use arrow_array::{Array, ArrayRef};
use arrow_schema::DataType;
use std::collections::HashMap;
/// Identifier of a single batch, used as one half of a [`CacheKey`].
///
/// The wrapped value is treated as opaque here: it is only compared and
/// hashed. NOTE(review): how callers derive `val` is not visible in this
/// file — confirm against the call sites.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct BatchID {
    /// Raw identifier value.
    pub val: usize,
}
/// Lookup key for a cached array: a column index paired with a batch id.
///
/// Two keys are equal (and hash identically) only when both fields match.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CacheKey {
    /// Index of the column the cached array belongs to.
    pub column_idx: usize,
    /// Batch the cached array belongs to.
    pub batch_id: BatchID,
}
/// Computes the number of bytes `array` is charged against the cache budget.
///
/// `Utf8View` arrays are sized as 16 bytes per element, plus
/// `total_buffer_bytes_used()`, plus the size of the array struct itself.
/// Every other data type defers to `get_array_memory_size()`.
fn get_array_memory_size_for_cache(array: &ArrayRef) -> usize {
    use arrow_array::cast::AsArray;

    // Everything except `Utf8View` uses the array's own size report.
    if !matches!(array.data_type(), DataType::Utf8View) {
        return array.get_array_memory_size();
    }

    let views = array.as_string_view();
    // NOTE(review): 16 is presumably the byte width of one view entry — confirm.
    let view_bytes = views.len() * 16;
    let data_bytes = views.total_buffer_bytes_used();
    let struct_bytes = std::mem::size_of_val(views);
    view_bytes + data_bytes + struct_bytes
}
/// Size-bounded cache of arrays keyed by (column index, batch id).
///
/// The cache tracks a running byte total for the arrays it holds and
/// rejects insertions that would push that total past `max_cache_bytes`.
#[derive(Debug)]
pub struct RowGroupCache {
// Cached arrays, keyed by column index and batch id.
cache: HashMap<CacheKey, ArrayRef>,
// Value supplied at construction and returned by `batch_size()`;
// it is not otherwise used by the cache logic in this file.
batch_size: usize,
// Upper bound on `current_cache_size`; checked on every insert.
max_cache_bytes: usize,
// Sum of the accounted sizes of all arrays currently in `cache`.
current_cache_size: usize,
}
impl RowGroupCache {
    /// Creates an empty cache.
    ///
    /// `batch_size` is stored as-is and reported by [`Self::batch_size`].
    /// `max_cache_bytes` caps the total accounted size of the cached arrays.
    pub fn new(batch_size: usize, max_cache_bytes: usize) -> Self {
        Self {
            cache: HashMap::new(),
            batch_size,
            max_cache_bytes,
            current_cache_size: 0,
        }
    }

    /// Stores `array` under the key (`column_idx`, `batch_id`).
    ///
    /// Returns `true` when the array was stored. Returns `false`, leaving the
    /// cache untouched, when adding the array would push the accounted size
    /// past `max_cache_bytes`.
    ///
    /// # Panics
    ///
    /// Panics if an entry already exists for the same key.
    pub fn insert(&mut self, column_idx: usize, batch_id: BatchID, array: ArrayRef) -> bool {
        let array_size = get_array_memory_size_for_cache(&array);

        // `checked_add` keeps the budget check sound even if the sum would
        // overflow `usize`: an overflowing total can never fit, so it is
        // rejected instead of panicking (debug) or wrapping (release).
        let new_size = match self.current_cache_size.checked_add(array_size) {
            Some(total) if total <= self.max_cache_bytes => total,
            _ => return false,
        };

        let key = CacheKey {
            column_idx,
            batch_id,
        };
        let existing = self.cache.insert(key, array);
        assert!(
            existing.is_none(),
            "cache entry already exists for column {}, batch {:?}",
            column_idx,
            batch_id
        );
        self.current_cache_size = new_size;
        true
    }

    /// Looks up the array stored under (`column_idx`, `batch_id`).
    ///
    /// Returns a clone of the stored `ArrayRef`, or `None` on a miss.
    pub fn get(&self, column_idx: usize, batch_id: BatchID) -> Option<ArrayRef> {
        let key = CacheKey {
            column_idx,
            batch_id,
        };
        self.cache.get(&key).cloned()
    }

    /// Returns the batch size this cache was constructed with.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Removes the entry stored under (`column_idx`, `batch_id`).
    ///
    /// Returns `true` if an entry was removed, in which case its accounted
    /// size is subtracted from the running total. Returns `false` if no such
    /// entry existed.
    pub fn remove(&mut self, column_idx: usize, batch_id: BatchID) -> bool {
        let key = CacheKey {
            column_idx,
            batch_id,
        };
        if let Some(array) = self.cache.remove(&key) {
            // The size is recomputed with the same function used on insert,
            // so the amount subtracted matches the amount that was added.
            self.current_cache_size -= get_array_memory_size_for_cache(&array);
            true
        } else {
            false
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{ArrayRef, Int32Array};
    use std::sync::Arc;

    /// Wraps `values` in a reference-counted `Int32Array`.
    fn int32_array(values: Vec<i32>) -> ArrayRef {
        Arc::new(Int32Array::from(values))
    }

    /// Shorthand for building a `BatchID`.
    fn batch(val: usize) -> BatchID {
        BatchID { val }
    }

    /// Insert one array, then check a hit and two kinds of miss.
    #[test]
    fn test_cache_basic_operations() {
        let mut cache = RowGroupCache::new(1000, usize::MAX);
        let first = batch(0);

        assert!(cache.insert(0, first, int32_array(vec![1, 2, 3, 4, 5])));

        // Hit: the stored array comes back with its five elements.
        let hit = cache.get(0, first);
        assert!(hit.is_some());
        assert_eq!(hit.map(|stored| stored.len()), Some(5));

        // Misses: unknown column first, then unknown batch.
        assert!(cache.get(1, first).is_none());
        assert!(cache.get(0, batch(1000)).is_none());
    }

    /// Removing an entry affects only that entry, and removing it twice
    /// reports `false` the second time.
    #[test]
    fn test_cache_remove() {
        let mut cache = RowGroupCache::new(1000, usize::MAX);
        let first = int32_array(vec![1, 2, 3]);
        let second = int32_array(vec![4, 5, 6]);

        let entries = [
            (0, batch(0), Arc::clone(&first)),
            (0, batch(1000), second),
            (1, batch(0), first),
        ];
        for (column_idx, batch_id, array) in entries {
            assert!(cache.insert(column_idx, batch_id, array));
        }

        // All three entries are present after insertion.
        for (column_idx, batch_id) in [(0, batch(0)), (0, batch(1000)), (1, batch(0))] {
            assert!(cache.get(column_idx, batch_id).is_some());
        }

        // Drop the first entry; the other two must be unaffected.
        assert!(cache.remove(0, batch(0)));
        assert!(cache.get(0, batch(0)).is_none());
        assert!(cache.get(0, batch(1000)).is_some());
        assert!(cache.get(1, batch(0)).is_some());

        // A second removal of the same key finds nothing.
        assert!(!cache.remove(0, batch(0)));

        // Remove the remaining entries and confirm they are gone.
        for (column_idx, batch_id) in [(0, batch(1000)), (1, batch(0))] {
            assert!(cache.remove(column_idx, batch_id));
        }
        for (column_idx, batch_id) in [(0, batch(1000)), (1, batch(0))] {
            assert!(cache.get(column_idx, batch_id).is_none());
        }
    }
}