1use super::types::*;
4use crate::session::Session;
5use crate::telemetry::{ToolExecution, record_persistent};
6use anyhow::Result;
7use axum::{
8 Router,
9 extract::State,
10 http::StatusCode,
11 response::Json,
12 routing::{get, post},
13};
14use dashmap::DashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use uuid::Uuid;
18
19#[derive(Clone)]
21pub struct A2AServer {
22 tasks: Arc<DashMap<String, Task>>,
23 agent_card: AgentCard,
24}
25
26impl A2AServer {
27 pub fn new(agent_card: AgentCard) -> Self {
29 Self {
30 tasks: Arc::new(DashMap::new()),
31 agent_card,
32 }
33 }
34
35 pub fn router(self) -> Router {
37 Router::new()
38 .route("/.well-known/agent.json", get(get_agent_card))
39 .route("/", post(handle_rpc))
40 .with_state(self)
41 }
42
43 #[allow(dead_code)]
45 pub fn card(&self) -> &AgentCard {
46 &self.agent_card
47 }
48
49 pub fn default_card(url: &str) -> AgentCard {
51 AgentCard {
52 name: "CodeTether Agent".to_string(),
53 description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
54 url: url.to_string(),
55 version: env!("CARGO_PKG_VERSION").to_string(),
56 protocol_version: "0.3.0".to_string(),
57 preferred_transport: None,
58 additional_interfaces: vec![],
59 capabilities: AgentCapabilities {
60 streaming: true,
61 push_notifications: false,
62 state_transition_history: true,
63 extensions: vec![],
64 },
65 skills: vec![
66 AgentSkill {
67 id: "code".to_string(),
68 name: "Code Generation".to_string(),
69 description: "Write, edit, and refactor code".to_string(),
70 tags: vec!["code".to_string(), "programming".to_string()],
71 examples: vec![
72 "Write a function to parse JSON".to_string(),
73 "Refactor this code to use async/await".to_string(),
74 ],
75 input_modes: vec!["text/plain".to_string()],
76 output_modes: vec!["text/plain".to_string()],
77 },
78 AgentSkill {
79 id: "debug".to_string(),
80 name: "Debugging".to_string(),
81 description: "Debug and fix code issues".to_string(),
82 tags: vec!["debug".to_string(), "fix".to_string()],
83 examples: vec![
84 "Why is this function returning undefined?".to_string(),
85 "Fix the null pointer exception".to_string(),
86 ],
87 input_modes: vec!["text/plain".to_string()],
88 output_modes: vec!["text/plain".to_string()],
89 },
90 AgentSkill {
91 id: "explain".to_string(),
92 name: "Code Explanation".to_string(),
93 description: "Explain code and concepts".to_string(),
94 tags: vec!["explain".to_string(), "learn".to_string()],
95 examples: vec![
96 "Explain how this algorithm works".to_string(),
97 "What does this regex do?".to_string(),
98 ],
99 input_modes: vec!["text/plain".to_string()],
100 output_modes: vec!["text/plain".to_string()],
101 },
102 ],
103 default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
104 default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
105 provider: Some(AgentProvider {
106 organization: "CodeTether".to_string(),
107 url: "https://codetether.run".to_string(),
108 }),
109 icon_url: None,
110 documentation_url: None,
111 security_schemes: Default::default(),
112 security: vec![],
113 supports_authenticated_extended_card: false,
114 signatures: vec![],
115 }
116 }
117}
118
119async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
121 Json(server.agent_card.clone())
122}
123
124fn record_a2a_message_telemetry(
125 tool_name: &str,
126 task_id: &str,
127 blocking: bool,
128 prompt: &str,
129 duration: Duration,
130 success: bool,
131 output: Option<String>,
132 error: Option<String>,
133) {
134 let input = serde_json::json!({
135 "task_id": task_id,
136 "blocking": blocking,
137 "prompt": prompt,
138 });
139
140 let execution = ToolExecution::start(tool_name, input).with_session(task_id.to_string());
141 let execution = if success {
142 execution.complete_success(output.unwrap_or_default(), duration)
143 } else {
144 execution.complete_error(
145 error.unwrap_or_else(|| "A2A message execution failed".to_string()),
146 duration,
147 )
148 };
149
150 record_persistent(execution);
151}
152
153async fn handle_rpc(
155 State(server): State<A2AServer>,
156 Json(request): Json<JsonRpcRequest>,
157) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
158 let request_id = request.id.clone();
159 let response = match request.method.as_str() {
160 "message/send" => handle_message_send(&server, request).await,
161 "message/stream" => handle_message_stream(&server, request).await,
162 "tasks/get" => handle_tasks_get(&server, request).await,
163 "tasks/cancel" => handle_tasks_cancel(&server, request).await,
164 _ => Err(JsonRpcError::method_not_found(&request.method)),
165 };
166
167 match response {
168 Ok(result) => Ok(Json(JsonRpcResponse {
169 jsonrpc: "2.0".to_string(),
170 id: request_id.clone(),
171 result: Some(result),
172 error: None,
173 })),
174 Err(error) => Err((
175 StatusCode::OK,
176 Json(JsonRpcResponse {
177 jsonrpc: "2.0".to_string(),
178 id: request_id,
179 result: None,
180 error: Some(error),
181 }),
182 )),
183 }
184}
185
186async fn handle_message_send(
187 server: &A2AServer,
188 request: JsonRpcRequest,
189) -> Result<serde_json::Value, JsonRpcError> {
190 let params: MessageSendParams = serde_json::from_value(request.params)
191 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
192
193 let task_id = params
195 .message
196 .task_id
197 .clone()
198 .unwrap_or_else(|| Uuid::new_v4().to_string());
199
200 let task = Task {
201 id: task_id.clone(),
202 context_id: params.message.context_id.clone(),
203 status: TaskStatus {
204 state: TaskState::Working,
205 message: Some(params.message.clone()),
206 timestamp: Some(chrono::Utc::now().to_rfc3339()),
207 },
208 artifacts: vec![],
209 history: vec![params.message.clone()],
210 metadata: std::collections::HashMap::new(),
211 };
212
213 server.tasks.insert(task_id.clone(), task.clone());
214
215 let prompt: String = params
217 .message
218 .parts
219 .iter()
220 .filter_map(|p| match p {
221 Part::Text { text } => Some(text.as_str()),
222 _ => None,
223 })
224 .collect::<Vec<_>>()
225 .join("\n");
226
227 if prompt.is_empty() {
228 if let Some(mut t) = server.tasks.get_mut(&task_id) {
230 t.status.state = TaskState::Failed;
231 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
232 }
233 return Err(JsonRpcError::invalid_params("No text content in message"));
234 }
235
236 let blocking = params
238 .configuration
239 .as_ref()
240 .and_then(|c| c.blocking)
241 .unwrap_or(true);
242
243 if blocking {
244 let mut session = Session::new().await.map_err(|e| {
246 JsonRpcError::internal_error(format!("Failed to create session: {}", e))
247 })?;
248 let started_at = Instant::now();
249
250 match session.prompt(&prompt).await {
251 Ok(result) => {
252 let result_text = result.text;
253 let response_message = Message {
254 message_id: Uuid::new_v4().to_string(),
255 role: MessageRole::Agent,
256 parts: vec![Part::Text {
257 text: result_text.clone(),
258 }],
259 context_id: params.message.context_id.clone(),
260 task_id: Some(task_id.clone()),
261 metadata: std::collections::HashMap::new(),
262 extensions: vec![],
263 };
264
265 let artifact = Artifact {
266 artifact_id: Uuid::new_v4().to_string(),
267 parts: vec![Part::Text {
268 text: result_text.clone(),
269 }],
270 name: Some("response".to_string()),
271 description: None,
272 metadata: std::collections::HashMap::new(),
273 extensions: vec![],
274 };
275
276 if let Some(mut t) = server.tasks.get_mut(&task_id) {
277 t.status.state = TaskState::Completed;
278 t.status.message = Some(response_message.clone());
279 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
280 t.artifacts.push(artifact.clone());
281 t.history.push(response_message);
282
283 let status_event = TaskStatusUpdateEvent {
284 id: task_id.clone(),
285 status: t.status.clone(),
286 is_final: true,
287 metadata: std::collections::HashMap::new(),
288 };
289 let artifact_event = TaskArtifactUpdateEvent {
290 id: task_id.clone(),
291 artifact,
292 metadata: std::collections::HashMap::new(),
293 };
294 tracing::debug!(
295 task_id = %task_id,
296 event = ?StreamEvent::StatusUpdate(status_event),
297 "Task completed"
298 );
299 tracing::debug!(
300 task_id = %task_id,
301 event = ?StreamEvent::ArtifactUpdate(artifact_event),
302 "Artifact produced"
303 );
304 }
305
306 record_a2a_message_telemetry(
307 "a2a_message_send",
308 &task_id,
309 true,
310 &prompt,
311 started_at.elapsed(),
312 true,
313 Some(result_text),
314 None,
315 );
316 }
317 Err(e) => {
318 let error_message = Message {
319 message_id: Uuid::new_v4().to_string(),
320 role: MessageRole::Agent,
321 parts: vec![Part::Text {
322 text: format!("Error: {}", e),
323 }],
324 context_id: params.message.context_id.clone(),
325 task_id: Some(task_id.clone()),
326 metadata: std::collections::HashMap::new(),
327 extensions: vec![],
328 };
329
330 if let Some(mut t) = server.tasks.get_mut(&task_id) {
331 t.status.state = TaskState::Failed;
332 t.status.message = Some(error_message);
333 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
334 }
335
336 record_a2a_message_telemetry(
337 "a2a_message_send",
338 &task_id,
339 true,
340 &prompt,
341 started_at.elapsed(),
342 false,
343 None,
344 Some(e.to_string()),
345 );
346 }
347 }
348 } else {
349 let tasks = server.tasks.clone();
351 let context_id = params.message.context_id.clone();
352 let spawn_task_id = task_id.clone();
353
354 tokio::spawn(async move {
355 let task_id = spawn_task_id;
356 let started_at = Instant::now();
357 let mut session = match Session::new().await {
358 Ok(s) => s,
359 Err(e) => {
360 tracing::error!("Failed to create session for task {}: {}", task_id, e);
361 if let Some(mut t) = tasks.get_mut(&task_id) {
362 t.status.state = TaskState::Failed;
363 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
364 }
365 record_a2a_message_telemetry(
366 "a2a_message_send",
367 &task_id,
368 false,
369 &prompt,
370 started_at.elapsed(),
371 false,
372 None,
373 Some(e.to_string()),
374 );
375 return;
376 }
377 };
378
379 match session.prompt(&prompt).await {
380 Ok(result) => {
381 let result_text = result.text;
382 let response_message = Message {
383 message_id: Uuid::new_v4().to_string(),
384 role: MessageRole::Agent,
385 parts: vec![Part::Text {
386 text: result_text.clone(),
387 }],
388 context_id,
389 task_id: Some(task_id.clone()),
390 metadata: std::collections::HashMap::new(),
391 extensions: vec![],
392 };
393
394 let artifact = Artifact {
395 artifact_id: Uuid::new_v4().to_string(),
396 parts: vec![Part::Text {
397 text: result_text.clone(),
398 }],
399 name: Some("response".to_string()),
400 description: None,
401 metadata: std::collections::HashMap::new(),
402 extensions: vec![],
403 };
404
405 if let Some(mut t) = tasks.get_mut(&task_id) {
406 t.status.state = TaskState::Completed;
407 t.status.message = Some(response_message.clone());
408 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
409 t.artifacts.push(artifact);
410 t.history.push(response_message);
411 }
412
413 record_a2a_message_telemetry(
414 "a2a_message_send",
415 &task_id,
416 false,
417 &prompt,
418 started_at.elapsed(),
419 true,
420 Some(result_text),
421 None,
422 );
423 }
424 Err(e) => {
425 tracing::error!("Task {} failed: {}", task_id, e);
426 if let Some(mut t) = tasks.get_mut(&task_id) {
427 t.status.state = TaskState::Failed;
428 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
429 }
430 record_a2a_message_telemetry(
431 "a2a_message_send",
432 &task_id,
433 false,
434 &prompt,
435 started_at.elapsed(),
436 false,
437 None,
438 Some(e.to_string()),
439 );
440 }
441 }
442 });
443 }
444
445 let task = server.tasks.get(&task_id).unwrap();
447 let response = SendMessageResponse::Task(task.value().clone());
448 serde_json::to_value(response)
449 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
450}
451
452async fn handle_message_stream(
453 server: &A2AServer,
454 request: JsonRpcRequest,
455) -> Result<serde_json::Value, JsonRpcError> {
456 let params: MessageSendParams = serde_json::from_value(request.params)
461 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
462
463 let task_id = params
464 .message
465 .task_id
466 .clone()
467 .unwrap_or_else(|| Uuid::new_v4().to_string());
468
469 let task = Task {
470 id: task_id.clone(),
471 context_id: params.message.context_id.clone(),
472 status: TaskStatus {
473 state: TaskState::Working,
474 message: Some(params.message.clone()),
475 timestamp: Some(chrono::Utc::now().to_rfc3339()),
476 },
477 artifacts: vec![],
478 history: vec![params.message.clone()],
479 metadata: std::collections::HashMap::new(),
480 };
481
482 server.tasks.insert(task_id.clone(), task.clone());
483
484 let prompt: String = params
486 .message
487 .parts
488 .iter()
489 .filter_map(|p| match p {
490 Part::Text { text } => Some(text.as_str()),
491 _ => None,
492 })
493 .collect::<Vec<_>>()
494 .join("\n");
495
496 if prompt.is_empty() {
497 if let Some(mut t) = server.tasks.get_mut(&task_id) {
498 t.status.state = TaskState::Failed;
499 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
500 }
501 return Err(JsonRpcError::invalid_params("No text content in message"));
502 }
503
504 let tasks = server.tasks.clone();
506 let context_id = params.message.context_id.clone();
507 let spawn_task_id = task_id.clone();
508
509 tokio::spawn(async move {
510 let task_id = spawn_task_id;
511 let started_at = Instant::now();
512 let mut session = match Session::new().await {
513 Ok(s) => s,
514 Err(e) => {
515 tracing::error!(
516 "Failed to create session for stream task {}: {}",
517 task_id,
518 e
519 );
520 if let Some(mut t) = tasks.get_mut(&task_id) {
521 t.status.state = TaskState::Failed;
522 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
523 }
524 record_a2a_message_telemetry(
525 "a2a_message_stream",
526 &task_id,
527 false,
528 &prompt,
529 started_at.elapsed(),
530 false,
531 None,
532 Some(e.to_string()),
533 );
534 return;
535 }
536 };
537
538 match session.prompt(&prompt).await {
539 Ok(result) => {
540 let result_text = result.text;
541 let response_message = Message {
542 message_id: Uuid::new_v4().to_string(),
543 role: MessageRole::Agent,
544 parts: vec![Part::Text {
545 text: result_text.clone(),
546 }],
547 context_id,
548 task_id: Some(task_id.clone()),
549 metadata: std::collections::HashMap::new(),
550 extensions: vec![],
551 };
552
553 let artifact = Artifact {
554 artifact_id: Uuid::new_v4().to_string(),
555 parts: vec![Part::Text {
556 text: result_text.clone(),
557 }],
558 name: Some("response".to_string()),
559 description: None,
560 metadata: std::collections::HashMap::new(),
561 extensions: vec![],
562 };
563
564 if let Some(mut t) = tasks.get_mut(&task_id) {
565 t.status.state = TaskState::Completed;
566 t.status.message = Some(response_message.clone());
567 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
568 t.artifacts.push(artifact.clone());
569 t.history.push(response_message);
570
571 let status_event = TaskStatusUpdateEvent {
573 id: task_id.clone(),
574 status: t.status.clone(),
575 is_final: true,
576 metadata: std::collections::HashMap::new(),
577 };
578 let artifact_event = TaskArtifactUpdateEvent {
579 id: task_id.clone(),
580 artifact,
581 metadata: std::collections::HashMap::new(),
582 };
583 tracing::debug!(
584 task_id = %task_id,
585 event = ?StreamEvent::StatusUpdate(status_event),
586 "Task completed"
587 );
588 tracing::debug!(
589 task_id = %task_id,
590 event = ?StreamEvent::ArtifactUpdate(artifact_event),
591 "Artifact produced"
592 );
593 }
594
595 record_a2a_message_telemetry(
596 "a2a_message_stream",
597 &task_id,
598 false,
599 &prompt,
600 started_at.elapsed(),
601 true,
602 Some(result_text),
603 None,
604 );
605 }
606 Err(e) => {
607 tracing::error!("Stream task {} failed: {}", task_id, e);
608 if let Some(mut t) = tasks.get_mut(&task_id) {
609 t.status.state = TaskState::Failed;
610 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
611 }
612 record_a2a_message_telemetry(
613 "a2a_message_stream",
614 &task_id,
615 false,
616 &prompt,
617 started_at.elapsed(),
618 false,
619 None,
620 Some(e.to_string()),
621 );
622 }
623 }
624 });
625
626 let response = SendMessageResponse::Task(task);
628 serde_json::to_value(response)
629 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
630}
631
632async fn handle_tasks_get(
633 server: &A2AServer,
634 request: JsonRpcRequest,
635) -> Result<serde_json::Value, JsonRpcError> {
636 let params: TaskQueryParams = serde_json::from_value(request.params)
637 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
638
639 let task = server.tasks.get(¶ms.id).ok_or_else(|| JsonRpcError {
640 code: TASK_NOT_FOUND,
641 message: format!("Task not found: {}", params.id),
642 data: None,
643 })?;
644
645 serde_json::to_value(task.value().clone())
646 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
647}
648
649async fn handle_tasks_cancel(
650 server: &A2AServer,
651 request: JsonRpcRequest,
652) -> Result<serde_json::Value, JsonRpcError> {
653 let params: TaskQueryParams = serde_json::from_value(request.params)
654 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
655
656 let mut task = server
657 .tasks
658 .get_mut(¶ms.id)
659 .ok_or_else(|| JsonRpcError {
660 code: TASK_NOT_FOUND,
661 message: format!("Task not found: {}", params.id),
662 data: None,
663 })?;
664
665 if !task.status.state.is_active() {
666 return Err(JsonRpcError {
667 code: TASK_NOT_CANCELABLE,
668 message: "Task is already in a terminal state".to_string(),
669 data: None,
670 });
671 }
672
673 task.status.state = TaskState::Cancelled;
674 task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
675
676 serde_json::to_value(task.value().clone())
677 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
678}