use std::collections::HashMap;
use async_trait::async_trait;
use super::{BackendType, SpawnConfig};
/// Strategy for choosing which backend should handle a spawn request.
///
/// Implementors inspect the [`SpawnConfig`] and pick one of the `available`
/// backends, or return `None` to make no choice. A `None` lets a
/// [`ChainRouter`] move on to its next router.
#[async_trait]
pub trait BackendRouter: Send + Sync {
    /// Returns the backend to use for `config`, or `None` if this router does
    /// not choose. The routers in this module only ever return an element of
    /// `available`.
    async fn route(
        &self,
        config: &SpawnConfig,
        available: &[BackendType],
    ) -> Option<BackendType>;
}
/// Routes by matching keywords against the lower-cased prompt.
///
/// Rules are tried in insertion order. The first rule whose keyword matches and
/// whose backend is available wins. Otherwise `default` is used if it is
/// available, else the first available backend.
#[derive(Debug)]
pub struct KeywordRouter {
    // `(keyword, backend)` pairs; keywords are stored lower-cased by the builders.
    rules: Vec<(String, BackendType)>,
    // Fallback backend when no rule selects an available backend.
    default: BackendType,
    // `true`: keywords must match whole words; `false`: any substring match counts.
    use_word_boundary: bool,
}
impl KeywordRouter {
pub fn new(default: BackendType) -> Self {
Self {
rules: Vec::new(),
default,
use_word_boundary: false,
}
}
pub fn word_boundary(mut self, enable: bool) -> Self {
self.use_word_boundary = enable;
self
}
pub fn rule(mut self, keyword: impl Into<String>, backend: BackendType) -> Self {
self.rules.push((keyword.into().to_lowercase(), backend));
self
}
pub fn rules(mut self, rules: impl IntoIterator<Item = (String, BackendType)>) -> Self {
for (kw, bt) in rules {
self.rules.push((kw.to_lowercase(), bt));
}
self
}
}
/// Returns `true` if `word` occurs in `text` as a whole word.
///
/// A candidate match counts as a whole word when the character immediately
/// before it (if any) and the character immediately after it (if any) are not
/// alphanumeric. "Alphanumeric" uses Unicode rules, so accented letters such
/// as `é` belong to a word. Whitespace, punctuation, hyphens and underscores
/// act as boundaries.
///
/// An empty `word` never matches.
fn contains_word(text: &str, word: &str) -> bool {
    if word.is_empty() || text.len() < word.len() {
        return false;
    }
    // Advance one character past each rejected candidate, not a whole word, so
    // overlapping candidates (e.g. "a-a" inside "xa-a-a") are still examined.
    // The first char of `word` is also the char at each candidate position,
    // so `start` always stays on a char boundary.
    let step = word.chars().next().map_or(1, char::len_utf8);
    let mut start = 0;
    while let Some(offset) = text[start..].find(word) {
        let begin = start + offset;
        let end = begin + word.len();
        // Inspect whole chars, not bytes: a UTF-8 continuation byte is never
        // ASCII-alphanumeric, which used to make letters like `é` look like boundaries.
        let before_ok = text[..begin]
            .chars()
            .next_back()
            .map_or(true, |c| !c.is_alphanumeric());
        let after_ok = text[end..]
            .chars()
            .next()
            .map_or(true, |c| !c.is_alphanumeric());
        if before_ok && after_ok {
            return true;
        }
        start = begin + step;
    }
    false
}
#[async_trait]
impl BackendRouter for KeywordRouter {
    /// Returns the backend of the first matching rule whose backend is
    /// available. Otherwise returns `default` if it is available, otherwise
    /// the first available backend. Returns `None` only when `available` is empty.
    async fn route(
        &self,
        config: &SpawnConfig,
        available: &[BackendType],
    ) -> Option<BackendType> {
        // Also serves as the empty-`available` guard.
        let first_available = available.first().copied()?;

        let prompt = config.prompt.to_lowercase();
        let keyword_hit = |keyword: &str| -> bool {
            if self.use_word_boundary {
                contains_word(&prompt, keyword)
            } else {
                prompt.contains(keyword)
            }
        };

        let by_rule = self
            .rules
            .iter()
            .find(|(keyword, backend)| keyword_hit(keyword.as_str()) && available.contains(backend))
            .map(|&(_, backend)| backend);
        if let Some(backend) = by_rule {
            return Some(backend);
        }

        let fallback = if available.contains(&self.default) {
            self.default
        } else {
            first_available
        };
        Some(fallback)
    }
}
/// Static description of what a backend supports and how it ranks on cost and latency.
#[derive(Debug, Clone)]
pub struct BackendCapability {
    // Whether multi-turn sessions are supported; consulted when a router requires multi-turn.
    pub multi_turn: bool,
    // Whether streaming output is supported.
    // NOTE(review): no router in this file reads this field — confirm it is used elsewhere.
    pub streaming: bool,
    // Relative cost tier; `CapabilityRouter` minimises its weighted score, so lower is preferred.
    pub cost_tier: u8,
    // Relative latency tier; `CapabilityRouter` minimises its weighted score, so lower is preferred.
    pub latency_tier: u8,
}
impl Default for BackendCapability {
fn default() -> Self {
Self {
multi_turn: true,
streaming: true,
cost_tier: 2,
latency_tier: 2,
}
}
}
impl BackendCapability {
    /// Built-in capability table for the known backends.
    ///
    /// | backend      | multi-turn | streaming | cost | latency |
    /// |--------------|------------|-----------|------|---------|
    /// | `ClaudeCode` | yes        | yes       | 3    | 2       |
    /// | `Codex`      | yes        | yes       | 2    | 2       |
    /// | `GeminiCli`  | no         | yes       | 1    | 3       |
    pub fn defaults() -> HashMap<BackendType, BackendCapability> {
        // Every built-in backend streams; only multi-turn support and the tiers differ.
        let profile = |multi_turn, cost_tier, latency_tier| BackendCapability {
            multi_turn,
            streaming: true,
            cost_tier,
            latency_tier,
        };
        HashMap::from([
            (BackendType::ClaudeCode, profile(true, 3, 2)),
            (BackendType::Codex, profile(true, 2, 2)),
            (BackendType::GeminiCli, profile(false, 1, 3)),
        ])
    }
}
/// Picks the eligible backend with the lowest weighted cost/latency score.
///
/// Score = `cost_weight * cost_tier + (1 - cost_weight) * latency_tier`. On a
/// tie, the earliest backend in `available` wins.
#[derive(Debug)]
pub struct CapabilityRouter {
    // Per-backend capabilities; a backend missing here scores as cost tier 2 / latency tier 2.
    capabilities: HashMap<BackendType, BackendCapability>,
    // If `true`, only backends whose entry has `multi_turn == true` are eligible
    // (backends without an entry are then excluded).
    require_multi_turn: bool,
    // Weight applied to cost; latency receives `1.0 - cost_weight`. Clamped by `cost_weight()`.
    cost_weight: f32,
}
impl Default for CapabilityRouter {
fn default() -> Self {
Self::new()
}
}
impl CapabilityRouter {
    /// Creates a router seeded with [`BackendCapability::defaults`].
    ///
    /// It starts with no multi-turn requirement and a cost weight of `0.5`,
    /// so cost and latency count equally.
    pub fn new() -> Self {
        Self {
            capabilities: BackendCapability::defaults(),
            require_multi_turn: false,
            cost_weight: 0.5,
        }
    }

