adk_server/rest/controllers/
a2a.rs1use crate::a2a::{
2 build_agent_card, jsonrpc, AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest,
3 JsonRpcResponse, MessageSendParams, Task, TaskState, TaskStatus, TasksCancelParams,
4 TasksGetParams, UpdateEvent,
5};
6use crate::ServerConfig;
7use adk_runner::RunnerConfig;
8use axum::{
9 extract::State,
10 http::StatusCode,
11 response::{
12 sse::{Event, Sse},
13 IntoResponse, Json,
14 },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{convert::Infallible, sync::Arc, time::Duration};
19
20#[derive(Clone)]
22pub struct A2aController {
23 config: ServerConfig,
24 agent_card: AgentCard,
25}
26
27impl A2aController {
28 pub fn new(config: ServerConfig, base_url: &str) -> Self {
29 let root_agent = config.agent_loader.root_agent();
30 let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
31 let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
32
33 Self { config, agent_card }
34 }
35}
36
37pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
39 Json(controller.agent_card.clone())
40}
41
42pub async fn handle_jsonrpc(
44 State(controller): State<A2aController>,
45 Json(request): Json<JsonRpcRequest>,
46) -> impl IntoResponse {
47 if request.jsonrpc != "2.0" {
48 return Json(JsonRpcResponse::error(
49 request.id,
50 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
51 ));
52 }
53
54 match request.method.as_str() {
55 jsonrpc::methods::MESSAGE_SEND => {
56 handle_message_send(&controller, request.params, request.id).await
57 }
58 jsonrpc::methods::TASKS_GET => {
59 handle_tasks_get(&controller, request.params, request.id).await
60 }
61 jsonrpc::methods::TASKS_CANCEL => {
62 handle_tasks_cancel(&controller, request.params, request.id).await
63 }
64 _ => Json(JsonRpcResponse::error(
65 request.id,
66 JsonRpcError::method_not_found(&request.method),
67 )),
68 }
69}
70
71pub async fn handle_jsonrpc_stream(
73 State(controller): State<A2aController>,
74 Json(request): Json<JsonRpcRequest>,
75) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
76{
77 if request.jsonrpc != "2.0" {
78 return Err((
79 StatusCode::BAD_REQUEST,
80 Json(JsonRpcResponse::error(
81 request.id.clone(),
82 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
83 )),
84 ));
85 }
86
87 if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
88 && request.method != jsonrpc::methods::MESSAGE_SEND
89 {
90 return Err((
91 StatusCode::BAD_REQUEST,
92 Json(JsonRpcResponse::error(
93 request.id.clone(),
94 JsonRpcError::method_not_found(&request.method),
95 )),
96 ));
97 }
98
99 let params: MessageSendParams = match request.params {
100 Some(p) => serde_json::from_value(p).map_err(|e| {
101 (
102 StatusCode::BAD_REQUEST,
103 Json(JsonRpcResponse::error(
104 request.id.clone(),
105 JsonRpcError::invalid_params(e.to_string()),
106 )),
107 )
108 })?,
109 None => {
110 return Err((
111 StatusCode::BAD_REQUEST,
112 Json(JsonRpcResponse::error(
113 request.id.clone(),
114 JsonRpcError::invalid_params("Missing params"),
115 )),
116 ))
117 }
118 };
119
120 let request_id = request.id.clone();
121 let stream = create_message_stream(controller, params, request_id);
122
123 Ok(Sse::new(stream).keep_alive(
124 axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
125 ))
126}
127
128fn create_message_stream(
129 controller: A2aController,
130 params: MessageSendParams,
131 request_id: Option<Value>,
132) -> impl Stream<Item = Result<Event, Infallible>> {
133 async_stream::stream! {
134 let context_id = params
135 .message
136 .context_id
137 .clone()
138 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
139 let task_id = params
140 .message
141 .task_id
142 .clone()
143 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
144
145 let root_agent = controller.config.agent_loader.root_agent();
146
147 let executor = Executor::new(ExecutorConfig {
148 app_name: root_agent.name().to_string(),
149 runner_config: Arc::new(RunnerConfig {
150 app_name: root_agent.name().to_string(),
151 agent: root_agent,
152 session_service: controller.config.session_service.clone(),
153 artifact_service: controller.config.artifact_service.clone(),
154 memory_service: None,
155 }),
156 });
157
158 match executor.execute(&context_id, &task_id, ¶ms.message).await {
159 Ok(events) => {
160 for event in events {
161 let event_data = match &event {
162 UpdateEvent::TaskStatusUpdate(status) => {
163 serde_json::to_string(&JsonRpcResponse::success(
164 request_id.clone(),
165 serde_json::to_value(status).unwrap_or_default(),
166 ))
167 }
168 UpdateEvent::TaskArtifactUpdate(artifact) => {
169 serde_json::to_string(&JsonRpcResponse::success(
170 request_id.clone(),
171 serde_json::to_value(artifact).unwrap_or_default(),
172 ))
173 }
174 };
175
176 if let Ok(data) = event_data {
177 yield Ok(Event::default().data(data));
178 }
179 }
180 }
181 Err(e) => {
182 let error_response = JsonRpcResponse::error(
183 request_id.clone(),
184 JsonRpcError::internal_error_sanitized(
185 &e,
186 controller.config.security.expose_error_details,
187 ),
188 );
189 if let Ok(data) = serde_json::to_string(&error_response) {
190 yield Ok(Event::default().data(data));
191 }
192 }
193 }
194
195 yield Ok(Event::default().event("done").data(""));
197 }
198}
199
200async fn handle_message_send(
201 controller: &A2aController,
202 params: Option<Value>,
203 id: Option<Value>,
204) -> Json<JsonRpcResponse> {
205 let params: MessageSendParams = match params {
206 Some(p) => match serde_json::from_value(p) {
207 Ok(p) => p,
208 Err(e) => {
209 return Json(JsonRpcResponse::error(
210 id,
211 JsonRpcError::invalid_params(e.to_string()),
212 ))
213 }
214 },
215 None => {
216 return Json(JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")))
217 }
218 };
219
220 let context_id =
221 params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
222 let task_id =
223 params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
224
225 let root_agent = controller.config.agent_loader.root_agent();
226
227 let executor = Executor::new(ExecutorConfig {
228 app_name: root_agent.name().to_string(),
229 runner_config: Arc::new(RunnerConfig {
230 app_name: root_agent.name().to_string(),
231 agent: root_agent,
232 session_service: controller.config.session_service.clone(),
233 artifact_service: controller.config.artifact_service.clone(),
234 memory_service: None,
235 }),
236 });
237
238 match executor.execute(&context_id, &task_id, ¶ms.message).await {
239 Ok(events) => {
240 let mut task = Task {
242 id: task_id,
243 context_id: Some(context_id),
244 status: TaskStatus { state: TaskState::Completed, message: None },
245 artifacts: Some(vec![]),
246 history: None,
247 };
248
249 for event in events {
250 match event {
251 UpdateEvent::TaskStatusUpdate(status) => {
252 task.status = status.status;
253 }
254 UpdateEvent::TaskArtifactUpdate(artifact) => {
255 if let Some(ref mut artifacts) = task.artifacts {
256 artifacts.push(artifact.artifact);
257 }
258 }
259 }
260 }
261
262 Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
263 }
264 Err(e) => Json(JsonRpcResponse::error(
265 id,
266 JsonRpcError::internal_error_sanitized(
267 &e,
268 controller.config.security.expose_error_details,
269 ),
270 )),
271 }
272}
273
274async fn handle_tasks_get(
275 _controller: &A2aController,
276 params: Option<Value>,
277 id: Option<Value>,
278) -> Json<JsonRpcResponse> {
279 let _params: TasksGetParams = match params {
280 Some(p) => match serde_json::from_value(p) {
281 Ok(p) => p,
282 Err(e) => {
283 return Json(JsonRpcResponse::error(
284 id,
285 JsonRpcError::invalid_params(e.to_string()),
286 ))
287 }
288 },
289 None => {
290 return Json(JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")))
291 }
292 };
293
294 Json(JsonRpcResponse::error(
296 id,
297 JsonRpcError::internal_error("Task retrieval not yet implemented"),
298 ))
299}
300
301async fn handle_tasks_cancel(
302 controller: &A2aController,
303 params: Option<Value>,
304 id: Option<Value>,
305) -> Json<JsonRpcResponse> {
306 let params: TasksCancelParams = match params {
307 Some(p) => match serde_json::from_value(p) {
308 Ok(p) => p,
309 Err(e) => {
310 return Json(JsonRpcResponse::error(
311 id,
312 JsonRpcError::invalid_params(e.to_string()),
313 ))
314 }
315 },
316 None => {
317 return Json(JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")))
318 }
319 };
320
321 let root_agent = controller.config.agent_loader.root_agent();
322
323 let executor = Executor::new(ExecutorConfig {
324 app_name: root_agent.name().to_string(),
325 runner_config: Arc::new(RunnerConfig {
326 app_name: root_agent.name().to_string(),
327 agent: root_agent,
328 session_service: controller.config.session_service.clone(),
329 artifact_service: controller.config.artifact_service.clone(),
330 memory_service: None,
331 }),
332 });
333
334 let context_id = uuid::Uuid::new_v4().to_string();
336
337 match executor.cancel(&context_id, ¶ms.task_id).await {
338 Ok(status) => {
339 Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
340 }
341 Err(e) => Json(JsonRpcResponse::error(
342 id,
343 JsonRpcError::internal_error_sanitized(
344 &e,
345 controller.config.security.expose_error_details,
346 ),
347 )),
348 }
349}