1use super::types::*;
4use crate::session::Session;
5use crate::telemetry::{ToolExecution, record_persistent};
6use anyhow::Result;
7use axum::{
8 Router,
9 extract::State,
10 http::StatusCode,
11 response::Json,
12 routing::{get, post},
13};
14use dashmap::DashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use uuid::Uuid;
18
19#[derive(Clone)]
21pub struct A2AServer {
22 tasks: Arc<DashMap<String, Task>>,
23 agent_card: AgentCard,
24}
25
26impl A2AServer {
27 pub fn new(agent_card: AgentCard) -> Self {
29 Self {
30 tasks: Arc::new(DashMap::new()),
31 agent_card,
32 }
33 }
34
35 pub fn router(self) -> Router {
37 Router::new()
38 .route("/.well-known/agent.json", get(get_agent_card))
39 .route("/.well-known/agent-card.json", get(get_agent_card))
40 .route("/", post(handle_rpc))
41 .with_state(self)
42 }
43
44 #[allow(dead_code)]
46 pub fn card(&self) -> &AgentCard {
47 &self.agent_card
48 }
49
50 pub fn default_card(url: &str) -> AgentCard {
52 AgentCard {
53 name: "CodeTether Agent".to_string(),
54 description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
55 url: url.to_string(),
56 version: env!("CARGO_PKG_VERSION").to_string(),
57 protocol_version: "0.3.0".to_string(),
58 preferred_transport: None,
59 additional_interfaces: vec![],
60 capabilities: AgentCapabilities {
61 streaming: true,
62 push_notifications: false,
63 state_transition_history: true,
64 extensions: vec![],
65 },
66 skills: vec![
67 AgentSkill {
68 id: "code".to_string(),
69 name: "Code Generation".to_string(),
70 description: "Write, edit, and refactor code".to_string(),
71 tags: vec!["code".to_string(), "programming".to_string()],
72 examples: vec![
73 "Write a function to parse JSON".to_string(),
74 "Refactor this code to use async/await".to_string(),
75 ],
76 input_modes: vec!["text/plain".to_string()],
77 output_modes: vec!["text/plain".to_string()],
78 },
79 AgentSkill {
80 id: "debug".to_string(),
81 name: "Debugging".to_string(),
82 description: "Debug and fix code issues".to_string(),
83 tags: vec!["debug".to_string(), "fix".to_string()],
84 examples: vec![
85 "Why is this function returning undefined?".to_string(),
86 "Fix the null pointer exception".to_string(),
87 ],
88 input_modes: vec!["text/plain".to_string()],
89 output_modes: vec!["text/plain".to_string()],
90 },
91 AgentSkill {
92 id: "explain".to_string(),
93 name: "Code Explanation".to_string(),
94 description: "Explain code and concepts".to_string(),
95 tags: vec!["explain".to_string(), "learn".to_string()],
96 examples: vec![
97 "Explain how this algorithm works".to_string(),
98 "What does this regex do?".to_string(),
99 ],
100 input_modes: vec!["text/plain".to_string()],
101 output_modes: vec!["text/plain".to_string()],
102 },
103 ],
104 default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
105 default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
106 provider: Some(AgentProvider {
107 organization: "CodeTether".to_string(),
108 url: "https://codetether.run".to_string(),
109 }),
110 icon_url: None,
111 documentation_url: None,
112 security_schemes: Default::default(),
113 security: vec![],
114 supports_authenticated_extended_card: false,
115 signatures: vec![],
116 }
117 }
118}
119
120async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
122 Json(server.agent_card.clone())
123}
124
125fn record_a2a_message_telemetry(
126 tool_name: &str,
127 task_id: &str,
128 blocking: bool,
129 prompt: &str,
130 duration: Duration,
131 success: bool,
132 output: Option<String>,
133 error: Option<String>,
134) {
135 let record = crate::telemetry::A2AMessageRecord {
136 tool_name: tool_name.to_string(),
137 task_id: task_id.to_string(),
138 blocking,
139 prompt: prompt.to_string(),
140 duration_ms: duration.as_millis() as u64,
141 success,
142 output,
143 error,
144 timestamp: chrono::Utc::now(),
145 };
146 let _ = record_persistent("a2a_message", &serde_json::to_value(&record).unwrap_or_default());
147}
148
149async fn handle_rpc(
151 State(server): State<A2AServer>,
152 Json(request): Json<JsonRpcRequest>,
153) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
154 let request_id = request.id.clone();
155 let response = match request.method.as_str() {
156 "message/send" => handle_message_send(&server, request).await,
157 "message/stream" => handle_message_stream(&server, request).await,
158 "tasks/get" => handle_tasks_get(&server, request).await,
159 "tasks/cancel" => handle_tasks_cancel(&server, request).await,
160 _ => Err(JsonRpcError::method_not_found(&request.method)),
161 };
162
163 match response {
164 Ok(result) => Ok(Json(JsonRpcResponse {
165 jsonrpc: "2.0".to_string(),
166 id: request_id.clone(),
167 result: Some(result),
168 error: None,
169 })),
170 Err(error) => Err((
171 StatusCode::OK,
172 Json(JsonRpcResponse {
173 jsonrpc: "2.0".to_string(),
174 id: request_id,
175 result: None,
176 error: Some(error),
177 }),
178 )),
179 }
180}
181
182async fn handle_message_send(
183 server: &A2AServer,
184 request: JsonRpcRequest,
185) -> Result<serde_json::Value, JsonRpcError> {
186 let params: MessageSendParams = serde_json::from_value(request.params)
187 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
188
189 let task_id = params
191 .message
192 .task_id
193 .clone()
194 .unwrap_or_else(|| Uuid::new_v4().to_string());
195
196 let task = Task {
197 id: task_id.clone(),
198 context_id: params.message.context_id.clone(),
199 status: TaskStatus {
200 state: TaskState::Working,
201 message: Some(params.message.clone()),
202 timestamp: Some(chrono::Utc::now().to_rfc3339()),
203 },
204 artifacts: vec![],
205 history: vec![params.message.clone()],
206 metadata: std::collections::HashMap::new(),
207 };
208
209 server.tasks.insert(task_id.clone(), task.clone());
210
211 let prompt: String = params
213 .message
214 .parts
215 .iter()
216 .filter_map(|p| match p {
217 Part::Text { text } => Some(text.as_str()),
218 _ => None,
219 })
220 .collect::<Vec<_>>()
221 .join("\n");
222
223 if prompt.is_empty() {
224 if let Some(mut t) = server.tasks.get_mut(&task_id) {
226 t.status.state = TaskState::Failed;
227 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
228 }
229 return Err(JsonRpcError::invalid_params("No text content in message"));
230 }
231
232 let blocking = params
234 .configuration
235 .as_ref()
236 .and_then(|c| c.blocking)
237 .unwrap_or(true);
238
239 if blocking {
240 let mut session = Session::new().await.map_err(|e| {
242 JsonRpcError::internal_error(format!("Failed to create session: {}", e))
243 })?;
244 let started_at = Instant::now();
245
246 match session.prompt(&prompt).await {
247 Ok(result) => {
248 let result_text = result.text;
249 let response_message = Message {
250 message_id: Uuid::new_v4().to_string(),
251 role: MessageRole::Agent,
252 parts: vec![Part::Text {
253 text: result_text.clone(),
254 }],
255 context_id: params.message.context_id.clone(),
256 task_id: Some(task_id.clone()),
257 metadata: std::collections::HashMap::new(),
258 extensions: vec![],
259 };
260
261 let artifact = Artifact {
262 artifact_id: Uuid::new_v4().to_string(),
263 parts: vec![Part::Text {
264 text: result_text.clone(),
265 }],
266 name: Some("response".to_string()),
267 description: None,
268 metadata: std::collections::HashMap::new(),
269 extensions: vec![],
270 };
271
272 if let Some(mut t) = server.tasks.get_mut(&task_id) {
273 t.status.state = TaskState::Completed;
274 t.status.message = Some(response_message.clone());
275 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
276 t.artifacts.push(artifact.clone());
277 t.history.push(response_message);
278
279 let status_event = TaskStatusUpdateEvent {
280 id: task_id.clone(),
281 status: t.status.clone(),
282 is_final: true,
283 metadata: std::collections::HashMap::new(),
284 };
285 let artifact_event = TaskArtifactUpdateEvent {
286 id: task_id.clone(),
287 artifact,
288 metadata: std::collections::HashMap::new(),
289 };
290 tracing::debug!(
291 task_id = %task_id,
292 event = ?StreamEvent::StatusUpdate(status_event),
293 "Task completed"
294 );
295 tracing::debug!(
296 task_id = %task_id,
297 event = ?StreamEvent::ArtifactUpdate(artifact_event),
298 "Artifact produced"
299 );
300 }
301
302 record_a2a_message_telemetry(
303 "a2a_message_send",
304 &task_id,
305 true,
306 &prompt,
307 started_at.elapsed(),
308 true,
309 Some(result_text),
310 None,
311 );
312 }
313 Err(e) => {
314 let error_message = Message {
315 message_id: Uuid::new_v4().to_string(),
316 role: MessageRole::Agent,
317 parts: vec![Part::Text {
318 text: format!("Error: {}", e),
319 }],
320 context_id: params.message.context_id.clone(),
321 task_id: Some(task_id.clone()),
322 metadata: std::collections::HashMap::new(),
323 extensions: vec![],
324 };
325
326 if let Some(mut t) = server.tasks.get_mut(&task_id) {
327 t.status.state = TaskState::Failed;
328 t.status.message = Some(error_message);
329 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
330 }
331
332 record_a2a_message_telemetry(
333 "a2a_message_send",
334 &task_id,
335 true,
336 &prompt,
337 started_at.elapsed(),
338 false,
339 None,
340 Some(e.to_string()),
341 );
342 }
343 }
344 } else {
345 let tasks = server.tasks.clone();
347 let context_id = params.message.context_id.clone();
348 let spawn_task_id = task_id.clone();
349
350 tokio::spawn(async move {
351 let task_id = spawn_task_id;
352 let started_at = Instant::now();
353 let mut session = match Session::new().await {
354 Ok(s) => s,
355 Err(e) => {
356 tracing::error!("Failed to create session for task {}: {}", task_id, e);
357 if let Some(mut t) = tasks.get_mut(&task_id) {
358 t.status.state = TaskState::Failed;
359 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
360 }
361 record_a2a_message_telemetry(
362 "a2a_message_send",
363 &task_id,
364 false,
365 &prompt,
366 started_at.elapsed(),
367 false,
368 None,
369 Some(e.to_string()),
370 );
371 return;
372 }
373 };
374
375 match session.prompt(&prompt).await {
376 Ok(result) => {
377 let result_text = result.text;
378 let response_message = Message {
379 message_id: Uuid::new_v4().to_string(),
380 role: MessageRole::Agent,
381 parts: vec![Part::Text {
382 text: result_text.clone(),
383 }],
384 context_id,
385 task_id: Some(task_id.clone()),
386 metadata: std::collections::HashMap::new(),
387 extensions: vec![],
388 };
389
390 let artifact = Artifact {
391 artifact_id: Uuid::new_v4().to_string(),
392 parts: vec![Part::Text {
393 text: result_text.clone(),
394 }],
395 name: Some("response".to_string()),
396 description: None,
397 metadata: std::collections::HashMap::new(),
398 extensions: vec![],
399 };
400
401 if let Some(mut t) = tasks.get_mut(&task_id) {
402 t.status.state = TaskState::Completed;
403 t.status.message = Some(response_message.clone());
404 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
405 t.artifacts.push(artifact);
406 t.history.push(response_message);
407 }
408
409 record_a2a_message_telemetry(
410 "a2a_message_send",
411 &task_id,
412 false,
413 &prompt,
414 started_at.elapsed(),
415 true,
416 Some(result_text),
417 None,
418 );
419 }
420 Err(e) => {
421 tracing::error!("Task {} failed: {}", task_id, e);
422 if let Some(mut t) = tasks.get_mut(&task_id) {
423 t.status.state = TaskState::Failed;
424 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
425 }
426 record_a2a_message_telemetry(
427 "a2a_message_send",
428 &task_id,
429 false,
430 &prompt,
431 started_at.elapsed(),
432 false,
433 None,
434 Some(e.to_string()),
435 );
436 }
437 }
438 });
439 }
440
441 let task = server.tasks.get(&task_id).unwrap();
443 let response = SendMessageResponse::Task(task.value().clone());
444 serde_json::to_value(response)
445 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
446}
447
448async fn handle_message_stream(
449 server: &A2AServer,
450 request: JsonRpcRequest,
451) -> Result<serde_json::Value, JsonRpcError> {
452 let params: MessageSendParams = serde_json::from_value(request.params)
457 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
458
459 let task_id = params
460 .message
461 .task_id
462 .clone()
463 .unwrap_or_else(|| Uuid::new_v4().to_string());
464
465 let task = Task {
466 id: task_id.clone(),
467 context_id: params.message.context_id.clone(),
468 status: TaskStatus {
469 state: TaskState::Working,
470 message: Some(params.message.clone()),
471 timestamp: Some(chrono::Utc::now().to_rfc3339()),
472 },
473 artifacts: vec![],
474 history: vec![params.message.clone()],
475 metadata: std::collections::HashMap::new(),
476 };
477
478 server.tasks.insert(task_id.clone(), task.clone());
479
480 let prompt: String = params
482 .message
483 .parts
484 .iter()
485 .filter_map(|p| match p {
486 Part::Text { text } => Some(text.as_str()),
487 _ => None,
488 })
489 .collect::<Vec<_>>()
490 .join("\n");
491
492 if prompt.is_empty() {
493 if let Some(mut t) = server.tasks.get_mut(&task_id) {
494 t.status.state = TaskState::Failed;
495 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
496 }
497 return Err(JsonRpcError::invalid_params("No text content in message"));
498 }
499
500 let tasks = server.tasks.clone();
502 let context_id = params.message.context_id.clone();
503 let spawn_task_id = task_id.clone();
504
505 tokio::spawn(async move {
506 let task_id = spawn_task_id;
507 let started_at = Instant::now();
508 let mut session = match Session::new().await {
509 Ok(s) => s,
510 Err(e) => {
511 tracing::error!(
512 "Failed to create session for stream task {}: {}",
513 task_id,
514 e
515 );
516 if let Some(mut t) = tasks.get_mut(&task_id) {
517 t.status.state = TaskState::Failed;
518 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
519 }
520 record_a2a_message_telemetry(
521 "a2a_message_stream",
522 &task_id,
523 false,
524 &prompt,
525 started_at.elapsed(),
526 false,
527 None,
528 Some(e.to_string()),
529 );
530 return;
531 }
532 };
533
534 match session.prompt(&prompt).await {
535 Ok(result) => {
536 let result_text = result.text;
537 let response_message = Message {
538 message_id: Uuid::new_v4().to_string(),
539 role: MessageRole::Agent,
540 parts: vec![Part::Text {
541 text: result_text.clone(),
542 }],
543 context_id,
544 task_id: Some(task_id.clone()),
545 metadata: std::collections::HashMap::new(),
546 extensions: vec![],
547 };
548
549 let artifact = Artifact {
550 artifact_id: Uuid::new_v4().to_string(),
551 parts: vec![Part::Text {
552 text: result_text.clone(),
553 }],
554 name: Some("response".to_string()),
555 description: None,
556 metadata: std::collections::HashMap::new(),
557 extensions: vec![],
558 };
559
560 if let Some(mut t) = tasks.get_mut(&task_id) {
561 t.status.state = TaskState::Completed;
562 t.status.message = Some(response_message.clone());
563 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
564 t.artifacts.push(artifact.clone());
565 t.history.push(response_message);
566
567 let status_event = TaskStatusUpdateEvent {
569 id: task_id.clone(),
570 status: t.status.clone(),
571 is_final: true,
572 metadata: std::collections::HashMap::new(),
573 };
574 let artifact_event = TaskArtifactUpdateEvent {
575 id: task_id.clone(),
576 artifact,
577 metadata: std::collections::HashMap::new(),
578 };
579 tracing::debug!(
580 task_id = %task_id,
581 event = ?StreamEvent::StatusUpdate(status_event),
582 "Task completed"
583 );
584 tracing::debug!(
585 task_id = %task_id,
586 event = ?StreamEvent::ArtifactUpdate(artifact_event),
587 "Artifact produced"
588 );
589 }
590
591 record_a2a_message_telemetry(
592 "a2a_message_stream",
593 &task_id,
594 false,
595 &prompt,
596 started_at.elapsed(),
597 true,
598 Some(result_text),
599 None,
600 );
601 }
602 Err(e) => {
603 tracing::error!("Stream task {} failed: {}", task_id, e);
604 if let Some(mut t) = tasks.get_mut(&task_id) {
605 t.status.state = TaskState::Failed;
606 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
607 }
608 record_a2a_message_telemetry(
609 "a2a_message_stream",
610 &task_id,
611 false,
612 &prompt,
613 started_at.elapsed(),
614 false,
615 None,
616 Some(e.to_string()),
617 );
618 }
619 }
620 });
621
622 let response = SendMessageResponse::Task(task);
624 serde_json::to_value(response)
625 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
626}
627
628async fn handle_tasks_get(
629 server: &A2AServer,
630 request: JsonRpcRequest,
631) -> Result<serde_json::Value, JsonRpcError> {
632 let params: TaskQueryParams = serde_json::from_value(request.params)
633 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
634
635 let task = server.tasks.get(¶ms.id).ok_or_else(|| JsonRpcError {
636 code: TASK_NOT_FOUND,
637 message: format!("Task not found: {}", params.id),
638 data: None,
639 })?;
640
641 serde_json::to_value(task.value().clone())
642 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
643}
644
645async fn handle_tasks_cancel(
646 server: &A2AServer,
647 request: JsonRpcRequest,
648) -> Result<serde_json::Value, JsonRpcError> {
649 let params: TaskQueryParams = serde_json::from_value(request.params)
650 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
651
652 let mut task = server
653 .tasks
654 .get_mut(¶ms.id)
655 .ok_or_else(|| JsonRpcError {
656 code: TASK_NOT_FOUND,
657 message: format!("Task not found: {}", params.id),
658 data: None,
659 })?;
660
661 if !task.status.state.is_active() {
662 return Err(JsonRpcError {
663 code: TASK_NOT_CANCELABLE,
664 message: "Task is already in a terminal state".to_string(),
665 data: None,
666 });
667 }
668
669 task.status.state = TaskState::Cancelled;
670 task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
671
672 serde_json::to_value(task.value().clone())
673 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
674}