Skip to main content

adk_server/rest/controllers/
a2a.rs

1use crate::ServerConfig;
2use crate::a2a::{
3    AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest, JsonRpcResponse, Message,
4    MessageSendParams, Task, TaskState, TaskStatus, TaskStatusUpdateEvent, TasksCancelParams,
5    TasksGetParams, UpdateEvent, build_agent_card, jsonrpc,
6};
7use adk_runner::{Runner, RunnerConfig};
8use axum::{
9    extract::State,
10    http::StatusCode,
11    response::{
12        IntoResponse, Json,
13        sse::{Event, Sse},
14    },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
19use tokio::sync::{Mutex, Notify, RwLock, mpsc, oneshot};
20use tokio_util::sync::CancellationToken;
21
22/// In-memory task storage
23#[derive(Default)]
24pub struct TaskStore {
25    tasks: RwLock<HashMap<String, Task>>,
26}
27
28impl TaskStore {
29    pub fn new() -> Self {
30        Self::default()
31    }
32
33    pub async fn store(&self, task: Task) {
34        self.tasks.write().await.insert(task.id.clone(), task);
35    }
36
37    pub async fn get(&self, task_id: &str) -> Option<Task> {
38        self.tasks.read().await.get(task_id).cloned()
39    }
40
41    pub async fn remove(&self, task_id: &str) -> Option<Task> {
42        self.tasks.write().await.remove(task_id)
43    }
44}
45
46#[derive(Clone)]
47struct ActiveTask {
48    token: CancellationToken,
49    abort_handle: tokio::task::AbortHandle,
50    completion: Arc<Notify>,
51    context_id: String,
52}
53
54enum StreamTaskMessage {
55    Update(Box<UpdateEvent>),
56    Error(String),
57}
58
59/// Controller for A2A protocol endpoints
60#[derive(Clone)]
61pub struct A2aController {
62    config: ServerConfig,
63    agent_card: AgentCard,
64    task_store: Arc<TaskStore>,
65    active_tasks: Arc<Mutex<HashMap<String, ActiveTask>>>,
66}
67
68impl A2aController {
69    pub fn new(config: ServerConfig, base_url: &str) -> Self {
70        let root_agent = config.agent_loader.root_agent();
71        let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
72        let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
73
74        Self {
75            config,
76            agent_card,
77            task_store: Arc::new(TaskStore::new()),
78            active_tasks: Arc::new(Mutex::new(HashMap::new())),
79        }
80    }
81}
82
83fn build_runner_config(
84    controller: &A2aController,
85    root_agent: Arc<dyn adk_core::Agent>,
86    cancellation_token: Option<CancellationToken>,
87) -> Arc<RunnerConfig> {
88    let mut builder = Runner::builder()
89        .app_name(root_agent.name())
90        .agent(root_agent)
91        .session_service(controller.config.session_service.clone());
92    if let Some(ref artifact_service) = controller.config.artifact_service {
93        builder = builder.artifact_service(artifact_service.clone());
94    }
95    if let Some(ref memory_service) = controller.config.memory_service {
96        builder = builder.memory_service(memory_service.clone());
97    }
98    if let Some(ref compaction_config) = controller.config.compaction_config {
99        builder = builder.compaction_config(compaction_config.clone());
100    }
101    if let Some(ref context_cache_config) = controller.config.context_cache_config {
102        builder = builder.context_cache_config(context_cache_config.clone());
103    }
104    if let Some(ref cache_capable) = controller.config.cache_capable {
105        builder = builder.cache_capable(cache_capable.clone());
106    }
107    if let Some(cancellation_token) = cancellation_token {
108        builder = builder.cancellation_token(cancellation_token);
109    }
110    Arc::new(builder.build_config())
111}
112
113fn build_task_from_events(task_id: &str, context_id: &str, events: &[UpdateEvent]) -> Task {
114    let mut task = Task {
115        id: task_id.to_string(),
116        context_id: Some(context_id.to_string()),
117        status: TaskStatus { state: TaskState::Completed, message: None },
118        artifacts: Some(vec![]),
119        history: None,
120    };
121
122    for event in events {
123        match event {
124            UpdateEvent::TaskStatusUpdate(status) => {
125                task.status = status.status.clone();
126            }
127            UpdateEvent::TaskArtifactUpdate(artifact) => {
128                if let Some(ref mut artifacts) = task.artifacts {
129                    artifacts.push(artifact.artifact.clone());
130                }
131            }
132        }
133    }
134
135    task
136}
137
138fn build_failed_task(task_id: &str, context_id: &str, message: impl Into<String>) -> Task {
139    Task {
140        id: task_id.to_string(),
141        context_id: Some(context_id.to_string()),
142        status: TaskStatus { state: TaskState::Failed, message: Some(message.into()) },
143        artifacts: None,
144        history: None,
145    }
146}
147
148fn build_canceled_task(task_id: &str, context_id: &str) -> Task {
149    Task {
150        id: task_id.to_string(),
151        context_id: Some(context_id.to_string()),
152        status: TaskStatus { state: TaskState::Canceled, message: None },
153        artifacts: None,
154        history: None,
155    }
156}
157
158fn sanitize_internal_error(config: &ServerConfig, error: &adk_core::AdkError) -> String {
159    if config.security.expose_error_details {
160        error.to_string()
161    } else {
162        "Internal server error".to_string()
163    }
164}
165
166async fn start_task(
167    controller: &A2aController,
168    context_id: String,
169    task_id: String,
170    message: Message,
171    stream_updates: bool,
172) -> (oneshot::Receiver<adk_core::Result<Task>>, Option<mpsc::Receiver<StreamTaskMessage>>) {
173    let token = CancellationToken::new();
174    let completion = Arc::new(Notify::new());
175    let (task_tx, task_rx) = oneshot::channel();
176    let (stream_tx, stream_rx) = if stream_updates {
177        let (tx, rx) = mpsc::channel(32);
178        (Some(tx), Some(rx))
179    } else {
180        (None, None)
181    };
182
183    let root_agent = controller.config.agent_loader.root_agent();
184    let executor = Executor::new(ExecutorConfig {
185        app_name: root_agent.name().to_string(),
186        runner_config: build_runner_config(controller, root_agent, Some(token.clone())),
187        cancellation_token: Some(token.clone()),
188        #[cfg(feature = "a2a-interceptors")]
189        interceptor_chain: controller.config.interceptor_chain.clone(),
190    });
191
192    let controller_clone = controller.clone();
193    let completion_clone = completion.clone();
194    let task_id_for_task = task_id.clone();
195    let context_id_for_task = context_id.clone();
196    let stream_tx_for_task = stream_tx.clone();
197
198    let join_handle = tokio::spawn(async move {
199        let result = executor.execute(&context_id_for_task, &task_id_for_task, &message).await;
200
201        match result {
202            Ok(events) => {
203                if let Some(sender) = stream_tx_for_task {
204                    for event in &events {
205                        if sender
206                            .send(StreamTaskMessage::Update(Box::new(event.clone())))
207                            .await
208                            .is_err()
209                        {
210                            break;
211                        }
212                    }
213                }
214
215                let task = build_task_from_events(&task_id_for_task, &context_id_for_task, &events);
216                controller_clone.task_store.store(task.clone()).await;
217                let _ = task_tx.send(Ok(task));
218            }
219            Err(error) => {
220                if let Some(sender) = stream_tx_for_task {
221                    let _ = sender
222                        .send(StreamTaskMessage::Error(sanitize_internal_error(
223                            &controller_clone.config,
224                            &error,
225                        )))
226                        .await;
227                }
228                controller_clone
229                    .task_store
230                    .store(build_failed_task(
231                        &task_id_for_task,
232                        &context_id_for_task,
233                        error.to_string(),
234                    ))
235                    .await;
236                let _ = task_tx.send(Err(error));
237            }
238        }
239
240        controller_clone.active_tasks.lock().await.remove(&task_id_for_task);
241        completion_clone.notify_waiters();
242    });
243
244    controller.active_tasks.lock().await.insert(
245        task_id,
246        ActiveTask { token, abort_handle: join_handle.abort_handle(), completion, context_id },
247    );
248
249    (task_rx, stream_rx)
250}
251
252/// GET /.well-known/agent.json - Serve the agent card
253pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
254    Json(controller.agent_card.clone())
255}
256
257/// POST /a2a - JSON-RPC endpoint for A2A protocol
258pub async fn handle_jsonrpc(
259    State(controller): State<A2aController>,
260    Json(request): Json<JsonRpcRequest>,
261) -> impl IntoResponse {
262    if request.jsonrpc != "2.0" {
263        return Json(JsonRpcResponse::error(
264            request.id,
265            JsonRpcError::invalid_request("Invalid JSON-RPC version"),
266        ));
267    }
268
269    match request.method.as_str() {
270        jsonrpc::methods::MESSAGE_SEND => {
271            handle_message_send(&controller, request.params, request.id).await
272        }
273        jsonrpc::methods::TASKS_GET => {
274            handle_tasks_get(&controller, request.params, request.id).await
275        }
276        jsonrpc::methods::TASKS_CANCEL => {
277            handle_tasks_cancel(&controller, request.params, request.id).await
278        }
279        _ => Json(JsonRpcResponse::error(
280            request.id,
281            JsonRpcError::method_not_found(&request.method),
282        )),
283    }
284}
285
286/// POST /a2a/stream - SSE streaming endpoint for A2A protocol
287pub async fn handle_jsonrpc_stream(
288    State(controller): State<A2aController>,
289    Json(request): Json<JsonRpcRequest>,
290) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
291{
292    if request.jsonrpc != "2.0" {
293        return Err((
294            StatusCode::BAD_REQUEST,
295            Json(JsonRpcResponse::error(
296                request.id.clone(),
297                JsonRpcError::invalid_request("Invalid JSON-RPC version"),
298            )),
299        ));
300    }
301
302    if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
303        && request.method != jsonrpc::methods::MESSAGE_SEND
304    {
305        return Err((
306            StatusCode::BAD_REQUEST,
307            Json(JsonRpcResponse::error(
308                request.id.clone(),
309                JsonRpcError::method_not_found(&request.method),
310            )),
311        ));
312    }
313
314    let params: MessageSendParams = match request.params {
315        Some(p) => serde_json::from_value(p).map_err(|e| {
316            (
317                StatusCode::BAD_REQUEST,
318                Json(JsonRpcResponse::error(
319                    request.id.clone(),
320                    JsonRpcError::invalid_params(e.to_string()),
321                )),
322            )
323        })?,
324        None => {
325            return Err((
326                StatusCode::BAD_REQUEST,
327                Json(JsonRpcResponse::error(
328                    request.id.clone(),
329                    JsonRpcError::invalid_params("Missing params"),
330                )),
331            ));
332        }
333    };
334
335    let request_id = request.id.clone();
336    let stream = create_message_stream(controller, params, request_id);
337
338    Ok(Sse::new(stream).keep_alive(
339        axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
340    ))
341}
342
343fn create_message_stream(
344    controller: A2aController,
345    params: MessageSendParams,
346    request_id: Option<Value>,
347) -> impl Stream<Item = Result<Event, Infallible>> {
348    async_stream::stream! {
349        let context_id = params
350            .message
351            .context_id
352            .clone()
353            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
354        let task_id = params
355            .message
356            .task_id
357            .clone()
358            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
359
360        let (_task_rx, maybe_stream_rx) = start_task(
361            &controller,
362            context_id.clone(),
363            task_id.clone(),
364            params.message.clone(),
365            true,
366        )
367        .await;
368
369        let Some(mut stream_rx) = maybe_stream_rx else {
370            yield Ok(Event::default().event("done").data(""));
371            return;
372        };
373
374        while let Some(message) = stream_rx.recv().await {
375            match message {
376                StreamTaskMessage::Update(event) => {
377                    let event_data = match event.as_ref() {
378                        UpdateEvent::TaskStatusUpdate(status) => {
379                            serde_json::to_string(&JsonRpcResponse::success(
380                                request_id.clone(),
381                                serde_json::to_value(status).unwrap_or_default(),
382                            ))
383                        }
384                        UpdateEvent::TaskArtifactUpdate(artifact) => {
385                            serde_json::to_string(&JsonRpcResponse::success(
386                                request_id.clone(),
387                                serde_json::to_value(artifact).unwrap_or_default(),
388                            ))
389                        }
390                    };
391
392                    if let Ok(data) = event_data {
393                        yield Ok(Event::default().data(data));
394                    }
395                }
396                StreamTaskMessage::Error(message) => {
397                    let error_response = JsonRpcResponse::error(
398                        request_id.clone(),
399                        JsonRpcError::internal_error(message),
400                    );
401                    if let Ok(data) = serde_json::to_string(&error_response) {
402                        yield Ok(Event::default().data(data));
403                    }
404                }
405            }
406        }
407
408        // Send done event
409        yield Ok(Event::default().event("done").data(""));
410    }
411}
412
413async fn handle_message_send(
414    controller: &A2aController,
415    params: Option<Value>,
416    id: Option<Value>,
417) -> Json<JsonRpcResponse> {
418    let params: MessageSendParams = match params {
419        Some(p) => match serde_json::from_value(p) {
420            Ok(p) => p,
421            Err(e) => {
422                return Json(JsonRpcResponse::error(
423                    id,
424                    JsonRpcError::invalid_params(e.to_string()),
425                ));
426            }
427        },
428        None => {
429            return Json(JsonRpcResponse::error(
430                id,
431                JsonRpcError::invalid_params("Missing params"),
432            ));
433        }
434    };
435
436    let context_id =
437        params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
438    let task_id =
439        params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
440
441    let (task_rx, _) =
442        start_task(controller, context_id.clone(), task_id.clone(), params.message, false).await;
443
444    match task_rx.await {
445        Ok(Ok(task)) => {
446            Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
447        }
448        Ok(Err(e)) => Json(JsonRpcResponse::error(
449            id,
450            JsonRpcError::internal_error_sanitized(
451                &e,
452                controller.config.security.expose_error_details,
453            ),
454        )),
455        Err(_) => {
456            Json(JsonRpcResponse::error(id, JsonRpcError::internal_error("Task execution aborted")))
457        }
458    }
459}
460
461async fn handle_tasks_get(
462    controller: &A2aController,
463    params: Option<Value>,
464    id: Option<Value>,
465) -> Json<JsonRpcResponse> {
466    let params: TasksGetParams = match params {
467        Some(p) => match serde_json::from_value(p) {
468            Ok(p) => p,
469            Err(e) => {
470                return Json(JsonRpcResponse::error(
471                    id,
472                    JsonRpcError::invalid_params(e.to_string()),
473                ));
474            }
475        },
476        None => {
477            return Json(JsonRpcResponse::error(
478                id,
479                JsonRpcError::invalid_params("Missing params"),
480            ));
481        }
482    };
483
484    if let Some(active_task) = controller.active_tasks.lock().await.get(&params.task_id).cloned() {
485        let task = Task {
486            id: params.task_id.clone(),
487            context_id: Some(active_task.context_id),
488            status: TaskStatus { state: TaskState::Working, message: None },
489            artifacts: None,
490            history: None,
491        };
492
493        return Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()));
494    }
495
496    match controller.task_store.get(&params.task_id).await {
497        Some(task) => {
498            Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
499        }
500        None => Json(JsonRpcResponse::error(
501            id,
502            JsonRpcError::internal_error(format!("Task not found: {}", params.task_id)),
503        )),
504    }
505}
506
507async fn handle_tasks_cancel(
508    controller: &A2aController,
509    params: Option<Value>,
510    id: Option<Value>,
511) -> Json<JsonRpcResponse> {
512    let params: TasksCancelParams = match params {
513        Some(p) => match serde_json::from_value(p) {
514            Ok(p) => p,
515            Err(e) => {
516                return Json(JsonRpcResponse::error(
517                    id,
518                    JsonRpcError::invalid_params(e.to_string()),
519                ));
520            }
521        },
522        None => {
523            return Json(JsonRpcResponse::error(
524                id,
525                JsonRpcError::invalid_params("Missing params"),
526            ));
527        }
528    };
529
530    let active_task = controller.active_tasks.lock().await.get(&params.task_id).cloned();
531
532    if let Some(active_task) = active_task {
533        active_task.token.cancel();
534
535        if tokio::time::timeout(Duration::from_secs(5), active_task.completion.notified())
536            .await
537            .is_err()
538        {
539            active_task.abort_handle.abort();
540            controller.active_tasks.lock().await.remove(&params.task_id);
541            controller
542                .task_store
543                .store(build_canceled_task(&params.task_id, &active_task.context_id))
544                .await;
545        }
546
547        let status = TaskStatusUpdateEvent {
548            task_id: params.task_id,
549            context_id: Some(active_task.context_id),
550            status: TaskStatus { state: TaskState::Canceled, message: None },
551            final_update: true,
552        };
553
554        return Json(JsonRpcResponse::success(
555            id,
556            serde_json::to_value(status).unwrap_or_default(),
557        ));
558    }
559
560    let status = TaskStatusUpdateEvent {
561        task_id: params.task_id,
562        context_id: Some(uuid::Uuid::new_v4().to_string()),
563        status: TaskStatus { state: TaskState::Canceled, message: None },
564        final_update: true,
565    };
566
567    Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
568}