1use super::types::*;
4use anyhow::Result;
5use axum::{
6 Router,
7 extract::State,
8 http::StatusCode,
9 response::Json,
10 routing::{get, post},
11};
12use dashmap::DashMap;
13use std::sync::Arc;
14use uuid::Uuid;
15
16#[derive(Clone)]
18pub struct A2AServer {
19 tasks: Arc<DashMap<String, Task>>,
20 agent_card: AgentCard,
21}
22
23impl A2AServer {
24 pub fn new(agent_card: AgentCard) -> Self {
26 Self {
27 tasks: Arc::new(DashMap::new()),
28 agent_card,
29 }
30 }
31
32 pub fn router(self) -> Router {
34 Router::new()
35 .route("/.well-known/agent.json", get(get_agent_card))
36 .route("/", post(handle_rpc))
37 .with_state(self)
38 }
39
40 #[allow(dead_code)]
42 pub fn card(&self) -> &AgentCard {
43 &self.agent_card
44 }
45
46 pub fn default_card(url: &str) -> AgentCard {
48 AgentCard {
49 name: "CodeTether Agent".to_string(),
50 description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
51 url: url.to_string(),
52 version: env!("CARGO_PKG_VERSION").to_string(),
53 protocol_version: "0.3.0".to_string(),
54 capabilities: AgentCapabilities {
55 streaming: true,
56 push_notifications: false,
57 state_transition_history: true,
58 },
59 skills: vec![
60 AgentSkill {
61 id: "code".to_string(),
62 name: "Code Generation".to_string(),
63 description: "Write, edit, and refactor code".to_string(),
64 tags: vec!["code".to_string(), "programming".to_string()],
65 examples: vec![
66 "Write a function to parse JSON".to_string(),
67 "Refactor this code to use async/await".to_string(),
68 ],
69 input_modes: vec!["text/plain".to_string()],
70 output_modes: vec!["text/plain".to_string()],
71 },
72 AgentSkill {
73 id: "debug".to_string(),
74 name: "Debugging".to_string(),
75 description: "Debug and fix code issues".to_string(),
76 tags: vec!["debug".to_string(), "fix".to_string()],
77 examples: vec![
78 "Why is this function returning undefined?".to_string(),
79 "Fix the null pointer exception".to_string(),
80 ],
81 input_modes: vec!["text/plain".to_string()],
82 output_modes: vec!["text/plain".to_string()],
83 },
84 AgentSkill {
85 id: "explain".to_string(),
86 name: "Code Explanation".to_string(),
87 description: "Explain code and concepts".to_string(),
88 tags: vec!["explain".to_string(), "learn".to_string()],
89 examples: vec![
90 "Explain how this algorithm works".to_string(),
91 "What does this regex do?".to_string(),
92 ],
93 input_modes: vec!["text/plain".to_string()],
94 output_modes: vec!["text/plain".to_string()],
95 },
96 ],
97 default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
98 default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
99 provider: Some(AgentProvider {
100 organization: "CodeTether".to_string(),
101 url: "https://codetether.ai".to_string(),
102 }),
103 icon_url: None,
104 documentation_url: None,
105 }
106 }
107}
108
109async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
111 Json(server.agent_card.clone())
112}
113
114async fn handle_rpc(
116 State(server): State<A2AServer>,
117 Json(request): Json<JsonRpcRequest>,
118) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
119 let request_id = request.id.clone();
120 let response = match request.method.as_str() {
121 "message/send" => handle_message_send(&server, request).await,
122 "message/stream" => handle_message_stream(&server, request).await,
123 "tasks/get" => handle_tasks_get(&server, request).await,
124 "tasks/cancel" => handle_tasks_cancel(&server, request).await,
125 _ => Err(JsonRpcError::method_not_found(&request.method)),
126 };
127
128 match response {
129 Ok(result) => Ok(Json(JsonRpcResponse {
130 jsonrpc: "2.0".to_string(),
131 id: request_id.clone(),
132 result: Some(result),
133 error: None,
134 })),
135 Err(error) => Err((
136 StatusCode::OK,
137 Json(JsonRpcResponse {
138 jsonrpc: "2.0".to_string(),
139 id: request_id,
140 result: None,
141 error: Some(error),
142 }),
143 )),
144 }
145}
146
147async fn handle_message_send(
148 server: &A2AServer,
149 request: JsonRpcRequest,
150) -> Result<serde_json::Value, JsonRpcError> {
151 let params: MessageSendParams = serde_json::from_value(request.params)
152 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
153
154 let task_id = params
156 .message
157 .task_id
158 .clone()
159 .unwrap_or_else(|| Uuid::new_v4().to_string());
160
161 let task = Task {
162 id: task_id.clone(),
163 context_id: params.message.context_id.clone(),
164 status: TaskStatus {
165 state: TaskState::Submitted,
166 message: Some(params.message.clone()),
167 timestamp: Some(chrono::Utc::now().to_rfc3339()),
168 },
169 artifacts: vec![],
170 history: vec![params.message],
171 metadata: std::collections::HashMap::new(),
172 };
173
174 server.tasks.insert(task_id.clone(), task.clone());
175
176 serde_json::to_value(task)
179 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
180}
181
182async fn handle_message_stream(
183 _server: &A2AServer,
184 _request: JsonRpcRequest,
185) -> Result<serde_json::Value, JsonRpcError> {
186 Err(JsonRpcError {
188 code: UNSUPPORTED_OPERATION,
189 message: "Streaming not yet implemented".to_string(),
190 data: None,
191 })
192}
193
194async fn handle_tasks_get(
195 server: &A2AServer,
196 request: JsonRpcRequest,
197) -> Result<serde_json::Value, JsonRpcError> {
198 let params: TaskQueryParams = serde_json::from_value(request.params)
199 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
200
201 let task = server.tasks.get(¶ms.id).ok_or_else(|| JsonRpcError {
202 code: TASK_NOT_FOUND,
203 message: format!("Task not found: {}", params.id),
204 data: None,
205 })?;
206
207 serde_json::to_value(task.value().clone())
208 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
209}
210
211async fn handle_tasks_cancel(
212 server: &A2AServer,
213 request: JsonRpcRequest,
214) -> Result<serde_json::Value, JsonRpcError> {
215 let params: TaskQueryParams = serde_json::from_value(request.params)
216 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
217
218 let mut task = server
219 .tasks
220 .get_mut(¶ms.id)
221 .ok_or_else(|| JsonRpcError {
222 code: TASK_NOT_FOUND,
223 message: format!("Task not found: {}", params.id),
224 data: None,
225 })?;
226
227 if !task.status.state.is_active() {
228 return Err(JsonRpcError {
229 code: TASK_NOT_CANCELABLE,
230 message: "Task is already in a terminal state".to_string(),
231 data: None,
232 });
233 }
234
235 task.status.state = TaskState::Cancelled;
236 task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
237
238 serde_json::to_value(task.value().clone())
239 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
240}