1use std::sync::RwLock;
7use std::time::{Duration, Instant};
8
9use bytes::Bytes;
10use dashmap::DashMap;
11use tokio::sync::Semaphore;
12
13use super::config::L3Config;
14use super::result::{CachedResult, L3Entry};
15use super::CacheContext;
16
/// Semantic (embedding-similarity) query result cache.
///
/// Lookups embed the query text and return the cached result of the most similar
/// stored entry for the same database and user, provided its similarity reaches
/// the configured threshold.
#[derive(Debug)]
pub struct L3SemanticCache {
    /// Cache settings: enable flag, capacity (`max_entries`), similarity
    /// threshold, maximum TTL and embedding service endpoint/model/dimension.
    config: L3Config,

    /// Stored entries. Lookups scan this vector linearly.
    entries: RwLock<Vec<L3Entry>>,

    /// HTTP client used to turn query text into embedding vectors.
    embedding_client: EmbeddingClient,

    /// Bounds how many embedding-service calls may be in flight at once.
    embedding_semaphore: Semaphore,

    /// Embeddings memoized by `quick_hash` of the query text, so repeated queries
    /// skip the embedding service.
    embedding_cache: DashMap<u64, Vec<f32>>,
}
38
/// Minimal HTTP client for a text-embedding service.
///
/// NOTE(review): the request paths used by this client (`/api/embeddings`,
/// `/api/tags`, `/api/pull`) match Ollama's HTTP API — presumably the backend is
/// an Ollama server; confirm.
#[derive(Debug)]
pub struct EmbeddingClient {
    /// Base URL of the service; request paths such as `/api/embeddings` are
    /// appended to it.
    endpoint: String,

    /// Model name sent with embedding and pull requests.
    model: String,

    /// Length every returned embedding is normalized to (truncated or
    /// zero-padded).
    dimension: usize,