    /// When `require` is `true`, only backends whose capability entry has
    /// `multi_turn` set are eligible. Backends with no entry are excluded too.
    pub fn require_multi_turn(mut self, require: bool) -> Self {
        self.require_multi_turn = require;
        self
    }

    /// Sets how strongly cost (versus latency) influences the score.
    ///
    /// `weight` is clamped to `[0.0, 1.0]`, and latency receives `1.0 - weight`.
    /// A NaN `weight` is ignored and the current weight is kept. A NaN weight
    /// would make every score NaN, so all candidates would compare equal and
    /// the first eligible backend would always win.
    pub fn cost_weight(mut self, weight: f32) -> Self {
        // `f32::clamp` passes NaN through unchanged, so it must be filtered explicitly.
        if !weight.is_nan() {
            self.cost_weight = weight.clamp(0.0, 1.0);
        }
        self
    }

    /// Registers or replaces the capability entry for `backend`.
    pub fn with_capability(mut self, backend: BackendType, cap: BackendCapability) -> Self {
        self.capabilities.insert(backend, cap);
        self
    }
}
#[async_trait]
impl BackendRouter for CapabilityRouter {
    /// Returns the eligible backend with the lowest weighted score, keeping the
    /// earliest one on ties. Returns `None` when no backend is eligible
    /// (including when `available` is empty).
    async fn route(
        &self,
        _config: &SpawnConfig,
        available: &[BackendType],
    ) -> Option<BackendType> {
        let cost_w = self.cost_weight;
        let latency_w = 1.0 - cost_w;

        // Backends without a capability entry are scored as tier 2 / tier 2.
        let score_of = |backend: &BackendType| -> f32 {
            match self.capabilities.get(backend) {
                Some(cap) => cost_w * cap.cost_tier as f32 + latency_w * cap.latency_tier as f32,
                None => cost_w * 2.0 + latency_w * 2.0,
            }
        };

        // Without a multi-turn requirement everything is eligible; with one, only
        // backends whose entry explicitly supports multi-turn qualify.
        let eligible = |backend: &&BackendType| -> bool {
            if !self.require_multi_turn {
                return true;
            }
            self.capabilities
                .get(*backend)
                .map_or(false, |cap| cap.multi_turn)
        };

        available
            .iter()
            .filter(eligible)
            .min_by(|lhs, rhs| {
                score_of(*lhs)
                    .partial_cmp(&score_of(*rhs))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .copied()
    }
}
/// Consults a sequence of routers and returns the first `Some` they produce.
pub struct ChainRouter {
    // Routers in push order; boxed so different router types can be mixed.
    routers: Vec<Box<dyn BackendRouter>>,
}
impl std::fmt::Debug for ChainRouter {
    /// `BackendRouter` does not require `Debug`, so only the number of chained
    /// routers is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let routers_count = self.routers.len();
        let mut builder = f.debug_struct("ChainRouter");
        builder.field("routers_count", &routers_count);
        builder.finish()
    }
}
impl ChainRouter {
pub fn new() -> Self {
Self {
routers: Vec::new(),
}
}
pub fn push(mut self, router: impl BackendRouter + 'static) -> Self {
self.routers.push(Box::new(router));
self
}
}
impl Default for ChainRouter {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl BackendRouter for ChainRouter {
    /// Awaits each router in order and stops at the first `Some`. Returns
    /// `None` when every router declines or the chain is empty.
    async fn route(
        &self,
        config: &SpawnConfig,
        available: &[BackendType],
    ) -> Option<BackendType> {
        let mut chosen = None;
        for candidate in self.routers.iter() {
            chosen = candidate.route(config, available).await;
            if chosen.is_some() {
                break;
            }
        }
        chosen
    }
}
/// Coarse classification of a prompt, produced by `SmartRouter::analyze_complexity`
/// and used to pick a per-complexity backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PromptComplexity {
    // Short prompt with no code markers and a single paragraph.
    Simple,
    // Reaches the simple threshold, contains code markers, or has several paragraphs.
    Medium,
    // Long, long-with-code, or made of many paragraphs.
    Complex,
}
/// Multi-stage router: keyword priorities, then a per-complexity backend, then
/// capability scoring, then `default`, then the first available backend.
#[derive(Debug)]
pub struct SmartRouter {
    // Lower-cased `(keyword, backend)` pairs checked first, as whole words, in insertion order.
    priorities: Vec<(String, BackendType)>,
    // Preferred backend for prompts classified as `PromptComplexity::Simple`.
    simple: Option<BackendType>,
    // Preferred backend for prompts classified as `PromptComplexity::Medium`.
    medium: Option<BackendType>,
    // Preferred backend for prompts classified as `PromptComplexity::Complex`.
    complex: Option<BackendType>,
    // Prompt length in bytes (`str::len`) at or above which a prompt is at least Medium.
    simple_threshold: usize,
    // Prompt length in bytes at or above which a prompt is Complex.
    complex_threshold: usize,
    // Fallback scorer used when neither priorities nor complexity select a backend.
    capability: CapabilityRouter,
    // Final fallback when capability routing yields nothing.
    default: BackendType,
}
impl SmartRouter {
    /// Creates a router whose final fallback is `default`.
    ///
    /// It starts with:
    /// - no keyword priorities and no per-complexity backends;
    /// - complexity thresholds of 200 and 800 bytes;
    /// - a [`CapabilityRouter`] with its default settings.
    pub fn new(default: BackendType) -> Self {
        Self {
            priorities: Vec::new(),
            simple: None,
            medium: None,
            complex: None,
            simple_threshold: 200,
            complex_threshold: 800,
            capability: CapabilityRouter::new(),
            default,
        }
    }

