use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
/// Embeddings provider for Mistral models, holding the API key it was built with.
///
/// NOTE(review): the `embed`/`embed_batch` methods in this file currently return
/// placeholder vectors and never read the stored key.
pub struct MistralEmbeddingsProvider {
    // Only the unit tests read this field today, so non-test builds would warn
    // about it being unused; presumably it will be used once real requests exist.
    #[allow(dead_code)]
    api_key: String,
}
impl MistralEmbeddingsProvider {
pub fn new(api_key: &str) -> Self {
Self {
api_key: api_key.to_string(),
}
}
pub fn from_env() -> Result<Self> {
let api_key = std::env::var("MISTRAL_API_KEY").map_err(|_| {
Error::Configuration("MISTRAL_API_KEY environment variable not set".to_string())
})?;
Ok(Self::new(&api_key))
}
pub async fn embed(&self, text: &str, model: &str) -> Result<EmbeddingResponse> {
Ok(EmbeddingResponse {
model: model.to_string(),
object: "list".to_string(),
data: vec![EmbeddingData {
object: "embedding".to_string(),
embedding: vec![0.1; 1024], index: 0,
}],
usage: EmbeddingUsage {
prompt_tokens: count_tokens(text),
total_tokens: count_tokens(text),
},
})
}
pub async fn embed_batch(&self, texts: &[&str], model: &str) -> Result<EmbeddingResponse> {
let mut data = Vec::new();
for (index, _text) in texts.iter().enumerate() {
data.push(EmbeddingData {
object: "embedding".to_string(),
embedding: vec![0.1; 1024], index: index as u32,
});
}
Ok(EmbeddingResponse {
model: model.to_string(),
object: "list".to_string(),
data,
usage: EmbeddingUsage {
prompt_tokens: texts.iter().map(|t| count_tokens(t)).sum(),
total_tokens: texts.iter().map(|t| count_tokens(t)).sum(),
},
})
}
}
/// Result of an embeddings call. It is serializable with serde.
///
/// Field order is significant: it determines the order of keys when serialized.
/// NOTE(review): the layout presumably mirrors the Mistral embeddings JSON response.
/// Confirm against the API reference.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddingResponse {
    /// Model identifier the request was made with. The provider echoes it back.
    pub model: String,
    /// Object kind tag. The provider in this file always sets it to `"list"`.
    pub object: String,
    /// One entry per embedded input.
    pub data: Vec<EmbeddingData>,
    /// Token accounting for the whole call.
    pub usage: EmbeddingUsage,
}
/// A single embedding vector together with its position in the request batch.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddingData {
    /// Object kind tag. The provider in this file always sets it to `"embedding"`.
    pub object: String,
    /// Embedding components.
    pub embedding: Vec<f32>,
    /// Zero-based position of the corresponding input in the batch.
    pub index: u32,
}
/// Token accounting attached to an [`EmbeddingResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddingUsage {
    /// Tokens attributed to the inputs. In this file this is a heuristic estimate.
    pub prompt_tokens: u32,
    /// Total tokens for the call. The provider in this file sets it equal to
    /// `prompt_tokens`.
    pub total_tokens: u32,
}
/// Roughly estimates how many tokens `text` contains.
///
/// The heuristic is one token per four bytes of UTF-8, plus one. As a result every
/// input counts as at least one token, including the empty string. Multi-byte
/// characters count by their byte length, not as a single character.
///
/// For inputs too large for `u32`, the estimate saturates at `u32::MAX` rather than
/// being silently truncated by a narrowing cast.
fn count_tokens(text: &str) -> u32 {
    // `len / 4 + 1` cannot overflow `usize`; only the narrowing to u32 can.
    let estimate = text.len() / 4 + 1;
    u32::try_from(estimate).unwrap_or(u32::MAX)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mistral_embeddings_creation() {
        let provider = MistralEmbeddingsProvider::new("test-api-key");
        assert!(!provider.api_key.is_empty());
    }

    #[tokio::test]
    async fn test_embed_single() {
        let provider = MistralEmbeddingsProvider::new("test-key");
        let response = provider
            .embed("Hello world", "mistral-embed")
            .await
            .unwrap();
        assert_eq!(response.model, "mistral-embed");
        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.data[0].index, 0);
        assert_eq!(response.data[0].embedding.len(), 1024);
        assert_eq!(response.usage.prompt_tokens, count_tokens("Hello world"));
        assert_eq!(response.usage.total_tokens, response.usage.prompt_tokens);
    }

    #[tokio::test]
    async fn test_embed_batch() {
        let provider = MistralEmbeddingsProvider::new("test-key");
        let texts = ["Hello", "World", "Test"];
        let response = provider.embed_batch(&texts, "mistral-embed").await.unwrap();
        assert_eq!(response.data.len(), 3);
        assert_eq!(response.usage.prompt_tokens, response.usage.total_tokens);

        // Entries must come back in input order, each tagged with its position.
        let indices: Vec<u32> = response.data.iter().map(|entry| entry.index).collect();
        assert_eq!(indices, [0, 1, 2]);

        let expected: u32 = texts.iter().map(|text| count_tokens(text)).sum();
        assert_eq!(response.usage.prompt_tokens, expected);
    }

    #[tokio::test]
    async fn test_embed_batch_empty() {
        let provider = MistralEmbeddingsProvider::new("test-key");
        let response = provider.embed_batch(&[], "mistral-embed").await.unwrap();
        assert!(response.data.is_empty());
        assert_eq!(response.usage.prompt_tokens, 0);
        assert_eq!(response.usage.total_tokens, 0);
    }

    #[test]
    fn test_token_counter() {
        assert!(count_tokens("Hello world") > 0);
        assert!(count_tokens("This is a longer text") > count_tokens("Hi"));
    }

    #[test]
    fn test_embedding_data() {
        let embedding = EmbeddingData {
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
            index: 0,
        };
        assert_eq!(embedding.embedding.len(), 3);
        assert_eq!(embedding.index, 0);
    }
}