1use async_trait::async_trait;
26
27use crate::error::AgentError;
28use crate::state::State;
29
30mod dispatch;
31mod fallback;
32mod fn_agent;
33mod llm;
34mod loop_agent;
35mod map_over;
36mod parallel;
37mod race;
38mod route;
39mod sequential;
40mod tap;
41mod timeout;
42
43pub use dispatch::{DispatchTextAgent, JoinTextAgent, TaskRegistry};
44pub use fallback::FallbackTextAgent;
45pub use fn_agent::FnTextAgent;
46pub use llm::LlmTextAgent;
47pub use loop_agent::LoopTextAgent;
48pub use map_over::MapOverTextAgent;
49pub use parallel::ParallelTextAgent;
50pub use race::RaceTextAgent;
51pub use route::{RouteRule, RouteTextAgent};
52pub use sequential::SequentialTextAgent;
53pub use tap::TapTextAgent;
54pub use timeout::TimeoutTextAgent;
55
56#[async_trait]
63pub trait TextAgent: Send + Sync {
64 fn name(&self) -> &str;
66
67 async fn run(&self, state: &State) -> Result<String, AgentError>;
69}
70
71const _: () = {
73 fn _assert_object_safe(_: &dyn TextAgent) {}
74};
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{BaseLlm, LlmError, LlmRequest, LlmResponse};
    use rs_genai::prelude::{Content, FunctionCall, Part, Role};
    use std::sync::Arc;
    use std::time::Duration;

    // ── Mock LLMs ──────────────────────────────────────────────────────────

    /// Mock LLM that ignores its request and always answers with `response`.
    struct FixedLlm {
        response: String,
    }

    #[async_trait]
    impl BaseLlm for FixedLlm {
        fn model_id(&self) -> &str {
            "fixed-mock"
        }

        async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            Ok(LlmResponse {
                content: Content {
                    role: Some(Role::Model),
                    parts: vec![Part::Text {
                        text: self.response.clone(),
                    }],
                },
                finish_reason: Some("STOP".into()),
                usage: None,
            })
        }
    }

    /// Mock LLM that answers with `prefix` followed by every text part of the
    /// request, joined with single spaces.
    struct EchoLlm {
        prefix: String,
    }

    #[async_trait]
    impl BaseLlm for EchoLlm {
        fn model_id(&self) -> &str {
            "echo-mock"
        }

        async fn generate(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
            let input_text: String = req
                .contents
                .iter()
                .flat_map(|c| &c.parts)
                .filter_map(|p| match p {
                    Part::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join(" ");

            Ok(LlmResponse {
                content: Content {
                    role: Some(Role::Model),
                    parts: vec![Part::Text {
                        text: format!("{}{}", self.prefix, input_text),
                    }],
                },
                finish_reason: Some("STOP".into()),
                usage: None,
            })
        }
    }

    /// Mock LLM that requests one call to `tool_name` until the conversation
    /// contains a function response, then answers with `final_response`.
    struct ToolCallingLlm {
        tool_name: String,
        tool_args: serde_json::Value,
        final_response: String,
    }

    #[async_trait]
    impl BaseLlm for ToolCallingLlm {
        fn model_id(&self) -> &str {
            "tool-mock"
        }

        async fn generate(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
            // Has the tool already been answered somewhere in the conversation?
            let has_tool_response = req.contents.iter().any(|c| {
                c.parts
                    .iter()
                    .any(|p| matches!(p, Part::FunctionResponse { .. }))
            });

            if has_tool_response {
                // Second turn: produce the final text answer.
                Ok(LlmResponse {
                    content: Content {
                        role: Some(Role::Model),
                        parts: vec![Part::Text {
                            text: self.final_response.clone(),
                        }],
                    },
                    finish_reason: Some("STOP".into()),
                    usage: None,
                })
            } else {
                // First turn: ask for the tool to be called.
                Ok(LlmResponse {
                    content: Content {
                        role: Some(Role::Model),
                        parts: vec![Part::FunctionCall {
                            function_call: FunctionCall {
                                name: self.tool_name.clone(),
                                args: self.tool_args.clone(),
                                id: Some("call-1".into()),
                            },
                        }],
                    },
                    finish_reason: None,
                    usage: None,
                })
            }
        }
    }

    /// Mock LLM whose every request fails.
    struct FailLlm;

    #[async_trait]
    impl BaseLlm for FailLlm {
        fn model_id(&self) -> &str {
            "fail-mock"
        }

        async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            Err(LlmError::RequestFailed("intentional failure".into()))
        }
    }

    // ── Trait ──────────────────────────────────────────────────────────────

    #[test]
    fn text_agent_is_object_safe() {
        fn _assert(_: &dyn TextAgent) {}
    }

    // ── LlmTextAgent ───────────────────────────────────────────────────────

    #[tokio::test]
    async fn llm_text_agent_returns_text() {
        let llm = Arc::new(FixedLlm {
            response: "Hello world".into(),
        });
        let agent = LlmTextAgent::new("greeter", llm).instruction("Say hello");
        let state = State::new();
        let result = agent.run(&state).await.unwrap();
        assert_eq!(result, "Hello world");
        assert_eq!(state.get::<String>("output"), Some("Hello world".into()));
    }

    #[tokio::test]
    async fn llm_text_agent_reads_input_from_state() {
        let llm = Arc::new(EchoLlm {
            prefix: "Echo: ".into(),
        });
        let agent = LlmTextAgent::new("echoer", llm);
        let state = State::new();
        state.set("input", "test message");
        let result = agent.run(&state).await.unwrap();
        assert!(result.contains("test message"));
    }

    #[tokio::test]
    async fn llm_text_agent_dispatches_tools() {
        let llm = Arc::new(ToolCallingLlm {
            tool_name: "get_weather".into(),
            tool_args: serde_json::json!({"city": "London"}),
            final_response: "The weather is sunny".into(),
        });

        let mut dispatcher = crate::tool::ToolDispatcher::new();
        dispatcher.register_function(Arc::new(crate::tool::SimpleTool::new(
            "get_weather",
            "Get weather",
            None,
            |_args| async { Ok(serde_json::json!({"temp": 22})) },
        )));

        let agent = LlmTextAgent::new("weather", llm).tools(Arc::new(dispatcher));
        let state = State::new();
        let result = agent.run(&state).await.unwrap();
        assert_eq!(result, "The weather is sunny");
    }

    #[tokio::test]
    async fn llm_text_agent_propagates_llm_error() {
        let llm = Arc::new(FailLlm);
        let agent = LlmTextAgent::new("failer", llm);
        let state = State::new();
        let result = agent.run(&state).await;
        assert!(result.is_err());
    }

    // ── FnTextAgent ────────────────────────────────────────────────────────

    #[tokio::test]
    async fn fn_agent_transforms_state() {
        let agent = FnTextAgent::new("upper", |state: &State| {
            let input = state.get::<String>("input").unwrap_or_default();
            let upper = input.to_uppercase();
            state.set("output", &upper);
            Ok(upper)
        });

        let state = State::new();
        state.set("input", "hello");
        let result = agent.run(&state).await.unwrap();
        assert_eq!(result, "HELLO");
        assert_eq!(state.get::<String>("output"), Some("HELLO".into()));
    }

    #[tokio::test]
    async fn fn_agent_can_fail() {
        let agent = FnTextAgent::new("failer", |_state: &State| {
            Err(AgentError::Other("nope".into()))
        });
        let state = State::new();
        assert!(agent.run(&state).await.is_err());
    }

    // ── SequentialTextAgent ────────────────────────────────────────────────

    #[tokio::test]
    async fn sequential_chains_agents() {
        let llm1: Arc<dyn BaseLlm> = Arc::new(FixedLlm {
            response: "step1 done".into(),
        });
        let llm2: Arc<dyn BaseLlm> = Arc::new(EchoLlm {
            prefix: "step2: ".into(),
        });

        let children: Vec<Arc<dyn TextAgent>> = vec![
            Arc::new(LlmTextAgent::new("step1", llm1)),
            Arc::new(LlmTextAgent::new("step2", llm2)),
        ];

        let pipeline = SequentialTextAgent::new("pipeline", children);
        let state = State::new();
        let result = pipeline.run(&state).await.unwrap();
        // step2 echoes what it received, so step1's output must appear in it.
        assert!(result.contains("step2:"));
        assert!(result.contains("step1 done"));
    }

    #[tokio::test]
    async fn sequential_stops_on_error() {
        let children: Vec<Arc<dyn TextAgent>> = vec![
            Arc::new(LlmTextAgent::new(
                "ok",
                Arc::new(FixedLlm {
                    response: "fine".into(),
                }),
            )),
            Arc::new(LlmTextAgent::new("fail", Arc::new(FailLlm))),
            Arc::new(LlmTextAgent::new(
                "never",
                Arc::new(FixedLlm {
                    response: "unreachable".into(),
                }),
            )),
        ];

        let pipeline = SequentialTextAgent::new("pipeline", children);
        let state = State::new();
        assert!(pipeline.run(&state).await.is_err());
        // LlmTextAgent writes its text to "output"; if the child after the
        // failure had run, "unreachable" would be there.
        assert_ne!(state.get::<String>("output"), Some("unreachable".into()));
    }

    #[tokio::test]
    async fn sequential_empty_returns_empty() {
        let pipeline = SequentialTextAgent::new("empty", vec![]);
        let state = State::new();
        let result = pipeline.run(&state).await.unwrap();
        assert_eq!(result, "");
    }

    // ── ParallelTextAgent ──────────────────────────────────────────────────

    #[tokio::test]
    async fn parallel_runs_concurrently() {
        let branches: Vec<Arc<dyn TextAgent>> = vec![
            Arc::new(FnTextAgent::new("a", |state: &State| {
                state.set("key_a", "val_a");
                Ok("result_a".into())
            })),
            Arc::new(FnTextAgent::new("b", |state: &State| {
                state.set("key_b", "val_b");
                Ok("result_b".into())
            })),
        ];

        let par = ParallelTextAgent::new("parallel", branches);
        let state = State::new();
        let result = par.run(&state).await.unwrap();
        assert!(result.contains("result_a"));
        assert!(result.contains("result_b"));
        assert_eq!(state.get::<String>("key_a"), Some("val_a".into()));
        assert_eq!(state.get::<String>("key_b"), Some("val_b".into()));
    }

    #[tokio::test]
    async fn parallel_fails_if_any_fails() {
        let branches: Vec<Arc<dyn TextAgent>> = vec![
            Arc::new(FnTextAgent::new("ok", |_| Ok("fine".into()))),
            Arc::new(FnTextAgent::new("fail", |_| {
                Err(AgentError::Other("boom".into()))
            })),
        ];

        let par = ParallelTextAgent::new("parallel", branches);
        let state = State::new();
        assert!(par.run(&state).await.is_err());
    }

    // ── LoopTextAgent ──────────────────────────────────────────────────────

    #[tokio::test]
    async fn loop_runs_max_iterations() {
        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = counter.clone();

        let body = Arc::new(FnTextAgent::new("counter", move |_state: &State| {
            counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok("tick".into())
        }));

        let loop_agent = LoopTextAgent::new("loop", body, 5);
        let state = State::new();
        loop_agent.run(&state).await.unwrap();
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn loop_breaks_on_predicate() {
        let body = Arc::new(FnTextAgent::new("incrementer", |state: &State| {
            let n = state.get::<i32>("n").unwrap_or(0);
            state.set("n", n + 1);
            Ok(format!("n={}", n + 1))
        }));

        let loop_agent = LoopTextAgent::new("loop", body, 100)
            .until(|state: &State| state.get::<i32>("n").unwrap_or(0) >= 3);

        let state = State::new();
        loop_agent.run(&state).await.unwrap();
        assert_eq!(state.get::<i32>("n"), Some(3));
    }

    // ── FallbackTextAgent ──────────────────────────────────────────────────

    #[tokio::test]
    async fn fallback_returns_first_success() {
        // Counts invocations of the candidate that follows the first success.
        let later_calls = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let later_calls_clone = later_calls.clone();

        let candidates: Vec<Arc<dyn TextAgent>> = vec![
            Arc::new(FnTextAgent::new("fail1", |_| {
                Err(AgentError::Other("fail1".into()))
            })),
            Arc::new(FnTextAgent::new("ok", |_| Ok("success".into()))),
            Arc::new(FnTextAgent::new("never", move |_state: &State| {
                later_calls_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                Ok("unreachable".into())
            })),
        ];

        let fallback = FallbackTextAgent::new("fallback", candidates);
        let state = State::new();
        let result = fallback.run(&state).await.unwrap();
        assert_eq!(result, "success");
        // Candidates after the first success must not be tried.
        assert_eq!(later_calls.load(std::sync::atomic::Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn fallback_returns_last_error() {
        let candidates: Vec<Arc<dyn TextAgent>> = vec![
            Arc::new(FnTextAgent::new("fail1", |_| {
                Err(AgentError::Other("fail1".into()))
            })),
            Arc::new(FnTextAgent::new("fail2", |_| {
                Err(AgentError::Other("fail2".into()))
            })),
        ];

        let fallback = FallbackTextAgent::new("fallback", candidates);
        let state = State::new();
        let err = fallback.run(&state).await.unwrap_err();
        assert!(err.to_string().contains("fail2"));
    }

    #[tokio::test]
    async fn fallback_empty_returns_error() {
        let fallback = FallbackTextAgent::new("fallback", vec![]);
        let state = State::new();
        assert!(fallback.run(&state).await.is_err());
    }

    // ── RouteTextAgent ─────────────────────────────────────────────────────

    #[tokio::test]
    async fn route_dispatches_matching_rule() {
        let agent_a: Arc<dyn TextAgent> = Arc::new(FnTextAgent::new("a", |_| Ok("route_a".into())));
        let agent_b: Arc<dyn TextAgent> = Arc::new(FnTextAgent::new("b", |_| Ok("route_b".into())));
        let default: Arc<dyn TextAgent> =
            Arc::new(FnTextAgent::new("default", |_| Ok("default".into())));

        let router = RouteTextAgent::new(
            "router",
            vec![
                RouteRule::new(
                    |s: &State| s.get::<String>("mode") == Some("a".into()),
                    agent_a,
                ),
                RouteRule::new(
                    |s: &State| s.get::<String>("mode") == Some("b".into()),
                    agent_b,
                ),
            ],
            default,
        );

        let state = State::new();
        state.set("mode", "b");
        let result = router.run(&state).await.unwrap();
        assert_eq!(result, "route_b");
    }

    #[tokio::test]
    async fn route_uses_default_when_no_match() {
        let default: Arc<dyn TextAgent> =
            Arc::new(FnTextAgent::new("default", |_| Ok("fallback".into())));
        // The non-matching rule gets its own agent with a distinct output, so
        // the assertion below fails if the router ever picks that rule.
        let never: Arc<dyn TextAgent> =
            Arc::new(FnTextAgent::new("never", |_| Ok("wrong route".into())));

        let router = RouteTextAgent::new(
            "router",
            vec![RouteRule::new(|_: &State| false, never)],
            default,
        );

        let state = State::new();
        let result = router.run(&state).await.unwrap();
        assert_eq!(result, "fallback");
    }

    // ── Race / Timeout ─────────────────────────────────────────────────────

    /// Agent that yields to the runtime for `delay` before answering, used to
    /// lose races and to trigger timeouts.
    struct AsyncSleepAgent {
        delay: Duration,
    }

    #[async_trait]
    impl TextAgent for AsyncSleepAgent {
        fn name(&self) -> &str {
            "async-sleeper"
        }

        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            tokio::time::sleep(self.delay).await;
            Ok("too late".into())
        }
    }

    #[tokio::test]
    async fn race_returns_first_to_complete() {
        let fast: Arc<dyn TextAgent> = Arc::new(FnTextAgent::new("fast", |_| Ok("winner".into())));
        let slow: Arc<dyn TextAgent> = Arc::new(AsyncSleepAgent {
            delay: Duration::from_millis(500),
        });

        let race = RaceTextAgent::new("race", vec![fast, slow]);
        let state = State::new();
        let result = race.run(&state).await.unwrap();
        assert_eq!(result, "winner");
    }

    #[tokio::test]
    async fn race_empty_returns_error() {
        let race = RaceTextAgent::new("race", vec![]);
        let state = State::new();
        assert!(race.run(&state).await.is_err());
    }

    #[tokio::test]
    async fn timeout_returns_result_within_limit() {
        let fast: Arc<dyn TextAgent> = Arc::new(FnTextAgent::new("fast", |_| Ok("done".into())));
        let timeout = TimeoutTextAgent::new("timeout", fast, Duration::from_secs(5));
        let state = State::new();
        let result = timeout.run(&state).await.unwrap();
        assert_eq!(result, "done");
    }

    #[tokio::test]
    async fn timeout_returns_error_when_exceeded() {
        let slow: Arc<dyn TextAgent> = Arc::new(AsyncSleepAgent {
            delay: Duration::from_secs(2),
        });
        let timeout = TimeoutTextAgent::new("timeout", slow, Duration::from_millis(50));
        let state = State::new();
        let err = timeout.run(&state).await.unwrap_err();
        assert!(matches!(err, AgentError::Timeout));
    }

    // ── MapOverTextAgent ───────────────────────────────────────────────────

    #[tokio::test]
    async fn map_over_iterates_items() {
        let agent: Arc<dyn TextAgent> = Arc::new(FnTextAgent::new("upper", |state: &State| {
            let item: String = state
                .get::<serde_json::Value>("_item")
                .map(|v| v.as_str().unwrap_or("").to_string())
                .unwrap_or_default();
            Ok(item.to_uppercase())
        }));

        let map = MapOverTextAgent::new("mapper", agent, "items");
        let state = State::new();
        state.set(
            "items",
            vec![
                serde_json::Value::String("hello".into()),
                serde_json::Value::String("world".into()),
            ],
        );

        let result = map.run(&state).await.unwrap();
        assert!(result.contains("HELLO"));
        assert!(result.contains("WORLD"));

        let results: Vec<String> = state.get("_results").unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], "HELLO");
        assert_eq!(results[1], "WORLD");
    }

    #[tokio::test]
    async fn map_over_empty_list() {
        let agent: Arc<dyn TextAgent> = Arc::new(FnTextAgent::new("noop", |_| Ok("x".into())));
        let map = MapOverTextAgent::new("mapper", agent, "items");

        // Key absent entirely.
        let state = State::new();
        let result = map.run(&state).await.unwrap();
        assert_eq!(result, "");

        // Key present but holding an empty list.
        let state = State::new();
        state.set("items", Vec::<serde_json::Value>::new());
        let result = map.run(&state).await.unwrap();
        assert_eq!(result, "");
    }

    // ── TapTextAgent ───────────────────────────────────────────────────────

    #[tokio::test]
    async fn tap_observes_state() {
        let observed = Arc::new(std::sync::Mutex::new(String::new()));
        let observed_clone = observed.clone();

        let tap = TapTextAgent::new("observer", move |state: &State| {
            let val = state.get::<String>("input").unwrap_or_default();
            *observed_clone.lock().unwrap() = val;
        });

        let state = State::new();
        state.set("input", "hello");
        let result = tap.run(&state).await.unwrap();
        assert_eq!(result, "");
        assert_eq!(*observed.lock().unwrap(), "hello");
    }

    // ── Dispatch / Join ────────────────────────────────────────────────────

    #[tokio::test]
    async fn dispatch_and_join_round_trip() {
        let registry = TaskRegistry::new();
        let budget = Arc::new(tokio::sync::Semaphore::new(10));

        let agent_a: Arc<dyn TextAgent> =
            Arc::new(FnTextAgent::new("task_a", |_| Ok("result_a".into())));
        let agent_b: Arc<dyn TextAgent> =
            Arc::new(FnTextAgent::new("task_b", |_| Ok("result_b".into())));

        let dispatch = DispatchTextAgent::new(
            "dispatch",
            vec![("task_a".into(), agent_a), ("task_b".into(), agent_b)],
            registry.clone(),
            budget,
        );

        let state = State::new();
        let dispatch_result = dispatch.run(&state).await.unwrap();
        // Dispatch returns no text of its own; the task results come back via the join.
        assert_eq!(dispatch_result, "");

        let join = JoinTextAgent::new("joiner", registry);
        let join_result = join.run(&state).await.unwrap();
        assert!(join_result.contains("result_a"));
        assert!(join_result.contains("result_b"));
    }

    #[tokio::test]
    async fn join_with_target_names() {
        let registry = TaskRegistry::new();
        let budget = Arc::new(tokio::sync::Semaphore::new(10));

        let children: Vec<(String, Arc<dyn TextAgent>)> = vec![
            (
                "x".into(),
                Arc::new(FnTextAgent::new("x", |_| Ok("rx".into()))),
            ),
            (
                "y".into(),
                Arc::new(FnTextAgent::new("y", |_| Ok("ry".into()))),
            ),
            (
                "z".into(),
                Arc::new(FnTextAgent::new("z", |_| Ok("rz".into()))),
            ),
        ];

        let dispatch = DispatchTextAgent::new("dispatch", children, registry.clone(), budget);
        let state = State::new();
        dispatch.run(&state).await.unwrap();

        let join =
            JoinTextAgent::new("joiner", registry.clone()).targets(vec!["x".into(), "z".into()]);
        let result = join.run(&state).await.unwrap();
        assert!(result.contains("rx"));
        assert!(result.contains("rz"));
        // The non-targeted task must not be joined...
        assert!(!result.contains("ry"));

        // ...and must still be pending in the registry.
        let remaining = registry.inner.lock().await;
        assert!(remaining.contains_key("y"));
    }

    #[tokio::test]
    async fn join_with_timeout() {
        let registry = TaskRegistry::new();
        let budget = Arc::new(tokio::sync::Semaphore::new(10));

        let slow: Arc<dyn TextAgent> = Arc::new(AsyncSleepAgent {
            delay: Duration::from_secs(2),
        });

        let dispatch = DispatchTextAgent::new(
            "dispatch",
            vec![("slow".into(), slow)],
            registry.clone(),
            budget,
        );
        let state = State::new();
        dispatch.run(&state).await.unwrap();

        let join = JoinTextAgent::new("joiner", registry).timeout(Duration::from_millis(50));
        let err = join.run(&state).await.unwrap_err();
        assert!(matches!(err, AgentError::Timeout));
    }
}