    /// Adds a priority rule. If `keyword` appears in the prompt as a whole word
    /// (case-insensitive), `backend` is chosen before any complexity or
    /// capability routing, provided it is available. Rules are checked in
    /// insertion order.
    pub fn priority(mut self, keyword: impl Into<String>, backend: BackendType) -> Self {
        self.priorities.push((keyword.into().to_lowercase(), backend));
        self
    }

    /// Sets the preferred backend for [`PromptComplexity::Simple`] prompts.
    pub fn simple_backend(mut self, backend: BackendType) -> Self {
        self.simple = Some(backend);
        self
    }

    /// Sets the preferred backend for [`PromptComplexity::Medium`] prompts.
    pub fn medium_backend(mut self, backend: BackendType) -> Self {
        self.medium = Some(backend);
        self
    }

    /// Sets the preferred backend for [`PromptComplexity::Complex`] prompts.
    pub fn complex_backend(mut self, backend: BackendType) -> Self {
        self.complex = Some(backend);
        self
    }

    /// Sets the length thresholds (in bytes) used by [`SmartRouter::analyze_complexity`].
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `simple >= complex`.
    pub fn complexity_threshold(mut self, simple: usize, complex: usize) -> Self {
        debug_assert!(simple < complex, "simple threshold must be less than complex threshold");
        self.simple_threshold = simple;
        self.complex_threshold = complex;
        self
    }

