octoroute 1.0.0

Intelligent multi-model router for self-hosted LLMs
Documentation
//! Basic ModelSelector tests
//!
//! Tests basic functionality: creation, simple selection, endpoint counting,
//! empty tiers, and concurrency safety.

use super::*;

/// Builds a fresh `Metrics` instance wrapped in an `Arc` for use in a test.
///
/// # Panics
///
/// Panics with "should create metrics" if `Metrics::new()` returns an error.
fn test_metrics() -> Arc<crate::metrics::Metrics> {
    let metrics = crate::metrics::Metrics::new().expect("should create metrics");
    Arc::new(metrics)
}
use crate::models::endpoint_name::ExclusionSet;
use crate::models::selector::ModelSelector;
use crate::router::TargetModel;
use std::sync::Arc;

/// A selector built from the shared test config reports the expected
/// number of endpoints for every tier.
#[tokio::test]
async fn test_selector_new_creates_selector() {
    let selector = ModelSelector::new(Arc::new(create_test_config()), test_metrics());

    // Reading the per-tier counts doubles as proof that construction worked.
    let fast = selector.endpoint_count(TargetModel::Fast);
    let balanced = selector.endpoint_count(TargetModel::Balanced);
    let deep = selector.endpoint_count(TargetModel::Deep);

    assert_eq!(fast, 2);
    assert_eq!(balanced, 1);
    assert_eq!(deep, 1);
}

/// With nothing excluded, `select` yields an endpoint for each of the
/// Fast, Balanced, and Deep tiers of the test config.
#[tokio::test]
async fn test_selector_select_returns_endpoint() {
    let selector = ModelSelector::new(Arc::new(create_test_config()), test_metrics());

    // An empty exclusion set: no endpoint is filtered out up front.
    let exclusions = ExclusionSet::new();

    let fast = selector.select(TargetModel::Fast, &exclusions).await;
    assert!(fast.is_some());

    let balanced = selector.select(TargetModel::Balanced, &exclusions).await;
    assert!(balanced.is_some());

    let deep = selector.select(TargetModel::Deep, &exclusions).await;
    assert!(deep.is_some());
}

/// Selecting twice from the Balanced tier returns "balanced-1" both times.
#[tokio::test]
async fn test_selector_single_endpoint_tier() {
    let selector = ModelSelector::new(Arc::new(create_test_config()), test_metrics());
    let exclusions = ExclusionSet::new();

    // The Balanced tier is expected to hold a single endpoint, so repeated
    // selections have nothing else to choose from.
    let first_pick = selector.select(TargetModel::Balanced, &exclusions).await;
    let first = first_pick.unwrap();
    let second_pick = selector.select(TargetModel::Balanced, &exclusions).await;
    let second = second_pick.unwrap();

    assert_eq!(first.name(), "balanced-1");
    assert_eq!(second.name(), "balanced-1");
}

/// `endpoint_count` reports 2 / 1 / 1 endpoints for the Fast / Balanced /
/// Deep tiers of the shared test config.
#[tokio::test]
async fn test_selector_endpoint_count() {
    let selector = ModelSelector::new(Arc::new(create_test_config()), test_metrics());

    // Small shorthand so each assertion reads as "tier -> expected count".
    let count_for = |tier| selector.endpoint_count(tier);

    assert_eq!(count_for(TargetModel::Fast), 2);
    assert_eq!(count_for(TargetModel::Balanced), 1);
    assert_eq!(count_for(TargetModel::Deep), 1);
}

/// Selecting from a tier configured with no endpoints returns `None`.
#[tokio::test]
async fn test_selector_returns_none_for_empty_tier() {
    // Fixture: the fast tier is an empty list, while the balanced and deep
    // tiers each define one endpoint.
    const EMPTY_FAST_TIER_TOML: &str = r#"
[server]
host = "127.0.0.1"
port = 3000
request_timeout_seconds = 30

[models]
fast = []

[[models.balanced]]
name = "balanced-1"
base_url = "http://localhost:1236/v1"
max_tokens = 4096

[[models.deep]]
name = "deep-1"
base_url = "http://localhost:1237/v1"
max_tokens = 8192

[routing]
strategy = "rule"
router_tier = "balanced"
"#;

    let parsed: Config = toml::from_str(EMPTY_FAST_TIER_TOML).expect("should parse TOML");
    let selector = ModelSelector::new(Arc::new(parsed), test_metrics());

    let exclusions = ExclusionSet::new();
    let picked = selector.select(TargetModel::Fast, &exclusions).await;

    assert!(picked.is_none(), "should return None for empty tier");
}

#[tokio::test]
async fn test_selector_concurrent_weighted_selection() {
    let config = Arc::new(create_test_config());
    let selector = Arc::new(ModelSelector::new(config, test_metrics()));

    // Spawn 10 concurrent tasks selecting from Fast tier (which has 2 endpoints)
    let mut handles = vec![];
    for _ in 0..10 {
        let sel = selector.clone();
        handles.push(tokio::spawn(async move {
            let no_exclude = ExclusionSet::new();
            sel.select(TargetModel::Fast, &no_exclude)
                .await
                .map(|e| e.name().to_string())
        }));
    }

    // Collect results
    let results: Vec<_> = futures::future::join_all(handles)
        .await
        .into_iter()
        .map(|r| r.unwrap())
        .collect();

    // Verify all selections succeeded (concurrency safety)
    assert_eq!(
        results.len(),
        10,
        "all concurrent selections should complete"
    );
    for result in &results {
        assert!(result.is_some(), "all selections should succeed");
    }

    // Verify all selected endpoints are valid (from the configured endpoints)
    let endpoint_names: Vec<String> = results.into_iter().flatten().collect();
    for name in &endpoint_names {
        assert!(
            name == "fast-1" || name == "fast-2",
            "selected endpoint should be valid, got: {}",
            name
        );
    }

    // Note: With weighted random selection, we cannot deterministically assert
    // that both endpoints are always selected in just 10 draws. With equal weights,
    // there's ~0.2% chance all 10 selections hit the same endpoint.
    // This test focuses on concurrency safety, not distribution.
}