use std::collections::HashMap;
use std::hash::Hash;
use std::time::{Duration, Instant};
/// Tunable settings for a `DataLoader`.
///
/// Build one with [`DataLoaderConfig::new`] (or `Default::default()`) and
/// adjust individual settings through the chained builder methods.
#[derive(Debug, Clone)]
pub struct DataLoaderConfig {
    /// Time window associated with a batch. Defaults to 10 ms.
    ///
    /// NOTE(review): the loader methods in this module only store this value;
    /// presumably an external scheduler consumes it — confirm.
    pub batch_window: Duration,
    /// Upper bound on the number of keys handed to the batch function in a
    /// single call. Defaults to 100.
    pub max_batch_size: usize,
    /// Whether loaded and primed values are kept in the cache. Defaults to `true`.
    pub cache_enabled: bool,
    /// Lifetime of a cache entry, counted from its insertion. Defaults to 60 s.
    pub cache_ttl: Duration,
    /// Whether repeated keys are collapsed before a batch is executed.
    /// Defaults to `true`.
    pub dedupe: bool,
}

impl Default for DataLoaderConfig {
    fn default() -> Self {
        Self {
            batch_window: Duration::from_millis(10),
            max_batch_size: 100,
            cache_enabled: true,
            cache_ttl: Duration::from_secs(60),
            dedupe: true,
        }
    }
}

impl DataLoaderConfig {
    /// Returns the default configuration; identical to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the batch window and returns the updated configuration.
    pub fn batch_window(mut self, duration: Duration) -> Self {
        self.batch_window = duration;
        self
    }

    /// Sets the maximum number of keys per batch call and returns the updated
    /// configuration.
    pub fn max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size = size;
        self
    }

    /// Enables or disables caching and returns the updated configuration.
    pub fn cache(mut self, enabled: bool) -> Self {
        self.cache_enabled = enabled;
        self
    }

    /// Sets the cache entry lifetime and returns the updated configuration.
    pub fn cache_ttl(mut self, ttl: Duration) -> Self {
        self.cache_ttl = ttl;
        self
    }

    /// Enables or disables key de-duplication and returns the updated
    /// configuration.
    ///
    /// Completes the builder: every public field now has a chained setter.
    pub fn dedupe(mut self, enabled: bool) -> Self {
        self.dedupe = enabled;
        self
    }
}
/// Outcome of one batch execution: the values that were found plus the keys
/// recorded as having no value.
#[derive(Debug, Clone)]
pub struct BatchResult<K, V> {
    /// Values keyed by the key they were loaded for.
    pub results: HashMap<K, V>,
    /// Keys recorded as missing.
    pub missing: Vec<K>,
}

impl<K: Eq + Hash, V> BatchResult<K, V> {
    /// Wraps `results` with an empty `missing` list.
    pub fn new(results: HashMap<K, V>) -> Self {
        let missing = Vec::new();
        Self { results, missing }
    }

    /// A result holding neither values nor missing keys.
    pub fn empty() -> Self {
        Self::new(HashMap::new())
    }

    /// Replaces the `missing` list, keeping the results as they are.
    pub fn with_missing(self, missing: Vec<K>) -> Self {
        Self { missing, ..self }
    }

    /// Looks up the value stored for `key`, if any.
    pub fn get(&self, key: &K) -> Option<&V> {
        let Self { results, .. } = self;
        results.get(key)
    }

    /// Reports whether `key` appears in the `missing` list (linear scan).
    pub fn is_missing(&self, key: &K) -> bool
    where
        K: PartialEq,
    {
        self.missing.iter().any(|candidate| candidate == key)
    }
}
/// A cached value together with the moment it stops being valid.
#[derive(Debug, Clone)]
struct CacheEntry<V> {
    /// The cached value.
    value: V,
    /// Deadline after which the entry counts as expired; `None` means the
    /// deadline was not representable, so the entry never expires.
    expires_at: Option<Instant>,
}

