use arrow_array::{Array, BooleanArray, Float64Array};
use arrow_schema::DataType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::logical_expr::{ColumnarValue, Volatility};
use datafusion::scalar::ScalarValue;
use crate::errors::FnError;
use crate::traits::scalar::ArgType;
/// The semiring an aggregate folds values with.
///
/// `AddMult` is the default. NOTE(review): the variant names suggest a sum-product semiring
/// (`AddMult`) versus a max-min semiring (`MaxMin`); confirm against the fold implementations.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum FoldSemiring {
    /// The default semiring.
    #[default]
    AddMult,
    /// Alternative semiring; see the note on the enum.
    MaxMin,
}

/// Options passed to `LocyAggState::ingest_indices`.
///
/// The default is non-strict, with `epsilon == 0.0` and [`FoldSemiring::AddMult`].
// `Default` is derived rather than hand-written: the per-field defaults of `bool`, `f64` and
// `FoldSemiring` are exactly `false`, `0.0` and `AddMult`.
#[derive(Clone, Copy, Debug, PartialEq, Default)]
pub struct FoldContext {
    /// Strict-mode flag. NOTE(review): its exact effect is aggregate-specific; confirm.
    pub strict: bool,
    /// Tolerance value. NOTE(review): meaning and units are aggregate-specific; confirm.
    pub epsilon: f64,
    /// Semiring to fold with.
    pub semiring: FoldSemiring,
}
/// Describes an aggregate function and creates accumulator states for it.
///
/// Only [`semilattice`](Self::semilattice), [`output_type`](Self::output_type) and
/// [`create`](Self::create) must be implemented; every other method has a default.
pub trait LocyAggregate: Send + Sync + std::fmt::Debug {
    /// Algebraic properties of this aggregate's combining operation.
    fn semilattice(&self) -> Semilattice;

    /// Arrow type of the aggregate's result.
    fn output_type(&self) -> DataType;

    /// Arrow type of the result for a given input type.
    ///
    /// The default ignores the input type and returns [`output_type`](Self::output_type).
    fn output_type_for_input(&self, _input: &DataType) -> DataType {
        Self::output_type(self)
    }

    /// Creates a fresh boxed accumulator state.
    fn create(&self) -> Box<dyn LocyAggState>;

    /// Starting accumulator for the row-level [`update_step`](Self::update_step) path.
    ///
    /// The default returns `None`. NOTE(review): presumably `None` means the aggregate has no
    /// row-level path; confirm against callers.
    fn initial_accum_f64(&self) -> Option<f64> {
        Option::None
    }

    /// Row-level step: given the current accumulator and one value, returns the next
    /// accumulator.
    ///
    /// # Errors
    ///
    /// The default always fails with `FnError::CODE_UNKNOWN_FUNCTION` and a message pointing
    /// callers to `ingest()`. Aggregates that support a row-level step override it.
    fn update_step(&self, _accum: f64, _val: f64, _strict: bool) -> Result<f64, FnError> {
        const NO_ROW_STEP: &str = "aggregate has no row-level update_step; use ingest()";
        Err(FnError::new(FnError::CODE_UNKNOWN_FUNCTION, NO_ROW_STEP))
    }

    /// Whether this aggregate is a probability aggregate. Defaults to `false`.
    fn is_probability_aggregate(&self) -> bool {
        false
    }

