adk_server/rest/controllers/
a2a.rs

1use crate::ServerConfig;
2use crate::a2a::{
3    AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest, JsonRpcResponse,
4    MessageSendParams, Task, TaskState, TaskStatus, TasksCancelParams, TasksGetParams, UpdateEvent,
5    build_agent_card, jsonrpc,
6};
7use adk_runner::RunnerConfig;
8use axum::{
9    extract::State,
10    http::StatusCode,
11    response::{
12        IntoResponse, Json,
13        sse::{Event, Sse},
14    },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
19use tokio::sync::RwLock;
20
21/// In-memory task storage
22#[derive(Default)]
23pub struct TaskStore {
24    tasks: RwLock<HashMap<String, Task>>,
25}
26
27impl TaskStore {
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    pub async fn store(&self, task: Task) {
33        self.tasks.write().await.insert(task.id.clone(), task);
34    }
35
36    pub async fn get(&self, task_id: &str) -> Option<Task> {
37        self.tasks.read().await.get(task_id).cloned()
38    }
39
40    pub async fn remove(&self, task_id: &str) -> Option<Task> {
41        self.tasks.write().await.remove(task_id)
42    }
43}
44
45/// Controller for A2A protocol endpoints
46#[derive(Clone)]
47pub struct A2aController {
48    config: ServerConfig,
49    agent_card: AgentCard,
50    task_store: Arc<TaskStore>,
51}
52
53impl A2aController {
54    pub fn new(config: ServerConfig, base_url: &str) -> Self {
55        let root_agent = config.agent_loader.root_agent();
56        let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
57        let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
58
59        Self { config, agent_card, task_store: Arc::new(TaskStore::new()) }
60    }
61}
62
63/// GET /.well-known/agent.json - Serve the agent card
64pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
65    Json(controller.agent_card.clone())
66}
67
68/// POST /a2a - JSON-RPC endpoint for A2A protocol
69pub async fn handle_jsonrpc(
70    State(controller): State<A2aController>,
71    Json(request): Json<JsonRpcRequest>,
72) -> impl IntoResponse {
73    if request.jsonrpc != "2.0" {
74        return Json(JsonRpcResponse::error(
75            request.id,
76            JsonRpcError::invalid_request("Invalid JSON-RPC version"),
77        ));
78    }
79
80    match request.method.as_str() {
81        jsonrpc::methods::MESSAGE_SEND => {
82            handle_message_send(&controller, request.params, request.id).await
83        }
84        jsonrpc::methods::TASKS_GET => {
85            handle_tasks_get(&controller, request.params, request.id).await
86        }
87        jsonrpc::methods::TASKS_CANCEL => {
88            handle_tasks_cancel(&controller, request.params, request.id).await
89        }
90        _ => Json(JsonRpcResponse::error(
91            request.id,
92            JsonRpcError::method_not_found(&request.method),
93        )),
94    }
95}
96
97/// POST /a2a/stream - SSE streaming endpoint for A2A protocol
98pub async fn handle_jsonrpc_stream(
99    State(controller): State<A2aController>,
100    Json(request): Json<JsonRpcRequest>,
101) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
102{
103    if request.jsonrpc != "2.0" {
104        return Err((
105            StatusCode::BAD_REQUEST,
106            Json(JsonRpcResponse::error(
107                request.id.clone(),
108                JsonRpcError::invalid_request("Invalid JSON-RPC version"),
109            )),
110        ));
111    }
112
113    if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
114        && request.method != jsonrpc::methods::MESSAGE_SEND
115    {
116        return Err((
117            StatusCode::BAD_REQUEST,
118            Json(JsonRpcResponse::error(
119                request.id.clone(),
120                JsonRpcError::method_not_found(&request.method),
121            )),
122        ));
123    }
124
125    let params: MessageSendParams = match request.params {
126        Some(p) => serde_json::from_value(p).map_err(|e| {
127            (
128                StatusCode::BAD_REQUEST,
129                Json(JsonRpcResponse::error(
130                    request.id.clone(),
131                    JsonRpcError::invalid_params(e.to_string()),
132                )),
133            )
134        })?,
135        None => {
136            return Err((
137                StatusCode::BAD_REQUEST,
138                Json(JsonRpcResponse::error(
139                    request.id.clone(),
140                    JsonRpcError::invalid_params("Missing params"),
141                )),
142            ));
143        }
144    };
145
146    let request_id = request.id.clone();
147    let stream = create_message_stream(controller, params, request_id);
148
149    Ok(Sse::new(stream).keep_alive(
150        axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
151    ))
152}
153
154fn create_message_stream(
155    controller: A2aController,
156    params: MessageSendParams,
157    request_id: Option<Value>,
158) -> impl Stream<Item = Result<Event, Infallible>> {
159    async_stream::stream! {
160        let context_id = params
161            .message
162            .context_id
163            .clone()
164            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
165        let task_id = params
166            .message
167            .task_id
168            .clone()
169            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
170
171        let root_agent = controller.config.agent_loader.root_agent();
172
173        let executor = Executor::new(ExecutorConfig {
174            app_name: root_agent.name().to_string(),
175            runner_config: Arc::new(RunnerConfig {
176                app_name: root_agent.name().to_string(),
177                agent: root_agent,
178                session_service: controller.config.session_service.clone(),
179                artifact_service: controller.config.artifact_service.clone(),
180                memory_service: None,
181                run_config: None,
182            }),
183        });
184
185        match executor.execute(&context_id, &task_id, &params.message).await {
186            Ok(events) => {
187                for event in events {
188                    let event_data = match &event {
189                        UpdateEvent::TaskStatusUpdate(status) => {
190                            serde_json::to_string(&JsonRpcResponse::success(
191                                request_id.clone(),
192                                serde_json::to_value(status).unwrap_or_default(),
193                            ))
194                        }
195                        UpdateEvent::TaskArtifactUpdate(artifact) => {
196                            serde_json::to_string(&JsonRpcResponse::success(
197                                request_id.clone(),
198                                serde_json::to_value(artifact).unwrap_or_default(),
199                            ))
200                        }
201                    };
202
203                    if let Ok(data) = event_data {
204                        yield Ok(Event::default().data(data));
205                    }
206                }
207            }
208            Err(e) => {
209                let error_response = JsonRpcResponse::error(
210                    request_id.clone(),
211                    JsonRpcError::internal_error_sanitized(
212                        &e,
213                        controller.config.security.expose_error_details,
214                    ),
215                );
216                if let Ok(data) = serde_json::to_string(&error_response) {
217                    yield Ok(Event::default().data(data));
218                }
219            }
220        }
221
222        // Send done event
223        yield Ok(Event::default().event("done").data(""));
224    }
225}
226
227async fn handle_message_send(
228    controller: &A2aController,
229    params: Option<Value>,
230    id: Option<Value>,
231) -> Json<JsonRpcResponse> {
232    let params: MessageSendParams = match params {
233        Some(p) => match serde_json::from_value(p) {
234            Ok(p) => p,
235            Err(e) => {
236                return Json(JsonRpcResponse::error(
237                    id,
238                    JsonRpcError::invalid_params(e.to_string()),
239                ));
240            }
241        },
242        None => {
243            return Json(JsonRpcResponse::error(
244                id,
245                JsonRpcError::invalid_params("Missing params"),
246            ));
247        }
248    };
249
250    let context_id =
251        params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
252    let task_id =
253        params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
254
255    let root_agent = controller.config.agent_loader.root_agent();
256
257    let executor = Executor::new(ExecutorConfig {
258        app_name: root_agent.name().to_string(),
259        runner_config: Arc::new(RunnerConfig {
260            app_name: root_agent.name().to_string(),
261            agent: root_agent,
262            session_service: controller.config.session_service.clone(),
263            artifact_service: controller.config.artifact_service.clone(),
264            memory_service: None,
265            run_config: None,
266        }),
267    });
268
269    match executor.execute(&context_id, &task_id, &params.message).await {
270        Ok(events) => {
271            // Build task from events
272            let mut task = Task {
273                id: task_id,
274                context_id: Some(context_id),
275                status: TaskStatus { state: TaskState::Completed, message: None },
276                artifacts: Some(vec![]),
277                history: None,
278            };
279
280            for event in events {
281                match event {
282                    UpdateEvent::TaskStatusUpdate(status) => {
283                        task.status = status.status;
284                    }
285                    UpdateEvent::TaskArtifactUpdate(artifact) => {
286                        if let Some(ref mut artifacts) = task.artifacts {
287                            artifacts.push(artifact.artifact);
288                        }
289                    }
290                }
291            }
292
293            // Store task for later retrieval
294            controller.task_store.store(task.clone()).await;
295
296            Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
297        }
298        Err(e) => Json(JsonRpcResponse::error(
299            id,
300            JsonRpcError::internal_error_sanitized(
301                &e,
302                controller.config.security.expose_error_details,
303            ),
304        )),
305    }
306}
307
308async fn handle_tasks_get(
309    controller: &A2aController,
310    params: Option<Value>,
311    id: Option<Value>,
312) -> Json<JsonRpcResponse> {
313    let params: TasksGetParams = match params {
314        Some(p) => match serde_json::from_value(p) {
315            Ok(p) => p,
316            Err(e) => {
317                return Json(JsonRpcResponse::error(
318                    id,
319                    JsonRpcError::invalid_params(e.to_string()),
320                ));
321            }
322        },
323        None => {
324            return Json(JsonRpcResponse::error(
325                id,
326                JsonRpcError::invalid_params("Missing params"),
327            ));
328        }
329    };
330
331    match controller.task_store.get(&params.task_id).await {
332        Some(task) => {
333            Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
334        }
335        None => Json(JsonRpcResponse::error(
336            id,
337            JsonRpcError::internal_error(format!("Task not found: {}", params.task_id)),
338        )),
339    }
340}
341
342async fn handle_tasks_cancel(
343    controller: &A2aController,
344    params: Option<Value>,
345    id: Option<Value>,
346) -> Json<JsonRpcResponse> {
347    let params: TasksCancelParams = match params {
348        Some(p) => match serde_json::from_value(p) {
349            Ok(p) => p,
350            Err(e) => {
351                return Json(JsonRpcResponse::error(
352                    id,
353                    JsonRpcError::invalid_params(e.to_string()),
354                ));
355            }
356        },
357        None => {
358            return Json(JsonRpcResponse::error(
359                id,
360                JsonRpcError::invalid_params("Missing params"),
361            ));
362        }
363    };
364
365    let root_agent = controller.config.agent_loader.root_agent();
366
367    let executor = Executor::new(ExecutorConfig {
368        app_name: root_agent.name().to_string(),
369        runner_config: Arc::new(RunnerConfig {
370            app_name: root_agent.name().to_string(),
371            agent: root_agent,
372            session_service: controller.config.session_service.clone(),
373            artifact_service: controller.config.artifact_service.clone(),
374            memory_service: None,
375            run_config: None,
376        }),
377    });
378
379    // Use a default context_id for cancel
380    let context_id = uuid::Uuid::new_v4().to_string();
381
382    match executor.cancel(&context_id, &params.task_id).await {
383        Ok(status) => {
384            Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
385        }
386        Err(e) => Json(JsonRpcResponse::error(
387            id,
388            JsonRpcError::internal_error_sanitized(
389                &e,
390                controller.config.security.expose_error_details,
391            ),
392        )),
393    }
394}