    /// Forwards to [`CapabilityRouter::cost_weight`] for the capability stage.
    pub fn cost_weight(mut self, weight: f32) -> Self {
        self.capability = self.capability.cost_weight(weight);
        self
    }

    /// Forwards to [`CapabilityRouter::require_multi_turn`].
    ///
    /// This only affects the capability stage. Priority and per-complexity
    /// backends are chosen without checking multi-turn support.
    pub fn require_multi_turn(mut self, require: bool) -> Self {
        self.capability = self.capability.require_multi_turn(require);
        self
    }

    /// Forwards to [`CapabilityRouter::with_capability`].
    pub fn with_capability(mut self, backend: BackendType, cap: BackendCapability) -> Self {
        self.capability = self.capability.with_capability(backend, cap);
        self
    }

    /// Classifies `prompt` using cheap heuristics.
    ///
    /// * **Complex**: any one of
    ///   - length ≥ `complex_threshold`;
    ///   - code markers together with length ≥ `simple_threshold`;
    ///   - at least four paragraphs.
    /// * **Medium**: length ≥ `simple_threshold`, or any code marker, or at
    ///   least two paragraphs.
    /// * **Simple**: everything else.
    ///
    /// Length is measured in bytes (`str::len`), so non-ASCII text reaches the
    /// thresholds sooner than its character count suggests. A paragraph is a
    /// run of non-blank lines. Runs of blank (whitespace-only) lines, leading
    /// or trailing blank lines, and `\r\n` line endings do not inflate the count.
    pub fn analyze_complexity(&self, prompt: &str) -> PromptComplexity {
        let len = prompt.len();
        let has_code = Self::has_code_markers(prompt);
        let paragraph_count = Self::count_paragraphs(prompt);
        if len >= self.complex_threshold
            || (has_code && len >= self.simple_threshold)
            || paragraph_count >= 4
        {
            PromptComplexity::Complex
        } else if len >= self.simple_threshold || has_code || paragraph_count >= 2 {
            PromptComplexity::Medium
        } else {
            PromptComplexity::Simple
        }
    }

    /// Returns `true` if `prompt` (case-sensitively) contains something that looks like code.
    fn has_code_markers(prompt: &str) -> bool {
        // Distinctive enough to match anywhere.
        const VERBATIM_MARKERS: [&str; 2] = ["```", "#include"];
        // Keyword markers must start a word, so English words such as
        // "construct ", "instruct " or "undef " are not mistaken for code.
        const KEYWORD_MARKERS: [&str; 8] = [
            "fn ", "def ", "class ", "impl ", "struct ", "import ", "function ", "async fn",
        ];
        let starts_word = |pos: usize| {
            prompt[..pos]
                .chars()
                .next_back()
                .map_or(true, |c| !c.is_alphanumeric())
        };
        VERBATIM_MARKERS.iter().any(|marker| prompt.contains(*marker))
            || KEYWORD_MARKERS.iter().any(|marker| {
                prompt
                    .match_indices(*marker)
                    .any(|(pos, _)| starts_word(pos))
            })
    }

