Skip to main content

codetether_agent/a2a/
grpc.rs

//! gRPC transport for the A2A protocol.
//!
//! Implements the `A2aService` trait generated by tonic-build from
//! `proto/a2a/v1/a2a.proto`. Delegates to the same task store and bus
//! that the JSON-RPC server uses, so both transports share state.
6
7use crate::a2a::bridge;
8use crate::a2a::proto;
9use crate::a2a::proto::a2a_service_server::{A2aService, A2aServiceServer};
10use crate::a2a::types::{self as local, TaskState};
11use crate::bus::AgentBus;
12
13use dashmap::DashMap;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::sync::broadcast;
17use tokio_stream::Stream;
18use tonic::{Request, Response, Status};
19
20type StreamResult<T> = Pin<Box<dyn Stream<Item = Result<T, Status>> + Send>>;
21
22/// Shared state backing both JSON-RPC and gRPC transports.
23pub struct GrpcTaskStore {
24    tasks: DashMap<String, local::Task>,
25    push_configs: DashMap<String, Vec<local::TaskPushNotificationConfig>>,
26    card: local::AgentCard,
27    bus: Option<Arc<AgentBus>>,
28    /// Broadcast for task updates (task_id, status)
29    update_tx: broadcast::Sender<(String, local::TaskStatus)>,
30}
31
32impl GrpcTaskStore {
33    /// Create a new task store with the given agent card.
34    pub fn new(card: local::AgentCard) -> Arc<Self> {
35        let (update_tx, _) = broadcast::channel(256);
36        Arc::new(Self {
37            tasks: DashMap::new(),
38            push_configs: DashMap::new(),
39            card,
40            bus: None,
41            update_tx,
42        })
43    }
44
45    /// Create with an agent bus attached.
46    pub fn with_bus(card: local::AgentCard, bus: Arc<AgentBus>) -> Arc<Self> {
47        let (update_tx, _) = broadcast::channel(256);
48        Arc::new(Self {
49            tasks: DashMap::new(),
50            push_configs: DashMap::new(),
51            card,
52            bus: Some(bus),
53            update_tx,
54        })
55    }
56
57    /// Insert or update a task.
58    pub fn upsert_task(&self, task: local::Task) {
59        let _ = self.update_tx.send((task.id.clone(), task.status.clone()));
60        self.tasks.insert(task.id.clone(), task);
61    }
62
63    /// Get a task by id.
64    pub fn get_task(&self, id: &str) -> Option<local::Task> {
65        self.tasks.get(id).map(|r| r.value().clone())
66    }
67
68    /// Subscribe to task update notifications.
69    pub fn subscribe_updates(&self) -> broadcast::Receiver<(String, local::TaskStatus)> {
70        self.update_tx.subscribe()
71    }
72
73    /// Build the tonic service layer from this store.
74    pub fn into_service(self: Arc<Self>) -> A2aServiceServer<A2aServiceImpl> {
75        A2aServiceServer::new(A2aServiceImpl { store: self })
76    }
77}
78
79/// Tonic service implementation.
80pub struct A2aServiceImpl {
81    store: Arc<GrpcTaskStore>,
82}
83
84#[tonic::async_trait]
85impl A2aService for A2aServiceImpl {
86    // ── SendMessage ─────────────────────────────────────────────────────
87
88    async fn send_message(
89        &self,
90        request: Request<proto::SendMessageRequest>,
91    ) -> Result<Response<proto::SendMessageResponse>, Status> {
92        let req = request.into_inner();
93        let msg = req
94            .request
95            .ok_or_else(|| Status::invalid_argument("missing message"))?;
96        let local_msg = bridge::proto_message_to_local(&msg);
97
98        // Create / route task
99        let task_id = local_msg
100            .task_id
101            .clone()
102            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
103
104        let task = local::Task {
105            id: task_id.clone(),
106            context_id: local_msg.context_id.clone(),
107            status: local::TaskStatus {
108                state: TaskState::Submitted,
109                message: Some(local_msg.clone()),
110                timestamp: Some(chrono::Utc::now().to_rfc3339()),
111            },
112            artifacts: vec![],
113            history: vec![local_msg],
114            metadata: Default::default(),
115        };
116        self.store.upsert_task(task.clone());
117
118        // Notify bus
119        if let Some(ref bus) = self.store.bus {
120            let handle = bus.handle("grpc-server");
121            handle.send_task_update(&task_id, TaskState::Submitted, None);
122        }
123
124        let proto_task = bridge::local_task_to_proto(&task);
125        Ok(Response::new(proto::SendMessageResponse {
126            payload: Some(proto::send_message_response::Payload::Task(proto_task)),
127        }))
128    }
129
130    // ── SendStreamingMessage ────────────────────────────────────────────
131
132    type SendStreamingMessageStream = StreamResult<proto::StreamResponse>;
133
134    async fn send_streaming_message(
135        &self,
136        request: Request<proto::SendMessageRequest>,
137    ) -> Result<Response<Self::SendStreamingMessageStream>, Status> {
138        let req = request.into_inner();
139        let msg = req
140            .request
141            .ok_or_else(|| Status::invalid_argument("missing message"))?;
142        let local_msg = bridge::proto_message_to_local(&msg);
143
144        let task_id = local_msg
145            .task_id
146            .clone()
147            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
148
149        let task = local::Task {
150            id: task_id.clone(),
151            context_id: local_msg.context_id.clone(),
152            status: local::TaskStatus {
153                state: TaskState::Submitted,
154                message: Some(local_msg.clone()),
155                timestamp: Some(chrono::Utc::now().to_rfc3339()),
156            },
157            artifacts: vec![],
158            history: vec![local_msg],
159            metadata: Default::default(),
160        };
161        self.store.upsert_task(task.clone());
162
163        // Return the initial task then listen for updates
164        let proto_task = bridge::local_task_to_proto(&task);
165        let mut rx = self.store.subscribe_updates();
166        let tid = task_id.clone();
167
168        let stream = async_stream::try_stream! {
169            // First frame: the task itself
170            yield proto::StreamResponse {
171                payload: Some(proto::stream_response::Payload::Task(proto_task)),
172            };
173
174            // Subsequent frames: status updates for this task
175            loop {
176                match rx.recv().await {
177                    Ok((id, status)) if id == tid => {
178                        let proto_status = bridge::local_task_status_to_proto(&status);
179                        let is_terminal = status.state.is_terminal();
180                        yield proto::StreamResponse {
181                            payload: Some(proto::stream_response::Payload::StatusUpdate(
182                                proto::TaskStatusUpdateEvent {
183                                    task_id: tid.clone(),
184                                    context_id: String::new(),
185                                    status: Some(proto_status),
186                                    r#final: is_terminal,
187                                    metadata: None,
188                                },
189                            )),
190                        };
191                        if is_terminal {
192                            break;
193                        }
194                    }
195                    Ok(_) => continue,
196                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
197                    Err(broadcast::error::RecvError::Closed) => break,
198                }
199            }
200        };
201
202        Ok(Response::new(
203            Box::pin(stream) as Self::SendStreamingMessageStream
204        ))
205    }
206
207    // ── GetTask ─────────────────────────────────────────────────────────
208
209    async fn get_task(
210        &self,
211        request: Request<proto::GetTaskRequest>,
212    ) -> Result<Response<proto::Task>, Status> {
213        let req = request.into_inner();
214        let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
215        let task = self
216            .store
217            .get_task(task_id)
218            .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
219        Ok(Response::new(bridge::local_task_to_proto(&task)))
220    }
221
222    // ── CancelTask ──────────────────────────────────────────────────────
223
224    async fn cancel_task(
225        &self,
226        request: Request<proto::CancelTaskRequest>,
227    ) -> Result<Response<proto::Task>, Status> {
228        let req = request.into_inner();
229        let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
230
231        let mut task = self
232            .store
233            .tasks
234            .get_mut(task_id)
235            .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
236
237        if task.status.state.is_terminal() {
238            return Err(Status::failed_precondition(
239                "task already in terminal state",
240            ));
241        }
242
243        task.status = local::TaskStatus {
244            state: TaskState::Cancelled,
245            message: None,
246            timestamp: Some(chrono::Utc::now().to_rfc3339()),
247        };
248        let snapshot = task.clone();
249        drop(task);
250
251        let _ = self
252            .store
253            .update_tx
254            .send((task_id.to_string(), snapshot.status.clone()));
255
256        Ok(Response::new(bridge::local_task_to_proto(&snapshot)))
257    }
258
259    // ── TaskSubscription ────────────────────────────────────────────────
260
261    type TaskSubscriptionStream = StreamResult<proto::StreamResponse>;
262
263    async fn task_subscription(
264        &self,
265        request: Request<proto::TaskSubscriptionRequest>,
266    ) -> Result<Response<Self::TaskSubscriptionStream>, Status> {
267        let req = request.into_inner();
268        let task_id = req
269            .name
270            .strip_prefix("tasks/")
271            .unwrap_or(&req.name)
272            .to_string();
273
274        let task = self
275            .store
276            .get_task(&task_id)
277            .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
278
279        let proto_task = bridge::local_task_to_proto(&task);
280        let mut rx = self.store.subscribe_updates();
281        let tid = task_id.clone();
282
283        // If already terminal, return just the task and close
284        if task.status.state.is_terminal() {
285            let stream = async_stream::try_stream! {
286                yield proto::StreamResponse {
287                    payload: Some(proto::stream_response::Payload::Task(proto_task)),
288                };
289            };
290            return Ok(Response::new(
291                Box::pin(stream) as Self::TaskSubscriptionStream
292            ));
293        }
294
295        let stream = async_stream::try_stream! {
296            yield proto::StreamResponse {
297                payload: Some(proto::stream_response::Payload::Task(proto_task)),
298            };
299
300            loop {
301                match rx.recv().await {
302                    Ok((id, status)) if id == tid => {
303                        let proto_status = bridge::local_task_status_to_proto(&status);
304                        let is_terminal = status.state.is_terminal();
305                        yield proto::StreamResponse {
306                            payload: Some(proto::stream_response::Payload::StatusUpdate(
307                                proto::TaskStatusUpdateEvent {
308                                    task_id: tid.clone(),
309                                    context_id: String::new(),
310                                    status: Some(proto_status),
311                                    r#final: is_terminal,
312                                    metadata: None,
313                                },
314                            )),
315                        };
316                        if is_terminal { break; }
317                    }
318                    Ok(_) => continue,
319                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
320                    Err(broadcast::error::RecvError::Closed) => break,
321                }
322            }
323        };
324
325        Ok(Response::new(
326            Box::pin(stream) as Self::TaskSubscriptionStream
327        ))
328    }
329
330    // ── Push notification config CRUD ───────────────────────────────────
331
332    async fn create_task_push_notification_config(
333        &self,
334        request: Request<proto::CreateTaskPushNotificationConfigRequest>,
335    ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
336        let req = request.into_inner();
337        let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
338
339        if self.store.get_task(task_id).is_none() {
340            return Err(Status::not_found(format!("task {task_id} not found")));
341        }
342
343        let config = req
344            .config
345            .ok_or_else(|| Status::invalid_argument("missing config"))?;
346        let pnc = config.push_notification_config.as_ref();
347
348        let local_config = local::TaskPushNotificationConfig {
349            id: task_id.to_string(),
350            push_notification_config: local::PushNotificationConfig {
351                url: pnc.map(|c| c.url.clone()).unwrap_or_default(),
352                token: pnc.and_then(|c| {
353                    if c.token.is_empty() {
354                        None
355                    } else {
356                        Some(c.token.clone())
357                    }
358                }),
359                id: pnc.and_then(|c| {
360                    if c.id.is_empty() {
361                        None
362                    } else {
363                        Some(c.id.clone())
364                    }
365                }),
366            },
367        };
368
369        self.store
370            .push_configs
371            .entry(task_id.to_string())
372            .or_default()
373            .push(local_config);
374
375        Ok(Response::new(config))
376    }
377
378    async fn get_task_push_notification_config(
379        &self,
380        request: Request<proto::GetTaskPushNotificationConfigRequest>,
381    ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
382        let req = request.into_inner();
383        // name format: tasks/{task_id}/pushNotificationConfigs/{config_id}
384        let parts: Vec<&str> = req.name.split('/').collect();
385        if parts.len() < 4 {
386            return Err(Status::invalid_argument("invalid name format"));
387        }
388        let task_id = parts[1];
389        let config_id = parts[3];
390
391        let configs = self
392            .store
393            .push_configs
394            .get(task_id)
395            .ok_or_else(|| Status::not_found("no configs for task"))?;
396
397        let _found = configs
398            .iter()
399            .find(|c| c.push_notification_config.id.as_deref() == Some(config_id))
400            .ok_or_else(|| Status::not_found("config not found"))?;
401
402        Ok(Response::new(proto::TaskPushNotificationConfig {
403            name: req.name,
404            push_notification_config: None, // simplified
405        }))
406    }
407
408    async fn list_task_push_notification_config(
409        &self,
410        request: Request<proto::ListTaskPushNotificationConfigRequest>,
411    ) -> Result<Response<proto::ListTaskPushNotificationConfigResponse>, Status> {
412        let req = request.into_inner();
413        let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
414
415        let configs: Vec<proto::TaskPushNotificationConfig> = self
416            .store
417            .push_configs
418            .get(task_id)
419            .map(|cs| {
420                cs.iter()
421                    .map(|c| proto::TaskPushNotificationConfig {
422                        name: format!(
423                            "tasks/{}/pushNotificationConfigs/{}",
424                            task_id,
425                            c.push_notification_config
426                                .id
427                                .as_deref()
428                                .unwrap_or("default")
429                        ),
430                        push_notification_config: None,
431                    })
432                    .collect()
433            })
434            .unwrap_or_default();
435
436        Ok(Response::new(
437            proto::ListTaskPushNotificationConfigResponse {
438                configs,
439                next_page_token: String::new(),
440            },
441        ))
442    }
443
444    async fn delete_task_push_notification_config(
445        &self,
446        request: Request<proto::DeleteTaskPushNotificationConfigRequest>,
447    ) -> Result<Response<()>, Status> {
448        let req = request.into_inner();
449        let parts: Vec<&str> = req.name.split('/').collect();
450        if parts.len() < 4 {
451            return Err(Status::invalid_argument("invalid name format"));
452        }
453        let task_id = parts[1];
454        let config_id = parts[3];
455
456        if let Some(mut configs) = self.store.push_configs.get_mut(task_id) {
457            configs.retain(|c| c.push_notification_config.id.as_deref() != Some(config_id));
458        }
459
460        Ok(Response::new(()))
461    }
462
463    // ── GetAgentCard ────────────────────────────────────────────────────
464
465    async fn get_agent_card(
466        &self,
467        _request: Request<proto::GetAgentCardRequest>,
468    ) -> Result<Response<proto::AgentCard>, Status> {
469        Ok(Response::new(bridge::local_card_to_proto(&self.store.card)))
470    }
471}