use std::net::SocketAddr;
use std::sync::Arc;
use axum::Router;
use async_trait::async_trait;
use object_store::ObjectStore;
use reqwest::Url;
use tempfile::TempDir;
use alien_bindings::{
providers::{kv::LocalKv, storage::LocalStorage},
traits::{Kv, Storage},
};
use crate::{
server::{create_axum_router, CommandDispatcher, CommandServer, InMemoryCommandRegistry},
test_utils::{MockDispatcher, MockDispatcherMode},
types::*,
Result,
};
use alien_core::DeploymentModel;
/// An in-process `CommandServer` served over HTTP on an ephemeral `127.0.0.1` port, for tests.
///
/// Field order matters: Rust drops fields in declaration order, so `_temp_dir` (the temporary
/// directory that backs the default KV store and storage) is declared last and is therefore
/// dropped after every other field.
pub struct TestCommandServer {
    /// The server instance; the HTTP router is built over this same `Arc`.
    pub command_server: Arc<CommandServer>,
    /// Address the HTTP listener is bound to.
    pub server_addr: SocketAddr,
    /// Graceful-shutdown trigger; becomes `None` once `shutdown` or `Drop` has used it.
    pub shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
    /// KV store the server was built with.
    pub kv: Arc<LocalKv>,
    /// Object storage the server was built with.
    pub storage: Arc<LocalStorage>,
    /// Dispatcher the server was built with (a `MockDispatcher` unless overridden).
    pub dispatcher: Arc<dyn CommandDispatcher>,
    /// Owns the temporary directory so it lives as long as the test server.
    _temp_dir: TempDir,
}
impl TestCommandServer {
    /// Delay between status polls in `wait_for_state` and `wait_for_completion`.
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(10);

    /// Starts a test server with every component left at the builder's defaults.
    pub async fn new() -> Self {
        Self::builder().build().await
    }

    /// Returns a [`TestCommandServerBuilder`] for overriding the KV store, storage, or
    /// dispatcher before the server is started.
    pub fn builder() -> TestCommandServerBuilder {
        TestCommandServerBuilder::new()
    }

    /// Root URL of the HTTP listener, e.g. `http://127.0.0.1:<port>`.
    pub fn base_url(&self) -> String {
        format!("http://{}", self.server_addr)
    }

    /// URL of the `/v1/` API prefix (with trailing slash) under [`Self::base_url`].
    ///
    /// # Panics
    /// Panics if the base URL cannot be parsed or joined, which would indicate a bug.
    pub fn command_base_url(&self) -> String {
        let base = Url::parse(&self.base_url()).expect("Valid base URL");
        base.join("v1/").expect("Valid URL join").to_string()
    }

    /// Creates a command by calling the underlying [`CommandServer`] directly (no HTTP).
    pub async fn create_command(
        &self,
        request: CreateCommandRequest,
    ) -> Result<CreateCommandResponse> {
        self.command_server.create_command(request).await
    }

    /// Marks the upload for `command_id` as complete via the underlying [`CommandServer`].
    pub async fn upload_complete(
        &self,
        command_id: &str,
        upload_request: UploadCompleteRequest,
    ) -> Result<UploadCompleteResponse> {
        self.command_server
            .upload_complete(command_id, upload_request)
            .await
    }

    /// Fetches the current status of `command_id` from the underlying [`CommandServer`].
    pub async fn get_command_status(&self, command_id: &str) -> Result<CommandStatusResponse> {
        self.command_server.get_command_status(command_id).await
    }

    /// Submits `response` for `command_id` via the underlying [`CommandServer`].
    pub async fn submit_command_response(
        &self,
        command_id: &str,
        response: CommandResponse,
    ) -> Result<()> {
        self.command_server
            .submit_command_response(command_id, response)
            .await
    }

    /// Acquires leases for `deployment_id`.
    ///
    /// The request's own `deployment_id` field is overwritten with the argument so the two
    /// can never disagree.
    pub async fn acquire_lease(
        &self,
        deployment_id: &str,
        mut lease_request: LeaseRequest,
    ) -> Result<LeaseResponse> {
        lease_request.deployment_id = deployment_id.to_string();
        self.command_server
            .acquire_lease(deployment_id, &lease_request)
            .await
    }