    /// Counts paragraphs, i.e. maximal runs of non-blank lines.
    fn count_paragraphs(prompt: &str) -> usize {
        let mut count = 0;
        let mut in_paragraph = false;
        for line in prompt.lines() {
            let blank = line.trim().is_empty();
            if !blank && !in_paragraph {
                count += 1;
            }
            in_paragraph = !blank;
        }
        count
    }
}
#[async_trait]
impl BackendRouter for SmartRouter {
    /// Routes in stages, returning the first backend that is available:
    ///
    /// 1. the first priority keyword found as a whole word in the lower-cased prompt;
    /// 2. the backend configured for the prompt's [`PromptComplexity`];
    /// 3. the capability router's choice;
    /// 4. `default`, otherwise the first available backend.
    ///
    /// Returns `None` only when `available` is empty.
    async fn route(
        &self,
        config: &SpawnConfig,
        available: &[BackendType],
    ) -> Option<BackendType> {
        // Also serves as the empty-`available` guard.
        let first_available = available.first().copied()?;
        let is_available = |backend: &BackendType| available.contains(backend);

        // Stage 1: explicit keyword priorities.
        let lowered = config.prompt.to_lowercase();
        let prioritized = self
            .priorities
            .iter()
            .find(|(keyword, backend)| contains_word(&lowered, keyword) && is_available(backend))
            .map(|&(_, backend)| backend);
        if prioritized.is_some() {
            return prioritized;
        }

        // Stage 2: backend configured for this prompt's complexity.
        let tier_backend = match self.analyze_complexity(&config.prompt) {
            PromptComplexity::Simple => self.simple,
            PromptComplexity::Medium => self.medium,
            PromptComplexity::Complex => self.complex,
        };
        if let Some(backend) = tier_backend.filter(|backend| is_available(backend)) {
            return Some(backend);
        }

        // Stage 3: capability scoring.
        if let Some(backend) = self.capability.route(config, available).await {
            return Some(backend);
        }

        // Stage 4: static fallbacks.
        Some(if is_available(&self.default) {
            self.default
        } else {
            first_available
        })
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the routers and helpers in this module. Routing tests use
    // `#[tokio::test]` because `BackendRouter::route` is async.
    use super::*;

    // Rules are tried in insertion order; "review" is the first rule whose keyword matches.
    #[tokio::test]
    async fn keyword_router_matches_first_rule() {
        let router = KeywordRouter::new(BackendType::ClaudeCode)
            .rule("review", BackendType::GeminiCli)
            .rule("implement", BackendType::ClaudeCode)
            .rule("test", BackendType::Codex);
        let config = SpawnConfig::new("reviewer", "Please review this code carefully.");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex, BackendType::GeminiCli];
        let result = router.route(&config, &available).await;
        assert_eq!(result, Some(BackendType::GeminiCli));
    }

    // No rule matches, so the router's default (Codex) is used because it is available.
    #[tokio::test]
    async fn keyword_router_falls_back_to_default() {
        let router = KeywordRouter::new(BackendType::Codex)
            .rule("review", BackendType::GeminiCli);
        let config = SpawnConfig::new("worker", "Do something unrelated.");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex];
        let result = router.route(&config, &available).await;
        assert_eq!(result, Some(BackendType::Codex));
    }