impl<V> CacheEntry<V> {
    /// Creates an entry that expires `ttl` after the current instant.
    ///
    /// `Instant + Duration` panics when the sum cannot be represented (for
    /// example with `Duration::MAX` used as "keep forever"), so the deadline
    /// is computed with `checked_add` and an overflow is treated as
    /// "never expires" instead of aborting the caller.
    fn new(value: V, ttl: Duration) -> Self {
        Self {
            value,
            expires_at: Instant::now().checked_add(ttl),
        }
    }

    /// Returns `true` once the current instant has reached the deadline.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => Instant::now() >= deadline,
            None => false,
        }
    }
}
/// Collects requested keys, serves repeat requests from a TTL cache and
/// resolves the queued keys in batches through a caller-supplied function.
///
/// Each piece of mutable state sits behind its own `std::sync::Mutex`, so the
/// methods take `&self`.
#[derive(Debug)]
pub struct DataLoader<K, V>
where
K: Eq + Hash + Clone,
V: Clone,
{
/// Settings supplied at construction time.
config: DataLoaderConfig,
/// Cached values by key, each carrying its own expiry deadline.
cache: std::sync::Mutex<HashMap<K, CacheEntry<V>>>,
/// Keys queued by `load` until `execute_batch` drains them.
pending: std::sync::Mutex<Vec<K>>,
/// Running counters describing load and batch activity.
stats: std::sync::Mutex<DataLoaderStats>,
}
/// Counters describing how a loader has been used so far.
#[derive(Debug, Clone, Default)]
pub struct DataLoaderStats {
    /// Number of single-key load requests.
    pub total_loads: u64,
    /// Requests answered from the cache.
    pub cache_hits: u64,
    /// Requests that consulted the cache and found nothing usable.
    pub cache_misses: u64,
    /// Number of calls made to the batch function.
    pub batch_loads: u64,
    /// Running mean of the number of keys per batch call.
    pub avg_batch_size: f64,
}