    /// Whether this aggregate is a noisy-OR. Defaults to `false`.
    fn is_noisy_or(&self) -> bool {
        false
    }
}
/// Mutable accumulator returned by `LocyAggregate::create`.
///
/// A state ingests Arrow columns, can be merged with another state, and is finalized into a
/// [`ScalarValue`].
pub trait LocyAggState: Send + 'static {
    /// Returns `self` as [`std::any::Any`] so callers can downcast to the concrete state.
    ///
    /// NOTE(review): presumably used by `merge` implementations to downcast `other`; confirm.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Folds the rows of `col` selected by `indices` into this state, using the options in
    /// `cx`.
    fn ingest_indices(
        &mut self,
        col: &dyn Array,
        indices: &[usize],
        cx: &FoldContext,
    ) -> Result<(), FnError>;

    /// Folds every row of column `value_col` of `batch` into this state.
    ///
    /// The default forwards to [`ingest_indices`](Self::ingest_indices) with the indices
    /// `0..batch.num_rows()` and `FoldContext::default()`.
    /// NOTE(review): `batch.column` indexes directly, so an out-of-range `value_col`
    /// presumably panics rather than returning an error; confirm.
    fn ingest(&mut self, batch: &RecordBatch, value_col: usize) -> Result<(), FnError> {
        let every_row = (0..batch.num_rows()).collect::<Vec<usize>>();
        let column = batch.column(value_col);
        let default_cx = FoldContext::default();
        self.ingest_indices(column, &every_row, &default_cx)
    }

    /// Merges another state into this one.
    fn merge(&mut self, other: &dyn LocyAggState) -> Result<(), FnError>;

    /// Produces the aggregate's result from the current state.
    fn finalize(&self) -> Result<ScalarValue, FnError>;

    /// Whether the state has reached a top element. The default always returns `false`.
    ///
    /// NOTE(review): presumably meaningful only when the aggregate's `Semilattice::has_top` is
    /// set; confirm.
    fn is_at_top(&self) -> bool {
        false
    }
}
/// Algebraic properties of an aggregate's combining operation.
///
/// Each property is a plain `bool`. The associated constants are ready-made presets.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Semilattice {
    /// Combining a value with itself leaves it unchanged.
    pub idempotent: bool,
    /// Operand order does not affect the result.
    pub commutative: bool,
    /// Grouping of operands does not affect the result.
    pub associative: bool,
    /// The join is monotone. NOTE(review): presumably gates use in fixpoint evaluation; confirm.
    pub monotone_join: bool,
    /// A top element exists. NOTE(review): presumably pairs with `LocyAggState::is_at_top`.
    pub has_top: bool,
}

impl Semilattice {
    /// Commutative and associative, but not idempotent, not monotone, and without a top.
    pub const NON_MONOTONE: Self = Self {
        commutative: true,
        associative: true,
        idempotent: false,
        monotone_join: false,
        has_top: false,
    };

    /// Every property holds: idempotent, commutative, associative, monotone, with a top.
    pub const BOUNDED_MIN_MAX: Self = Self {
        commutative: true,
        associative: true,
        idempotent: true,
        monotone_join: true,
        has_top: true,
    };

    /// Commutative, associative and monotone, but not idempotent and without a top.
    pub const COUNT: Self = Self {
        commutative: true,
        associative: true,
        idempotent: false,
        monotone_join: true,
        has_top: false,
    };
}
/// A predicate evaluated over columnar arguments, with optional fuzzy evaluation.
pub trait LocyPredicate: Send + Sync {
    /// Argument types, volatility and hints describing this predicate.
    fn signature(&self) -> &PredSignature;

    /// Evaluates the predicate over `args` for `rows` rows.
    ///
    /// NOTE(review): the returned [`BooleanArray`] presumably has length `rows`; confirm.
    fn evaluate(&self, args: &[ColumnarValue], rows: usize) -> Result<BooleanArray, FnError>;

