use std::collections::BTreeSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use aion_core::{ActivityError, ActivityErrorKind, ActivityId, ContentType, Payload, WorkflowId};
use aion_proto::{ProtoActivityId, ProtoActivityTask, ProtoPayload, ProtoWorkflowId};
use async_trait::async_trait;
use futures::stream;
use serde_json::json;
use tokio::sync::{Mutex, mpsc};
use super::{ActivityDispatcher, DispatchOutcome, ServeEnd, serve_activity_tasks};
use crate::context::ActivityContext;
use crate::error::WorkerError;
use crate::protocol::{
ActivityTask, PendingActivityReport, UnackedResultTracker, WorkerSession, WorkerSessionEvent,
WorkerTaskStream, validate_activity_handlers,
};
use crate::{ReconnectConfig, WorkerConfig};
/// In-memory `WorkerSession` that replays a fixed list of events and records
/// every reported outcome instead of sending it anywhere.
#[derive(Default)]
struct FakeSession {
    /// Events handed out by the first `receive_tasks` call (the vector is drained).
    tasks: Vec<Result<WorkerSessionEvent, WorkerError>>,
    /// Outcomes passed to `report_result` / `report_failure`, in call order.
    reports: Vec<RecordedReport>,
}

/// An activity outcome captured by a fake session, keyed by activity id.
#[derive(Clone, Debug, PartialEq, Eq)]
enum RecordedReport {
    /// Recorded by `report_result`.
    Completed(ActivityId, Payload),
    /// Recorded by `report_failure`.
    Failed(ActivityId, ActivityError),
}

#[async_trait]
impl WorkerSession for FakeSession {
    /// Accepts any configuration without contacting a server.
    async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
        // The config is only borrowed; marking it used needs no clone.
        let _ = config;
        Ok(())
    }

    /// Validates the requested activity types against the available handlers
    /// via `validate_activity_handlers`, returning its result unchanged.
    async fn register(
        &mut self,
        activity_types: Vec<String>,
        available_handlers: &BTreeSet<String>,
    ) -> Result<(), WorkerError> {
        validate_activity_handlers(&activity_types, available_handlers)
    }

    /// Yields the canned events; because they are taken out of `self.tasks`,
    /// any later call yields an empty stream.
    fn receive_tasks(&mut self) -> WorkerTaskStream {
        Box::pin(stream::iter(std::mem::take(&mut self.tasks)))
    }

    /// Records a successful outcome; the workflow id is ignored.
    async fn report_result(
        &mut self,
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        result: Payload,
    ) -> Result<(), WorkerError> {
        let _ = workflow_id;
        self.reports
            .push(RecordedReport::Completed(activity_id, result));
        Ok(())
    }

    /// Records a failed outcome; the workflow id is ignored.
    async fn report_failure(
        &mut self,
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        failure: ActivityError,
    ) -> Result<(), WorkerError> {
        let _ = workflow_id;
        self.reports
            .push(RecordedReport::Failed(activity_id, failure));
        Ok(())
    }

    /// Discards heartbeats; always succeeds.
    async fn send_heartbeat(
        &mut self,
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        progress: Option<Payload>,
    ) -> Result<(), WorkerError> {
        drop((workflow_id, activity_id, progress));
        Ok(())
    }
}
/// Dispatcher that logs every activity id it receives and answers with
/// pre-seeded outcomes, front of the list first.
struct RecordingDispatcher {
    /// Outcomes still to hand out; the first entry answers the next dispatch.
    outcomes: Mutex<Vec<DispatchOutcome>>,
    /// Activity ids in the order `dispatch` was invoked.
    dispatched: Mutex<Vec<ActivityId>>,
}

#[async_trait]
impl ActivityDispatcher for RecordingDispatcher {
    /// Records the task's activity id, then hands back the next canned outcome.
    ///
    /// # Errors
    /// Returns `WorkerError::decode(NoOutcome)` once the canned outcomes are exhausted
    /// (the activity id is still recorded in that case).
    async fn dispatch(
        &self,
        task: ActivityTask,
        context: ActivityContext,
    ) -> Result<DispatchOutcome, WorkerError> {
        let activity_id = task.activity_id.clone();
        self.dispatched.lock().await.push(activity_id);
        // This fake never consults the context.
        drop(context);

        let mut remaining = self.outcomes.lock().await;
        if remaining.is_empty() {
            Err(WorkerError::decode(NoOutcome))
        } else {
            Ok(remaining.remove(0))
        }
    }

