Skip to main content

codetether_agent/a2a/
server.rs

//! A2A Server - serve as an A2A agent
2
3use super::types::*;
4use crate::session::Session;
5use anyhow::Result;
6use axum::{
7    Router,
8    extract::State,
9    http::StatusCode,
10    response::Json,
11    routing::{get, post},
12};
13use dashmap::DashMap;
14use std::sync::Arc;
15use uuid::Uuid;
16
17/// A2A Server state
18#[derive(Clone)]
19pub struct A2AServer {
20    tasks: Arc<DashMap<String, Task>>,
21    agent_card: AgentCard,
22}
23
24impl A2AServer {
25    /// Create a new A2A server
26    pub fn new(agent_card: AgentCard) -> Self {
27        Self {
28            tasks: Arc::new(DashMap::new()),
29            agent_card,
30        }
31    }
32
33    /// Create the router for A2A endpoints
34    pub fn router(self) -> Router {
35        Router::new()
36            .route("/.well-known/agent.json", get(get_agent_card))
37            .route("/", post(handle_rpc))
38            .with_state(self)
39    }
40
41    /// Get the agent card for this server
42    #[allow(dead_code)]
43    pub fn card(&self) -> &AgentCard {
44        &self.agent_card
45    }
46
47    /// Create a default agent card
48    pub fn default_card(url: &str) -> AgentCard {
49        AgentCard {
50            name: "CodeTether Agent".to_string(),
51            description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
52            url: url.to_string(),
53            version: env!("CARGO_PKG_VERSION").to_string(),
54            protocol_version: "0.3.0".to_string(),
55            capabilities: AgentCapabilities {
56                streaming: true,
57                push_notifications: false,
58                state_transition_history: true,
59            },
60            skills: vec![
61                AgentSkill {
62                    id: "code".to_string(),
63                    name: "Code Generation".to_string(),
64                    description: "Write, edit, and refactor code".to_string(),
65                    tags: vec!["code".to_string(), "programming".to_string()],
66                    examples: vec![
67                        "Write a function to parse JSON".to_string(),
68                        "Refactor this code to use async/await".to_string(),
69                    ],
70                    input_modes: vec!["text/plain".to_string()],
71                    output_modes: vec!["text/plain".to_string()],
72                },
73                AgentSkill {
74                    id: "debug".to_string(),
75                    name: "Debugging".to_string(),
76                    description: "Debug and fix code issues".to_string(),
77                    tags: vec!["debug".to_string(), "fix".to_string()],
78                    examples: vec![
79                        "Why is this function returning undefined?".to_string(),
80                        "Fix the null pointer exception".to_string(),
81                    ],
82                    input_modes: vec!["text/plain".to_string()],
83                    output_modes: vec!["text/plain".to_string()],
84                },
85                AgentSkill {
86                    id: "explain".to_string(),
87                    name: "Code Explanation".to_string(),
88                    description: "Explain code and concepts".to_string(),
89                    tags: vec!["explain".to_string(), "learn".to_string()],
90                    examples: vec![
91                        "Explain how this algorithm works".to_string(),
92                        "What does this regex do?".to_string(),
93                    ],
94                    input_modes: vec!["text/plain".to_string()],
95                    output_modes: vec!["text/plain".to_string()],
96                },
97            ],
98            default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
99            default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
100            provider: Some(AgentProvider {
101                organization: "CodeTether".to_string(),
102                url: "https://codetether.ai".to_string(),
103            }),
104            icon_url: None,
105            documentation_url: None,
106        }
107    }
108}
109
110/// Get agent card handler
111async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
112    Json(server.agent_card.clone())
113}
114
115/// Handle JSON-RPC requests
116async fn handle_rpc(
117    State(server): State<A2AServer>,
118    Json(request): Json<JsonRpcRequest>,
119) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
120    let request_id = request.id.clone();
121    let response = match request.method.as_str() {
122        "message/send" => handle_message_send(&server, request).await,
123        "message/stream" => handle_message_stream(&server, request).await,
124        "tasks/get" => handle_tasks_get(&server, request).await,
125        "tasks/cancel" => handle_tasks_cancel(&server, request).await,
126        _ => Err(JsonRpcError::method_not_found(&request.method)),
127    };
128
129    match response {
130        Ok(result) => Ok(Json(JsonRpcResponse {
131            jsonrpc: "2.0".to_string(),
132            id: request_id.clone(),
133            result: Some(result),
134            error: None,
135        })),
136        Err(error) => Err((
137            StatusCode::OK,
138            Json(JsonRpcResponse {
139                jsonrpc: "2.0".to_string(),
140                id: request_id,
141                result: None,
142                error: Some(error),
143            }),
144        )),
145    }
146}
147
148async fn handle_message_send(
149    server: &A2AServer,
150    request: JsonRpcRequest,
151) -> Result<serde_json::Value, JsonRpcError> {
152    let params: MessageSendParams = serde_json::from_value(request.params)
153        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
154
155    // Create a new task
156    let task_id = params
157        .message
158        .task_id
159        .clone()
160        .unwrap_or_else(|| Uuid::new_v4().to_string());
161
162    let task = Task {
163        id: task_id.clone(),
164        context_id: params.message.context_id.clone(),
165        status: TaskStatus {
166            state: TaskState::Working,
167            message: Some(params.message.clone()),
168            timestamp: Some(chrono::Utc::now().to_rfc3339()),
169        },
170        artifacts: vec![],
171        history: vec![params.message.clone()],
172        metadata: std::collections::HashMap::new(),
173    };
174
175    server.tasks.insert(task_id.clone(), task.clone());
176
177    // Extract prompt text from message parts
178    let prompt: String = params
179        .message
180        .parts
181        .iter()
182        .filter_map(|p| match p {
183            Part::Text { text } => Some(text.as_str()),
184            _ => None,
185        })
186        .collect::<Vec<_>>()
187        .join("\n");
188
189    if prompt.is_empty() {
190        // Update task to failed
191        if let Some(mut t) = server.tasks.get_mut(&task_id) {
192            t.status.state = TaskState::Failed;
193            t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
194        }
195        return Err(JsonRpcError::invalid_params("No text content in message"));
196    }
197
198    // Determine if blocking (default true for message/send)
199    let blocking = params
200        .configuration
201        .as_ref()
202        .and_then(|c| c.blocking)
203        .unwrap_or(true);
204
205    if blocking {
206        // Synchronous execution: create session, run prompt, return completed task
207        let mut session = Session::new().await.map_err(|e| {
208            JsonRpcError::internal_error(format!("Failed to create session: {}", e))
209        })?;
210
211        match session.prompt(&prompt).await {
212            Ok(result) => {
213                let response_message = Message {
214                    message_id: Uuid::new_v4().to_string(),
215                    role: MessageRole::Agent,
216                    parts: vec![Part::Text {
217                        text: result.text.clone(),
218                    }],
219                    context_id: params.message.context_id.clone(),
220                    task_id: Some(task_id.clone()),
221                    metadata: std::collections::HashMap::new(),
222                };
223
224                let artifact = Artifact {
225                    artifact_id: Uuid::new_v4().to_string(),
226                    parts: vec![Part::Text { text: result.text }],
227                    name: Some("response".to_string()),
228                    description: None,
229                    metadata: std::collections::HashMap::new(),
230                };
231
232                if let Some(mut t) = server.tasks.get_mut(&task_id) {
233                    t.status.state = TaskState::Completed;
234                    t.status.message = Some(response_message.clone());
235                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
236                    t.artifacts.push(artifact);
237                    t.history.push(response_message);
238                }
239            }
240            Err(e) => {
241                let error_message = Message {
242                    message_id: Uuid::new_v4().to_string(),
243                    role: MessageRole::Agent,
244                    parts: vec![Part::Text {
245                        text: format!("Error: {}", e),
246                    }],
247                    context_id: params.message.context_id.clone(),
248                    task_id: Some(task_id.clone()),
249                    metadata: std::collections::HashMap::new(),
250                };
251
252                if let Some(mut t) = server.tasks.get_mut(&task_id) {
253                    t.status.state = TaskState::Failed;
254                    t.status.message = Some(error_message);
255                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
256                }
257            }
258        }
259    } else {
260        // Async execution: spawn background task, return immediately with Working state
261        let tasks = server.tasks.clone();
262        let context_id = params.message.context_id.clone();
263        let spawn_task_id = task_id.clone();
264
265        tokio::spawn(async move {
266            let task_id = spawn_task_id;
267            let mut session = match Session::new().await {
268                Ok(s) => s,
269                Err(e) => {
270                    tracing::error!("Failed to create session for task {}: {}", task_id, e);
271                    if let Some(mut t) = tasks.get_mut(&task_id) {
272                        t.status.state = TaskState::Failed;
273                        t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
274                    }
275                    return;
276                }
277            };
278
279            match session.prompt(&prompt).await {
280                Ok(result) => {
281                    let response_message = Message {
282                        message_id: Uuid::new_v4().to_string(),
283                        role: MessageRole::Agent,
284                        parts: vec![Part::Text {
285                            text: result.text.clone(),
286                        }],
287                        context_id,
288                        task_id: Some(task_id.clone()),
289                        metadata: std::collections::HashMap::new(),
290                    };
291
292                    let artifact = Artifact {
293                        artifact_id: Uuid::new_v4().to_string(),
294                        parts: vec![Part::Text { text: result.text }],
295                        name: Some("response".to_string()),
296                        description: None,
297                        metadata: std::collections::HashMap::new(),
298                    };
299
300                    if let Some(mut t) = tasks.get_mut(&task_id) {
301                        t.status.state = TaskState::Completed;
302                        t.status.message = Some(response_message.clone());
303                        t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
304                        t.artifacts.push(artifact);
305                        t.history.push(response_message);
306                    }
307                }
308                Err(e) => {
309                    tracing::error!("Task {} failed: {}", task_id, e);
310                    if let Some(mut t) = tasks.get_mut(&task_id) {
311                        t.status.state = TaskState::Failed;
312                        t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
313                    }
314                }
315            }
316        });
317    }
318
319    // Return current task state
320    let task = server.tasks.get(&task_id).unwrap();
321    serde_json::to_value(task.value().clone())
322        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
323}
324
325async fn handle_message_stream(
326    server: &A2AServer,
327    request: JsonRpcRequest,
328) -> Result<serde_json::Value, JsonRpcError> {
329    // message/stream submits the task for async processing.
330    // The client should poll tasks/get for status updates.
331    // True SSE streaming requires a dedicated endpoint outside JSON-RPC.
332
333    let params: MessageSendParams = serde_json::from_value(request.params)
334        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
335
336    let task_id = params
337        .message
338        .task_id
339        .clone()
340        .unwrap_or_else(|| Uuid::new_v4().to_string());
341
342    let task = Task {
343        id: task_id.clone(),
344        context_id: params.message.context_id.clone(),
345        status: TaskStatus {
346            state: TaskState::Working,
347            message: Some(params.message.clone()),
348            timestamp: Some(chrono::Utc::now().to_rfc3339()),
349        },
350        artifacts: vec![],
351        history: vec![params.message.clone()],
352        metadata: std::collections::HashMap::new(),
353    };
354
355    server.tasks.insert(task_id.clone(), task.clone());
356
357    // Extract prompt
358    let prompt: String = params
359        .message
360        .parts
361        .iter()
362        .filter_map(|p| match p {
363            Part::Text { text } => Some(text.as_str()),
364            _ => None,
365        })
366        .collect::<Vec<_>>()
367        .join("\n");
368
369    if prompt.is_empty() {
370        if let Some(mut t) = server.tasks.get_mut(&task_id) {
371            t.status.state = TaskState::Failed;
372            t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
373        }
374        return Err(JsonRpcError::invalid_params("No text content in message"));
375    }
376
377    // Spawn async processing
378    let tasks = server.tasks.clone();
379    let context_id = params.message.context_id.clone();
380    let spawn_task_id = task_id.clone();
381
382    tokio::spawn(async move {
383        let task_id = spawn_task_id;
384        let mut session = match Session::new().await {
385            Ok(s) => s,
386            Err(e) => {
387                tracing::error!("Failed to create session for stream task {}: {}", task_id, e);
388                if let Some(mut t) = tasks.get_mut(&task_id) {
389                    t.status.state = TaskState::Failed;
390                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
391                }
392                return;
393            }
394        };
395
396        match session.prompt(&prompt).await {
397            Ok(result) => {
398                let response_message = Message {
399                    message_id: Uuid::new_v4().to_string(),
400                    role: MessageRole::Agent,
401                    parts: vec![Part::Text {
402                        text: result.text.clone(),
403                    }],
404                    context_id,
405                    task_id: Some(task_id.clone()),
406                    metadata: std::collections::HashMap::new(),
407                };
408
409                let artifact = Artifact {
410                    artifact_id: Uuid::new_v4().to_string(),
411                    parts: vec![Part::Text { text: result.text }],
412                    name: Some("response".to_string()),
413                    description: None,
414                    metadata: std::collections::HashMap::new(),
415                };
416
417                if let Some(mut t) = tasks.get_mut(&task_id) {
418                    t.status.state = TaskState::Completed;
419                    t.status.message = Some(response_message.clone());
420                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
421                    t.artifacts.push(artifact);
422                    t.history.push(response_message);
423                }
424            }
425            Err(e) => {
426                tracing::error!("Stream task {} failed: {}", task_id, e);
427                if let Some(mut t) = tasks.get_mut(&task_id) {
428                    t.status.state = TaskState::Failed;
429                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
430                }
431            }
432        }
433    });
434
435    // Return task in Working state — client polls tasks/get for completion
436    serde_json::to_value(task)
437        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
438}
439
440async fn handle_tasks_get(
441    server: &A2AServer,
442    request: JsonRpcRequest,
443) -> Result<serde_json::Value, JsonRpcError> {
444    let params: TaskQueryParams = serde_json::from_value(request.params)
445        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
446
447    let task = server.tasks.get(&params.id).ok_or_else(|| JsonRpcError {
448        code: TASK_NOT_FOUND,
449        message: format!("Task not found: {}", params.id),
450        data: None,
451    })?;
452
453    serde_json::to_value(task.value().clone())
454        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
455}
456
457async fn handle_tasks_cancel(
458    server: &A2AServer,
459    request: JsonRpcRequest,
460) -> Result<serde_json::Value, JsonRpcError> {
461    let params: TaskQueryParams = serde_json::from_value(request.params)
462        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
463
464    let mut task = server
465        .tasks
466        .get_mut(&params.id)
467        .ok_or_else(|| JsonRpcError {
468            code: TASK_NOT_FOUND,
469            message: format!("Task not found: {}", params.id),
470            data: None,
471        })?;
472
473    if !task.status.state.is_active() {
474        return Err(JsonRpcError {
475            code: TASK_NOT_CANCELABLE,
476            message: "Task is already in a terminal state".to_string(),
477            data: None,
478        });
479    }
480
481    task.status.state = TaskState::Cancelled;
482    task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
483
484    serde_json::to_value(task.value().clone())
485        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
486}