Skip to main content

heliosdb_proxy/cache/
l3_semantic.rs

1//! L3 Semantic Cache
2//!
3//! Vector similarity cache for AI/RAG workloads.
4//! Uses embeddings to find semantically similar queries.
5
6use std::sync::RwLock;
7use std::time::Duration;
8
9use dashmap::DashMap;
10use tokio::sync::Semaphore;
11
12use super::config::L3Config;
13use super::result::{CachedResult, L3Entry};
14use super::CacheContext;
15
/// L3 semantic cache (vector similarity).
///
/// Stores an embedding for every cached query and answers lookups by
/// comparing the embedding of the incoming query against the stored ones,
/// so a hit is possible even when the query text is not identical.
#[derive(Debug)]
pub struct L3SemanticCache {
    /// Configuration (enable flag, capacity, TTL, similarity threshold and
    /// embedding service settings are read from here).
    config: L3Config,

    /// Cache entries, scanned linearly on lookup. Guarded by a synchronous
    /// `std::sync::RwLock`, so a guard must not be held across an `.await`.
    entries: RwLock<Vec<L3Entry>>,

    /// Embedding service client used to turn query text into a vector.
    embedding_client: EmbeddingClient,

    /// Semaphore for limiting concurrent embedding requests.
    embedding_semaphore: Semaphore,

    /// Memo of computed embeddings (query hash -> embedding), consulted
    /// before the embedding service is called.
    embedding_cache: DashMap<u64, Vec<f32>>,
}
37
/// Embedding service client (Ollama).
///
/// Thin HTTP wrapper: it posts text to the configured endpoint and parses the
/// returned embedding vector.
#[derive(Debug)]
pub struct EmbeddingClient {
    /// Ollama endpoint; API paths such as `/api/embeddings` are appended to it.
    endpoint: String,

    /// Model name sent with embedding and pull requests.
    model: String,

    /// Expected embedding dimension; returned vectors are adjusted to this
    /// length.
    dimension: usize,

