1#![allow(missing_docs)]
9use serde::{Deserialize, Serialize};
10
11use super::events::AgentEvent;
12use crate::error::Error;
13
/// Pluggable policy deciding how a task is dispatched across the available agents.
///
/// The `Send + Sync` bound lets one strategy be shared across threads, e.g. behind an
/// `Arc<dyn RoutingStrategy>`.
pub trait RoutingStrategy: Send + Sync {
    /// Chooses a [`RoutingDecision`] for `task` given the candidate `agents`.
    ///
    /// Returns the decision together with the [`ComplexitySignals`] that informed it.
    /// Strategies that compute no signals may return `ComplexitySignals::default()`.
    /// Any `agent_index` in the decision refers to a position in `agents`.
    fn route(&self, task: &str, agents: &[AgentCapability])
        -> (RoutingDecision, ComplexitySignals);
}
43
44pub struct KeywordRoutingStrategy;
51
52impl RoutingStrategy for KeywordRoutingStrategy {
53 fn route(
54 &self,
55 task: &str,
56 agents: &[AgentCapability],
57 ) -> (RoutingDecision, ComplexitySignals) {
58 let analyzer = TaskComplexityAnalyzer::new(agents);
59 analyzer.analyze(task)
60 }
61}
62
/// Configured routing policy.
///
/// Serialized in `snake_case`: `"auto"`, `"always_orchestrate"`, `"single_agent"`.
/// `resolve_routing_mode` lets the `HEARTBIT_ROUTING` environment variable override it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum RoutingMode {
    /// Default mode: decide per task.
    /// NOTE(review): presumably via a [`RoutingStrategy`]; the dispatch happens outside
    /// this module, so confirm at the call site.
    #[default]
    Auto,
    /// Orchestrate regardless of task heuristics (enforced by the caller, not in this module).
    AlwaysOrchestrate,
    /// Use a single agent regardless of task heuristics (enforced by the caller, not in this module).
    SingleAgent,
}
75
/// Outcome of routing one task.
#[derive(Debug, Clone, PartialEq)]
pub enum RoutingDecision {
    /// Dispatch the task to a single agent.
    SingleAgent {
        /// Index into the agent slice that was handed to the router.
        agent_index: usize,
        /// Static, human-readable explanation of why this branch was chosen.
        reason: &'static str,
    },
    /// Hand the task to multi-agent orchestration.
    Orchestrate {
        /// Static, human-readable explanation of why this branch was chosen.
        reason: &'static str,
    },
}
87
/// Heuristic features extracted from a task, plus the score derived from them.
///
/// Field descriptions below reflect how `TaskComplexityAnalyzer` fills them in. Custom
/// strategies may return `ComplexitySignals::default()` instead.
#[derive(Debug, Clone, Default)]
pub struct ComplexitySignals {
    /// Number of whitespace-separated words in the task.
    pub word_count: usize,
    /// Distinct phrase markers (e.g. `"first,"`, `"step 1"`) plus numbered-list tokens
    /// (`1.`, `(1)`, `1)`).
    pub step_markers: usize,
    /// Domain names detected in the task, in `DOMAIN_LISTS` order.
    pub domain_signals: Vec<String>,
    /// Whether a delegation phrase (e.g. `"delegate"`, `"ask the "`) appears.
    pub explicit_delegation: bool,
    /// Whether at least two known agents are named in the task.
    pub names_multiple_agents: bool,
    /// Agents whose domains cover every detected domain. Only the capability-matching
    /// tier of `TaskComplexityAnalyzer::analyze` fills this in.
    pub covering_agents: Vec<usize>,
    /// Weighted heuristic score, clamped to `[0.0, 1.0]`.
    pub complexity_score: f32,
}
100
/// Routing-relevant profile of one agent, normally built with `AgentCapability::from_config`.
#[derive(Debug, Clone)]
pub struct AgentCapability {
    /// Agent name (lowercased by `from_config`).
    pub name: String,
    /// Agent description (lowercased by `from_config`).
    pub description_lower: String,
    /// Tool names (lowercased by `from_config`).
    pub tool_names: Vec<String>,
    /// Domains inferred from the description and tool names, e.g. `"code"` or `"database"`.
    pub domains: Vec<String>,
}
109
/// Keywords marking the `"code"` domain.
///
/// Like every domain table below, entries are lowercase and are tested with
/// `contains_keyword`. Single words must sit on ASCII word boundaries. Entries that
/// contain a space (here `"unit test"`) match as plain substrings.
const CODE_KEYWORDS: &[&str] = &[
    "code",
    "implement",
    "refactor",
    "debug",
    "compile",
    "function",
    "class",
    "module",
    "rust",
    "python",
    "javascript",
    "typescript",
    "java",
    "golang",
    "programming",
    "syntax",
    "bug",
    "fix",
    "test",
    // Substring-matched, so it also catches "unit tests" (which "test" alone would miss).
    "unit test",
];

/// Keywords marking the `"research"` domain.
const RESEARCH_KEYWORDS: &[&str] = &[
    "research",
    "investigate",
    "analyze",
    "study",
    "survey",
    "find",
    "search",
    "explore",
    "review",
    "literature",
    "paper",
    "arxiv",
];

