use std::marker::PhantomData;
use std::ops::{Add, Div, Mul, Rem, Sub};
use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer};
/// Chunking contract reported by a kernel through its `contract()` method.
///
/// NOTE(review): the precise scheduling semantics of each variant are not
/// defined in this file; the per-variant notes below are inferred from the
/// names and should be confirmed against the code that consumes them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkContract {
    /// Presumably: each output element depends only on the input elements at
    /// the same position, so chunk boundaries may fall anywhere.
    ElementIndependent,
    /// Presumably: the kernel expects chunks of the given element count.
    FixedSize(usize),
    /// Presumably: results depend on where chunk boundaries fall.
    BoundaryDependent,
}
/// A computation over one chunk of native-typed input slices.
///
/// Implementors must be `Send + Sync`.
pub trait Kernel<T>: Send + Sync
where
    T: ArrowNativeType,
{
    /// Runs the kernel over `inputs` (one slice per operand) and returns a
    /// freshly allocated output vector.
    ///
    /// `chunk_idx` identifies the chunk being processed; the kernels defined
    /// alongside this trait ignore it.
    fn execute_chunk(&self, inputs: &[&[T]], chunk_idx: usize) -> Vec<T>;

    /// Reports the [`ChunkContract`] this kernel declares.
    fn contract(&self) -> ChunkContract;
}
/// Kernel wrapper around a two-argument operation `op`.
pub struct BinaryKernel<T, F> {
    /// The operation applied to each pair of input elements.
    op: F,
    /// Ties the otherwise unused type parameter `T` to the struct without
    /// storing a `T`; the marker mirrors the `(T, T) -> T` shape of `op`.
    _phantom: PhantomData<fn(T, T) -> T>,
}
impl<T: ArrowNativeType, F: Fn(T, T) -> T + Send + Sync> BinaryKernel<T, F> {
pub fn new(op: F) -> Self {
Self { op, _phantom: PhantomData }
}
}
impl<T, F> Kernel<T> for BinaryKernel<T, F>
where
    T: ArrowNativeType,
    F: Fn(T, T) -> T + Send + Sync,
{
    /// Applies the wrapped operation to `inputs[0][i]` and `inputs[1][i]` for
    /// every position `i` and collects the results in order.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` does not hold exactly two slices, or if the two
    /// slices differ in length.
    fn execute_chunk(&self, inputs: &[&[T]], _chunk_idx: usize) -> Vec<T> {
        let n_inputs = inputs.len();
        assert_eq!(n_inputs, 2, "BinaryKernel requires exactly 2 inputs, got {}", n_inputs);

        let (lhs, rhs) = (inputs[0], inputs[1]);
        assert_eq!(rhs.len(), lhs.len(), "BinaryKernel: input length mismatch");

        // Lengths were checked above, so zipping visits every position once.
        lhs.iter()
            .zip(rhs)
            .map(|(&a, &b)| (self.op)(a, b))
            .collect()
    }

    /// Always reports [`ChunkContract::ElementIndependent`].
    fn contract(&self) -> ChunkContract {
        ChunkContract::ElementIndependent
    }
}
/// Kernel wrapper around a one-argument operation `op`.
pub struct UnaryKernel<T, F> {
    /// The operation applied to each input element.
    op: F,
    /// Ties the otherwise unused type parameter `T` to the struct without
    /// storing a `T`; the marker mirrors the `T -> T` shape of `op`.
    _phantom: PhantomData<fn(T) -> T>,
}
impl<T: ArrowNativeType, F: Fn(T) -> T + Send + Sync> UnaryKernel<T, F> {
pub fn new(op: F) -> Self {
Self { op, _phantom: PhantomData }
}
}
impl<T, F> Kernel<T> for UnaryKernel<T, F>
where
    T: ArrowNativeType,
    F: Fn(T) -> T + Send + Sync,
{
    /// Applies the wrapped operation to every element of `inputs[0]` and
    /// returns the results in the same order.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` does not hold exactly one slice.
    fn execute_chunk(&self, inputs: &[&[T]], _chunk_idx: usize) -> Vec<T> {
        let n_inputs = inputs.len();
        assert_eq!(n_inputs, 1, "UnaryKernel requires exactly 1 input, got {}", n_inputs);

        let source = inputs[0];
        let mut output = Vec::with_capacity(source.len());
        output.extend(source.iter().map(|&value| (self.op)(value)));
        output
    }

    /// Always reports [`ChunkContract::ElementIndependent`].
    fn contract(&self) -> ChunkContract {
        ChunkContract::ElementIndependent
    }
}
/// Stateless kernel type parameterised by the element type `T`.
///
/// The only field is a zero-sized marker, so a value carries no data.
#[derive(Debug, Clone, Default)]
pub struct CondKernel<T>(
    /// Marker tying `T` to the type.
    PhantomData<T>,
);
impl<T: ArrowNativeType> CondKernel<T> {
pub fn new() -> Self {
Self(PhantomData)
}
}
impl<T: ArrowNativeType> Kernel<T> for CondKernel<T> {
    /// Element-wise select over three inputs: `inputs[0]` is the mask,
    /// `inputs[1]` the "then" values and `inputs[2]` the "else" values.
    ///
    /// For each position the "then" value is taken when the mask element
    /// compares unequal to `T::default()`, otherwise the "else" value.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` does not hold exactly three slices, or if the
    /// "then" or "else" slice differs in length from the mask.
    fn execute_chunk(&self, inputs: &[&[T]], _chunk_idx: usize) -> Vec<T> {
        assert_eq!(
            inputs.len(), 3,
            "CondKernel requires exactly 3 inputs (mask, then, else), got {}",
            inputs.len(),
        );
        let (mask, then_vals, else_vals) = (inputs[0], inputs[1], inputs[2]);

        // Validate lengths up front, as BinaryKernel does. Without this a
        // short branch slice fails with an opaque index-out-of-bounds panic
        // and a long one is silently truncated to the mask length.
        assert_eq!(then_vals.len(), mask.len(), "CondKernel: input length mismatch");
        assert_eq!(else_vals.len(), mask.len(), "CondKernel: input length mismatch");

        let zero = T::default();
        mask.iter()
            .zip(then_vals.iter().zip(else_vals))
            .map(|(&flag, (&on_true, &on_false))| if flag != zero { on_true } else { on_false })
            .collect()
    }

