use std::sync::Arc;
use std::time::Duration;
use arrow::record_batch::RecordBatch;
use oxistore_columnar::{ColumnarError, ColumnarTable};
use crate::lru::LruCache;
use crate::Cache;
/// Identifies one cached row group: the file it belongs to plus the
/// zero-based index of the row group (record batch) inside that file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RowGroupKey {
    /// Caller-chosen identifier of the source file.
    pub file_id: String,
    /// Zero-based index of the row group within the file.
    pub row_group_index: usize,
}

impl RowGroupKey {
    /// Builds a key from any value convertible into a `String` file id and a
    /// row-group index.
    #[must_use]
    pub fn new(file_id: impl Into<String>, row_group_index: usize) -> Self {
        let file_id: String = file_id.into();
        Self {
            row_group_index,
            file_id,
        }
    }
}
/// LRU cache of serialised row groups keyed by [`RowGroupKey`].
///
/// Each value is the byte serialisation of a single record batch, as produced
/// by `serialise_row_group`. Lookups made through [`Self::get`],
/// [`Self::load_row_group`] and [`Self::load_row_group_with_ttl`] are tallied
/// in hit/miss counters; the other methods leave the counters untouched.
pub struct ColumnarRowGroupCache {
    /// Underlying LRU store of serialised row-group bytes.
    inner: LruCache<RowGroupKey, Vec<u8>>,
    /// Counted lookups that found a cached entry.
    hits: u64,
    /// Counted lookups that found no cached entry.
    misses: u64,
}

impl ColumnarRowGroupCache {
    /// Creates an empty cache with zeroed counters.
    ///
    /// `max_entries` is forwarded to `LruCache::new` (presumably the maximum
    /// number of row groups retained — confirm against `LruCache`).
    #[must_use]
    pub fn new(max_entries: usize) -> Self {
        Self {
            inner: LruCache::new(max_entries),
            hits: 0,
            misses: 0,
        }
    }

    /// Number of counted lookups that were served from the cache.
    #[must_use]
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of counted lookups that were not served from the cache.
    #[must_use]
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Fraction of counted lookups that were hits.
    ///
    /// Returns `0.0` when no lookup has been counted yet.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }

    /// Number of entries currently held, as reported by the underlying LRU.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when the underlying LRU reports no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Capacity reported by the underlying LRU.
    #[must_use]
    pub fn cap(&self) -> usize {
        self.inner.cap()
    }

    /// Returns the cached bytes for `key`.
    ///
    /// Records a hit when an entry is returned and a miss otherwise.
    pub fn get(&mut self, key: &RowGroupKey) -> Option<&[u8]> {
        let v = self.inner.get(key);
        if v.is_some() {
            self.hits += 1;
        } else {
            self.misses += 1;
        }
        v.map(|b| b.as_slice())
    }

    /// Stores `bytes` under `key` via `LruCache::put` (no TTL).
    ///
    /// Does not touch the hit/miss counters.
    pub fn insert(&mut self, key: RowGroupKey, bytes: Vec<u8>) {
        self.inner.put(key, bytes);
    }

    /// Stores `bytes` under `key` with time-to-live `ttl` via
    /// `LruCache::put_with_ttl`.
    ///
    /// Does not touch the hit/miss counters.
    pub fn insert_with_ttl(&mut self, key: RowGroupKey, bytes: Vec<u8>, ttl: Duration) {
        self.inner.put_with_ttl(key, bytes, ttl);
    }

    /// Removes the entry for `key`, returning what `LruCache::remove` yields
    /// (the removed bytes, if the key was cached).
    pub fn evict(&mut self, key: &RowGroupKey) -> Option<Vec<u8>> {
        self.inner.remove(key)
    }

