//! herolib-ai 0.3.7
//!
//! AI client with multi-provider support (Groq, OpenRouter, SambaNova) and automatic failover.
//!
//! AI model definitions.
//!
//! This module defines the available AI models and their provider mappings.

use serde::{Deserialize, Serialize};

use crate::provider::Provider;

/// Available AI models with our own naming convention.
///
/// Each model maps to one or more providers, tried in order of preference
/// (see [`Model::info`]). The type is `Copy` and hashable, so it can be
/// passed by value and used as a map key; the serde derives serialize it
/// by variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Model {
    // Fast, small models for quick tasks
    /// Llama 3.3 70B - Fast, capable model for general tasks.
    Llama3_3_70B,
    /// Llama 3.1 70B - Versatile model for various tasks.
    Llama3_1_70B,
    /// Llama 3.1 8B - Small, fast model for simple tasks.
    Llama3_1_8B,

    // Coding-focused models
    /// Qwen 2.5 Coder 32B - Specialized for code generation.
    Qwen2_5Coder32B,
    /// DeepSeek Coder V2.5 - Advanced coding model.
    DeepSeekCoderV2_5,
    /// DeepSeek V3 - Latest DeepSeek model.
    DeepSeekV3,

    // Large reasoning models
    /// Llama 3.1 405B - Largest Llama model for complex tasks.
    Llama3_1_405B,

    // Mixture of Experts models
    /// Mixtral 8x7B - Efficient mixture of experts model.
    Mixtral8x7B,

    // Vision models
    /// Llama 3.2 90B Vision - Multimodal model with vision.
    Llama3_2_90BVision,
    /// Llama 3.2 11B Vision - Smaller vision model.
    Llama3_2_11BVision,

    // NVIDIA models
    /// Nemotron 3 Nano 30B - NVIDIA MoE model with reasoning support.
    NemotronNano30B,
}

/// Model information including provider mappings.
///
/// Produced by [`Model::info`]. All text fields are `&'static str` because
/// the metadata is compiled into the binary; only the provider list is
/// heap-allocated.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Our internal model name.
    pub model: Model,
    /// Human-readable description.
    pub description: &'static str,
    /// Context window size in tokens.
    pub context_window: usize,
    /// Provider mappings in order of preference; the first entry is tried first.
    pub providers: Vec<ProviderMapping>,
}

/// Mapping of a model to a specific provider.
///
/// Pairs a [`Provider`] with the model identifier that provider expects in
/// its API requests (provider naming schemes differ, e.g. Groq vs OpenRouter).
#[derive(Debug, Clone)]
pub struct ProviderMapping {
    /// The provider.
    pub provider: Provider,
    /// The model name/ID used by this provider.
    pub model_id: &'static str,
}

impl ProviderMapping {
    /// Creates a new provider mapping.
    pub const fn new(provider: Provider, model_id: &'static str) -> Self {
        Self { provider, model_id }
    }
}

