Skip to main content

routa_server/api/
a2a.rs

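//! A2A Protocol API
//!
//! /api/a2a/sessions   - List active sessions
//! /api/a2a/rpc        - JSON-RPC endpoint (POST) + SSE event stream (GET)
//! /api/a2a/card       - Agent card discovery
//! /api/a2a/message    - Send a message (acknowledgement only)
//! /api/a2a/tasks      - List tasks
//! /api/a2a/tasks/{id} - Get (GET) or update the status of (POST) a task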
6
7use axum::{
8    extract::{Path, Query, State},
9    response::sse::{Event, KeepAlive, Sse},
10    routing::get,
11    Json, Router,
12};
13use chrono::Utc;
14use routa_core::models::task::{Task, TaskStatus};
15use serde::Deserialize;
16use std::convert::Infallible;
17use std::time::Duration;
18use tokio_stream::StreamExt as _;
19
20use crate::error::ServerError;
21use crate::state::AppState;
22
23pub fn router() -> Router<AppState> {
24    Router::new()
25        .route("/sessions", get(list_sessions))
26        .route("/rpc", get(rpc_sse).post(rpc_handler))
27        .route("/card", get(agent_card))
28        .route("/message", axum::routing::post(send_message))
29        .route("/tasks", get(list_tasks))
30        .route("/tasks/{id}", get(get_task).post(update_task))
31}
32
33// ─── /api/a2a/sessions ────────────────────────────────────────────────
34
35async fn list_sessions(
36    State(state): State<AppState>,
37) -> Result<Json<serde_json::Value>, ServerError> {
38    let sessions = state.acp_manager.list_sessions().await;
39
40    let a2a_sessions: Vec<serde_json::Value> = sessions
41        .iter()
42        .map(|s| {
43            serde_json::json!({
44                "id": s.session_id,
45                "agentName": format!("routa-{}-{}", s.provider.as_deref().unwrap_or("agent"), &s.session_id[..8.min(s.session_id.len())]),
46                "provider": s.provider.as_deref().unwrap_or("unknown"),
47                "status": "connected",
48                "capabilities": [
49                    "initialize", "method_list",
50                    "session/new", "session/prompt", "session/cancel", "session/load",
51                    "list_agents", "create_agent", "delegate_task", "message_agent"
52                ],
53                "rpcUrl": format!("/api/a2a/rpc?sessionId={}", s.session_id),
54                "eventStreamUrl": format!("/api/a2a/rpc?sessionId={}", s.session_id),
55                "createdAt": s.created_at,
56            })
57        })
58        .collect();
59
60    Ok(Json(serde_json::json!({
61        "sessions": a2a_sessions,
62        "count": a2a_sessions.len(),
63    })))
64}
65
66// ─── /api/a2a/card ────────────────────────────────────────────────────
67
68async fn agent_card() -> Json<serde_json::Value> {
69    Json(serde_json::json!({
70        "name": "Routa Multi-Agent Coordinator",
71        "description": "Multi-agent coordination platform with ACP and MCP support",
72        "protocolVersion": "0.3.0",
73        "version": "0.1.0",
74        "url": "/api/a2a/rpc",
75        "skills": [
76            {
77                "id": "coordination",
78                "name": "Agent Coordination",
79                "description": "Create, delegate tasks to, and coordinate multiple AI agents",
80                "tags": ["coordination", "multi-agent", "orchestration"],
81            },
82            {
83                "id": "acp-proxy",
84                "name": "ACP Session Proxy",
85                "description": "Proxy access to backend ACP agent sessions",
86                "tags": ["acp", "session", "proxy"],
87            }
88        ],
89        "capabilities": { "pushNotifications": true },
90        "defaultInputModes": ["text"],
91        "defaultOutputModes": ["text"],
92        "additionalInterfaces": [{
93            "url": "/api/a2a/rpc",
94            "transport": "JSONRPC",
95        }],
96    }))
97}
98
99// ─── /api/a2a/rpc POST ───────────────────────────────────────────────
100
101#[derive(Debug, Deserialize)]
102#[serde(rename_all = "camelCase")]
103struct RpcQuery {
104    session_id: Option<String>,
105}
106
107async fn rpc_handler(
108    State(state): State<AppState>,
109    Query(query): Query<RpcQuery>,
110    Json(body): Json<serde_json::Value>,
111) -> Result<Json<serde_json::Value>, ServerError> {
112    let method = body.get("method").and_then(|m| m.as_str()).unwrap_or("");
113    let id = body.get("id").cloned().unwrap_or(serde_json::json!(null));
114    let params = body.get("params").cloned().unwrap_or_default();
115
116    let result =
117        match method {
118            "method_list" => serde_json::json!({
119                "methods": [
120                    "SendMessage", "GetTask", "ListTasks", "CancelTask",
121                    "method_list", "initialize",
122                    "session/new", "session/prompt", "session/cancel", "session/load",
123                    "list_agents", "create_agent", "delegate_task", "message_agent",
124                ]
125            }),
126
127            "initialize" => serde_json::json!({
128                "protocolVersion": "0.3.0",
129                "agentInfo": { "name": "routa-a2a-bridge", "version": "0.1.0" },
130                "capabilities": { "sessions": true, "coordination": true, "tasks": true },
131            }),
132
133            "SendMessage" => {
134                let workspace_id = params
135                    .get("metadata")
136                    .and_then(|value| value.get("workspaceId"))
137                    .and_then(|value| value.as_str())
138                    .unwrap_or("default")
139                    .to_string();
140                let prompt = extract_a2a_prompt(&params)?;
141                let task_id = uuid::Uuid::new_v4().to_string();
142                let context_id = params
143                    .get("message")
144                    .and_then(|value| value.get("contextId"))
145                    .and_then(|value| value.as_str())
146                    .map(ToOwned::to_owned)
147                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
148                let title = prompt
149                    .lines()
150                    .find(|line| !line.trim().is_empty())
151                    .map(|line| truncate_text(line.trim(), 80))
152                    .filter(|line| !line.is_empty())
153                    .unwrap_or_else(|| "A2A task".to_string());
154
155                let task = Task::new(
156                    task_id.clone(),
157                    title,
158                    prompt,
159                    workspace_id,
160                    Some(context_id.clone()),
161                    None,
162                    None,
163                    None,
164                    None,
165                    None,
166                    None,
167                );
168                state.task_store.save(&task).await?;
169
170                let state_clone = state.clone();
171                let task_id_clone = task_id.clone();
172                tokio::spawn(async move {
173                    tokio::time::sleep(Duration::from_millis(200)).await;
174                    let _ = state_clone
175                        .task_store
176                        .update_status(&task_id_clone, &TaskStatus::Completed)
177                        .await;
178                });
179
180                build_a2a_task_payload(&task, "submitted", Some(Utc::now().to_rfc3339()))
181            }
182
183            "GetTask" => {
184                let task_id = params
185                    .get("id")
186                    .and_then(|value| value.as_str())
187                    .ok_or_else(|| ServerError::BadRequest("Missing task id".into()))?;
188                let task =
189                    state.task_store.get(task_id).await?.ok_or_else(|| {
190                        ServerError::NotFound(format!("Task {task_id} not found"))
191                    })?;
192                build_a2a_task_payload(
193                    &task,
194                    map_task_status_to_a2a_state(&task.status),
195                    Some(task.updated_at.to_rfc3339()),
196                )
197            }
198
199            "ListTasks" => {
200                let workspace_id = params
201                    .get("workspaceId")
202                    .and_then(|value| value.as_str())
203                    .unwrap_or("default");
204                let tasks = state.task_store.list_by_workspace(workspace_id).await?;
205                serde_json::json!({
206                    "tasks": tasks
207                        .iter()
208                        .map(|task| {
209                            build_a2a_task_payload(
210                                task,
211                                map_task_status_to_a2a_state(&task.status),
212                                Some(task.updated_at.to_rfc3339()),
213                            )["task"].clone()
214                        })
215                        .collect::<Vec<_>>()
216                })
217            }
218
219            "CancelTask" => {
220                let task_id = params
221                    .get("id")
222                    .and_then(|value| value.as_str())
223                    .ok_or_else(|| ServerError::BadRequest("Missing task id".into()))?;
224                state
225                    .task_store
226                    .update_status(task_id, &TaskStatus::Cancelled)
227                    .await?;
228                let task =
229                    state.task_store.get(task_id).await?.ok_or_else(|| {
230                        ServerError::NotFound(format!("Task {task_id} not found"))
231                    })?;
232                build_a2a_task_payload(&task, "canceled", Some(task.updated_at.to_rfc3339()))
233            }
234
235            "list_agents" => {
236                let workspace_id = params
237                    .get("workspaceId")
238                    .and_then(|v| v.as_str())
239                    .unwrap_or("default");
240                let agents = state.agent_store.list_by_workspace(workspace_id).await?;
241                serde_json::json!({ "agents": agents })
242            }
243
244            "create_agent" => {
245                let name = params
246                    .get("name")
247                    .and_then(|v| v.as_str())
248                    .ok_or_else(|| ServerError::BadRequest("Missing name".into()))?;
249                let role = params
250                    .get("role")
251                    .and_then(|v| v.as_str())
252                    .ok_or_else(|| ServerError::BadRequest("Missing role".into()))?;
253                let workspace_id = params
254                    .get("workspaceId")
255                    .and_then(|v| v.as_str())
256                    .unwrap_or("default");
257
258                let agent_role = crate::models::agent::AgentRole::from_str(role)
259                    .ok_or_else(|| ServerError::BadRequest(format!("Invalid role: {role}")))?;
260
261                let agent = crate::models::agent::Agent::new(
262                    uuid::Uuid::new_v4().to_string(),
263                    name.to_string(),
264                    agent_role,
265                    workspace_id.to_string(),
266                    None,
267                    None,
268                    None,
269                );
270                state.agent_store.save(&agent).await?;
271                serde_json::json!({ "success": true, "agentId": agent.id })
272            }
273
274            "delegate_task" | "message_agent" => {
275                // Acknowledge and return stub
276                serde_json::json!({
277                    "status": "forwarded",
278                    "sessionId": query.session_id,
279                    "method": method,
280                    "message": "Request forwarded to backend session",
281                })
282            }
283
284            _ => {
285                return Ok(Json(serde_json::json!({
286                    "jsonrpc": "2.0",
287                    "id": id,
288                    "error": { "code": -32601, "message": format!("Unknown method: {}", method) }
289                })));
290            }
291        };
292
293    Ok(Json(serde_json::json!({
294        "jsonrpc": "2.0",
295        "id": id,
296        "result": result,
297    })))
298}
299
300// ─── /api/a2a/rpc GET (SSE) ──────────────────────────────────────────
301
302async fn rpc_sse(
303    Query(query): Query<RpcQuery>,
304) -> Result<Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>>>, axum::http::StatusCode>
305{
306    let session_id = match query.session_id {
307        Some(id) => id,
308        None => return Err(axum::http::StatusCode::BAD_REQUEST),
309    };
310
311    let connected_event = serde_json::json!({
312        "jsonrpc": "2.0",
313        "method": "notification",
314        "params": {
315            "type": "connected",
316            "sessionId": session_id,
317            "message": "A2A event stream connected",
318        }
319    });
320
321    let initial = tokio_stream::once(Ok::<_, Infallible>(
322        Event::default().data(connected_event.to_string()),
323    ));
324
325    let heartbeat = tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(
326        std::time::Duration::from_secs(30),
327    ))
328    .map(|_| Ok(Event::default().comment("keep-alive")));
329
330    let stream = initial.chain(heartbeat);
331
332    Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
333}
334
335// ─── /api/a2a/message ────────────────────────────────────────────────
336
337/// POST /api/a2a/message — Send a message via the A2A protocol
338async fn send_message(Json(body): Json<serde_json::Value>) -> Json<serde_json::Value> {
339    let method = body
340        .get("method")
341        .and_then(|v| v.as_str())
342        .unwrap_or("sendMessage");
343
344    let session_id = body
345        .get("params")
346        .and_then(|p| p.get("sessionId"))
347        .and_then(|v| v.as_str())
348        .unwrap_or("default");
349
350    Json(serde_json::json!({
351        "jsonrpc": "2.0",
352        "id": body.get("id"),
353        "result": {
354            "status": "accepted",
355            "method": method,
356            "sessionId": session_id,
357        }
358    }))
359}
360
361// ─── /api/a2a/tasks ──────────────────────────────────────────────────
362
363#[derive(Debug, Deserialize)]
364#[serde(rename_all = "camelCase")]
365struct TasksQuery {
366    session_id: Option<String>,
367    workspace_id: Option<String>,
368}
369
370/// GET /api/a2a/tasks — List A2A tasks (mapped from Routa tasks)
371async fn list_tasks(
372    State(state): State<AppState>,
373    Query(q): Query<TasksQuery>,
374) -> Result<Json<serde_json::Value>, ServerError> {
375    let tasks = if let Some(session_id) = &q.session_id {
376        state.task_store.list_by_session(session_id).await?
377    } else {
378        let ws = q.workspace_id.as_deref().unwrap_or("default");
379        state.task_store.list_by_workspace(ws).await?
380    };
381    Ok(Json(serde_json::json!({ "tasks": tasks })))
382}
383
384/// GET /api/a2a/tasks/{id} — Get an A2A task by ID
385async fn get_task(
386    State(state): State<AppState>,
387    Path(id): Path<String>,
388) -> Result<Json<serde_json::Value>, ServerError> {
389    state
390        .task_store
391        .get(&id)
392        .await?
393        .map(|t| Json(serde_json::json!(t)))
394        .ok_or_else(|| ServerError::NotFound(format!("Task {id} not found")))
395}
396
397/// POST /api/a2a/tasks/{id} — Update / respond to an A2A task
398async fn update_task(
399    State(state): State<AppState>,
400    Path(id): Path<String>,
401    Json(body): Json<serde_json::Value>,
402) -> Result<Json<serde_json::Value>, ServerError> {
403    if let Some(status) = body.get("status").and_then(|v| v.as_str()) {
404        let task_status = crate::models::task::TaskStatus::from_str(status)
405            .ok_or_else(|| ServerError::BadRequest(format!("Invalid status: {status}")))?;
406        state.task_store.update_status(&id, &task_status).await?;
407        Ok(Json(
408            serde_json::json!({ "updated": true, "id": id, "status": status }),
409        ))
410    } else {
411        Ok(Json(
412            serde_json::json!({ "updated": false, "id": id, "message": "No status change requested" }),
413        ))
414    }
415}
416
417fn extract_a2a_prompt(params: &serde_json::Value) -> Result<String, ServerError> {
418    let parts = params
419        .get("message")
420        .and_then(|value| value.get("parts"))
421        .and_then(|value| value.as_array())
422        .ok_or_else(|| ServerError::BadRequest("Missing message parts".into()))?;
423    let prompt = parts
424        .iter()
425        .filter_map(|part| part.get("text").and_then(|value| value.as_str()))
426        .map(str::trim)
427        .filter(|part| !part.is_empty())
428        .collect::<Vec<_>>()
429        .join("\n");
430    if prompt.is_empty() {
431        return Err(ServerError::BadRequest(
432            "A2A message must contain at least one text part".into(),
433        ));
434    }
435    Ok(prompt)
436}
437
/// Returns at most the first `max_len` characters (Unicode scalar values) of
/// `text` as an owned string.
///
/// The cut is made on `char` boundaries, so it never panics on multi-byte
/// UTF-8. It may still split a grapheme cluster (e.g. a letter and a following
/// combining mark). No ellipsis is appended.
fn truncate_text(text: &str, max_len: usize) -> String {
    // Byte offset of the first character beyond the limit, if the text is
    // longer than `max_len` characters.
    match text.char_indices().nth(max_len) {
        Some((cut, _)) => text[..cut].to_owned(),
        None => text.to_owned(),
    }
}
444
445fn map_task_status_to_a2a_state(status: &TaskStatus) -> &'static str {
446    match status {
447        TaskStatus::Completed => "completed",
448        TaskStatus::Cancelled => "canceled",
449        TaskStatus::Blocked | TaskStatus::NeedsFix => "failed",
450        TaskStatus::Pending => "submitted",
451        TaskStatus::InProgress | TaskStatus::ReviewRequired => "working",
452    }
453}
454
455fn build_a2a_task_payload(
456    task: &Task,
457    state: &str,
458    timestamp: Option<String>,
459) -> serde_json::Value {
460    let timestamp = timestamp.unwrap_or_else(|| Utc::now().to_rfc3339());
461    serde_json::json!({
462        "task": {
463            "id": task.id,
464            "contextId": task.session_id,
465            "status": {
466                "state": state,
467                "timestamp": timestamp,
468            },
469            "history": [{
470                "messageId": format!("msg-{}", task.id),
471                "role": "user",
472                "parts": [{ "text": task.objective }],
473                "contextId": task.session_id,
474                "taskId": task.id,
475            }],
476            "artifacts": [],
477            "metadata": {
478                "workspaceId": task.workspace_id,
479                "columnId": task.column_id,
480            }
481        }
482    })
483}