1use crate::a2a::bridge;
8use crate::a2a::proto;
9use crate::a2a::proto::a2a_service_server::{A2aService, A2aServiceServer};
10use crate::a2a::types::{self as local, TaskState};
11use crate::bus::AgentBus;
12
13use dashmap::DashMap;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::sync::broadcast;
17use tokio_stream::Stream;
18use tonic::{Request, Response, Status};
19
20type StreamResult<T> = Pin<Box<dyn Stream<Item = Result<T, Status>> + Send>>;
21
22pub struct GrpcTaskStore {
24 tasks: DashMap<String, local::Task>,
25 push_configs: DashMap<String, Vec<local::TaskPushNotificationConfig>>,
26 card: local::AgentCard,
27 bus: Option<Arc<AgentBus>>,
28 update_tx: broadcast::Sender<(String, local::TaskStatus)>,
30}
31
32impl GrpcTaskStore {
33 pub fn new(card: local::AgentCard) -> Arc<Self> {
35 let (update_tx, _) = broadcast::channel(256);
36 Arc::new(Self {
37 tasks: DashMap::new(),
38 push_configs: DashMap::new(),
39 card,
40 bus: None,
41 update_tx,
42 })
43 }
44
45 pub fn with_bus(card: local::AgentCard, bus: Arc<AgentBus>) -> Arc<Self> {
47 let (update_tx, _) = broadcast::channel(256);
48 Arc::new(Self {
49 tasks: DashMap::new(),
50 push_configs: DashMap::new(),
51 card,
52 bus: Some(bus),
53 update_tx,
54 })
55 }
56
57 pub fn upsert_task(&self, task: local::Task) {
59 let _ = self.update_tx.send((task.id.clone(), task.status.clone()));
60 self.tasks.insert(task.id.clone(), task);
61 }
62
63 pub fn get_task(&self, id: &str) -> Option<local::Task> {
65 self.tasks.get(id).map(|r| r.value().clone())
66 }
67
68 pub fn subscribe_updates(&self) -> broadcast::Receiver<(String, local::TaskStatus)> {
70 self.update_tx.subscribe()
71 }
72
73 pub fn into_service(self: Arc<Self>) -> A2aServiceServer<A2aServiceImpl> {
75 A2aServiceServer::new(A2aServiceImpl { store: self })
76 }
77}
78
79pub struct A2aServiceImpl {
81 store: Arc<GrpcTaskStore>,
82}
83
84#[tonic::async_trait]
85impl A2aService for A2aServiceImpl {
86 async fn send_message(
89 &self,
90 request: Request<proto::SendMessageRequest>,
91 ) -> Result<Response<proto::SendMessageResponse>, Status> {
92 let req = request.into_inner();
93 let msg = req
94 .request
95 .ok_or_else(|| Status::invalid_argument("missing message"))?;
96 let local_msg = bridge::proto_message_to_local(&msg);
97
98 let task_id = local_msg
100 .task_id
101 .clone()
102 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
103
104 let task = local::Task {
105 id: task_id.clone(),
106 context_id: local_msg.context_id.clone(),
107 status: local::TaskStatus {
108 state: TaskState::Submitted,
109 message: Some(local_msg.clone()),
110 timestamp: Some(chrono::Utc::now().to_rfc3339()),
111 },
112 artifacts: vec![],
113 history: vec![local_msg],
114 metadata: Default::default(),
115 };
116 self.store.upsert_task(task.clone());
117
118 if let Some(ref bus) = self.store.bus {
120 let handle = bus.handle("grpc-server");
121 handle.send_task_update(&task_id, TaskState::Submitted, None);
122 }
123
124 let proto_task = bridge::local_task_to_proto(&task);
125 Ok(Response::new(proto::SendMessageResponse {
126 payload: Some(proto::send_message_response::Payload::Task(proto_task)),
127 }))
128 }
129
130 type SendStreamingMessageStream = StreamResult<proto::StreamResponse>;
133
134 async fn send_streaming_message(
135 &self,
136 request: Request<proto::SendMessageRequest>,
137 ) -> Result<Response<Self::SendStreamingMessageStream>, Status> {
138 let req = request.into_inner();
139 let msg = req
140 .request
141 .ok_or_else(|| Status::invalid_argument("missing message"))?;
142 let local_msg = bridge::proto_message_to_local(&msg);
143
144 let task_id = local_msg
145 .task_id
146 .clone()
147 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
148
149 let task = local::Task {
150 id: task_id.clone(),
151 context_id: local_msg.context_id.clone(),
152 status: local::TaskStatus {
153 state: TaskState::Submitted,
154 message: Some(local_msg.clone()),
155 timestamp: Some(chrono::Utc::now().to_rfc3339()),
156 },
157 artifacts: vec![],
158 history: vec![local_msg],
159 metadata: Default::default(),
160 };
161 self.store.upsert_task(task.clone());
162
163 let proto_task = bridge::local_task_to_proto(&task);
165 let mut rx = self.store.subscribe_updates();
166 let tid = task_id.clone();
167
168 let stream = async_stream::try_stream! {
169 yield proto::StreamResponse {
171 payload: Some(proto::stream_response::Payload::Task(proto_task)),
172 };
173
174 loop {
176 match rx.recv().await {
177 Ok((id, status)) if id == tid => {
178 let proto_status = bridge::local_task_status_to_proto(&status);
179 let is_terminal = status.state.is_terminal();
180 yield proto::StreamResponse {
181 payload: Some(proto::stream_response::Payload::StatusUpdate(
182 proto::TaskStatusUpdateEvent {
183 task_id: tid.clone(),
184 context_id: String::new(),
185 status: Some(proto_status),
186 r#final: is_terminal,
187 metadata: None,
188 },
189 )),
190 };
191 if is_terminal {
192 break;
193 }
194 }
195 Ok(_) => continue,
196 Err(broadcast::error::RecvError::Lagged(_)) => continue,
197 Err(broadcast::error::RecvError::Closed) => break,
198 }
199 }
200 };
201
202 Ok(Response::new(
203 Box::pin(stream) as Self::SendStreamingMessageStream
204 ))
205 }
206
207 async fn get_task(
210 &self,
211 request: Request<proto::GetTaskRequest>,
212 ) -> Result<Response<proto::Task>, Status> {
213 let req = request.into_inner();
214 let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
215 let task = self
216 .store
217 .get_task(task_id)
218 .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
219 Ok(Response::new(bridge::local_task_to_proto(&task)))
220 }
221
222 async fn cancel_task(
225 &self,
226 request: Request<proto::CancelTaskRequest>,
227 ) -> Result<Response<proto::Task>, Status> {
228 let req = request.into_inner();
229 let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
230
231 let mut task = self
232 .store
233 .tasks
234 .get_mut(task_id)
235 .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
236
237 if task.status.state.is_terminal() {
238 return Err(Status::failed_precondition(
239 "task already in terminal state",
240 ));
241 }
242
243 task.status = local::TaskStatus {
244 state: TaskState::Cancelled,
245 message: None,
246 timestamp: Some(chrono::Utc::now().to_rfc3339()),
247 };
248 let snapshot = task.clone();
249 drop(task);
250
251 let _ = self
252 .store
253 .update_tx
254 .send((task_id.to_string(), snapshot.status.clone()));
255
256 Ok(Response::new(bridge::local_task_to_proto(&snapshot)))
257 }
258
259 type TaskSubscriptionStream = StreamResult<proto::StreamResponse>;
262
263 async fn task_subscription(
264 &self,
265 request: Request<proto::TaskSubscriptionRequest>,
266 ) -> Result<Response<Self::TaskSubscriptionStream>, Status> {
267 let req = request.into_inner();
268 let task_id = req
269 .name
270 .strip_prefix("tasks/")
271 .unwrap_or(&req.name)
272 .to_string();
273
274 let task = self
275 .store
276 .get_task(&task_id)
277 .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
278
279 let proto_task = bridge::local_task_to_proto(&task);
280 let mut rx = self.store.subscribe_updates();
281 let tid = task_id.clone();
282
283 if task.status.state.is_terminal() {
285 let stream = async_stream::try_stream! {
286 yield proto::StreamResponse {
287 payload: Some(proto::stream_response::Payload::Task(proto_task)),
288 };
289 };
290 return Ok(Response::new(
291 Box::pin(stream) as Self::TaskSubscriptionStream
292 ));
293 }
294
295 let stream = async_stream::try_stream! {
296 yield proto::StreamResponse {
297 payload: Some(proto::stream_response::Payload::Task(proto_task)),
298 };
299
300 loop {
301 match rx.recv().await {
302 Ok((id, status)) if id == tid => {
303 let proto_status = bridge::local_task_status_to_proto(&status);
304 let is_terminal = status.state.is_terminal();
305 yield proto::StreamResponse {
306 payload: Some(proto::stream_response::Payload::StatusUpdate(
307 proto::TaskStatusUpdateEvent {
308 task_id: tid.clone(),
309 context_id: String::new(),
310 status: Some(proto_status),
311 r#final: is_terminal,
312 metadata: None,
313 },
314 )),
315 };
316 if is_terminal { break; }
317 }
318 Ok(_) => continue,
319 Err(broadcast::error::RecvError::Lagged(_)) => continue,
320 Err(broadcast::error::RecvError::Closed) => break,
321 }
322 }
323 };
324
325 Ok(Response::new(
326 Box::pin(stream) as Self::TaskSubscriptionStream
327 ))
328 }
329
330 async fn create_task_push_notification_config(
333 &self,
334 request: Request<proto::CreateTaskPushNotificationConfigRequest>,
335 ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
336 let req = request.into_inner();
337 let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
338
339 if self.store.get_task(task_id).is_none() {
340 return Err(Status::not_found(format!("task {task_id} not found")));
341 }
342
343 let config = req
344 .config
345 .ok_or_else(|| Status::invalid_argument("missing config"))?;
346 let pnc = config.push_notification_config.as_ref();
347
348 let local_config = local::TaskPushNotificationConfig {
349 id: task_id.to_string(),
350 push_notification_config: local::PushNotificationConfig {
351 url: pnc.map(|c| c.url.clone()).unwrap_or_default(),
352 token: pnc.and_then(|c| {
353 if c.token.is_empty() {
354 None
355 } else {
356 Some(c.token.clone())
357 }
358 }),
359 id: pnc.and_then(|c| {
360 if c.id.is_empty() {
361 None
362 } else {
363 Some(c.id.clone())
364 }
365 }),
366 },
367 };
368
369 self.store
370 .push_configs
371 .entry(task_id.to_string())
372 .or_default()
373 .push(local_config);
374
375 Ok(Response::new(config))
376 }
377
378 async fn get_task_push_notification_config(
379 &self,
380 request: Request<proto::GetTaskPushNotificationConfigRequest>,
381 ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
382 let req = request.into_inner();
383 let parts: Vec<&str> = req.name.split('/').collect();
385 if parts.len() < 4 {
386 return Err(Status::invalid_argument("invalid name format"));
387 }
388 let task_id = parts[1];
389 let config_id = parts[3];
390
391 let configs = self
392 .store
393 .push_configs
394 .get(task_id)
395 .ok_or_else(|| Status::not_found("no configs for task"))?;
396
397 let _found = configs
398 .iter()
399 .find(|c| c.push_notification_config.id.as_deref() == Some(config_id))
400 .ok_or_else(|| Status::not_found("config not found"))?;
401
402 Ok(Response::new(proto::TaskPushNotificationConfig {
403 name: req.name,
404 push_notification_config: None, }))
406 }
407
408 async fn list_task_push_notification_config(
409 &self,
410 request: Request<proto::ListTaskPushNotificationConfigRequest>,
411 ) -> Result<Response<proto::ListTaskPushNotificationConfigResponse>, Status> {
412 let req = request.into_inner();
413 let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
414
415 let configs: Vec<proto::TaskPushNotificationConfig> = self
416 .store
417 .push_configs
418 .get(task_id)
419 .map(|cs| {
420 cs.iter()
421 .map(|c| proto::TaskPushNotificationConfig {
422 name: format!(
423 "tasks/{}/pushNotificationConfigs/{}",
424 task_id,
425 c.push_notification_config
426 .id
427 .as_deref()
428 .unwrap_or("default")
429 ),
430 push_notification_config: None,
431 })
432 .collect()
433 })
434 .unwrap_or_default();
435
436 Ok(Response::new(
437 proto::ListTaskPushNotificationConfigResponse {
438 configs,
439 next_page_token: String::new(),
440 },
441 ))
442 }
443
444 async fn delete_task_push_notification_config(
445 &self,
446 request: Request<proto::DeleteTaskPushNotificationConfigRequest>,
447 ) -> Result<Response<()>, Status> {
448 let req = request.into_inner();
449 let parts: Vec<&str> = req.name.split('/').collect();
450 if parts.len() < 4 {
451 return Err(Status::invalid_argument("invalid name format"));
452 }
453 let task_id = parts[1];
454 let config_id = parts[3];
455
456 if let Some(mut configs) = self.store.push_configs.get_mut(task_id) {
457 configs.retain(|c| c.push_notification_config.id.as_deref() != Some(config_id));
458 }
459
460 Ok(Response::new(()))
461 }
462
463 async fn get_agent_card(
466 &self,
467 _request: Request<proto::GetAgentCardRequest>,
468 ) -> Result<Response<proto::AgentCard>, Status> {
469 Ok(Response::new(bridge::local_card_to_proto(&self.store.card)))
470 }
471}