use crate::error::{DbxError, DbxResult};
use arrow::array::{ArrayRef, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::ipc::reader::StreamReader;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
/// Default cache memory budget: 1 GiB.
const DEFAULT_MAX_MEMORY: usize = 1024 * 1024 * 1024;
/// A table's cached contents.
///
/// `Typed` holds batches that conform to a known table schema (the schema is
/// stored alongside so it is available even when the batch list is empty).
/// `Raw` holds untyped key/value batches; their schema is derived on demand
/// from the first batch.
#[derive(Clone)]
pub enum CachedData {
    Typed {
        schema: SchemaRef,
        batches: Vec<RecordBatch>,
    },
    Raw {
        batches: Vec<RecordBatch>,
    },
}
impl CachedData {
    /// Returns the cached record batches, regardless of variant.
    pub fn batches(&self) -> &[RecordBatch] {
        let (CachedData::Typed { batches, .. } | CachedData::Raw { batches }) = self;
        batches
    }

    /// Returns the schema describing the cached data.
    ///
    /// For `Typed` data the stored schema is returned. For `Raw` data the
    /// schema of the first batch is used; if no batches are cached yet, a
    /// default binary key/value schema is synthesized.
    pub fn schema(&self) -> SchemaRef {
        match self {
            CachedData::Typed { schema, .. } => Arc::clone(schema),
            CachedData::Raw { batches } => match batches.first() {
                Some(first) => first.schema(),
                None => Arc::new(Schema::new(vec![
                    Field::new("key", DataType::Binary, false),
                    Field::new("value", DataType::Binary, true),
                ])),
            },
        }
    }

    /// True when this cache entry carries an explicit table schema.
    pub fn is_typed(&self) -> bool {
        match self {
            CachedData::Typed { .. } => true,
            CachedData::Raw { .. } => false,
        }
    }
}
/// In-memory columnar (Arrow) cache keyed by table name, with LRU eviction
/// driven by a monotonically increasing logical access counter.
pub struct ColumnarCache {
    // Per-table cached data; DashMap provides sharded concurrent access.
    tables: DashMap<String, Arc<TableCache>>,
    // Memory budget in bytes; inserts evict LRU tables to stay under it.
    max_memory: usize,
    // Estimated bytes currently cached across all tables.
    current_memory: AtomicUsize,
    // Logical clock; each table access stamps the table for LRU ordering.
    access_counter: AtomicU64,
}
/// Cached contents and bookkeeping for a single table.
struct TableCache {
    // The table's batches (typed or raw) behind a reader/writer lock.
    data: parking_lot::RwLock<CachedData>,
    // Reserved for sync timestamping; currently unused.
    _last_sync_ts: AtomicU64,
    // Value of the cache-wide access counter when this table was last touched.
    last_access: AtomicU64,
    // Estimated bytes attributed to this table.
    memory_usage: AtomicUsize,
}
impl ColumnarCache {
/// Creates a cache with the default 1 GiB memory budget.
pub fn new() -> Self {
    Self::with_memory_limit(DEFAULT_MAX_MEMORY)
}

/// Creates a cache that evicts once tracked memory would exceed `max_memory` bytes.
pub fn with_memory_limit(max_memory: usize) -> Self {
    Self {
        tables: DashMap::new(),
        max_memory,
        current_memory: AtomicUsize::new(0),
        access_counter: AtomicU64::new(0),
    }
}

/// Estimated bytes currently held by all cached tables.
pub fn memory_usage(&self) -> usize {
    self.current_memory.load(Ordering::Relaxed)
}

/// The configured memory budget in bytes.
pub fn memory_limit(&self) -> usize {
    self.max_memory
}

/// True when tracked memory exceeds the budget and eviction is warranted.
pub fn should_evict(&self) -> bool {
    self.memory_usage() > self.max_memory
}
/// Writes every cached batch of `table` to `cache_dir` as numbered Arrow IPC
/// files (`{table}_{idx}.arrow`).
///
/// After writing, any trailing files left over from a previous, larger
/// persist are removed; otherwise `load_from_disk` (which reads indices
/// until the first gap) would append stale batches after the fresh ones.
///
/// # Errors
/// Returns `DbxError::Storage` if the table is not cached, the directory
/// cannot be created, or any file write/removal fails.
pub fn persist_to_disk(&self, table: &str, cache_dir: &str) -> DbxResult<()> {
    use crate::storage::arrow_ipc::write_ipc_batch;
    use std::fs;
    use std::path::Path;

    let table_cache = self
        .tables
        .get(table)
        .ok_or_else(|| DbxError::Storage(format!("Table '{}' not in cache", table)))?;
    let data = table_cache.data.read();
    let batches = data.batches();
    if batches.is_empty() {
        return Ok(());
    }
    let cache_path = Path::new(cache_dir);
    fs::create_dir_all(cache_path)
        .map_err(|e| DbxError::Storage(format!("Failed to create cache dir: {}", e)))?;
    for (idx, batch) in batches.iter().enumerate() {
        let ipc_bytes = write_ipc_batch(batch)?;
        let file_path = cache_path.join(format!("{}_{}.arrow", table, idx));
        fs::write(&file_path, ipc_bytes)
            .map_err(|e| DbxError::Storage(format!("Failed to write cache file: {}", e)))?;
    }
    // Delete stale files beyond the freshly written range (bug fix: these
    // previously accumulated and were re-read by load_from_disk).
    let mut idx = batches.len();
    loop {
        let stale = cache_path.join(format!("{}_{}.arrow", table, idx));
        if !stale.exists() {
            break;
        }
        fs::remove_file(&stale)
            .map_err(|e| DbxError::Storage(format!("Failed to remove cache file: {}", e)))?;
        idx += 1;
    }
    Ok(())
}
/// Reads persisted Arrow IPC files for `table` from `cache_dir`, repopulates
/// the in-memory cache with them, and returns the loaded batches.
///
/// Files are read as `{table}_{idx}.arrow` for idx = 0, 1, ... until the
/// first missing index. A missing cache directory yields an empty result.
///
/// # Errors
/// Returns `DbxError::Storage` on read failures, and propagates IPC decode
/// and cache-insertion errors.
pub fn load_from_disk(&self, table: &str, cache_dir: &str) -> DbxResult<Vec<RecordBatch>> {
    use crate::storage::arrow_ipc::read_ipc_batch;
    use std::fs;
    use std::path::Path;

    let cache_path = Path::new(cache_dir);
    if !cache_path.exists() {
        return Ok(Vec::new());
    }
    let mut batches = Vec::new();
    for idx in 0.. {
        let file_path = cache_path.join(format!("{}_{}.arrow", table, idx));
        if !file_path.exists() {
            break;
        }
        let bytes = fs::read(&file_path)
            .map_err(|e| DbxError::Storage(format!("Failed to read cache file: {}", e)))?;
        batches.push(read_ipc_batch(&bytes)?);
    }
    for batch in &batches {
        self.insert_batch(table, batch.clone())?;
    }
    Ok(batches)
}
/// Removes all persisted `{table}_{idx}.arrow` files for `table` from
/// `cache_dir`, stopping at the first missing index.
///
/// A missing cache directory is treated as already clean.
///
/// # Errors
/// Returns `DbxError::Storage` if a file removal fails.
pub fn clear_disk_cache(&self, table: &str, cache_dir: &str) -> DbxResult<()> {
    use std::fs;
    use std::path::Path;

    let cache_path = Path::new(cache_dir);
    if !cache_path.exists() {
        return Ok(());
    }
    for idx in 0.. {
        let file_path = cache_path.join(format!("{}_{}.arrow", table, idx));
        if !file_path.exists() {
            return Ok(());
        }
        fs::remove_file(&file_path)
            .map_err(|e| DbxError::Storage(format!("Failed to remove cache file: {}", e)))?;
    }
    Ok(())
}
/// Appends `batch` to `table`'s cache, evicting LRU tables as needed to
/// respect the memory budget, and updates memory/access bookkeeping.
///
/// # Errors
/// Returns `DbxError::Storage` when the batch cannot fit even after eviction
/// (either nothing is left to evict or the attempt limit is reached).
pub fn insert_batch(&self, table: &str, batch: RecordBatch) -> DbxResult<()> {
    let memory_size = estimate_batch_memory(&batch);
    let mut attempts = 0;
    const MAX_EVICTION_ATTEMPTS: usize = 10;
    // Best-effort admission check: concurrent inserts may still overshoot
    // briefly, since the check and the insert are not atomic.
    while self.current_memory.load(Ordering::Relaxed) + memory_size > self.max_memory {
        if attempts >= MAX_EVICTION_ATTEMPTS {
            return Err(DbxError::Storage(
                "Columnar cache memory limit exceeded (eviction failed)".to_string(),
            ));
        }
        if !self.evict_lru() {
            return Err(DbxError::Storage(
                "Columnar cache memory limit exceeded (nothing to evict)".to_string(),
            ));
        }
        attempts += 1;
    }
    // Clone the Arc and release the DashMap entry guard before taking the
    // inner write lock (same pattern as insert_typed_batch). Previously the
    // shard guard was held across the RwLock acquisition, which dashmap's
    // documentation warns can deadlock against other map operations.
    let table_cache = self
        .tables
        .entry(table.to_string())
        .or_insert_with(|| {
            Arc::new(TableCache {
                data: parking_lot::RwLock::new(CachedData::Raw {
                    batches: Vec::new(),
                }),
                _last_sync_ts: AtomicU64::new(0),
                last_access: AtomicU64::new(self.access_counter.fetch_add(1, Ordering::Relaxed)),
                memory_usage: AtomicUsize::new(0),
            })
        })
        .clone();
    table_cache.last_access.store(
        self.access_counter.fetch_add(1, Ordering::Relaxed),
        Ordering::Relaxed,
    );
    {
        // Scope the write lock so it is released before the atomic updates.
        let mut data = table_cache.data.write();
        match &mut *data {
            CachedData::Raw { batches } => batches.push(batch),
            CachedData::Typed { batches, .. } => batches.push(batch),
        }
    }
    table_cache
        .memory_usage
        .fetch_add(memory_size, Ordering::Relaxed);
    self.current_memory
        .fetch_add(memory_size, Ordering::Relaxed);
    Ok(())
}
/// Replaces `table`'s cached contents with `rows` pulled from storage.
///
/// With a schema the rows are decoded as Arrow IPC and cached typed;
/// without one they are cached as raw key/value binary columns. An empty
/// row set simply clears the table. Returns the number of rows cached.
pub fn sync_from_storage(
    &self,
    table: &str,
    rows: Vec<(Vec<u8>, Vec<u8>)>,
    table_schema: Option<SchemaRef>,
) -> DbxResult<usize> {
    if rows.is_empty() {
        self.clear_table(table)?;
        return Ok(0);
    }
    match table_schema {
        Some(schema) => self.sync_typed(table, rows, schema),
        None => self.sync_raw(table, rows),
    }
}
/// Decodes a single record batch from an Arrow IPC stream byte slice.
///
/// # Errors
/// Returns `DbxError::Serialization` when the stream header is invalid,
/// contains no batch, or the first batch fails to decode.
fn deserialize_arrow_ipc(value: &[u8]) -> DbxResult<RecordBatch> {
    let mut reader = StreamReader::try_new(std::io::Cursor::new(value), None)
        .map_err(|e| DbxError::Serialization(format!("Arrow IPC read error: {}", e)))?;
    match reader.next() {
        Some(Ok(batch)) => Ok(batch),
        Some(Err(e)) => Err(DbxError::Serialization(format!(
            "Arrow IPC batch error: {}",
            e
        ))),
        None => Err(DbxError::Serialization(
            "No batch in Arrow IPC stream".to_string(),
        )),
    }
}
/// Replaces `table`'s cache with the typed batches decoded from `rows`'
/// values, consolidated into a single batch under `schema`.
///
/// Returns the number of input rows (one IPC-encoded batch per row).
fn sync_typed(
    &self,
    table: &str,
    rows: Vec<(Vec<u8>, Vec<u8>)>,
    schema: SchemaRef,
) -> DbxResult<usize> {
    // Decode every value first so a bad row aborts before the cache is cleared.
    let batches = rows
        .iter()
        .map(|(_key, value)| Self::deserialize_arrow_ipc(value))
        .collect::<DbxResult<Vec<_>>>()?;
    self.clear_table(table)?;
    if !batches.is_empty() {
        let consolidated = arrow::compute::concat_batches(&schema, &batches)
            .map_err(|e| DbxError::Storage(format!("Failed to consolidate batches: {}", e)))?;
        self.insert_typed_batch(table, schema, consolidated)?;
    }
    Ok(rows.len())
}
/// Replaces `table`'s cache with a single raw key/value batch built from
/// `rows`, returning the number of rows cached.
///
/// Keys longer than 8 bytes are assumed to carry an MVCC version suffix and
/// are decoded back to the user key; undecodable keys are kept verbatim.
fn sync_raw(&self, table: &str, rows: Vec<(Vec<u8>, Vec<u8>)>) -> DbxResult<usize> {
    use arrow::array::builder::BinaryBuilder;

    let schema = Arc::new(Schema::new(vec![
        Field::new("key", DataType::Binary, false),
        Field::new("value", DataType::Binary, true),
    ]));
    let n = rows.len();
    // Capacity hints: ~32 bytes per key, ~128 bytes per value.
    let mut keys = BinaryBuilder::with_capacity(n, n * 32);
    let mut values = BinaryBuilder::with_capacity(n, n * 128);
    for (raw_key, value) in rows {
        let user_key = if raw_key.len() > 8 {
            crate::transaction::mvcc::version::VersionedKey::decode(&raw_key)
                .map(|vk| vk.user_key)
                .unwrap_or(raw_key)
        } else {
            raw_key
        };
        keys.append_value(user_key);
        values.append_value(value);
    }
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(keys.finish()), Arc::new(values.finish())],
    )?;
    let row_count = batch.num_rows();
    self.clear_table(table)?;
    self.insert_batch(table, batch)?;
    Ok(row_count)
}
/// Appends a typed `batch` to `table`'s cache under `schema`, evicting LRU
/// tables as needed and updating memory/access bookkeeping.
///
/// If the table currently holds raw data it is promoted to typed and its
/// contents are replaced by `batch`.
///
/// # Errors
/// Returns `DbxError::Storage` when the batch cannot fit even after eviction.
fn insert_typed_batch(
    &self,
    table: &str,
    schema: SchemaRef,
    batch: RecordBatch,
) -> DbxResult<()> {
    let memory_size = estimate_batch_memory(&batch);
    let mut attempts = 0;
    const MAX_EVICTION_ATTEMPTS: usize = 10;
    while self.current_memory.load(Ordering::Relaxed) + memory_size > self.max_memory {
        if attempts >= MAX_EVICTION_ATTEMPTS {
            return Err(DbxError::Storage(
                "Columnar cache memory limit exceeded (eviction failed)".to_string(),
            ));
        }
        if !self.evict_lru() {
            return Err(DbxError::Storage(
                "Columnar cache memory limit exceeded (nothing to evict)".to_string(),
            ));
        }
        attempts += 1;
    }
    // Clone the Arc so the DashMap entry guard is released before taking the
    // inner write lock.
    let table_cache = self
        .tables
        .entry(table.to_string())
        .or_insert_with(|| {
            Arc::new(TableCache {
                data: parking_lot::RwLock::new(CachedData::Typed {
                    schema: schema.clone(),
                    batches: Vec::new(),
                }),
                _last_sync_ts: AtomicU64::new(0),
                last_access: AtomicU64::new(
                    self.access_counter.fetch_add(1, Ordering::Relaxed),
                ),
                memory_usage: AtomicUsize::new(0),
            })
        })
        .clone();
    table_cache.last_access.store(
        self.access_counter.fetch_add(1, Ordering::Relaxed),
        Ordering::Relaxed,
    );
    {
        let mut data = table_cache.data.write();
        match &mut *data {
            CachedData::Typed { batches, .. } => {
                batches.push(batch);
            }
            CachedData::Raw { .. } => {
                // Bug fix: the replaced raw batches' memory must be released
                // from the accounting, otherwise it stays counted forever and
                // causes premature eviction / spurious limit errors.
                let stale = table_cache.memory_usage.swap(0, Ordering::Relaxed);
                self.current_memory.fetch_sub(stale, Ordering::Relaxed);
                *data = CachedData::Typed {
                    schema,
                    batches: vec![batch],
                };
            }
        }
    }
    table_cache
        .memory_usage
        .fetch_add(memory_size, Ordering::Relaxed);
    self.current_memory
        .fetch_add(memory_size, Ordering::Relaxed);
    Ok(())
}
/// Returns `table`'s batches with `filter` applied to each batch and an
/// optional column projection applied afterwards.
///
/// Returns `Ok(None)` when the table is absent or holds no batches. Batches
/// whose filtered result is empty are dropped from the output. Touching the
/// table updates its LRU access stamp.
///
/// # Errors
/// Propagates filter-callback errors and wraps Arrow filter failures in
/// `DbxError::Storage`.
pub fn get_batches_with_filter<F>(
    &self,
    table: &str,
    projection: Option<&[usize]>,
    filter: F,
) -> DbxResult<Option<Vec<RecordBatch>>>
where
    F: Fn(&RecordBatch) -> DbxResult<arrow::array::BooleanArray>,
{
    let table_cache = match self.tables.get(table) {
        Some(tc) => tc,
        None => return Ok(None),
    };
    table_cache.last_access.store(
        self.access_counter.fetch_add(1, Ordering::Relaxed),
        Ordering::Relaxed,
    );
    let data = table_cache.data.read();
    let batches = data.batches();
    if batches.is_empty() {
        return Ok(None);
    }
    let mut result = Vec::with_capacity(batches.len());
    for batch in batches {
        let mask = filter(batch)?;
        let filtered = arrow::compute::filter_record_batch(batch, &mask)
            .map_err(|e| DbxError::Storage(format!("Failed to filter batch: {}", e)))?;
        if filtered.num_rows() > 0 {
            let projected = match projection {
                Some(indices) => project_batch(&filtered, indices)?,
                None => filtered,
            };
            result.push(projected);
        }
    }
    Ok(Some(result))
}
/// Returns `table`'s cached batches, optionally projected to `projection`
/// column indices. Returns `Ok(None)` when the table is absent or empty.
///
/// Lookup is exact-case first; only on a miss does it fall back to a
/// case-insensitive scan. (Previously every call paid an O(n) scan with a
/// `to_lowercase` allocation per entry even when the exact key existed, and
/// the needle's lowercase form was recomputed per entry.)
///
/// Touching the table updates its LRU access stamp.
pub fn get_batches(
    &self,
    table: &str,
    projection: Option<&[usize]>,
) -> DbxResult<Option<Vec<RecordBatch>>> {
    let table_cache = match self.tables.get(table) {
        hit @ Some(_) => hit,
        None => {
            // Case-insensitive fallback; hoist the needle's lowercase form.
            let needle = table.to_lowercase();
            let matched = self
                .tables
                .iter()
                .find(|entry| entry.key().to_lowercase() == needle)
                .map(|entry| entry.key().clone());
            matched.and_then(|key| self.tables.get(&key))
        }
    };
    let Some(table_cache) = table_cache else {
        return Ok(None);
    };
    let current_access = self.access_counter.fetch_add(1, Ordering::Relaxed);
    table_cache
        .last_access
        .store(current_access, Ordering::Relaxed);
    let data = table_cache.data.read();
    let batches = data.batches();
    if batches.is_empty() {
        return Ok(None);
    }
    let result = match projection {
        Some(indices) => batches
            .iter()
            .map(|batch| project_batch(batch, indices))
            .collect::<DbxResult<Vec<_>>>()?,
        None => batches.to_vec(),
    };
    Ok(Some(result))
}
/// Drops a table's cached data and releases its accounted memory.
pub fn clear_table(&self, table: &str) -> DbxResult<()> {
    if let Some((_, evicted)) = self.tables.remove(table) {
        let freed = evicted.memory_usage.load(Ordering::Relaxed);
        self.current_memory.fetch_sub(freed, Ordering::Relaxed);
    }
    Ok(())
}