    /// Advertises the single `charge-card` activity type.
    fn activity_types(&self) -> BTreeSet<String> {
        BTreeSet::from([String::from("charge-card")])
    }
}
/// Dispatcher that holds every task until `release` is set, tracking how many
/// dispatches overlap so tests can check the concurrency cap.
struct SlowDispatcher {
    /// Dispatches currently inside `dispatch`.
    current: AtomicUsize,
    /// Largest value of `current` passed to `update_peak`.
    peak: AtomicUsize,
    /// Total number of dispatches that have entered `dispatch`.
    started: AtomicUsize,
    /// Gate: while `false`, every dispatch keeps waiting.
    release: AtomicBool,
}

#[async_trait]
impl ActivityDispatcher for SlowDispatcher {
    /// Waits (polling every millisecond) until `release` is set, then completes
    /// with an empty JSON object.
    async fn dispatch(
        &self,
        task: ActivityTask,
        context: ActivityContext,
    ) -> Result<DispatchOutcome, WorkerError> {
        // Count this dispatch as in flight and update the peak before bumping
        // `started`, so a reader that sees the new start count also sees the peak.
        let in_flight = 1 + self.current.fetch_add(1, Ordering::SeqCst);
        update_peak(&self.peak, in_flight);
        self.started.fetch_add(1, Ordering::SeqCst);

        loop {
            if self.release.load(Ordering::SeqCst) {
                break;
            }
            tokio::time::sleep(Duration::from_millis(1)).await;
        }

        self.current.fetch_sub(1, Ordering::SeqCst);
        drop((task, context));
        let output = Payload::new(ContentType::Json, b"{}".to_vec());
        Ok(DispatchOutcome::Completed { output })
    }

    /// Advertises the single `slow` activity type.
    fn activity_types(&self) -> BTreeSet<String> {
        BTreeSet::from([String::from("slow")])
    }
}
/// Dispatcher that signals when it starts, then waits for the activity's
/// cancellation before completing successfully.
struct CancellingDispatcher {
    /// Signalled once a dispatch has begun.
    started: tokio::sync::Notify,
    /// What `context.is_cancelled()` reported after `cancelled()` resolved.
    observed_cancelled: AtomicBool,
}

#[async_trait]
impl ActivityDispatcher for CancellingDispatcher {
    /// Signals `started`, waits for cancellation, records what the context
    /// reports, then completes with an empty JSON object.
    async fn dispatch(
        &self,
        task: ActivityTask,
        context: ActivityContext,
    ) -> Result<DispatchOutcome, WorkerError> {
        drop(task);
        // `notify_one` stores a permit when nobody is waiting yet, so a later
        // `started.notified().await` still completes even if this dispatch runs
        // first (e.g. on a multi-threaded runtime). `notify_waiters` stores no
        // permit and would let that waiter hang forever.
        self.started.notify_one();
        context.cancelled().await;
        self.observed_cancelled
            .store(context.is_cancelled(), Ordering::SeqCst);
        Ok(DispatchOutcome::Completed {
            output: Payload::new(ContentType::Json, b"{}".to_vec()),
        })
    }

