Skip to main content

heliosdb_proxy/cache/
l3_semantic.rs

1//! L3 Semantic Cache
2//!
3//! Vector similarity cache for AI/RAG workloads.
4//! Uses embeddings to find semantically similar queries.
5
6use std::sync::RwLock;
7use std::time::{Duration, Instant};
8
9use bytes::Bytes;
10use dashmap::DashMap;
11use tokio::sync::Semaphore;
12
13use super::config::L3Config;
14use super::result::{CachedResult, L3Entry};
15use super::CacheContext;
16
/// L3 semantic cache (vector similarity).
///
/// Stores an embedding for every cached query and answers lookups by cosine
/// similarity against those embeddings, so a query can hit even when its text
/// is not identical to a previously cached one.
#[derive(Debug)]
pub struct L3SemanticCache {
    /// Cache configuration: enablement, capacity (`max_entries`), TTL cap,
    /// similarity threshold and embedding-service settings.
    config: L3Config,

    /// Cached entries; lookups scan this list linearly.
    entries: RwLock<Vec<L3Entry>>,

    /// Client for the external embedding service.
    embedding_client: EmbeddingClient,

    /// Bounds how many embedding requests may be in flight at once.
    embedding_semaphore: Semaphore,

    /// Memo of computed embeddings, keyed by a 64-bit hash of the query text
    /// (not the text itself, so distinct queries could in principle collide).
    embedding_cache: DashMap<u64, Vec<f32>>,
}
38
/// HTTP client for an Ollama embedding service.
#[derive(Debug)]
pub struct EmbeddingClient {
    /// Base URL of the Ollama server, e.g. `http://localhost:11434`.
    endpoint: String,

    /// Name of the embedding model sent with each request.
    model: String,

    /// Embedding length expected by the cache; `embed` truncates or
    /// zero-pads service responses to this length.
    dimension: usize,

