1use super::types::*;
4use crate::session::Session;
5use anyhow::Result;
6use axum::{
7 Router,
8 extract::State,
9 http::StatusCode,
10 response::Json,
11 routing::{get, post},
12};
13use dashmap::DashMap;
14use std::sync::Arc;
15use uuid::Uuid;
16
17#[derive(Clone)]
19pub struct A2AServer {
20 tasks: Arc<DashMap<String, Task>>,
21 agent_card: AgentCard,
22}
23
24impl A2AServer {
25 pub fn new(agent_card: AgentCard) -> Self {
27 Self {
28 tasks: Arc::new(DashMap::new()),
29 agent_card,
30 }
31 }
32
33 pub fn router(self) -> Router {
35 Router::new()
36 .route("/.well-known/agent.json", get(get_agent_card))
37 .route("/", post(handle_rpc))
38 .with_state(self)
39 }
40
41 #[allow(dead_code)]
43 pub fn card(&self) -> &AgentCard {
44 &self.agent_card
45 }
46
47 pub fn default_card(url: &str) -> AgentCard {
49 AgentCard {
50 name: "CodeTether Agent".to_string(),
51 description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
52 url: url.to_string(),
53 version: env!("CARGO_PKG_VERSION").to_string(),
54 protocol_version: "0.3.0".to_string(),
55 capabilities: AgentCapabilities {
56 streaming: true,
57 push_notifications: false,
58 state_transition_history: true,
59 },
60 skills: vec![
61 AgentSkill {
62 id: "code".to_string(),
63 name: "Code Generation".to_string(),
64 description: "Write, edit, and refactor code".to_string(),
65 tags: vec!["code".to_string(), "programming".to_string()],
66 examples: vec![
67 "Write a function to parse JSON".to_string(),
68 "Refactor this code to use async/await".to_string(),
69 ],
70 input_modes: vec!["text/plain".to_string()],
71 output_modes: vec!["text/plain".to_string()],
72 },
73 AgentSkill {
74 id: "debug".to_string(),
75 name: "Debugging".to_string(),
76 description: "Debug and fix code issues".to_string(),
77 tags: vec!["debug".to_string(), "fix".to_string()],
78 examples: vec![
79 "Why is this function returning undefined?".to_string(),
80 "Fix the null pointer exception".to_string(),
81 ],
82 input_modes: vec!["text/plain".to_string()],
83 output_modes: vec!["text/plain".to_string()],
84 },
85 AgentSkill {
86 id: "explain".to_string(),
87 name: "Code Explanation".to_string(),
88 description: "Explain code and concepts".to_string(),
89 tags: vec!["explain".to_string(), "learn".to_string()],
90 examples: vec![
91 "Explain how this algorithm works".to_string(),
92 "What does this regex do?".to_string(),
93 ],
94 input_modes: vec!["text/plain".to_string()],
95 output_modes: vec!["text/plain".to_string()],
96 },
97 ],
98 default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
99 default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
100 provider: Some(AgentProvider {
101 organization: "CodeTether".to_string(),
102 url: "https://codetether.ai".to_string(),
103 }),
104 icon_url: None,
105 documentation_url: None,
106 }
107 }
108}
109
110async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
112 Json(server.agent_card.clone())
113}
114
115async fn handle_rpc(
117 State(server): State<A2AServer>,
118 Json(request): Json<JsonRpcRequest>,
119) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
120 let request_id = request.id.clone();
121 let response = match request.method.as_str() {
122 "message/send" => handle_message_send(&server, request).await,
123 "message/stream" => handle_message_stream(&server, request).await,
124 "tasks/get" => handle_tasks_get(&server, request).await,
125 "tasks/cancel" => handle_tasks_cancel(&server, request).await,
126 _ => Err(JsonRpcError::method_not_found(&request.method)),
127 };
128
129 match response {
130 Ok(result) => Ok(Json(JsonRpcResponse {
131 jsonrpc: "2.0".to_string(),
132 id: request_id.clone(),
133 result: Some(result),
134 error: None,
135 })),
136 Err(error) => Err((
137 StatusCode::OK,
138 Json(JsonRpcResponse {
139 jsonrpc: "2.0".to_string(),
140 id: request_id,
141 result: None,
142 error: Some(error),
143 }),
144 )),
145 }
146}
147
148async fn handle_message_send(
149 server: &A2AServer,
150 request: JsonRpcRequest,
151) -> Result<serde_json::Value, JsonRpcError> {
152 let params: MessageSendParams = serde_json::from_value(request.params)
153 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
154
155 let task_id = params
157 .message
158 .task_id
159 .clone()
160 .unwrap_or_else(|| Uuid::new_v4().to_string());
161
162 let task = Task {
163 id: task_id.clone(),
164 context_id: params.message.context_id.clone(),
165 status: TaskStatus {
166 state: TaskState::Working,
167 message: Some(params.message.clone()),
168 timestamp: Some(chrono::Utc::now().to_rfc3339()),
169 },
170 artifacts: vec![],
171 history: vec![params.message.clone()],
172 metadata: std::collections::HashMap::new(),
173 };
174
175 server.tasks.insert(task_id.clone(), task.clone());
176
177 let prompt: String = params
179 .message
180 .parts
181 .iter()
182 .filter_map(|p| match p {
183 Part::Text { text } => Some(text.as_str()),
184 _ => None,
185 })
186 .collect::<Vec<_>>()
187 .join("\n");
188
189 if prompt.is_empty() {
190 if let Some(mut t) = server.tasks.get_mut(&task_id) {
192 t.status.state = TaskState::Failed;
193 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
194 }
195 return Err(JsonRpcError::invalid_params("No text content in message"));
196 }
197
198 let blocking = params
200 .configuration
201 .as_ref()
202 .and_then(|c| c.blocking)
203 .unwrap_or(true);
204
205 if blocking {
206 let mut session = Session::new().await.map_err(|e| {
208 JsonRpcError::internal_error(format!("Failed to create session: {}", e))
209 })?;
210
211 match session.prompt(&prompt).await {
212 Ok(result) => {
213 let response_message = Message {
214 message_id: Uuid::new_v4().to_string(),
215 role: MessageRole::Agent,
216 parts: vec![Part::Text {
217 text: result.text.clone(),
218 }],
219 context_id: params.message.context_id.clone(),
220 task_id: Some(task_id.clone()),
221 metadata: std::collections::HashMap::new(),
222 };
223
224 let artifact = Artifact {
225 artifact_id: Uuid::new_v4().to_string(),
226 parts: vec![Part::Text { text: result.text }],
227 name: Some("response".to_string()),
228 description: None,
229 metadata: std::collections::HashMap::new(),
230 };
231
232 if let Some(mut t) = server.tasks.get_mut(&task_id) {
233 t.status.state = TaskState::Completed;
234 t.status.message = Some(response_message.clone());
235 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
236 t.artifacts.push(artifact);
237 t.history.push(response_message);
238 }
239 }
240 Err(e) => {
241 let error_message = Message {
242 message_id: Uuid::new_v4().to_string(),
243 role: MessageRole::Agent,
244 parts: vec![Part::Text {
245 text: format!("Error: {}", e),
246 }],
247 context_id: params.message.context_id.clone(),
248 task_id: Some(task_id.clone()),
249 metadata: std::collections::HashMap::new(),
250 };
251
252 if let Some(mut t) = server.tasks.get_mut(&task_id) {
253 t.status.state = TaskState::Failed;
254 t.status.message = Some(error_message);
255 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
256 }
257 }
258 }
259 } else {
260 let tasks = server.tasks.clone();
262 let context_id = params.message.context_id.clone();
263 let spawn_task_id = task_id.clone();
264
265 tokio::spawn(async move {
266 let task_id = spawn_task_id;
267 let mut session = match Session::new().await {
268 Ok(s) => s,
269 Err(e) => {
270 tracing::error!("Failed to create session for task {}: {}", task_id, e);
271 if let Some(mut t) = tasks.get_mut(&task_id) {
272 t.status.state = TaskState::Failed;
273 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
274 }
275 return;
276 }
277 };
278
279 match session.prompt(&prompt).await {
280 Ok(result) => {
281 let response_message = Message {
282 message_id: Uuid::new_v4().to_string(),
283 role: MessageRole::Agent,
284 parts: vec![Part::Text {
285 text: result.text.clone(),
286 }],
287 context_id,
288 task_id: Some(task_id.clone()),
289 metadata: std::collections::HashMap::new(),
290 };
291
292 let artifact = Artifact {
293 artifact_id: Uuid::new_v4().to_string(),
294 parts: vec![Part::Text { text: result.text }],
295 name: Some("response".to_string()),
296 description: None,
297 metadata: std::collections::HashMap::new(),
298 };
299
300 if let Some(mut t) = tasks.get_mut(&task_id) {
301 t.status.state = TaskState::Completed;
302 t.status.message = Some(response_message.clone());
303 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
304 t.artifacts.push(artifact);
305 t.history.push(response_message);
306 }
307 }
308 Err(e) => {
309 tracing::error!("Task {} failed: {}", task_id, e);
310 if let Some(mut t) = tasks.get_mut(&task_id) {
311 t.status.state = TaskState::Failed;
312 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
313 }
314 }
315 }
316 });
317 }
318
319 let task = server.tasks.get(&task_id).unwrap();
321 serde_json::to_value(task.value().clone())
322 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
323}
324
325async fn handle_message_stream(
326 server: &A2AServer,
327 request: JsonRpcRequest,
328) -> Result<serde_json::Value, JsonRpcError> {
329 let params: MessageSendParams = serde_json::from_value(request.params)
334 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
335
336 let task_id = params
337 .message
338 .task_id
339 .clone()
340 .unwrap_or_else(|| Uuid::new_v4().to_string());
341
342 let task = Task {
343 id: task_id.clone(),
344 context_id: params.message.context_id.clone(),
345 status: TaskStatus {
346 state: TaskState::Working,
347 message: Some(params.message.clone()),
348 timestamp: Some(chrono::Utc::now().to_rfc3339()),
349 },
350 artifacts: vec![],
351 history: vec![params.message.clone()],
352 metadata: std::collections::HashMap::new(),
353 };
354
355 server.tasks.insert(task_id.clone(), task.clone());
356
357 let prompt: String = params
359 .message
360 .parts
361 .iter()
362 .filter_map(|p| match p {
363 Part::Text { text } => Some(text.as_str()),
364 _ => None,
365 })
366 .collect::<Vec<_>>()
367 .join("\n");
368
369 if prompt.is_empty() {
370 if let Some(mut t) = server.tasks.get_mut(&task_id) {
371 t.status.state = TaskState::Failed;
372 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
373 }
374 return Err(JsonRpcError::invalid_params("No text content in message"));
375 }
376
377 let tasks = server.tasks.clone();
379 let context_id = params.message.context_id.clone();
380 let spawn_task_id = task_id.clone();
381
382 tokio::spawn(async move {
383 let task_id = spawn_task_id;
384 let mut session = match Session::new().await {
385 Ok(s) => s,
386 Err(e) => {
387 tracing::error!("Failed to create session for stream task {}: {}", task_id, e);
388 if let Some(mut t) = tasks.get_mut(&task_id) {
389 t.status.state = TaskState::Failed;
390 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
391 }
392 return;
393 }
394 };
395
396 match session.prompt(&prompt).await {
397 Ok(result) => {
398 let response_message = Message {
399 message_id: Uuid::new_v4().to_string(),
400 role: MessageRole::Agent,
401 parts: vec![Part::Text {
402 text: result.text.clone(),
403 }],
404 context_id,
405 task_id: Some(task_id.clone()),
406 metadata: std::collections::HashMap::new(),
407 };
408
409 let artifact = Artifact {
410 artifact_id: Uuid::new_v4().to_string(),
411 parts: vec![Part::Text { text: result.text }],
412 name: Some("response".to_string()),
413 description: None,
414 metadata: std::collections::HashMap::new(),
415 };
416
417 if let Some(mut t) = tasks.get_mut(&task_id) {
418 t.status.state = TaskState::Completed;
419 t.status.message = Some(response_message.clone());
420 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
421 t.artifacts.push(artifact);
422 t.history.push(response_message);
423 }
424 }
425 Err(e) => {
426 tracing::error!("Stream task {} failed: {}", task_id, e);
427 if let Some(mut t) = tasks.get_mut(&task_id) {
428 t.status.state = TaskState::Failed;
429 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
430 }
431 }
432 }
433 });
434
435 serde_json::to_value(task)
437 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
438}
439
440async fn handle_tasks_get(
441 server: &A2AServer,
442 request: JsonRpcRequest,
443) -> Result<serde_json::Value, JsonRpcError> {
444 let params: TaskQueryParams = serde_json::from_value(request.params)
445 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
446
447 let task = server.tasks.get(¶ms.id).ok_or_else(|| JsonRpcError {
448 code: TASK_NOT_FOUND,
449 message: format!("Task not found: {}", params.id),
450 data: None,
451 })?;
452
453 serde_json::to_value(task.value().clone())
454 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
455}
456
457async fn handle_tasks_cancel(
458 server: &A2AServer,
459 request: JsonRpcRequest,
460) -> Result<serde_json::Value, JsonRpcError> {
461 let params: TaskQueryParams = serde_json::from_value(request.params)
462 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
463
464 let mut task = server
465 .tasks
466 .get_mut(¶ms.id)
467 .ok_or_else(|| JsonRpcError {
468 code: TASK_NOT_FOUND,
469 message: format!("Task not found: {}", params.id),
470 data: None,
471 })?;
472
473 if !task.status.state.is_active() {
474 return Err(JsonRpcError {
475 code: TASK_NOT_CANCELABLE,
476 message: "Task is already in a terminal state".to_string(),
477 data: None,
478 });
479 }
480
481 task.status.state = TaskState::Cancelled;
482 task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
483
484 serde_json::to_value(task.value().clone())
485 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
486}