#![cfg(feature = "sqs")]
use std::sync::Arc;
use async_trait::async_trait;
use aws_lambda_events::event::sqs::{BatchItemFailure, SqsBatchResponse, SqsEvent, SqsMessage};
use aws_sdk_sqs::Client as SqsClient;
use turul_a2a::durable_executor::{DurableExecutorQueue, QueueError, QueuedExecutorJob};
use turul_a2a::router::AppState;
use turul_a2a::server::spawn::{SpawnDeps, SpawnScope, run_queued_executor_job};
use turul_a2a_types::{Message, Part, Role, TaskState, TaskStatus};
/// `DurableExecutorQueue` backend that publishes serialized executor jobs
/// to a single AWS SQS queue.
#[derive(Clone)]
pub struct SqsDurableExecutorQueue {
// Shared SQS client; `Arc` keeps `Clone` cheap (handle copy, no new connection).
client: Arc<SqsClient>,
// Full URL of the destination queue, passed verbatim to `send_message`.
queue_url: String,
}
impl SqsDurableExecutorQueue {
/// Build a queue handle targeting `queue_url`, reusing a shared SQS client.
pub fn new(queue_url: impl Into<String>, client: Arc<SqsClient>) -> Self {
let queue_url = queue_url.into();
Self { queue_url, client }
}
}
#[async_trait]
impl DurableExecutorQueue for SqsDurableExecutorQueue {
/// SQS caps a single message body at 256 KiB.
fn max_payload_bytes(&self) -> usize {
256 * 1024
}

/// Serialize the job as JSON and publish it to the configured queue.
///
/// Returns `QueueError::PayloadTooLarge` before attempting the network call
/// when the encoded body would exceed the SQS limit, and
/// `QueueError::Transport` if the SDK call itself fails.
async fn enqueue(&self, job: QueuedExecutorJob) -> Result<(), QueueError> {
let body = serde_json::to_string(&job)?;
let limit = self.max_payload_bytes();
if body.len() > limit {
return Err(QueueError::PayloadTooLarge {
actual: body.len(),
max: limit,
});
}
let sent = self
.client
.send_message()
.queue_url(&self.queue_url)
.message_body(body)
.send()
.await;
match sent {
Ok(_) => Ok(()),
Err(e) => Err(QueueError::Transport(format!("{e}"))),
}
}

/// Stable identifier for this queue backend.
fn kind(&self) -> &'static str {
"sqs"
}
}
/// Coarse classification of an incoming Lambda payload, used to route it to
/// either the HTTP handler or the SQS batch driver.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LambdaEvent {
// API Gateway v1/v2 or function-URL style request.
Http,
// SQS batch delivery (`Records[0].eventSource == "aws:sqs"`).
Sqs,
// Payload matched neither known shape.
Unknown,
}
/// Classify a raw Lambda payload by its JSON shape.
///
/// SQS is checked first so an SQS record is never misread as HTTP; anything
/// matching neither shape is reported as `Unknown`.
pub fn classify_event(event: &serde_json::Value) -> LambdaEvent {
if is_sqs_event_shape(event) {
LambdaEvent::Sqs
} else if is_http_event_shape(event) {
LambdaEvent::Http
} else {
LambdaEvent::Unknown
}
}
fn is_sqs_event_shape(event: &serde_json::Value) -> bool {
event
.get("Records")
.and_then(|r| r.as_array())
.and_then(|arr| arr.first())
.and_then(|rec| rec.get("eventSource"))
.and_then(|s| s.as_str())
.map(|s| s == "aws:sqs")
.unwrap_or(false)
}
/// True when the payload looks like an API Gateway / function-URL request:
/// a v1 `httpMethod`, a v2 `requestContext.http.method`, or a `routeKey`.
fn is_http_event_shape(event: &serde_json::Value) -> bool {
event.get("httpMethod").is_some()
|| event
.get("requestContext")
.map_or(false, |ctx| ctx.pointer("/http/method").is_some())
|| event.get("routeKey").is_some()
}
/// Drive every record in an SQS batch, collecting partial-batch failures.
///
/// Records are processed sequentially; each `Err(message_id)` from
/// `drive_sqs_record` becomes a `BatchItemFailure`, so SQS redelivers only
/// the records that failed.
pub async fn drive_sqs_batch(state: &AppState, event: SqsEvent) -> SqsBatchResponse {
let mut response = SqsBatchResponse::default();
for record in event.records {
if let Err(item_identifier) = drive_sqs_record(state, record).await {
let mut failure = BatchItemFailure::default();
failure.item_identifier = item_identifier;
response.batch_item_failures.push(failure);
}
}
response
}
/// Process one SQS record end-to-end.
///
/// Returns `Ok(())` on success (or when the work is safely skippable) and
/// `Err(message_id)` when the record should be reported as a partial-batch
/// failure so SQS redelivers it. The steps run in a deliberate order:
///   1. decode the body into a `QueuedExecutorJob` and check its envelope version,
///   2. load the task and skip execution if it is already terminal,
///   3. honour a pending cancellation (ADR-018) *before* ever invoking the executor,
///   4. otherwise hand off to `run_queued_executor_job`.
async fn drive_sqs_record(state: &AppState, record: SqsMessage) -> Result<(), String> {
    // The SQS message id doubles as the batch-item-failure identifier.
    let identifier = record
        .message_id
        .clone()
        .unwrap_or_else(|| "<no-message-id>".to_string());
    // A body-less record can never succeed on retry, but it is still
    // reported as a failure rather than silently dropped.
    let body = match record.body.as_deref() {
        Some(b) => b,
        None => {
            tracing::error!(item = %identifier, "SQS record has no body");
            return Err(identifier);
        }
    };
    let job: QueuedExecutorJob = match serde_json::from_str(body) {
        Ok(j) => j,
        Err(e) => {
            tracing::error!(item = %identifier, error = %e, "failed to deserialise QueuedExecutorJob");
            return Err(identifier);
        }
    };
    // Envelope versioning: refuse payloads written by an incompatible
    // producer instead of misinterpreting their fields.
    if job.version != QueuedExecutorJob::VERSION {
        tracing::error!(
            item = %identifier,
            version = job.version,
            expected = QueuedExecutorJob::VERSION,
            "unknown envelope version"
        );
        return Err(identifier);
    }
    // The task must still exist; a missing task is treated as a retryable
    // failure (it may appear later if storage is eventually consistent —
    // NOTE(review): confirm intended semantics for permanently-deleted tasks).
    let task = match state
        .task_storage
        .get_task(&job.tenant, &job.task_id, &job.owner, None)
        .await
    {
        Ok(Some(t)) => t,
        Ok(None) => {
            tracing::error!(
                item = %identifier,
                tenant = %job.tenant,
                task_id = %job.task_id,
                "task not found on SQS dequeue"
            );
            return Err(identifier);
        }
        Err(e) => {
            tracing::error!(item = %identifier, error = %e, "get_task failed on SQS dequeue");
            return Err(identifier);
        }
    };
    // Idempotency guard: SQS is at-least-once, so a redelivered record for a
    // task that already reached a terminal state is acknowledged as success.
    if let Some(status) = task.status() {
        if let Ok(s) = status.state() {
            use turul_a2a_types::state_machine::is_terminal;
            if is_terminal(s) {
                tracing::debug!(
                    item = %identifier,
                    state = ?s,
                    "task already terminal; skipping executor invocation"
                );
                return Ok(());
            }
        }
    }
    // ADR-018: if cancellation was requested while the job sat in the queue,
    // commit CANCELED and never invoke the executor. A supervisor read
    // failure is treated as "not cancelled" (best effort).
    let cancel_requested = state
        .cancellation_supervisor
        .supervisor_get_cancel_requested(&job.tenant, &job.task_id)
        .await
        .unwrap_or(false);
    if cancel_requested {
        let reason = Message::new(
            uuid::Uuid::now_v7().to_string(),
            Role::Agent,
            vec![Part::text("canceled before durable executor dispatch")],
        );
        let canceled_status = TaskStatus::new(TaskState::Canceled).with_message(reason);
        // Best-effort serialization: an unserializable status degrades to a
        // default JSON value rather than failing the whole cancellation.
        let canceled_event = turul_a2a::streaming::StreamEvent::StatusUpdate {
            status_update: turul_a2a::streaming::StatusUpdatePayload {
                task_id: job.task_id.clone(),
                context_id: job.context_id.clone(),
                status: serde_json::to_value(&canceled_status).unwrap_or_default(),
            },
        };
        match state
            .atomic_store
            .update_task_status_with_events(
                &job.tenant,
                &job.task_id,
                &job.owner,
                canceled_status,
                vec![canceled_event],
            )
            .await
        {
            Ok(_) => {
                tracing::info!(
                    item = %identifier,
                    tenant = %job.tenant,
                    task_id = %job.task_id,
                    "ADR-018: canceled before dispatch — CANCELED committed, executor never invoked"
                );
                // Wake any stream subscribers waiting on this task.
                state.event_broker.notify(&job.task_id).await;
                return Ok(());
            }
            // Lost the race with another terminal transition — the task is
            // already done either way, so the record is acknowledged.
            Err(turul_a2a::storage::A2aStorageError::TerminalStateAlreadySet { .. }) => {
                tracing::debug!(
                    item = %identifier,
                    "task reached terminal concurrently with cancel — success"
                );
                return Ok(());
            }
            // Any other storage error: let SQS redeliver this record.
            Err(e) => {
                tracing::error!(
                    item = %identifier,
                    error = %e,
                    "ADR-018 canceled compensation failed; batch-item retry"
                );
                return Err(identifier);
            }
        }
    }
    // Assemble the executor's dependencies from shared application state.
    let deps = SpawnDeps {
        executor: state.executor.clone(),
        task_storage: state.task_storage.clone(),
        atomic_store: state.atomic_store.clone(),
        event_broker: state.event_broker.clone(),
        in_flight: state.in_flight.clone(),
        push_dispatcher: state.push_dispatcher.clone(),
    };
    let scope = SpawnScope {
        tenant: job.tenant.clone(),
        owner: job.owner.clone(),
        task_id: job.task_id.clone(),
        context_id: job.context_id.clone(),
        // The stored message is re-validated here; an invalid message is a
        // poison payload and is surfaced as a batch-item failure.
        message: match Message::try_from(job.message.clone()) {
            Ok(m) => m,
            Err(e) => {
                tracing::error!(item = %identifier, error = %e, "invalid message in SQS job");
                return Err(identifier);
            }
        },
        claims: job.claims.clone(),
    };
    // Runs the executor inline (awaited) so the Lambda invocation stays alive
    // for the duration of the job.
    run_queued_executor_job(deps, scope).await;
    Ok(())
}