Skip to main content

jamjet_worker/executors/
agent_tool.rs

1//! Executor for `AgentTool` workflow nodes (Phase B — sync, streaming, and conversational modes).
2//!
3//! Invokes a remote agent via A2A-style HTTP POST and returns the response as the node output.
4//! - Sync:           POST `{agent_uri}/tasks/send`           — single request/response
5//! - Streaming:      POST `{agent_uri}/tasks/sendSubscribe`  — NDJSON stream with budget guard
6//! - Conversational: POST `{agent_uri}/tasks/send` in a loop — multi-turn exchange
7
8#![allow(clippy::too_many_arguments)]
9
10use crate::executor::{ExecutionResult, NodeExecutor, StreamEventSender};
11use async_trait::async_trait;
12use bytes::BytesMut;
13use futures::StreamExt;
14use jamjet_state::backend::WorkItem;
15use reqwest::Client;
16use serde_json::{json, Value};
17use std::collections::hash_map::DefaultHasher;
18use std::hash::{Hash, Hasher};
19use std::time::Duration;
20use tracing::{debug, instrument, warn};
21
/// Executor for `AgentTool` workflow nodes.
///
/// A field-less unit struct: every setting the methods below use (agent URI, mode,
/// timeouts, budget, input) is read from the `WorkItem` payload on each call.
pub struct AgentToolExecutor;
23
24impl AgentToolExecutor {
25    /// Build an HTTP client with the given timeout.
26    fn build_client(timeout_ms: u64) -> Result<Client, String> {
27        reqwest::Client::builder()
28            .timeout(std::time::Duration::from_millis(timeout_ms))
29            .build()
30            .map_err(|e| format!("HTTP client: {e}"))
31    }
32
33    /// Resolve the base agent URI to an `https://` URL, returning an error for unsupported schemes.
34    /// In test builds, `http://` is also accepted (for wiremock).
35    fn resolve_url(agent_uri: &str, endpoint: &str) -> Result<String, String> {
36        let is_http =
37            agent_uri.starts_with("https://") || (cfg!(test) && agent_uri.starts_with("http://"));
38        if is_http {
39            Ok(format!("{}/{}", agent_uri.trim_end_matches('/'), endpoint))
40        } else {
41            Err(format!(
42                "Cannot resolve '{}' to HTTP endpoint. \
43                 Only https:// agent URIs are supported for remote invocation.",
44                agent_uri
45            ))
46        }
47    }
48
49    /// Execute in sync mode: single POST to `/tasks/send`, returns one response.
50    async fn execute_sync(
51        &self,
52        item: &WorkItem,
53        agent_uri: &str,
54        protocol: &str,
55        output_key: &str,
56        timeout_ms: u64,
57        input: &Value,
58        input_hash: &str,
59        start: std::time::Instant,
60    ) -> Result<ExecutionResult, String> {
61        let client = Self::build_client(timeout_ms)?;
62        let task_url = Self::resolve_url(agent_uri, "tasks/send")?;
63
64        let resp = client
65            .post(&task_url)
66            .json(&json!({
67                "jsonrpc": "2.0",
68                "method": "tasks/send",
69                "params": { "message": { "parts": [{ "text": input.to_string() }] } }
70            }))
71            .send()
72            .await
73            .map_err(|e| format!("AgentTool invocation failed: {e}"))?;
74
75        if !resp.status().is_success() {
76            let status = resp.status();
77            let body = resp.text().await.unwrap_or_default();
78            return Err(format!("Agent returned error {status}: {body}"));
79        }
80
81        let output: Value = resp
82            .json()
83            .await
84            .map_err(|e| format!("Failed to parse agent response: {e}"))?;
85
86        let duration_ms = start.elapsed().as_millis() as u64;
87
88        Ok(ExecutionResult {
89            output: json!({ output_key: output }),
90            state_patch: json!({
91                "agent_tool_events": [
92                    {
93                        "type": "agent_tool_invoked",
94                        "node_id": &item.node_id,
95                        "agent_uri": agent_uri,
96                        "mode": "sync",
97                        "protocol": protocol,
98                        "input_hash": input_hash
99                    },
100                    {
101                        "type": "agent_tool_completed",
102                        "node_id": &item.node_id,
103                        "output": &output,
104                        "total_cost": 0.0,
105                        "latency_ms": duration_ms
106                    },
107                ]
108            }),
109            duration_ms,
110            gen_ai_system: None,
111            gen_ai_model: None,
112            input_tokens: None,
113            output_tokens: None,
114            finish_reason: None,
115        })
116    }
117
118    /// Legacy streaming mode: POST to `/tasks/sendSubscribe`, buffer full body, parse NDJSON.
119    ///
120    /// Emits `agent_tool_progress` events per chunk, with optional early termination
121    /// when `max_cost_usd` budget is exceeded. Final event is `agent_tool_completed`
122    /// or `agent_tool_terminated` on budget breach.
123    ///
124    /// Kept as fallback when invoked through `execute()` (no channel available).
125    /// The incremental path is `execute_streaming` on the trait.
126    async fn stream_ndjson(
127        &self,
128        item: &WorkItem,
129        agent_uri: &str,
130        protocol: &str,
131        output_key: &str,
132        timeout_ms: u64,
133        input: &Value,
134        input_hash: &str,
135        max_cost_usd: Option<f64>,
136        start: std::time::Instant,
137    ) -> Result<ExecutionResult, String> {
138        let client = Self::build_client(timeout_ms)?;
139        let task_url = Self::resolve_url(agent_uri, "tasks/sendSubscribe")?;
140
141        let resp = client
142            .post(&task_url)
143            .json(&json!({
144                "jsonrpc": "2.0",
145                "method": "tasks/sendSubscribe",
146                "params": { "message": { "parts": [{ "text": input.to_string() }] } }
147            }))
148            .send()
149            .await
150            .map_err(|e| format!("AgentTool streaming invocation failed: {e}"))?;
151
152        if !resp.status().is_success() {
153            let status = resp.status();
154            let body = resp.text().await.unwrap_or_default();
155            return Err(format!("Agent returned error {status}: {body}"));
156        }
157
158        // Read full body, then parse NDJSON line by line
159        let body = resp
160            .text()
161            .await
162            .map_err(|e| format!("Failed to read streaming response body: {e}"))?;
163
164        let mut events: Vec<Value> = vec![json!({
165            "type": "agent_tool_invoked",
166            "node_id": &item.node_id,
167            "agent_uri": agent_uri,
168            "mode": "streaming",
169            "protocol": protocol,
170            "input_hash": input_hash
171        })];
172
173        let mut accumulated_cost: f64 = 0.0;
174        let mut terminated_early = false;
175        let mut last_chunk: Value = json!(null);
176        let mut chunk_index: u64 = 0;
177
178        for line in body.lines() {
179            let trimmed = line.trim();
180            if trimmed.is_empty() {
181                continue;
182            }
183
184            let chunk: Value =
185                serde_json::from_str(trimmed).unwrap_or_else(|_| json!({ "raw": trimmed }));
186
187            // Extract cost from chunk if present
188            if let Some(cost) = chunk.get("cost_usd").and_then(|v| v.as_f64()) {
189                accumulated_cost += cost;
190            }
191
192            events.push(json!({
193                "type": "agent_tool_progress",
194                "node_id": &item.node_id,
195                "chunk_index": chunk_index,
196                "chunk": &chunk,
197                "accumulated_cost_usd": accumulated_cost
198            }));
199
200            last_chunk = chunk;
201            chunk_index += 1;
202
203            // Budget guard — terminate early if cost threshold exceeded
204            if let Some(budget) = max_cost_usd {
205                if accumulated_cost > budget {
206                    terminated_early = true;
207                    debug!(
208                        node_id = %item.node_id,
209                        accumulated_cost,
210                        budget,
211                        "AgentTool streaming: budget exceeded, terminating early"
212                    );
213                    break;
214                }
215            }
216        }
217
218        let duration_ms = start.elapsed().as_millis() as u64;
219
220        if terminated_early {
221            events.push(json!({
222                "type": "agent_tool_terminated",
223                "node_id": &item.node_id,
224                "reason": "budget_exceeded",
225                "accumulated_cost_usd": accumulated_cost,
226                "latency_ms": duration_ms
227            }));
228        } else {
229            events.push(json!({
230                "type": "agent_tool_completed",
231                "node_id": &item.node_id,
232                "output": &last_chunk,
233                "total_cost": accumulated_cost,
234                "latency_ms": duration_ms
235            }));
236        }
237
238        Ok(ExecutionResult {
239            output: json!({ output_key: &last_chunk }),
240            state_patch: json!({ "agent_tool_events": events }),
241            duration_ms,
242            gen_ai_system: None,
243            gen_ai_model: None,
244            input_tokens: None,
245            output_tokens: None,
246            finish_reason: None,
247        })
248    }
249
250    /// Best-effort A2A cancel. Fire-and-forget with 5s timeout.
251    async fn send_a2a_cancel(client: &Client, agent_uri: &str, task_id: &Option<String>) {
252        if let Some(ref id) = task_id {
253            if let Ok(cancel_url) = Self::resolve_url(agent_uri, "tasks/cancel") {
254                let _ = client
255                    .post(&cancel_url)
256                    .json(&serde_json::json!({ "id": id }))
257                    .timeout(Duration::from_secs(5))
258                    .send()
259                    .await;
260            }
261        }
262    }
263
264    /// Execute in conversational mode: multi-turn loop sending to `/tasks/send`.
265    ///
266    /// Reads `max_turns` from `payload.mode.conversational.max_turns` (default 5).
267    /// Each turn records outbound and inbound `agent_tool_turn` events. Stops early
268    /// when the agent response carries `status: "completed"`.
269    async fn execute_conversational(
270        &self,
271        item: &WorkItem,
272        agent_uri: &str,
273        protocol: &str,
274        output_key: &str,
275        timeout_ms: u64,
276        input: &Value,
277        input_hash: &str,
278        start: std::time::Instant,
279    ) -> Result<ExecutionResult, String> {
280        let p = &item.payload;
281
282        let max_turns = p
283            .get("mode")
284            .and_then(|m| m.get("conversational"))
285            .and_then(|c| c.get("max_turns"))
286            .and_then(|v| v.as_u64())
287            .unwrap_or(5) as usize;
288
289        let client = Self::build_client(timeout_ms)?;
290        let task_url = Self::resolve_url(agent_uri, "tasks/send")?;
291
292        let mut events: Vec<Value> = vec![json!({
293            "type": "agent_tool_invoked",
294            "node_id": &item.node_id,
295            "agent_uri": agent_uri,
296            "mode": "conversational",
297            "protocol": protocol,
298            "input_hash": input_hash
299        })];
300
301        let mut current_input = input.clone();
302        let mut final_output: Value = json!(null);
303
304        for turn in 0..max_turns {
305            // Record outbound turn event
306            events.push(json!({
307                "type": "agent_tool_turn",
308                "node_id": &item.node_id,
309                "turn": turn,
310                "direction": "outbound",
311                "input": &current_input
312            }));
313
314            debug!(
315                node_id = %item.node_id,
316                turn,
317                "AgentTool conversational: sending turn"
318            );
319
320            let resp = client
321                .post(&task_url)
322                .json(&json!({
323                    "jsonrpc": "2.0",
324                    "method": "tasks/send",
325                    "params": { "message": { "parts": [{ "text": current_input.to_string() }] } }
326                }))
327                .send()
328                .await
329                .map_err(|e| format!("AgentTool turn {turn} failed: {e}"))?;
330
331            if !resp.status().is_success() {
332                let status = resp.status();
333                let body = resp.text().await.unwrap_or_default();
334                return Err(format!(
335                    "Agent returned error {status} on turn {turn}: {body}"
336                ));
337            }
338
339            let response: Value = resp
340                .json()
341                .await
342                .map_err(|e| format!("Failed to parse agent response on turn {turn}: {e}"))?;
343
344            // Record inbound turn event
345            events.push(json!({
346                "type": "agent_tool_turn",
347                "node_id": &item.node_id,
348                "turn": turn,
349                "direction": "inbound",
350                "output": &response
351            }));
352
353            final_output = response.clone();
354
355            // Check if the agent signals completion
356            let status = response
357                .get("status")
358                .and_then(|v| v.as_str())
359                .unwrap_or("");
360            if status == "completed" {
361                debug!(
362                    node_id = %item.node_id,
363                    turn,
364                    "AgentTool conversational: agent signalled completion"
365                );
366                break;
367            }
368
369            // Use agent output as next input for the following turn
370            current_input = response
371                .get("output")
372                .cloned()
373                .unwrap_or_else(|| response.clone());
374        }
375
376        let duration_ms = start.elapsed().as_millis() as u64;
377
378        events.push(json!({
379            "type": "agent_tool_completed",
380            "node_id": &item.node_id,
381            "output": &final_output,
382            "total_cost": 0.0,
383            "latency_ms": duration_ms
384        }));
385
386        Ok(ExecutionResult {
387            output: json!({ output_key: &final_output }),
388            state_patch: json!({ "agent_tool_events": events }),
389            duration_ms,
390            gen_ai_system: None,
391            gen_ai_model: None,
392            input_tokens: None,
393            output_tokens: None,
394            finish_reason: None,
395        })
396    }
397}
398
399#[async_trait]
400impl NodeExecutor for AgentToolExecutor {
401    #[instrument(skip(self, item), fields(node_id = %item.node_id))]
402    async fn execute(&self, item: &WorkItem) -> Result<ExecutionResult, String> {
403        let start = std::time::Instant::now();
404        let p = &item.payload;
405
406        // Extract agent target — handle both { "explicit": "uri" } and plain string
407        let agent_uri = p
408            .get("agent")
409            .and_then(|a| {
410                a.get("explicit")
411                    .and_then(|v| v.as_str())
412                    .or_else(|| a.as_str())
413            })
414            .ok_or("AgentTool: missing 'agent' URI in payload")?;
415
416        // Extract mode — handle both string and object forms
417        // e.g. "sync" or {"conversational": {"max_turns": 5}}
418        let mode = if let Some(mode_val) = p.get("mode") {
419            if let Some(s) = mode_val.as_str() {
420                s.to_string()
421            } else if mode_val.get("conversational").is_some() {
422                "conversational".to_string()
423            } else if mode_val.get("streaming").is_some() {
424                "streaming".to_string()
425            } else {
426                "sync".to_string()
427            }
428        } else {
429            "sync".to_string()
430        };
431        let output_key = p
432            .get("output_key")
433            .and_then(|v| v.as_str())
434            .unwrap_or("result");
435        let timeout_ms = p
436            .get("timeout_ms")
437            .and_then(|v| v.as_u64())
438            .unwrap_or(30_000);
439        let input = p.get("input").cloned().unwrap_or(json!({}));
440        // Budget lookup: check nested {"budget": {"max_cost_usd": …}} first, then flat "max_cost_usd"
441        let max_cost_usd = p
442            .get("budget")
443            .and_then(|b| b.get("max_cost_usd"))
444            .and_then(|v| v.as_f64())
445            .or_else(|| p.get("max_cost_usd").and_then(|v| v.as_f64()));
446
447        // Check for unresolved auto target
448        if p.get("agent").and_then(|a| a.get("auto")).is_some() {
449            return Err(
450                "AgentTool with 'auto' target was not expanded at compile time. \
451                 Use the compiler to expand 'auto' into coordinator + agent_tool nodes."
452                    .into(),
453            );
454        }
455
456        // Resolve protocol based on URI scheme
457        let protocol = if agent_uri.starts_with("https://") {
458            "a2a"
459        } else if agent_uri.starts_with("jamjet://") {
460            "local"
461        } else {
462            "mcp"
463        };
464
465        // Compute input hash for tracing
466        let mut hasher = DefaultHasher::new();
467        input.to_string().hash(&mut hasher);
468        let input_hash = format!("{:016x}", hasher.finish());
469
470        debug!(agent_uri = %agent_uri, mode = %mode, protocol = %protocol, "AgentTool: invoking");
471
472        match mode.as_str() {
473            "sync" => {
474                self.execute_sync(
475                    item,
476                    agent_uri,
477                    protocol,
478                    output_key,
479                    timeout_ms,
480                    &input,
481                    &input_hash,
482                    start,
483                )
484                .await
485            }
486            "streaming" => {
487                self.stream_ndjson(
488                    item,
489                    agent_uri,
490                    protocol,
491                    output_key,
492                    timeout_ms,
493                    &input,
494                    &input_hash,
495                    max_cost_usd,
496                    start,
497                )
498                .await
499            }
500            "conversational" => {
501                self.execute_conversational(
502                    item,
503                    agent_uri,
504                    protocol,
505                    output_key,
506                    timeout_ms,
507                    &input,
508                    &input_hash,
509                    start,
510                )
511                .await
512            }
513            other => Err(format!("Unknown agent_tool mode: '{other}'")),
514        }
515    }
516
517    /// Incremental NDJSON streaming with per-chunk idle timeout, budget guard,
518    /// and A2A cancel on early termination. Events are sent via `tx` in real time.
519    #[instrument(skip(self, item, tx), fields(node_id = %item.node_id))]
520    async fn execute_streaming(
521        &self,
522        item: &WorkItem,
523        tx: StreamEventSender,
524    ) -> Result<ExecutionResult, String> {
525        let start = std::time::Instant::now();
526        let p = &item.payload;
527
528        // ── Extract params ──────────────────────────────────────────────
529        let agent_uri = p
530            .get("agent")
531            .and_then(|a| {
532                a.get("explicit")
533                    .and_then(|v| v.as_str())
534                    .or_else(|| a.as_str())
535            })
536            .ok_or("AgentTool: missing 'agent' URI in payload")?;
537
538        let mode = if let Some(mode_val) = p.get("mode") {
539            if let Some(s) = mode_val.as_str() {
540                s.to_string()
541            } else if mode_val.get("streaming").is_some() {
542                "streaming".to_string()
543            } else {
544                "sync".to_string()
545            }
546        } else {
547            "sync".to_string()
548        };
549
550        // Short-circuit non-streaming modes back to execute()
551        if mode != "streaming" {
552            return self.execute(item).await;
553        }
554
555        let input = p.get("input").cloned().unwrap_or(json!({}));
556
557        let max_cost_usd = p
558            .get("budget")
559            .and_then(|b| b.get("max_cost_usd"))
560            .and_then(|v| v.as_f64())
561            .or_else(|| p.get("max_cost_usd").and_then(|v| v.as_f64()));
562
563        let idle_timeout_secs = p
564            .get("idle_timeout_secs")
565            .and_then(|v| v.as_u64())
566            .unwrap_or(30);
567
568        // Resolve protocol
569        let protocol = if agent_uri.starts_with("https://")
570            || (cfg!(test) && agent_uri.starts_with("http://"))
571        {
572            "a2a"
573        } else if agent_uri.starts_with("jamjet://") {
574            "local"
575        } else {
576            "mcp"
577        };
578
579        // Compute input hash
580        let mut hasher = DefaultHasher::new();
581        input.to_string().hash(&mut hasher);
582        let input_hash = format!("{:016x}", hasher.finish());
583
584        // ── Build client WITHOUT overall timeout (streaming uses per-chunk idle) ──
585        let client = reqwest::Client::builder()
586            .build()
587            .map_err(|e| format!("HTTP client: {e}"))?;
588
589        // ── Emit invoked event ──────────────────────────────────────────
590        let now_ms = || -> u64 {
591            std::time::SystemTime::now()
592                .duration_since(std::time::UNIX_EPOCH)
593                .unwrap()
594                .as_millis() as u64
595        };
596
597        let invoked_event = json!({
598            "type": "agent_tool_invoked",
599            "node_id": &item.node_id,
600            "agent_uri": agent_uri,
601            "mode": &mode,
602            "protocol": protocol,
603            "input_hash": &input_hash,
604            "timestamp_ms": now_ms()
605        });
606        if tx.send(invoked_event).await.is_err() {
607            return Err("Streaming receiver dropped before invocation event".into());
608        }
609
610        // ── POST to /tasks/sendSubscribe ────────────────────────────────
611        let task_url = Self::resolve_url(agent_uri, "tasks/sendSubscribe")?;
612        let resp = client
613            .post(&task_url)
614            .json(&json!({
615                "jsonrpc": "2.0",
616                "method": "tasks/sendSubscribe",
617                "params": { "message": { "parts": [{ "text": input.to_string() }] } }
618            }))
619            .send()
620            .await
621            .map_err(|e| format!("AgentTool streaming invocation failed: {e}"))?;
622
623        if !resp.status().is_success() {
624            let status = resp.status();
625            let body = resp.text().await.unwrap_or_default();
626            return Err(format!("Agent returned error {status}: {body}"));
627        }
628
629        // ── Incremental NDJSON read loop ────────────────────────────────
630        let mut stream = resp.bytes_stream();
631        let mut line_buf = BytesMut::new();
632        let mut chunk_index: u64 = 0;
633        let mut accumulated_cost: f64 = 0.0;
634        let mut task_id: Option<String> = None;
635        let mut last_chunk: Value = json!(null);
636        let output_key = p
637            .get("output_key")
638            .and_then(|v| v.as_str())
639            .unwrap_or("result");
640        let mut terminated_early = false;
641        let mut terminal_error: Option<String> = None;
642        let idle_dur = Duration::from_secs(idle_timeout_secs);
643
644        loop {
645            match tokio::time::timeout(idle_dur, stream.next()).await {
646                // Timeout — no data within idle window
647                Err(_elapsed) => {
648                    warn!(
649                        node_id = %item.node_id,
650                        idle_timeout_secs,
651                        "AgentTool streaming: idle timeout, terminating"
652                    );
653                    let _ = tx
654                        .send(json!({
655                            "type": "agent_tool_terminated",
656                            "node_id": &item.node_id,
657                            "reason": "idle_timeout",
658                            "accumulated_cost_usd": accumulated_cost,
659                            "latency_ms": start.elapsed().as_millis() as u64,
660                            "timestamp_ms": now_ms()
661                        }))
662                        .await;
663                    Self::send_a2a_cancel(&client, agent_uri, &task_id).await;
664                    terminated_early = true;
665                    terminal_error =
666                        Some(format!("AgentTool idle timeout after {idle_timeout_secs}s"));
667                    break;
668                }
669                // Stream ended normally
670                Ok(None) => {
671                    break;
672                }
673                // Network error
674                Ok(Some(Err(e))) => {
675                    warn!(
676                        node_id = %item.node_id,
677                        error = %e,
678                        "AgentTool streaming: network error"
679                    );
680                    let _ = tx
681                        .send(json!({
682                            "type": "agent_tool_error",
683                            "node_id": &item.node_id,
684                            "error": e.to_string(),
685                            "timestamp_ms": now_ms()
686                        }))
687                        .await;
688                    terminated_early = true;
689                    terminal_error = Some(format!("AgentTool stream error: {e}"));
690                    break;
691                }
692                // Got a chunk of bytes
693                Ok(Some(Ok(bytes))) => {
694                    line_buf.extend_from_slice(&bytes);
695
696                    // Process all complete lines in the buffer
697                    while let Some(newline_pos) = line_buf.iter().position(|&b| b == b'\n') {
698                        let line_bytes = line_buf.split_to(newline_pos + 1);
699                        let line_str = match std::str::from_utf8(&line_bytes) {
700                            Ok(s) => s.trim().to_string(),
701                            Err(e) => {
702                                warn!(
703                                    node_id = %item.node_id,
704                                    error = %e,
705                                    "AgentTool streaming: non-UTF8 chunk, skipping"
706                                );
707                                continue;
708                            }
709                        };
710                        if line_str.is_empty() {
711                            continue;
712                        }
713
714                        let chunk: Value = serde_json::from_str(&line_str)
715                            .unwrap_or_else(|_| json!({ "raw": &line_str }));
716                        last_chunk = chunk.clone();
717
718                        // Extract task_id from first chunk
719                        if task_id.is_none() {
720                            if let Some(id) = chunk.get("id").and_then(|v| v.as_str()) {
721                                task_id = Some(id.to_string());
722                            }
723                        }
724
725                        // Accumulate cost
726                        if let Some(cost) = chunk.get("cost_usd").and_then(|v| v.as_f64()) {
727                            accumulated_cost += cost;
728                        }
729
730                        // Emit progress event
731                        let progress = json!({
732                            "type": "agent_tool_progress",
733                            "node_id": &item.node_id,
734                            "chunk_index": chunk_index,
735                            "chunk": &chunk,
736                            "accumulated_cost_usd": accumulated_cost,
737                            "timestamp_ms": now_ms()
738                        });
739                        chunk_index += 1;
740
741                        if tx.send(progress).await.is_err() {
742                            // Receiver dropped — treat as cancellation
743                            debug!(
744                                node_id = %item.node_id,
745                                "AgentTool streaming: receiver dropped, cancelling"
746                            );
747                            Self::send_a2a_cancel(&client, agent_uri, &task_id).await;
748                            terminated_early = true;
749                            terminal_error = Some("AgentTool stream receiver dropped".into());
750                            break;
751                        }
752
753                        // Budget guard
754                        if let Some(budget) = max_cost_usd {
755                            if accumulated_cost > budget {
756                                debug!(
757                                    node_id = %item.node_id,
758                                    accumulated_cost,
759                                    budget,
760                                    "AgentTool streaming: budget exceeded, terminating"
761                                );
762                                let _ = tx
763                                    .send(json!({
764                                        "type": "agent_tool_terminated",
765                                        "node_id": &item.node_id,
766                                        "reason": "budget_exceeded",
767                                        "accumulated_cost_usd": accumulated_cost,
768                                        "latency_ms": start.elapsed().as_millis() as u64,
769                                        "timestamp_ms": now_ms()
770                                    }))
771                                    .await;
772                                Self::send_a2a_cancel(&client, agent_uri, &task_id).await;
773                                terminated_early = true;
774                                break;
775                            }
776                        }
777                    }
778
779                    // If inner loop broke due to termination, break outer loop too
780                    if terminated_early {
781                        break;
782                    }
783                }
784            }
785        }
786
787        // ── Drain remaining bytes in line_buf ───────────────────────────
788        if !terminated_early && !line_buf.is_empty() {
789            if let Ok(remaining) = std::str::from_utf8(&line_buf) {
790                let trimmed = remaining.trim();
791                if !trimmed.is_empty() {
792                    let chunk: Value =
793                        serde_json::from_str(trimmed).unwrap_or_else(|_| json!({ "raw": trimmed }));
794                    last_chunk = chunk.clone();
795
796                    if let Some(cost) = chunk.get("cost_usd").and_then(|v| v.as_f64()) {
797                        accumulated_cost += cost;
798                    }
799
800                    let _ = tx
801                        .send(json!({
802                            "type": "agent_tool_progress",
803                            "node_id": &item.node_id,
804                            "chunk_index": chunk_index,
805                            "chunk": &chunk,
806                            "accumulated_cost_usd": accumulated_cost,
807                            "timestamp_ms": now_ms()
808                        }))
809                        .await;
810                }
811            }
812        }
813
814        // ── Return error for hard failures ───────────────────────────────
815        if let Some(error) = terminal_error {
816            return Err(error);
817        }
818
819        // ── Emit completed (if not terminated early) ────────────────────
820        let duration_ms = start.elapsed().as_millis() as u64;
821        if !terminated_early {
822            let _ = tx
823                .send(json!({
824                    "type": "agent_tool_completed",
825                    "node_id": &item.node_id,
826                    "output": &last_chunk,
827                    "total_cost": accumulated_cost,
828                    "latency_ms": duration_ms,
829                    "timestamp_ms": now_ms()
830                }))
831                .await;
832        }
833
834        Ok(ExecutionResult {
835            output: json!({ output_key: last_chunk }),
836            state_patch: json!({}),
837            duration_ms,
838            gen_ai_system: None,
839            gen_ai_model: None,
840            input_tokens: None,
841            output_tokens: None,
842            finish_reason: None,
843        })
844    }
845}
846
847// ── Tests ───────────────────────────────────────────────────────────────────
848
849#[cfg(test)]
850mod tests {
851    use super::*;
852    use crate::executor::NodeExecutor;
853    use wiremock::matchers::{method, path};
854    use wiremock::{Mock, MockServer, ResponseTemplate};
855
856    /// Build a WorkItem that targets the given agent URI with streaming mode.
857    fn make_test_work_item(
858        agent_uri: &str,
859        idle_timeout: Option<u64>,
860        max_cost: Option<f64>,
861    ) -> WorkItem {
862        let mut payload = serde_json::json!({
863            "agent": agent_uri,
864            "mode": "streaming",
865            "input": {"query": "test"},
866            "workflow_id": "wf1",
867            "workflow_version": "1.0.0",
868        });
869        if let Some(t) = idle_timeout {
870            payload["idle_timeout_secs"] = serde_json::json!(t);
871        }
872        if let Some(c) = max_cost {
873            payload["budget"] = serde_json::json!({"max_cost_usd": c});
874        }
875        WorkItem {
876            id: uuid::Uuid::new_v4(),
877            execution_id: jamjet_core::workflow::ExecutionId::new(),
878            node_id: "n1".into(),
879            queue_type: "agent_tool".into(),
880            payload,
881            attempt: 1,
882            max_attempts: 3,
883            created_at: chrono::Utc::now(),
884            lease_expires_at: None,
885            worker_id: None,
886            tenant_id: "default".into(),
887        }
888    }
889
890    /// Join NDJSON lines into a single body terminated by a newline.
891    fn ndjson_body(lines: &[&str]) -> String {
892        lines.join("\n") + "\n"
893    }
894
895    /// Drain all events from the receiver (non-blocking).
896    fn collect_events(
897        rx: &mut tokio::sync::mpsc::Receiver<serde_json::Value>,
898    ) -> Vec<serde_json::Value> {
899        let mut events = Vec::new();
900        while let Ok(ev) = rx.try_recv() {
901            events.push(ev);
902        }
903        events
904    }
905
906    // ── Test 1: streams NDJSON chunks in order ──────────────────────────
907
908    #[tokio::test]
909    async fn streams_ndjson_chunks_in_order() {
910        let server = MockServer::start().await;
911
912        let body = ndjson_body(&[r#"{"text":"hello"}"#, r#"{"text":"world"}"#]);
913
914        Mock::given(method("POST"))
915            .and(path("/tasks/sendSubscribe"))
916            .respond_with(ResponseTemplate::new(200).set_body_string(body))
917            .mount(&server)
918            .await;
919
920        let item = make_test_work_item(&server.uri(), Some(5), None);
921        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
922
923        let executor = AgentToolExecutor;
924        let result = executor.execute_streaming(&item, tx).await;
925        assert!(
926            result.is_ok(),
927            "execute_streaming failed: {:?}",
928            result.err()
929        );
930
931        let events = collect_events(&mut rx);
932
933        // Expect: invoked, progress(0), progress(1), completed
934        assert!(
935            events.len() >= 4,
936            "Expected at least 4 events, got {}: {:#?}",
937            events.len(),
938            events
939        );
940
941        assert_eq!(events[0]["type"], "agent_tool_invoked");
942        assert_eq!(events[0]["mode"], "streaming");
943
944        assert_eq!(events[1]["type"], "agent_tool_progress");
945        assert_eq!(events[1]["chunk_index"], 0);
946        assert_eq!(events[1]["chunk"]["text"], "hello");
947
948        assert_eq!(events[2]["type"], "agent_tool_progress");
949        assert_eq!(events[2]["chunk_index"], 1);
950        assert_eq!(events[2]["chunk"]["text"], "world");
951
952        assert_eq!(events[3]["type"], "agent_tool_completed");
953    }
954
955    // ── Test 2: budget exceeded terminates stream ───────────────────────
956
957    #[tokio::test]
958    async fn budget_exceeded_terminates_stream() {
959        let server = MockServer::start().await;
960
961        // 3 chunks each costing 0.3; budget is 0.5 → should terminate after chunk 1
962        let body = ndjson_body(&[
963            r#"{"text":"a","cost_usd":0.3}"#,
964            r#"{"text":"b","cost_usd":0.3}"#,
965            r#"{"text":"c","cost_usd":0.3}"#,
966        ]);
967
968        Mock::given(method("POST"))
969            .and(path("/tasks/sendSubscribe"))
970            .respond_with(ResponseTemplate::new(200).set_body_string(body))
971            .mount(&server)
972            .await;
973
974        // Also mock /tasks/cancel so the A2A cancel doesn't fail
975        Mock::given(method("POST"))
976            .and(path("/tasks/cancel"))
977            .respond_with(ResponseTemplate::new(200))
978            .mount(&server)
979            .await;
980
981        let item = make_test_work_item(&server.uri(), Some(5), Some(0.5));
982        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
983
984        let executor = AgentToolExecutor;
985        let result = executor.execute_streaming(&item, tx).await;
986        assert!(result.is_ok());
987
988        let events = collect_events(&mut rx);
989
990        // Find a terminated event with reason "budget_exceeded"
991        let terminated = events.iter().find(|e| e["type"] == "agent_tool_terminated");
992        assert!(
993            terminated.is_some(),
994            "Expected an agent_tool_terminated event, got: {:#?}",
995            events
996        );
997        assert_eq!(terminated.unwrap()["reason"], "budget_exceeded");
998
999        // Should NOT have a completed event
1000        let completed = events.iter().any(|e| e["type"] == "agent_tool_completed");
1001        assert!(
1002            !completed,
1003            "Should not have agent_tool_completed when budget exceeded"
1004        );
1005    }
1006
1007    // ── Test 3: malformed JSON becomes raw ──────────────────────────────
1008
1009    #[tokio::test]
1010    async fn malformed_json_becomes_raw() {
1011        let server = MockServer::start().await;
1012
1013        let body = ndjson_body(&[
1014            r#"{"text":"first"}"#,
1015            "not json at all",
1016            r#"{"text":"third"}"#,
1017        ]);
1018
1019        Mock::given(method("POST"))
1020            .and(path("/tasks/sendSubscribe"))
1021            .respond_with(ResponseTemplate::new(200).set_body_string(body))
1022            .mount(&server)
1023            .await;
1024
1025        let item = make_test_work_item(&server.uri(), Some(5), None);
1026        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
1027
1028        let executor = AgentToolExecutor;
1029        let result = executor.execute_streaming(&item, tx).await;
1030        assert!(result.is_ok());
1031
1032        let events = collect_events(&mut rx);
1033
1034        // events[0] = invoked, events[1] = progress(0), events[2] = progress(1), events[3] = progress(2), events[4] = completed
1035        let progress_events: Vec<&serde_json::Value> = events
1036            .iter()
1037            .filter(|e| e["type"] == "agent_tool_progress")
1038            .collect();
1039
1040        assert_eq!(
1041            progress_events.len(),
1042            3,
1043            "Expected 3 progress events, got {}: {:#?}",
1044            progress_events.len(),
1045            progress_events
1046        );
1047
1048        // First chunk: valid JSON
1049        assert_eq!(progress_events[0]["chunk"]["text"], "first");
1050
1051        // Second chunk: malformed → wrapped in {"raw": "not json at all"}
1052        assert_eq!(progress_events[1]["chunk"]["raw"], "not json at all");
1053
1054        // Third chunk: valid JSON
1055        assert_eq!(progress_events[2]["chunk"]["text"], "third");
1056    }
1057}