    /// Underlying HTTP client, reused across requests.
    client: reqwest::Client,
}
54
55impl L3SemanticCache {
56    /// Create a new L3 semantic cache
57    pub fn new(config: L3Config) -> Self {
58        let embedding_client = EmbeddingClient::new(
59            config.embedding_endpoint.clone(),
60            config.embedding_model.clone(),
61            config.embedding_dim,
62        );
63
64        Self {
65            config: config.clone(),
66            entries: RwLock::new(Vec::with_capacity(config.max_entries)),
67            embedding_client,
68            embedding_semaphore: Semaphore::new(10), // Max 10 concurrent embedding requests
69            embedding_cache: DashMap::new(),
70        }
71    }
72
73    /// Look up a query using semantic similarity
74    pub async fn get(&self, query: &str, context: &CacheContext) -> Option<CachedResult> {
75        if !self.config.enabled {
76            return None;
77        }
78
79        // Get embedding for the query
80        let embedding = self.get_embedding(query).await?;
81
82        // Find best match
83        let entries = self.entries.read().ok()?;
84
85        let mut best_match: Option<(f32, &L3Entry)> = None;
86
87        for entry in entries.iter() {
88            // Skip expired entries
89            if entry.is_expired() {
90                continue;
91            }
92
93            // Check context match (database, user for RLS)
94            if entry.context.database != context.database {
95                continue;
96            }
97
98            if entry.context.user != context.user {
99                continue;
100            }
101
102            // Calculate similarity
103            let similarity = entry.similarity(&embedding);
104
105            if similarity >= self.config.similarity_threshold {
106                if let Some((best_sim, _)) = best_match {
107                    if similarity > best_sim {
108                        best_match = Some((similarity, entry));
109                    }
110                } else {
111                    best_match = Some((similarity, entry));
112                }
113            }
114        }
115
116        best_match.map(|(_, entry)| entry.result.clone())
117    }
118
119    /// Store a query and result in the semantic cache
120    pub async fn put(&self, query: &str, context: &CacheContext, result: CachedResult) {
121        if !self.config.enabled {
122            return;
123        }
124
125        // Get embedding for the query
126        let embedding = match self.get_embedding(query).await {
127            Some(e) => e,
128            None => return,
129        };
130
131        // Create entry
132        let mut entry = L3Entry::new(
133            query.to_string(),
134            embedding,
135            context.clone(),
136            result,
137        );
138
139        // Enforce TTL from config
140        if entry.result.ttl > self.config.ttl {
141            entry.result.ttl = self.config.ttl;
142        }
143
144        let mut entries = match self.entries.write() {
145            Ok(e) => e,
146            Err(_) => return,
147        };
148
149        // Check capacity and evict if needed
150        if entries.len() >= self.config.max_entries {
151            self.evict(&mut entries);
152        }
153
154        entries.push(entry);
155    }
156
157    /// Clear all entries
158    pub async fn clear(&self) {
159        if let Ok(mut entries) = self.entries.write() {
160            entries.clear();
161        }
162        self.embedding_cache.clear();
163    }
164
165    /// Get entry count
166    pub fn len(&self) -> usize {
167        self.entries.read().map(|e| e.len()).unwrap_or(0)
168    }
169
170    /// Check if cache is empty
171    pub fn is_empty(&self) -> bool {
172        self.len() == 0
173    }
174
175    /// Get cache statistics
176    pub fn stats(&self) -> L3CacheStats {
177        let entries = self.entries.read().unwrap();
178
179        let total_access: u64 = entries.iter().map(|e| e.access_count).sum();
180        let avg_embedding_size = if entries.is_empty() {
181            0
182        } else {
183            entries.first().map(|e| e.embedding.len()).unwrap_or(0)
184        };
185
186        L3CacheStats {
187            entry_count: entries.len(),
188            max_entries: self.config.max_entries,
189            similarity_threshold: self.config.similarity_threshold,
190            embedding_dimension: avg_embedding_size,
191            total_accesses: total_access,
192            embedding_cache_size: self.embedding_cache.len(),
193        }
194    }
195
196    /// Get embedding for a query (cached)
197    async fn get_embedding(&self, query: &str) -> Option<Vec<f32>> {
198        // Check embedding cache first
199        let query_hash = quick_hash(query);
200
201        if let Some(cached) = self.embedding_cache.get(&query_hash) {
202            return Some(cached.clone());
203        }
204
205        // Acquire semaphore to limit concurrent requests
206        let _permit = self.embedding_semaphore.acquire().await.ok()?;
207
208        // Call embedding service
209        let embedding = self.embedding_client.embed(query).await?;
210
211        // Cache the embedding
212        self.embedding_cache.insert(query_hash, embedding.clone());
213
214        Some(embedding)
215    }
216
217    /// Evict entries to make room for new ones
218    fn evict(&self, entries: &mut Vec<L3Entry>) {
219        // First, remove expired entries
220        entries.retain(|e| !e.is_expired());
221
222        // If still full, remove LRU entries
223        while entries.len() >= self.config.max_entries {
224            if let Some(lru_idx) = entries
225                .iter()
226                .enumerate()
227                .min_by_key(|(_, e)| e.last_access)
228                .map(|(i, _)| i)
229            {
230                entries.remove(lru_idx);
231            } else {
232                break;
233            }
234        }
235    }
236
237    /// Check if the embedding service is available
238    pub async fn health_check(&self) -> bool {
239        self.embedding_client.health_check().await
240    }
241}
242
243impl EmbeddingClient {
244    /// Create a new embedding client
245    pub fn new(endpoint: String, model: String, dimension: usize) -> Self {
246        let client = reqwest::Client::builder()
247            .timeout(Duration::from_secs(30))
248            .build()
249            .unwrap_or_default();
250
251        Self {
252            endpoint,
253            model,
254            dimension,
255            client,
256        }
257    }
258
259    /// Generate embedding for text using Ollama
260    pub async fn embed(&self, text: &str) -> Option<Vec<f32>> {
261        let url = format!("{}/api/embeddings", self.endpoint);
262
263        let request = serde_json::json!({
264            "model": self.model,
265            "prompt": text
266        });
267
268        let response = self.client
269            .post(&url)
270            .json(&request)
271            .send()
272            .await
273            .ok()?;
274
275        if !response.status().is_success() {
276            return None;
277        }
278
279        let body: serde_json::Value = response.json().await.ok()?;
280
281        let embedding = body.get("embedding")?
282            .as_array()?
283            .iter()
284            .filter_map(|v| v.as_f64().map(|f| f as f32))
285            .collect::<Vec<f32>>();
286
287        // Validate dimension
288        if embedding.len() != self.dimension {
289            // Try to handle dimension mismatch gracefully
290            if embedding.len() > self.dimension {
291                return Some(embedding[..self.dimension].to_vec());
292            } else {
293                // Pad with zeros (not ideal, but better than failing)
294                let mut padded = embedding;
295                padded.resize(self.dimension, 0.0);
296                return Some(padded);
297            }
298        }
299
300        Some(embedding)
301    }
302
303    /// Check if Ollama is available
304    pub async fn health_check(&self) -> bool {
305        let url = format!("{}/api/tags", self.endpoint);
306
307        match self.client.get(&url).send().await {
308            Ok(response) => response.status().is_success(),
309            Err(_) => false,
310        }
311    }
312
313    /// List available models
314    pub async fn list_models(&self) -> Option<Vec<String>> {
315        let url = format!("{}/api/tags", self.endpoint);
316
317        let response = self.client.get(&url).send().await.ok()?;
318        let body: serde_json::Value = response.json().await.ok()?;
319
320        let models = body.get("models")?
321            .as_array()?
322            .iter()
323            .filter_map(|m| m.get("name")?.as_str().map(String::from))
324            .collect();
325
326        Some(models)
327    }
328
329    /// Pull a model if not available
330    pub async fn pull_model(&self) -> Result<(), String> {
331        let url = format!("{}/api/pull", self.endpoint);
332
333        let request = serde_json::json!({
334            "name": self.model
335        });
336
337        let response = self.client
338            .post(&url)
339            .json(&request)
340            .send()
341            .await
342            .map_err(|e| e.to_string())?;
343
344        if response.status().is_success() {
345            Ok(())
346        } else {
347            Err(format!("Failed to pull model: {}", response.status()))
348        }
349    }
350}
351
/// Point-in-time statistics for the L3 semantic cache.
#[derive(Debug, Clone)]
pub struct L3CacheStats {
    /// Number of stored entries (expired entries not yet evicted included).
    pub entry_count: usize,

