// adk_server/rest/controllers/a2a.rs

1use crate::a2a::{
2    build_agent_card, jsonrpc, AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest,
3    JsonRpcResponse, MessageSendParams, Task, TaskState, TaskStatus, TasksCancelParams,
4    TasksGetParams, UpdateEvent,
5};
6use crate::ServerConfig;
7use adk_runner::RunnerConfig;
8use axum::{
9    extract::State,
10    http::StatusCode,
11    response::{
12        sse::{Event, Sse},
13        IntoResponse, Json,
14    },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{convert::Infallible, sync::Arc, time::Duration};
19
20/// Controller for A2A protocol endpoints
21#[derive(Clone)]
22pub struct A2aController {
23    config: ServerConfig,
24    agent_card: AgentCard,
25}
26
27impl A2aController {
28    pub fn new(config: ServerConfig, base_url: &str) -> Self {
29        let root_agent = config.agent_loader.root_agent();
30        let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
31        let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
32
33        Self { config, agent_card }
34    }
35}
36
37/// GET /.well-known/agent.json - Serve the agent card
38pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
39    Json(controller.agent_card.clone())
40}
41
42/// POST /a2a - JSON-RPC endpoint for A2A protocol
43pub async fn handle_jsonrpc(
44    State(controller): State<A2aController>,
45    Json(request): Json<JsonRpcRequest>,
46) -> impl IntoResponse {
47    if request.jsonrpc != "2.0" {
48        return Json(JsonRpcResponse::error(
49            request.id,
50            JsonRpcError::invalid_request("Invalid JSON-RPC version"),
51        ));
52    }
53
54    match request.method.as_str() {
55        jsonrpc::methods::MESSAGE_SEND => {
56            handle_message_send(&controller, request.params, request.id).await
57        }
58        jsonrpc::methods::TASKS_GET => {
59            handle_tasks_get(&controller, request.params, request.id).await
60        }
61        jsonrpc::methods::TASKS_CANCEL => {
62            handle_tasks_cancel(&controller, request.params, request.id).await
63        }
64        _ => Json(JsonRpcResponse::error(
65            request.id,
66            JsonRpcError::method_not_found(&request.method),
67        )),
68    }
69}
70
71/// POST /a2a/stream - SSE streaming endpoint for A2A protocol
72pub async fn handle_jsonrpc_stream(
73    State(controller): State<A2aController>,
74    Json(request): Json<JsonRpcRequest>,
75) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
76{
77    if request.jsonrpc != "2.0" {
78        return Err((
79            StatusCode::BAD_REQUEST,
80            Json(JsonRpcResponse::error(
81                request.id.clone(),
82                JsonRpcError::invalid_request("Invalid JSON-RPC version"),
83            )),
84        ));
85    }
86
87    if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
88        && request.method != jsonrpc::methods::MESSAGE_SEND
89    {
90        return Err((
91            StatusCode::BAD_REQUEST,
92            Json(JsonRpcResponse::error(
93                request.id.clone(),
94                JsonRpcError::method_not_found(&request.method),
95            )),
96        ));
97    }
98
99    let params: MessageSendParams = match request.params {
100        Some(p) => serde_json::from_value(p).map_err(|e| {
101            (
102                StatusCode::BAD_REQUEST,
103                Json(JsonRpcResponse::error(
104                    request.id.clone(),
105                    JsonRpcError::invalid_params(e.to_string()),
106                )),
107            )
108        })?,
109        None => {
110            return Err((
111                StatusCode::BAD_REQUEST,
112                Json(JsonRpcResponse::error(
113                    request.id.clone(),
114                    JsonRpcError::invalid_params("Missing params"),
115                )),
116            ))
117        }
118    };
119
120    let request_id = request.id.clone();
121    let stream = create_message_stream(controller, params, request_id);
122
123    Ok(Sse::new(stream).keep_alive(
124        axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
125    ))
126}
127
128fn create_message_stream(
129    controller: A2aController,
130    params: MessageSendParams,
131    request_id: Option<Value>,
132) -> impl Stream<Item = Result<Event, Infallible>> {
133    async_stream::stream! {
134        let context_id = params
135            .message
136            .context_id
137            .clone()
138            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
139        let task_id = params
140            .message
141            .task_id
142            .clone()
143            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
144
145        let root_agent = controller.config.agent_loader.root_agent();
146
147        let executor = Executor::new(ExecutorConfig {
148            app_name: root_agent.name().to_string(),
149            runner_config: Arc::new(RunnerConfig {
150                app_name: root_agent.name().to_string(),
151                agent: root_agent,
152                session_service: controller.config.session_service.clone(),
153                artifact_service: controller.config.artifact_service.clone(),
154                memory_service: None,
155            }),
156        });
157
158        match executor.execute(&context_id, &task_id, &params.message).await {
159            Ok(events) => {
160                for event in events {
161                    let event_data = match &event {
162                        UpdateEvent::TaskStatusUpdate(status) => {
163                            serde_json::to_string(&JsonRpcResponse::success(
164                                request_id.clone(),
165                                serde_json::to_value(status).unwrap_or_default(),
166                            ))
167                        }
168                        UpdateEvent::TaskArtifactUpdate(artifact) => {
169                            serde_json::to_string(&JsonRpcResponse::success(
170                                request_id.clone(),
171                                serde_json::to_value(artifact).unwrap_or_default(),
172                            ))
173                        }
174                    };
175
176                    if let Ok(data) = event_data {
177                        yield Ok(Event::default().data(data));
178                    }
179                }
180            }
181            Err(e) => {
182                let error_response = JsonRpcResponse::error(
183                    request_id.clone(),
184                    JsonRpcError::internal_error_sanitized(
185                        &e,
186                        controller.config.security.expose_error_details,
187                    ),
188                );
189                if let Ok(data) = serde_json::to_string(&error_response) {
190                    yield Ok(Event::default().data(data));
191                }
192            }
193        }
194
195        // Send done event
196        yield Ok(Event::default().event("done").data(""));
197    }
198}
199
200async fn handle_message_send(
201    controller: &A2aController,
202    params: Option<Value>,
203    id: Option<Value>,
204) -> Json<JsonRpcResponse> {
205    let params: MessageSendParams = match params {
206        Some(p) => match serde_json::from_value(p) {
207            Ok(p) => p,
208            Err(e) => {
209                return Json(JsonRpcResponse::error(
210                    id,
211                    JsonRpcError::invalid_params(e.to_string()),
212                ))
213            }
214        },
215        None => {
216            return Json(JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")))
217        }
218    };
219
220    let context_id =
221        params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
222    let task_id =
223        params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
224
225    let root_agent = controller.config.agent_loader.root_agent();
226
227    let executor = Executor::new(ExecutorConfig {
228        app_name: root_agent.name().to_string(),
229        runner_config: Arc::new(RunnerConfig {
230            app_name: root_agent.name().to_string(),
231            agent: root_agent,
232            session_service: controller.config.session_service.clone(),
233            artifact_service: controller.config.artifact_service.clone(),
234            memory_service: None,
235        }),
236    });
237
238    match executor.execute(&context_id, &task_id, &params.message).await {
239        Ok(events) => {
240            // Build task from events
241            let mut task = Task {
242                id: task_id,
243                context_id: Some(context_id),
244                status: TaskStatus { state: TaskState::Completed, message: None },
245                artifacts: Some(vec![]),
246                history: None,
247            };
248
249            for event in events {
250                match event {
251                    UpdateEvent::TaskStatusUpdate(status) => {
252                        task.status = status.status;
253                    }
254                    UpdateEvent::TaskArtifactUpdate(artifact) => {
255                        if let Some(ref mut artifacts) = task.artifacts {
256                            artifacts.push(artifact.artifact);
257                        }
258                    }
259                }
260            }
261
262            Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
263        }
264        Err(e) => Json(JsonRpcResponse::error(
265            id,
266            JsonRpcError::internal_error_sanitized(
267                &e,
268                controller.config.security.expose_error_details,
269            ),
270        )),
271    }
272}
273
274async fn handle_tasks_get(
275    _controller: &A2aController,
276    params: Option<Value>,
277    id: Option<Value>,
278) -> Json<JsonRpcResponse> {
279    let _params: TasksGetParams = match params {
280        Some(p) => match serde_json::from_value(p) {
281            Ok(p) => p,
282            Err(e) => {
283                return Json(JsonRpcResponse::error(
284                    id,
285                    JsonRpcError::invalid_params(e.to_string()),
286                ))
287            }
288        },
289        None => {
290            return Json(JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")))
291        }
292    };
293
294    // TODO: Implement task storage and retrieval
295    Json(JsonRpcResponse::error(
296        id,
297        JsonRpcError::internal_error("Task retrieval not yet implemented"),
298    ))
299}
300
301async fn handle_tasks_cancel(
302    controller: &A2aController,
303    params: Option<Value>,
304    id: Option<Value>,
305) -> Json<JsonRpcResponse> {
306    let params: TasksCancelParams = match params {
307        Some(p) => match serde_json::from_value(p) {
308            Ok(p) => p,
309            Err(e) => {
310                return Json(JsonRpcResponse::error(
311                    id,
312                    JsonRpcError::invalid_params(e.to_string()),
313                ))
314            }
315        },
316        None => {
317            return Json(JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")))
318        }
319    };
320
321    let root_agent = controller.config.agent_loader.root_agent();
322
323    let executor = Executor::new(ExecutorConfig {
324        app_name: root_agent.name().to_string(),
325        runner_config: Arc::new(RunnerConfig {
326            app_name: root_agent.name().to_string(),
327            agent: root_agent,
328            session_service: controller.config.session_service.clone(),
329            artifact_service: controller.config.artifact_service.clone(),
330            memory_service: None,
331        }),
332    });
333
334    // Use a default context_id for cancel
335    let context_id = uuid::Uuid::new_v4().to_string();
336
337    match executor.cancel(&context_id, &params.task_id).await {
338        Ok(status) => {
339            Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
340        }
341        Err(e) => Json(JsonRpcResponse::error(
342            id,
343            JsonRpcError::internal_error_sanitized(
344                &e,
345                controller.config.security.expose_error_details,
346            ),
347        )),
348    }
349}