    /// Always reports [`ChunkContract::ElementIndependent`].
    fn contract(&self) -> ChunkContract {
        ChunkContract::ElementIndependent
    }
}
/// Expands a validity buffer into a dense `f64` mask.
///
/// The result has one entry per slot of `validity`: `one` where the slot is
/// null and `zero` where it is valid.
pub fn is_null_mask(validity: &NullBuffer, one: f64, zero: f64) -> Vec<f64> {
    let slots = validity.len();
    let mut mask = Vec::with_capacity(slots);
    for idx in 0..slots {
        let flag = if validity.is_valid(idx) { zero } else { one };
        mask.push(flag);
    }
    mask
}
/// Returns a binary kernel computing `a + b` for each pair of elements.
pub fn add<T>() -> impl Kernel<T>
where
    T: ArrowNativeType + Add<Output = T>,
{
    let sum = |lhs: T, rhs: T| -> T { lhs + rhs };
    BinaryKernel::new(sum)
}
/// Returns a binary kernel computing `a - b` for each pair of elements.
pub fn sub<T>() -> impl Kernel<T>
where
    T: ArrowNativeType + Sub<Output = T>,
{
    let difference = |lhs: T, rhs: T| -> T { lhs - rhs };
    BinaryKernel::new(difference)
}
/// Returns a binary kernel computing `a * b` for each pair of elements.
pub fn mul<T>() -> impl Kernel<T>
where
    T: ArrowNativeType + Mul<Output = T>,
{
    let product = |lhs: T, rhs: T| -> T { lhs * rhs };
    BinaryKernel::new(product)
}
/// Returns a binary kernel computing `a / b` for each pair of elements.
///
/// No divisor check is performed here: a zero divisor behaves however `T`'s
/// `Div` implementation behaves (Rust's primitive integer types panic).
pub fn div<T>() -> impl Kernel<T>
where
    T: ArrowNativeType + Div<Output = T>,
{
    let quotient = |lhs: T, rhs: T| -> T { lhs / rhs };
    BinaryKernel::new(quotient)
}
/// Returns a binary kernel computing `a % b` for each pair of elements.
///
/// No divisor check is performed here: a zero divisor behaves however `T`'s
/// `Rem` implementation behaves (Rust's primitive integer types panic).
pub fn rem<T>() -> impl Kernel<T>
where
    T: ArrowNativeType + Rem<Output = T>,
{
    let remainder = |lhs: T, rhs: T| -> T { lhs % rhs };
    BinaryKernel::new(remainder)
}
pub fn cond<T: ArrowNativeType>() -> impl Kernel<T> {
CondKernel::new()
}
/// Combines the validity bitmaps of several inputs by bitwise AND, so a slot
/// is valid in the result only if it is valid in every bitmap supplied.
///
/// `None` entries contribute nothing. Returns `None` only when no input
/// carries a bitmap (including an empty `inputs`); otherwise returns
/// `Some`, even if the combined bitmap contains no nulls.
///
/// NOTE(review): all supplied buffers are presumably expected to have the
/// same length; what `&` does on mismatched `BooleanBuffer`s should be
/// confirmed against arrow-buffer.
pub fn propagate_nulls(inputs: &[Option<&NullBuffer>]) -> Option<NullBuffer> {
    // Only the inputs that actually carry a validity bitmap take part.
    let mut bitmaps = inputs.iter().copied().flatten().map(NullBuffer::inner);

    // No bitmap at all means there is nothing to propagate.
    let first = bitmaps.next()?;

    let mut merged = first.clone();
    for bitmap in bitmaps {
        merged = &merged & bitmap;
    }
    Some(NullBuffer::new(merged))
}
#[cfg(test)]
mod tests {
    use super::*;