    /// HTTP client (built with a 30 second request timeout).
    client: reqwest::Client,
}
53
54impl L3SemanticCache {
55    /// Create a new L3 semantic cache
56    pub fn new(config: L3Config) -> Self {
57        let embedding_client = EmbeddingClient::new(
58            config.embedding_endpoint.clone(),
59            config.embedding_model.clone(),
60            config.embedding_dim,
61        );
62
63        Self {
64            config: config.clone(),
65            entries: RwLock::new(Vec::with_capacity(config.max_entries)),
66            embedding_client,
67            embedding_semaphore: Semaphore::new(10), // Max 10 concurrent embedding requests
68            embedding_cache: DashMap::new(),
69        }
70    }
71
72    /// Look up a query using semantic similarity
73    pub async fn get(&self, query: &str, context: &CacheContext) -> Option<CachedResult> {
74        if !self.config.enabled {
75            return None;
76        }
77
78        // Get embedding for the query
79        let embedding = self.get_embedding(query).await?;
80
81        // Find best match
82        let entries = self.entries.read().ok()?;
83
84        let mut best_match: Option<(f32, &L3Entry)> = None;
85
86        for entry in entries.iter() {
87            // Skip expired entries
88            if entry.is_expired() {
89                continue;
90            }
91
92            // Check context match (database, user for RLS)
93            if entry.context.database != context.database {
94                continue;
95            }
96
97            if entry.context.user != context.user {
98                continue;
99            }
100
101            // Calculate similarity
102            let similarity = entry.similarity(&embedding);
103
104            if similarity >= self.config.similarity_threshold {
105                if let Some((best_sim, _)) = best_match {
106                    if similarity > best_sim {
107                        best_match = Some((similarity, entry));
108                    }
109                } else {
110                    best_match = Some((similarity, entry));
111                }
112            }
113        }
114
115        best_match.map(|(_, entry)| entry.result.clone())
116    }
117
118    /// Store a query and result in the semantic cache
119    pub async fn put(&self, query: &str, context: &CacheContext, result: CachedResult) {
120        if !self.config.enabled {
121            return;
122        }
123
124        // Get embedding for the query
125        let embedding = match self.get_embedding(query).await {
126            Some(e) => e,
127            None => return,
128        };
129
130        // Create entry
131        let mut entry = L3Entry::new(query.to_string(), embedding, context.clone(), result);
132
133        // Enforce TTL from config
134        if entry.result.ttl > self.config.ttl {
135            entry.result.ttl = self.config.ttl;
136        }
137
138        let mut entries = match self.entries.write() {
139            Ok(e) => e,
140            Err(_) => return,
141        };
142
143        // Check capacity and evict if needed
144        if entries.len() >= self.config.max_entries {
145            self.evict(&mut entries);
146        }
147
148        entries.push(entry);
149    }
150
151    /// Clear all entries
152    pub async fn clear(&self) {
153        if let Ok(mut entries) = self.entries.write() {
154            entries.clear();
155        }
156        self.embedding_cache.clear();
157    }
158
159    /// Get entry count
160    pub fn len(&self) -> usize {
161        self.entries.read().map(|e| e.len()).unwrap_or(0)
162    }
163
164    /// Check if cache is empty
165    pub fn is_empty(&self) -> bool {
166        self.len() == 0
167    }
168
169    /// Get cache statistics
170    pub fn stats(&self) -> L3CacheStats {
171        let entries = self.entries.read().unwrap();
172
173        let total_access: u64 = entries.iter().map(|e| e.access_count).sum();
174        let avg_embedding_size = if entries.is_empty() {
175            0
176        } else {
177            entries.first().map(|e| e.embedding.len()).unwrap_or(0)
178        };
179
180        L3CacheStats {
181            entry_count: entries.len(),
182            max_entries: self.config.max_entries,
183            similarity_threshold: self.config.similarity_threshold,
184            embedding_dimension: avg_embedding_size,
185            total_accesses: total_access,
186            embedding_cache_size: self.embedding_cache.len(),
187        }
188    }
189
190    /// Get embedding for a query (cached)
191    async fn get_embedding(&self, query: &str) -> Option<Vec<f32>> {
192        // Check embedding cache first
193        let query_hash = quick_hash(query);
194
195        if let Some(cached) = self.embedding_cache.get(&query_hash) {
196            return Some(cached.clone());
197        }
198
199        // Acquire semaphore to limit concurrent requests
200        let _permit = self.embedding_semaphore.acquire().await.ok()?;
201
202        // Call embedding service
203        let embedding = self.embedding_client.embed(query).await?;
204
205        // Cache the embedding
206        self.embedding_cache.insert(query_hash, embedding.clone());
207
208        Some(embedding)
209    }
210
211    /// Evict entries to make room for new ones
212    fn evict(&self, entries: &mut Vec<L3Entry>) {
213        // First, remove expired entries
214        entries.retain(|e| !e.is_expired());
215
216        // If still full, remove LRU entries
217        while entries.len() >= self.config.max_entries {
218            if let Some(lru_idx) = entries
219                .iter()
220                .enumerate()
221                .min_by_key(|(_, e)| e.last_access)
222                .map(|(i, _)| i)
223            {
224                entries.remove(lru_idx);
225            } else {
226                break;
227            }
228        }
229    }
230
231    /// Check if the embedding service is available
232    pub async fn health_check(&self) -> bool {
233        self.embedding_client.health_check().await
234    }
235}
236
237impl EmbeddingClient {
238    /// Create a new embedding client
239    pub fn new(endpoint: String, model: String, dimension: usize) -> Self {
240        let client = reqwest::Client::builder()
241            .timeout(Duration::from_secs(30))
242            .build()
243            .unwrap_or_default();
244
245        Self {
246            endpoint,
247            model,
248            dimension,
249            client,
250        }
251    }
252
253    /// Generate embedding for text using Ollama
254    pub async fn embed(&self, text: &str) -> Option<Vec<f32>> {
255        let url = format!("{}/api/embeddings", self.endpoint);
256
257        let request = serde_json::json!({
258            "model": self.model,
259            "prompt": text
260        });
261
262        let response = self.client.post(&url).json(&request).send().await.ok()?;
263
264        if !response.status().is_success() {
265            return None;
266        }
267
268        let body: serde_json::Value = response.json().await.ok()?;
269
270        let embedding = body
271            .get("embedding")?
272            .as_array()?
273            .iter()
274            .filter_map(|v| v.as_f64().map(|f| f as f32))
275            .collect::<Vec<f32>>();
276
277        // Validate dimension
278        if embedding.len() != self.dimension {
279            // Try to handle dimension mismatch gracefully
280            if embedding.len() > self.dimension {
281                return Some(embedding[..self.dimension].to_vec());
282            } else {
283                // Pad with zeros (not ideal, but better than failing)
284                let mut padded = embedding;
285                padded.resize(self.dimension, 0.0);
286                return Some(padded);
287            }
288        }
289
290        Some(embedding)
291    }
292
293    /// Check if Ollama is available
294    pub async fn health_check(&self) -> bool {
295        let url = format!("{}/api/tags", self.endpoint);
296
297        match self.client.get(&url).send().await {
298            Ok(response) => response.status().is_success(),
299            Err(_) => false,
300        }
301    }
302
303    /// List available models
304    pub async fn list_models(&self) -> Option<Vec<String>> {
305        let url = format!("{}/api/tags", self.endpoint);
306
307        let response = self.client.get(&url).send().await.ok()?;
308        let body: serde_json::Value = response.json().await.ok()?;
309
310        let models = body
311            .get("models")?
312            .as_array()?
313            .iter()
314            .filter_map(|m| m.get("name")?.as_str().map(String::from))
315            .collect();
316
317        Some(models)
318    }
319
320    /// Pull a model if not available
321    pub async fn pull_model(&self) -> Result<(), String> {
322        let url = format!("{}/api/pull", self.endpoint);
323
324        let request = serde_json::json!({
325            "name": self.model
326        });
327
328        let response = self
329            .client
330            .post(&url)
331            .json(&request)
332            .send()
333            .await
334            .map_err(|e| e.to_string())?;
335
336        if response.status().is_success() {
337            Ok(())
338        } else {
339            Err(format!("Failed to pull model: {}", response.status()))
340        }
341    }
342}
343
/// L3 cache statistics: a point-in-time snapshot produced by
/// `L3SemanticCache::stats`.
#[derive(Debug, Clone)]
pub struct L3CacheStats {
    /// Number of entries currently stored
    pub entry_count: usize,