/// Drops every cached table and resets the memory counter to zero.
pub fn clear_all(&self) -> DbxResult<()> {
    self.tables.clear();
    self.current_memory.store(0, Ordering::Relaxed);
    Ok(())
}

/// Returns the schema for a cached table, or `None` if it is not cached.
pub fn get_schema(&self, table: &str) -> Option<SchemaRef> {
    self.tables.get(table).map(|tc| tc.data.read().schema())
}
/// Evicts the least-recently-used table (smallest access stamp) and releases
/// its accounted memory. Returns `true` if a table was evicted.
fn evict_lru(&self) -> bool {
    let victim = self
        .tables
        .iter()
        .min_by_key(|entry| entry.value().last_access.load(Ordering::Relaxed))
        .map(|entry| entry.key().clone());
    let Some(name) = victim else {
        return false;
    };
    match self.tables.remove(&name) {
        Some((_, cache)) => {
            let freed = cache.memory_usage.load(Ordering::Relaxed);
            self.current_memory.fetch_sub(freed, Ordering::Relaxed);
            true
        }
        // Another thread removed it first; report nothing evicted.
        None => false,
    }
}
/// Names of all currently cached tables (unordered).
pub fn table_names(&self) -> Vec<String> {
    self.tables.iter().map(|e| e.key().clone()).collect()
}