    // ChunkContract

    #[test]
    fn chunk_contract_variants_eq() {
        // Unit variants equal themselves; `FixedSize` also compares its payload.
        assert_eq!(ChunkContract::ElementIndependent, ChunkContract::ElementIndependent);
        assert_eq!(ChunkContract::FixedSize(4), ChunkContract::FixedSize(4));
        assert_ne!(ChunkContract::FixedSize(4), ChunkContract::FixedSize(8));
        assert_eq!(ChunkContract::BoundaryDependent, ChunkContract::BoundaryDependent);
    }

    // BinaryKernel

    #[test]
    fn binary_kernel_add_f64() {
        let kernel = BinaryKernel::new(|a: f64, b: f64| a + b);
        let lhs = [1.0_f64, 2.0, 3.0];
        let rhs = [10.0_f64, 20.0, 30.0];

        let out = kernel.execute_chunk(&[&lhs, &rhs], 0);

        assert_eq!(out, [11.0, 22.0, 33.0]);
        assert_eq!(kernel.contract(), ChunkContract::ElementIndependent);
    }

    #[test]
    fn binary_kernel_mul_i64() {
        let kernel = BinaryKernel::new(|a: i64, b: i64| a * b);
        let lhs = [2_i64, 3, 4];
        let rhs = [5_i64, 6, 7];

        let out = kernel.execute_chunk(&[&lhs, &rhs], 0);

        assert_eq!(out, [10, 18, 28]);
    }

    #[test]
    #[should_panic(expected = "requires exactly 2 inputs")]
    fn binary_kernel_panics_wrong_input_count() {
        let kernel = BinaryKernel::new(|a: f64, _b: f64| a);
        let only = [1.0_f64];
        // A single operand must be rejected before any element is touched.
        kernel.execute_chunk(&[&only], 0);
    }

    // UnaryKernel

    #[test]
    fn unary_kernel_double() {
        let kernel = UnaryKernel::new(|x: f64| x * 2.0);
        let values = [1.0_f64, 2.0, 3.0];

        let out = kernel.execute_chunk(&[&values], 0);

        assert_eq!(out, [2.0, 4.0, 6.0]);
        assert_eq!(kernel.contract(), ChunkContract::ElementIndependent);
    }

    #[test]
    fn unary_kernel_negate_i32() {
        let kernel = UnaryKernel::new(|x: i32| -x);
        let values = [1_i32, -2, 3];

        let out = kernel.execute_chunk(&[&values], 0);

        assert_eq!(out, [-1, 2, -3]);
    }

    // CondKernel

    #[test]
    fn cond_kernel_selects_correctly() {
        let kernel = CondKernel::<f64>::new();
        let mask = [1.0_f64, 0.0, 1.0, 0.0];
        let then = [10.0_f64, 20.0, 30.0, 40.0];
        let else_ = [100.0_f64, 200.0, 300.0, 400.0];

        let out = kernel.execute_chunk(&[&mask, &then, &else_], 0);

        assert_eq!(out, [10.0, 200.0, 30.0, 400.0]);
    }

