use crate::error::{ClientError, Result};
use byteorder::{ByteOrder, LittleEndian};
use std::marker::PhantomData;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FieldType {
Int64,
UInt64,
Float64,
String,
Bytes,
Bool,
}
#[derive(Debug, Clone)]
pub struct ColumnRef {
pub id: usize,
pub name: String,
pub field_type: FieldType,
}
#[derive(Debug, Clone)]
pub struct ArraySchema {
pub name: String,
pub fields: Vec<String>,
pub types: Vec<FieldType>,
}
#[derive(Debug)]
pub struct TypedColumn<'a, T> {
data: &'a [u8],
count: usize,
_marker: PhantomData<T>,
}
impl<'a> TypedColumn<'a, i64> {
pub fn from_bytes(data: &'a [u8], count: usize) -> Self {
Self {
data,
count,
_marker: PhantomData,
}
}
#[inline]
pub fn get(&self, index: usize) -> Option<i64> {
if index >= self.count {
return None;
}
let offset = index * 8;
if offset + 8 > self.data.len() {
return None;
}
Some(LittleEndian::read_i64(&self.data[offset..offset + 8]))
}
pub fn to_vec(&self) -> Vec<i64> {
(0..self.count).filter_map(|i| self.get(i)).collect()
}
pub fn sum(&self) -> i64 {
(0..self.count).filter_map(|i| self.get(i)).sum()
}
pub fn min(&self) -> Option<i64> {
(0..self.count).filter_map(|i| self.get(i)).min()
}
pub fn max(&self) -> Option<i64> {
(0..self.count).filter_map(|i| self.get(i)).max()
}
pub fn avg(&self) -> Option<f64> {
if self.count == 0 {
return None;
}
let sum: i64 = self.sum();
Some(sum as f64 / self.count as f64)
}
pub fn len(&self) -> usize {
self.count
}
pub fn is_empty(&self) -> bool {
self.count == 0
}
pub fn iter(&self) -> impl Iterator<Item = i64> + '_ {
(0..self.count).filter_map(move |i| self.get(i))
}
}
impl<'a> TypedColumn<'a, f64> {
pub fn from_bytes(data: &'a [u8], count: usize) -> Self {
Self {
data,
count,
_marker: PhantomData,
}
}
#[inline]
pub fn get(&self, index: usize) -> Option<f64> {
if index >= self.count {
return None;
}
let offset = index * 8;
if offset + 8 > self.data.len() {
return None;
}
Some(LittleEndian::read_f64(&self.data[offset..offset + 8]))
}
pub fn to_vec(&self) -> Vec<f64> {
(0..self.count).filter_map(|i| self.get(i)).collect()
}
pub fn sum(&self) -> f64 {
(0..self.count).filter_map(|i| self.get(i)).sum()
}
pub fn min(&self) -> Option<f64> {
(0..self.count)
.filter_map(|i| self.get(i))
.fold(None, |acc, x| {
Some(match acc {
None => x,
Some(min) => {
if x < min {
x
} else {
min
}
}
})
})
}
pub fn max(&self) -> Option<f64> {
(0..self.count)
.filter_map(|i| self.get(i))
.fold(None, |acc, x| {
Some(match acc {
None => x,
Some(max) => {
if x > max {
x
} else {
max
}
}
})
})
}
pub fn avg(&self) -> Option<f64> {
if self.count == 0 {
return None;
}
let sum: f64 = self.sum();
Some(sum / self.count as f64)
}
pub fn std_dev(&self) -> Option<f64> {
let avg = self.avg()?;
let variance: f64 = (0..self.count)
.filter_map(|i| self.get(i))
.map(|x| (x - avg).powi(2))
.sum::<f64>()
/ self.count as f64;
Some(variance.sqrt())
}
pub fn len(&self) -> usize {
self.count
}
pub fn is_empty(&self) -> bool {
self.count == 0
}
pub fn iter(&self) -> impl Iterator<Item = f64> + '_ {
(0..self.count).filter_map(move |i| self.get(i))
}
}
pub struct ColumnView<'a> {
schema: &'a ArraySchema,
columns: &'a [ColumnRef],
_data: PhantomData<&'a [u8]>,
}
impl<'a> ColumnView<'a> {
pub fn new(schema: &'a ArraySchema, columns: &'a [ColumnRef]) -> Self {
Self {
schema,
columns,
_data: PhantomData,
}
}
#[allow(dead_code)]
pub fn schema_name(&self) -> &str {
&self.schema.name
}
pub fn column_count(&self) -> usize {
self.columns.len()
}
#[allow(dead_code)]
pub fn column_names(&self) -> Vec<&str> {
self.columns.iter().map(|c| c.name.as_str()).collect()
}
pub fn get_column(&self, name: &str) -> Option<&ColumnRef> {
self.columns.iter().find(|c| c.name == name)
}
pub fn simd_groups(&self) -> Vec<Vec<&str>> {
use std::collections::HashMap;
let mut groups: HashMap<FieldType, Vec<&str>> = HashMap::new();
for col in self.columns {
groups.entry(col.field_type).or_default().push(&col.name);
}
groups.into_values().collect()
}
pub fn column_i64(
&self,
name: &str,
data: &'a [u8],
count: usize,
) -> Result<TypedColumn<'a, i64>> {
let col = self
.get_column(name)
.ok_or_else(|| ClientError::NotFound(format!("Column '{}' not found", name)))?;
if col.field_type != FieldType::Int64 && col.field_type != FieldType::UInt64 {
return Err(ClientError::TypeMismatch {
expected: "Int64".to_string(),
actual: format!("{:?}", col.field_type),
});
}
Ok(TypedColumn::<i64>::from_bytes(data, count))
}
pub fn column_f64(
&self,
name: &str,
data: &'a [u8],
count: usize,
) -> Result<TypedColumn<'a, f64>> {
let col = self
.get_column(name)
.ok_or_else(|| ClientError::NotFound(format!("Column '{}' not found", name)))?;
if col.field_type != FieldType::Float64 {
return Err(ClientError::TypeMismatch {
expected: "Float64".to_string(),
actual: format!("{:?}", col.field_type),
});
}
Ok(TypedColumn::<f64>::from_bytes(data, count))
}
}
pub trait ColumnAccess {
fn row_count(&self) -> usize;
fn col_count(&self) -> usize;
fn field_names(&self) -> Vec<&str>;
}
#[cfg(test)]
mod tests {
use super::*;
use byteorder::WriteBytesExt;
fn i64_to_bytes(values: &[i64]) -> Vec<u8> {
let mut bytes = Vec::with_capacity(values.len() * 8);
for &v in values {
bytes.write_i64::<LittleEndian>(v).unwrap();
}
bytes
}
fn f64_to_bytes(values: &[f64]) -> Vec<u8> {
let mut bytes = Vec::with_capacity(values.len() * 8);
for &v in values {
bytes.write_f64::<LittleEndian>(v).unwrap();
}
bytes
}
#[test]
fn test_typed_column_i64() {
let data: Vec<i64> = vec![1, 2, 3, 4, 5];
let bytes = i64_to_bytes(&data);
let col = TypedColumn::<i64>::from_bytes(&bytes, 5);
assert_eq!(col.sum(), 15);
assert_eq!(col.min(), Some(1));
assert_eq!(col.max(), Some(5));
assert!((col.avg().unwrap() - 3.0).abs() < 0.001);
}
#[test]
fn test_typed_column_f64() {
let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let bytes = f64_to_bytes(&data);
let col = TypedColumn::<f64>::from_bytes(&bytes, 5);
assert!((col.sum() - 15.0).abs() < 0.001);
assert!((col.avg().unwrap() - 3.0).abs() < 0.001);
assert!(col.std_dev().is_some());
}
#[test]
fn test_safe_column_access_i64() {
let values = vec![100i64, 200, 300, -400, 500];
let bytes = i64_to_bytes(&values);
let col = TypedColumn::<i64>::from_bytes(&bytes, 5);
assert_eq!(col.get(0), Some(100));
assert_eq!(col.get(1), Some(200));
assert_eq!(col.get(2), Some(300));
assert_eq!(col.get(3), Some(-400));
assert_eq!(col.get(4), Some(500));
assert_eq!(col.get(5), None);
assert_eq!(col.to_vec(), values);
let collected: Vec<i64> = col.iter().collect();
assert_eq!(collected, values);
}
#[test]
fn test_safe_column_access_f64() {
let values = vec![1.5f64, 2.5, 3.5, -4.5, 5.5];
let bytes = f64_to_bytes(&values);
let col = TypedColumn::<f64>::from_bytes(&bytes, 5);
assert!((col.get(0).unwrap() - 1.5).abs() < 0.001);
assert!((col.get(1).unwrap() - 2.5).abs() < 0.001);
assert!((col.get(2).unwrap() - 3.5).abs() < 0.001);
assert!((col.get(3).unwrap() - (-4.5)).abs() < 0.001);
assert!((col.get(4).unwrap() - 5.5).abs() < 0.001);
assert_eq!(col.get(5), None);
for (a, b) in col.to_vec().iter().zip(values.iter()) {
assert!((a - b).abs() < 0.001);
}
}
#[test]
fn test_safe_column_misaligned_data() {
let values = vec![42i64, 84, 126];
let aligned_bytes = i64_to_bytes(&values);
let mut misaligned = vec![0u8];
misaligned.extend_from_slice(&aligned_bytes);
let col = TypedColumn::<i64>::from_bytes(&misaligned[1..], 3);
assert_eq!(col.get(0), Some(42));
assert_eq!(col.get(1), Some(84));
assert_eq!(col.get(2), Some(126));
assert_eq!(col.sum(), 252);
}
#[test]
fn test_empty_column() {
let bytes: Vec<u8> = vec![];
let i64_col = TypedColumn::<i64>::from_bytes(&bytes, 0);
assert!(i64_col.is_empty());
assert_eq!(i64_col.len(), 0);
assert_eq!(i64_col.sum(), 0);
assert_eq!(i64_col.min(), None);
assert_eq!(i64_col.max(), None);
assert_eq!(i64_col.avg(), None);
let f64_col = TypedColumn::<f64>::from_bytes(&bytes, 0);
assert!(f64_col.is_empty());
assert_eq!(f64_col.len(), 0);
assert!((f64_col.sum() - 0.0).abs() < 0.001);
assert_eq!(f64_col.min(), None);
assert_eq!(f64_col.max(), None);
assert_eq!(f64_col.avg(), None);
}
#[test]
fn test_column_view() {
let schema = ArraySchema {
name: "test".to_string(),
fields: vec!["a".to_string(), "b".to_string()],
types: vec![FieldType::Int64, FieldType::Float64],
};
let columns = vec![
ColumnRef {
id: 0,
name: "a".to_string(),
field_type: FieldType::Int64,
},
ColumnRef {
id: 1,
name: "b".to_string(),
field_type: FieldType::Float64,
},
];
let view = ColumnView::new(&schema, &columns);
assert_eq!(view.column_count(), 2);
assert!(view.get_column("a").is_some());
assert!(view.get_column("c").is_none());
}
#[test]
fn test_simd_groups() {
let schema = ArraySchema {
name: "test".to_string(),
fields: vec!["a".to_string(), "b".to_string(), "c".to_string()],
types: vec![FieldType::Int64, FieldType::Int64, FieldType::Float64],
};
let columns = vec![
ColumnRef {
id: 0,
name: "a".to_string(),
field_type: FieldType::Int64,
},
ColumnRef {
id: 1,
name: "b".to_string(),
field_type: FieldType::Int64,
},
ColumnRef {
id: 2,
name: "c".to_string(),
field_type: FieldType::Float64,
},
];
let view = ColumnView::new(&schema, &columns);
let groups = view.simd_groups();
assert_eq!(groups.len(), 2); }
#[test]
fn test_column_view_type_checking() {
let schema = ArraySchema {
name: "test".to_string(),
fields: vec!["int_col".to_string(), "float_col".to_string()],
types: vec![FieldType::Int64, FieldType::Float64],
};
let columns = vec![
ColumnRef {
id: 0,
name: "int_col".to_string(),
field_type: FieldType::Int64,
},
ColumnRef {
id: 1,
name: "float_col".to_string(),
field_type: FieldType::Float64,
},
];
let view = ColumnView::new(&schema, &columns);
let int_bytes = i64_to_bytes(&[1, 2, 3]);
let float_bytes = f64_to_bytes(&[1.0, 2.0, 3.0]);
assert!(view.column_i64("int_col", &int_bytes, 3).is_ok());
assert!(view.column_f64("float_col", &float_bytes, 3).is_ok());
assert!(view.column_f64("int_col", &int_bytes, 3).is_err());
assert!(view.column_i64("float_col", &float_bytes, 3).is_err());
assert!(view.column_i64("nonexistent", &int_bytes, 3).is_err());
}
#[test]
fn test_large_column_performance() {
let size = 10000;
let values: Vec<i64> = (0..size).collect();
let bytes = i64_to_bytes(&values);
let col = TypedColumn::<i64>::from_bytes(&bytes, size as usize);
assert_eq!(col.sum(), 49995000);
assert_eq!(col.min(), Some(0));
assert_eq!(col.max(), Some(9999));
assert_eq!(col.len(), 10000);
}
}