//! car-inference 0.15.0
//!
//! Local model inference for CAR — Candle backend with Qwen3 models.
//!
//! # Documentation
//! Intelligent model routing — select the best model based on prompt characteristics.
//!
//! Routes generation requests based on prompt complexity analysis:
//! - Simple questions → smallest model (0.6B)
//! - Medium tasks → Qwen3-1.7B
//! - Code tasks → Qwen3-4B
//! - Complex reasoning → best available (8B or 30B-A3B)
//!
//! Classify/embed tasks always route to the smallest model.

use crate::hardware::HardwareInfo;
use crate::models::ModelRegistry;
use serde::{Deserialize, Serialize};
use tracing::debug;

// Re-export TaskComplexity from adaptive_router (canonical definition).
pub use crate::adaptive_router::TaskComplexity;

/// The result of routing a request to a model.
///
/// Serializable so the decision can be logged or returned to API callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
    /// Name of the selected model (e.g. "Qwen3-1.7B").
    pub model: String,
    /// The assessed complexity that drove the selection.
    pub complexity: TaskComplexity,
    /// Human-readable explanation of why this model was chosen.
    pub reason: String,
}

/// Routes inference requests to the best available model based on prompt analysis and hardware.
pub struct ModelRouter {
    /// Hardware profile; supplies `max_model_mb` (memory budget) and
    /// `recommended_model` (target for complex tasks).
    hw: HardwareInfo,
}

/// Model tier ordering, smallest to largest.
///
/// `best_available` scans this list relative to a target tier, so the order
/// doubles as the quality ranking (later == larger/better).
const TIERS: &[&str] = &[
    "Qwen3-0.6B",
    "Qwen3-1.7B",
    "Qwen3-4B",
    "Qwen3-8B",
    "Qwen3-30B-A3B",
];

impl ModelRouter {
    /// Create a router for the given hardware profile.
    pub fn new(hw: HardwareInfo) -> Self {
        Self { hw }
    }

    /// Route a generation request to the best model.
    ///
    /// Assesses the prompt's complexity, maps it to a target tier, then
    /// selects the best downloaded model near that target.
    pub fn route_generate(&self, prompt: &str, registry: &ModelRegistry) -> RoutingDecision {
        let complexity = Self::assess_complexity(prompt);
        let target = self.target_for_complexity(complexity);
        let model = self.best_available(&target, registry);
        let reason = format!("{:?} task -> {} (target: {})", complexity, model, target);
        debug!(
            complexity = ?complexity,
            target = %target,
            selected = %model,
            "routed generation request"
        );
        RoutingDecision {
            model,
            complexity,
            reason,
        }
    }

    /// Route classify to the smallest available generative model.
    pub fn route_small(&self, registry: &ModelRegistry) -> String {
        let model = self.best_available("Qwen3-0.6B", registry);
        debug!(selected = %model, "routed small task (classify)");
        model
    }

    /// Route embedding to the dedicated embedding model.
    ///
    /// Embeddings always use the fixed embedding model; the registry argument
    /// is kept for signature symmetry with the other routing methods.
    pub fn route_embedding(&self, _registry: &ModelRegistry) -> String {
        "Qwen3-Embedding-0.6B".to_string()
    }

    /// Assess prompt complexity based on content analysis.
    ///
    /// Heuristic, substring-based classification:
    /// - code or repair/debug markers  -> `Code`
    /// - reasoning markers or long (>500 est. tokens) -> `Complex`
    /// - simple-question patterns or short (<30 est. tokens) -> `Simple`
    /// - everything else -> `Medium`
    ///
    /// Code markers are matched case-sensitively against the raw prompt (they
    /// are syntax, e.g. `fn `); the other marker sets are matched against the
    /// lowercased prompt. NOTE(review): `contains` matching means substrings
    /// like "prefix" trip the "fix" repair marker — accepted as a heuristic
    /// trade-off, kept as-is to preserve behavior.
    pub fn assess_complexity(prompt: &str) -> TaskComplexity {
        let lower = prompt.to_lowercase();
        let word_count = prompt.split_whitespace().count();
        // Rough words->tokens conversion (~1.3 tokens per English word).
        let estimated_tokens = (word_count as f64 * 1.3) as usize;

        // Code markers (case-sensitive: these are language syntax fragments).
        let code_markers = [
            "```",
            "fn ",
            "def ",
            "class ",
            "import ",
            "require(",
            "async fn",
            "pub fn",
            "function ",
            "const ",
            "let ",
            "var ",
            "#include",
            "package ",
            "impl ",
        ];
        let has_code = code_markers.iter().any(|m| prompt.contains(m));

        // Repair/debug markers (checked case-insensitive).
        let repair_markers = [
            "fix", "repair", "debug", "refactor", "broken", "failing", "error", "bug",
        ];
        let has_repair = repair_markers.iter().any(|m| lower.contains(m));

        // Complex reasoning markers (checked case-insensitive).
        let reasoning_markers = [
            "analyze",
            "compare",
            "explain why",
            "step by step",
            "think through",
            "evaluate",
            "trade-off",
            "tradeoff",
            "pros and cons",
            "architecture",
            "design",
            "strategy",
            "optimize",
            "comprehensive",
        ];
        let has_reasoning = reasoning_markers.iter().any(|m| lower.contains(m));

        // Simple question patterns (checked case-insensitive).
        let simple_patterns = [
            "what is",
            "who is",
            "when did",
            "where is",
            "how many",
            "yes or no",
            "true or false",
            "name the",
            "list the",
            "define ",
        ];
        let is_simple_question = simple_patterns.iter().any(|p| lower.contains(p));

        // Priority: code > reasoning/long > simple > medium
        if has_code || has_repair {
            TaskComplexity::Code
        } else if has_reasoning || estimated_tokens > 500 {
            TaskComplexity::Complex
        } else if is_simple_question || estimated_tokens < 30 {
            TaskComplexity::Simple
        } else {
            TaskComplexity::Medium
        }
    }