    /// Advertises the single `cancellable` activity type.
    fn activity_types(&self) -> BTreeSet<String> {
        [String::from("cancellable")].into_iter().collect()
    }
}
#[tokio::test]
async fn dispatches_two_tasks_and_reports_corresponding_outcomes() -> Result<(), WorkerError> {
let workflow_id = WorkflowId::new_v4();
let first_activity = ActivityId::from_sequence_position(1);
let second_activity = ActivityId::from_sequence_position(2);
let first_output = Payload::from_json(&json!({"ok": true})).map_err(WorkerError::encode)?;
let failure = ActivityError {
kind: ActivityErrorKind::Terminal,
message: String::from("invalid card"),
details: None,
};
let mut session = FakeSession {
tasks: vec![
Ok(WorkerSessionEvent::Task(proto_task(
workflow_id.clone(),
first_activity.clone(),
"charge-card",
))),
Ok(WorkerSessionEvent::Task(proto_task(
workflow_id.clone(),
second_activity.clone(),
"charge-card",
))),
],
reports: Vec::new(),
};
let dispatcher = Arc::new(RecordingDispatcher {
outcomes: Mutex::new(vec![
DispatchOutcome::Completed {
output: first_output.clone(),
},
DispatchOutcome::Failed {
failure: failure.clone(),
},
]),
dispatched: Mutex::new(Vec::new()),
});
let config = test_config(2);
let mut tracker = UnackedResultTracker::new();
let end =
serve_activity_tasks(&config, &mut session, Arc::clone(&dispatcher), &mut tracker).await?;
assert_eq!(end, ServeEnd::StreamClosed);
assert_eq!(
*dispatcher.dispatched.lock().await,
vec![first_activity.clone(), second_activity.clone()]
);
assert_eq!(
session.reports,
vec![
RecordedReport::Completed(first_activity.clone(), first_output),
RecordedReport::Failed(second_activity.clone(), failure),
]
);
assert_eq!(tracker.len(), 2);
assert!(matches!(
tracker.get(&workflow_id, &first_activity),
Some(PendingActivityReport::Completed { .. })
));
assert!(matches!(
tracker.get(&workflow_id, &second_activity),
Some(PendingActivityReport::Failed { .. })
));
Ok(())
}
/// With `max_concurrency` set to 2, five queued slow tasks never have more than
/// two dispatches in flight. All five are reported and tracked once released.
#[tokio::test]
async fn max_concurrency_caps_dispatches_at_two() -> Result<(), WorkerError> {
    let workflow_id = WorkflowId::new_v4();
    // Capacity 5 lets all five tasks be queued up front without blocking the sender.
    let (task_sender, task_receiver) = mpsc::channel(5);
    let mut session = ChannelSession {
        receiver: Some(task_receiver),
        reports: Vec::new(),
    };
    for position in 1..=5 {
        task_sender
            .send(Ok(WorkerSessionEvent::Task(proto_task(
                workflow_id.clone(),
                ActivityId::from_sequence_position(position),
                "slow",
            ))))
            .await
            .map_err(WorkerError::decode)?;
    }
    // Dropping the only sender lets the task stream end once the queue is drained.
    drop(task_sender);
    let dispatcher = Arc::new(SlowDispatcher {
        current: AtomicUsize::new(0),
        peak: AtomicUsize::new(0),
        started: AtomicUsize::new(0),
        release: AtomicBool::new(false),
    });
    let config = test_config(2);
    // Serve on a separate task so this test can inspect the dispatcher while
    // dispatches are still blocked on `release`.
    let worker = tokio::spawn({
        let dispatcher = Arc::clone(&dispatcher);
        async move {
            let mut tracker = UnackedResultTracker::new();
            let result =
                serve_activity_tasks(&config, &mut session, dispatcher, &mut tracker).await;
            (result, session, tracker)
        }
    });
    wait_until_started(&dispatcher.started, 2).await;
    // Give the server time to (wrongly) start a third dispatch before checking.
    tokio::time::sleep(Duration::from_millis(20)).await;
    assert_eq!(dispatcher.started.load(Ordering::SeqCst), 2);
    assert_eq!(dispatcher.peak.load(Ordering::SeqCst), 2);
    // Open the gate for every in-flight and future dispatch.
    dispatcher.release.store(true, Ordering::SeqCst);
    let (result, session, tracker) = worker.await.map_err(WorkerError::decode)?;
    assert_eq!(result?, ServeEnd::StreamClosed);
    assert_eq!(session.reports.len(), 5);
    assert_eq!(tracker.len(), 5);
    // The cap must also have held for the dispatches that ran after release.
    assert_eq!(dispatcher.peak.load(Ordering::SeqCst), 2);
    Ok(())
}
/// A `Cancel` event for an in-flight activity flips its context to cancelled,
/// and the handler's result is still reported afterwards.
#[tokio::test]
async fn cancellation_event_flips_context_without_suppressing_result() -> Result<(), WorkerError> {
    let workflow_id = WorkflowId::new_v4();
    let activity_id = ActivityId::from_sequence_position(9);
    let (event_sender, event_receiver) = mpsc::channel(2);
    let mut session = ChannelSession {
        receiver: Some(event_receiver),
        reports: Vec::new(),
    };
    // Queue only the task for now; the cancel is sent once the handler is running.
    event_sender
        .send(Ok(WorkerSessionEvent::Task(proto_task(
            workflow_id.clone(),
            activity_id.clone(),
            "cancellable",
        ))))
        .await
        .map_err(WorkerError::decode)?;
    let dispatcher = Arc::new(CancellingDispatcher {
        started: tokio::sync::Notify::new(),
        observed_cancelled: AtomicBool::new(false),
    });
    let config = test_config(1);
    let worker = tokio::spawn({
        let dispatcher = Arc::clone(&dispatcher);
        async move {
            let mut tracker = UnackedResultTracker::new();
            let result =
                serve_activity_tasks(&config, &mut session, dispatcher, &mut tracker).await;
            (result, session)
        }
    });
    // Wait until the handler is inside `dispatch`, so the cancel targets an
    // in-flight activity.
    dispatcher.started.notified().await;
    event_sender
        .send(Ok(WorkerSessionEvent::Cancel {
            workflow_id,
            activity_id: activity_id.clone(),
        }))
        .await
        .map_err(WorkerError::decode)?;
    // Dropping the only sender closes the stream so the serve loop can end.
    drop(event_sender);
    let (result, session) = worker.await.map_err(WorkerError::decode)?;
    assert_eq!(result?, ServeEnd::StreamClosed);
    assert!(dispatcher.observed_cancelled.load(Ordering::SeqCst));
    // The handler only returns after `cancelled()` resolves, so this report
    // shows the cancellation did not suppress its result.
    assert_eq!(session.reports.len(), 1);
    assert!(matches!(
        &session.reports[0],
        RecordedReport::Completed(reported_id, _) if reported_id == &activity_id
    ));
    Ok(())
}
/// `WorkerSession` whose task stream is fed from an mpsc channel, letting a
/// test push events while the serve loop runs; reported outcomes are recorded.
struct ChannelSession {
    /// Taken by the first `receive_tasks` call; later calls get an empty stream.
    receiver: Option<mpsc::Receiver<Result<WorkerSessionEvent, WorkerError>>>,
    /// Outcomes passed to `report_result` / `report_failure`, in call order.
    reports: Vec<RecordedReport>,
}

