// adk_server/rest/controllers/a2a.rs
use crate::ServerConfig;
use crate::a2a::{
    AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest, JsonRpcResponse,
    MessageSendParams, Task, TaskState, TaskStatus, TasksCancelParams, TasksGetParams, UpdateEvent,
    build_agent_card, jsonrpc,
};
use adk_runner::RunnerConfig;
use axum::{
    extract::State,
    http::StatusCode,
    response::{
        IntoResponse, Json,
        sse::{Event, Sse},
    },
};
use futures::stream::Stream;
use serde_json::Value;
use std::{convert::Infallible, sync::Arc, time::Duration};

20#[derive(Clone)]
22pub struct A2aController {
23 config: ServerConfig,
24 agent_card: AgentCard,
25}
26
27impl A2aController {
28 pub fn new(config: ServerConfig, base_url: &str) -> Self {
29 let root_agent = config.agent_loader.root_agent();
30 let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
31 let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
32
33 Self { config, agent_card }
34 }
35}
36
37pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
39 Json(controller.agent_card.clone())
40}
41
42pub async fn handle_jsonrpc(
44 State(controller): State<A2aController>,
45 Json(request): Json<JsonRpcRequest>,
46) -> impl IntoResponse {
47 if request.jsonrpc != "2.0" {
48 return Json(JsonRpcResponse::error(
49 request.id,
50 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
51 ));
52 }
53
54 match request.method.as_str() {
55 jsonrpc::methods::MESSAGE_SEND => {
56 handle_message_send(&controller, request.params, request.id).await
57 }
58 jsonrpc::methods::TASKS_GET => {
59 handle_tasks_get(&controller, request.params, request.id).await
60 }
61 jsonrpc::methods::TASKS_CANCEL => {
62 handle_tasks_cancel(&controller, request.params, request.id).await
63 }
64 _ => Json(JsonRpcResponse::error(
65 request.id,
66 JsonRpcError::method_not_found(&request.method),
67 )),
68 }
69}
70
71pub async fn handle_jsonrpc_stream(
73 State(controller): State<A2aController>,
74 Json(request): Json<JsonRpcRequest>,
75) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
76{
77 if request.jsonrpc != "2.0" {
78 return Err((
79 StatusCode::BAD_REQUEST,
80 Json(JsonRpcResponse::error(
81 request.id.clone(),
82 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
83 )),
84 ));
85 }
86
87 if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
88 && request.method != jsonrpc::methods::MESSAGE_SEND
89 {
90 return Err((
91 StatusCode::BAD_REQUEST,
92 Json(JsonRpcResponse::error(
93 request.id.clone(),
94 JsonRpcError::method_not_found(&request.method),
95 )),
96 ));
97 }
98
99 let params: MessageSendParams = match request.params {
100 Some(p) => serde_json::from_value(p).map_err(|e| {
101 (
102 StatusCode::BAD_REQUEST,
103 Json(JsonRpcResponse::error(
104 request.id.clone(),
105 JsonRpcError::invalid_params(e.to_string()),
106 )),
107 )
108 })?,
109 None => {
110 return Err((
111 StatusCode::BAD_REQUEST,
112 Json(JsonRpcResponse::error(
113 request.id.clone(),
114 JsonRpcError::invalid_params("Missing params"),
115 )),
116 ));
117 }
118 };
119
120 let request_id = request.id.clone();
121 let stream = create_message_stream(controller, params, request_id);
122
123 Ok(Sse::new(stream).keep_alive(
124 axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
125 ))
126}
127
128fn create_message_stream(
129 controller: A2aController,
130 params: MessageSendParams,
131 request_id: Option<Value>,
132) -> impl Stream<Item = Result<Event, Infallible>> {
133 async_stream::stream! {
134 let context_id = params
135 .message
136 .context_id
137 .clone()
138 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
139 let task_id = params
140 .message
141 .task_id
142 .clone()
143 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
144
145 let root_agent = controller.config.agent_loader.root_agent();
146
147 let executor = Executor::new(ExecutorConfig {
148 app_name: root_agent.name().to_string(),
149 runner_config: Arc::new(RunnerConfig {
150 app_name: root_agent.name().to_string(),
151 agent: root_agent,
152 session_service: controller.config.session_service.clone(),
153 artifact_service: controller.config.artifact_service.clone(),
154 memory_service: None,
155 run_config: None,
156 }),
157 });
158
159 match executor.execute(&context_id, &task_id, ¶ms.message).await {
160 Ok(events) => {
161 for event in events {
162 let event_data = match &event {
163 UpdateEvent::TaskStatusUpdate(status) => {
164 serde_json::to_string(&JsonRpcResponse::success(
165 request_id.clone(),
166 serde_json::to_value(status).unwrap_or_default(),
167 ))
168 }
169 UpdateEvent::TaskArtifactUpdate(artifact) => {
170 serde_json::to_string(&JsonRpcResponse::success(
171 request_id.clone(),
172 serde_json::to_value(artifact).unwrap_or_default(),
173 ))
174 }
175 };
176
177 if let Ok(data) = event_data {
178 yield Ok(Event::default().data(data));
179 }
180 }
181 }
182 Err(e) => {
183 let error_response = JsonRpcResponse::error(
184 request_id.clone(),
185 JsonRpcError::internal_error_sanitized(
186 &e,
187 controller.config.security.expose_error_details,
188 ),
189 );
190 if let Ok(data) = serde_json::to_string(&error_response) {
191 yield Ok(Event::default().data(data));
192 }
193 }
194 }
195
196 yield Ok(Event::default().event("done").data(""));
198 }
199}
200
201async fn handle_message_send(
202 controller: &A2aController,
203 params: Option<Value>,
204 id: Option<Value>,
205) -> Json<JsonRpcResponse> {
206 let params: MessageSendParams = match params {
207 Some(p) => match serde_json::from_value(p) {
208 Ok(p) => p,
209 Err(e) => {
210 return Json(JsonRpcResponse::error(
211 id,
212 JsonRpcError::invalid_params(e.to_string()),
213 ));
214 }
215 },
216 None => {
217 return Json(JsonRpcResponse::error(
218 id,
219 JsonRpcError::invalid_params("Missing params"),
220 ));
221 }
222 };
223
224 let context_id =
225 params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
226 let task_id =
227 params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
228
229 let root_agent = controller.config.agent_loader.root_agent();
230
231 let executor = Executor::new(ExecutorConfig {
232 app_name: root_agent.name().to_string(),
233 runner_config: Arc::new(RunnerConfig {
234 app_name: root_agent.name().to_string(),
235 agent: root_agent,
236 session_service: controller.config.session_service.clone(),
237 artifact_service: controller.config.artifact_service.clone(),
238 memory_service: None,
239 run_config: None,
240 }),
241 });
242
243 match executor.execute(&context_id, &task_id, ¶ms.message).await {
244 Ok(events) => {
245 let mut task = Task {
247 id: task_id,
248 context_id: Some(context_id),
249 status: TaskStatus { state: TaskState::Completed, message: None },
250 artifacts: Some(vec![]),
251 history: None,
252 };
253
254 for event in events {
255 match event {
256 UpdateEvent::TaskStatusUpdate(status) => {
257 task.status = status.status;
258 }
259 UpdateEvent::TaskArtifactUpdate(artifact) => {
260 if let Some(ref mut artifacts) = task.artifacts {
261 artifacts.push(artifact.artifact);
262 }
263 }
264 }
265 }
266
267 Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
268 }
269 Err(e) => Json(JsonRpcResponse::error(
270 id,
271 JsonRpcError::internal_error_sanitized(
272 &e,
273 controller.config.security.expose_error_details,
274 ),
275 )),
276 }
277}
278
279async fn handle_tasks_get(
280 _controller: &A2aController,
281 params: Option<Value>,
282 id: Option<Value>,
283) -> Json<JsonRpcResponse> {
284 let _params: TasksGetParams = match params {
285 Some(p) => match serde_json::from_value(p) {
286 Ok(p) => p,
287 Err(e) => {
288 return Json(JsonRpcResponse::error(
289 id,
290 JsonRpcError::invalid_params(e.to_string()),
291 ));
292 }
293 },
294 None => {
295 return Json(JsonRpcResponse::error(
296 id,
297 JsonRpcError::invalid_params("Missing params"),
298 ));
299 }
300 };
301
302 Json(JsonRpcResponse::error(
304 id,
305 JsonRpcError::internal_error("Task retrieval not yet implemented"),
306 ))
307}
308
309async fn handle_tasks_cancel(
310 controller: &A2aController,
311 params: Option<Value>,
312 id: Option<Value>,
313) -> Json<JsonRpcResponse> {
314 let params: TasksCancelParams = match params {
315 Some(p) => match serde_json::from_value(p) {
316 Ok(p) => p,
317 Err(e) => {
318 return Json(JsonRpcResponse::error(
319 id,
320 JsonRpcError::invalid_params(e.to_string()),
321 ));
322 }
323 },
324 None => {
325 return Json(JsonRpcResponse::error(
326 id,
327 JsonRpcError::invalid_params("Missing params"),
328 ));
329 }
330 };
331
332 let root_agent = controller.config.agent_loader.root_agent();
333
334 let executor = Executor::new(ExecutorConfig {
335 app_name: root_agent.name().to_string(),
336 runner_config: Arc::new(RunnerConfig {
337 app_name: root_agent.name().to_string(),
338 agent: root_agent,
339 session_service: controller.config.session_service.clone(),
340 artifact_service: controller.config.artifact_service.clone(),
341 memory_service: None,
342 run_config: None,
343 }),
344 });
345
346 let context_id = uuid::Uuid::new_v4().to_string();
348
349 match executor.cancel(&context_id, ¶ms.task_id).await {
350 Ok(status) => {
351 Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
352 }
353 Err(e) => Json(JsonRpcResponse::error(
354 id,
355 JsonRpcError::internal_error_sanitized(
356 &e,
357 controller.config.security.expose_error_details,
358 ),
359 )),
360 }
361}