//! Embedding request/response types, vector similarity helpers, and the
//! provider-agnostic [`EmbeddingProvider`] trait.

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::error::Result;

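/// A request for text embeddings.
///
/// Construct with [`EmbeddingRequest::new`] for a single text or
/// [`EmbeddingRequest::batch`] for multiple texts, then refine it with the
/// builder methods.
///
/// # Examples
///
/// ```ignore
/// // Illustrative only: the crate/module path below is assumed.
/// use your_crate::embedding::EmbeddingRequest;
///
/// let request = EmbeddingRequest::new("text-embedding-3-small", "Hello, world!")
///     .with_dimensions(256);
/// assert_eq!(request.text_count(), 1);
/// ```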
#[derive(Debug, Clone)]
pub struct EmbeddingRequest {
    /// Model identifier, e.g. "text-embedding-3-small".
    pub model: String,
    /// The text or texts to embed.
    pub input: EmbeddingInput,
    /// Requested output dimensionality, for models that support it.
    pub dimensions: Option<usize>,
    /// Encoding of the returned vectors.
    pub encoding_format: Option<EncodingFormat>,
    /// Hint for how the embedded text will be used.
    pub input_type: Option<EmbeddingInputType>,
}

impl EmbeddingRequest {
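    /// Creates a request that embeds a single text with the given model.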
    pub fn new(model: impl Into<String>, text: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            input: EmbeddingInput::Single(text.into()),
            dimensions: None,
            encoding_format: None,
            input_type: None,
        }
    }

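    /// Creates a request that embeds several texts in one call.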
    pub fn batch(model: impl Into<String>, texts: Vec<impl Into<String>>) -> Self {
        Self {
            model: model.into(),
            input: EmbeddingInput::Batch(texts.into_iter().map(|t| t.into()).collect()),
            dimensions: None,
            encoding_format: None,
            input_type: None,
        }
    }

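    /// Sets the requested output dimensionality, for models that support it.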
    pub fn with_dimensions(mut self, dimensions: usize) -> Self {
        self.dimensions = Some(dimensions);
        self
    }

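    /// Sets the encoding format for the returned vectors.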
    pub fn with_encoding_format(mut self, format: EncodingFormat) -> Self {
        self.encoding_format = Some(format);
        self
    }

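    /// Sets the input type hint (query vs. document) for providers that use it.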
    pub fn with_input_type(mut self, input_type: EmbeddingInputType) -> Self {
        self.input_type = Some(input_type);
        self
    }

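    /// Returns the number of texts in this request.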
    pub fn text_count(&self) -> usize {
        match &self.input {
            EmbeddingInput::Single(_) => 1,
            EmbeddingInput::Batch(texts) => texts.len(),
        }
    }

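    /// Returns all input texts as string slices, in order.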
    pub fn texts(&self) -> Vec<&str> {
        match &self.input {
            EmbeddingInput::Single(text) => vec![text.as_str()],
            EmbeddingInput::Batch(texts) => texts.iter().map(|s| s.as_str()).collect(),
        }
    }
}

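/// The input to an embedding request: a single text or a batch of texts.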
#[derive(Debug, Clone)]
pub enum EmbeddingInput {
    /// A single text.
    Single(String),
    /// A batch of texts embedded in one request.
    Batch(Vec<String>),
}

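/// Wire format for the returned embedding vectors.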
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
    /// Vectors returned as arrays of floats (the default).
    #[default]
    Float,
    /// Vectors returned as base64-encoded strings.
    Base64,
}

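/// Hint describing how the embedded text will be used, for providers that
/// distinguish between search queries and indexed documents.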
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingInputType {
    /// The text is a search query.
    Query,
    /// The text is a document to be indexed.
    Document,
}

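/// The response to an embedding request.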
#[derive(Debug, Clone)]
pub struct EmbeddingResponse {
    /// The model that produced the embeddings.
    pub model: String,
    /// One embedding per input text, in request order.
    pub embeddings: Vec<Embedding>,
    /// Token usage reported by the provider.
    pub usage: EmbeddingUsage,
}

impl EmbeddingResponse {
    /// Returns the first embedding, if any.
    pub fn first(&self) -> Option<&Embedding> {
        self.embeddings.first()
    }

    /// Returns the values of the first embedding, if any.
    pub fn values(&self) -> Option<&[f32]> {
        self.first().map(|e| e.values.as_slice())
    }

    /// Returns the dimensionality of the embeddings, or 0 if there are none.
    pub fn dimensions(&self) -> usize {
        self.embeddings.first().map(|e| e.values.len()).unwrap_or(0)
    }
}

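/// A single embedding vector together with its position in the request.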
#[derive(Debug, Clone)]
pub struct Embedding {
    /// Position of this embedding in the request's input order.
    pub index: usize,
    /// The embedding vector.
    pub values: Vec<f32>,
}

impl Embedding {
    /// Creates an embedding from its index and vector values.
    pub fn new(index: usize, values: Vec<f32>) -> Self {
        Self { index, values }
    }

    /// Returns the number of dimensions in this embedding.
    pub fn dimensions(&self) -> usize {
        self.values.len()
    }

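    /// Computes the cosine similarity with another embedding.
    ///
    /// Returns a value in `[-1.0, 1.0]`, or `0.0` if the vectors have
    /// different lengths or either has zero magnitude.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Illustrative only: the crate/module path below is assumed.
    /// use your_crate::embedding::Embedding;
    ///
    /// let a = Embedding::new(0, vec![1.0, 0.0]);
    /// let b = Embedding::new(1, vec![0.0, 1.0]);
    /// assert!(a.cosine_similarity(&b).abs() < 1e-6); // orthogonal vectors
    /// ```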
    pub fn cosine_similarity(&self, other: &Embedding) -> f32 {
        if self.values.len() != other.values.len() {
            return 0.0;
        }

        let dot_product: f32 = self
            .values
            .iter()
            .zip(other.values.iter())
            .map(|(a, b)| a * b)
            .sum();

        let norm_a: f32 = self.values.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = other.values.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }

        dot_product / (norm_a * norm_b)
    }

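    /// Computes the dot product with another embedding.
    ///
    /// If the vectors have different lengths, only the overlapping
    /// components are multiplied.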
    pub fn dot_product(&self, other: &Embedding) -> f32 {
        self.values
            .iter()
            .zip(other.values.iter())
            .map(|(a, b)| a * b)
            .sum()
    }

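    /// Computes the Euclidean distance to another embedding, or
    /// `f32::INFINITY` if the vectors have different lengths.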
    pub fn euclidean_distance(&self, other: &Embedding) -> f32 {
        if self.values.len() != other.values.len() {
            return f32::INFINITY;
        }

        self.values
            .iter()
            .zip(other.values.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt()
    }
}

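/// Token usage for an embedding request.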
#[derive(Debug, Clone, Default)]
pub struct EmbeddingUsage {
    /// Number of tokens in the input texts.
    pub prompt_tokens: u32,
    /// Total number of tokens consumed by the request.
    pub total_tokens: u32,
}

impl EmbeddingUsage {
    /// Creates a usage record from prompt and total token counts.
    pub fn new(prompt_tokens: u32, total_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            total_tokens,
        }
    }
}

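/// A provider that can turn text into embedding vectors.
///
/// Implementors supply [`embed`](EmbeddingProvider::embed); the remaining
/// methods have conservative defaults that can be overridden with
/// provider-specific metadata.
///
/// The sketch below shows how the trait might be consumed; the crate paths
/// are assumed and the helper function is hypothetical.
///
/// ```ignore
/// use your_crate::embedding::{EmbeddingProvider, EmbeddingRequest};
///
/// async fn embed_one(
///     provider: &dyn EmbeddingProvider,
///     text: &str,
/// ) -> your_crate::error::Result<Vec<f32>> {
///     let model = provider.default_embedding_model().unwrap_or("text-embedding-3-small");
///     let response = provider.embed(EmbeddingRequest::new(model, text)).await?;
///     Ok(response.values().unwrap_or_default().to_vec())
/// }
/// ```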
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Returns the provider's name.
    fn name(&self) -> &str;

    /// Generates embeddings for the given request.
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse>;

    /// Returns the output dimensionality of the given model, if known.
    fn embedding_dimensions(&self, _model: &str) -> Option<usize> {
        None
    }

    /// Returns the provider's default embedding model, if it has one.
    fn default_embedding_model(&self) -> Option<&str> {
        None
    }

    /// Returns the maximum number of texts accepted in a single batch.
    fn max_batch_size(&self) -> usize {
        2048
    }

    /// Returns whether the given model supports configurable output dimensions.
    fn supports_dimensions(&self, _model: &str) -> bool {
        false
    }

    /// Returns the embedding models supported by this provider, if known.
    fn supported_embedding_models(&self) -> Option<&[&str]> {
        None
    }
}

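/// Static metadata about a known embedding model.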
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    /// Model identifier.
    pub id: &'static str,
    /// Provider name, e.g. "openai" or "voyage".
    pub provider: &'static str,
    /// Default output dimensionality.
    pub dimensions: usize,
    /// Maximum number of input tokens per text.
    pub max_tokens: usize,
    /// Price per 1,000 input tokens.
    pub pricing_per_1k: f64,
    /// Whether the model supports configurable output dimensions.
    pub supports_dimensions: bool,
}

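/// Registry of well-known embedding models and their metadata.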
pub static EMBEDDING_MODELS: &[EmbeddingModelInfo] = &[
    // OpenAI
    EmbeddingModelInfo {
        id: "text-embedding-3-small",
        provider: "openai",
        dimensions: 1536,
        max_tokens: 8191,
        pricing_per_1k: 0.00002,
        supports_dimensions: true,
    },
    EmbeddingModelInfo {
        id: "text-embedding-3-large",
        provider: "openai",
        dimensions: 3072,
        max_tokens: 8191,
        pricing_per_1k: 0.00013,
        supports_dimensions: true,
    },
    EmbeddingModelInfo {
        id: "text-embedding-ada-002",
        provider: "openai",
        dimensions: 1536,
        max_tokens: 8191,
        pricing_per_1k: 0.0001,
        supports_dimensions: false,
    },
    // Voyage
    EmbeddingModelInfo {
        id: "voyage-3",
        provider: "voyage",
        dimensions: 1024,
        max_tokens: 32000,
        pricing_per_1k: 0.00006,
        supports_dimensions: false,
    },
    EmbeddingModelInfo {
        id: "voyage-3-lite",
        provider: "voyage",
        dimensions: 512,
        max_tokens: 32000,
        pricing_per_1k: 0.00002,
        supports_dimensions: false,
    },
    EmbeddingModelInfo {
        id: "voyage-code-3",
        provider: "voyage",
        dimensions: 1024,
        max_tokens: 32000,
        pricing_per_1k: 0.00006,
        supports_dimensions: false,
    },
    // Jina
    EmbeddingModelInfo {
        id: "jina-embeddings-v3",
        provider: "jina",
        dimensions: 1024,
        max_tokens: 8192,
        pricing_per_1k: 0.00002,
        supports_dimensions: true,
    },
    EmbeddingModelInfo {
        id: "jina-clip-v2",
        provider: "jina",
        dimensions: 1024,
        max_tokens: 8192,
        pricing_per_1k: 0.00002,
        supports_dimensions: false,
    },
    // Cohere
    EmbeddingModelInfo {
        id: "embed-english-v3.0",
        provider: "cohere",
        dimensions: 1024,
        max_tokens: 512,
        pricing_per_1k: 0.0001,
        supports_dimensions: false,
    },
    EmbeddingModelInfo {
        id: "embed-multilingual-v3.0",
        provider: "cohere",
        dimensions: 1024,
        max_tokens: 512,
        pricing_per_1k: 0.0001,
        supports_dimensions: false,
    },
    EmbeddingModelInfo {
        id: "embed-english-light-v3.0",
        provider: "cohere",
        dimensions: 384,
        max_tokens: 512,
        pricing_per_1k: 0.0001,
        supports_dimensions: false,
    },
    // Google
    EmbeddingModelInfo {
        id: "textembedding-gecko@003",
        provider: "google",
        dimensions: 768,
        max_tokens: 3072,
        pricing_per_1k: 0.000025,
        supports_dimensions: false,
    },
    EmbeddingModelInfo {
        id: "text-embedding-004",
        provider: "google",
        dimensions: 768,
        max_tokens: 2048,
        pricing_per_1k: 0.000025,
        supports_dimensions: true,
    },
];

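/// Looks up a model in [`EMBEDDING_MODELS`] by its identifier.
///
/// ```ignore
/// // Illustrative only: the crate/module path below is assumed.
/// use your_crate::embedding::get_embedding_model_info;
///
/// let info = get_embedding_model_info("text-embedding-3-small").unwrap();
/// assert_eq!(info.dimensions, 1536);
/// ```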
pub fn get_embedding_model_info(model_id: &str) -> Option<&'static EmbeddingModelInfo> {
    EMBEDDING_MODELS.iter().find(|m| m.id == model_id)
}

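/// Returns all registered embedding models for the given provider name.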
pub fn get_embedding_models_by_provider(provider: &str) -> Vec<&'static EmbeddingModelInfo> {
    EMBEDDING_MODELS
        .iter()
        .filter(|m| m.provider == provider)
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_request_single() {
        let request = EmbeddingRequest::new("text-embedding-3-small", "Hello, world!");
        assert_eq!(request.model, "text-embedding-3-small");
        assert_eq!(request.text_count(), 1);
        assert_eq!(request.texts(), vec!["Hello, world!"]);
    }

    #[test]
    fn test_embedding_request_batch() {
        let request =
            EmbeddingRequest::batch("text-embedding-3-small", vec!["First", "Second", "Third"]);
        assert_eq!(request.text_count(), 3);
        assert_eq!(request.texts(), vec!["First", "Second", "Third"]);
    }

    #[test]
    fn test_embedding_request_with_dimensions() {
        let request = EmbeddingRequest::new("text-embedding-3-small", "test").with_dimensions(256);
        assert_eq!(request.dimensions, Some(256));
    }

    #[test]
    fn test_cosine_similarity() {
        let e1 = Embedding::new(0, vec![1.0, 0.0, 0.0]);
        let e2 = Embedding::new(1, vec![1.0, 0.0, 0.0]);
        let e3 = Embedding::new(2, vec![0.0, 1.0, 0.0]);

        assert!((e1.cosine_similarity(&e2) - 1.0).abs() < 0.0001);
        assert!((e1.cosine_similarity(&e3) - 0.0).abs() < 0.0001);
    }

    #[test]
    fn test_euclidean_distance() {
        let e1 = Embedding::new(0, vec![0.0, 0.0]);
        let e2 = Embedding::new(1, vec![3.0, 4.0]);

        assert!((e1.euclidean_distance(&e2) - 5.0).abs() < 0.0001);
    }

    #[test]
    fn test_dot_product() {
        let e1 = Embedding::new(0, vec![1.0, 2.0, 3.0]);
        let e2 = Embedding::new(1, vec![4.0, 5.0, 6.0]);

        assert!((e1.dot_product(&e2) - 32.0).abs() < 0.0001);
    }

    #[test]
    fn test_embedding_model_registry() {
        let model = get_embedding_model_info("text-embedding-3-small");
        assert!(model.is_some());
        let model = model.unwrap();
        assert_eq!(model.provider, "openai");
        assert_eq!(model.dimensions, 1536);
        assert!(model.supports_dimensions);
    }

    #[test]
    fn test_get_models_by_provider() {
        let voyage_models = get_embedding_models_by_provider("voyage");
        assert!(!voyage_models.is_empty());
        assert!(voyage_models.iter().all(|m| m.provider == "voyage"));
    }

    #[test]
    fn test_embedding_response() {
        let response = EmbeddingResponse {
            model: "test-model".to_string(),
            embeddings: vec![
                Embedding::new(0, vec![0.1, 0.2, 0.3]),
                Embedding::new(1, vec![0.4, 0.5, 0.6]),
            ],
            usage: EmbeddingUsage::new(10, 10),
        };

        assert_eq!(response.dimensions(), 3);
        assert!(response.first().is_some());
        assert_eq!(response.values().unwrap(), &[0.1, 0.2, 0.3]);
    }
}