1use crate::a2a::bridge;
8use crate::a2a::proto;
9use crate::a2a::proto::a2a_service_server::{A2aService, A2aServiceServer};
10use crate::a2a::types::{self as local, TaskState};
11use crate::bus::AgentBus;
12
13use dashmap::DashMap;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::sync::broadcast;
17use tokio_stream::Stream;
18use tonic::{Request, Response, Status};
19
20type StreamResult<T> = Pin<Box<dyn Stream<Item = Result<T, Status>> + Send>>;
21
22pub struct GrpcTaskStore {
24 tasks: DashMap<String, local::Task>,
25 push_configs: DashMap<String, Vec<local::TaskPushNotificationConfig>>,
26 card: local::AgentCard,
27 bus: Option<Arc<AgentBus>>,
28 update_tx: broadcast::Sender<(String, local::TaskStatus)>,
30}
31
32impl GrpcTaskStore {
33 pub fn new(card: local::AgentCard) -> Arc<Self> {
35 let (update_tx, _) = broadcast::channel(256);
36 Arc::new(Self {
37 tasks: DashMap::new(),
38 push_configs: DashMap::new(),
39 card,
40 bus: None,
41 update_tx,
42 })
43 }
44
45 pub fn with_bus(card: local::AgentCard, bus: Arc<AgentBus>) -> Arc<Self> {
47 let mut store = Self::new(card);
48 Arc::get_mut(&mut store)
49 .expect("fresh Arc must be uniquely owned")
50 .bus = Some(bus);
51 store
52 }
53
54 pub fn upsert_task(&self, task: local::Task) {
56 let _ = self.update_tx.send((task.id.clone(), task.status.clone()));
57 self.tasks.insert(task.id.clone(), task);
58 }
59
60 pub fn get_task(&self, id: &str) -> Option<local::Task> {
62 self.tasks.get(id).map(|r| r.value().clone())
63 }
64
65 pub fn subscribe_updates(&self) -> broadcast::Receiver<(String, local::TaskStatus)> {
67 self.update_tx.subscribe()
68 }
69
70 pub fn into_service(self: Arc<Self>) -> A2aServiceServer<A2aServiceImpl> {
72 A2aServiceServer::new(A2aServiceImpl { store: self })
73 }
74}
75
76pub struct A2aServiceImpl {
78 store: Arc<GrpcTaskStore>,
79}
80
81#[tonic::async_trait]
82impl A2aService for A2aServiceImpl {
83 async fn send_message(
86 &self,
87 request: Request<proto::SendMessageRequest>,
88 ) -> Result<Response<proto::SendMessageResponse>, Status> {
89 let req = request.into_inner();
90 let msg = req
91 .request
92 .ok_or_else(|| Status::invalid_argument("missing message"))?;
93 let local_msg = bridge::proto_message_to_local(&msg);
94
95 let task_id = local_msg
97 .task_id
98 .clone()
99 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
100
101 let task = local::Task {
102 id: task_id.clone(),
103 context_id: local_msg.context_id.clone(),
104 status: local::TaskStatus {
105 state: TaskState::Submitted,
106 message: Some(local_msg.clone()),
107 timestamp: Some(chrono::Utc::now().to_rfc3339()),
108 },
109 artifacts: vec![],
110 history: vec![local_msg],
111 metadata: Default::default(),
112 };
113 self.store.upsert_task(task.clone());
114
115 if let Some(ref bus) = self.store.bus {
117 let handle = bus.handle("grpc-server");
118 handle.send_task_update(&task_id, TaskState::Submitted, None);
119 }
120
121 let proto_task = bridge::local_task_to_proto(&task);
122 Ok(Response::new(proto::SendMessageResponse {
123 payload: Some(proto::send_message_response::Payload::Task(proto_task)),
124 }))
125 }
126
127 type SendStreamingMessageStream = StreamResult<proto::StreamResponse>;
130
131 async fn send_streaming_message(
132 &self,
133 request: Request<proto::SendMessageRequest>,
134 ) -> Result<Response<Self::SendStreamingMessageStream>, Status> {
135 let req = request.into_inner();
136 let msg = req
137 .request
138 .ok_or_else(|| Status::invalid_argument("missing message"))?;
139 let local_msg = bridge::proto_message_to_local(&msg);
140
141 let task_id = local_msg
142 .task_id
143 .clone()
144 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
145
146 let task = local::Task {
147 id: task_id.clone(),
148 context_id: local_msg.context_id.clone(),
149 status: local::TaskStatus {
150 state: TaskState::Submitted,
151 message: Some(local_msg.clone()),
152 timestamp: Some(chrono::Utc::now().to_rfc3339()),
153 },
154 artifacts: vec![],
155 history: vec![local_msg],
156 metadata: Default::default(),
157 };
158 self.store.upsert_task(task.clone());
159
160 let proto_task = bridge::local_task_to_proto(&task);
162 let mut rx = self.store.subscribe_updates();
163 let tid = task_id.clone();
164
165 let stream = async_stream::try_stream! {
166 yield proto::StreamResponse {
168 payload: Some(proto::stream_response::Payload::Task(proto_task)),
169 };
170
171 loop {
173 match rx.recv().await {
174 Ok((id, status)) if id == tid => {
175 let proto_status = bridge::local_task_status_to_proto(&status);
176 let is_terminal = status.state.is_terminal();
177 yield proto::StreamResponse {
178 payload: Some(proto::stream_response::Payload::StatusUpdate(
179 proto::TaskStatusUpdateEvent {
180 task_id: tid.clone(),
181 context_id: String::new(),
182 status: Some(proto_status),
183 r#final: is_terminal,
184 metadata: None,
185 },
186 )),
187 };
188 if is_terminal {
189 break;
190 }
191 }
192 Ok(_) => continue,
193 Err(broadcast::error::RecvError::Lagged(_)) => continue,
194 Err(broadcast::error::RecvError::Closed) => break,
195 }
196 }
197 };
198
199 Ok(Response::new(
200 Box::pin(stream) as Self::SendStreamingMessageStream
201 ))
202 }
203
204 async fn get_task(
207 &self,
208 request: Request<proto::GetTaskRequest>,
209 ) -> Result<Response<proto::Task>, Status> {
210 let req = request.into_inner();
211 let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
212 let task = self
213 .store
214 .get_task(task_id)
215 .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
216 Ok(Response::new(bridge::local_task_to_proto(&task)))
217 }
218
219 async fn cancel_task(
222 &self,
223 request: Request<proto::CancelTaskRequest>,
224 ) -> Result<Response<proto::Task>, Status> {
225 let req = request.into_inner();
226 let task_id = req.name.strip_prefix("tasks/").unwrap_or(&req.name);
227
228 let mut task = self
229 .store
230 .tasks
231 .get_mut(task_id)
232 .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
233
234 if task.status.state.is_terminal() {
235 return Err(Status::failed_precondition(
236 "task already in terminal state",
237 ));
238 }
239
240 task.status = local::TaskStatus {
241 state: TaskState::Cancelled,
242 message: None,
243 timestamp: Some(chrono::Utc::now().to_rfc3339()),
244 };
245 let snapshot = task.clone();
246 drop(task);
247
248 let _ = self
249 .store
250 .update_tx
251 .send((task_id.to_string(), snapshot.status.clone()));
252
253 Ok(Response::new(bridge::local_task_to_proto(&snapshot)))
254 }
255
256 type TaskSubscriptionStream = StreamResult<proto::StreamResponse>;
259
260 async fn task_subscription(
261 &self,
262 request: Request<proto::TaskSubscriptionRequest>,
263 ) -> Result<Response<Self::TaskSubscriptionStream>, Status> {
264 let req = request.into_inner();
265 let task_id = req
266 .name
267 .strip_prefix("tasks/")
268 .unwrap_or(&req.name)
269 .to_string();
270
271 let task = self
272 .store
273 .get_task(&task_id)
274 .ok_or_else(|| Status::not_found(format!("task {task_id} not found")))?;
275
276 let proto_task = bridge::local_task_to_proto(&task);
277 let mut rx = self.store.subscribe_updates();
278 let tid = task_id.clone();
279
280 if task.status.state.is_terminal() {
282 let stream = async_stream::try_stream! {
283 yield proto::StreamResponse {
284 payload: Some(proto::stream_response::Payload::Task(proto_task)),
285 };
286 };
287 return Ok(Response::new(
288 Box::pin(stream) as Self::TaskSubscriptionStream
289 ));
290 }
291
292 let stream = async_stream::try_stream! {
293 yield proto::StreamResponse {
294 payload: Some(proto::stream_response::Payload::Task(proto_task)),
295 };
296
297 loop {
298 match rx.recv().await {
299 Ok((id, status)) if id == tid => {
300 let proto_status = bridge::local_task_status_to_proto(&status);
301 let is_terminal = status.state.is_terminal();
302 yield proto::StreamResponse {
303 payload: Some(proto::stream_response::Payload::StatusUpdate(
304 proto::TaskStatusUpdateEvent {
305 task_id: tid.clone(),
306 context_id: String::new(),
307 status: Some(proto_status),
308 r#final: is_terminal,
309 metadata: None,
310 },
311 )),
312 };
313 if is_terminal { break; }
314 }
315 Ok(_) => continue,
316 Err(broadcast::error::RecvError::Lagged(_)) => continue,
317 Err(broadcast::error::RecvError::Closed) => break,
318 }
319 }
320 };
321
322 Ok(Response::new(
323 Box::pin(stream) as Self::TaskSubscriptionStream
324 ))
325 }
326
327 async fn create_task_push_notification_config(
330 &self,
331 request: Request<proto::CreateTaskPushNotificationConfigRequest>,
332 ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
333 let req = request.into_inner();
334 let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
335
336 if self.store.get_task(task_id).is_none() {
337 return Err(Status::not_found(format!("task {task_id} not found")));
338 }
339
340 let config = req
341 .config
342 .ok_or_else(|| Status::invalid_argument("missing config"))?;
343 let pnc = config.push_notification_config.as_ref();
344
345 let local_config = local::TaskPushNotificationConfig {
346 id: task_id.to_string(),
347 push_notification_config: local::PushNotificationConfig {
348 url: pnc.map(|c| c.url.clone()).unwrap_or_default(),
349 token: pnc.and_then(|c| {
350 if c.token.is_empty() {
351 None
352 } else {
353 Some(c.token.clone())
354 }
355 }),
356 id: pnc.and_then(|c| {
357 if c.id.is_empty() {
358 None
359 } else {
360 Some(c.id.clone())
361 }
362 }),
363 },
364 };
365
366 self.store
367 .push_configs
368 .entry(task_id.to_string())
369 .or_default()
370 .push(local_config);
371
372 Ok(Response::new(config))
373 }
374
375 async fn get_task_push_notification_config(
376 &self,
377 request: Request<proto::GetTaskPushNotificationConfigRequest>,
378 ) -> Result<Response<proto::TaskPushNotificationConfig>, Status> {
379 let req = request.into_inner();
380 let parts: Vec<&str> = req.name.split('/').collect();
382 if parts.len() < 4 {
383 return Err(Status::invalid_argument("invalid name format"));
384 }
385 let task_id = parts[1];
386 let config_id = parts[3];
387
388 let configs = self
389 .store
390 .push_configs
391 .get(task_id)
392 .ok_or_else(|| Status::not_found("no configs for task"))?;
393
394 let _found = configs
395 .iter()
396 .find(|c| c.push_notification_config.id.as_deref() == Some(config_id))
397 .ok_or_else(|| Status::not_found("config not found"))?;
398
399 Ok(Response::new(proto::TaskPushNotificationConfig {
400 name: req.name,
401 push_notification_config: None, }))
403 }
404
405 async fn list_task_push_notification_config(
406 &self,
407 request: Request<proto::ListTaskPushNotificationConfigRequest>,
408 ) -> Result<Response<proto::ListTaskPushNotificationConfigResponse>, Status> {
409 let req = request.into_inner();
410 let task_id = req.parent.strip_prefix("tasks/").unwrap_or(&req.parent);
411
412 let configs: Vec<proto::TaskPushNotificationConfig> = self
413 .store
414 .push_configs
415 .get(task_id)
416 .map(|cs| {
417 cs.iter()
418 .map(|c| proto::TaskPushNotificationConfig {
419 name: format!(
420 "tasks/{}/pushNotificationConfigs/{}",
421 task_id,
422 c.push_notification_config
423 .id
424 .as_deref()
425 .unwrap_or("default")
426 ),
427 push_notification_config: None,
428 })
429 .collect()
430 })
431 .unwrap_or_default();
432
433 Ok(Response::new(
434 proto::ListTaskPushNotificationConfigResponse {
435 configs,
436 next_page_token: String::new(),
437 },
438 ))
439 }
440
441 async fn delete_task_push_notification_config(
442 &self,
443 request: Request<proto::DeleteTaskPushNotificationConfigRequest>,
444 ) -> Result<Response<()>, Status> {
445 let req = request.into_inner();
446 let parts: Vec<&str> = req.name.split('/').collect();
447 if parts.len() < 4 {
448 return Err(Status::invalid_argument("invalid name format"));
449 }
450 let task_id = parts[1];
451 let config_id = parts[3];
452
453 if let Some(mut configs) = self.store.push_configs.get_mut(task_id) {
454 configs.retain(|c| c.push_notification_config.id.as_deref() != Some(config_id));
455 }
456
457 Ok(Response::new(()))
458 }
459
460 async fn get_agent_card(
463 &self,
464 _request: Request<proto::GetAgentCardRequest>,
465 ) -> Result<Response<proto::AgentCard>, Status> {
466 Ok(Response::new(bridge::local_card_to_proto(&self.store.card)))
467 }
468}