use lru::LruCache;
use ruvector_gnn::layer::RuvectorLayer;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// A cached value bundled with usage bookkeeping.
#[derive(Debug, Clone)]
pub struct CacheEntry<T> {
    /// The cached payload.
    pub value: T,
    /// When the entry was constructed.
    pub created_at: Instant,
    /// When the entry was last used (construction counts as a use).
    pub last_accessed: Instant,
    /// How many times the entry has been used; starts at 1.
    pub access_count: u64,
}

impl<T: Clone> CacheEntry<T> {
    /// Wraps `value`, stamping both timestamps with the same instant and
    /// counting the construction itself as the first access.
    pub fn new(value: T) -> Self {
        let created_at = Instant::now();
        Self {
            value,
            created_at,
            last_accessed: created_at,
            access_count: 1,
        }
    }

    /// Records one more use of the entry and returns a borrow of the value.
    pub fn access(&mut self) -> &T {
        self.access_count += 1;
        self.last_accessed = Instant::now();
        &self.value
    }
}
/// Tunables for the GNN cache.
#[derive(Debug, Clone)]
pub struct GnnCacheConfig {
    /// Number of layers kept before the least recently accessed one is evicted.
    pub max_layers: usize,
    /// Capacity of the LRU holding query results.
    pub max_query_results: usize,
    /// Age, in seconds, after which a cached query result is treated as stale.
    pub query_result_ttl_secs: u64,
    /// Flag indicating that common layer configurations should be preloaded.
    pub preload_common: bool,
}

impl Default for GnnCacheConfig {
    /// 32 layers, 1000 query results, a five-minute TTL, preloading enabled.
    fn default() -> Self {
        let query_result_ttl_secs = 5 * 60;
        Self {
            preload_common: true,
            query_result_ttl_secs,
            max_query_results: 1_000,
            max_layers: 32,
        }
    }
}
/// Identifies one cached query result: which layer, which query vector, which `k`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct QueryCacheKey {
    /// Identifier of the layer the query ran against (stored verbatim).
    pub layer_hash: String,
    /// 64-bit fingerprint of the query vector.
    pub query_hash: u64,
    /// The `k` parameter of the query.
    pub k: usize,
}

impl QueryCacheKey {
    /// Builds a key from a layer identifier, the full query vector and `k`.
    ///
    /// The fingerprint is an FNV-1a hash over the IEEE-754 bit pattern of
    /// *every* element, in order. Two queries therefore share a key only if
    /// they are bit-for-bit identical (barring a 64-bit hash collision).
    /// A plain sum over a prefix of the vector would make permuted queries,
    /// and queries differing only in their tail, alias to the same entry.
    pub fn new(layer_id: &str, query: &[f32], k: usize) -> Self {
        const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

        let query_hash = query
            .iter()
            .flat_map(|value| value.to_bits().to_le_bytes())
            .fold(FNV_OFFSET_BASIS, |hash, byte| {
                (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
            });

        Self {
            layer_hash: layer_id.to_string(),
            query_hash,
            k,
        }
    }
}
/// A query result stored in the cache together with its insertion time.
#[derive(Debug, Clone)]
pub struct CachedQueryResult {
/// The cached output vector.
pub result: Vec<f32>,
/// When the result was stored; compared against the configured TTL on lookup.
pub cached_at: Instant,
}
/// Async cache of GNN layers and query results, with hit/miss statistics.
pub struct GnnCache {
/// Cached layers, keyed by a string built from their construction parameters.
layers: Arc<RwLock<HashMap<String, CacheEntry<RuvectorLayer>>>>,
/// LRU-bounded query results; entries older than the configured TTL are dropped on lookup.
query_results: Arc<RwLock<LruCache<QueryCacheKey, CachedQueryResult>>>,
/// Capacity and TTL settings supplied at construction.
config: GnnCacheConfig,
/// Running hit, miss and eviction counters.
stats: Arc<RwLock<CacheStats>>,
}
/// Counters describing how the cache has been used.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Layer lookups served from the cache.
    pub layer_hits: u64,
    /// Layer lookups that required building a new layer.
    pub layer_misses: u64,
    /// Query lookups served from the cache.
    pub query_hits: u64,
    /// Query lookups that found nothing usable.
    pub query_misses: u64,
    /// Number of entries evicted to make room.
    pub evictions: u64,
    /// Total query lookups (hits plus misses).
    pub total_queries: u64,
}