    /// Acquires leases for `deployment_id` using a default [`LeaseRequest`] and returns the
    /// first lease granted, or `None` if none was.
    pub async fn acquire_single_lease(&self, deployment_id: &str) -> Result<Option<LeaseInfo>> {
        // `acquire_lease` fills in `deployment_id`, so the default request is enough here.
        let response = self
            .acquire_lease(deployment_id, LeaseRequest::default())
            .await?;
        Ok(response.leases.into_iter().next())
    }

    /// Releases lease `lease_id` held on `command_id`.
    pub async fn release_lease(&self, command_id: &str, lease_id: &str) -> Result<()> {
        self.command_server
            .release_lease(command_id, lease_id)
            .await
    }

    /// Signals the server task to shut down gracefully.
    ///
    /// Idempotent, and also done automatically on drop. Returns as soon as the signal is sent;
    /// it does not wait for the server task to finish.
    pub async fn shutdown(&mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            // The server task may already be gone, in which case there is nobody to notify.
            let _ = tx.send(());
        }
    }

    /// Polls until `command_id` reaches `expected_state` or `timeout` elapses.
    ///
    /// The status is always checked at least once, even for a zero timeout, and once more
    /// after the final sleep. Status-lookup errors are treated as "not yet" and polling
    /// continues. Returns `true` if the state was observed, `false` on timeout.
    pub async fn wait_for_state(
        &self,
        command_id: &str,
        expected_state: CommandState,
        timeout: std::time::Duration,
    ) -> bool {
        let start = std::time::Instant::now();
        loop {
            if let Ok(status) = self.get_command_status(command_id).await {
                if status.state == expected_state {
                    return true;
                }
            }
            if start.elapsed() >= timeout {
                return false;
            }
            tokio::time::sleep(Self::POLL_INTERVAL).await;
        }
    }

    /// Polls until `command_id` reaches a terminal state (per `CommandState::is_terminal`)
    /// and returns that status.
    ///
    /// The status is always checked at least once, even for a zero timeout, and once more
    /// after the final sleep.
    ///
    /// # Errors
    /// Propagates any status-lookup error immediately. Returns an `Other` error if no
    /// terminal state is seen within `timeout`.
    pub async fn wait_for_completion(
        &self,
        command_id: &str,
        timeout: std::time::Duration,
    ) -> Result<CommandStatusResponse> {
        let start = std::time::Instant::now();
        loop {
            let status = self.get_command_status(command_id).await?;
            if status.state.is_terminal() {
                return Ok(status);
            }
            if start.elapsed() >= timeout {
                break;
            }
            tokio::time::sleep(Self::POLL_INTERVAL).await;
        }
        Err(alien_error::AlienError::new(crate::ErrorData::Other {
            message: format!("Command {} did not complete within timeout", command_id),
        }))
    }

    /// Best-effort reset between test cases.
    ///
    /// Clears the KV store, ignoring any error, and calls `clear` on the dispatcher when it is
    /// a [`MockDispatcher`]. Objects in storage are NOT removed.
    pub async fn reset(&self) {
        let _ = self.kv.clear().await;
        if let Some(mock_dispatcher) = self.mock_dispatcher() {
            mock_dispatcher.clear().await;
        }
    }

    /// Number of KV keys that start with `cmd:` and do not contain `:lease`.
    ///
    /// A KV listing error is treated as an empty store.
    pub async fn command_count(&self) -> usize {
        let keys = self.kv.keys().await.unwrap_or_default();
        keys.iter()
            .filter(|k| k.starts_with("cmd:") && !k.contains(":lease"))
            .count()
    }

    /// Number of entries yielded by listing the whole storage.
    ///
    /// Listing errors are counted as entries too, so a failing store never looks clean.
    pub async fn storage_object_count(&self) -> usize {
        let mut count = 0;
        let mut stream = self.storage.list(None);
        while futures::stream::StreamExt::next(&mut stream)
            .await
            .is_some()
        {
            count += 1;
        }
        count
    }

    /// The dispatcher downcast to [`MockDispatcher`], or `None` if a different dispatcher
    /// was injected.
    pub fn mock_dispatcher(&self) -> Option<&MockDispatcher> {
        self.dispatcher.as_any().downcast_ref::<MockDispatcher>()
    }

    /// `true` when there are no command keys in KV and no objects in storage.
    pub async fn is_clean(&self) -> bool {
        self.command_count().await == 0 && self.storage_object_count().await == 0
    }
}
/// Builder for [`TestCommandServer`]. Any component left unset gets a default in `build`.
pub struct TestCommandServerBuilder {
    /// KV store override; by default `build` creates a `LocalKv` in the server's temp directory.
    kv: Option<Arc<LocalKv>>,
    /// Storage override; by default `build` creates a `LocalStorage` in the server's temp directory.
    storage: Option<Arc<LocalStorage>>,
    /// Dispatcher override; by default `build` uses `MockDispatcher::new()`.
    dispatcher: Option<Arc<dyn CommandDispatcher>>,
}
impl TestCommandServerBuilder {
    /// Creates a builder with no overrides; `build` fills in every default.
    fn new() -> Self {
        Self {
            kv: None,
            storage: None,
            dispatcher: None,
        }
    }

