1use crate::a2a::bridge;
8use crate::a2a::proto;
9use crate::a2a::proto::a2a_service_server::{A2aService, A2aServiceServer};
10use crate::a2a::types::{self as local, TaskState};
11use crate::bus::AgentBus;
12
13use dashmap::DashMap;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::sync::broadcast;
17use tokio_stream::Stream;
18use tonic::{Request, Response, Status};
19
/// Boxed, pinned server-streaming response type used by the streaming RPCs below.
type StreamResult<T> = Pin<Box<dyn Stream<Item = Result<T, Status>> + Send>>;
/// In-memory state (tasks, push notification configs, agent card) backing the A2A gRPC service.
pub struct GrpcTaskStore {
    /// Tasks keyed by task id.
    tasks: DashMap<String, local::Task>,
    /// Push notification configs registered per task id.
    push_configs: DashMap<String, Vec<local::TaskPushNotificationConfig>>,
    /// Agent card returned by the `GetAgentCard` RPC.
    card: local::AgentCard,
    /// Optional agent bus; when present, the service publishes task updates to it.
    bus: Option<Arc<AgentBus>>,
    /// Broadcasts `(task_id, status)` pairs whenever a task status is recorded.
    update_tx: broadcast::Sender<(String, local::TaskStatus)>,
}
56impl GrpcTaskStore {
57 pub fn new(card: local::AgentCard) -> Arc<Self> {
81 let (update_tx, _) = broadcast::channel(256);
82 Arc::new(Self {
83 tasks: DashMap::new(),
84 push_configs: DashMap::new(),
85 card,
86 bus: None,
87 update_tx,
88 })
89 }
90
91 pub fn with_bus(card: local::AgentCard, bus: Arc<AgentBus>) -> Arc<Self> {
115 let mut store = Self::new(card);
116 Arc::get_mut(&mut store)
117 .expect("fresh Arc must be uniquely owned")
118 .bus = Some(bus);
119 store
120 }
121
122 pub fn upsert_task(&self, task: local::Task) {
139 let _ = self.update_tx.send((task.id.clone(), task.status.clone()));
140 self.tasks.insert(task.id.clone(), task);
141 }
142
143 pub fn get_task(&self, id: &str) -> Option<local::Task> {
161 self.tasks.get(id).map(|r| r.value().clone())
162 }
163
164 pub fn subscribe_updates(&self) -> broadcast::Receiver<(String, local::TaskStatus)> {
180 self.update_tx.subscribe()
181 }
182
183 pub fn into_service(self: Arc<Self>) -> A2aServiceServer<A2aServiceImpl> {
200 A2aServiceServer::new(A2aServiceImpl { store: self })
201 }
202}
/// tonic implementation of the `A2aService` gRPC API, backed by a shared [`GrpcTaskStore`].
pub struct A2aServiceImpl {
    /// Shared task store; other holders of the `Arc` observe the same state.
    store: Arc<GrpcTaskStore>,
}
207
208#[tonic::async_trait]
209impl A2aService for A2aServiceImpl {
210 async fn send_message(
213 &self,
214 request: Request<proto::SendMessageRequest>,
215 ) -> Result<Response<proto::SendMessageResponse>, Status> {
216 let req = request.into_inner();
217 let msg = req
218 .request
219 .ok_or_else(|| Status::invalid_argument("missing message"))?;
220 let local_msg = bridge::proto_message_to_local(&msg);
221
222 let task_id = local_msg
224 .task_id
225 .clone()
226 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
227
228 let task = local::Task {
229 id: task_id.clone(),
230 context_id: local_msg.context_id.clone(),
231 status: local::TaskStatus {
232 state: TaskState::Submitted,
233 message: Some(local_msg.clone()),
234 timestamp: Some(chrono::Utc::now().to_rfc3339()),
235 },
236 artifacts: vec![],
237 history: vec![local_msg],
238 metadata: Default::default(),
239 };
240 self.store.upsert_task(task.clone());
241
242 if let Some(ref bus) = self.store.bus {
244 let handle = bus.handle("grpc-server");
245 handle.send_task_update(&task_id, TaskState::Submitted, None);
246 }
247
248 let proto_task = bridge::local_task_to_proto(&task);
249 Ok(Response::new(proto::SendMessageResponse {
250 payload: Some(proto::send_message_response::Payload::Task(proto_task)),
251 }))
252 }
253
254 type SendStreamingMessageStream = StreamResult<proto::StreamResponse>;
257
258 async fn send_streaming_message(
259 &self,
260 request: Request<proto::SendMessageRequest>,
261 ) -> Result<Response<Self::SendStreamingMessageStream>, Status> {
262 let req = request.into_inner();
263 let msg = req
264 .request
265 .ok_or_else(|| Status::invalid_argument("missing message"))?;
266 let local_msg = bridge::proto_message_to_local(&msg);
267
268 let task_id = local_msg
269 .task_id
270 .clone()
271 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
272
273 let task = local::Task {
274 id: task_id.clone(),
275 context_id: local_msg.context_id.clone(),
276 status: local::TaskStatus {
277 state: TaskState::Submitted,
278 message: Some(local_msg.clone()),
279 timestamp: Some(chrono::Utc::now().to_rfc3339()),
280 },
281 artifacts: vec![],
282 history: vec![local_msg],
283 metadata: Default::default(),
284 };
285 self.store.upsert_task(task.clone());
286
287 let proto_task = bridge::local_task_to_proto(&task);
289 let mut rx = self.store.subscribe_updates();
290 let tid = task_id.clone();
291
292 let stream = async_stream::try_stream! {
293 yield proto::StreamResponse {
295 payload: Some(proto::stream_response::Payload::Task(proto_task)),
296 };
297
298 loop {
300 match rx.recv().await {
301 Ok((id, status)) if id == tid => {
302 let proto_status = bridge::local_task_status_to_proto(&status);
303 let is_terminal = status.state.is_terminal();
304 yield proto::StreamResponse {
305 payload: Some(proto::stream_response::Payload::StatusUpdate(
306 proto::TaskStatusUpdateEvent {
307 task_id: tid.clone(),
308 context_id: String::new(),
309 status: Some(proto_status),
310 r#final: is_terminal,
311 metadata: None,
312 },
313 )),
314 };
315 if is_terminal {
316 break;
317 }
318 }
319 Ok(_) => continue,
320 Err(broadcast::error::RecvError::Lagged(_)) => continue,
321 Err(broadcast::error::RecvError::Closed) => break,
322 }
323 }
324 };
325
326 Ok(Response::new(
327 Box::pin(stream) as Self::SendStreamingMessageStream
328 ))
329 }
330
331 async fn get_task(
334 &self,
335 request: Request<proto::GetTaskRequest>,
336 ) -> Result<Response<proto::Task>, Status> {
337 let req = request.into_inner();
338 let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
339 let task = self
340 .store
341 .get_task(task_id)
342 .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
343 Ok(Response::new(bridge::local_task_to_proto(&task)))
344 }
345
346 async fn cancel_task(
349 &self,
350 request: Request<proto::CancelTaskRequest>,
351 ) -> Result<Response<proto::Task>, Status> {
352 let req = request.into_inner();
353 let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
354
355 let mut task = self
356 .store
357 .tasks
358 .get_mut(task_id)
359 .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
360
361 if task.status.state.is_terminal() {
362 return Err(Status::failed_precondition(
363 "task already in terminal state",
364 ));
365 }
366
367 task.status = local::TaskStatus {
368 state: TaskState::Cancelled,
369 message: None,
370 timestamp: Some(chrono::Utc::now().to_rfc3339()),
371 };
372 let snapshot = task.clone();
373 drop(task);
374
375 let _ = self
376 .store
377 .update_tx
378 .send((task_id.to_string(), snapshot.status.clone()));
379
380 Ok(Response::new(bridge::local_task_to_proto(&snapshot)))
381 }
382
383 type TaskSubscriptionStream = StreamResult<proto::StreamResponse>;
386
387 async fn task_subscription(
388 &self,
389 request: Request<proto::TaskSubscriptionRequest>,
390 ) -> Result<Response<Self::TaskSubscriptionStream>, Status> {
391 let req = request.into_inner();
392 let task_id = req
393 .name
394 .strip_prefix("tasks/")
395 .unwrap_or(&req.name)
396 .to_string();
397
398 let task = self
399 .store
400 .get_task(&task_id)
401 .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
402
403 let proto_task = bridge::local_task_to_proto(&task);
404 let mut rx = self.store.subscribe_updates();
405 let tid = task_id.clone();
406
407 if task.status.state.is_terminal() {
409 let stream = async_stream::try_stream! {
410 yield proto::StreamResponse {
411 payload: Some(proto::stream_response::Payload::Task(proto_task)),
412 };
413 };
414 return Ok(Response::new(
415 Box::pin(stream) as Self::TaskSubscriptionStream
416 ));
417 }
418
419 let stream = async_stream::try_stream! {
420 yield proto::StreamResponse {
421 payload: Some(proto::stream_response::Payload::Task(proto_task)),
422 };
423
424 loop {
425 match rx.recv().await {
426 Ok((id, status)) if id == tid => {
427 let proto_status = bridge::local_task_status_to_proto(&status);
428 let is_terminal = status.state.is_terminal();
429 yield proto::StreamResponse {
430 payload: Some(proto::stream_response::Payload::StatusUpdate(
431 proto::TaskStatusUpdateEvent {
432 task_id: tid.clone(),
433 context_id: String::new(),
434 status: Some(proto_status),
435 r#final: is_terminal,
436 metadata: None,
437 },
438 )),
439 };
440 if is_terminal { break; }
441 }
442 Ok(_) => continue,
443 Err(broadcast::error::RecvError::Lagged(_)) => continue,
444 Err(broadcast::error::RecvError::Closed) => break,
445 }
446 }
447 };
448
449 Ok(Response::new(
450 Box::pin(stream) as Self::TaskSubscriptionStream
451 ))
452 }
453
454 async fn create_task_push_notification_config(
457 &self,
458 request: Request<proto::CreateTaskPushNotificationConfigRequest>,
459 ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
460 let req = request.into_inner();
461 let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
462
463 if self.store.get_task(task_id).is_none() {
464 return Err(Status::not_found(format!("task {task_id} not found")));
465 }
466
467 let config = req
468 .config
469 .ok_or_else(|| Status::invalid_argument("missing config"))?;
470 let pnc = config.push_notification_config.as_ref();
471
472 let local_config = local::TaskPushNotificationConfig {
473 id: task_id.to_string(),
474 push_notification_config: local::PushNotificationConfig {
475 url: pnc.map(|c| c.url.clone()).unwrap_or_default(),
476 token: pnc.and_then(|c| {
477 if c.token.is_empty() {
478 None
479 } else {
480 Some(c.token.clone())
481 }
482 }),
483 id: pnc.and_then(|c| {
484 if c.id.is_empty() {
485 None
486 } else {
487 Some(c.id.clone())
488 }
489 }),
490 },
491 };
492
493 self.store
494 .push_configs
495 .entry(task_id.to_string())
496 .or_default()
497 .push(local_config);
498
499 Ok(Response::new(config))
500 }
501
502 async fn get_task_push_notification_config(
503 &self,
504 request: Request<proto::GetTaskPushNotificationConfigRequest>,
505 ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
506 let req = request.into_inner();
507 let parts: Vec<&str> = req.name.split('/').collect();
509 if parts.len() < 4 {
510 return Err(Status::invalid_argument("invalid name format"));
511 }
512 let task_id = parts[1];
513 let config_id = parts[3];
514
515 let configs = self
516 .store
517 .push_configs
518 .get(task_id)
519 .ok_or_else(|| Status::not_found("no configs for task"))?;
520
521 let _found = configs
522 .iter()
523 .find(|c| c.push_notification_config.id.as_deref() == Some(config_id))
524 .ok_or_else(|| Status::not_found("config not found"))?;
525
526 Ok(Response::new(proto::TaskPushNotificationConfig {
527 name: req.name,
528 push_notification_config: None, }))
530 }
531
532 async fn list_task_push_notification_config(
533 &self,
534 request: Request<proto::ListTaskPushNotificationConfigRequest>,
535 ) -> Result<Response<proto::ListTaskPushNotificationConfigResponse>, Status> {
536 let req = request.into_inner();
537 let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
538
539 let configs: Vec<proto::TaskPushNotificationConfig> = self
540 .store
541 .push_configs
542 .get(task_id)
543 .map(|cs| {
544 cs.iter()
545 .map(|c| proto::TaskPushNotificationConfig {
546 name: format!(
547 "tasks/{}/pushNotificationConfigs/{}",
548 task_id,
549 c.push_notification_config
550 .id
551 .as_deref()
552 .unwrap_or("default")
553 ),
554 push_notification_config: None,
555 })
556 .collect()
557 })
558 .unwrap_or_default();
559
560 Ok(Response::new(
561 proto::ListTaskPushNotificationConfigResponse {
562 configs,
563 next_page_token: String::new(),
564 },
565 ))
566 }
567
568 async fn delete_task_push_notification_config(
569 &self,
570 request: Request<proto::DeleteTaskPushNotificationConfigRequest>,
571 ) -> Result<Response<()>, Status> {
572 let req = request.into_inner();
573 let parts: Vec<&str> = req.name.split('/').collect();
574 if parts.len() < 4 {
575 return Err(Status::invalid_argument("invalid name format"));
576 }
577 let task_id = parts[1];
578 let config_id = parts[3];
579
580 if let Some(mut configs) = self.store.push_configs.get_mut(task_id) {
581 configs.retain(|c| c.push_notification_config.id.as_deref() != Some(config_id));
582 }
583
584 Ok(Response::new(()))
585 }
586
587 async fn get_agent_card(
590 &self,
591 _request: Request<proto::GetAgentCardRequest>,
592 ) -> Result<Response<proto::AgentCard>, Status> {
593 Ok(Response::new(bridge::local_card_to_proto(&self.store.card)))
594 }
595}