use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use aion::ActivityDispatcher as _;
use aion_proto::generated::worker_protocol_client::WorkerProtocolClient;
use aion_proto::generated::{self, server_to_worker, worker_to_server};
use aion_server::ServerState;
use aion_server::api::worker_grpc::worker_service;
use aion_server::config::{
AuthConfig, DashboardAssetSource, DashboardConfig, DeployConfig, ListenConfig, MetricsConfig,
NamespaceConfig, NamespaceMode, RuntimeConfig, WebSocketConfig, WorkerConfig,
};
use aion_server::worker::{ConnectedWorkerRegistry, WorkerActivityDispatcher};
use aion_server::{NamespaceResolver, StaticScheduleNamespaces, StaticWorkflowNamespaces};
use tokio::net::TcpListener;
use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream};
/// Boxed error type that lets the tests use `?` on transport, gRPC-status, channel, and
/// plain-string errors alike.
type TestError = Box<dyn std::error::Error>;
/// Namespace the harness worker registers under and that the dispatcher targets.
const NAMESPACE: &str = "default";
/// The single activity type the harness worker advertises at registration.
const ACTIVITY_TYPE: &str = "greet";
/// Builds the [`RuntimeConfig`] handed to every test server in this file.
///
/// Both listen addresses are loopback with port 0; TLS, auth, and metrics are disabled.
/// The worker heartbeat window is 30 000 ms, which is the value
/// `register_ack_is_first_frame_then_task` expects to see echoed in the `RegisterAck`.
fn runtime_config() -> RuntimeConfig {
    // `SocketAddr` is `Copy`, so one value serves both listeners.
    let loopback_any_port = SocketAddr::from(([127, 0, 0, 1], 0));
    let listen = ListenConfig {
        grpc: loopback_any_port,
        http: loopback_any_port,
    };
    let auth = AuthConfig {
        enabled: false,
        jwks_url: None,
        jwks_refresh_seconds: 300,
    };
    let dashboard = DashboardConfig {
        source: DashboardAssetSource::Embedded,
    };
    let namespace = NamespaceConfig {
        mode: NamespaceMode::SharedEngine,
    };
    let worker = WorkerConfig {
        heartbeat_window: Duration::from_millis(30_000),
    };
    let websocket = WebSocketConfig {
        outbound_buffer_bound: 32,
        event_broadcast_capacity: Some(64),
    };

    RuntimeConfig {
        listen,
        tls: None,
        auth,
        dashboard,
        namespace,
        worker,
        websocket,
        workflow_packages: Vec::new(),
        deploy: DeployConfig::default(),
        scheduler_threads: 1,
        query_timeout: Some(Duration::from_millis(10_000)),
        default_namespace: NAMESPACE.to_owned(),
        drain_timeout: Duration::from_secs(30),
        metrics: MetricsConfig { enabled: false },
    }
}
/// In-process test fixture: a tonic worker-protocol server plus one connected, registered
/// worker stream, as assembled by `Harness::start`.
struct Harness {
    /// Server state the gRPC service was built from; also the source of the pending,
    /// heartbeat, and drain handles given to dispatchers.
    state: ServerState,
    /// Handle to the worker registry; a clone of it was passed into `state`.
    registry: ConnectedWorkerRegistry,
    /// Sender side of the worker-to-server request stream.
    worker_tx: tokio::sync::mpsc::Sender<generated::WorkerToServer>,
    /// Server-to-worker response stream, positioned just after the `RegisterAck`.
    inbound: tonic::Streaming<generated::ServerToWorker>,
    /// The `RegisterAck` that arrived as the first response frame.
    register_ack: generated::RegisterAck,
    /// Join handle of the spawned server task; tests abort it when they finish.
    server: tokio::task::JoinHandle<Result<(), tonic::transport::Error>>,
}
impl Harness {
    /// Starts a worker-protocol server on an ephemeral loopback port, connects a client,
    /// registers one worker for `ACTIVITY_TYPE` in `NAMESPACE`, and consumes the
    /// `RegisterAck`.
    ///
    /// # Errors
    ///
    /// Fails if any transport step fails, if the first response frame is missing or is not
    /// a `RegisterAck`, or if the registry has no worker for the registration by the time
    /// the ack has been received.
    async fn start() -> Result<Self, TestError> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let address = listener.local_addr()?;

