use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use aion::ActivityDispatcher as _;
use aion_proto::generated::worker_protocol_client::WorkerProtocolClient;
use aion_proto::generated::{self, server_to_worker, worker_to_server};
use aion_server::ServerState;
use aion_server::api::worker_grpc::worker_service;
use aion_server::config::{
AuthConfig, DashboardAssetSource, DashboardConfig, ListenConfig, MetricsConfig,
NamespaceConfig, NamespaceMode, RuntimeConfig, WebSocketConfig, WorkerConfig,
};
use aion_server::worker::{ConnectedWorkerRegistry, WorkerActivityDispatcher};
use aion_server::{NamespaceResolver, StaticScheduleNamespaces, StaticWorkflowNamespaces};
use tokio::net::TcpListener;
use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream};
type TestError = Box<dyn std::error::Error>;
const NAMESPACE: &str = "default";
const ACTIVITY_TYPE: &str = "greet";
/// Builds the [`RuntimeConfig`] shared by every test in this file.
///
/// Both listen addresses are loopback with port 0. TLS, auth and metrics are
/// disabled, the namespace mode is `SharedEngine`, and `default_namespace` is
/// [`NAMESPACE`], the namespace the test worker registers under.
fn runtime_config() -> RuntimeConfig {
    let loopback = SocketAddr::from(([127, 0, 0, 1], 0));
    let listen = ListenConfig {
        grpc: loopback,
        http: loopback,
    };
    let auth = AuthConfig {
        enabled: false,
        jwks_url: None,
        jwks_refresh_seconds: 300,
    };
    RuntimeConfig {
        listen,
        tls: None,
        auth,
        dashboard: DashboardConfig {
            source: DashboardAssetSource::Embedded,
        },
        namespace: NamespaceConfig {
            mode: NamespaceMode::SharedEngine,
        },
        worker: WorkerConfig {
            heartbeat_window: Duration::from_secs(30),
        },
        websocket: WebSocketConfig {
            outbound_buffer_bound: 32,
            event_broadcast_capacity: Some(64),
        },
        workflow_packages: vec![],
        scheduler_threads: 1,
        query_timeout: Some(Duration::from_secs(10)),
        default_namespace: String::from(NAMESPACE),
        drain_timeout: Duration::from_secs(30),
        metrics: MetricsConfig { enabled: false },
    }
}
/// End-to-end fixture: an in-process worker gRPC server plus one connected
/// worker registered for [`ACTIVITY_TYPE`] in [`NAMESPACE`].
struct Harness {
    /// Server state handed to the gRPC service; also supplies the dispatcher's
    /// pending activities, heartbeat tracker and drain state.
    state: ServerState,
    /// Worker registry; a clone of it is inside `state`, and the harness polls
    /// this handle to observe registrations made by the server.
    registry: ConnectedWorkerRegistry,
    /// Worker -> server half of the bidirectional stream. While this sender is
    /// alive the request stream stays open, so the worker can stay "silent".
    worker_tx: tokio::sync::mpsc::Sender<generated::WorkerToServer>,
    /// Server -> worker half of the bidirectional stream.
    inbound: tonic::Streaming<generated::ServerToWorker>,
    /// Spawned server task; the tests abort it when they finish.
    server: tokio::task::JoinHandle<Result<(), tonic::transport::Error>>,
}
impl Harness {
async fn start() -> Result<Self, TestError> {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let address = listener.local_addr()?;
let registry = ConnectedWorkerRegistry::default();
let resolver = NamespaceResolver::authorization_only(
NamespaceMode::SharedEngine,
StaticWorkflowNamespaces::default(),
StaticScheduleNamespaces::default(),
);
let state =
ServerState::from_parts_with_registry(resolver, runtime_config(), registry.clone());
let server = tokio::spawn(
tonic::transport::Server::builder()
.add_service(worker_service(state.clone()))
.serve_with_incoming(TcpListenerStream::new(listener)),
);
let mut client = WorkerProtocolClient::connect(format!("http://{address}")).await?;
let (worker_tx, worker_rx) = tokio::sync::mpsc::channel::<generated::WorkerToServer>(8);
worker_tx
.send(generated::WorkerToServer {
message: Some(worker_to_server::Message::Register(
generated::RegisterWorker {
namespace: NAMESPACE.to_owned(),
activity_types: vec![ACTIVITY_TYPE.to_owned()],
},
)),
})
.await?;
let mut request = tonic::Request::new(ReceiverStream::new(worker_rx));
request
.metadata_mut()
.insert("x-aion-namespaces", NAMESPACE.parse()?);
let inbound = client.stream_worker(request).await?.into_inner();
let deadline = Instant::now() + Duration::from_secs(10);
while registry.workers_for(NAMESPACE, ACTIVITY_TYPE)?.is_empty() {
if Instant::now() >= deadline {
return Err("worker registration did not reach the registry".into());
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
Ok(Self {
state,
registry,
worker_tx,
inbound,
server,
})
}
fn dispatcher(&self) -> WorkerActivityDispatcher {
WorkerActivityDispatcher::new(self.registry.clone(), NAMESPACE)
.with_pending(self.state.pending_activities().clone())
.with_heartbeat_tracker(self.state.heartbeat_tracker().clone())
.with_drain_state(self.state.drain_state().clone())
}
async fn next_task(&mut self) -> Result<generated::ActivityTask, TestError> {
while let Some(message) = self.inbound.message().await? {
if let Some(server_to_worker::Message::Task(task)) = message.message {
return Ok(task);
}
}
Err("worker stream closed before a task was delivered".into())
}
async fn complete(
&self,
task: generated::ActivityTask,
result_json: &[u8],
) -> Result<(), TestError> {
self.worker_tx
.send(generated::WorkerToServer {
message: Some(worker_to_server::Message::Result(
generated::ActivityResult {
workflow_id: task.workflow_id,
activity_id: task.activity_id,
outcome: Some(generated::activity_result::Outcome::Result(
generated::Payload {
content_type: "application/json".to_owned(),
bytes: result_json.to_vec(),
},
)),
},
)),
})
.await?;
Ok(())
}
}
/// Checks that a dispatch reaches the worker stream promptly. It also checks
/// that the worker's reply comes back as the dispatch's return value.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_dispatch_delivers_task_promptly_and_round_trips() -> Result<(), TestError> {
    const REPLY_JSON: &str = r#"{"greeting":"hello world"}"#;

    let mut harness = Harness::start().await?;
    let dispatcher = Arc::new(harness.dispatcher());
    let started = Instant::now();
    // `dispatch` returns its result synchronously. Wrapping it in `lazy`
    // makes it run inside a task on a runtime worker thread, not on the
    // blocking pool. That placement is presumably deliberate, to reproduce
    // the delivery stall the assertion below guards against.
    let dispatch_task = tokio::spawn(futures::future::lazy(move |_| {
        dispatcher.dispatch(ACTIVITY_TYPE, r#"{"name":"world"}"#, "{}")
    }));

    // Phase 1: the task must arrive well before any dispatch timeout.
    let task = harness.next_task().await?;
    let delivery_elapsed = started.elapsed();
    assert_eq!(task.activity_type, ACTIVITY_TYPE);
    assert!(
        delivery_elapsed < Duration::from_secs(5),
        "task took {delivery_elapsed:?} to reach the worker stream; delivery \
         must not be coupled to the dispatch timeout"
    );

    // Phase 2: answer it and expect the reply to surface from `dispatch`.
    harness.complete(task, REPLY_JSON.as_bytes()).await?;
    let joined = dispatch_task.await;
    let result = joined.map_err(|join_error| join_error.to_string())?;
    let round_trip_elapsed = started.elapsed();
    assert_eq!(result, Ok(REPLY_JSON.to_owned()));
    assert!(
        round_trip_elapsed < Duration::from_secs(5),
        "dispatch round trip took {round_trip_elapsed:?}"
    );

    harness.server.abort();
    Ok(())
}
/// Checks that the task is delivered immediately even with a short dispatch
/// timeout. The timeout error should then appear only because the worker
/// never answers.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn dispatch_times_out_only_when_worker_stays_silent() -> Result<(), TestError> {
    let mut harness = Harness::start().await?;
    let dispatch_timeout = Duration::from_secs(2);
    let dispatcher = Arc::new(harness.dispatcher().with_timeout(dispatch_timeout));
    let started = Instant::now();
    // Same placement as the round-trip test: the synchronous `dispatch` runs
    // inside a task on a runtime worker thread.
    let dispatch_task = tokio::spawn(futures::future::lazy(move |_| {
        dispatcher.dispatch(ACTIVITY_TYPE, "{}", "{}")
    }));

    // The worker receives the task but deliberately never replies to it.
    let task = harness.next_task().await?;
    let delivery_elapsed = started.elapsed();
    assert_eq!(task.activity_type, ACTIVITY_TYPE);
    assert!(
        delivery_elapsed < Duration::from_secs(1),
        "task took {delivery_elapsed:?} to reach the worker stream; with the \
         dispatch stall defect it would only arrive when the 2s timeout fired"
    );

    let outcome = dispatch_task
        .await
        .map_err(|join_error| join_error.to_string())?;
    let error = match outcome {
        Ok(_) => return Err("expected dispatch to time out".into()),
        Err(error) => error,
    };
    assert!(
        error.contains("timed out after 2s"),
        "unexpected error: {error}"
    );

    harness.server.abort();
    Ok(())
}