//! herolib-ai 0.3.13
//!
//! AI client with multi-provider support (Groq, OpenRouter, SambaNova) and automatic failover.
//!
//! Embedding model definitions and types.
//!
//! This module defines the available embedding models and the types for embedding requests/responses.

use serde::{Deserialize, Serialize};

use crate::provider::Provider;

/// Available embedding models.
///
/// Each model maps to one or more providers.
/// Available embedding models.
///
/// Each model maps to one or more providers; see [`EmbeddingModel::info`]
/// for the provider mappings and context-window size.
/// Serializes and deserializes by variant name via the serde derives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EmbeddingModel {
    /// Qwen3 Embedding 8B - Multilingual embedding model with strong retrieval capabilities.
    Qwen3Embedding8B,
    /// OpenAI Text Embedding 3 Small - Fast, efficient embedding model with 1536 dimensions.
    TextEmbedding3Small,
}

/// Embedding model information including provider mappings.
/// Embedding model information including provider mappings.
///
/// Produced by [`EmbeddingModel::info`]; a plain data carrier, not serialized.
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    /// Our internal model name.
    pub model: EmbeddingModel,
    /// Human-readable description.
    pub description: &'static str,
    /// Context window size in tokens (maximum input length accepted by the model).
    pub context_window: usize,
    /// Provider mappings in order of preference; earlier entries are preferred.
    pub providers: Vec<EmbeddingProviderMapping>,
}

/// Mapping of an embedding model to a specific provider.
/// Mapping of an embedding model to a specific provider.
///
/// Pairs a [`Provider`] with the model identifier that provider expects
/// in API requests (e.g. "openai/text-embedding-3-small").
#[derive(Debug, Clone)]
pub struct EmbeddingProviderMapping {
    /// The provider.
    pub provider: Provider,
    /// The model name/ID used by this provider.
    pub model_id: &'static str,
}

impl EmbeddingProviderMapping {
    /// Creates a new provider mapping.
    pub const fn new(provider: Provider, model_id: &'static str) -> Self {
        Self { provider, model_id }
    }
}

impl EmbeddingModel {
    /// Returns the model information: description, context window, and
    /// provider mappings in order of preference.
    pub fn info(&self) -> EmbeddingModelInfo {
        match self {
            EmbeddingModel::Qwen3Embedding8B => EmbeddingModelInfo {
                model: *self,
                description: "Qwen3 Embedding 8B - Multilingual embedding model",
                context_window: 32_768,
                providers: vec![EmbeddingProviderMapping::new(
                    Provider::OpenRouter,
                    "qwen/qwen3-embedding-8b",
                )],
            },
            EmbeddingModel::TextEmbedding3Small => EmbeddingModelInfo {
                model: *self,
                description: "OpenAI Text Embedding 3 Small - Fast, efficient embedding model",
                context_window: 8_191,
                providers: vec![EmbeddingProviderMapping::new(
                    Provider::OpenRouter,
                    "openai/text-embedding-3-small",
                )],
            },
        }
    }

    /// Returns the human-readable name.
    pub fn name(&self) -> &'static str {
        match self {
            EmbeddingModel::Qwen3Embedding8B => "Qwen3 Embedding 8B",
            EmbeddingModel::TextEmbedding3Small => "Text Embedding 3 Small",
        }
    }

    /// Returns all available embedding models.
    pub fn all() -> &'static [EmbeddingModel] {
        &[
            EmbeddingModel::Qwen3Embedding8B,
            EmbeddingModel::TextEmbedding3Small,
        ]
    }
}

/// The default embedding model is [`EmbeddingModel::TextEmbedding3Small`].
///
/// Implemented as the standard `Default` trait rather than an inherent
/// `default()` method (clippy `should_implement_trait`); callers can still
/// write `EmbeddingModel::default()`, which now resolves through the trait.
impl Default for EmbeddingModel {
    fn default() -> Self {
        EmbeddingModel::TextEmbedding3Small
    }
}

/// Formats the model as its human-readable name (same text as [`EmbeddingModel::name`]).
impl std::fmt::Display for EmbeddingModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}

/// Request body for embeddings.
/// Request body for embeddings.
///
/// Serializes to the OpenAI-compatible `/embeddings` request shape;
/// optional fields are omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingRequest {
    /// The model to use.
    pub model: String,
    /// Input text to embed. Can be a string or array of strings.
    pub input: EmbeddingInput,
    /// The format to return the embeddings in.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<String>,
    /// The number of dimensions the resulting output embeddings should have.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
}

/// Input for embedding requests.
/// Input for embedding requests.
///
/// `#[serde(untagged)]` means this serializes as a bare JSON string or a
/// bare JSON array of strings, with no variant wrapper.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// Single text input.
    Single(String),
    /// Multiple text inputs.
    Multiple(Vec<String>),
}

