oxirs_embed/inference.rs

//! High-performance inference engine for embedding models

use crate::Vector;
use crate::{EmbeddingModel, ModelStats};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::Semaphore;
use tracing::{debug, info};

/// High-performance inference engine with caching and batching
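///
/// # Example
///
/// A minimal usage sketch (module paths are assumed from this file's location
/// in `oxirs_embed`; marked `no_run` because it needs model state):
///
/// ```no_run
/// use oxirs_embed::inference::{InferenceConfig, InferenceEngine};
/// use oxirs_embed::models::TransE;
/// use oxirs_embed::ModelConfig;
///
/// # async fn example() -> anyhow::Result<()> {
/// let model = TransE::new(ModelConfig::default().with_dimensions(64));
/// let engine = InferenceEngine::new(Box::new(model), InferenceConfig::default());
///
/// // First access hits the model; repeated accesses are served from the cache.
/// let _embedding = engine.get_entity_embedding("http://example.org/alice").await?;
/// # Ok(())
/// # }
/// ```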
pub struct InferenceEngine {
    model: Arc<RwLock<Box<dyn EmbeddingModel>>>,
    cache: Arc<RwLock<InferenceCache>>,
    config: InferenceConfig,
    batch_processor: BatchProcessor,
}

/// Configuration for inference engine
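///
/// A construction sketch (module path assumed): override selected fields with
/// struct-update syntax and keep the rest at their defaults.
///
/// ```
/// use oxirs_embed::inference::InferenceConfig;
///
/// let config = InferenceConfig {
///     cache_size: 50_000,
///     cache_ttl: 600, // 10 minutes
///     ..Default::default()
/// };
/// assert_eq!(config.batch_size, 100);
/// ```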
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Maximum cache size
    pub cache_size: usize,
    /// Batch size for batch processing
    pub batch_size: usize,
    /// Maximum number of concurrent requests
    pub max_concurrent: usize,
    /// Cache TTL in seconds
    pub cache_ttl: u64,
    /// Enable result caching
    pub enable_caching: bool,
    /// Warm up cache on startup
    pub warm_up_cache: bool,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            cache_size: 10000,
            batch_size: 100,
            max_concurrent: 10,
            cache_ttl: 3600, // 1 hour
            enable_caching: true,
            warm_up_cache: false,
        }
    }
}

/// Embedding cache with TTL support
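///
/// Expired entries are evicted lazily on lookup; when a cache map is full, the
/// entry with the oldest insertion timestamp is evicted first (timestamps are
/// not refreshed on access, so eviction is FIFO rather than true LRU).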
#[derive(Debug)]
pub struct InferenceCache {
    entity_cache: HashMap<String, CacheEntry<Vector>>,
    relation_cache: HashMap<String, CacheEntry<Vector>>,
    triple_score_cache: HashMap<String, CacheEntry<f64>>,
    max_size: usize,
    ttl_seconds: u64,
}

#[derive(Debug, Clone)]
struct CacheEntry<T> {
    value: T,
    timestamp: std::time::SystemTime,
}

impl<T> CacheEntry<T> {
    fn new(value: T) -> Self {
        Self {
            value,
            timestamp: std::time::SystemTime::now(),
        }
    }

    fn is_expired(&self, ttl_seconds: u64) -> bool {
        if let Ok(elapsed) = self.timestamp.elapsed() {
            elapsed.as_secs() > ttl_seconds
        } else {
            true // If we can't determine elapsed time, consider expired
        }
    }
}

impl InferenceCache {
    pub fn new(max_size: usize, ttl_seconds: u64) -> Self {
        Self {
            entity_cache: HashMap::new(),
            relation_cache: HashMap::new(),
            triple_score_cache: HashMap::new(),
            max_size,
            ttl_seconds,
        }
    }

    pub fn get_entity_embedding(&mut self, entity: &str) -> Option<Vector> {
        if let Some(entry) = self.entity_cache.get(entity) {
            if !entry.is_expired(self.ttl_seconds) {
                return Some(entry.value.clone());
            }
        }
        // Entry is missing or expired; removing a missing key is a no-op.
        self.entity_cache.remove(entity);
        None
    }

    pub fn cache_entity_embedding(&mut self, entity: String, embedding: Vector) {
        if self.entity_cache.len() >= self.max_size {
            // Evict the entry with the oldest insertion timestamp
            if let Some(oldest_key) = self.find_oldest_entity() {
                self.entity_cache.remove(&oldest_key);
            }
        }

        self.entity_cache.insert(entity, CacheEntry::new(embedding));
    }

    pub fn get_relation_embedding(&mut self, relation: &str) -> Option<Vector> {
        if let Some(entry) = self.relation_cache.get(relation) {
            if !entry.is_expired(self.ttl_seconds) {
                return Some(entry.value.clone());
            }
        }
        // Entry is missing or expired; removing a missing key is a no-op.
        self.relation_cache.remove(relation);
        None
    }

    pub fn cache_relation_embedding(&mut self, relation: String, embedding: Vector) {
        if self.relation_cache.len() >= self.max_size {
            // Evict the entry with the oldest insertion timestamp
            if let Some(oldest_key) = self.find_oldest_relation() {
                self.relation_cache.remove(&oldest_key);
            }
        }

        self.relation_cache
            .insert(relation, CacheEntry::new(embedding));
    }

    pub fn get_triple_score(&mut self, key: &str) -> Option<f64> {
        if let Some(entry) = self.triple_score_cache.get(key) {
            if !entry.is_expired(self.ttl_seconds) {
                return Some(entry.value);
            }
        }
        // Entry is missing or expired; removing a missing key is a no-op.
        self.triple_score_cache.remove(key);
        None
    }

    pub fn cache_triple_score(&mut self, key: String, score: f64) {
        if self.triple_score_cache.len() >= self.max_size {
            if let Some(oldest_key) = self.find_oldest_triple() {
                self.triple_score_cache.remove(&oldest_key);
            }
        }

        self.triple_score_cache.insert(key, CacheEntry::new(score));
    }

    fn find_oldest_entity(&self) -> Option<String> {
        self.entity_cache
            .iter()
            .min_by_key(|(_, entry)| entry.timestamp)
            .map(|(key, _)| key.clone())
    }

    fn find_oldest_relation(&self) -> Option<String> {
        self.relation_cache
            .iter()
            .min_by_key(|(_, entry)| entry.timestamp)
            .map(|(key, _)| key.clone())
    }

    fn find_oldest_triple(&self) -> Option<String> {
        self.triple_score_cache
            .iter()
            .min_by_key(|(_, entry)| entry.timestamp)
            .map(|(key, _)| key.clone())
    }

    pub fn clear(&mut self) {
        self.entity_cache.clear();
        self.relation_cache.clear();
        self.triple_score_cache.clear();
    }

    pub fn stats(&self) -> CacheStats {
        CacheStats {
            entity_cache_size: self.entity_cache.len(),
            relation_cache_size: self.relation_cache.len(),
            triple_cache_size: self.triple_score_cache.len(),
            max_size: self.max_size,
        }
    }
}

#[derive(Debug, Clone)]
pub struct CacheStats {
    pub entity_cache_size: usize,
    pub relation_cache_size: usize,
    pub triple_cache_size: usize,
    pub max_size: usize,
}

/// Batch processor for efficient bulk operations
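///
/// A `Semaphore` caps the number of concurrently executing batch calls at
/// `max_concurrent`, and each batch is processed in chunks of `batch_size`.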
#[derive(Debug)]
pub struct BatchProcessor {
    batch_size: usize,
    semaphore: Arc<Semaphore>,
}

impl BatchProcessor {
    pub fn new(batch_size: usize, max_concurrent: usize) -> Self {
        Self {
            batch_size,
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
        }
    }

    pub async fn process_entity_batch(
        &self,
        model: Arc<RwLock<Box<dyn EmbeddingModel>>>,
        entities: Vec<String>,
    ) -> Result<Vec<(String, Result<Vector>)>> {
        let _permit = self
            .semaphore
            .acquire()
            .await
            .expect("semaphore should not be closed");

        let mut results = Vec::new();

        for chunk in entities.chunks(self.batch_size) {
            let model_guard = model.read().expect("rwlock should not be poisoned");
            for entity in chunk {
                let result = model_guard.get_entity_embedding(entity);
                results.push((entity.clone(), result));
            }
        }

        Ok(results)
    }

    pub async fn process_relation_batch(
        &self,
        model: Arc<RwLock<Box<dyn EmbeddingModel>>>,
        relations: Vec<String>,
    ) -> Result<Vec<(String, Result<Vector>)>> {
        let _permit = self
            .semaphore
            .acquire()
            .await
            .expect("semaphore should not be closed");

        let mut results = Vec::new();

        for chunk in relations.chunks(self.batch_size) {
            let model_guard = model.read().expect("rwlock should not be poisoned");
            for relation in chunk {
                let result = model_guard.get_relation_embedding(relation);
                results.push((relation.clone(), result));
            }
        }

        Ok(results)
    }
}

impl InferenceEngine {
    /// Create a new inference engine
    pub fn new(model: Box<dyn EmbeddingModel>, config: InferenceConfig) -> Self {
        let cache = Arc::new(RwLock::new(InferenceCache::new(
            config.cache_size,
            config.cache_ttl,
        )));

        let batch_processor = BatchProcessor::new(config.batch_size, config.max_concurrent);

        Self {
            model: Arc::new(RwLock::new(model)),
            cache,
            config,
            batch_processor,
        }
    }

    /// Get entity embedding with caching
    pub async fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        // Check cache first
        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                if let Some(cached) = cache.get_entity_embedding(entity) {
                    debug!("Cache hit for entity: {}", entity);
                    // The cache getter already returns an owned clone.
                    return Ok(cached);
                }
            }
        }

        // Get from model
        let embedding = {
            let model_guard = self.model.read().expect("rwlock should not be poisoned");
            model_guard.get_entity_embedding(entity)?
        };

        // Cache result
        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                cache.cache_entity_embedding(entity.to_string(), embedding.clone());
            }
        }

        Ok(embedding)
    }

    /// Get relation embedding with caching
    pub async fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        // Check cache first
        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                if let Some(cached) = cache.get_relation_embedding(relation) {
                    debug!("Cache hit for relation: {}", relation);
                    // The cache getter already returns an owned clone.
                    return Ok(cached);
                }
            }
        }

        // Get from model
        let embedding = {
            let model_guard = self.model.read().expect("rwlock should not be poisoned");
            model_guard.get_relation_embedding(relation)?
        };

        // Cache result
        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                cache.cache_relation_embedding(relation.to_string(), embedding.clone());
            }
        }

        Ok(embedding)
    }

    /// Score triple with caching
    pub async fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let cache_key = format!("{subject}|{predicate}|{object}");

        // Check cache first
        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                if let Some(cached_score) = cache.get_triple_score(&cache_key) {
                    debug!("Cache hit for triple: {}", cache_key);
                    return Ok(cached_score);
                }
            }
        }

        // Get from model
        let score = {
            let model_guard = self.model.read().expect("rwlock should not be poisoned");
            model_guard.score_triple(subject, predicate, object)?
        };

        // Cache result
        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                cache.cache_triple_score(cache_key, score);
            }
        }

        Ok(score)
    }

    /// Batch process entity embeddings
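    ///
    /// Note: batch results are fetched straight from the model and bypass the
    /// inference cache.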
    pub async fn get_entity_embeddings_batch(
        &self,
        entities: Vec<String>,
    ) -> Result<Vec<(String, Result<Vector>)>> {
        self.batch_processor
            .process_entity_batch(self.model.clone(), entities)
            .await
    }

    /// Batch process relation embeddings
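    ///
    /// Note: like the entity variant, batch results bypass the inference cache.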
    pub async fn get_relation_embeddings_batch(
        &self,
        relations: Vec<String>,
    ) -> Result<Vec<(String, Result<Vector>)>> {
        self.batch_processor
            .process_relation_batch(self.model.clone(), relations)
            .await
    }

    /// Warm up cache with common entities and relations
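    ///
    /// This is a no-op unless [`InferenceConfig::warm_up_cache`] is set; at most
    /// `cache_size / 2` entities and `cache_size / 2` relations are pre-loaded.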
    pub async fn warm_up_cache(&self) -> Result<()> {
        if !self.config.warm_up_cache {
            return Ok(());
        }

        info!("Warming up inference cache...");

        let (entities, relations) = {
            let model_guard = self.model.read().expect("rwlock should not be poisoned");
            (model_guard.get_entities(), model_guard.get_relations())
        };

        // Warm up entity cache
        for entity in entities.iter().take(self.config.cache_size / 2) {
            let _ = self.get_entity_embedding(entity).await;
        }

        // Warm up relation cache
        for relation in relations.iter().take(self.config.cache_size / 2) {
            let _ = self.get_relation_embedding(relation).await;
        }

        info!("Cache warm-up completed");
        Ok(())
    }

    /// Get cache statistics
    pub fn cache_stats(&self) -> Result<CacheStats> {
        let cache_guard = self.cache.read().expect("rwlock should not be poisoned");
        Ok(cache_guard.stats())
    }

    /// Clear cache
    pub fn clear_cache(&self) -> Result<()> {
        let mut cache_guard = self.cache.write().expect("rwlock should not be poisoned");
        cache_guard.clear();
        info!("Inference cache cleared");
        Ok(())
    }

    /// Get model statistics
    pub fn model_stats(&self) -> Result<ModelStats> {
        let model_guard = self.model.read().expect("rwlock should not be poisoned");
        Ok(model_guard.get_stats())
    }
}

/// Performance monitoring for inference
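///
/// A construction sketch (module path assumed); `cache_hit_rate` is
/// `cache_hits / total_requests`, or `0.0` when no requests were recorded:
///
/// ```
/// use oxirs_embed::inference::InferenceMetrics;
///
/// let metrics = InferenceMetrics {
///     total_requests: 200,
///     cache_hits: 150,
///     cache_misses: 50,
///     average_latency_ms: 1.8,
///     throughput_per_second: 540.0,
/// };
/// assert!((metrics.cache_hit_rate() - 0.75).abs() < f64::EPSILON);
/// ```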
#[derive(Debug, Clone)]
pub struct InferenceMetrics {
    pub total_requests: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub average_latency_ms: f64,
    pub throughput_per_second: f64,
}

impl InferenceMetrics {
    pub fn cache_hit_rate(&self) -> f64 {
        if self.total_requests > 0 {
            self.cache_hits as f64 / self.total_requests as f64
        } else {
            0.0
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::TransE;
    use crate::ModelConfig;

    #[tokio::test]
    async fn test_inference_cache() {
        let mut cache = InferenceCache::new(2, 3600);

        let vec1 = Vector::new(vec![1.0, 2.0, 3.0]);
        let vec2 = Vector::new(vec![4.0, 5.0, 6.0]);

        cache.cache_entity_embedding("entity1".to_string(), vec1.clone());
        cache.cache_entity_embedding("entity2".to_string(), vec2.clone());

        assert!(cache.get_entity_embedding("entity1").is_some());
        assert!(cache.get_entity_embedding("entity2").is_some());

        // Adding a third entry should evict the one with the oldest timestamp
        let vec3 = Vector::new(vec![7.0, 8.0, 9.0]);
        cache.cache_entity_embedding("entity3".to_string(), vec3);

        assert_eq!(cache.entity_cache.len(), 2);
    }

    #[tokio::test]
    async fn test_inference_engine() -> Result<()> {
        let config = ModelConfig::default().with_dimensions(10).with_seed(42);
        let model = TransE::new(config);

        let inference_config = InferenceConfig {
            cache_size: 100,
            enable_caching: true,
            ..Default::default()
        };

        let engine = InferenceEngine::new(Box::new(model), inference_config);

        // Test should work even with untrained model
        let stats = engine.model_stats()?;
        assert_eq!(stats.dimensions, 10);

        Ok(())
    }
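
    // A small sketch of TTL behavior: it relies on `CacheEntry` being visible
    // within this module and backdates the timestamp by hand.
    #[test]
    fn test_cache_entry_expiry() {
        // A fresh entry is well within a one-hour TTL.
        let fresh = CacheEntry::new(42.0_f64);
        assert!(!fresh.is_expired(3600));

        // Backdate an entry by 10 seconds to force expiry against a 5s TTL.
        let stale = CacheEntry {
            value: 42.0_f64,
            timestamp: std::time::SystemTime::now() - std::time::Duration::from_secs(10),
        };
        assert!(stale.is_expired(5));
        assert!(!stale.is_expired(60));
    }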
}