Skip to main content

adk_server/rest/controllers/
a2a.rs

1use crate::ServerConfig;
2use crate::a2a::{
3    AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest, JsonRpcResponse, Message,
4    MessageSendParams, Task, TaskState, TaskStatus, TaskStatusUpdateEvent, TasksCancelParams,
5    TasksGetParams, UpdateEvent, build_agent_card, jsonrpc,
6};
7use adk_runner::RunnerConfig;
8use axum::{
9    extract::State,
10    http::StatusCode,
11    response::{
12        IntoResponse, Json,
13        sse::{Event, Sse},
14    },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
19use tokio::sync::{Mutex, Notify, RwLock, mpsc, oneshot};
20use tokio_util::sync::CancellationToken;
21
/// In-memory task storage.
///
/// Holds the most recently stored [`Task`] for each task id in a `HashMap`
/// guarded by an async `RwLock`. Nothing here writes tasks anywhere else, so
/// the contents last only as long as the store itself.
#[derive(Default)]
pub struct TaskStore {
    /// Stored tasks, keyed by the task's own `id` field.
    tasks: RwLock<HashMap<String, Task>>,
}
27
28impl TaskStore {
29    pub fn new() -> Self {
30        Self::default()
31    }
32
33    pub async fn store(&self, task: Task) {
34        self.tasks.write().await.insert(task.id.clone(), task);
35    }
36
37    pub async fn get(&self, task_id: &str) -> Option<Task> {
38        self.tasks.read().await.get(task_id).cloned()
39    }
40
41    pub async fn remove(&self, task_id: &str) -> Option<Task> {
42        self.tasks.write().await.remove(task_id)
43    }
44}
45
/// Bookkeeping for a task whose worker has been spawned and has not yet
/// removed itself from the controller's `active_tasks` map.
#[derive(Clone)]
struct ActiveTask {
    /// Cancellation token that was also handed to the task's executor and
    /// runner configuration; `tasks/cancel` triggers it.
    token: CancellationToken,
    /// Abort handle of the spawned tokio task, used by `tasks/cancel` when the
    /// worker does not finish within the cancel timeout.
    abort_handle: tokio::task::AbortHandle,
    /// Signalled with `notify_waiters` by the worker once it has finished and
    /// removed its entry from `active_tasks`.
    completion: Arc<Notify>,
    /// Context id the task was started with.
    context_id: String,
}
53
/// Message sent from a task's worker to the SSE stream that is serving it.
enum StreamTaskMessage {
    /// An update event produced by the executor (boxed).
    Update(Box<UpdateEvent>),
    /// Error text to report to the client; the worker passes the output of
    /// `sanitize_internal_error` here.
    Error(String),
}
58
/// Controller for A2A protocol endpoints.
///
/// Cloning is cheap with respect to task state: the task store and the
/// active-task registry are behind `Arc`, so every clone shares them.
#[derive(Clone)]
pub struct A2aController {
    /// Server configuration (agent loader, services, security settings).
    config: ServerConfig,
    /// Agent card built once in `new` and served by `get_agent_card`.
    agent_card: AgentCard,
    /// Storage for tasks that have finished, failed, or been canceled.
    task_store: Arc<TaskStore>,
    /// Tasks whose worker is currently registered as running, keyed by task id.
    active_tasks: Arc<Mutex<HashMap<String, ActiveTask>>>,
}
67
68impl A2aController {
69    pub fn new(config: ServerConfig, base_url: &str) -> Self {
70        let root_agent = config.agent_loader.root_agent();
71        let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
72        let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
73
74        Self {
75            config,
76            agent_card,
77            task_store: Arc::new(TaskStore::new()),
78            active_tasks: Arc::new(Mutex::new(HashMap::new())),
79        }
80    }
81}
82
83fn build_runner_config(
84    controller: &A2aController,
85    root_agent: Arc<dyn adk_core::Agent>,
86    cancellation_token: Option<CancellationToken>,
87) -> Arc<RunnerConfig> {
88    Arc::new(RunnerConfig {
89        app_name: root_agent.name().to_string(),
90        agent: root_agent,
91        session_service: controller.config.session_service.clone(),
92        artifact_service: controller.config.artifact_service.clone(),
93        memory_service: controller.config.memory_service.clone(),
94        plugin_manager: None,
95        run_config: None,
96        compaction_config: controller.config.compaction_config.clone(),
97        context_cache_config: controller.config.context_cache_config.clone(),
98        cache_capable: controller.config.cache_capable.clone(),
99        request_context: None,
100        cancellation_token,
101        intra_compaction_config: None,
102        intra_compaction_summarizer: None,
103    })
104}
105
106fn build_task_from_events(task_id: &str, context_id: &str, events: &[UpdateEvent]) -> Task {
107    let mut task = Task {
108        id: task_id.to_string(),
109        context_id: Some(context_id.to_string()),
110        status: TaskStatus { state: TaskState::Completed, message: None },
111        artifacts: Some(vec![]),
112        history: None,
113    };
114
115    for event in events {
116        match event {
117            UpdateEvent::TaskStatusUpdate(status) => {
118                task.status = status.status.clone();
119            }
120            UpdateEvent::TaskArtifactUpdate(artifact) => {
121                if let Some(ref mut artifacts) = task.artifacts {
122                    artifacts.push(artifact.artifact.clone());
123                }
124            }
125        }
126    }
127
128    task
129}
130
131fn build_failed_task(task_id: &str, context_id: &str, message: impl Into<String>) -> Task {
132    Task {
133        id: task_id.to_string(),
134        context_id: Some(context_id.to_string()),
135        status: TaskStatus { state: TaskState::Failed, message: Some(message.into()) },
136        artifacts: None,
137        history: None,
138    }
139}
140
141fn build_canceled_task(task_id: &str, context_id: &str) -> Task {
142    Task {
143        id: task_id.to_string(),
144        context_id: Some(context_id.to_string()),
145        status: TaskStatus { state: TaskState::Canceled, message: None },
146        artifacts: None,
147        history: None,
148    }
149}
150
151fn sanitize_internal_error(config: &ServerConfig, error: &adk_core::AdkError) -> String {
152    if config.security.expose_error_details {
153        error.to_string()
154    } else {
155        "Internal server error".to_string()
156    }
157}
158
159async fn start_task(
160    controller: &A2aController,
161    context_id: String,
162    task_id: String,
163    message: Message,
164    stream_updates: bool,
165) -> (oneshot::Receiver<adk_core::Result<Task>>, Option<mpsc::Receiver<StreamTaskMessage>>) {
166    let token = CancellationToken::new();
167    let completion = Arc::new(Notify::new());
168    let (task_tx, task_rx) = oneshot::channel();
169    let (stream_tx, stream_rx) = if stream_updates {
170        let (tx, rx) = mpsc::channel(32);
171        (Some(tx), Some(rx))
172    } else {
173        (None, None)
174    };
175
176    let root_agent = controller.config.agent_loader.root_agent();
177    let executor = Executor::new(ExecutorConfig {
178        app_name: root_agent.name().to_string(),
179        runner_config: build_runner_config(controller, root_agent, Some(token.clone())),
180        cancellation_token: Some(token.clone()),
181    });
182
183    let controller_clone = controller.clone();
184    let completion_clone = completion.clone();
185    let task_id_for_task = task_id.clone();
186    let context_id_for_task = context_id.clone();
187    let stream_tx_for_task = stream_tx.clone();
188
189    let join_handle = tokio::spawn(async move {
190        let result = executor.execute(&context_id_for_task, &task_id_for_task, &message).await;
191
192        match result {
193            Ok(events) => {
194                if let Some(sender) = stream_tx_for_task {
195                    for event in &events {
196                        if sender
197                            .send(StreamTaskMessage::Update(Box::new(event.clone())))
198                            .await
199                            .is_err()
200                        {
201                            break;
202                        }
203                    }
204                }
205
206                let task = build_task_from_events(&task_id_for_task, &context_id_for_task, &events);
207                controller_clone.task_store.store(task.clone()).await;
208                let _ = task_tx.send(Ok(task));
209            }
210            Err(error) => {
211                if let Some(sender) = stream_tx_for_task {
212                    let _ = sender
213                        .send(StreamTaskMessage::Error(sanitize_internal_error(
214                            &controller_clone.config,
215                            &error,
216                        )))
217                        .await;
218                }
219                controller_clone
220                    .task_store
221                    .store(build_failed_task(
222                        &task_id_for_task,
223                        &context_id_for_task,
224                        error.to_string(),
225                    ))
226                    .await;
227                let _ = task_tx.send(Err(error));
228            }
229        }
230
231        controller_clone.active_tasks.lock().await.remove(&task_id_for_task);
232        completion_clone.notify_waiters();
233    });
234
235    controller.active_tasks.lock().await.insert(
236        task_id,
237        ActiveTask { token, abort_handle: join_handle.abort_handle(), completion, context_id },
238    );
239
240    (task_rx, stream_rx)
241}
242
243/// GET /.well-known/agent.json - Serve the agent card
244pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
245    Json(controller.agent_card.clone())
246}
247
248/// POST /a2a - JSON-RPC endpoint for A2A protocol
249pub async fn handle_jsonrpc(
250    State(controller): State<A2aController>,
251    Json(request): Json<JsonRpcRequest>,
252) -> impl IntoResponse {
253    if request.jsonrpc != "2.0" {
254        return Json(JsonRpcResponse::error(
255            request.id,
256            JsonRpcError::invalid_request("Invalid JSON-RPC version"),
257        ));
258    }
259
260    match request.method.as_str() {
261        jsonrpc::methods::MESSAGE_SEND => {
262            handle_message_send(&controller, request.params, request.id).await
263        }
264        jsonrpc::methods::TASKS_GET => {
265            handle_tasks_get(&controller, request.params, request.id).await
266        }
267        jsonrpc::methods::TASKS_CANCEL => {
268            handle_tasks_cancel(&controller, request.params, request.id).await
269        }
270        _ => Json(JsonRpcResponse::error(
271            request.id,
272            JsonRpcError::method_not_found(&request.method),
273        )),
274    }
275}
276
277/// POST /a2a/stream - SSE streaming endpoint for A2A protocol
278pub async fn handle_jsonrpc_stream(
279    State(controller): State<A2aController>,
280    Json(request): Json<JsonRpcRequest>,
281) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
282{
283    if request.jsonrpc != "2.0" {
284        return Err((
285            StatusCode::BAD_REQUEST,
286            Json(JsonRpcResponse::error(
287                request.id.clone(),
288                JsonRpcError::invalid_request("Invalid JSON-RPC version"),
289            )),
290        ));
291    }
292
293    if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
294        && request.method != jsonrpc::methods::MESSAGE_SEND
295    {
296        return Err((
297            StatusCode::BAD_REQUEST,
298            Json(JsonRpcResponse::error(
299                request.id.clone(),
300                JsonRpcError::method_not_found(&request.method),
301            )),
302        ));
303    }
304
305    let params: MessageSendParams = match request.params {
306        Some(p) => serde_json::from_value(p).map_err(|e| {
307            (
308                StatusCode::BAD_REQUEST,
309                Json(JsonRpcResponse::error(
310                    request.id.clone(),
311                    JsonRpcError::invalid_params(e.to_string()),
312                )),
313            )
314        })?,
315        None => {
316            return Err((
317                StatusCode::BAD_REQUEST,
318                Json(JsonRpcResponse::error(
319                    request.id.clone(),
320                    JsonRpcError::invalid_params("Missing params"),
321                )),
322            ));
323        }
324    };
325
326    let request_id = request.id.clone();
327    let stream = create_message_stream(controller, params, request_id);
328
329    Ok(Sse::new(stream).keep_alive(
330        axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
331    ))
332}
333
334fn create_message_stream(
335    controller: A2aController,
336    params: MessageSendParams,
337    request_id: Option<Value>,
338) -> impl Stream<Item = Result<Event, Infallible>> {
339    async_stream::stream! {
340        let context_id = params
341            .message
342            .context_id
343            .clone()
344            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
345        let task_id = params
346            .message
347            .task_id
348            .clone()
349            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
350
351        let (_task_rx, maybe_stream_rx) = start_task(
352            &controller,
353            context_id.clone(),
354            task_id.clone(),
355            params.message.clone(),
356            true,
357        )
358        .await;
359
360        let Some(mut stream_rx) = maybe_stream_rx else {
361            yield Ok(Event::default().event("done").data(""));
362            return;
363        };
364
365        while let Some(message) = stream_rx.recv().await {
366            match message {
367                StreamTaskMessage::Update(event) => {
368                    let event_data = match event.as_ref() {
369                        UpdateEvent::TaskStatusUpdate(status) => {
370                            serde_json::to_string(&JsonRpcResponse::success(
371                                request_id.clone(),
372                                serde_json::to_value(status).unwrap_or_default(),
373                            ))
374                        }
375                        UpdateEvent::TaskArtifactUpdate(artifact) => {
376                            serde_json::to_string(&JsonRpcResponse::success(
377                                request_id.clone(),
378                                serde_json::to_value(artifact).unwrap_or_default(),
379                            ))
380                        }
381                    };
382
383                    if let Ok(data) = event_data {
384                        yield Ok(Event::default().data(data));
385                    }
386                }
387                StreamTaskMessage::Error(message) => {
388                    let error_response = JsonRpcResponse::error(
389                        request_id.clone(),
390                        JsonRpcError::internal_error(message),
391                    );
392                    if let Ok(data) = serde_json::to_string(&error_response) {
393                        yield Ok(Event::default().data(data));
394                    }
395                }
396            }
397        }
398
399        // Send done event
400        yield Ok(Event::default().event("done").data(""));
401    }
402}
403
404async fn handle_message_send(
405    controller: &A2aController,
406    params: Option<Value>,
407    id: Option<Value>,
408) -> Json<JsonRpcResponse> {
409    let params: MessageSendParams = match params {
410        Some(p) => match serde_json::from_value(p) {
411            Ok(p) => p,
412            Err(e) => {
413                return Json(JsonRpcResponse::error(
414                    id,
415                    JsonRpcError::invalid_params(e.to_string()),
416                ));
417            }
418        },
419        None => {
420            return Json(JsonRpcResponse::error(
421                id,
422                JsonRpcError::invalid_params("Missing params"),
423            ));
424        }
425    };
426
427    let context_id =
428        params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
429    let task_id =
430        params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
431
432    let (task_rx, _) =
433        start_task(controller, context_id.clone(), task_id.clone(), params.message, false).await;
434
435    match task_rx.await {
436        Ok(Ok(task)) => {
437            Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
438        }
439        Ok(Err(e)) => Json(JsonRpcResponse::error(
440            id,
441            JsonRpcError::internal_error_sanitized(
442                &e,
443                controller.config.security.expose_error_details,
444            ),
445        )),
446        Err(_) => {
447            Json(JsonRpcResponse::error(id, JsonRpcError::internal_error("Task execution aborted")))
448        }
449    }
450}
451
452async fn handle_tasks_get(
453    controller: &A2aController,
454    params: Option<Value>,
455    id: Option<Value>,
456) -> Json<JsonRpcResponse> {
457    let params: TasksGetParams = match params {
458        Some(p) => match serde_json::from_value(p) {
459            Ok(p) => p,
460            Err(e) => {
461                return Json(JsonRpcResponse::error(
462                    id,
463                    JsonRpcError::invalid_params(e.to_string()),
464                ));
465            }
466        },
467        None => {
468            return Json(JsonRpcResponse::error(
469                id,
470                JsonRpcError::invalid_params("Missing params"),
471            ));
472        }
473    };
474
475    if let Some(active_task) = controller.active_tasks.lock().await.get(&params.task_id).cloned() {
476        let task = Task {
477            id: params.task_id.clone(),
478            context_id: Some(active_task.context_id),
479            status: TaskStatus { state: TaskState::Working, message: None },
480            artifacts: None,
481            history: None,
482        };
483
484        return Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()));
485    }
486
487    match controller.task_store.get(&params.task_id).await {
488        Some(task) => {
489            Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
490        }
491        None => Json(JsonRpcResponse::error(
492            id,
493            JsonRpcError::internal_error(format!("Task not found: {}", params.task_id)),
494        )),
495    }
496}
497
498async fn handle_tasks_cancel(
499    controller: &A2aController,
500    params: Option<Value>,
501    id: Option<Value>,
502) -> Json<JsonRpcResponse> {
503    let params: TasksCancelParams = match params {
504        Some(p) => match serde_json::from_value(p) {
505            Ok(p) => p,
506            Err(e) => {
507                return Json(JsonRpcResponse::error(
508                    id,
509                    JsonRpcError::invalid_params(e.to_string()),
510                ));
511            }
512        },
513        None => {
514            return Json(JsonRpcResponse::error(
515                id,
516                JsonRpcError::invalid_params("Missing params"),
517            ));
518        }
519    };
520
521    let active_task = controller.active_tasks.lock().await.get(&params.task_id).cloned();
522
523    if let Some(active_task) = active_task {
524        active_task.token.cancel();
525
526        if tokio::time::timeout(Duration::from_secs(5), active_task.completion.notified())
527            .await
528            .is_err()
529        {
530            active_task.abort_handle.abort();
531            controller.active_tasks.lock().await.remove(&params.task_id);
532            controller
533                .task_store
534                .store(build_canceled_task(&params.task_id, &active_task.context_id))
535                .await;
536        }
537
538        let status = TaskStatusUpdateEvent {
539            task_id: params.task_id,
540            context_id: Some(active_task.context_id),
541            status: TaskStatus { state: TaskState::Canceled, message: None },
542            final_update: true,
543        };
544
545        return Json(JsonRpcResponse::success(
546            id,
547            serde_json::to_value(status).unwrap_or_default(),
548        ));
549    }
550
551    let status = TaskStatusUpdateEvent {
552        task_id: params.task_id,
553        context_id: Some(uuid::Uuid::new_v4().to_string()),
554        status: TaskStatus { state: TaskState::Canceled, message: None },
555        final_update: true,
556    };
557
558    Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
559}