1use super::types::*;
4use crate::session::Session;
5use anyhow::Result;
6use axum::{
7 Router,
8 extract::State,
9 http::StatusCode,
10 response::Json,
11 routing::{get, post},
12};
13use dashmap::DashMap;
14use std::sync::Arc;
15use uuid::Uuid;
16
17#[derive(Clone)]
19pub struct A2AServer {
20 tasks: Arc<DashMap<String, Task>>,
21 agent_card: AgentCard,
22}
23
24impl A2AServer {
25 pub fn new(agent_card: AgentCard) -> Self {
27 Self {
28 tasks: Arc::new(DashMap::new()),
29 agent_card,
30 }
31 }
32
33 pub fn router(self) -> Router {
35 Router::new()
36 .route("/.well-known/agent.json", get(get_agent_card))
37 .route("/", post(handle_rpc))
38 .with_state(self)
39 }
40
41 #[allow(dead_code)]
43 pub fn card(&self) -> &AgentCard {
44 &self.agent_card
45 }
46
47 pub fn default_card(url: &str) -> AgentCard {
49 AgentCard {
50 name: "CodeTether Agent".to_string(),
51 description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
52 url: url.to_string(),
53 version: env!("CARGO_PKG_VERSION").to_string(),
54 protocol_version: "0.3.0".to_string(),
55 capabilities: AgentCapabilities {
56 streaming: true,
57 push_notifications: false,
58 state_transition_history: true,
59 },
60 skills: vec![
61 AgentSkill {
62 id: "code".to_string(),
63 name: "Code Generation".to_string(),
64 description: "Write, edit, and refactor code".to_string(),
65 tags: vec!["code".to_string(), "programming".to_string()],
66 examples: vec![
67 "Write a function to parse JSON".to_string(),
68 "Refactor this code to use async/await".to_string(),
69 ],
70 input_modes: vec!["text/plain".to_string()],
71 output_modes: vec!["text/plain".to_string()],
72 },
73 AgentSkill {
74 id: "debug".to_string(),
75 name: "Debugging".to_string(),
76 description: "Debug and fix code issues".to_string(),
77 tags: vec!["debug".to_string(), "fix".to_string()],
78 examples: vec![
79 "Why is this function returning undefined?".to_string(),
80 "Fix the null pointer exception".to_string(),
81 ],
82 input_modes: vec!["text/plain".to_string()],
83 output_modes: vec!["text/plain".to_string()],
84 },
85 AgentSkill {
86 id: "explain".to_string(),
87 name: "Code Explanation".to_string(),
88 description: "Explain code and concepts".to_string(),
89 tags: vec!["explain".to_string(), "learn".to_string()],
90 examples: vec![
91 "Explain how this algorithm works".to_string(),
92 "What does this regex do?".to_string(),
93 ],
94 input_modes: vec!["text/plain".to_string()],
95 output_modes: vec!["text/plain".to_string()],
96 },
97 ],
98 default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
99 default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
100 provider: Some(AgentProvider {
101 organization: "CodeTether".to_string(),
102 url: "https://codetether.ai".to_string(),
103 }),
104 icon_url: None,
105 documentation_url: None,
106 }
107 }
108}
109
110async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
112 Json(server.agent_card.clone())
113}
114
115async fn handle_rpc(
117 State(server): State<A2AServer>,
118 Json(request): Json<JsonRpcRequest>,
119) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
120 let request_id = request.id.clone();
121 let response = match request.method.as_str() {
122 "message/send" => handle_message_send(&server, request).await,
123 "message/stream" => handle_message_stream(&server, request).await,
124 "tasks/get" => handle_tasks_get(&server, request).await,
125 "tasks/cancel" => handle_tasks_cancel(&server, request).await,
126 _ => Err(JsonRpcError::method_not_found(&request.method)),
127 };
128
129 match response {
130 Ok(result) => Ok(Json(JsonRpcResponse {
131 jsonrpc: "2.0".to_string(),
132 id: request_id.clone(),
133 result: Some(result),
134 error: None,
135 })),
136 Err(error) => Err((
137 StatusCode::OK,
138 Json(JsonRpcResponse {
139 jsonrpc: "2.0".to_string(),
140 id: request_id,
141 result: None,
142 error: Some(error),
143 }),
144 )),
145 }
146}
147
148async fn handle_message_send(
149 server: &A2AServer,
150 request: JsonRpcRequest,
151) -> Result<serde_json::Value, JsonRpcError> {
152 let params: MessageSendParams = serde_json::from_value(request.params)
153 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
154
155 let task_id = params
157 .message
158 .task_id
159 .clone()
160 .unwrap_or_else(|| Uuid::new_v4().to_string());
161
162 let task = Task {
163 id: task_id.clone(),
164 context_id: params.message.context_id.clone(),
165 status: TaskStatus {
166 state: TaskState::Working,
167 message: Some(params.message.clone()),
168 timestamp: Some(chrono::Utc::now().to_rfc3339()),
169 },
170 artifacts: vec![],
171 history: vec![params.message.clone()],
172 metadata: std::collections::HashMap::new(),
173 };
174
175 server.tasks.insert(task_id.clone(), task.clone());
176
177 let prompt: String = params
179 .message
180 .parts
181 .iter()
182 .filter_map(|p| match p {
183 Part::Text { text } => Some(text.as_str()),
184 _ => None,
185 })
186 .collect::<Vec<_>>()
187 .join("\n");
188
189 if prompt.is_empty() {
190 if let Some(mut t) = server.tasks.get_mut(&task_id) {
192 t.status.state = TaskState::Failed;
193 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
194 }
195 return Err(JsonRpcError::invalid_params("No text content in message"));
196 }
197
198 let blocking = params
200 .configuration
201 .as_ref()
202 .and_then(|c| c.blocking)
203 .unwrap_or(true);
204
205 if blocking {
206 let mut session = Session::new().await.map_err(|e| {
208 JsonRpcError::internal_error(format!("Failed to create session: {}", e))
209 })?;
210
211 match session.prompt(&prompt).await {
212 Ok(result) => {
213 let response_message = Message {
214 message_id: Uuid::new_v4().to_string(),
215 role: MessageRole::Agent,
216 parts: vec![Part::Text {
217 text: result.text.clone(),
218 }],
219 context_id: params.message.context_id.clone(),
220 task_id: Some(task_id.clone()),
221 metadata: std::collections::HashMap::new(),
222 };
223
224 let artifact = Artifact {
225 artifact_id: Uuid::new_v4().to_string(),
226 parts: vec![Part::Text { text: result.text }],
227 name: Some("response".to_string()),
228 description: None,
229 metadata: std::collections::HashMap::new(),
230 };
231
232 if let Some(mut t) = server.tasks.get_mut(&task_id) {
233 t.status.state = TaskState::Completed;
234 t.status.message = Some(response_message.clone());
235 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
236 t.artifacts.push(artifact);
237 t.history.push(response_message);
238 }
239 }
240 Err(e) => {
241 let error_message = Message {
242 message_id: Uuid::new_v4().to_string(),
243 role: MessageRole::Agent,
244 parts: vec![Part::Text {
245 text: format!("Error: {}", e),
246 }],
247 context_id: params.message.context_id.clone(),
248 task_id: Some(task_id.clone()),
249 metadata: std::collections::HashMap::new(),
250 };
251
252 if let Some(mut t) = server.tasks.get_mut(&task_id) {
253 t.status.state = TaskState::Failed;
254 t.status.message = Some(error_message);
255 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
256 }
257 }
258 }
259 } else {
260 let tasks = server.tasks.clone();
262 let context_id = params.message.context_id.clone();
263 let spawn_task_id = task_id.clone();
264
265 tokio::spawn(async move {
266 let task_id = spawn_task_id;
267 let mut session = match Session::new().await {
268 Ok(s) => s,
269 Err(e) => {
270 tracing::error!("Failed to create session for task {}: {}", task_id, e);
271 if let Some(mut t) = tasks.get_mut(&task_id) {
272 t.status.state = TaskState::Failed;
273 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
274 }
275 return;
276 }
277 };
278
279 match session.prompt(&prompt).await {
280 Ok(result) => {
281 let response_message = Message {
282 message_id: Uuid::new_v4().to_string(),
283 role: MessageRole::Agent,
284 parts: vec![Part::Text {
285 text: result.text.clone(),
286 }],
287 context_id,
288 task_id: Some(task_id.clone()),
289 metadata: std::collections::HashMap::new(),
290 };
291
292 let artifact = Artifact {
293 artifact_id: Uuid::new_v4().to_string(),
294 parts: vec![Part::Text { text: result.text }],
295 name: Some("response".to_string()),
296 description: None,
297 metadata: std::collections::HashMap::new(),
298 };
299
300 if let Some(mut t) = tasks.get_mut(&task_id) {
301 t.status.state = TaskState::Completed;
302 t.status.message = Some(response_message.clone());
303 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
304 t.artifacts.push(artifact);
305 t.history.push(response_message);
306 }
307 }
308 Err(e) => {
309 tracing::error!("Task {} failed: {}", task_id, e);
310 if let Some(mut t) = tasks.get_mut(&task_id) {
311 t.status.state = TaskState::Failed;
312 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
313 }
314 }
315 }
316 });
317 }
318
319 let task = server.tasks.get(&task_id).unwrap();
321 serde_json::to_value(task.value().clone())
322 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
323}
324
325async fn handle_message_stream(
326 server: &A2AServer,
327 request: JsonRpcRequest,
328) -> Result<serde_json::Value, JsonRpcError> {
329 let params: MessageSendParams = serde_json::from_value(request.params)
334 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
335
336 let task_id = params
337 .message
338 .task_id
339 .clone()
340 .unwrap_or_else(|| Uuid::new_v4().to_string());
341
342 let task = Task {
343 id: task_id.clone(),
344 context_id: params.message.context_id.clone(),
345 status: TaskStatus {
346 state: TaskState::Working,
347 message: Some(params.message.clone()),
348 timestamp: Some(chrono::Utc::now().to_rfc3339()),
349 },
350 artifacts: vec![],
351 history: vec![params.message.clone()],
352 metadata: std::collections::HashMap::new(),
353 };
354
355 server.tasks.insert(task_id.clone(), task.clone());
356
357 let prompt: String = params
359 .message
360 .parts
361 .iter()
362 .filter_map(|p| match p {
363 Part::Text { text } => Some(text.as_str()),
364 _ => None,
365 })
366 .collect::<Vec<_>>()
367 .join("\n");
368
369 if prompt.is_empty() {
370 if let Some(mut t) = server.tasks.get_mut(&task_id) {
371 t.status.state = TaskState::Failed;
372 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
373 }
374 return Err(JsonRpcError::invalid_params("No text content in message"));
375 }
376
377 let tasks = server.tasks.clone();
379 let context_id = params.message.context_id.clone();
380 let spawn_task_id = task_id.clone();
381
382 tokio::spawn(async move {
383 let task_id = spawn_task_id;
384 let mut session = match Session::new().await {
385 Ok(s) => s,
386 Err(e) => {
387 tracing::error!(
388 "Failed to create session for stream task {}: {}",
389 task_id,
390 e
391 );
392 if let Some(mut t) = tasks.get_mut(&task_id) {
393 t.status.state = TaskState::Failed;
394 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
395 }
396 return;
397 }
398 };
399
400 match session.prompt(&prompt).await {
401 Ok(result) => {
402 let response_message = Message {
403 message_id: Uuid::new_v4().to_string(),
404 role: MessageRole::Agent,
405 parts: vec![Part::Text {
406 text: result.text.clone(),
407 }],
408 context_id,
409 task_id: Some(task_id.clone()),
410 metadata: std::collections::HashMap::new(),
411 };
412
413 let artifact = Artifact {
414 artifact_id: Uuid::new_v4().to_string(),
415 parts: vec![Part::Text { text: result.text }],
416 name: Some("response".to_string()),
417 description: None,
418 metadata: std::collections::HashMap::new(),
419 };
420
421 if let Some(mut t) = tasks.get_mut(&task_id) {
422 t.status.state = TaskState::Completed;
423 t.status.message = Some(response_message.clone());
424 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
425 t.artifacts.push(artifact);
426 t.history.push(response_message);
427 }
428 }
429 Err(e) => {
430 tracing::error!("Stream task {} failed: {}", task_id, e);
431 if let Some(mut t) = tasks.get_mut(&task_id) {
432 t.status.state = TaskState::Failed;
433 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
434 }
435 }
436 }
437 });
438
439 serde_json::to_value(task)
441 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
442}
443
444async fn handle_tasks_get(
445 server: &A2AServer,
446 request: JsonRpcRequest,
447) -> Result<serde_json::Value, JsonRpcError> {
448 let params: TaskQueryParams = serde_json::from_value(request.params)
449 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
450
451 let task = server.tasks.get(¶ms.id).ok_or_else(|| JsonRpcError {
452 code: TASK_NOT_FOUND,
453 message: format!("Task not found: {}", params.id),
454 data: None,
455 })?;
456
457 serde_json::to_value(task.value().clone())
458 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
459}
460
461async fn handle_tasks_cancel(
462 server: &A2AServer,
463 request: JsonRpcRequest,
464) -> Result<serde_json::Value, JsonRpcError> {
465 let params: TaskQueryParams = serde_json::from_value(request.params)
466 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
467
468 let mut task = server
469 .tasks
470 .get_mut(¶ms.id)
471 .ok_or_else(|| JsonRpcError {
472 code: TASK_NOT_FOUND,
473 message: format!("Task not found: {}", params.id),
474 data: None,
475 })?;
476
477 if !task.status.state.is_active() {
478 return Err(JsonRpcError {
479 code: TASK_NOT_CANCELABLE,
480 message: "Task is already in a terminal state".to_string(),
481 data: None,
482 });
483 }
484
485 task.status.state = TaskState::Cancelled;
486 task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
487
488 serde_json::to_value(task.value().clone())
489 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
490}