    /// Removes every cached row group whose key has the given `file_id`.
    pub fn invalidate_file(&mut self, file_id: &str) {
        // Collect the matching keys first: the LRU cannot be mutated while it
        // is being iterated.
        let keys: Vec<RowGroupKey> = self
            .inner
            .iter()
            .filter(|(k, _)| k.file_id == file_id)
            .map(|(k, _)| k.clone())
            .collect();
        for k in keys {
            self.inner.remove(&k);
        }
    }

    /// Removes all entries. The hit/miss counters are left unchanged.
    pub fn clear(&mut self) {
        self.inner.clear();
    }

    /// Returns the serialised bytes of row group `row_group_index` of
    /// `file_id`.
    ///
    /// A cached entry counts as a hit and is returned without consulting
    /// `table`. Otherwise the row group is serialised from `table`, cached
    /// without a TTL, and the lookup counts as a miss.
    ///
    /// If the freshly stored entry cannot be read back from the LRU, an empty
    /// slice is returned.
    ///
    /// # Errors
    ///
    /// Only on a miss: returns an error when `row_group_index` is out of range
    /// for `table.batches` or serialisation fails.
    pub fn load_row_group(
        &mut self,
        file_id: impl Into<String>,
        row_group_index: usize,
        table: &ColumnarTable,
    ) -> Result<&[u8], ColumnarError> {
        self.load_key(RowGroupKey::new(file_id, row_group_index), table, None)
    }

    /// Like [`Self::load_row_group`], but a freshly serialised entry is stored
    /// with time-to-live `ttl`.
    ///
    /// A cache hit returns the existing entry and does not refresh its TTL.
    ///
    /// # Errors
    ///
    /// Same as [`Self::load_row_group`].
    pub fn load_row_group_with_ttl(
        &mut self,
        file_id: impl Into<String>,
        row_group_index: usize,
        table: &ColumnarTable,
        ttl: Duration,
    ) -> Result<&[u8], ColumnarError> {
        self.load_key(RowGroupKey::new(file_id, row_group_index), table, Some(ttl))
    }

    /// Shared body of the `load_row_group*` methods.
    ///
    /// If `key` is cached, returns its bytes and counts a hit. Otherwise it
    /// serialises the row group from `table` and counts a miss. It then stores
    /// the result (with `ttl` when given) and returns the stored bytes.
    fn load_key(
        &mut self,
        key: RowGroupKey,
        table: &ColumnarTable,
        ttl: Option<Duration>,
    ) -> Result<&[u8], ColumnarError> {
        // Checked with `contains_key` rather than by matching on `get`: the
        // current (non-Polonius) borrow checker rejects returning a borrow
        // obtained from `get` on one branch while mutating `self` on the other.
        if self.inner.contains_key(&key) {
            self.hits += 1;
            return Ok(self.cached_bytes(&key));
        }
        self.misses += 1;
        let bytes = serialise_row_group(table, key.row_group_index)?;
        match ttl {
            Some(ttl) => {
                self.inner.put_with_ttl(key.clone(), bytes, ttl);
            }
            None => {
                self.inner.put(key.clone(), bytes);
            }
        }
        Ok(self.cached_bytes(&key))
    }

    /// Reads back the bytes stored under `key` without touching the counters.
    ///
    /// Falls back to an empty slice when the LRU returns nothing.
    // NOTE(review): the empty-slice fallback hides the case where the LRU
    // could not retain or return the entry; consider surfacing it as an error.
    fn cached_bytes(&mut self, key: &RowGroupKey) -> &[u8] {
        self.inner.get(key).map(|v| v.as_slice()).unwrap_or_default()
    }

    /// Serialises and caches (without TTL) every row group of `table` under
    /// `file_id` that is not already cached.
    ///
    /// Does not touch the hit/miss counters. If the table has more batches
    /// than the cache can hold, later insertions may evict earlier ones.
    ///
    /// # Errors
    ///
    /// Stops at the first row group whose serialisation fails. Row groups
    /// cached before that failure remain in the cache.
    pub fn warm_from_table(
        &mut self,
        file_id: impl Into<String> + Clone,
        table: &ColumnarTable,
    ) -> Result<(), ColumnarError> {
        for idx in 0..table.batches.len() {
            let key = RowGroupKey::new(file_id.clone(), idx);
            if !self.inner.contains_key(&key) {
                let bytes = serialise_row_group(table, idx)?;
                self.inner.put(key, bytes);
            }
        }
        Ok(())
    }

