use dashmap::DashMap;
use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
/// Identifier of a cached chunk; the key type of the cache's chunk store.
pub type ChunkId = u64;
/// 64-bit hash of a quantized embedding vector, as returned by `hash_embedding`.
pub type EmbeddingHash = u64;
/// A single cached piece of a document, optionally carrying its embedding.
#[derive(Debug, Clone)]
pub struct Chunk {
/// Identifier under which the chunk is stored in the cache.
pub id: ChunkId,
/// Identifier of the document the chunk is indexed under.
pub document_id: String,
/// Chunk content.
pub content: String,
/// Embedding vector, if one has been attached.
pub embedding: Option<Vec<f32>>,
/// Position value supplied by the caller; `Chunk::new` initialises it to 0.
/// Presumably the chunk's index within its document; confirm with callers.
pub position: usize,
/// Optional free-form JSON metadata.
pub metadata: Option<serde_json::Value>,
/// Creation timestamp; `Chunk::new` sets it to `Instant::now()` and the
/// cache's eviction compares it to find the oldest chunk.
pub created_at: Instant,
}
impl Chunk {
    /// Builds a chunk from its id, owning document id and content.
    ///
    /// The embedding and metadata start as `None`, the position as `0`, and
    /// `created_at` is stamped with the current [`Instant`].
    pub fn new(id: ChunkId, document_id: impl Into<String>, content: impl Into<String>) -> Self {
        let document_id = document_id.into();
        let content = content.into();
        Self {
            id,
            document_id,
            content,
            embedding: None,
            position: 0,
            metadata: None,
            created_at: Instant::now(),
        }
    }

    /// Returns the chunk with `embedding` attached, replacing any previous one.
    pub fn with_embedding(self, embedding: Vec<f32>) -> Self {
        Self {
            embedding: Some(embedding),
            ..self
        }
    }

    /// Returns the chunk with its position set to `position`.
    pub fn with_position(self, position: usize) -> Self {
        Self { position, ..self }
    }