    /// Fuzzy evaluation producing a [`Float64Array`], or `None` when unsupported.
    ///
    /// The default returns `None`. NOTE(review): implementations overriding this should
    /// presumably also set `PredSignature::supports_fuzzy`; confirm.
    fn evaluate_fuzzy(
        &self,
        _: &[ColumnarValue],
        _: usize,
    ) -> Option<Result<Float64Array, FnError>> {
        Option::None
    }
}
/// Description of a [`LocyPredicate`], as returned by `LocyPredicate::signature`.
#[derive(Clone, Debug)]
pub struct PredSignature {
    /// Expected argument types (presumably positional; confirm).
    pub args: Vec<ArgType>,
    /// DataFusion volatility of the predicate.
    pub volatility: Volatility,
    /// Whether the predicate supports fuzzy evaluation.
    /// NOTE(review): presumably `true` exactly when `evaluate_fuzzy` returns `Some`; confirm.
    pub supports_fuzzy: bool,
    /// Batch hint for this predicate; see [`BatchHint`].
    pub batch_hint: BatchHint,
}
/// Coarse hint carried in `PredSignature::batch_hint`; defaults to `Medium`.
///
/// NOTE(review): presumably describes the batch sizes the predicate is suited to; confirm.
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum BatchHint {
    /// Small batches.
    Small,
    /// Medium batches (the default).
    #[default]
    Medium,
    /// Large batches.
    Large,
}
/// Placeholder type; no methods are defined for it here.
///
/// Its only field is a private unit value, so code outside this module obtains one via
/// `Default::default()`.
#[derive(Clone, Debug, Default)]
pub struct DerivationTracker {
    // Private unit field: prevents struct-literal construction outside this module.
    _placeholder: (),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal aggregate that relies on every default method of [`LocyAggregate`].
    #[derive(Debug)]
    struct DefaultsOnly;

    impl LocyAggregate for DefaultsOnly {
        fn semilattice(&self) -> Semilattice {
            Semilattice::COUNT
        }

        fn output_type(&self) -> DataType {
            DataType::Float64
        }

        fn create(&self) -> Box<dyn LocyAggState> {
            unimplemented!("not exercised by these tests")
        }
    }

    /// Compile-time checks of every flag in each `Semilattice` preset.
    #[test]
    fn semilattice_constants() {
        const {
            assert!(Semilattice::BOUNDED_MIN_MAX.idempotent);
            assert!(Semilattice::BOUNDED_MIN_MAX.commutative);
            assert!(Semilattice::BOUNDED_MIN_MAX.associative);
            assert!(Semilattice::BOUNDED_MIN_MAX.monotone_join);
            assert!(Semilattice::BOUNDED_MIN_MAX.has_top);

            assert!(!Semilattice::NON_MONOTONE.idempotent);
            assert!(Semilattice::NON_MONOTONE.commutative);
            assert!(Semilattice::NON_MONOTONE.associative);
            assert!(!Semilattice::NON_MONOTONE.monotone_join);
            assert!(!Semilattice::NON_MONOTONE.has_top);

            assert!(!Semilattice::COUNT.idempotent);
            assert!(Semilattice::COUNT.commutative);
            assert!(Semilattice::COUNT.associative);
            assert!(Semilattice::COUNT.monotone_join);
            assert!(!Semilattice::COUNT.has_top);
        }
    }

    /// `Default` values of the option/hint types.
    #[test]
    fn default_values() {
        assert_eq!(FoldSemiring::default(), FoldSemiring::AddMult);
        assert_eq!(BatchHint::default(), BatchHint::Medium);

        let cx = FoldContext::default();
        assert!(!cx.strict);
        // Bit comparison pins exact positive zero and sidesteps clippy::float_cmp.
        assert_eq!(cx.epsilon.to_bits(), 0.0_f64.to_bits());
        assert_eq!(cx.semiring, FoldSemiring::AddMult);
    }

    /// Default methods of `LocyAggregate` behave as documented.
    #[test]
    fn aggregate_trait_defaults() {
        let agg = DefaultsOnly;
        assert_eq!(agg.output_type_for_input(&DataType::Int64), DataType::Float64);
        assert_eq!(agg.initial_accum_f64(), None);
        assert!(agg.update_step(0.0, 1.0, false).is_err());
        assert!(!agg.is_probability_aggregate());
        assert!(!agg.is_noisy_or());
    }
}