1use super::types::*;
4use anyhow::Result;
5use axum::{
6 extract::State,
7 http::StatusCode,
8 response::Json,
9 routing::{get, post},
10 Router,
11};
12use dashmap::DashMap;
13use std::sync::Arc;
14use uuid::Uuid;
15
/// State of the A2A JSON-RPC server.
///
/// The value is handed to the axum router as shared state, which is why it is
/// `Clone`. Clones share one task store through the `Arc`; the agent card is
/// cloned by value.
#[derive(Clone)]
pub struct A2AServer {
    /// Known tasks keyed by task id, shared between all clones of the server.
    tasks: Arc<DashMap<String, Task>>,
    /// Card returned from the `/.well-known/agent.json` route.
    agent_card: AgentCard,
}
22
23impl A2AServer {
24 pub fn new(agent_card: AgentCard) -> Self {
26 Self {
27 tasks: Arc::new(DashMap::new()),
28 agent_card,
29 }
30 }
31
32 pub fn router(self) -> Router {
34 Router::new()
35 .route("/.well-known/agent.json", get(get_agent_card))
36 .route("/", post(handle_rpc))
37 .with_state(self)
38 }
39
40 #[allow(dead_code)]
42 pub fn card(&self) -> &AgentCard {
43 &self.agent_card
44 }
45
46 pub fn default_card(url: &str) -> AgentCard {
48 AgentCard {
49 name: "CodeTether Agent".to_string(),
50 description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
51 url: url.to_string(),
52 version: env!("CARGO_PKG_VERSION").to_string(),
53 protocol_version: "0.3.0".to_string(),
54 capabilities: AgentCapabilities {
55 streaming: true,
56 push_notifications: false,
57 state_transition_history: true,
58 },
59 skills: vec![
60 AgentSkill {
61 id: "code".to_string(),
62 name: "Code Generation".to_string(),
63 description: "Write, edit, and refactor code".to_string(),
64 tags: vec!["code".to_string(), "programming".to_string()],
65 examples: vec![
66 "Write a function to parse JSON".to_string(),
67 "Refactor this code to use async/await".to_string(),
68 ],
69 input_modes: vec!["text/plain".to_string()],
70 output_modes: vec!["text/plain".to_string()],
71 },
72 AgentSkill {
73 id: "debug".to_string(),
74 name: "Debugging".to_string(),
75 description: "Debug and fix code issues".to_string(),
76 tags: vec!["debug".to_string(), "fix".to_string()],
77 examples: vec![
78 "Why is this function returning undefined?".to_string(),
79 "Fix the null pointer exception".to_string(),
80 ],
81 input_modes: vec!["text/plain".to_string()],
82 output_modes: vec!["text/plain".to_string()],
83 },
84 AgentSkill {
85 id: "explain".to_string(),
86 name: "Code Explanation".to_string(),
87 description: "Explain code and concepts".to_string(),
88 tags: vec!["explain".to_string(), "learn".to_string()],
89 examples: vec![
90 "Explain how this algorithm works".to_string(),
91 "What does this regex do?".to_string(),
92 ],
93 input_modes: vec!["text/plain".to_string()],
94 output_modes: vec!["text/plain".to_string()],
95 },
96 ],
97 default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
98 default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
99 provider: Some(AgentProvider {
100 organization: "CodeTether".to_string(),
101 url: "https://codetether.ai".to_string(),
102 }),
103 icon_url: None,
104 documentation_url: None,
105 }
106 }
107}
108
109async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
111 Json(server.agent_card.clone())
112}
113
114async fn handle_rpc(
116 State(server): State<A2AServer>,
117 Json(request): Json<JsonRpcRequest>,
118) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
119 let request_id = request.id.clone();
120 let response = match request.method.as_str() {
121 "message/send" => handle_message_send(&server, request).await,
122 "message/stream" => handle_message_stream(&server, request).await,
123 "tasks/get" => handle_tasks_get(&server, request).await,
124 "tasks/cancel" => handle_tasks_cancel(&server, request).await,
125 _ => Err(JsonRpcError::method_not_found(&request.method)),
126 };
127
128 match response {
129 Ok(result) => Ok(Json(JsonRpcResponse {
130 jsonrpc: "2.0".to_string(),
131 id: request_id.clone(),
132 result: Some(result),
133 error: None,
134 })),
135 Err(error) => Err((
136 StatusCode::OK,
137 Json(JsonRpcResponse {
138 jsonrpc: "2.0".to_string(),
139 id: request_id,
140 result: None,
141 error: Some(error),
142 }),
143 )),
144 }
145}
146
147async fn handle_message_send(
148 server: &A2AServer,
149 request: JsonRpcRequest,
150) -> Result<serde_json::Value, JsonRpcError> {
151 let params: MessageSendParams = serde_json::from_value(request.params)
152 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
153
154 let task_id = params.message.task_id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
156
157 let task = Task {
158 id: task_id.clone(),
159 context_id: params.message.context_id.clone(),
160 status: TaskStatus {
161 state: TaskState::Submitted,
162 message: Some(params.message.clone()),
163 timestamp: Some(chrono::Utc::now().to_rfc3339()),
164 },
165 artifacts: vec![],
166 history: vec![params.message],
167 metadata: std::collections::HashMap::new(),
168 };
169
170 server.tasks.insert(task_id.clone(), task.clone());
171
172 serde_json::to_value(task)
175 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
176}
177
178async fn handle_message_stream(
179 _server: &A2AServer,
180 _request: JsonRpcRequest,
181) -> Result<serde_json::Value, JsonRpcError> {
182 Err(JsonRpcError {
184 code: UNSUPPORTED_OPERATION,
185 message: "Streaming not yet implemented".to_string(),
186 data: None,
187 })
188}
189
190async fn handle_tasks_get(
191 server: &A2AServer,
192 request: JsonRpcRequest,
193) -> Result<serde_json::Value, JsonRpcError> {
194 let params: TaskQueryParams = serde_json::from_value(request.params)
195 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
196
197 let task = server.tasks.get(¶ms.id).ok_or_else(|| JsonRpcError {
198 code: TASK_NOT_FOUND,
199 message: format!("Task not found: {}", params.id),
200 data: None,
201 })?;
202
203 serde_json::to_value(task.value().clone())
204 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
205}
206
207async fn handle_tasks_cancel(
208 server: &A2AServer,
209 request: JsonRpcRequest,
210) -> Result<serde_json::Value, JsonRpcError> {
211 let params: TaskQueryParams = serde_json::from_value(request.params)
212 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
213
214 let mut task = server.tasks.get_mut(¶ms.id).ok_or_else(|| JsonRpcError {
215 code: TASK_NOT_FOUND,
216 message: format!("Task not found: {}", params.id),
217 data: None,
218 })?;
219
220 if !task.status.state.is_active() {
221 return Err(JsonRpcError {
222 code: TASK_NOT_CANCELABLE,
223 message: "Task is already in a terminal state".to_string(),
224 data: None,
225 });
226 }
227
228 task.status.state = TaskState::Cancelled;
229 task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
230
231 serde_json::to_value(task.value().clone())
232 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
233}