/// Keywords marking the `"database"` domain.
const DATABASE_KEYWORDS: &[&str] = &[
    "database",
    "sql",
    "query",
    "table",
    "schema",
    "migration",
    "postgres",
    "mysql",
    "sqlite",
    "mongodb",
    "redis",
    "index",
];

/// Keywords marking the `"frontend"` domain.
const FRONTEND_KEYWORDS: &[&str] = &[
    "frontend",
    "ui",
    "ux",
    "component",
    "react",
    "vue",
    "angular",
    "css",
    "html",
    "layout",
    "responsive",
    "design",
    "button",
    "form",
    "modal",
];

/// Keywords marking the `"backend"` domain.
const BACKEND_KEYWORDS: &[&str] = &[
    "backend",
    "api",
    "endpoint",
    "server",
    "rest",
    "graphql",
    "middleware",
    "authentication",
    "authorization",
    "route",
    "handler",
];

/// Keywords marking the `"devops"` domain.
const DEVOPS_KEYWORDS: &[&str] = &[
    "devops",
    "deploy",
    "docker",
    "kubernetes",
    "ci/cd",
    "pipeline",
    "infrastructure",
    "terraform",
    "ansible",
    "nginx",
    "monitoring",
    "logging",
];

/// Keywords marking the `"writing"` domain.
const WRITING_KEYWORDS: &[&str] = &[
    "write",
    "document",
    "documentation",
    "readme",
    "blog",
    "article",
    "report",
    "summary",
    "copywriting",
    "content",
    "draft",
    "edit text",
];

/// Keywords marking the `"security"` domain.
const SECURITY_KEYWORDS: &[&str] = &[
    "security",
    "vulnerability",
    "audit",
    "penetration",
    "encryption",
    "auth",
    "cors",
    "xss",
    "injection",
    "firewall",
    "certificate",
    "tls",
];

/// Domain name → keyword table.
///
/// Domain detection scans this list front to back, so detected domains are reported in
/// this order.
const DOMAIN_LISTS: &[(&str, &[&str])] = &[
    ("code", CODE_KEYWORDS),
    ("research", RESEARCH_KEYWORDS),
    ("database", DATABASE_KEYWORDS),
    ("frontend", FRONTEND_KEYWORDS),
    ("backend", BACKEND_KEYWORDS),
    ("devops", DEVOPS_KEYWORDS),
    ("writing", WRITING_KEYWORDS),
    ("security", SECURITY_KEYWORDS),
];
252
/// Phrases treated as sequencing markers in a lowercased task.
///
/// `count_step_markers` counts how many distinct entries occur as substrings. Repeating one
/// marker therefore counts once, and `"step 1"` also matches inside `"step 10"`.
const STEP_MARKERS: &[&str] = &[
    // Ordinal openers; the trailing comma narrows matches to sentence-leading use.
    "first,",
    "second,",
    "third,",
    "then,",
    "finally,",
    "next,",
    "after that",
    "step 1",
    "step 2",
    "step 3",
    "step 4",
    "step 5",
];
269
/// Phrases indicating the user wants work delegated to, or coordinated between, agents.
///
/// Entries are lowercase and matched against the lowercased task.
const DELEGATION_PHRASES: &[&str] = &[
    "delegate",
    "have them",
    "coordinate between",
    "coordinate with",
    "team up",
    "work together",
    "collaborate",
    "assign to",
    "hand off",
    "pass to",
    // `<verb> the <agent>` phrasings, e.g. "ask the coder to ...". The trailing space keeps
    // these from matching longer words such as "ask them".
    "ask the ",
    "have the ",
    "use the ",
    "tell the ",
    "instruct the ",
    "let the ",
    "get the ",
];
296
/// Scores strictly below this route to a single agent without capability matching.
const SINGLE_AGENT_THRESHOLD: f32 = 0.30;
/// Scores strictly above this route straight to orchestration. Scores in
/// `[SINGLE_AGENT_THRESHOLD, ORCHESTRATE_THRESHOLD]` fall through to capability matching.
const ORCHESTRATE_THRESHOLD: f32 = 0.55;

