// adk_server/rest/controllers/a2a.rs

1use crate::ServerConfig;
2use crate::a2a::{
3    AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest, JsonRpcResponse, Message,
4    MessageSendParams, Task, TaskState, TaskStatus, TaskStatusUpdateEvent, TasksCancelParams,
5    TasksGetParams, UpdateEvent, build_agent_card, jsonrpc,
6};
7use adk_runner::RunnerConfig;
8use axum::{
9    extract::State,
10    http::StatusCode,
11    response::{
12        IntoResponse, Json,
13        sse::{Event, Sse},
14    },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
19use tokio::sync::{Mutex, Notify, RwLock, mpsc, oneshot};
20use tokio_util::sync::CancellationToken;
21
22/// In-memory task storage
23#[derive(Default)]
24pub struct TaskStore {
25    tasks: RwLock<HashMap<String, Task>>,
26}
27
28impl TaskStore {
29    pub fn new() -> Self {
30        Self::default()
31    }
32
33    pub async fn store(&self, task: Task) {
34        self.tasks.write().await.insert(task.id.clone(), task);
35    }
36
37    pub async fn get(&self, task_id: &str) -> Option<Task> {
38        self.tasks.read().await.get(task_id).cloned()
39    }
40
41    pub async fn remove(&self, task_id: &str) -> Option<Task> {
42        self.tasks.write().await.remove(task_id)
43    }
44}
45
46#[derive(Clone)]
47struct ActiveTask {
48    token: CancellationToken,
49    abort_handle: tokio::task::AbortHandle,
50    completion: Arc<Notify>,
51    context_id: String,
52}
53
54enum StreamTaskMessage {
55    Update(Box<UpdateEvent>),
56    Error(String),
57}
58
59/// Controller for A2A protocol endpoints
60#[derive(Clone)]
61pub struct A2aController {
62    config: ServerConfig,
63    agent_card: AgentCard,
64    task_store: Arc<TaskStore>,
65    active_tasks: Arc<Mutex<HashMap<String, ActiveTask>>>,
66}
67
68impl A2aController {
69    pub fn new(config: ServerConfig, base_url: &str) -> Self {
70        let root_agent = config.agent_loader.root_agent();
71        let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
72        let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
73
74        Self {
75            config,
76            agent_card,
77            task_store: Arc::new(TaskStore::new()),
78            active_tasks: Arc::new(Mutex::new(HashMap::new())),
79        }
80    }
81}
82
83fn build_runner_config(
84    controller: &A2aController,
85    root_agent: Arc<dyn adk_core::Agent>,
86    cancellation_token: Option<CancellationToken>,
87) -> Arc<RunnerConfig> {
88    Arc::new(RunnerConfig {
89        app_name: root_agent.name().to_string(),
90        agent: root_agent,
91        session_service: controller.config.session_service.clone(),
92        artifact_service: controller.config.artifact_service.clone(),
93        memory_service: controller.config.memory_service.clone(),
94        plugin_manager: None,
95        run_config: None,
96        compaction_config: controller.config.compaction_config.clone(),
97        context_cache_config: controller.config.context_cache_config.clone(),
98        cache_capable: controller.config.cache_capable.clone(),
99        request_context: None,
100        cancellation_token,
101    })
102}
103
104fn build_task_from_events(task_id: &str, context_id: &str, events: &[UpdateEvent]) -> Task {
105    let mut task = Task {
106        id: task_id.to_string(),
107        context_id: Some(context_id.to_string()),
108        status: TaskStatus { state: TaskState::Completed, message: None },
109        artifacts: Some(vec![]),
110        history: None,
111    };
112
113    for event in events {
114        match event {
115            UpdateEvent::TaskStatusUpdate(status) => {
116                task.status = status.status.clone();
117            }
118            UpdateEvent::TaskArtifactUpdate(artifact) => {
119                if let Some(ref mut artifacts) = task.artifacts {
120                    artifacts.push(artifact.artifact.clone());
121                }
122            }
123        }
124    }
125
126    task
127}
128
129fn build_failed_task(task_id: &str, context_id: &str, message: impl Into<String>) -> Task {
130    Task {
131        id: task_id.to_string(),
132        context_id: Some(context_id.to_string()),
133        status: TaskStatus { state: TaskState::Failed, message: Some(message.into()) },
134        artifacts: None,
135        history: None,
136    }
137}
138
139fn build_canceled_task(task_id: &str, context_id: &str) -> Task {
140    Task {
141        id: task_id.to_string(),
142        context_id: Some(context_id.to_string()),
143        status: TaskStatus { state: TaskState::Canceled, message: None },
144        artifacts: None,
145        history: None,
146    }
147}
148
149fn sanitize_internal_error(config: &ServerConfig, error: &adk_core::AdkError) -> String {
150    if config.security.expose_error_details {
151        error.to_string()
152    } else {
153        "Internal server error".to_string()
154    }
155}
156
157async fn start_task(
158    controller: &A2aController,
159    context_id: String,
160    task_id: String,
161    message: Message,
162    stream_updates: bool,
163) -> (oneshot::Receiver<adk_core::Result<Task>>, Option<mpsc::Receiver<StreamTaskMessage>>) {
164    let token = CancellationToken::new();
165    let completion = Arc::new(Notify::new());
166    let (task_tx, task_rx) = oneshot::channel();
167    let (stream_tx, stream_rx) = if stream_updates {
168        let (tx, rx) = mpsc::channel(32);
169        (Some(tx), Some(rx))
170    } else {
171        (None, None)
172    };
173
174    let root_agent = controller.config.agent_loader.root_agent();
175    let executor = Executor::new(ExecutorConfig {
176        app_name: root_agent.name().to_string(),
177        runner_config: build_runner_config(controller, root_agent, Some(token.clone())),
178        cancellation_token: Some(token.clone()),
179    });
180
181    let controller_clone = controller.clone();
182    let completion_clone = completion.clone();
183    let task_id_for_task = task_id.clone();
184    let context_id_for_task = context_id.clone();
185    let stream_tx_for_task = stream_tx.clone();
186
187    let join_handle = tokio::spawn(async move {
188        let result = executor.execute(&context_id_for_task, &task_id_for_task, &message).await;
189
190        match result {
191            Ok(events) => {
192                if let Some(sender) = stream_tx_for_task {
193                    for event in &events {
194                        if sender
195                            .send(StreamTaskMessage::Update(Box::new(event.clone())))
196                            .await
197                            .is_err()
198                        {
199                            break;
200                        }
201                    }
202                }
203
204                let task = build_task_from_events(&task_id_for_task, &context_id_for_task, &events);
205                controller_clone.task_store.store(task.clone()).await;
206                let _ = task_tx.send(Ok(task));
207            }
208            Err(error) => {
209                if let Some(sender) = stream_tx_for_task {
210                    let _ = sender
211                        .send(StreamTaskMessage::Error(sanitize_internal_error(
212                            &controller_clone.config,
213                            &error,
214                        )))
215                        .await;
216                }
217                controller_clone
218                    .task_store
219                    .store(build_failed_task(
220                        &task_id_for_task,
221                        &context_id_for_task,
222                        error.to_string(),
223                    ))
224                    .await;
225                let _ = task_tx.send(Err(error));
226            }
227        }
228
229        controller_clone.active_tasks.lock().await.remove(&task_id_for_task);
230        completion_clone.notify_waiters();
231    });
232
233    controller.active_tasks.lock().await.insert(
234        task_id,
235        ActiveTask { token, abort_handle: join_handle.abort_handle(), completion, context_id },
236    );
237
238    (task_rx, stream_rx)
239}
240
241/// GET /.well-known/agent.json - Serve the agent card
242pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
243    Json(controller.agent_card.clone())
244}
245
246/// POST /a2a - JSON-RPC endpoint for A2A protocol
247pub async fn handle_jsonrpc(
248    State(controller): State<A2aController>,
249    Json(request): Json<JsonRpcRequest>,
250) -> impl IntoResponse {
251    if request.jsonrpc != "2.0" {
252        return Json(JsonRpcResponse::error(
253            request.id,
254            JsonRpcError::invalid_request("Invalid JSON-RPC version"),
255        ));
256    }
257
258    match request.method.as_str() {
259        jsonrpc::methods::MESSAGE_SEND => {
260            handle_message_send(&controller, request.params, request.id).await
261        }
262        jsonrpc::methods::TASKS_GET => {
263            handle_tasks_get(&controller, request.params, request.id).await
264        }
265        jsonrpc::methods::TASKS_CANCEL => {
266            handle_tasks_cancel(&controller, request.params, request.id).await
267        }
268        _ => Json(JsonRpcResponse::error(
269            request.id,
270            JsonRpcError::method_not_found(&request.method),
271        )),
272    }
273}
274
275/// POST /a2a/stream - SSE streaming endpoint for A2A protocol
276pub async fn handle_jsonrpc_stream(
277    State(controller): State<A2aController>,
278    Json(request): Json<JsonRpcRequest>,
279) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
280{
281    if request.jsonrpc != "2.0" {
282        return Err((
283            StatusCode::BAD_REQUEST,
284            Json(JsonRpcResponse::error(
285                request.id.clone(),
286                JsonRpcError::invalid_request("Invalid JSON-RPC version"),
287            )),
288        ));
289    }
290
291    if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
292        && request.method != jsonrpc::methods::MESSAGE_SEND
293    {
294        return Err((
295            StatusCode::BAD_REQUEST,
296            Json(JsonRpcResponse::error(
297                request.id.clone(),
298                JsonRpcError::method_not_found(&request.method),
299            )),
300        ));
301    }
302
303    let params: MessageSendParams = match request.params {
304        Some(p) => serde_json::from_value(p).map_err(|e| {
305            (
306                StatusCode::BAD_REQUEST,
307                Json(JsonRpcResponse::error(
308                    request.id.clone(),
309                    JsonRpcError::invalid_params(e.to_string()),
310                )),
311            )
312        })?,
313        None => {
314            return Err((
315                StatusCode::BAD_REQUEST,
316                Json(JsonRpcResponse::error(
317                    request.id.clone(),
318                    JsonRpcError::invalid_params("Missing params"),
319                )),
320            ));
321        }
322    };
323
324    let request_id = request.id.clone();
325    let stream = create_message_stream(controller, params, request_id);
326
327    Ok(Sse::new(stream).keep_alive(
328        axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
329    ))
330}
331
332fn create_message_stream(
333    controller: A2aController,
334    params: MessageSendParams,
335    request_id: Option<Value>,
336) -> impl Stream<Item = Result<Event, Infallible>> {
337    async_stream::stream! {
338        let context_id = params
339            .message
340            .context_id
341            .clone()
342            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
343        let task_id = params
344            .message
345            .task_id
346            .clone()
347            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
348
349        let (_task_rx, maybe_stream_rx) = start_task(
350            &controller,
351            context_id.clone(),
352            task_id.clone(),
353            params.message.clone(),
354            true,
355        )
356        .await;
357
358        let Some(mut stream_rx) = maybe_stream_rx else {
359            yield Ok(Event::default().event("done").data(""));
360            return;
361        };
362
363        while let Some(message) = stream_rx.recv().await {
364            match message {
365                StreamTaskMessage::Update(event) => {
366                    let event_data = match event.as_ref() {
367                        UpdateEvent::TaskStatusUpdate(status) => {
368                            serde_json::to_string(&JsonRpcResponse::success(
369                                request_id.clone(),
370                                serde_json::to_value(status).unwrap_or_default(),
371                            ))
372                        }
373                        UpdateEvent::TaskArtifactUpdate(artifact) => {
374                            serde_json::to_string(&JsonRpcResponse::success(
375                                request_id.clone(),
376                                serde_json::to_value(artifact).unwrap_or_default(),
377                            ))
378                        }
379                    };
380
381                    if let Ok(data) = event_data {
382                        yield Ok(Event::default().data(data));
383                    }
384                }
385                StreamTaskMessage::Error(message) => {
386                    let error_response = JsonRpcResponse::error(
387                        request_id.clone(),
388                        JsonRpcError::internal_error(message),
389                    );
390                    if let Ok(data) = serde_json::to_string(&error_response) {
391                        yield Ok(Event::default().data(data));
392                    }
393                }
394            }
395        }
396
397        // Send done event
398        yield Ok(Event::default().event("done").data(""));
399    }
400}
401
402async fn handle_message_send(
403    controller: &A2aController,
404    params: Option<Value>,
405    id: Option<Value>,
406) -> Json<JsonRpcResponse> {
407    let params: MessageSendParams = match params {
408        Some(p) => match serde_json::from_value(p) {
409            Ok(p) => p,
410            Err(e) => {
411                return Json(JsonRpcResponse::error(
412                    id,
413                    JsonRpcError::invalid_params(e.to_string()),
414                ));
415            }
416        },
417        None => {
418            return Json(JsonRpcResponse::error(
419                id,
420                JsonRpcError::invalid_params("Missing params"),
421            ));
422        }
423    };
424
425    let context_id =
426        params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
427    let task_id =
428        params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
429
430    let (task_rx, _) =
431        start_task(controller, context_id.clone(), task_id.clone(), params.message, false).await;
432
433    match task_rx.await {
434        Ok(Ok(task)) => {
435            Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
436        }
437        Ok(Err(e)) => Json(JsonRpcResponse::error(
438            id,
439            JsonRpcError::internal_error_sanitized(
440                &e,
441                controller.config.security.expose_error_details,
442            ),
443        )),
444        Err(_) => {
445            Json(JsonRpcResponse::error(id, JsonRpcError::internal_error("Task execution aborted")))
446        }
447    }
448}
449
450async fn handle_tasks_get(
451    controller: &A2aController,
452    params: Option<Value>,
453    id: Option<Value>,
454) -> Json<JsonRpcResponse> {
455    let params: TasksGetParams = match params {
456        Some(p) => match serde_json::from_value(p) {
457            Ok(p) => p,
458            Err(e) => {
459                return Json(JsonRpcResponse::error(
460                    id,
461                    JsonRpcError::invalid_params(e.to_string()),
462                ));
463            }
464        },
465        None => {
466            return Json(JsonRpcResponse::error(
467                id,
468                JsonRpcError::invalid_params("Missing params"),
469            ));
470        }
471    };
472
473    if let Some(active_task) = controller.active_tasks.lock().await.get(&params.task_id).cloned() {
474        let task = Task {
475            id: params.task_id.clone(),
476            context_id: Some(active_task.context_id),
477            status: TaskStatus { state: TaskState::Working, message: None },
478            artifacts: None,
479            history: None,
480        };
481
482        return Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()));
483    }
484
485    match controller.task_store.get(&params.task_id).await {
486        Some(task) => {
487            Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
488        }
489        None => Json(JsonRpcResponse::error(
490            id,
491            JsonRpcError::internal_error(format!("Task not found: {}", params.task_id)),
492        )),
493    }
494}
495
496async fn handle_tasks_cancel(
497    controller: &A2aController,
498    params: Option<Value>,
499    id: Option<Value>,
500) -> Json<JsonRpcResponse> {
501    let params: TasksCancelParams = match params {
502        Some(p) => match serde_json::from_value(p) {
503            Ok(p) => p,
504            Err(e) => {
505                return Json(JsonRpcResponse::error(
506                    id,
507                    JsonRpcError::invalid_params(e.to_string()),
508                ));
509            }
510        },
511        None => {
512            return Json(JsonRpcResponse::error(
513                id,
514                JsonRpcError::invalid_params("Missing params"),
515            ));
516        }
517    };
518
519    let active_task = controller.active_tasks.lock().await.get(&params.task_id).cloned();
520
521    if let Some(active_task) = active_task {
522        active_task.token.cancel();
523
524        if tokio::time::timeout(Duration::from_secs(5), active_task.completion.notified())
525            .await
526            .is_err()
527        {
528            active_task.abort_handle.abort();
529            controller.active_tasks.lock().await.remove(&params.task_id);
530            controller
531                .task_store
532                .store(build_canceled_task(&params.task_id, &active_task.context_id))
533                .await;
534        }
535
536        let status = TaskStatusUpdateEvent {
537            task_id: params.task_id,
538            context_id: Some(active_task.context_id),
539            status: TaskStatus { state: TaskState::Canceled, message: None },
540            final_update: true,
541        };
542
543        return Json(JsonRpcResponse::success(
544            id,
545            serde_json::to_value(status).unwrap_or_default(),
546        ));
547    }
548
549    let status = TaskStatusUpdateEvent {
550        task_id: params.task_id,
551        context_id: Some(uuid::Uuid::new_v4().to_string()),
552        status: TaskStatus { state: TaskState::Canceled, message: None },
553        final_update: true,
554    };
555
556    Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
557}