use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use tokio_stream::wrappers::ReceiverStream;
use tonic::Status;
use turul_a2a_proto as pb;
use turul_a2a_types::Task;
use crate::error::A2aError;
use crate::grpc::error::a2a_to_status;
use crate::grpc::service::BoxedStreamResponseStream;
use crate::router::{self, AppState};
use crate::storage::A2aEventStore;
use crate::streaming::replay;
/// How long the store reader waits for a broker wake-up before re-reading the
/// event store anyway (fallback polling period).
const STORE_POLL_INTERVAL: Duration = Duration::from_millis(2_000);
/// Capacity of the bounded channel between the store reader task and the gRPC
/// response stream; the reader waits (backpressure) once this many items are queued.
const STREAM_CHANNEL_DEPTH: usize = 64;
/// gRPC metadata key for a client-supplied "last event id".
/// NOTE(review): the metadata lookup happens outside this view — presumably its
/// value is what reaches `handle_subscribe_to_task` as `last_event_id_meta`; confirm.
pub const LAST_EVENT_ID_METADATA: &str = "a2a-last-event-id";
/// Starts a streaming send and returns a gRPC stream of the task's events.
///
/// Delegates setup to `router::setup_streaming_send` (passing `None` for its
/// fourth argument), then streams every stored event for the resulting task
/// from sequence 0. No initial task snapshot is sent.
///
/// # Errors
/// Any error from the setup step is converted with `a2a_to_status`.
pub async fn handle_send_streaming_message(
    state: AppState,
    tenant: String,
    owner: String,
    body: String,
) -> Result<BoxedStreamResponseStream, Status> {
    let setup = router::setup_streaming_send(state.clone(), &tenant, &owner, None, body).await;
    let (task_id, wake_rx) = setup.map_err(a2a_to_status)?;

    let stream = make_store_grpc_stream(state.event_store, tenant, task_id, 0, wake_rx, None);
    Ok(stream)
}
/// Subscribes to an existing, non-terminal task and returns a gRPC event stream.
///
/// The task is looked up for `tenant`/`owner`. If `last_event_id_meta` parses
/// (via `replay::parse_last_event_id`) to an id for this same task, streaming
/// resumes after that sequence. Otherwise it starts from sequence 0 and the
/// current task snapshot is sent first.
///
/// # Errors
/// * Storage failures, converted through `A2aError` and `a2a_to_status`.
/// * `TaskNotFound` when the lookup returns nothing.
/// * `UnsupportedOperation` when the task's state is terminal.
pub async fn handle_subscribe_to_task(
    state: AppState,
    tenant: String,
    owner: String,
    task_id: String,
    last_event_id_meta: Option<String>,
) -> Result<BoxedStreamResponseStream, Status> {
    let lookup = state
        .task_storage
        .get_task(&tenant, &task_id, &owner, None)
        .await
        .map_err(|e| a2a_to_status(A2aError::from(e)))?;
    let Some(task) = lookup else {
        return Err(a2a_to_status(A2aError::TaskNotFound { task_id }));
    };

    // Terminal tasks cannot be subscribed to. A status whose state cannot be
    // read is treated as non-terminal.
    if let Some(status) = task.status() {
        match status.state() {
            Ok(current) if current.is_terminal() => {
                return Err(a2a_to_status(A2aError::UnsupportedOperation {
                    message: format!("Task {task_id} is already in terminal state {current:?}"),
                }));
            }
            _ => {}
        }
    }

    // Only honour a last-event id that refers to this task; anything else
    // (absent, unparsable, or for another task) restarts from the beginning.
    let resume_from = match last_event_id_meta
        .as_deref()
        .and_then(replay::parse_last_event_id)
    {
        Some(parsed) if parsed.task_id == task_id => parsed.sequence,
        _ => 0,
    };

    // A full replay starts with the current task snapshot; a resume does not.
    let snapshot = (resume_from == 0).then_some(task);

    let wake_rx = state.event_broker.subscribe(&task_id).await;
    Ok(make_store_grpc_stream(
        state.event_store,
        tenant,
        task_id,
        resume_from,
        wake_rx,
        snapshot,
    ))
}
/// Spawns a background task that forwards persisted events for `task_id` from
/// `event_store` into a bounded channel. Returns the receiving end as a gRPC
/// response stream.
///
/// * `after_sequence` — the store is first read with
///   `get_events_after(.., after_sequence)`. Later reads continue after the
///   last sequence number delivered.
/// * `wake_rx` — broker notifications that prompt an immediate store re-read.
///   The store is also re-read every [`STORE_POLL_INTERVAL`] as a fallback.
/// * `initial_task` — when present, sent as a `Task` payload before any event.
///
/// The background task stops when any of the following happens:
/// * a terminal event has been delivered;
/// * the client drops the stream, detected both on send and while idle;
/// * the broker channel closes, after one final store read;
/// * a store read or event conversion fails. The error is sent as the last item.
fn make_store_grpc_stream(
    event_store: Arc<dyn A2aEventStore>,
    tenant: String,
    task_id: String,
    after_sequence: u64,
    mut wake_rx: broadcast::Receiver<()>,
    initial_task: Option<Task>,
) -> BoxedStreamResponseStream {
    let (tx, rx) =
        tokio::sync::mpsc::channel::<Result<pb::StreamResponse, Status>>(STREAM_CHANNEL_DEPTH);
    tokio::spawn(async move {
        if let Some(task) = initial_task {
            let response = pb::StreamResponse {
                payload: Some(pb::stream_response::Payload::Task(task.as_proto().clone())),
            };
            if tx.send(Ok(response)).await.is_err() {
                return;
            }
        }
        let mut last_seq = after_sequence;
        // Once the broker closes, `recv` would return `Closed` immediately on
        // every call, so we do one last store read to pick up anything
        // persisted since the previous read, then exit.
        let mut broker_closed = false;
        loop {
            let events = match event_store
                .get_events_after(&tenant, &task_id, last_seq)
                .await
            {
                Ok(events) => events,
                Err(err) => {
                    let _ = tx.send(Err(a2a_to_status(A2aError::from(err)))).await;
                    return;
                }
            };
            let mut saw_terminal = false;
            for (seq, event) in events {
                last_seq = seq;
                // The stored event is re-shaped into the proto type through its
                // JSON form. Failures in either direction are reported to the
                // client instead of being masked as `Value::Null`.
                let response = match serde_json::to_value(&event)
                    .and_then(serde_json::from_value::<pb::StreamResponse>)
                {
                    Ok(r) => r,
                    Err(err) => {
                        let _ = tx
                            .send(Err(Status::internal(format!(
                                "grpc adapter: failed to encode event seq {seq}: {err}"
                            ))))
                            .await;
                        return;
                    }
                };
                if tx.send(Ok(response)).await.is_err() {
                    return;
                }
                if event.is_terminal() {
                    saw_terminal = true;
                }
            }
            if saw_terminal || broker_closed {
                return;
            }
            tokio::select! {
                // The client dropped the stream while we were idle. Stop now
                // rather than keep polling the store until the next event.
                _ = tx.closed() => return,
                result = wake_rx.recv() => {
                    match result {
                        // A missed (lagged) notification still just means
                        // "re-read the store".
                        Ok(()) | Err(broadcast::error::RecvError::Lagged(_)) => {}
                        Err(broadcast::error::RecvError::Closed) => broker_closed = true,
                    }
                }
                _ = tokio::time::sleep(STORE_POLL_INTERVAL) => {}
            }
        }
    });
    Box::pin(ReceiverStream::new(rx))
}