use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
use aion_core::{ActivityError, ActivityId, Payload, WorkflowId};
use async_trait::async_trait;
use futures::StreamExt;
use futures::future;
use tokio::sync::{Semaphore, mpsc};
use tracing::{debug, info};
use crate::config::WorkerConfig;
use crate::context::{ActivityContext, HeartbeatRequest};
use crate::error::WorkerError;
use crate::protocol::reconnect::UnackedResultTracker;
use crate::protocol::{
ActivityExecutionKey, ActivityTask, HeartbeatBookkeeper, WorkerSession, WorkerSessionEvent,
};
use crate::runtime::report::{
DispatchFinished, InFlightActivity, RuntimeChannels, drain_remaining, record_first_error,
report_finished,
};
/// Executes activity tasks on behalf of the worker runtime.
///
/// A dispatcher is shared behind an `Arc` and invoked from tasks created with `tokio::spawn`,
/// which is why it must be `Send + Sync + 'static`.
#[async_trait]
pub trait ActivityDispatcher: Send + Sync + 'static {
    /// Runs one activity task.
    ///
    /// `context` is the per-attempt [`ActivityContext`] built for this task; it carries the
    /// heartbeat channel and is paired with the cancellation handle the runtime keeps for the
    /// activity. The returned value (including an `Err`) is sent back to the serve loop as the
    /// `outcome` of a `DispatchFinished` message.
    async fn dispatch(
        &self,
        task: ActivityTask,
        context: ActivityContext,
    ) -> Result<DispatchOutcome, WorkerError>;
    /// Returns the activity type names this dispatcher handles.
    ///
    /// NOTE(review): not used in this file; presumably advertised to the server elsewhere.
    fn activity_types(&self) -> BTreeSet<String>;
}
/// Activity-level result of a single [`ActivityDispatcher::dispatch`] call.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DispatchOutcome {
    /// The activity completed successfully.
    Completed {
        /// Output payload produced by the activity.
        output: Payload,
    },
    /// The activity ran but reported a failure.
    Failed {
        /// The activity's failure description.
        failure: ActivityError,
    },
}
/// A shutdown future that never resolves, for callers of [`serve_activity_tasks_until`] that
/// have no shutdown signal of their own.
pub type NoShutdown = future::Pending<()>;
/// Why the serve loop stopped when it ended without an error.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ServeEnd {
    /// The caller's shutdown future resolved; in-flight activities were sent cancellation.
    Shutdown,
    /// The server's task stream ended.
    StreamClosed,
    /// The server sent a drain request; in-flight work was finished before returning.
    Drained,
}
/// Observations about one serve session, filled in by [`serve_activity_tasks_until`].
#[derive(Debug, Default)]
pub struct SessionHealth {
    /// Counter handed to the result-reporting helpers (`report_finished`, `drain_remaining`);
    /// presumably the number of activity outcomes reported to the server — verify there.
    pub tasks_reported: usize,
    /// Stamped when the serve loop stops reading the task stream and starts draining
    /// in-flight work; `None` until then.
    pub stream_ended_at: Option<tokio::time::Instant>,
    /// Set to `true` when the server sent a drain request during the session.
    pub drain_received: bool,
}
pub async fn serve_activity_tasks<S, D>(
config: &WorkerConfig,
session: &mut S,
dispatcher: Arc<D>,
tracker: &mut UnackedResultTracker,
) -> Result<ServeEnd, WorkerError>
where
S: WorkerSession,
D: ActivityDispatcher,
{
let mut health = SessionHealth::default();
serve_activity_tasks_until(
config,
session,
dispatcher,
tracker,
&mut health,
future::pending(),
)
.await
}
/// Serves activity tasks on `session` until `shutdown` resolves, the task stream ends, the
/// server requests a drain, or an error is recorded.
///
/// The loop multiplexes four sources with a `biased` select, in priority order:
/// 1. `shutdown`: sends cooperative cancellation to every in-flight activity and stops;
/// 2. finished activity outcomes: reported through `report_finished`;
/// 3. heartbeats from running activities: forwarded to the server;
/// 4. session events: cancellations, result acks, drain requests and new tasks.
///
/// New tasks are admitted under a semaphore of `config.max_concurrency` permits. While every
/// permit is taken, the loop waits for one (or for `shutdown`) before handling anything else.
///
/// However the loop ends, in-flight activities are then drained with `drain_remaining`. Their
/// outcomes are still reported and recorded in `tracker` before this function returns.
/// `health.stream_ended_at` is stamped at that point. `health.drain_received` is set when the
/// server requested a drain.
///
/// # Errors
///
/// Returns an error if `config.max_concurrency` is invalid (before anything is served).
/// Otherwise it returns the error recorded during the session: a stream error, a malformed
/// task, a failure to spawn or report an activity, or a heartbeat send failure.
pub async fn serve_activity_tasks_until<S, D, Shutdown>(
    config: &WorkerConfig,
    session: &mut S,
    dispatcher: Arc<D>,
    tracker: &mut UnackedResultTracker,
    health: &mut SessionHealth,
    shutdown: Shutdown,
) -> Result<ServeEnd, WorkerError>
where
    S: WorkerSession,
    D: ActivityDispatcher,
    Shutdown: Future<Output = ()> + Send,
{
    ensure_max_concurrency(config)?;
    let semaphore = Arc::new(Semaphore::new(config.max_concurrency));
    let (result_sender, heartbeat_sender, mut channels) = runtime_channels();
    let heartbeat_bookkeeper = HeartbeatBookkeeper::default();
    let mut stream = session.receive_tasks();
    let mut in_flight = HashMap::<ActivityExecutionKey, InFlightActivity>::new();
    let mut pending_error = None;
    // The shutdown and drain paths overwrite this; a plain end of the stream keeps it.
    let mut end = ServeEnd::StreamClosed;
    tokio::pin!(shutdown);
    while pending_error.is_none() {
        tokio::select! {
            // Poll in declaration order: shutdown wins, and finished outcomes and heartbeats
            // are flushed before more work is pulled from the stream.
            biased;
            () = &mut shutdown => {
                cancel_all_in_flight(&in_flight);
                end = ServeEnd::Shutdown;
                break;
            }
            finished = channels.results.recv() => {
                if let Some(finished) = finished {
                    report_finished(
                        session,
                        &heartbeat_bookkeeper,
                        finished,
                        &mut in_flight,
                        tracker,
                        &mut health.tasks_reported,
                        &mut pending_error,
                    )
                    .await;
                }
            }
            request = channels.heartbeats.recv() => {
                if let Some(request) = request {
                    forward_heartbeat(session, &heartbeat_bookkeeper, request, &mut pending_error)
                        .await;
                }
            }
            event = stream.next() => {
                let Some(event) = event else { break; };
                match event {
                    Ok(WorkerSessionEvent::Cancel { workflow_id, activity_id }) => {
                        deliver_cancellation(workflow_id, &activity_id, &in_flight);
                    }
                    Ok(WorkerSessionEvent::ResultAck { workflow_id, activity_id }) => {
                        acknowledge_result(&workflow_id, &activity_id, tracker);
                    }
                    Ok(WorkerSessionEvent::Drain) => {
                        info!("server drain received; finishing in-flight work before reconnect");
                        health.drain_received = true;
                        end = ServeEnd::Drained;
                        break;
                    }
                    Err(error) => {
                        pending_error = Some(error);
                        break;
                    }
                    Ok(WorkerSessionEvent::Task(proto_task)) => {
                        // While every permit is taken this waits here, so results, heartbeats
                        // and other session events are not serviced until a running activity
                        // releases its slot (or shutdown fires).
                        let permit =
                            match acquire_permit_or_shutdown(shutdown.as_mut(), &semaphore).await {
                                Ok(Some(permit)) => permit,
                                Ok(None) => {
                                    cancel_all_in_flight(&in_flight);
                                    end = ServeEnd::Shutdown;
                                    break;
                                }
                                // Record instead of returning early so in-flight activities
                                // are still drained and their outcomes reported below.
                                Err(error) => {
                                    pending_error = Some(error);
                                    break;
                                }
                            };
                        let keep_serving = handle_task(
                            proto_task,
                            SessionEventContext {
                                permit,
                                dispatcher: Arc::clone(&dispatcher),
                                result_sender: &result_sender,
                                heartbeat_sender: &heartbeat_sender,
                                heartbeat_bookkeeper: &heartbeat_bookkeeper,
                                in_flight: &mut in_flight,
                                pending_error: &mut pending_error,
                            },
                        );
                        match keep_serving {
                            Ok(true) => {}
                            Ok(false) => break,
                            // Same reasoning as above: never skip the drain phase.
                            Err(error) => {
                                pending_error = Some(error);
                                break;
                            }
                        }
                    }
                }
            }
        }
    }
    health.stream_ended_at = Some(tokio::time::Instant::now());
    // Drop the loop's own senders. The channels can then close once every spawned activity
    // has dropped its clones; presumably `drain_remaining` relies on that to finish.
    drop((result_sender, heartbeat_sender));
    drain_remaining(
        session,
        &heartbeat_bookkeeper,
        &mut channels,
        &mut in_flight,
        tracker,
        &mut health.tasks_reported,
        &mut pending_error,
    )
    .await;
    pending_error.map_or(Ok(end), Err)
}
/// Creates the unbounded channels that spawned activities use to talk back to the serve loop.
///
/// Returns the result sender, the heartbeat sender, and both receiving halves bundled as
/// [`RuntimeChannels`].
fn runtime_channels() -> (
    mpsc::UnboundedSender<DispatchFinished>,
    mpsc::UnboundedSender<HeartbeatRequest>,
    RuntimeChannels,
) {
    let (results_tx, results) = mpsc::unbounded_channel();
    let (heartbeats_tx, heartbeats) = mpsc::unbounded_channel();
    (results_tx, heartbeats_tx, RuntimeChannels { heartbeats, results })
}
/// Serve-loop state handed to [`handle_task`] for one incoming task.
struct SessionEventContext<'a, D> {
    /// Concurrency slot already acquired for this task.
    permit: tokio::sync::OwnedSemaphorePermit,
    /// Dispatcher that will run the task.
    dispatcher: Arc<D>,
    /// Channel the spawned activity uses to send its outcome back to the loop.
    result_sender: &'a mpsc::UnboundedSender<DispatchFinished>,
    /// Channel the activity's context uses to send heartbeats to the loop.
    heartbeat_sender: &'a mpsc::UnboundedSender<HeartbeatRequest>,
    /// Registry the new activity is registered with before it is spawned.
    heartbeat_bookkeeper: &'a HeartbeatBookkeeper,
    /// Running activities keyed by (workflow id, activity id).
    in_flight: &'a mut HashMap<ActivityExecutionKey, InFlightActivity>,
    /// Where a loop-stopping error is recorded.
    pending_error: &'a mut Option<WorkerError>,
}
fn handle_task<D>(
proto_task: aion_proto::ProtoActivityTask,
ctx: SessionEventContext<'_, D>,
) -> Result<bool, WorkerError>
where
D: ActivityDispatcher,
{
let task = match ActivityTask::try_from(proto_task) {
Ok(task) => task,
Err(error) => {
drop(ctx.permit);
*ctx.pending_error = Some(error);
return Ok(false);
}
};
spawn_activity(
task,
ctx.permit,
ctx.dispatcher,
ctx.result_sender.clone(),
ctx.heartbeat_sender.clone(),
ctx.heartbeat_bookkeeper,
ctx.in_flight,
)?;
Ok(true)
}
fn ensure_max_concurrency(config: &WorkerConfig) -> Result<(), WorkerError> {
if config.max_concurrency == 0 {
return Err(WorkerError::registration(InvalidMaxConcurrency));
}
Ok(())
}
/// Waits for a concurrency permit, giving up as soon as `shutdown` resolves.
///
/// `shutdown` is polled first (`biased`), so a ready shutdown always wins over a free permit.
/// Returns `Ok(Some(permit))` once a slot is free and `Ok(None)` when shutdown wins the race.
///
/// # Errors
///
/// Returns a registration [`WorkerError`] if the semaphore has been closed.
async fn acquire_permit_or_shutdown<F>(
    shutdown: std::pin::Pin<&mut F>,
    semaphore: &Arc<Semaphore>,
) -> Result<Option<tokio::sync::OwnedSemaphorePermit>, WorkerError>
where
    F: Future<Output = ()> + Send,
{
    let next_slot = Arc::clone(semaphore).acquire_owned();
    tokio::select! {
        biased;
        () = shutdown => Ok(None),
        acquired = next_slot => match acquired {
            Ok(slot) => Ok(Some(slot)),
            Err(closed) => Err(WorkerError::registration(closed)),
        },
    }
}
/// Sends one activity heartbeat to the server through `session`.
///
/// A send failure is not returned. It is handed to `record_first_error` together with
/// `pending_error`, which (per its name) keeps only the first error so the serve loop can stop
/// and drain.
async fn forward_heartbeat<S>(
    session: &mut S,
    heartbeat_bookkeeper: &HeartbeatBookkeeper,
    request: HeartbeatRequest,
    pending_error: &mut Option<WorkerError>,
) where
    S: WorkerSession,
{
    let send_result =
        crate::protocol::send_heartbeat(session, heartbeat_bookkeeper, request).await;
    record_first_error(pending_error, send_result);
}
/// Clears the reconnect-tracker entry for an activity result the server has acknowledged.
///
/// An ack for an entry the tracker does not hold is logged at debug level and otherwise
/// ignored.
fn acknowledge_result(
    workflow_id: &WorkflowId,
    activity_id: &ActivityId,
    tracker: &mut UnackedResultTracker,
) {
    let entry_cleared = tracker.acknowledge(workflow_id, activity_id).is_some();
    if !entry_cleared {
        debug!(
            workflow_id = %workflow_id,
            activity_id = activity_id.sequence_position(),
            "result ack for unknown tracker entry ignored"
        );
        return;
    }
    debug!(
        workflow_id = %workflow_id,
        activity_id = activity_id.sequence_position(),
        "server acknowledged activity result; tracker entry cleared"
    );
}
/// Registers `task` for heartbeating and runs it on a new tokio task.
///
/// The spawned task does three things in order:
/// 1. dispatches through `dispatcher`;
/// 2. sends the outcome on `result_sender`, logging at debug level if the serve loop is gone;
/// 3. releases `permit`.
///
/// The concurrency slot therefore stays occupied until the outcome has been handed off.
/// The activity's cancellation handle and join handle are stored in `in_flight` under its
/// (workflow id, activity id) key.
///
/// # Errors
///
/// Returns the error from `HeartbeatBookkeeper::register`. In that case nothing is spawned and
/// `permit` is dropped when this function returns.
fn spawn_activity<D>(
    task: ActivityTask,
    permit: tokio::sync::OwnedSemaphorePermit,
    dispatcher: Arc<D>,
    result_sender: mpsc::UnboundedSender<DispatchFinished>,
    heartbeat_sender: mpsc::UnboundedSender<HeartbeatRequest>,
    heartbeat_bookkeeper: &HeartbeatBookkeeper,
    in_flight: &mut HashMap<ActivityExecutionKey, InFlightActivity>,
) -> Result<(), WorkerError>
where
    D: ActivityDispatcher,
{
    info!(
        activity_type = %task.activity_type,
        activity_id = task.activity_id.sequence_position(),
        workflow_id = %task.workflow_id,
        attempt = task.attempt,
        "received activity task"
    );

    let execution_key =
        ActivityExecutionKey::new(task.workflow_id.clone(), task.activity_id.clone());
    heartbeat_bookkeeper.register(execution_key.clone())?;

    let (context, cancellation_handle) = ActivityContext::for_workflow(
        Some(task.workflow_id.clone()),
        task.activity_id.clone(),
        task.attempt,
        Some(heartbeat_sender),
    );

    let report_key = execution_key.clone();
    let join_handle = tokio::spawn(async move {
        let outcome = dispatcher.dispatch(task, context).await;
        let finished = DispatchFinished {
            key: report_key,
            outcome,
        };
        let delivered = result_sender.send(finished).is_ok();
        if !delivered {
            debug!("worker loop stopped before dispatch outcome could be delivered");
        }
        // Keep the slot until the outcome has been handed to the serve loop.
        drop(permit);
    });

    let entry = InFlightActivity {
        cancellation_handle,
        join_handle,
    };
    in_flight.insert(execution_key, entry);
    Ok(())
}
fn deliver_cancellation(
workflow_id: WorkflowId,
activity_id: &ActivityId,
in_flight: &HashMap<ActivityExecutionKey, InFlightActivity>,
) {
let key = ActivityExecutionKey::new(workflow_id, activity_id.clone());
if let Some(in_flight_activity) = in_flight.get(&key) {
in_flight_activity.cancellation_handle.cancel();
info!(
activity_id = activity_id.sequence_position(),
"delivered cooperative activity cancellation"
);
}
}
/// Signals cooperative cancellation to every in-flight activity; called on worker shutdown.
///
/// Only triggers the cancellation handles. It neither waits for the activities nor removes
/// them from `in_flight`.
fn cancel_all_in_flight(in_flight: &HashMap<ActivityExecutionKey, InFlightActivity>) {
    in_flight.iter().for_each(|(key, activity)| {
        activity.cancellation_handle.cancel();
        info!(
            activity_id = key.activity_id.sequence_position(),
            workflow_id = %key.workflow_id,
            "delivered cooperative activity cancellation during worker shutdown"
        );
    });
}
/// Configuration error: a `max_concurrency` of zero would create a semaphore with no permits,
/// so no activity task could ever be admitted.
#[derive(Debug, thiserror::Error)]
#[error("worker max_concurrency must be greater than zero")]
struct InvalidMaxConcurrency;
#[cfg(test)]
#[path = "loop_tests.rs"]
mod tests;