use crate::a2a::bridge;
use crate::a2a::proto;
use crate::a2a::proto::a2a_service_server::{A2aService, A2aServiceServer};
use crate::a2a::types::{self as local, TaskState};
use crate::bus::AgentBus;
use dashmap::DashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio_stream::Stream;
use tonic::{Request, Response, Status};
/// Pinned, boxed, `Send` stream of per-item gRPC results; used as the concrete stream
/// type for this file's server-streaming RPC responses.
type StreamResult<T> = Pin<Box<dyn Stream<Item = Result<T, Status>> + Send>>;
/// In-memory state backing the A2A gRPC service: tasks, per-task push-notification
/// configs, the advertised agent card, an optional agent bus, and a broadcast channel
/// carrying task status changes.
pub struct GrpcTaskStore {
    /// Tasks keyed by task id.
    tasks: DashMap<String, local::Task>,
    /// Push-notification configs registered for each task id.
    push_configs: DashMap<String, Vec<local::TaskPushNotificationConfig>>,
    /// Agent card returned by `GetAgentCard`.
    card: local::AgentCard,
    /// Optional agent bus; when set, the gRPC handlers announce submitted tasks on it.
    bus: Option<Arc<AgentBus>>,
    /// Publishes `(task_id, new_status)` whenever a task status is written to the store.
    update_tx: broadcast::Sender<(String, local::TaskStatus)>,
}
impl GrpcTaskStore {
    /// Capacity of the status broadcast channel. A receiver that falls more than this
    /// many messages behind observes `RecvError::Lagged` on its next `recv`.
    const UPDATE_CHANNEL_CAPACITY: usize = 256;

    /// Creates an empty store that advertises `card` and has no agent bus attached.
    pub fn new(card: local::AgentCard) -> Arc<Self> {
        Arc::new(Self::from_parts(card, None))
    }

    /// Creates an empty store that advertises `card` and has `bus` attached.
    pub fn with_bus(card: local::AgentCard, bus: Arc<AgentBus>) -> Arc<Self> {
        Arc::new(Self::from_parts(card, Some(bus)))
    }

    /// Shared constructor body for [`Self::new`] and [`Self::with_bus`].
    fn from_parts(card: local::AgentCard, bus: Option<Arc<AgentBus>>) -> Self {
        // The initial receiver is dropped; consumers get their own via `subscribe_updates`.
        let (update_tx, _) = broadcast::channel(Self::UPDATE_CHANNEL_CAPACITY);
        Self {
            tasks: DashMap::new(),
            push_configs: DashMap::new(),
            card,
            bus,
            update_tx,
        }
    }

    /// Inserts or replaces `task` (keyed by its id) and broadcasts its current status.
    pub fn upsert_task(&self, task: local::Task) {
        let update = (task.id.clone(), task.status.clone());
        self.tasks.insert(task.id.clone(), task);
        // Publish only after the map is updated, so a subscriber that reacts to the
        // event by calling `get_task` observes the new state rather than the old one.
        // `send` fails only when nobody is subscribed, which is not an error here.
        let _ = self.update_tx.send(update);
    }

    /// Returns a clone of the task with `id`, if present.
    pub fn get_task(&self, id: &str) -> Option<local::Task> {
        self.tasks.get(id).map(|entry| entry.value().clone())
    }

    /// Returns a receiver for `(task_id, status)` updates published after this call.
    pub fn subscribe_updates(&self) -> broadcast::Receiver<(String, local::TaskStatus)> {
        self.update_tx.subscribe()
    }

