use std::collections::BTreeSet;
use std::pin::Pin;
use aion_core::{ActivityError, ActivityId, Payload, WorkflowId};
use aion_proto::{
ProtoActivityId, ProtoActivityResult, ProtoActivityTask, ProtoHeartbeat, ProtoPayload,
ProtoWorkflowId, proto_activity_result,
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, metadata::MetadataValue, transport::Channel};
use crate::config::WorkerConfig;
use crate::error::{MissingActivityHandler, WorkerError};
/// Tonic-generated `WorkerProtocol` client bound to a transport [`Channel`].
type GeneratedClient =
    aion_proto::generated::worker_protocol_client::WorkerProtocolClient<Channel>;

/// Boxed, `Send` stream of decoded server events produced by
/// [`WorkerSession::receive_tasks`]; each item is either an event or the error
/// encountered while reading or decoding that frame.
pub type WorkerTaskStream = Pin<
    Box<dyn Stream<Item = Result<WorkerSessionEvent, WorkerError>> + Send>,
>;
/// Event delivered to the worker through [`WorkerSession::receive_tasks`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkerSessionEvent {
    /// An activity task sent by the server.
    Task(ProtoActivityTask),
    /// The server sent a `Drain` frame.
    Drain,
    /// The server sent a `ResultAck` for the identified activity.
    ResultAck {
        workflow_id: WorkflowId,
        activity_id: ActivityId,
    },
    /// Cancellation of the identified activity.
    ///
    /// NOTE(review): the gRPC decoder in this module never produces this variant;
    /// confirm which session implementation emits it.
    Cancel {
        workflow_id: WorkflowId,
        activity_id: ActivityId,
    },
}
/// Conversation between a worker process and the server.
///
/// The gRPC implementation requires [`handshake`](WorkerSession::handshake) before
/// [`register`](WorkerSession::register). It also requires `register` before
/// [`receive_tasks`](WorkerSession::receive_tasks) or any of the reporting calls.
/// Other implementations (e.g. test doubles) may be more permissive.
#[async_trait]
pub trait WorkerSession: Send {
    /// Adopts `config` and ensures a transport to the server is available.
    async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError>;

    /// Announces `activity_types` to the server.
    ///
    /// Implementations in this crate first reject any type that has no entry in
    /// `available_handlers` (see [`validate_activity_handlers`]).
    async fn register(
        &mut self,
        activity_types: Vec<String>,
        available_handlers: &BTreeSet<String>,
    ) -> Result<(), WorkerError>;

    /// Returns the stream of events the server sends to this worker.
    fn receive_tasks(&mut self) -> WorkerTaskStream;

    /// Reports that the identified activity completed with `result`.
    async fn report_result(
        &mut self,
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        result: Payload,
    ) -> Result<(), WorkerError>;

    /// Reports that the identified activity failed with `failure`.
    async fn report_failure(
        &mut self,
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        failure: ActivityError,
    ) -> Result<(), WorkerError>;

