Skip to main content

codetether_agent/a2a/
server.rs

//! A2A server: exposes this agent over HTTP through an agent-card discovery route and a JSON-RPC endpoint.
2
3use super::types::*;
4use crate::session::Session;
5use crate::telemetry::{ToolExecution, record_persistent};
6use anyhow::Result;
7use axum::{
8    Router,
9    extract::State,
10    http::StatusCode,
11    response::Json,
12    routing::{get, post},
13};
14use dashmap::DashMap;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use uuid::Uuid;
18
19/// A2A Server state
20#[derive(Clone)]
21pub struct A2AServer {
22    tasks: Arc<DashMap<String, Task>>,
23    agent_card: AgentCard,
24}
25
26impl A2AServer {
27    /// Create a new A2A server
28    pub fn new(agent_card: AgentCard) -> Self {
29        Self {
30            tasks: Arc::new(DashMap::new()),
31            agent_card,
32        }
33    }
34
35    /// Create the router for A2A endpoints
36    pub fn router(self) -> Router {
37        Router::new()
38            .route("/.well-known/agent.json", get(get_agent_card))
39            .route("/", post(handle_rpc))
40            .with_state(self)
41    }
42
43    /// Get the agent card for this server
44    #[allow(dead_code)]
45    pub fn card(&self) -> &AgentCard {
46        &self.agent_card
47    }
48
49    /// Create a default agent card
50    pub fn default_card(url: &str) -> AgentCard {
51        AgentCard {
52            name: "CodeTether Agent".to_string(),
53            description: "A2A-native AI coding agent for the CodeTether ecosystem".to_string(),
54            url: url.to_string(),
55            version: env!("CARGO_PKG_VERSION").to_string(),
56            protocol_version: "0.3.0".to_string(),
57            preferred_transport: None,
58            additional_interfaces: vec![],
59            capabilities: AgentCapabilities {
60                streaming: true,
61                push_notifications: false,
62                state_transition_history: true,
63                extensions: vec![],
64            },
65            skills: vec![
66                AgentSkill {
67                    id: "code".to_string(),
68                    name: "Code Generation".to_string(),
69                    description: "Write, edit, and refactor code".to_string(),
70                    tags: vec!["code".to_string(), "programming".to_string()],
71                    examples: vec![
72                        "Write a function to parse JSON".to_string(),
73                        "Refactor this code to use async/await".to_string(),
74                    ],
75                    input_modes: vec!["text/plain".to_string()],
76                    output_modes: vec!["text/plain".to_string()],
77                },
78                AgentSkill {
79                    id: "debug".to_string(),
80                    name: "Debugging".to_string(),
81                    description: "Debug and fix code issues".to_string(),
82                    tags: vec!["debug".to_string(), "fix".to_string()],
83                    examples: vec![
84                        "Why is this function returning undefined?".to_string(),
85                        "Fix the null pointer exception".to_string(),
86                    ],
87                    input_modes: vec!["text/plain".to_string()],
88                    output_modes: vec!["text/plain".to_string()],
89                },
90                AgentSkill {
91                    id: "explain".to_string(),
92                    name: "Code Explanation".to_string(),
93                    description: "Explain code and concepts".to_string(),
94                    tags: vec!["explain".to_string(), "learn".to_string()],
95                    examples: vec![
96                        "Explain how this algorithm works".to_string(),
97                        "What does this regex do?".to_string(),
98                    ],
99                    input_modes: vec!["text/plain".to_string()],
100                    output_modes: vec!["text/plain".to_string()],
101                },
102            ],
103            default_input_modes: vec!["text/plain".to_string(), "application/json".to_string()],
104            default_output_modes: vec!["text/plain".to_string(), "application/json".to_string()],
105            provider: Some(AgentProvider {
106                organization: "CodeTether".to_string(),
107                url: "https://codetether.run".to_string(),
108            }),
109            icon_url: None,
110            documentation_url: None,
111            security_schemes: Default::default(),
112            security: vec![],
113            supports_authenticated_extended_card: false,
114            signatures: vec![],
115        }
116    }
117}
118
119/// Get agent card handler
120async fn get_agent_card(State(server): State<A2AServer>) -> Json<AgentCard> {
121    Json(server.agent_card.clone())
122}
123
124fn record_a2a_message_telemetry(
125    tool_name: &str,
126    task_id: &str,
127    blocking: bool,
128    prompt: &str,
129    duration: Duration,
130    success: bool,
131    output: Option<String>,
132    error: Option<String>,
133) {
134    let input = serde_json::json!({
135        "task_id": task_id,
136        "blocking": blocking,
137        "prompt": prompt,
138    });
139
140    let execution = ToolExecution::start(tool_name, input).with_session(task_id.to_string());
141    let execution = if success {
142        execution.complete_success(output.unwrap_or_default(), duration)
143    } else {
144        execution.complete_error(
145            error.unwrap_or_else(|| "A2A message execution failed".to_string()),
146            duration,
147        )
148    };
149
150    record_persistent(execution);
151}
152
153/// Handle JSON-RPC requests
154async fn handle_rpc(
155    State(server): State<A2AServer>,
156    Json(request): Json<JsonRpcRequest>,
157) -> Result<Json<JsonRpcResponse>, (StatusCode, Json<JsonRpcResponse>)> {
158    let request_id = request.id.clone();
159    let response = match request.method.as_str() {
160        "message/send" => handle_message_send(&server, request).await,
161        "message/stream" => handle_message_stream(&server, request).await,
162        "tasks/get" => handle_tasks_get(&server, request).await,
163        "tasks/cancel" => handle_tasks_cancel(&server, request).await,
164        _ => Err(JsonRpcError::method_not_found(&request.method)),
165    };
166
167    match response {
168        Ok(result) => Ok(Json(JsonRpcResponse {
169            jsonrpc: "2.0".to_string(),
170            id: request_id.clone(),
171            result: Some(result),
172            error: None,
173        })),
174        Err(error) => Err((
175            StatusCode::OK,
176            Json(JsonRpcResponse {
177                jsonrpc: "2.0".to_string(),
178                id: request_id,
179                result: None,
180                error: Some(error),
181            }),
182        )),
183    }
184}
185
186async fn handle_message_send(
187    server: &A2AServer,
188    request: JsonRpcRequest,
189) -> Result<serde_json::Value, JsonRpcError> {
190    let params: MessageSendParams = serde_json::from_value(request.params)
191        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
192
193    // Create a new task
194    let task_id = params
195        .message
196        .task_id
197        .clone()
198        .unwrap_or_else(|| Uuid::new_v4().to_string());
199
200    let task = Task {
201        id: task_id.clone(),
202        context_id: params.message.context_id.clone(),
203        status: TaskStatus {
204            state: TaskState::Working,
205            message: Some(params.message.clone()),
206            timestamp: Some(chrono::Utc::now().to_rfc3339()),
207        },
208        artifacts: vec![],
209        history: vec![params.message.clone()],
210        metadata: std::collections::HashMap::new(),
211    };
212
213    server.tasks.insert(task_id.clone(), task.clone());
214
215    // Extract prompt text from message parts
216    let prompt: String = params
217        .message
218        .parts
219        .iter()
220        .filter_map(|p| match p {
221            Part::Text { text } => Some(text.as_str()),
222            _ => None,
223        })
224        .collect::<Vec<_>>()
225        .join("\n");
226
227    if prompt.is_empty() {
228        // Update task to failed
229        if let Some(mut t) = server.tasks.get_mut(&task_id) {
230            t.status.state = TaskState::Failed;
231            t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
232        }
233        return Err(JsonRpcError::invalid_params("No text content in message"));
234    }
235
236    // Determine if blocking (default true for message/send)
237    let blocking = params
238        .configuration
239        .as_ref()
240        .and_then(|c| c.blocking)
241        .unwrap_or(true);
242
243    if blocking {
244        // Synchronous execution: create session, run prompt, return completed task
245        let mut session = Session::new().await.map_err(|e| {
246            JsonRpcError::internal_error(format!("Failed to create session: {}", e))
247        })?;
248        let started_at = Instant::now();
249
250        match session.prompt(&prompt).await {
251            Ok(result) => {
252                let result_text = result.text;
253                let response_message = Message {
254                    message_id: Uuid::new_v4().to_string(),
255                    role: MessageRole::Agent,
256                    parts: vec![Part::Text {
257                        text: result_text.clone(),
258                    }],
259                    context_id: params.message.context_id.clone(),
260                    task_id: Some(task_id.clone()),
261                    metadata: std::collections::HashMap::new(),
262                    extensions: vec![],
263                };
264
265                let artifact = Artifact {
266                    artifact_id: Uuid::new_v4().to_string(),
267                    parts: vec![Part::Text {
268                        text: result_text.clone(),
269                    }],
270                    name: Some("response".to_string()),
271                    description: None,
272                    metadata: std::collections::HashMap::new(),
273                    extensions: vec![],
274                };
275
276                if let Some(mut t) = server.tasks.get_mut(&task_id) {
277                    t.status.state = TaskState::Completed;
278                    t.status.message = Some(response_message.clone());
279                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
280                    t.artifacts.push(artifact.clone());
281                    t.history.push(response_message);
282
283                    let status_event = TaskStatusUpdateEvent {
284                        id: task_id.clone(),
285                        status: t.status.clone(),
286                        is_final: true,
287                        metadata: std::collections::HashMap::new(),
288                    };
289                    let artifact_event = TaskArtifactUpdateEvent {
290                        id: task_id.clone(),
291                        artifact,
292                        metadata: std::collections::HashMap::new(),
293                    };
294                    tracing::debug!(
295                        task_id = %task_id,
296                        event = ?StreamEvent::StatusUpdate(status_event),
297                        "Task completed"
298                    );
299                    tracing::debug!(
300                        task_id = %task_id,
301                        event = ?StreamEvent::ArtifactUpdate(artifact_event),
302                        "Artifact produced"
303                    );
304                }
305
306                record_a2a_message_telemetry(
307                    "a2a_message_send",
308                    &task_id,
309                    true,
310                    &prompt,
311                    started_at.elapsed(),
312                    true,
313                    Some(result_text),
314                    None,
315                );
316            }
317            Err(e) => {
318                let error_message = Message {
319                    message_id: Uuid::new_v4().to_string(),
320                    role: MessageRole::Agent,
321                    parts: vec![Part::Text {
322                        text: format!("Error: {}", e),
323                    }],
324                    context_id: params.message.context_id.clone(),
325                    task_id: Some(task_id.clone()),
326                    metadata: std::collections::HashMap::new(),
327                    extensions: vec![],
328                };
329
330                if let Some(mut t) = server.tasks.get_mut(&task_id) {
331                    t.status.state = TaskState::Failed;
332                    t.status.message = Some(error_message);
333                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
334                }
335
336                record_a2a_message_telemetry(
337                    "a2a_message_send",
338                    &task_id,
339                    true,
340                    &prompt,
341                    started_at.elapsed(),
342                    false,
343                    None,
344                    Some(e.to_string()),
345                );
346            }
347        }
348    } else {
349        // Async execution: spawn background task, return immediately with Working state
350        let tasks = server.tasks.clone();
351        let context_id = params.message.context_id.clone();
352        let spawn_task_id = task_id.clone();
353
354        tokio::spawn(async move {
355            let task_id = spawn_task_id;
356            let started_at = Instant::now();
357            let mut session = match Session::new().await {
358                Ok(s) => s,
359                Err(e) => {
360                    tracing::error!("Failed to create session for task {}: {}", task_id, e);
361                    if let Some(mut t) = tasks.get_mut(&task_id) {
362                        t.status.state = TaskState::Failed;
363                        t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
364                    }
365                    record_a2a_message_telemetry(
366                        "a2a_message_send",
367                        &task_id,
368                        false,
369                        &prompt,
370                        started_at.elapsed(),
371                        false,
372                        None,
373                        Some(e.to_string()),
374                    );
375                    return;
376                }
377            };
378
379            match session.prompt(&prompt).await {
380                Ok(result) => {
381                    let result_text = result.text;
382                    let response_message = Message {
383                        message_id: Uuid::new_v4().to_string(),
384                        role: MessageRole::Agent,
385                        parts: vec![Part::Text {
386                            text: result_text.clone(),
387                        }],
388                        context_id,
389                        task_id: Some(task_id.clone()),
390                        metadata: std::collections::HashMap::new(),
391                        extensions: vec![],
392                    };
393
394                    let artifact = Artifact {
395                        artifact_id: Uuid::new_v4().to_string(),
396                        parts: vec![Part::Text {
397                            text: result_text.clone(),
398                        }],
399                        name: Some("response".to_string()),
400                        description: None,
401                        metadata: std::collections::HashMap::new(),
402                        extensions: vec![],
403                    };
404
405                    if let Some(mut t) = tasks.get_mut(&task_id) {
406                        t.status.state = TaskState::Completed;
407                        t.status.message = Some(response_message.clone());
408                        t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
409                        t.artifacts.push(artifact);
410                        t.history.push(response_message);
411                    }
412
413                    record_a2a_message_telemetry(
414                        "a2a_message_send",
415                        &task_id,
416                        false,
417                        &prompt,
418                        started_at.elapsed(),
419                        true,
420                        Some(result_text),
421                        None,
422                    );
423                }
424                Err(e) => {
425                    tracing::error!("Task {} failed: {}", task_id, e);
426                    if let Some(mut t) = tasks.get_mut(&task_id) {
427                        t.status.state = TaskState::Failed;
428                        t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
429                    }
430                    record_a2a_message_telemetry(
431                        "a2a_message_send",
432                        &task_id,
433                        false,
434                        &prompt,
435                        started_at.elapsed(),
436                        false,
437                        None,
438                        Some(e.to_string()),
439                    );
440                }
441            }
442        });
443    }
444
445    // Return current task state wrapped in SendMessageResponse
446    let task = server.tasks.get(&task_id).unwrap();
447    let response = SendMessageResponse::Task(task.value().clone());
448    serde_json::to_value(response)
449        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
450}
451
452async fn handle_message_stream(
453    server: &A2AServer,
454    request: JsonRpcRequest,
455) -> Result<serde_json::Value, JsonRpcError> {
456    // message/stream submits the task for async processing.
457    // The client should poll tasks/get for status updates.
458    // True SSE streaming requires a dedicated endpoint outside JSON-RPC.
459
460    let params: MessageSendParams = serde_json::from_value(request.params)
461        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
462
463    let task_id = params
464        .message
465        .task_id
466        .clone()
467        .unwrap_or_else(|| Uuid::new_v4().to_string());
468
469    let task = Task {
470        id: task_id.clone(),
471        context_id: params.message.context_id.clone(),
472        status: TaskStatus {
473            state: TaskState::Working,
474            message: Some(params.message.clone()),
475            timestamp: Some(chrono::Utc::now().to_rfc3339()),
476        },
477        artifacts: vec![],
478        history: vec![params.message.clone()],
479        metadata: std::collections::HashMap::new(),
480    };
481
482    server.tasks.insert(task_id.clone(), task.clone());
483
484    // Extract prompt
485    let prompt: String = params
486        .message
487        .parts
488        .iter()
489        .filter_map(|p| match p {
490            Part::Text { text } => Some(text.as_str()),
491            _ => None,
492        })
493        .collect::<Vec<_>>()
494        .join("\n");
495
496    if prompt.is_empty() {
497        if let Some(mut t) = server.tasks.get_mut(&task_id) {
498            t.status.state = TaskState::Failed;
499            t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
500        }
501        return Err(JsonRpcError::invalid_params("No text content in message"));
502    }
503
504    // Spawn async processing
505    let tasks = server.tasks.clone();
506    let context_id = params.message.context_id.clone();
507    let spawn_task_id = task_id.clone();
508
509    tokio::spawn(async move {
510        let task_id = spawn_task_id;
511        let started_at = Instant::now();
512        let mut session = match Session::new().await {
513            Ok(s) => s,
514            Err(e) => {
515                tracing::error!(
516                    "Failed to create session for stream task {}: {}",
517                    task_id,
518                    e
519                );
520                if let Some(mut t) = tasks.get_mut(&task_id) {
521                    t.status.state = TaskState::Failed;
522                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
523                }
524                record_a2a_message_telemetry(
525                    "a2a_message_stream",
526                    &task_id,
527                    false,
528                    &prompt,
529                    started_at.elapsed(),
530                    false,
531                    None,
532                    Some(e.to_string()),
533                );
534                return;
535            }
536        };
537
538        match session.prompt(&prompt).await {
539            Ok(result) => {
540                let result_text = result.text;
541                let response_message = Message {
542                    message_id: Uuid::new_v4().to_string(),
543                    role: MessageRole::Agent,
544                    parts: vec![Part::Text {
545                        text: result_text.clone(),
546                    }],
547                    context_id,
548                    task_id: Some(task_id.clone()),
549                    metadata: std::collections::HashMap::new(),
550                    extensions: vec![],
551                };
552
553                let artifact = Artifact {
554                    artifact_id: Uuid::new_v4().to_string(),
555                    parts: vec![Part::Text {
556                        text: result_text.clone(),
557                    }],
558                    name: Some("response".to_string()),
559                    description: None,
560                    metadata: std::collections::HashMap::new(),
561                    extensions: vec![],
562                };
563
564                if let Some(mut t) = tasks.get_mut(&task_id) {
565                    t.status.state = TaskState::Completed;
566                    t.status.message = Some(response_message.clone());
567                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
568                    t.artifacts.push(artifact.clone());
569                    t.history.push(response_message);
570
571                    // Emit streaming events for SSE consumers
572                    let status_event = TaskStatusUpdateEvent {
573                        id: task_id.clone(),
574                        status: t.status.clone(),
575                        is_final: true,
576                        metadata: std::collections::HashMap::new(),
577                    };
578                    let artifact_event = TaskArtifactUpdateEvent {
579                        id: task_id.clone(),
580                        artifact,
581                        metadata: std::collections::HashMap::new(),
582                    };
583                    tracing::debug!(
584                        task_id = %task_id,
585                        event = ?StreamEvent::StatusUpdate(status_event),
586                        "Task completed"
587                    );
588                    tracing::debug!(
589                        task_id = %task_id,
590                        event = ?StreamEvent::ArtifactUpdate(artifact_event),
591                        "Artifact produced"
592                    );
593                }
594
595                record_a2a_message_telemetry(
596                    "a2a_message_stream",
597                    &task_id,
598                    false,
599                    &prompt,
600                    started_at.elapsed(),
601                    true,
602                    Some(result_text),
603                    None,
604                );
605            }
606            Err(e) => {
607                tracing::error!("Stream task {} failed: {}", task_id, e);
608                if let Some(mut t) = tasks.get_mut(&task_id) {
609                    t.status.state = TaskState::Failed;
610                    t.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
611                }
612                record_a2a_message_telemetry(
613                    "a2a_message_stream",
614                    &task_id,
615                    false,
616                    &prompt,
617                    started_at.elapsed(),
618                    false,
619                    None,
620                    Some(e.to_string()),
621                );
622            }
623        }
624    });
625
626    // Return task in Working state — client polls tasks/get for completion
627    let response = SendMessageResponse::Task(task);
628    serde_json::to_value(response)
629        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
630}
631
632async fn handle_tasks_get(
633    server: &A2AServer,
634    request: JsonRpcRequest,
635) -> Result<serde_json::Value, JsonRpcError> {
636    let params: TaskQueryParams = serde_json::from_value(request.params)
637        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
638
639    let task = server.tasks.get(&params.id).ok_or_else(|| JsonRpcError {
640        code: TASK_NOT_FOUND,
641        message: format!("Task not found: {}", params.id),
642        data: None,
643    })?;
644
645    serde_json::to_value(task.value().clone())
646        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
647}
648
649async fn handle_tasks_cancel(
650    server: &A2AServer,
651    request: JsonRpcRequest,
652) -> Result<serde_json::Value, JsonRpcError> {
653    let params: TaskQueryParams = serde_json::from_value(request.params)
654        .map_err(|e| JsonRpcError::invalid_params(format!("Invalid parameters: {}", e)))?;
655
656    let mut task = server
657        .tasks
658        .get_mut(&params.id)
659        .ok_or_else(|| JsonRpcError {
660            code: TASK_NOT_FOUND,
661            message: format!("Task not found: {}", params.id),
662            data: None,
663        })?;
664
665    if !task.status.state.is_active() {
666        return Err(JsonRpcError {
667            code: TASK_NOT_CANCELABLE,
668            message: "Task is already in a terminal state".to_string(),
669            data: None,
670        });
671    }
672
673    task.status.state = TaskState::Cancelled;
674    task.status.timestamp = Some(chrono::Utc::now().to_rfc3339());
675
676    serde_json::to_value(task.value().clone())
677        .map_err(|e| JsonRpcError::internal_error(format!("Serialization error: {}", e)))
678}