rrag/embeddings.rs
//! # RRAG Embeddings System
//!
//! High-performance, async-first embedding generation with pluggable providers,
//! efficient batching, and comprehensive error handling. Built for production
//! workloads with robust retry logic and monitoring capabilities.
//!
//! ## Features
//!
//! - **Multi-Provider Support**: OpenAI, local models, and custom providers
//! - **Efficient Batching**: Automatic batching with configurable sizes
//! - **Retry Logic**: Robust error handling with exponential backoff
//! - **Parallel Processing**: Concurrent embedding generation
//! - **Similarity Metrics**: Built-in cosine similarity and distance calculations
//! - **Metadata Support**: Rich metadata tracking for embeddings
//! - **Health Monitoring**: Provider health checks and monitoring
//!
//! ## Quick Start
//!
//! ```rust
//! use rrag::prelude::*;
//! use std::sync::Arc;
//!
//! # #[tokio::main]
//! # async fn main() -> RragResult<()> {
//! // Create an OpenAI provider
//! let provider = Arc::new(OpenAIEmbeddingProvider::new("your-api-key")
//!     .with_model("text-embedding-3-small"));
//!
//! // Create embedding service
//! let service = EmbeddingService::new(provider);
//!
//! // Embed a document
//! let document = Document::new("The quick brown fox jumps over the lazy dog");
//! let embedding = service.embed_document(&document).await?;
//!
//! println!("Generated embedding with {} dimensions", embedding.dimensions);
//! # Ok(())
//! # }
//! ```
//!
//! ## Batch Processing
//!
//! For high-throughput scenarios, use batch processing:
//!
//! ```rust
//! use rrag::prelude::*;
//! use std::sync::Arc;
//!
//! # #[tokio::main]
//! # async fn main() -> RragResult<()> {
//! let provider = Arc::new(LocalEmbeddingProvider::new(
//!     "sentence-transformers/all-MiniLM-L6-v2",
//!     384,
//! ));
//! let service = EmbeddingService::with_config(
//!     provider,
//!     EmbeddingConfig {
//!         batch_size: 50,
//!         parallel_processing: true,
//!         max_retries: 3,
//!         retry_delay_ms: 1000,
//!     },
//! );
//!
//! let documents = vec![
//!     Document::new("First document"),
//!     Document::new("Second document"),
//!     Document::new("Third document"),
//! ];
//!
//! let embeddings = service.embed_documents(&documents).await?;
//! println!("Generated {} embeddings", embeddings.len());
//! # Ok(())
//! # }
//! ```
//!
//! ## Custom Providers
//!
//! Implement the [`EmbeddingProvider`] trait to create custom providers:
//!
//! ```rust
//! use rrag::prelude::*;
//! use async_trait::async_trait;
//!
//! struct MyCustomProvider {
//!     model_name: String,
//! }
//!
//! #[async_trait]
//! impl EmbeddingProvider for MyCustomProvider {
//!     fn name(&self) -> &str { "custom" }
//!     fn supported_models(&self) -> Vec<&str> { vec!["custom-model-v1"] }
//!     fn max_batch_size(&self) -> usize { 32 }
//!     fn embedding_dimensions(&self) -> usize { 512 }
//!
//!     async fn embed_text(&self, text: &str) -> RragResult<Embedding> {
//!         // Your custom embedding logic here
//!         # let vector = vec![0.0; 512];
//!         # Ok(Embedding::new(vector, &self.model_name, text))
//!     }
//!
//!     async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch> {
//!         // Your custom batch processing logic
//!         # todo!()
//!     }
//!
//!     async fn health_check(&self) -> RragResult<bool> {
//!         // Your health check logic
//!         Ok(true)
//!     }
//! }
//! ```
//!
//! ## Performance Considerations
//!
//! - Use batch processing for multiple documents to reduce API overhead
//! - Configure appropriate batch sizes based on your provider's limits
//! - Enable parallel processing for local models to utilize multiple CPU cores
//! - Monitor provider health and implement fallback strategies
//! - Cache embeddings when possible to avoid redundant API calls
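//!
//! For example, a minimal content-keyed cache (a sketch only; `MockEmbeddingProvider`
//! stands in for a real provider, and a production cache would bound its size and
//! key on a content hash rather than the full text):
//!
//! ```rust
//! use rrag::prelude::*;
//! use std::collections::HashMap;
//! use std::sync::Arc;
//!
//! # #[tokio::main]
//! # async fn main() -> RragResult<()> {
//! let service = EmbeddingService::new(Arc::new(MockEmbeddingProvider::new()));
//! let mut cache: HashMap<String, Embedding> = HashMap::new();
//!
//! let document = Document::new("Repeated content");
//! let key = document.content_str().to_string();
//! if !cache.contains_key(&key) {
//!     // Only pay the embedding cost on a cache miss.
//!     let embedding = service.embed_document(&document).await?;
//!     cache.insert(key.clone(), embedding);
//! }
//! assert!(cache.contains_key(&key));
//! # Ok(())
//! # }
//! ```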
//!
//! ## Error Handling
//!
//! The embedding system provides detailed error information:
//!
//! ```rust
//! use rrag::prelude::*;
//! use std::sync::Arc;
//!
//! # #[tokio::main]
//! # async fn main() {
//! # let service = EmbeddingService::new(Arc::new(MockEmbeddingProvider::new()));
//! # let document = Document::new("Example content");
//! match service.embed_document(&document).await {
//!     Ok(embedding) => {
//!         println!("Success: {} dimensions", embedding.dimensions);
//!     }
//!     Err(RragError::Embedding { content_type, message, .. }) => {
//!         eprintln!("Embedding error for {}: {}", content_type, message);
//!     }
//!     Err(e) => {
//!         eprintln!("Other error: {}", e);
//!     }
//! }
//! # }
//! ```

use crate::{Document, DocumentChunk, RragError, RragResult};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

/// Embedding vector type optimized for common dimensions
///
/// Uses `Vec<f32>` for maximum compatibility with ML libraries and APIs.
/// Common dimensions:
/// - 384: sentence-transformers/all-MiniLM-L6-v2
/// - 768: BERT-base models
/// - 1536: OpenAI text-embedding-ada-002
/// - 3072: OpenAI text-embedding-3-large
pub type EmbeddingVector = Vec<f32>;

/// A dense vector representation of text content with metadata
///
/// Embeddings are high-dimensional vectors that capture semantic meaning
/// of text in a format suitable for similarity comparison and retrieval.
/// Each embedding includes the vector data, source information, and
/// generation metadata.
///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
///
/// let vector = vec![0.1, -0.2, 0.3, 0.4]; // 4-dimensional example
/// let embedding = Embedding::new(vector, "text-embedding-ada-002", "doc-123")
///     .with_metadata("content_type", "paragraph".into())
///     .with_metadata("language", "en".into());
///
/// assert_eq!(embedding.dimensions, 4);
/// assert_eq!(embedding.model, "text-embedding-ada-002");
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// The actual embedding vector
    pub vector: EmbeddingVector,

    /// Dimensions of the embedding
    pub dimensions: usize,

    /// Model used to generate this embedding
    pub model: String,

    /// Source content identifier
    pub source_id: String,

    /// Embedding metadata
    pub metadata: HashMap<String, serde_json::Value>,

    /// Generation timestamp
    pub created_at: chrono::DateTime<chrono::Utc>,
}

impl Embedding {
    /// Create a new embedding
    pub fn new(
        vector: EmbeddingVector,
        model: impl Into<String>,
        source_id: impl Into<String>,
    ) -> Self {
        let dimensions = vector.len();
        Self {
            vector,
            dimensions,
            model: model.into(),
            source_id: source_id.into(),
            metadata: HashMap::new(),
            created_at: chrono::Utc::now(),
        }
    }

    /// Add metadata using builder pattern
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }

    /// Calculate cosine similarity with another embedding
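    ///
    /// Returns an error if the embeddings have different dimensions.
    ///
    /// # Example
    ///
    /// ```rust
    /// use rrag::prelude::*;
    ///
    /// let a = Embedding::new(vec![1.0, 0.0], "model", "a");
    /// let b = Embedding::new(vec![0.0, 1.0], "model", "b");
    ///
    /// // Orthogonal vectors have a cosine similarity of zero.
    /// let similarity = a.cosine_similarity(&b).unwrap();
    /// assert!(similarity.abs() < 1e-6);
    /// ```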
    pub fn cosine_similarity(&self, other: &Embedding) -> RragResult<f32> {
        if self.dimensions != other.dimensions {
            return Err(RragError::embedding(
                "similarity_calculation",
                format!(
                    "Dimension mismatch: {} vs {}",
                    self.dimensions, other.dimensions
                ),
            ));
        }

        let dot_product: f32 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| a * b)
            .sum();

        let norm_a: f32 = self.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = other.vector.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            return Ok(0.0);
        }

        Ok(dot_product / (norm_a * norm_b))
    }

    /// Calculate Euclidean distance with another embedding
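    ///
    /// Returns an error if the embeddings have different dimensions.
    ///
    /// # Example
    ///
    /// ```rust
    /// use rrag::prelude::*;
    ///
    /// let a = Embedding::new(vec![0.0, 0.0], "model", "a");
    /// let b = Embedding::new(vec![3.0, 4.0], "model", "b");
    ///
    /// // A 3-4-5 right triangle: the distance is 5.
    /// let distance = a.euclidean_distance(&b).unwrap();
    /// assert!((distance - 5.0).abs() < 1e-6);
    /// ```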
    pub fn euclidean_distance(&self, other: &Embedding) -> RragResult<f32> {
        if self.dimensions != other.dimensions {
            return Err(RragError::embedding(
                "distance_calculation",
                format!(
                    "Dimension mismatch: {} vs {}",
                    self.dimensions, other.dimensions
                ),
            ));
        }

        let distance: f32 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        Ok(distance)
    }
}

/// A request for embedding generation, used in batch processing
///
/// Embedding requests bundle text content with metadata and a unique
/// identifier for efficient batch processing and result tracking.
///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
///
/// let request = EmbeddingRequest::new("chunk-1", "The quick brown fox")
///     .with_metadata("chunk_index", 0.into())
///     .with_metadata("document_id", "doc-123".into());
/// ```
#[derive(Debug, Clone)]
pub struct EmbeddingRequest {
    /// Unique identifier for the request
    pub id: String,

    /// Text content to embed
    pub content: String,

    /// Optional metadata to attach to the embedding
    pub metadata: HashMap<String, serde_json::Value>,
}

impl EmbeddingRequest {
    /// Create a new embedding request
    pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            content: content.into(),
            metadata: HashMap::new(),
        }
    }

    /// Add metadata to the embedding request
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }
}

/// Batch embedding response
#[derive(Debug)]
pub struct EmbeddingBatch {
    /// Generated embeddings indexed by request ID
    pub embeddings: HashMap<String, Embedding>,

    /// Processing metadata
    pub metadata: BatchMetadata,
}

/// Metadata for batch processing
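///
/// # Example
///
/// A sketch of deriving a success rate from batch metadata (the struct
/// literal below is purely illustrative):
///
/// ```rust
/// use rrag::prelude::*;
/// use std::collections::HashMap;
///
/// # let mut failed_items = HashMap::new();
/// # failed_items.insert("doc-7".to_string(), "timeout".to_string());
/// # let metadata = BatchMetadata {
/// #     total_items: 10,
/// #     successful_items: 9,
/// #     failed_items,
/// #     duration_ms: 42,
/// #     model: "mock-model".to_string(),
/// #     provider: "mock".to_string(),
/// # };
/// let success_rate = metadata.successful_items as f64 / metadata.total_items as f64;
/// assert!((success_rate - 0.9).abs() < 1e-9);
/// ```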
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchMetadata {
    /// Total items processed
    pub total_items: usize,

    /// Successfully processed items
    pub successful_items: usize,

    /// Failed items with error messages
    pub failed_items: HashMap<String, String>,

    /// Processing duration in milliseconds
    pub duration_ms: u64,

    /// Model used for embedding
    pub model: String,

    /// Provider used
    pub provider: String,
}

/// Core trait for embedding providers
///
/// This trait defines the interface that all embedding providers must implement.
/// Providers can be remote APIs (OpenAI, Cohere), local models (Hugging Face),
/// or custom implementations.
///
/// The trait is designed for async operation and includes methods for:
/// - Single text embedding
/// - Batch processing for efficiency
/// - Health monitoring
/// - Provider introspection
///
/// # Implementation Guidelines
///
/// - Implement proper error handling with detailed context
/// - Use appropriate timeouts for network operations
/// - Support cancellation via async context
/// - Provide accurate metadata in batch responses
/// - Implement health checks that verify actual connectivity
///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
/// use async_trait::async_trait;
///
/// struct MyProvider;
///
/// #[async_trait]
/// impl EmbeddingProvider for MyProvider {
///     fn name(&self) -> &str { "my-provider" }
///     fn supported_models(&self) -> Vec<&str> { vec!["model-v1"] }
///     fn max_batch_size(&self) -> usize { 100 }
///     fn embedding_dimensions(&self) -> usize { 768 }
///
///     async fn embed_text(&self, text: &str) -> RragResult<Embedding> {
///         // Implementation here
///         # todo!()
///     }
///
///     async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch> {
///         // Batch implementation here
///         # todo!()
///     }
///
///     async fn health_check(&self) -> RragResult<bool> {
///         // Health check implementation
///         Ok(true)
///     }
/// }
/// ```
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Provider name (e.g., "openai", "huggingface")
    fn name(&self) -> &str;

    /// Supported models for this provider
    fn supported_models(&self) -> Vec<&str>;

    /// Maximum batch size supported
    fn max_batch_size(&self) -> usize;

    /// Embedding dimensions for the current model
    fn embedding_dimensions(&self) -> usize;

    /// Generate embedding for a single text
    async fn embed_text(&self, text: &str) -> RragResult<Embedding>;

    /// Generate embeddings for multiple texts (more efficient)
    async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch>;

    /// Health check for the provider
    async fn health_check(&self) -> RragResult<bool>;
}

/// OpenAI embedding provider for production-grade text embeddings
///
/// Provides access to OpenAI's embedding models through their API.
/// Supports all current OpenAI embedding models with automatic
/// batching and retry logic.
///
/// # Supported Models
///
/// - `text-embedding-ada-002`: Legacy model, 1536 dimensions
/// - `text-embedding-3-small`: Newer model, 1536 dimensions, better performance
/// - `text-embedding-3-large`: Premium model, 3072 dimensions, highest quality
///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
/// use std::sync::Arc;
///
/// # #[tokio::main]
/// # async fn main() -> RragResult<()> {
/// let provider = Arc::new(
///     OpenAIEmbeddingProvider::new(
///         std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "demo-key".to_string()),
///     )
///     .with_model("text-embedding-3-small"),
/// );
///
/// let service = EmbeddingService::new(provider);
/// let embedding = service.embed_document(&Document::new("Hello world")).await?;
///
/// assert_eq!(embedding.dimensions, 1536);
/// # Ok(())
/// # }
/// ```
///
/// # Rate Limits
///
/// OpenAI enforces rate limits (requests and tokens per minute) that vary by
/// account tier and change over time; consult OpenAI's current documentation
/// for the limits that apply to your account.
///
/// The provider automatically handles batching within its maximum batch size.
#[allow(dead_code)]
pub struct OpenAIEmbeddingProvider {
    /// API client (placeholder - would use actual HTTP client)
    client: String, // In production: reqwest::Client or rsllm client

    /// Model to use for embeddings
    model: String,

    /// API key
    api_key: String,

    /// Request timeout
    timeout: std::time::Duration,
}

impl OpenAIEmbeddingProvider {
    /// Create a new OpenAI embedding provider
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            client: "openai_client".to_string(), // Placeholder
            model: "text-embedding-ada-002".to_string(),
            api_key: api_key.into(),
            timeout: std::time::Duration::from_secs(30),
        }
    }

    /// Set the model to use for embeddings
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
}

#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
    fn name(&self) -> &str {
        "openai"
    }

    fn supported_models(&self) -> Vec<&str> {
        vec![
            "text-embedding-ada-002",
            "text-embedding-3-small",
            "text-embedding-3-large",
        ]
    }

    fn max_batch_size(&self) -> usize {
        100 // OpenAI's current limit
    }

    fn embedding_dimensions(&self) -> usize {
        match self.model.as_str() {
            "text-embedding-ada-002" => 1536,
            "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            _ => 1536, // Default fallback
        }
    }

    async fn embed_text(&self, text: &str) -> RragResult<Embedding> {
        // Mock implementation - in production, this would make actual API calls
        let start = std::time::Instant::now();

        // Simulate API delay
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Generate mock embedding vector
        let dimensions = self.embedding_dimensions();
        let vector: Vec<f32> = (0..dimensions)
            .map(|i| (text.len() as f32 + i as f32) / 1000.0)
            .collect();

        let embedding = Embedding::new(vector, &self.model, text)
            .with_metadata(
                "processing_time_ms",
                serde_json::Value::Number((start.elapsed().as_millis() as u64).into()),
            )
            .with_metadata(
                "provider",
                serde_json::Value::String(self.name().to_string()),
            );

        Ok(embedding)
    }

    async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch> {
        let start = std::time::Instant::now();

        if requests.len() > self.max_batch_size() {
            return Err(RragError::embedding(
                "batch_processing",
                format!(
                    "Batch size {} exceeds maximum {}",
                    requests.len(),
                    self.max_batch_size()
                ),
            ));
        }

        let mut embeddings = HashMap::new();
        let mut failed_items = HashMap::new();
        let mut successful_count = 0;

        for request in requests.iter() {
            match self.embed_text(&request.content).await {
                Ok(mut embedding) => {
                    // Merge request metadata
                    embedding.metadata.extend(request.metadata.clone());
                    embedding.source_id = request.id.clone();

                    embeddings.insert(request.id.clone(), embedding);
                    successful_count += 1;
                }
                Err(e) => {
                    failed_items.insert(request.id.clone(), e.to_string());
                }
            }
        }

        let batch = EmbeddingBatch {
            embeddings,
            metadata: BatchMetadata {
                total_items: requests.len(),
                successful_items: successful_count,
                failed_items,
                duration_ms: start.elapsed().as_millis() as u64,
                model: self.model.clone(),
                provider: self.name().to_string(),
            },
        };

        Ok(batch)
    }

    async fn health_check(&self) -> RragResult<bool> {
        // Mock health check - in production, this would ping the API
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        Ok(true)
    }
}

/// Local embedding provider for offline and privacy-focused deployments
///
/// Supports local inference with Hugging Face models or custom implementations.
/// Ideal for:
/// - Privacy-sensitive applications
/// - High-volume processing without API costs
/// - Offline or air-gapped environments
/// - Custom fine-tuned models
///
/// # Popular Models
///
/// - `sentence-transformers/all-MiniLM-L6-v2`: 384 dims, fast and lightweight
/// - `sentence-transformers/all-mpnet-base-v2`: 768 dims, high quality
/// - `sentence-transformers/all-distilroberta-v1`: 768 dims, balanced
///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
/// use std::sync::Arc;
///
/// # #[tokio::main]
/// # async fn main() -> RragResult<()> {
/// let provider = Arc::new(LocalEmbeddingProvider::new(
///     "sentence-transformers/all-MiniLM-L6-v2",
///     384,
/// ));
///
/// let service = EmbeddingService::new(provider);
/// let embeddings = service.embed_documents(&[
///     Document::new("First document"),
///     Document::new("Second document"),
/// ]).await?;
///
/// assert_eq!(embeddings.len(), 2);
/// # Ok(())
/// # }
/// ```
///
/// # Performance Notes
///
/// - Uses parallel processing for batch operations
/// - Smaller batch sizes (16-32) are recommended for memory efficiency
/// - CPU-intensive; consider GPU acceleration for large workloads
pub struct LocalEmbeddingProvider {
    /// Model name or path
    model_path: String,

    /// Embedding dimensions
    dimensions: usize,
}

impl LocalEmbeddingProvider {
    /// Create a new local embedding provider
    pub fn new(model_path: impl Into<String>, dimensions: usize) -> Self {
        Self {
            model_path: model_path.into(),
            dimensions,
        }
    }
}

#[async_trait]
impl EmbeddingProvider for LocalEmbeddingProvider {
    fn name(&self) -> &str {
        "local"
    }

    fn supported_models(&self) -> Vec<&str> {
        vec![
            "sentence-transformers/all-MiniLM-L6-v2",
            "custom-local-model",
        ]
    }

    fn max_batch_size(&self) -> usize {
        32 // Smaller batches for local processing
    }

    fn embedding_dimensions(&self) -> usize {
        self.dimensions
    }

    async fn embed_text(&self, text: &str) -> RragResult<Embedding> {
        // Mock local model inference
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let vector: Vec<f32> = (0..self.dimensions)
            .map(|i| ((text.len() * 31 + i * 17) % 1000) as f32 / 1000.0)
            .collect();

        Ok(
            Embedding::new(vector, &self.model_path, text).with_metadata(
                "provider",
                serde_json::Value::String(self.name().to_string()),
            ),
        )
    }

    async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch> {
        let start = std::time::Instant::now();

        let mut embeddings = HashMap::new();
        let mut failed_items = HashMap::new();

        // Process requests in parallel; local inference benefits from using
        // multiple cores. Each future carries its request ID so that failures
        // can be attributed to the request that caused them.
        let futures: Vec<_> = requests
            .iter()
            .map(|req| async move {
                let result = self.embed_text(&req.content).await;
                (req.id.clone(), result)
            })
            .collect();

        let results = futures::future::join_all(futures).await;

        for (id, result) in results {
            match result {
                Ok(embedding) => {
                    embeddings.insert(id, embedding);
                }
                Err(e) => {
                    failed_items.insert(id, e.to_string());
                }
            }
        }

        let successful_items = embeddings.len();
        let batch = EmbeddingBatch {
            embeddings,
            metadata: BatchMetadata {
                total_items: requests.len(),
                successful_items,
                failed_items,
                duration_ms: start.elapsed().as_millis() as u64,
                model: self.model_path.clone(),
                provider: self.name().to_string(),
            },
        };

        Ok(batch)
    }

    async fn health_check(&self) -> RragResult<bool> {
        // Check if model is loaded/accessible
        Ok(true)
    }
}

/// High-level embedding service with comprehensive provider management
///
/// The embedding service provides a unified interface for embedding generation
/// with advanced features like automatic batching, retry logic, parallel processing,
/// and comprehensive error handling.
///
/// # Features
///
/// - **Provider Abstraction**: Work with any embedding provider through a common interface
/// - **Intelligent Batching**: Automatically batches requests for optimal performance
/// - **Retry Logic**: Configurable retry with exponential backoff for transient failures
/// - **Parallel Processing**: Concurrent processing for improved throughput
/// - **Order Preservation**: Maintains document order in batch operations
/// - **Metadata Propagation**: Preserves metadata through the processing pipeline
///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
/// use std::sync::Arc;
///
/// # #[tokio::main]
/// # async fn main() -> RragResult<()> {
/// let provider = Arc::new(OpenAIEmbeddingProvider::new("api-key"));
/// let service = EmbeddingService::with_config(
///     provider,
///     EmbeddingConfig {
///         batch_size: 100,
///         parallel_processing: true,
///         max_retries: 3,
///         retry_delay_ms: 1000,
///     },
/// );
///
/// // Process multiple documents efficiently
/// let documents = vec![
///     Document::new("First document"),
///     Document::new("Second document"),
/// ];
///
/// let embeddings = service.embed_documents(&documents).await?;
/// println!("Generated {} embeddings", embeddings.len());
/// # Ok(())
/// # }
/// ```
///
/// # Configuration Options
///
/// - `batch_size`: Number of items to process in each batch (default: 50)
/// - `parallel_processing`: Whether to enable concurrent batch processing (default: true)
/// - `max_retries`: Maximum retry attempts for failed requests (default: 3)
/// - `retry_delay_ms`: Initial delay between retries, doubling after each attempt (default: 1000ms)
pub struct EmbeddingService {
    /// Active embedding provider
    provider: Arc<dyn EmbeddingProvider>,

    /// Service configuration
    config: EmbeddingConfig,
}

/// Configuration for embedding service
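///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
///
/// // Override only the batch size; keep the remaining defaults.
/// let config = EmbeddingConfig {
///     batch_size: 16,
///     ..Default::default()
/// };
///
/// assert_eq!(config.max_retries, 3);
/// assert!(config.parallel_processing);
/// ```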
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Batch size for processing documents
    pub batch_size: usize,

    /// Whether to enable parallel processing
    pub parallel_processing: bool,

    /// Maximum retries for failed embeddings
    pub max_retries: usize,

    /// Initial retry delay in milliseconds (doubles after each failed attempt)
    pub retry_delay_ms: u64,
}

impl Default for EmbeddingConfig {
    fn default() -> Self {
        Self {
            batch_size: 50,
            parallel_processing: true,
            max_retries: 3,
            retry_delay_ms: 1000,
        }
    }
}

impl EmbeddingService {
    /// Create embedding service with provider
    pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
        Self {
            provider,
            config: EmbeddingConfig::default(),
        }
    }

    /// Create service with configuration
    pub fn with_config(provider: Arc<dyn EmbeddingProvider>, config: EmbeddingConfig) -> Self {
        Self { provider, config }
    }

    /// Embed a single document
    pub async fn embed_document(&self, document: &Document) -> RragResult<Embedding> {
        self.provider.embed_text(document.content_str()).await
    }

    /// Embed multiple documents with batching
    pub async fn embed_documents(&self, documents: &[Document]) -> RragResult<Vec<Embedding>> {
        let requests: Vec<EmbeddingRequest> = documents
            .iter()
            .map(|doc| EmbeddingRequest::new(&doc.id, doc.content_str()))
            .collect();

        let batches = self.create_batches(requests);
        let mut all_embeddings = Vec::new();

        for batch in batches {
            // Capture the request order before the batch is moved, so results
            // can be returned in the original document order (HashMap
            // iteration order is arbitrary).
            let request_ids: Vec<String> = batch.iter().map(|r| r.id.clone()).collect();
            let mut batch_result = self.process_batch_with_retry(batch).await?;

            for request_id in &request_ids {
                if let Some(embedding) = batch_result.embeddings.remove(request_id) {
                    all_embeddings.push(embedding);
                }
            }
        }

        Ok(all_embeddings)
    }

    /// Embed document chunks efficiently
    pub async fn embed_chunks(&self, chunks: &[DocumentChunk]) -> RragResult<Vec<Embedding>> {
        let requests: Vec<EmbeddingRequest> = chunks
            .iter()
            .map(|chunk| {
                EmbeddingRequest::new(
                    format!("{}_{}", chunk.document_id, chunk.chunk_index),
                    &chunk.content,
                )
                .with_metadata(
                    "chunk_index",
                    serde_json::Value::Number(chunk.chunk_index.into()),
                )
                .with_metadata(
                    "document_id",
                    serde_json::Value::String(chunk.document_id.clone()),
                )
            })
            .collect();

        let batches = self.create_batches(requests);
        let mut all_embeddings = Vec::new();

        for batch in batches {
            let request_ids: Vec<String> = batch.iter().map(|r| r.id.clone()).collect();
            let mut batch_result = self.process_batch_with_retry(batch).await?;

            // Preserve chunk order rather than relying on HashMap iteration order.
            for request_id in &request_ids {
                if let Some(embedding) = batch_result.embeddings.remove(request_id) {
                    all_embeddings.push(embedding);
                }
            }
        }

        Ok(all_embeddings)
    }

    /// Create batches from requests
    fn create_batches(&self, requests: Vec<EmbeddingRequest>) -> Vec<Vec<EmbeddingRequest>> {
        requests
            .chunks(self.config.batch_size.min(self.provider.max_batch_size()))
            .map(|chunk| chunk.to_vec())
            .collect()
    }

    /// Process a batch with retry logic
    async fn process_batch_with_retry(
        &self,
        batch: Vec<EmbeddingRequest>,
    ) -> RragResult<EmbeddingBatch> {
        let mut attempts = 0;
        let mut last_error = None;

        while attempts < self.config.max_retries {
            match self.provider.embed_batch(batch.clone()).await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    last_error = Some(e);
                    attempts += 1;

                    if attempts < self.config.max_retries {
                        // Exponential backoff: the base delay doubles after
                        // each failed attempt (1x, 2x, 4x, ...).
                        let delay = self.config.retry_delay_ms * 2u64.pow(attempts as u32 - 1);
                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
                    }
                }
            }
        }

        Err(last_error
            .unwrap_or_else(|| RragError::embedding("batch_processing", "Max retries exceeded")))
    }

    /// Get provider information
    pub fn provider_info(&self) -> ProviderInfo {
        ProviderInfo {
            name: self.provider.name().to_string(),
            supported_models: self
                .provider
                .supported_models()
                .iter()
                .map(|s| s.to_string())
                .collect(),
            max_batch_size: self.provider.max_batch_size(),
            embedding_dimensions: self.provider.embedding_dimensions(),
        }
    }
}

/// Provider information for introspection
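///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
/// use std::sync::Arc;
///
/// let service = EmbeddingService::new(Arc::new(MockEmbeddingProvider::new()));
/// let info = service.provider_info();
///
/// assert_eq!(info.name, "mock");
/// assert_eq!(info.embedding_dimensions, 384);
/// ```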
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderInfo {
    /// Provider name
    pub name: String,
    /// List of supported model names
    pub supported_models: Vec<String>,
    /// Maximum batch size for efficient processing
    pub max_batch_size: usize,
    /// Number of dimensions in the embedding vectors
    pub embedding_dimensions: usize,
}

/// Mock embedding provider for testing and development
///
/// Provides deterministic, fast embedding generation for testing purposes.
/// The mock provider generates consistent embeddings based on input text,
/// making it suitable for unit tests and development workflows.
///
/// # Features
///
/// - **Deterministic**: Same input always produces the same embedding
/// - **Fast**: No network calls or heavy computation
/// - **Configurable**: Adjustable dimensions and behavior
/// - **Testing-Friendly**: Predictable behavior for assertions
///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
/// use std::sync::Arc;
///
/// # #[tokio::main]
/// # async fn main() -> RragResult<()> {
/// let provider = Arc::new(MockEmbeddingProvider::new());
/// let service = EmbeddingService::new(provider);
///
/// let document = Document::new("Test content");
/// let embedding = service.embed_document(&document).await?;
///
/// // Mock provider always returns 384 dimensions
/// assert_eq!(embedding.dimensions, 384);
/// assert_eq!(embedding.model, "mock-model");
/// # Ok(())
/// # }
/// ```
pub struct MockEmbeddingProvider {
    model: String,
    dimensions: usize,
}

impl MockEmbeddingProvider {
    /// Create a new mock embedding provider
    pub fn new() -> Self {
        Self {
            model: "mock-model".to_string(),
            dimensions: 384,
        }
    }
}
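
impl Default for MockEmbeddingProvider {
    // `Default` delegates to `new`, so the mock can be used with APIs that
    // expect a `Default` provider.
    fn default() -> Self {
        Self::new()
    }
}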

#[async_trait]
impl EmbeddingProvider for MockEmbeddingProvider {
    fn name(&self) -> &str {
        "mock"
    }

    fn supported_models(&self) -> Vec<&str> {
        vec!["mock-model"]
    }

    fn max_batch_size(&self) -> usize {
        100
    }

    fn embedding_dimensions(&self) -> usize {
        self.dimensions
    }

    async fn embed_text(&self, text: &str) -> RragResult<Embedding> {
        // Generate a simple, deterministic mock embedding derived from the
        // text length
        let hash = text.len() as f32;
        let vector: Vec<f32> = (0..self.dimensions)
            .map(|i| (hash + i as f32).sin() / (i + 1) as f32)
            .collect();

        Ok(Embedding::new(vector, &self.model, "mock"))
    }

    async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch> {
        let mut embeddings = HashMap::new();

        for request in &requests {
            let embedding = self.embed_text(&request.content).await?;
            embeddings.insert(request.id.clone(), embedding);
        }

        Ok(EmbeddingBatch {
            embeddings,
            metadata: BatchMetadata {
                total_items: requests.len(),
                successful_items: requests.len(),
                failed_items: HashMap::new(),
                duration_ms: 10,
                model: self.model.clone(),
                provider: self.name().to_string(),
            },
        })
    }

    async fn health_check(&self) -> RragResult<bool> {
        Ok(true)
    }
}

/// Type alias for `MockEmbeddingProvider`, kept for backward compatibility
pub type MockEmbeddingService = MockEmbeddingProvider;

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_openai_provider() {
        let provider = OpenAIEmbeddingProvider::new("test-key");

        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.embedding_dimensions(), 1536);
        assert!(provider.health_check().await.unwrap());

        let embedding = provider.embed_text("Hello, world!").await.unwrap();
        assert_eq!(embedding.dimensions, 1536);
        assert_eq!(embedding.model, "text-embedding-ada-002");
    }

    #[tokio::test]
    async fn test_local_provider() {
        let provider = LocalEmbeddingProvider::new("test-model", 384);

        assert_eq!(provider.name(), "local");
        assert_eq!(provider.embedding_dimensions(), 384);

        let embedding = provider.embed_text("Test content").await.unwrap();
        assert_eq!(embedding.dimensions, 384);
    }

    #[tokio::test]
    async fn test_embedding_service() {
        let provider = Arc::new(LocalEmbeddingProvider::new("test-model", 384));
        let service = EmbeddingService::new(provider);

        let doc = Document::new("Test document content");
        let embedding = service.embed_document(&doc).await.unwrap();

        assert_eq!(embedding.dimensions, 384);
        assert!(!embedding.vector.is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        let vector1 = vec![1.0, 0.0, 0.0];
        let vector2 = vec![0.0, 1.0, 0.0];
        let vector3 = vec![1.0, 0.0, 0.0];

        let emb1 = Embedding::new(vector1, "test", "1");
        let emb2 = Embedding::new(vector2, "test", "2");
        let emb3 = Embedding::new(vector3, "test", "3");

        let similarity_12 = emb1.cosine_similarity(&emb2).unwrap();
        let similarity_13 = emb1.cosine_similarity(&emb3).unwrap();

        assert!((similarity_12 - 0.0).abs() < 1e-6); // Orthogonal vectors
        assert!((similarity_13 - 1.0).abs() < 1e-6); // Identical vectors
    }
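
    #[test]
    fn test_create_batches_respects_limits() {
        // The effective batch size is the smaller of the configured size and
        // the provider's maximum, so 5 requests with batch_size = 2 split
        // into batches of 2, 2, and 1.
        let provider = Arc::new(LocalEmbeddingProvider::new("test-model", 64));
        let service = EmbeddingService::with_config(
            provider,
            EmbeddingConfig {
                batch_size: 2,
                ..Default::default()
            },
        );

        let requests: Vec<EmbeddingRequest> = (0..5)
            .map(|i| EmbeddingRequest::new(i.to_string(), "text"))
            .collect();

        let batches = service.create_batches(requests);
        assert_eq!(batches.len(), 3);
        assert!(batches.iter().all(|batch| batch.len() <= 2));
    }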

    #[tokio::test]
    async fn test_batch_processing() {
        let provider = Arc::new(LocalEmbeddingProvider::new("test-model", 128));

        let requests = vec![
            EmbeddingRequest::new("1", "First text"),
            EmbeddingRequest::new("2", "Second text"),
            EmbeddingRequest::new("3", "Third text"),
        ];

        let batch_result = provider.embed_batch(requests).await.unwrap();

        assert_eq!(batch_result.metadata.total_items, 3);
        assert_eq!(batch_result.metadata.successful_items, 3);
        assert_eq!(batch_result.embeddings.len(), 3);
    }
}