use crate::query::df_graph::common::{
ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
};
use arrow_array::builder::Float64Builder;
use arrow_array::{Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::common::Result as DFResult;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use datafusion::scalar::ScalarValue;
use futures::{Stream, TryStreamExt};
use smol_str::SmolStr;
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
use uni_locy::SemiringKind;
use uni_plugin::traits::locy::{FoldContext, FoldSemiring, LocyAggregate};
use super::locy_explain::ProofTerm;
#[must_use]
pub fn is_monotonic_aggregate(registry: &uni_plugin::PluginRegistry, name: &str) -> Option<bool> {
resolve_locy_aggregate(registry, name).map(|e| e.aggregate.semilattice().monotone_join)
}
#[must_use]
pub fn resolve_locy_aggregate(
registry: &uni_plugin::PluginRegistry,
name: &str,
) -> Option<std::sync::Arc<uni_plugin::registry::LocyAggregateEntry>> {
let canonical = match name.to_uppercase().as_str() {
"MMAX" => "MAX".to_owned(),
"MMIN" => "MIN".to_owned(),
"MCOUNT" => "COUNT".to_owned(),
"NOR" => "MNOR".to_owned(),
"PROD" => "MPROD".to_owned(),
other => other.to_owned(),
};
let qname = uni_plugin::QName::builtin(canonical);
if let Some(session_pr) = crate::current_session_plugin_registry()
&& let Some(entry) = session_pr.locy_aggregate(&qname)
{
return Some(entry);
}
registry.locy_aggregate(&qname)
}
#[must_use]
pub fn default_locy_plugin_registry() -> Arc<uni_plugin::PluginRegistry> {
static REGISTRY: OnceLock<Arc<uni_plugin::PluginRegistry>> = OnceLock::new();
Arc::clone(REGISTRY.get_or_init(|| {
let registry = uni_plugin::PluginRegistry::new();
let plugin_id = uni_plugin::PluginId::new(uni_plugin::QName::BUILTIN_NS);
let caps = uni_plugin::CapabilitySet::from_iter_of([uni_plugin::Capability::LocyAggregate]);
let mut r = uni_plugin::PluginRegistrar::new(plugin_id, &caps, ®istry);
uni_plugin_builtin::locy_aggregates::register_into(&mut r)
.expect("built-in locy aggregates register");
r.commit_to_registry().expect("commit built-in aggregates");
Arc::new(registry)
}))
}
#[derive(Debug, Clone)]
pub struct FoldBinding {
pub output_name: String,
pub name: SmolStr,
pub aggregate: Arc<dyn LocyAggregate>,
pub input_col_index: usize,
pub input_col_name: Option<String>,
}
#[derive(Debug)]
pub struct FoldExec {
input: Arc<dyn ExecutionPlan>,
key_indices: Vec<usize>,
fold_bindings: Vec<FoldBinding>,
strict_probability_domain: bool,
probability_epsilon: f64,
semiring_kind: SemiringKind,
provenance_tracker: Option<Arc<super::locy_explain::ProvenanceStore>>,
top_k_proofs_k: usize,
body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>>,
schema: SchemaRef,
properties: Arc<PlanProperties>,
metrics: ExecutionPlanMetricsSet,
}
impl FoldExec {
pub fn new(
input: Arc<dyn ExecutionPlan>,
key_indices: Vec<usize>,
fold_bindings: Vec<FoldBinding>,
strict_probability_domain: bool,
probability_epsilon: f64,
) -> Self {
Self::new_with_semiring(
input,
key_indices,
fold_bindings,
strict_probability_domain,
probability_epsilon,
SemiringKind::AddMultProb,
)
}
pub fn new_with_semiring(
input: Arc<dyn ExecutionPlan>,
key_indices: Vec<usize>,
fold_bindings: Vec<FoldBinding>,
strict_probability_domain: bool,
probability_epsilon: f64,
semiring_kind: SemiringKind,
) -> Self {
Self::new_with_topk(
input,
key_indices,
fold_bindings,
strict_probability_domain,
probability_epsilon,
semiring_kind,
None,
0,
None,
)
}
#[allow(clippy::too_many_arguments)]
pub fn new_with_topk(
input: Arc<dyn ExecutionPlan>,
key_indices: Vec<usize>,
fold_bindings: Vec<FoldBinding>,
strict_probability_domain: bool,
probability_epsilon: f64,
semiring_kind: SemiringKind,
provenance_tracker: Option<Arc<super::locy_explain::ProvenanceStore>>,
top_k_proofs_k: usize,
body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>>,
) -> Self {
let input_schema = input.schema();
let schema = Self::build_output_schema(&input_schema, &key_indices, &fold_bindings);
let properties = compute_plan_properties(Arc::clone(&schema));
Self {
input,
key_indices,
fold_bindings,
strict_probability_domain,
probability_epsilon,
semiring_kind,
provenance_tracker,
top_k_proofs_k,
body_support_map,
schema,
properties,
metrics: ExecutionPlanMetricsSet::new(),
}
}
fn build_output_schema(
input_schema: &SchemaRef,
key_indices: &[usize],
fold_bindings: &[FoldBinding],
) -> SchemaRef {
let mut fields = Vec::new();
for &ki in key_indices {
fields.push(Arc::new(input_schema.field(ki).clone()));
}
for binding in fold_bindings {
let idx = binding
.input_col_name
.as_ref()
.and_then(|name| input_schema.index_of(name).ok())
.unwrap_or(binding.input_col_index);
let input_type = if idx < input_schema.fields().len() {
input_schema.field(idx).data_type().clone()
} else {
DataType::Float64
};
let output_type = binding.aggregate.output_type_for_input(&input_type);
fields.push(Arc::new(Field::new(
&binding.output_name,
output_type,
true,
)));
}
Arc::new(Schema::new(fields))
}
}
impl DisplayAs for FoldExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"FoldExec: key_indices={:?}, bindings={:?}",
self.key_indices, self.fold_bindings
)
}
}
impl ExecutionPlan for FoldExec {
fn name(&self) -> &str {
"FoldExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn properties(&self) -> &Arc<PlanProperties> {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
if children.len() != 1 {
return Err(datafusion::error::DataFusionError::Plan(
"FoldExec requires exactly one child".to_string(),
));
}
Ok(Arc::new(Self::new_with_topk(
Arc::clone(&children[0]),
self.key_indices.clone(),
self.fold_bindings.clone(),
self.strict_probability_domain,
self.probability_epsilon,
self.semiring_kind,
self.provenance_tracker.as_ref().map(Arc::clone),
self.top_k_proofs_k,
self.body_support_map.as_ref().map(Arc::clone),
)))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> DFResult<SendableRecordBatchStream> {
let input_stream = self.input.execute(partition, Arc::clone(&context))?;
let metrics = BaselineMetrics::new(&self.metrics, partition);
let key_indices = self.key_indices.clone();
let fold_bindings = self.fold_bindings.clone();
let strict = self.strict_probability_domain;
let epsilon = self.probability_epsilon;
let semiring_kind = self.semiring_kind;
let _provenance_tracker = self.provenance_tracker.as_ref().map(Arc::clone);
let top_k_proofs_k = self.top_k_proofs_k;
let body_support_map = self.body_support_map.as_ref().map(Arc::clone);
let output_schema = Arc::clone(&self.schema);
let input_schema = self.input.schema();
let fut = async move {
let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
if batches.is_empty() {
return Ok(RecordBatch::new_empty(output_schema));
}
let actual_schema = batches
.first()
.map(|b| b.schema())
.unwrap_or(input_schema.clone());
let batch =
arrow::compute::concat_batches(&actual_schema, &batches).map_err(arrow_err)?;
if batch.num_rows() == 0 {
return Ok(RecordBatch::new_empty(output_schema));
}
let mut groups: HashMap<Vec<ScalarKey>, Vec<usize>> = HashMap::new();
let mut ordered_keys: Vec<Vec<ScalarKey>> = Vec::new();
for row_idx in 0..batch.num_rows() {
let key = extract_scalar_key(&batch, &key_indices, row_idx);
let entry = groups.entry(key.clone());
if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
ordered_keys.push(key);
}
entry.or_default().push(row_idx);
}
let num_groups = ordered_keys.len();
let mut output_columns: Vec<arrow_array::ArrayRef> = Vec::new();
for &ki in &key_indices {
if ki >= batch.num_columns() {
continue; }
let col = batch.column(ki);
let first_indices: Vec<u32> =
ordered_keys.iter().map(|k| groups[k][0] as u32).collect();
let idx_array = arrow_array::UInt32Array::from(first_indices);
let taken = arrow::compute::take(col.as_ref(), &idx_array, None).map_err(|e| {
datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
})?;
output_columns.push(taken);
}
let cx = FoldContext {
strict,
epsilon,
semiring: match semiring_kind {
SemiringKind::MaxMinProb => FoldSemiring::MaxMin,
_ => FoldSemiring::AddMult,
},
};
for binding in &fold_bindings {
let col: Arc<dyn Array> = if binding.name.as_str() == "COUNTALL" {
Arc::new(arrow_array::Int64Array::from(vec![0i64; batch.num_rows()]))
} else {
let resolved_idx = binding
.input_col_name
.as_ref()
.and_then(|name| batch.schema().index_of(name).ok())
.unwrap_or(binding.input_col_index);
if resolved_idx < batch.num_columns() {
Arc::clone(batch.column(resolved_idx))
} else {
Arc::new(arrow_array::Float64Array::from(vec![
0.0f64;
batch.num_rows()
]))
}
};
let topk_ctx = if matches!(semiring_kind, SemiringKind::TopKProofs { .. }) {
Some(TopKFoldCtx {
k: top_k_proofs_k,
batch: &batch,
body_support_map: body_support_map.as_deref(),
})
} else {
None
};
let agg_col = compute_fold_aggregate(
col.as_ref(),
&binding.aggregate,
FoldGroups {
ordered_keys: &ordered_keys,
groups: &groups,
num_groups,
},
&cx,
topk_ctx.as_ref(),
)?;
output_columns.push(agg_col);
}
RecordBatch::try_new(output_schema, output_columns).map_err(arrow_err)
};
Ok(Box::pin(FoldStream {
state: FoldStreamState::Running(Box::pin(fut)),
schema: Arc::clone(&self.schema),
metrics,
}))
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
}
struct FoldGroups<'a> {
ordered_keys: &'a [Vec<ScalarKey>],
groups: &'a HashMap<Vec<ScalarKey>, Vec<usize>>,
num_groups: usize,
}
struct TopKFoldCtx<'a> {
k: usize,
batch: &'a RecordBatch,
body_support_map: Option<&'a HashMap<Vec<u8>, Vec<ProofTerm>>>,
}
fn compute_fold_aggregate(
col: &dyn Array,
aggregate: &Arc<dyn LocyAggregate>,
groups_ctx: FoldGroups<'_>,
cx: &FoldContext,
topk_ctx: Option<&TopKFoldCtx<'_>>,
) -> DFResult<arrow_array::ArrayRef> {
let ordered_keys = groups_ctx.ordered_keys;
let groups = groups_ctx.groups;
let num_groups = groups_ctx.num_groups;
if let Some(ctx) = topk_ctx
&& aggregate.is_noisy_or()
{
let mut builder = Float64Builder::with_capacity(num_groups);
for key in ordered_keys {
builder.append_option(topk_dnf_disjunction(col, &groups[key], cx.strict, ctx)?);
}
return Ok(Arc::new(builder.finish()));
}
let mut scalars: Vec<ScalarValue> = Vec::with_capacity(num_groups);
for key in ordered_keys {
let mut state = aggregate.create();
state
.ingest_indices(col, &groups[key], cx)
.map_err(fn_error_to_df)?;
scalars.push(state.finalize().map_err(fn_error_to_df)?);
}
ScalarValue::iter_to_array(scalars)
}
fn fn_error_to_df(e: uni_plugin::FnError) -> datafusion::error::DataFusionError {
datafusion::error::DataFusionError::Execution(e.message)
}
fn topk_dnf_disjunction(
col: &dyn Array,
indices: &[usize],
strict: bool,
ctx: &TopKFoldCtx<'_>,
) -> DFResult<Option<f64>> {
use uni_locy::{BaseRv, BaseRvSet, Proof};
let batch = ctx.batch;
let all_indices: Vec<usize> = (0..batch.num_columns()).collect();
let mut interner: HashMap<Vec<u8>, BaseRv> = HashMap::new();
let mut next_rv: u32 = 0;
let mut base_weights: HashMap<BaseRv, f64> = HashMap::new();
let mut proofs: Vec<Proof> = Vec::with_capacity(indices.len());
for &i in indices {
if col.is_null(i) {
continue;
}
let val = match col.as_any().downcast_ref::<arrow_array::Float64Array>() {
Some(arr) => arr.value(i),
None => match col.as_any().downcast_ref::<arrow_array::Int64Array>() {
Some(arr) => arr.value(i) as f64,
None => continue,
},
};
if strict && !(0.0..=1.0).contains(&val) {
return Err(datafusion::error::DataFusionError::Execution(format!(
"strict_probability_domain: MNOR input {val} outside [0,1] under TopKProofs"
)));
}
let weight = val.clamp(0.0, 1.0);
let fact_hash = super::locy_fixpoint::fact_hash_key(batch, &all_indices, i);
let mut base_rvs = BaseRvSet::empty();
if let Some(support) = ctx.body_support_map.and_then(|m| m.get(&fact_hash)) {
for term in support {
let rv = *interner
.entry(term.base_fact_id.clone())
.or_insert_with(|| {
let r = BaseRv(next_rv);
next_rv += 1;
r
});
base_rvs.insert(rv);
}
}
if base_rvs.iter().count() > 0 {
for rv in base_rvs.iter() {
base_weights
.entry(rv)
.and_modify(|w| {
if weight > *w {
*w = weight;
}
})
.or_insert(weight);
}
}
proofs.push(Proof {
weight,
base_rvs,
neural_calls: Vec::new(),
});
}
if proofs.is_empty() {
return Ok(None);
}
if base_weights.is_empty() {
let mut complement = 1.0;
for p in &proofs {
complement *= 1.0 - p.weight;
}
return Ok(Some((1.0 - complement).clamp(0.0, 1.0)));
}
let k = if ctx.k == 0 { proofs.len() } else { ctx.k };
let (kept, _notice) = uni_locy::merge_top_k_runtime(Vec::new(), proofs, k);
let tag = uni_locy::TopKTag { proofs: kept };
Ok(Some(tag.to_dnf().weight(&base_weights)))
}
enum FoldStreamState {
Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
Done,
}
struct FoldStream {
state: FoldStreamState,
schema: SchemaRef,
metrics: BaselineMetrics,
}
impl Stream for FoldStream {
type Item = DFResult<RecordBatch>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let metrics = self.metrics.clone();
let _timer = metrics.elapsed_compute().timer();
match &mut self.state {
FoldStreamState::Running(fut) => match fut.as_mut().poll(cx) {
Poll::Ready(Ok(batch)) => {
self.metrics.record_output(batch.num_rows());
self.state = FoldStreamState::Done;
Poll::Ready(Some(Ok(batch)))
}
Poll::Ready(Err(e)) => {
self.state = FoldStreamState::Done;
Poll::Ready(Some(Err(e)))
}
Poll::Pending => Poll::Pending,
},
FoldStreamState::Done => Poll::Ready(None),
}
}
}
impl RecordBatchStream for FoldStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Float64Array, Int64Array, StringArray};
use arrow_schema::{DataType, Field, Schema};
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::prelude::SessionContext;
fn builtin_agg(name: &str) -> Arc<dyn LocyAggregate> {
use uni_plugin_builtin::locy_aggregates::*;
match name {
"SUM" | "MSUM" => Arc::new(SumAgg),
"MAX" | "MMAX" => Arc::new(MaxAgg),
"MIN" | "MMIN" => Arc::new(MinAgg),
"COUNT" | "COUNTALL" | "MCOUNT" => Arc::new(CountAgg),
"AVG" => Arc::new(AvgAgg),
"COLLECT" => Arc::new(CollectAgg),
"MNOR" | "NOR" => Arc::new(MnorAgg),
"MPROD" | "PROD" => Arc::new(MprodAgg),
other => panic!("unknown test aggregate `{other}`"),
}
}
fn make_test_batch(names: Vec<&str>, values: Vec<f64>) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("value", DataType::Float64, true),
]));
RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(
names.into_iter().map(Some).collect::<Vec<_>>(),
)),
Arc::new(Float64Array::from(values)),
],
)
.unwrap()
}
fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
let schema = batch.schema();
Arc::new(TestMemoryExec {
batches: vec![batch],
schema: schema.clone(),
properties: compute_plan_properties(schema),
})
}
#[derive(Debug)]
struct TestMemoryExec {
batches: Vec<RecordBatch>,
schema: SchemaRef,
properties: Arc<PlanProperties>,
}
impl DisplayAs for TestMemoryExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "TestMemoryExec")
}
}
impl ExecutionPlan for TestMemoryExec {
fn name(&self) -> &str {
"TestMemoryExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn properties(&self) -> &Arc<PlanProperties> {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
Ok(self)
}
fn execute(
&self,
_partition: usize,
_context: Arc<TaskContext>,
) -> DFResult<SendableRecordBatchStream> {
Ok(Box::pin(MemoryStream::try_new(
self.batches.clone(),
Arc::clone(&self.schema),
None,
)?))
}
}
async fn execute_fold(
input: Arc<dyn ExecutionPlan>,
key_indices: Vec<usize>,
fold_bindings: Vec<FoldBinding>,
) -> RecordBatch {
let exec = FoldExec::new(input, key_indices, fold_bindings, false, 1e-15);
let ctx = SessionContext::new();
let task_ctx = ctx.task_ctx();
let stream = exec.execute(0, task_ctx).unwrap();
let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
.await
.unwrap();
if batches.is_empty() {
RecordBatch::new_empty(exec.schema())
} else {
arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
}
}
#[tokio::test]
async fn test_sum_single_group() {
let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "total".to_string(),
name: SmolStr::new_static("SUM"),
aggregate: builtin_agg("SUM"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 1);
let totals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
}
#[tokio::test]
async fn test_count_non_null() {
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("value", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(vec![Some("a"), Some("a"), Some("a")])),
Arc::new(Float64Array::from(vec![Some(1.0), None, Some(3.0)])),
],
)
.unwrap();
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "cnt".to_string(),
name: SmolStr::new_static("COUNT"),
aggregate: builtin_agg("COUNT"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 1);
let counts = result
.column(1)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(counts.value(0), 2); }
#[tokio::test]
async fn test_max_min() {
let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 5.0]);
let input_max = make_memory_exec(batch.clone());
let input_min = make_memory_exec(batch);
let result_max = execute_fold(
input_max,
vec![0],
vec![FoldBinding {
output_name: "mx".to_string(),
name: SmolStr::new_static("MAX"),
aggregate: builtin_agg("MAX"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let result_min = execute_fold(
input_min,
vec![0],
vec![FoldBinding {
output_name: "mn".to_string(),
name: SmolStr::new_static("MIN"),
aggregate: builtin_agg("MIN"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let max_vals = result_max
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert_eq!(max_vals.value(0), 5.0);
let min_vals = result_min
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert_eq!(min_vals.value(0), 1.0);
}
#[tokio::test]
async fn test_avg() {
let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![2.0, 4.0, 6.0, 8.0]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "average".to_string(),
name: SmolStr::new_static("AVG"),
aggregate: builtin_agg("AVG"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 1);
let avgs = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((avgs.value(0) - 5.0).abs() < f64::EPSILON);
}
#[tokio::test]
async fn test_multiple_groups() {
let batch = make_test_batch(
vec!["a", "a", "b", "b", "b"],
vec![1.0, 2.0, 10.0, 20.0, 30.0],
);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "total".to_string(),
name: SmolStr::new_static("SUM"),
aggregate: builtin_agg("SUM"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 2);
let names = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let totals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
for i in 0..2 {
match names.value(i) {
"a" => assert!((totals.value(i) - 3.0).abs() < f64::EPSILON),
"b" => assert!((totals.value(i) - 60.0).abs() < f64::EPSILON),
_ => panic!("unexpected name"),
}
}
}
#[tokio::test]
async fn test_empty_input() {
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("value", DataType::Float64, true),
]));
let batch = RecordBatch::new_empty(schema);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "total".to_string(),
name: SmolStr::new_static("SUM"),
aggregate: builtin_agg("SUM"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 0);
}
#[tokio::test]
async fn test_multiple_bindings() {
let batch = make_test_batch(vec!["a", "a", "a"], vec![1.0, 2.0, 3.0]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![
FoldBinding {
output_name: "total".to_string(),
name: SmolStr::new_static("SUM"),
aggregate: builtin_agg("SUM"),
input_col_index: 1,
input_col_name: None,
},
FoldBinding {
output_name: "cnt".to_string(),
name: SmolStr::new_static("COUNT"),
aggregate: builtin_agg("COUNT"),
input_col_index: 1,
input_col_name: None,
},
FoldBinding {
output_name: "mx".to_string(),
name: SmolStr::new_static("MAX"),
aggregate: builtin_agg("MAX"),
input_col_index: 1,
input_col_name: None,
},
],
)
.await;
assert_eq!(result.num_rows(), 1);
assert_eq!(result.num_columns(), 4);
let totals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((totals.value(0) - 6.0).abs() < f64::EPSILON);
let counts = result
.column(2)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(counts.value(0), 3);
let maxes = result
.column(3)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert_eq!(maxes.value(0), 3.0);
}
#[tokio::test]
async fn test_nor_single_group() {
let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 1);
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.65).abs() < 1e-10);
}
#[tokio::test]
async fn test_nor_identity() {
let batch = make_test_batch(vec!["a", "a"], vec![0.0, 0.0]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.0).abs() < 1e-10);
}
#[tokio::test]
async fn test_nor_clamping() {
let batch = make_test_batch(vec!["a", "a"], vec![-0.5, 1.5]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 1.0).abs() < 1e-10);
}
#[tokio::test]
async fn test_nor_multiple_groups() {
let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.3, 0.5, 0.1, 0.2]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 2);
let names = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
for i in 0..2 {
match names.value(i) {
"a" => assert!((vals.value(i) - 0.65).abs() < 1e-10),
"b" => assert!((vals.value(i) - 0.28).abs() < 1e-10),
_ => panic!("unexpected name"),
}
}
}
#[tokio::test]
async fn test_prod_single_group() {
let batch = make_test_batch(vec!["a", "a"], vec![0.6, 0.8]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 1);
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.48).abs() < 1e-10);
}
#[tokio::test]
async fn test_prod_identity() {
let batch = make_test_batch(vec!["a", "a"], vec![1.0, 1.0]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 1.0).abs() < 1e-10);
}
#[tokio::test]
async fn test_prod_zero_absorbing() {
let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.0, 0.8]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.0).abs() < 1e-10);
}
#[tokio::test]
async fn test_prod_underflow_protection() {
let names: Vec<&str> = vec!["a"; 50];
let values: Vec<f64> = vec![0.5; 50];
let batch = make_test_batch(names, values);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let expected = 0.5_f64.powi(50); assert!(vals.value(0) > 0.0, "should not underflow to zero");
assert!(
(vals.value(0) - expected).abs() / expected < 1e-6,
"result {} should be close to expected {}",
vals.value(0),
expected
);
}
fn make_nullable_test_batch(names: Vec<&str>, values: Vec<Option<f64>>) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("value", DataType::Float64, true),
]));
RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(
names.into_iter().map(Some).collect::<Vec<_>>(),
)),
Arc::new(Float64Array::from(values)),
],
)
.unwrap()
}
#[tokio::test]
async fn test_nor_single_element() {
let batch = make_test_batch(vec!["a"], vec![0.7]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.7).abs() < 1e-10);
}
#[tokio::test]
async fn test_prod_single_element() {
let batch = make_test_batch(vec!["a"], vec![0.7]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.7).abs() < 1e-10);
}
#[tokio::test]
async fn test_nor_three_elements() {
let batch = make_test_batch(vec!["a", "a", "a"], vec![0.3, 0.4, 0.5]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.79).abs() < 1e-10);
}
#[tokio::test]
async fn test_nor_four_elements_spec_example() {
let batch = make_test_batch(vec!["a", "a", "a", "a"], vec![0.72, 0.54, 0.56, 0.42]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!(
(vals.value(0) - 0.96713024).abs() < 1e-10,
"expected 0.96713024, got {}",
vals.value(0)
);
}
#[tokio::test]
async fn test_prod_three_elements() {
let batch = make_test_batch(vec!["a", "a", "a"], vec![0.5, 0.5, 0.5]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.125).abs() < 1e-10);
}
#[tokio::test]
async fn test_nor_absorbing_element() {
let batch = make_test_batch(vec!["a", "a"], vec![0.3, 1.0]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 1.0).abs() < 1e-10);
}
#[tokio::test]
async fn test_prod_clamping() {
let batch = make_test_batch(vec!["a", "a"], vec![2.0, 0.5]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.5).abs() < 1e-10);
}
#[tokio::test]
async fn test_prod_multiple_groups() {
let batch = make_test_batch(vec!["a", "a", "b", "b"], vec![0.6, 0.8, 0.5, 0.5]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 2);
let names = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
for i in 0..2 {
match names.value(i) {
"a" => assert!((vals.value(i) - 0.48).abs() < 1e-10),
"b" => assert!((vals.value(i) - 0.25).abs() < 1e-10),
_ => panic!("unexpected group name"),
}
}
}
#[tokio::test]
async fn test_nor_commutativity() {
let fwd = make_test_batch(vec!["a", "a", "a"], vec![0.2, 0.5, 0.8]);
let rev = make_test_batch(vec!["a", "a", "a"], vec![0.8, 0.5, 0.2]);
let binding = vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}];
let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
let v1 = r1
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.value(0);
let v2 = r2
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.value(0);
assert!((v1 - 0.92).abs() < 1e-10);
assert!((v2 - 0.92).abs() < 1e-10);
assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
}
#[tokio::test]
async fn test_prod_commutativity() {
let fwd = make_test_batch(vec!["a", "a"], vec![0.5, 0.25]);
let rev = make_test_batch(vec!["a", "a"], vec![0.25, 0.5]);
let binding = vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}];
let r1 = execute_fold(make_memory_exec(fwd), vec![0], binding.clone()).await;
let r2 = execute_fold(make_memory_exec(rev), vec![0], binding).await;
let v1 = r1
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.value(0);
let v2 = r2
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.value(0);
assert!((v1 - 0.125).abs() < 1e-10);
assert!((v2 - 0.125).abs() < 1e-10);
assert!((v1 - v2).abs() < 1e-15, "commutativity violated");
}
#[tokio::test]
async fn test_nor_boundary_near_zero() {
let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let expected = 1.0 - 0.999 * 0.998;
assert!(
(vals.value(0) - expected).abs() < 1e-10,
"expected {}, got {}",
expected,
vals.value(0)
);
}
#[tokio::test]
async fn test_nor_boundary_near_one() {
let batch = make_test_batch(vec!["a", "a"], vec![0.999, 0.998]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let expected = 1.0 - 0.001 * 0.002;
assert!(
(vals.value(0) - expected).abs() < 1e-10,
"expected {}, got {}",
expected,
vals.value(0)
);
}
#[tokio::test]
async fn test_prod_boundary_near_zero() {
let batch = make_test_batch(vec!["a", "a"], vec![0.001, 0.002]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!(
(vals.value(0) - 2e-6).abs() < 1e-15,
"expected 2e-6, got {}",
vals.value(0)
);
}
#[tokio::test]
async fn test_nor_empty_input() {
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, true),
Field::new("value", DataType::Float64, true),
]));
let batch = RecordBatch::new_empty(schema);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 0);
}
#[tokio::test]
async fn test_nor_nan_handling() {
let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::NAN]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!(vals.value(0).is_nan(), "NaN should propagate through MNOR");
}
#[tokio::test]
async fn test_prod_nan_handling() {
let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::NAN]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!(vals.value(0).is_nan(), "NaN should propagate through MPROD");
}
#[tokio::test]
async fn test_prod_infinity_handling() {
let batch = make_test_batch(vec!["a", "a"], vec![0.5, f64::INFINITY]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.5).abs() < 1e-10);
}
#[tokio::test]
async fn test_nor_infinity_handling() {
let batch = make_test_batch(vec!["a", "a"], vec![0.3, f64::INFINITY]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 1.0).abs() < 1e-10);
}
#[tokio::test]
async fn test_nor_all_null_values() {
let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 1);
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!(vals.is_null(0), "all-null MNOR should produce null");
}
#[tokio::test]
async fn test_prod_all_null_values() {
let batch = make_nullable_test_batch(vec!["a", "a"], vec![None, None]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 1);
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!(vals.is_null(0), "all-null MPROD should produce null");
}
#[tokio::test]
async fn test_nor_mixed_null_values() {
let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.3), None, Some(0.5)]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.65).abs() < 1e-10);
}
#[tokio::test]
async fn test_prod_mixed_null_values() {
let batch = make_nullable_test_batch(vec!["a", "a", "a"], vec![Some(0.6), None, Some(0.8)]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert!((vals.value(0) - 0.48).abs() < 1e-10);
}
#[tokio::test]
async fn test_nor_many_small_values() {
let names: Vec<&str> = vec!["a"; 20];
let values: Vec<f64> = vec![0.1; 20];
let batch = make_test_batch(names, values);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "prob".to_string(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
)
.await;
let vals = result
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let expected = 1.0 - 0.9_f64.powi(20);
assert!(
(vals.value(0) - expected).abs() < 1e-10,
"expected {}, got {}",
expected,
vals.value(0)
);
}
#[test]
fn trait_dispatch_monotonicity() {
for name in [
"SUM", "MAX", "MIN", "COUNT", "AVG", "COLLECT", "MNOR", "MPROD",
] {
let agg = builtin_agg(name);
let sl = agg.semilattice();
let expect_monotone =
matches!(name, "MIN" | "MAX" | "MNOR" | "MPROD" | "COLLECT" | "COUNT");
assert_eq!(
sl.monotone_join, expect_monotone,
"monotone_join mismatch for {name}"
);
}
}
#[test]
fn trait_dispatch_initial_accumulator() {
assert_eq!(builtin_agg("SUM").initial_accum_f64(), Some(0.0));
assert_eq!(builtin_agg("COUNT").initial_accum_f64(), Some(0.0));
assert_eq!(builtin_agg("MNOR").initial_accum_f64(), Some(0.0));
assert_eq!(
builtin_agg("MAX").initial_accum_f64(),
Some(f64::NEG_INFINITY)
);
assert_eq!(builtin_agg("MIN").initial_accum_f64(), Some(f64::INFINITY));
assert_eq!(builtin_agg("MPROD").initial_accum_f64(), Some(1.0));
assert_eq!(builtin_agg("AVG").initial_accum_f64(), None);
assert_eq!(builtin_agg("COLLECT").initial_accum_f64(), None);
}
#[test]
fn trait_dispatch_probability_predicate() {
for name in ["MNOR", "MPROD"] {
assert!(
builtin_agg(name).is_probability_aggregate(),
"expected {name} to be probability-domain"
);
}
for name in ["SUM", "MAX", "MIN", "COUNT", "AVG", "COLLECT"] {
assert!(
!builtin_agg(name).is_probability_aggregate(),
"{name} should NOT be probability-domain"
);
}
}
#[test]
fn trait_dispatch_noisy_or_predicate() {
assert!(builtin_agg("MNOR").is_noisy_or());
assert!(!builtin_agg("MPROD").is_noisy_or());
}
async fn execute_fold_strict(
input: Arc<dyn ExecutionPlan>,
key_indices: Vec<usize>,
fold_bindings: Vec<FoldBinding>,
strict: bool,
) -> DFResult<RecordBatch> {
let exec = FoldExec::new(input, key_indices, fold_bindings, strict, 1e-15);
let ctx = SessionContext::new();
let task_ctx = ctx.task_ctx();
let stream = exec.execute(0, task_ctx).unwrap();
let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream).await?;
if batches.is_empty() {
Ok(RecordBatch::new_empty(exec.schema()))
} else {
arrow::compute::concat_batches(&exec.schema(), &batches).map_err(arrow_err)
}
}
#[tokio::test]
async fn test_nor_strict_rejects_above_one() {
let batch = make_test_batch(vec!["a"], vec![1.5]);
let input = make_memory_exec(batch);
let result = execute_fold_strict(
input,
vec![0],
vec![FoldBinding {
output_name: "p".into(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
true,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("strict_probability_domain"),
"Expected strict error, got: {}",
err
);
}
#[tokio::test]
async fn test_nor_strict_rejects_negative() {
let batch = make_test_batch(vec!["a"], vec![-0.1]);
let input = make_memory_exec(batch);
let result = execute_fold_strict(
input,
vec![0],
vec![FoldBinding {
output_name: "p".into(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
true,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("strict_probability_domain"),
"Expected strict error, got: {}",
err
);
}
#[tokio::test]
async fn test_prod_strict_rejects_above_one() {
let batch = make_test_batch(vec!["a"], vec![2.0]);
let input = make_memory_exec(batch);
let result = execute_fold_strict(
input,
vec![0],
vec![FoldBinding {
output_name: "p".into(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
true,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("strict_probability_domain"),
"Expected strict error, got: {}",
err
);
}
#[tokio::test]
async fn test_prod_strict_rejects_negative() {
let batch = make_test_batch(vec!["a"], vec![-0.5]);
let input = make_memory_exec(batch);
let result = execute_fold_strict(
input,
vec![0],
vec![FoldBinding {
output_name: "p".into(),
name: SmolStr::new_static("MPROD"),
aggregate: builtin_agg("MPROD"),
input_col_index: 1,
input_col_name: None,
}],
true,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("strict_probability_domain"),
"Expected strict error, got: {}",
err
);
}
#[tokio::test]
async fn test_nor_strict_accepts_valid() {
let batch = make_test_batch(vec!["a", "a"], vec![0.3, 0.5]);
let input = make_memory_exec(batch);
let result = execute_fold_strict(
input,
vec![0],
vec![FoldBinding {
output_name: "p".into(),
name: SmolStr::new_static("MNOR"),
aggregate: builtin_agg("MNOR"),
input_col_index: 1,
input_col_name: None,
}],
true,
)
.await;
assert!(result.is_ok(), "Expected Ok, got: {:?}", result.err());
let batch = result.unwrap();
let vals = batch
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let expected = 0.65; assert!(
(vals.value(0) - expected).abs() < 1e-10,
"expected {}, got {}",
expected,
vals.value(0)
);
}
#[tokio::test]
async fn test_count_all_groups_by_key() {
let batch = make_test_batch(vec!["a", "a", "b"], vec![10.0, 20.0, 30.0]);
let input = make_memory_exec(batch);
let result = execute_fold(
input,
vec![0],
vec![FoldBinding {
output_name: "cnt".to_string(),
name: SmolStr::new_static("COUNTALL"),
aggregate: builtin_agg("COUNTALL"),
input_col_index: 0, input_col_name: None,
}],
)
.await;
assert_eq!(result.num_rows(), 2, "Should have 2 groups");
let counts = result
.column(1)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(counts.value(0), 2, "Group 'a' should have count 2");
assert_eq!(counts.value(1), 1, "Group 'b' should have count 1");
}
#[derive(Debug)]
struct IdentityAgg;
impl LocyAggregate for IdentityAgg {
fn semilattice(&self) -> uni_plugin::traits::locy::Semilattice {
uni_plugin::traits::locy::Semilattice::BOUNDED_MIN_MAX
}
fn output_type(&self) -> arrow_schema::DataType {
arrow_schema::DataType::Float64
}
fn create(&self) -> Box<dyn uni_plugin::traits::locy::LocyAggState> {
panic!("IdentityAgg::create not used in this sanity test")
}
}
#[test]
fn resolve_locy_aggregate_returns_registered_instance() {
let registry = uni_plugin::PluginRegistry::new();
let plugin_id = uni_plugin::PluginId::new(uni_plugin::QName::BUILTIN_NS);
let caps = uni_plugin::CapabilitySet::from_iter_of([uni_plugin::Capability::LocyAggregate]);
let registered: Arc<dyn LocyAggregate> = Arc::new(IdentityAgg);
let mut r = uni_plugin::PluginRegistrar::new(plugin_id, &caps, ®istry);
r.locy_aggregate(
uni_plugin::QName::builtin("TEST_IDENTITY"),
Arc::clone(®istered),
)
.expect("register");
r.commit_to_registry().expect("commit");
let resolved = resolve_locy_aggregate(®istry, "TEST_IDENTITY")
.expect("registered aggregate should resolve");
assert!(
Arc::ptr_eq(®istered, &resolved.aggregate),
"registry must return the exact Arc that was registered"
);
assert!(resolve_locy_aggregate(®istry, "NOT_REGISTERED").is_none());
}
#[test]
fn default_locy_plugin_registry_contains_all_builtins() {
let r = default_locy_plugin_registry();
for name in [
"MIN", "MAX", "SUM", "MSUM", "COUNT", "AVG", "COLLECT", "MNOR", "MPROD",
] {
assert!(
resolve_locy_aggregate(&r, name).is_some(),
"default registry should contain built-in `{name}`"
);
}
}
#[derive(Debug)]
struct RangeAgg;
impl LocyAggregate for RangeAgg {
fn semilattice(&self) -> uni_plugin::traits::locy::Semilattice {
uni_plugin::traits::locy::Semilattice::NON_MONOTONE
}
fn output_type(&self) -> DataType {
DataType::Float64
}
fn create(&self) -> Box<dyn uni_plugin::traits::locy::LocyAggState> {
Box::new(RangeState {
min: None,
max: None,
})
}
}
#[derive(Debug)]
struct RangeState {
min: Option<f64>,
max: Option<f64>,
}
impl uni_plugin::traits::locy::LocyAggState for RangeState {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn ingest_indices(
&mut self,
col: &dyn Array,
indices: &[usize],
_cx: &FoldContext,
) -> Result<(), uni_plugin::FnError> {
let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
for &i in indices {
if arr.is_null(i) {
continue;
}
let v = arr.value(i);
self.min = Some(self.min.map_or(v, |m| m.min(v)));
self.max = Some(self.max.map_or(v, |m| m.max(v)));
}
Ok(())
}
fn merge(
&mut self,
_other: &dyn uni_plugin::traits::locy::LocyAggState,
) -> Result<(), uni_plugin::FnError> {
Ok(())
}
fn finalize(&self) -> Result<ScalarValue, uni_plugin::FnError> {
match (self.min, self.max) {
(Some(lo), Some(hi)) => Ok(ScalarValue::Float64(Some(hi - lo))),
_ => Ok(ScalarValue::Float64(None)),
}
}
}
#[tokio::test]
async fn user_defined_aggregate_runs_in_non_recursive_fold() {
let batch = make_test_batch(vec!["a", "a", "b"], vec![1.0, 5.0, 3.0]);
let input = make_memory_exec(batch);
let binding = FoldBinding {
output_name: "r".into(),
name: SmolStr::new_static("RANGE"),
aggregate: Arc::new(RangeAgg),
input_col_index: 1,
input_col_name: Some("value".to_string()),
};
let out = execute_fold(input, vec![0], vec![binding]).await;
assert_eq!(out.num_rows(), 2);
let col = out
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.expect("range output is Float64");
assert_eq!(col.value(0), 4.0);
assert_eq!(col.value(1), 0.0);
}
}