    /// Sends a liveness heartbeat for the identified activity, optionally with progress.
    async fn send_heartbeat(
        &mut self,
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        progress: Option<Payload>,
    ) -> Result<(), WorkerError>;
}
/// Checks that every requested activity type has a registered handler.
///
/// Types are checked in order and the first one without a handler is reported.
///
/// # Errors
///
/// Returns a registration error wrapping [`MissingActivityHandler`] naming the
/// first activity type absent from `available_handlers`.
pub fn validate_activity_handlers(
    activity_types: &[String],
    available_handlers: &BTreeSet<String>,
) -> Result<(), WorkerError> {
    for requested in activity_types {
        if !available_handlers.contains(requested.as_str()) {
            return Err(WorkerError::registration(MissingActivityHandler {
                activity_type: requested.clone(),
            }));
        }
    }
    Ok(())
}
/// Details the server returned in its `RegisterAck` for this session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RegisteredSessionInfo {
    /// `worker_id` field of the `RegisterAck`.
    pub worker_id: u64,
    /// `namespace` field of the `RegisterAck`.
    pub namespace: String,
    /// `heartbeat_window_ms` from the `RegisterAck`, as a [`std::time::Duration`].
    pub heartbeat_window: std::time::Duration,
}
/// [`WorkerSession`] backed by the generated `WorkerProtocol` gRPC client and a
/// single bidirectional `stream_worker` call.
pub struct GrpcWorkerSession {
    /// Active configuration; replaced on every `handshake`.
    config: WorkerConfig,
    /// Activity types passed to the most recent `register` call.
    activity_types: Vec<String>,
    /// Generated client; when `None`, `handshake` connects one.
    client: Option<GeneratedClient>,
    /// Outbound half of the worker stream, stored once registration is acknowledged.
    sender: Option<mpsc::Sender<aion_proto::generated::WorkerToServer>>,
    /// Inbound half of the worker stream; moved out by `receive_tasks`.
    receiver: Option<tonic::codec::Streaming<aion_proto::generated::ServerToWorker>>,
    /// Details from the last successful `RegisterAck`.
    registered_info: Option<RegisteredSessionInfo>,
}
impl GrpcWorkerSession {
pub async fn connect(config: WorkerConfig) -> Result<Self, WorkerError> {
let client = GeneratedClient::connect(config.endpoint.clone())
.await
.map_err(|source| WorkerError::Connect { source })?;
Ok(Self {
config,
activity_types: Vec::new(),
client: Some(client),
sender: None,
receiver: None,
registered_info: None,
})
}
#[must_use]
pub fn from_channel(config: WorkerConfig, channel: Channel) -> Self {
Self {
config,
activity_types: Vec::new(),
client: Some(GeneratedClient::new(channel)),
sender: None,
receiver: None,
registered_info: None,
}
}
#[must_use]
pub const fn registered_info(&self) -> Option<&RegisteredSessionInfo> {
self.registered_info.as_ref()
}
async fn open_registered_stream(
&mut self,
register: aion_proto::generated::RegisterWorker,
) -> Result<(), WorkerError> {
let client = self.client.as_mut().ok_or_else(|| {
WorkerError::registration(SessionStateError {
message: String::from("worker session has not completed its handshake"),
})
})?;
let (sender, outbound) = mpsc::channel(16);
sender
.try_send(aion_proto::generated::WorkerToServer {
message: Some(aion_proto::generated::worker_to_server::Message::Register(
register,
)),
})
.map_err(|_| {
WorkerError::registration(SessionStateError {
message: String::from(
"could not queue RegisterWorker as the first stream frame",
),
})
})?;
let mut request = Request::new(ReceiverStream::new(outbound));
apply_auth_metadata(request.metadata_mut(), &self.config)?;
let response = client
.stream_worker(request)
.await
.map_err(registration_denial_error)?;
let mut receiver = response.into_inner();
let first = tokio::time::timeout(self.config.reconnect.max_backoff, receiver.message())
.await
.map_err(|_| {
WorkerError::registration(SessionStateError {
message: format!(
"server did not acknowledge registration within {:?}",
self.config.reconnect.max_backoff
),
})
})?
.map_err(registration_denial_error)?;
let ack = match first.and_then(|frame| frame.message) {
Some(aion_proto::generated::server_to_worker::Message::RegisterAck(ack)) => ack,
Some(_) => {
return Err(WorkerError::decode(SessionStateError {
message: String::from(
"protocol violation: server sent a non-RegisterAck frame before \
acknowledging registration",
),
}));
}
None => {
return Err(WorkerError::registration(SessionStateError {
message: String::from(
"server ended the stream before acknowledging registration",
),
}));
}
};
self.registered_info = Some(RegisteredSessionInfo {
worker_id: ack.worker_id,
namespace: ack.namespace,
heartbeat_window: std::time::Duration::from_millis(ack.heartbeat_window_ms),
});
self.sender = Some(sender);
self.receiver = Some(receiver);
Ok(())
}
async fn send_to_server(
&self,
message: aion_proto::generated::worker_to_server::Message,
) -> Result<(), WorkerError> {
let sender = self.sender.as_ref().ok_or_else(|| {
WorkerError::registration(SessionStateError {
message: String::from("worker stream has not been opened"),
})
})?;
let send = sender.send(aion_proto::generated::WorkerToServer {
message: Some(message),
});
tokio::time::timeout(self.config.reconnect.max_backoff, send)
.await
.map_err(|_| WorkerError::Transport {
source: tonic::Status::unavailable(format!(
"worker stream send did not complete within {:?}",
self.config.reconnect.max_backoff
)),
})?
.map_err(|source| WorkerError::Transport {
source: tonic::Status::unavailable(format!("worker stream send failed: {source}")),
})
}
}
/// Classifies an RPC failure status received while opening or registering a stream.
///
/// `Unauthenticated` becomes [`WorkerError::Handshake`]; every other code becomes
/// [`WorkerError::Registration`] with the boxed status as its source.
fn registration_denial_error(status: tonic::Status) -> WorkerError {
    match status.code() {
        tonic::Code::Unauthenticated => WorkerError::Handshake { source: status },
        _ => WorkerError::Registration {
            source: Box::new(status),
        },
    }
}
/// Attaches the worker authorization headers to outgoing request metadata.
///
/// Writes `x-aion-namespaces` from `config.namespace` and `x-aion-subject` from
/// `config.subject`. Both values are validated before either header is inserted,
/// so on error `metadata` is left unchanged.
///
/// # Errors
///
/// Returns [`WorkerError::Handshake`] with an `InvalidArgument` status when a value
/// is not valid ASCII gRPC metadata.
fn apply_auth_metadata(
    metadata: &mut tonic::metadata::MetadataMap,
    config: &WorkerConfig,
) -> Result<(), WorkerError> {
    /// Parses one header value, mapping failure to a handshake error with `complaint`.
    fn header_value(
        raw: &str,
        complaint: &'static str,
    ) -> Result<tonic::metadata::AsciiMetadataValue, WorkerError> {
        MetadataValue::try_from(raw).map_err(|_| WorkerError::Handshake {
            source: tonic::Status::invalid_argument(complaint),
        })
    }

    let namespace_header = header_value(
        config.namespace.as_str(),
        "worker namespace is not valid gRPC metadata",
    )?;
    let subject_header = header_value(
        config.subject.as_str(),
        "worker subject is not valid gRPC metadata",
    )?;
    metadata.insert("x-aion-namespaces", namespace_header);
    metadata.insert("x-aion-subject", subject_header);
    Ok(())
}
#[async_trait]
impl WorkerSession for GrpcWorkerSession {
async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
self.config = config.clone();
if self.client.is_none() {
self.client = Some(
GeneratedClient::connect(self.config.endpoint.clone())
.await
.map_err(|source| WorkerError::Connect { source })?,
);
}
Ok(())
}
async fn register(
&mut self,
activity_types: Vec<String>,
available_handlers: &BTreeSet<String>,
) -> Result<(), WorkerError> {
validate_activity_handlers(&activity_types, available_handlers)?;
self.activity_types.clone_from(&activity_types);
let register = aion_proto::generated::RegisterWorker {
namespace: self.config.task_queue.clone(),
activity_types,
};
self.open_registered_stream(register).await
}
fn receive_tasks(&mut self) -> WorkerTaskStream {
match self.receiver.take() {
Some(receiver) => Box::pin(receiver.filter_map(|message| async move {
Some(match message {
Ok(server_message) => decode_server_message(server_message),
Err(source) => Err(WorkerError::Transport { source }),
})
})),
None => Box::pin(futures::stream::iter([Err(WorkerError::Transport {
source: tonic::Status::failed_precondition(
"worker receive stream has not been opened",
),
})])),
}
}
async fn report_result(
&mut self,
workflow_id: WorkflowId,
activity_id: ActivityId,
result: Payload,
) -> Result<(), WorkerError> {
let result = ProtoActivityResult {
workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
activity_id: Some(ProtoActivityId::from(activity_id)),
outcome: Some(proto_activity_result::Outcome::Result(ProtoPayload::from(
result,
))),
};
self.send_to_server(aion_proto::generated::worker_to_server::Message::Result(
generated_activity_result(result),
))
.await
}
async fn report_failure(
&mut self,
workflow_id: WorkflowId,
activity_id: ActivityId,
failure: ActivityError,
) -> Result<(), WorkerError> {
let result = ProtoActivityResult {
workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
activity_id: Some(ProtoActivityId::from(activity_id)),
outcome: Some(proto_activity_result::Outcome::Error(failure.into())),
};
self.send_to_server(aion_proto::generated::worker_to_server::Message::Result(
generated_activity_result(result),
))
.await
}
async fn send_heartbeat(
&mut self,
workflow_id: WorkflowId,
activity_id: ActivityId,
progress: Option<Payload>,
) -> Result<(), WorkerError> {
let heartbeat = ProtoHeartbeat {
workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
activity_id: Some(ProtoActivityId::from(activity_id)),
progress: progress.map(ProtoPayload::from),
};
self.send_to_server(aion_proto::generated::worker_to_server::Message::Heartbeat(
generated_heartbeat(heartbeat),
))
.await
}
}
/// Converts one post-registration server frame into a [`WorkerSessionEvent`].
///
/// # Errors
///
/// Returns a decode error when the frame is empty, is a second `RegisterAck`, or
/// carries a malformed `ResultAck`.
fn decode_server_message(
    message: aion_proto::generated::ServerToWorker,
) -> Result<WorkerSessionEvent, WorkerError> {
    use aion_proto::generated::server_to_worker::Message as Frame;

    let Some(frame) = message.message else {
        return Err(WorkerError::decode(SessionStateError {
            message: String::from("server-to-worker message was empty"),
        }));
    };
    match frame {
        Frame::Task(task) => Ok(WorkerSessionEvent::Task(proto_task(task))),
        Frame::Drain(_) => Ok(WorkerSessionEvent::Drain),
        Frame::ResultAck(ack) => decode_result_ack(ack),
        Frame::RegisterAck(_) => Err(WorkerError::decode(SessionStateError {
            message: String::from(
                "protocol violation: RegisterAck received after registration completed",
            ),
        })),
    }
}
/// Converts a generated `ResultAck` into [`WorkerSessionEvent::ResultAck`].
///
/// # Errors
///
/// Returns a decode error when `workflow_id` is missing or fails
/// `WorkflowId::try_from`, or when `activity_id` is missing. The workflow id is
/// checked first.
fn decode_result_ack(
    ack: aion_proto::generated::ResultAck,
) -> Result<WorkerSessionEvent, WorkerError> {
    let malformed = |message: String| WorkerError::decode(SessionStateError { message });

    let Some(raw_workflow) = ack.workflow_id else {
        return Err(malformed(String::from("result ack workflow_id is missing")));
    };
    let workflow_id = WorkflowId::try_from(ProtoWorkflowId {
        uuid: raw_workflow.uuid,
    })
    .map_err(|source| malformed(format!("result ack workflow_id is invalid: {source}")))?;

    let Some(raw_activity) = ack.activity_id else {
        return Err(malformed(String::from("result ack activity_id is missing")));
    };
    let activity_id = ActivityId::from_sequence_position(raw_activity.sequence_position);

    Ok(WorkerSessionEvent::ResultAck {
        workflow_id,
        activity_id,
    })
}
fn generated_activity_result(value: ProtoActivityResult) -> aion_proto::generated::ActivityResult {
aion_proto::generated::ActivityResult {
workflow_id: value.workflow_id.map(generated_workflow_id),
activity_id: value.activity_id.map(generated_activity_id),
outcome: value.outcome.map(|outcome| match outcome {
proto_activity_result::Outcome::Result(result) => {
aion_proto::generated::activity_result::Outcome::Result(generated_payload(result))
}
proto_activity_result::Outcome::Error(error) => {
aion_proto::generated::activity_result::Outcome::Error(generated_error(error))
}
}),
}
}
fn generated_heartbeat(value: ProtoHeartbeat) -> aion_proto::generated::Heartbeat {
aion_proto::generated::Heartbeat {
workflow_id: value.workflow_id.map(generated_workflow_id),
activity_id: value.activity_id.map(generated_activity_id),
progress: value.progress.map(generated_payload),
}
}
/// Converts a generated `ActivityTask` frame into the domain-side [`ProtoActivityTask`].
fn proto_task(value: aion_proto::generated::ActivityTask) -> ProtoActivityTask {
    let aion_proto::generated::ActivityTask {
        workflow_id,
        activity_id,
        activity_type,
        input,
        attempt,
        ..
    } = value;
    ProtoActivityTask {
        workflow_id: workflow_id.map(proto_workflow_id),
        activity_id: activity_id.map(proto_activity_id),
        activity_type,
        input: input.map(proto_payload),
        attempt,
    }
}
fn generated_payload(value: ProtoPayload) -> aion_proto::generated::Payload {
aion_proto::generated::Payload {
content_type: value.content_type,
bytes: value.bytes,
}
}
/// Generated wire payload to domain-side payload (field-for-field move).
fn proto_payload(value: aion_proto::generated::Payload) -> ProtoPayload {
    let aion_proto::generated::Payload {
        content_type,
        bytes,
    } = value;
    ProtoPayload {
        content_type,
        bytes,
    }
}
fn generated_workflow_id(value: ProtoWorkflowId) -> aion_proto::generated::WorkflowId {
aion_proto::generated::WorkflowId { uuid: value.uuid }
}
/// Generated wire workflow id to domain-side id.
fn proto_workflow_id(value: aion_proto::generated::WorkflowId) -> ProtoWorkflowId {
    let aion_proto::generated::WorkflowId { uuid } = value;
    ProtoWorkflowId { uuid }
}
fn generated_activity_id(value: ProtoActivityId) -> aion_proto::generated::ActivityId {
aion_proto::generated::ActivityId {
sequence_position: value.sequence_position,
}
}
/// Generated wire activity id to domain-side id.
fn proto_activity_id(value: aion_proto::generated::ActivityId) -> ProtoActivityId {
    let aion_proto::generated::ActivityId { sequence_position } = value;
    ProtoActivityId { sequence_position }
}
/// Domain-side activity error to generated wire error; `details` is converted as a payload.
fn generated_error(value: aion_proto::ProtoActivityError) -> aion_proto::generated::ActivityError {
    let aion_proto::ProtoActivityError {
        kind,
        message,
        details,
        ..
    } = value;
    aion_proto::generated::ActivityError {
        kind,
        message,
        details: details.map(generated_payload),
    }
}
#[derive(thiserror::Error, Debug)]
#[error("{message}")]
struct SessionStateError {
message: String,
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use aion_proto::ProtoActivityTask;
    use async_trait::async_trait;
    use futures::{StreamExt, stream};
    use super::{
        WorkerSession, WorkerSessionEvent, WorkerTaskStream, apply_auth_metadata,
        validate_activity_handlers,
    };
    use crate::error::WorkerError;
    use crate::{ReconnectConfig, WorkerConfig};

