use crate::query::df_graph::GraphExecutionContext;
use crate::query::df_graph::common::{
ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
extract_scalar_key,
};
use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
use crate::query::df_graph::locy_errors::LocyRuntimeError;
use crate::query::df_graph::locy_explain::{
ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
};
use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
use crate::query::df_graph::locy_priority::PriorityExec;
use crate::query::planner::LogicalPlan;
use arrow_array::RecordBatch;
use arrow_row::{RowConverter, SortField};
use arrow_schema::SchemaRef;
use datafusion::common::JoinType;
use datafusion::common::Result as DFResult;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::Stream;
use parking_lot::RwLock;
use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::pin::Pin;
use std::sync::{Arc, RwLock as StdRwLock};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use uni_common::Value;
use uni_common::core::schema::Schema as UniSchema;
use uni_cypher::ast::Expr;
use uni_locy::RuntimeWarning;
use uni_store::storage::manager::StorageManager;
/// One registered derived-table scan: the shared buffer a fixpoint rule
/// writes into and that plan scans read from.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Plan-wide index identifying this scan.
    pub scan_index: usize,
    /// Name of the rule whose facts back this scan.
    pub rule_name: String,
    /// True when the scan refers back to the rule being evaluated (recursion).
    pub is_self_ref: bool,
    /// Shared, lock-protected batches; replaced wholesale via `write_data`.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Arrow schema of the batches in `data`.
    pub schema: SchemaRef,
}
/// Registry of every derived scan participating in a fixpoint computation,
/// keyed by `scan_index` (linear lookup; entry count is small).
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    entries: Vec<DerivedScanEntry>,
}
impl DerivedScanRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a derived-scan entry.
    pub fn add(&mut self, entry: DerivedScanEntry) {
        self.entries.push(entry);
    }

    /// Looks up the entry whose `scan_index` matches, if any.
    pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
        self.entries
            .iter()
            .find(|candidate| candidate.scan_index == scan_index)
    }

    /// Replaces the shared batch vector for `scan_index`.
    /// An unknown index is silently a no-op.
    pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
        let Some(entry) = self.get(scan_index) else {
            return;
        };
        *entry.data.write() = batches;
    }

    /// Returns every entry registered under the given rule name.
    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
        let mut matching = Vec::new();
        for entry in &self.entries {
            if entry.rule_name == rule_name {
                matching.push(entry);
            }
        }
        matching
    }
}
/// Binding of one fold (aggregate) output to the input column it consumes.
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Name of the fold's output column, used as part of the accumulator key.
    pub fold_name: String,
    /// Which monotonic aggregate to apply (sum, max, MNOR, …).
    pub kind: crate::query::df_graph::locy_fold::FoldAggKind,
    /// Positional fallback for the input column when name resolution fails.
    pub input_col_index: usize,
    /// Preferred name-based lookup of the input column, if known.
    pub input_col_name: Option<String>,
}
/// Running accumulators for monotonic aggregates across fixpoint iterations,
/// keyed by (group key, fold name). `prev_snapshot` holds the values from the
/// previous iteration so convergence can be detected.
#[derive(Debug)]
pub struct MonotonicAggState {
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    bindings: Vec<MonotonicFoldBinding>,
}
impl MonotonicAggState {
    /// Creates a fresh state with empty accumulators for the given bindings.
    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
        Self {
            accumulators: HashMap::new(),
            prev_snapshot: HashMap::new(),
            bindings,
        }
    }

    /// Folds every row of `delta_batches` into the per-(group, fold)
    /// accumulators.
    ///
    /// Group keys come from `key_indices`. For each binding the input column
    /// is resolved by name first, falling back to the stored positional index;
    /// rows with an out-of-range index or a null/non-numeric value are
    /// skipped. For MNOR/MPROD, inputs outside [0, 1] are an execution error
    /// when `strict` is set, otherwise they are clamped with a warning.
    ///
    /// Returns `Ok(true)` if any accumulator moved by more than
    /// `f64::EPSILON`.
    pub fn update(
        &mut self,
        key_indices: &[usize],
        delta_batches: &[RecordBatch],
        strict: bool,
    ) -> DFResult<bool> {
        use crate::query::df_graph::locy_fold::FoldAggKind;
        let mut changed = false;
        for batch in delta_batches {
            for row_idx in 0..batch.num_rows() {
                let group_key = extract_scalar_key(batch, key_indices, row_idx);
                for binding in &self.bindings {
                    // Name-based resolution wins; positional index is the fallback.
                    let idx = binding
                        .input_col_name
                        .as_ref()
                        .and_then(|name| batch.schema().index_of(name).ok())
                        .unwrap_or(binding.input_col_index);
                    if idx >= batch.num_columns() {
                        continue;
                    }
                    let col = batch.column(idx);
                    let val = extract_f64(col.as_ref(), row_idx);
                    if let Some(val) = val {
                        let map_key = (group_key.clone(), binding.fold_name.clone());
                        // Seed a missing accumulator with the aggregate's identity
                        // (0.0 when the kind has none).
                        let entry = self
                            .accumulators
                            .entry(map_key)
                            .or_insert(binding.kind.identity().unwrap_or(0.0));
                        let old = *entry;
                        match binding.kind {
                            FoldAggKind::Sum | FoldAggKind::Count => *entry += val,
                            FoldAggKind::Max => {
                                if val > *entry {
                                    *entry = val;
                                }
                            }
                            FoldAggKind::Min => {
                                if val < *entry {
                                    *entry = val;
                                }
                            }
                            // MNOR: noisy-or combination, acc = 1 - (1-acc)(1-p).
                            FoldAggKind::Nor => {
                                if strict && !(0.0..=1.0).contains(&val) {
                                    return Err(datafusion::error::DataFusionError::Execution(
                                        format!(
                                            "strict_probability_domain: MNOR input {val} is outside [0, 1]"
                                        ),
                                    ));
                                }
                                if !strict && !(0.0..=1.0).contains(&val) {
                                    tracing::warn!(
                                        "MNOR input {val} outside [0,1], clamped to {}",
                                        val.clamp(0.0, 1.0)
                                    );
                                }
                                let p = val.clamp(0.0, 1.0);
                                *entry = 1.0 - (1.0 - *entry) * (1.0 - p);
                            }
                            // MPROD: probability product.
                            FoldAggKind::Prod => {
                                if strict && !(0.0..=1.0).contains(&val) {
                                    return Err(datafusion::error::DataFusionError::Execution(
                                        format!(
                                            "strict_probability_domain: MPROD input {val} is outside [0, 1]"
                                        ),
                                    ));
                                }
                                if !strict && !(0.0..=1.0).contains(&val) {
                                    tracing::warn!(
                                        "MPROD input {val} outside [0,1], clamped to {}",
                                        val.clamp(0.0, 1.0)
                                    );
                                }
                                let p = val.clamp(0.0, 1.0);
                                *entry *= p;
                            }
                            // Other kinds are not monotonic-aggregated here.
                            _ => {}
                        }
                        if (*entry - old).abs() > f64::EPSILON {
                            changed = true;
                        }
                    }
                }
            }
        }
        Ok(changed)
    }

    /// Records the current accumulator values as the baseline for the next
    /// stability check.
    pub fn snapshot(&mut self) {
        self.prev_snapshot = self.accumulators.clone();
    }

    /// True when every accumulator is within `f64::EPSILON` of the snapshot
    /// (and no keys were added or removed).
    pub fn is_stable(&self) -> bool {
        if self.accumulators.len() != self.prev_snapshot.len() {
            return false;
        }
        for (key, val) in &self.accumulators {
            match self.prev_snapshot.get(key) {
                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
                _ => return false,
            }
        }
        true
    }

    /// Test-only peek at a single accumulator value.
    #[cfg(test)]
    pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
        self.accumulators.get(key).copied()
    }
}
/// Reads row `row_idx` of `col` as an `f64`.
///
/// Returns `Some` for non-null `Float64Array` values (as-is) and non-null
/// `Int64Array` values (cast to `f64`); `None` for nulls or any other
/// array type.
fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
    if col.is_null(row_idx) {
        return None;
    }
    let any = col.as_any();
    if let Some(floats) = any.downcast_ref::<arrow_array::Float64Array>() {
        return Some(floats.value(row_idx));
    }
    any.downcast_ref::<arrow_array::Int64Array>()
        .map(|ints| ints.value(row_idx) as f64)
}
/// Row-level deduplication state backed by Arrow's row format: each row is
/// converted to its comparable byte encoding and remembered in `seen`.
struct RowDedupState {
    converter: RowConverter,
    seen: HashSet<Box<[u8]>>,
}
impl RowDedupState {
    /// Builds a converter over all columns of `schema`.
    ///
    /// Returns `None` (with a warning) when `RowConverter` does not support
    /// the schema, signalling the caller to fall back to legacy scalar-key
    /// dedup.
    fn try_new(schema: &SchemaRef) -> Option<Self> {
        let fields: Vec<SortField> = schema
            .fields()
            .iter()
            .map(|f| SortField::new(f.data_type().clone()))
            .collect();
        match RowConverter::new(fields) {
            Ok(converter) => Some(Self {
                converter,
                seen: HashSet::new(),
            }),
            Err(e) => {
                tracing::warn!(
                    "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
                    e
                );
                None
            }
        }
    }

    /// Rebuilds `seen` from scratch out of `facts`.
    ///
    /// Conversion failures on a batch are silently skipped, leaving those
    /// rows unrecorded.
    fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
        self.seen.clear();
        for batch in facts {
            if batch.num_rows() == 0 {
                continue;
            }
            let arrays: Vec<_> = batch.columns().to_vec();
            if let Ok(rows) = self.converter.convert_columns(&arrays) {
                for row_idx in 0..batch.num_rows() {
                    let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
                    self.seen.insert(row_bytes);
                }
            }
        }
    }

    /// Filters `candidates` down to rows never seen before, inserting each
    /// surviving row into `seen` as it goes (so duplicates *within* the
    /// candidate set are also dropped). Empty filtered batches are omitted
    /// from the result.
    fn compute_delta(
        &mut self,
        candidates: &[RecordBatch],
        schema: &SchemaRef,
    ) -> DFResult<Vec<RecordBatch>> {
        let mut delta_batches = Vec::new();
        for batch in candidates {
            if batch.num_rows() == 0 {
                continue;
            }
            let arrays: Vec<_> = batch.columns().to_vec();
            let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
            let mut keep = Vec::with_capacity(batch.num_rows());
            for row_idx in 0..batch.num_rows() {
                let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
                // HashSet::insert returns true only for rows not seen before.
                keep.push(self.seen.insert(row_bytes));
            }
            let keep_mask = arrow_array::BooleanArray::from(keep);
            let new_cols = batch
                .columns()
                .iter()
                .map(|col| {
                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;
            if new_cols.first().is_some_and(|c| !c.is_empty()) {
                let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                })?;
                delta_batches.push(filtered);
            }
        }
        Ok(delta_batches)
    }
}
/// Per-rule state carried across fixpoint iterations: the accumulated facts,
/// the most recent delta, and the dedup/aggregation machinery used to merge
/// new candidate rows.
pub struct FixpointState {
    rule_name: String,
    // All facts derived so far for this rule.
    facts: Vec<RecordBatch>,
    // Rows newly added in the most recent merge (empty => no change).
    delta: Vec<RecordBatch>,
    schema: SchemaRef,
    // Key columns, both as indices and names (names survive schema reconciliation).
    key_column_indices: Vec<usize>,
    key_column_names: Vec<String>,
    // 0..num_cols, used for whole-row scalar-key extraction.
    all_column_indices: Vec<usize>,
    // Running byte size of `facts`, checked against `max_derived_bytes`.
    facts_bytes: usize,
    max_derived_bytes: usize,
    // Present only for rules with monotonic fold bindings.
    monotonic_agg: Option<MonotonicAggState>,
    // Row-format dedup; `None` when the schema is unsupported (legacy path used).
    row_dedup: Option<RowDedupState>,
    strict_probability_domain: bool,
}
impl FixpointState {
    /// Builds an empty state for one rule.
    ///
    /// Key column names are resolved from `key_column_indices` up front so
    /// indices can be re-derived after a schema reconciliation. Row-level
    /// dedup is created eagerly and is `None` when `RowConverter` cannot
    /// handle the schema.
    pub fn new(
        rule_name: String,
        schema: SchemaRef,
        key_column_indices: Vec<usize>,
        max_derived_bytes: usize,
        monotonic_agg: Option<MonotonicAggState>,
        strict_probability_domain: bool,
    ) -> Self {
        let num_cols = schema.fields().len();
        let row_dedup = RowDedupState::try_new(&schema);
        let key_column_names: Vec<String> = key_column_indices
            .iter()
            .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
            .collect();
        Self {
            rule_name,
            facts: Vec::new(),
            delta: Vec::new(),
            schema,
            key_column_indices,
            key_column_names,
            all_column_indices: (0..num_cols).collect(),
            facts_bytes: 0,
            max_derived_bytes,
            monotonic_agg,
            row_dedup,
            strict_probability_domain,
        }
    }