/// True when `table` is cached under exactly this (case-sensitive) name.
pub fn has_table(&self, table: &str) -> bool {
    self.tables.contains_key(table)
}
}
impl Default for ColumnarCache {
    /// Equivalent to [`ColumnarCache::new`] (1 GiB budget).
    fn default() -> Self {
        Self::new()
    }
}
/// Estimates the heap memory held by a batch as the sum of its columns'
/// array memory sizes.
fn estimate_batch_memory(batch: &RecordBatch) -> usize {
    batch
        .columns()
        .iter()
        .fold(0, |total, array| total + array.get_array_memory_size())
}
/// Returns a new batch containing only the columns at `indices`, in order.
///
/// Delegates to Arrow's built-in `RecordBatch::project`, which (unlike the
/// previous manual rebuild) returns an error instead of panicking on an
/// out-of-range index and preserves schema-level metadata.
fn project_batch(batch: &RecordBatch, indices: &[usize]) -> DbxResult<RecordBatch> {
    batch
        .project(indices)
        .map_err(|e| DbxError::Storage(format!("Failed to project batch: {}", e)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    // Builds a 3-row batch with an Int32 `id` column and a Utf8 `name` column.
    fn create_test_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let id_array = Int32Array::from(vec![1, 2, 3]);
        let name_array = StringArray::from(vec!["Alice", "Bob", "Charlie"]);
        RecordBatch::try_new(schema, vec![Arc::new(id_array), Arc::new(name_array)]).unwrap()
    }

    // An inserted batch is retrievable as a single cached batch.
    #[test]
    fn test_insert_and_get() {
        let cache = ColumnarCache::new();
        let batch = create_test_batch();
        cache.insert_batch("users", batch.clone()).unwrap();
        let result = cache.get_batches("users", None).unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().len(), 1);
    }

    // Projection to column 0 keeps only the `id` column and its field name.
    #[test]
    fn test_projection() {
        let cache = ColumnarCache::new();
        let batch = create_test_batch();
        cache.insert_batch("users", batch).unwrap();
        let result = cache.get_batches("users", Some(&[0])).unwrap().unwrap();
        assert_eq!(result[0].num_columns(), 1);
        assert_eq!(result[0].schema().field(0).name(), "id");
    }

    // Inserting a batch increases the tracked memory estimate.
    #[test]
    fn test_memory_tracking() {
        let cache = ColumnarCache::new();
        let batch = create_test_batch();
        let initial_memory = cache.memory_usage();
        cache.insert_batch("users", batch).unwrap();
        let after_insert = cache.memory_usage();
        assert!(after_insert > initial_memory);
    }

    // Clearing a table removes its data and returns its memory to zero.
    #[test]
    fn test_clear_table() {
        let cache = ColumnarCache::new();
        let batch = create_test_batch();
        cache.insert_batch("users", batch).unwrap();
        assert!(cache.get_batches("users", None).unwrap().is_some());
        cache.clear_table("users").unwrap();
        assert!(cache.get_batches("users", None).unwrap().is_none());
        assert_eq!(cache.memory_usage(), 0);
    }

    // A batch larger than the whole budget fails to insert (nothing to evict).
    #[test]
    fn test_memory_limit() {
        let cache = ColumnarCache::with_memory_limit(100);
        let batch = create_test_batch();
        let result = cache.insert_batch("users", batch);
        assert!(result.is_err());
    }

    // table_names reports every cached table (order not guaranteed, so sort).
    #[test]
    fn test_table_names() {
        let cache = ColumnarCache::new();
        let batch = create_test_batch();
        cache.insert_batch("users", batch.clone()).unwrap();
        cache.insert_batch("orders", batch).unwrap();
        let mut names = cache.table_names();
        names.sort();
        assert_eq!(names, vec!["orders", "users"]);
    }

    // With room for two tables: insert A and B, touch A (making B the LRU),
    // then insert C — B must be evicted while A and C survive.
    #[test]
    fn test_lru_eviction() {
        let batch = create_test_batch();
        let batch_size = estimate_batch_memory(&batch);
        let cache = ColumnarCache::with_memory_limit(batch_size * 2 + 100);
        cache.insert_batch("A", batch.clone()).unwrap();
        cache.insert_batch("B", batch.clone()).unwrap();
        cache.get_batches("A", None).unwrap();
        cache.insert_batch("C", batch.clone()).unwrap();
        let names = cache.table_names();
        assert!(names.contains(&"A".to_string()));
        assert!(names.contains(&"C".to_string()));
        assert!(!names.contains(&"B".to_string()));
    }

    // Filtering on id > 1 keeps exactly rows 2 and 3 (nulls preserved as null).
    #[test]
    fn test_filter_pushdown() {
        let cache = ColumnarCache::new();
        let batch = create_test_batch();
        cache.insert_batch("users", batch).unwrap();
        let result = cache
            .get_batches_with_filter("users", None, |batch| {
                use arrow::array::Array;
                let id_col = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap();
                let mut builder = arrow::array::BooleanBuilder::with_capacity(id_col.len());
                for i in 0..id_col.len() {
                    if id_col.is_null(i) {
                        builder.append_null();
                    } else {
                        builder.append_value(id_col.value(i) > 1);
                    }
                }
                Ok(builder.finish())
            })
            .unwrap()
            .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 2);
        let ids = result[0]
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(ids.value(0), 2);
        assert_eq!(ids.value(1), 3);
    }
}