#[async_trait]
impl WorkerSession for ChannelSession {
    /// Accepts any configuration without contacting a server.
    async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
        // The config is only borrowed; marking it used needs no clone.
        let _ = config;
        Ok(())
    }

    /// Validates the requested activity types against the available handlers
    /// via `validate_activity_handlers`, returning its result unchanged.
    async fn register(
        &mut self,
        activity_types: Vec<String>,
        available_handlers: &BTreeSet<String>,
    ) -> Result<(), WorkerError> {
        validate_activity_handlers(&activity_types, available_handlers)
    }

    /// Streams events from the channel on the first call. The stream ends when
    /// every sender is dropped, and any later call yields an empty stream.
    fn receive_tasks(&mut self) -> WorkerTaskStream {
        match self.receiver.take() {
            Some(receiver) => Box::pin(tokio_stream::wrappers::ReceiverStream::new(receiver)),
            None => Box::pin(stream::empty()),
        }
    }

    /// Records a successful outcome; the workflow id is ignored.
    async fn report_result(
        &mut self,
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        result: Payload,
    ) -> Result<(), WorkerError> {
        let _ = workflow_id;
        self.reports
            .push(RecordedReport::Completed(activity_id, result));
        Ok(())
    }

    /// Records a failed outcome; the workflow id is ignored.
    async fn report_failure(
        &mut self,
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        failure: ActivityError,
    ) -> Result<(), WorkerError> {
        let _ = workflow_id;
        self.reports
            .push(RecordedReport::Failed(activity_id, failure));
        Ok(())
    }

    /// Discards heartbeats; always succeeds.
    async fn send_heartbeat(
        &mut self,
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        progress: Option<Payload>,
    ) -> Result<(), WorkerError> {
        drop((workflow_id, activity_id, progress));
        Ok(())
    }
}
/// Builds a wire-format activity task of type `activity_type` whose input is
/// an empty JSON object (`{}`).
fn proto_task(
    workflow_id: WorkflowId,
    activity_id: ActivityId,
    activity_type: &str,
) -> ProtoActivityTask {
    let empty_json_input = Payload::new(ContentType::Json, b"{}".to_vec());
    let wire_workflow_id = ProtoWorkflowId::from(workflow_id);
    let wire_activity_id = ProtoActivityId::from(activity_id);
    let wire_input = ProtoPayload::from(empty_json_input);
    ProtoActivityTask {
        workflow_id: Some(wire_workflow_id),
        activity_id: Some(wire_activity_id),
        activity_type: activity_type.to_owned(),
        input: Some(wire_input),
    }
}
async fn wait_until_started(started: &AtomicUsize, expected: usize) {
while started.load(Ordering::SeqCst) < expected {
tokio::time::sleep(Duration::from_millis(1)).await;
}
}
/// Raises `peak` to `observed` when `observed` is larger; never lowers it.
fn update_peak(peak: &AtomicUsize, observed: usize) {
    // `fetch_max` is a single atomic read-modify-write, equivalent to a manual
    // compare-exchange retry loop.
    peak.fetch_max(observed, Ordering::SeqCst);
}
/// Worker configuration for the `payments` queue as `worker-a` against a local
/// endpoint. Only the concurrency limit varies between tests.
fn test_config(max_concurrency: usize) -> WorkerConfig {
    let reconnect = ReconnectConfig::new(Duration::from_millis(5), Duration::from_millis(20), 3);
    let endpoint = "http://127.0.0.1:50051";
    WorkerConfig::new(
        endpoint,
        "payments",
        "worker-a",
        max_concurrency,
        reconnect,
        None,
    )
}
/// Error that `RecordingDispatcher::dispatch` wraps in `WorkerError::decode`
/// when it has no canned outcome left to return.
#[derive(Debug, thiserror::Error)]
#[error("fake dispatcher has no canned outcome")]
struct NoOutcome;