impl CacheStats {
    /// Fraction of `hits` among `hits + misses`; `0.0` when nothing was recorded.
    fn ratio(hits: u64, misses: u64) -> f64 {
        match hits + misses {
            0 => 0.0,
            total => hits as f64 / total as f64,
        }
    }

    /// Share of layer lookups that were hits, in `[0.0, 1.0]`.
    pub fn layer_hit_rate(&self) -> f64 {
        Self::ratio(self.layer_hits, self.layer_misses)
    }

    /// Share of query lookups that were hits, in `[0.0, 1.0]`.
    pub fn query_hit_rate(&self) -> f64 {
        Self::ratio(self.query_hits, self.query_misses)
    }
}
impl GnnCache {
    /// Creates an empty cache governed by `config`.
    ///
    /// A `max_query_results` of zero cannot size an LRU, so it falls back to a
    /// capacity of 1000.
    pub fn new(config: GnnCacheConfig) -> Self {
        let query_cache_size =
            NonZeroUsize::new(config.max_query_results).unwrap_or(NonZeroUsize::new(1000).unwrap());
        Self {
            layers: Arc::new(RwLock::new(HashMap::new())),
            query_results: Arc::new(RwLock::new(LruCache::new(query_cache_size))),
            config,
            stats: Arc::new(RwLock::new(CacheStats::default())),
        }
    }

    /// Returns a clone of the cached layer for this configuration, building
    /// and caching it first if absent.
    ///
    /// Layers are keyed by `input_dim`, `hidden_dim`, `heads` and the dropout
    /// truncated to thousandths. When the cache already holds `max_layers`
    /// entries, the least recently accessed one is evicted before inserting.
    pub async fn get_or_create_layer(
        &self,
        input_dim: usize,
        hidden_dim: usize,
        heads: usize,
        dropout: f32,
    ) -> RuvectorLayer {
        let key = format!(
            "{}_{}_{}_{}",
            input_dim,
            hidden_dim,
            heads,
            (dropout * 1000.0) as u32
        );

        // Fast path: the layer is already cached. A write lock is required
        // because a hit updates the entry's access metadata.
        {
            let mut layers = self.layers.write().await;
            if let Some(entry) = layers.get_mut(&key) {
                let mut stats = self.stats.write().await;
                stats.layer_hits += 1;
                return entry.access().clone();
            }
        }

        // Build the layer without holding any lock.
        let layer = RuvectorLayer::new(input_dim, hidden_dim, heads, dropout);

        // Lock order (layers, then stats) matches the fast path above.
        let mut layers = self.layers.write().await;
        let mut stats = self.stats.write().await;

        // Re-check: another task may have inserted this configuration while the
        // lock was released. Without this, the insert below would double-count
        // the miss and, at capacity, evict an unrelated layer for nothing.
        if let Some(entry) = layers.get_mut(&key) {
            stats.layer_hits += 1;
            return entry.access().clone();
        }

        stats.layer_misses += 1;
        if layers.len() >= self.config.max_layers {
            let stalest = layers
                .iter()
                .min_by_key(|(_, entry)| entry.last_accessed)
                .map(|(name, _)| name.clone());
            if let Some(stalest) = stalest {
                layers.remove(&stalest);
                stats.evictions += 1;
            }
        }
        layers.insert(key, CacheEntry::new(layer.clone()));
        layer
    }