    /// Maximum entries (copied from the configuration)
    pub max_entries: usize,

    /// Similarity threshold (copied from the configuration)
    pub similarity_threshold: f32,

    /// Embedding dimension: length of the first stored entry's embedding,
    /// or 0 when the cache holds no entries
    pub embedding_dimension: usize,

    /// Total accesses: sum of the access counters of all stored entries
    pub total_accesses: u64,

    /// Embedding cache size: number of memoized query embeddings
    pub embedding_cache_size: usize,
}
365
/// Hash a query string into the 64-bit key used by the embedding cache.
///
/// Uses the standard library's `DefaultHasher` in its default state, so the
/// same text always maps to the same key within one build of the program.
fn quick_hash(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::default();
    Hash::hash(s, &mut state);
    Hasher::finish(&state)
}
375
/// Compute the cosine similarity between two vectors.
///
/// The result is `dot(a, b) / (|a| * |b|)`. Returns `0.0` when the slices
/// are empty, differ in length, or when either vector has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass accumulating the dot product and both squared magnitudes.
    let (dot, squared_a, squared_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, squared_a, squared_b), (x, y)| (dot + x * y, squared_a + x * x, squared_b + y * y),
    );

    let magnitude_a = squared_a.sqrt();
    let magnitude_b = squared_b.sqrt();

    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        0.0
    } else {
        dot / (magnitude_a * magnitude_b)
    }
}
401
/// Generate a pseudo-random embedding for testing.
///
/// Returns `dim` components in `[-1.0, 1.0)`. Every call produces a
/// different vector, but the sequence of vectors is the same on every test
/// run, so failures are reproducible.
///
/// The previous implementation derived each component from
/// `(seed + i) as f64` where `seed` was a 64-bit hash: at that magnitude
/// neighbouring integers convert to the same `f64`, so nearly all
/// components were equal and all "random" vectors pointed in (almost) the
/// same direction.
#[cfg(test)]
fn random_embedding(dim: usize) -> Vec<f32> {
    use std::sync::atomic::{AtomicU64, Ordering};

    /// SplitMix64 increment (odd constant from the reference generator).
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;

    /// SplitMix64 output function: scrambles the bits of `z`.
    fn mix(mut z: u64) -> u64 {
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    // A per-call counter gives each call its own stream.
    static CALLS: AtomicU64 = AtomicU64::new(0);
    let call = CALLS.fetch_add(1, Ordering::Relaxed);

    // Scramble the call number so that consecutive calls do not produce
    // shifted copies of the same sequence.
    let mut state = mix(call.wrapping_add(GAMMA));

    (0..dim)
        .map(|_| {
            state = state.wrapping_add(GAMMA);
            // The top 24 bits fit an f32 mantissa exactly: unit is in [0, 1).
            let unit = (mix(state) >> 40) as f32 / (1u32 << 24) as f32;
            unit * 2.0 - 1.0
        })
        .collect()
}
419
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Absolute tolerance for approximate float comparisons.
    const TOLERANCE: f32 = 0.001;

    /// True when `actual` lies within `TOLERANCE` of `expected`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Build a cached result carrying `data` as payload, with a 60 second
    /// TTL, one row, the table tag "test" and a 5 ms execution time.
    fn make_result(data: &str) -> CachedResult {
        let payload = Bytes::from(data.to_string());
        let ttl = Duration::from_secs(60);
        let tables = vec!["test".to_string()];
        let execution_time = Duration::from_millis(5);

        CachedResult::new(payload, 1, ttl, tables, execution_time)
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];
        let negative_x = vec![-1.0, 0.0, 0.0];
        let two_dimensional = vec![1.0, 0.0];

        // A vector compared with itself scores 1.0.
        assert!(approx(cosine_similarity(&x_axis, &x_axis), 1.0));

        // Orthogonal vectors score 0.0.
        assert!(approx(cosine_similarity(&x_axis, &y_axis), 0.0));

        // Opposite vectors score -1.0.
        assert!(approx(cosine_similarity(&x_axis, &negative_x), -1.0));

        // Empty input scores 0.0.
        assert!(approx(cosine_similarity(&[], &[]), 0.0));

        // Mismatched lengths score 0.0.
        assert!(approx(cosine_similarity(&x_axis, &two_dimensional), 0.0));
    }

    #[test]
    fn test_l3_entry_similarity() {
        let entry = L3Entry::new(
            "SELECT * FROM users".to_string(),
            vec![0.5, 0.5, 0.5, 0.5],
            CacheContext::default(),
            make_result("test"),
        );

        // Identical embedding: similarity of 1.0.
        let identical = vec![0.5, 0.5, 0.5, 0.5];
        assert!(approx(entry.similarity(&identical), 1.0));

        // Partially overlapping embedding: strictly between 0.5 and 1.0.
        let partial = vec![0.5, 0.5, 0.0, 0.0];
        assert!(entry.similarity(&partial) > 0.5);
        assert!(entry.similarity(&partial) < 1.0);
    }

    #[test]
    fn test_quick_hash() {
        let users = quick_hash("SELECT * FROM users");
        let users_again = quick_hash("SELECT * FROM users");
        let orders = quick_hash("SELECT * FROM orders");

        assert_eq!(users, users_again);
        assert_ne!(users, orders);
    }

    #[test]
    fn test_random_embedding() {
        assert_eq!(random_embedding(384).len(), 384);
    }

    #[tokio::test]
    async fn test_l3_cache_disabled() {
        let cache = L3SemanticCache::new(L3Config {
            enabled: false,
            ..Default::default()
        });

        // A disabled cache never reports a hit.
        let lookup = cache.get("test query", &CacheContext::default()).await;
        assert!(lookup.is_none());
    }

    #[test]
    fn test_embedding_client_creation() {
        let endpoint = "http://localhost:11434";
        let model = "all-minilm";

        let client = EmbeddingClient::new(endpoint.to_string(), model.to_string(), 384);

        assert_eq!(client.endpoint, "http://localhost:11434");
        assert_eq!(client.model, "all-minilm");
        assert_eq!(client.dimension, 384);
    }

    #[test]
    fn test_l3_stats() {
        let cache = L3SemanticCache::new(L3Config {
            enabled: true,
            max_entries: 1000,
            similarity_threshold: 0.9,
            ..Default::default()
        });

        let snapshot = cache.stats();

        assert_eq!(snapshot.entry_count, 0);
        assert_eq!(snapshot.max_entries, 1000);
        assert!(approx(snapshot.similarity_threshold, 0.9));
    }

    #[test]
    fn test_eviction() {
        // Exercise the eviction logic directly on the entry list.
        let cache = L3SemanticCache::new(L3Config {
            enabled: true,
            max_entries: 3,
            ..Default::default()
        });

        let mut stored = cache.entries.write().unwrap();

        for index in 0..5 {
            let entry = L3Entry::new(
                format!("query_{}", index),
                random_embedding(384),
                CacheContext::default(),
                make_result(&format!("result_{}", index)),
            );
            stored.push(entry);

            // Evict after every insertion, as the cache would when full.
            cache.evict(&mut stored);
        }

        // The list never ends up above the configured capacity.
        assert!(stored.len() <= 3);
    }
}
566}