jamjet_worker/executors/
a2a_task.rs1use crate::executor::{ExecutionResult, NodeExecutor};
13use async_trait::async_trait;
14use jamjet_a2a_proto::A2aMessage as Message;
15use jamjet_a2a_proto::A2aPart as Part;
16use jamjet_a2a_proto::{
17 A2aArtifact, A2aClient, A2aTaskState, PartContent, Role, SendMessageRequest,
18 SendMessageResponse,
19};
20use jamjet_state::backend::WorkItem;
21use serde_json::{json, Value};
22use std::collections::HashMap;
23use std::time::Duration;
24use tracing::{debug, instrument, warn};
25use uuid::Uuid;
26
27pub struct A2aTaskExecutor {
29 client: A2aClient,
30 poll_interval: Duration,
32}
33
34impl A2aTaskExecutor {
35 pub fn new() -> Self {
36 Self {
37 client: A2aClient::new(),
38 poll_interval: Duration::from_secs(2),
39 }
40 }
41
42 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
43 self.poll_interval = interval;
44 self
45 }
46}
47
48impl Default for A2aTaskExecutor {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54#[async_trait]
55impl NodeExecutor for A2aTaskExecutor {
56 #[instrument(skip(self, item), fields(node_id = %item.node_id))]
57 async fn execute(&self, item: &WorkItem) -> Result<ExecutionResult, String> {
58 let start = std::time::Instant::now();
59
60 let agent_uri = item
62 .payload
63 .get("agent_uri")
64 .and_then(|v| v.as_str())
65 .ok_or("A2aTaskExecutor: missing 'agent_uri' in payload")?;
66
67 let skill = item
68 .payload
69 .get("skill")
70 .and_then(|v| v.as_str())
71 .unwrap_or("default");
72
73 let input = item.payload.get("input").cloned().unwrap_or(json!({}));
74
75 let task_id = item
76 .payload
77 .get("task_id")
78 .and_then(|v| v.as_str())
79 .map(|s| s.to_string())
80 .unwrap_or_else(|| Uuid::new_v4().to_string());
81
82 debug!(
83 agent_uri = %agent_uri,
84 skill = %skill,
85 task_id = %task_id,
86 "Submitting A2A task"
87 );
88
89 let a2a_span = tracing::info_span!(
91 "jamjet.a2a_task",
92 "jamjet.tool.protocol" = "a2a",
93 "jamjet.a2a.agent_uri" = %agent_uri,
94 "jamjet.a2a.skill" = %skill,
95 "jamjet.a2a.task_id" = %task_id,
96 );
97 let _a2a_guard = a2a_span.enter();
98
99 let mut metadata_map = HashMap::new();
101 metadata_map.insert("skill".to_string(), json!(skill));
102
103 let message = Message {
104 message_id: Uuid::new_v4().to_string(),
105 context_id: None,
106 task_id: Some(task_id.clone()),
107 role: Role::User,
108 parts: vec![Part {
109 content: PartContent::Data(input),
110 metadata: None,
111 filename: None,
112 media_type: None,
113 }],
114 metadata: Some(metadata_map),
115 extensions: vec![],
116 reference_task_ids: vec![],
117 };
118
119 let request = SendMessageRequest {
120 tenant: None,
121 message,
122 configuration: None,
123 metadata: None,
124 };
125
126 let submitted = self
127 .client
128 .send_message(agent_uri, request)
129 .await
130 .map_err(|e| format!("A2A task submission failed: {e}"))?;
131
132 let response_task_id = match &submitted {
134 SendMessageResponse::Task(t) => t.id.clone(),
135 SendMessageResponse::WrappedTask(w) => w.task.id.clone(),
136 SendMessageResponse::Message(m) => m.task_id.clone().unwrap_or(task_id.clone()),
137 SendMessageResponse::WrappedMessage(w) => {
138 w.message.task_id.clone().unwrap_or(task_id.clone())
139 }
140 };
141
142 debug!(task_id = %response_task_id, "A2A task submitted");
143
144 let final_task = self
146 .client
147 .wait_for_completion(agent_uri, &response_task_id, self.poll_interval, None)
148 .await
149 .map_err(|e| format!("A2A task polling failed: {e}"))?;
150
151 let duration_ms = start.elapsed().as_millis() as u64;
152
153 match final_task.status.state {
154 A2aTaskState::Completed => {
155 let output = extract_output(&final_task.artifacts);
157 Ok(ExecutionResult {
158 output: output.clone(),
159 state_patch: json!({ "last_a2a_output": output }),
160 duration_ms,
161 gen_ai_system: None,
162 gen_ai_model: None,
163 input_tokens: None,
164 output_tokens: None,
165 finish_reason: Some("completed".into()),
166 })
167 }
168 A2aTaskState::Failed => {
169 let error = final_task
170 .status
171 .message
172 .as_ref()
173 .and_then(|m| {
174 m.parts.iter().find_map(|p| {
175 if let PartContent::Text(ref text) = p.content {
176 Some(text.clone())
177 } else {
178 None
179 }
180 })
181 })
182 .unwrap_or_else(|| "A2A task failed".into());
183 Err(error)
184 }
185 A2aTaskState::InputRequired => {
186 warn!(task_id = %response_task_id, "A2A task requires input — not yet handled");
189 Err("A2A task requires input — multi-turn not yet supported in this node".into())
190 }
191 other => Err(format!("A2A task ended in unexpected state: {other:?}")),
192 }
193 }
194}
195
196fn extract_output(artifacts: &[A2aArtifact]) -> Value {
197 artifacts
198 .first()
199 .map(|a| {
200 a.parts
201 .iter()
202 .map(|p| match &p.content {
203 PartContent::Data(data) => data.clone(),
204 PartContent::Text(text) => json!({ "text": text }),
205 PartContent::Url(url) => json!({ "uri": url }),
206 PartContent::Raw(_) => json!({ "type": "raw_binary" }),
207 _ => json!({}),
208 })
209 .next()
210 .unwrap_or(json!({}))
211 })
212 .unwrap_or(json!({}))
213}