Skip to main content

codetether_agent/a2a/
server.rs

1//! A2A Server - serve as an A2A agent
2
3use super::types::*;
4use anyhow::Result;
5use axum::{
6    extract::State,
7    http::StatusCode,
8    response::Json,
9    routing::{get, post},
10    Router,
11};
12use dashmap::DashMap;
13use std::sync::Arc;
14use uuid::Uuid;
15
16/// A2A Server state
17#[derive(Clone)]
18pub struct A2AServer {
19    tasks: Arc<DashMap<String, Task>>,
20    agent_card: AgentCard,
21}
22
23impl A2AServer {
24    /// Create a new A2A server
25    pub fn new(agent_card: AgentCard) -> Self {
26        Self {
27            tasks: Arc::new(DashMap::new()),
28            agent_card,
29        }
30    }
31
32    /// Create the router for A2A endpoints
33    pub fn router(self) -> Router {
34        Router::new()
35            .route("/.well-known/agent.json", get(get_agent_card))
36            .route("/", post(handle_rpc))
37            .with_state(self)
38    }
39
40    /// Get the agent card for this server
41    #[allow(dead_code)]
42    pub fn card(&self) -> &AgentCard {
43        &self.agent_card
44    }
45
46    /// Create a default agent card
47    pub fn default_card(url: &str) -> AgentCard {
48        AgentCard {
49            name: "CodeTether Agent".to_string(),
50            description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
51            url: url.to_string(),
52            version: env!("CARGO_PKG_VERSION").to_string(),
53            protocol_version: "0.3.0".to_string(),
54            capabilities: AgentCapabilities {
55                streaming: true,
56                push_notifications: false,
57                state_transition_history: true,
58            },
59            skills: vec![
60                AgentSkill {
61                    id: "code".to_string(),
62                    name: "Code Generation".to_string(),
63                    description: "Write, edit, and refactor code".to_string(),
64                    tags: vec!["code".to_string(), "programming".to_string()],
65                    examples: vec![
66                        "Write a function to parse JSON".to_string(),
67                        "Refactor this code to use async/await".to_string(),
68                    ],
69                    input_modes: vec!["text/plain".to_string()],
70                    output_modes: vec!["text/plain".to_string()],
71                },
72                AgentSkill {
73                    id: "debug".to_string(),
74                    name: "Debugging".to_string(),
75                    description: "Debug and fix code issues".to_string(),
76                    tags: vec!["debug".to_string(), "fix".to_string()],
77                    examples: vec![
78                        "Why is this function returning undefined?".to_string(),
79                        "Fix the null pointer exception".to_string(),
80                    ],
81                    input_modes: vec!["text/plain".to_string()],
82                    output_modes: vec!["text/plain".to_string()],
83                },
84                AgentSkill {
85                    id: "explain".to_string(),
86                    name: "Code Explanation".to_string(),
87                    description: "Explain code and concepts".to_string(),
88                    tags: vec!["explain".to_string(), "learn".to_string()],
89                    examples: vec![
90                        "Explain how this algorithm works".to_string(),
91                        "What does this regex do?".to_string(),
92                    ],
93                    input_modes: vec!["text/plain".to_string()],
94                    output_modes: vec!["text/plain".to_string()],
95                },
96            ],
97            default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
98            default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
99            provider: Some(AgentProvider {
100                organization: "CodeTether".to_string(),
101                url: "https://codetether.ai".to_string(),
102            }),
103            icon_url: None,
104            documentation_url: None,
105        }
106    }
107}
108
109/// Get agent card handler
110async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
111    Json(server.agent_card.clone())
112}
113
114/// Handle JSON-RPC requests
115async fn handle_rpc(
116    State(server): State<A2AServer>,
117    Json(request): Json<JsonRpcRequest>,
118) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
119    let request_id = request.id.clone();
120    let response = match request.method.as_str() {
121        "message/send" => handle_message_send(&server, request).await,
122        "message/stream" => handle_message_stream(&server, request).await,
123        "tasks/get" => handle_tasks_get(&server, request).await,
124        "tasks/cancel" => handle_tasks_cancel(&server, request).await,
125        _ => Err(JsonRpcError::method_not_found(&request.method)),
126    };
127
128    match response {
129        Ok(result) => Ok(Json(JsonRpcResponse {
130            jsonrpc: "2.0".to_string(),
131            id: request_id.clone(),
132            result: Some(result),
133            error: None,
134        })),
135        Err(error) => Err((
136            StatusCode::OK,
137            Json(JsonRpcResponse {
138                jsonrpc: "2.0".to_string(),
139                id: request_id,
140                result: None,
141                error: Some(error),
142            }),
143        )),
144    }
145}
146
147async fn handle_message_send(
148    server: &A2AServer,
149    request: JsonRpcRequest,
150) -> Result<serde_json::Value, JsonRpcError> {
151    let params: MessageSendParams = serde_json::from_value(request.params)
152        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
153
154    // Create a new task
155    let task_id = params.message.task_id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
156    
157    let task = Task {
158        id: task_id.clone(),
159        context_id: params.message.context_id.clone(),
160        status: TaskStatus {
161            state: TaskState::Submitted,
162            message: Some(params.message.clone()),
163            timestamp: Some(chrono::Utc::now().to_rfc3339()),
164        },
165        artifacts: vec![],
166        history: vec![params.message],
167        metadata: std::collections::HashMap::new(),
168    };
169
170    server.tasks.insert(task_id.clone(), task.clone());
171
172    // TODO: Process the task asynchronously
173
174    serde_json::to_value(task)
175        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
176}
177
178async fn handle_message_stream(
179    _server: &A2AServer,
180    _request: JsonRpcRequest,
181) -> Result<serde_json::Value, JsonRpcError> {
182    // TODO: Implement streaming
183    Err(JsonRpcError {
184        code: UNSUPPORTED_OPERATION,
185        message: "Streaming not yet implemented".to_string(),
186        data: None,
187    })
188}
189
190async fn handle_tasks_get(
191    server: &A2AServer,
192    request: JsonRpcRequest,
193) -> Result<serde_json::Value, JsonRpcError> {
194    let params: TaskQueryParams = serde_json::from_value(request.params)
195        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
196
197    let task = server.tasks.get(&params.id).ok_or_else(|| JsonRpcError {
198        code: TASK_NOT_FOUND,
199        message: format!("Task not found: {}", params.id),
200        data: None,
201    })?;
202
203    serde_json::to_value(task.value().clone())
204        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
205}
206
207async fn handle_tasks_cancel(
208    server: &A2AServer,
209    request: JsonRpcRequest,
210) -> Result<serde_json::Value, JsonRpcError> {
211    let params: TaskQueryParams = serde_json::from_value(request.params)
212        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
213
214    let mut task = server.tasks.get_mut(&params.id).ok_or_else(|| JsonRpcError {
215        code: TASK_NOT_FOUND,
216        message: format!("Task not found: {}", params.id),
217        data: None,
218    })?;
219
220    if !task.status.state.is_active() {
221        return Err(JsonRpcError {
222            code: TASK_NOT_CANCELABLE,
223            message: "Task is already in a terminal state".to_string(),
224            data: None,
225        });
226    }
227
228    task.status.state = TaskState::Cancelled;
229    task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
230
231    serde_json::to_value(task.value().clone())
232        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
233}