        let registry = ConnectedWorkerRegistry::default();
        let state = ServerState::from_parts_with_registry(
            NamespaceResolver::authorization_only(
                NamespaceMode::SharedEngine,
                StaticWorkflowNamespaces::default(),
                StaticScheduleNamespaces::default(),
            ),
            runtime_config(),
            registry.clone(),
        );

        let service = worker_service(state.clone());
        let incoming = TcpListenerStream::new(listener);
        let server = tokio::spawn(
            tonic::transport::Server::builder()
                .add_service(service)
                .serve_with_incoming(incoming),
        );

        let mut client = WorkerProtocolClient::connect(format!("http://{address}")).await?;

        // Queue the registration frame before the stream is opened so it is the first
        // message on the request stream.
        let registration = generated::WorkerToServer {
            message: Some(worker_to_server::Message::Register(
                generated::RegisterWorker {
                    namespace: NAMESPACE.to_owned(),
                    activity_types: vec![ACTIVITY_TYPE.to_owned()],
                },
            )),
        };
        let (worker_tx, worker_rx) = tokio::sync::mpsc::channel::<generated::WorkerToServer>(8);
        worker_tx.send(registration).await?;

        let mut request = tonic::Request::new(ReceiverStream::new(worker_rx));
        request
            .metadata_mut()
            .insert("x-aion-namespaces", NAMESPACE.parse()?);
        let mut inbound = client.stream_worker(request).await?.into_inner();

        let Some(first) = inbound.message().await?.and_then(|frame| frame.message) else {
            return Err("response stream ended before the RegisterAck".into());
        };
        let register_ack = match first {
            server_to_worker::Message::RegisterAck(ack) => ack,
            first => {
                return Err(
                    format!("first response frame must be RegisterAck, got {first:?}").into(),
                );
            }
        };

        // The ack must not be observable before the worker is in the registry.
        if registry.workers_for(NAMESPACE, ACTIVITY_TYPE)?.is_empty() {
            return Err("RegisterAck arrived before the registry registration".into());
        }

