Skip to main content

jamjet_worker/executors/
a2a_task.rs

1//! Executor for `A2aTask` workflow nodes.
2//!
3//! When a workflow node has kind `A2aTask`, this executor:
4//! 1. Resolves the remote agent URI from the IR.
5//! 2. Submits a task via the A2A client.
6//! 3. Polls (or SSE-streams) for completion.
7//! 4. Maps artifacts into the node output and workflow state patch.
8//!
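//! The executor is crash-resumable: if the worker dies mid-poll, the scheduler
//! will reclaim the lease and re-submit the task on the next attempt. Unless the
//! work-item payload carries a `task_id`, that retry is submitted under a freshly
//! generated task ID rather than the one used by the interrupted attempt.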
11
12use crate::executor::{ExecutionResult, NodeExecutor};
13use async_trait::async_trait;
14use jamjet_a2a_proto::A2aMessage as Message;
15use jamjet_a2a_proto::A2aPart as Part;
16use jamjet_a2a_proto::{
17    A2aArtifact, A2aClient, A2aTaskState, PartContent, Role, SendMessageRequest,
18    SendMessageResponse,
19};
20use jamjet_state::backend::WorkItem;
21use serde_json::{json, Value};
22use std::collections::HashMap;
23use std::time::Duration;
24use tracing::{debug, instrument, warn};
25use uuid::Uuid;
26
27/// Executor for `a2a_task` workflow nodes.
28pub struct A2aTaskExecutor {
29    client: A2aClient,
30    /// Default poll interval when SSE is not available.
31    poll_interval: Duration,
32}
33
34impl A2aTaskExecutor {
35    pub fn new() -> Self {
36        Self {
37            client: A2aClient::new(),
38            poll_interval: Duration::from_secs(2),
39        }
40    }
41
42    pub fn with_poll_interval(mut self, interval: Duration) -> Self {
43        self.poll_interval = interval;
44        self
45    }
46}
47
48impl Default for A2aTaskExecutor {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54#[async_trait]
55impl NodeExecutor for A2aTaskExecutor {
56    #[instrument(skip(self, item), fields(node_id = %item.node_id))]
57    async fn execute(&self, item: &WorkItem) -> Result<ExecutionResult, String> {
58        let start = std::time::Instant::now();
59
60        // Extract agent URI and skill from the work item payload.
61        let agent_uri = item
62            .payload
63            .get("agent_uri")
64            .and_then(|v| v.as_str())
65            .ok_or("A2aTaskExecutor: missing 'agent_uri' in payload")?;
66
67        let skill = item
68            .payload
69            .get("skill")
70            .and_then(|v| v.as_str())
71            .unwrap_or("default");
72
73        let input = item.payload.get("input").cloned().unwrap_or(json!({}));
74
75        let task_id = item
76            .payload
77            .get("task_id")
78            .and_then(|v| v.as_str())
79            .map(|s| s.to_string())
80            .unwrap_or_else(|| Uuid::new_v4().to_string());
81
82        debug!(
83            agent_uri = %agent_uri,
84            skill = %skill,
85            task_id = %task_id,
86            "Submitting A2A task"
87        );
88
89        // Open a protocol-level span for A2A round-trip tracking (H2.4, H2.5, H2.9).
90        let a2a_span = tracing::info_span!(
91            "jamjet.a2a_task",
92            "jamjet.tool.protocol" = "a2a",
93            "jamjet.a2a.agent_uri" = %agent_uri,
94            "jamjet.a2a.skill" = %skill,
95            "jamjet.a2a.task_id" = %task_id,
96        );
97        let _a2a_guard = a2a_span.enter();
98
99        // Build the message using the published crate's types.
100        let mut metadata_map = HashMap::new();
101        metadata_map.insert("skill".to_string(), json!(skill));
102
103        let message = Message {
104            message_id: Uuid::new_v4().to_string(),
105            context_id: None,
106            task_id: Some(task_id.clone()),
107            role: Role::User,
108            parts: vec![Part {
109                content: PartContent::Data(input),
110                metadata: None,
111                filename: None,
112                media_type: None,
113            }],
114            metadata: Some(metadata_map),
115            extensions: vec![],
116            reference_task_ids: vec![],
117        };
118
119        let request = SendMessageRequest {
120            tenant: None,
121            message,
122            configuration: None,
123            metadata: None,
124        };
125
126        let submitted = self
127            .client
128            .send_message(agent_uri, request)
129            .await
130            .map_err(|e| format!("A2A task submission failed: {e}"))?;
131
132        // Extract the task ID from the response (may differ from our generated one).
133        let response_task_id = match &submitted {
134            SendMessageResponse::Task(t) => t.id.clone(),
135            SendMessageResponse::WrappedTask(w) => w.task.id.clone(),
136            SendMessageResponse::Message(m) => m.task_id.clone().unwrap_or(task_id.clone()),
137            SendMessageResponse::WrappedMessage(w) => {
138                w.message.task_id.clone().unwrap_or(task_id.clone())
139            }
140        };
141
142        debug!(task_id = %response_task_id, "A2A task submitted");
143
144        // Poll until completion.
145        let final_task = self
146            .client
147            .wait_for_completion(agent_uri, &response_task_id, self.poll_interval, None)
148            .await
149            .map_err(|e| format!("A2A task polling failed: {e}"))?;
150
151        let duration_ms = start.elapsed().as_millis() as u64;
152
153        match final_task.status.state {
154            A2aTaskState::Completed => {
155                // Extract output from the first artifact.
156                let output = extract_output(&final_task.artifacts);
157                Ok(ExecutionResult {
158                    output: output.clone(),
159                    state_patch: json!({ "last_a2a_output": output }),
160                    duration_ms,
161                    gen_ai_system: None,
162                    gen_ai_model: None,
163                    input_tokens: None,
164                    output_tokens: None,
165                    finish_reason: Some("completed".into()),
166                })
167            }
168            A2aTaskState::Failed => {
169                let error = final_task
170                    .status
171                    .message
172                    .as_ref()
173                    .and_then(|m| {
174                        m.parts.iter().find_map(|p| {
175                            if let PartContent::Text(ref text) = p.content {
176                                Some(text.clone())
177                            } else {
178                                None
179                            }
180                        })
181                    })
182                    .unwrap_or_else(|| "A2A task failed".into());
183                Err(error)
184            }
185            A2aTaskState::InputRequired => {
186                // The workflow should be paused for user input — return error
187                // so the retry mechanism handles re-submission.
188                warn!(task_id = %response_task_id, "A2A task requires input — not yet handled");
189                Err("A2A task requires input — multi-turn not yet supported in this node".into())
190            }
191            other => Err(format!("A2A task ended in unexpected state: {other:?}")),
192        }
193    }
194}
195
196fn extract_output(artifacts: &[A2aArtifact]) -> Value {
197    artifacts
198        .first()
199        .map(|a| {
200            a.parts
201                .iter()
202                .map(|p| match &p.content {
203                    PartContent::Data(data) => data.clone(),
204                    PartContent::Text(text) => json!({ "text": text }),
205                    PartContent::Url(url) => json!({ "uri": url }),
206                    PartContent::Raw(_) => json!({ "type": "raw_binary" }),
207                    _ => json!({}),
208                })
209                .next()
210                .unwrap_or(json!({}))
211        })
212        .unwrap_or(json!({}))
213}