octoroute 1.0.0

Intelligent multi-model router for self-hosted LLMs
Documentation
//! Runtime-validated wrapper for tier-specific endpoint selection
//!
//! The TierSelector validates that at least one endpoint exists for the specified tier
//! and provides selection from that tier only.

use crate::config::ModelEndpoint;
use crate::error::{AppError, AppResult};
use crate::models::endpoint_name::ExclusionSet;
use crate::models::selector::ModelSelector;
use crate::router::TargetModel;
use std::sync::Arc;

/// Selector restricted to a single model tier.
///
/// This newtype wraps a shared [`ModelSelector`] and pins it to one [`TargetModel`] tier.
/// Construction fails unless at least one endpoint is configured for that tier; after
/// that, every selection is drawn from that tier only.
///
/// # Tier Selection for LLM Routing
///
/// When used by the LLM-based router, the tier choice impacts routing performance.
/// The model sizes and latencies below are typical figures for a self-hosted setup,
/// not values enforced or measured by this type:
///
/// - **FAST (~8B)**: Lowest latency (~50-200ms) but may misroute complex requests.
///   Risk of bad routing decisions usually outweighs the latency savings.
///
/// - **BALANCED (~30B)**: Recommended default. Good reasoning for classification
///   with acceptable latency (~100-500ms). Best balance of accuracy and speed.
///
/// - **DEEP (~120B)**: Highest accuracy but very slow (~2-5s). Router latency may
///   exceed the time to just run the user query on BALANCED. Rarely worth it.
///
/// # Construction-Time Validation
///
/// The tier is validated once in [`TierSelector::new`] (at least one endpoint must be
/// configured) and then stored immutably, so a selector can never switch tiers. The
/// tier itself is a runtime value, not a compile-time guarantee. Having a configured
/// endpoint also does not mean one is currently usable: [`TierSelector::select`] can
/// still return `None` when every endpoint of the tier is unhealthy or excluded.
///
/// Cloning is cheap: it bumps the reference count of the shared [`ModelSelector`] and
/// copies the tier value.
#[derive(Debug, Clone)]
pub struct TierSelector {
    /// Shared selector that is queried for endpoints of `tier`.
    inner: Arc<ModelSelector>,
    /// The only tier this selector will ever select from.
    tier: TargetModel,
}

impl TierSelector {
    /// Create a new `TierSelector` pinned to `tier`.
    ///
    /// Checks that `selector` has at least one endpoint configured for `tier`, so a
    /// selector for a tier with an empty endpoint list can never be constructed.
    /// This only validates configuration: endpoint health and retry exclusions are
    /// applied later, on each call to [`TierSelector::select`], which may therefore
    /// still return `None`.
    ///
    /// # Arguments
    /// * `selector` - The shared underlying `ModelSelector`
    /// * `tier` - Which tier (Fast, Balanced, Deep) to select from
    ///
    /// # Errors
    ///
    /// Returns `AppError::Config` if no endpoints are configured for `tier`.
    pub fn new(selector: Arc<ModelSelector>, tier: TargetModel) -> AppResult<Self> {
        if selector.endpoint_count(tier) == 0 {
            return Err(AppError::Config(format!(
                "TierSelector requires at least one {:?} tier endpoint",
                tier
            )));
        }
        Ok(Self {
            inner: selector,
            tier,
        })
    }

    /// Select an endpoint from this selector's tier with health filtering and exclusion.
    ///
    /// Delegates to the underlying `ModelSelector`, always passing this selector's tier.
    ///
    /// # Arguments
    /// * `exclude` - Set of endpoint names to exclude (for retry logic)
    ///
    /// # Returns
    /// - `Some(&ModelEndpoint)` if a healthy, non-excluded endpoint exists for this tier
    /// - `None` if all endpoints for this tier are unhealthy or excluded
    pub async fn select(&self, exclude: &ExclusionSet) -> Option<&ModelEndpoint> {
        self.inner.select(self.tier, exclude).await
    }

    /// Get the tier this selector operates on (fixed at construction).
    #[must_use]
    pub fn tier(&self) -> TargetModel {
        self.tier
    }

