adk_server/rest/controllers/a2a.rs

1use crate::ServerConfig;
2use crate::a2a::{
3    AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest, JsonRpcResponse,
4    MessageSendParams, Task, TaskState, TaskStatus, TasksCancelParams, TasksGetParams, UpdateEvent,
5    build_agent_card, jsonrpc,
6};
7use adk_runner::RunnerConfig;
8use axum::{
9    extract::State,
10    http::StatusCode,
11    response::{
12        IntoResponse, Json,
13        sse::{Event, Sse},
14    },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{convert::Infallible, sync::Arc, time::Duration};
19
20/// Controller for A2A protocol endpoints
21#[derive(Clone)]
22pub struct A2aController {
23    config: ServerConfig,
24    agent_card: AgentCard,
25}
26
27impl A2aController {
28    pub fn new(config: ServerConfig, base_url: &str) -> Self {
29        let root_agent = config.agent_loader.root_agent();
30        let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
31        let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
32
33        Self { config, agent_card }
34    }
35}
36
37/// GET /.well-known/agent.json - Serve the agent card
38pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
39    Json(controller.agent_card.clone())
40}
41
42/// POST /a2a - JSON-RPC endpoint for A2A protocol
43pub async fn handle_jsonrpc(
44    State(controller): State<A2aController>,
45    Json(request): Json<JsonRpcRequest>,
46) -> impl IntoResponse {
47    if request.jsonrpc != "2.0" {
48        return Json(JsonRpcResponse::error(
49            request.id,
50            JsonRpcError::invalid_request("Invalid JSON-RPC version"),
51        ));
52    }
53
54    match request.method.as_str() {
55        jsonrpc::methods::MESSAGE_SEND => {
56            handle_message_send(&controller, request.params, request.id).await
57        }
58        jsonrpc::methods::TASKS_GET => {
59            handle_tasks_get(&controller, request.params, request.id).await
60        }
61        jsonrpc::methods::TASKS_CANCEL => {
62            handle_tasks_cancel(&controller, request.params, request.id).await
63        }
64        _ => Json(JsonRpcResponse::error(
65            request.id,
66            JsonRpcError::method_not_found(&request.method),
67        )),
68    }
69}
70
71/// POST /a2a/stream - SSE streaming endpoint for A2A protocol
72pub async fn handle_jsonrpc_stream(
73    State(controller): State<A2aController>,
74    Json(request): Json<JsonRpcRequest>,
75) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
76{
77    if request.jsonrpc != "2.0" {
78        return Err((
79            StatusCode::BAD_REQUEST,
80            Json(JsonRpcResponse::error(
81                request.id.clone(),
82                JsonRpcError::invalid_request("Invalid JSON-RPC version"),
83            )),
84        ));
85    }
86
87    if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
88        && request.method != jsonrpc::methods::MESSAGE_SEND
89    {
90        return Err((
91            StatusCode::BAD_REQUEST,
92            Json(JsonRpcResponse::error(
93                request.id.clone(),
94                JsonRpcError::method_not_found(&request.method),
95            )),
96        ));
97    }
98
99    let params: MessageSendParams = match request.params {
100        Some(p) => serde_json::from_value(p).map_err(|e| {
101            (
102                StatusCode::BAD_REQUEST,
103                Json(JsonRpcResponse::error(
104                    request.id.clone(),
105                    JsonRpcError::invalid_params(e.to_string()),
106                )),
107            )
108        })?,
109        None => {
110            return Err((
111                StatusCode::BAD_REQUEST,
112                Json(JsonRpcResponse::error(
113                    request.id.clone(),
114                    JsonRpcError::invalid_params("Missing params"),
115                )),
116            ));
117        }
118    };
119
120    let request_id = request.id.clone();
121    let stream = create_message_stream(controller, params, request_id);
122
123    Ok(Sse::new(stream).keep_alive(
124        axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
125    ))
126}
127
128fn create_message_stream(
129    controller: A2aController,
130    params: MessageSendParams,
131    request_id: Option<Value>,
132) -> impl Stream<Item = Result<Event, Infallible>> {
133    async_stream::stream! {
134        let context_id = params
135            .message
136            .context_id
137            .clone()
138            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
139        let task_id = params
140            .message
141            .task_id
142            .clone()
143            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
144
145        let root_agent = controller.config.agent_loader.root_agent();
146
147        let executor = Executor::new(ExecutorConfig {
148            app_name: root_agent.name().to_string(),
149            runner_config: Arc::new(RunnerConfig {
150                app_name: root_agent.name().to_string(),
151                agent: root_agent,
152                session_service: controller.config.session_service.clone(),
153                artifact_service: controller.config.artifact_service.clone(),
154                memory_service: None,
155                run_config: None,
156            }),
157        });
158
159        match executor.execute(&context_id, &task_id, &params.message).await {
160            Ok(events) => {
161                for event in events {
162                    let event_data = match &event {
163                        UpdateEvent::TaskStatusUpdate(status) => {
164                            serde_json::to_string(&JsonRpcResponse::success(
165                                request_id.clone(),
166                                serde_json::to_value(status).unwrap_or_default(),
167                            ))
168                        }
169                        UpdateEvent::TaskArtifactUpdate(artifact) => {
170                            serde_json::to_string(&JsonRpcResponse::success(
171                                request_id.clone(),
172                                serde_json::to_value(artifact).unwrap_or_default(),
173                            ))
174                        }
175                    };
176
177                    if let Ok(data) = event_data {
178                        yield Ok(Event::default().data(data));
179                    }
180                }
181            }
182            Err(e) => {
183                let error_response = JsonRpcResponse::error(
184                    request_id.clone(),
185                    JsonRpcError::internal_error_sanitized(
186                        &e,
187                        controller.config.security.expose_error_details,
188                    ),
189                );
190                if let Ok(data) = serde_json::to_string(&error_response) {
191                    yield Ok(Event::default().data(data));
192                }
193            }
194        }
195
196        // Send done event
197        yield Ok(Event::default().event("done").data(""));
198    }
199}
200
201async fn handle_message_send(
202    controller: &A2aController,
203    params: Option<Value>,
204    id: Option<Value>,
205) -> Json<JsonRpcResponse> {
206    let params: MessageSendParams = match params {
207        Some(p) => match serde_json::from_value(p) {
208            Ok(p) => p,
209            Err(e) => {
210                return Json(JsonRpcResponse::error(
211                    id,
212                    JsonRpcError::invalid_params(e.to_string()),
213                ));
214            }
215        },
216        None => {
217            return Json(JsonRpcResponse::error(
218                id,
219                JsonRpcError::invalid_params("Missing params"),
220            ));
221        }
222    };
223
224    let context_id =
225        params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
226    let task_id =
227        params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
228
229    let root_agent = controller.config.agent_loader.root_agent();
230
231    let executor = Executor::new(ExecutorConfig {
232        app_name: root_agent.name().to_string(),
233        runner_config: Arc::new(RunnerConfig {
234            app_name: root_agent.name().to_string(),
235            agent: root_agent,
236            session_service: controller.config.session_service.clone(),
237            artifact_service: controller.config.artifact_service.clone(),
238            memory_service: None,
239            run_config: None,
240        }),
241    });
242
243    match executor.execute(&context_id, &task_id, &params.message).await {
244        Ok(events) => {
245            // Build task from events
246            let mut task = Task {
247                id: task_id,
248                context_id: Some(context_id),
249                status: TaskStatus { state: TaskState::Completed, message: None },
250                artifacts: Some(vec![]),
251                history: None,
252            };
253
254            for event in events {
255                match event {
256                    UpdateEvent::TaskStatusUpdate(status) => {
257                        task.status = status.status;
258                    }
259                    UpdateEvent::TaskArtifactUpdate(artifact) => {
260                        if let Some(ref mut artifacts) = task.artifacts {
261                            artifacts.push(artifact.artifact);
262                        }
263                    }
264                }
265            }
266
267            Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
268        }
269        Err(e) => Json(JsonRpcResponse::error(
270            id,
271            JsonRpcError::internal_error_sanitized(
272                &e,
273                controller.config.security.expose_error_details,
274            ),
275        )),
276    }
277}
278
279async fn handle_tasks_get(
280    _controller: &A2aController,
281    params: Option<Value>,
282    id: Option<Value>,
283) -> Json<JsonRpcResponse> {
284    let _params: TasksGetParams = match params {
285        Some(p) => match serde_json::from_value(p) {
286            Ok(p) => p,
287            Err(e) => {
288                return Json(JsonRpcResponse::error(
289                    id,
290                    JsonRpcError::invalid_params(e.to_string()),
291                ));
292            }
293        },
294        None => {
295            return Json(JsonRpcResponse::error(
296                id,
297                JsonRpcError::invalid_params("Missing params"),
298            ));
299        }
300    };
301
302    // TODO: Implement task storage and retrieval
303    Json(JsonRpcResponse::error(
304        id,
305        JsonRpcError::internal_error("Task retrieval not yet implemented"),
306    ))
307}
308
309async fn handle_tasks_cancel(
310    controller: &A2aController,
311    params: Option<Value>,
312    id: Option<Value>,
313) -> Json<JsonRpcResponse> {
314    let params: TasksCancelParams = match params {
315        Some(p) => match serde_json::from_value(p) {
316            Ok(p) => p,
317            Err(e) => {
318                return Json(JsonRpcResponse::error(
319                    id,
320                    JsonRpcError::invalid_params(e.to_string()),
321                ));
322            }
323        },
324        None => {
325            return Json(JsonRpcResponse::error(
326                id,
327                JsonRpcError::invalid_params("Missing params"),
328            ));
329        }
330    };
331
332    let root_agent = controller.config.agent_loader.root_agent();
333
334    let executor = Executor::new(ExecutorConfig {
335        app_name: root_agent.name().to_string(),
336        runner_config: Arc::new(RunnerConfig {
337            app_name: root_agent.name().to_string(),
338            agent: root_agent,
339            session_service: controller.config.session_service.clone(),
340            artifact_service: controller.config.artifact_service.clone(),
341            memory_service: None,
342            run_config: None,
343        }),
344    });
345
346    // Use a default context_id for cancel
347    let context_id = uuid::Uuid::new_v4().to_string();
348
349    match executor.cancel(&context_id, &params.task_id).await {
350        Ok(status) => {
351            Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
352        }
353        Err(e) => Json(JsonRpcResponse::error(
354            id,
355            JsonRpcError::internal_error_sanitized(
356                &e,
357                controller.config.security.expose_error_details,
358            ),
359        )),
360    }
361}