    /// Returns the chunk with `metadata` attached, replacing any previous value.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Approximate memory footprint of the chunk in bytes.
    ///
    /// Counts the content and document-id bytes, four bytes per embedding
    /// component, and a fixed 64-byte allowance. Metadata is not counted.
    pub fn size(&self) -> usize {
        const FIXED_OVERHEAD: usize = 64;
        let embedding_bytes = match &self.embedding {
            Some(values) => values.len() * std::mem::size_of::<f32>(),
            None => 0,
        };
        self.content.len() + self.document_id.len() + embedding_bytes + FIXED_OVERHEAD
    }
}
/// Hashes an embedding after quantizing each component.
///
/// Every component is multiplied by 1000 and cast to `i32` (the cast truncates
/// toward zero, saturates at the `i32` bounds and maps NaN to 0). The resulting
/// integers are fed, in order, into a [`DefaultHasher`], so embeddings whose
/// components all quantize to the same integers produce the same hash.
///
/// The hashing algorithm behind `DefaultHasher` is not guaranteed to be stable
/// across Rust releases, so the value should not be persisted.
///
/// [`DefaultHasher`]: std::collections::hash_map::DefaultHasher
pub fn hash_embedding(embedding: &[f32]) -> EmbeddingHash {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    embedding
        .iter()
        .map(|&component| (component * 1000.0) as i32)
        .for_each(|bucket| bucket.hash(&mut state));
    state.finish()
}
/// Size-bounded cache of `Chunk`s with secondary indexes by embedding hash
/// and by document id.
pub struct RagChunkCache {
/// Primary store: chunk id -> chunk.
chunks: DashMap<ChunkId, Chunk>,
/// Embedding hash -> chunk ids. Appended to by `insert_chunk` for chunks that
/// carry an embedding, and overwritten wholesale by `cache_embedding_result`.
embedding_to_chunks: DashMap<EmbeddingHash, Vec<ChunkId>>,
/// Document id -> ids of the chunks inserted for that document.
document_to_chunks: DashMap<String, HashSet<ChunkId>>,
/// Size budget in mebibytes (multiplied by 1024 * 1024 wherever it is used).
max_size_mb: usize,
/// Running total, in bytes, of `Chunk::size()` over the chunks inserted and
/// not yet removed.
current_size: AtomicU64,
/// Hit/miss and embedding-lookup counters.
stats: RagCacheStats,
}
/// Internal atomic counters backing `RagChunkCache::stats`.
#[derive(Debug, Default)]
struct RagCacheStats {
/// Number of `get_chunk` calls that found the requested id.
hits: AtomicU64,
/// Number of `get_chunk` calls that did not find the requested id.
misses: AtomicU64,
/// Number of `get_chunks_by_embedding` calls.
embedding_lookups: AtomicU64,
/// Number of `get_chunks_by_embedding` calls whose embedding hash had an entry.
embedding_cache_hits: AtomicU64,
}
impl RagChunkCache {
pub fn new(max_size_mb: usize) -> Self {
Self {
chunks: DashMap::new(),
embedding_to_chunks: DashMap::new(),
document_to_chunks: DashMap::new(),
max_size_mb,
current_size: AtomicU64::new(0),
stats: RagCacheStats::default(),
}
}
pub fn get_chunk(&self, id: ChunkId) -> Option<Chunk> {
if let Some(chunk) = self.chunks.get(&id) {
self.stats.hits.fetch_add(1, Ordering::Relaxed);
Some(chunk.clone())
} else {
self.stats.misses.fetch_add(1, Ordering::Relaxed);
None
}
}
pub fn get_chunks_by_embedding(&self, embedding: &[f32], k: usize) -> Vec<Chunk> {
self.stats.embedding_lookups.fetch_add(1, Ordering::Relaxed);
let hash = hash_embedding(embedding);
if let Some(chunk_ids) = self.embedding_to_chunks.get(&hash) {
self.stats.embedding_cache_hits.fetch_add(1, Ordering::Relaxed);
let chunks: Vec<_> = chunk_ids.iter()
.filter_map(|id| self.chunks.get(id).map(|c| c.clone()))
.take(k)
.collect();
return chunks;
}
Vec::new()
}
pub fn get_document_chunks(&self, document_id: &str) -> Vec<Chunk> {
if let Some(ids) = self.document_to_chunks.get(document_id) {
ids.iter()
.filter_map(|id| self.chunks.get(id).map(|c| c.clone()))
.collect()
} else {
Vec::new()
}
}
pub fn insert_chunk(&self, chunk: Chunk) {
let size = chunk.size() as u64;
let max_bytes = (self.max_size_mb * 1024 * 1024) as u64;
while self.current_size.load(Ordering::Relaxed) + size > max_bytes {
if !self.evict_one() {
break;
}
}
self.document_to_chunks
.entry(chunk.document_id.clone())
.or_default()
.insert(chunk.id);
if let Some(ref embedding) = chunk.embedding {
let hash = hash_embedding(embedding);
self.embedding_to_chunks
.entry(hash)
.or_default()
.push(chunk.id);
}
self.chunks.insert(chunk.id, chunk);
self.current_size.fetch_add(size, Ordering::Relaxed);
}
pub fn insert_chunks(&self, chunks: Vec<Chunk>) {
for chunk in chunks {
self.insert_chunk(chunk);
}
}
pub fn cache_embedding_result(&self, embedding: &[f32], chunk_ids: Vec<ChunkId>) {
let hash = hash_embedding(embedding);
self.embedding_to_chunks.insert(hash, chunk_ids);
}
pub fn remove_chunk(&self, id: ChunkId) {
if let Some((_, chunk)) = self.chunks.remove(&id) {
self.current_size.fetch_sub(chunk.size() as u64, Ordering::Relaxed);
if let Some(mut ids) = self.document_to_chunks.get_mut(&chunk.document_id) {
ids.remove(&id);
}
}
}
pub fn remove_document(&self, document_id: &str) {
if let Some((_, ids)) = self.document_to_chunks.remove(document_id) {
for id in ids {
self.remove_chunk(id);
}
}
}
fn evict_one(&self) -> bool {
let mut oldest_id = None;
let mut oldest_time = Instant::now();
for entry in self.chunks.iter() {
if entry.created_at < oldest_time {
oldest_time = entry.created_at;
oldest_id = Some(*entry.key());
}
}
if let Some(id) = oldest_id {
self.remove_chunk(id);
return true;
}
false
}
pub fn stats(&self) -> RagCacheStatsSnapshot {
RagCacheStatsSnapshot {
chunk_count: self.chunks.len(),
document_count: self.document_to_chunks.len(),
size_bytes: self.current_size.load(Ordering::Relaxed),
max_size_bytes: (self.max_size_mb * 1024 * 1024) as u64,
hits: self.stats.hits.load(Ordering::Relaxed),
misses: self.stats.misses.load(Ordering::Relaxed),
embedding_lookups: self.stats.embedding_lookups.load(Ordering::Relaxed),
embedding_cache_hit_rate: {
let lookups = self.stats.embedding_lookups.load(Ordering::Relaxed);
let hits = self.stats.embedding_cache_hits.load(Ordering::Relaxed);
if lookups > 0 { hits as f64 / lookups as f64 } else { 0.0 }
},
}
}
pub fn clear(&self) {
self.chunks.clear();
self.embedding_to_chunks.clear();
self.document_to_chunks.clear();
self.current_size.store(0, Ordering::Relaxed);
}
}
/// Plain-value snapshot of the cache state returned by `RagChunkCache::stats`.
#[derive(Debug, Clone)]
pub struct RagCacheStatsSnapshot {
/// Number of chunks in the chunk store.
pub chunk_count: usize,
/// Number of entries in the document index.
pub document_count: usize,
/// Tracked total size of the cached chunks, in bytes.
pub size_bytes: u64,
/// Configured size budget, in bytes.
pub max_size_bytes: u64,
/// Number of `get_chunk` calls that found a chunk.
pub hits: u64,
/// Number of `get_chunk` calls that found nothing.
pub misses: u64,
/// Number of `get_chunks_by_embedding` calls.
pub embedding_lookups: u64,
/// Embedding cache hits divided by embedding lookups; 0.0 when there were no lookups.
pub embedding_cache_hit_rate: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps the id, document id and position it was given.
    #[test]
    fn test_chunk_creation() {
        let built = Chunk::new(1, "doc-1", "This is a test chunk").with_position(0);
        assert_eq!(built.id, 1);
        assert_eq!(built.document_id, "doc-1");
        assert_eq!(built.position, 0);
    }

