#![allow(clippy::disallowed_types)]
use std::collections::HashMap;
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use crate::error::DbError;
/// Storage backend for a table's rows.
///
/// Only an in-memory variant exists today; the enum shape (and the
/// `Result`-returning API implemented on it) leaves room for persistent
/// backends to be added later without changing call sites.
#[allow(dead_code)]
pub(crate) enum TableBackend {
/// Volatile storage: batches keyed by a string row key. Contents are
/// lost when the value is dropped (see `is_persistent`).
InMemory {
// Map from row key to the batch stored under it. The tests insert
// single-row batches, but nothing here enforces one row per key —
// presumably a convention of the callers; TODO confirm.
rows: HashMap<String, RecordBatch>,
},
}
#[allow(dead_code, clippy::unnecessary_wraps)]
impl TableBackend {
    /// Creates an empty in-memory backend.
    pub fn in_memory() -> Self {
        Self::InMemory {
            rows: HashMap::new(),
        }
    }

    /// Returns a clone of the batch stored under `key`, or `None` if absent.
    pub fn get(&self, key: &str) -> Result<Option<RecordBatch>, DbError> {
        match self {
            Self::InMemory { rows } => Ok(rows.get(key).cloned()),
        }
    }

    /// Stores `batch` under `key`, replacing any previous value.
    ///
    /// Returns `true` if the key already existed (i.e. this was an overwrite).
    pub fn put(&mut self, key: &str, batch: RecordBatch) -> Result<bool, DbError> {
        match self {
            Self::InMemory { rows } => {
                let existed = rows.insert(key.to_string(), batch).is_some();
                Ok(existed)
            }
        }
    }

    /// Removes the entry for `key`; returns `true` if it was present.
    pub fn remove(&mut self, key: &str) -> Result<bool, DbError> {
        match self {
            Self::InMemory { rows } => Ok(rows.remove(key).is_some()),
        }
    }

    /// Reports whether an entry exists for `key`.
    pub fn contains_key(&self, key: &str) -> Result<bool, DbError> {
        match self {
            Self::InMemory { rows } => Ok(rows.contains_key(key)),
        }
    }

    /// Returns all stored keys. Order is unspecified (`HashMap` iteration
    /// order); callers that need determinism must sort.
    pub fn keys(&self) -> Result<Vec<String>, DbError> {
        match self {
            Self::InMemory { rows } => Ok(rows.keys().cloned().collect()),
        }
    }

    /// Number of stored entries.
    pub fn len(&self) -> Result<usize, DbError> {
        match self {
            Self::InMemory { rows } => Ok(rows.len()),
        }
    }

    /// Removes and returns every `(key, batch)` pair, leaving the backend
    /// empty. Pair order is unspecified.
    pub fn drain(&mut self) -> Result<Vec<(String, RecordBatch)>, DbError> {
        match self {
            Self::InMemory { rows } => Ok(rows.drain().collect()),
        }
    }

    /// Concatenates every stored batch into a single `RecordBatch` under
    /// `schema`.
    ///
    /// Always returns `Some`: an empty backend yields a zero-row batch.
    /// Row order is unspecified because it follows `HashMap` iteration order.
    ///
    /// # Errors
    /// Returns `DbError::Storage` if the batches cannot be concatenated
    /// (e.g. a stored batch does not match `schema`).
    pub fn to_record_batch(&self, schema: &SchemaRef) -> Result<Option<RecordBatch>, DbError> {
        match self {
            Self::InMemory { rows } => {
                if rows.is_empty() {
                    // Keep the explicit empty-batch path: concat_batches'
                    // handling of empty input has varied across arrow
                    // versions, and this pins the behavior we want.
                    return Ok(Some(RecordBatch::new_empty(schema.clone())));
                }
                // Fix (clippy::needless_collect): pass the values iterator
                // straight through instead of collecting into an intermediate
                // Vec<&RecordBatch> and re-iterating it — concat_batches
                // accepts any `IntoIterator<Item = &RecordBatch>`.
                arrow::compute::concat_batches(schema, rows.values())
                    .map(Some)
                    .map_err(|e| DbError::Storage(format!("concat batches: {e}")))
            }
        }
    }

    /// `false` for the in-memory backend: contents are lost on drop.
    #[allow(clippy::unused_self)]
    pub fn is_persistent(&self) -> bool {
        false
    }

    /// Reports whether the backend holds no entries.
    pub fn is_empty(&self) -> Result<bool, DbError> {
        self.len().map(|n| n == 0)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use arrow::array::{ArrayRef, Float64Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Schema shared by every test: non-null `id`/`name`, nullable `price`.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("price", DataType::Float64, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds a one-row batch conforming to `test_schema`.
    fn make_batch(id: i32, name: &str, price: f64) -> RecordBatch {
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![id])),
            Arc::new(StringArray::from(vec![name])),
            Arc::new(Float64Array::from(vec![price])),
        ];
        RecordBatch::try_new(test_schema(), columns).unwrap()
    }

    #[test]
    fn test_in_memory_crud() {
        let mut backend = TableBackend::in_memory();
        // A fresh in-memory backend is empty and non-persistent.
        assert!(!backend.is_persistent());
        assert!(backend.is_empty().unwrap());

        // Fresh key: put reports "did not exist".
        assert!(!backend.put("1", make_batch(1, "A", 1.0)).unwrap());
        assert_eq!(backend.len().unwrap(), 1);

        let fetched = backend.get("1").unwrap().unwrap();
        assert_eq!(fetched.num_rows(), 1);
        assert!(backend.contains_key("1").unwrap());
        assert!(!backend.contains_key("2").unwrap());

        // Same key again: put reports an overwrite and len is unchanged.
        assert!(backend.put("1", make_batch(1, "B", 2.0)).unwrap());
        assert_eq!(backend.len().unwrap(), 1);

        // First remove succeeds; removing a missing key is a no-op.
        assert!(backend.remove("1").unwrap());
        assert!(backend.is_empty().unwrap());
        assert!(!backend.remove("1").unwrap());
    }

    #[test]
    fn test_in_memory_keys_and_drain() {
        let mut backend = TableBackend::in_memory();
        backend.put("a", make_batch(1, "A", 1.0)).unwrap();
        backend.put("b", make_batch(2, "B", 2.0)).unwrap();

        // Keys come back in arbitrary order; sort before comparing.
        let mut all_keys = backend.keys().unwrap();
        all_keys.sort();
        assert_eq!(all_keys, vec!["a", "b"]);

        // drain hands back every pair and leaves the backend empty.
        let drained = backend.drain().unwrap();
        assert_eq!(drained.len(), 2);
        assert!(backend.is_empty().unwrap());
    }

    #[test]
    fn test_in_memory_to_record_batch() {
        let mut backend = TableBackend::in_memory();
        let schema = test_schema();

        // An empty backend still yields Some(zero-row batch), never None.
        let empty = backend.to_record_batch(&schema).unwrap().unwrap();
        assert_eq!(empty.num_rows(), 0);

        backend.put("1", make_batch(1, "A", 1.0)).unwrap();
        backend.put("2", make_batch(2, "B", 2.0)).unwrap();
        let combined = backend.to_record_batch(&schema).unwrap().unwrap();
        assert_eq!(combined.num_rows(), 2);
    }
}