Skip to main content

oxirs_vec/embeddings/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::Vector;
6use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10
11use super::functions::EmbeddingGenerator;
12use super::openaiembeddinggenerator_type::OpenAIEmbeddingGenerator;
13use super::sentencetransformergenerator_type::SentenceTransformerGenerator;
14
15/// Embedding cache for frequently accessed embeddings
16pub struct EmbeddingCache {
17    cache: HashMap<u64, Vector>,
18    max_size: usize,
19    access_order: Vec<u64>,
20}
21impl EmbeddingCache {
22    pub fn new(max_size: usize) -> Self {
23        Self {
24            cache: HashMap::new(),
25            max_size,
26            access_order: Vec::new(),
27        }
28    }
29    pub fn get(&mut self, content: &EmbeddableContent) -> Option<&Vector> {
30        let hash = content.content_hash();
31        if let Some(vector) = self.cache.get(&hash) {
32            if let Some(pos) = self.access_order.iter().position(|&x| x == hash) {
33                self.access_order.remove(pos);
34            }
35            self.access_order.push(hash);
36            Some(vector)
37        } else {
38            None
39        }
40    }
41    pub fn insert(&mut self, content: &EmbeddableContent, vector: Vector) {
42        let hash = content.content_hash();
43        if self.cache.len() >= self.max_size && !self.cache.contains_key(&hash) {
44            if let Some(&lru_hash) = self.access_order.first() {
45                self.cache.remove(&lru_hash);
46                self.access_order.remove(0);
47            }
48        }
49        self.cache.insert(hash, vector);
50        self.access_order.push(hash);
51    }
52    pub fn clear(&mut self) {
53        self.cache.clear();
54        self.access_order.clear();
55    }
56    pub fn size(&self) -> usize {
57        self.cache.len()
58    }
59}
/// Detailed information about a transformer model
#[derive(Debug, Clone)]
pub struct ModelDetails {
    /// Number of entries in the tokenizer vocabulary.
    pub vocab_size: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Number of attention heads.
    pub num_attention_heads: usize,
    /// Dimensionality of the hidden representations.
    pub hidden_size: usize,
    /// Dimensionality of the feed-forward intermediate layer.
    pub intermediate_size: usize,
    /// Maximum sequence length supported by the position embeddings.
    pub max_position_embeddings: usize,
    /// Languages the model supports (format is producer-defined —
    /// NOTE(review): confirm whether these are names or ISO codes).
    pub supports_languages: Vec<String>,
    /// Approximate model size on disk, in megabytes.
    pub model_size_mb: usize,
    /// Typical single-inference latency in milliseconds.
    pub typical_inference_time_ms: u64,
}
/// Retry strategy for failed requests
///
/// Selected via `OpenAIConfig::retry_strategy`; presumably scaled from
/// `OpenAIConfig::retry_delay_ms` by the OpenAI client — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RetryStrategy {
    /// Fixed delay between retries
    Fixed,
    /// Exponential backoff with jitter
    ExponentialBackoff,
    /// Linear backoff
    LinearBackoff,
}
/// Embedding model configuration
///
/// NOTE(review): `Default` is implemented elsewhere — this module relies on
/// `EmbeddingConfig::default()` and `..Default::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Identifier of the backing model.
    pub model_name: String,
    /// Dimensionality of produced embedding vectors.
    pub dimensions: usize,
    /// Maximum input sequence length; handling of longer inputs is
    /// backend-specific — confirm per generator.
    pub max_sequence_length: usize,
    /// Whether generated vectors are normalized (the TF-IDF backend applies
    /// L2 normalization; see `normalize_vector` below).
    pub normalize: bool,
}
/// Mock embedding generator for testing
#[cfg(test)]
pub struct MockEmbeddingGenerator {
    pub(super) config: EmbeddingConfig,
}
#[cfg(test)]
impl MockEmbeddingGenerator {
    /// Create a mock generator with the default test dimensionality of 128.
    pub fn new() -> Self {
        Self::with_dimensions(128)
    }
    /// Create a mock generator producing vectors of `dimensions` components.
    pub fn with_dimensions(dimensions: usize) -> Self {
        let config = EmbeddingConfig {
            dimensions,
            ..Default::default()
        };
        Self { config }
    }
}
/// Content to be embedded
#[derive(Debug, Clone)]
pub enum EmbeddableContent {
    /// Plain text content
    Text(String),
    /// RDF resource with properties
    RdfResource {
        uri: String,
        label: Option<String>,
        description: Option<String>,
        properties: HashMap<String, Vec<String>>,
    },
    /// SPARQL query or query fragment
    SparqlQuery(String),
    /// Knowledge graph path or pattern
    GraphPattern(String),
}
impl EmbeddableContent {
    /// Convert content to text representation for embedding
    ///
    /// String-carrying variants embed their inner text verbatim; RDF
    /// resources are flattened into a space-separated
    /// `<uri> label: <l> description: <d> <prop>: <v1, v2>` rendering
    /// (label/description only when present).
    pub fn to_text(&self) -> String {
        match self {
            EmbeddableContent::Text(inner)
            | EmbeddableContent::SparqlQuery(inner)
            | EmbeddableContent::GraphPattern(inner) => inner.clone(),
            EmbeddableContent::RdfResource {
                uri,
                label,
                description,
                properties,
            } => {
                let mut parts = Vec::with_capacity(3 + properties.len());
                parts.push(uri.clone());
                parts.extend(label.as_ref().map(|l| format!("label: {l}")));
                parts.extend(description.as_ref().map(|d| format!("description: {d}")));
                parts.extend(
                    properties
                        .iter()
                        .map(|(prop, values)| format!("{prop}: {}", values.join(", "))),
                );
                parts.join(" ")
            }
        }
    }
    /// Get a unique identifier for this content
    ///
    /// Hashes the `to_text()` rendering, so distinct variants with identical
    /// text share a hash. Deterministic within a process; `DefaultHasher`'s
    /// algorithm is unspecified, so do not persist these values.
    pub fn content_hash(&self) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        self.to_text().hash(&mut hasher);
        hasher.finish()
    }
}
/// Embedding generation strategy
///
/// Determines which concrete `EmbeddingGenerator` an `EmbeddingManager`
/// constructs in `EmbeddingManager::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EmbeddingStrategy {
    /// Simple TF-IDF based embeddings (for testing/fallback)
    TfIdf,
    /// Sentence transformer embeddings (requires external service)
    SentenceTransformer,
    /// BERT-based transformer models
    Transformer(TransformerModelType),
    /// Word2Vec embeddings
    Word2Vec(crate::word2vec::Word2VecConfig),
    /// OpenAI embeddings (requires API key)
    OpenAI(OpenAIConfig),
    /// Custom embedding model
    ///
    /// Payload is a model path; currently handled like
    /// `SentenceTransformer` with a default config (see `EmbeddingManager::new`).
    Custom(String),
}
/// Embedding manager that combines generation, caching, and persistence
pub struct EmbeddingManager {
    /// Strategy-selected backend that actually produces vectors.
    generator: Box<dyn EmbeddingGenerator>,
    /// LRU cache keyed by content hash.
    cache: EmbeddingCache,
    /// Strategy the manager was built with; consulted again in
    /// `build_vocabulary` to gate TF-IDF-specific behavior.
    strategy: EmbeddingStrategy,
}
188impl EmbeddingManager {
189    pub fn new(strategy: EmbeddingStrategy, cache_size: usize) -> Result<Self> {
190        let generator: Box<dyn EmbeddingGenerator> = match &strategy {
191            EmbeddingStrategy::TfIdf => {
192                let config = EmbeddingConfig::default();
193                Box::new(TfIdfEmbeddingGenerator::new(config))
194            }
195            EmbeddingStrategy::SentenceTransformer => {
196                let config = EmbeddingConfig::default();
197                Box::new(SentenceTransformerGenerator::new(config))
198            }
199            EmbeddingStrategy::Transformer(model_type) => {
200                let config = EmbeddingConfig {
201                    model_name: format!("{model_type:?}"),
202                    dimensions: match model_type {
203                        TransformerModelType::DistilBERT => 384,
204                        _ => 768,
205                    },
206                    max_sequence_length: 512,
207                    normalize: true,
208                };
209                Box::new(SentenceTransformerGenerator::with_model_type(
210                    config,
211                    model_type.clone(),
212                ))
213            }
214            EmbeddingStrategy::Word2Vec(word2vec_config) => {
215                let embedding_config = EmbeddingConfig {
216                    model_name: "word2vec".to_string(),
217                    dimensions: word2vec_config.dimensions,
218                    max_sequence_length: 512,
219                    normalize: word2vec_config.normalize,
220                };
221                Box::new(crate::word2vec::Word2VecEmbeddingGenerator::new(
222                    word2vec_config.clone(),
223                    embedding_config,
224                )?)
225            }
226            EmbeddingStrategy::OpenAI(openai_config) => {
227                Box::new(OpenAIEmbeddingGenerator::new(openai_config.clone())?)
228            }
229            EmbeddingStrategy::Custom(_model_path) => {
230                let config = EmbeddingConfig::default();
231                Box::new(SentenceTransformerGenerator::new(config))
232            }
233        };
234        Ok(Self {
235            generator,
236            cache: EmbeddingCache::new(cache_size),
237            strategy,
238        })
239    }
240    /// Get or generate embedding for content
241    pub fn get_embedding(&mut self, content: &EmbeddableContent) -> Result<Vector> {
242        if let Some(cached) = self.cache.get(content) {
243            return Ok(cached.clone());
244        }
245        let embedding = self.generator.generate(content)?;
246        self.cache.insert(content, embedding.clone());
247        Ok(embedding)
248    }
249    /// Pre-compute embeddings for a batch of content
250    pub fn precompute_embeddings(&mut self, contents: &[EmbeddableContent]) -> Result<()> {
251        let embeddings = self.generator.generate_batch(contents)?;
252        for (content, embedding) in contents.iter().zip(embeddings) {
253            self.cache.insert(content, embedding);
254        }
255        Ok(())
256    }
257    /// Build vocabulary for TF-IDF strategy
258    pub fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
259        if let EmbeddingStrategy::TfIdf = self.strategy {
260            if let Some(tfidf_gen) = self
261                .generator
262                .as_any_mut()
263                .downcast_mut::<TfIdfEmbeddingGenerator>()
264            {
265                tfidf_gen.build_vocabulary(documents)?;
266            }
267        }
268        Ok(())
269    }
270    pub fn dimensions(&self) -> usize {
271        self.generator.dimensions()
272    }
273    pub fn cache_stats(&self) -> (usize, usize) {
274        (self.cache.size(), self.cache.max_size)
275    }
276}
/// Supported transformer model types
///
/// Used as the payload of `EmbeddingStrategy::Transformer`; the model name
/// passed to the generator is this value's `Debug` rendering.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum TransformerModelType {
    /// Basic BERT-based model (already implemented)
    #[default]
    BERT,
    /// RoBERTa model with improved training
    RoBERTa,
    /// DistilBERT for efficiency
    DistilBERT,
    /// Multilingual BERT
    MultiBERT,
    /// Custom model path
    Custom(String),
}
/// OpenAI embeddings configuration
///
/// Construct via [`OpenAIConfig::production`] or [`OpenAIConfig::development`]
/// (both fill the remaining fields from `Default`, which is implemented
/// elsewhere), then check with [`OpenAIConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    /// API key for OpenAI service
    pub api_key: String,
    /// Model to use (e.g., "text-embedding-ada-002", "text-embedding-3-small")
    pub model: String,
    /// Base URL for API calls (default: `https://api.openai.com/v1`)
    pub base_url: String,
    /// Request timeout in seconds
    pub timeout_seconds: u64,
    /// Rate limiting: requests per minute
    pub requests_per_minute: u32,
    /// Batch size for batch processing
    pub batch_size: usize,
    /// Enable local caching
    pub enable_cache: bool,
    /// Cache size (number of embeddings to cache)
    pub cache_size: usize,
    /// Cache TTL in seconds (0 for no expiration)
    pub cache_ttl_seconds: u64,
    /// Maximum retries for failed requests
    pub max_retries: u32,
    /// Retry delay in milliseconds
    pub retry_delay_ms: u64,
    /// Retry strategy
    pub retry_strategy: RetryStrategy,
    /// Enable cost tracking
    pub track_costs: bool,
    /// Enable detailed metrics
    pub enable_metrics: bool,
    /// User agent for requests
    pub user_agent: String,
}
326impl OpenAIConfig {
327    /// Create config for production use
328    pub fn production() -> Self {
329        Self {
330            requests_per_minute: 1000,
331            cache_size: 50000,
332            cache_ttl_seconds: 7200,
333            max_retries: 5,
334            retry_strategy: RetryStrategy::ExponentialBackoff,
335            ..Default::default()
336        }
337    }
338    /// Create config for development/testing
339    pub fn development() -> Self {
340        Self {
341            requests_per_minute: 100,
342            cache_size: 1000,
343            cache_ttl_seconds: 300,
344            max_retries: 2,
345            ..Default::default()
346        }
347    }
348    /// Validate configuration
349    pub fn validate(&self) -> Result<()> {
350        if self.api_key.is_empty() {
351            return Err(anyhow!("OpenAI API key is required"));
352        }
353        if self.requests_per_minute == 0 {
354            return Err(anyhow!("requests_per_minute must be greater than 0"));
355        }
356        if self.batch_size == 0 {
357            return Err(anyhow!("batch_size must be greater than 0"));
358        }
359        if self.timeout_seconds == 0 {
360            return Err(anyhow!("timeout_seconds must be greater than 0"));
361        }
362        Ok(())
363    }
364}
365/// Simple rate limiter implementation
366pub struct RateLimiter {
367    requests_per_minute: u32,
368    request_times: std::collections::VecDeque<std::time::Instant>,
369}
370impl RateLimiter {
371    pub fn new(requests_per_minute: u32) -> Self {
372        Self {
373            requests_per_minute,
374            request_times: std::collections::VecDeque::new(),
375        }
376    }
377    pub async fn wait_if_needed(&mut self) {
378        let now = std::time::Instant::now();
379        let minute_ago = now - std::time::Duration::from_secs(60);
380        while let Some(&front_time) = self.request_times.front() {
381            if front_time < minute_ago {
382                self.request_times.pop_front();
383            } else {
384                break;
385            }
386        }
387        if self.request_times.len() >= self.requests_per_minute as usize {
388            if let Some(&oldest) = self.request_times.front() {
389                let wait_time = oldest + std::time::Duration::from_secs(60) - now;
390                if !wait_time.is_zero() {
391                    tokio::time::sleep(wait_time).await;
392                }
393            }
394        }
395        self.request_times.push_back(now);
396    }
397}
/// Metrics for OpenAI API usage
#[derive(Debug, Clone, Default)]
pub struct OpenAIMetrics {
    pub total_requests: u64,
    pub successful_requests: u64,
    pub failed_requests: u64,
    pub total_tokens_processed: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub total_cost_usd: f64,
    pub retry_count: u64,
    pub rate_limit_waits: u64,
    pub average_response_time_ms: f64,
    pub last_request_time: Option<std::time::SystemTime>,
    pub requests_by_model: HashMap<String, u64>,
    pub errors_by_type: HashMap<String, u64>,
}
impl OpenAIMetrics {
    /// Calculate cache hit ratio
    ///
    /// `hits / (hits + misses)`; 0.0 before any lookup has happened.
    pub fn cache_hit_ratio(&self) -> f64 {
        let lookups = self.cache_hits + self.cache_misses;
        if lookups == 0 {
            return 0.0;
        }
        self.cache_hits as f64 / lookups as f64
    }
    /// Calculate success rate
    ///
    /// 0.0 before any request has been made.
    pub fn success_rate(&self) -> f64 {
        match self.total_requests {
            0 => 0.0,
            total => self.successful_requests as f64 / total as f64,
        }
    }
    /// Calculate average cost per request
    ///
    /// Averaged over successful requests only; 0.0 when there are none.
    pub fn average_cost_per_request(&self) -> f64 {
        match self.successful_requests {
            0 => 0.0,
            n => self.total_cost_usd / n as f64,
        }
    }
    /// Get formatted metrics report
    pub fn report(&self) -> String {
        format!(
            "OpenAI Metrics Report:\n\
            Total Requests: {}\n\
            Success Rate: {:.2}%\n\
            Cache Hit Ratio: {:.2}%\n\
            Total Cost: ${:.4}\n\
            Avg Cost/Request: ${:.6}\n\
            Avg Response Time: {:.2}ms\n\
            Retries: {}\n\
            Rate Limit Waits: {}",
            self.total_requests,
            self.success_rate() * 100.0,
            self.cache_hit_ratio() * 100.0,
            self.total_cost_usd,
            self.average_cost_per_request(),
            self.average_response_time_ms,
            self.retry_count,
            self.rate_limit_waits
        )
    }
}
/// Cached embedding with metadata
#[derive(Debug, Clone)]
pub struct CachedEmbedding {
    /// The embedding vector itself.
    pub vector: Vector,
    /// When the entry was stored — presumably paired with
    /// `OpenAIConfig::cache_ttl_seconds` for expiry; confirm in the
    /// OpenAI generator.
    pub cached_at: std::time::SystemTime,
    /// Model that produced the vector.
    pub model: String,
    /// API cost attributed to generating this embedding, in USD.
    pub cost_usd: f64,
}
/// Simple TF-IDF based embedding generator
pub struct TfIdfEmbeddingGenerator {
    /// Model configuration; `dimensions` caps the vocabulary size.
    pub(super) config: EmbeddingConfig,
    /// Token -> embedding-index map built by `build_vocabulary`.
    pub(super) vocabulary: HashMap<String, usize>,
    /// Token -> inverse-document-frequency score for vocabulary tokens.
    idf_scores: HashMap<String, f32>,
}
477impl TfIdfEmbeddingGenerator {
478    pub fn new(config: EmbeddingConfig) -> Self {
479        Self {
480            config,
481            vocabulary: HashMap::new(),
482            idf_scores: HashMap::new(),
483        }
484    }
485    /// Build vocabulary from a corpus of documents
486    pub fn build_vocabulary(&mut self, documents: &[String]) -> Result<()> {
487        let mut word_counts: HashMap<String, usize> = HashMap::new();
488        let mut doc_counts: HashMap<String, usize> = HashMap::new();
489        for doc in documents {
490            let words: Vec<String> = self.tokenize(doc);
491            let unique_words: std::collections::HashSet<_> = words.iter().collect();
492            for word in &words {
493                *word_counts.entry(word.clone()).or_insert(0) += 1;
494            }
495            for word in unique_words {
496                *doc_counts.entry(word.clone()).or_insert(0) += 1;
497            }
498        }
499        let mut word_freq: Vec<(String, usize)> = word_counts.into_iter().collect();
500        word_freq.sort_by(|a, b| b.1.cmp(&a.1));
501        self.vocabulary = word_freq
502            .into_iter()
503            .take(self.config.dimensions)
504            .enumerate()
505            .map(|(idx, (word, _))| (word, idx))
506            .collect();
507        let total_docs = documents.len() as f32;
508        for word in self.vocabulary.keys() {
509            let doc_freq = doc_counts.get(word).unwrap_or(&0);
510            let idf = (total_docs / (*doc_freq as f32 + 1.0)).ln();
511            self.idf_scores.insert(word.clone(), idf);
512        }
513        Ok(())
514    }
515    fn tokenize(&self, text: &str) -> Vec<String> {
516        text.to_lowercase()
517            .split_whitespace()
518            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
519            .filter(|s| !s.is_empty())
520            .map(String::from)
521            .collect()
522    }
523    pub(super) fn calculate_tf_idf(&self, text: &str) -> Vector {
524        let words = self.tokenize(text);
525        let mut tf_counts: HashMap<String, usize> = HashMap::new();
526        for word in &words {
527            *tf_counts.entry(word.clone()).or_insert(0) += 1;
528        }
529        let total_words = words.len() as f32;
530        let mut embedding = vec![0.0; self.config.dimensions];
531        for (word, count) in tf_counts {
532            if let Some(&idx) = self.vocabulary.get(&word) {
533                let tf = count as f32 / total_words;
534                let idf = self.idf_scores.get(&word).unwrap_or(&0.0);
535                embedding[idx] = tf * idf;
536            }
537        }
538        if self.config.normalize {
539            self.normalize_vector(&mut embedding);
540        }
541        Vector::new(embedding)
542    }
543    fn normalize_vector(&self, vector: &mut [f32]) {
544        let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
545        if magnitude > 0.0 {
546            for value in vector {
547                *value /= magnitude;
548            }
549        }
550    }
551}