    /// Underlying HTTP client (built with a 30-second timeout in `new`).
    client: reqwest::Client,
}
54
55impl L3SemanticCache {
56 pub fn new(config: L3Config) -> Self {
58 let embedding_client = EmbeddingClient::new(
59 config.embedding_endpoint.clone(),
60 config.embedding_model.clone(),
61 config.embedding_dim,
62 );
63
64 Self {
65 config: config.clone(),
66 entries: RwLock::new(Vec::with_capacity(config.max_entries)),
67 embedding_client,
68 embedding_semaphore: Semaphore::new(10), embedding_cache: DashMap::new(),
70 }
71 }
72
73 pub async fn get(&self, query: &str, context: &CacheContext) -> Option<CachedResult> {
75 if !self.config.enabled {
76 return None;
77 }
78
79 let embedding = self.get_embedding(query).await?;
81
82 let entries = self.entries.read().ok()?;
84
85 let mut best_match: Option<(f32, &L3Entry)> = None;
86
87 for entry in entries.iter() {
88 if entry.is_expired() {
90 continue;
91 }
92
93 if entry.context.database != context.database {
95 continue;
96 }
97
98 if entry.context.user != context.user {
99 continue;
100 }
101
102 let similarity = entry.similarity(&embedding);
104
105 if similarity >= self.config.similarity_threshold {
106 if let Some((best_sim, _)) = best_match {
107 if similarity > best_sim {
108 best_match = Some((similarity, entry));
109 }
110 } else {
111 best_match = Some((similarity, entry));
112 }
113 }
114 }
115
116 best_match.map(|(_, entry)| entry.result.clone())
117 }
118
119 pub async fn put(&self, query: &str, context: &CacheContext, result: CachedResult) {
121 if !self.config.enabled {
122 return;
123 }
124
125 let embedding = match self.get_embedding(query).await {
127 Some(e) => e,
128 None => return,
129 };
130
131 let mut entry = L3Entry::new(
133 query.to_string(),
134 embedding,
135 context.clone(),
136 result,
137 );
138
139 if entry.result.ttl > self.config.ttl {
141 entry.result.ttl = self.config.ttl;
142 }
143
144 let mut entries = match self.entries.write() {
145 Ok(e) => e,
146 Err(_) => return,
147 };
148
149 if entries.len() >= self.config.max_entries {
151 self.evict(&mut entries);
152 }
153
154 entries.push(entry);
155 }
156
157 pub async fn clear(&self) {
159 if let Ok(mut entries) = self.entries.write() {
160 entries.clear();
161 }
162 self.embedding_cache.clear();
163 }
164
165 pub fn len(&self) -> usize {
167 self.entries.read().map(|e| e.len()).unwrap_or(0)
168 }
169
170 pub fn is_empty(&self) -> bool {
172 self.len() == 0
173 }
174
175 pub fn stats(&self) -> L3CacheStats {
177 let entries = self.entries.read().unwrap();
178
179 let total_access: u64 = entries.iter().map(|e| e.access_count).sum();
180 let avg_embedding_size = if entries.is_empty() {
181 0
182 } else {
183 entries.first().map(|e| e.embedding.len()).unwrap_or(0)
184 };
185
186 L3CacheStats {
187 entry_count: entries.len(),
188 max_entries: self.config.max_entries,
189 similarity_threshold: self.config.similarity_threshold,
190 embedding_dimension: avg_embedding_size,
191 total_accesses: total_access,
192 embedding_cache_size: self.embedding_cache.len(),
193 }
194 }
195
196 async fn get_embedding(&self, query: &str) -> Option<Vec<f32>> {
198 let query_hash = quick_hash(query);
200
201 if let Some(cached) = self.embedding_cache.get(&query_hash) {
202 return Some(cached.clone());
203 }
204
205 let _permit = self.embedding_semaphore.acquire().await.ok()?;
207
208 let embedding = self.embedding_client.embed(query).await?;
210
211 self.embedding_cache.insert(query_hash, embedding.clone());
213
214 Some(embedding)
215 }
216
217 fn evict(&self, entries: &mut Vec<L3Entry>) {
219 entries.retain(|e| !e.is_expired());
221
222 while entries.len() >= self.config.max_entries {
224 if let Some(lru_idx) = entries
225 .iter()
226 .enumerate()
227 .min_by_key(|(_, e)| e.last_access)
228 .map(|(i, _)| i)
229 {
230 entries.remove(lru_idx);
231 } else {
232 break;
233 }
234 }
235 }
236
237 pub async fn health_check(&self) -> bool {
239 self.embedding_client.health_check().await
240 }
241}
242
243impl EmbeddingClient {
244 pub fn new(endpoint: String, model: String, dimension: usize) -> Self {
246 let client = reqwest::Client::builder()
247 .timeout(Duration::from_secs(30))
248 .build()
249 .unwrap_or_default();
250
251 Self {
252 endpoint,
253 model,
254 dimension,
255 client,
256 }
257 }
258
259 pub async fn embed(&self, text: &str) -> Option<Vec<f32>> {
261 let url = format!("{}/api/embeddings", self.endpoint);
262
263 let request = serde_json::json!({
264 "model": self.model,
265 "prompt": text
266 });
267
268 let response = self.client
269 .post(&url)
270 .json(&request)
271 .send()
272 .await
273 .ok()?;
274
275 if !response.status().is_success() {
276 return None;
277 }
278
279 let body: serde_json::Value = response.json().await.ok()?;
280
281 let embedding = body.get("embedding")?
282 .as_array()?
283 .iter()
284 .filter_map(|v| v.as_f64().map(|f| f as f32))
285 .collect::<Vec<f32>>();
286
287 if embedding.len() != self.dimension {
289 if embedding.len() > self.dimension {
291 return Some(embedding[..self.dimension].to_vec());
292 } else {
293 let mut padded = embedding;
295 padded.resize(self.dimension, 0.0);
296 return Some(padded);
297 }
298 }
299
300 Some(embedding)
301 }
302
303 pub async fn health_check(&self) -> bool {
305 let url = format!("{}/api/tags", self.endpoint);
306
307 match self.client.get(&url).send().await {
308 Ok(response) => response.status().is_success(),
309 Err(_) => false,
310 }
311 }
312
313 pub async fn list_models(&self) -> Option<Vec<String>> {
315 let url = format!("{}/api/tags", self.endpoint);
316
317 let response = self.client.get(&url).send().await.ok()?;
318 let body: serde_json::Value = response.json().await.ok()?;
319
320 let models = body.get("models")?
321 .as_array()?
322 .iter()
323 .filter_map(|m| m.get("name")?.as_str().map(String::from))
324 .collect();
325
326 Some(models)
327 }
328
329 pub async fn pull_model(&self) -> Result<(), String> {
331 let url = format!("{}/api/pull", self.endpoint);
332
333 let request = serde_json::json!({
334 "name": self.model
335 });
336
337 let response = self.client
338 .post(&url)
339 .json(&request)
340 .send()
341 .await
342 .map_err(|e| e.to_string())?;
343
344 if response.status().is_success() {
345 Ok(())
346 } else {
347 Err(format!("Failed to pull model: {}", response.status()))
348 }
349 }
350}
351
/// Point-in-time statistics for an `L3SemanticCache`.
#[derive(Debug, Clone)]
pub struct L3CacheStats {
    /// Number of stored entries, including expired ones not yet evicted.
    pub entry_count: usize,