    /// Adopts the physical plan's actual output schema when it differs from
    /// the planned one: rebuilds dedup state, re-resolves key indices by name
    /// (kept only if *every* key name resolves), and refreshes the all-column
    /// index list.
    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
        if self.schema.fields() != actual_schema.fields() {
            tracing::debug!(
                rule = %self.rule_name,
                "Reconciling fixpoint schema from physical plan output",
            );
            self.schema = Arc::clone(actual_schema);
            self.row_dedup = RowDedupState::try_new(&self.schema);
            let new_indices: Vec<usize> = self
                .key_column_names
                .iter()
                .filter_map(|name| actual_schema.index_of(name).ok())
                .collect();
            if new_indices.len() == self.key_column_names.len() {
                self.key_column_indices = new_indices;
            }
            let num_cols = actual_schema.fields().len();
            self.all_column_indices = (0..num_cols).collect();
        }
    }

    /// Merges candidate rows for this iteration: dedups them against the
    /// accumulated facts, enforces the derived-bytes memory limit, updates
    /// monotonic aggregates, and appends the surviving delta to `facts`.
    ///
    /// Returns `Ok(true)` when new rows were added. Floats are rounded
    /// (12 decimal places) before dedup so tiny numeric noise does not
    /// produce spurious "new" rows.
    pub async fn merge_delta(
        &mut self,
        candidates: Vec<RecordBatch>,
        task_ctx: Option<Arc<TaskContext>>,
    ) -> DFResult<bool> {
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }
        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }
        let candidates = round_float_columns(&candidates);
        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            // Snapshot so is_stable() compares against this iteration's values.
            if let Some(ref mut agg) = self.monotonic_agg {
                agg.snapshot();
            }
            return Ok(false);
        }
        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::MemoryLimitExceeded {
                    rule: self.rule_name.clone(),
                    bytes: self.facts_bytes + delta_bytes,
                    limit: self.max_derived_bytes,
                }
                .to_string(),
            ));
        }
        // Snapshot first, then fold the delta in, so stability reflects
        // whether this delta moved any accumulator.
        if let Some(ref mut agg) = self.monotonic_agg {
            agg.snapshot();
            agg.update(
                &self.key_column_indices,
                &delta,
                self.strict_probability_domain,
            )?;
        }
        self.facts_bytes += delta_bytes;
        self.facts.extend(delta.iter().cloned());
        self.delta = delta;
        Ok(true)
    }

    /// Chooses the dedup strategy: above `DEDUP_ANTI_JOIN_THRESHOLD`
    /// existing rows (and with a task context available), run an Arrow
    /// LEFT ANTI hash join; otherwise use row-format dedup, falling back
    /// to scalar-key dedup when the schema is unsupported.
    async fn compute_delta(
        &mut self,
        candidates: &[RecordBatch],
        task_ctx: Option<&Arc<TaskContext>>,
    ) -> DFResult<Vec<RecordBatch>> {
        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
            && let Some(ctx) = task_ctx
        {
            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
                .await;
        }
        if let Some(ref mut rd) = self.row_dedup {
            rd.compute_delta(candidates, &self.schema)
        } else {
            self.compute_delta_legacy(candidates)
        }
    }

    /// Scalar-key fallback dedup: builds a set of whole-row keys from the
    /// existing facts, then keeps only candidate rows absent from that set
    /// (a second pass also removes duplicates within the candidates).
    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
        for batch in &self.facts {
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                existing.insert(key);
            }
        }
        let mut delta_batches = Vec::new();
        for batch in candidates {
            if batch.num_rows() == 0 {
                continue;
            }
            let mut keep = Vec::with_capacity(batch.num_rows());
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                keep.push(!existing.contains(&key));
            }
            // Second pass: de-duplicate rows repeated inside this batch by
            // inserting survivors into `existing` as we go.
            for (row_idx, kept) in keep.iter_mut().enumerate() {
                if *kept {
                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                    if !existing.insert(key) {
                        *kept = false;
                    }
                }
            }
            let keep_mask = arrow_array::BooleanArray::from(keep);
            let new_rows = batch
                .columns()
                .iter()
                .map(|col| {
                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;
            if new_rows.first().is_some_and(|c| !c.is_empty()) {
                let filtered =
                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })?;
                delta_batches.push(filtered);
            }
        }
        Ok(delta_batches)
    }

    /// Converged when the last delta was empty AND the monotonic aggregates
    /// (if any) are stable against their snapshot.
    pub fn is_converged(&self) -> bool {
        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
        delta_empty && agg_stable
    }

    /// All facts accumulated so far.
    pub fn all_facts(&self) -> &[RecordBatch] {
        &self.facts
    }

    /// Rows added by the most recent merge.
    pub fn all_delta(&self) -> &[RecordBatch] {
        &self.delta
    }

    /// Consumes the state, yielding the accumulated facts.
    pub fn into_facts(self) -> Vec<RecordBatch> {
        self.facts
    }

    /// BEST BY merge: combines existing facts with `candidates`, sorts by the
    /// key columns (ascending) then by `sort_criteria`, and keeps only the
    /// first row of each key group.
    ///
    /// Change is detected by comparing the per-key criteria values before and
    /// after the merge; on change the pruned result becomes both `facts` and
    /// `delta`. Row dedup state is rebuilt from the pruned facts either way.
    pub fn merge_best_by(
        &mut self,
        candidates: Vec<RecordBatch>,
        sort_criteria: &[SortCriterion],
    ) -> DFResult<bool> {
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }
        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }
        let candidates = round_float_columns(&candidates);
        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
            self.build_key_criteria_map(sort_criteria);
        let mut all_batches = self.facts.clone();
        all_batches.extend(candidates);
        let all_batches: Vec<_> = all_batches
            .into_iter()
            .filter(|b| b.num_rows() > 0)
            .collect();
        if all_batches.is_empty() {
            self.delta.clear();
            return Ok(false);
        }
        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        if combined.num_rows() == 0 {
            self.delta.clear();
            return Ok(false);
        }
        // Sort columns: key columns first (groups rows), then the user's
        // criteria (orders rows within each group, best first).
        let mut sort_columns = Vec::new();
        for &ki in &self.key_column_indices {
            if ki >= combined.num_columns() {
                continue;
            }
            sort_columns.push(arrow::compute::SortColumn {
                values: Arc::clone(combined.column(ki)),
                options: Some(arrow::compute::SortOptions {
                    descending: false,
                    nulls_first: false,
                }),
            });
        }
        for criterion in sort_criteria {
            if criterion.col_index >= combined.num_columns() {
                continue;
            }
            sort_columns.push(arrow::compute::SortColumn {
                values: Arc::clone(combined.column(criterion.col_index)),
                options: Some(arrow::compute::SortOptions {
                    descending: !criterion.ascending,
                    nulls_first: criterion.nulls_first,
                }),
            });
        }
        let sorted_indices =
            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
        let sorted_columns: Vec<_> = combined
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        // Keep the first (best-sorted) row of each consecutive key group.
        let mut keep_indices: Vec<u32> = Vec::new();
        let mut prev_key: Option<Vec<ScalarKey>> = None;
        for row_idx in 0..sorted.num_rows() {
            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
            let is_new_group = match &prev_key {
                None => true,
                Some(prev) => *prev != key,
            };
            if is_new_group {
                keep_indices.push(row_idx as u32);
                prev_key = Some(key);
            }
        }
        let keep_array = arrow_array::UInt32Array::from(keep_indices);
        let output_columns: Vec<_> = sorted
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        // Compare per-key criteria values to decide whether anything improved.
        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
            let mut map = HashMap::new();
            for row_idx in 0..pruned.num_rows() {
                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
                let criteria: Vec<ScalarKey> = sort_criteria
                    .iter()
                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
                    .collect();
                map.insert(key, criteria);
            }
            map
        };
        let changed = old_best != new_best;
        tracing::debug!(
            rule = %self.rule_name,
            old_keys = old_best.len(),
            new_keys = new_best.len(),
            changed = changed,
            "BEST BY merge"
        );
        self.facts_bytes = batch_byte_size(&pruned);
        self.facts = vec![pruned];
        if changed {
            self.delta = self.facts.clone();
        } else {
            self.delta.clear();
        }
        // Facts were replaced wholesale, so rebuild dedup state from scratch.
        self.row_dedup = RowDedupState::try_new(&self.schema);
        if let Some(ref mut rd) = self.row_dedup {
            rd.ingest_existing(&self.facts, &self.schema);
        }
        Ok(changed)
    }

    /// Maps each key group in the current facts to its criteria-column values
    /// (used by `merge_best_by` for before/after change detection).
    fn build_key_criteria_map(
        &self,
        sort_criteria: &[SortCriterion],
    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
        let mut map = HashMap::new();
        for batch in &self.facts {
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
                let criteria: Vec<ScalarKey> = sort_criteria
                    .iter()
                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
                    .collect();
                map.insert(key, criteria);
            }
        }
        map
    }
}
/// Approximates a batch's in-memory footprint by summing the buffer memory
/// reported by each column.
fn batch_byte_size(batch: &RecordBatch) -> usize {
    let mut total = 0usize;
    for column in batch.columns() {
        total += column.get_buffer_memory_size();
    }
    total
}
/// Rounds every Float64 column in the given batches to 12 decimal places so
/// that tiny floating-point noise does not defeat row-level deduplication.
/// Batches without Float64 columns are returned unchanged (cheap clone).
fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
    batches
        .iter()
        .map(|batch| {
            let schema = batch.schema();
            let has_float = schema
                .fields()
                .iter()
                .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
            if !has_float {
                return batch.clone();
            }
            let columns: Vec<arrow_array::ArrayRef> = batch
                .columns()
                .iter()
                .enumerate()
                .map(|(i, col)| {
                    if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
                        // Downcast cannot fail: the schema says Float64.
                        let arr = col
                            .as_any()
                            .downcast_ref::<arrow_array::Float64Array>()
                            .unwrap();
                        let rounded: arrow_array::Float64Array = arr
                            .iter()
                            .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
                            .collect();
                        Arc::new(rounded) as arrow_array::ArrayRef
                    } else {
                        Arc::clone(col)
                    }
                })
                .collect();
            // On the (unexpected) failure to rebuild, keep the original batch.
            RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
        })
        .collect()
}
/// Existing-fact row count at which dedup switches from in-memory row/scalar
/// key tracking to an Arrow LEFT ANTI hash join (see `FixpointState::compute_delta`).
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
/// Deduplicates `candidates` against `existing` with a DataFusion
/// LEFT ANTI hash join over *all* columns (null equals null), returning
/// only candidate rows that have no match in `existing`.
///
/// Fast paths: with no existing rows, all candidates pass through; with a
/// zero-column schema there is nothing to join on and nothing is returned.
async fn arrow_left_anti_dedup(
    candidates: Vec<RecordBatch>,
    existing: &[RecordBatch],
    schema: &SchemaRef,
    task_ctx: &Arc<TaskContext>,
) -> DFResult<Vec<RecordBatch>> {
    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
        return Ok(candidates);
    }
    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
    let right: Arc<dyn ExecutionPlan> =
        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
    // Join on every column: both sides share the same schema.
    let on: Vec<(
        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
    )> = schema
        .fields()
        .iter()
        .enumerate()
        .map(|(i, field)| {
            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
                datafusion::physical_plan::expressions::Column::new(field.name(), i),
            );
            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
                datafusion::physical_plan::expressions::Column::new(field.name(), i),
            );
            (l, r)
        })
        .collect();
    if on.is_empty() {
        return Ok(vec![]);
    }
    let join = HashJoinExec::try_new(
        left,
        right,
        on,
        None,
        &JoinType::LeftAnti,
        None,
        PartitionMode::CollectLeft,
        datafusion::common::NullEquality::NullEqualsNull,
    )?;
    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
    collect_all_partitions(&join_arc, task_ctx.clone()).await
}
/// Binding of one IS-ref (reference to a derived rule) inside a clause body.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Registry index of the referenced derived scan.
    pub derived_scan_index: usize,
    /// Name of the referenced rule.
    pub rule_name: String,
    /// True when the reference is to the rule currently being evaluated.
    pub is_self_ref: bool,
    /// True for negated references (handled via anti-join / prob complement).
    pub negated: bool,
    /// (body column, derived column) pairs used for the anti-join.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the referenced rule carries a probability column.
    pub target_has_prob: bool,
    /// Name of that probability column, if any.
    pub target_prob_col: Option<String>,
    /// (body column, derived column) pairs used for provenance matching.
    pub provenance_join_cols: Vec<(String, String)>,
}
/// One clause of a fixpoint rule: its body plan plus the IS-ref bindings
/// that must be applied to the body's output.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// Logical plan executed to produce this clause's candidate rows.
    pub body_logical: LogicalPlan,
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Optional clause-level priority (presumably for PRIORITY handling — see PriorityExec).
    pub priority: Option<i64>,
    /// Column names captured as ALONG values in provenance annotations.
    pub along_bindings: Vec<String>,
}
/// Full plan for one fixpoint rule: its clauses, output schema, and the
/// post-fixpoint features (fold, having, best-by, priority) it carries.
#[derive(Debug)]
pub struct FixpointRulePlan {
    pub name: String,
    pub clauses: Vec<FixpointClausePlan>,
    /// Schema of the rows the rule yields.
    pub yield_schema: SchemaRef,
    /// Indices of the KEY columns within `yield_schema`.
    pub key_column_indices: Vec<usize>,
    pub priority: Option<i64>,
    pub has_fold: bool,
    pub fold_bindings: Vec<FoldBinding>,
    /// HAVING predicates applied after folding.
    pub having: Vec<Expr>,
    pub has_best_by: bool,
    pub best_by_criteria: Vec<SortCriterion>,
    pub has_priority: bool,
    pub deterministic: bool,
    /// Name of the probability column, when the rule is probabilistic.
    pub prob_column_name: Option<String>,
}
/// Drives the fixpoint evaluation of all `rules` to convergence (or until
/// `max_iterations` / `timeout`), then applies post-fixpoint processing and
/// returns the combined output batches.
///
/// Each iteration re-executes every clause body against the current derived
/// facts (published through `registry`), applies negated IS-ref handling
/// (anti-join or probability complement), merges candidates into the rule's
/// `FixpointState` (BEST BY or delta merge), and records provenance for new
/// rows. Non-convergence and timeouts set `timeout_flag` and return partial
/// results rather than erroring.
#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
async fn run_fixpoint_loop(
    rules: Vec<FixpointRulePlan>,
    max_iterations: usize,
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let task_ctx = session_ctx.read().task_ctx();
    // One FixpointState per rule; monotonic aggregation only for rules
    // that declared fold bindings.
    let mut states: Vec<FixpointState> = rules
        .iter()
        .map(|rule| {
            let monotonic_agg = if !rule.fold_bindings.is_empty() {
                let bindings: Vec<MonotonicFoldBinding> = rule
                    .fold_bindings
                    .iter()
                    .map(|fb| MonotonicFoldBinding {
                        fold_name: fb.output_name.clone(),
                        kind: fb.kind.clone(),
                        input_col_index: fb.input_col_index,
                        input_col_name: fb.input_col_name.clone(),
                    })
                    .collect();
                Some(MonotonicAggState::new(bindings))
            } else {
                None
            };
            FixpointState::new(
                rule.name.clone(),
                Arc::clone(&rule.yield_schema),
                rule.key_column_indices.clone(),
                max_derived_bytes,
                monotonic_agg,
                strict_probability_domain,
            )
        })
        .collect();
    let mut converged = false;
    let mut total_iters = 0usize;
    for iteration in 0..max_iterations {
        total_iters = iteration + 1;
        tracing::debug!("fixpoint iteration {}", iteration);
        let mut any_changed = false;
        for rule_idx in 0..rules.len() {
            let rule = &rules[rule_idx];
            // Publish the current facts so derived scans in clause bodies
            // see this iteration's state.
            update_derived_scan_handles(&registry, &states, rule_idx, &rules);
            let mut all_candidates = Vec::new();
            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
            for clause in &rule.clauses {
                let mut batches = execute_subplan(
                    &clause.body_logical,
                    &params,
                    &HashMap::new(),
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                    &schema_info,
                )
                .await?;
                // Negated IS-refs: probabilistic targets get a complement
                // factor column; otherwise plain anti-join removal.
                for binding in &clause.is_ref_bindings {
                    if binding.negated
                        && !binding.anti_join_cols.is_empty()
                        && let Some(entry) = registry.get(binding.derived_scan_index)
                    {
                        let neg_facts = entry.data.read().clone();
                        if !neg_facts.is_empty() {
                            if binding.target_has_prob && rule.prob_column_name.is_some() {
                                let complement_col =
                                    format!("__prob_complement_{}", binding.rule_name);
                                if let Some(prob_col) = &binding.target_prob_col {
                                    batches = apply_prob_complement_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                        prob_col,
                                        &complement_col,
                                    )?;
                                } else {
                                    batches = apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            } else {
                                batches = apply_anti_join_composite(
                                    batches,
                                    &neg_facts,
                                    &binding.anti_join_cols,
                                )?;
                            }
                        }
                    }
                }
                // Fold any accumulated complement columns into the rule's
                // probability column, then drop them.
                let complement_cols: Vec<String> = if !batches.is_empty() {
                    batches[0]
                        .schema()
                        .fields()
                        .iter()
                        .filter(|f| f.name().starts_with("__prob_complement_"))
                        .map(|f| f.name().clone())
                        .collect()
                } else {
                    vec![]
                };
                if !complement_cols.is_empty() {
                    batches = multiply_prob_factors(
                        batches,
                        rule.prob_column_name.as_deref(),
                        &complement_cols,
                    )?;
                }
                clause_candidates.push(batches.clone());
                all_candidates.extend(batches);
            }
            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
            } else {
                states[rule_idx]
                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
                    .await?
            };
            if changed {
                any_changed = true;
                if let Some(ref tracker) = derivation_tracker {
                    record_provenance(
                        tracker,
                        rule,
                        &states[rule_idx],
                        &clause_candidates,
                        iteration,
                        &registry,
                        top_k_proofs,
                    );
                }
            }
        }
        if !any_changed && states.iter().all(|s| s.is_converged()) {
            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
            converged = true;
            break;
        }
        // Timeout is checked after a full iteration, so partial results are
        // always the product of complete iterations.
        if start.elapsed() > timeout {
            tracing::warn!(
                "fixpoint timeout after {} iterations; returning partial results",
                iteration + 1,
            );
            timeout_flag.store(true, std::sync::atomic::Ordering::Relaxed);
            break;
        }
    }
    if let Ok(mut counts) = iteration_counts.write() {
        for rule in &rules {
            counts.insert(rule.name.clone(), total_iters);
        }
    }
    // Hitting max_iterations without convergence is reported the same way
    // as a timeout.
    if !converged && !timeout_flag.load(std::sync::atomic::Ordering::Relaxed) {
        tracing::warn!(
            "fixpoint did not converge after {max_iterations} iterations; returning partial results",
        );
        timeout_flag.store(true, std::sync::atomic::Ordering::Relaxed);
    }
    let task_ctx = session_ctx.read().task_ctx();
    let mut all_output = Vec::new();
    for (rule_idx, state) in states.into_iter().enumerate() {
        let rule = &rules[rule_idx];
        let mut facts = state.into_facts();
        if facts.is_empty() {
            continue;
        }
        // Shared-lineage detection (and optional exact WMC correction) only
        // applies when provenance was tracked.
        let shared_info = if let Some(ref tracker) = derivation_tracker {
            detect_shared_lineage(rule, &facts, tracker, &warnings_slot)
        } else {
            None
        };
        if exact_probability
            && let Some(ref info) = shared_info
            && let Some(ref tracker) = derivation_tracker
        {
            facts = apply_exact_wmc(
                facts,
                rule,
                info,
                tracker,
                max_bdd_variables,
                &warnings_slot,
                &approximate_slot,
            )?;
        }
        let processed = apply_post_fixpoint_chain(
            facts,
            rule,
            &task_ctx,
            strict_probability_domain,
            probability_epsilon,
        )
        .await?;
        all_output.extend(processed);
    }
    if all_output.is_empty() {
        all_output.push(RecordBatch::new_empty(output_schema));
    }
    Ok(all_output)
}
/// Records a `ProvenanceAnnotation` for every row in the rule's latest delta.
///
/// The row identity is the Debug formatting of its whole-row scalar key.
/// Each annotation captures the originating clause, the IS-ref support facts,
/// the clause's ALONG column values, and — when `top_k_proofs > 0` — a proof
/// probability computed from base-fact probabilities (stored via the
/// tracker's top-k path instead of plain `record`).
fn record_provenance(
    tracker: &Arc<ProvenanceStore>,
    rule: &FixpointRulePlan,
    state: &FixpointState,
    clause_candidates: &[Vec<RecordBatch>],
    iteration: usize,
    registry: &Arc<DerivedScanRegistry>,
    top_k_proofs: usize,
) {
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
    let base_probs = if top_k_proofs > 0 {
        tracker.base_fact_probs()
    } else {
        HashMap::new()
    };
    for delta_batch in state.all_delta() {
        for row_idx in 0..delta_batch.num_rows() {
            let row_hash = format!(
                "{:?}",
                extract_scalar_key(delta_batch, &all_indices, row_idx)
            )
            .into_bytes();
            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
            let clause_index =
                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
            let proof_probability = if top_k_proofs > 0 {
                compute_proof_probability(&support, &base_probs)
            } else {
                None
            };
            let entry = ProvenanceAnnotation {
                rule_name: rule.name.clone(),
                clause_index,
                support,
                // ALONG values: only columns named in the clause's bindings,
                // pulled from this row's value map.
                along_values: {
                    let along_names: Vec<String> = rule
                        .clauses
                        .get(clause_index)
                        .map(|c| c.along_bindings.clone())
                        .unwrap_or_default();
                    along_names
                        .iter()
                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
                        .collect()
                },
                iteration,
                fact_row,
                proof_probability,
            };
            if top_k_proofs > 0 {
                tracker.record_top_k(row_hash, entry, top_k_proofs);
            } else {
                tracker.record(row_hash, entry);
            }
        }
    }
}
/// Resolves the supporting facts behind one derived row: for every
/// non-negated IS-ref binding of the clause, finds the rows in the
/// referenced derived scan whose provenance-join columns match the delta
/// row's values, and returns them as `ProofTerm`s.
///
/// Bindings that are negated, have no provenance-join columns, or whose
/// body columns cannot all be resolved from the delta schema are skipped.
/// NOTE(review): matching is a linear scan of every source row per delta
/// row — O(delta × source); acceptable while derived scans stay small.
fn collect_is_ref_inputs(
    rule: &FixpointRulePlan,
    clause_index: usize,
    delta_batch: &RecordBatch,
    row_idx: usize,
    registry: &Arc<DerivedScanRegistry>,
) -> Vec<ProofTerm> {
    let clause = match rule.clauses.get(clause_index) {
        Some(c) => c,
        None => return vec![],
    };
    let mut inputs = Vec::new();
    let delta_schema = delta_batch.schema();
    for binding in &clause.is_ref_bindings {
        if binding.negated {
            continue;
        }
        if binding.provenance_join_cols.is_empty() {
            continue;
        }
        // Pull this row's value for each body-side join column.
        let body_values: Vec<(String, ScalarKey)> = binding
            .provenance_join_cols
            .iter()
            .filter_map(|(body_col, _derived_col)| {
                let col_idx = delta_schema
                    .fields()
                    .iter()
                    .position(|f| f.name() == body_col)?;
                let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
                Some((body_col.clone(), key.into_iter().next()?))
            })
            .collect();
        // Require every join column to resolve; otherwise skip the binding.
        if body_values.len() != binding.provenance_join_cols.len() {
            continue;
        }
        let entry = match registry.get(binding.derived_scan_index) {
            Some(e) => e,
            None => continue,
        };
        let source_batches = entry.data.read();
        let source_schema = &entry.schema;
        for src_batch in source_batches.iter() {
            let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
            for src_row in 0..src_batch.num_rows() {
                // Row matches when every derived-side column equals the
                // corresponding body-side value.
                let matches = binding.provenance_join_cols.iter().enumerate().all(
                    |(i, (_body_col, derived_col))| {
                        let src_col_idx = source_schema
                            .fields()
                            .iter()
                            .position(|f| f.name() == derived_col);
                        match src_col_idx {
                            Some(idx) => {
                                let src_key = extract_scalar_key(src_batch, &[idx], src_row);
                                src_key.first() == Some(&body_values[i].1)
                            }
                            None => false,
                        }
                    },
                );
                if matches {
                    let fact_hash = format!(
                        "{:?}",
                        extract_scalar_key(src_batch, &all_src_indices, src_row)
                    )
                    .into_bytes();
                    inputs.push(ProofTerm {
                        source_rule: binding.rule_name.clone(),
                        base_fact_id: fact_hash,
                    });
                }
            }
        }
    }
    inputs
}
/// One row of a KEY group with shared lineage: the row's identity hash and
/// the set of base-fact hashes it transitively depends on.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    pub fact_hash: Vec<u8>,
    pub lineage: HashSet<Vec<u8>>,
}
/// KEY groups whose rows share probabilistic lineage, as detected by
/// `detect_shared_lineage`; consumed by exact-WMC correction.
pub(crate) struct SharedLineageInfo {
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
/// Builds a stable row identity by Debug-formatting the row's scalar key
/// over the given column indices and taking the UTF-8 bytes.
fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
    let key = extract_scalar_key(batch, all_indices, row_idx);
    format!("{key:?}").into_bytes()
}
/// Detects violations of the independence assumption behind MNOR/MPROD
/// aggregation.
///
/// Only runs for rules with a probabilistic fold. Rows are grouped by the
/// rule's KEY columns; within a multi-row group, lineage (transitive base
/// facts) is computed per row and the group is flagged when two rows share
/// a base fact. When no row has recorded support, a weaker heuristic flags
/// groups containing any row derived through a non-negated IS-ref clause.
///
/// Side effects: pushes at most one `CrossGroupCorrelationNotExact` warning
/// (base facts shared across *different* KEY groups) and one
/// `SharedProbabilisticDependency` warning per rule into `warnings_slot`.
/// Returns `Some(SharedLineageInfo)` with the flagged groups, or `None`
/// when nothing is shared (or the rule has no probabilistic fold).
fn detect_shared_lineage(
    rule: &FixpointRulePlan,
    pre_fold_facts: &[RecordBatch],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
) -> Option<SharedLineageInfo> {
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};
    let has_prob_fold = rule
        .fold_bindings
        .iter()
        .any(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    if !has_prob_fold {
        return None;
    }
    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
    // Bucket every fact row's identity hash under its KEY group.
    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
    for batch in pre_fold_facts {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
            groups.entry(key).or_default().push(fact_hash);
        }
    }
    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
    let mut any_shared = false;
    for (key, fact_hashes) in &groups {
        // Single-row groups cannot share lineage internally.
        if fact_hashes.len() < 2 {
            continue;
        }
        let mut has_inputs = false;
        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
        for fh in fact_hashes {
            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
            if let Some(entry) = tracker.lookup(fh)
                && !entry.support.is_empty()
            {
                has_inputs = true;
            }
            per_row_bases.push(bases);
        }
        let shared_found = if has_inputs {
            // Precise check: any pairwise overlap between base-fact sets.
            let mut found = false;
            'outer: for i in 0..per_row_bases.len() {
                for j in (i + 1)..per_row_bases.len() {
                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
                        found = true;
                        break 'outer;
                    }
                }
            }
            found
        } else {
            // No support recorded: conservatively flag groups whose rows came
            // from clauses with non-negated IS-refs.
            fact_hashes.iter().any(|fh| {
                tracker.lookup(fh).is_some_and(|entry| {
                    rule.clauses
                        .get(entry.clause_index)
                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
                })
            })
        };
        if shared_found {
            any_shared = true;
            let rows: Vec<SharedGroupRow> = fact_hashes
                .iter()
                .zip(per_row_bases.into_iter())
                .map(|(fh, bases)| SharedGroupRow {
                    fact_hash: fh.clone(),
                    lineage: bases,
                })
                .collect();
            shared_groups.insert(key.clone(), rows);
        }
    }
    // Cross-group check: one base fact feeding more than one KEY group cannot
    // be corrected per-group; warn once per rule.
    {
        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
        for (key, fact_hashes) in &groups {
            for fh in fact_hashes {
                if let Some(entry) = tracker.lookup(fh) {
                    for input in &entry.support {
                        input_to_groups
                            .entry(input.base_fact_id.clone())
                            .or_default()
                            .insert(key.clone());
                    }
                }
            }
        }
        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
                    message: format!(
                        "Rule '{}': IS-ref base facts are shared across different KEY \
                         groups. BDD corrects per-group probabilities but cannot account \
                         for cross-group correlations.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
    }
    if any_shared {
        if let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::SharedProbabilisticDependency
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::SharedProbabilisticDependency,
                    message: format!(
                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
                         share intermediate facts, violating the independence assumption. \
                         Results may overestimate probability.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
        Some(SharedLineageInfo { shared_groups })
    } else {
        None
    }
}
/// Records provenance for every fact produced by a non-recursive rule, then
/// scans the recorded facts for proofs that share base facts within a KEY
/// group.
///
/// For each row of each clause's output batches, a `ProvenanceAnnotation` is
/// stored in `tracker`, keyed by a hash over all columns of the row. When
/// `top_k_proofs > 0`, a per-proof probability is computed from the tracker's
/// base-fact probabilities and only the top-k proofs per fact are retained;
/// otherwise every proof is recorded.
///
/// Returns whatever `detect_shared_lineage` finds over the flattened clause
/// outputs (`None` when no KEY group shares probabilistic dependencies).
pub(crate) fn record_and_detect_lineage_nonrecursive(
    rule: &FixpointRulePlan,
    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    registry: &Arc<DerivedScanRegistry>,
    top_k_proofs: usize,
) -> Option<SharedLineageInfo> {
    // Hash every column so the fact hash identifies the whole row.
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
    // Base-fact probabilities are only needed for top-k proof scoring.
    let base_probs = if top_k_proofs > 0 {
        tracker.base_fact_probs()
    } else {
        HashMap::new()
    };
    for (clause_index, batches) in tagged_clause_facts {
        for batch in batches {
            for row_idx in 0..batch.num_rows() {
                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
                let fact_row = batch_row_to_value_map(batch, row_idx);
                // IS-ref inputs supporting this row, resolved via the registry.
                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
                let proof_probability = if top_k_proofs > 0 {
                    compute_proof_probability(&support, &base_probs)
                } else {
                    None
                };
                let entry = ProvenanceAnnotation {
                    rule_name: rule.name.clone(),
                    clause_index: *clause_index,
                    support,
                    // Capture values of the clause's ALONG-bound variables
                    // that are present in this row.
                    along_values: {
                        let along_names: Vec<String> = rule
                            .clauses
                            .get(*clause_index)
                            .map(|c| c.along_bindings.clone())
                            .unwrap_or_default();
                        along_names
                            .iter()
                            .filter_map(|name| {
                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
                            })
                            .collect()
                    },
                    // Non-recursive rules derive everything in "iteration 0".
                    iteration: 0,
                    fact_row,
                    proof_probability,
                };
                if top_k_proofs > 0 {
                    tracker.record_top_k(row_hash, entry, top_k_proofs);
                } else {
                    tracker.record(row_hash, entry);
                }
            }
        }
    }
    // Flatten all clause outputs and run shared-lineage detection over them.
    let all_facts: Vec<RecordBatch> = tagged_clause_facts
        .iter()
        .flat_map(|(_, batches)| batches.iter().cloned())
        .collect();
    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot)
}
/// Replaces independence-assumption probabilities with exact weighted model
/// counts (computed via BDDs) for KEY groups whose proofs share base facts.
///
/// No-op unless the rule folds with MNOR (disjunction semiring) or MPROD
/// (conjunction semiring). For each shared group, every row's lineage is
/// expanded down to base facts and their probabilities are read from the
/// provenance tracker; the group is then collapsed to a single representative
/// row carrying the BDD-computed probability. Groups whose BDD would exceed
/// `max_bdd_variables` keep all original rows (independence fallback), emit a
/// `BddLimitExceeded` warning once per (rule, key group), and are recorded in
/// `approximate_slot`. Rows outside any shared group pass through unchanged.
pub(crate) fn apply_exact_wmc(
    pre_fold_facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    shared_info: &SharedLineageInfo,
    tracker: &Arc<ProvenanceStore>,
    max_bdd_variables: usize,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
) -> DFResult<Vec<RecordBatch>> {
    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};
    // Only MNOR/MPROD folds carry the probabilistic semantics that need WMC.
    let prob_fold = rule
        .fold_bindings
        .iter()
        .find(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    let prob_fold = match prob_fold {
        Some(f) => f,
        None => return Ok(pre_fold_facts),
    };
    // MNOR aggregates as a disjunction (noisy-or), MPROD as a conjunction.
    let semiring_op = if matches!(prob_fold.kind, FoldAggKind::Nor) {
        SemiringOp::Disjunction
    } else {
        SemiringOp::Conjunction
    };
    let prob_col_idx = prob_fold.input_col_index;
    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();
    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();
    // Per-group accumulator: each row's base-fact set, the probability of
    // every distinct base fact, one representative row location, and the
    // locations of all rows in the group.
    struct GroupAccum {
        base_facts: Vec<HashSet<Vec<u8>>>,
        base_probs: HashMap<Vec<u8>, f64>,
        representative: (usize, usize),
        row_locations: Vec<(usize, usize)>,
    }
    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            if shared_keys.contains(&key) {
                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());
                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
                    base_facts: Vec::new(),
                    base_probs: HashMap::new(),
                    representative: (batch_idx, row_idx),
                    row_locations: Vec::new(),
                });
                // Look up each base fact's probability (from its stored prob
                // column value) the first time we encounter it.
                for bf in &bases {
                    if !accum.base_probs.contains_key(bf)
                        && let Some(entry) = tracker.lookup(bf)
                        && let Some(val) = entry.fact_row.get(&prob_col_name)
                        && let Some(p) = value_to_f64(val)
                    {
                        accum.base_probs.insert(bf.clone(), p);
                    }
                }
                accum.base_facts.push(bases);
                accum.row_locations.push((batch_idx, row_idx));
            } else {
                non_shared_rows.push((batch_idx, row_idx));
            }
        }
    }
    // keep_rows: (batch, row) positions that survive; overrides: exact
    // probabilities to patch into the surviving representative rows.
    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();
    for &loc in &non_shared_rows {
        keep_rows.insert(loc);
    }
    for (key, accum) in &group_accums {
        let bdd_result = weighted_model_count(
            &accum.base_facts,
            &accum.base_probs,
            semiring_op,
            max_bdd_variables,
        );
        if bdd_result.approximated {
            // BDD too large: keep every original row (independence fallback)
            // and surface a warning once per (rule, key group).
            if let Ok(mut warnings) = warnings_slot.write() {
                let key_desc = format!("{key:?}");
                let already_warned = warnings.iter().any(|w| {
                    w.code == RuntimeWarningCode::BddLimitExceeded
                        && w.rule_name == rule.name
                        && w.key_group.as_deref() == Some(&key_desc)
                });
                if !already_warned {
                    warnings.push(RuntimeWarning {
                        code: RuntimeWarningCode::BddLimitExceeded,
                        message: format!(
                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
                             Falling back to independence-mode result.",
                            rule.name, bdd_result.variable_count, max_bdd_variables
                        ),
                        rule_name: rule.name.clone(),
                        variable_count: Some(bdd_result.variable_count),
                        key_group: Some(key_desc),
                    });
                }
            }
            // Remember which key groups are only approximate.
            if let Ok(mut approx) = approximate_slot.write() {
                let key_desc = format!("{key:?}");
                approx.entry(rule.name.clone()).or_default().push(key_desc);
            }
            for &loc in &accum.row_locations {
                keep_rows.insert(loc);
            }
        } else {
            // Exact result: collapse the group to its representative row and
            // override that row's probability with the WMC value.
            keep_rows.insert(accum.representative);
            overrides.insert(accum.representative, bdd_result.probability);
        }
    }
    // Rebuild output batches: keep the surviving rows and patch overridden
    // probabilities into the probability column.
    let mut result_batches = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        let kept_indices: Vec<usize> = (0..batch.num_rows())
            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
            .collect();
        if kept_indices.is_empty() {
            continue;
        }
        let indices = arrow::array::UInt32Array::from(
            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
        );
        let mut columns: Vec<arrow::array::ArrayRef> = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col, &indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(arrow_err)?;
        let override_map: Vec<Option<f64>> = kept_indices
            .iter()
            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
            .collect();
        if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
            let existing_prob = columns[prob_col_idx]
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>();
            let new_values: Vec<f64> = override_map
                .iter()
                .enumerate()
                .map(|(i, ov)| match ov {
                    Some(p) => *p,
                    // Rows without an override keep their existing value
                    // (0.0 if the column is not actually Float64).
                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
                })
                .collect();
            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
        }
        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
        result_batches.push(result_batch);
    }
    Ok(result_batches)
}
fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
match val {
uni_common::Value::Float(f) => Some(*f),
uni_common::Value::Int(i) => Some(*i as f64),
_ => None,
}
}
/// Expands a fact's derivation chain down to its base facts.
///
/// A fact with recorded support contributes the union of its inputs' base
/// facts; a fact without support (or unknown to the tracker) is its own base
/// fact. `visited` breaks cycles: a hash already seen contributes nothing.
fn compute_lineage(
    fact_hash: &[u8],
    tracker: &Arc<ProvenanceStore>,
    visited: &mut HashSet<Vec<u8>>,
) -> HashSet<Vec<u8>> {
    // A hash we have already expanded (or are currently expanding) must not
    // be expanded again — this terminates recursive derivations.
    if !visited.insert(fact_hash.to_vec()) {
        return HashSet::new();
    }
    if let Some(entry) = tracker.lookup(fact_hash) {
        if !entry.support.is_empty() {
            // Derived fact: union the lineage of every supporting input.
            return entry
                .support
                .iter()
                .flat_map(|input| compute_lineage(&input.base_fact_id, tracker, visited))
                .collect();
        }
    }
    // Base fact (no recorded support): it is its own lineage.
    std::iter::once(fact_hash.to_vec()).collect()
}
/// Locates the clause whose candidate output contains the given row of
/// `delta_batch`, comparing the scalar key built from all columns.
///
/// Returns the index of the first matching clause, or 0 when no candidate
/// batch contains the row.
fn find_clause_for_row(
    delta_batch: &RecordBatch,
    row_idx: usize,
    all_indices: &[usize],
    clause_candidates: &[Vec<RecordBatch>],
) -> usize {
    let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
    clause_candidates
        .iter()
        .position(|batches| {
            batches.iter().any(|batch| {
                // A batch whose column count differs cannot hold the key.
                batch.num_columns() == all_indices.len()
                    && (0..batch.num_rows())
                        .any(|r| extract_scalar_key(batch, all_indices, r) == target_key)
            })
        })
        .unwrap_or(0)
}
/// Converts one row of a record batch into a column-name -> `Value` map.
fn batch_row_to_value_map(
    batch: &RecordBatch,
    row_idx: usize,
) -> std::collections::HashMap<String, Value> {
    use uni_store::storage::arrow_convert::arrow_to_value;
    let schema = batch.schema();
    let mut row = std::collections::HashMap::new();
    for (col_idx, field) in schema.fields().iter().enumerate() {
        let value = arrow_to_value(batch.column(col_idx).as_ref(), row_idx, None);
        row.insert(field.name().clone(), value);
    }
    row
}
/// Removes from `batches` every row whose `left_col` id appears among the
/// negative facts' `right_col` ids (a hash anti-join on one u64 column).
///
/// Batches whose join column cannot be resolved or downcast to `UInt64`
/// pass through unchanged; null ids are never banned and never filtered.
/// Batches that filter down to zero rows are dropped from the output.
pub fn apply_anti_join(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    left_col: &str,
    right_col: &str,
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow::compute::filter_record_batch;
    use arrow_array::{Array as _, BooleanArray, UInt64Array};
    // Collect every non-null id listed by the negative facts.
    let mut banned = std::collections::HashSet::<u64>::new();
    for batch in neg_facts {
        let Ok(col_idx) = batch.schema().index_of(right_col) else {
            continue;
        };
        let arr = batch.column(col_idx);
        let Some(ids) = arr.as_any().downcast_ref::<UInt64Array>() else {
            continue;
        };
        banned.extend(
            (0..ids.len())
                .filter(|&i| !ids.is_null(i))
                .map(|i| ids.value(i)),
        );
    }
    // Nothing to subtract: hand the input back untouched.
    if banned.is_empty() {
        return Ok(batches);
    }
    let mut kept = Vec::new();
    for batch in batches {
        let Ok(col_idx) = batch.schema().index_of(left_col) else {
            kept.push(batch);
            continue;
        };
        let arr = batch.column(col_idx);
        let Some(ids) = arr.as_any().downcast_ref::<UInt64Array>() else {
            kept.push(batch);
            continue;
        };
        // Null ids survive; everything else must be absent from `banned`.
        let mask: Vec<bool> = (0..ids.len())
            .map(|i| ids.is_null(i) || !banned.contains(&ids.value(i)))
            .collect();
        let filtered =
            filter_record_batch(&batch, &BooleanArray::from(mask)).map_err(arrow_err)?;
        if filtered.num_rows() > 0 {
            kept.push(filtered);
        }
    }
    Ok(kept)
}
/// Appends a nullable Float64 column named `complement_col_name` holding,
/// per row, `1 - P(banned)` for the row's `left_col` id, where `P(banned)`
/// is the noisy-or combination (`p = 1 - prod(1 - p_i)`) of the `prob_col`
/// probabilities of all negative facts matching on `right_col`.
///
/// Rows with a null id (or an id with no negative match) get complement 1.0.
/// Batches lacking `left_col`, or whose id column is not `UInt64`, pass
/// through unchanged without the extra column.
pub fn apply_prob_complement(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    left_col: &str,
    right_col: &str,
    prob_col: &str,
    complement_col_name: &str,
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow_array::{Array as _, Float64Array, UInt64Array};
    // Aggregate negative-fact probabilities per id (noisy-or across duplicates).
    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
    for batch in neg_facts {
        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
            continue;
        };
        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
            continue;
        };
        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
            continue;
        };
        let prob_arr = batch.column(prob_idx);
        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
        for i in 0..vids.len() {
            if !vids.is_null(i) {
                // Null or non-Float64 probabilities count as 0.0.
                let p = probs
                    .and_then(|arr| {
                        if arr.is_null(i) {
                            None
                        } else {
                            Some(arr.value(i))
                        }
                    })
                    .unwrap_or(0.0);
                prob_map
                    .entry(vids.value(i))
                    .and_modify(|existing| {
                        // Noisy-or accumulation for ids seen more than once.
                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
                    })
                    .or_insert(p);
            }
        }
    }
    let mut result = Vec::new();
    for batch in batches {
        let Ok(idx) = batch.schema().index_of(left_col) else {
            // No join column: pass through without a complement column.
            result.push(batch);
            continue;
        };
        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
            result.push(batch);
            continue;
        };
        let complements: Vec<f64> = (0..vids.len())
            .map(|i| {
                if vids.is_null(i) {
                    1.0
                } else {
                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
                    1.0 - p
                }
            })
            .collect();
        // Extend the batch with the complement column (nullable Float64).
        let complement_arr = Float64Array::from(complements);
        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
        columns.push(std::sync::Arc::new(complement_arr));
        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
            batch.schema().fields().iter().cloned().collect();
        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
            complement_col_name,
            arrow_schema::DataType::Float64,
            true,
        )));
        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
        result.push(new_batch);
    }
    Ok(result)
}
/// Multi-column variant of `apply_prob_complement`: matches negative facts
/// on a composite key of several u64 id columns.
///
/// `join_cols` pairs (left batch column, negative-fact column). All paired
/// columns must resolve and be `UInt64` for a batch to participate. For each
/// row, the complement `1 - P(banned)` of the noisy-or-combined negative
/// probabilities for the row's composite key is appended as a nullable
/// Float64 column named `complement_col_name`. Rows with any null key
/// component get complement 1.0; non-conforming batches pass through
/// unchanged.
pub fn apply_prob_complement_composite(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    join_cols: &[(String, String)],
    prob_col: &str,
    complement_col_name: &str,
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow_array::{Array as _, Float64Array, UInt64Array};
    // Composite key -> noisy-or-combined probability of the negative facts.
    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
    for batch in neg_facts {
        let right_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
            .collect();
        // Every right-side column must resolve, otherwise skip the batch.
        if right_indices.len() != join_cols.len() {
            continue;
        }
        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
            continue;
        };
        let prob_arr = batch.column(prob_idx);
        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
        for row in 0..batch.num_rows() {
            // Assemble the composite key; a null component or a non-u64
            // column invalidates the whole row.
            let mut key = Vec::with_capacity(right_indices.len());
            let mut valid = true;
            for &ci in &right_indices {
                let col = batch.column(ci);
                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
                    if vids.is_null(row) {
                        valid = false;
                        break;
                    }
                    key.push(vids.value(row));
                } else {
                    valid = false;
                    break;
                }
            }
            if !valid {
                continue;
            }
            // Null or non-Float64 probabilities count as 0.0.
            let p = probs
                .and_then(|arr| {
                    if arr.is_null(row) {
                        None
                    } else {
                        Some(arr.value(row))
                    }
                })
                .unwrap_or(0.0);
            prob_map
                .entry(key)
                .and_modify(|existing| {
                    // Noisy-or accumulation across duplicate composite keys.
                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
                })
                .or_insert(p);
        }
    }
    let mut result = Vec::new();
    for batch in batches {
        let left_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
            .collect();
        if left_indices.len() != join_cols.len() {
            // Missing join columns: pass through without a complement column.
            result.push(batch);
            continue;
        }
        // All key columns must be UInt64 for the unchecked downcasts below.
        let all_u64 = left_indices.iter().all(|&ci| {
            batch
                .column(ci)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .is_some()
        });
        if !all_u64 {
            result.push(batch);
            continue;
        }
        let complements: Vec<f64> = (0..batch.num_rows())
            .map(|row| {
                let mut key = Vec::with_capacity(left_indices.len());
                for &ci in &left_indices {
                    // Safe: `all_u64` verified this downcast for every column.
                    let vids = batch
                        .column(ci)
                        .as_any()
                        .downcast_ref::<UInt64Array>()
                        .unwrap();
                    if vids.is_null(row) {
                        // Null key component means "no match": complement 1.0.
                        return 1.0;
                    }
                    key.push(vids.value(row));
                }
                let p = prob_map.get(&key).copied().unwrap_or(0.0);
                1.0 - p
            })
            .collect();
        // Extend the batch with the complement column (nullable Float64).
        let complement_arr = Float64Array::from(complements);
        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
        columns.push(Arc::new(complement_arr));
        let mut fields: Vec<Arc<arrow_schema::Field>> =
            batch.schema().fields().iter().cloned().collect();
        fields.push(Arc::new(arrow_schema::Field::new(
            complement_col_name,
            arrow_schema::DataType::Float64,
            true,
        )));
        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
        result.push(new_batch);
    }
    Ok(result)
}
/// Multi-column variant of `apply_anti_join`: drops rows whose composite key
/// (over the `join_cols` left columns) appears among the negative facts'
/// composite keys (right columns).
///
/// Keys containing a null component are never banned and never filtered.
/// Batches whose join columns cannot all be resolved as `UInt64` pass
/// through unchanged; batches that filter to zero rows are dropped.
pub fn apply_anti_join_composite(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    join_cols: &[(String, String)],
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow::compute::filter_record_batch;
    use arrow_array::{Array as _, BooleanArray, UInt64Array};
    // Composite keys listed by the negative facts.
    let mut banned: HashSet<Vec<u64>> = HashSet::new();
    for batch in neg_facts {
        let right_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
            .collect();
        // Every right-side column must resolve, otherwise skip the batch.
        if right_indices.len() != join_cols.len() {
            continue;
        }
        for row in 0..batch.num_rows() {
            // Assemble the composite key; a null component or a non-u64
            // column invalidates the whole row.
            let mut key = Vec::with_capacity(right_indices.len());
            let mut valid = true;
            for &ci in &right_indices {
                let col = batch.column(ci);
                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
                    if vids.is_null(row) {
                        valid = false;
                        break;
                    }
                    key.push(vids.value(row));
                } else {
                    valid = false;
                    break;
                }
            }
            if valid {
                banned.insert(key);
            }
        }
    }
    if banned.is_empty() {
        return Ok(batches);
    }
    let mut result = Vec::new();
    for batch in batches {
        let left_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
            .collect();
        if left_indices.len() != join_cols.len() {
            // Missing join columns: nothing to anti-join on.
            result.push(batch);
            continue;
        }
        // All key columns must be UInt64 for the unchecked downcasts below.
        let all_u64 = left_indices.iter().all(|&ci| {
            batch
                .column(ci)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .is_some()
        });
        if !all_u64 {
            result.push(batch);
            continue;
        }
        let keep: Vec<bool> = (0..batch.num_rows())
            .map(|row| {
                let mut key = Vec::with_capacity(left_indices.len());
                for &ci in &left_indices {
                    // Safe: `all_u64` verified this downcast for every column.
                    let vids = batch
                        .column(ci)
                        .as_any()
                        .downcast_ref::<UInt64Array>()
                        .unwrap();
                    if vids.is_null(row) {
                        // Null key component: never banned, always kept.
                        return true;
                    }
                    key.push(vids.value(row));
                }
                !banned.contains(&key)
            })
            .collect();
        let keep_arr = BooleanArray::from(keep);
        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
        if filtered.num_rows() > 0 {
            result.push(filtered);
        }
    }
    Ok(result)
}
/// Folds per-row complement factors into a single probability column and
/// removes the auxiliary complement columns.
///
/// Each column named in `complement_cols` must be `Float64`; its non-null
/// values are multiplied row-wise into a running product (starting at 1.0,
/// so null factors contribute nothing). When `prob_col` names an existing
/// `Float64` column, its non-null values are multiplied in as well and the
/// column is replaced by the combined product. Empty batches are passed
/// through with the complement columns projected away so the output schema
/// matches the non-empty case.
///
/// Errors when a named complement/PROB column exists but is not `Float64`,
/// or when rebuilding a record batch fails.
pub fn multiply_prob_factors(
    batches: Vec<RecordBatch>,
    prob_col: Option<&str>,
    complement_cols: &[String],
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow_array::{Array as _, Float64Array};
    let mut result = Vec::with_capacity(batches.len());
    for batch in batches {
        if batch.num_rows() == 0 {
            // Zero-row batch: just drop the complement columns.
            let keep: Vec<usize> = batch
                .schema()
                .fields()
                .iter()
                .enumerate()
                .filter(|(_, f)| !complement_cols.contains(f.name()))
                .map(|(i, _)| i)
                .collect();
            let fields: Vec<_> = keep
                .iter()
                .map(|&i| batch.schema().field(i).clone())
                .collect();
            let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
            let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
            // Use the shared arrow_err converter, consistent with the
            // non-empty path below and the rest of this module.
            result.push(RecordBatch::try_new(schema, cols).map_err(arrow_err)?);
            continue;
        }
        let num_rows = batch.num_rows();
        // Running product of all complement factors, per row.
        let mut combined = vec![1.0f64; num_rows];
        for col_name in complement_cols {
            if let Ok(idx) = batch.schema().index_of(col_name) {
                let arr = batch
                    .column(idx)
                    .as_any()
                    .downcast_ref::<Float64Array>()
                    .ok_or_else(|| {
                        datafusion::error::DataFusionError::Internal(format!(
                            "Expected Float64 for complement column {col_name}"
                        ))
                    })?;
                // Null factors are treated as 1.0 (no contribution).
                for (i, val) in combined.iter_mut().enumerate() {
                    if !arr.is_null(i) {
                        *val *= arr.value(i);
                    }
                }
            }
        }
        // Fold the existing PROB column (when present) into the product;
        // null probabilities contribute only the complement product.
        let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
            if let Ok(idx) = batch.schema().index_of(prob_name) {
                let arr = batch
                    .column(idx)
                    .as_any()
                    .downcast_ref::<Float64Array>()
                    .ok_or_else(|| {
                        datafusion::error::DataFusionError::Internal(format!(
                            "Expected Float64 for PROB column {prob_name}"
                        ))
                    })?;
                (0..num_rows)
                    .map(|i| {
                        if arr.is_null(i) {
                            combined[i]
                        } else {
                            arr.value(i) * combined[i]
                        }
                    })
                    .collect()
            } else {
                combined
            }
        } else {
            combined
        };
        let new_prob_array: arrow_array::ArrayRef =
            std::sync::Arc::new(Float64Array::from(final_prob));
        // Rebuild the batch: drop complement columns, substitute the combined
        // probabilities for the PROB column, keep everything else as-is.
        let mut fields = Vec::new();
        let mut columns = Vec::new();
        for (idx, field) in batch.schema().fields().iter().enumerate() {
            if complement_cols.contains(field.name()) {
                continue;
            }
            fields.push(field.clone());
            if prob_col.is_some_and(|p| field.name() == p) {
                columns.push(new_prob_array.clone());
            } else {
                columns.push(batch.column(idx).clone());
            }
        }
        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
        result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
    }
    Ok(result)
}
/// Refreshes every registered derived-scan handle with the latest fixpoint
/// state before a rule is (re)evaluated.
///
/// A self-referential scan (the entry's rule is the rule currently being
/// evaluated) sees only the previous iteration's delta (semi-naive
/// evaluation); scans over other rules see all facts accumulated so far.
/// Empty data is replaced by a single empty batch so downstream plans keep a
/// valid schema. Entries whose rule name matches no known rule are skipped.
fn update_derived_scan_handles(
    registry: &DerivedScanRegistry,
    states: &[FixpointState],
    current_rule_idx: usize,
    rules: &[FixpointRulePlan],
) {
    let current_rule_name = &rules[current_rule_idx].name;
    // NOTE: the original text contained a mis-encoded `&registry` here
    // (HTML-entity garbling); restored to a plain reference.
    for entry in &registry.entries {
        // Find the fixpoint state belonging to the rule this scan reads.
        let Some(source_idx) = rules.iter().position(|r| r.name == entry.rule_name) else {
            continue;
        };
        let is_self = entry.rule_name == *current_rule_name;
        // Self-recursion reads only the previous delta; cross-rule scans
        // read everything derived so far.
        let data = if is_self {
            states[source_idx].all_delta().to_vec()
        } else {
            states[source_idx].all_facts().to_vec()
        };
        // Guarantee at least one (possibly empty) batch so the scan's schema
        // survives even when nothing has been derived yet.
        let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
        } else {
            data
        };
        *entry.data.write() = data;
    }
}
/// Leaf execution node that scans the current contents of a derived-scan
/// handle: shared, lock-protected record batches that the fixpoint loop
/// refreshes between iterations (see `update_derived_scan_handles`).
pub struct DerivedScanExec {
    /// Shared batch storage, swapped out by the fixpoint loop.
    data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Output schema of the scanned rule.
    schema: SchemaRef,
    /// Cached DataFusion plan properties derived from `schema`.
    properties: PlanProperties,
}
impl DerivedScanExec {
    /// Builds a scan over the shared `data` handle; plan properties are
    /// computed once from `schema` so `properties()` can return a reference.
    pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
        Self {
            properties: compute_plan_properties(Arc::clone(&schema)),
            data,
            schema,
        }
    }
}
impl fmt::Debug for DerivedScanExec {
    // Same output as the original: only the schema field is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("DerivedScanExec");
        dbg.field("schema", &self.schema);
        dbg.finish()
    }
}
impl DisplayAs for DerivedScanExec {
    // The display form is just the node name, regardless of format type.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("DerivedScanExec")
    }
}
impl ExecutionPlan for DerivedScanExec {
    fn name(&self) -> &str {
        "DerivedScanExec"
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }
    // Leaf node: no children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        Vec::new()
    }
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }
    /// Snapshots the shared handle under a read lock and streams the copy;
    /// an empty handle yields a single empty batch so the schema survives.
    fn execute(
        &self,
        _partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let guard = self.data.read();
        let snapshot = if guard.is_empty() {
            vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
        } else {
            guard.clone()
        };
        // Release the lock before constructing the stream.
        drop(guard);
        let stream = MemoryStream::try_new(snapshot, Arc::clone(&self.schema), None)?;
        Ok(Box::pin(stream))
    }
}
/// Simple leaf node that replays a fixed set of in-memory record batches on
/// every `execute` call.
struct InMemoryExec {
    /// Batches to replay.
    batches: Vec<RecordBatch>,
    /// Schema reported for (and shared by) the batches.
    schema: SchemaRef,
    /// Cached DataFusion plan properties derived from `schema`.
    properties: PlanProperties,
}
impl InMemoryExec {
    /// Wraps the given batches; plan properties are computed once from
    /// `schema` so `properties()` can return a reference.
    fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
        Self {
            properties: compute_plan_properties(Arc::clone(&schema)),
            batches,
            schema,
        }
    }
}
impl fmt::Debug for InMemoryExec {
    // Same output as the original: batch count plus schema.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("InMemoryExec");
        dbg.field("num_batches", &self.batches.len());
        dbg.field("schema", &self.schema);
        dbg.finish()
    }
}
impl DisplayAs for InMemoryExec {
    // Identical text to the original display form.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let count = self.batches.len();
        write!(f, "InMemoryExec: batches={count}")
    }
}
impl ExecutionPlan for InMemoryExec {
    fn name(&self) -> &str {
        "InMemoryExec"
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }
    // Leaf node: no children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        Vec::new()
    }
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }
    /// Replays the stored batches as a memory stream (cloned per call).
    fn execute(
        &self,
        _partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let stream =
            MemoryStream::try_new(self.batches.clone(), Arc::clone(&self.schema), None)?;
        Ok(Box::pin(stream))
    }
}
/// Filters materialized batches by the rule's HAVING expressions.
///
/// Each Cypher expression is converted to a DataFusion logical expression,
/// type-coerced against `schema` (by running the TypeCoercion analyzer on a
/// throwaway filter plan), compiled to a physical expression, and evaluated
/// per batch. Multiple HAVING expressions are AND-ed together; rows failing
/// the combined boolean mask are dropped, and batches that filter down to
/// zero rows are omitted from the result.
fn apply_having_filter(
    batches: Vec<RecordBatch>,
    having_exprs: &[Expr],
    schema: &SchemaRef,
) -> DFResult<Vec<RecordBatch>> {
    use arrow::compute::{and, filter_record_batch};
    use arrow_array::BooleanArray;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::LogicalPlanBuilder;
    use datafusion::optimizer::AnalyzerRule;
    use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
    use datafusion::physical_expr::create_physical_expr;
    use datafusion::prelude::SessionContext;
    if batches.is_empty() {
        return Ok(batches);
    }
    let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
        datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
    })?;
    // A throwaway session supplies analyzer config and execution props.
    let ctx = SessionContext::new();
    let state = ctx.state();
    let config = state.config_options().clone();
    let props = state.execution_props();
    // Compile each HAVING expression: Cypher AST -> DataFusion logical expr
    // -> type-coerced expr -> physical expr.
    let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
        .iter()
        .map(|expr| {
            let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
                datafusion::common::DataFusionError::Internal(format!(
                    "HAVING expression conversion: {e}"
                ))
            })?;
            // Wrap the predicate in a filter over an empty relation so the
            // TypeCoercion analyzer can rewrite it against our schema.
            let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
                datafusion::logical_expr::EmptyRelation {
                    produce_one_row: false,
                    schema: Arc::new(df_schema.clone()),
                },
            );
            let filter_plan = LogicalPlanBuilder::from(empty)
                .filter(df_expr.clone())?
                .build()?;
            // Fall back to the uncoerced expression if analysis fails or the
            // analyzed plan is no longer a Filter.
            let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
                Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
                _ => df_expr,
            };
            create_physical_expr(&coerced_expr, &df_schema, props)
        })
        .collect::<DFResult<Vec<_>>>()?;
    let mut result = Vec::new();
    for batch in batches {
        // AND together the boolean masks of all HAVING conditions.
        let mut mask: Option<BooleanArray> = None;
        for phys_expr in &physical_exprs {
            let value = phys_expr.evaluate(&batch)?;
            let arr = value.into_array(batch.num_rows())?;
            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
                datafusion::common::DataFusionError::Internal(
                    "HAVING condition must evaluate to boolean".into(),
                )
            })?;
            mask = Some(match mask {
                None => bool_arr.clone(),
                Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
            });
        }
        if let Some(ref m) = mask {
            let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
            // Fully-filtered batches are dropped from the output.
            if filtered.num_rows() > 0 {
                result.push(filtered);
            }
        } else {
            // No HAVING expressions: keep the batch as-is.
            result.push(batch);
        }
    }
    Ok(result)
}
/// Applies a rule's post-fixpoint operator chain to `facts`:
/// PRIORITY -> FOLD -> HAVING -> BEST BY, in that order, skipping any stage
/// the rule does not use. Returns the input unchanged when none apply.
///
/// The HAVING stage materializes the upstream plan's output because it
/// filters by evaluating expressions over collected batches; the other
/// stages are chained as execution plans and collected once at the end.
///
/// NOTE: the original text contained mis-encoded `&current` references
/// (HTML-entity garbling, `¤t`); they are restored here.
pub(crate) async fn apply_post_fixpoint_chain(
    facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    task_ctx: &Arc<TaskContext>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
) -> DFResult<Vec<RecordBatch>> {
    if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
        return Ok(facts);
    }
    // Prefer the schema of a non-empty batch; fall back to the rule's yield
    // schema when every batch is empty.
    let schema = facts
        .iter()
        .find(|b| b.num_rows() > 0)
        .map(|b| b.schema())
        .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));
    // Remap the rule's key columns (defined against yield_schema) onto the
    // actual input schema by name; unknown names are silently dropped.
    let key_column_indices: Vec<usize> = rule
        .key_column_indices
        .iter()
        .filter_map(|&i| {
            let name = rule.yield_schema.field(i).name();
            schema.index_of(name).ok()
        })
        .collect();
    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
        let priority_schema = input.schema();
        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
            datafusion::common::DataFusionError::Internal(
                "PRIORITY rule missing __priority column".to_string(),
            )
        })?;
        Arc::new(PriorityExec::new(
            input,
            key_column_indices.clone(),
            priority_idx,
        ))
    } else {
        input
    };
    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
        Arc::new(FoldExec::new(
            current,
            key_column_indices.clone(),
            rule.fold_bindings.clone(),
            strict_probability_domain,
            probability_epsilon,
        ))
    } else {
        current
    };
    let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
        // HAVING is evaluated over materialized batches rather than streamed.
        let having_schema = current.schema();
        let batches = collect_all_partitions(&current, Arc::clone(task_ctx)).await?;
        let filtered = apply_having_filter(batches, &rule.having, &having_schema)?;
        if filtered.is_empty() {
            return Ok(filtered);
        }
        Arc::new(InMemoryExec::new(filtered, having_schema))
    } else {
        current
    };
    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
        Arc::new(BestByExec::new(
            current,
            key_column_indices.clone(),
            rule.best_by_criteria.clone(),
            rule.deterministic,
        ))
    } else {
        current
    };
    collect_all_partitions(&current, Arc::clone(task_ctx)).await
}
/// Physical operator that drives the Locy fixpoint (recursive-rule) loop.
///
/// `execute` clones all of this configuration into a future that runs
/// `run_fixpoint_loop` and exposes the loop's final batches as a stream.
pub struct FixpointExec {
    /// Compiled per-rule plans evaluated by the loop.
    rules: Vec<FixpointRulePlan>,
    /// Upper bound on fixpoint iterations.
    max_iterations: usize,
    /// Wall-clock budget for the loop.
    timeout: Duration,
    /// Graph execution context handed through to `run_fixpoint_loop`.
    graph_ctx: Arc<GraphExecutionContext>,
    /// DataFusion session handed through to the loop.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage manager handed through to the loop.
    storage: Arc<StorageManager>,
    /// Logical schema information handed through to the loop.
    schema_info: Arc<UniSchema>,
    /// Query parameters handed through to the loop.
    params: HashMap<String, Value>,
    /// Registry of derived-scan handles refreshed between iterations.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema of the batches this operator emits.
    output_schema: SchemaRef,
    /// Cached DataFusion plan properties derived from `output_schema`.
    properties: PlanProperties,
    /// Operator metrics (baseline metrics are created per partition).
    metrics: ExecutionPlanMetricsSet,
    /// Byte budget for derived facts — presumably enforced inside the loop
    /// (loop body not visible here).
    max_derived_bytes: usize,
    /// Optional provenance store for lineage / probability tracking.
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Per-rule iteration counters, shared via `iteration_counts()`.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Probability-domain options forwarded to the fold stages.
    strict_probability_domain: bool,
    probability_epsilon: f64,
    /// When true, shared-lineage groups use exact BDD model counting.
    exact_probability: bool,
    /// Variable limit for exact BDD computation before falling back
    /// (see `apply_exact_wmc`).
    max_bdd_variables: usize,
    /// Sink for runtime warnings raised during evaluation.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Rule name -> key groups whose probabilities are only approximate.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// Number of proofs to retain per fact (0 records every proof).
    top_k_proofs: usize,
    /// Flag shared with the loop — presumably set on timeout; confirm
    /// against `run_fixpoint_loop`.
    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
}
impl fmt::Debug for FixpointExec {
    // Same output as the original: a summary of the cheap-to-print fields,
    // finished non-exhaustively since most context handles are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("FixpointExec");
        dbg.field("rules_count", &self.rules.len());
        dbg.field("max_iterations", &self.max_iterations);
        dbg.field("timeout", &self.timeout);
        dbg.field("output_schema", &self.output_schema);
        dbg.field("max_derived_bytes", &self.max_derived_bytes);
        dbg.finish_non_exhaustive()
    }
}
impl FixpointExec {
    /// Constructs the operator. Plan properties are derived from
    /// `output_schema` up front; metrics start empty.
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    pub fn new(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicBool>,
    ) -> Self {
        Self {
            // Derived fields first; the rest are plain moves of the arguments.
            properties: compute_plan_properties(Arc::clone(&output_schema)),
            metrics: ExecutionPlanMetricsSet::new(),
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
        }
    }
    /// Shared handle to the per-rule iteration counters.
    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
        Arc::clone(&self.iteration_counts)
    }
}
impl DisplayAs for FixpointExec {
    // Produces the exact same text as before:
    // "FixpointExec: rules=[a, b], max_iter=N, timeout=..."
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rule_names = self
            .rules
            .iter()
            .map(|r| r.name.as_str())
            .collect::<Vec<_>>()
            .join(", ");
        write!(
            f,
            "FixpointExec: rules=[{rule_names}], max_iter={}, timeout={:?}",
            self.max_iterations, self.timeout,
        )
    }
}
impl ExecutionPlan for FixpointExec {
    fn name(&self) -> &str {
        "FixpointExec"
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }
    // Leaf node: the fixpoint loop builds its own sub-plans internally.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "FixpointExec has no children".to_string(),
            ));
        }
        Ok(self)
    }
    /// Clones all configuration into an owned ('static) future that runs the
    /// entire fixpoint loop, then wraps it in a `FixpointStream` that emits
    /// the loop's result batches once the future completes.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);
        // Deep-copy the rule plans field by field so the future owns them.
        let rules = self
            .rules
            .iter()
            .map(|r| {
                FixpointRulePlan {
                    name: r.name.clone(),
                    clauses: r
                        .clauses
                        .iter()
                        .map(|c| FixpointClausePlan {
                            body_logical: c.body_logical.clone(),
                            is_ref_bindings: c.is_ref_bindings.clone(),
                            priority: c.priority,
                            along_bindings: c.along_bindings.clone(),
                        })
                        .collect(),
                    yield_schema: Arc::clone(&r.yield_schema),
                    key_column_indices: r.key_column_indices.clone(),
                    priority: r.priority,
                    has_fold: r.has_fold,
                    fold_bindings: r.fold_bindings.clone(),
                    having: r.having.clone(),
                    has_best_by: r.has_best_by,
                    best_by_criteria: r.best_by_criteria.clone(),
                    has_priority: r.has_priority,
                    deterministic: r.deterministic,
                    prob_column_name: r.prob_column_name.clone(),
                }
            })
            .collect();
        // Clone every shared handle / config value the loop needs, so the
        // future is self-contained and does not borrow `self`.
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let output_schema = Arc::clone(&self.output_schema);
        let max_derived_bytes = self.max_derived_bytes;
        let derivation_tracker = self.derivation_tracker.clone();
        let iteration_counts = Arc::clone(&self.iteration_counts);
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let top_k_proofs = self.top_k_proofs;
        let timeout_flag = Arc::clone(&self.timeout_flag);
        // The loop itself runs lazily inside the stream's Running state.
        let fut = async move {
            run_fixpoint_loop(
                rules,
                max_iterations,
                timeout,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                registry,
                output_schema,
                max_derived_bytes,
                derivation_tracker,
                iteration_counts,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                warnings_slot,
                approximate_slot,
                top_k_proofs,
                timeout_flag,
            )
            .await
        };
        Ok(Box::pin(FixpointStream {
            state: FixpointStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
/// Driving state for `FixpointStream`.
enum FixpointStreamState {
    /// The fixpoint computation is still running; holds the future that
    /// resolves to the complete set of result batches.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Computation finished; emitting the buffered batches, with the usize
    /// tracking the index of the next batch to yield.
    Emitting(Vec<RecordBatch>, usize),
    /// All batches emitted (or an error was surfaced); the stream is exhausted.
    Done,
}
/// Adapter exposing the (fully buffered) fixpoint result as a DataFusion
/// record-batch stream.
struct FixpointStream {
    // Current phase: running the future, emitting batches, or done.
    state: FixpointStreamState,
    // Output schema reported to downstream consumers.
    schema: SchemaRef,
    // Baseline metrics; output row counts are recorded as batches are emitted.
    metrics: BaselineMetrics,
}
impl Stream for FixpointStream {
    type Item = DFResult<RecordBatch>;

    /// State machine: drive the fixpoint future to completion, then emit the
    /// buffered batches one at a time, then terminate.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            match &mut this.state {
                FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(batches)) => {
                        if batches.is_empty() {
                            // Nothing was derived: end the stream immediately.
                            this.state = FixpointStreamState::Done;
                            return Poll::Ready(None);
                        }
                        // Transition to emitting; the loop re-enters the
                        // Emitting arm on the same poll.
                        this.state = FixpointStreamState::Emitting(batches, 0);
                    }
                    Poll::Ready(Err(e)) => {
                        // Surface the error exactly once, then terminate.
                        this.state = FixpointStreamState::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                FixpointStreamState::Emitting(batches, idx) => {
                    if *idx >= batches.len() {
                        this.state = FixpointStreamState::Done;
                        return Poll::Ready(None);
                    }
                    // RecordBatch clones are cheap (Arc-backed columns).
                    let batch = batches[*idx].clone();
                    *idx += 1;
                    this.metrics.record_output(batch.num_rows());
                    return Poll::Ready(Some(Ok(batch)));
                }
                FixpointStreamState::Done => return Poll::Ready(None),
            }
        }
    }
}
impl RecordBatchStream for FixpointStream {
    /// Reports the schema of the batches this stream yields.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the fixpoint machinery defined earlier in this file:
    //! `FixpointState` delta merging, persistent row-level dedup, the derived
    //! scan registry, monotonic probabilistic aggregation (noisy-or and
    //! product), probability complement / anti-join helpers, and the
    //! `FixpointStream` emit path.
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    // ---- shared fixtures ----------------------------------------------