    /// In-memory `WorkerSession` test double.
    ///
    /// It records handshake and registration calls and serves one canned
    /// `charge-card` task. The report and heartbeat methods discard their arguments
    /// and succeed.
    #[derive(Default)]
    struct FakeSession {
        // One `(task_queue, identity)` pair per `handshake` call.
        handshakes: Vec<(String, String)>,
        // Activity types from each `register` call that passed handler validation.
        registrations: Vec<Vec<String>>,
    }

    #[async_trait]
    impl WorkerSession for FakeSession {
        async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
            self.handshakes
                .push((config.task_queue.clone(), config.identity.clone()));
            Ok(())
        }

        // Runs the same handler validation as the gRPC session before recording.
        async fn register(
            &mut self,
            activity_types: Vec<String>,
            available_handlers: &BTreeSet<String>,
        ) -> Result<(), WorkerError> {
            validate_activity_handlers(&activity_types, available_handlers)?;
            self.registrations.push(activity_types);
            Ok(())
        }

        // Always yields exactly one task event, regardless of registration state.
        fn receive_tasks(&mut self) -> WorkerTaskStream {
            Box::pin(stream::iter([Ok(WorkerSessionEvent::Task(
                ProtoActivityTask {
                    workflow_id: None,
                    activity_id: None,
                    activity_type: String::from("charge-card"),
                    input: None,
                    attempt: 1,
                },
            ))]))
        }