    /// Looks up a cached query result.
    ///
    /// Returns a copy of the result if an entry exists and is younger than
    /// `query_result_ttl_secs`. An expired entry is removed and reported as a
    /// miss. Every call updates the query hit/miss counters.
    pub async fn get_query_result(&self, key: &QueryCacheKey) -> Option<Vec<f32>> {
        let mut results = self.query_results.write().await;
        if let Some(cached) = results.get(key) {
            let ttl = Duration::from_secs(self.config.query_result_ttl_secs);
            if cached.cached_at.elapsed() < ttl {
                let mut stats = self.stats.write().await;
                stats.query_hits += 1;
                stats.total_queries += 1;
                return Some(cached.result.clone());
            }
            // Present but stale: drop it so it stops occupying an LRU slot.
            results.pop(key);
        }
        let mut stats = self.stats.write().await;
        stats.query_misses += 1;
        stats.total_queries += 1;
        None
    }

    /// Stores `result` under `key`, timestamped now, replacing any previous entry.
    pub async fn cache_query_result(&self, key: QueryCacheKey, result: Vec<f32>) {
        let mut results = self.query_results.write().await;
        results.put(
            key,
            CachedQueryResult {
                result,
                cached_at: Instant::now(),
            },
        );
    }

    /// Returns a snapshot of the current statistics.
    pub async fn stats(&self) -> CacheStats {
        self.stats.read().await.clone()
    }

    /// Removes every cached layer and query result. Statistics are left untouched.
    pub async fn clear(&self) {
        self.layers.write().await.clear();
        self.query_results.write().await.clear();
    }

    /// Populates the cache with a fixed set of layer configurations.
    ///
    /// This method does not consult `config.preload_common`; the caller
    /// decides whether to invoke it.
    pub async fn preload_common_layers(&self) {
        // (input_dim, hidden_dim, heads, dropout)
        let common_configs = [
            (128, 256, 4, 0.1),
            (256, 512, 8, 0.1),
            (384, 768, 8, 0.1),
            (768, 1024, 16, 0.1),
        ];
        for (input, hidden, heads, dropout) in common_configs {
            let _ = self
                .get_or_create_layer(input, hidden, heads, dropout)
                .await;
        }
    }

    /// Number of layers currently cached.
    pub async fn layer_count(&self) -> usize {
        self.layers.read().await.len()
    }