    // Two-column (name: Utf8, value: Int64) schema used by most tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Int64, true),
        ]))
    }

    // Builds a batch over `test_schema()` from parallel name/value slices.
    fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(StringArray::from(
                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(values.to_vec())),
            ],
        )
        .unwrap()
    }

    // ---- FixpointState delta merging ----------------------------------

    #[tokio::test]
    async fn test_fixpoint_state_empty_facts_adds_all() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let changed = state.merge_delta(vec![batch], None).await.unwrap();
        // First merge into an empty state keeps every row as both fact and delta.
        assert!(changed);
        assert_eq!(state.all_facts().len(), 1);
        assert_eq!(state.all_facts()[0].num_rows(), 3);
        assert_eq!(state.all_delta().len(), 1);
        assert_eq!(state.all_delta()[0].num_rows(), 3);
    }

    #[tokio::test]
    async fn test_fixpoint_state_exact_duplicates_excluded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();
        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        // Re-merging identical rows must not report a change or produce delta rows.
        assert!(!changed);
        assert!(
            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
        );
    }

    #[tokio::test]
    async fn test_fixpoint_state_partial_overlap() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();
        let batch2 = make_batch(&["a", "c"], &[1, 3]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        // Only the genuinely new row ("c", 3) lands in the delta.
        assert!(changed);
        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
        assert_eq!(delta_rows, 1);
        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_fixpoint_state_convergence() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
        let batch = make_batch(&["a"], &[1]);
        state.merge_delta(vec![batch], None).await.unwrap();
        // An empty merge leaves the state unchanged and converged.
        let changed = state.merge_delta(vec![], None).await.unwrap();
        assert!(!changed);
        assert!(state.is_converged());
    }

    // ---- RowDedupState ------------------------------------------------

    #[test]
    fn test_row_dedup_persistent_across_calls() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows1, 2);
        // Dedup state persists: resubmitting the same rows yields nothing.
        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 0);
        let batch3 = make_batch(&["a", "c"], &[1, 3]);
        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows3, 1);
    }

    #[test]
    fn test_row_dedup_null_handling() {
        use arrow_array::StringArray;
        use arrow_schema::{DataType, Field, Schema};
        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int64, true),
        ]));
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
        // NULLs must compare equal for dedup purposes (unlike SQL semantics).
        let batch_nulls = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
            ],
        )
        .unwrap();
        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
        let batch_diff = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
            ],
        )
        .unwrap();
        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
    }

    #[test]
    fn test_row_dedup_within_candidate_dedup() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
        // Duplicates inside a single candidate batch are also collapsed.
        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
        let delta = rd.compute_delta(&[batch], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
    }

    // ---- float rounding for dedup -------------------------------------

    #[test]
    fn test_round_float_columns_near_duplicates() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("dist", DataType::Float64, true),
        ]));
        // Values differing only past the rounding precision must collapse
        // to the same value so row dedup can treat them as equal.
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
            ],
        )
        .unwrap();
        let rounded = round_float_columns(&[batch]);
        assert_eq!(rounded.len(), 1);
        let col = rounded[0]
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(col.value(0), col.value(1));
    }

    // ---- DerivedScanRegistry ------------------------------------------

    #[test]
    fn test_registry_write_read_round_trip() {
        let schema = test_schema();
        let data = Arc::new(RwLock::new(Vec::new()));
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "reachable".into(),
            is_self_ref: true,
            data: Arc::clone(&data),
            schema: Arc::clone(&schema),
        });
        // Writing through the registry is visible via the shared data handle.
        let batch = make_batch(&["x"], &[42]);
        reg.write_data(0, vec![batch.clone()]);
        let entry = reg.get(0).unwrap();
        let guard = entry.data.read();
        assert_eq!(guard.len(), 1);
        assert_eq!(guard[0].num_rows(), 1);
    }

    #[test]
    fn test_registry_entries_for_rule() {
        let schema = test_schema();
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "r1".into(),
            is_self_ref: true,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 1,
            rule_name: "r2".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 2,
            rule_name: "r1".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        // Lookup by rule name returns all matching entries, empty otherwise.
        assert_eq!(reg.entries_for_rule("r1").len(), 2);
        assert_eq!(reg.entries_for_rule("r2").len(), 1);
        assert_eq!(reg.entries_for_rule("r3").len(), 0);
    }

    // ---- MonotonicAggState: generic update/stability ------------------

    #[test]
    fn test_monotonic_agg_update_and_stability() {
        use crate::query::df_graph::locy_fold::FoldAggKind;
        let bindings = vec![MonotonicFoldBinding {
            fold_name: "total".into(),
            kind: FoldAggKind::Sum,
            input_col_index: 1,
            input_col_name: None,
        }];
        let mut agg = MonotonicAggState::new(bindings);
        let batch = make_batch(&["a"], &[10]);
        agg.snapshot();
        // A new contribution changes the accumulator and breaks stability ...
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        assert!(!agg.is_stable());
        agg.snapshot();
        // ... while an empty update leaves it stable.
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }

    // ---- memory limit enforcement -------------------------------------

    #[tokio::test]
    async fn test_memory_limit_exceeded() {
        let schema = test_schema();
        // A 1-byte budget cannot hold any batch, so the merge must fail.
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);
        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let result = state.merge_delta(vec![batch], None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("memory limit"), "Error was: {}", err);
    }

    // ---- FixpointStream emit path --------------------------------------

    #[tokio::test]
    async fn test_fixpoint_stream_emitting() {
        use futures::StreamExt;
        let schema = test_schema();
        let batch1 = make_batch(&["a"], &[1]);
        let batch2 = make_batch(&["b"], &[2]);
        let metrics = ExecutionPlanMetricsSet::new();
        let baseline = BaselineMetrics::new(&metrics, 0);
        // Start the stream directly in the Emitting phase and drain it.
        let mut stream = FixpointStream {
            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
            schema,
            metrics: baseline,
        };
        let stream = Pin::new(&mut stream);
        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[1].num_rows(), 1);
    }

    // ---- probabilistic aggregation fixtures ----------------------------

    // Builds a (name: Utf8, value: Float64) batch for probability tests.
    fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(values.to_vec())),
            ],
        )
        .unwrap()
    }

    // Single noisy-or ("Nor") binding over column 1.
    fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
        use crate::query::df_graph::locy_fold::FoldAggKind;
        vec![MonotonicFoldBinding {
            fold_name: "prob".into(),
            kind: FoldAggKind::Nor,
            input_col_index: 1,
            input_col_name: None,
        }]
    }

    // Single product ("Prod") binding over column 1.
    fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
        use crate::query::df_graph::locy_fold::FoldAggKind;
        vec![MonotonicFoldBinding {
            fold_name: "prob".into(),
            kind: FoldAggKind::Prod,
            input_col_index: 1,
            input_col_name: None,
        }]
    }

    // Accumulator key for group `name` under the "prob" fold.
    fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
        (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
    }

    // ---- noisy-or / product accumulators --------------------------------

    #[test]
    fn test_monotonic_nor_first_update() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
    }

    #[test]
    fn test_monotonic_nor_two_updates() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch2], false).unwrap();
        // noisy-or: 1 - (1-0.3)*(1-0.5) = 0.65
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
    }

    #[test]
    fn test_monotonic_prod_first_update() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
    }

    #[test]
    fn test_monotonic_prod_two_updates() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.8]);
        agg.update(&[0], &[batch2], false).unwrap();
        // product: 0.6 * 0.8 = 0.48
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
    }

    #[test]
    fn test_monotonic_nor_stability() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }

    #[test]
    fn test_monotonic_prod_stability() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }

    #[test]
    fn test_monotonic_nor_multi_group() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
        agg.update(&[0], &[batch2], false).unwrap();
        // a: 1-(0.7*0.5)=0.65; b: 1-(0.5*0.8)=0.6 — groups accumulate independently.
        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
        assert!(
            (val_a - 0.65).abs() < 1e-10,
            "expected a=0.65, got {}",
            val_a
        );
        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
    }

    #[test]
    fn test_monotonic_prod_zero_absorbing() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.0]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);
        agg.snapshot();
        // Once the product hits 0 it can never change again.
        let batch3 = make_f64_batch(&["a"], &[0.5]);
        let changed = agg.update(&[0], &[batch3], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }

    #[test]
    fn test_monotonic_nor_clamping() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        // In non-strict mode out-of-range inputs are clamped into [0, 1].
        let batch = make_f64_batch(&["a"], &[1.5]);
        agg.update(&[0], &[batch], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }

    #[test]
    fn test_monotonic_nor_absorbing() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        // A certain (p = 1.0) contribution absorbs: noisy-or saturates at 1.
        let batch2 = make_f64_batch(&["a"], &[1.0]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }

    // ---- strict probability domain --------------------------------------

    #[test]
    fn test_monotonic_agg_strict_nor_rejects() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        // strict = true: out-of-range probability is an error, not clamped.
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }

    #[test]
    fn test_monotonic_agg_strict_prod_rejects() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[2.0]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }

    #[test]
    fn test_monotonic_agg_strict_accepts_valid() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_ok());
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
    }

    // ---- probability complement / anti-join helpers ---------------------

    // Builds a (vid: UInt64, prob: Float64) batch for negation tests.
    fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
        use arrow_array::UInt64Array;
        let schema = Arc::new(Schema::new(vec![
            Field::new("vid", DataType::UInt64, true),
            Field::new("prob", DataType::Float64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(UInt64Array::from(vids.to_vec())),
                Arc::new(Float64Array::from(probs.to_vec())),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_prob_complement_basic() {
        let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
        let neg = make_vid_prob_batch(&[1], &[0.7]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_prob_complement_composite(
            vec![body],
            &[neg],
            &join_cols,
            "prob",
            "__complement_0",
        )
        .unwrap();
        assert_eq!(result.len(), 1);
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        // vid 1 matches a negation with p=0.7 → complement 1-0.7 = 0.3;
        // vid 2 has no match → complement 1.0.
        assert!(
            (complement.value(0) - 0.3).abs() < 1e-10,
            "expected 0.3, got {}",
            complement.value(0)
        );
        assert!(
            (complement.value(1) - 1.0).abs() < 1e-10,
            "expected 1.0, got {}",
            complement.value(1)
        );
    }

    #[test]
    fn test_prob_complement_noisy_or_duplicates() {
        let body = make_vid_prob_batch(&[1], &[0.9]);
        let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_prob_complement_composite(
            vec![body],
            &[neg],
            &join_cols,
            "prob",
            "__complement_0",
        )
        .unwrap();
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        // Duplicate negations combine by noisy-or: complement =
        // (1-0.3)*(1-0.5) = 0.35 (not a duplicated output row).
        assert!(
            (complement.value(0) - 0.35).abs() < 1e-10,
            "expected 0.35, got {}",
            complement.value(0)
        );
    }

    #[test]
    fn test_prob_complement_empty_neg() {
        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result =
            apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
                .unwrap();
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        // No negation facts → every complement is 1.0.
        for i in 0..2 {
            assert!(
                (complement.value(i) - 1.0).abs() < 1e-10,
                "row {}: expected 1.0, got {}",
                i,
                complement.value(i)
            );
        }
    }

    #[test]
    fn test_anti_join_basic() {
        use arrow_array::UInt64Array;
        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
        let neg = make_vid_prob_batch(&[2], &[0.0]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
        assert_eq!(result.len(), 1);
        let batch = &result[0];
        // vid 2 is excluded; 1 and 3 survive in order.
        assert_eq!(batch.num_rows(), 2);
        let vids = batch
            .column_by_name("vid")
            .unwrap()
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        assert_eq!(vids.value(0), 1);
        assert_eq!(vids.value(1), 3);
    }

    #[test]
    fn test_anti_join_empty_neg() {
        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 3);
    }

    #[test]
    fn test_anti_join_all_excluded() {
        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
        let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
        let total: usize = result.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total, 0);
    }

    // ---- complement-factor multiplication --------------------------------

    #[test]
    fn test_multiply_prob_single_complement() {
        let body = make_vid_prob_batch(&[1], &[0.8]);
        let complement_arr = Float64Array::from(vec![0.5]);
        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
        cols.push(Arc::new(complement_arr));
        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
        fields.push(Arc::new(Field::new(
            "__complement_0",
            DataType::Float64,
            true,
        )));
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(schema, cols).unwrap();
        let result =
            multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
                .unwrap();
        assert_eq!(result.len(), 1);
        let out = &result[0];
        // The factor column is folded into "prob" (0.8*0.5=0.4) and dropped.
        assert!(out.column_by_name("__complement_0").is_none());
        let prob = out
            .column_by_name("prob")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(
            (prob.value(0) - 0.4).abs() < 1e-10,
            "expected 0.4, got {}",
            prob.value(0)
        );
    }

    #[test]
    fn test_multiply_prob_multiple_complements() {
        let body = make_vid_prob_batch(&[1], &[0.8]);
        let c1 = Float64Array::from(vec![0.5]);
        let c2 = Float64Array::from(vec![0.6]);
        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
        cols.push(Arc::new(c1));
        cols.push(Arc::new(c2));
        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
        fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
        fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(schema, cols).unwrap();
        let result = multiply_prob_factors(
            vec![batch],
            Some("prob"),
            &["__c1".to_string(), "__c2".to_string()],
        )
        .unwrap();
        let out = &result[0];
        // Both factor columns multiply in: 0.8*0.5*0.6 = 0.24.
        assert!(out.column_by_name("__c1").is_none());
        assert!(out.column_by_name("__c2").is_none());
        let prob = out
            .column_by_name("prob")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(
            (prob.value(0) - 0.24).abs() < 1e-10,
            "expected 0.24, got {}",
            prob.value(0)
        );
    }

    #[test]
    fn test_multiply_prob_no_prob_column() {
        use arrow_array::UInt64Array;
        let schema = Arc::new(Schema::new(vec![
            Field::new("vid", DataType::UInt64, true),
            Field::new("__c1", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(UInt64Array::from(vec![1u64])),
                Arc::new(Float64Array::from(vec![0.7])),
            ],
        )
        .unwrap();
        // With no target prob column the factor columns are simply dropped.
        let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
        let out = &result[0];
        assert!(out.column_by_name("__c1").is_none());
        assert_eq!(out.num_columns(), 1);
    }
}