impl DataLoaderStats {
    /// Fraction of load requests served from the cache, in `0.0..=1.0`.
    ///
    /// Yields `0.0` while no load has been recorded, avoiding a division by
    /// zero.
    pub fn hit_rate(&self) -> f64 {
        match self.total_loads {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }
}
impl<K, V> DataLoader<K, V>
where
    K: Eq + Hash + Clone + Send + Sync,
    V: Clone + Send + Sync,
{
    /// Creates a loader with an empty cache, an empty pending queue and
    /// zeroed statistics.
    pub fn new(config: DataLoaderConfig) -> Self {
        Self {
            config,
            cache: std::sync::Mutex::new(HashMap::new()),
            pending: std::sync::Mutex::new(Vec::new()),
            stats: std::sync::Mutex::new(DataLoaderStats::default()),
        }
    }

    /// Requests the value for `key`.
    ///
    /// Returns `Some(value)` when caching is enabled and a non-expired entry
    /// exists. Otherwise the key is queued for the next
    /// [`execute_batch`](Self::execute_batch) call and `None` is returned.
    ///
    /// Every call counts towards `total_loads`; hits and misses are only
    /// counted while caching is enabled.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal mutexes is poisoned.
    pub fn load(&self, key: K) -> Option<V> {
        self.update_stats(|s| s.total_loads += 1);
        if self.config.cache_enabled {
            if let Some(value) = self.get_cached(&key) {
                self.update_stats(|s| s.cache_hits += 1);
                return Some(value);
            }
            self.update_stats(|s| s.cache_misses += 1);
        }
        self.pending.lock().unwrap().push(key);
        None
    }

    /// Calls [`load`](Self::load) for every key and returns the individual
    /// outcomes keyed by the requested key.
    ///
    /// When a key occurs more than once, the outcome of its last occurrence
    /// is the one kept in the returned map.
    pub fn load_many(&self, keys: Vec<K>) -> HashMap<K, Option<V>> {
        let mut results = HashMap::with_capacity(keys.len());
        for key in keys {
            let outcome = self.load(key.clone());
            results.insert(key, outcome);
        }
        results
    }

    /// Stores `value` for `key` in the cache, replacing any existing entry.
    ///
    /// Does nothing when caching is disabled.
    pub fn prime(&self, key: K, value: V) {
        if self.config.cache_enabled {
            let entry = CacheEntry::new(value, self.config.cache_ttl);
            self.cache.lock().unwrap().insert(key, entry);
        }
    }

    /// Removes every entry from the cache.
    pub fn clear(&self) {
        self.cache.lock().unwrap().clear();
    }

    /// Removes the cache entry for `key`, if there is one.
    pub fn clear_key(&self, key: &K) {
        self.cache.lock().unwrap().remove(key);
    }

    /// Resolves every queued key by calling `loader` with batches of at most
    /// `max_batch_size` keys.
    ///
    /// The pending queue is drained first. With `dedupe` enabled, repeated
    /// keys are collapsed and the first occurrence keeps its position. Each
    /// call to `loader` updates the batch statistics, and with caching
    /// enabled every returned pair is cached.
    ///
    /// The returned [`BatchResult`] holds all values returned by `loader`.
    /// Its `missing` list names the requested keys for which `loader`
    /// returned nothing, in request order.
    ///
    /// A `max_batch_size` of zero is treated as one, because neither the
    /// chunking nor a batch of zero keys is meaningful (`slice::chunks`
    /// panics on a chunk size of zero).
    ///
    /// # Panics
    ///
    /// Panics if one of the internal mutexes is poisoned.
    pub fn execute_batch<F>(&self, mut loader: F) -> BatchResult<K, V>
    where
        F: FnMut(Vec<K>) -> HashMap<K, V>,
    {
        // The guard is a temporary of this statement, so the queue is
        // unlocked again before `loader` gets to run.
        let keys: Vec<K> = std::mem::take(&mut *self.pending.lock().unwrap());
        if keys.is_empty() {
            return BatchResult::empty();
        }

        let unique_keys: Vec<K> = if self.config.dedupe {
            let mut seen = std::collections::HashSet::with_capacity(keys.len());
            keys.into_iter()
                .filter(|k| seen.insert(k.clone()))
                .collect()
        } else {
            keys
        };

        // Guard against a zero batch size, which would make `chunks` panic.
        let chunk_size = self.config.max_batch_size.max(1);

        let mut all_results = HashMap::with_capacity(unique_keys.len());
        for batch in unique_keys.chunks(chunk_size) {
            let batch_size = batch.len();
            let results = loader(batch.to_vec());

            self.update_stats(|s| {
                s.batch_loads += 1;
                let total_batches = s.batch_loads as f64;
                // Incremental mean: undo the previous division, add the new
                // sample, divide by the new count.
                s.avg_batch_size = ((s.avg_batch_size * (total_batches - 1.0))
                    + batch_size as f64)
                    / total_batches;
            });

            if self.config.cache_enabled {
                let mut cache = self.cache.lock().unwrap();
                for (k, v) in &results {
                    cache.insert(k.clone(), CacheEntry::new(v.clone(), self.config.cache_ttl));
                }
            }
            all_results.extend(results);
        }

        // Report the requested keys the loader produced no value for.
        let missing: Vec<K> = unique_keys
            .into_iter()
            .filter(|k| !all_results.contains_key(k))
            .collect();
        BatchResult::new(all_results).with_missing(missing)
    }

    /// Returns a clone of the cached value for `key` when it has not
    /// expired; an expired entry is evicted on the way.
    fn get_cached(&self, key: &K) -> Option<V> {
        let mut cache = self.cache.lock().unwrap();
        if let Some(entry) = cache.get(key) {
            if !entry.is_expired() {
                return Some(entry.value.clone());
            }
            cache.remove(key);
        }
        None
    }

    /// Runs `f` on the statistics while holding the statistics lock.
    fn update_stats<F>(&self, f: F)
    where
        F: FnOnce(&mut DataLoaderStats),
    {
        let mut stats = self.stats.lock().unwrap();
        f(&mut stats);
    }

    /// Returns a snapshot of the current statistics.
    pub fn stats(&self) -> DataLoaderStats {
        self.stats.lock().unwrap().clone()
    }

    /// Returns the configuration this loader was created with.
    pub fn config(&self) -> &DataLoaderConfig {
        &self.config
    }

    /// Drops every cache entry whose deadline has passed.
    pub fn clean_expired(&self) {
        let mut cache = self.cache.lock().unwrap();
        cache.retain(|_, entry| !entry.is_expired());
    }
}
impl<K, V> Clone for DataLoader<K, V>
where
K: Eq + Hash + Clone,
V: Clone,
{
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
cache: std::sync::Mutex::new(self.cache.lock().unwrap().clone()),
pending: std::sync::Mutex::new(self.pending.lock().unwrap().clone()),
stats: std::sync::Mutex::new(self.stats.lock().unwrap().clone()),
}
}
}
/// Creates [`DataLoader`] instances that share one default configuration.
#[derive(Debug)]
pub struct DataLoaderFactory {
    /// Configuration cloned into every loader made by `create`.
    default_config: DataLoaderConfig,
}