    /// Decodes the cached bytes for `key` back into a `RecordBatch`.
    ///
    /// Returns `Ok(None)` when `key` is not cached. Only the first decoded
    /// batch is returned. Unlike [`Self::get`], this does not touch the
    /// hit/miss counters.
    ///
    /// # Errors
    ///
    /// Propagates any error from `oxistore_columnar::read_batches_from_bytes`.
    pub fn get_as_batch(
        &mut self,
        key: &RowGroupKey,
    ) -> Result<Option<RecordBatch>, ColumnarError> {
        let Some(bytes) = self.inner.get(key) else {
            return Ok(None);
        };
        let batches = oxistore_columnar::read_batches_from_bytes(bytes)?;
        Ok(batches.into_iter().next())
    }
}
/// Serialises the record batch at `row_group_index` of `table` on its own.
///
/// The batch is wrapped in a one-batch `ColumnarTable` that shares the
/// source table's schema, and that table's bytes are returned.
///
/// # Errors
///
/// Returns `ColumnarError::SchemaMismatch` when `row_group_index` is not a
/// valid index into `table.batches`. Propagates any error from
/// `ColumnarTable::write_to_bytes`.
fn serialise_row_group(
    table: &ColumnarTable,
    row_group_index: usize,
) -> Result<Vec<u8>, ColumnarError> {
    let Some(batch) = table.batches.get(row_group_index) else {
        let message = format!(
            "row group index {row_group_index} out of range (table has {} batches)",
            table.batches.len()
        );
        return Err(ColumnarError::SchemaMismatch(message));
    };
    let shared_schema = Arc::clone(&table.schema);
    let mut one_batch_table = ColumnarTable::new(shared_schema);
    one_batch_table.push_unchecked(batch.clone());
    one_batch_table.write_to_bytes()
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a table with non-null Int64 columns `id` and `value`
    /// (`value = 2 * id`).
    ///
    /// The table holds `num_batches` batches of `rows_per_batch` rows each.
    /// Ids run consecutively from 0 across batches.
    fn make_table(num_batches: usize, rows_per_batch: usize) -> ColumnarTable {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Int64, false),
        ]));
        let mut table = ColumnarTable::new(Arc::clone(&schema));
        for batch_idx in 0..num_batches {
            let base = (batch_idx * rows_per_batch) as i64;
            let ids: Vec<i64> = (base..base + rows_per_batch as i64).collect();
            let vals: Vec<i64> = ids.iter().map(|&i| i * 2).collect();
            let batch = RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(ids)),
                    Arc::new(Int64Array::from(vals)),
                ],
            )
            .expect("batch construction");
            table.push_unchecked(batch);
        }
        table
    }

    #[test]
    fn load_row_group_returns_bytes() {
        let table = make_table(3, 10);
        let mut cache = ColumnarRowGroupCache::new(16);
        let bytes = cache.load_row_group("test.parquet", 0, &table).unwrap();
        assert!(!bytes.is_empty());
    }

    #[test]
    fn load_row_group_hit_on_second_access() {
        let table = make_table(3, 10);
        let mut cache = ColumnarRowGroupCache::new(16);
        cache.load_row_group("f.parquet", 0, &table).unwrap();
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
        cache.load_row_group("f.parquet", 0, &table).unwrap();
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 1);
    }

    #[test]
    fn load_row_group_with_ttl_hit_on_second_access() {
        let table = make_table(2, 4);
        let mut cache = ColumnarRowGroupCache::new(16);
        let ttl = Duration::from_secs(60);
        let first = cache
            .load_row_group_with_ttl("t.parquet", 1, &table, ttl)
            .unwrap()
            .to_vec();
        assert!(!first.is_empty());
        let second = cache
            .load_row_group_with_ttl("t.parquet", 1, &table, ttl)
            .unwrap()
            .to_vec();
        assert_eq!(first, second);
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn load_row_group_out_of_range_errors() {
        let table = make_table(2, 5);
        let mut cache = ColumnarRowGroupCache::new(16);
        let result = cache.load_row_group("f.parquet", 99, &table);
        assert!(result.is_err());
    }

    #[test]
    fn invalidate_file_removes_entries() {
        let table = make_table(4, 5);
        let mut cache = ColumnarRowGroupCache::new(32);
        for i in 0..4 {
            cache.load_row_group("file_a", i, &table).unwrap();
        }
        cache.load_row_group("file_b", 0, &table).unwrap();
        assert_eq!(cache.len(), 5);
        cache.invalidate_file("file_a");
        assert_eq!(cache.len(), 1);
        // The surviving entry must be the one belonging to the other file.
        assert!(cache.get(&RowGroupKey::new("file_b", 0)).is_some());
    }

    #[test]
    fn warm_from_table_populates_all_groups() {
        let table = make_table(5, 8);
        let mut cache = ColumnarRowGroupCache::new(32);
        cache.warm_from_table("warm_file", &table).unwrap();
        assert_eq!(cache.len(), 5);
        for i in 0..5 {
            let key = RowGroupKey::new("warm_file", i);
            assert!(cache.get(&key).is_some());
        }
    }

    #[test]
    fn get_as_batch_round_trip() {
        let table = make_table(2, 4);
        let mut cache = ColumnarRowGroupCache::new(16);
        cache.load_row_group("rt.parquet", 0, &table).unwrap();
        let key = RowGroupKey::new("rt.parquet", 0);
        let batch = cache.get_as_batch(&key).unwrap().expect("batch present");
        assert_eq!(batch.num_rows(), 4);
    }

    #[test]
    fn clear_empties_cache() {
        let table = make_table(3, 5);
        let mut cache = ColumnarRowGroupCache::new(16);
        cache.warm_from_table("f", &table).unwrap();
        assert!(!cache.is_empty());
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn hit_rate_is_zero_without_lookups() {
        let cache = ColumnarRowGroupCache::new(16);
        assert_eq!(cache.hit_rate(), 0.0);
    }

    #[test]
    fn hit_rate_computation() {
        let table = make_table(1, 3);
        let mut cache = ColumnarRowGroupCache::new(16);
        // The first access misses and the next two hit: 2 hits out of 3 lookups.
        for _ in 0..3 {
            cache.load_row_group("h.parquet", 0, &table).unwrap();
        }
        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
        assert!((cache.hit_rate() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn evict_removes_single_entry() {
        let table = make_table(2, 5);
        let mut cache = ColumnarRowGroupCache::new(16);
        cache.load_row_group("e.parquet", 0, &table).unwrap();
        cache.load_row_group("e.parquet", 1, &table).unwrap();
        let key = RowGroupKey::new("e.parquet", 0);
        let evicted = cache.evict(&key);
        assert!(evicted.is_some());
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn ttl_expired_entry_is_miss() {
        let table = make_table(1, 2);
        let mut cache = ColumnarRowGroupCache::new(16);
        let key = RowGroupKey::new("ttl_file", 0);
        let bytes = make_table(1, 2).write_to_bytes().expect("serialise");
        cache.insert_with_ttl(key, bytes, Duration::from_nanos(1));
        // `yield_now` guarantees no elapsed time, and a coarse monotonic clock
        // can report the same instant twice. Sleep so the 1 ns TTL has
        // definitely lapsed.
        std::thread::sleep(Duration::from_millis(2));
        cache.load_row_group("ttl_file", 0, &table).unwrap();
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.len(), 1);
    }
}