// llmkit/embedding.rs

1//! Embedding API for generating text embeddings across multiple providers.
2//!
3//! This module provides a unified interface for generating text embeddings
4//! from various providers including OpenAI, Voyage, Jina, Cohere, and others.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use llmkit::{EmbeddingProvider, EmbeddingRequest, EmbeddingInput};
10//!
11//! // Create provider (OpenAI example)
12//! let provider = OpenAIProvider::from_env()?;
13//!
14//! // Create embedding request
15//! let request = EmbeddingRequest::new("text-embedding-3-small", "Hello, world!");
16//!
17//! // Get embeddings
18//! let response = provider.embed(request).await?;
19//! println!("Embedding dimensions: {}", response.embeddings[0].values.len());
20//! ```
21//!
22//! # Batch Embeddings
23//!
24//! ```ignore
25//! // Batch embed multiple texts
26//! let request = EmbeddingRequest::batch(
27//!     "text-embedding-3-small",
28//!     vec!["First text", "Second text", "Third text"],
29//! );
30//!
31//! let response = provider.embed(request).await?;
32//! for embedding in &response.embeddings {
33//!     println!("Index {}: {} dimensions", embedding.index, embedding.values.len());
34//! }
35//! ```
36
37use async_trait::async_trait;
38use serde::{Deserialize, Serialize};
39
40use crate::error::Result;
41
42/// Request for generating embeddings.
43#[derive(Debug, Clone)]
44pub struct EmbeddingRequest {
45    /// The embedding model to use (e.g., "text-embedding-3-small").
46    pub model: String,
47    /// The input text(s) to embed.
48    pub input: EmbeddingInput,
49    /// Optional: Number of dimensions for the output embedding.
50    /// Only supported by some models (e.g., OpenAI text-embedding-3-*).
51    pub dimensions: Option<usize>,
52    /// Optional: Output encoding format.
53    pub encoding_format: Option<EncodingFormat>,
54    /// Optional: Input type hint for optimized embeddings.
55    /// Supported by some providers like Voyage AI.
56    pub input_type: Option<EmbeddingInputType>,
57}
58
59impl EmbeddingRequest {
60    /// Create a new embedding request for a single text.
61    pub fn new(model: impl Into<String>, text: impl Into<String>) -> Self {
62        Self {
63            model: model.into(),
64            input: EmbeddingInput::Single(text.into()),
65            dimensions: None,
66            encoding_format: None,
67            input_type: None,
68        }
69    }
70
71    /// Create a new embedding request for multiple texts (batch).
72    pub fn batch(model: impl Into<String>, texts: Vec<impl Into<String>>) -> Self {
73        Self {
74            model: model.into(),
75            input: EmbeddingInput::Batch(texts.into_iter().map(|t| t.into()).collect()),
76            dimensions: None,
77            encoding_format: None,
78            input_type: None,
79        }
80    }
81
82    /// Set the output dimensions (for models that support dimension reduction).
83    pub fn with_dimensions(mut self, dimensions: usize) -> Self {
84        self.dimensions = Some(dimensions);
85        self
86    }
87
88    /// Set the encoding format.
89    pub fn with_encoding_format(mut self, format: EncodingFormat) -> Self {
90        self.encoding_format = Some(format);
91        self
92    }
93
94    /// Set the input type hint.
95    pub fn with_input_type(mut self, input_type: EmbeddingInputType) -> Self {
96        self.input_type = Some(input_type);
97        self
98    }
99
100    /// Get the number of texts to embed.
101    pub fn text_count(&self) -> usize {
102        match &self.input {
103            EmbeddingInput::Single(_) => 1,
104            EmbeddingInput::Batch(texts) => texts.len(),
105        }
106    }
107
108    /// Get all input texts as a vector.
109    pub fn texts(&self) -> Vec<&str> {
110        match &self.input {
111            EmbeddingInput::Single(text) => vec![text.as_str()],
112            EmbeddingInput::Batch(texts) => texts.iter().map(|s| s.as_str()).collect(),
113        }
114    }
115}
116
/// Input for embedding requests.
///
/// Use [`EmbeddingInput::Single`] for one text and [`EmbeddingInput::Batch`]
/// to embed several texts in a single request.
#[derive(Debug, Clone)]
pub enum EmbeddingInput {
    /// Single text to embed.
    Single(String),
    /// Multiple texts to embed in a batch.
    Batch(Vec<String>),
}
125
/// Output encoding format for embeddings.
///
/// Serialized in lowercase ("float" / "base64") via the serde attribute,
/// matching the JSON wire format used by provider APIs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
    /// Float32 array (default).
    #[default]
    Float,
    /// Base64-encoded binary.
    Base64,
}
136
/// Input type hint for embedding optimization.
///
/// Some providers (like Voyage AI) optimize embeddings based on the input type.
/// Serialized in snake_case ("query" / "document").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingInputType {
    /// The input is a search query.
    Query,
    /// The input is a document to be indexed.
    Document,
}
148
/// Response from an embedding request.
///
/// Holds one [`Embedding`] per input text; each embedding carries its own
/// `index` into the original batch.
#[derive(Debug, Clone)]
pub struct EmbeddingResponse {
    /// The model used for embedding.
    pub model: String,
    /// The generated embeddings.
    pub embeddings: Vec<Embedding>,
    /// Token usage information.
    pub usage: EmbeddingUsage,
}
159
160impl EmbeddingResponse {
161    /// Get the first embedding (convenience for single-text requests).
162    pub fn first(&self) -> Option<&Embedding> {
163        self.embeddings.first()
164    }
165
166    /// Get embedding values as a flat vector (for single-text requests).
167    pub fn values(&self) -> Option<&[f32]> {
168        self.first().map(|e| e.values.as_slice())
169    }
170
171    /// Get the embedding dimensions.
172    pub fn dimensions(&self) -> usize {
173        self.embeddings.first().map(|e| e.values.len()).unwrap_or(0)
174    }
175}
176
/// A single embedding vector.
///
/// `index` ties the vector back to its position in the request batch,
/// since providers may return results keyed by index.
#[derive(Debug, Clone)]
pub struct Embedding {
    /// The index of this embedding in the batch.
    pub index: usize,
    /// The embedding vector values.
    pub values: Vec<f32>,
}
185
186impl Embedding {
187    /// Create a new embedding.
188    pub fn new(index: usize, values: Vec<f32>) -> Self {
189        Self { index, values }
190    }
191
192    /// Get the number of dimensions.
193    pub fn dimensions(&self) -> usize {
194        self.values.len()
195    }
196
197    /// Compute cosine similarity with another embedding.
198    pub fn cosine_similarity(&self, other: &Embedding) -> f32 {
199        if self.values.len() != other.values.len() {
200            return 0.0;
201        }
202
203        let dot_product: f32 = self
204            .values
205            .iter()
206            .zip(other.values.iter())
207            .map(|(a, b)| a * b)
208            .sum();
209
210        let norm_a: f32 = self.values.iter().map(|x| x * x).sum::<f32>().sqrt();
211        let norm_b: f32 = other.values.iter().map(|x| x * x).sum::<f32>().sqrt();
212
213        if norm_a == 0.0 || norm_b == 0.0 {
214            return 0.0;
215        }
216
217        dot_product / (norm_a * norm_b)
218    }
219
220    /// Compute dot product with another embedding.
221    pub fn dot_product(&self, other: &Embedding) -> f32 {
222        self.values
223            .iter()
224            .zip(other.values.iter())
225            .map(|(a, b)| a * b)
226            .sum()
227    }
228
229    /// Compute Euclidean distance to another embedding.
230    pub fn euclidean_distance(&self, other: &Embedding) -> f32 {
231        if self.values.len() != other.values.len() {
232            return f32::INFINITY;
233        }
234
235        self.values
236            .iter()
237            .zip(other.values.iter())
238            .map(|(a, b)| (a - b).powi(2))
239            .sum::<f32>()
240            .sqrt()
241    }
242}
243
244/// Token usage for embedding requests.
245#[derive(Debug, Clone, Default)]
246pub struct EmbeddingUsage {
247    /// Number of tokens in the input.
248    pub prompt_tokens: u32,
249    /// Total tokens processed.
250    pub total_tokens: u32,
251}
252
253impl EmbeddingUsage {
254    /// Create new usage stats.
255    pub fn new(prompt_tokens: u32, total_tokens: u32) -> Self {
256        Self {
257            prompt_tokens,
258            total_tokens,
259        }
260    }
261}
262
/// Trait for providers that support text embeddings.
///
/// Only [`name`](EmbeddingProvider::name) and
/// [`embed`](EmbeddingProvider::embed) are required; the remaining methods
/// have conservative defaults that concrete providers can override.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Get the provider name.
    fn name(&self) -> &str;

    /// Generate embeddings for the given request.
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse>;

    /// Get the default embedding dimensions for a model.
    ///
    /// Returns `None` if the model is unknown or dimensions are variable.
    /// Default implementation knows no models.
    fn embedding_dimensions(&self, _model: &str) -> Option<usize> {
        None
    }

    /// Get the default embedding model for this provider.
    /// Defaults to `None` (no preferred model).
    fn default_embedding_model(&self) -> Option<&str> {
        None
    }

    /// Get the maximum batch size supported by this provider.
    /// Defaults to 2048 — presumably a typical provider cap; providers with
    /// stricter limits should override this.
    fn max_batch_size(&self) -> usize {
        2048
    }

    /// Check if a model supports dimension reduction.
    /// Defaults to `false` (assume no support unless the provider says so).
    fn supports_dimensions(&self, _model: &str) -> bool {
        false
    }

    /// Get all supported embedding models.
    /// Returns `None` by default, meaning the model list is unknown/open.
    fn supported_embedding_models(&self) -> Option<&[&str]> {
        None
    }
}
299
/// Information about an embedding model.
///
/// Entries are `'static` so they can live in the [`EMBEDDING_MODELS`] table.
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    /// Model ID/name.
    pub id: &'static str,
    /// Provider that offers this model.
    pub provider: &'static str,
    /// Default output dimensions.
    pub dimensions: usize,
    /// Maximum input tokens.
    pub max_tokens: usize,
    /// Price per 1K tokens (USD).
    pub pricing_per_1k: f64,
    /// Whether the model supports dimension reduction.
    pub supports_dimensions: bool,
}
316
/// Registry of known embedding models.
///
/// NOTE(review): dimensions, token limits, and pricing are point-in-time
/// snapshots of provider documentation and can drift — verify against the
/// provider's current docs before relying on them for billing.
pub static EMBEDDING_MODELS: &[EmbeddingModelInfo] = &[
    // OpenAI
    EmbeddingModelInfo {
        id: "text-embedding-3-small",
        provider: "openai",
        dimensions: 1536,
        max_tokens: 8191,
        pricing_per_1k: 0.00002,
        supports_dimensions: true,
    },
    EmbeddingModelInfo {
        id: "text-embedding-3-large",
        provider: "openai",
        dimensions: 3072,
        max_tokens: 8191,
        pricing_per_1k: 0.00013,
        supports_dimensions: true,
    },
    EmbeddingModelInfo {
        id: "text-embedding-ada-002",
        provider: "openai",
        dimensions: 1536,
        max_tokens: 8191,
        pricing_per_1k: 0.0001,
        supports_dimensions: false,
    },
    // Voyage AI
    EmbeddingModelInfo {
        id: "voyage-3",
        provider: "voyage",
        dimensions: 1024,
        max_tokens: 32000,
        pricing_per_1k: 0.00006,
        supports_dimensions: false,
    },
    EmbeddingModelInfo {
        id: "voyage-3-lite",
        provider: "voyage",
        dimensions: 512,
        max_tokens: 32000,
        pricing_per_1k: 0.00002,
        supports_dimensions: false,
    },
    EmbeddingModelInfo {
        id: "voyage-code-3",
        provider: "voyage",
        dimensions: 1024,
        max_tokens: 32000,
        pricing_per_1k: 0.00006,
        supports_dimensions: false,
    },
    // Jina AI
    EmbeddingModelInfo {
        id: "jina-embeddings-v3",
        provider: "jina",
        dimensions: 1024,
        max_tokens: 8192,
        pricing_per_1k: 0.00002,
        supports_dimensions: true,
    },
    EmbeddingModelInfo {
        id: "jina-clip-v2",
        provider: "jina",
        dimensions: 1024,
        max_tokens: 8192,
        pricing_per_1k: 0.00002,
        supports_dimensions: false,
    },
    // Cohere
    EmbeddingModelInfo {
        id: "embed-english-v3.0",
        provider: "cohere",
        dimensions: 1024,
        max_tokens: 512,
        pricing_per_1k: 0.0001,
        supports_dimensions: false,
    },
    EmbeddingModelInfo {
        id: "embed-multilingual-v3.0",
        provider: "cohere",
        dimensions: 1024,
        max_tokens: 512,
        pricing_per_1k: 0.0001,
        supports_dimensions: false,
    },
    EmbeddingModelInfo {
        id: "embed-english-light-v3.0",
        provider: "cohere",
        dimensions: 384,
        max_tokens: 512,
        pricing_per_1k: 0.0001,
        supports_dimensions: false,
    },
    // Google
    EmbeddingModelInfo {
        id: "textembedding-gecko@003",
        provider: "google",
        dimensions: 768,
        max_tokens: 3072,
        pricing_per_1k: 0.000025,
        supports_dimensions: false,
    },
    EmbeddingModelInfo {
        id: "text-embedding-004",
        provider: "google",
        dimensions: 768,
        max_tokens: 2048,
        pricing_per_1k: 0.000025,
        supports_dimensions: true,
    },
];
429
430/// Get embedding model info by ID.
431pub fn get_embedding_model_info(model_id: &str) -> Option<&'static EmbeddingModelInfo> {
432    EMBEDDING_MODELS.iter().find(|m| m.id == model_id)
433}
434
435/// Get all embedding models for a provider.
436pub fn get_embedding_models_by_provider(provider: &str) -> Vec<&'static EmbeddingModelInfo> {
437    EMBEDDING_MODELS
438        .iter()
439        .filter(|m| m.provider == provider)
440        .collect()
441}
442
// Unit tests: request construction, vector math, and registry lookups.
#[cfg(test)]
mod tests {
    use super::*;