    /// A chunk that was inserted can be fetched back by id.
    #[test]
    fn test_insert_and_get() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunk(Chunk::new(1, "doc-1", "Test content"));

        let content = cache.get_chunk(1).map(|hit| hit.content);
        assert!(content.is_some());
        assert_eq!(content.as_deref(), Some("Test content"));
    }

    /// Chunks are grouped under the document they were inserted for.
    #[test]
    fn test_document_chunks() {
        let cache = RagChunkCache::new(10);
        let fixtures = [
            (1, "doc-1", "Chunk 1", 0),
            (2, "doc-1", "Chunk 2", 1),
            (3, "doc-2", "Chunk 3", 0),
        ];
        for &(id, document, text, position) in fixtures.iter() {
            cache.insert_chunk(Chunk::new(id, document, text).with_position(position));
        }

        assert_eq!(cache.get_document_chunks("doc-1").len(), 2);
        assert_eq!(cache.get_document_chunks("doc-2").len(), 1);
    }

    /// A cached embedding result resolves back to the stored chunk.
    #[test]
    fn test_embedding_lookup() {
        let cache = RagChunkCache::new(10);
        let embedding = vec![0.1, 0.2, 0.3];
        cache.insert_chunk(
            Chunk::new(1, "doc-1", "Embedded content").with_embedding(embedding.clone()),
        );
        cache.cache_embedding_result(&embedding, vec![1]);

        let found = cache.get_chunks_by_embedding(&embedding, 10);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].id, 1);
    }

    /// Removing a document removes every chunk that belonged to it.
    #[test]
    fn test_remove_document() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunk(Chunk::new(1, "doc-1", "Chunk 1"));
        cache.insert_chunk(Chunk::new(2, "doc-1", "Chunk 2"));

        cache.remove_document("doc-1");

        assert!(cache.get_chunk(1).is_none());
        assert!(cache.get_chunk(2).is_none());
    }

    /// One successful and one failed lookup show up as one hit and one miss.
    #[test]
    fn test_stats() {
        let cache = RagChunkCache::new(10);
        cache.insert_chunk(Chunk::new(1, "doc-1", "Content"));
        let _ = cache.get_chunk(1);
        let _ = cache.get_chunk(2);

        let snapshot = cache.stats();
        assert_eq!(snapshot.chunk_count, 1);
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
    }
}