1use crate::ServerConfig;
2use crate::a2a::{
3 AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest, JsonRpcResponse, Message,
4 MessageSendParams, Task, TaskState, TaskStatus, TaskStatusUpdateEvent, TasksCancelParams,
5 TasksGetParams, UpdateEvent, build_agent_card, jsonrpc,
6};
7use adk_runner::RunnerConfig;
8use axum::{
9 extract::State,
10 http::StatusCode,
11 response::{
12 IntoResponse, Json,
13 sse::{Event, Sse},
14 },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
19use tokio::sync::{Mutex, Notify, RwLock, mpsc, oneshot};
20use tokio_util::sync::CancellationToken;
21
22#[derive(Default)]
24pub struct TaskStore {
25 tasks: RwLock<HashMap<String, Task>>,
26}
27
28impl TaskStore {
29 pub fn new() -> Self {
30 Self::default()
31 }
32
33 pub async fn store(&self, task: Task) {
34 self.tasks.write().await.insert(task.id.clone(), task);
35 }
36
37 pub async fn get(&self, task_id: &str) -> Option<Task> {
38 self.tasks.read().await.get(task_id).cloned()
39 }
40
41 pub async fn remove(&self, task_id: &str) -> Option<Task> {
42 self.tasks.write().await.remove(task_id)
43 }
44}
45
46#[derive(Clone)]
47struct ActiveTask {
48 token: CancellationToken,
49 abort_handle: tokio::task::AbortHandle,
50 completion: Arc<Notify>,
51 context_id: String,
52}
53
54enum StreamTaskMessage {
55 Update(Box<UpdateEvent>),
56 Error(String),
57}
58
59#[derive(Clone)]
61pub struct A2aController {
62 config: ServerConfig,
63 agent_card: AgentCard,
64 task_store: Arc<TaskStore>,
65 active_tasks: Arc<Mutex<HashMap<String, ActiveTask>>>,
66}
67
68impl A2aController {
69 pub fn new(config: ServerConfig, base_url: &str) -> Self {
70 let root_agent = config.agent_loader.root_agent();
71 let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
72 let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
73
74 Self {
75 config,
76 agent_card,
77 task_store: Arc::new(TaskStore::new()),
78 active_tasks: Arc::new(Mutex::new(HashMap::new())),
79 }
80 }
81}
82
83fn build_runner_config(
84 controller: &A2aController,
85 root_agent: Arc<dyn adk_core::Agent>,
86 cancellation_token: Option<CancellationToken>,
87) -> Arc<RunnerConfig> {
88 Arc::new(RunnerConfig {
89 app_name: root_agent.name().to_string(),
90 agent: root_agent,
91 session_service: controller.config.session_service.clone(),
92 artifact_service: controller.config.artifact_service.clone(),
93 memory_service: controller.config.memory_service.clone(),
94 plugin_manager: None,
95 run_config: None,
96 compaction_config: controller.config.compaction_config.clone(),
97 context_cache_config: controller.config.context_cache_config.clone(),
98 cache_capable: controller.config.cache_capable.clone(),
99 request_context: None,
100 cancellation_token,
101 })
102}
103
104fn build_task_from_events(task_id: &str, context_id: &str, events: &[UpdateEvent]) -> Task {
105 let mut task = Task {
106 id: task_id.to_string(),
107 context_id: Some(context_id.to_string()),
108 status: TaskStatus { state: TaskState::Completed, message: None },
109 artifacts: Some(vec![]),
110 history: None,
111 };
112
113 for event in events {
114 match event {
115 UpdateEvent::TaskStatusUpdate(status) => {
116 task.status = status.status.clone();
117 }
118 UpdateEvent::TaskArtifactUpdate(artifact) => {
119 if let Some(ref mut artifacts) = task.artifacts {
120 artifacts.push(artifact.artifact.clone());
121 }
122 }
123 }
124 }
125
126 task
127}
128
129fn build_failed_task(task_id: &str, context_id: &str, message: impl Into<String>) -> Task {
130 Task {
131 id: task_id.to_string(),
132 context_id: Some(context_id.to_string()),
133 status: TaskStatus { state: TaskState::Failed, message: Some(message.into()) },
134 artifacts: None,
135 history: None,
136 }
137}
138
139fn build_canceled_task(task_id: &str, context_id: &str) -> Task {
140 Task {
141 id: task_id.to_string(),
142 context_id: Some(context_id.to_string()),
143 status: TaskStatus { state: TaskState::Canceled, message: None },
144 artifacts: None,
145 history: None,
146 }
147}
148
149fn sanitize_internal_error(config: &ServerConfig, error: &adk_core::AdkError) -> String {
150 if config.security.expose_error_details {
151 error.to_string()
152 } else {
153 "Internal server error".to_string()
154 }
155}
156
157async fn start_task(
158 controller: &A2aController,
159 context_id: String,
160 task_id: String,
161 message: Message,
162 stream_updates: bool,
163) -> (oneshot::Receiver<adk_core::Result<Task>>, Option<mpsc::Receiver<StreamTaskMessage>>) {
164 let token = CancellationToken::new();
165 let completion = Arc::new(Notify::new());
166 let (task_tx, task_rx) = oneshot::channel();
167 let (stream_tx, stream_rx) = if stream_updates {
168 let (tx, rx) = mpsc::channel(32);
169 (Some(tx), Some(rx))
170 } else {
171 (None, None)
172 };
173
174 let root_agent = controller.config.agent_loader.root_agent();
175 let executor = Executor::new(ExecutorConfig {
176 app_name: root_agent.name().to_string(),
177 runner_config: build_runner_config(controller, root_agent, Some(token.clone())),
178 cancellation_token: Some(token.clone()),
179 });
180
181 let controller_clone = controller.clone();
182 let completion_clone = completion.clone();
183 let task_id_for_task = task_id.clone();
184 let context_id_for_task = context_id.clone();
185 let stream_tx_for_task = stream_tx.clone();
186
187 let join_handle = tokio::spawn(async move {
188 let result = executor.execute(&context_id_for_task, &task_id_for_task, &message).await;
189
190 match result {
191 Ok(events) => {
192 if let Some(sender) = stream_tx_for_task {
193 for event in &events {
194 if sender
195 .send(StreamTaskMessage::Update(Box::new(event.clone())))
196 .await
197 .is_err()
198 {
199 break;
200 }
201 }
202 }
203
204 let task = build_task_from_events(&task_id_for_task, &context_id_for_task, &events);
205 controller_clone.task_store.store(task.clone()).await;
206 let _ = task_tx.send(Ok(task));
207 }
208 Err(error) => {
209 if let Some(sender) = stream_tx_for_task {
210 let _ = sender
211 .send(StreamTaskMessage::Error(sanitize_internal_error(
212 &controller_clone.config,
213 &error,
214 )))
215 .await;
216 }
217 controller_clone
218 .task_store
219 .store(build_failed_task(
220 &task_id_for_task,
221 &context_id_for_task,
222 error.to_string(),
223 ))
224 .await;
225 let _ = task_tx.send(Err(error));
226 }
227 }
228
229 controller_clone.active_tasks.lock().await.remove(&task_id_for_task);
230 completion_clone.notify_waiters();
231 });
232
233 controller.active_tasks.lock().await.insert(
234 task_id,
235 ActiveTask { token, abort_handle: join_handle.abort_handle(), completion, context_id },
236 );
237
238 (task_rx, stream_rx)
239}
240
241pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
243 Json(controller.agent_card.clone())
244}
245
246pub async fn handle_jsonrpc(
248 State(controller): State<A2aController>,
249 Json(request): Json<JsonRpcRequest>,
250) -> impl IntoResponse {
251 if request.jsonrpc != "2.0" {
252 return Json(JsonRpcResponse::error(
253 request.id,
254 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
255 ));
256 }
257
258 match request.method.as_str() {
259 jsonrpc::methods::MESSAGE_SEND => {
260 handle_message_send(&controller, request.params, request.id).await
261 }
262 jsonrpc::methods::TASKS_GET => {
263 handle_tasks_get(&controller, request.params, request.id).await
264 }
265 jsonrpc::methods::TASKS_CANCEL => {
266 handle_tasks_cancel(&controller, request.params, request.id).await
267 }
268 _ => Json(JsonRpcResponse::error(
269 request.id,
270 JsonRpcError::method_not_found(&request.method),
271 )),
272 }
273}
274
275pub async fn handle_jsonrpc_stream(
277 State(controller): State<A2aController>,
278 Json(request): Json<JsonRpcRequest>,
279) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
280{
281 if request.jsonrpc != "2.0" {
282 return Err((
283 StatusCode::BAD_REQUEST,
284 Json(JsonRpcResponse::error(
285 request.id.clone(),
286 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
287 )),
288 ));
289 }
290
291 if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
292 && request.method != jsonrpc::methods::MESSAGE_SEND
293 {
294 return Err((
295 StatusCode::BAD_REQUEST,
296 Json(JsonRpcResponse::error(
297 request.id.clone(),
298 JsonRpcError::method_not_found(&request.method),
299 )),
300 ));
301 }
302
303 let params: MessageSendParams = match request.params {
304 Some(p) => serde_json::from_value(p).map_err(|e| {
305 (
306 StatusCode::BAD_REQUEST,
307 Json(JsonRpcResponse::error(
308 request.id.clone(),
309 JsonRpcError::invalid_params(e.to_string()),
310 )),
311 )
312 })?,
313 None => {
314 return Err((
315 StatusCode::BAD_REQUEST,
316 Json(JsonRpcResponse::error(
317 request.id.clone(),
318 JsonRpcError::invalid_params("Missing params"),
319 )),
320 ));
321 }
322 };
323
324 let request_id = request.id.clone();
325 let stream = create_message_stream(controller, params, request_id);
326
327 Ok(Sse::new(stream).keep_alive(
328 axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
329 ))
330}
331
332fn create_message_stream(
333 controller: A2aController,
334 params: MessageSendParams,
335 request_id: Option<Value>,
336) -> impl Stream<Item = Result<Event, Infallible>> {
337 async_stream::stream! {
338 let context_id = params
339 .message
340 .context_id
341 .clone()
342 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
343 let task_id = params
344 .message
345 .task_id
346 .clone()
347 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
348
349 let (_task_rx, maybe_stream_rx) = start_task(
350 &controller,
351 context_id.clone(),
352 task_id.clone(),
353 params.message.clone(),
354 true,
355 )
356 .await;
357
358 let Some(mut stream_rx) = maybe_stream_rx else {
359 yield Ok(Event::default().event("done").data(""));
360 return;
361 };
362
363 while let Some(message) = stream_rx.recv().await {
364 match message {
365 StreamTaskMessage::Update(event) => {
366 let event_data = match event.as_ref() {
367 UpdateEvent::TaskStatusUpdate(status) => {
368 serde_json::to_string(&JsonRpcResponse::success(
369 request_id.clone(),
370 serde_json::to_value(status).unwrap_or_default(),
371 ))
372 }
373 UpdateEvent::TaskArtifactUpdate(artifact) => {
374 serde_json::to_string(&JsonRpcResponse::success(
375 request_id.clone(),
376 serde_json::to_value(artifact).unwrap_or_default(),
377 ))
378 }
379 };
380
381 if let Ok(data) = event_data {
382 yield Ok(Event::default().data(data));
383 }
384 }
385 StreamTaskMessage::Error(message) => {
386 let error_response = JsonRpcResponse::error(
387 request_id.clone(),
388 JsonRpcError::internal_error(message),
389 );
390 if let Ok(data) = serde_json::to_string(&error_response) {
391 yield Ok(Event::default().data(data));
392 }
393 }
394 }
395 }
396
397 yield Ok(Event::default().event("done").data(""));
399 }
400}
401
402async fn handle_message_send(
403 controller: &A2aController,
404 params: Option<Value>,
405 id: Option<Value>,
406) -> Json<JsonRpcResponse> {
407 let params: MessageSendParams = match params {
408 Some(p) => match serde_json::from_value(p) {
409 Ok(p) => p,
410 Err(e) => {
411 return Json(JsonRpcResponse::error(
412 id,
413 JsonRpcError::invalid_params(e.to_string()),
414 ));
415 }
416 },
417 None => {
418 return Json(JsonRpcResponse::error(
419 id,
420 JsonRpcError::invalid_params("Missing params"),
421 ));
422 }
423 };
424
425 let context_id =
426 params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
427 let task_id =
428 params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
429
430 let (task_rx, _) =
431 start_task(controller, context_id.clone(), task_id.clone(), params.message, false).await;
432
433 match task_rx.await {
434 Ok(Ok(task)) => {
435 Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
436 }
437 Ok(Err(e)) => Json(JsonRpcResponse::error(
438 id,
439 JsonRpcError::internal_error_sanitized(
440 &e,
441 controller.config.security.expose_error_details,
442 ),
443 )),
444 Err(_) => {
445 Json(JsonRpcResponse::error(id, JsonRpcError::internal_error("Task execution aborted")))
446 }
447 }
448}
449
450async fn handle_tasks_get(
451 controller: &A2aController,
452 params: Option<Value>,
453 id: Option<Value>,
454) -> Json<JsonRpcResponse> {
455 let params: TasksGetParams = match params {
456 Some(p) => match serde_json::from_value(p) {
457 Ok(p) => p,
458 Err(e) => {
459 return Json(JsonRpcResponse::error(
460 id,
461 JsonRpcError::invalid_params(e.to_string()),
462 ));
463 }
464 },
465 None => {
466 return Json(JsonRpcResponse::error(
467 id,
468 JsonRpcError::invalid_params("Missing params"),
469 ));
470 }
471 };
472
473 if let Some(active_task) = controller.active_tasks.lock().await.get(¶ms.task_id).cloned() {
474 let task = Task {
475 id: params.task_id.clone(),
476 context_id: Some(active_task.context_id),
477 status: TaskStatus { state: TaskState::Working, message: None },
478 artifacts: None,
479 history: None,
480 };
481
482 return Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()));
483 }
484
485 match controller.task_store.get(¶ms.task_id).await {
486 Some(task) => {
487 Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
488 }
489 None => Json(JsonRpcResponse::error(
490 id,
491 JsonRpcError::internal_error(format!("Task not found: {}", params.task_id)),
492 )),
493 }
494}
495
496async fn handle_tasks_cancel(
497 controller: &A2aController,
498 params: Option<Value>,
499 id: Option<Value>,
500) -> Json<JsonRpcResponse> {
501 let params: TasksCancelParams = match params {
502 Some(p) => match serde_json::from_value(p) {
503 Ok(p) => p,
504 Err(e) => {
505 return Json(JsonRpcResponse::error(
506 id,
507 JsonRpcError::invalid_params(e.to_string()),
508 ));
509 }
510 },
511 None => {
512 return Json(JsonRpcResponse::error(
513 id,
514 JsonRpcError::invalid_params("Missing params"),
515 ));
516 }
517 };
518
519 let active_task = controller.active_tasks.lock().await.get(¶ms.task_id).cloned();
520
521 if let Some(active_task) = active_task {
522 active_task.token.cancel();
523
524 if tokio::time::timeout(Duration::from_secs(5), active_task.completion.notified())
525 .await
526 .is_err()
527 {
528 active_task.abort_handle.abort();
529 controller.active_tasks.lock().await.remove(¶ms.task_id);
530 controller
531 .task_store
532 .store(build_canceled_task(¶ms.task_id, &active_task.context_id))
533 .await;
534 }
535
536 let status = TaskStatusUpdateEvent {
537 task_id: params.task_id,
538 context_id: Some(active_task.context_id),
539 status: TaskStatus { state: TaskState::Canceled, message: None },
540 final_update: true,
541 };
542
543 return Json(JsonRpcResponse::success(
544 id,
545 serde_json::to_value(status).unwrap_or_default(),
546 ));
547 }
548
549 let status = TaskStatusUpdateEvent {
550 task_id: params.task_id,
551 context_id: Some(uuid::Uuid::new_v4().to_string()),
552 status: TaskStatus { state: TaskState::Canceled, message: None },
553 final_update: true,
554 };
555
556 Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
557}