    /// Uses `kv` instead of a fresh temporary `LocalKv`.
    pub fn with_kv(mut self, kv: Arc<LocalKv>) -> Self {
        self.kv = Some(kv);
        self
    }

    /// Uses `storage` instead of a fresh temporary `LocalStorage`.
    pub fn with_storage(mut self, storage: Arc<LocalStorage>) -> Self {
        self.storage = Some(storage);
        self
    }

    /// Uses `dispatcher` instead of the default `MockDispatcher::new()`.
    ///
    /// For dispatchers that are not a [`MockDispatcher`], the command registry is configured
    /// with `DeploymentModel::Pull`.
    pub fn with_dispatcher(mut self, dispatcher: Arc<dyn CommandDispatcher>) -> Self {
        self.dispatcher = Some(dispatcher);
        self
    }

    /// Uses a pull-mode [`MockDispatcher`], so the registry uses `DeploymentModel::Pull`.
    pub fn with_pull_mode(mut self) -> Self {
        self.dispatcher = Some(Arc::new(MockDispatcher::new_pull()) as Arc<dyn CommandDispatcher>);
        self
    }

    /// Creates any missing components, binds `127.0.0.1:0`, and spawns the HTTP server
    /// (API mounted under `/v1`) on the current Tokio runtime.
    ///
    /// # Panics
    /// Panics if the temp directory, KV store, storage, or listener cannot be created; this
    /// is test scaffolding, so setup failures should abort the test loudly.
    pub async fn build(self) -> TestCommandServer {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");

        // KV and storage live in disjoint locations inside the temp dir. If storage were
        // rooted at the temp dir itself, files written by the KV backend would show up in
        // storage listings and skew `storage_object_count` / `is_clean`.
        let kv = match self.kv {
            Some(kv) => kv,
            None => Arc::new(
                LocalKv::new(temp_dir.path().join("kv.db"))
                    .await
                    .expect("Failed to create LocalKv for testing"),
            ),
        };

        let storage = match self.storage {
            Some(storage) => storage,
            None => {
                let storage_dir = temp_dir.path().join("storage");
                // Create the root up front in case the storage backend requires it to exist.
                std::fs::create_dir_all(&storage_dir)
                    .expect("Failed to create storage directory for testing");
                let storage_path = storage_dir
                    .to_str()
                    .expect("Temp directory path is not valid UTF-8");
                Arc::new(
                    LocalStorage::new_from_path(storage_path)
                        .expect("Failed to create LocalStorage for testing"),
                )
            }
        };

        let dispatcher = self
            .dispatcher
            .unwrap_or_else(|| Arc::new(MockDispatcher::new()) as Arc<dyn CommandDispatcher>);

        // Mirror a mock dispatcher's mode in the registry; anything else defaults to Pull.
        let deployment_model = dispatcher
            .as_any()
            .downcast_ref::<MockDispatcher>()
            .map(|d| match d.mode() {
                MockDispatcherMode::Push => DeploymentModel::Push,
                MockDispatcherMode::Pull => DeploymentModel::Pull,
            })
            .unwrap_or(DeploymentModel::Pull);

        // Port 0 lets the OS pick a free port, so tests can run in parallel.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind to port");
        let server_addr = listener.local_addr().expect("Failed to get local address");
        let base_url = format!("http://{}", server_addr);
        let command_base_url = {
            let base = Url::parse(&base_url).expect("Valid base URL");
            base.join("v1/").expect("Valid URL join").to_string()
        };

        let command_server = Arc::new(CommandServer::new(
            kv.clone() as Arc<dyn Kv>,
            storage.clone() as Arc<dyn Storage>,
            dispatcher.clone(),
            Arc::new(InMemoryCommandRegistry::with_deployment_model(
                deployment_model,
            )),
            command_base_url,
        ));

        let commands_router: Router<Arc<CommandServer>> = create_axum_router();
        let router = Router::new()
            .nest("/v1", commands_router)
            .with_state(command_server.clone());

        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            axum::serve(listener, router)
                .with_graceful_shutdown(async {
                    // Resolves on an explicit signal or when the sender is dropped.
                    shutdown_rx.await.ok();
                })
                .await
                .expect("Server failed");
        });

        // The listener is already bound, so early connections queue in the backlog; this
        // short pause just gives the spawned server task a chance to start polling.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        TestCommandServer {
            command_server,
            server_addr,
            shutdown_tx: Some(shutdown_tx),
            kv,
            storage,
            dispatcher,
            _temp_dir: temp_dir,
        }
    }
}
impl Drop for TestCommandServer {
    /// Sends the graceful-shutdown signal if `shutdown` has not already done so.
    fn drop(&mut self) {
        let Some(sender) = self.shutdown_tx.take() else {
            return;
        };
        // A send error only means the server task has already exited; nothing to do.
        sender.send(()).ok();
    }
}
/// Panicking assertion helpers over a [`TestCommandServer`]'s observable state.
#[async_trait]
pub trait TestCommandServerAssertions {
    /// Asserts that `command_id` is currently in `expected_state`.
    async fn assert_command_state(&self, command_id: &str, expected_state: CommandState);

