1#![allow(clippy::too_many_arguments)]
9
10use crate::executor::{ExecutionResult, NodeExecutor, StreamEventSender};
11use async_trait::async_trait;
12use bytes::BytesMut;
13use futures::StreamExt;
14use jamjet_state::backend::WorkItem;
15use reqwest::Client;
16use serde_json::{json, Value};
17use std::collections::hash_map::DefaultHasher;
18use std::hash::{Hash, Hasher};
19use std::time::Duration;
20use tracing::{debug, instrument, warn};
21
22pub struct AgentToolExecutor;
23
24impl AgentToolExecutor {
25 fn build_client(timeout_ms: u64) -> Result<Client, String> {
27 reqwest::Client::builder()
28 .timeout(std::time::Duration::from_millis(timeout_ms))
29 .build()
30 .map_err(|e| format!("HTTP client: {e}"))
31 }
32
33 fn resolve_url(agent_uri: &str, endpoint: &str) -> Result<String, String> {
36 let is_http =
37 agent_uri.starts_with("https://") || (cfg!(test) && agent_uri.starts_with("http://"));
38 if is_http {
39 Ok(format!("{}/{}", agent_uri.trim_end_matches('/'), endpoint))
40 } else {
41 Err(format!(
42 "Cannot resolve '{}' to HTTP endpoint. \
43 Only https:// agent URIs are supported for remote invocation.",
44 agent_uri
45 ))
46 }
47 }
48
49 async fn execute_sync(
51 &self,
52 item: &WorkItem,
53 agent_uri: &str,
54 protocol: &str,
55 output_key: &str,
56 timeout_ms: u64,
57 input: &Value,
58 input_hash: &str,
59 start: std::time::Instant,
60 ) -> Result<ExecutionResult, String> {
61 let client = Self::build_client(timeout_ms)?;
62 let task_url = Self::resolve_url(agent_uri, "tasks/send")?;
63
64 let resp = client
65 .post(&task_url)
66 .json(&json!({
67 "jsonrpc": "2.0",
68 "method": "tasks/send",
69 "params": { "message": { "parts": [{ "text": input.to_string() }] } }
70 }))
71 .send()
72 .await
73 .map_err(|e| format!("AgentTool invocation failed: {e}"))?;
74
75 if !resp.status().is_success() {
76 let status = resp.status();
77 let body = resp.text().await.unwrap_or_default();
78 return Err(format!("Agent returned error {status}: {body}"));
79 }
80
81 let output: Value = resp
82 .json()
83 .await
84 .map_err(|e| format!("Failed to parse agent response: {e}"))?;
85
86 let duration_ms = start.elapsed().as_millis() as u64;
87
88 Ok(ExecutionResult {
89 output: json!({ output_key: output }),
90 state_patch: json!({
91 "agent_tool_events": [
92 {
93 "type": "agent_tool_invoked",
94 "node_id": &item.node_id,
95 "agent_uri": agent_uri,
96 "mode": "sync",
97 "protocol": protocol,
98 "input_hash": input_hash
99 },
100 {
101 "type": "agent_tool_completed",
102 "node_id": &item.node_id,
103 "output": &output,
104 "total_cost": 0.0,
105 "latency_ms": duration_ms
106 },
107 ]
108 }),
109 duration_ms,
110 gen_ai_system: None,
111 gen_ai_model: None,
112 input_tokens: None,
113 output_tokens: None,
114 finish_reason: None,
115 })
116 }
117
118 async fn stream_ndjson(
127 &self,
128 item: &WorkItem,
129 agent_uri: &str,
130 protocol: &str,
131 output_key: &str,
132 timeout_ms: u64,
133 input: &Value,
134 input_hash: &str,
135 max_cost_usd: Option<f64>,
136 start: std::time::Instant,
137 ) -> Result<ExecutionResult, String> {
138 let client = Self::build_client(timeout_ms)?;
139 let task_url = Self::resolve_url(agent_uri, "tasks/sendSubscribe")?;
140
141 let resp = client
142 .post(&task_url)
143 .json(&json!({
144 "jsonrpc": "2.0",
145 "method": "tasks/sendSubscribe",
146 "params": { "message": { "parts": [{ "text": input.to_string() }] } }
147 }))
148 .send()
149 .await
150 .map_err(|e| format!("AgentTool streaming invocation failed: {e}"))?;
151
152 if !resp.status().is_success() {
153 let status = resp.status();
154 let body = resp.text().await.unwrap_or_default();
155 return Err(format!("Agent returned error {status}: {body}"));
156 }
157
158 let body = resp
160 .text()
161 .await
162 .map_err(|e| format!("Failed to read streaming response body: {e}"))?;
163
164 let mut events: Vec<Value> = vec![json!({
165 "type": "agent_tool_invoked",
166 "node_id": &item.node_id,
167 "agent_uri": agent_uri,
168 "mode": "streaming",
169 "protocol": protocol,
170 "input_hash": input_hash
171 })];
172
173 let mut accumulated_cost: f64 = 0.0;
174 let mut terminated_early = false;
175 let mut last_chunk: Value = json!(null);
176 let mut chunk_index: u64 = 0;
177
178 for line in body.lines() {
179 let trimmed = line.trim();
180 if trimmed.is_empty() {
181 continue;
182 }
183
184 let chunk: Value =
185 serde_json::from_str(trimmed).unwrap_or_else(|_| json!({ "raw": trimmed }));
186
187 if let Some(cost) = chunk.get("cost_usd").and_then(|v| v.as_f64()) {
189 accumulated_cost += cost;
190 }
191
192 events.push(json!({
193 "type": "agent_tool_progress",
194 "node_id": &item.node_id,
195 "chunk_index": chunk_index,
196 "chunk": &chunk,
197 "accumulated_cost_usd": accumulated_cost
198 }));
199
200 last_chunk = chunk;
201 chunk_index += 1;
202
203 if let Some(budget) = max_cost_usd {
205 if accumulated_cost > budget {
206 terminated_early = true;
207 debug!(
208 node_id = %item.node_id,
209 accumulated_cost,
210 budget,
211 "AgentTool streaming: budget exceeded, terminating early"
212 );
213 break;
214 }
215 }
216 }
217
218 let duration_ms = start.elapsed().as_millis() as u64;
219
220 if terminated_early {
221 events.push(json!({
222 "type": "agent_tool_terminated",
223 "node_id": &item.node_id,
224 "reason": "budget_exceeded",
225 "accumulated_cost_usd": accumulated_cost,
226 "latency_ms": duration_ms
227 }));
228 } else {
229 events.push(json!({
230 "type": "agent_tool_completed",
231 "node_id": &item.node_id,
232 "output": &last_chunk,
233 "total_cost": accumulated_cost,
234 "latency_ms": duration_ms
235 }));
236 }
237
238 Ok(ExecutionResult {
239 output: json!({ output_key: &last_chunk }),
240 state_patch: json!({ "agent_tool_events": events }),
241 duration_ms,
242 gen_ai_system: None,
243 gen_ai_model: None,
244 input_tokens: None,
245 output_tokens: None,
246 finish_reason: None,
247 })
248 }
249
250 async fn send_a2a_cancel(client: &Client, agent_uri: &str, task_id: &Option<String>) {
252 if let Some(ref id) = task_id {
253 if let Ok(cancel_url) = Self::resolve_url(agent_uri, "tasks/cancel") {
254 let _ = client
255 .post(&cancel_url)
256 .json(&serde_json::json!({ "id": id }))
257 .timeout(Duration::from_secs(5))
258 .send()
259 .await;
260 }
261 }
262 }
263
264 async fn execute_conversational(
270 &self,
271 item: &WorkItem,
272 agent_uri: &str,
273 protocol: &str,
274 output_key: &str,
275 timeout_ms: u64,
276 input: &Value,
277 input_hash: &str,
278 start: std::time::Instant,
279 ) -> Result<ExecutionResult, String> {
280 let p = &item.payload;
281
282 let max_turns = p
283 .get("mode")
284 .and_then(|m| m.get("conversational"))
285 .and_then(|c| c.get("max_turns"))
286 .and_then(|v| v.as_u64())
287 .unwrap_or(5) as usize;
288
289 let client = Self::build_client(timeout_ms)?;
290 let task_url = Self::resolve_url(agent_uri, "tasks/send")?;
291
292 let mut events: Vec<Value> = vec![json!({
293 "type": "agent_tool_invoked",
294 "node_id": &item.node_id,
295 "agent_uri": agent_uri,
296 "mode": "conversational",
297 "protocol": protocol,
298 "input_hash": input_hash
299 })];
300
301 let mut current_input = input.clone();
302 let mut final_output: Value = json!(null);
303
304 for turn in 0..max_turns {
305 events.push(json!({
307 "type": "agent_tool_turn",
308 "node_id": &item.node_id,
309 "turn": turn,
310 "direction": "outbound",
311 "input": ¤t_input
312 }));
313
314 debug!(
315 node_id = %item.node_id,
316 turn,
317 "AgentTool conversational: sending turn"
318 );
319
320 let resp = client
321 .post(&task_url)
322 .json(&json!({
323 "jsonrpc": "2.0",
324 "method": "tasks/send",
325 "params": { "message": { "parts": [{ "text": current_input.to_string() }] } }
326 }))
327 .send()
328 .await
329 .map_err(|e| format!("AgentTool turn {turn} failed: {e}"))?;
330
331 if !resp.status().is_success() {
332 let status = resp.status();
333 let body = resp.text().await.unwrap_or_default();
334 return Err(format!(
335 "Agent returned error {status} on turn {turn}: {body}"
336 ));
337 }
338
339 let response: Value = resp
340 .json()
341 .await
342 .map_err(|e| format!("Failed to parse agent response on turn {turn}: {e}"))?;
343
344 events.push(json!({
346 "type": "agent_tool_turn",
347 "node_id": &item.node_id,
348 "turn": turn,
349 "direction": "inbound",
350 "output": &response
351 }));
352
353 final_output = response.clone();
354
355 let status = response
357 .get("status")
358 .and_then(|v| v.as_str())
359 .unwrap_or("");
360 if status == "completed" {
361 debug!(
362 node_id = %item.node_id,
363 turn,
364 "AgentTool conversational: agent signalled completion"
365 );
366 break;
367 }
368
369 current_input = response
371 .get("output")
372 .cloned()
373 .unwrap_or_else(|| response.clone());
374 }
375
376 let duration_ms = start.elapsed().as_millis() as u64;
377
378 events.push(json!({
379 "type": "agent_tool_completed",
380 "node_id": &item.node_id,
381 "output": &final_output,
382 "total_cost": 0.0,
383 "latency_ms": duration_ms
384 }));
385
386 Ok(ExecutionResult {
387 output: json!({ output_key: &final_output }),
388 state_patch: json!({ "agent_tool_events": events }),
389 duration_ms,
390 gen_ai_system: None,
391 gen_ai_model: None,
392 input_tokens: None,
393 output_tokens: None,
394 finish_reason: None,
395 })
396 }
397}
398
399#[async_trait]
400impl NodeExecutor for AgentToolExecutor {
401 #[instrument(skip(self, item), fields(node_id = %item.node_id))]
402 async fn execute(&self, item: &WorkItem) -> Result<ExecutionResult, String> {
403 let start = std::time::Instant::now();
404 let p = &item.payload;
405
406 let agent_uri = p
408 .get("agent")
409 .and_then(|a| {
410 a.get("explicit")
411 .and_then(|v| v.as_str())
412 .or_else(|| a.as_str())
413 })
414 .ok_or("AgentTool: missing 'agent' URI in payload")?;
415
416 let mode = if let Some(mode_val) = p.get("mode") {
419 if let Some(s) = mode_val.as_str() {
420 s.to_string()
421 } else if mode_val.get("conversational").is_some() {
422 "conversational".to_string()
423 } else if mode_val.get("streaming").is_some() {
424 "streaming".to_string()
425 } else {
426 "sync".to_string()
427 }
428 } else {
429 "sync".to_string()
430 };
431 let output_key = p
432 .get("output_key")
433 .and_then(|v| v.as_str())
434 .unwrap_or("result");
435 let timeout_ms = p
436 .get("timeout_ms")
437 .and_then(|v| v.as_u64())
438 .unwrap_or(30_000);
439 let input = p.get("input").cloned().unwrap_or(json!({}));
440 let max_cost_usd = p
442 .get("budget")
443 .and_then(|b| b.get("max_cost_usd"))
444 .and_then(|v| v.as_f64())
445 .or_else(|| p.get("max_cost_usd").and_then(|v| v.as_f64()));
446
447 if p.get("agent").and_then(|a| a.get("auto")).is_some() {
449 return Err(
450 "AgentTool with 'auto' target was not expanded at compile time. \
451 Use the compiler to expand 'auto' into coordinator + agent_tool nodes."
452 .into(),
453 );
454 }
455
456 let protocol = if agent_uri.starts_with("https://") {
458 "a2a"
459 } else if agent_uri.starts_with("jamjet://") {
460 "local"
461 } else {
462 "mcp"
463 };
464
465 let mut hasher = DefaultHasher::new();
467 input.to_string().hash(&mut hasher);
468 let input_hash = format!("{:016x}", hasher.finish());
469
470 debug!(agent_uri = %agent_uri, mode = %mode, protocol = %protocol, "AgentTool: invoking");
471
472 match mode.as_str() {
473 "sync" => {
474 self.execute_sync(
475 item,
476 agent_uri,
477 protocol,
478 output_key,
479 timeout_ms,
480 &input,
481 &input_hash,
482 start,
483 )
484 .await
485 }
486 "streaming" => {
487 self.stream_ndjson(
488 item,
489 agent_uri,
490 protocol,
491 output_key,
492 timeout_ms,
493 &input,
494 &input_hash,
495 max_cost_usd,
496 start,
497 )
498 .await
499 }
500 "conversational" => {
501 self.execute_conversational(
502 item,
503 agent_uri,
504 protocol,
505 output_key,
506 timeout_ms,
507 &input,
508 &input_hash,
509 start,
510 )
511 .await
512 }
513 other => Err(format!("Unknown agent_tool mode: '{other}'")),
514 }
515 }
516
517 #[instrument(skip(self, item, tx), fields(node_id = %item.node_id))]
520 async fn execute_streaming(
521 &self,
522 item: &WorkItem,
523 tx: StreamEventSender,
524 ) -> Result<ExecutionResult, String> {
525 let start = std::time::Instant::now();
526 let p = &item.payload;
527
528 let agent_uri = p
530 .get("agent")
531 .and_then(|a| {
532 a.get("explicit")
533 .and_then(|v| v.as_str())
534 .or_else(|| a.as_str())
535 })
536 .ok_or("AgentTool: missing 'agent' URI in payload")?;
537
538 let mode = if let Some(mode_val) = p.get("mode") {
539 if let Some(s) = mode_val.as_str() {
540 s.to_string()
541 } else if mode_val.get("streaming").is_some() {
542 "streaming".to_string()
543 } else {
544 "sync".to_string()
545 }
546 } else {
547 "sync".to_string()
548 };
549
550 if mode != "streaming" {
552 return self.execute(item).await;
553 }
554
555 let input = p.get("input").cloned().unwrap_or(json!({}));
556
557 let max_cost_usd = p
558 .get("budget")
559 .and_then(|b| b.get("max_cost_usd"))
560 .and_then(|v| v.as_f64())
561 .or_else(|| p.get("max_cost_usd").and_then(|v| v.as_f64()));
562
563 let idle_timeout_secs = p
564 .get("idle_timeout_secs")
565 .and_then(|v| v.as_u64())
566 .unwrap_or(30);
567
568 let protocol = if agent_uri.starts_with("https://")
570 || (cfg!(test) && agent_uri.starts_with("http://"))
571 {
572 "a2a"
573 } else if agent_uri.starts_with("jamjet://") {
574 "local"
575 } else {
576 "mcp"
577 };
578
579 let mut hasher = DefaultHasher::new();
581 input.to_string().hash(&mut hasher);
582 let input_hash = format!("{:016x}", hasher.finish());
583
584 let client = reqwest::Client::builder()
586 .build()
587 .map_err(|e| format!("HTTP client: {e}"))?;
588
589 let now_ms = || -> u64 {
591 std::time::SystemTime::now()
592 .duration_since(std::time::UNIX_EPOCH)
593 .unwrap()
594 .as_millis() as u64
595 };
596
597 let invoked_event = json!({
598 "type": "agent_tool_invoked",
599 "node_id": &item.node_id,
600 "agent_uri": agent_uri,
601 "mode": &mode,
602 "protocol": protocol,
603 "input_hash": &input_hash,
604 "timestamp_ms": now_ms()
605 });
606 if tx.send(invoked_event).await.is_err() {
607 return Err("Streaming receiver dropped before invocation event".into());
608 }
609
610 let task_url = Self::resolve_url(agent_uri, "tasks/sendSubscribe")?;
612 let resp = client
613 .post(&task_url)
614 .json(&json!({
615 "jsonrpc": "2.0",
616 "method": "tasks/sendSubscribe",
617 "params": { "message": { "parts": [{ "text": input.to_string() }] } }
618 }))
619 .send()
620 .await
621 .map_err(|e| format!("AgentTool streaming invocation failed: {e}"))?;
622
623 if !resp.status().is_success() {
624 let status = resp.status();
625 let body = resp.text().await.unwrap_or_default();
626 return Err(format!("Agent returned error {status}: {body}"));
627 }
628
629 let mut stream = resp.bytes_stream();
631 let mut line_buf = BytesMut::new();
632 let mut chunk_index: u64 = 0;
633 let mut accumulated_cost: f64 = 0.0;
634 let mut task_id: Option<String> = None;
635 let mut last_chunk: Value = json!(null);
636 let output_key = p
637 .get("output_key")
638 .and_then(|v| v.as_str())
639 .unwrap_or("result");
640 let mut terminated_early = false;
641 let mut terminal_error: Option<String> = None;
642 let idle_dur = Duration::from_secs(idle_timeout_secs);
643
644 loop {
645 match tokio::time::timeout(idle_dur, stream.next()).await {
646 Err(_elapsed) => {
648 warn!(
649 node_id = %item.node_id,
650 idle_timeout_secs,
651 "AgentTool streaming: idle timeout, terminating"
652 );
653 let _ = tx
654 .send(json!({
655 "type": "agent_tool_terminated",
656 "node_id": &item.node_id,
657 "reason": "idle_timeout",
658 "accumulated_cost_usd": accumulated_cost,
659 "latency_ms": start.elapsed().as_millis() as u64,
660 "timestamp_ms": now_ms()
661 }))
662 .await;
663 Self::send_a2a_cancel(&client, agent_uri, &task_id).await;
664 terminated_early = true;
665 terminal_error =
666 Some(format!("AgentTool idle timeout after {idle_timeout_secs}s"));
667 break;
668 }
669 Ok(None) => {
671 break;
672 }
673 Ok(Some(Err(e))) => {
675 warn!(
676 node_id = %item.node_id,
677 error = %e,
678 "AgentTool streaming: network error"
679 );
680 let _ = tx
681 .send(json!({
682 "type": "agent_tool_error",
683 "node_id": &item.node_id,
684 "error": e.to_string(),
685 "timestamp_ms": now_ms()
686 }))
687 .await;
688 terminated_early = true;
689 terminal_error = Some(format!("AgentTool stream error: {e}"));
690 break;
691 }
692 Ok(Some(Ok(bytes))) => {
694 line_buf.extend_from_slice(&bytes);
695
696 while let Some(newline_pos) = line_buf.iter().position(|&b| b == b'\n') {
698 let line_bytes = line_buf.split_to(newline_pos + 1);
699 let line_str = match std::str::from_utf8(&line_bytes) {
700 Ok(s) => s.trim().to_string(),
701 Err(e) => {
702 warn!(
703 node_id = %item.node_id,
704 error = %e,
705 "AgentTool streaming: non-UTF8 chunk, skipping"
706 );
707 continue;
708 }
709 };
710 if line_str.is_empty() {
711 continue;
712 }
713
714 let chunk: Value = serde_json::from_str(&line_str)
715 .unwrap_or_else(|_| json!({ "raw": &line_str }));
716 last_chunk = chunk.clone();
717
718 if task_id.is_none() {
720 if let Some(id) = chunk.get("id").and_then(|v| v.as_str()) {
721 task_id = Some(id.to_string());
722 }
723 }
724
725 if let Some(cost) = chunk.get("cost_usd").and_then(|v| v.as_f64()) {
727 accumulated_cost += cost;
728 }
729
730 let progress = json!({
732 "type": "agent_tool_progress",
733 "node_id": &item.node_id,
734 "chunk_index": chunk_index,
735 "chunk": &chunk,
736 "accumulated_cost_usd": accumulated_cost,
737 "timestamp_ms": now_ms()
738 });
739 chunk_index += 1;
740
741 if tx.send(progress).await.is_err() {
742 debug!(
744 node_id = %item.node_id,
745 "AgentTool streaming: receiver dropped, cancelling"
746 );
747 Self::send_a2a_cancel(&client, agent_uri, &task_id).await;
748 terminated_early = true;
749 terminal_error = Some("AgentTool stream receiver dropped".into());
750 break;
751 }
752
753 if let Some(budget) = max_cost_usd {
755 if accumulated_cost > budget {
756 debug!(
757 node_id = %item.node_id,
758 accumulated_cost,
759 budget,
760 "AgentTool streaming: budget exceeded, terminating"
761 );
762 let _ = tx
763 .send(json!({
764 "type": "agent_tool_terminated",
765 "node_id": &item.node_id,
766 "reason": "budget_exceeded",
767 "accumulated_cost_usd": accumulated_cost,
768 "latency_ms": start.elapsed().as_millis() as u64,
769 "timestamp_ms": now_ms()
770 }))
771 .await;
772 Self::send_a2a_cancel(&client, agent_uri, &task_id).await;
773 terminated_early = true;
774 break;
775 }
776 }
777 }
778
779 if terminated_early {
781 break;
782 }
783 }
784 }
785 }
786
787 if !terminated_early && !line_buf.is_empty() {
789 if let Ok(remaining) = std::str::from_utf8(&line_buf) {
790 let trimmed = remaining.trim();
791 if !trimmed.is_empty() {
792 let chunk: Value =
793 serde_json::from_str(trimmed).unwrap_or_else(|_| json!({ "raw": trimmed }));
794 last_chunk = chunk.clone();
795
796 if let Some(cost) = chunk.get("cost_usd").and_then(|v| v.as_f64()) {
797 accumulated_cost += cost;
798 }
799
800 let _ = tx
801 .send(json!({
802 "type": "agent_tool_progress",
803 "node_id": &item.node_id,
804 "chunk_index": chunk_index,
805 "chunk": &chunk,
806 "accumulated_cost_usd": accumulated_cost,
807 "timestamp_ms": now_ms()
808 }))
809 .await;
810 }
811 }
812 }
813
814 if let Some(error) = terminal_error {
816 return Err(error);
817 }
818
819 let duration_ms = start.elapsed().as_millis() as u64;
821 if !terminated_early {
822 let _ = tx
823 .send(json!({
824 "type": "agent_tool_completed",
825 "node_id": &item.node_id,
826 "output": &last_chunk,
827 "total_cost": accumulated_cost,
828 "latency_ms": duration_ms,
829 "timestamp_ms": now_ms()
830 }))
831 .await;
832 }
833
834 Ok(ExecutionResult {
835 output: json!({ output_key: last_chunk }),
836 state_patch: json!({}),
837 duration_ms,
838 gen_ai_system: None,
839 gen_ai_model: None,
840 input_tokens: None,
841 output_tokens: None,
842 finish_reason: None,
843 })
844 }
845}
846
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::NodeExecutor;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Streaming-mode work item aimed at `agent_uri`, with an optional idle
    /// timeout (seconds) and an optional `budget.max_cost_usd` cap.
    fn make_test_work_item(
        agent_uri: &str,
        idle_timeout: Option<u64>,
        max_cost: Option<f64>,
    ) -> WorkItem {
        let mut payload = serde_json::json!({
            "agent": agent_uri,
            "mode": "streaming",
            "input": {"query": "test"},
            "workflow_id": "wf1",
            "workflow_version": "1.0.0",
        });
        if let Some(secs) = idle_timeout {
            payload["idle_timeout_secs"] = serde_json::json!(secs);
        }
        if let Some(cap) = max_cost {
            payload["budget"] = serde_json::json!({ "max_cost_usd": cap });
        }

        WorkItem {
            id: uuid::Uuid::new_v4(),
            execution_id: jamjet_core::workflow::ExecutionId::new(),
            node_id: "n1".into(),
            queue_type: "agent_tool".into(),
            payload,
            attempt: 1,
            max_attempts: 3,
            created_at: chrono::Utc::now(),
            lease_expires_at: None,
            worker_id: None,
            tenant_id: "default".into(),
        }
    }

    /// Joins `lines` with `\n` and terminates the final line as well.
    fn ndjson_body(lines: &[&str]) -> String {
        let mut body = lines.join("\n");
        body.push('\n');
        body
    }

    /// Drains every event already buffered in `rx`, without waiting.
    fn collect_events(
        rx: &mut tokio::sync::mpsc::Receiver<serde_json::Value>,
    ) -> Vec<serde_json::Value> {
        std::iter::from_fn(|| rx.try_recv().ok()).collect()
    }

    /// Serves `body` with status 200 for `POST /tasks/sendSubscribe`.
    async fn mount_subscribe(server: &MockServer, body: String) {
        Mock::given(method("POST"))
            .and(path("/tasks/sendSubscribe"))
            .respond_with(ResponseTemplate::new(200).set_body_string(body))
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn streams_ndjson_chunks_in_order() {
        let server = MockServer::start().await;
        mount_subscribe(
            &server,
            ndjson_body(&[r#"{"text":"hello"}"#, r#"{"text":"world"}"#]),
        )
        .await;

        let item = make_test_work_item(&server.uri(), Some(5), None);
        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
        let executor = AgentToolExecutor;

        let result = executor.execute_streaming(&item, tx).await;
        assert!(
            result.is_ok(),
            "execute_streaming failed: {:?}",
            result.err()
        );

        let events = collect_events(&mut rx);
        assert!(
            events.len() >= 4,
            "Expected at least 4 events, got {}: {:#?}",
            events.len(),
            events
        );

        let (invoked, hello, world, completed) = (&events[0], &events[1], &events[2], &events[3]);

        assert_eq!(invoked["type"], "agent_tool_invoked");
        assert_eq!(invoked["mode"], "streaming");

        assert_eq!(hello["type"], "agent_tool_progress");
        assert_eq!(hello["chunk_index"], 0);
        assert_eq!(hello["chunk"]["text"], "hello");

        assert_eq!(world["type"], "agent_tool_progress");
        assert_eq!(world["chunk_index"], 1);
        assert_eq!(world["chunk"]["text"], "world");

        assert_eq!(completed["type"], "agent_tool_completed");
    }

    #[tokio::test]
    async fn budget_exceeded_terminates_stream() {
        let server = MockServer::start().await;
        mount_subscribe(
            &server,
            ndjson_body(&[
                r#"{"text":"a","cost_usd":0.3}"#,
                r#"{"text":"b","cost_usd":0.3}"#,
                r#"{"text":"c","cost_usd":0.3}"#,
            ]),
        )
        .await;
        Mock::given(method("POST"))
            .and(path("/tasks/cancel"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let item = make_test_work_item(&server.uri(), Some(5), Some(0.5));
        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
        let executor = AgentToolExecutor;

        let result = executor.execute_streaming(&item, tx).await;
        assert!(result.is_ok());

        let events = collect_events(&mut rx);

        match events.iter().find(|e| e["type"] == "agent_tool_terminated") {
            Some(terminated) => assert_eq!(terminated["reason"], "budget_exceeded"),
            None => panic!(
                "Expected an agent_tool_terminated event, got: {:#?}",
                events
            ),
        }

        assert!(
            !events.iter().any(|e| e["type"] == "agent_tool_completed"),
            "Should not have agent_tool_completed when budget exceeded"
        );
    }

    #[tokio::test]
    async fn malformed_json_becomes_raw() {
        let server = MockServer::start().await;
        mount_subscribe(
            &server,
            ndjson_body(&[
                r#"{"text":"first"}"#,
                "not json at all",
                r#"{"text":"third"}"#,
            ]),
        )
        .await;

        let item = make_test_work_item(&server.uri(), Some(5), None);
        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
        let executor = AgentToolExecutor;

        let result = executor.execute_streaming(&item, tx).await;
        assert!(result.is_ok());

        let events = collect_events(&mut rx);
        let progress: Vec<&serde_json::Value> = events
            .iter()
            .filter(|e| e["type"] == "agent_tool_progress")
            .collect();

        assert_eq!(
            progress.len(),
            3,
            "Expected 3 progress events, got {}: {:#?}",
            progress.len(),
            progress
        );

        let expected = [("text", "first"), ("raw", "not json at all"), ("text", "third")];
        for (event, &(field, value)) in progress.iter().zip(expected.iter()) {
            assert_eq!(event["chunk"][field], value);
        }
    }
}