impl EmbeddingRequest {
    /// Creates a new embedding request with a single input.
    pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            input: EmbeddingInput::Single(input.into()),
            encoding_format: None,
            dimensions: None,
        }
    }

    /// Creates a new embedding request with multiple inputs.
    pub fn new_batch(model: impl Into<String>, inputs: Vec<String>) -> Self {
        Self {
            model: model.into(),
            input: EmbeddingInput::Multiple(inputs),
            encoding_format: None,
            dimensions: None,
        }
    }

    /// Sets the encoding format.
    pub fn with_encoding_format(mut self, format: impl Into<String>) -> Self {
        self.encoding_format = Some(format.into());
        self
    }

    /// Sets the number of dimensions for the output embeddings.
    pub fn with_dimensions(mut self, dimensions: u32) -> Self {
        self.dimensions = Some(dimensions);
        self
    }
}

/// Response from embeddings API.
/// Response from embeddings API.
///
/// Mirrors the OpenAI-compatible `/embeddings` response shape.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type (always "list").
    pub object: String,
    /// The embedding data.
    pub data: Vec<EmbeddingData>,
    /// Model used for embedding.
    pub model: String,
    /// Token usage statistics.
    /// `#[serde(default)]` makes this `None` when the provider omits the field.
    #[serde(default)]
    pub usage: Option<EmbeddingUsage>,
}

impl EmbeddingResponse {
    /// Returns the first embedding vector, or `None` when `data` is empty.
    pub fn embedding(&self) -> Option<&[f32]> {
        self.data.first().map(|entry| &entry.embedding[..])
    }

    /// Returns all embedding vectors, in the order they appear in `data`.
    pub fn embeddings(&self) -> Vec<&[f32]> {
        let mut vectors = Vec::with_capacity(self.data.len());
        for entry in &self.data {
            vectors.push(entry.embedding.as_slice());
        }
        vectors
    }
}

/// Individual embedding data.
/// Individual embedding data.
///
/// One element of [`EmbeddingResponse::data`].
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingData {
    /// Object type (always "embedding").
    pub object: String,
    /// Index of this embedding in the batch.
    pub index: usize,
    /// The embedding vector.
    pub embedding: Vec<f32>,
}

/// Token usage for embedding requests.
/// Token usage for embedding requests.
///
/// Each field defaults to 0 when the provider omits it from the response.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct EmbeddingUsage {
    /// Tokens in the prompt/input.
    #[serde(default)]
    pub prompt_tokens: u32,
    /// Total tokens used.
    #[serde(default)]
    pub total_tokens: u32,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_model_info() {
        let info = EmbeddingModel::Qwen3Embedding8B.info();
        assert!(!info.providers.is_empty());
        assert!(info.context_window > 0);
    }

    #[test]
    fn test_all_embedding_models_have_providers() {
        for model in EmbeddingModel::all() {
            let info = model.info();
            assert!(
                !info.providers.is_empty(),
                "Embedding model {} has no providers",
                model.name()
            );
        }
    }

    #[test]
    fn test_display_matches_name() {
        // Display must render the same text as name() for every model.
        for model in EmbeddingModel::all() {
            assert_eq!(model.to_string(), model.name());
        }
    }

    #[test]
    fn test_embedding_request() {
        let request = EmbeddingRequest::new("model", "Hello, world!");
        assert_eq!(request.model, "model");
        // Verify the single-input constructor actually produces the
        // Single variant (previously unchecked).
        match request.input {
            EmbeddingInput::Single(text) => assert_eq!(text, "Hello, world!"),
            EmbeddingInput::Multiple(_) => panic!("Expected Single input"),
        }

        let batch_request =
            EmbeddingRequest::new_batch("model", vec!["Hello".to_string(), "World".to_string()]);
        if let EmbeddingInput::Multiple(inputs) = batch_request.input {
            assert_eq!(inputs.len(), 2);
        } else {
            panic!("Expected Multiple input");
        }
    }

    #[test]
    fn test_embedding_request_builders() {
        // Builder methods must set the optional fields and leave the rest intact.
        let request = EmbeddingRequest::new("model", "text")
            .with_encoding_format("float")
            .with_dimensions(256);
        assert_eq!(request.encoding_format.as_deref(), Some("float"));
        assert_eq!(request.dimensions, Some(256));
        assert_eq!(request.model, "model");
    }

    #[test]
    fn test_embedding_response_parsing() {
        let json = r#"{
            "object": "list",
            "data": [{
                "object": "embedding",
                "index": 0,
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5]
            }],
            "model": "qwen/qwen3-embedding-8b",
            "usage": {
                "prompt_tokens": 5,
                "total_tokens": 5
            }
        }"#;

        let response: EmbeddingResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.embedding().unwrap().len(), 5);
        assert_eq!(response.usage.as_ref().unwrap().prompt_tokens, 5);
    }
}