rrag/
embeddings.rs

1//! # RRAG Embeddings System
2//!
3//! High-performance, async-first embedding generation with pluggable providers,
4//! efficient batching, and comprehensive error handling. Built for production
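//! workloads with robust retry logic and monitoring capabilities.
//!
//! Note: the bundled [`OpenAIEmbeddingProvider`] and [`LocalEmbeddingProvider`]
//! currently generate placeholder vectors (see their `embed_text`
//! implementations) rather than calling a real API or loading a model.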
6//!
7//! ## Features
8//!
9//! - **Multi-Provider Support**: OpenAI, local models, and custom providers
10//! - **Efficient Batching**: Automatic batching with configurable sizes
11//! - **Retry Logic**: Robust error handling with exponential backoff
12//! - **Parallel Processing**: Concurrent embedding generation
13//! - **Similarity Metrics**: Built-in cosine similarity and distance calculations
14//! - **Metadata Support**: Rich metadata tracking for embeddings
15//! - **Health Monitoring**: Provider health checks and monitoring
16//!
17//! ## Quick Start
18//!
19//! ```rust
20//! use rrag::prelude::*;
21//! use std::sync::Arc;
22//!
23//! # #[tokio::main]
24//! # async fn main() -> RragResult<()> {
25//! // Create an OpenAI provider
26//! let provider = Arc::new(OpenAIEmbeddingProvider::new("your-api-key")
27//!     .with_model("text-embedding-3-small"));
28//!
29//! // Create embedding service
30//! let service = EmbeddingService::new(provider);
31//!
32//! // Embed a document
33//! let document = Document::new("The quick brown fox jumps over the lazy dog");
34//! let embedding = service.embed_document(&document).await?;
35//!
36//! println!("Generated embedding with {} dimensions", embedding.dimensions);
37//! # Ok(())
38//! # }
39//! ```
40//!
41//! ## Batch Processing
42//!
43//! For high-throughput scenarios, use batch processing:
44//!
45//! ```rust
46//! use rrag::prelude::*;
47//! use std::sync::Arc;
48//!
49//! # #[tokio::main]
50//! # async fn main() -> RragResult<()> {
51//! let provider = Arc::new(LocalEmbeddingProvider::new("sentence-transformers/all-MiniLM-L6-v2", 384));
52//! let service = EmbeddingService::with_config(
53//!     provider,
54//!     EmbeddingConfig {
55//!         batch_size: 50,
56//!         parallel_processing: true,
57//!         max_retries: 3,
58//!         retry_delay_ms: 1000,
59//!     }
60//! );
61//!
62//! let documents = vec![
63//!     Document::new("First document"),
64//!     Document::new("Second document"),
65//!     Document::new("Third document"),
66//! ];
67//!
68//! let embeddings = service.embed_documents(&documents).await?;
69//! println!("Generated {} embeddings", embeddings.len());
70//! # Ok(())
71//! # }
72//! ```
73//!
74//! ## Custom Providers
75//!
76//! Implement the [`EmbeddingProvider`] trait to create custom providers:
77//!
78//! ```rust
79//! use rrag::prelude::*;
80//! use async_trait::async_trait;
81//!
82//! struct MyCustomProvider {
83//!     model_name: String,
84//! }
85//!
86//! #[async_trait]
87//! impl EmbeddingProvider for MyCustomProvider {
88//!     fn name(&self) -> &str { "custom" }
89//!     fn supported_models(&self) -> Vec<&str> { vec!["custom-model-v1"] }
90//!     fn max_batch_size(&self) -> usize { 32 }
91//!     fn embedding_dimensions(&self) -> usize { 512 }
92//!
93//!     async fn embed_text(&self, text: &str) -> RragResult<Embedding> {
94//!         // Your custom embedding logic here
95//!         # let vector = vec![0.0; 512];
96//!         # Ok(Embedding::new(vector, &self.model_name, text))
97//!     }
98//!
99//!     async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch> {
100//!         // Your custom batch processing logic
101//!         # todo!()
102//!     }
103//!
104//!     async fn health_check(&self) -> RragResult<bool> {
105//!         // Your health check logic
106//!         Ok(true)
107//!     }
108//! }
109//! ```
110//!
111//! ## Performance Considerations
112//!
113//! - Use batch processing for multiple documents to reduce API overhead
114//! - Configure appropriate batch sizes based on your provider's limits
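//! - `EmbeddingService` submits batches one at a time; concurrency within a batch
//!   depends on the provider (`LocalEmbeddingProvider` awaits a batch's requests
//!   concurrently via `join_all`, which does not by itself spread work across CPU cores)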
116//! - Monitor provider health and implement fallback strategies
117//! - Cache embeddings when possible to avoid redundant API calls
118//!
119//! ## Error Handling
120//!
121//! The embedding system provides detailed error information:
122//!
123//! ```rust
124//! use rrag::prelude::*;
125//!
126//! # #[tokio::main]
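//! # async fn main() {
//! # let provider = std::sync::Arc::new(LocalEmbeddingProvider::new("sentence-transformers/all-MiniLM-L6-v2", 384));
//! # let service = EmbeddingService::new(provider);
//! # let document = Document::new("The quick brown fox jumps over the lazy dog");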
128//! match service.embed_document(&document).await {
129//!     Ok(embedding) => {
130//!         println!("Success: {} dimensions", embedding.dimensions);
131//!     }
132//!     Err(RragError::Embedding { content_type, message, .. }) => {
133//!         eprintln!("Embedding error for {}: {}", content_type, message);
134//!     }
135//!     Err(e) => {
136//!         eprintln!("Other error: {}", e);
137//!     }
138//! }
139//! # }
140//! ```
141
142use crate::{Document, DocumentChunk, RragError, RragResult};
143use async_trait::async_trait;
144use serde::{Deserialize, Serialize};
145use std::collections::HashMap;
146use std::sync::Arc;
147
/// Dense embedding vector: one `f32` component per dimension.
///
/// A plain `Vec<f32>` is used for maximum compatibility with ML libraries and
/// APIs; its length is whatever the producing model emits. Common dimensions:
/// - 384: sentence-transformers/all-MiniLM-L6-v2
/// - 768: BERT-base models
/// - 1536: OpenAI text-embedding-ada-002 and text-embedding-3-small
/// - 3072: OpenAI text-embedding-3-large
pub type EmbeddingVector = Vec<f32>;
157
158/// A dense vector representation of text content with metadata
159///
160/// Embeddings are high-dimensional vectors that capture semantic meaning
161/// of text in a format suitable for similarity comparison and retrieval.
162/// Each embedding includes the vector data, source information, and
163/// generation metadata.
164///
165/// # Example
166///
167/// ```rust
168/// use rrag::prelude::*;
169///
170/// let vector = vec![0.1, -0.2, 0.3, 0.4]; // 4-dimensional example
171/// let embedding = Embedding::new(vector, "text-embedding-ada-002", "doc-123")
172///     .with_metadata("content_type", "paragraph".into())
173///     .with_metadata("language", "en".into());
174///
175/// assert_eq!(embedding.dimensions, 4);
176/// assert_eq!(embedding.model, "text-embedding-ada-002");
177/// ```
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct Embedding {
180    /// The actual embedding vector
181    pub vector: EmbeddingVector,
182
183    /// Dimensions of the embedding
184    pub dimensions: usize,
185
186    /// Model used to generate this embedding
187    pub model: String,
188
189    /// Source content identifier
190    pub source_id: String,
191
192    /// Embedding metadata
193    pub metadata: HashMap<String, serde_json::Value>,
194
195    /// Generation timestamp
196    pub created_at: chrono::DateTime<chrono::Utc>,
197}
198
199impl Embedding {
200    /// Create a new embedding
201    pub fn new(
202        vector: EmbeddingVector,
203        model: impl Into<String>,
204        source_id: impl Into<String>,
205    ) -> Self {
206        let dimensions = vector.len();
207        Self {
208            vector,
209            dimensions,
210            model: model.into(),
211            source_id: source_id.into(),
212            metadata: HashMap::new(),
213            created_at: chrono::Utc::now(),
214        }
215    }
216
217    /// Add metadata using builder pattern
218    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
219        self.metadata.insert(key.into(), value);
220        self
221    }
222
223    /// Calculate cosine similarity with another embedding
224    pub fn cosine_similarity(&self, other: &Embedding) -> RragResult<f32> {
225        if self.dimensions != other.dimensions {
226            return Err(RragError::embedding(
227                "similarity_calculation",
228                format!(
229                    "Dimension mismatch: {} vs {}",
230                    self.dimensions, other.dimensions
231                ),
232            ));
233        }
234
235        let dot_product: f32 = self
236            .vector
237            .iter()
238            .zip(other.vector.iter())
239            .map(|(a, b)| a * b)
240            .sum();
241
242        let norm_a: f32 = self.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
243        let norm_b: f32 = other.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
244
245        if norm_a == 0.0 || norm_b == 0.0 {
246            return Ok(0.0);
247        }
248
249        Ok(dot_product / (norm_a * norm_b))
250    }
251
252    /// Calculate Euclidean distance with another embedding
253    pub fn euclidean_distance(&self, other: &Embedding) -> RragResult<f32> {
254        if self.dimensions != other.dimensions {
255            return Err(RragError::embedding(
256                "distance_calculation",
257                format!(
258                    "Dimension mismatch: {} vs {}",
259                    self.dimensions, other.dimensions
260                ),
261            ));
262        }
263
264        let distance: f32 = self
265            .vector
266            .iter()
267            .zip(other.vector.iter())
268            .map(|(a, b)| (a - b).powi(2))
269            .sum::<f32>()
270            .sqrt();
271
272        Ok(distance)
273    }
274}
275
276/// A request for embedding generation, used in batch processing
277///
278/// Embedding requests bundle text content with metadata and a unique
279/// identifier for efficient batch processing and result tracking.
280///
281/// # Example
282///
283/// ```rust
284/// use rrag::prelude::*;
285///
286/// let request = EmbeddingRequest::new("chunk-1", "The quick brown fox")
287///     .with_metadata("chunk_index", 0.into())
288///     .with_metadata("document_id", "doc-123".into());
289/// ```
290#[derive(Debug, Clone)]
291pub struct EmbeddingRequest {
292    /// Unique identifier for the request
293    pub id: String,
294
295    /// Text content to embed
296    pub content: String,
297
298    /// Optional metadata to attach to the embedding
299    pub metadata: HashMap<String, serde_json::Value>,
300}
301
302impl EmbeddingRequest {
303    /// Create a new embedding request
304    pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
305        Self {
306            id: id.into(),
307            content: content.into(),
308            metadata: HashMap::new(),
309        }
310    }
311
312    /// Add metadata to the embedding request
313    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
314        self.metadata.insert(key.into(), value);
315        self
316    }
317}
318
319/// Batch embedding response
320#[derive(Debug)]
321pub struct EmbeddingBatch {
322    /// Generated embeddings indexed by request ID
323    pub embeddings: HashMap<String, Embedding>,
324
325    /// Processing metadata
326    pub metadata: BatchMetadata,
327}
328
329/// Metadata for batch processing
330#[derive(Debug, Clone, Serialize, Deserialize)]
331pub struct BatchMetadata {
332    /// Total items processed
333    pub total_items: usize,
334
335    /// Successfully processed items
336    pub successful_items: usize,
337
338    /// Failed items with error messages
339    pub failed_items: HashMap<String, String>,
340
341    /// Processing duration in milliseconds
342    pub duration_ms: u64,
343
344    /// Model used for embedding
345    pub model: String,
346
347    /// Provider used
348    pub provider: String,
349}
350
351/// Core trait for embedding providers
352///
353/// This trait defines the interface that all embedding providers must implement.
354/// Providers can be remote APIs (OpenAI, Cohere), local models (Hugging Face),
355/// or custom implementations.
356///
357/// The trait is designed for async operation and includes methods for:
358/// - Single text embedding
359/// - Batch processing for efficiency
360/// - Health monitoring
361/// - Provider introspection
362///
363/// # Implementation Guidelines
364///
365/// - Implement proper error handling with detailed context
366/// - Use appropriate timeouts for network operations
367/// - Support cancellation via async context
368/// - Provide accurate metadata in batch responses
369/// - Implement health checks that verify actual connectivity
370///
371/// # Example
372///
373/// ```rust
374/// use rrag::prelude::*;
375/// use async_trait::async_trait;
376///
377/// struct MyProvider;
378///
379/// #[async_trait]
380/// impl EmbeddingProvider for MyProvider {
381///     fn name(&self) -> &str { "my-provider" }
382///     fn supported_models(&self) -> Vec<&str> { vec!["model-v1"] }
383///     fn max_batch_size(&self) -> usize { 100 }
384///     fn embedding_dimensions(&self) -> usize { 768 }
385///
386///     async fn embed_text(&self, text: &str) -> RragResult<Embedding> {
387///         // Implementation here
388///         # todo!()
389///     }
390///
391///     async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch> {
392///         // Batch implementation here
393///         # todo!()
394///     }
395///
396///     async fn health_check(&self) -> RragResult<bool> {
397///         // Health check implementation
398///         Ok(true)
399///     }
400/// }
401/// ```
402#[async_trait]
403pub trait EmbeddingProvider: Send + Sync {
404    /// Provider name (e.g., "openai", "huggingface")
405    fn name(&self) -> &str;
406
407    /// Supported models for this provider
408    fn supported_models(&self) -> Vec<&str>;
409
410    /// Maximum batch size supported
411    fn max_batch_size(&self) -> usize;
412
413    /// Embedding dimensions for the current model
414    fn embedding_dimensions(&self) -> usize;
415
416    /// Generate embedding for a single text
417    async fn embed_text(&self, text: &str) -> RragResult<Embedding>;
418
419    /// Generate embeddings for multiple texts (more efficient)
420    async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch>;
421
422    /// Health check for the provider
423    async fn health_check(&self) -> RragResult<bool>;
424}
425
/// OpenAI embedding provider for production-grade text embeddings
///
/// Targets OpenAI's embedding models. Batching and retries are handled by
/// [`EmbeddingService`]; this provider itself only rejects batches larger than
/// `max_batch_size()` (100).
///
/// **Current status:** this is a mock. `embed_text` sleeps for about 100 ms and
/// returns a placeholder vector derived from the text length. `client`,
/// `api_key` and `timeout` are stored but not yet used for HTTP requests.
///
/// # Supported Models
///
/// - `text-embedding-ada-002`: Legacy model, 1536 dimensions (the default)
/// - `text-embedding-3-small`: Newer model, 1536 dimensions, better performance
/// - `text-embedding-3-large`: Premium model, 3072 dimensions, highest quality
///
/// `with_model` accepts any name, but `embedding_dimensions()` falls back to
/// 1536 for models not listed above.
///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
/// use std::sync::Arc;
///
/// # #[tokio::main]
/// # async fn main() -> RragResult<()> {
/// let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "your-api-key".to_string());
/// let provider = Arc::new(
///     OpenAIEmbeddingProvider::new(api_key)
///         .with_model("text-embedding-3-small")
/// );
///
/// let service = EmbeddingService::new(provider);
/// let embedding = service.embed_document(&Document::new("Hello world")).await?;
///
/// assert_eq!(embedding.dimensions, 1536);
/// # Ok(())
/// # }
/// ```
///
/// # Rate Limits
///
/// OpenAI rate limits vary by account tier and change over time, so check
/// OpenAI's current documentation. Historical examples:
/// - Free tier: 3 RPM, 150,000 TPM
/// - Pay-as-you-go: 3,000 RPM, 1,000,000 TPM
///
/// This provider does not throttle requests itself.
#[allow(dead_code)] // `client`, `api_key` and `timeout` are unused until real HTTP calls exist
pub struct OpenAIEmbeddingProvider {
    /// API client (placeholder - would use actual HTTP client)
    client: String, // In production: reqwest::Client or rsllm client

    /// Model to use for embeddings
    model: String,

    /// API key
    api_key: String,

    /// Request timeout
    timeout: std::time::Duration,
}

impl OpenAIEmbeddingProvider {
    /// Create a provider that uses `text-embedding-ada-002` with a 30 second timeout.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            client: "openai_client".to_string(), // Placeholder
            model: "text-embedding-ada-002".to_string(),
            api_key: api_key.into(),
            timeout: std::time::Duration::from_secs(30),
        }
    }

    /// Select the model used for embeddings (builder style).
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
}
498
499#[async_trait]
500impl EmbeddingProvider for OpenAIEmbeddingProvider {
501    fn name(&self) -> &str {
502        "openai"
503    }
504
505    fn supported_models(&self) -> Vec<&str> {
506        vec![
507            "text-embedding-ada-002",
508            "text-embedding-3-small",
509            "text-embedding-3-large",
510        ]
511    }
512
513    fn max_batch_size(&self) -> usize {
514        100 // OpenAI's current limit
515    }
516
517    fn embedding_dimensions(&self) -> usize {
518        match self.model.as_str() {
519            "text-embedding-ada-002" => 1536,
520            "text-embedding-3-small" => 1536,
521            "text-embedding-3-large" => 3072,
522            _ => 1536, // Default fallback
523        }
524    }
525
526    async fn embed_text(&self, text: &str) -> RragResult<Embedding> {
527        // Mock implementation - in production, this would make actual API calls
528        let start = std::time::Instant::now();
529
530        // Simulate API delay
531        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
532
533        // Generate mock embedding vector
534        let dimensions = self.embedding_dimensions();
535        let vector: Vec<f32> = (0..dimensions)
536            .map(|i| (text.len() as f32 + i as f32) / 1000.0)
537            .collect();
538
539        let embedding = Embedding::new(vector, &self.model, text)
540            .with_metadata(
541                "processing_time_ms",
542                serde_json::Value::Number((start.elapsed().as_millis() as u64).into()),
543            )
544            .with_metadata(
545                "provider",
546                serde_json::Value::String(self.name().to_string()),
547            );
548
549        Ok(embedding)
550    }
551
552    async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch> {
553        let start = std::time::Instant::now();
554
555        if requests.len() > self.max_batch_size() {
556            return Err(RragError::embedding(
557                "batch_processing",
558                format!(
559                    "Batch size {} exceeds maximum {}",
560                    requests.len(),
561                    self.max_batch_size()
562                ),
563            ));
564        }
565
566        let mut embeddings = HashMap::new();
567        let mut failed_items = HashMap::new();
568        let mut successful_count = 0;
569
570        for request in requests.iter() {
571            match self.embed_text(&request.content).await {
572                Ok(mut embedding) => {
573                    // Merge request metadata
574                    embedding.metadata.extend(request.metadata.clone());
575                    embedding.source_id = request.id.clone();
576
577                    embeddings.insert(request.id.clone(), embedding);
578                    successful_count += 1;
579                }
580                Err(e) => {
581                    failed_items.insert(request.id.clone(), e.to_string());
582                }
583            }
584        }
585
586        let batch = EmbeddingBatch {
587            embeddings,
588            metadata: BatchMetadata {
589                total_items: requests.len(),
590                successful_items: successful_count,
591                failed_items,
592                duration_ms: start.elapsed().as_millis() as u64,
593                model: self.model.clone(),
594                provider: self.name().to_string(),
595            },
596        };
597
598        Ok(batch)
599    }
600
601    async fn health_check(&self) -> RragResult<bool> {
602        // Mock health check - in production, this would ping the API
603        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
604        Ok(true)
605    }
606}
607
/// Local embedding provider for offline and privacy-focused deployments
///
/// Intended for local inference with Hugging Face models or custom
/// implementations. Ideal for:
/// - Privacy-sensitive applications
/// - High-volume processing without API costs
/// - Offline or air-gapped environments
/// - Custom fine-tuned models
///
/// **Current status:** no model is loaded. `embed_text` sleeps briefly and
/// returns a deterministic placeholder vector of the configured length.
///
/// # Popular Models
///
/// - `sentence-transformers/all-MiniLM-L6-v2`: 384 dims, fast and lightweight
/// - `sentence-transformers/all-mpnet-base-v2`: 768 dims, high quality
/// - `sentence-transformers/all-distilroberta-v1`: 768 dims, balanced
///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
/// use std::sync::Arc;
///
/// # #[tokio::main]
/// # async fn main() -> RragResult<()> {
/// let provider = Arc::new(LocalEmbeddingProvider::new(
///     "sentence-transformers/all-MiniLM-L6-v2",
///     384
/// ));
///
/// let service = EmbeddingService::new(provider);
/// let embeddings = service.embed_documents(&[
///     Document::new("First document"),
///     Document::new("Second document")
/// ]).await?;
///
/// assert_eq!(embeddings.len(), 2);
/// # Ok(())
/// # }
/// ```
///
/// # Performance Notes
///
/// - `embed_batch` awaits all requests of a batch concurrently on the calling
///   task (via `join_all`); this is async concurrency, not multi-core parallelism
/// - `max_batch_size()` is 32, keeping batches small for memory efficiency
/// - Real model inference would be CPU-intensive; consider GPU acceleration for
///   large workloads once inference is wired in
#[derive(Debug, Clone)]
pub struct LocalEmbeddingProvider {
    /// Model name or path; also reported as the embedding model name
    model_path: String,

    /// Length of the vectors to produce (caller-supplied, not validated against the model)
    dimensions: usize,
}

impl LocalEmbeddingProvider {
    /// Create a provider for `model_path` that emits `dimensions`-long vectors.
    pub fn new(model_path: impl Into<String>, dimensions: usize) -> Self {
        Self {
            model_path: model_path.into(),
            dimensions,
        }
    }
}
669
670#[async_trait]
671impl EmbeddingProvider for LocalEmbeddingProvider {
672    fn name(&self) -> &str {
673        "local"
674    }
675
676    fn supported_models(&self) -> Vec<&str> {
677        vec![
678            "sentence-transformers/all-MiniLM-L6-v2",
679            "custom-local-model",
680        ]
681    }
682
683    fn max_batch_size(&self) -> usize {
684        32 // Smaller batches for local processing
685    }
686
687    fn embedding_dimensions(&self) -> usize {
688        self.dimensions
689    }
690
691    async fn embed_text(&self, text: &str) -> RragResult<Embedding> {
692        // Mock local model inference
693        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
694
695        let vector: Vec<f32> = (0..self.dimensions)
696            .map(|i| ((text.len() * 31 + i * 17) % 1000) as f32 / 1000.0)
697            .collect();
698
699        Ok(
700            Embedding::new(vector, &self.model_path, text).with_metadata(
701                "provider",
702                serde_json::Value::String(self.name().to_string()),
703            ),
704        )
705    }
706
707    async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch> {
708        let start = std::time::Instant::now();
709
710        let mut embeddings = HashMap::new();
711        let failed_items = HashMap::new();
712
713        // Process in parallel for local models
714        let futures: Vec<_> = requests
715            .iter()
716            .map(|req| async move {
717                let embedding = self.embed_text(&req.content).await?;
718                Ok::<_, RragError>((req.id.clone(), embedding))
719            })
720            .collect();
721
722        let results = futures::future::join_all(futures).await;
723
724        for result in results {
725            match result {
726                Ok((id, embedding)) => {
727                    embeddings.insert(id, embedding);
728                }
729                Err(_) => {
730                    // Error handling would be more sophisticated in production
731                }
732            }
733        }
734
735        let successful_items = embeddings.len();
736        let batch = EmbeddingBatch {
737            embeddings,
738            metadata: BatchMetadata {
739                total_items: requests.len(),
740                successful_items,
741                failed_items,
742                duration_ms: start.elapsed().as_millis() as u64,
743                model: self.model_path.clone(),
744                provider: self.name().to_string(),
745            },
746        };
747
748        Ok(batch)
749    }
750
751    async fn health_check(&self) -> RragResult<bool> {
752        // Check if model is loaded/accessible
753        Ok(true)
754    }
755}
756
757/// High-level embedding service with comprehensive provider management
758///
759/// The embedding service provides a unified interface for embedding generation
760/// with advanced features like automatic batching, retry logic, parallel processing,
761/// and comprehensive error handling.
762///
763/// # Features
764///
765/// - **Provider Abstraction**: Work with any embedding provider through a common interface
766/// - **Intelligent Batching**: Automatically batches requests for optimal performance
767/// - **Retry Logic**: Configurable retry with exponential backoff for transient failures
768/// - **Parallel Processing**: Concurrent processing for improved throughput
769/// - **Order Preservation**: Maintains document order in batch operations
770/// - **Metadata Propagation**: Preserves metadata through processing pipeline
771///
772/// # Example
773///
774/// ```rust
775/// use rrag::prelude::*;
776/// use std::sync::Arc;
777///
778/// # #[tokio::main]
779/// # async fn main() -> RragResult<()> {
780/// let provider = Arc::new(OpenAIEmbeddingProvider::new("api-key"));
781/// let service = EmbeddingService::with_config(
782///     provider,
783///     EmbeddingConfig {
784///         batch_size: 100,
785///         parallel_processing: true,
786///         max_retries: 3,
787///         retry_delay_ms: 1000,
788///     }
789/// );
790///
791/// // Process multiple documents efficiently
792/// let documents = vec![
793///     Document::new("First document"),
794///     Document::new("Second document"),
795/// ];
796///
797/// let embeddings = service.embed_documents(&documents).await?;
798/// println!("Generated {} embeddings", embeddings.len());
799/// # Ok(())
800/// # }
801/// ```
802///
803/// # Configuration Options
804///
805/// - `batch_size`: Number of items to process in each batch (default: 50)
806/// - `parallel_processing`: Whether to enable concurrent batch processing (default: true)
807/// - `max_retries`: Maximum retry attempts for failed requests (default: 3)
808/// - `retry_delay_ms`: Initial delay between retries, increases exponentially (default: 1000ms)
809pub struct EmbeddingService {
810    /// Active embedding provider
811    provider: Arc<dyn EmbeddingProvider>,
812
813    /// Service configuration
814    config: EmbeddingConfig,
815}
816
817/// Configuration for embedding service
818#[derive(Debug, Clone)]
819pub struct EmbeddingConfig {
820    /// Batch size for processing documents
821    pub batch_size: usize,
822
823    /// Whether to enable parallel processing
824    pub parallel_processing: bool,
825
826    /// Maximum retries for failed embeddings
827    pub max_retries: usize,
828
829    /// Retry delay in milliseconds
830    pub retry_delay_ms: u64,
831}
832
833impl Default for EmbeddingConfig {
834    fn default() -> Self {
835        Self {
836            batch_size: 50,
837            parallel_processing: true,
838            max_retries: 3,
839            retry_delay_ms: 1000,
840        }
841    }
842}
843
844impl EmbeddingService {
845    /// Create embedding service with provider
846    pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
847        Self {
848            provider,
849            config: EmbeddingConfig::default(),
850        }
851    }
852
853    /// Create service with configuration
854    pub fn with_config(provider: Arc<dyn EmbeddingProvider>, config: EmbeddingConfig) -> Self {
855        Self { provider, config }
856    }
857
858    /// Embed a single document
859    pub async fn embed_document(&self, document: &Document) -> RragResult<Embedding> {
860        self.provider.embed_text(document.content_str()).await
861    }
862
863    /// Embed multiple documents with batching
864    pub async fn embed_documents(&self, documents: &[Document]) -> RragResult<Vec<Embedding>> {
865        let requests: Vec<EmbeddingRequest> = documents
866            .iter()
867            .map(|doc| EmbeddingRequest::new(&doc.id, doc.content_str()))
868            .collect();
869
870        let batches = self.create_batches(requests);
871        let mut all_embeddings = Vec::new();
872
873        for batch in batches {
874            let batch_result = self.process_batch_with_retry(batch).await?;
875
876            // Collect embeddings in original order
877            for request_id in batch_result.embeddings.keys() {
878                if let Some(embedding) = batch_result.embeddings.get(request_id) {
879                    all_embeddings.push(embedding.clone());
880                }
881            }
882        }
883
884        Ok(all_embeddings)
885    }
886
887    /// Embed document chunks efficiently
888    pub async fn embed_chunks(&self, chunks: &[DocumentChunk]) -> RragResult<Vec<Embedding>> {
889        let requests: Vec<EmbeddingRequest> = chunks
890            .iter()
891            .map(|chunk| {
892                EmbeddingRequest::new(
893                    format!("{}_{}", chunk.document_id, chunk.chunk_index),
894                    &chunk.content,
895                )
896                .with_metadata(
897                    "chunk_index",
898                    serde_json::Value::Number(chunk.chunk_index.into()),
899                )
900                .with_metadata(
901                    "document_id",
902                    serde_json::Value::String(chunk.document_id.clone()),
903                )
904            })
905            .collect();
906
907        let batches = self.create_batches(requests);
908        let mut all_embeddings = Vec::new();
909
910        for batch in batches {
911            let batch_result = self.process_batch_with_retry(batch).await?;
912
913            for embedding in batch_result.embeddings.into_values() {
914                all_embeddings.push(embedding);
915            }
916        }
917
918        Ok(all_embeddings)
919    }
920
921    /// Create batches from requests
922    fn create_batches(&self, requests: Vec<EmbeddingRequest>) -> Vec<Vec<EmbeddingRequest>> {
923        requests
924            .chunks(self.config.batch_size.min(self.provider.max_batch_size()))
925            .map(|chunk| chunk.to_vec())
926            .collect()
927    }
928
929    /// Process a batch with retry logic
930    async fn process_batch_with_retry(
931        &self,
932        batch: Vec<EmbeddingRequest>,
933    ) -> RragResult<EmbeddingBatch> {
934        let mut attempts = 0;
935        let mut last_error = None;
936
937        while attempts < self.config.max_retries {
938            match self.provider.embed_batch(batch.clone()).await {
939                Ok(result) => return Ok(result),
940                Err(e) => {
941                    last_error = Some(e);
942                    attempts += 1;
943
944                    if attempts < self.config.max_retries {
945                        tokio::time::sleep(std::time::Duration::from_millis(
946                            self.config.retry_delay_ms * attempts as u64,
947                        ))
948                        .await;
949                    }
950                }
951            }
952        }
953
954        Err(last_error
955            .unwrap_or_else(|| RragError::embedding("batch_processing", "Max retries exceeded")))
956    }
957
958    /// Get provider information
959    pub fn provider_info(&self) -> ProviderInfo {
960        ProviderInfo {
961            name: self.provider.name().to_string(),
962            supported_models: self
963                .provider
964                .supported_models()
965                .iter()
966                .map(|s| s.to_string())
967                .collect(),
968            max_batch_size: self.provider.max_batch_size(),
969            embedding_dimensions: self.provider.embedding_dimensions(),
970        }
971    }
972}
973
974/// Provider information for introspection
975#[derive(Debug, Clone, Serialize, Deserialize)]
976pub struct ProviderInfo {
977    /// Provider name
978    pub name: String,
979    /// List of supported model names
980    pub supported_models: Vec<String>,
981    /// Maximum batch size for efficient processing
982    pub max_batch_size: usize,
983    /// Number of dimensions in the embedding vectors
984    pub embedding_dimensions: usize,
985}
986
/// Mock embedding provider for testing and development
///
/// Generates embeddings locally and deterministically, with no network
/// access, which makes it suitable for unit tests and development workflows.
///
/// # Behavior
///
/// - **Deterministic**: the same input always produces the same embedding.
/// - **Fast**: no network calls or heavy computation.
/// - **Length-based**: the vector is derived only from the input's length in
///   bytes, so different texts of equal length get identical embeddings.
///   Do not rely on it for meaningful similarity rankings.
/// - **Fixed defaults**: `new()` produces 384-dimensional vectors tagged with
///   the model name `"mock-model"`.
///
/// # Example
///
/// ```rust
/// use rrag::prelude::*;
/// use std::sync::Arc;
///
/// # #[tokio::main]
/// # async fn main() -> RragResult<()> {
/// let provider = Arc::new(MockEmbeddingProvider::new());
/// let service = EmbeddingService::new(provider);
///
/// let document = Document::new("Test content");
/// let embedding = service.embed_document(&document).await?;
///
/// // Mock provider always returns 384 dimensions
/// assert_eq!(embedding.dimensions, 384);
/// assert_eq!(embedding.model, "mock-model");
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct MockEmbeddingProvider {
    /// Model name stamped on every generated embedding.
    model: String,
    /// Length of every generated vector.
    dimensions: usize,
}
1024
1025impl MockEmbeddingProvider {
1026    /// Create a new mock embedding provider
1027    pub fn new() -> Self {
1028        Self {
1029            model: "mock-model".to_string(),
1030            dimensions: 384,
1031        }
1032    }
1033}
1034
1035#[async_trait]
1036impl EmbeddingProvider for MockEmbeddingProvider {
1037    fn name(&self) -> &str {
1038        "mock"
1039    }
1040
1041    fn supported_models(&self) -> Vec<&str> {
1042        vec!["mock-model"]
1043    }
1044
1045    fn max_batch_size(&self) -> usize {
1046        100
1047    }
1048
1049    fn embedding_dimensions(&self) -> usize {
1050        self.dimensions
1051    }
1052
1053    async fn embed_text(&self, text: &str) -> RragResult<Embedding> {
1054        // Generate a simple mock embedding based on text hash
1055        let hash = text.len() as f32;
1056        let mut vector = vec![0.0; self.dimensions];
1057        for i in 0..self.dimensions {
1058            vector[i] = (hash + i as f32).sin() / (i + 1) as f32;
1059        }
1060
1061        Ok(Embedding::new(vector, &self.model, "mock"))
1062    }
1063
1064    async fn embed_batch(&self, requests: Vec<EmbeddingRequest>) -> RragResult<EmbeddingBatch> {
1065        let mut embeddings = HashMap::new();
1066
1067        for request in &requests {
1068            let embedding = self.embed_text(&request.content).await?;
1069            embeddings.insert(request.id.clone(), embedding);
1070        }
1071
1072        Ok(EmbeddingBatch {
1073            embeddings,
1074            metadata: BatchMetadata {
1075                total_items: requests.len(),
1076                successful_items: requests.len(),
1077                failed_items: HashMap::new(),
1078                duration_ms: 10,
1079                model: self.model.clone(),
1080                provider: self.name().to_string(),
1081            },
1082        })
1083    }
1084
1085    async fn health_check(&self) -> RragResult<bool> {
1086        Ok(true)
1087    }
1088}
1089
1090// For backward compatibility
1091/// Type alias for MockEmbeddingProvider for backward compatibility
1092pub type MockEmbeddingService = MockEmbeddingProvider;
1093
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_openai_provider() {
        let provider = OpenAIEmbeddingProvider::new("test-key");

        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.embedding_dimensions(), 1536);
        assert!(provider.health_check().await.unwrap());

        let embedding = provider.embed_text("Hello, world!").await.unwrap();
        assert_eq!(embedding.dimensions, 1536);
        assert_eq!(embedding.model, "text-embedding-ada-002");
    }

    #[tokio::test]
    async fn test_local_provider() {
        let provider = LocalEmbeddingProvider::new("test-model", 384);

        assert_eq!(provider.name(), "local");
        assert_eq!(provider.embedding_dimensions(), 384);

        let embedding = provider.embed_text("Test content").await.unwrap();
        assert_eq!(embedding.dimensions, 384);
    }

    #[tokio::test]
    async fn test_embedding_service() {
        let provider = Arc::new(LocalEmbeddingProvider::new("test-model", 384));
        let service = EmbeddingService::new(provider);

        let doc = Document::new("Test document content");
        let embedding = service.embed_document(&doc).await.unwrap();

        assert_eq!(embedding.dimensions, 384);
        assert!(!embedding.vector.is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        let vector1 = vec![1.0, 0.0, 0.0];
        let vector2 = vec![0.0, 1.0, 0.0];
        let vector3 = vec![1.0, 0.0, 0.0];

        let emb1 = Embedding::new(vector1, "test", "1");
        let emb2 = Embedding::new(vector2, "test", "2");
        let emb3 = Embedding::new(vector3, "test", "3");

        let similarity_12 = emb1.cosine_similarity(&emb2).unwrap();
        let similarity_13 = emb1.cosine_similarity(&emb3).unwrap();

        assert!((similarity_12 - 0.0).abs() < 1e-6); // Orthogonal vectors
        assert!((similarity_13 - 1.0).abs() < 1e-6); // Identical vectors
    }

    #[tokio::test]
    async fn test_batch_processing() {
        let provider = Arc::new(LocalEmbeddingProvider::new("test-model", 128));

        let requests = vec![
            EmbeddingRequest::new("1", "First text"),
            EmbeddingRequest::new("2", "Second text"),
            EmbeddingRequest::new("3", "Third text"),
        ];

        let batch_result = provider.embed_batch(requests).await.unwrap();

        assert_eq!(batch_result.metadata.total_items, 3);
        assert_eq!(batch_result.metadata.successful_items, 3);
        assert_eq!(batch_result.embeddings.len(), 3);
    }

    /// The mock provider must report its fixed defaults and be deterministic.
    #[tokio::test]
    async fn test_mock_provider() {
        let provider = MockEmbeddingProvider::new();

        assert_eq!(provider.name(), "mock");
        assert_eq!(provider.embedding_dimensions(), 384);
        assert_eq!(provider.max_batch_size(), 100);
        assert!(provider.health_check().await.unwrap());

        let first = provider.embed_text("Deterministic input").await.unwrap();
        let second = provider.embed_text("Deterministic input").await.unwrap();

        assert_eq!(first.dimensions, 384);
        assert_eq!(first.model, "mock-model");
        assert_eq!(first.vector, second.vector);
    }

    /// Batch metadata from the mock provider must reflect every request.
    #[tokio::test]
    async fn test_mock_batch_processing() {
        let provider = MockEmbeddingProvider::new();

        let requests = vec![
            EmbeddingRequest::new("a", "alpha"),
            EmbeddingRequest::new("b", "beta"),
        ];

        let batch_result = provider.embed_batch(requests).await.unwrap();

        assert_eq!(batch_result.metadata.total_items, 2);
        assert_eq!(batch_result.metadata.successful_items, 2);
        assert!(batch_result.metadata.failed_items.is_empty());
        assert_eq!(batch_result.metadata.provider, "mock");
        assert_eq!(batch_result.embeddings.len(), 2);
    }

    /// `provider_info` must mirror what the wrapped provider reports.
    #[tokio::test]
    async fn test_provider_info() {
        let service = EmbeddingService::new(Arc::new(MockEmbeddingProvider::new()));
        let info = service.provider_info();

        assert_eq!(info.name, "mock");
        assert_eq!(info.supported_models, vec!["mock-model".to_string()]);
        assert_eq!(info.max_batch_size, 100);
        assert_eq!(info.embedding_dimensions, 384);
    }
}