        Ok(Self {
            state,
            registry,
            worker_tx,
            inbound,
            register_ack,
            server,
        })
    }

    /// Builds a dispatcher for `NAMESPACE` over the harness registry, wired to the server
    /// state's pending-activity, heartbeat-tracker, and drain-state handles.
    fn dispatcher(&self) -> WorkerActivityDispatcher {
        let pending = self.state.pending_activities().clone();
        let heartbeats = self.state.heartbeat_tracker().clone();
        let drain = self.state.drain_state().clone();
        WorkerActivityDispatcher::new(self.registry.clone(), NAMESPACE)
            .with_pending(pending)
            .with_heartbeat_tracker(heartbeats)
            .with_drain_state(drain)
    }

    /// Reads response frames until `select` accepts one, discarding every other frame.
    ///
    /// Returns `closed_message` as the error if the stream ends first.
    async fn next_matching<T>(
        &mut self,
        mut select: impl FnMut(server_to_worker::Message) -> Option<T>,
        closed_message: &'static str,
    ) -> Result<T, TestError> {
        while let Some(frame) = self.inbound.message().await? {
            if let Some(found) = frame.message.and_then(&mut select) {
                return Ok(found);
            }
        }
        Err(closed_message.into())
    }

    /// Waits for the next `Task` frame, skipping any other frames.
    async fn next_task(&mut self) -> Result<generated::ActivityTask, TestError> {
        self.next_matching(
            |message| match message {
                server_to_worker::Message::Task(task) => Some(task),
                _ => None,
            },
            "worker stream closed before a task was delivered",
        )
        .await
    }

    /// Waits for the next `ResultAck` frame, skipping any other frames.
    async fn next_result_ack(&mut self) -> Result<generated::ResultAck, TestError> {
        self.next_matching(
            |message| match message {
                server_to_worker::Message::ResultAck(ack) => Some(ack),
                _ => None,
            },
            "worker stream closed before a result ack was delivered",
        )
        .await
    }

    /// Waits for the next `Drain` frame, skipping any other frames.
    async fn next_drain(&mut self) -> Result<(), TestError> {
        self.next_matching(
            |message| match message {
                server_to_worker::Message::Drain(_) => Some(()),
                _ => None,
            },
            "worker stream closed before a drain request was delivered",
        )
        .await
    }

    /// Sends a successful `ActivityResult` for `task` with `result_json` as an
    /// `application/json` payload.
    async fn complete(
        &self,
        task: generated::ActivityTask,
        result_json: &[u8],
    ) -> Result<(), TestError> {
        let payload = generated::Payload {
            content_type: "application/json".to_owned(),
            bytes: result_json.to_vec(),
        };
        let result = generated::ActivityResult {
            workflow_id: task.workflow_id,
            activity_id: task.activity_id,
            outcome: Some(generated::activity_result::Outcome::Result(payload)),
        };
        let frame = generated::WorkerToServer {
            message: Some(worker_to_server::Message::Result(result)),
        };
        self.worker_tx.send(frame).await?;
        Ok(())
    }
}
/// Dispatches one activity, checks the task reaches the worker stream within 5s with the
/// requested attempt number, completes it, and checks the dispatcher returns the worker's
/// payload within 5s overall.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_dispatch_delivers_task_promptly_and_round_trips() -> Result<(), TestError> {
    let mut harness = Harness::start().await?;
    let dispatcher = Arc::new(harness.dispatcher());
    let started = Instant::now();
    // `dispatch` is synchronous and does not return until the worker answers, so it runs
    // on the blocking pool. Running it inside an async task would park one of the two
    // runtime workers that the in-process server also depends on.
    let dispatch_task = tokio::task::spawn_blocking(move || {
        dispatcher.dispatch(ACTIVITY_TYPE, r#"{"name":"world"}"#, "{}", 3)
    });

    let task = harness.next_task().await?;
    let delivery_elapsed = started.elapsed();
    assert_eq!(task.activity_type, ACTIVITY_TYPE);
    assert_eq!(
        task.attempt, 3,
        "the engine-seam attempt must be stamped onto the wire task"
    );
    assert!(
        delivery_elapsed < Duration::from_secs(5),
        "task took {delivery_elapsed:?} to reach the worker stream; delivery \
         must not be coupled to the dispatch timeout"
    );

    harness
        .complete(task, br#"{"greeting":"hello world"}"#)
        .await?;
    let result = dispatch_task.await.map_err(|error| error.to_string())?;
    let round_trip_elapsed = started.elapsed();
    assert_eq!(result, Ok(r#"{"greeting":"hello world"}"#.to_owned()));
    assert!(
        round_trip_elapsed < Duration::from_secs(5),
        "dispatch round trip took {round_trip_elapsed:?}"
    );

    harness.server.abort();
    Ok(())
}
/// With a 2s dispatch timeout and a worker that never replies, checks that the task still
/// reaches the worker stream in under 1s and that the dispatcher then fails with a
/// "timed out after 2s" error.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn dispatch_times_out_only_when_worker_stays_silent() -> Result<(), TestError> {
    let mut harness = Harness::start().await?;
    let dispatcher = Arc::new(harness.dispatcher().with_timeout(Duration::from_secs(2)));
    let started = Instant::now();
    // `dispatch` blocks for the full timeout here, so keep it off the async workers:
    // the delivery-latency assertion below must measure the server, not a runtime
    // worker starved by this very call.
    let dispatch_task =
        tokio::task::spawn_blocking(move || dispatcher.dispatch(ACTIVITY_TYPE, "{}", "{}", 1));

    let task = harness.next_task().await?;
    let delivery_elapsed = started.elapsed();
    assert_eq!(task.activity_type, ACTIVITY_TYPE);
    assert!(
        delivery_elapsed < Duration::from_secs(1),
        "task took {delivery_elapsed:?} to reach the worker stream; with the \
         dispatch stall defect it would only arrive when the 2s timeout fired"
    );

    // The worker deliberately never completes the task.
    let result = dispatch_task.await.map_err(|error| error.to_string())?;
    let error = result.err().ok_or("expected dispatch to time out")?;
    assert!(
        error.contains("timed out after 2s"),
        "unexpected error: {error}"
    );

    harness.server.abort();
    Ok(())
}
/// Checks the `RegisterAck` captured by the harness (namespace, 30 000 ms heartbeat window,
/// positive worker id) and then that a task dispatched afterwards round-trips normally.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn register_ack_is_first_frame_then_task() -> Result<(), TestError> {
    let mut harness = Harness::start().await?;
    assert_eq!(harness.register_ack.namespace, NAMESPACE);
    assert_eq!(
        harness.register_ack.heartbeat_window_ms, 30_000,
        "the ack must carry the operator-configured heartbeat window"
    );
    assert!(
        harness.register_ack.worker_id > 0,
        "the ack must carry the server-assigned worker id"
    );

    let dispatcher = Arc::new(harness.dispatcher());
    // `dispatch` is a synchronous call that waits for the worker's result; run it on the
    // blocking pool so it does not occupy an async runtime worker.
    let dispatch_task =
        tokio::task::spawn_blocking(move || dispatcher.dispatch(ACTIVITY_TYPE, "{}", "{}", 1));

    let task = harness.next_task().await?;
    assert_eq!(task.activity_type, ACTIVITY_TYPE);
    harness.complete(task, br#"{"greeting":"hi"}"#).await?;
    let result = dispatch_task.await.map_err(|error| error.to_string())?;
    assert_eq!(result, Ok(r#"{"greeting":"hi"}"#.to_owned()));

    harness.server.abort();
    Ok(())
}
/// Registers for a namespace ("ungranted") that the request metadata does not list and
/// checks that the RPC fails with `PermissionDenied`, that no response frame is delivered
/// first, and that the registry holds no worker for that namespace.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn denied_registration_fails_rpc_without_frames() -> Result<(), TestError> {
    // Server setup is done by hand rather than via `Harness::start`, which requires a
    // successful registration.
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let address = listener.local_addr()?;
    let registry = ConnectedWorkerRegistry::default();
    let state = ServerState::from_parts_with_registry(
        NamespaceResolver::authorization_only(
            NamespaceMode::SharedEngine,
            StaticWorkflowNamespaces::default(),
            StaticScheduleNamespaces::default(),
        ),
        runtime_config(),
        registry.clone(),
    );
    let incoming = TcpListenerStream::new(listener);
    let server = tokio::spawn(
        tonic::transport::Server::builder()
            .add_service(worker_service(state))
            .serve_with_incoming(incoming),
    );

    let mut client = WorkerProtocolClient::connect(format!("http://{address}")).await?;
    let registration = generated::WorkerToServer {
        message: Some(worker_to_server::Message::Register(
            generated::RegisterWorker {
                namespace: "ungranted".to_owned(),
                activity_types: vec![ACTIVITY_TYPE.to_owned()],
            },
        )),
    };
    let (worker_tx, worker_rx) = tokio::sync::mpsc::channel::<generated::WorkerToServer>(8);
    worker_tx.send(registration).await?;

    // The metadata grants only `NAMESPACE`, not the namespace being registered.
    let mut request = tonic::Request::new(ReceiverStream::new(worker_rx));
    request
        .metadata_mut()
        .insert("x-aion-namespaces", NAMESPACE.parse()?);

    // The denial may surface either as the RPC's own status or as the first stream item.
    let denial = match client.stream_worker(request).await {
        Err(status) => status,
        Ok(mut response) => match response.get_mut().message().await {
            Err(status) => status,
            Ok(None) => {
                return Err("denied registration ended the stream without a status".into());
            }
            Ok(Some(frame)) => {
                return Err(format!("denied registration delivered a frame: {frame:?}").into());
            }
        },
    };

    assert_eq!(denial.code(), tonic::Code::PermissionDenied);
    assert!(registry.workers_for("ungranted", ACTIVITY_TYPE)?.is_empty());

    server.abort();
    Ok(())
}
/// Completes one task twice and checks that both the first result and the duplicate are
/// answered with a `ResultAck` carrying the task's workflow and activity ids, while the
/// dispatcher receives the first result's payload.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn result_frames_are_acked_including_duplicates() -> Result<(), TestError> {
    let mut harness = Harness::start().await?;
    let dispatcher = Arc::new(harness.dispatcher());
    // `dispatch` is a synchronous call that waits for the worker's result; run it on the
    // blocking pool so it does not occupy an async runtime worker.
    let dispatch_task =
        tokio::task::spawn_blocking(move || dispatcher.dispatch(ACTIVITY_TYPE, "{}", "{}", 1));

    let task = harness.next_task().await?;
    let workflow_id = task.workflow_id.clone();
    let activity_id = task.activity_id;

    // First result: acked and delivered to the waiting dispatcher.
    harness.complete(task.clone(), br#"{"ok":true}"#).await?;
    let ack = harness.next_result_ack().await?;
    assert_eq!(ack.workflow_id, workflow_id);
    assert_eq!(ack.activity_id, activity_id);
    let result = dispatch_task.await.map_err(|error| error.to_string())?;
    assert_eq!(result, Ok(r#"{"ok":true}"#.to_owned()));

    // Duplicate result for the same task: still acked.
    harness.complete(task, br#"{"ok":true}"#).await?;
    let duplicate_ack = harness.next_result_ack().await?;
    assert_eq!(duplicate_ack.workflow_id, workflow_id);
    assert_eq!(duplicate_ack.activity_id, activity_id);

    harness.server.abort();
    Ok(())
}
/// Sends a result frame with no `activity_id`, then runs a normal dispatch round trip and
/// checks that the first `ResultAck` seen on the stream belongs to the well-formed result,
/// i.e. the malformed frame was not acked and did not break the stream.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn malformed_result_gets_no_ack_and_stream_stays_healthy() -> Result<(), TestError> {
    let mut harness = Harness::start().await?;

    // Malformed: a result that names a workflow but carries no activity id.
    let malformed = generated::WorkerToServer {
        message: Some(worker_to_server::Message::Result(
            generated::ActivityResult {
                workflow_id: Some(generated::WorkflowId {
                    uuid: "00000000-0000-0000-0000-000000000000".to_owned(),
                }),
                activity_id: None,
                outcome: Some(generated::activity_result::Outcome::Result(
                    generated::Payload {
                        content_type: "application/json".to_owned(),
                        bytes: b"{}".to_vec(),
                    },
                )),
            },
        )),
    };
    harness.worker_tx.send(malformed).await?;

    let dispatcher = Arc::new(harness.dispatcher());
    // `dispatch` is a synchronous call that waits for the worker's result; run it on the
    // blocking pool so it does not occupy an async runtime worker.
    let dispatch_task =
        tokio::task::spawn_blocking(move || dispatcher.dispatch(ACTIVITY_TYPE, "{}", "{}", 1));

    let task = harness.next_task().await?;
    let workflow_id = task.workflow_id.clone();
    harness.complete(task, br#"{"ok":true}"#).await?;
    let ack = harness.next_result_ack().await?;
    assert_eq!(
        ack.workflow_id, workflow_id,
        "the only ack on the stream must belong to the well-formed result"
    );
    let result = dispatch_task.await.map_err(|error| error.to_string())?;
    assert_eq!(result, Ok(r#"{"ok":true}"#.to_owned()));

    harness.server.abort();
    Ok(())
}
/// Begins a drain, broadcasts it, and checks that exactly one worker was notified, that
/// the `Drain` frame arrives on the worker stream, and that a dispatch attempted afterwards
/// is rejected with an error mentioning "draining".
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn drain_broadcast_reaches_worker_and_gates_dispatch() -> Result<(), TestError> {
    let mut harness = Harness::start().await?;

    assert!(harness.state.drain_state().begin());
    let delivered = harness.state.worker_registry().broadcast_drain()?;
    assert_eq!(delivered, 1);
    harness.next_drain().await?;

    // `dispatch` is synchronous, so it runs on the blocking pool.
    let dispatcher = harness.dispatcher();
    let outcome =
        tokio::task::spawn_blocking(move || dispatcher.dispatch(ACTIVITY_TYPE, "{}", "{}", 1))
            .await
            .map_err(|join_error| join_error.to_string())?;
    let Err(error) = outcome else {
        return Err("post-drain dispatch must be rejected".into());
    };
    assert!(
        error.contains("draining"),
        "rejection must name the drain gate: {error}"
    );

    harness.server.abort();
    Ok(())
}