1use crate::ServerConfig;
2use crate::a2a::{
3 AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest, JsonRpcResponse, Message,
4 MessageSendParams, Task, TaskState, TaskStatus, TaskStatusUpdateEvent, TasksCancelParams,
5 TasksGetParams, UpdateEvent, build_agent_card, jsonrpc,
6};
7use adk_runner::RunnerConfig;
8use axum::{
9 extract::State,
10 http::StatusCode,
11 response::{
12 IntoResponse, Json,
13 sse::{Event, Sse},
14 },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
19use tokio::sync::{Mutex, Notify, RwLock, mpsc, oneshot};
20use tokio_util::sync::CancellationToken;
21
22#[derive(Default)]
24pub struct TaskStore {
25 tasks: RwLock<HashMap<String, Task>>,
26}
27
28impl TaskStore {
29 pub fn new() -> Self {
30 Self::default()
31 }
32
33 pub async fn store(&self, task: Task) {
34 self.tasks.write().await.insert(task.id.clone(), task);
35 }
36
37 pub async fn get(&self, task_id: &str) -> Option<Task> {
38 self.tasks.read().await.get(task_id).cloned()
39 }
40
41 pub async fn remove(&self, task_id: &str) -> Option<Task> {
42 self.tasks.write().await.remove(task_id)
43 }
44}
45
46#[derive(Clone)]
47struct ActiveTask {
48 token: CancellationToken,
49 abort_handle: tokio::task::AbortHandle,
50 completion: Arc<Notify>,
51 context_id: String,
52}
53
54enum StreamTaskMessage {
55 Update(Box<UpdateEvent>),
56 Error(String),
57}
58
59#[derive(Clone)]
61pub struct A2aController {
62 config: ServerConfig,
63 agent_card: AgentCard,
64 task_store: Arc<TaskStore>,
65 active_tasks: Arc<Mutex<HashMap<String, ActiveTask>>>,
66}
67
68impl A2aController {
69 pub fn new(config: ServerConfig, base_url: &str) -> Self {
70 let root_agent = config.agent_loader.root_agent();
71 let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
72 let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
73
74 Self {
75 config,
76 agent_card,
77 task_store: Arc::new(TaskStore::new()),
78 active_tasks: Arc::new(Mutex::new(HashMap::new())),
79 }
80 }
81}
82
83fn build_runner_config(
84 controller: &A2aController,
85 root_agent: Arc<dyn adk_core::Agent>,
86 cancellation_token: Option<CancellationToken>,
87) -> Arc<RunnerConfig> {
88 Arc::new(RunnerConfig {
89 app_name: root_agent.name().to_string(),
90 agent: root_agent,
91 session_service: controller.config.session_service.clone(),
92 artifact_service: controller.config.artifact_service.clone(),
93 memory_service: controller.config.memory_service.clone(),
94 plugin_manager: None,
95 run_config: None,
96 compaction_config: controller.config.compaction_config.clone(),
97 context_cache_config: controller.config.context_cache_config.clone(),
98 cache_capable: controller.config.cache_capable.clone(),
99 request_context: None,
100 cancellation_token,
101 intra_compaction_config: None,
102 intra_compaction_summarizer: None,
103 })
104}
105
106fn build_task_from_events(task_id: &str, context_id: &str, events: &[UpdateEvent]) -> Task {
107 let mut task = Task {
108 id: task_id.to_string(),
109 context_id: Some(context_id.to_string()),
110 status: TaskStatus { state: TaskState::Completed, message: None },
111 artifacts: Some(vec![]),
112 history: None,
113 };
114
115 for event in events {
116 match event {
117 UpdateEvent::TaskStatusUpdate(status) => {
118 task.status = status.status.clone();
119 }
120 UpdateEvent::TaskArtifactUpdate(artifact) => {
121 if let Some(ref mut artifacts) = task.artifacts {
122 artifacts.push(artifact.artifact.clone());
123 }
124 }
125 }
126 }
127
128 task
129}
130
131fn build_failed_task(task_id: &str, context_id: &str, message: impl Into<String>) -> Task {
132 Task {
133 id: task_id.to_string(),
134 context_id: Some(context_id.to_string()),
135 status: TaskStatus { state: TaskState::Failed, message: Some(message.into()) },
136 artifacts: None,
137 history: None,
138 }
139}
140
141fn build_canceled_task(task_id: &str, context_id: &str) -> Task {
142 Task {
143 id: task_id.to_string(),
144 context_id: Some(context_id.to_string()),
145 status: TaskStatus { state: TaskState::Canceled, message: None },
146 artifacts: None,
147 history: None,
148 }
149}
150
151fn sanitize_internal_error(config: &ServerConfig, error: &adk_core::AdkError) -> String {
152 if config.security.expose_error_details {
153 error.to_string()
154 } else {
155 "Internal server error".to_string()
156 }
157}
158
159async fn start_task(
160 controller: &A2aController,
161 context_id: String,
162 task_id: String,
163 message: Message,
164 stream_updates: bool,
165) -> (oneshot::Receiver<adk_core::Result<Task>>, Option<mpsc::Receiver<StreamTaskMessage>>) {
166 let token = CancellationToken::new();
167 let completion = Arc::new(Notify::new());
168 let (task_tx, task_rx) = oneshot::channel();
169 let (stream_tx, stream_rx) = if stream_updates {
170 let (tx, rx) = mpsc::channel(32);
171 (Some(tx), Some(rx))
172 } else {
173 (None, None)
174 };
175
176 let root_agent = controller.config.agent_loader.root_agent();
177 let executor = Executor::new(ExecutorConfig {
178 app_name: root_agent.name().to_string(),
179 runner_config: build_runner_config(controller, root_agent, Some(token.clone())),
180 cancellation_token: Some(token.clone()),
181 });
182
183 let controller_clone = controller.clone();
184 let completion_clone = completion.clone();
185 let task_id_for_task = task_id.clone();
186 let context_id_for_task = context_id.clone();
187 let stream_tx_for_task = stream_tx.clone();
188
189 let join_handle = tokio::spawn(async move {
190 let result = executor.execute(&context_id_for_task, &task_id_for_task, &message).await;
191
192 match result {
193 Ok(events) => {
194 if let Some(sender) = stream_tx_for_task {
195 for event in &events {
196 if sender
197 .send(StreamTaskMessage::Update(Box::new(event.clone())))
198 .await
199 .is_err()
200 {
201 break;
202 }
203 }
204 }
205
206 let task = build_task_from_events(&task_id_for_task, &context_id_for_task, &events);
207 controller_clone.task_store.store(task.clone()).await;
208 let _ = task_tx.send(Ok(task));
209 }
210 Err(error) => {
211 if let Some(sender) = stream_tx_for_task {
212 let _ = sender
213 .send(StreamTaskMessage::Error(sanitize_internal_error(
214 &controller_clone.config,
215 &error,
216 )))
217 .await;
218 }
219 controller_clone
220 .task_store
221 .store(build_failed_task(
222 &task_id_for_task,
223 &context_id_for_task,
224 error.to_string(),
225 ))
226 .await;
227 let _ = task_tx.send(Err(error));
228 }
229 }
230
231 controller_clone.active_tasks.lock().await.remove(&task_id_for_task);
232 completion_clone.notify_waiters();
233 });
234
235 controller.active_tasks.lock().await.insert(
236 task_id,
237 ActiveTask { token, abort_handle: join_handle.abort_handle(), completion, context_id },
238 );
239
240 (task_rx, stream_rx)
241}
242
243pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
245 Json(controller.agent_card.clone())
246}
247
248pub async fn handle_jsonrpc(
250 State(controller): State<A2aController>,
251 Json(request): Json<JsonRpcRequest>,
252) -> impl IntoResponse {
253 if request.jsonrpc != "2.0" {
254 return Json(JsonRpcResponse::error(
255 request.id,
256 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
257 ));
258 }
259
260 match request.method.as_str() {
261 jsonrpc::methods::MESSAGE_SEND => {
262 handle_message_send(&controller, request.params, request.id).await
263 }
264 jsonrpc::methods::TASKS_GET => {
265 handle_tasks_get(&controller, request.params, request.id).await
266 }
267 jsonrpc::methods::TASKS_CANCEL => {
268 handle_tasks_cancel(&controller, request.params, request.id).await
269 }
270 _ => Json(JsonRpcResponse::error(
271 request.id,
272 JsonRpcError::method_not_found(&request.method),
273 )),
274 }
275}
276
277pub async fn handle_jsonrpc_stream(
279 State(controller): State<A2aController>,
280 Json(request): Json<JsonRpcRequest>,
281) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
282{
283 if request.jsonrpc != "2.0" {
284 return Err((
285 StatusCode::BAD_REQUEST,
286 Json(JsonRpcResponse::error(
287 request.id.clone(),
288 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
289 )),
290 ));
291 }
292
293 if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
294 && request.method != jsonrpc::methods::MESSAGE_SEND
295 {
296 return Err((
297 StatusCode::BAD_REQUEST,
298 Json(JsonRpcResponse::error(
299 request.id.clone(),
300 JsonRpcError::method_not_found(&request.method),
301 )),
302 ));
303 }
304
305 let params: MessageSendParams = match request.params {
306 Some(p) => serde_json::from_value(p).map_err(|e| {
307 (
308 StatusCode::BAD_REQUEST,
309 Json(JsonRpcResponse::error(
310 request.id.clone(),
311 JsonRpcError::invalid_params(e.to_string()),
312 )),
313 )
314 })?,
315 None => {
316 return Err((
317 StatusCode::BAD_REQUEST,
318 Json(JsonRpcResponse::error(
319 request.id.clone(),
320 JsonRpcError::invalid_params("Missing params"),
321 )),
322 ));
323 }
324 };
325
326 let request_id = request.id.clone();
327 let stream = create_message_stream(controller, params, request_id);
328
329 Ok(Sse::new(stream).keep_alive(
330 axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
331 ))
332}
333
334fn create_message_stream(
335 controller: A2aController,
336 params: MessageSendParams,
337 request_id: Option<Value>,
338) -> impl Stream<Item = Result<Event, Infallible>> {
339 async_stream::stream! {
340 let context_id = params
341 .message
342 .context_id
343 .clone()
344 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
345 let task_id = params
346 .message
347 .task_id
348 .clone()
349 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
350
351 let (_task_rx, maybe_stream_rx) = start_task(
352 &controller,
353 context_id.clone(),
354 task_id.clone(),
355 params.message.clone(),
356 true,
357 )
358 .await;
359
360 let Some(mut stream_rx) = maybe_stream_rx else {
361 yield Ok(Event::default().event("done").data(""));
362 return;
363 };
364
365 while let Some(message) = stream_rx.recv().await {
366 match message {
367 StreamTaskMessage::Update(event) => {
368 let event_data = match event.as_ref() {
369 UpdateEvent::TaskStatusUpdate(status) => {
370 serde_json::to_string(&JsonRpcResponse::success(
371 request_id.clone(),
372 serde_json::to_value(status).unwrap_or_default(),
373 ))
374 }
375 UpdateEvent::TaskArtifactUpdate(artifact) => {
376 serde_json::to_string(&JsonRpcResponse::success(
377 request_id.clone(),
378 serde_json::to_value(artifact).unwrap_or_default(),
379 ))
380 }
381 };
382
383 if let Ok(data) = event_data {
384 yield Ok(Event::default().data(data));
385 }
386 }
387 StreamTaskMessage::Error(message) => {
388 let error_response = JsonRpcResponse::error(
389 request_id.clone(),
390 JsonRpcError::internal_error(message),
391 );
392 if let Ok(data) = serde_json::to_string(&error_response) {
393 yield Ok(Event::default().data(data));
394 }
395 }
396 }
397 }
398
399 yield Ok(Event::default().event("done").data(""));
401 }
402}
403
404async fn handle_message_send(
405 controller: &A2aController,
406 params: Option<Value>,
407 id: Option<Value>,
408) -> Json<JsonRpcResponse> {
409 let params: MessageSendParams = match params {
410 Some(p) => match serde_json::from_value(p) {
411 Ok(p) => p,
412 Err(e) => {
413 return Json(JsonRpcResponse::error(
414 id,
415 JsonRpcError::invalid_params(e.to_string()),
416 ));
417 }
418 },
419 None => {
420 return Json(JsonRpcResponse::error(
421 id,
422 JsonRpcError::invalid_params("Missing params"),
423 ));
424 }
425 };
426
427 let context_id =
428 params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
429 let task_id =
430 params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
431
432 let (task_rx, _) =
433 start_task(controller, context_id.clone(), task_id.clone(), params.message, false).await;
434
435 match task_rx.await {
436 Ok(Ok(task)) => {
437 Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
438 }
439 Ok(Err(e)) => Json(JsonRpcResponse::error(
440 id,
441 JsonRpcError::internal_error_sanitized(
442 &e,
443 controller.config.security.expose_error_details,
444 ),
445 )),
446 Err(_) => {
447 Json(JsonRpcResponse::error(id, JsonRpcError::internal_error("Task execution aborted")))
448 }
449 }
450}
451
452async fn handle_tasks_get(
453 controller: &A2aController,
454 params: Option<Value>,
455 id: Option<Value>,
456) -> Json<JsonRpcResponse> {
457 let params: TasksGetParams = match params {
458 Some(p) => match serde_json::from_value(p) {
459 Ok(p) => p,
460 Err(e) => {
461 return Json(JsonRpcResponse::error(
462 id,
463 JsonRpcError::invalid_params(e.to_string()),
464 ));
465 }
466 },
467 None => {
468 return Json(JsonRpcResponse::error(
469 id,
470 JsonRpcError::invalid_params("Missing params"),
471 ));
472 }
473 };
474
475 if let Some(active_task) = controller.active_tasks.lock().await.get(¶ms.task_id).cloned() {
476 let task = Task {
477 id: params.task_id.clone(),
478 context_id: Some(active_task.context_id),
479 status: TaskStatus { state: TaskState::Working, message: None },
480 artifacts: None,
481 history: None,
482 };
483
484 return Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()));
485 }
486
487 match controller.task_store.get(¶ms.task_id).await {
488 Some(task) => {
489 Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
490 }
491 None => Json(JsonRpcResponse::error(
492 id,
493 JsonRpcError::internal_error(format!("Task not found: {}", params.task_id)),
494 )),
495 }
496}
497
498async fn handle_tasks_cancel(
499 controller: &A2aController,
500 params: Option<Value>,
501 id: Option<Value>,
502) -> Json<JsonRpcResponse> {
503 let params: TasksCancelParams = match params {
504 Some(p) => match serde_json::from_value(p) {
505 Ok(p) => p,
506 Err(e) => {
507 return Json(JsonRpcResponse::error(
508 id,
509 JsonRpcError::invalid_params(e.to_string()),
510 ));
511 }
512 },
513 None => {
514 return Json(JsonRpcResponse::error(
515 id,
516 JsonRpcError::invalid_params("Missing params"),
517 ));
518 }
519 };
520
521 let active_task = controller.active_tasks.lock().await.get(¶ms.task_id).cloned();
522
523 if let Some(active_task) = active_task {
524 active_task.token.cancel();
525
526 if tokio::time::timeout(Duration::from_secs(5), active_task.completion.notified())
527 .await
528 .is_err()
529 {
530 active_task.abort_handle.abort();
531 controller.active_tasks.lock().await.remove(¶ms.task_id);
532 controller
533 .task_store
534 .store(build_canceled_task(¶ms.task_id, &active_task.context_id))
535 .await;
536 }
537
538 let status = TaskStatusUpdateEvent {
539 task_id: params.task_id,
540 context_id: Some(active_task.context_id),
541 status: TaskStatus { state: TaskState::Canceled, message: None },
542 final_update: true,
543 };
544
545 return Json(JsonRpcResponse::success(
546 id,
547 serde_json::to_value(status).unwrap_or_default(),
548 ));
549 }
550
551 let status = TaskStatusUpdateEvent {
552 task_id: params.task_id,
553 context_id: Some(uuid::Uuid::new_v4().to_string()),
554 status: TaskStatus { state: TaskState::Canceled, message: None },
555 final_update: true,
556 };
557
558 Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
559}