Skip to main content

codetether_agent/a2a/
grpc.rs

//! gRPC transport for the A2A protocol.
//!
//! Implements the `A2aService` trait generated by tonic-build from
//! `proto/a2a/v1/a2a.proto`. Delegates to the same task store and bus
//! that the JSON-RPC server uses, so both transports share state.
6
7use crate::a2a::bridge;
8use crate::a2a::proto;
9use crate::a2a::proto::a2a_service_server::{A2aService, A2aServiceServer};
10use crate::a2a::types::{self as local, TaskState};
11use crate::bus::AgentBus;
12
13use dashmap::DashMap;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::sync::broadcast;
17use tokio_stream::Stream;
18use tonic::{Request, Response, Status};
19
20type StreamResult<T> = Pin<Box<dyn Stream<Item = Result<T, Status>> + Send>>;
21
22/// Shared state backing both JSON-RPC and gRPC transports.
23pub struct GrpcTaskStore {
24    tasks: DashMap<String, local::Task>,
25    push_configs: DashMap<String, Vec<local::TaskPushNotificationConfig>>,
26    card: local::AgentCard,
27    bus: Option<Arc<AgentBus>>,
28    /// Broadcast for task updates (task_id, status)
29    update_tx: broadcast::Sender<(String, local::TaskStatus)>,
30}
31
32impl GrpcTaskStore {
33    /// Create a new task store with the given agent card.
34    pub fn new(card: local::AgentCard) -> Arc<Self> {
35        let (update_tx, _) = broadcast::channel(256);
36        Arc::new(Self {
37            tasks: DashMap::new(),
38            push_configs: DashMap::new(),
39            card,
40            bus: None,
41            update_tx,
42        })
43    }
44
45    /// Create with an agent bus attached.
46    pub fn with_bus(card: local::AgentCard, bus: Arc<AgentBus>) -> Arc<Self> {
47        let mut store = Self::new(card);
48        Arc::get_mut(&mut store)
49            .expect("fresh Arc must be uniquely owned")
50            .bus = Some(bus);
51        store
52    }
53
54    /// Insert or update a task.
55    pub fn upsert_task(&self, task: local::Task) {
56        let _ = self.update_tx.send((task.id.clone(), task.status.clone()));
57        self.tasks.insert(task.id.clone(), task);
58    }
59
60    /// Get a task by id.
61    pub fn get_task(&self, id: &str) -> Option<local::Task> {
62        self.tasks.get(id).map(|r| r.value().clone())
63    }
64
65    /// Subscribe to task update notifications.
66    pub fn subscribe_updates(&self) -> broadcast::Receiver<(String, local::TaskStatus)> {
67        self.update_tx.subscribe()
68    }
69
70    /// Build the tonic service layer from this store.
71    pub fn into_service(self: Arc<Self>) -> A2aServiceServer<A2aServiceImpl> {
72        A2aServiceServer::new(A2aServiceImpl { store: self })
73    }
74}
75
76/// Tonic service implementation.
77pub struct A2aServiceImpl {
78    store: Arc<GrpcTaskStore>,
79}
80
81#[tonic::async_trait]
82impl A2aService for A2aServiceImpl {
83    // ── SendMessage ─────────────────────────────────────────────────────
84
85    async fn send_message(
86        &self,
87        request: Request<proto::SendMessageRequest>,
88    ) -> Result<Response<proto::SendMessageResponse>, Status> {
89        let req = request.into_inner();
90        let msg = req
91            .request
92            .ok_or_else(|| Status::invalid_argument("missing message"))?;
93        let local_msg = bridge::proto_message_to_local(&msg);
94
95        // Create / route task
96        let task_id = local_msg
97            .task_id
98            .clone()
99            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
100
101        let task = local::Task {
102            id: task_id.clone(),
103            context_id: local_msg.context_id.clone(),
104            status: local::TaskStatus {
105                state: TaskState::Submitted,
106                message: Some(local_msg.clone()),
107                timestamp: Some(chrono::Utc::now().to_rfc3339()),
108            },
109            artifacts: vec![],
110            history: vec![local_msg],
111            metadata: Default::default(),
112        };
113        self.store.upsert_task(task.clone());
114
115        // Notify bus
116        if let Some(ref bus) = self.store.bus {
117            let handle = bus.handle("grpc-server");
118            handle.send_task_update(&task_id, TaskState::Submitted, None);
119        }
120
121        let proto_task = bridge::local_task_to_proto(&task);
122        Ok(Response::new(proto::SendMessageResponse {
123            payload: Some(proto::send_message_response::Payload::Task(proto_task)),
124        }))
125    }
126
127    // ── SendStreamingMessage ────────────────────────────────────────────
128
129    type SendStreamingMessageStream = StreamResult<proto::StreamResponse>;
130
131    async fn send_streaming_message(
132        &self,
133        request: Request<proto::SendMessageRequest>,
134    ) -> Result<Response<Self::SendStreamingMessageStream>, Status> {
135        let req = request.into_inner();
136        let msg = req
137            .request
138            .ok_or_else(|| Status::invalid_argument("missing message"))?;
139        let local_msg = bridge::proto_message_to_local(&msg);
140
141        let task_id = local_msg
142            .task_id
143            .clone()
144            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
145
146        let task = local::Task {
147            id: task_id.clone(),
148            context_id: local_msg.context_id.clone(),
149            status: local::TaskStatus {
150                state: TaskState::Submitted,
151                message: Some(local_msg.clone()),
152                timestamp: Some(chrono::Utc::now().to_rfc3339()),
153            },
154            artifacts: vec![],
155            history: vec![local_msg],
156            metadata: Default::default(),
157        };
158        self.store.upsert_task(task.clone());
159
160        // Return the initial task then listen for updates
161        let proto_task = bridge::local_task_to_proto(&task);
162        let mut rx = self.store.subscribe_updates();
163        let tid = task_id.clone();
164
165        let stream = async_stream::try_stream! {
166            // First frame: the task itself
167            yield proto::StreamResponse {
168                payload: Some(proto::stream_response::Payload::Task(proto_task)),
169            };
170
171            // Subsequent frames: status updates for this task
172            loop {
173                match rx.recv().await {
174                    Ok((id, status)) if id == tid => {
175                        let proto_status = bridge::local_task_status_to_proto(&status);
176                        let is_terminal = status.state.is_terminal();
177                        yield proto::StreamResponse {
178                            payload: Some(proto::stream_response::Payload::StatusUpdate(
179                                proto::TaskStatusUpdateEvent {
180                                    task_id: tid.clone(),
181                                    context_id: String::new(),
182                                    status: Some(proto_status),
183                                    r#final: is_terminal,
184                                    metadata: None,
185                                },
186                            )),
187                        };
188                        if is_terminal {
189                            break;
190                        }
191                    }
192                    Ok(_) => continue,
193                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
194                    Err(broadcast::error::RecvError::Closed) => break,
195                }
196            }
197        };
198
199        Ok(Response::new(
200            Box::pin(stream) as Self::SendStreamingMessageStream
201        ))
202    }
203
204    // ── GetTask ─────────────────────────────────────────────────────────
205
206    async fn get_task(
207        &self,
208        request: Request<proto::GetTaskRequest>,
209    ) -> Result<Response<proto::Task>, Status> {
210        let req = request.into_inner();
211        let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
212        let task = self
213            .store
214            .get_task(task_id)
215            .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
216        Ok(Response::new(bridge::local_task_to_proto(&task)))
217    }
218
219    // ── CancelTask ──────────────────────────────────────────────────────
220
221    async fn cancel_task(
222        &self,
223        request: Request<proto::CancelTaskRequest>,
224    ) -> Result<Response<proto::Task>, Status> {
225        let req = request.into_inner();
226        let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
227
228        let mut task = self
229            .store
230            .tasks
231            .get_mut(task_id)
232            .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
233
234        if task.status.state.is_terminal() {
235            return Err(Status::failed_precondition(
236                "task already in terminal state",
237            ));
238        }
239
240        task.status = local::TaskStatus {
241            state: TaskState::Cancelled,
242            message: None,
243            timestamp: Some(chrono::Utc::now().to_rfc3339()),
244        };
245        let snapshot = task.clone();
246        drop(task);
247
248        let _ = self
249            .store
250            .update_tx
251            .send((task_id.to_string(), snapshot.status.clone()));
252
253        Ok(Response::new(bridge::local_task_to_proto(&snapshot)))
254    }
255
256    // ── TaskSubscription ────────────────────────────────────────────────
257
258    type TaskSubscriptionStream = StreamResult<proto::StreamResponse>;
259
260    async fn task_subscription(
261        &self,
262        request: Request<proto::TaskSubscriptionRequest>,
263    ) -> Result<Response<Self::TaskSubscriptionStream>, Status> {
264        let req = request.into_inner();
265        let task_id = req
266            .name
267            .strip_prefix("tasks/")
268            .unwrap_or(&req.name)
269            .to_string();
270
271        let task = self
272            .store
273            .get_task(&task_id)
274            .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
275
276        let proto_task = bridge::local_task_to_proto(&task);
277        let mut rx = self.store.subscribe_updates();
278        let tid = task_id.clone();
279
280        // If already terminal, return just the task and close
281        if task.status.state.is_terminal() {
282            let stream = async_stream::try_stream! {
283                yield proto::StreamResponse {
284                    payload: Some(proto::stream_response::Payload::Task(proto_task)),
285                };
286            };
287            return Ok(Response::new(
288                Box::pin(stream) as Self::TaskSubscriptionStream
289            ));
290        }
291
292        let stream = async_stream::try_stream! {
293            yield proto::StreamResponse {
294                payload: Some(proto::stream_response::Payload::Task(proto_task)),
295            };
296
297            loop {
298                match rx.recv().await {
299                    Ok((id, status)) if id == tid => {
300                        let proto_status = bridge::local_task_status_to_proto(&status);
301                        let is_terminal = status.state.is_terminal();
302                        yield proto::StreamResponse {
303                            payload: Some(proto::stream_response::Payload::StatusUpdate(
304                                proto::TaskStatusUpdateEvent {
305                                    task_id: tid.clone(),
306                                    context_id: String::new(),
307                                    status: Some(proto_status),
308                                    r#final: is_terminal,
309                                    metadata: None,
310                                },
311                            )),
312                        };
313                        if is_terminal { break; }
314                    }
315                    Ok(_) => continue,
316                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
317                    Err(broadcast::error::RecvError::Closed) => break,
318                }
319            }
320        };
321
322        Ok(Response::new(
323            Box::pin(stream) as Self::TaskSubscriptionStream
324        ))
325    }
326
327    // ── Push notification config CRUD ───────────────────────────────────
328
329    async fn create_task_push_notification_config(
330        &self,
331        request: Request<proto::CreateTaskPushNotificationConfigRequest>,
332    ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
333        let req = request.into_inner();
334        let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
335
336        if self.store.get_task(task_id).is_none() {
337            return Err(Status::not_found(format!("task {task_id} not found")));
338        }
339
340        let config = req
341            .config
342            .ok_or_else(|| Status::invalid_argument("missing config"))?;
343        let pnc = config.push_notification_config.as_ref();
344
345        let local_config = local::TaskPushNotificationConfig {
346            id: task_id.to_string(),
347            push_notification_config: local::PushNotificationConfig {
348                url: pnc.map(|c| c.url.clone()).unwrap_or_default(),
349                token: pnc.and_then(|c| {
350                    if c.token.is_empty() {
351                        None
352                    } else {
353                        Some(c.token.clone())
354                    }
355                }),
356                id: pnc.and_then(|c| {
357                    if c.id.is_empty() {
358                        None
359                    } else {
360                        Some(c.id.clone())
361                    }
362                }),
363            },
364        };
365
366        self.store
367            .push_configs
368            .entry(task_id.to_string())
369            .or_default()
370            .push(local_config);
371
372        Ok(Response::new(config))
373    }
374
375    async fn get_task_push_notification_config(
376        &self,
377        request: Request<proto::GetTaskPushNotificationConfigRequest>,
378    ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
379        let req = request.into_inner();
380        // name format: tasks/{task_id}/pushNotificationConfigs/{config_id}
381        let parts: Vec<&str> = req.name.split('/').collect();
382        if parts.len() < 4 {
383            return Err(Status::invalid_argument("invalid name format"));
384        }
385        let task_id = parts[1];
386        let config_id = parts[3];
387
388        let configs = self
389            .store
390            .push_configs
391            .get(task_id)
392            .ok_or_else(|| Status::not_found("no configs for task"))?;
393
394        let _found = configs
395            .iter()
396            .find(|c| c.push_notification_config.id.as_deref() == Some(config_id))
397            .ok_or_else(|| Status::not_found("config not found"))?;
398
399        Ok(Response::new(proto::TaskPushNotificationConfig {
400            name: req.name,
401            push_notification_config: None, // simplified
402        }))
403    }
404
405    async fn list_task_push_notification_config(
406        &self,
407        request: Request<proto::ListTaskPushNotificationConfigRequest>,
408    ) -> Result<Response<proto::ListTaskPushNotificationConfigResponse>, Status> {
409        let req = request.into_inner();
410        let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
411
412        let configs: Vec<proto::TaskPushNotificationConfig> = self
413            .store
414            .push_configs
415            .get(task_id)
416            .map(|cs| {
417                cs.iter()
418                    .map(|c| proto::TaskPushNotificationConfig {
419                        name: format!(
420                            "tasks/{}/pushNotificationConfigs/{}",
421                            task_id,
422                            c.push_notification_config
423                                .id
424                                .as_deref()
425                                .unwrap_or("default")
426                        ),
427                        push_notification_config: None,
428                    })
429                    .collect()
430            })
431            .unwrap_or_default();
432
433        Ok(Response::new(
434            proto::ListTaskPushNotificationConfigResponse {
435                configs,
436                next_page_token: String::new(),
437            },
438        ))
439    }
440
441    async fn delete_task_push_notification_config(
442        &self,
443        request: Request<proto::DeleteTaskPushNotificationConfigRequest>,
444    ) -> Result<Response<()>, Status> {
445        let req = request.into_inner();
446        let parts: Vec<&str> = req.name.split('/').collect();
447        if parts.len() < 4 {
448            return Err(Status::invalid_argument("invalid name format"));
449        }
450        let task_id = parts[1];
451        let config_id = parts[3];
452
453        if let Some(mut configs) = self.store.push_configs.get_mut(task_id) {
454            configs.retain(|c| c.push_notification_config.id.as_deref() != Some(config_id));
455        }
456
457        Ok(Response::new(()))
458    }
459
460    // ── GetAgentCard ────────────────────────────────────────────────────
461
462    async fn get_agent_card(
463        &self,
464        _request: Request<proto::GetAgentCardRequest>,
465    ) -> Result<Response<proto::AgentCard>, Status> {
466        Ok(Response::new(bridge::local_card_to_proto(&self.store.card)))
467    }
468}