    // Single-text constructor stores the model and exposes one text.
    #[test]
    fn test_embedding_request_single() {
        let request = EmbeddingRequest::new("text-embedding-3-small", "Hello, world!");
        assert_eq!(request.model, "text-embedding-3-small");
        assert_eq!(request.text_count(), 1);
        assert_eq!(request.texts(), vec!["Hello, world!"]);
    }

    // Batch constructor preserves input order and count.
    #[test]
    fn test_embedding_request_batch() {
        let request =
            EmbeddingRequest::batch("text-embedding-3-small", vec!["First", "Second", "Third"]);
        assert_eq!(request.text_count(), 3);
        assert_eq!(request.texts(), vec!["First", "Second", "Third"]);
    }

    // Builder method sets the optional dimensions field.
    #[test]
    fn test_embedding_request_with_dimensions() {
        let request = EmbeddingRequest::new("text-embedding-3-small", "test").with_dimensions(256);
        assert_eq!(request.dimensions, Some(256));
    }

    // Identical unit vectors → similarity 1; orthogonal unit vectors → 0.
    #[test]
    fn test_cosine_similarity() {
        let e1 = Embedding::new(0, vec![1.0, 0.0, 0.0]);
        let e2 = Embedding::new(1, vec![1.0, 0.0, 0.0]);
        let e3 = Embedding::new(2, vec![0.0, 1.0, 0.0]);

        assert!((e1.cosine_similarity(&e2) - 1.0).abs() < 0.0001);
        assert!((e1.cosine_similarity(&e3) - 0.0).abs() < 0.0001);
    }

