1use super::types::*;
4use crate::session::{Session, SessionEvent};
5use crate::telemetry::record_persistent;
6use anyhow::Result;
7use axum::{
8 Router,
9 extract::State,
10 http::StatusCode,
11 response::Json,
12 routing::{get, post},
13};
14use dashmap::DashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use tokio::sync::mpsc;
18use uuid::Uuid;
19
20#[derive(Clone)]
22pub struct A2AServer {
23 tasks: Arc<DashMap<String, Task>>,
24 agent_card: AgentCard,
25 bus: Option<Arc<crate::bus::AgentBus>>,
27}
28
29impl A2AServer {
30 pub fn new(agent_card: AgentCard) -> Self {
32 Self {
33 tasks: Arc::new(DashMap::new()),
34 agent_card,
35 bus: None,
36 }
37 }
38
39 pub fn with_bus(agent_card: AgentCard, bus: Arc<crate::bus::AgentBus>) -> Self {
41 Self {
42 tasks: Arc::new(DashMap::new()),
43 agent_card,
44 bus: Some(bus),
45 }
46 }
47
48 pub fn router(self) -> Router {
50 Router::new()
51 .route("/.well-known/agent.json", get(get_agent_card))
52 .route("/.well-known/agent-card.json", get(get_agent_card))
53 .route("/", post(handle_rpc))
54 .with_state(self)
55 }
56
57 #[allow(dead_code)]
59 pub fn card(&self) -> &AgentCard {
60 &self.agent_card
61 }
62
63 pub fn default_card(url: &str) -> AgentCard {
65 AgentCard {
66 name: "CodeTether Agent".to_string(),
67 description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
68 url: url.to_string(),
69 version: env!("CARGO_PKG_VERSION").to_string(),
70 protocol_version: "0.3.0".to_string(),
71 preferred_transport: None,
72 additional_interfaces: vec![],
73 capabilities: AgentCapabilities {
74 streaming: true,
75 push_notifications: false,
76 state_transition_history: true,
77 extensions: vec![],
78 },
79 skills: vec![
80 AgentSkill {
81 id: "code".to_string(),
82 name: "Code Generation".to_string(),
83 description: "Write, edit, and refactor code".to_string(),
84 tags: vec!["code".to_string(), "programming".to_string()],
85 examples: vec![
86 "Write a function to parse JSON".to_string(),
87 "Refactor this code to use async/await".to_string(),
88 ],
89 input_modes: vec!["text/plain".to_string()],
90 output_modes: vec!["text/plain".to_string()],
91 },
92 AgentSkill {
93 id: "debug".to_string(),
94 name: "Debugging".to_string(),
95 description: "Debug and fix code issues".to_string(),
96 tags: vec!["debug".to_string(), "fix".to_string()],
97 examples: vec![
98 "Why is this function returning undefined?".to_string(),
99 "Fix the null pointer exception".to_string(),
100 ],
101 input_modes: vec!["text/plain".to_string()],
102 output_modes: vec!["text/plain".to_string()],
103 },
104 AgentSkill {
105 id: "explain".to_string(),
106 name: "Code Explanation".to_string(),
107 description: "Explain code and concepts".to_string(),
108 tags: vec!["explain".to_string(), "learn".to_string()],
109 examples: vec![
110 "Explain how this algorithm works".to_string(),
111 "What does this regex do?".to_string(),
112 ],
113 input_modes: vec!["text/plain".to_string()],
114 output_modes: vec!["text/plain".to_string()],
115 },
116 ],
117 default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
118 default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
119 provider: Some(AgentProvider {
120 organization: "CodeTether".to_string(),
121 url: "https://codetether.run".to_string(),
122 }),
123 icon_url: None,
124 documentation_url: None,
125 security_schemes: Default::default(),
126 security: vec![],
127 supports_authenticated_extended_card: false,
128 signatures: vec![],
129 }
130 }
131}
132
133async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
135 Json(server.agent_card.clone())
136}
137
138async fn configure_a2a_session(session: &mut Session) {
139 let configured_model = std::env::var("CODETETHER_DEFAULT_MODEL")
140 .ok()
141 .map(|value| value.trim().to_string())
142 .filter(|value| !value.is_empty());
143
144 let configured_model = match configured_model {
145 Some(model) => Some(model),
146 None => match crate::config::Config::load().await {
147 Ok(config) => config
148 .default_model
149 .filter(|value| !value.trim().is_empty()),
150 Err(e) => {
151 tracing::debug!(error = %e, "Failed to load config for A2A session model");
152 None
153 }
154 },
155 };
156
157 if let Some(model) = configured_model {
158 session.metadata.model = Some(model);
159 }
160}
161
162fn record_a2a_message_telemetry(
163 tool_name: &str,
164 task_id: &str,
165 blocking: bool,
166 prompt: &str,
167 duration: Duration,
168 success: bool,
169 output: Option<String>,
170 error: Option<String>,
171) {
172 let record = crate::telemetry::A2AMessageRecord {
173 tool_name: tool_name.to_string(),
174 task_id: task_id.to_string(),
175 blocking,
176 prompt: prompt.to_string(),
177 duration_ms: duration.as_millis() as u64,
178 success,
179 output,
180 error,
181 timestamp: chrono::Utc::now(),
182 };
183 let _ = record_persistent(
184 "a2a_message",
185 &serde_json::to_value(&record).unwrap_or_default(),
186 );
187}
188
189async fn handle_rpc(
191 State(server): State<A2AServer>,
192 Json(request): Json<JsonRpcRequest>,
193) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
194 let request_id = request.id.clone();
195 let response = match request.method.as_str() {
196 "message/send" => handle_message_send(&server, request).await,
197 "message/stream" => handle_message_stream(&server, request).await,
198 "tasks/get" => handle_tasks_get(&server, request).await,
199 "tasks/cancel" => handle_tasks_cancel(&server, request).await,
200 _ => Err(JsonRpcError::method_not_found(&request.method)),
201 };
202
203 match response {
204 Ok(result) => Ok(Json(JsonRpcResponse {
205 jsonrpc: "2.0".to_string(),
206 id: request_id.clone(),
207 result: Some(result),
208 error: None,
209 })),
210 Err(error) => Err((
211 StatusCode::OK,
212 Json(JsonRpcResponse {
213 jsonrpc: "2.0".to_string(),
214 id: request_id,
215 result: None,
216 error: Some(error),
217 }),
218 )),
219 }
220}
221
222async fn handle_message_send(
223 server: &A2AServer,
224 request: JsonRpcRequest,
225) -> Result<serde_json::Value, JsonRpcError> {
226 let params: MessageSendParams = serde_json::from_value(request.params)
227 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
228
229 let task_id = params
231 .message
232 .task_id
233 .clone()
234 .unwrap_or_else(|| Uuid::new_v4().to_string());
235
236 let task = Task {
237 id: task_id.clone(),
238 context_id: params.message.context_id.clone(),
239 status: TaskStatus {
240 state: TaskState::Working,
241 message: Some(params.message.clone()),
242 timestamp: Some(chrono::Utc::now().to_rfc3339()),
243 },
244 artifacts: vec![],
245 history: vec![params.message.clone()],
246 metadata: std::collections::HashMap::new(),
247 };
248
249 server.tasks.insert(task_id.clone(), task.clone());
250
251 let prompt: String = params
253 .message
254 .parts
255 .iter()
256 .filter_map(|p| match p {
257 Part::Text { text } => Some(text.as_str()),
258 _ => None,
259 })
260 .collect::<Vec<_>>()
261 .join("\n");
262
263 if prompt.is_empty() {
264 if let Some(mut t) = server.tasks.get_mut(&task_id) {
266 t.status.state = TaskState::Failed;
267 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
268 }
269 return Err(JsonRpcError::invalid_params("No text content in message"));
270 }
271
272 let blocking = params
274 .configuration
275 .as_ref()
276 .and_then(|c| c.blocking)
277 .unwrap_or(true);
278
279 if blocking {
280 let mut session = Session::new().await.map_err(|e| {
282 JsonRpcError::internal_error(format!("Failed to create session: {}", e))
283 })?;
284 configure_a2a_session(&mut session).await;
285 let started_at = Instant::now();
286
287 match session.prompt(&prompt).await {
288 Ok(result) => {
289 let result_text = result.text;
290 let response_message = Message {
291 message_id: Uuid::new_v4().to_string(),
292 role: MessageRole::Agent,
293 parts: vec![Part::Text {
294 text: result_text.clone(),
295 }],
296 context_id: params.message.context_id.clone(),
297 task_id: Some(task_id.clone()),
298 metadata: std::collections::HashMap::new(),
299 extensions: vec![],
300 };
301
302 let artifact = Artifact {
303 artifact_id: Uuid::new_v4().to_string(),
304 parts: vec![Part::Text {
305 text: result_text.clone(),
306 }],
307 name: Some("response".to_string()),
308 description: None,
309 metadata: std::collections::HashMap::new(),
310 extensions: vec![],
311 };
312
313 if let Some(mut t) = server.tasks.get_mut(&task_id) {
314 t.status.state = TaskState::Completed;
315 t.status.message = Some(response_message.clone());
316 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
317 t.artifacts.push(artifact.clone());
318 t.history.push(response_message);
319
320 let status_event = TaskStatusUpdateEvent {
321 id: task_id.clone(),
322 status: t.status.clone(),
323 is_final: true,
324 metadata: std::collections::HashMap::new(),
325 };
326 let artifact_event = TaskArtifactUpdateEvent {
327 id: task_id.clone(),
328 artifact,
329 metadata: std::collections::HashMap::new(),
330 };
331 tracing::debug!(
332 task_id = %task_id,
333 event = ?StreamEvent::StatusUpdate(status_event),
334 "Task completed"
335 );
336 tracing::debug!(
337 task_id = %task_id,
338 event = ?StreamEvent::ArtifactUpdate(artifact_event),
339 "Artifact produced"
340 );
341 }
342
343 record_a2a_message_telemetry(
344 "a2a_message_send",
345 &task_id,
346 true,
347 &prompt,
348 started_at.elapsed(),
349 true,
350 Some(result_text),
351 None,
352 );
353 }
354 Err(e) => {
355 let error_message = Message {
356 message_id: Uuid::new_v4().to_string(),
357 role: MessageRole::Agent,
358 parts: vec![Part::Text {
359 text: format!("Error: {}", e),
360 }],
361 context_id: params.message.context_id.clone(),
362 task_id: Some(task_id.clone()),
363 metadata: std::collections::HashMap::new(),
364 extensions: vec![],
365 };
366
367 if let Some(mut t) = server.tasks.get_mut(&task_id) {
368 t.status.state = TaskState::Failed;
369 t.status.message = Some(error_message);
370 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
371 }
372
373 record_a2a_message_telemetry(
374 "a2a_message_send",
375 &task_id,
376 true,
377 &prompt,
378 started_at.elapsed(),
379 false,
380 None,
381 Some(e.to_string()),
382 );
383 }
384 }
385 } else {
386 let tasks = server.tasks.clone();
388 let context_id = params.message.context_id.clone();
389 let spawn_task_id = task_id.clone();
390
391 tokio::spawn(async move {
392 let task_id = spawn_task_id;
393 let started_at = Instant::now();
394 let mut session = match Session::new().await {
395 Ok(s) => s,
396 Err(e) => {
397 tracing::error!("Failed to create session for task {}: {}", task_id, e);
398 if let Some(mut t) = tasks.get_mut(&task_id) {
399 t.status.state = TaskState::Failed;
400 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
401 }
402 record_a2a_message_telemetry(
403 "a2a_message_send",
404 &task_id,
405 false,
406 &prompt,
407 started_at.elapsed(),
408 false,
409 None,
410 Some(e.to_string()),
411 );
412 return;
413 }
414 };
415 configure_a2a_session(&mut session).await;
416
417 match session.prompt(&prompt).await {
418 Ok(result) => {
419 let result_text = result.text;
420 let response_message = Message {
421 message_id: Uuid::new_v4().to_string(),
422 role: MessageRole::Agent,
423 parts: vec![Part::Text {
424 text: result_text.clone(),
425 }],
426 context_id,
427 task_id: Some(task_id.clone()),
428 metadata: std::collections::HashMap::new(),
429 extensions: vec![],
430 };
431
432 let artifact = Artifact {
433 artifact_id: Uuid::new_v4().to_string(),
434 parts: vec![Part::Text {
435 text: result_text.clone(),
436 }],
437 name: Some("response".to_string()),
438 description: None,
439 metadata: std::collections::HashMap::new(),
440 extensions: vec![],
441 };
442
443 if let Some(mut t) = tasks.get_mut(&task_id) {
444 t.status.state = TaskState::Completed;
445 t.status.message = Some(response_message.clone());
446 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
447 t.artifacts.push(artifact);
448 t.history.push(response_message);
449 }
450
451 record_a2a_message_telemetry(
452 "a2a_message_send",
453 &task_id,
454 false,
455 &prompt,
456 started_at.elapsed(),
457 true,
458 Some(result_text),
459 None,
460 );
461 }
462 Err(e) => {
463 tracing::error!("Task {} failed: {}", task_id, e);
464 if let Some(mut t) = tasks.get_mut(&task_id) {
465 t.status.state = TaskState::Failed;
466 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
467 }
468 record_a2a_message_telemetry(
469 "a2a_message_send",
470 &task_id,
471 false,
472 &prompt,
473 started_at.elapsed(),
474 false,
475 None,
476 Some(e.to_string()),
477 );
478 }
479 }
480 });
481 }
482
483 let task = server
485 .tasks
486 .get(&task_id)
487 .ok_or_else(|| JsonRpcError::internal_error(format!("Task disappeared: {}", task_id)))?;
488 let response = SendMessageResponse::Task(task.value().clone());
489 serde_json::to_value(response)
490 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
491}
492
493async fn handle_message_stream(
494 server: &A2AServer,
495 request: JsonRpcRequest,
496) -> Result<serde_json::Value, JsonRpcError> {
497 let params: MessageSendParams = serde_json::from_value(request.params)
502 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
503
504 let task_id = params
505 .message
506 .task_id
507 .clone()
508 .unwrap_or_else(|| Uuid::new_v4().to_string());
509
510 let task = Task {
511 id: task_id.clone(),
512 context_id: params.message.context_id.clone(),
513 status: TaskStatus {
514 state: TaskState::Working,
515 message: Some(params.message.clone()),
516 timestamp: Some(chrono::Utc::now().to_rfc3339()),
517 },
518 artifacts: vec![],
519 history: vec![params.message.clone()],
520 metadata: std::collections::HashMap::new(),
521 };
522
523 server.tasks.insert(task_id.clone(), task.clone());
524
525 let prompt: String = params
527 .message
528 .parts
529 .iter()
530 .filter_map(|p| match p {
531 Part::Text { text } => Some(text.as_str()),
532 _ => None,
533 })
534 .collect::<Vec<_>>()
535 .join("\n");
536
537 if prompt.is_empty() {
538 if let Some(mut t) = server.tasks.get_mut(&task_id) {
539 t.status.state = TaskState::Failed;
540 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
541 }
542 return Err(JsonRpcError::invalid_params("No text content in message"));
543 }
544
545 let tasks = server.tasks.clone();
547 let context_id = params.message.context_id.clone();
548 let spawn_task_id = task_id.clone();
549 let bus = server.bus.clone();
550
551 tokio::spawn(async move {
552 let task_id = spawn_task_id;
553 let started_at = Instant::now();
554
555 let (event_tx, mut event_rx) = mpsc::channel::<SessionEvent>(256);
557
558 let mut session = match Session::new().await {
559 Ok(s) => s,
560 Err(e) => {
561 tracing::error!(
562 "Failed to create session for stream task {}: {}",
563 task_id,
564 e
565 );
566 if let Some(mut t) = tasks.get_mut(&task_id) {
567 t.status.state = TaskState::Failed;
568 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
569 }
570 record_a2a_message_telemetry(
571 "a2a_message_stream",
572 &task_id,
573 false,
574 &prompt,
575 started_at.elapsed(),
576 false,
577 None,
578 Some(e.to_string()),
579 );
580 return;
581 }
582 };
583 configure_a2a_session(&mut session).await;
584
585 let bus_clone = bus.clone();
587 let task_id_clone = task_id.clone();
588 tokio::spawn(async move {
589 while let Some(event) = event_rx.recv().await {
590 let event_data = match &event {
591 SessionEvent::Thinking => {
592 serde_json::json!({ "type": "thinking" })
593 }
594 SessionEvent::ToolCallStart { name, arguments } => {
595 serde_json::json!({
596 "type": "tool_call_start",
597 "name": name,
598 "arguments": arguments
599 })
600 }
601 SessionEvent::ToolCallComplete {
602 name,
603 output,
604 success,
605 duration_ms,
606 } => {
607 serde_json::json!({
608 "type": "tool_call_complete",
609 "name": name,
610 "output": output.chars().take(500).collect::<String>(),
611 "success": success,
612 "duration_ms": duration_ms
613 })
614 }
615 SessionEvent::TextChunk(text) => {
616 serde_json::json!({ "type": "text_chunk", "text": text })
617 }
618 SessionEvent::TextComplete(text) => {
619 serde_json::json!({ "type": "text_complete", "text": text })
620 }
621 SessionEvent::ThinkingComplete(thought) => {
622 serde_json::json!({ "type": "thinking_complete", "thought": thought })
623 }
624 SessionEvent::UsageReport {
625 prompt_tokens,
626 completion_tokens,
627 duration_ms,
628 model,
629 } => {
630 serde_json::json!({
631 "type": "usage_report",
632 "prompt_tokens": prompt_tokens,
633 "completion_tokens": completion_tokens,
634 "duration_ms": duration_ms,
635 "model": model
636 })
637 }
638 SessionEvent::Done => {
639 serde_json::json!({ "type": "done" })
640 }
641 SessionEvent::Error(err) => {
642 serde_json::json!({ "type": "error", "error": err })
643 }
644 SessionEvent::SessionSync(_) => {
645 continue; }
647 _ => continue,
651 };
652
653 if let Some(ref bus) = bus_clone {
655 let handle = bus.handle("a2a-stream");
656 handle.send(
657 format!("task.{}", task_id_clone),
658 crate::bus::BusMessage::TaskUpdate {
659 task_id: task_id_clone.clone(),
660 state: crate::a2a::types::TaskState::Working,
661 message: Some(serde_json::to_string(&event_data).unwrap_or_default()),
662 },
663 );
664 }
665 }
666 });
667
668 let registry = match crate::provider::ProviderRegistry::from_vault().await {
670 Ok(r) => Arc::new(r),
671 Err(e) => {
672 tracing::error!("Failed to load provider registry: {}", e);
673 if let Some(mut t) = tasks.get_mut(&task_id) {
674 t.status.state = TaskState::Failed;
675 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
676 }
677 return;
678 }
679 };
680
681 match session
682 .prompt_with_events(&prompt, event_tx, registry)
683 .await
684 {
685 Ok(result) => {
686 let result_text = result.text;
687 let response_message = Message {
688 message_id: Uuid::new_v4().to_string(),
689 role: MessageRole::Agent,
690 parts: vec![Part::Text {
691 text: result_text.clone(),
692 }],
693 context_id,
694 task_id: Some(task_id.clone()),
695 metadata: std::collections::HashMap::new(),
696 extensions: vec![],
697 };
698
699 let artifact = Artifact {
700 artifact_id: Uuid::new_v4().to_string(),
701 parts: vec![Part::Text {
702 text: result_text.clone(),
703 }],
704 name: Some("response".to_string()),
705 description: None,
706 metadata: std::collections::HashMap::new(),
707 extensions: vec![],
708 };
709
710 if let Some(mut t) = tasks.get_mut(&task_id) {
711 t.status.state = TaskState::Completed;
712 t.status.message = Some(response_message.clone());
713 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
714 t.artifacts.push(artifact.clone());
715 t.history.push(response_message);
716
717 let status_event = TaskStatusUpdateEvent {
719 id: task_id.clone(),
720 status: t.status.clone(),
721 is_final: true,
722 metadata: std::collections::HashMap::new(),
723 };
724 let artifact_event = TaskArtifactUpdateEvent {
725 id: task_id.clone(),
726 artifact,
727 metadata: std::collections::HashMap::new(),
728 };
729 tracing::debug!(
730 task_id = %task_id,
731 event = ?StreamEvent::StatusUpdate(status_event),
732 "Task completed"
733 );
734 tracing::debug!(
735 task_id = %task_id,
736 event = ?StreamEvent::ArtifactUpdate(artifact_event),
737 "Artifact produced"
738 );
739 }
740
741 record_a2a_message_telemetry(
742 "a2a_message_stream",
743 &task_id,
744 false,
745 &prompt,
746 started_at.elapsed(),
747 true,
748 Some(result_text),
749 None,
750 );
751 }
752 Err(e) => {
753 tracing::error!("Stream task {} failed: {}", task_id, e);
754 if let Some(mut t) = tasks.get_mut(&task_id) {
755 t.status.state = TaskState::Failed;
756 t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
757 }
758 record_a2a_message_telemetry(
759 "a2a_message_stream",
760 &task_id,
761 false,
762 &prompt,
763 started_at.elapsed(),
764 false,
765 None,
766 Some(e.to_string()),
767 );
768 }
769 }
770 });
771
772 let response = SendMessageResponse::Task(task);
774 serde_json::to_value(response)
775 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
776}
777
778async fn handle_tasks_get(
779 server: &A2AServer,
780 request: JsonRpcRequest,
781) -> Result<serde_json::Value, JsonRpcError> {
782 let params: TaskQueryParams = serde_json::from_value(request.params)
783 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
784
785 let task = server.tasks.get(¶ms.id).ok_or_else(|| JsonRpcError {
786 code: TASK_NOT_FOUND,
787 message: format!("Task not found: {}", params.id),
788 data: None,
789 })?;
790
791 serde_json::to_value(task.value().clone())
792 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
793}
794
795async fn handle_tasks_cancel(
796 server: &A2AServer,
797 request: JsonRpcRequest,
798) -> Result<serde_json::Value, JsonRpcError> {
799 let params: TaskQueryParams = serde_json::from_value(request.params)
800 .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
801
802 let mut task = server
803 .tasks
804 .get_mut(¶ms.id)
805 .ok_or_else(|| JsonRpcError {
806 code: TASK_NOT_FOUND,
807 message: format!("Task not found: {}", params.id),
808 data: None,
809 })?;
810
811 if !task.status.state.is_active() {
812 return Err(JsonRpcError {
813 code: TASK_NOT_CANCELABLE,
814 message: "Task is already in a terminal state".to_string(),
815 data: None,
816 });
817 }
818
819 task.status.state = TaskState::Cancelled;
820 task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
821
822 serde_json::to_value(task.value().clone())
823 .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
824}