    #[test]
    fn cond_kernel_all_true() {
        let kernel = CondKernel::<i32>::new();
        // Any non-zero mask value selects the `then` branch.
        let mask = [1_i32, 2, 3];
        let then = [10_i32, 20, 30];
        let else_ = [0_i32, 0, 0];

        let out = kernel.execute_chunk(&[&mask, &then, &else_], 0);

        assert_eq!(out, [10, 20, 30]);
    }

    #[test]
    #[should_panic(expected = "requires exactly 3 inputs")]
    fn cond_kernel_panics_wrong_input_count() {
        let kernel = CondKernel::<f64>::new();
        let first = [1.0_f64];
        let second = [2.0_f64];
        kernel.execute_chunk(&[&first, &second], 0);
    }

    // Factory functions

    #[test]
    fn factory_add_kernel() {
        let kernel = add::<f64>();
        let lhs = [3.0_f64, 4.0];
        let rhs = [1.0_f64, 2.0];
        assert_eq!(kernel.execute_chunk(&[&lhs, &rhs], 0), [4.0, 6.0]);
    }

    #[test]
    fn factory_sub_kernel() {
        let kernel = sub::<f64>();
        let lhs = [10.0_f64];
        let rhs = [3.0_f64];
        assert_eq!(kernel.execute_chunk(&[&lhs, &rhs], 0), [7.0]);
    }

    #[test]
    fn factory_mul_kernel() {
        let kernel = mul::<i64>();
        let lhs = [6_i64];
        let rhs = [7_i64];
        assert_eq!(kernel.execute_chunk(&[&lhs, &rhs], 0), [42]);
    }

    #[test]
    fn factory_div_kernel() {
        let kernel = div::<f64>();
        let lhs = [10.0_f64];
        let rhs = [4.0_f64];
        assert_eq!(kernel.execute_chunk(&[&lhs, &rhs], 0), [2.5]);
    }

    #[test]
    fn factory_rem_kernel() {
        let kernel = rem::<i32>();
        let lhs = [10_i32];
        let rhs = [3_i32];
        assert_eq!(kernel.execute_chunk(&[&lhs, &rhs], 0), [1]);
    }

    #[test]
    fn factory_cond_kernel() {
        let kernel = cond::<f64>();
        let mask = [0.0_f64, 1.0];
        let then = [99.0_f64, 99.0];
        let else_ = [0.0_f64, 0.0];
        assert_eq!(kernel.execute_chunk(&[&mask, &then, &else_], 0), [0.0, 99.0]);
    }

    // Null handling helpers

    #[test]
    fn is_null_mask_marks_nulls() {
        // `false` bits in the validity buffer are nulls.
        let validity = NullBuffer::new(BooleanBuffer::from(vec![true, false, true, false]));
        let mask = is_null_mask(&validity, 1.0, 0.0);
        assert_eq!(mask, [0.0, 1.0, 0.0, 1.0]);
    }

    #[test]
    fn propagate_nulls_all_valid_returns_none() {
        assert!(propagate_nulls(&[None, None, None]).is_none());
    }

    #[test]
    fn propagate_nulls_empty_returns_none() {
        assert!(propagate_nulls(&[]).is_none());
    }

    #[test]
    fn propagate_nulls_single_buffer() {
        let only = NullBuffer::new(BooleanBuffer::from(vec![false, true, true]));

        let merged = propagate_nulls(&[None, Some(&only)]).unwrap();

        assert!(merged.is_null(0));
        assert!(merged.is_valid(1));
        assert!(merged.is_valid(2));
        assert_eq!(merged.null_count(), 1);
    }

    #[test]
    fn propagate_nulls_ands_two_bitmaps() {
        let left = NullBuffer::new(BooleanBuffer::from(vec![true, false, true, true]));
        let right = NullBuffer::new(BooleanBuffer::from(vec![true, true, false, true]));

        let merged = propagate_nulls(&[Some(&left), Some(&right)]).unwrap();

        assert!(merged.is_valid(0));
        assert!(merged.is_null(1));
        assert!(merged.is_null(2));
        assert!(merged.is_valid(3));
        assert_eq!(merged.null_count(), 2);
    }

    #[test]
    fn propagate_nulls_three_inputs_mixed() {
        let first = NullBuffer::new(BooleanBuffer::from(vec![true, false, true]));
        let third = NullBuffer::new(BooleanBuffer::from(vec![true, true, false]));

        // The `None` in the middle must simply be skipped.
        let merged = propagate_nulls(&[Some(&first), None, Some(&third)]).unwrap();

        assert!(merged.is_valid(0));
        assert!(merged.is_null(1));
        assert!(merged.is_null(2));
    }
}