impl DataLoaderFactory {
    /// Builds a factory whose loaders default to `config`.
    pub fn new(config: DataLoaderConfig) -> Self {
        let default_config = config;
        Self { default_config }
    }

    /// Creates a loader using a clone of the factory's default configuration.
    pub fn create<K, V>(&self) -> DataLoader<K, V>
    where
        K: Eq + Hash + Clone + Send + Sync,
        V: Clone + Send + Sync,
    {
        self.create_with_config(self.default_config.clone())
    }

    /// Creates a loader using `config`; the factory default is not consulted.
    pub fn create_with_config<K, V>(&self, config: DataLoaderConfig) -> DataLoader<K, V>
    where
        K: Eq + Hash + Clone + Send + Sync,
        V: Clone + Send + Sync,
    {
        DataLoader::new(config)
    }
}

impl Default for DataLoaderFactory {
    /// A factory whose default configuration is `DataLoaderConfig::default()`.
    fn default() -> Self {
        Self {
            default_config: DataLoaderConfig::default(),
        }
    }
}
/// A [`DataLoader`] whose keys are `String`s.
pub type IdLoader<V> = DataLoader<String, V>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an owned `String` key or value.
    fn s(text: &str) -> String {
        text.to_owned()
    }

    /// A string-to-string loader built from the default configuration.
    fn default_string_loader() -> DataLoader<String, String> {
        DataLoader::new(DataLoaderConfig::default())
    }

    /// Every builder call must land in the matching config field.
    #[test]
    fn test_dataloader_config() {
        let window = Duration::from_millis(20);
        let ttl = Duration::from_secs(120);
        let config = DataLoaderConfig::new()
            .batch_window(window)
            .max_batch_size(50)
            .cache(true)
            .cache_ttl(ttl);

        assert_eq!(config.batch_window, window);
        assert_eq!(config.max_batch_size, 50);
        assert!(config.cache_enabled);
        assert_eq!(config.cache_ttl, ttl);
    }

    /// A primed key is answered from the cache and counted as a hit.
    #[test]
    fn test_dataloader_prime_and_load() {
        let loader = default_string_loader();
        loader.prime(s("key1"), s("value1"));

        assert_eq!(loader.load(s("key1")), Some(s("value1")));
        assert_eq!(loader.stats().cache_hits, 1);
    }

    /// Queued keys are resolved together in a single batch call.
    #[test]
    fn test_dataloader_batch_execution() {
        let loader = default_string_loader();
        for key in &["key1", "key2", "key3"] {
            loader.load(s(key));
        }

        let result = loader.execute_batch(|keys| {
            let mut loaded = HashMap::new();
            for k in keys {
                let value = format!("value_{}", k);
                loaded.insert(k, value);
            }
            loaded
        });

        assert_eq!(result.results.len(), 3);
        assert_eq!(result.get(&s("key1")), Some(&s("value_key1")));
        assert_eq!(loader.stats().batch_loads, 1);
    }

    /// Repeated keys reach the batch function only once.
    #[test]
    fn test_dataloader_deduplication() {
        let config = DataLoaderConfig::default().max_batch_size(100);
        let loader: DataLoader<String, i32> = DataLoader::new(config);
        for key in &["key1", "key1", "key2", "key1"] {
            loader.load(s(key));
        }

        let mut batch_keys_count = 0;
        let result = loader.execute_batch(|keys| {
            batch_keys_count = keys.len();
            keys.into_iter().map(|k| (k, 1)).collect()
        });

        assert_eq!(batch_keys_count, 2);
        assert_eq!(result.results.len(), 2);
    }

    /// Five keys with a batch size of two need three batch calls.
    #[test]
    fn test_dataloader_batch_splitting() {
        let config = DataLoaderConfig::default().max_batch_size(2);
        let loader: DataLoader<i32, i32> = DataLoader::new(config);
        (0..5).for_each(|i| {
            loader.load(i);
        });

        let result = loader.execute_batch(|keys| {
            keys.into_iter().map(|k| (k, k * 10)).collect()
        });

        assert_eq!(result.results.len(), 5);
        assert_eq!(loader.stats().batch_loads, 3);
    }

    /// Clearing the cache makes previously primed keys miss.
    #[test]
    fn test_dataloader_clear() {
        let loader = default_string_loader();
        loader.prime(s("key1"), s("value1"));
        loader.prime(s("key2"), s("value2"));
        assert!(loader.load(s("key1")).is_some());

        loader.clear();

        assert!(loader.load(s("key1")).is_none());
    }

    /// Clearing one key leaves the other entries untouched.
    #[test]
    fn test_dataloader_clear_key() {
        let loader = default_string_loader();
        loader.prime(s("key1"), s("value1"));
        loader.prime(s("key2"), s("value2"));

        loader.clear_key(&s("key1"));

        assert!(loader.load(s("key1")).is_none());
        assert!(loader.load(s("key2")).is_some());
    }

    /// One hit plus one miss gives a hit rate of one half.
    #[test]
    fn test_dataloader_stats() {
        let loader = default_string_loader();
        loader.prime(s("cached"), s("value"));
        loader.load(s("cached"));
        loader.load(s("not_cached"));

        let snapshot = loader.stats();
        assert_eq!(snapshot.total_loads, 2);
        assert_eq!(snapshot.cache_hits, 1);
        assert_eq!(snapshot.cache_misses, 1);
        assert_eq!(snapshot.hit_rate(), 0.5);
    }

    /// With caching disabled, priming has no effect on later loads.
    #[test]
    fn test_dataloader_cache_disabled() {
        let config = DataLoaderConfig::default().cache(false);
        let loader: DataLoader<String, String> = DataLoader::new(config);
        loader.prime(s("key1"), s("value1"));

        assert!(loader.load(s("key1")).is_none());
    }

    /// `BatchResult` exposes found values and reports missing keys.
    #[test]
    fn test_batch_result() {
        let results: HashMap<String, i32> = [("a", 1), ("b", 2)]
            .iter()
            .map(|(k, v)| (s(k), *v))
            .collect();
        let batch = BatchResult::new(results).with_missing(vec![s("c")]);

        assert_eq!(batch.get(&s("a")), Some(&1));
        assert_eq!(batch.get(&s("c")), None);
        assert!(batch.is_missing(&s("c")));
        assert!(!batch.is_missing(&s("a")));
    }

    /// The factory applies its default config unless one is passed in.
    #[test]
    fn test_dataloader_factory() {
        let factory = DataLoaderFactory::new(DataLoaderConfig::default().max_batch_size(50));

        let loader = factory.create::<String, i32>();
        assert_eq!(loader.config().max_batch_size, 50);

        let override_config = DataLoaderConfig::default().max_batch_size(100);
        let custom_loader = factory.create_with_config::<String, i32>(override_config);
        assert_eq!(custom_loader.config().max_batch_size, 100);
    }

    /// `load_many` reports cached values and `None` for queued keys.
    #[test]
    fn test_dataloader_load_many() {
        let loader = default_string_loader();
        loader.prime(s("key1"), s("value1"));

        let results = loader.load_many(vec![s("key1"), s("key2")]);

        assert_eq!(results.get(&s("key1")), Some(&Some(s("value1"))));
        assert_eq!(results.get(&s("key2")), Some(&None));
    }
}