Skip to main content

codetether_agent/a2a/
server.rs

1//! A2A Server - serve as an A2A agent
2
3use super::types::*;
4use crate::session::Session;
5use crate::telemetry::{ToolExecution, record_persistent};
6use anyhow::Result;
7use axum::{
8    Router,
9    extract::State,
10    http::StatusCode,
11    response::Json,
12    routing::{get, post},
13};
14use dashmap::DashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use uuid::Uuid;
18
19/// A2A Server state
20#[derive(Clone)]
21pub struct A2AServer {
22    tasks: Arc<DashMap<String, Task>>,
23    agent_card: AgentCard,
24}
25
26impl A2AServer {
27    /// Create a new A2A server
28    pub fn new(agent_card: AgentCard) -> Self {
29        Self {
30            tasks: Arc::new(DashMap::new()),
31            agent_card,
32        }
33    }
34
35    /// Create the router for A2A endpoints
36    pub fn router(self) -> Router {
37        Router::new()
38            .route("/.well-known/agent.json", get(get_agent_card))
39            .route("/.well-known/agent-card.json", get(get_agent_card))
40            .route("/", post(handle_rpc))
41            .with_state(self)
42    }
43
44    /// Get the agent card for this server
45    #[allow(dead_code)]
46    pub fn card(&self) -> &AgentCard {
47        &self.agent_card
48    }
49
50    /// Create a default agent card
51    pub fn default_card(url: &str) -> AgentCard {
52        AgentCard {
53            name: "CodeTether Agent".to_string(),
54            description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
55            url: url.to_string(),
56            version: env!("CARGO_PKG_VERSION").to_string(),
57            protocol_version: "0.3.0".to_string(),
58            preferred_transport: None,
59            additional_interfaces: vec![],
60            capabilities: AgentCapabilities {
61                streaming: true,
62                push_notifications: false,
63                state_transition_history: true,
64                extensions: vec![],
65            },
66            skills: vec![
67                AgentSkill {
68                    id: "code".to_string(),
69                    name: "Code Generation".to_string(),
70                    description: "Write, edit, and refactor code".to_string(),
71                    tags: vec!["code".to_string(), "programming".to_string()],
72                    examples: vec![
73                        "Write a function to parse JSON".to_string(),
74                        "Refactor this code to use async/await".to_string(),
75                    ],
76                    input_modes: vec!["text/plain".to_string()],
77                    output_modes: vec!["text/plain".to_string()],
78                },
79                AgentSkill {
80                    id: "debug".to_string(),
81                    name: "Debugging".to_string(),
82                    description: "Debug and fix code issues".to_string(),
83                    tags: vec!["debug".to_string(), "fix".to_string()],
84                    examples: vec![
85                        "Why is this function returning undefined?".to_string(),
86                        "Fix the null pointer exception".to_string(),
87                    ],
88                    input_modes: vec!["text/plain".to_string()],
89                    output_modes: vec!["text/plain".to_string()],
90                },
91                AgentSkill {
92                    id: "explain".to_string(),
93                    name: "Code Explanation".to_string(),
94                    description: "Explain code and concepts".to_string(),
95                    tags: vec!["explain".to_string(), "learn".to_string()],
96                    examples: vec![
97                        "Explain how this algorithm works".to_string(),
98                        "What does this regex do?".to_string(),
99                    ],
100                    input_modes: vec!["text/plain".to_string()],
101                    output_modes: vec!["text/plain".to_string()],
102                },
103            ],
104            default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
105            default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
106            provider: Some(AgentProvider {
107                organization: "CodeTether".to_string(),
108                url: "https://codetether.run".to_string(),
109            }),
110            icon_url: None,
111            documentation_url: None,
112            security_schemes: Default::default(),
113            security: vec![],
114            supports_authenticated_extended_card: false,
115            signatures: vec![],
116        }
117    }
118}
119
120/// Get agent card handler
121async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
122    Json(server.agent_card.clone())
123}
124
125fn record_a2a_message_telemetry(
126    tool_name: &str,
127    task_id: &str,
128    blocking: bool,
129    prompt: &str,
130    duration: Duration,
131    success: bool,
132    output: Option<String>,
133    error: Option<String>,
134) {
135    let record = crate::telemetry::A2AMessageRecord {
136        tool_name: tool_name.to_string(),
137        task_id: task_id.to_string(),
138        blocking,
139        prompt: prompt.to_string(),
140        duration_ms: duration.as_millis() as u64,
141        success,
142        output,
143        error,
144        timestamp: chrono::Utc::now(),
145    };
146    let _ = record_persistent("a2a_message", &serde_json::to_value(&record).unwrap_or_default());
147}
148
149/// Handle JSON-RPC requests
150async fn handle_rpc(
151    State(server): State<A2AServer>,
152    Json(request): Json<JsonRpcRequest>,
153) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
154    let request_id = request.id.clone();
155    let response = match request.method.as_str() {
156        "message/send" => handle_message_send(&server, request).await,
157        "message/stream" => handle_message_stream(&server, request).await,
158        "tasks/get" => handle_tasks_get(&server, request).await,
159        "tasks/cancel" => handle_tasks_cancel(&server, request).await,
160        _ => Err(JsonRpcError::method_not_found(&request.method)),
161    };
162
163    match response {
164        Ok(result) => Ok(Json(JsonRpcResponse {
165            jsonrpc: "2.0".to_string(),
166            id: request_id.clone(),
167            result: Some(result),
168            error: None,
169        })),
170        Err(error) => Err((
171            StatusCode::OK,
172            Json(JsonRpcResponse {
173                jsonrpc: "2.0".to_string(),
174                id: request_id,
175                result: None,
176                error: Some(error),
177            }),
178        )),
179    }
180}
181
182async fn handle_message_send(
183    server: &A2AServer,
184    request: JsonRpcRequest,
185) -> Result<serde_json::Value, JsonRpcError> {
186    let params: MessageSendParams = serde_json::from_value(request.params)
187        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
188
189    // Create a new task
190    let task_id = params
191        .message
192        .task_id
193        .clone()
194        .unwrap_or_else(|| Uuid::new_v4().to_string());
195
196    let task = Task {
197        id: task_id.clone(),
198        context_id: params.message.context_id.clone(),
199        status: TaskStatus {
200            state: TaskState::Working,
201            message: Some(params.message.clone()),
202            timestamp: Some(chrono::Utc::now().to_rfc3339()),
203        },
204        artifacts: vec![],
205        history: vec![params.message.clone()],
206        metadata: std::collections::HashMap::new(),
207    };
208
209    server.tasks.insert(task_id.clone(), task.clone());
210
211    // Extract prompt text from message parts
212    let prompt: String = params
213        .message
214        .parts
215        .iter()
216        .filter_map(|p| match p {
217            Part::Text { text } => Some(text.as_str()),
218            _ => None,
219        })
220        .collect::<Vec<_>>()
221        .join("\n");
222
223    if prompt.is_empty() {
224        // Update task to failed
225        if let Some(mut t) = server.tasks.get_mut(&task_id) {
226            t.status.state = TaskState::Failed;
227            t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
228        }
229        return Err(JsonRpcError::invalid_params("No text content in message"));
230    }
231
232    // Determine if blocking (default true for message/send)
233    let blocking = params
234        .configuration
235        .as_ref()
236        .and_then(|c| c.blocking)
237        .unwrap_or(true);
238
239    if blocking {
240        // Synchronous execution: create session, run prompt, return completed task
241        let mut session = Session::new().await.map_err(|e| {
242            JsonRpcError::internal_error(format!("Failed to create session: {}", e))
243        })?;
244        let started_at = Instant::now();
245
246        match session.prompt(&prompt).await {
247            Ok(result) => {
248                let result_text = result.text;
249                let response_message = Message {
250                    message_id: Uuid::new_v4().to_string(),
251                    role: MessageRole::Agent,
252                    parts: vec![Part::Text {
253                        text: result_text.clone(),
254                    }],
255                    context_id: params.message.context_id.clone(),
256                    task_id: Some(task_id.clone()),
257                    metadata: std::collections::HashMap::new(),
258                    extensions: vec![],
259                };
260
261                let artifact = Artifact {
262                    artifact_id: Uuid::new_v4().to_string(),
263                    parts: vec![Part::Text {
264                        text: result_text.clone(),
265                    }],
266                    name: Some("response".to_string()),
267                    description: None,
268                    metadata: std::collections::HashMap::new(),
269                    extensions: vec![],
270                };
271
272                if let Some(mut t) = server.tasks.get_mut(&task_id) {
273                    t.status.state = TaskState::Completed;
274                    t.status.message = Some(response_message.clone());
275                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
276                    t.artifacts.push(artifact.clone());
277                    t.history.push(response_message);
278
279                    let status_event = TaskStatusUpdateEvent {
280                        id: task_id.clone(),
281                        status: t.status.clone(),
282                        is_final: true,
283                        metadata: std::collections::HashMap::new(),
284                    };
285                    let artifact_event = TaskArtifactUpdateEvent {
286                        id: task_id.clone(),
287                        artifact,
288                        metadata: std::collections::HashMap::new(),
289                    };
290                    tracing::debug!(
291                        task_id = %task_id,
292                        event = ?StreamEvent::StatusUpdate(status_event),
293                        "Task completed"
294                    );
295                    tracing::debug!(
296                        task_id = %task_id,
297                        event = ?StreamEvent::ArtifactUpdate(artifact_event),
298                        "Artifact produced"
299                    );
300                }
301
302                record_a2a_message_telemetry(
303                    "a2a_message_send",
304                    &task_id,
305                    true,
306                    &prompt,
307                    started_at.elapsed(),
308                    true,
309                    Some(result_text),
310                    None,
311                );
312            }
313            Err(e) => {
314                let error_message = Message {
315                    message_id: Uuid::new_v4().to_string(),
316                    role: MessageRole::Agent,
317                    parts: vec![Part::Text {
318                        text: format!("Error: {}", e),
319                    }],
320                    context_id: params.message.context_id.clone(),
321                    task_id: Some(task_id.clone()),
322                    metadata: std::collections::HashMap::new(),
323                    extensions: vec![],
324                };
325
326                if let Some(mut t) = server.tasks.get_mut(&task_id) {
327                    t.status.state = TaskState::Failed;
328                    t.status.message = Some(error_message);
329                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
330                }
331
332                record_a2a_message_telemetry(
333                    "a2a_message_send",
334                    &task_id,
335                    true,
336                    &prompt,
337                    started_at.elapsed(),
338                    false,
339                    None,
340                    Some(e.to_string()),
341                );
342            }
343        }
344    } else {
345        // Async execution: spawn background task, return immediately with Working state
346        let tasks = server.tasks.clone();
347        let context_id = params.message.context_id.clone();
348        let spawn_task_id = task_id.clone();
349
350        tokio::spawn(async move {
351            let task_id = spawn_task_id;
352            let started_at = Instant::now();
353            let mut session = match Session::new().await {
354                Ok(s) => s,
355                Err(e) => {
356                    tracing::error!("Failed to create session for task {}: {}", task_id, e);
357                    if let Some(mut t) = tasks.get_mut(&task_id) {
358                        t.status.state = TaskState::Failed;
359                        t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
360                    }
361                    record_a2a_message_telemetry(
362                        "a2a_message_send",
363                        &task_id,
364                        false,
365                        &prompt,
366                        started_at.elapsed(),
367                        false,
368                        None,
369                        Some(e.to_string()),
370                    );
371                    return;
372                }
373            };
374
375            match session.prompt(&prompt).await {
376                Ok(result) => {
377                    let result_text = result.text;
378                    let response_message = Message {
379                        message_id: Uuid::new_v4().to_string(),
380                        role: MessageRole::Agent,
381                        parts: vec![Part::Text {
382                            text: result_text.clone(),
383                        }],
384                        context_id,
385                        task_id: Some(task_id.clone()),
386                        metadata: std::collections::HashMap::new(),
387                        extensions: vec![],
388                    };
389
390                    let artifact = Artifact {
391                        artifact_id: Uuid::new_v4().to_string(),
392                        parts: vec![Part::Text {
393                            text: result_text.clone(),
394                        }],
395                        name: Some("response".to_string()),
396                        description: None,
397                        metadata: std::collections::HashMap::new(),
398                        extensions: vec![],
399                    };
400
401                    if let Some(mut t) = tasks.get_mut(&task_id) {
402                        t.status.state = TaskState::Completed;
403                        t.status.message = Some(response_message.clone());
404                        t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
405                        t.artifacts.push(artifact);
406                        t.history.push(response_message);
407                    }
408
409                    record_a2a_message_telemetry(
410                        "a2a_message_send",
411                        &task_id,
412                        false,
413                        &prompt,
414                        started_at.elapsed(),
415                        true,
416                        Some(result_text),
417                        None,
418                    );
419                }
420                Err(e) => {
421                    tracing::error!("Task {} failed: {}", task_id, e);
422                    if let Some(mut t) = tasks.get_mut(&task_id) {
423                        t.status.state = TaskState::Failed;
424                        t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
425                    }
426                    record_a2a_message_telemetry(
427                        "a2a_message_send",
428                        &task_id,
429                        false,
430                        &prompt,
431                        started_at.elapsed(),
432                        false,
433                        None,
434                        Some(e.to_string()),
435                    );
436                }
437            }
438        });
439    }
440
441    // Return current task state wrapped in SendMessageResponse
442    let task = server.tasks.get(&task_id).unwrap();
443    let response = SendMessageResponse::Task(task.value().clone());
444    serde_json::to_value(response)
445        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
446}
447
448async fn handle_message_stream(
449    server: &A2AServer,
450    request: JsonRpcRequest,
451) -> Result<serde_json::Value, JsonRpcError> {
452    // message/stream submits the task for async processing.
453    // The client should poll tasks/get for status updates.
454    // True SSE streaming requires a dedicated endpoint outside JSON-RPC.
455
456    let params: MessageSendParams = serde_json::from_value(request.params)
457        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
458
459    let task_id = params
460        .message
461        .task_id
462        .clone()
463        .unwrap_or_else(|| Uuid::new_v4().to_string());
464
465    let task = Task {
466        id: task_id.clone(),
467        context_id: params.message.context_id.clone(),
468        status: TaskStatus {
469            state: TaskState::Working,
470            message: Some(params.message.clone()),
471            timestamp: Some(chrono::Utc::now().to_rfc3339()),
472        },
473        artifacts: vec![],
474        history: vec![params.message.clone()],
475        metadata: std::collections::HashMap::new(),
476    };
477
478    server.tasks.insert(task_id.clone(), task.clone());
479
480    // Extract prompt
481    let prompt: String = params
482        .message
483        .parts
484        .iter()
485        .filter_map(|p| match p {
486            Part::Text { text } => Some(text.as_str()),
487            _ => None,
488        })
489        .collect::<Vec<_>>()
490        .join("\n");
491
492    if prompt.is_empty() {
493        if let Some(mut t) = server.tasks.get_mut(&task_id) {
494            t.status.state = TaskState::Failed;
495            t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
496        }
497        return Err(JsonRpcError::invalid_params("No text content in message"));
498    }
499
500    // Spawn async processing
501    let tasks = server.tasks.clone();
502    let context_id = params.message.context_id.clone();
503    let spawn_task_id = task_id.clone();
504
505    tokio::spawn(async move {
506        let task_id = spawn_task_id;
507        let started_at = Instant::now();
508        let mut session = match Session::new().await {
509            Ok(s) => s,
510            Err(e) => {
511                tracing::error!(
512                    "Failed to create session for stream task {}: {}",
513                    task_id,
514                    e
515                );
516                if let Some(mut t) = tasks.get_mut(&task_id) {
517                    t.status.state = TaskState::Failed;
518                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
519                }
520                record_a2a_message_telemetry(
521                    "a2a_message_stream",
522                    &task_id,
523                    false,
524                    &prompt,
525                    started_at.elapsed(),
526                    false,
527                    None,
528                    Some(e.to_string()),
529                );
530                return;
531            }
532        };
533
534        match session.prompt(&prompt).await {
535            Ok(result) => {
536                let result_text = result.text;
537                let response_message = Message {
538                    message_id: Uuid::new_v4().to_string(),
539                    role: MessageRole::Agent,
540                    parts: vec![Part::Text {
541                        text: result_text.clone(),
542                    }],
543                    context_id,
544                    task_id: Some(task_id.clone()),
545                    metadata: std::collections::HashMap::new(),
546                    extensions: vec![],
547                };
548
549                let artifact = Artifact {
550                    artifact_id: Uuid::new_v4().to_string(),
551                    parts: vec![Part::Text {
552                        text: result_text.clone(),
553                    }],
554                    name: Some("response".to_string()),
555                    description: None,
556                    metadata: std::collections::HashMap::new(),
557                    extensions: vec![],
558                };
559
560                if let Some(mut t) = tasks.get_mut(&task_id) {
561                    t.status.state = TaskState::Completed;
562                    t.status.message = Some(response_message.clone());
563                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
564                    t.artifacts.push(artifact.clone());
565                    t.history.push(response_message);
566
567                    // Emit streaming events for SSE consumers
568                    let status_event = TaskStatusUpdateEvent {
569                        id: task_id.clone(),
570                        status: t.status.clone(),
571                        is_final: true,
572                        metadata: std::collections::HashMap::new(),
573                    };
574                    let artifact_event = TaskArtifactUpdateEvent {
575                        id: task_id.clone(),
576                        artifact,
577                        metadata: std::collections::HashMap::new(),
578                    };
579                    tracing::debug!(
580                        task_id = %task_id,
581                        event = ?StreamEvent::StatusUpdate(status_event),
582                        "Task completed"
583                    );
584                    tracing::debug!(
585                        task_id = %task_id,
586                        event = ?StreamEvent::ArtifactUpdate(artifact_event),
587                        "Artifact produced"
588                    );
589                }
590
591                record_a2a_message_telemetry(
592                    "a2a_message_stream",
593                    &task_id,
594                    false,
595                    &prompt,
596                    started_at.elapsed(),
597                    true,
598                    Some(result_text),
599                    None,
600                );
601            }
602            Err(e) => {
603                tracing::error!("Stream task {} failed: {}", task_id, e);
604                if let Some(mut t) = tasks.get_mut(&task_id) {
605                    t.status.state = TaskState::Failed;
606                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
607                }
608                record_a2a_message_telemetry(
609                    "a2a_message_stream",
610                    &task_id,
611                    false,
612                    &prompt,
613                    started_at.elapsed(),
614                    false,
615                    None,
616                    Some(e.to_string()),
617                );
618            }
619        }
620    });
621
622    // Return task in Working state — client polls tasks/get for completion
623    let response = SendMessageResponse::Task(task);
624    serde_json::to_value(response)
625        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
626}
627
628async fn handle_tasks_get(
629    server: &A2AServer,
630    request: JsonRpcRequest,
631) -> Result<serde_json::Value, JsonRpcError> {
632    let params: TaskQueryParams = serde_json::from_value(request.params)
633        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
634
635    let task = server.tasks.get(&params.id).ok_or_else(|| JsonRpcError {
636        code: TASK_NOT_FOUND,
637        message: format!("Task not found: {}", params.id),
638        data: None,
639    })?;
640
641    serde_json::to_value(task.value().clone())
642        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
643}
644
645async fn handle_tasks_cancel(
646    server: &A2AServer,
647    request: JsonRpcRequest,
648) -> Result<serde_json::Value, JsonRpcError> {
649    let params: TaskQueryParams = serde_json::from_value(request.params)
650        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
651
652    let mut task = server
653        .tasks
654        .get_mut(&params.id)
655        .ok_or_else(|| JsonRpcError {
656            code: TASK_NOT_FOUND,
657            message: format!("Task not found: {}", params.id),
658            data: None,
659        })?;
660
661    if !task.status.state.is_active() {
662        return Err(JsonRpcError {
663            code: TASK_NOT_CANCELABLE,
664            message: "Task is already in a terminal state".to_string(),
665            data: None,
666        });
667    }
668
669    task.status.state = TaskState::Cancelled;
670    task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
671
672    serde_json::to_value(task.value().clone())
673        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
674}