        async fn report_result(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            result: aion_core::Payload,
        ) -> Result<(), WorkerError> {
            drop((workflow_id, activity_id, result));
            Ok(())
        }

        async fn report_failure(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            failure: aion_core::ActivityError,
        ) -> Result<(), WorkerError> {
            drop((workflow_id, activity_id, failure));
            Ok(())
        }

        async fn send_heartbeat(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            progress: Option<aion_core::Payload>,
        ) -> Result<(), WorkerError> {
            drop((workflow_id, activity_id, progress));
            Ok(())
        }
    }

    /// `apply_auth_metadata` writes the configured namespace and subject into the
    /// `x-aion-namespaces` and `x-aion-subject` headers.
    #[test]
    fn apply_auth_metadata_sets_worker_authorization_headers() -> Result<(), WorkerError> {
        let config = WorkerConfig::builder()
            .endpoint("http://127.0.0.1:50051")
            .task_queue("payments")
            .identity("worker-a")
            .max_concurrency(4)
            .reconnect_initial_backoff(std::time::Duration::from_millis(5))
            .reconnect_max_backoff(std::time::Duration::from_millis(20))
            .reconnect_max_attempts(3)
            .namespace("payments")
            .subject("worker-a")
            .build()
            .map_err(WorkerError::registration)?;
        let mut metadata = tonic::metadata::MetadataMap::new();
        apply_auth_metadata(&mut metadata, &config)?;
        assert_eq!(
            metadata
                .get("x-aion-namespaces")
                .and_then(|value| value.to_str().ok()),
            Some("payments")
        );
        assert_eq!(
            metadata
                .get("x-aion-subject")
                .and_then(|value| value.to_str().ok()),
            Some("worker-a")
        );
        Ok(())
    }

    /// The fake session records one handshake and one registration, then yields a task.
    #[tokio::test]
    async fn fake_session_records_handshake_and_registration() -> Result<(), WorkerError> {
        let config = WorkerConfig::new(
            "http://127.0.0.1:50051",
            "payments",
            "worker-a",
            4,
            ReconnectConfig::new(
                std::time::Duration::from_millis(5),
                std::time::Duration::from_millis(20),
                3,
            ),
            None,
        );
        let activity_types = vec![String::from("charge-card"), String::from("send-email")];
        // Every requested type has a handler, so registration must succeed.
        let handlers = activity_types.iter().cloned().collect::<BTreeSet<_>>();
        let mut session = FakeSession::default();
        session.handshake(&config).await?;
        session.register(activity_types.clone(), &handlers).await?;
        let received = session.receive_tasks().next().await;
        assert_eq!(
            session.handshakes,
            vec![(String::from("payments"), String::from("worker-a"))]
        );
        assert_eq!(session.registrations, vec![activity_types]);
        assert!(received.is_some());
        Ok(())
    }

    /// A report whose send cannot make progress fails with a retryable transport
    /// error once `reconnect.max_backoff` (20 ms here) elapses.
    ///
    /// Test mechanics:
    /// - `start_paused = true` runs on tokio's paused clock, presumably so the
    ///   20 ms deadline elapses without real waiting.
    /// - The capacity-1 channel is pre-filled, so the session's `send` blocks.
    /// - `receiver` is kept alive until the end so the send waits for capacity
    ///   instead of failing immediately on a closed channel.
    #[tokio::test(start_paused = true)]
    async fn report_send_times_out_retryably_at_max_backoff() -> Result<(), WorkerError> {
        let config = WorkerConfig::new(
            "http://127.0.0.1:50051",
            "payments",
            "worker-a",
            1,
            ReconnectConfig::new(
                std::time::Duration::from_millis(5),
                std::time::Duration::from_millis(20),
                3,
            ),
            None,
        );
        let (sender, receiver) = tokio::sync::mpsc::channel(1);
        // Occupy the only slot so the next send has to wait.
        sender
            .try_send(aion_proto::generated::WorkerToServer { message: None })
            .map_err(WorkerError::decode)?;
        // Built directly (no client, no registration) so only the send path is exercised.
        let mut session = super::GrpcWorkerSession {
            config,
            activity_types: Vec::new(),
            client: None,
            sender: Some(sender),
            receiver: None,
            registered_info: None,
        };
        let result = session
            .report_result(
                aion_core::WorkflowId::new_v4(),
                aion_core::ActivityId::from_sequence_position(1),
                aion_core::Payload::new(aion_core::ContentType::Json, b"{}".to_vec()),
            )
            .await;
        let Err(error) = result else {
            return Err(WorkerError::Transport {
                source: tonic::Status::internal("a hung send must time out, not hang"),
            });
        };
        assert!(
            matches!(error, WorkerError::Transport { .. }),
            "send deadline elapse must be a retryable transport error: {error}"
        );
        assert!(error.is_retryable());
        assert!(
            error.to_string().contains("did not complete"),
            "the error must name the deadline: {error}"
        );
        // Dropped only now: dropping it earlier would close the channel and change
        // the failure mode from a timeout to a send error.
        drop(receiver);
        Ok(())
    }

    /// Validation names the first activity type without a handler in the error message.
    #[test]
    fn registration_rejects_activity_without_handler() {
        let activity_types = vec![String::from("charge-card"), String::from("send-email")];
        let handlers = [String::from("charge-card")]
            .into_iter()
            .collect::<BTreeSet<_>>();
        let result = validate_activity_handlers(&activity_types, &handlers);
        assert!(result.is_err());
        // The `Ok` arm is unreachable after the assertion above; it only exists to
        // extract the error without `unwrap_err`.
        let error = match result {
            Ok(()) => return,
            Err(error) => error,
        };
        assert_eq!(
            error.to_string(),
            "worker registration failed: activity type `send-email` has no registered handler"
        );
    }
}