Skip to main content

codetether_agent/a2a/
grpc.rs

1//! gRPC transport for the A2A protocol.
2//!
3//! Implements the `A2AService` trait generated by tonic-build from
4//! `proto/a2a/v1/a2a.proto`.  Delegates to the same task store and bus
5//! that the JSON-RPC server uses, so both transports share state.
6
7use crate::a2a::bridge;
8use crate::a2a::proto;
9use crate::a2a::proto::a2a_service_server::{A2aService, A2aServiceServer};
10use crate::a2a::types::{self as local, TaskState};
11use crate::bus::AgentBus;
12
13use dashmap::DashMap;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::sync::broadcast;
17use tokio_stream::Stream;
18use tonic::{Request, Response, Status};
19
/// Boxed, pinned, `Send` stream of `Result<T, Status>` items.
///
/// Used as the associated stream type of the server-streaming RPCs
/// implemented in this file.
type StreamResult<T> = Pin<Box<dyn Stream<Item = Result<T, Status>> + Send>>;
/// In-process state behind the gRPC A2A service.
///
/// Holds task records, per-task push-notification configurations, the local
/// agent card, an optional [`AgentBus`], and the sender half of a Tokio
/// broadcast channel used to fan task status changes out to streaming RPCs.
///
/// The two maps are [`DashMap`]s, so they can be read and written through a
/// shared reference from concurrent request handlers. Note that `DashMap`
/// synchronizes with internal (sharded) locks; it is not lock-free, and an
/// entry guard should be released before calling back into the same map.
///
/// # Invariants
///
/// * Keys of `tasks` and `push_configs` are task IDs (without the `tasks/`
///   resource prefix; the RPC handlers strip it before lookup).
/// * The `String` in every `update_tx` message is a task ID from the same
///   namespace as the keys of `tasks`.
pub struct GrpcTaskStore {
    /// Task records keyed by task ID.
    tasks: DashMap<String, local::Task>,
    /// Push-notification configurations keyed by task ID.
    push_configs: DashMap<String, Vec<local::TaskPushNotificationConfig>>,
    /// Agent card returned by the `GetAgentCard` RPC.
    card: local::AgentCard,
    /// Optional event bus; when present, `send_message` reports newly
    /// submitted tasks to it.
    bus: Option<Arc<AgentBus>>,
    /// Broadcast of `(task_id, status)` pairs.
    ///
    /// This is a live feed, not a durable log: a receiver that falls behind
    /// the channel capacity gets a `Lagged` error and loses the skipped
    /// messages, and a receiver created after a send never sees that send.
    update_tx: broadcast::Sender<(String, local::TaskStatus)>,
}
56impl GrpcTaskStore {
57    /// Create a new task store with the given agent card.
58    ///
59    /// Initializes an empty, shareable [`GrpcTaskStore`] for a single local
60    /// agent. The returned [`Arc`] is intended to be cloned into gRPC handlers,
61    /// JSON-RPC adapters, streaming tasks, and other transport glue that needs
62    /// access to the same in-process task state.
63    ///
64    /// # Parameters
65    ///
66    /// * `card` - Agent metadata served by this store and returned to clients
67    ///   that request the local agent card.
68    ///
69    /// # Returns
70    ///
71    /// A reference-counted store with no tasks, no push-notification
72    /// configuration, no attached [`AgentBus`], and a task-update broadcast
73    /// channel sized for 256 queued status updates per receiver.
74    ///
75    /// # Side effects
76    ///
77    /// Allocates the concurrent maps and creates a Tokio broadcast channel. It
78    /// does not publish network services, spawn background tasks, or emit any
79    /// task updates.
80    pub fn new(card: local::AgentCard) -> Arc<Self> {
81        let (update_tx, _) = broadcast::channel(256);
82        Arc::new(Self {
83            tasks: DashMap::new(),
84            push_configs: DashMap::new(),
85            card,
86            bus: None,
87            update_tx,
88        })
89    }
90
91    /// Create with an agent bus attached.
92    ///
93    /// Constructs a new store using [`Self::new`] and attaches an [`AgentBus`]
94    /// before the store is shared. The bus allows gRPC task activity to be
95    /// mirrored into the broader agent-event system while preserving the same
96    /// task storage and update-stream behavior as [`Self::new`].
97    ///
98    /// # Parameters
99    ///
100    /// * `card` - Agent metadata served by this store.
101    /// * `bus` - Shared event bus used to publish or coordinate task activity
102    ///   outside the gRPC transport.
103    ///
104    /// # Returns
105    ///
106    /// A reference-counted store with the supplied bus installed.
107    ///
108    /// # Panics
109    ///
110    /// Panics only if the freshly created [`Arc`] is not uniquely owned before
111    /// installing the bus. Under normal control flow this cannot happen because
112    /// the [`Arc`] has just been returned by [`Self::new`] and has not yet been
113    /// cloned.
114    pub fn with_bus(card: local::AgentCard, bus: Arc<AgentBus>) -> Arc<Self> {
115        let mut store = Self::new(card);
116        Arc::get_mut(&mut store)
117            .expect("fresh Arc must be uniquely owned")
118            .bus = Some(bus);
119        store
120    }
121
122    /// Insert or update a task.
123    ///
124    /// Stores the provided task by its ID, replacing any existing task with the
125    /// same ID. Before writing the task into the map, this method broadcasts the
126    /// task's current status so live subscribers can react to status changes.
127    ///
128    /// # Parameters
129    ///
130    /// * `task` - Complete task record to insert or replace. Its `id` field is
131    ///   used as the storage key and as the task ID in the broadcast update.
132    ///
133    /// # Side effects
134    ///
135    /// Sends a best-effort status update on the broadcast channel and mutates
136    /// the in-memory task map. If there are no active receivers, or if receivers
137    /// have lagged, the broadcast result is ignored; the task is still stored.
138    pub fn upsert_task(&self, task: local::Task) {
139        let _ = self.update_tx.send((task.id.clone(), task.status.clone()));
140        self.tasks.insert(task.id.clone(), task);
141    }
142
143    /// Get a task by id.
144    ///
145    /// Looks up a task in the in-memory store and returns an owned clone so the
146    /// caller can inspect or serialize it without holding a [`DashMap`] guard.
147    ///
148    /// # Parameters
149    ///
150    /// * `id` - Task identifier to retrieve.
151    ///
152    /// # Returns
153    ///
154    /// `Some(local::Task)` when a task with the given ID exists, or `None` when
155    /// the store has no matching task.
156    ///
157    /// # Side effects
158    ///
159    /// This method does not mutate store state or emit task updates.
160    pub fn get_task(&self, id: &str) -> Option<local::Task> {
161        self.tasks.get(id).map(|r| r.value().clone())
162    }
163
164    /// Subscribe to task update notifications.
165    ///
166    /// Creates a new receiver for the live task-status broadcast stream. Each
167    /// message contains the task ID and the status observed by [`Self::upsert_task`].
168    ///
169    /// # Returns
170    ///
171    /// A Tokio broadcast receiver for `(task_id, status)` tuples. Broadcast
172    /// receivers are live streams, not durable logs; slow receivers may observe
173    /// lag errors according to Tokio broadcast-channel semantics.
174    ///
175    /// # Side effects
176    ///
177    /// Registers a new receiver with the broadcast channel. It does not read,
178    /// mutate, or replay existing tasks.
179    pub fn subscribe_updates(&self) -> broadcast::Receiver<(String, local::TaskStatus)> {
180        self.update_tx.subscribe()
181    }
182
183    /// Build the tonic service layer from this store.
184    ///
185    /// Wraps the shared task store in an [`A2aServiceImpl`] and returns the
186    /// generated tonic [`A2aServiceServer`] ready to be mounted into a gRPC
187    /// server.
188    ///
189    /// # Returns
190    ///
191    /// A tonic service server that handles A2A gRPC requests using this store as
192    /// its backing state.
193    ///
194    /// # Side effects
195    ///
196    /// Consumes one [`Arc`] handle to the store and moves it into the service
197    /// implementation. This method does not start listening on a socket; the
198    /// caller is responsible for adding the returned service to a tonic server.
199    pub fn into_service(self: Arc<Self>) -> A2aServiceServer<A2aServiceImpl> {
200        A2aServiceServer::new(A2aServiceImpl { store: self })
201    }
202}
/// Tonic service implementation of the generated `A2aService` trait.
///
/// Holds no state of its own; every RPC handler reads and writes the shared
/// [`GrpcTaskStore`]. Instances are created by
/// [`GrpcTaskStore::into_service`].
pub struct A2aServiceImpl {
    /// Shared task store backing all handlers.
    store: Arc<GrpcTaskStore>,
}
207
208#[tonic::async_trait]
209impl A2aService for A2aServiceImpl {
210    // ── SendMessage ─────────────────────────────────────────────────────
211
212    async fn send_message(
213        &self,
214        request: Request<proto::SendMessageRequest>,
215    ) -> Result<Response<proto::SendMessageResponse>, Status> {
216        let req = request.into_inner();
217        let msg = req
218            .request
219            .ok_or_else(|| Status::invalid_argument("missing message"))?;
220        let local_msg = bridge::proto_message_to_local(&msg);
221
222        // Create / route task
223        let task_id = local_msg
224            .task_id
225            .clone()
226            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
227
228        let task = local::Task {
229            id: task_id.clone(),
230            context_id: local_msg.context_id.clone(),
231            status: local::TaskStatus {
232                state: TaskState::Submitted,
233                message: Some(local_msg.clone()),
234                timestamp: Some(chrono::Utc::now().to_rfc3339()),
235            },
236            artifacts: vec![],
237            history: vec![local_msg],
238            metadata: Default::default(),
239        };
240        self.store.upsert_task(task.clone());
241
242        // Notify bus
243        if let Some(ref bus) = self.store.bus {
244            let handle = bus.handle("grpc-server");
245            handle.send_task_update(&task_id, TaskState::Submitted, None);
246        }
247
248        let proto_task = bridge::local_task_to_proto(&task);
249        Ok(Response::new(proto::SendMessageResponse {
250            payload: Some(proto::send_message_response::Payload::Task(proto_task)),
251        }))
252    }
253
254    // ── SendStreamingMessage ────────────────────────────────────────────
255
256    type SendStreamingMessageStream = StreamResult<proto::StreamResponse>;
257
258    async fn send_streaming_message(
259        &self,
260        request: Request<proto::SendMessageRequest>,
261    ) -> Result<Response<Self::SendStreamingMessageStream>, Status> {
262        let req = request.into_inner();
263        let msg = req
264            .request
265            .ok_or_else(|| Status::invalid_argument("missing message"))?;
266        let local_msg = bridge::proto_message_to_local(&msg);
267
268        let task_id = local_msg
269            .task_id
270            .clone()
271            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
272
273        let task = local::Task {
274            id: task_id.clone(),
275            context_id: local_msg.context_id.clone(),
276            status: local::TaskStatus {
277                state: TaskState::Submitted,
278                message: Some(local_msg.clone()),
279                timestamp: Some(chrono::Utc::now().to_rfc3339()),
280            },
281            artifacts: vec![],
282            history: vec![local_msg],
283            metadata: Default::default(),
284        };
285        self.store.upsert_task(task.clone());
286
287        // Return the initial task then listen for updates
288        let proto_task = bridge::local_task_to_proto(&task);
289        let mut rx = self.store.subscribe_updates();
290        let tid = task_id.clone();
291
292        let stream = async_stream::try_stream! {
293            // First frame: the task itself
294            yield proto::StreamResponse {
295                payload: Some(proto::stream_response::Payload::Task(proto_task)),
296            };
297
298            // Subsequent frames: status updates for this task
299            loop {
300                match rx.recv().await {
301                    Ok((id, status)) if id == tid => {
302                        let proto_status = bridge::local_task_status_to_proto(&status);
303                        let is_terminal = status.state.is_terminal();
304                        yield proto::StreamResponse {
305                            payload: Some(proto::stream_response::Payload::StatusUpdate(
306                                proto::TaskStatusUpdateEvent {
307                                    task_id: tid.clone(),
308                                    context_id: String::new(),
309                                    status: Some(proto_status),
310                                    r#final: is_terminal,
311                                    metadata: None,
312                                },
313                            )),
314                        };
315                        if is_terminal {
316                            break;
317                        }
318                    }
319                    Ok(_) => continue,
320                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
321                    Err(broadcast::error::RecvError::Closed) => break,
322                }
323            }
324        };
325
326        Ok(Response::new(
327            Box::pin(stream) as Self::SendStreamingMessageStream
328        ))
329    }
330
331    // ── GetTask ─────────────────────────────────────────────────────────
332
333    async fn get_task(
334        &self,
335        request: Request<proto::GetTaskRequest>,
336    ) -> Result<Response<proto::Task>, Status> {
337        let req = request.into_inner();
338        let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
339        let task = self
340            .store
341            .get_task(task_id)
342            .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
343        Ok(Response::new(bridge::local_task_to_proto(&task)))
344    }
345
346    // ── CancelTask ──────────────────────────────────────────────────────
347
348    async fn cancel_task(
349        &self,
350        request: Request<proto::CancelTaskRequest>,
351    ) -> Result<Response<proto::Task>, Status> {
352        let req = request.into_inner();
353        let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
354
355        let mut task = self
356            .store
357            .tasks
358            .get_mut(task_id)
359            .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
360
361        if task.status.state.is_terminal() {
362            return Err(Status::failed_precondition(
363                "task already in terminal state",
364            ));
365        }
366
367        task.status = local::TaskStatus {
368            state: TaskState::Cancelled,
369            message: None,
370            timestamp: Some(chrono::Utc::now().to_rfc3339()),
371        };
372        let snapshot = task.clone();
373        drop(task);
374
375        let _ = self
376            .store
377            .update_tx
378            .send((task_id.to_string(), snapshot.status.clone()));
379
380        Ok(Response::new(bridge::local_task_to_proto(&snapshot)))
381    }
382
383    // ── TaskSubscription ────────────────────────────────────────────────
384
385    type TaskSubscriptionStream = StreamResult<proto::StreamResponse>;
386
387    async fn task_subscription(
388        &self,
389        request: Request<proto::TaskSubscriptionRequest>,
390    ) -> Result<Response<Self::TaskSubscriptionStream>, Status> {
391        let req = request.into_inner();
392        let task_id = req
393            .name
394            .strip_prefix("tasks/")
395            .unwrap_or(&req.name)
396            .to_string();
397
398        let task = self
399            .store
400            .get_task(&task_id)
401            .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
402
403        let proto_task = bridge::local_task_to_proto(&task);
404        let mut rx = self.store.subscribe_updates();
405        let tid = task_id.clone();
406
407        // If already terminal, return just the task and close
408        if task.status.state.is_terminal() {
409            let stream = async_stream::try_stream! {
410                yield proto::StreamResponse {
411                    payload: Some(proto::stream_response::Payload::Task(proto_task)),
412                };
413            };
414            return Ok(Response::new(
415                Box::pin(stream) as Self::TaskSubscriptionStream
416            ));
417        }
418
419        let stream = async_stream::try_stream! {
420            yield proto::StreamResponse {
421                payload: Some(proto::stream_response::Payload::Task(proto_task)),
422            };
423
424            loop {
425                match rx.recv().await {
426                    Ok((id, status)) if id == tid => {
427                        let proto_status = bridge::local_task_status_to_proto(&status);
428                        let is_terminal = status.state.is_terminal();
429                        yield proto::StreamResponse {
430                            payload: Some(proto::stream_response::Payload::StatusUpdate(
431                                proto::TaskStatusUpdateEvent {
432                                    task_id: tid.clone(),
433                                    context_id: String::new(),
434                                    status: Some(proto_status),
435                                    r#final: is_terminal,
436                                    metadata: None,
437                                },
438                            )),
439                        };
440                        if is_terminal { break; }
441                    }
442                    Ok(_) => continue,
443                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
444                    Err(broadcast::error::RecvError::Closed) => break,
445                }
446            }
447        };
448
449        Ok(Response::new(
450            Box::pin(stream) as Self::TaskSubscriptionStream
451        ))
452    }
453
454    // ── Push notification config CRUD ───────────────────────────────────
455
456    async fn create_task_push_notification_config(
457        &self,
458        request: Request<proto::CreateTaskPushNotificationConfigRequest>,
459    ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
460        let req = request.into_inner();
461        let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
462
463        if self.store.get_task(task_id).is_none() {
464            return Err(Status::not_found(format!("task {task_id} not found")));
465        }
466
467        let config = req
468            .config
469            .ok_or_else(|| Status::invalid_argument("missing config"))?;
470        let pnc = config.push_notification_config.as_ref();
471
472        let local_config = local::TaskPushNotificationConfig {
473            id: task_id.to_string(),
474            push_notification_config: local::PushNotificationConfig {
475                url: pnc.map(|c| c.url.clone()).unwrap_or_default(),
476                token: pnc.and_then(|c| {
477                    if c.token.is_empty() {
478                        None
479                    } else {
480                        Some(c.token.clone())
481                    }
482                }),
483                id: pnc.and_then(|c| {
484                    if c.id.is_empty() {
485                        None
486                    } else {
487                        Some(c.id.clone())
488                    }
489                }),
490            },
491        };
492
493        self.store
494            .push_configs
495            .entry(task_id.to_string())
496            .or_default()
497            .push(local_config);
498
499        Ok(Response::new(config))
500    }
501
502    async fn get_task_push_notification_config(
503        &self,
504        request: Request<proto::GetTaskPushNotificationConfigRequest>,
505    ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
506        let req = request.into_inner();
507        // name format: tasks/{task_id}/pushNotificationConfigs/{config_id}
508        let parts: Vec<&str> = req.name.split('/').collect();
509        if parts.len() < 4 {
510            return Err(Status::invalid_argument("invalid name format"));
511        }
512        let task_id = parts[1];
513        let config_id = parts[3];
514
515        let configs = self
516            .store
517            .push_configs
518            .get(task_id)
519            .ok_or_else(|| Status::not_found("no configs for task"))?;
520
521        let _found = configs
522            .iter()
523            .find(|c| c.push_notification_config.id.as_deref() == Some(config_id))
524            .ok_or_else(|| Status::not_found("config not found"))?;
525
526        Ok(Response::new(proto::TaskPushNotificationConfig {
527            name: req.name,
528            push_notification_config: None, // simplified
529        }))
530    }
531
532    async fn list_task_push_notification_config(
533        &self,
534        request: Request<proto::ListTaskPushNotificationConfigRequest>,
535    ) -> Result<Response<proto::ListTaskPushNotificationConfigResponse>, Status> {
536        let req = request.into_inner();
537        let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
538
539        let configs: Vec<proto::TaskPushNotificationConfig> = self
540            .store
541            .push_configs
542            .get(task_id)
543            .map(|cs| {
544                cs.iter()
545                    .map(|c| proto::TaskPushNotificationConfig {
546                        name: format!(
547                            "tasks/{}/pushNotificationConfigs/{}",
548                            task_id,
549                            c.push_notification_config
550                                .id
551                                .as_deref()
552                                .unwrap_or("default")
553                        ),
554                        push_notification_config: None,
555                    })
556                    .collect()
557            })
558            .unwrap_or_default();
559
560        Ok(Response::new(
561            proto::ListTaskPushNotificationConfigResponse {
562                configs,
563                next_page_token: String::new(),
564            },
565        ))
566    }
567
568    async fn delete_task_push_notification_config(
569        &self,
570        request: Request<proto::DeleteTaskPushNotificationConfigRequest>,
571    ) -> Result<Response<()>, Status> {
572        let req = request.into_inner();
573        let parts: Vec<&str> = req.name.split('/').collect();
574        if parts.len() < 4 {
575            return Err(Status::invalid_argument("invalid name format"));
576        }
577        let task_id = parts[1];
578        let config_id = parts[3];
579
580        if let Some(mut configs) = self.store.push_configs.get_mut(task_id) {
581            configs.retain(|c| c.push_notification_config.id.as_deref() != Some(config_id));
582        }
583
584        Ok(Response::new(()))
585    }
586
587    // ── GetAgentCard ────────────────────────────────────────────────────
588
589    async fn get_agent_card(
590        &self,
591        _request: Request<proto::GetAgentCardRequest>,
592    ) -> Result<Response<proto::AgentCard>, Status> {
593        Ok(Response::new(bridge::local_card_to_proto(&self.store.card)))
594    }
595}