1use crate::ServerConfig;
2use crate::a2a::{
3 AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest, JsonRpcResponse, Message,
4 MessageSendParams, Task, TaskState, TaskStatus, TaskStatusUpdateEvent, TasksCancelParams,
5 TasksGetParams, UpdateEvent, build_agent_card, jsonrpc,
6};
7use adk_runner::{Runner, RunnerConfig};
8use axum::{
9 extract::State,
10 http::StatusCode,
11 response::{
12 IntoResponse, Json,
13 sse::{Event, Sse},
14 },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
19use tokio::sync::{Mutex, Notify, RwLock, mpsc, oneshot};
20use tokio_util::sync::CancellationToken;
21
22#[derive(Default)]
24pub struct TaskStore {
25 tasks: RwLock<HashMap<String, Task>>,
26}
27
28impl TaskStore {
29 pub fn new() -> Self {
30 Self::default()
31 }
32
33 pub async fn store(&self, task: Task) {
34 self.tasks.write().await.insert(task.id.clone(), task);
35 }
36
37 pub async fn get(&self, task_id: &str) -> Option<Task> {
38 self.tasks.read().await.get(task_id).cloned()
39 }
40
41 pub async fn remove(&self, task_id: &str) -> Option<Task> {
42 self.tasks.write().await.remove(task_id)
43 }
44}
45
46#[derive(Clone)]
47struct ActiveTask {
48 token: CancellationToken,
49 abort_handle: tokio::task::AbortHandle,
50 completion: Arc<Notify>,
51 context_id: String,
52}
53
54enum StreamTaskMessage {
55 Update(Box<UpdateEvent>),
56 Error(String),
57}
58
59#[derive(Clone)]
61pub struct A2aController {
62 config: ServerConfig,
63 agent_card: AgentCard,
64 task_store: Arc<TaskStore>,
65 active_tasks: Arc<Mutex<HashMap<String, ActiveTask>>>,
66}
67
68impl A2aController {
69 pub fn new(config: ServerConfig, base_url: &str) -> Self {
70 let root_agent = config.agent_loader.root_agent();
71 let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
72 let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
73
74 Self {
75 config,
76 agent_card,
77 task_store: Arc::new(TaskStore::new()),
78 active_tasks: Arc::new(Mutex::new(HashMap::new())),
79 }
80 }
81}
82
83fn build_runner_config(
84 controller: &A2aController,
85 root_agent: Arc<dyn adk_core::Agent>,
86 cancellation_token: Option<CancellationToken>,
87) -> Arc<RunnerConfig> {
88 let mut builder = Runner::builder()
89 .app_name(root_agent.name())
90 .agent(root_agent)
91 .session_service(controller.config.session_service.clone());
92 if let Some(ref artifact_service) = controller.config.artifact_service {
93 builder = builder.artifact_service(artifact_service.clone());
94 }
95 if let Some(ref memory_service) = controller.config.memory_service {
96 builder = builder.memory_service(memory_service.clone());
97 }
98 if let Some(ref compaction_config) = controller.config.compaction_config {
99 builder = builder.compaction_config(compaction_config.clone());
100 }
101 if let Some(ref context_cache_config) = controller.config.context_cache_config {
102 builder = builder.context_cache_config(context_cache_config.clone());
103 }
104 if let Some(ref cache_capable) = controller.config.cache_capable {
105 builder = builder.cache_capable(cache_capable.clone());
106 }
107 if let Some(cancellation_token) = cancellation_token {
108 builder = builder.cancellation_token(cancellation_token);
109 }
110 Arc::new(builder.build_config())
111}
112
113fn build_task_from_events(task_id: &str, context_id: &str, events: &[UpdateEvent]) -> Task {
114 let mut task = Task {
115 id: task_id.to_string(),
116 context_id: Some(context_id.to_string()),
117 status: TaskStatus { state: TaskState::Completed, message: None },
118 artifacts: Some(vec![]),
119 history: None,
120 };
121
122 for event in events {
123 match event {
124 UpdateEvent::TaskStatusUpdate(status) => {
125 task.status = status.status.clone();
126 }
127 UpdateEvent::TaskArtifactUpdate(artifact) => {
128 if let Some(ref mut artifacts) = task.artifacts {
129 artifacts.push(artifact.artifact.clone());
130 }
131 }
132 }
133 }
134
135 task
136}
137
138fn build_failed_task(task_id: &str, context_id: &str, message: impl Into<String>) -> Task {
139 Task {
140 id: task_id.to_string(),
141 context_id: Some(context_id.to_string()),
142 status: TaskStatus { state: TaskState::Failed, message: Some(message.into()) },
143 artifacts: None,
144 history: None,
145 }
146}
147
148fn build_canceled_task(task_id: &str, context_id: &str) -> Task {
149 Task {
150 id: task_id.to_string(),
151 context_id: Some(context_id.to_string()),
152 status: TaskStatus { state: TaskState::Canceled, message: None },
153 artifacts: None,
154 history: None,
155 }
156}
157
158fn sanitize_internal_error(config: &ServerConfig, error: &adk_core::AdkError) -> String {
159 if config.security.expose_error_details {
160 error.to_string()
161 } else {
162 "Internal server error".to_string()
163 }
164}
165
166async fn start_task(
167 controller: &A2aController,
168 context_id: String,
169 task_id: String,
170 message: Message,
171 stream_updates: bool,
172) -> (oneshot::Receiver<adk_core::Result<Task>>, Option<mpsc::Receiver<StreamTaskMessage>>) {
173 let token = CancellationToken::new();
174 let completion = Arc::new(Notify::new());
175 let (task_tx, task_rx) = oneshot::channel();
176 let (stream_tx, stream_rx) = if stream_updates {
177 let (tx, rx) = mpsc::channel(32);
178 (Some(tx), Some(rx))
179 } else {
180 (None, None)
181 };
182
183 let root_agent = controller.config.agent_loader.root_agent();
184 let executor = Executor::new(ExecutorConfig {
185 app_name: root_agent.name().to_string(),
186 runner_config: build_runner_config(controller, root_agent, Some(token.clone())),
187 cancellation_token: Some(token.clone()),
188 #[cfg(feature = "a2a-interceptors")]
189 interceptor_chain: controller.config.interceptor_chain.clone(),
190 });
191
192 let controller_clone = controller.clone();
193 let completion_clone = completion.clone();
194 let task_id_for_task = task_id.clone();
195 let context_id_for_task = context_id.clone();
196 let stream_tx_for_task = stream_tx.clone();
197
198 let join_handle = tokio::spawn(async move {
199 let result = executor.execute(&context_id_for_task, &task_id_for_task, &message).await;
200
201 match result {
202 Ok(events) => {
203 if let Some(sender) = stream_tx_for_task {
204 for event in &events {
205 if sender
206 .send(StreamTaskMessage::Update(Box::new(event.clone())))
207 .await
208 .is_err()
209 {
210 break;
211 }
212 }
213 }
214
215 let task = build_task_from_events(&task_id_for_task, &context_id_for_task, &events);
216 controller_clone.task_store.store(task.clone()).await;
217 let _ = task_tx.send(Ok(task));
218 }
219 Err(error) => {
220 if let Some(sender) = stream_tx_for_task {
221 let _ = sender
222 .send(StreamTaskMessage::Error(sanitize_internal_error(
223 &controller_clone.config,
224 &error,
225 )))
226 .await;
227 }
228 controller_clone
229 .task_store
230 .store(build_failed_task(
231 &task_id_for_task,
232 &context_id_for_task,
233 error.to_string(),
234 ))
235 .await;
236 let _ = task_tx.send(Err(error));
237 }
238 }
239
240 controller_clone.active_tasks.lock().await.remove(&task_id_for_task);
241 completion_clone.notify_waiters();
242 });
243
244 controller.active_tasks.lock().await.insert(
245 task_id,
246 ActiveTask { token, abort_handle: join_handle.abort_handle(), completion, context_id },
247 );
248
249 (task_rx, stream_rx)
250}
251
252pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
254 Json(controller.agent_card.clone())
255}
256
257pub async fn handle_jsonrpc(
259 State(controller): State<A2aController>,
260 Json(request): Json<JsonRpcRequest>,
261) -> impl IntoResponse {
262 if request.jsonrpc != "2.0" {
263 return Json(JsonRpcResponse::error(
264 request.id,
265 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
266 ));
267 }
268
269 match request.method.as_str() {
270 jsonrpc::methods::MESSAGE_SEND => {
271 handle_message_send(&controller, request.params, request.id).await
272 }
273 jsonrpc::methods::TASKS_GET => {
274 handle_tasks_get(&controller, request.params, request.id).await
275 }
276 jsonrpc::methods::TASKS_CANCEL => {
277 handle_tasks_cancel(&controller, request.params, request.id).await
278 }
279 _ => Json(JsonRpcResponse::error(
280 request.id,
281 JsonRpcError::method_not_found(&request.method),
282 )),
283 }
284}
285
286pub async fn handle_jsonrpc_stream(
288 State(controller): State<A2aController>,
289 Json(request): Json<JsonRpcRequest>,
290) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
291{
292 if request.jsonrpc != "2.0" {
293 return Err((
294 StatusCode::BAD_REQUEST,
295 Json(JsonRpcResponse::error(
296 request.id.clone(),
297 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
298 )),
299 ));
300 }
301
302 if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
303 && request.method != jsonrpc::methods::MESSAGE_SEND
304 {
305 return Err((
306 StatusCode::BAD_REQUEST,
307 Json(JsonRpcResponse::error(
308 request.id.clone(),
309 JsonRpcError::method_not_found(&request.method),
310 )),
311 ));
312 }
313
314 let params: MessageSendParams = match request.params {
315 Some(p) => serde_json::from_value(p).map_err(|e| {
316 (
317 StatusCode::BAD_REQUEST,
318 Json(JsonRpcResponse::error(
319 request.id.clone(),
320 JsonRpcError::invalid_params(e.to_string()),
321 )),
322 )
323 })?,
324 None => {
325 return Err((
326 StatusCode::BAD_REQUEST,
327 Json(JsonRpcResponse::error(
328 request.id.clone(),
329 JsonRpcError::invalid_params("Missing params"),
330 )),
331 ));
332 }
333 };
334
335 let request_id = request.id.clone();
336 let stream = create_message_stream(controller, params, request_id);
337
338 Ok(Sse::new(stream).keep_alive(
339 axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
340 ))
341}
342
343fn create_message_stream(
344 controller: A2aController,
345 params: MessageSendParams,
346 request_id: Option<Value>,
347) -> impl Stream<Item = Result<Event, Infallible>> {
348 async_stream::stream! {
349 let context_id = params
350 .message
351 .context_id
352 .clone()
353 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
354 let task_id = params
355 .message
356 .task_id
357 .clone()
358 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
359
360 let (_task_rx, maybe_stream_rx) = start_task(
361 &controller,
362 context_id.clone(),
363 task_id.clone(),
364 params.message.clone(),
365 true,
366 )
367 .await;
368
369 let Some(mut stream_rx) = maybe_stream_rx else {
370 yield Ok(Event::default().event("done").data(""));
371 return;
372 };
373
374 while let Some(message) = stream_rx.recv().await {
375 match message {
376 StreamTaskMessage::Update(event) => {
377 let event_data = match event.as_ref() {
378 UpdateEvent::TaskStatusUpdate(status) => {
379 serde_json::to_string(&JsonRpcResponse::success(
380 request_id.clone(),
381 serde_json::to_value(status).unwrap_or_default(),
382 ))
383 }
384 UpdateEvent::TaskArtifactUpdate(artifact) => {
385 serde_json::to_string(&JsonRpcResponse::success(
386 request_id.clone(),
387 serde_json::to_value(artifact).unwrap_or_default(),
388 ))
389 }
390 };
391
392 if let Ok(data) = event_data {
393 yield Ok(Event::default().data(data));
394 }
395 }
396 StreamTaskMessage::Error(message) => {
397 let error_response = JsonRpcResponse::error(
398 request_id.clone(),
399 JsonRpcError::internal_error(message),
400 );
401 if let Ok(data) = serde_json::to_string(&error_response) {
402 yield Ok(Event::default().data(data));
403 }
404 }
405 }
406 }
407
408 yield Ok(Event::default().event("done").data(""));
410 }
411}
412
413async fn handle_message_send(
414 controller: &A2aController,
415 params: Option<Value>,
416 id: Option<Value>,
417) -> Json<JsonRpcResponse> {
418 let params: MessageSendParams = match params {
419 Some(p) => match serde_json::from_value(p) {
420 Ok(p) => p,
421 Err(e) => {
422 return Json(JsonRpcResponse::error(
423 id,
424 JsonRpcError::invalid_params(e.to_string()),
425 ));
426 }
427 },
428 None => {
429 return Json(JsonRpcResponse::error(
430 id,
431 JsonRpcError::invalid_params("Missing params"),
432 ));
433 }
434 };
435
436 let context_id =
437 params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
438 let task_id =
439 params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
440
441 let (task_rx, _) =
442 start_task(controller, context_id.clone(), task_id.clone(), params.message, false).await;
443
444 match task_rx.await {
445 Ok(Ok(task)) => {
446 Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
447 }
448 Ok(Err(e)) => Json(JsonRpcResponse::error(
449 id,
450 JsonRpcError::internal_error_sanitized(
451 &e,
452 controller.config.security.expose_error_details,
453 ),
454 )),
455 Err(_) => {
456 Json(JsonRpcResponse::error(id, JsonRpcError::internal_error("Task execution aborted")))
457 }
458 }
459}
460
461async fn handle_tasks_get(
462 controller: &A2aController,
463 params: Option<Value>,
464 id: Option<Value>,
465) -> Json<JsonRpcResponse> {
466 let params: TasksGetParams = match params {
467 Some(p) => match serde_json::from_value(p) {
468 Ok(p) => p,
469 Err(e) => {
470 return Json(JsonRpcResponse::error(
471 id,
472 JsonRpcError::invalid_params(e.to_string()),
473 ));
474 }
475 },
476 None => {
477 return Json(JsonRpcResponse::error(
478 id,
479 JsonRpcError::invalid_params("Missing params"),
480 ));
481 }
482 };
483
484 if let Some(active_task) = controller.active_tasks.lock().await.get(¶ms.task_id).cloned() {
485 let task = Task {
486 id: params.task_id.clone(),
487 context_id: Some(active_task.context_id),
488 status: TaskStatus { state: TaskState::Working, message: None },
489 artifacts: None,
490 history: None,
491 };
492
493 return Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()));
494 }
495
496 match controller.task_store.get(¶ms.task_id).await {
497 Some(task) => {
498 Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
499 }
500 None => Json(JsonRpcResponse::error(
501 id,
502 JsonRpcError::internal_error(format!("Task not found: {}", params.task_id)),
503 )),
504 }
505}
506
507async fn handle_tasks_cancel(
508 controller: &A2aController,
509 params: Option<Value>,
510 id: Option<Value>,
511) -> Json<JsonRpcResponse> {
512 let params: TasksCancelParams = match params {
513 Some(p) => match serde_json::from_value(p) {
514 Ok(p) => p,
515 Err(e) => {
516 return Json(JsonRpcResponse::error(
517 id,
518 JsonRpcError::invalid_params(e.to_string()),
519 ));
520 }
521 },
522 None => {
523 return Json(JsonRpcResponse::error(
524 id,
525 JsonRpcError::invalid_params("Missing params"),
526 ));
527 }
528 };
529
530 let active_task = controller.active_tasks.lock().await.get(¶ms.task_id).cloned();
531
532 if let Some(active_task) = active_task {
533 active_task.token.cancel();
534
535 if tokio::time::timeout(Duration::from_secs(5), active_task.completion.notified())
536 .await
537 .is_err()
538 {
539 active_task.abort_handle.abort();
540 controller.active_tasks.lock().await.remove(¶ms.task_id);
541 controller
542 .task_store
543 .store(build_canceled_task(¶ms.task_id, &active_task.context_id))
544 .await;
545 }
546
547 let status = TaskStatusUpdateEvent {
548 task_id: params.task_id,
549 context_id: Some(active_task.context_id),
550 status: TaskStatus { state: TaskState::Canceled, message: None },
551 final_update: true,
552 };
553
554 return Json(JsonRpcResponse::success(
555 id,
556 serde_json::to_value(status).unwrap_or_default(),
557 ));
558 }
559
560 let status = TaskStatusUpdateEvent {
561 task_id: params.task_id,
562 context_id: Some(uuid::Uuid::new_v4().to_string()),
563 status: TaskStatus { state: TaskState::Canceled, message: None },
564 final_update: true,
565 };
566
567 Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
568}