1use std::sync::RwLock;
7use std::time::Duration;
8
9use dashmap::DashMap;
10use tokio::sync::Semaphore;
11
12use super::config::L3Config;
13use super::result::{CachedResult, L3Entry};
14use super::CacheContext;
15
16#[derive(Debug)]
21pub struct L3SemanticCache {
22 config: L3Config,
24
25 entries: RwLock<Vec<L3Entry>>,
27
28 embedding_client: EmbeddingClient,
30
31 embedding_semaphore: Semaphore,
33
34 embedding_cache: DashMap<u64, Vec<f32>>,
36}
37
38#[derive(Debug)]
40pub struct EmbeddingClient {
41 endpoint: String,
43
44 model: String,
46
47 dimension: usize,
49
50 client: reqwest::Client,
52}
53
54impl L3SemanticCache {
55 pub fn new(config: L3Config) -> Self {
57 let embedding_client = EmbeddingClient::new(
58 config.embedding_endpoint.clone(),
59 config.embedding_model.clone(),
60 config.embedding_dim,
61 );
62
63 Self {
64 config: config.clone(),
65 entries: RwLock::new(Vec::with_capacity(config.max_entries)),
66 embedding_client,
67 embedding_semaphore: Semaphore::new(10), embedding_cache: DashMap::new(),
69 }
70 }
71
72 pub async fn get(&self, query: &str, context: &CacheContext) -> Option<CachedResult> {
74 if !self.config.enabled {
75 return None;
76 }
77
78 let embedding = self.get_embedding(query).await?;
80
81 let entries = self.entries.read().ok()?;
83
84 let mut best_match: Option<(f32, &L3Entry)> = None;
85
86 for entry in entries.iter() {
87 if entry.is_expired() {
89 continue;
90 }
91
92 if entry.context.database != context.database {
94 continue;
95 }
96
97 if entry.context.user != context.user {
98 continue;
99 }
100
101 let similarity = entry.similarity(&embedding);
103
104 if similarity >= self.config.similarity_threshold {
105 if let Some((best_sim, _)) = best_match {
106 if similarity > best_sim {
107 best_match = Some((similarity, entry));
108 }
109 } else {
110 best_match = Some((similarity, entry));
111 }
112 }
113 }
114
115 best_match.map(|(_, entry)| entry.result.clone())
116 }
117
118 pub async fn put(&self, query: &str, context: &CacheContext, result: CachedResult) {
120 if !self.config.enabled {
121 return;
122 }
123
124 let embedding = match self.get_embedding(query).await {
126 Some(e) => e,
127 None => return,
128 };
129
130 let mut entry = L3Entry::new(query.to_string(), embedding, context.clone(), result);
132
133 if entry.result.ttl > self.config.ttl {
135 entry.result.ttl = self.config.ttl;
136 }
137
138 let mut entries = match self.entries.write() {
139 Ok(e) => e,
140 Err(_) => return,
141 };
142
143 if entries.len() >= self.config.max_entries {
145 self.evict(&mut entries);
146 }
147
148 entries.push(entry);
149 }
150
151 pub async fn clear(&self) {
153 if let Ok(mut entries) = self.entries.write() {
154 entries.clear();
155 }
156 self.embedding_cache.clear();
157 }
158
159 pub fn len(&self) -> usize {
161 self.entries.read().map(|e| e.len()).unwrap_or(0)
162 }
163
164 pub fn is_empty(&self) -> bool {
166 self.len() == 0
167 }
168
169 pub fn stats(&self) -> L3CacheStats {
171 let entries = self.entries.read().unwrap();
172
173 let total_access: u64 = entries.iter().map(|e| e.access_count).sum();
174 let avg_embedding_size = if entries.is_empty() {
175 0
176 } else {
177 entries.first().map(|e| e.embedding.len()).unwrap_or(0)
178 };
179
180 L3CacheStats {
181 entry_count: entries.len(),
182 max_entries: self.config.max_entries,
183 similarity_threshold: self.config.similarity_threshold,
184 embedding_dimension: avg_embedding_size,
185 total_accesses: total_access,
186 embedding_cache_size: self.embedding_cache.len(),
187 }
188 }
189
190 async fn get_embedding(&self, query: &str) -> Option<Vec<f32>> {
192 let query_hash = quick_hash(query);
194
195 if let Some(cached) = self.embedding_cache.get(&query_hash) {
196 return Some(cached.clone());
197 }
198
199 let _permit = self.embedding_semaphore.acquire().await.ok()?;
201
202 let embedding = self.embedding_client.embed(query).await?;
204
205 self.embedding_cache.insert(query_hash, embedding.clone());
207
208 Some(embedding)
209 }
210
211 fn evict(&self, entries: &mut Vec<L3Entry>) {
213 entries.retain(|e| !e.is_expired());
215
216 while entries.len() >= self.config.max_entries {
218 if let Some(lru_idx) = entries
219 .iter()
220 .enumerate()
221 .min_by_key(|(_, e)| e.last_access)
222 .map(|(i, _)| i)
223 {
224 entries.remove(lru_idx);
225 } else {
226 break;
227 }
228 }
229 }
230
231 pub async fn health_check(&self) -> bool {
233 self.embedding_client.health_check().await
234 }
235}
236
237impl EmbeddingClient {
238 pub fn new(endpoint: String, model: String, dimension: usize) -> Self {
240 let client = reqwest::Client::builder()
241 .timeout(Duration::from_secs(30))
242 .build()
243 .unwrap_or_default();
244
245 Self {
246 endpoint,
247 model,
248 dimension,
249 client,
250 }
251 }
252
253 pub async fn embed(&self, text: &str) -> Option<Vec<f32>> {
255 let url = format!("{}/api/embeddings", self.endpoint);
256
257 let request = serde_json::json!({
258 "model": self.model,
259 "prompt": text
260 });
261
262 let response = self.client.post(&url).json(&request).send().await.ok()?;
263
264 if !response.status().is_success() {
265 return None;
266 }
267
268 let body: serde_json::Value = response.json().await.ok()?;
269
270 let embedding = body
271 .get("embedding")?
272 .as_array()?
273 .iter()
274 .filter_map(|v| v.as_f64().map(|f| f as f32))
275 .collect::<Vec<f32>>();
276
277 if embedding.len() != self.dimension {
279 if embedding.len() > self.dimension {
281 return Some(embedding[..self.dimension].to_vec());
282 } else {
283 let mut padded = embedding;
285 padded.resize(self.dimension, 0.0);
286 return Some(padded);
287 }
288 }
289
290 Some(embedding)
291 }
292
293 pub async fn health_check(&self) -> bool {
295 let url = format!("{}/api/tags", self.endpoint);
296
297 match self.client.get(&url).send().await {
298 Ok(response) => response.status().is_success(),
299 Err(_) => false,
300 }
301 }
302
303 pub async fn list_models(&self) -> Option<Vec<String>> {
305 let url = format!("{}/api/tags", self.endpoint);
306
307 let response = self.client.get(&url).send().await.ok()?;
308 let body: serde_json::Value = response.json().await.ok()?;
309
310 let models = body
311 .get("models")?
312 .as_array()?
313 .iter()
314 .filter_map(|m| m.get("name")?.as_str().map(String::from))
315 .collect();
316
317 Some(models)
318 }
319
320 pub async fn pull_model(&self) -> Result<(), String> {
322 let url = format!("{}/api/pull", self.endpoint);
323
324 let request = serde_json::json!({
325 "name": self.model
326 });
327
328 let response = self
329 .client
330 .post(&url)
331 .json(&request)
332 .send()
333 .await
334 .map_err(|e| e.to_string())?;
335
336 if response.status().is_success() {
337 Ok(())
338 } else {
339 Err(format!("Failed to pull model: {}", response.status()))
340 }
341 }
342}
343
/// Point-in-time counters describing an L3 semantic cache.
///
/// Implements `PartialEq` so two snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct L3CacheStats {
    /// Number of entries currently stored.
    pub entry_count: usize,

    /// Configured capacity of the entry store.
    pub max_entries: usize,

    /// Minimum similarity a stored entry needs to count as a hit.
    pub similarity_threshold: f32,

    /// Embedding length of the first stored entry, or `0` when nothing is stored.
    pub embedding_dimension: usize,

    /// Sum of the per-entry access counters.
    pub total_accesses: u64,

    /// Number of memoised query embeddings.
    pub embedding_cache_size: usize,
}
365
/// Hashes `s` to a 64-bit value with the standard library's `DefaultHasher`.
///
/// Equal strings always produce equal values within one build. The standard
/// library does not specify the algorithm across Rust releases, so the values
/// are only suitable as in-memory keys.
fn quick_hash(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(s, &mut state);
    state.finish()
}
375
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices are empty, differ in length, or either one
/// has zero magnitude; otherwise the dot product divided by the product of
/// the two magnitudes, accumulated in `f32`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass accumulates the dot product and both squared magnitudes.
    let (dot, squared_a, squared_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, squared_a, squared_b), (&x, &y)| {
            (dot + x * y, squared_a + x * x, squared_b + y * y)
        },
    );

    let (magnitude_a, magnitude_b) = (squared_a.sqrt(), squared_b.sqrt());
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        0.0
    } else {
        dot / (magnitude_a * magnitude_b)
    }
}
401
/// Test helper: produces `dim` pseudo-random values in `[-1.0, 1.0]`.
///
/// The seed is a hash of the current `Instant`, so separate calls normally
/// yield different vectors; each component is the sine of a scaled
/// `seed + index`.
#[cfg(test)]
fn random_embedding(dim: usize) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let seed = {
        let mut state = DefaultHasher::new();
        std::time::Instant::now().hash(&mut state);
        state.finish()
    };

    (0..dim)
        .map(|index| seed.wrapping_add(index as u64))
        .map(|shifted| ((shifted as f64) * 0.0001).sin() as f32)
        .collect()
}
419
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Tolerance used for floating-point comparisons in these tests.
    const EPSILON: f32 = 0.001;

    /// Builds a one-row cached result whose payload is `data`.
    fn create_result(data: &str) -> CachedResult {
        let payload = Bytes::from(data.to_string());
        let tables = vec!["test".to_string()];

        CachedResult::new(
            payload,
            1,
            Duration::from_secs(60),
            tables,
            Duration::from_millis(5),
        )
    }

    /// Covers identical, orthogonal, opposite, empty and mismatched inputs.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = [1.0_f32, 0.0, 0.0];
        let unit_y = [0.0_f32, 1.0, 0.0];
        let neg_x = [-1.0_f32, 0.0, 0.0];
        let shorter = [1.0_f32, 0.0];

        // Identical vectors score 1.
        assert!((cosine_similarity(&unit_x, &unit_x) - 1.0).abs() < EPSILON);
        // Orthogonal vectors score 0.
        assert!(cosine_similarity(&unit_x, &unit_y).abs() < EPSILON);
        // Opposite vectors score -1.
        assert!((cosine_similarity(&unit_x, &neg_x) + 1.0).abs() < EPSILON);
        // Empty input scores 0.
        assert!(cosine_similarity(&[], &[]).abs() < EPSILON);
        // Mismatched lengths score 0.
        assert!(cosine_similarity(&unit_x, &shorter).abs() < EPSILON);
    }

    /// An entry scores 1 against its own embedding and strictly between 0.5
    /// and 1 against a partially overlapping one.
    #[test]
    fn test_l3_entry_similarity() {
        let entry = L3Entry::new(
            "SELECT * FROM users".to_string(),
            vec![0.5, 0.5, 0.5, 0.5],
            CacheContext::default(),
            create_result("test"),
        );

        let same = vec![0.5, 0.5, 0.5, 0.5];
        assert!((entry.similarity(&same) - 1.0).abs() < EPSILON);

        let partial = vec![0.5, 0.5, 0.0, 0.0];
        let score = entry.similarity(&partial);
        assert!(score > 0.5);
        assert!(score < 1.0);
    }

    /// Equal strings hash equally; different strings hash differently.
    #[test]
    fn test_quick_hash() {
        let users = quick_hash("SELECT * FROM users");
        let users_again = quick_hash("SELECT * FROM users");
        let orders = quick_hash("SELECT * FROM orders");

        assert_eq!(users, users_again);
        assert_ne!(users, orders);
    }

    /// The helper returns exactly the requested number of components.
    #[test]
    fn test_random_embedding() {
        assert_eq!(random_embedding(384).len(), 384);
    }

    /// A disabled cache reports a miss for every lookup.
    #[tokio::test]
    async fn test_l3_cache_disabled() {
        let cache = L3SemanticCache::new(L3Config {
            enabled: false,
            ..Default::default()
        });

        let lookup = cache.get("test query", &CacheContext::default()).await;
        assert!(lookup.is_none());
    }

    /// The constructor stores its arguments unchanged.
    #[test]
    fn test_embedding_client_creation() {
        let endpoint = "http://localhost:11434".to_string();
        let model = "all-minilm".to_string();

        let client = EmbeddingClient::new(endpoint, model, 384);

        assert_eq!(client.endpoint, "http://localhost:11434");
        assert_eq!(client.model, "all-minilm");
        assert_eq!(client.dimension, 384);
    }

    /// Stats of a fresh cache echo the configuration and report no entries.
    #[test]
    fn test_l3_stats() {
        let cache = L3SemanticCache::new(L3Config {
            enabled: true,
            max_entries: 1000,
            similarity_threshold: 0.9,
            ..Default::default()
        });

        let snapshot = cache.stats();
        assert_eq!(snapshot.entry_count, 0);
        assert_eq!(snapshot.max_entries, 1000);
        assert!((snapshot.similarity_threshold - 0.9).abs() < EPSILON);
    }

    /// Pushing more entries than the capacity, evicting after each push,
    /// never leaves more than `max_entries` behind.
    #[test]
    fn test_eviction() {
        let cache = L3SemanticCache::new(L3Config {
            enabled: true,
            max_entries: 3,
            ..Default::default()
        });

        // Work on the entry vector directly; `put` would need the embedding service.
        let mut stored = cache.entries.write().unwrap();

        for n in 0..5 {
            let entry = L3Entry::new(
                format!("query_{}", n),
                random_embedding(384),
                CacheContext::default(),
                create_result(&format!("result_{}", n)),
            );
            stored.push(entry);

            cache.evict(&mut stored);
        }

        assert!(stored.len() <= 3);
    }
}