    // Classic 3-4-5 right triangle distance check.
    #[test]
    fn test_euclidean_distance() {
        let e1 = Embedding::new(0, vec![0.0, 0.0]);
        let e2 = Embedding::new(1, vec![3.0, 4.0]);

        assert!((e1.euclidean_distance(&e2) - 5.0).abs() < 0.0001);
    }

    #[test]
    fn test_dot_product() {
        let e1 = Embedding::new(0, vec![1.0, 2.0, 3.0]);
        let e2 = Embedding::new(1, vec![4.0, 5.0, 6.0]);

        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert!((e1.dot_product(&e2) - 32.0).abs() < 0.0001);
    }

    // Registry lookup by exact model ID returns the expected metadata.
    #[test]
    fn test_embedding_model_registry() {
        let model = get_embedding_model_info("text-embedding-3-small");
        assert!(model.is_some());
        let model = model.unwrap();
        assert_eq!(model.provider, "openai");
        assert_eq!(model.dimensions, 1536);
        assert!(model.supports_dimensions);
    }

    // Provider filter returns a non-empty, homogeneous list.
    #[test]
    fn test_get_models_by_provider() {
        let voyage_models = get_embedding_models_by_provider("voyage");
        assert!(!voyage_models.is_empty());
        assert!(voyage_models.iter().all(|m| m.provider == "voyage"));
    }

    // Response accessors: dimensions, first(), and values() of first embedding.
    #[test]
    fn test_embedding_response() {
        let response = EmbeddingResponse {
            model: "test-model".to_string(),
            embeddings: vec![
                Embedding::new(0, vec![0.1, 0.2, 0.3]),
                Embedding::new(1, vec![0.4, 0.5, 0.6]),
            ],
            usage: EmbeddingUsage::new(10, 10),
        };

        assert_eq!(response.dimensions(), 3);
        assert!(response.first().is_some());
        assert_eq!(response.values().unwrap(), &[0.1, 0.2, 0.3]);
    }
}