    /// Asserts that `command_id` is in `CommandState::Succeeded`.
    async fn assert_command_succeeded(&self, command_id: &str);

    /// Asserts that `command_id` is in `CommandState::Failed`.
    async fn assert_command_failed(&self, command_id: &str);

    /// Asserts that exactly `expected` commands are stored.
    async fn assert_command_count(&self, expected: usize);

    /// Asserts that exactly `expected` storage objects exist.
    async fn assert_storage_count(&self, expected: usize);

    /// Asserts that no commands and no storage objects remain.
    async fn assert_clean(&self);
}
#[async_trait]
impl TestCommandServerAssertions for TestCommandServer {
    /// Fetches the status of `command_id` and asserts it equals `expected_state`.
    ///
    /// # Panics
    /// Panics, including the underlying error, if the status cannot be fetched, and panics
    /// if the state differs.
    async fn assert_command_state(&self, command_id: &str, expected_state: CommandState) {
        let status = self
            .get_command_status(command_id)
            .await
            .unwrap_or_else(|e| {
                panic!("Failed to get status for command {}: {:?}", command_id, e)
            });
        assert_eq!(
            status.state, expected_state,
            "Command {} expected to be in state {:?}, but was {:?}",
            command_id, expected_state, status.state
        );
    }

    /// Asserts that `command_id` is in `CommandState::Succeeded`.
    async fn assert_command_succeeded(&self, command_id: &str) {
        self.assert_command_state(command_id, CommandState::Succeeded)
            .await;
    }

    /// Asserts that `command_id` is in `CommandState::Failed`.
    async fn assert_command_failed(&self, command_id: &str) {
        self.assert_command_state(command_id, CommandState::Failed)
            .await;
    }

    /// Asserts that `command_count` and `storage_object_count` are both zero.
    async fn assert_clean(&self) {
        // Compute each count once so the failure message reports exactly what was checked.
        let commands = self.command_count().await;
        let objects = self.storage_object_count().await;
        assert!(
            commands == 0 && objects == 0,
            "Expected server state to be clean, but found {} commands and {} storage objects",
            commands,
            objects
        );
    }

    /// Asserts that `command_count` equals `expected`.
    async fn assert_command_count(&self, expected: usize) {
        let actual = self.command_count().await;
        assert_eq!(
            actual, expected,
            "Expected {} commands in KV store, but found {}",
            expected, actual
        );
    }

    /// Asserts that `storage_object_count` equals `expected`.
    async fn assert_storage_count(&self, expected: usize) {
        let actual = self.storage_object_count().await;
        assert_eq!(
            actual, expected,
            "Expected {} objects in storage, but found {}",
            expected, actual
        );
    }
}