1use crate::ServerConfig;
2use crate::a2a::{
3 AgentCard, Executor, ExecutorConfig, JsonRpcError, JsonRpcRequest, JsonRpcResponse,
4 MessageSendParams, Task, TaskState, TaskStatus, TasksCancelParams, TasksGetParams, UpdateEvent,
5 build_agent_card, jsonrpc,
6};
7use adk_runner::RunnerConfig;
8use axum::{
9 extract::State,
10 http::StatusCode,
11 response::{
12 IntoResponse, Json,
13 sse::{Event, Sse},
14 },
15};
16use futures::stream::Stream;
17use serde_json::Value;
18use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
19use tokio::sync::RwLock;
20
21#[derive(Default)]
23pub struct TaskStore {
24 tasks: RwLock<HashMap<String, Task>>,
25}
26
27impl TaskStore {
28 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub async fn store(&self, task: Task) {
33 self.tasks.write().await.insert(task.id.clone(), task);
34 }
35
36 pub async fn get(&self, task_id: &str) -> Option<Task> {
37 self.tasks.read().await.get(task_id).cloned()
38 }
39
40 pub async fn remove(&self, task_id: &str) -> Option<Task> {
41 self.tasks.write().await.remove(task_id)
42 }
43}
44
45#[derive(Clone)]
47pub struct A2aController {
48 config: ServerConfig,
49 agent_card: AgentCard,
50 task_store: Arc<TaskStore>,
51}
52
53impl A2aController {
54 pub fn new(config: ServerConfig, base_url: &str) -> Self {
55 let root_agent = config.agent_loader.root_agent();
56 let invoke_url = format!("{}/a2a", base_url.trim_end_matches('/'));
57 let agent_card = build_agent_card(root_agent.as_ref(), &invoke_url);
58
59 Self { config, agent_card, task_store: Arc::new(TaskStore::new()) }
60 }
61}
62
63pub async fn get_agent_card(State(controller): State<A2aController>) -> impl IntoResponse {
65 Json(controller.agent_card.clone())
66}
67
68pub async fn handle_jsonrpc(
70 State(controller): State<A2aController>,
71 Json(request): Json<JsonRpcRequest>,
72) -> impl IntoResponse {
73 if request.jsonrpc != "2.0" {
74 return Json(JsonRpcResponse::error(
75 request.id,
76 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
77 ));
78 }
79
80 match request.method.as_str() {
81 jsonrpc::methods::MESSAGE_SEND => {
82 handle_message_send(&controller, request.params, request.id).await
83 }
84 jsonrpc::methods::TASKS_GET => {
85 handle_tasks_get(&controller, request.params, request.id).await
86 }
87 jsonrpc::methods::TASKS_CANCEL => {
88 handle_tasks_cancel(&controller, request.params, request.id).await
89 }
90 _ => Json(JsonRpcResponse::error(
91 request.id,
92 JsonRpcError::method_not_found(&request.method),
93 )),
94 }
95}
96
97pub async fn handle_jsonrpc_stream(
99 State(controller): State<A2aController>,
100 Json(request): Json<JsonRpcRequest>,
101) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
102{
103 if request.jsonrpc != "2.0" {
104 return Err((
105 StatusCode::BAD_REQUEST,
106 Json(JsonRpcResponse::error(
107 request.id.clone(),
108 JsonRpcError::invalid_request("Invalid JSON-RPC version"),
109 )),
110 ));
111 }
112
113 if request.method != jsonrpc::methods::MESSAGE_SEND_STREAM
114 && request.method != jsonrpc::methods::MESSAGE_SEND
115 {
116 return Err((
117 StatusCode::BAD_REQUEST,
118 Json(JsonRpcResponse::error(
119 request.id.clone(),
120 JsonRpcError::method_not_found(&request.method),
121 )),
122 ));
123 }
124
125 let params: MessageSendParams = match request.params {
126 Some(p) => serde_json::from_value(p).map_err(|e| {
127 (
128 StatusCode::BAD_REQUEST,
129 Json(JsonRpcResponse::error(
130 request.id.clone(),
131 JsonRpcError::invalid_params(e.to_string()),
132 )),
133 )
134 })?,
135 None => {
136 return Err((
137 StatusCode::BAD_REQUEST,
138 Json(JsonRpcResponse::error(
139 request.id.clone(),
140 JsonRpcError::invalid_params("Missing params"),
141 )),
142 ));
143 }
144 };
145
146 let request_id = request.id.clone();
147 let stream = create_message_stream(controller, params, request_id);
148
149 Ok(Sse::new(stream).keep_alive(
150 axum::response::sse::KeepAlive::new().interval(Duration::from_secs(15)).text("ping"),
151 ))
152}
153
154fn create_message_stream(
155 controller: A2aController,
156 params: MessageSendParams,
157 request_id: Option<Value>,
158) -> impl Stream<Item = Result<Event, Infallible>> {
159 async_stream::stream! {
160 let context_id = params
161 .message
162 .context_id
163 .clone()
164 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
165 let task_id = params
166 .message
167 .task_id
168 .clone()
169 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
170
171 let root_agent = controller.config.agent_loader.root_agent();
172
173 let executor = Executor::new(ExecutorConfig {
174 app_name: root_agent.name().to_string(),
175 runner_config: Arc::new(RunnerConfig {
176 app_name: root_agent.name().to_string(),
177 agent: root_agent,
178 session_service: controller.config.session_service.clone(),
179 artifact_service: controller.config.artifact_service.clone(),
180 memory_service: None,
181 run_config: None,
182 }),
183 });
184
185 match executor.execute(&context_id, &task_id, ¶ms.message).await {
186 Ok(events) => {
187 for event in events {
188 let event_data = match &event {
189 UpdateEvent::TaskStatusUpdate(status) => {
190 serde_json::to_string(&JsonRpcResponse::success(
191 request_id.clone(),
192 serde_json::to_value(status).unwrap_or_default(),
193 ))
194 }
195 UpdateEvent::TaskArtifactUpdate(artifact) => {
196 serde_json::to_string(&JsonRpcResponse::success(
197 request_id.clone(),
198 serde_json::to_value(artifact).unwrap_or_default(),
199 ))
200 }
201 };
202
203 if let Ok(data) = event_data {
204 yield Ok(Event::default().data(data));
205 }
206 }
207 }
208 Err(e) => {
209 let error_response = JsonRpcResponse::error(
210 request_id.clone(),
211 JsonRpcError::internal_error_sanitized(
212 &e,
213 controller.config.security.expose_error_details,
214 ),
215 );
216 if let Ok(data) = serde_json::to_string(&error_response) {
217 yield Ok(Event::default().data(data));
218 }
219 }
220 }
221
222 yield Ok(Event::default().event("done").data(""));
224 }
225}
226
227async fn handle_message_send(
228 controller: &A2aController,
229 params: Option<Value>,
230 id: Option<Value>,
231) -> Json<JsonRpcResponse> {
232 let params: MessageSendParams = match params {
233 Some(p) => match serde_json::from_value(p) {
234 Ok(p) => p,
235 Err(e) => {
236 return Json(JsonRpcResponse::error(
237 id,
238 JsonRpcError::invalid_params(e.to_string()),
239 ));
240 }
241 },
242 None => {
243 return Json(JsonRpcResponse::error(
244 id,
245 JsonRpcError::invalid_params("Missing params"),
246 ));
247 }
248 };
249
250 let context_id =
251 params.message.context_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
252 let task_id =
253 params.message.task_id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
254
255 let root_agent = controller.config.agent_loader.root_agent();
256
257 let executor = Executor::new(ExecutorConfig {
258 app_name: root_agent.name().to_string(),
259 runner_config: Arc::new(RunnerConfig {
260 app_name: root_agent.name().to_string(),
261 agent: root_agent,
262 session_service: controller.config.session_service.clone(),
263 artifact_service: controller.config.artifact_service.clone(),
264 memory_service: None,
265 run_config: None,
266 }),
267 });
268
269 match executor.execute(&context_id, &task_id, ¶ms.message).await {
270 Ok(events) => {
271 let mut task = Task {
273 id: task_id,
274 context_id: Some(context_id),
275 status: TaskStatus { state: TaskState::Completed, message: None },
276 artifacts: Some(vec![]),
277 history: None,
278 };
279
280 for event in events {
281 match event {
282 UpdateEvent::TaskStatusUpdate(status) => {
283 task.status = status.status;
284 }
285 UpdateEvent::TaskArtifactUpdate(artifact) => {
286 if let Some(ref mut artifacts) = task.artifacts {
287 artifacts.push(artifact.artifact);
288 }
289 }
290 }
291 }
292
293 controller.task_store.store(task.clone()).await;
295
296 Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
297 }
298 Err(e) => Json(JsonRpcResponse::error(
299 id,
300 JsonRpcError::internal_error_sanitized(
301 &e,
302 controller.config.security.expose_error_details,
303 ),
304 )),
305 }
306}
307
308async fn handle_tasks_get(
309 controller: &A2aController,
310 params: Option<Value>,
311 id: Option<Value>,
312) -> Json<JsonRpcResponse> {
313 let params: TasksGetParams = match params {
314 Some(p) => match serde_json::from_value(p) {
315 Ok(p) => p,
316 Err(e) => {
317 return Json(JsonRpcResponse::error(
318 id,
319 JsonRpcError::invalid_params(e.to_string()),
320 ));
321 }
322 },
323 None => {
324 return Json(JsonRpcResponse::error(
325 id,
326 JsonRpcError::invalid_params("Missing params"),
327 ));
328 }
329 };
330
331 match controller.task_store.get(¶ms.task_id).await {
332 Some(task) => {
333 Json(JsonRpcResponse::success(id, serde_json::to_value(task).unwrap_or_default()))
334 }
335 None => Json(JsonRpcResponse::error(
336 id,
337 JsonRpcError::internal_error(format!("Task not found: {}", params.task_id)),
338 )),
339 }
340}
341
342async fn handle_tasks_cancel(
343 controller: &A2aController,
344 params: Option<Value>,
345 id: Option<Value>,
346) -> Json<JsonRpcResponse> {
347 let params: TasksCancelParams = match params {
348 Some(p) => match serde_json::from_value(p) {
349 Ok(p) => p,
350 Err(e) => {
351 return Json(JsonRpcResponse::error(
352 id,
353 JsonRpcError::invalid_params(e.to_string()),
354 ));
355 }
356 },
357 None => {
358 return Json(JsonRpcResponse::error(
359 id,
360 JsonRpcError::invalid_params("Missing params"),
361 ));
362 }
363 };
364
365 let root_agent = controller.config.agent_loader.root_agent();
366
367 let executor = Executor::new(ExecutorConfig {
368 app_name: root_agent.name().to_string(),
369 runner_config: Arc::new(RunnerConfig {
370 app_name: root_agent.name().to_string(),
371 agent: root_agent,
372 session_service: controller.config.session_service.clone(),
373 artifact_service: controller.config.artifact_service.clone(),
374 memory_service: None,
375 run_config: None,
376 }),
377 });
378
379 let context_id = uuid::Uuid::new_v4().to_string();
381
382 match executor.cancel(&context_id, ¶ms.task_id).await {
383 Ok(status) => {
384 Json(JsonRpcResponse::success(id, serde_json::to_value(status).unwrap_or_default()))
385 }
386 Err(e) => Json(JsonRpcResponse::error(
387 id,
388 JsonRpcError::internal_error_sanitized(
389 &e,
390 controller.config.security.expose_error_details,
391 ),
392 )),
393 }
394}