use crate::Vector;
use crate::{EmbeddingModel, ModelStats};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::Semaphore;
use tracing::{debug, info};

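/// A caching inference layer over a boxed [`EmbeddingModel`].
///
/// A minimal usage sketch (marked `ignore` since it is illustrative only;
/// `TransE` and `ModelConfig` are the same crate types the tests below use,
/// and the entities are assumed to exist in the model):
///
/// ```ignore
/// let model = TransE::new(ModelConfig::default().with_dimensions(10));
/// let engine = InferenceEngine::new(Box::new(model), InferenceConfig::default());
/// let score = engine.score_triple("alice", "knows", "bob").await?;
/// ```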
pub struct InferenceEngine {
    model: Arc<RwLock<Box<dyn EmbeddingModel>>>,
    cache: Arc<RwLock<InferenceCache>>,
    config: InferenceConfig,
    batch_processor: BatchProcessor,
}

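/// Configuration for the [`InferenceEngine`].
///
/// A sketch of overriding a subset of fields; struct-update syntax keeps the
/// rest at their [`Default`] values:
///
/// ```ignore
/// let config = InferenceConfig {
///     cache_size: 50_000,
///     cache_ttl: 600,
///     ..Default::default()
/// };
/// ```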
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Maximum number of entries held in each of the internal caches.
    pub cache_size: usize,
    /// Number of items processed per chunk in batch operations.
    pub batch_size: usize,
    /// Maximum number of batch operations allowed in flight at once.
    pub max_concurrent: usize,
    /// Time-to-live for cache entries, in seconds.
    pub cache_ttl: u64,
    /// Whether embedding and score caching is enabled.
    pub enable_caching: bool,
    /// Whether [`InferenceEngine::warm_up_cache`] should pre-populate the caches.
    pub warm_up_cache: bool,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            cache_size: 10000,
            batch_size: 100,
            max_concurrent: 10,
            cache_ttl: 3600,
            enable_caching: true,
            warm_up_cache: false,
        }
    }
}

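/// A bounded cache with per-entry TTL for entity/relation embeddings and
/// triple scores. When a cache is full, the entry with the oldest timestamp
/// is evicted before a new one is inserted.
///
/// A sketch of the eviction behaviour (see `test_inference_cache` below):
///
/// ```ignore
/// let mut cache = InferenceCache::new(2, 3600);
/// cache.cache_entity_embedding("a".into(), Vector::new(vec![1.0]));
/// cache.cache_entity_embedding("b".into(), Vector::new(vec![2.0]));
/// cache.cache_entity_embedding("c".into(), Vector::new(vec![3.0])); // evicts "a"
/// assert!(cache.get_entity_embedding("a").is_none());
/// ```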
#[derive(Debug)]
pub struct InferenceCache {
    entity_cache: HashMap<String, CacheEntry<Vector>>,
    relation_cache: HashMap<String, CacheEntry<Vector>>,
    triple_score_cache: HashMap<String, CacheEntry<f64>>,
    max_size: usize,
    ttl_seconds: u64,
}

#[derive(Debug, Clone)]
struct CacheEntry<T> {
    value: T,
    timestamp: std::time::SystemTime,
}

impl<T> CacheEntry<T> {
    fn new(value: T) -> Self {
        Self {
            value,
            timestamp: std::time::SystemTime::now(),
        }
    }

    fn is_expired(&self, ttl_seconds: u64) -> bool {
        if let Ok(elapsed) = self.timestamp.elapsed() {
            elapsed.as_secs() > ttl_seconds
        } else {
            // The system clock went backwards; treat the entry as expired.
            true
        }
    }
}

impl InferenceCache {
    pub fn new(max_size: usize, ttl_seconds: u64) -> Self {
        Self {
            entity_cache: HashMap::new(),
            relation_cache: HashMap::new(),
            triple_score_cache: HashMap::new(),
            max_size,
            ttl_seconds,
        }
    }

    pub fn get_entity_embedding(&mut self, entity: &str) -> Option<Vector> {
        if let Some(entry) = self.entity_cache.get(entity) {
            if !entry.is_expired(self.ttl_seconds) {
                return Some(entry.value.clone());
            }
            self.entity_cache.remove(entity);
        }
        None
    }

    pub fn cache_entity_embedding(&mut self, entity: String, embedding: Vector) {
        if self.entity_cache.len() >= self.max_size {
            if let Some(oldest_key) = self.find_oldest_entity() {
                self.entity_cache.remove(&oldest_key);
            }
        }

        self.entity_cache.insert(entity, CacheEntry::new(embedding));
    }

    pub fn get_relation_embedding(&mut self, relation: &str) -> Option<Vector> {
        if let Some(entry) = self.relation_cache.get(relation) {
            if !entry.is_expired(self.ttl_seconds) {
                return Some(entry.value.clone());
            }
            self.relation_cache.remove(relation);
        }
        None
    }

    pub fn cache_relation_embedding(&mut self, relation: String, embedding: Vector) {
        if self.relation_cache.len() >= self.max_size {
            if let Some(oldest_key) = self.find_oldest_relation() {
                self.relation_cache.remove(&oldest_key);
            }
        }

        self.relation_cache
            .insert(relation, CacheEntry::new(embedding));
    }

    pub fn get_triple_score(&mut self, key: &str) -> Option<f64> {
        if let Some(entry) = self.triple_score_cache.get(key) {
            if !entry.is_expired(self.ttl_seconds) {
                return Some(entry.value);
            }
            self.triple_score_cache.remove(key);
        }
        None
    }

    pub fn cache_triple_score(&mut self, key: String, score: f64) {
        if self.triple_score_cache.len() >= self.max_size {
            if let Some(oldest_key) = self.find_oldest_triple() {
                self.triple_score_cache.remove(&oldest_key);
            }
        }

        self.triple_score_cache.insert(key, CacheEntry::new(score));
    }

    fn find_oldest_entity(&self) -> Option<String> {
        self.entity_cache
            .iter()
            .min_by_key(|(_, entry)| entry.timestamp)
            .map(|(key, _)| key.clone())
    }

    fn find_oldest_relation(&self) -> Option<String> {
        self.relation_cache
            .iter()
            .min_by_key(|(_, entry)| entry.timestamp)
            .map(|(key, _)| key.clone())
    }

    fn find_oldest_triple(&self) -> Option<String> {
        self.triple_score_cache
            .iter()
            .min_by_key(|(_, entry)| entry.timestamp)
            .map(|(key, _)| key.clone())
    }

    pub fn clear(&mut self) {
        self.entity_cache.clear();
        self.relation_cache.clear();
        self.triple_score_cache.clear();
    }

    pub fn stats(&self) -> CacheStats {
        CacheStats {
            entity_cache_size: self.entity_cache.len(),
            relation_cache_size: self.relation_cache.len(),
            triple_cache_size: self.triple_score_cache.len(),
            max_size: self.max_size,
        }
    }
}

#[derive(Debug, Clone)]
pub struct CacheStats {
    pub entity_cache_size: usize,
    pub relation_cache_size: usize,
    pub triple_cache_size: usize,
    pub max_size: usize,
}

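/// Runs embedding lookups in fixed-size chunks, with a [`Semaphore`] bounding
/// the number of batch calls in flight at once.
///
/// A sketch (illustrative only; `model` and `entities` are assumed to be in
/// scope with the types the methods below expect):
///
/// ```ignore
/// let processor = BatchProcessor::new(100, 10);
/// let results = processor.process_entity_batch(model, entities).await?;
/// for (entity, embedding) in results {
///     // each lookup succeeds or fails independently of the others
/// }
/// ```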
#[derive(Debug)]
pub struct BatchProcessor {
    batch_size: usize,
    semaphore: Arc<Semaphore>,
}

impl BatchProcessor {
    pub fn new(batch_size: usize, max_concurrent: usize) -> Self {
        Self {
            batch_size,
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
        }
    }

    pub async fn process_entity_batch(
        &self,
        model: Arc<RwLock<Box<dyn EmbeddingModel>>>,
        entities: Vec<String>,
    ) -> Result<Vec<(String, Result<Vector>)>> {
        // Bound the number of batch calls in flight; the permit is released
        // when `_permit` is dropped at the end of this call.
        let _permit = self
            .semaphore
            .acquire()
            .await
            .expect("semaphore should not be closed");

        let mut results = Vec::new();

        for chunk in entities.chunks(self.batch_size) {
            let model_guard = model.read().expect("rwlock should not be poisoned");
            for entity in chunk {
                // Each lookup keeps its own Result so one failure does not
                // abort the rest of the batch.
                let result = model_guard.get_entity_embedding(entity);
                results.push((entity.clone(), result));
            }
        }

        Ok(results)
    }

    pub async fn process_relation_batch(
        &self,
        model: Arc<RwLock<Box<dyn EmbeddingModel>>>,
        relations: Vec<String>,
    ) -> Result<Vec<(String, Result<Vector>)>> {
        let _permit = self
            .semaphore
            .acquire()
            .await
            .expect("semaphore should not be closed");

        let mut results = Vec::new();

        for chunk in relations.chunks(self.batch_size) {
            let model_guard = model.read().expect("rwlock should not be poisoned");
            for relation in chunk {
                let result = model_guard.get_relation_embedding(relation);
                results.push((relation.clone(), result));
            }
        }

        Ok(results)
    }
}

impl InferenceEngine {
    pub fn new(model: Box<dyn EmbeddingModel>, config: InferenceConfig) -> Self {
        let cache = Arc::new(RwLock::new(InferenceCache::new(
            config.cache_size,
            config.cache_ttl,
        )));

        let batch_processor = BatchProcessor::new(config.batch_size, config.max_concurrent);

        Self {
            model: Arc::new(RwLock::new(model)),
            cache,
            config,
            batch_processor,
        }
    }

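    /// Returns the embedding for `entity`, consulting the cache first when
    /// caching is enabled and populating it on a miss.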
    pub async fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                if let Some(cached) = cache.get_entity_embedding(entity) {
                    debug!("Cache hit for entity: {}", entity);
                    return Ok(cached);
                }
            }
        }

        let embedding = {
            let model_guard = self.model.read().expect("rwlock should not be poisoned");
            model_guard.get_entity_embedding(entity)?
        };

        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                cache.cache_entity_embedding(entity.to_string(), embedding.clone());
            }
        }

        Ok(embedding)
    }

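    /// Returns the embedding for `relation`, with the same cache-first
    /// behaviour as [`Self::get_entity_embedding`].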
    pub async fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                if let Some(cached) = cache.get_relation_embedding(relation) {
                    debug!("Cache hit for relation: {}", relation);
                    return Ok(cached);
                }
            }
        }

        let embedding = {
            let model_guard = self.model.read().expect("rwlock should not be poisoned");
            model_guard.get_relation_embedding(relation)?
        };

        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                cache.cache_relation_embedding(relation.to_string(), embedding.clone());
            }
        }

        Ok(embedding)
    }

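    /// Scores a `(subject, predicate, object)` triple, memoizing the result
    /// under the pipe-joined key `"subject|predicate|object"`.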
    pub async fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let cache_key = format!("{subject}|{predicate}|{object}");

        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                if let Some(cached_score) = cache.get_triple_score(&cache_key) {
                    debug!("Cache hit for triple: {}", cache_key);
                    return Ok(cached_score);
                }
            }
        }

        let score = {
            let model_guard = self.model.read().expect("rwlock should not be poisoned");
            model_guard.score_triple(subject, predicate, object)?
        };

        if self.config.enable_caching {
            if let Ok(mut cache) = self.cache.write() {
                cache.cache_triple_score(cache_key, score);
            }
        }

        Ok(score)
    }

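    /// Resolves many entity embeddings at once. Note that the batch path goes
    /// straight to the model and does not consult or populate the cache.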
    pub async fn get_entity_embeddings_batch(
        &self,
        entities: Vec<String>,
    ) -> Result<Vec<(String, Result<Vector>)>> {
        self.batch_processor
            .process_entity_batch(self.model.clone(), entities)
            .await
    }

    pub async fn get_relation_embeddings_batch(
        &self,
        relations: Vec<String>,
    ) -> Result<Vec<(String, Result<Vector>)>> {
        self.batch_processor
            .process_relation_batch(self.model.clone(), relations)
            .await
    }

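    /// Pre-populates the caches with up to `cache_size / 2` entities and as
    /// many relations, if `warm_up_cache` is enabled; lookup errors during
    /// warm-up are ignored.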
    pub async fn warm_up_cache(&self) -> Result<()> {
        if !self.config.warm_up_cache {
            return Ok(());
        }

        info!("Warming up inference cache...");

        let (entities, relations) = {
            let model_guard = self.model.read().expect("rwlock should not be poisoned");
            (model_guard.get_entities(), model_guard.get_relations())
        };

        for entity in entities.iter().take(self.config.cache_size / 2) {
            let _ = self.get_entity_embedding(entity).await;
        }

        for relation in relations.iter().take(self.config.cache_size / 2) {
            let _ = self.get_relation_embedding(relation).await;
        }

        info!("Cache warm-up completed");
        Ok(())
    }

    pub fn cache_stats(&self) -> Result<CacheStats> {
        let cache_guard = self.cache.read().expect("rwlock should not be poisoned");
        Ok(cache_guard.stats())
    }

    pub fn clear_cache(&self) -> Result<()> {
        let mut cache_guard = self.cache.write().expect("rwlock should not be poisoned");
        cache_guard.clear();
        info!("Inference cache cleared");
        Ok(())
    }

    pub fn model_stats(&self) -> Result<ModelStats> {
        let model_guard = self.model.read().expect("rwlock should not be poisoned");
        Ok(model_guard.get_stats())
    }
}

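/// Aggregate request metrics. Nothing in this module records these
/// automatically; assembling them is left to the caller.
///
/// A sketch of the hit-rate computation (see `test_cache_hit_rate` below):
///
/// ```ignore
/// let metrics = InferenceMetrics {
///     total_requests: 100,
///     cache_hits: 80,
///     cache_misses: 20,
///     average_latency_ms: 1.5,
///     throughput_per_second: 650.0,
/// };
/// assert_eq!(metrics.cache_hit_rate(), 0.8);
/// ```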
#[derive(Debug, Clone)]
pub struct InferenceMetrics {
    pub total_requests: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub average_latency_ms: f64,
    pub throughput_per_second: f64,
}

impl InferenceMetrics {
    pub fn cache_hit_rate(&self) -> f64 {
        if self.total_requests > 0 {
            self.cache_hits as f64 / self.total_requests as f64
        } else {
            0.0
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::TransE;
    use crate::ModelConfig;

    #[tokio::test]
    async fn test_inference_cache() {
        let mut cache = InferenceCache::new(2, 3600);

        let vec1 = Vector::new(vec![1.0, 2.0, 3.0]);
        let vec2 = Vector::new(vec![4.0, 5.0, 6.0]);

        cache.cache_entity_embedding("entity1".to_string(), vec1.clone());
        cache.cache_entity_embedding("entity2".to_string(), vec2.clone());

        assert!(cache.get_entity_embedding("entity1").is_some());
        assert!(cache.get_entity_embedding("entity2").is_some());

        let vec3 = Vector::new(vec![7.0, 8.0, 9.0]);
        cache.cache_entity_embedding("entity3".to_string(), vec3);

        assert_eq!(cache.entity_cache.len(), 2);
    }

    #[tokio::test]
    async fn test_inference_engine() -> Result<()> {
        let config = ModelConfig::default().with_dimensions(10).with_seed(42);
        let model = TransE::new(config);

        let inference_config = InferenceConfig {
            cache_size: 100,
            enable_caching: true,
            ..Default::default()
        };

        let engine = InferenceEngine::new(Box::new(model), inference_config);

        let stats = engine.model_stats()?;
        assert_eq!(stats.dimensions, 10);

        Ok(())
    }
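
    // Checks the hit-rate arithmetic, including the zero-request guard; uses
    // only types defined in this module.
    #[test]
    fn test_cache_hit_rate() {
        let metrics = InferenceMetrics {
            total_requests: 100,
            cache_hits: 80,
            cache_misses: 20,
            average_latency_ms: 0.0,
            throughput_per_second: 0.0,
        };
        assert!((metrics.cache_hit_rate() - 0.8).abs() < f64::EPSILON);

        let empty = InferenceMetrics {
            total_requests: 0,
            cache_hits: 0,
            cache_misses: 0,
            average_latency_ms: 0.0,
            throughput_per_second: 0.0,
        };
        assert_eq!(empty.cache_hit_rate(), 0.0);
    }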
}