    /// Configured maximum number of entries.
    pub max_entries: usize,

    /// Configured minimum similarity required for a cache hit.
    pub similarity_threshold: f32,

    /// Length of the first stored entry's embedding, or 0 when the cache is
    /// empty. This is neither an average nor the configured dimension.
    pub embedding_dimension: usize,

    /// Sum of `access_count` over all stored entries.
    pub total_accesses: u64,

    /// Number of memoized query embeddings.
    pub embedding_cache_size: usize,
}
373
/// Hashes query text into the 64-bit key used by the embedding cache.
///
/// Every `DefaultHasher` created via `new`/`default` starts from the same state,
/// so equal strings hash equally within one build. The algorithm is unspecified
/// and may change between Rust releases, so these keys must not be persisted.
fn quick_hash(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::default();
    Hash::hash(s, &mut state);
    state.finish()
}
383
/// Cosine similarity of two vectors, in the range [-1, 1] (up to rounding).
///
/// Returns 0.0 when the slices differ in length, are empty, or either one has
/// zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, sq_a, sq_b), (&x, &y)| {
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });

    let mag_a = sq_a.sqrt();
    let mag_b = sq_b.sqrt();

    if mag_a == 0.0 || mag_b == 0.0 {
        0.0
    } else {
        dot / (mag_a * mag_b)
    }
}
409
/// Test helper: builds a pseudo-random vector of `dim` values in [-1, 1].
///
/// The seed is a hash of the current `Instant`, so successive calls usually yield
/// different vectors; component `i` is `sin((seed + i) * 0.0001)`.
#[cfg(test)]
fn random_embedding(dim: usize) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let seed = {
        let mut hasher = DefaultHasher::new();
        std::time::Instant::now().hash(&mut hasher);
        hasher.finish()
    };

    let mut values = Vec::with_capacity(dim);
    for i in 0..dim {
        let phase = seed.wrapping_add(i as u64) as f64 * 0.0001;
        values.push(phase.sin() as f32);
    }
    values
}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small `CachedResult` fixture whose payload is `data`.
    fn create_result(data: &str) -> CachedResult {
        CachedResult::new(
            Bytes::from(data.to_string()),
            1,
            Duration::from_secs(60),
            vec!["test".to_string()],
            Duration::from_millis(5),
        )
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        // Identical direction.
        assert!((cosine_similarity(&a, &a) - 1.0).abs() < 0.001);

        // Orthogonal.
        let b = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 0.001);

        // Opposite direction.
        let c = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &c) + 1.0).abs() < 0.001);

        // Empty input.
        assert!(cosine_similarity(&[], &[]).abs() < 0.001);

        // Length mismatch.
        let d = vec![1.0, 0.0];
        assert!(cosine_similarity(&a, &d).abs() < 0.001);

        // Zero-magnitude vector.
        assert!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]).abs() < 0.001);
    }

    #[test]
    fn test_l3_entry_similarity() {
        let result = create_result("test");
        let ctx = CacheContext::default();

        let entry = L3Entry::new(
            "SELECT * FROM users".to_string(),
            vec![0.5, 0.5, 0.5, 0.5],
            ctx,
            result,
        );

        let similar = vec![0.5, 0.5, 0.5, 0.5];
        assert!((entry.similarity(&similar) - 1.0).abs() < 0.001);

        let moderate = vec![0.5, 0.5, 0.0, 0.0];
        assert!(entry.similarity(&moderate) > 0.5);
        assert!(entry.similarity(&moderate) < 1.0);
    }

    #[test]
    fn test_quick_hash() {
        let hash1 = quick_hash("SELECT * FROM users");
        let hash2 = quick_hash("SELECT * FROM users");
        let hash3 = quick_hash("SELECT * FROM orders");

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    #[test]
    fn test_random_embedding() {
        let emb = random_embedding(384);
        assert_eq!(emb.len(), 384);
        assert!(emb.iter().all(|x| (-1.0..=1.0).contains(x)));
    }

    #[tokio::test]
    async fn test_l3_cache_disabled() {
        let config = L3Config {
            enabled: false,
            ..Default::default()
        };
        let cache = L3SemanticCache::new(config);

        let ctx = CacheContext::default();
        let result = cache.get("test query", &ctx).await;
        assert!(result.is_none());

        // A disabled cache must not store anything (and returns before any
        // embedding request).
        cache.put("test query", &ctx, create_result("r")).await;
        assert!(cache.is_empty());
    }

    #[test]
    fn test_embedding_client_creation() {
        let client = EmbeddingClient::new(
            "http://localhost:11434".to_string(),
            "all-minilm".to_string(),
            384,
        );

        assert_eq!(client.endpoint, "http://localhost:11434");
        assert_eq!(client.model, "all-minilm");
        assert_eq!(client.dimension, 384);
    }

    #[test]
    fn test_l3_stats() {
        let config = L3Config {
            enabled: true,
            max_entries: 1000,
            similarity_threshold: 0.9,
            ..Default::default()
        };
        let cache = L3SemanticCache::new(config);

        let stats = cache.stats();
        assert_eq!(stats.entry_count, 0);
        assert_eq!(stats.max_entries, 1000);
        assert!((stats.similarity_threshold - 0.9).abs() < 0.001);
        // An empty cache reports zeroed derived statistics.
        assert_eq!(stats.total_accesses, 0);
        assert_eq!(stats.embedding_dimension, 0);
        assert_eq!(stats.embedding_cache_size, 0);
    }

    #[test]
    fn test_eviction() {
        let config = L3Config {
            enabled: true,
            max_entries: 3,
            ..Default::default()
        };
        let cache = L3SemanticCache::new(config);

        let mut entries = cache.entries.write().unwrap();
        for i in 0..5 {
            // Mirror `put`: make room only when full, then insert.
            if entries.len() >= 3 {
                cache.evict(&mut entries);
            }
            entries.push(L3Entry::new(
                format!("query_{}", i),
                random_embedding(384),
                CacheContext::default(),
                create_result(&format!("result_{}", i)),
            ));
            assert!(entries.len() <= 3);
        }

        // Once full, each insertion evicts exactly one (unexpired) entry, so the
        // cache stays at capacity.
        assert_eq!(entries.len(), 3);
    }
}