/// Added when the task reads like a plain question. This is the only negative weight.
const WEIGHT_SIMPLE_QUESTION: f32 = -0.30;
/// Added when the task has more than 100 words.
const WEIGHT_WORD_COUNT_HIGH: f32 = 0.10;
/// Added when the task has two or more step markers.
const WEIGHT_STEP_MARKERS: f32 = 0.25;
/// Added when a delegation phrase is present.
const WEIGHT_DELEGATION: f32 = 0.30;
/// Added when two or more agents are named in the task.
const WEIGHT_NAMES_AGENTS: f32 = 0.40;
/// Added when three or more domains are detected.
const WEIGHT_DOMAIN_DIVERSITY: f32 = 0.20;
310
311impl AgentCapability {
312 pub fn from_config(name: &str, description: &str, tool_names: &[String]) -> Self {
314 let description_lower = description.to_lowercase();
315 let tool_lower: Vec<String> = tool_names.iter().map(|t| t.to_lowercase()).collect();
316
317 let combined = format!("{} {}", description_lower, tool_lower.join(" "));
319 let mut domains = Vec::new();
320 for &(domain, keywords) in DOMAIN_LISTS {
321 if keywords.iter().any(|kw| contains_keyword(&combined, kw)) {
322 domains.push(domain.to_string());
323 }
324 }
325
326 Self {
327 name: name.to_lowercase(),
328 description_lower,
329 tool_names: tool_lower,
330 domains,
331 }
332 }
333}
334
/// Keyword-based analyzer that scores task complexity and matches tasks to agents.
pub struct TaskComplexityAnalyzer<'a> {
    /// Candidate agents; `agent_index` values in routing decisions index into this slice.
    agents: &'a [AgentCapability],
}
339
340impl<'a> TaskComplexityAnalyzer<'a> {
341 pub fn new(agents: &'a [AgentCapability]) -> Self {
342 Self { agents }
343 }
344
345 pub fn analyze(&self, task: &str) -> (RoutingDecision, ComplexitySignals) {
347 let mut signals = self.heuristic_signals(task);
348
349 if signals.complexity_score < SINGLE_AGENT_THRESHOLD {
357 let agent_index = if signals.domain_signals.is_empty() {
358 0
359 } else {
360 best_covering_agent(&signals.domain_signals, self.agents).unwrap_or(0)
361 };
362 let reason = if agent_index == 0 && signals.domain_signals.is_empty() {
363 "heuristic score below single-agent threshold (no domain signals)"
364 } else if agent_index == 0 {
365 "heuristic score below single-agent threshold (no agent matched detected domains)"
366 } else {
367 "heuristic score below single-agent threshold (matched by domain coverage)"
368 };
369 return (
370 RoutingDecision::SingleAgent {
371 agent_index,
372 reason,
373 },
374 signals,
375 );
376 }
377 if signals.complexity_score > ORCHESTRATE_THRESHOLD {
378 return (
379 RoutingDecision::Orchestrate {
380 reason: "heuristic score above orchestrate threshold",
381 },
382 signals,
383 );
384 }
385
386 let decision = self.capability_match(&signals.domain_signals, &mut signals.covering_agents);
388 (decision, signals)
389 }
390
391 pub fn heuristic_signals(&self, task: &str) -> ComplexitySignals {
393 let task_lower = task.to_lowercase();
394 let words: Vec<&str> = task.split_whitespace().collect();
395 let word_count = words.len();
396
397 let simple_question = is_simple_question(&task_lower, &words);
399
400 let step_markers = count_step_markers(&task_lower);
402 let numbered_items = words.iter().filter(|w| is_numbered_step_marker(w)).count();
406 let total_step_markers = step_markers + numbered_items;
407
408 let explicit_delegation = DELEGATION_PHRASES.iter().any(|p| task_lower.contains(p));
410
411 let domain_signals = detect_domains(&task_lower);
413
414 let names_multiple_agents = if self.agents.len() >= 2 {
416 let matching = self
417 .agents
418 .iter()
419 .filter(|a| task_lower.contains(&a.name))
420 .count();
421 matching >= 2
422 } else {
423 false
424 };
425
426 let mut score: f32 = 0.0;
428 if simple_question {
429 score += WEIGHT_SIMPLE_QUESTION;
430 }
431 if word_count > 100 {
432 score += WEIGHT_WORD_COUNT_HIGH;
433 }
434 if total_step_markers >= 2 {
435 score += WEIGHT_STEP_MARKERS;
436 }
437 if explicit_delegation {
438 score += WEIGHT_DELEGATION;
439 }
440 if names_multiple_agents {
441 score += WEIGHT_NAMES_AGENTS;
442 }
443 if domain_signals.len() >= 3 {
444 score += WEIGHT_DOMAIN_DIVERSITY;
445 }
446
447 score = score.clamp(0.0, 1.0);
449
450 ComplexitySignals {
451 word_count,
452 step_markers: total_step_markers,
453 domain_signals,
454 explicit_delegation,
455 names_multiple_agents,
456 covering_agents: Vec::new(),
457 complexity_score: score,
458 }
459 }
460
461 fn capability_match(
463 &self,
464 task_domains: &[String],
465 covering_agents: &mut Vec<usize>,
466 ) -> RoutingDecision {
467 if task_domains.is_empty() {
468 return RoutingDecision::SingleAgent {
470 agent_index: 0,
471 reason: "no domains detected, defaulting to single agent",
472 };
473 }
474
475 for (i, agent) in self.agents.iter().enumerate() {
476 let covers_all = task_domains.iter().all(|d| agent.domains.contains(d));
477 if covers_all {
478 covering_agents.push(i);
479 }
480 }
481
482 if covering_agents.is_empty() {
483 RoutingDecision::Orchestrate {
484 reason: "no single agent covers all detected domains",
485 }
486 } else {
487 RoutingDecision::SingleAgent {
489 agent_index: covering_agents[0],
490 reason: "single agent covers all detected domains",
491 }
492 }
493 }
494}
495
496fn is_simple_question(task_lower: &str, words: &[&str]) -> bool {
498 let question_starters = [
499 "what", "how", "why", "explain", "describe", "who", "when", "where",
500 ];
501 let starts_with_question = words
505 .first()
506 .is_some_and(|w| question_starters.iter().any(|q| w.starts_with(q)));
507 let has_step_markers = count_step_markers(task_lower) >= 2;
508 starts_with_question && !has_step_markers
509}
510
511fn count_step_markers(task_lower: &str) -> usize {
513 STEP_MARKERS
514 .iter()
515 .filter(|marker| task_lower.contains(*marker))
516 .count()
517}
518
/// Returns `true` if `word` is a numbered-list token: `N.`, `(N)` or `N)`, where `N` is one
/// or more ASCII digits.
///
/// Trailing `;`, `,` and `:` are ignored first, so `"(1);"` and `"2.,"` also qualify.
/// Tokens without a number (`"()"`, `"."`) or without a delimiter (`"1"`) do not.
fn is_numbered_step_marker(word: &str) -> bool {
    let is_number = |s: &str| !s.is_empty() && s.bytes().all(|b| b.is_ascii_digit());
    let token = word.trim_end_matches([';', ',', ':']);

    if token.strip_suffix('.').is_some_and(is_number) {
        return true;
    }
    // `(N)` and `N)` share the closing paren; an optional opening paren is dropped.
    token
        .strip_suffix(')')
        .is_some_and(|body| is_number(body.strip_prefix('(').unwrap_or(body)))
}
551
/// Returns `true` if `keyword` occurs in `text` as a whole word.
///
/// Single-word keywords need a non-ASCII-alphanumeric byte (or the string edge) on both
/// sides, so `"ui"` matches `"fix the ui"` but not `"builds"`. Underscores and punctuation
/// count as boundaries (`"docker_build"` contains `"docker"`). Keywords containing a space
/// are matched as plain substrings.
fn contains_keyword(text: &str, keyword: &str) -> bool {
    if keyword.contains(' ') {
        return text.contains(keyword);
    }

    let bytes = text.as_bytes();
    let is_word_byte = |idx: usize| bytes.get(idx).is_some_and(u8::is_ascii_alphanumeric);

    text.match_indices(keyword).any(|(start, matched)| {
        let end = start + matched.len();
        let clear_before = start == 0 || !is_word_byte(start - 1);
        let clear_after = !is_word_byte(end);
        clear_before && clear_after
    })
}
570
571fn best_covering_agent(task_domains: &[String], agents: &[AgentCapability]) -> Option<usize> {
582 if task_domains.is_empty() || agents.is_empty() {
583 return None;
584 }
585 let mut best: Option<(usize, usize)> = None;
586 for (i, agent) in agents.iter().enumerate() {
587 let count = task_domains
588 .iter()
589 .filter(|d| agent.domains.contains(d))
590 .count();
591 if count == 0 {
592 continue;
593 }
594 match best {
595 Some((_, c)) if c >= count => {}
596 _ => best = Some((i, count)),
597 }
598 }
599 best.map(|(i, _)| i)
600}
601
602fn detect_domains(task_lower: &str) -> Vec<String> {
604 let mut domains = Vec::new();
605 for &(domain, keywords) in DOMAIN_LISTS {
606 if keywords.iter().any(|kw| contains_keyword(task_lower, kw)) {
607 domains.push(domain.to_string());
608 }
609 }
610 domains
611}
612
613pub fn should_escalate(error: &Error, events: &[AgentEvent]) -> bool {
615 let inner = match error {
617 Error::WithPartialUsage { source, .. } => source.as_ref(),
618 other => other,
619 };
620
621 if matches!(inner, Error::MaxTurnsExceeded(_) | Error::RunTimeout(_)) {
623 return true;
624 }
625
626 if events
628 .iter()
629 .any(|e| matches!(e, AgentEvent::DoomLoopDetected { .. }))
630 {
631 return true;
632 }
633
634 let compaction_count = events
636 .iter()
637 .filter(|e| matches!(e, AgentEvent::AutoCompactionTriggered { .. }))
638 .count();
639 if compaction_count >= 2 {
640 return true;
641 }
642
643 false
644}
645
646pub fn resolve_routing_mode(config_mode: RoutingMode) -> RoutingMode {
648 match std::env::var("HEARTBIT_ROUTING").ok() {
649 Some(val) => match val.to_lowercase().as_str() {
650 "auto" => RoutingMode::Auto,
651 "always_orchestrate" => RoutingMode::AlwaysOrchestrate,
652 "single_agent" => RoutingMode::SingleAgent,
653 _ => {
654 tracing::warn!(
655 value = %val,
656 "unknown HEARTBIT_ROUTING value, falling back to config"
657 );
658 config_mode
659 }
660 },
661 None => config_mode,
662 }
663}
664
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-agent roster shared by most tests: index 0 is `coder`, index 1 is `researcher`.
    fn make_agents() -> Vec<AgentCapability> {
        vec![
            AgentCapability::from_config(
                "coder",
                "A code implementation agent that writes and debugs software",
                &["bash".into(), "read_file".into(), "write_file".into()],
            ),
            AgentCapability::from_config(
                "researcher",
                "A research agent that investigates and analyzes topics",
                &["web_search".into(), "read_file".into()],
            ),
        ]
    }

    // --- AgentCapability::from_config -------------------------------------------------

    #[test]
    fn agent_capability_extracts_domains_from_description() {
        let cap = AgentCapability::from_config(
            "fullstack",
            "Handles frontend React components and backend API endpoints with database queries",
            &[],
        );
        assert!(cap.domains.contains(&"frontend".to_string()));
        assert!(cap.domains.contains(&"backend".to_string()));
        assert!(cap.domains.contains(&"database".to_string()));
    }

    #[test]
    fn agent_capability_extracts_domains_from_tools() {
        // Domain keywords embedded in tool names (`docker_build`, `deploy_k8s`) count too.
        let cap = AgentCapability::from_config(
            "devops-agent",
            "Manages infrastructure",
            &["docker_build".into(), "deploy_k8s".into()],
        );
        assert!(cap.domains.contains(&"devops".to_string()));
    }

    // --- Heuristic scoring ------------------------------------------------------------

    #[test]
    fn simple_question_scores_below_threshold() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let (decision, signals) = analyzer.analyze("What is the capital of France?");
        assert!(
            signals.complexity_score < SINGLE_AGENT_THRESHOLD,
            "score {} should be < {}",
            signals.complexity_score,
            SINGLE_AGENT_THRESHOLD
        );
        assert!(matches!(decision, RoutingDecision::SingleAgent { .. }));
    }

    #[test]
    fn multi_step_multi_domain_routes_to_orchestrate() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "First, research the best database schema for user authentication. \
                    Then, implement the backend API endpoints in Rust. \
                    Finally, write frontend React components for the login form.";
        let (decision, signals) = analyzer.analyze(task);
        assert!(
            signals.step_markers >= 2,
            "step_markers: {}",
            signals.step_markers
        );
        assert!(
            signals.domain_signals.len() >= 3,
            "domains: {:?}",
            signals.domain_signals
        );
        // Neither agent covers every detected domain, so orchestration is expected whether
        // the score lands in the capability-matching tier or above it.
        assert!(
            matches!(decision, RoutingDecision::Orchestrate { .. }),
            "decision: {decision:?}, score: {}",
            signals.complexity_score
        );
    }

    #[test]
    fn delegation_language_boosts_score() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "Delegate the research task to the researcher and coordinate with the coder";
        let signals = analyzer.heuristic_signals(task);
        assert!(signals.explicit_delegation);
        // Delegation plus naming both agents clears the orchestrate threshold on its own.
        assert!(
            signals.complexity_score > ORCHESTRATE_THRESHOLD,
            "score: {}",
            signals.complexity_score
        );
    }

    #[test]
    fn naming_multiple_agents_boosts_score() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "Have coder write the code and researcher find the documentation";
        let signals = analyzer.heuristic_signals(task);
        assert!(signals.names_multiple_agents);
        assert!(
            signals.complexity_score >= WEIGHT_NAMES_AGENTS,
            "score: {}",
            signals.complexity_score
        );
    }

    #[test]
    fn word_count_above_100_adds_weight() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = (0..110)
            // Synthetic words that match no keyword, so only the word-count weight applies.
            .map(|i| format!("word{i}"))
            .collect::<Vec<_>>()
            .join(" ");
        let signals = analyzer.heuristic_signals(&task);
        assert!(signals.word_count > 100);
        assert!(
            signals.complexity_score >= WEIGHT_WORD_COUNT_HIGH,
            "score: {}",
            signals.complexity_score
        );
    }

    #[test]
    fn numbered_list_detected_as_step_markers() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "1. Set up the database. 2. Write the API. 3. Test everything.";
        let signals = analyzer.heuristic_signals(task);
        assert!(
            signals.step_markers >= 2,
            "step_markers: {}",
            signals.step_markers
        );
    }

    #[test]
    fn score_clamped_to_zero_one() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);

        // Lower bound: a trivial question must never produce a negative score.
        let signals = analyzer.heuristic_signals("What is 2+2?");
        assert!(
            signals.complexity_score >= 0.0,
            "score: {}",
            signals.complexity_score
        );

        // Upper bound: a task that trips every positive weight must stay <= 1.0.
        let task = "First, delegate to coder and researcher. Then step 1 deploy the docker \
                    kubernetes infrastructure with database schema, frontend React components, \
                    backend API endpoints, security audit, research papers, and write documentation. \
                    Finally, coordinate the team. ".repeat(5);
        let signals = analyzer.heuristic_signals(&task);
        assert!(
            signals.complexity_score <= 1.0,
            "score: {}",
            signals.complexity_score
        );
    }

    // --- Capability matching ----------------------------------------------------------

    #[test]
    fn one_agent_covers_all_domains_routes_to_single() {
        let agents = vec![AgentCapability::from_config(
            "fullstack",
            "Handles code implementation, database queries, and backend API endpoints",
            &[],
        )];
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "Update the database query and fix the backend API endpoint bug";
        let (decision, signals) = analyzer.analyze(task);
        match &decision {
            RoutingDecision::SingleAgent { agent_index, .. } => {
                assert_eq!(*agent_index, 0);
            }
            // Orchestration is only acceptable if the heuristic score alone forced it.
            RoutingDecision::Orchestrate { reason } => {
                assert!(
                    signals.complexity_score > ORCHESTRATE_THRESHOLD,
                    "unexpected orchestrate: {reason}"
                );
            }
        }
    }

    #[test]
    fn split_coverage_routes_to_orchestrate() {
        let agents = vec![
            AgentCapability::from_config(
                "frontend-dev",
                "Builds frontend React components and CSS layouts",
                &[],
            ),
            AgentCapability::from_config(
                "backend-dev",
                "Builds backend API endpoints and database schemas",
                &[],
            ),
        ];
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "Build a React form that submits to a new backend API endpoint and stores data in the database";
        // Call the capability-matching tier directly so the heuristic score cannot bypass it.
        let mut signals = analyzer.heuristic_signals(task);
        let mut covering = Vec::new();
        let decision = analyzer.capability_match(&signals.domain_signals, &mut covering);
        signals.covering_agents = covering;
        assert!(
            matches!(decision, RoutingDecision::Orchestrate { .. }),
            "expected Orchestrate, got: {decision:?}"
        );
        assert!(signals.covering_agents.is_empty());
    }

    #[test]
    fn no_domains_defaults_to_single_agent() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let mut covering = Vec::new();
        let decision = analyzer.capability_match(&[], &mut covering);
        assert!(matches!(
            decision,
            RoutingDecision::SingleAgent { agent_index: 0, .. }
        ));
    }

    // --- should_escalate --------------------------------------------------------------

    #[test]
    fn escalate_on_max_turns_exceeded() {
        let err = Error::MaxTurnsExceeded(10);
        assert!(should_escalate(&err, &[]));
    }

    #[test]
    fn escalate_on_max_turns_wrapped_in_partial_usage() {
        use crate::llm::types::TokenUsage;
        let err = Error::MaxTurnsExceeded(10).with_partial_usage(TokenUsage::default());
        assert!(should_escalate(&err, &[]));
    }

    #[test]
    fn escalate_on_run_timeout() {
        let err = Error::RunTimeout(std::time::Duration::from_secs(60));
        assert!(should_escalate(&err, &[]));
    }

    #[test]
    fn escalate_on_doom_loop_event() {
        let events = vec![AgentEvent::DoomLoopDetected {
            agent: "a".into(),
            turn: 5,
            consecutive_count: 3,
            tool_names: vec!["web_search".into()],
        }];
        let err = Error::Agent("generic error".into());
        assert!(should_escalate(&err, &events));
    }

    #[test]
    fn escalate_on_two_compactions() {
        let events = vec![
            AgentEvent::AutoCompactionTriggered {
                agent: "a".into(),
                turn: 2,
                success: true,
                usage: Default::default(),
            },
            AgentEvent::AutoCompactionTriggered {
                agent: "a".into(),
                turn: 5,
                success: true,
                usage: Default::default(),
            },
        ];
        let err = Error::Agent("context overflow".into());
        assert!(should_escalate(&err, &events));
    }

    #[test]
    fn no_escalation_on_single_compaction() {
        let events = vec![AgentEvent::AutoCompactionTriggered {
            agent: "a".into(),
            turn: 2,
            success: true,
            usage: Default::default(),
        }];
        let err = Error::Agent("some error".into());
        assert!(!should_escalate(&err, &events));
    }

    #[test]
    fn no_escalation_on_normal_error() {
        let err = Error::Agent("tool failed".into());
        assert!(!should_escalate(&err, &[]));
    }

    // --- RoutingMode ------------------------------------------------------------------

    #[test]
    fn routing_mode_default_is_auto() {
        assert_eq!(RoutingMode::default(), RoutingMode::Auto);
    }

    #[test]
    fn routing_mode_roundtrips_json() {
        for mode in [
            RoutingMode::Auto,
            RoutingMode::AlwaysOrchestrate,
            RoutingMode::SingleAgent,
        ] {
            let json = serde_json::to_string(&mode).unwrap();
            let back: RoutingMode = serde_json::from_str(&json).unwrap();
            assert_eq!(mode, back, "failed for {json}");
        }
    }

    #[test]
    fn routing_mode_deserializes_from_toml_strings() {
        #[derive(Deserialize)]
        struct W {
            mode: RoutingMode,
        }
        let w: W = toml::from_str(r#"mode = "auto""#).unwrap();
        assert_eq!(w.mode, RoutingMode::Auto);
        let w: W = toml::from_str(r#"mode = "always_orchestrate""#).unwrap();
        assert_eq!(w.mode, RoutingMode::AlwaysOrchestrate);
        let w: W = toml::from_str(r#"mode = "single_agent""#).unwrap();
        assert_eq!(w.mode, RoutingMode::SingleAgent);
    }

    // --- End-to-end analyze -----------------------------------------------------------

    #[test]
    fn analyze_simple_task_two_agents_routes_single() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let (decision, _) = analyzer.analyze("How do I parse JSON in Rust?");
        assert!(
            matches!(decision, RoutingDecision::SingleAgent { .. }),
            "got: {decision:?}"
        );
    }

    #[test]
    fn analyze_complex_multi_domain_task_routes_orchestrate() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "First, research the latest security vulnerabilities. \
                    Then, implement a fix in the backend API code. \
                    Finally, deploy the fix using Docker and update the documentation.";
        let (decision, signals) = analyzer.analyze(task);
        assert!(
            signals.complexity_score > ORCHESTRATE_THRESHOLD
                || matches!(decision, RoutingDecision::Orchestrate { .. }),
            "decision: {decision:?}, score: {}",
            signals.complexity_score
        );
    }

    #[test]
    fn analyze_delegation_with_agent_names_routes_orchestrate() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task =
            "Delegate to coder to implement the feature and have researcher find best practices";
        let (decision, signals) = analyzer.analyze(task);
        assert!(
            matches!(decision, RoutingDecision::Orchestrate { .. }),
            "decision: {decision:?}, score: {}",
            signals.complexity_score
        );
    }

    #[test]
    fn resolve_routing_mode_uses_config_when_no_env() {
        // SAFETY: no other test in this module reads or writes HEARTBIT_ROUTING.
        // NOTE(review): environment mutation can still race with any other thread in the
        // test binary that reads the environment concurrently.
        unsafe {
            std::env::remove_var("HEARTBIT_ROUTING");
        }
        assert_eq!(
            resolve_routing_mode(RoutingMode::AlwaysOrchestrate),
            RoutingMode::AlwaysOrchestrate
        );
    }

    // --- RoutingStrategy trait --------------------------------------------------------

    #[test]
    fn keyword_routing_strategy_routes_simple_to_single() {
        let agents = make_agents();
        let strategy = KeywordRoutingStrategy;
        let (decision, _) = strategy.route("What is Rust?", &agents);
        assert!(
            matches!(decision, RoutingDecision::SingleAgent { .. }),
            "got: {decision:?}"
        );
    }

    #[test]
    fn keyword_routing_strategy_routes_complex_to_orchestrate() {
        let agents = make_agents();
        let strategy = KeywordRoutingStrategy;
        let task =
            "Delegate to coder to implement the feature and have researcher find best practices";
        let (decision, _) = strategy.route(task, &agents);
        assert!(
            matches!(decision, RoutingDecision::Orchestrate { .. }),
            "got: {decision:?}"
        );
    }

    #[test]
    fn custom_routing_strategy() {
        struct AlwaysOrchestrate;
        impl RoutingStrategy for AlwaysOrchestrate {
            fn route(
                &self,
                _task: &str,
                _agents: &[AgentCapability],
            ) -> (RoutingDecision, ComplexitySignals) {
                (
                    RoutingDecision::Orchestrate {
                        reason: "custom: always orchestrate",
                    },
                    ComplexitySignals::default(),
                )
            }
        }

        let agents = make_agents();
        let strategy = AlwaysOrchestrate;
        let (decision, _) = strategy.route("What is 2+2?", &agents);
        assert!(
            matches!(decision, RoutingDecision::Orchestrate { .. }),
            "got: {decision:?}"
        );
    }

    #[test]
    fn custom_routing_strategy_with_domain_matching() {
        struct PricingRouter;
        impl RoutingStrategy for PricingRouter {
            fn route(
                &self,
                task: &str,
                agents: &[AgentCapability],
            ) -> (RoutingDecision, ComplexitySignals) {
                let task_lower = task.to_lowercase();
                if task_lower.contains("pricing") || task_lower.contains("quote") {
                    let idx = agents.iter().position(|a| a.name == "quoter").unwrap_or(0);
                    return (
                        RoutingDecision::SingleAgent {
                            agent_index: idx,
                            reason: "pricing domain detected",
                        },
                        ComplexitySignals::default(),
                    );
                }
                // Anything else falls back to the built-in keyword strategy.
                KeywordRoutingStrategy.route(task, agents)
            }
        }

        let agents = vec![
            AgentCapability::from_config("miner", "Finds sales leads", &[]),
            AgentCapability::from_config("quoter", "Generates pricing quotes", &[]),
        ];
        let strategy = PricingRouter;

        let (decision, _) = strategy.route("Generate a pricing quote for the client", &agents);
        // The custom rule must select `quoter` (index 1).
        match decision {
            RoutingDecision::SingleAgent { agent_index, .. } => assert_eq!(agent_index, 1),
            other => panic!("expected SingleAgent, got: {other:?}"),
        }

        let (decision, _) = strategy.route("What is 2+2?", &agents);
        // Non-pricing task: delegated to KeywordRoutingStrategy, which routes it to one agent.
        assert!(matches!(decision, RoutingDecision::SingleAgent { .. }));
    }

    #[test]
    fn routing_strategy_dyn_dispatch() {
        let strategy: std::sync::Arc<dyn RoutingStrategy> =
            std::sync::Arc::new(KeywordRoutingStrategy);
        let agents = make_agents();
        let (decision, _) = strategy.route("What is Rust?", &agents);
        assert!(matches!(decision, RoutingDecision::SingleAgent { .. }));
    }

    #[test]
    fn missing_routing_field_defaults_to_auto() {
        #[derive(Deserialize)]
        struct TestConfig {
            #[serde(default)]
            routing: RoutingMode,
        }
        let config: TestConfig = toml::from_str("").unwrap();
        assert_eq!(config.routing, RoutingMode::Auto);
    }

    // --- Keyword helpers --------------------------------------------------------------

    #[test]
    fn contains_keyword_word_boundary() {
        // "ui" appears inside "builds" but not as a word.
        assert!(!contains_keyword("builds backend api", "ui"));
        // Standalone mid-sentence.
        assert!(contains_keyword("the ui is broken", "ui"));
        // At the start of the text.
        assert!(contains_keyword("ui components", "ui"));
        // At the end of the text.
        assert!(contains_keyword("fix the ui", "ui"));
        // "api" as a word versus inside "capital".
        assert!(contains_keyword("the api endpoint", "api"));
        assert!(!contains_keyword("the capital city", "api"));
    }

    #[test]
    fn contains_keyword_multi_word() {
        assert!(contains_keyword("run the unit test suite", "unit test"));
        assert!(!contains_keyword("run the unittest suite", "unit test"));
    }

    #[test]
    fn contains_keyword_adjacent_to_punctuation() {
        // Punctuation on either side counts as a word boundary.
        assert!(contains_keyword("fix the api.", "api"));
        assert!(contains_keyword("(api) endpoint", "api"));
        assert!(contains_keyword("api/rest", "api"));
    }

    #[test]
    fn detect_domains_finds_multiple() {
        let domains = detect_domains("implement the api endpoint and write database migration");
        assert!(domains.contains(&"code".to_string()));
        assert!(domains.contains(&"backend".to_string()));
        assert!(domains.contains(&"database".to_string()));
    }

    #[test]
    fn detect_domains_empty_for_generic_text() {
        let domains = detect_domains("hello world how are you");
        assert!(domains.is_empty());
    }

    // --- Edge cases -------------------------------------------------------------------

    #[test]
    fn empty_task_routes_single_agent() {
        let agents = make_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let (decision, signals) = analyzer.analyze("");
        assert_eq!(signals.complexity_score, 0.0);
        assert!(matches!(decision, RoutingDecision::SingleAgent { .. }));
    }

    #[test]
    fn single_agent_list_always_routes_single() {
        let agents = vec![AgentCapability::from_config("solo", "Does everything", &[])];
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "Delegate the complex multi-step task that involves code, database, frontend, backend, security, devops, research, and writing";
        let (decision, signals) = analyzer.analyze(task);
        // With a single agent, "names multiple agents" can never fire.
        assert!(!signals.names_multiple_agents);
        // NOTE(review): despite the test name, this match accepts both variants, so the
        // routing outcome is not actually asserted here.
        match decision {
            RoutingDecision::SingleAgent { .. } | RoutingDecision::Orchestrate { .. } => {}
        }
    }

    /// Roster for the `gh9_*` regression tests (presumably GitHub issue #9): index 0 is
    /// `researcher`, index 1 is `coder`. The order is reversed relative to `make_agents`, so
    /// falling back to index 0 is distinguishable from a real domain match.
    fn issue9_agents() -> Vec<AgentCapability> {
        vec![
            AgentCapability::from_config(
                "researcher",
                "Investigates a topic, gathers facts.",
                &["web_search".into()],
            ),
            AgentCapability::from_config(
                "coder",
                "Writes and refactors Rust code.",
                &["read".into(), "write".into(), "edit".into()],
            ),
        ]
    }

    #[test]
    fn gh9_tier1_routes_to_best_covering_agent_when_domains_present() {
        let agents = issue9_agents();
        // A code task must not silently land on agent 0 (the researcher).
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "Ask the coder to refactor the parse_args function and write tests.";
        let (decision, signals) = analyzer.analyze(task);
        assert!(
            signals.domain_signals.contains(&"code".to_string()),
            "expected `code` in domain_signals, got: {:?}",
            signals.domain_signals
        );
        match decision {
            RoutingDecision::SingleAgent { agent_index, .. } => assert_eq!(
                agent_index, 1,
                "Tier 1 should route to coder (idx 1), got idx {agent_index}",
            ),
            // Orchestration is tolerated; only a single-agent pick is checked for the index.
            RoutingDecision::Orchestrate { .. } => {}
        }
    }

    #[test]
    fn gh9_parenthesised_numbers_count_as_step_markers() {
        let agents = make_agents();
        // `(N)` list items, with trailing `;` or `.`, are counted as numbered steps.
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "We need to: (1) investigate Rust async; (2) compare tokio vs smol; \
                    (3) write a benchmark; (4) implement it; (5) review the code; \
                    (6) summarize findings.";
        let signals = analyzer.heuristic_signals(task);
        assert!(
            signals.step_markers >= 4,
            "expected step_markers >= 4 for `(1) (2) ... (6)` pattern, got {}",
            signals.step_markers
        );
    }

    #[test]
    fn gh9_right_paren_numbers_count_as_step_markers() {
        let agents = make_agents();
        // `N)` list items are counted as numbered steps.
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let task = "Plan: 1) gather data; 2) draft outline; 3) review; 4) publish.";
        let signals = analyzer.heuristic_signals(task);
        assert!(
            signals.step_markers >= 3,
            "expected step_markers >= 3 for `1) 2) 3) 4)` pattern, got {}",
            signals.step_markers
        );
    }

    #[test]
    fn gh9_ask_the_phrasing_triggers_delegation() {
        let agents = issue9_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let signals =
            analyzer.heuristic_signals("Ask the coder to refactor the parse_args function.");
        assert!(
            signals.explicit_delegation,
            "`ask the X to ...` should trigger explicit_delegation",
        );
    }

    #[test]
    fn gh9_have_the_phrasing_triggers_delegation() {
        let agents = issue9_agents();
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let signals = analyzer.heuristic_signals("Have the researcher gather data on Rust async.");
        assert!(
            signals.explicit_delegation,
            "`have the X ...` should trigger explicit_delegation",
        );
    }

    #[test]
    fn gh9_no_domains_still_falls_back_to_agent_zero() {
        let agents = issue9_agents();
        // Without domain signals there is nothing to match on, so agent 0 is the
        // expected fallback.
        let analyzer = TaskComplexityAnalyzer::new(&agents);
        let (decision, signals) = analyzer.analyze("Hello, are you there?");
        assert!(
            signals.domain_signals.is_empty(),
            "domains should be empty for plain greeting, got: {:?}",
            signals.domain_signals
        );
        assert!(matches!(
            decision,
            RoutingDecision::SingleAgent { agent_index: 0, .. }
        ));
    }
}