    // The matching rule's backend (GeminiCli) is unavailable, so the default wins.
    #[tokio::test]
    async fn keyword_router_skips_unavailable_backend() {
        let router = KeywordRouter::new(BackendType::ClaudeCode)
            .rule("review", BackendType::GeminiCli);
        let config = SpawnConfig::new("reviewer", "Please review this code.");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex];
        let result = router.route(&config, &available).await;
        assert_eq!(result, Some(BackendType::ClaudeCode));
    }

    // Keywords are lower-cased at registration and the prompt is lower-cased at routing.
    #[tokio::test]
    async fn keyword_router_case_insensitive() {
        let router = KeywordRouter::new(BackendType::ClaudeCode)
            .rule("REVIEW", BackendType::GeminiCli);
        let config = SpawnConfig::new("reviewer", "review this code");
        let available = vec![BackendType::ClaudeCode, BackendType::GeminiCli];
        let result = router.route(&config, &available).await;
        assert_eq!(result, Some(BackendType::GeminiCli));
    }

    #[tokio::test]
    async fn keyword_router_returns_none_for_empty_available() {
        let router = KeywordRouter::new(BackendType::ClaudeCode);
        let config = SpawnConfig::new("worker", "Do something.");
        let result = router.route(&config, &[]).await;
        assert_eq!(result, None);
    }

    // cost_weight 1.0 ranks purely by cost tier; GeminiCli has the lowest (1) in the defaults.
    #[tokio::test]
    async fn capability_router_prefers_cheapest() {
        let router = CapabilityRouter::new()
            .cost_weight(1.0);
        let config = SpawnConfig::new("worker", "Do a simple task.");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex, BackendType::GeminiCli];
        let result = router.route(&config, &available).await;
        assert_eq!(result, Some(BackendType::GeminiCli));
    }

    // GeminiCli is excluded (multi_turn: false); Codex (cost 2) then beats ClaudeCode (cost 3).
    #[tokio::test]
    async fn capability_router_multi_turn_requirement() {
        let router = CapabilityRouter::new()
            .require_multi_turn(true)
            .cost_weight(1.0);
        let config = SpawnConfig::new("worker", "Multi-turn session.");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex, BackendType::GeminiCli];
        let result = router.route(&config, &available).await;
        assert_eq!(result, Some(BackendType::Codex));
    }

    // cost_weight 0.0 ranks purely by latency. ClaudeCode and Codex tie at tier 2,
    // and `min_by` keeps the first one, ClaudeCode.
    #[tokio::test]
    async fn capability_router_prefers_fastest() {
        let router = CapabilityRouter::new()
            .cost_weight(0.0);
        let config = SpawnConfig::new("worker", "Quick task.");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex, BackendType::GeminiCli];
        let result = router.route(&config, &available).await;
        assert_eq!(result, Some(BackendType::ClaudeCode));
    }

    #[tokio::test]
    async fn capability_router_returns_none_for_empty_available() {
        let router = CapabilityRouter::new();
        let config = SpawnConfig::new("worker", "Do something.");
        let result = router.route(&config, &[]).await;
        assert_eq!(result, None);
    }

    #[test]
    fn contains_word_exact_match() {
        assert!(super::contains_word("review this code", "review"));
        assert!(super::contains_word("please review", "review"));
        assert!(super::contains_word("review", "review"));
    }

    #[test]
    fn contains_word_rejects_substring() {
        assert!(!super::contains_word("reviewing this code", "review"));
        assert!(!super::contains_word("code reviewer", "review"));
        assert!(!super::contains_word("prereview step", "review"));
    }

    #[test]
    fn contains_word_with_punctuation() {
        assert!(super::contains_word("please review, thanks", "review"));
        assert!(super::contains_word("review.", "review"));
        assert!(super::contains_word("(review)", "review"));
    }

    #[test]
    fn contains_word_test_vs_testing() {
        assert!(super::contains_word("test this function", "test"));
        assert!(!super::contains_word("testing this function", "test"));
        assert!(!super::contains_word("run the contest", "test"));
    }

    // Empty inputs, single-letter words, a later valid match after a rejected one,
    // hyphen boundaries, and multi-byte UTF-8 text.
    #[test]
    fn contains_word_edge_cases() {
        assert!(!super::contains_word("", "review"));
        assert!(!super::contains_word("some text", ""));
        assert!(!super::contains_word("", ""));
        assert!(super::contains_word("a quick task", "a"));
        assert!(!super::contains_word("analyze this", "a"));
        assert!(super::contains_word("do a test", "test"));
        assert!(super::contains_word("testing the test", "test"));
        assert!(super::contains_word("run unit-test now", "test"));
        assert!(super::contains_word("café test", "test"));
        assert!(super::contains_word("café", "café"));
        assert!(!super::contains_word("acafé", "café"));
        assert!(super::contains_word("aéb test éb", "éb"));
    }

    // With word_boundary(true), "testing" no longer matches the "test" rule.
    #[tokio::test]
    async fn keyword_router_word_boundary_mode() {
        let router = KeywordRouter::new(BackendType::ClaudeCode)
            .word_boundary(true)
            .rule("test", BackendType::Codex);
        let available = vec![BackendType::ClaudeCode, BackendType::Codex];
        let config = SpawnConfig::new("w", "test this function");
        assert_eq!(router.route(&config, &available).await, Some(BackendType::Codex));
        let config = SpawnConfig::new("w", "testing this function");
        assert_eq!(router.route(&config, &available).await, Some(BackendType::ClaudeCode));
    }

    #[tokio::test]
    async fn chain_router_first_match_wins() {
        let keyword = KeywordRouter::new(BackendType::ClaudeCode)
            .rule("review", BackendType::GeminiCli);
        let capability = CapabilityRouter::new().cost_weight(1.0);
        let router = ChainRouter::new().push(keyword).push(capability);
        let config = SpawnConfig::new("w", "review this code");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex, BackendType::GeminiCli];
        assert_eq!(router.route(&config, &available).await, Some(BackendType::GeminiCli));
    }

    // NOTE(review): despite the name, the chain never reaches the capability router.
    // KeywordRouter returns its default (ClaudeCode) whenever `available` is non-empty,
    // so the asserted ClaudeCode comes from the first router. Routing through the
    // capability router with cost_weight 1.0 would have produced GeminiCli.
    #[tokio::test]
    async fn chain_router_falls_through_to_second() {
        let keyword = KeywordRouter::new(BackendType::ClaudeCode)
            .word_boundary(true)
            .rule("review", BackendType::GeminiCli);
        let capability = CapabilityRouter::new().cost_weight(1.0);
        let router = ChainRouter::new().push(keyword).push(capability);
        let config = SpawnConfig::new("w", "do a simple task");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex, BackendType::GeminiCli];
        let result = router.route(&config, &available).await;
        assert_eq!(result, Some(BackendType::ClaudeCode));
    }

    #[tokio::test]
    async fn chain_router_empty_returns_none() {
        let router = ChainRouter::new();
        let config = SpawnConfig::new("w", "anything");
        let result = router.route(&config, &[BackendType::ClaudeCode]).await;
        assert_eq!(result, None);
    }

    #[test]
    fn backend_capability_default() {
        let cap = BackendCapability::default();
        assert!(cap.multi_turn);
        assert!(cap.streaming);
        assert_eq!(cap.cost_tier, 2);
        assert_eq!(cap.latency_tier, 2);
    }

    // NOTE(review): this only exercises struct-update syntax on BackendCapability;
    // no router is constructed despite the test name.
    #[test]
    fn capability_router_with_custom_capability() {
        let cap = BackendCapability {
            cost_tier: 1,
            ..Default::default()
        };
        assert!(cap.multi_turn);
        assert_eq!(cap.cost_tier, 1);
    }

    #[test]
    fn routers_are_debug() {
        let kw = KeywordRouter::new(BackendType::ClaudeCode).rule("test", BackendType::Codex);
        assert!(format!("{kw:?}").contains("KeywordRouter"));
        let cap = CapabilityRouter::new();
        assert!(format!("{cap:?}").contains("CapabilityRouter"));
        let chain = ChainRouter::new().push(KeywordRouter::new(BackendType::ClaudeCode));
        let debug = format!("{chain:?}");
        assert!(debug.contains("ChainRouter"));
        assert!(debug.contains("routers_count"));
        let smart = SmartRouter::new(BackendType::ClaudeCode);
        assert!(format!("{smart:?}").contains("SmartRouter"));
    }

    #[test]
    fn smart_router_complexity_simple() {
        let router = SmartRouter::new(BackendType::ClaudeCode);
        let complexity = router.analyze_complexity("Fix the bug.");
        assert_eq!(complexity, PromptComplexity::Simple);
    }

    // 250 bytes is between the default thresholds (200 and 800).
    #[test]
    fn smart_router_complexity_medium_by_length() {
        let router = SmartRouter::new(BackendType::ClaudeCode);
        let prompt = "a".repeat(250);
        let complexity = router.analyze_complexity(&prompt);
        assert_eq!(complexity, PromptComplexity::Medium);
    }

    // A short prompt containing the "fn " code marker.
    #[test]
    fn smart_router_complexity_medium_by_code() {
        let router = SmartRouter::new(BackendType::ClaudeCode);
        let complexity = router.analyze_complexity("Please add fn main() here");
        assert_eq!(complexity, PromptComplexity::Medium);
    }

    #[test]
    fn smart_router_complexity_complex_by_length() {
        let router = SmartRouter::new(BackendType::ClaudeCode);
        let prompt = "a".repeat(900);
        let complexity = router.analyze_complexity(&prompt);
        assert_eq!(complexity, PromptComplexity::Complex);
    }

    // Code fence plus a length above the simple threshold.
    #[test]
    fn smart_router_complexity_complex_by_code_and_length() {
        let router = SmartRouter::new(BackendType::ClaudeCode);
        let prompt = format!("Please implement this:\n```rust\nfn main() {{}}\n```\n{}", "x".repeat(200));
        let complexity = router.analyze_complexity(&prompt);
        assert_eq!(complexity, PromptComplexity::Complex);
    }

    #[test]
    fn smart_router_complexity_complex_by_paragraphs() {
        let router = SmartRouter::new(BackendType::ClaudeCode);
        let prompt = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.\n\nFourth paragraph.";
        let complexity = router.analyze_complexity(prompt);
        assert_eq!(complexity, PromptComplexity::Complex);
    }

    // A multi-word priority keyword outranks the simple-complexity backend.
    #[tokio::test]
    async fn smart_router_priority_wins() {
        let router = SmartRouter::new(BackendType::Codex)
            .priority("security audit", BackendType::ClaudeCode)
            .simple_backend(BackendType::GeminiCli);
        let config = SpawnConfig::new("w", "Run a security audit on this module");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex, BackendType::GeminiCli];
        assert_eq!(router.route(&config, &available).await, Some(BackendType::ClaudeCode));
    }

    #[tokio::test]
    async fn smart_router_complexity_routing_simple() {
        let router = SmartRouter::new(BackendType::Codex)
            .simple_backend(BackendType::GeminiCli);
        let config = SpawnConfig::new("w", "Fix the typo.");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex, BackendType::GeminiCli];
        assert_eq!(router.route(&config, &available).await, Some(BackendType::GeminiCli));
    }

    #[tokio::test]
    async fn smart_router_complexity_routing_complex() {
        let router = SmartRouter::new(BackendType::Codex)
            .complex_backend(BackendType::ClaudeCode);
        let long_prompt = "a".repeat(900);
        let config = SpawnConfig::new("w", &long_prompt);
        let available = vec![BackendType::ClaudeCode, BackendType::Codex, BackendType::GeminiCli];
        assert_eq!(router.route(&config, &available).await, Some(BackendType::ClaudeCode));
    }

    // No priorities or complexity backends are set, so the capability stage
    // (cost-only) picks GeminiCli.
    #[tokio::test]
    async fn smart_router_falls_back_to_capability() {
        let router = SmartRouter::new(BackendType::Codex)
            .cost_weight(1.0);
        let config = SpawnConfig::new("w", "Do a task.");
        let available = vec![BackendType::ClaudeCode, BackendType::Codex, BackendType::GeminiCli];
        assert_eq!(router.route(&config, &available).await, Some(BackendType::GeminiCli));
    }

    // The priority backend is unavailable, so the simple-complexity backend is used.
    #[tokio::test]
    async fn smart_router_priority_skips_unavailable() {
        let router = SmartRouter::new(BackendType::Codex)
            .priority("audit", BackendType::ClaudeCode)
            .simple_backend(BackendType::GeminiCli);
        let config = SpawnConfig::new("w", "Run an audit");
        let available = vec![BackendType::Codex, BackendType::GeminiCli];
        assert_eq!(router.route(&config, &available).await, Some(BackendType::GeminiCli));
    }

    #[tokio::test]
    async fn smart_router_returns_none_for_empty() {
        let router = SmartRouter::new(BackendType::ClaudeCode);
        let config = SpawnConfig::new("w", "anything");
        assert_eq!(router.route(&config, &[]).await, None);
    }

    #[tokio::test]
    async fn smart_router_default_fallback() {
        let router = SmartRouter::new(BackendType::Codex);
        let config = SpawnConfig::new("w", "Short task.");
        let available = vec![BackendType::Codex];
        assert_eq!(router.route(&config, &available).await, Some(BackendType::Codex));
    }

    #[test]
    fn smart_router_custom_thresholds() {
        let router = SmartRouter::new(BackendType::ClaudeCode)
            .complexity_threshold(50, 300);
        assert_eq!(router.analyze_complexity(&"a".repeat(60)), PromptComplexity::Medium);
        assert_eq!(router.analyze_complexity(&"a".repeat(30)), PromptComplexity::Simple);
        assert_eq!(router.analyze_complexity(&"a".repeat(400)), PromptComplexity::Complex);
    }
}