    /// Configured capacity.
    pub max_entries: usize,

    /// Configured minimum cosine similarity for a hit.
    pub similarity_threshold: f32,

    /// Embedding length of the first stored entry; 0 when the cache is empty.
    pub embedding_dimension: usize,

    /// Sum of `access_count` over all stored entries.
    pub total_accesses: u64,

    /// Number of memoized query embeddings.
    pub embedding_cache_size: usize,
}
373
/// Hash `s` into the 64-bit key used by the embedding memo.
///
/// Every `DefaultHasher::new()` instance starts from the same keys, so equal
/// text maps to equal keys within a process.
fn quick_hash(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(s, &mut state);
    state.finish()
}
383
/// Cosine similarity of two vectors: `dot(a, b) / (|a| * |b|)`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has
/// zero magnitude. All accumulation is done in `f32`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = a.len() == b.len() && !a.is_empty();
    if !comparable {
        return 0.0;
    }

    // One pass computes the dot product and both squared magnitudes.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, sq_a, sq_b), (&x, &y)| {
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });

    let magnitude_a = sq_a.sqrt();
    let magnitude_b = sq_b.sqrt();

    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        0.0
    } else {
        dot / (magnitude_a * magnitude_b)
    }
}
409
/// Generate a pseudo-random embedding of length `dim` for tests.
///
/// Each component lies in `[-1.0, 1.0)`. The xorshift64 generator is seeded
/// from the current `Instant`, so successive calls produce different vectors.
/// The previous `sin(seed + i)` form collapsed to a (nearly) constant vector,
/// because a ~2^63 seed cannot distinguish `+ i` once converted to `f64`.
#[cfg(test)]
fn random_embedding(dim: usize) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    std::time::Instant::now().hash(&mut hasher);
    // xorshift requires a non-zero state; forcing the low bit guarantees it.
    let mut state = hasher.finish() | 1;

    (0..dim)
        .map(|_| {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            // Top 24 bits are exactly representable in f32: [0, 2^24) / 2^23 - 1 -> [-1, 1).
            ((state >> 40) as f32 / (1u64 << 23) as f32) - 1.0
        })
        .collect()
}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a small cached result carrying `payload`.
    fn make_result(payload: &str) -> CachedResult {
        let data = Bytes::from(payload.to_string());
        let tables = vec!["test".to_string()];
        CachedResult::new(
            data,
            1,
            Duration::from_secs(60),
            tables,
            Duration::from_millis(5),
        )
    }

    /// Assert `actual` is within 0.001 of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0f32, 0.0, 0.0];
        let y_axis = [0.0f32, 1.0, 0.0];
        let neg_x_axis = [-1.0f32, 0.0, 0.0];
        let short = [1.0f32, 0.0];

        assert_close(cosine_similarity(&x_axis, &x_axis), 1.0); // identical
        assert_close(cosine_similarity(&x_axis, &y_axis), 0.0); // orthogonal
        assert_close(cosine_similarity(&x_axis, &neg_x_axis), -1.0); // opposite
        assert_close(cosine_similarity(&[], &[]), 0.0); // empty
        assert_close(cosine_similarity(&x_axis, &short), 0.0); // length mismatch
    }

    #[test]
    fn test_l3_entry_similarity() {
        let result = make_result("test");
        let ctx = CacheContext::default();
        let entry = L3Entry::new(
            "SELECT * FROM users".to_string(),
            vec![0.5, 0.5, 0.5, 0.5],
            ctx,
            result,
        );

        // Identical direction.
        let same = vec![0.5, 0.5, 0.5, 0.5];
        assert_close(entry.similarity(&same), 1.0);

        // Partially overlapping direction.
        let partial = vec![0.5, 0.5, 0.0, 0.0];
        let score = entry.similarity(&partial);
        assert!(score > 0.5 && score < 1.0);
    }

    #[test]
    fn test_quick_hash() {
        let users_a = quick_hash("SELECT * FROM users");
        let users_b = quick_hash("SELECT * FROM users");
        let orders = quick_hash("SELECT * FROM orders");

        assert_eq!(users_a, users_b);
        assert_ne!(users_a, orders);
    }

    #[test]
    fn test_random_embedding() {
        assert_eq!(random_embedding(384).len(), 384);
    }

    #[tokio::test]
    async fn test_l3_cache_disabled() {
        let cache = L3SemanticCache::new(L3Config {
            enabled: false,
            ..Default::default()
        });

        let lookup = cache.get("test query", &CacheContext::default()).await;
        assert!(lookup.is_none());
    }

    #[test]
    fn test_embedding_client_creation() {
        let client = EmbeddingClient::new(
            "http://localhost:11434".to_string(),
            "all-minilm".to_string(),
            384,
        );

        assert_eq!(client.endpoint, "http://localhost:11434");
        assert_eq!(client.model, "all-minilm");
        assert_eq!(client.dimension, 384);
    }

    #[test]
    fn test_l3_stats() {
        let cache = L3SemanticCache::new(L3Config {
            enabled: true,
            max_entries: 1000,
            similarity_threshold: 0.9,
            ..Default::default()
        });

        let stats = cache.stats();
        assert_eq!(stats.entry_count, 0);
        assert_eq!(stats.max_entries, 1000);
        assert_close(stats.similarity_threshold, 0.9);
    }

    #[test]
    fn test_eviction() {
        let cache = L3SemanticCache::new(L3Config {
            enabled: true,
            max_entries: 3,
            ..Default::default()
        });

        // Insert directly, running eviction after every push.
        let mut entries = cache.entries.write().unwrap();
        for i in 0..5 {
            let ctx = CacheContext::default();
            let result = make_result(&format!("result_{}", i));
            let embedding = random_embedding(384);
            entries.push(L3Entry::new(format!("query_{}", i), embedding, ctx, result));
            cache.evict(&mut entries);
        }

        assert!(entries.len() <= 3);
    }
}