impl Model {
    /// Returns the model information.
    pub fn info(&self) -> ModelInfo {
        match self {
            Model::Llama3_3_70B => ModelInfo {
                model: *self,
                description: "Llama 3.3 70B - Fast, capable model for general tasks",
                context_window: 128_000,
                providers: vec![
                    ProviderMapping::new(Provider::Groq, "llama-3.3-70b-versatile"),
                    ProviderMapping::new(Provider::SambaNova, "Meta-Llama-3.3-70B-Instruct"),
                    ProviderMapping::new(Provider::OpenRouter, "meta-llama/llama-3.3-70b-instruct"),
                ],
            },
            Model::Llama3_1_70B => ModelInfo {
                model: *self,
                description: "Llama 3.1 70B - Versatile model for various tasks",
                context_window: 128_000,
                providers: vec![
                    ProviderMapping::new(Provider::Groq, "llama-3.1-70b-versatile"),
                    ProviderMapping::new(Provider::SambaNova, "Meta-Llama-3.1-70B-Instruct"),
                    ProviderMapping::new(Provider::OpenRouter, "meta-llama/llama-3.1-70b-instruct"),
                ],
            },
            Model::Llama3_1_8B => ModelInfo {
                model: *self,
                description: "Llama 3.1 8B - Small, fast model for simple tasks",
                context_window: 128_000,
                providers: vec![
                    ProviderMapping::new(Provider::Groq, "llama-3.1-8b-instant"),
                    ProviderMapping::new(Provider::SambaNova, "Meta-Llama-3.1-8B-Instruct"),
                    ProviderMapping::new(Provider::OpenRouter, "meta-llama/llama-3.1-8b-instruct"),
                ],
            },
            Model::Qwen2_5Coder32B => ModelInfo {
                model: *self,
                description: "Qwen 2.5 Coder 32B - Specialized for code generation",
                context_window: 32_000,
                providers: vec![
                    ProviderMapping::new(Provider::Groq, "qwen-2.5-coder-32b"),
                    ProviderMapping::new(Provider::SambaNova, "Qwen2.5-Coder-32B-Instruct"),
                    ProviderMapping::new(Provider::OpenRouter, "qwen/qwen-2.5-coder-32b-instruct"),
                ],
            },
            Model::DeepSeekCoderV2_5 => ModelInfo {
                model: *self,
                description: "DeepSeek Coder V2.5 - Advanced coding model",
                context_window: 128_000,
                providers: vec![
                    ProviderMapping::new(Provider::OpenRouter, "deepseek/deepseek-coder"),
                    ProviderMapping::new(Provider::SambaNova, "DeepSeek-Coder-V2-Instruct"),
                ],
            },
            Model::DeepSeekV3 => ModelInfo {
                model: *self,
                description: "DeepSeek V3 - Latest DeepSeek model",
                context_window: 128_000,
                providers: vec![
                    ProviderMapping::new(Provider::OpenRouter, "deepseek/deepseek-chat"),
                    ProviderMapping::new(Provider::SambaNova, "DeepSeek-V3"),
                ],
            },
            Model::Llama3_1_405B => ModelInfo {
                model: *self,
                description: "Llama 3.1 405B - Largest Llama model for complex tasks",
                context_window: 128_000,
                providers: vec![
                    ProviderMapping::new(Provider::SambaNova, "Meta-Llama-3.1-405B-Instruct"),
                    ProviderMapping::new(
                        Provider::OpenRouter,
                        "meta-llama/llama-3.1-405b-instruct",
                    ),
                ],
            },
            Model::Mixtral8x7B => ModelInfo {
                model: *self,
                description: "Mixtral 8x7B - Efficient mixture of experts model",
                context_window: 32_000,
                providers: vec![
                    ProviderMapping::new(Provider::Groq, "mixtral-8x7b-32768"),
                    ProviderMapping::new(Provider::OpenRouter, "mistralai/mixtral-8x7b-instruct"),
                ],
            },
            Model::Llama3_2_90BVision => ModelInfo {
                model: *self,
                description: "Llama 3.2 90B Vision - Multimodal model with vision",
                context_window: 128_000,
                providers: vec![
                    ProviderMapping::new(Provider::Groq, "llama-3.2-90b-vision-preview"),
                    ProviderMapping::new(
                        Provider::OpenRouter,
                        "meta-llama/llama-3.2-90b-vision-instruct",
                    ),
                ],
            },
            Model::Llama3_2_11BVision => ModelInfo {
                model: *self,
                description: "Llama 3.2 11B Vision - Smaller vision model",
                context_window: 128_000,
                providers: vec![
                    ProviderMapping::new(Provider::Groq, "llama-3.2-11b-vision-preview"),
                    ProviderMapping::new(Provider::SambaNova, "Llama-3.2-11B-Vision-Instruct"),
                    ProviderMapping::new(
                        Provider::OpenRouter,
                        "meta-llama/llama-3.2-11b-vision-instruct",
                    ),
                ],
            },
            Model::NemotronNano30B => ModelInfo {
                model: *self,
                description: "Nemotron 3 Nano 30B - NVIDIA MoE model with reasoning support",
                context_window: 262_144,
                providers: vec![ProviderMapping::new(
                    Provider::OpenRouter,
                    "nvidia/nemotron-3-nano-30b-a3b",
                )],
            },
        }
    }

    /// Returns the human-readable name.
    pub fn name(&self) -> &'static str {
        match self {
            Model::Llama3_3_70B => "Llama 3.3 70B",
            Model::Llama3_1_70B => "Llama 3.1 70B",
            Model::Llama3_1_8B => "Llama 3.1 8B",
            Model::Qwen2_5Coder32B => "Qwen 2.5 Coder 32B",
            Model::DeepSeekCoderV2_5 => "DeepSeek Coder V2.5",
            Model::DeepSeekV3 => "DeepSeek V3",
            Model::Llama3_1_405B => "Llama 3.1 405B",
            Model::Mixtral8x7B => "Mixtral 8x7B",
            Model::Llama3_2_90BVision => "Llama 3.2 90B Vision",
            Model::Llama3_2_11BVision => "Llama 3.2 11B Vision",
            Model::NemotronNano30B => "Nemotron 3 Nano 30B",
        }
    }

    /// Returns the default model for general tasks.
    pub fn default_general() -> Self {
        Model::Llama3_3_70B
    }

    /// Returns the default model for coding tasks.
    pub fn default_coding() -> Self {
        Model::Qwen2_5Coder32B
    }

    /// Returns all available models.
    pub fn all() -> &'static [Model] {
        &[
            Model::Llama3_3_70B,
            Model::Llama3_1_70B,
            Model::Llama3_1_8B,
            Model::Qwen2_5Coder32B,
            Model::DeepSeekCoderV2_5,
            Model::DeepSeekV3,
            Model::Llama3_1_405B,
            Model::Mixtral8x7B,
            Model::Llama3_2_90BVision,
            Model::Llama3_2_11BVision,
            Model::NemotronNano30B,
        ]
    }
}

impl std::fmt::Display for Model {
    /// Writes the human-readable name (same text as [`Model::name`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-checks one model's metadata: at least one provider and a
    /// positive context window.
    #[test]
    fn test_model_info() {
        let ModelInfo {
            providers,
            context_window,
            ..
        } = Model::Llama3_3_70B.info();
        assert!(!providers.is_empty());
        assert!(context_window > 0);
    }

    /// Every model in the catalog must resolve to at least one provider,
    /// otherwise failover has nothing to try.
    #[test]
    fn test_all_models_have_providers() {
        for model in Model::all() {
            assert!(
                !model.info().providers.is_empty(),
                "Model {} has no providers",
                model.name()
            );
        }
    }

    /// Pins the documented default model choices.
    #[test]
    fn test_default_models() {
        assert_eq!(Model::Llama3_3_70B, Model::default_general());
        assert_eq!(Model::Qwen2_5Coder32B, Model::default_coding());
    }
}