    /// Get the number of configured endpoints for this selector's tier.
    ///
    /// This counts configured endpoints regardless of their current health.
    /// It is at least 1 as long as the underlying configuration is unchanged
    /// since [`TierSelector::new`] validated it.
    #[must_use]
    pub fn endpoint_count(&self) -> usize {
        self.inner.endpoint_count(self.tier)
    }

    /// Get a reference to the health checker for external use (e.g., marking success/failure).
    pub fn health_checker(&self) -> &Arc<crate::models::health::HealthChecker> {
        self.inner.health_checker()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::models::EndpointName;

    /// Helper to create test metrics
    fn test_metrics() -> Arc<crate::metrics::Metrics> {
        Arc::new(crate::metrics::Metrics::new().expect("should create metrics"))
    }

    /// Config with one fast, two balanced (`balanced-1`, `balanced-2`) and one deep endpoint.
    fn create_test_config_with_balanced() -> Arc<Config> {
        let toml = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30

[[models.fast]]
name = "fast-1"
base_url = "http://localhost:11434/v1"
max_tokens = 2048
weight = 1.0
priority = 1

[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1234/v1"
max_tokens = 4096
weight = 1.0
priority = 1

[[models.balanced]]
name = "balanced-2"
base_url = "http://localhost:1235/v1"
max_tokens = 4096
weight = 1.0
priority = 1

[[models.deep]]
name = "deep-1"
base_url = "http://localhost:8080/v1"
max_tokens = 8192
weight = 1.0
priority = 1

[routing]
strategy = "hybrid"
default_importance = "normal"
router_tier = "balanced"
"#;
        let config: Config = toml::from_str(toml).expect("should parse config");
        Arc::new(config)
    }

    /// Config with one fast and one deep endpoint and an explicitly empty balanced tier.
    fn create_test_config_without_balanced() -> Arc<Config> {
        let toml = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30

[[models.fast]]
name = "fast-1"
base_url = "http://localhost:11434/v1"
max_tokens = 2048
weight = 1.0
priority = 1

[models]
balanced = []

[[models.deep]]
name = "deep-1"
base_url = "http://localhost:8080/v1"
max_tokens = 8192
weight = 1.0
priority = 1

[routing]
strategy = "rule"
default_importance = "normal"
router_tier = "balanced"
"#;
        let config: Config = toml::from_str(toml).expect("should parse config");
        Arc::new(config)
    }

    /// Asserts that `result` failed with an `AppError::Config` naming the missing Balanced tier.
    fn assert_missing_balanced_error(result: AppResult<TierSelector>, context: &str) {
        match result.expect_err(context) {
            AppError::Config(msg) => assert!(
                msg.contains("Balanced") && msg.contains("tier endpoint"),
                "error should mention Balanced tier requirement, got: {}",
                msg
            ),
            other => panic!("expected Config error, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_tier_selector_new_with_balanced_endpoints() {
        let config = create_test_config_with_balanced();
        let selector = Arc::new(ModelSelector::new(config, test_metrics()));

        let tier_selector = TierSelector::new(selector, TargetModel::Balanced)
            .expect("should create TierSelector with balanced endpoints");
        assert_eq!(tier_selector.tier(), TargetModel::Balanced);
    }

    #[tokio::test]
    async fn test_tier_selector_new_without_balanced_endpoints() {
        let config = create_test_config_without_balanced();
        let selector = Arc::new(ModelSelector::new(config, test_metrics()));

        let result = TierSelector::new(selector, TargetModel::Balanced);
        assert_missing_balanced_error(result, "should fail without balanced endpoints");
    }

    #[tokio::test]
    async fn test_tier_selector_selects_balanced_endpoint() {
        let config = create_test_config_with_balanced();
        let selector = Arc::new(ModelSelector::new(config, test_metrics()));
        let tier_selector =
            TierSelector::new(selector, TargetModel::Balanced).expect("should create TierSelector");

        let exclude = ExclusionSet::new();
        let endpoint = tier_selector
            .select(&exclude)
            .await
            .expect("should select a balanced endpoint");

        assert!(
            endpoint.name().starts_with("balanced-"),
            "selected endpoint should be from balanced tier, got: {}",
            endpoint.name()
        );
    }

    #[tokio::test]
    async fn test_tier_selector_respects_exclusion() {
        let config = create_test_config_with_balanced();
        let selector = Arc::new(ModelSelector::new(config, test_metrics()));
        let tier_selector =
            TierSelector::new(selector, TargetModel::Balanced).expect("should create TierSelector");

        // Baseline: with nothing excluded a selection must succeed.
        let exclude = ExclusionSet::new();
        assert!(
            tier_selector.select(&exclude).await.is_some(),
            "should select a balanced endpoint when nothing is excluded"
        );

        // Excluding one endpoint must fall through to the other one in the same tier.
        let mut exclude = ExclusionSet::new();
        exclude.insert(EndpointName::from("balanced-1"));
        let remaining = tier_selector
            .select(&exclude)
            .await
            .expect("balanced-2 should still be selectable");
        assert_eq!(remaining.name(), "balanced-2");

        // Excluding both balanced endpoints leaves nothing to select.
        exclude.insert(EndpointName::from("balanced-2"));
        let excluded = tier_selector.select(&exclude).await;
        assert!(
            excluded.is_none(),
            "should return None when all balanced endpoints excluded"
        );
    }

    #[tokio::test]
    async fn test_tier_selector_endpoint_count() {
        let config = create_test_config_with_balanced();
        let selector = Arc::new(ModelSelector::new(config, test_metrics()));
        let tier_selector =
            TierSelector::new(selector, TargetModel::Balanced).expect("should create TierSelector");

        assert_eq!(
            tier_selector.endpoint_count(),
            2,
            "should have 2 balanced endpoints"
        );
    }

    #[tokio::test]
    async fn test_tier_selector_with_fast_tier() {
        let config = create_test_config_with_balanced();
        let selector = Arc::new(ModelSelector::new(config, test_metrics()));

        let tier_selector = TierSelector::new(selector, TargetModel::Fast)
            .expect("should create TierSelector for Fast tier");

        assert_eq!(tier_selector.tier(), TargetModel::Fast);
        assert_eq!(
            tier_selector.endpoint_count(),
            1,
            "should have 1 fast endpoint"
        );

        let exclude = ExclusionSet::new();
        let endpoint = tier_selector
            .select(&exclude)
            .await
            .expect("should select a fast endpoint");
        assert_eq!(endpoint.name(), "fast-1");
    }

    #[tokio::test]
    async fn test_tier_selector_with_deep_tier() {
        let config = create_test_config_with_balanced();
        let selector = Arc::new(ModelSelector::new(config, test_metrics()));

        let tier_selector = TierSelector::new(selector, TargetModel::Deep)
            .expect("should create TierSelector for Deep tier");

        assert_eq!(tier_selector.tier(), TargetModel::Deep);
        assert_eq!(
            tier_selector.endpoint_count(),
            1,
            "should have 1 deep endpoint"
        );

        let exclude = ExclusionSet::new();
        let endpoint = tier_selector
            .select(&exclude)
            .await
            .expect("should select a deep endpoint");
        assert_eq!(endpoint.name(), "deep-1");
    }

    #[tokio::test]
    async fn test_tier_selector_fails_with_no_endpoints_for_tier() {
        let config = create_test_config_without_balanced();
        let selector = Arc::new(ModelSelector::new(config, test_metrics()));

        // Should fail for Balanced tier (no endpoints)
        let result = TierSelector::new(Arc::clone(&selector), TargetModel::Balanced);
        assert_missing_balanced_error(result, "should fail when tier has no endpoints");

        // Only the empty tier is rejected: tiers that do have endpoints in the
        // same configuration must still construct successfully.
        for tier in [TargetModel::Fast, TargetModel::Deep] {
            let tier_selector = TierSelector::new(Arc::clone(&selector), tier)
                .unwrap_or_else(|e| panic!("{:?} tier should be valid, got: {:?}", tier, e));
            assert_eq!(tier_selector.tier(), tier);
            assert_eq!(
                tier_selector.endpoint_count(),
                1,
                "{:?} tier should have 1 endpoint",
                tier
            );
        }
    }
}