llmkit 0.1.3

Production-grade LLM client - 100+ providers, 11,000+ models. Pure Rust.
Documentation
//! Mistral embeddings provider for vector representations.
//!
//! Mistral provides high-quality embedding models optimized for
//! semantic search, RAG, and similarity matching.
//!
//! # Models
//! - mistral-embed: General-purpose embeddings
//! - mistral-large-latest: Can be used for embeddings

use serde::{Deserialize, Serialize};

use crate::error::{Error, Result};

/// Mistral embeddings provider
pub struct MistralEmbeddingsProvider {
    #[allow(dead_code)]
    api_key: String,
}

impl MistralEmbeddingsProvider {
    /// Create a new Mistral embeddings provider
    pub fn new(api_key: &str) -> Self {
        Self {
            api_key: api_key.to_string(),
        }
    }

    /// Create from environment variable `MISTRAL_API_KEY`
    pub fn from_env() -> Result<Self> {
        let api_key = std::env::var("MISTRAL_API_KEY").map_err(|_| {
            Error::Configuration("MISTRAL_API_KEY environment variable not set".to_string())
        })?;
        Ok(Self::new(&api_key))
    }

    /// Generate embeddings for text
    pub async fn embed(&self, text: &str, model: &str) -> Result<EmbeddingResponse> {
        // In a real implementation, this would call Mistral API
        // For now, return a mock response
        Ok(EmbeddingResponse {
            model: model.to_string(),
            object: "list".to_string(),
            data: vec![EmbeddingData {
                object: "embedding".to_string(),
                embedding: vec![0.1; 1024], // Mock 1024-dim embedding
                index: 0,
            }],
            usage: EmbeddingUsage {
                prompt_tokens: count_tokens(text),
                total_tokens: count_tokens(text),
            },
        })
    }

    /// Generate embeddings for multiple texts (batch)
    pub async fn embed_batch(&self, texts: &[&str], model: &str) -> Result<EmbeddingResponse> {
        // In a real implementation, this would handle batching efficiently
        let mut data = Vec::new();
        for (index, _text) in texts.iter().enumerate() {
            data.push(EmbeddingData {
                object: "embedding".to_string(),
                embedding: vec![0.1; 1024], // Mock embedding
                index: index as u32,
            });
        }

        Ok(EmbeddingResponse {
            model: model.to_string(),
            object: "list".to_string(),
            data,
            usage: EmbeddingUsage {
                prompt_tokens: texts.iter().map(|t| count_tokens(t)).sum(),
                total_tokens: texts.iter().map(|t| count_tokens(t)).sum(),
            },
        })
    }
}

/// Embedding response from Mistral
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Model name used
    pub model: String,
    /// Object type
    pub object: String,
    /// List of embeddings
    pub data: Vec<EmbeddingData>,
    /// Token usage
    pub usage: EmbeddingUsage,
}

/// Individual embedding data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingData {
    /// Object type
    pub object: String,
    /// Embedding vector
    pub embedding: Vec<f32>,
    /// Index in batch
    pub index: u32,
}

/// Token usage for embeddings
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingUsage {
    /// Prompt tokens
    pub prompt_tokens: u32,
    /// Total tokens
    pub total_tokens: u32,
}

/// Simple token counter (approximate)
fn count_tokens(text: &str) -> u32 {
    // Approximate: ~4 chars per token
    ((text.len() / 4) + 1) as u32
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mistral_embeddings_creation() {
        let provider = MistralEmbeddingsProvider::new("test-api-key");
        assert!(!provider.api_key.is_empty());
    }

    #[tokio::test]
    async fn test_embed_single() {
        let provider = MistralEmbeddingsProvider::new("test-key");
        let response = provider
            .embed("Hello world", "mistral-embed")
            .await
            .unwrap();

        assert_eq!(response.model, "mistral-embed");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.data[0].embedding.len(), 1024);
    }

    #[tokio::test]
    async fn test_embed_batch() {
        let provider = MistralEmbeddingsProvider::new("test-key");
        let texts = vec!["Hello", "World", "Test"];
        let response = provider.embed_batch(&texts, "mistral-embed").await.unwrap();

        assert_eq!(response.data.len(), 3);
        assert_eq!(response.usage.prompt_tokens, response.usage.total_tokens);
    }

    #[test]
    fn test_token_counter() {
        assert!(count_tokens("Hello world") > 0);
        assert!(count_tokens("This is a longer text") > count_tokens("Hi"));
    }

    #[test]
    fn test_embedding_data() {
        let embedding = EmbeddingData {
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
            index: 0,
        };
        assert_eq!(embedding.embedding.len(), 3);
        assert_eq!(embedding.index, 0);
    }
}