    /// Wraps this store in the tonic server for [`A2aService`].
    pub fn into_service(self: Arc<Self>) -> A2aServiceServer<A2aServiceImpl> {
        A2aServiceServer::new(A2aServiceImpl { store: self })
    }
}
/// gRPC handler for [`A2aService`], operating on a shared [`GrpcTaskStore`].
pub struct A2aServiceImpl {
    /// Shared task/config state; the same `Arc` that produced this service via
    /// `GrpcTaskStore::into_service`.
    store: Arc<GrpcTaskStore>,
}
#[tonic::async_trait]
impl A2aService for A2aServiceImpl {
async fn send_message(
&self,
request: Request<proto::SendMessageRequest>,
) -> Result<Response<proto::SendMessageResponse>, Status> {
let req = request.into_inner();
let msg = req
.request
.ok_or_else(|| Status::invalid_argument("missing message"))?;
let local_msg = bridge::proto_message_to_local(&msg);
let task_id = local_msg
.task_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let task = local::Task {
id: task_id.clone(),
context_id: local_msg.context_id.clone(),
status: local::TaskStatus {
state: TaskState::Submitted,
message: Some(local_msg.clone()),
timestamp: Some(chrono::Utc::now().to_rfc3339()),
},
artifacts: vec![],
history: vec![local_msg],
metadata: Default::default(),
};
self.store.upsert_task(task.clone());
if let Some(ref bus) = self.store.bus {
let handle = bus.handle("grpc-server");
handle.send_task_update(&task_id, TaskState::Submitted, None);
}
let proto_task = bridge::local_task_to_proto(&task);
Ok(Response::new(proto::SendMessageResponse {
payload: Some(proto::send_message_response::Payload::Task(proto_task)),
}))
}
type SendStreamingMessageStream = StreamResult<proto::StreamResponse>;
async fn send_streaming_message(
&self,
request: Request<proto::SendMessageRequest>,
) -> Result<Response<Self::SendStreamingMessageStream>, Status> {
let req = request.into_inner();
let msg = req
.request
.ok_or_else(|| Status::invalid_argument("missing message"))?;
let local_msg = bridge::proto_message_to_local(&msg);
let task_id = local_msg
.task_id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let task = local::Task {
id: task_id.clone(),
context_id: local_msg.context_id.clone(),
status: local::TaskStatus {
state: TaskState::Submitted,
message: Some(local_msg.clone()),
timestamp: Some(chrono::Utc::now().to_rfc3339()),
},
artifacts: vec![],
history: vec![local_msg],
metadata: Default::default(),
};
self.store.upsert_task(task.clone());
let proto_task = bridge::local_task_to_proto(&task);
let mut rx = self.store.subscribe_updates();
let tid = task_id.clone();
let stream = async_stream::try_stream! {
yield proto::StreamResponse {
payload: Some(proto::stream_response::Payload::Task(proto_task)),
};
loop {
match rx.recv().await {
Ok((id, status)) if id == tid => {
let proto_status = bridge::local_task_status_to_proto(&status);
let is_terminal = status.state.is_terminal();
yield proto::StreamResponse {
payload: Some(proto::stream_response::Payload::StatusUpdate(
proto::TaskStatusUpdateEvent {
task_id: tid.clone(),
context_id: String::new(),
status: Some(proto_status),
r#final: is_terminal,
metadata: None,
},
)),
};
if is_terminal {
break;
}
}
Ok(_) => continue,
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => break,
}
}
};
Ok(Response::new(
Box::pin(stream) as Self::SendStreamingMessageStream
))
}
async fn get_task(
&self,
request: Request<proto::GetTaskRequest>,
) -> Result<Response<proto::Task>, Status> {
let req = request.into_inner();
let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
let task = self
.store
.get_task(task_id)
.ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
Ok(Response::new(bridge::local_task_to_proto(&task)))
}
async fn cancel_task(
&self,
request: Request<proto::CancelTaskRequest>,
) -> Result<Response<proto::Task>, Status> {
let req = request.into_inner();
let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
let mut task = self
.store
.tasks
.get_mut(task_id)
.ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
if task.status.state.is_terminal() {
return Err(Status::failed_precondition(
"task already in terminal state",
));
}
task.status = local::TaskStatus {
state: TaskState::Cancelled,
message: None,
timestamp: Some(chrono::Utc::now().to_rfc3339()),
};
let snapshot = task.clone();
drop(task);
let _ = self
.store
.update_tx
.send((task_id.to_string(), snapshot.status.clone()));
Ok(Response::new(bridge::local_task_to_proto(&snapshot)))
}
type TaskSubscriptionStream = StreamResult<proto::StreamResponse>;
async fn task_subscription(
&self,
request: Request<proto::TaskSubscriptionRequest>,
) -> Result<Response<Self::TaskSubscriptionStream>, Status> {
let req = request.into_inner();
let task_id = req
.name
.strip_prefix("tasks/")
.unwrap_or(&req.name)
.to_string();
let task = self
.store
.get_task(&task_id)
.ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
let proto_task = bridge::local_task_to_proto(&task);
let mut rx = self.store.subscribe_updates();
let tid = task_id.clone();
if task.status.state.is_terminal() {
let stream = async_stream::try_stream! {
yield proto::StreamResponse {
payload: Some(proto::stream_response::Payload::Task(proto_task)),
};
};
return Ok(Response::new(
Box::pin(stream) as Self::TaskSubscriptionStream
));
}
let stream = async_stream::try_stream! {
yield proto::StreamResponse {
payload: Some(proto::stream_response::Payload::Task(proto_task)),
};
loop {
match rx.recv().await {
Ok((id, status)) if id == tid => {
let proto_status = bridge::local_task_status_to_proto(&status);
let is_terminal = status.state.is_terminal();
yield proto::StreamResponse {
payload: Some(proto::stream_response::Payload::StatusUpdate(
proto::TaskStatusUpdateEvent {
task_id: tid.clone(),
context_id: String::new(),
status: Some(proto_status),
r#final: is_terminal,
metadata: None,
},
)),
};
if is_terminal { break; }
}
Ok(_) => continue,
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => break,
}
}
};
Ok(Response::new(
Box::pin(stream) as Self::TaskSubscriptionStream
))
}
async fn create_task_push_notification_config(
&self,
request: Request<proto::CreateTaskPushNotificationConfigRequest>,
) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
let req = request.into_inner();
let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
if self.store.get_task(task_id).is_none() {
return Err(Status::not_found(format!("task {task_id} not found")));
}
let config = req
.config
.ok_or_else(|| Status::invalid_argument("missing config"))?;
let pnc = config.push_notification_config.as_ref();
let local_config = local::TaskPushNotificationConfig {
id: task_id.to_string(),
push_notification_config: local::PushNotificationConfig {
url: pnc.map(|c| c.url.clone()).unwrap_or_default(),
token: pnc.and_then(|c| {
if c.token.is_empty() {
None
} else {
Some(c.token.clone())
}
}),
id: pnc.and_then(|c| {
if c.id.is_empty() {
None
} else {
Some(c.id.clone())
}
}),
},
};
self.store
.push_configs
.entry(task_id.to_string())
.or_default()
.push(local_config);
Ok(Response::new(config))
}
async fn get_task_push_notification_config(
&self,
request: Request<proto::GetTaskPushNotificationConfigRequest>,
) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
let req = request.into_inner();
let parts: Vec<&str> = req.name.split('/').collect();
if parts.len() < 4 {
return Err(Status::invalid_argument("invalid name format"));
}
let task_id = parts[1];
let config_id = parts[3];
let configs = self
.store
.push_configs
.get(task_id)
.ok_or_else(|| Status::not_found("no configs for task"))?;
let _found = configs
.iter()
.find(|c| c.push_notification_config.id.as_deref() == Some(config_id))
.ok_or_else(|| Status::not_found("config not found"))?;
Ok(Response::new(proto::TaskPushNotificationConfig {
name: req.name,
push_notification_config: None, }))
}
async fn list_task_push_notification_config(
&self,
request: Request<proto::ListTaskPushNotificationConfigRequest>,
) -> Result<Response<proto::ListTaskPushNotificationConfigResponse>, Status> {
let req = request.into_inner();
let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
let configs: Vec<proto::TaskPushNotificationConfig> = self
.store
.push_configs
.get(task_id)
.map(|cs| {
cs.iter()
.map(|c| proto::TaskPushNotificationConfig {
name: format!(
"tasks/{}/pushNotificationConfigs/{}",
task_id,
c.push_notification_config
.id
.as_deref()
.unwrap_or("default")
),
push_notification_config: None,
})
.collect()
})
.unwrap_or_default();
Ok(Response::new(
proto::ListTaskPushNotificationConfigResponse {
configs,
next_page_token: String::new(),
},
))
}
async fn delete_task_push_notification_config(
&self,
request: Request<proto::DeleteTaskPushNotificationConfigRequest>,
) -> Result<Response<()>, Status> {
let req = request.into_inner();
let parts: Vec<&str> = req.name.split('/').collect();
if parts.len() < 4 {
return Err(Status::invalid_argument("invalid name format"));
}
let task_id = parts[1];
let config_id = parts[3];
if let Some(mut configs) = self.store.push_configs.get_mut(task_id) {
configs.retain(|c| c.push_notification_config.id.as_deref() != Some(config_id));
}
Ok(Response::new(()))
}
async fn get_agent_card(
&self,
_request: Request<proto::GetAgentCardRequest>,
) -> Result<Response<proto::AgentCard>, Status> {
Ok(Response::new(bridge::local_card_to_proto(&self.store.card)))
}
}