1use std::fmt;
24
25use claude_wrapper::ClaudeCommand;
26use serde::{Deserialize, Serialize};
27
28use crate::chain::{ChainOptions, ChainResult, ChainStep, StepAction, StepFailurePolicy};
29use crate::pool::Pool;
30use crate::store::PoolStore;
31use crate::types::TaskResult;
32
/// Built-in routing system prompt, embedded at compile time from
/// `prompts/auto_route.md`. Used whenever no custom prompt is configured.
const DEFAULT_ROUTING_PROMPT: &str = include_str!("prompts/auto_route.md");
35
/// Routing strategy the caller would like the router to favor.
///
/// Serialized in `snake_case` (for example `prefer_single`). When rendered
/// into the routing prompt it is presented as a preference the router may
/// override, not as a hard requirement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutePreference {
    /// Favor running the task as one prompt.
    PreferSingle,
    /// Favor splitting the task into parallel prompts.
    PreferParallel,
    /// Favor a sequential chain of steps.
    PreferChain,
}
47
48impl fmt::Display for RoutePreference {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 match self {
51 Self::PreferSingle => write!(f, "single"),
52 Self::PreferParallel => write!(f, "parallel"),
53 Self::PreferChain => write!(f, "chain"),
54 }
55 }
56}
57
/// Optional hints that are rendered into a `## Constraints` section of the
/// routing system prompt.
///
/// Every field is optional; unset fields are skipped both when rendering the
/// prompt and when serializing.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AutoHint {
    /// Rendered as "Maximum parallel tasks: N".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_parallel: Option<usize>,
    /// Rendered as "Maximum chain steps: N".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_chain_steps: Option<usize>,
    /// Rendered as a preferred route that the router may still override.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefer: Option<RoutePreference>,
    /// Free-form domain description, rendered as "Domain: ...".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub domain: Option<String>,
    /// Suggested decomposition boundaries, rendered comma-separated.
    /// An empty list is treated the same as `None` when rendering.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decomposition_hints: Option<Vec<String>>,
}
84
/// Configuration for building the routing system prompt.
#[derive(Debug, Clone, Default)]
pub struct AutoConfig {
    /// Replaces the built-in routing prompt when set.
    pub custom_prompt: Option<String>,
    /// Hints appended to the base prompt (built-in or custom) as constraints.
    pub hints: Option<AutoHint>,
}
102
103fn render_hints(hints: &AutoHint) -> String {
105 let mut parts = Vec::new();
106
107 if let Some(n) = hints.max_parallel {
108 parts.push(format!("- Maximum parallel tasks: {n}"));
109 }
110 if let Some(n) = hints.max_chain_steps {
111 parts.push(format!("- Maximum chain steps: {n}"));
112 }
113 if let Some(pref) = &hints.prefer {
114 parts.push(format!(
115 "- Preferred route: {pref} (but choose differently if the task clearly warrants it)"
116 ));
117 }
118 if let Some(domain) = &hints.domain {
119 parts.push(format!("- Domain: {domain}"));
120 }
121 if let Some(decomp) = &hints.decomposition_hints
122 && !decomp.is_empty()
123 {
124 parts.push(format!(
125 "- Suggested decomposition boundaries: {}",
126 decomp.join(", ")
127 ));
128 }
129
130 if parts.is_empty() {
131 return String::new();
132 }
133
134 let mut section = String::from("\n\n## Constraints\n\n");
135 section.push_str(&parts.join("\n"));
136 section
137}
138
139pub(crate) fn assemble_routing_system_prompt(config: Option<&AutoConfig>) -> String {
143 let base = config
144 .and_then(|c| c.custom_prompt.as_deref())
145 .unwrap_or(DEFAULT_ROUTING_PROMPT);
146
147 let mut prompt = base.to_string();
148
149 if let Some(hints) = config.and_then(|c| c.hints.as_ref()) {
150 prompt.push_str(&render_hints(hints));
151 }
152
153 prompt
154}
155
/// Wraps the raw task text in `<task>...</task>` tags for the user message.
/// The text is inserted verbatim, with no escaping.
pub(crate) fn wrap_task(task: &str) -> String {
    ["<task>", task, "</task>"].concat()
}
160
/// Routing decision produced by the routing LLM.
///
/// Internally tagged on the `route` field with `snake_case` variant names,
/// e.g. `{"route": "single", "prompt": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "route", rename_all = "snake_case")]
pub enum AutoRoute {
    /// Run one prompt.
    Single { prompt: String },
    /// Run several prompts via fan-out.
    Parallel { prompts: Vec<String> },
    /// Run named steps as a chain.
    Chain { steps: Vec<AutoStep> },
}
172
/// One step of a [`AutoRoute::Chain`] decision.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoStep {
    /// Step name, carried over to the chain step that is submitted.
    pub name: String,
    /// Prompt executed for this step.
    pub prompt: String,
}
181
/// Outcome of executing an [`AutoRoute`], one variant per route kind.
#[derive(Debug, Clone)]
pub enum AutoResult {
    /// Result of a single task.
    Single(TaskResult),
    /// Results of the fanned-out tasks.
    Parallel(Vec<TaskResult>),
    /// Result of a chain run.
    Chain(ChainResult),
}
192
193impl AutoResult {
194 pub fn output(&self) -> String {
196 match self {
197 Self::Single(r) => r.output.clone(),
198 Self::Parallel(results) => results
199 .iter()
200 .enumerate()
201 .map(|(i, r)| format!("[{}] {}", i, r.output.trim()))
202 .collect::<Vec<_>>()
203 .join("\n"),
204 Self::Chain(r) => r.final_output.clone(),
205 }
206 }
207
208 pub fn route_name(&self) -> &'static str {
210 match self {
211 Self::Single(_) => "single",
212 Self::Parallel(_) => "parallel",
213 Self::Chain(_) => "chain",
214 }
215 }
216
217 pub fn cost_microdollars(&self) -> u64 {
219 match self {
220 Self::Single(r) => r.cost_microdollars,
221 Self::Parallel(results) => results.iter().map(|r| r.cost_microdollars).sum(),
222 Self::Chain(r) => r.total_cost_microdollars,
223 }
224 }
225}
226
227impl<S: PoolStore + 'static> Pool<S> {
228 pub async fn auto(&self, prompt: &str) -> crate::Result<AutoResult> {
247 self.auto_with_config(prompt, None).await
248 }
249
250 pub async fn auto_with_hints(
255 &self,
256 prompt: &str,
257 hints: &AutoHint,
258 ) -> crate::Result<AutoResult> {
259 let config = AutoConfig {
260 custom_prompt: None,
261 hints: Some(hints.clone()),
262 };
263 self.auto_with_config(prompt, Some(&config)).await
264 }
265
266 pub async fn auto_with_config(
271 &self,
272 prompt: &str,
273 config: Option<&AutoConfig>,
274 ) -> crate::Result<AutoResult> {
275 let route = match self.route_with_config(prompt, config).await {
276 Ok(route) => route,
277 Err(e) => {
278 tracing::warn!(error = %e, "auto-route parse failed, falling back to single");
279 AutoRoute::Single {
280 prompt: prompt.to_string(),
281 }
282 }
283 };
284
285 tracing::info!(route = route.route_name(), "auto-route decided");
286
287 self.execute_route(route).await
288 }
289
290 pub async fn route(&self, prompt: &str) -> crate::Result<AutoRoute> {
295 self.route_with_config(prompt, None).await
296 }
297
298 pub async fn route_with_hints(
300 &self,
301 prompt: &str,
302 hints: &AutoHint,
303 ) -> crate::Result<AutoRoute> {
304 let config = AutoConfig {
305 custom_prompt: None,
306 hints: Some(hints.clone()),
307 };
308 self.route_with_config(prompt, Some(&config)).await
309 }
310
311 pub async fn route_with_config(
313 &self,
314 prompt: &str,
315 config: Option<&AutoConfig>,
316 ) -> crate::Result<AutoRoute> {
317 let system = assemble_routing_system_prompt(config);
318 let user_message = wrap_task(prompt);
319
320 let cmd = claude_wrapper::QueryCommand::new(&user_message)
325 .system_prompt(system)
326 .output_format(claude_wrapper::OutputFormat::Json)
327 .permission_mode(claude_wrapper::PermissionMode::Plan)
328 .disallowed_tools(["Bash", "Read", "Write", "Edit", "Glob", "Grep", "Agent"])
329 .no_session_persistence()
330 .max_turns(2);
331
332 let output = cmd
333 .execute(self.claude())
334 .await
335 .map_err(crate::Error::Wrapper)?;
336
337 parse_route_from_output(&output.stdout)
338 }
339
340 pub async fn execute_route(&self, route: AutoRoute) -> crate::Result<AutoResult> {
347 let route = normalize_route(route)?;
348
349 match route {
350 AutoRoute::Single { prompt } => {
351 let result = self.run(&prompt).await?;
352 Ok(AutoResult::Single(result))
353 }
354 AutoRoute::Parallel { prompts } => {
355 let refs: Vec<&str> = prompts.iter().map(|s| s.as_str()).collect();
356 let results = self.fan_out(&refs).await?;
357 Ok(AutoResult::Parallel(results))
358 }
359 AutoRoute::Chain { steps } => {
360 let chain_steps: Vec<ChainStep> = steps
361 .into_iter()
362 .map(|s| ChainStep {
363 name: s.name,
364 action: StepAction::Prompt { prompt: s.prompt },
365 config: None,
366 failure_policy: StepFailurePolicy::default(),
367 output_vars: Default::default(),
368 })
369 .collect();
370
371 let task_id = self
372 .submit_chain(chain_steps, ChainOptions::default())
373 .await?;
374
375 let deadline = tokio::time::Instant::now()
377 + std::time::Duration::from_secs(CHAIN_POLL_TIMEOUT_SECS);
378 loop {
379 if let Some(result) = self.result(&task_id).await? {
380 if let Ok(chain_result) =
382 serde_json::from_str::<ChainResult>(&result.output)
383 {
384 return Ok(AutoResult::Chain(chain_result));
385 }
386 return Ok(AutoResult::Single(result));
388 }
389 if tokio::time::Instant::now() >= deadline {
390 return Err(crate::Error::Store(format!(
391 "auto-route chain timed out after {CHAIN_POLL_TIMEOUT_SECS}s"
392 )));
393 }
394 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
395 }
396 }
397 }
398 }
399}
400
401impl AutoRoute {
402 fn route_name(&self) -> &'static str {
404 match self {
405 Self::Single { .. } => "single",
406 Self::Parallel { .. } => "parallel",
407 Self::Chain { .. } => "chain",
408 }
409 }
410}
411
/// How long, in seconds, `execute_route` keeps polling for a chain result
/// before giving up with a timeout error (ten minutes).
const CHAIN_POLL_TIMEOUT_SECS: u64 = 10 * 60;
416
417fn normalize_route(route: AutoRoute) -> crate::Result<AutoRoute> {
424 match route {
425 AutoRoute::Single { prompt } => {
426 let prompt = prompt.trim().to_string();
427 if prompt.is_empty() {
428 return Err(crate::Error::Store(
429 "auto-route produced an empty prompt".into(),
430 ));
431 }
432 Ok(AutoRoute::Single { prompt })
433 }
434 AutoRoute::Parallel { prompts } => {
435 let prompts: Vec<String> = prompts
436 .into_iter()
437 .map(|p| p.trim().to_string())
438 .filter(|p| !p.is_empty())
439 .collect();
440 match prompts.len() {
441 0 => Err(crate::Error::Store(
442 "auto-route produced parallel with no prompts".into(),
443 )),
444 1 => {
445 tracing::info!("normalizing parallel(1) to single");
446 Ok(AutoRoute::Single {
447 prompt: prompts.into_iter().next().unwrap(),
448 })
449 }
450 _ => Ok(AutoRoute::Parallel { prompts }),
451 }
452 }
453 AutoRoute::Chain { steps } => {
454 let steps: Vec<AutoStep> = steps
455 .into_iter()
456 .filter(|s| !s.prompt.trim().is_empty())
457 .map(|s| AutoStep {
458 name: s.name,
459 prompt: s.prompt.trim().to_string(),
460 })
461 .collect();
462 match steps.len() {
463 0 => Err(crate::Error::Store(
464 "auto-route produced chain with no steps".into(),
465 )),
466 1 => {
467 tracing::info!("normalizing chain(1) to single");
468 Ok(AutoRoute::Single {
469 prompt: steps.into_iter().next().unwrap().prompt,
470 })
471 }
472 _ => Ok(AutoRoute::Chain { steps }),
473 }
474 }
475 }
476}
477
478pub(crate) fn parse_route_from_output(output: &str) -> crate::Result<AutoRoute> {
486 if let Ok(query_result) = serde_json::from_str::<serde_json::Value>(output) {
488 if let Some(subtype) = query_result.get("subtype").and_then(|v| v.as_str())
491 && subtype != "success"
492 {
493 tracing::warn!(
494 subtype,
495 "routing LLM returned non-success result (likely used tools instead of classifying)"
496 );
497 return Err(crate::Error::Store(format!(
498 "routing LLM returned '{subtype}' instead of a routing decision"
499 )));
500 }
501
502 if let Some(result_text) = query_result.get("result").and_then(|v| v.as_str())
503 && let Ok(route) = extract_json_route(result_text)
504 {
505 return Ok(route);
506 }
507 }
508
509 if let Ok(route) = extract_json_route(output) {
511 return Ok(route);
512 }
513
514 tracing::debug!(
515 output = %output.chars().take(500).collect::<String>(),
516 "could not parse routing decision from LLM output"
517 );
518 Err(crate::Error::Store(
519 "could not parse routing decision from LLM output".into(),
520 ))
521}
522
523pub(crate) fn extract_json_route(text: &str) -> crate::Result<AutoRoute> {
525 if let Ok(route) = serde_json::from_str::<AutoRoute>(text) {
527 return Ok(route);
528 }
529
530 if let Some(start) = text.find("```json") {
532 let json_start = start + 7;
533 if let Some(end) = text[json_start..].find("```") {
534 let json_str = text[json_start..json_start + end].trim();
535 if let Ok(route) = serde_json::from_str::<AutoRoute>(json_str) {
536 return Ok(route);
537 }
538 }
539 }
540
541 if let Some(start) = text.find("```\n") {
543 let json_start = start + 4;
544 if let Some(end) = text[json_start..].find("```") {
545 let json_str = text[json_start..json_start + end].trim();
546 if let Ok(route) = serde_json::from_str::<AutoRoute>(json_str) {
547 return Ok(route);
548 }
549 }
550 }
551
552 if let Some(start) = text.find(r#""route""#) {
554 let before = &text[..start];
555 if let Some(brace) = before.rfind('{') {
556 let candidate = &text[brace..];
557 let mut depth = 0;
558 let mut end = 0;
559 for (i, ch) in candidate.char_indices() {
560 match ch {
561 '{' => depth += 1,
562 '}' => {
563 depth -= 1;
564 if depth == 0 {
565 end = i + 1;
566 break;
567 }
568 }
569 _ => {}
570 }
571 }
572 if end > 0 {
573 let json_str = &candidate[..end];
574 if let Ok(route) = serde_json::from_str::<AutoRoute>(json_str) {
575 return Ok(route);
576 }
577 }
578 }
579 }
580
581 Err(crate::Error::Store(
582 "no valid JSON routing decision found in text".into(),
583 ))
584}
585
// Unit tests for route parsing, result accessors, hint rendering, prompt
// assembly and route normalization. None of them invoke the routing LLM.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Route parsing: direct JSON -------------------------------------

    #[test]
    fn parse_single_route() {
        let json = r#"{"route": "single", "prompt": "fix the bug"}"#;
        let route = extract_json_route(json).unwrap();
        match route {
            AutoRoute::Single { prompt } => assert_eq!(prompt, "fix the bug"),
            _ => panic!("expected Single"),
        }
    }

    #[test]
    fn parse_parallel_route() {
        let json =
            r#"{"route": "parallel", "prompts": ["review a.rs", "review b.rs", "review c.rs"]}"#;
        let route = extract_json_route(json).unwrap();
        match route {
            AutoRoute::Parallel { prompts } => {
                assert_eq!(prompts.len(), 3);
                assert_eq!(prompts[0], "review a.rs");
            }
            _ => panic!("expected Parallel"),
        }
    }

    #[test]
    fn parse_chain_route() {
        let json = r#"{"route": "chain", "steps": [{"name": "analyze", "prompt": "analyze the code"}, {"name": "fix", "prompt": "fix based on {previous_output}"}]}"#;
        let route = extract_json_route(json).unwrap();
        match route {
            AutoRoute::Chain { steps } => {
                assert_eq!(steps.len(), 2);
                assert_eq!(steps[0].name, "analyze");
                // The placeholder must survive parsing untouched.
                assert!(steps[1].prompt.contains("{previous_output}"));
            }
            _ => panic!("expected Chain"),
        }
    }

    // --- Route parsing: JSON wrapped in other text ----------------------

    #[test]
    fn parse_from_markdown_fence() {
        let text = r#"Here is my decision:
```json
{"route": "single", "prompt": "just do it"}
```
"#;
        let route = extract_json_route(text).unwrap();
        assert!(matches!(route, AutoRoute::Single { .. }));
    }

    #[test]
    fn parse_from_bare_fence() {
        let text = "```\n{\"route\": \"single\", \"prompt\": \"do it\"}\n```\n";
        let route = extract_json_route(text).unwrap();
        assert!(matches!(route, AutoRoute::Single { .. }));
    }

    #[test]
    fn parse_from_embedded_json() {
        let text = r#"I think this should be {"route": "single", "prompt": "just do it"} and that's my answer."#;
        let route = extract_json_route(text).unwrap();
        assert!(matches!(route, AutoRoute::Single { .. }));
    }

    // The decision is a JSON string nested in the `result` field of the
    // command's JSON output.
    #[test]
    fn parse_from_query_result_wrapper() {
        let output = r#"{"result": "{\"route\": \"parallel\", \"prompts\": [\"a\", \"b\"]}", "session_id": "abc", "cost_usd": 0.01}"#;
        let route = parse_route_from_output(output).unwrap();
        match route {
            AutoRoute::Parallel { prompts } => assert_eq!(prompts.len(), 2),
            _ => panic!("expected Parallel"),
        }
    }

    #[test]
    fn parse_fails_on_garbage() {
        assert!(extract_json_route("this is not json at all").is_err());
        assert!(parse_route_from_output("garbage").is_err());
    }

    // Mirrors the fallback in `auto_with_config` without calling it: a parse
    // failure is replaced by a Single route carrying the original prompt.
    #[test]
    fn fallback_to_single_on_parse_failure() {
        let original_prompt = "do the thing";
        let route =
            parse_route_from_output("unparseable garbage").unwrap_or_else(|_| AutoRoute::Single {
                prompt: original_prompt.to_string(),
            });
        match route {
            AutoRoute::Single { prompt } => assert_eq!(prompt, "do the thing"),
            _ => panic!("expected fallback to Single"),
        }
    }

    // --- AutoResult accessors -------------------------------------------

    #[test]
    fn auto_result_output_single() {
        let result = AutoResult::Single(TaskResult::success(String::from("hello world"), 100, 50));
        assert_eq!(result.output(), "hello world");
        assert_eq!(result.route_name(), "single");
        assert_eq!(result.cost_microdollars(), 100);
    }

    #[test]
    fn auto_result_output_parallel() {
        let results = vec![
            TaskResult::success(String::from("one"), 100, 50),
            TaskResult::success(String::from("two"), 200, 50),
        ];
        let result = AutoResult::Parallel(results);
        assert_eq!(result.route_name(), "parallel");
        // Costs are summed across the parallel tasks.
        assert_eq!(result.cost_microdollars(), 300);
        assert!(result.output().contains("[0] one"));
        assert!(result.output().contains("[1] two"));
    }

    #[test]
    fn auto_result_output_chain() {
        let chain = ChainResult {
            steps: vec![],
            final_output: "chain done".into(),
            total_cost_microdollars: 500,
            success: true,
        };
        let result = AutoResult::Chain(chain);
        assert_eq!(result.output(), "chain done");
        assert_eq!(result.route_name(), "chain");
        assert_eq!(result.cost_microdollars(), 500);
    }

    // --- AutoRoute serde round-trips -------------------------------------

    #[test]
    fn serde_roundtrip_single() {
        let route = AutoRoute::Single {
            prompt: "test".into(),
        };
        let json = serde_json::to_string(&route).unwrap();
        let parsed: AutoRoute = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, AutoRoute::Single { .. }));
    }

    #[test]
    fn serde_roundtrip_parallel() {
        let route = AutoRoute::Parallel {
            prompts: vec!["a".into(), "b".into()],
        };
        let json = serde_json::to_string(&route).unwrap();
        let parsed: AutoRoute = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, AutoRoute::Parallel { .. }));
    }

    #[test]
    fn serde_roundtrip_chain() {
        let route = AutoRoute::Chain {
            steps: vec![AutoStep {
                name: "s1".into(),
                prompt: "do it".into(),
            }],
        };
        let json = serde_json::to_string(&route).unwrap();
        let parsed: AutoRoute = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, AutoRoute::Chain { .. }));
    }

    // --- Hint rendering --------------------------------------------------

    #[test]
    fn render_empty_hints_produces_nothing() {
        let hints = AutoHint::default();
        assert_eq!(render_hints(&hints), "");
    }

    #[test]
    fn render_hints_max_parallel() {
        let hints = AutoHint {
            max_parallel: Some(3),
            ..Default::default()
        };
        let rendered = render_hints(&hints);
        assert!(rendered.contains("Maximum parallel tasks: 3"));
        assert!(rendered.contains("## Constraints"));
    }

    #[test]
    fn render_hints_max_chain_steps() {
        let hints = AutoHint {
            max_chain_steps: Some(4),
            ..Default::default()
        };
        let rendered = render_hints(&hints);
        assert!(rendered.contains("Maximum chain steps: 4"));
    }

    #[test]
    fn render_hints_preference() {
        let hints = AutoHint {
            prefer: Some(RoutePreference::PreferParallel),
            ..Default::default()
        };
        let rendered = render_hints(&hints);
        assert!(rendered.contains("Preferred route: parallel"));
        assert!(rendered.contains("choose differently if the task clearly warrants it"));
    }

    #[test]
    fn render_hints_domain() {
        let hints = AutoHint {
            domain: Some("monorepo with independent crates".into()),
            ..Default::default()
        };
        let rendered = render_hints(&hints);
        assert!(rendered.contains("Domain: monorepo with independent crates"));
    }

    #[test]
    fn render_hints_decomposition() {
        let hints = AutoHint {
            decomposition_hints: Some(vec![
                "auth module".into(),
                "api module".into(),
                "db module".into(),
            ]),
            ..Default::default()
        };
        let rendered = render_hints(&hints);
        assert!(
            rendered
                .contains("Suggested decomposition boundaries: auth module, api module, db module")
        );
    }

    // An empty decomposition list must not produce a Constraints section.
    #[test]
    fn render_hints_empty_decomposition_skipped() {
        let hints = AutoHint {
            decomposition_hints: Some(vec![]),
            ..Default::default()
        };
        assert_eq!(render_hints(&hints), "");
    }

    #[test]
    fn render_hints_all_fields() {
        let hints = AutoHint {
            max_parallel: Some(2),
            max_chain_steps: Some(3),
            prefer: Some(RoutePreference::PreferChain),
            domain: Some("microservices".into()),
            decomposition_hints: Some(vec!["svc-a".into(), "svc-b".into()]),
        };
        let rendered = render_hints(&hints);
        assert!(rendered.contains("Maximum parallel tasks: 2"));
        assert!(rendered.contains("Maximum chain steps: 3"));
        assert!(rendered.contains("Preferred route: chain"));
        assert!(rendered.contains("Domain: microservices"));
        assert!(rendered.contains("svc-a, svc-b"));
    }

    // --- System prompt assembly ------------------------------------------
    // The task itself travels in the user message, so the system prompt must
    // never contain a "## Task" section.

    #[test]
    fn assemble_system_prompt_no_config() {
        let prompt = assemble_routing_system_prompt(None);
        assert!(prompt.starts_with("You are a work router."));
        assert!(!prompt.contains("## Task"));
        assert!(!prompt.contains("## Constraints"));
    }

    #[test]
    fn assemble_system_prompt_with_hints() {
        let config = AutoConfig {
            custom_prompt: None,
            hints: Some(AutoHint {
                max_parallel: Some(2),
                ..Default::default()
            }),
        };
        let prompt = assemble_routing_system_prompt(Some(&config));
        assert!(prompt.starts_with("You are a work router."));
        assert!(prompt.contains("## Constraints"));
        assert!(prompt.contains("Maximum parallel tasks: 2"));
        assert!(!prompt.contains("## Task"));
    }

    #[test]
    fn assemble_system_prompt_with_custom_prompt() {
        let config = AutoConfig {
            custom_prompt: Some("You are a custom router.".into()),
            hints: None,
        };
        let prompt = assemble_routing_system_prompt(Some(&config));
        assert!(prompt.starts_with("You are a custom router."));
        assert!(!prompt.contains("You are a work router."));
        assert!(!prompt.contains("## Task"));
    }

    #[test]
    fn assemble_system_prompt_custom_prompt_with_hints() {
        let config = AutoConfig {
            custom_prompt: Some("Custom instructions.".into()),
            hints: Some(AutoHint {
                domain: Some("testing".into()),
                ..Default::default()
            }),
        };
        let prompt = assemble_routing_system_prompt(Some(&config));
        assert!(prompt.starts_with("Custom instructions."));
        assert!(prompt.contains("## Constraints"));
        assert!(prompt.contains("Domain: testing"));
        assert!(!prompt.contains("## Task"));
    }

    #[test]
    fn wrap_task_adds_xml_tags() {
        let wrapped = wrap_task("do the thing");
        assert_eq!(wrapped, "<task>do the thing</task>");
    }

    // Guards against accidental edits to the embedded prompt file by checking
    // for the phrases the router relies on.
    #[test]
    fn default_prompt_loaded_from_file() {
        assert!(DEFAULT_ROUTING_PROMPT.contains("You are a work router."));
        assert!(DEFAULT_ROUTING_PROMPT.contains("THREE options"));
        assert!(DEFAULT_ROUTING_PROMPT.contains("SINGLE"));
        assert!(DEFAULT_ROUTING_PROMPT.contains("PARALLEL"));
        assert!(DEFAULT_ROUTING_PROMPT.contains("CHAIN"));
        assert!(DEFAULT_ROUTING_PROMPT.contains("Decision test"));
        assert!(DEFAULT_ROUTING_PROMPT.contains("<examples>"));
        assert!(DEFAULT_ROUTING_PROMPT.contains("<example>"));
        assert!(DEFAULT_ROUTING_PROMPT.contains("</examples>"));
        assert!(DEFAULT_ROUTING_PROMPT.contains("Common mistakes to avoid"));
        assert!(DEFAULT_ROUTING_PROMPT.contains("Splitting incorrectly is worse"));
        assert!(
            DEFAULT_ROUTING_PROMPT.contains("task to classify is provided in the user message")
        );
    }

    // --- RoutePreference and AutoHint ------------------------------------

    #[test]
    fn route_preference_display() {
        assert_eq!(RoutePreference::PreferSingle.to_string(), "single");
        assert_eq!(RoutePreference::PreferParallel.to_string(), "parallel");
        assert_eq!(RoutePreference::PreferChain.to_string(), "chain");
    }

    #[test]
    fn route_preference_serde_roundtrip() {
        let pref = RoutePreference::PreferParallel;
        let json = serde_json::to_string(&pref).unwrap();
        let parsed: RoutePreference = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, RoutePreference::PreferParallel);
    }

    #[test]
    fn auto_hint_serde_skips_none_fields() {
        let hints = AutoHint {
            max_parallel: Some(3),
            ..Default::default()
        };
        let json = serde_json::to_string(&hints).unwrap();
        assert!(json.contains("max_parallel"));
        assert!(!json.contains("max_chain_steps"));
        assert!(!json.contains("prefer"));
        assert!(!json.contains("domain"));
        assert!(!json.contains("decomposition_hints"));
    }

    #[test]
    fn auto_hint_default_is_empty() {
        let hints = AutoHint::default();
        assert!(hints.max_parallel.is_none());
        assert!(hints.max_chain_steps.is_none());
        assert!(hints.prefer.is_none());
        assert!(hints.domain.is_none());
        assert!(hints.decomposition_hints.is_none());
    }

    // --- Route normalization ---------------------------------------------

    #[test]
    fn normalize_single_trims_whitespace() {
        let route = AutoRoute::Single {
            prompt: " hello ".into(),
        };
        let normalized = normalize_route(route).unwrap();
        match normalized {
            AutoRoute::Single { prompt } => assert_eq!(prompt, "hello"),
            _ => panic!("expected Single"),
        }
    }

    #[test]
    fn normalize_single_rejects_empty() {
        let route = AutoRoute::Single {
            prompt: " ".into(),
        };
        assert!(normalize_route(route).is_err());
    }

    #[test]
    fn normalize_parallel_one_becomes_single() {
        let route = AutoRoute::Parallel {
            prompts: vec!["only one".into()],
        };
        let normalized = normalize_route(route).unwrap();
        match normalized {
            AutoRoute::Single { prompt } => assert_eq!(prompt, "only one"),
            _ => panic!("expected Single, got {:?}", normalized),
        }
    }

    #[test]
    fn normalize_parallel_empty_is_error() {
        let route = AutoRoute::Parallel { prompts: vec![] };
        assert!(normalize_route(route).is_err());
    }

    #[test]
    fn normalize_parallel_filters_empty_prompts() {
        let route = AutoRoute::Parallel {
            prompts: vec!["good".into(), " ".into(), "also good".into()],
        };
        let normalized = normalize_route(route).unwrap();
        match normalized {
            AutoRoute::Parallel { prompts } => {
                assert_eq!(prompts.len(), 2);
                assert_eq!(prompts[0], "good");
                assert_eq!(prompts[1], "also good");
            }
            _ => panic!("expected Parallel"),
        }
    }

    #[test]
    fn normalize_parallel_all_empty_is_error() {
        let route = AutoRoute::Parallel {
            prompts: vec![" ".into(), "".into()],
        };
        assert!(normalize_route(route).is_err());
    }

    #[test]
    fn normalize_chain_one_becomes_single() {
        let route = AutoRoute::Chain {
            steps: vec![AutoStep {
                name: "only".into(),
                prompt: "do it".into(),
            }],
        };
        let normalized = normalize_route(route).unwrap();
        match normalized {
            AutoRoute::Single { prompt } => assert_eq!(prompt, "do it"),
            _ => panic!("expected Single"),
        }
    }

    #[test]
    fn normalize_chain_empty_is_error() {
        let route = AutoRoute::Chain { steps: vec![] };
        assert!(normalize_route(route).is_err());
    }

    #[test]
    fn normalize_chain_filters_empty_prompts() {
        let route = AutoRoute::Chain {
            steps: vec![
                AutoStep {
                    name: "a".into(),
                    prompt: "step one".into(),
                },
                AutoStep {
                    name: "b".into(),
                    prompt: " ".into(),
                },
                AutoStep {
                    name: "c".into(),
                    prompt: "step three".into(),
                },
            ],
        };
        let normalized = normalize_route(route).unwrap();
        match normalized {
            AutoRoute::Chain { steps } => {
                // The blank middle step is dropped; order is preserved.
                assert_eq!(steps.len(), 2);
                assert_eq!(steps[0].name, "a");
                assert_eq!(steps[1].name, "c");
            }
            _ => panic!("expected Chain"),
        }
    }

    #[test]
    fn normalize_valid_parallel_unchanged() {
        let route = AutoRoute::Parallel {
            prompts: vec!["a".into(), "b".into(), "c".into()],
        };
        let normalized = normalize_route(route).unwrap();
        assert!(matches!(normalized, AutoRoute::Parallel { prompts } if prompts.len() == 3));
    }

    #[test]
    fn normalize_valid_chain_unchanged() {
        let route = AutoRoute::Chain {
            steps: vec![
                AutoStep {
                    name: "s1".into(),
                    prompt: "first".into(),
                },
                AutoStep {
                    name: "s2".into(),
                    prompt: "second".into(),
                },
            ],
        };
        let normalized = normalize_route(route).unwrap();
        assert!(matches!(normalized, AutoRoute::Chain { steps } if steps.len() == 2));
    }
}