    /// Number of query results currently held, including expired ones that
    /// have not yet been looked up.
    pub async fn query_result_count(&self) -> usize {
        self.query_results.read().await.len()
    }
}
/// A batch of GNN operations that all run through one layer configuration.
#[derive(Debug, Clone)]
pub struct BatchGnnRequest {
/// Configuration of the layer used for every operation in the batch.
pub layer_config: LayerConfig,
/// Operations to evaluate, processed in order.
pub operations: Vec<GnnOperation>,
}
/// Parameters used to look up or construct a `RuvectorLayer`.
#[derive(Debug, Clone)]
pub struct LayerConfig {
/// Input dimension passed to the layer constructor.
pub input_dim: usize,
/// Hidden dimension passed to the layer constructor.
pub hidden_dim: usize,
/// Head count passed to the layer constructor.
pub heads: usize,
/// Dropout value passed to the layer constructor.
pub dropout: f32,
}
/// Inputs for a single layer forward pass.
#[derive(Debug, Clone)]
pub struct GnnOperation {
/// Embedding of the node being processed (first argument to `forward`).
pub node_embedding: Vec<f32>,
/// Embeddings of the node's neighbors (second argument to `forward`).
pub neighbor_embeddings: Vec<Vec<f32>>,
/// Edge weights (third argument to `forward`); presumably one per neighbor — confirm against the layer API.
pub edge_weights: Vec<f32>,
}
/// Outcome of a batched forward pass.
#[derive(Debug, Clone)]
pub struct BatchGnnResult {
/// One output vector per requested operation, in request order.
pub results: Vec<Vec<f32>>,
/// How many outputs were served from the query cache.
pub cached_count: usize,
/// How many outputs were computed by running the layer.
pub computed_count: usize,
/// Wall-clock time spent on the whole batch, in milliseconds.
pub total_time_ms: f64,
}
impl GnnCache {
pub async fn batch_forward(&self, request: BatchGnnRequest) -> BatchGnnResult {
let start = Instant::now();
let layer = self
.get_or_create_layer(
request.layer_config.input_dim,
request.layer_config.hidden_dim,
request.layer_config.heads,
request.layer_config.dropout,
)
.await;
let layer_id = format!(
"{}_{}_{}",
request.layer_config.input_dim,
request.layer_config.hidden_dim,
request.layer_config.heads
);
let mut results = Vec::with_capacity(request.operations.len());
let mut cached_count = 0;
let mut computed_count = 0;
for op in &request.operations {
let cache_key = QueryCacheKey::new(&layer_id, &op.node_embedding, 1);
if let Some(cached) = self.get_query_result(&cache_key).await {
results.push(cached);
cached_count += 1;
} else {
let result = layer.forward(
&op.node_embedding,
&op.neighbor_embeddings,
&op.edge_weights,
);
self.cache_query_result(cache_key, result.clone()).await;
results.push(result);
computed_count += 1;
}
}
BatchGnnResult {
results,
cached_count,
computed_count,
total_time_ms: start.elapsed().as_secs_f64() * 1000.0,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The first request for a configuration is a miss; repeating it is a hit.
    #[tokio::test]
    async fn test_layer_caching() {
        let cache = GnnCache::new(GnnCacheConfig::default());

        let _first = cache.get_or_create_layer(128, 256, 4, 0.1).await;
        let after_first = cache.stats().await;
        assert_eq!(after_first.layer_misses, 1);
        assert_eq!(after_first.layer_hits, 0);

        let _second = cache.get_or_create_layer(128, 256, 4, 0.1).await;
        let after_second = cache.stats().await;
        assert_eq!(after_second.layer_misses, 1);
        assert_eq!(after_second.layer_hits, 1);
    }

    /// A stored query result can be read back unchanged.
    #[tokio::test]
    async fn test_query_result_caching() {
        let cache = GnnCache::new(GnnCacheConfig::default());
        let key = QueryCacheKey::new("test", &[1.0, 2.0, 3.0], 10);
        let expected = vec![0.1, 0.2, 0.3];

        // Nothing has been stored yet.
        assert!(cache.get_query_result(&key).await.is_none());

        cache.cache_query_result(key.clone(), expected.clone()).await;

        let fetched = cache.get_query_result(&key).await;
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap(), expected);
    }

    /// Two distinct operations on a fresh cache are both computed.
    #[tokio::test]
    async fn test_batch_forward() {
        let cache = GnnCache::new(GnnCacheConfig::default());

        let layer_config = LayerConfig {
            input_dim: 4,
            hidden_dim: 8,
            heads: 2,
            dropout: 0.1,
        };
        let first_op = GnnOperation {
            node_embedding: vec![1.0, 2.0, 3.0, 4.0],
            neighbor_embeddings: vec![vec![0.5, 1.0, 1.5, 2.0]],
            edge_weights: vec![1.0],
        };
        let second_op = GnnOperation {
            node_embedding: vec![2.0, 3.0, 4.0, 5.0],
            neighbor_embeddings: vec![vec![1.0, 1.5, 2.0, 2.5]],
            edge_weights: vec![1.0],
        };
        let request = BatchGnnRequest {
            layer_config,
            operations: vec![first_op, second_op],
        };

        let outcome = cache.batch_forward(request).await;
        assert_eq!(outcome.results.len(), 2);
        assert_eq!(outcome.computed_count, 2);
        assert_eq!(outcome.cached_count, 0);
    }

    /// Preloading inserts the four built-in configurations.
    #[tokio::test]
    async fn test_preload_common_layers() {
        let config = GnnCacheConfig {
            preload_common: true,
            ..Default::default()
        };
        let cache = GnnCache::new(config);

        cache.preload_common_layers().await;

        assert_eq!(cache.layer_count().await, 4);
    }
}