    /// Get the target model name for a complexity level.
    ///
    /// `Complex` defers to the hardware-recommended model, which may be a
    /// name outside [`TIERS`].
    fn target_for_complexity(&self, complexity: TaskComplexity) -> String {
        match complexity {
            TaskComplexity::Simple => "Qwen3-0.6B".into(),
            TaskComplexity::Medium => "Qwen3-1.7B".into(),
            TaskComplexity::Code => "Qwen3-4B".into(),
            TaskComplexity::Complex => self.hw.recommended_model.clone(),
        }
    }

    /// Find the best available model at or near the target.
    ///
    /// Prefers the target if downloaded and fits in memory. Otherwise falls back
    /// upward to larger models first (prefer better quality), then downward to
    /// smaller models. If nothing is downloaded, returns the target name (which
    /// will trigger a download on first use).
    fn best_available(&self, target: &str, registry: &ModelRegistry) -> String {
        let models = registry.list_models();

        // A model is usable when it is already downloaded and its quantized
        // size fits within the hardware memory budget.
        let usable = |name: &str| {
            models
                .iter()
                .find(|m| m.name == name)
                .map_or(false, |m| {
                    m.downloaded && m.quantized_size_mb <= self.hw.max_model_mb
                })
        };

        // Check the target first. This also covers targets that are not in
        // TIERS (e.g. a hardware-recommended model), which the tier scan
        // below would miss.
        if usable(target) {
            return target.to_string();
        }

        // Unknown targets scan from the smallest tier.
        let target_idx = TIERS.iter().position(|&t| t == target).unwrap_or(0);

        // Scan upward from the target (prefer better quality), then downward
        // to smaller models.
        let upward = TIERS[target_idx..].iter();
        let downward = TIERS[..target_idx].iter().rev();
        if let Some(&name) = upward.chain(downward).find(|&&n| usable(n)) {
            return name.to_string();
        }

        // Nothing usable is downloaded — return the target anyway so first
        // use triggers a download.
        target.to_string()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these tests previously called `TaskComplexity::assess`,
    // but the assessment function defined in this module is
    // `ModelRouter::assess_complexity` (TaskComplexity is a re-exported enum).
    // Fixed to call the actual function.

    #[test]
    fn simple_question() {
        assert_eq!(
            ModelRouter::assess_complexity("What is the capital of France?"),
            TaskComplexity::Simple,
        );
    }

    #[test]
    fn code_with_function_keyword() {
        assert_eq!(
            ModelRouter::assess_complexity("Write a function to sort a list"),
            TaskComplexity::Code,
        );
    }

    #[test]
    fn complex_reasoning() {
        assert_eq!(
            ModelRouter::assess_complexity(
                "Analyze the trade-offs between microservices and monolithic architecture for our use case"
            ),
            TaskComplexity::Complex,
        );
    }

    #[test]
    fn medium_default() {
        // Must be >30 estimated tokens (~24 words) to avoid Simple, and no code/reasoning markers
        assert_eq!(
            ModelRouter::assess_complexity(
                "Tell me about the history of computing and how it has evolved over the decades from early mechanical calculation devices through transistors to the modern digital systems we use today in everyday life"
            ),
            TaskComplexity::Medium,
        );
    }

    #[test]
    fn code_with_backticks() {
        assert_eq!(
            ModelRouter::assess_complexity("Here is my code:\n```rust\nfn main() {}\n```\nFix it"),
            TaskComplexity::Code,
        );
    }

    #[test]
    fn repair_marker() {
        assert_eq!(
            ModelRouter::assess_complexity("Debug this failing test case"),
            TaskComplexity::Code,
        );
    }
}