use crate::query::df_graph::GraphExecutionContext;
use crate::query::df_graph::common::{
collect_all_partitions, compute_plan_properties, execute_subplan,
};
use crate::query::df_graph::locy_best_by::SortCriterion;
use crate::query::df_graph::locy_explain::ProvenanceStore;
use crate::query::df_graph::locy_fixpoint::{
DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
};
use crate::query::df_graph::locy_fold::{FoldBinding, resolve_locy_aggregate};
use crate::query::planner_locy_types::{
LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
};
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
use datafusion::common::Result as DFResult;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::Stream;
use parking_lot::RwLock;
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use uni_common::Value;
use uni_common::core::schema::Schema as UniSchema;
use uni_cypher::ast::Expr;
use uni_locy::{
ClassifierRegistry, CommandResult, FactRow, ModelInvocationCache, RuntimeWarning, SemiringKind,
};
use uni_plugin::PluginRegistry;
use uni_store::storage::manager::StorageManager;
/// In-memory store of derived facts, keyed by rule name.
///
/// Each entry maps a rule name to the record batches that were stored for
/// that rule via [`DerivedStore::insert`].
pub struct DerivedStore {
/// Rule name -> record batches stored for that rule.
relations: HashMap<String, Vec<RecordBatch>>,
}
impl Default for DerivedStore {
fn default() -> Self {
Self::new()
}
}
impl DerivedStore {
pub fn new() -> Self {
Self {
relations: HashMap::new(),
}
}
pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
self.relations.insert(rule_name, facts);
}
pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
self.relations.get(rule_name)
}
pub fn fact_count(&self, rule_name: &str) -> usize {
self.relations
.get(rule_name)
.map(|batches| batches.iter().map(|b| b.num_rows()).sum())
.unwrap_or(0)
}
pub fn rule_names(&self) -> impl Iterator<Item = &str> {
self.relations.keys().map(|s| s.as_str())
}
}
/// DataFusion execution-plan node that evaluates a whole Locy program.
///
/// `execute` snapshots these fields and hands them to `run_program`, which
/// evaluates the strata in order, runs the inline commands, fills the shared
/// output slots below and yields a single statistics batch.
pub struct LocyProgramExec {
/// Rule strata, evaluated in order.
strata: Vec<LocyStratum>,
/// Program commands; Cypher, Validate and Calibrate are handled inline by `run_program`.
commands: Vec<LocyCommand>,
/// Registry of derived-scan entries that rule facts are written into.
derived_scan_registry: Arc<DerivedScanRegistry>,
/// Plugin registry handed to fixpoint-plan conversion (fold aggregate resolution).
plugin_registry: Arc<PluginRegistry>,
/// Graph execution context forwarded to sub-plan execution.
graph_ctx: Arc<GraphExecutionContext>,
/// Shared DataFusion session context; also the source of the task context.
session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
/// Storage manager forwarded to sub-plan execution.
storage: Arc<StorageManager>,
/// Graph schema used for planning and sub-plan execution.
schema_info: Arc<UniSchema>,
/// Query parameters forwarded to sub-plan execution.
params: HashMap<String, Value>,
/// Schema reported by `schema()` and used for the statistics batch.
output_schema: SchemaRef,
/// Plan properties computed from `output_schema` in the constructor.
properties: Arc<PlanProperties>,
/// DataFusion metrics set for this node.
metrics: ExecutionPlanMetricsSet,
/// Iteration limit passed to the fixpoint executor.
max_iterations: usize,
/// Overall time budget; the remaining budget is recomputed before each stratum.
timeout: Duration,
/// Derived-data byte limit passed to the fixpoint executor.
max_derived_bytes: usize,
/// Copied into each `FixpointRulePlan::deterministic`.
deterministic_best_by: bool,
/// Forwarded to the fixpoint executor and the post-fixpoint chain.
strict_probability_domain: bool,
/// Forwarded to the fixpoint executor and the post-fixpoint chain.
probability_epsilon: f64,
/// When set, non-recursive strata apply `apply_exact_wmc` if lineage info is available.
exact_probability: bool,
/// Variable limit passed along with `exact_probability`.
max_bdd_variables: usize,
/// Semiring under which the program is evaluated.
semiring_kind: SemiringKind,
/// Classifier registry forwarded to evaluation and CALIBRATE commands.
classifier_registry: Arc<ClassifierRegistry>,
/// Optional classifier invocation cache forwarded to evaluation.
classifier_cache: Option<Arc<ModelInvocationCache>>,
/// Optional neural provenance store forwarded to evaluation.
classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
/// Output slot: receives the final `DerivedStore` when `run_program` finishes.
derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
/// Output slot shared with fixpoint / exact-WMC evaluation.
approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
/// Optional provenance tracker; set via `set_derivation_tracker`, read once in `execute`.
derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
/// Output slot shared with the fixpoint executor.
iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
/// Output slot: receives the summed buffer size of all derived batches.
peak_memory_slot: Arc<StdRwLock<usize>>,
/// Output slot: runtime warnings collected during evaluation.
warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
/// Output slot: `(command index, result)` pairs for inline commands.
command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
/// Forwarded to the fixpoint executor and lineage recording.
top_k_proofs: usize,
/// Interruption code (see the `interruption` module); starts at `interruption::NONE`.
timeout_flag: Arc<std::sync::atomic::AtomicU8>,
/// Output slot: filled with a `LocyIncomplete` report when evaluation was interrupted.
incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
}
/// Interruption codes stored in the shared `AtomicU8` flag, plus helpers to
/// read and set them.
pub(crate) mod interruption {
    use std::sync::atomic::{AtomicU8, Ordering};
    use uni_common::LocyIncompleteReason;

    /// No interruption has been recorded.
    pub(crate) const NONE: u8 = 0;
    /// Evaluation ran out of time.
    pub(crate) const TIMEOUT: u8 = 1;
    /// Evaluation hit the iteration limit.
    pub(crate) const ITERATION_LIMIT: u8 = 2;

    /// Translates the code currently stored in `flag` into an incomplete
    /// reason; any code other than `TIMEOUT` / `ITERATION_LIMIT` yields `None`.
    pub(crate) fn reason(flag: &AtomicU8) -> Option<LocyIncompleteReason> {
        let code = flag.load(Ordering::Relaxed);
        if code == TIMEOUT {
            Some(LocyIncompleteReason::Timeout)
        } else if code == ITERATION_LIMIT {
            Some(LocyIncompleteReason::IterationLimit)
        } else {
            None
        }
    }

    /// Records `code` in `flag` only if the flag still holds `NONE`, so the
    /// first recorded interruption is the one that sticks.
    pub(crate) fn set(flag: &AtomicU8, code: u8) {
        // A failed exchange means another code was already recorded; keep it.
        let _previous = flag.compare_exchange(NONE, code, Ordering::Relaxed, Ordering::Relaxed);
    }
}
impl fmt::Debug for LocyProgramExec {
    /// Prints a compact, non-exhaustive summary (counts and limits) rather
    /// than the full plan contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut summary = f.debug_struct("LocyProgramExec");
        summary.field("strata_count", &self.strata.len());
        summary.field("commands_count", &self.commands.len());
        summary.field("max_iterations", &self.max_iterations);
        summary.field("timeout", &self.timeout);
        summary.field("output_schema", &self.output_schema);
        summary.field("max_derived_bytes", &self.max_derived_bytes);
        summary.finish_non_exhaustive()
    }
}
impl LocyProgramExec {
/// Legacy constructor.
///
/// Delegates to [`Self::new_with_semiring_and_classifiers`] with
/// `SemiringKind::AddMultProb` and a freshly created, empty
/// `ClassifierRegistry`.
#[expect(
clippy::too_many_arguments,
reason = "execution plan node requires full graph and session context"
)]
#[deprecated(
note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
`new_with_semiring_and_classifiers` / `new_with_semiring`) — \
this legacy ctor defaults the semiring to AddMultProb and \
ships no classifier registry. To be removed after C0 Stage 2."
)]
pub fn new(
strata: Vec<LocyStratum>,
commands: Vec<LocyCommand>,
derived_scan_registry: Arc<DerivedScanRegistry>,
plugin_registry: Arc<PluginRegistry>,
graph_ctx: Arc<GraphExecutionContext>,
session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
storage: Arc<StorageManager>,
schema_info: Arc<UniSchema>,
params: HashMap<String, Value>,
output_schema: SchemaRef,
max_iterations: usize,
timeout: Duration,
max_derived_bytes: usize,
deterministic_best_by: bool,
strict_probability_domain: bool,
probability_epsilon: f64,
exact_probability: bool,
max_bdd_variables: usize,
top_k_proofs: usize,
) -> Self {
Self::new_with_semiring_and_classifiers(
strata,
commands,
derived_scan_registry,
plugin_registry,
graph_ctx,
session_ctx,
storage,
schema_info,
params,
output_schema,
max_iterations,
timeout,
max_derived_bytes,
deterministic_best_by,
strict_probability_domain,
probability_epsilon,
exact_probability,
max_bdd_variables,
top_k_proofs,
SemiringKind::AddMultProb,
Arc::new(ClassifierRegistry::new()),
)
}
/// Constructor with an explicit semiring.
///
/// Delegates to [`Self::new_with_semiring_and_classifiers`] with a freshly
/// created, empty `ClassifierRegistry`.
#[expect(
clippy::too_many_arguments,
reason = "execution plan node requires full graph and session context"
)]
pub fn new_with_semiring(
strata: Vec<LocyStratum>,
commands: Vec<LocyCommand>,
derived_scan_registry: Arc<DerivedScanRegistry>,
plugin_registry: Arc<PluginRegistry>,
graph_ctx: Arc<GraphExecutionContext>,
session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
storage: Arc<StorageManager>,
schema_info: Arc<UniSchema>,
params: HashMap<String, Value>,
output_schema: SchemaRef,
max_iterations: usize,
timeout: Duration,
max_derived_bytes: usize,
deterministic_best_by: bool,
strict_probability_domain: bool,
probability_epsilon: f64,
exact_probability: bool,
max_bdd_variables: usize,
top_k_proofs: usize,
semiring_kind: SemiringKind,
) -> Self {
Self::new_with_semiring_and_classifiers(
strata,
commands,
derived_scan_registry,
plugin_registry,
graph_ctx,
session_ctx,
storage,
schema_info,
params,
output_schema,
max_iterations,
timeout,
max_derived_bytes,
deterministic_best_by,
strict_probability_domain,
probability_epsilon,
exact_probability,
max_bdd_variables,
top_k_proofs,
semiring_kind,
Arc::new(ClassifierRegistry::new()),
)
}
/// Constructor with an explicit semiring and classifier registry.
///
/// Delegates to [`Self::new_with_semiring_classifiers_and_cache`] with no
/// classifier cache and no classifier provenance store.
#[expect(
clippy::too_many_arguments,
reason = "execution plan node requires full graph and session context"
)]
pub fn new_with_semiring_and_classifiers(
strata: Vec<LocyStratum>,
commands: Vec<LocyCommand>,
derived_scan_registry: Arc<DerivedScanRegistry>,
plugin_registry: Arc<PluginRegistry>,
graph_ctx: Arc<GraphExecutionContext>,
session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
storage: Arc<StorageManager>,
schema_info: Arc<UniSchema>,
params: HashMap<String, Value>,
output_schema: SchemaRef,
max_iterations: usize,
timeout: Duration,
max_derived_bytes: usize,
deterministic_best_by: bool,
strict_probability_domain: bool,
probability_epsilon: f64,
exact_probability: bool,
max_bdd_variables: usize,
top_k_proofs: usize,
semiring_kind: SemiringKind,
classifier_registry: Arc<ClassifierRegistry>,
) -> Self {
Self::new_with_semiring_classifiers_and_cache(
strata,
commands,
derived_scan_registry,
plugin_registry,
graph_ctx,
session_ctx,
storage,
schema_info,
params,
output_schema,
max_iterations,
timeout,
max_derived_bytes,
deterministic_best_by,
strict_probability_domain,
probability_epsilon,
exact_probability,
max_bdd_variables,
top_k_proofs,
semiring_kind,
classifier_registry,
None,
None,
)
}
/// Full constructor; every other constructor ends up here.
///
/// Computes the plan properties from `output_schema`, stores all arguments,
/// and initialises every shared output slot to its empty state (`None`,
/// empty map/vector, `0`) with the interruption flag set to
/// `interruption::NONE`.
#[expect(
clippy::too_many_arguments,
reason = "execution plan node requires full graph and session context"
)]
pub fn new_with_semiring_classifiers_and_cache(
strata: Vec<LocyStratum>,
commands: Vec<LocyCommand>,
derived_scan_registry: Arc<DerivedScanRegistry>,
plugin_registry: Arc<PluginRegistry>,
graph_ctx: Arc<GraphExecutionContext>,
session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
storage: Arc<StorageManager>,
schema_info: Arc<UniSchema>,
params: HashMap<String, Value>,
output_schema: SchemaRef,
max_iterations: usize,
timeout: Duration,
max_derived_bytes: usize,
deterministic_best_by: bool,
strict_probability_domain: bool,
probability_epsilon: f64,
exact_probability: bool,
max_bdd_variables: usize,
top_k_proofs: usize,
semiring_kind: SemiringKind,
classifier_registry: Arc<ClassifierRegistry>,
classifier_cache: Option<Arc<ModelInvocationCache>>,
classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
) -> Self {
let properties = compute_plan_properties(Arc::clone(&output_schema));
Self {
strata,
commands,
derived_scan_registry,
plugin_registry,
graph_ctx,
session_ctx,
storage,
schema_info,
params,
output_schema,
properties,
metrics: ExecutionPlanMetricsSet::new(),
max_iterations,
timeout,
max_derived_bytes,
deterministic_best_by,
strict_probability_domain,
probability_epsilon,
exact_probability,
max_bdd_variables,
semiring_kind,
classifier_registry,
classifier_cache,
classifier_provenance_store,
derived_store_slot: Arc::new(StdRwLock::new(None)),
approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
derivation_tracker: Arc::new(StdRwLock::new(None)),
iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
peak_memory_slot: Arc::new(StdRwLock::new(0)),
warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
top_k_proofs,
timeout_flag: Arc::new(std::sync::atomic::AtomicU8::new(interruption::NONE)),
incomplete_slot: Arc::new(StdRwLock::new(None)),
}
}
/// Shared handle to the slot that receives the final `DerivedStore`.
pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
Arc::clone(&self.derived_store_slot)
}
/// Installs the provenance tracker that `execute` will snapshot.
///
/// If the lock is poisoned the tracker is silently not installed.
pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
if let Ok(mut guard) = self.derivation_tracker.write() {
*guard = Some(tracker);
}
}
/// Shared handle to the per-rule iteration-count slot.
pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
Arc::clone(&self.iteration_counts_slot)
}
/// Shared handle to the memory-usage slot.
pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
Arc::clone(&self.peak_memory_slot)
}
/// Shared handle to the runtime-warnings slot.
pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
Arc::clone(&self.warnings_slot)
}
/// Shared handle to the approximation slot.
pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
Arc::clone(&self.approximate_slot)
}
/// Shared handle to the inline command-results slot.
pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
Arc::clone(&self.command_results_slot)
}
/// Shared handle to the interruption flag (see the `interruption` module).
pub fn timeout_flag(&self) -> Arc<std::sync::atomic::AtomicU8> {
Arc::clone(&self.timeout_flag)
}
/// Shared handle to the slot that receives the incomplete-evaluation report.
pub fn incomplete_slot(&self) -> Arc<StdRwLock<Option<uni_common::LocyIncomplete>>> {
Arc::clone(&self.incomplete_slot)
}
}
impl DisplayAs for LocyProgramExec {
    /// Writes a one-line plan summary; the same text is used for every
    /// display format.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let strata_count = self.strata.len();
        let command_count = self.commands.len();
        let max_iter = self.max_iterations;
        let timeout = self.timeout;
        write!(
            f,
            "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
            strata_count, command_count, max_iter, timeout,
        )
    }
}
impl ExecutionPlan for LocyProgramExec {
fn name(&self) -> &str {
"LocyProgramExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.output_schema)
}
fn properties(&self) -> &Arc<PlanProperties> {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
if !children.is_empty() {
return Err(datafusion::error::DataFusionError::Plan(
"LocyProgramExec has no children".to_string(),
));
}
Ok(self)
}
fn execute(
&self,
partition: usize,
_context: Arc<TaskContext>,
) -> DFResult<SendableRecordBatchStream> {
let metrics = BaselineMetrics::new(&self.metrics, partition);
let strata = self.strata.clone();
let registry = Arc::clone(&self.derived_scan_registry);
let plugin_registry = Arc::clone(&self.plugin_registry);
let graph_ctx = Arc::clone(&self.graph_ctx);
let session_ctx = Arc::clone(&self.session_ctx);
let storage = Arc::clone(&self.storage);
let schema_info = Arc::clone(&self.schema_info);
let params = self.params.clone();
let output_schema = Arc::clone(&self.output_schema);
let max_iterations = self.max_iterations;
let timeout = self.timeout;
let max_derived_bytes = self.max_derived_bytes;
let deterministic_best_by = self.deterministic_best_by;
let strict_probability_domain = self.strict_probability_domain;
let probability_epsilon = self.probability_epsilon;
let exact_probability = self.exact_probability;
let max_bdd_variables = self.max_bdd_variables;
let derived_store_slot = Arc::clone(&self.derived_store_slot);
let approximate_slot = Arc::clone(&self.approximate_slot);
let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
let warnings_slot = Arc::clone(&self.warnings_slot);
let commands = self.commands.clone();
let command_results_slot = Arc::clone(&self.command_results_slot);
let top_k_proofs = self.top_k_proofs;
let timeout_flag = Arc::clone(&self.timeout_flag);
let incomplete_slot = Arc::clone(&self.incomplete_slot);
let semiring_kind = self.semiring_kind;
let classifier_registry = Arc::clone(&self.classifier_registry);
let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
let fut = async move {
run_program(
strata,
commands,
registry,
plugin_registry,
graph_ctx,
session_ctx,
storage,
schema_info,
params,
output_schema,
max_iterations,
timeout,
max_derived_bytes,
deterministic_best_by,
strict_probability_domain,
probability_epsilon,
exact_probability,
max_bdd_variables,
derived_store_slot,
approximate_slot,
iteration_counts_slot,
peak_memory_slot,
derivation_tracker,
warnings_slot,
command_results_slot,
top_k_proofs,
timeout_flag,
incomplete_slot,
semiring_kind,
classifier_registry,
classifier_cache,
classifier_provenance_store,
)
.await
};
Ok(Box::pin(ProgramStream {
state: ProgramStreamState::Running(Box::pin(fut)),
schema: Arc::clone(&self.output_schema),
metrics,
}))
}
fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
}
/// State machine driving `ProgramStream`.
enum ProgramStreamState {
/// The program future is still being polled; it resolves to the result batches.
Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
/// The program finished; the batches are handed out one per poll, the
/// `usize` being the index of the next batch to emit.
Emitting(Vec<RecordBatch>, usize),
/// Nothing more to emit (finished, empty result, or an error was returned).
Done,
}
/// Record-batch stream returned by `LocyProgramExec::execute`.
struct ProgramStream {
/// Current position in the run -> emit -> done state machine.
state: ProgramStreamState,
/// Schema reported through `RecordBatchStream::schema`.
schema: SchemaRef,
/// Baseline metrics (compute time, output rows) for this partition.
metrics: BaselineMetrics,
}
impl Stream for ProgramStream {
    type Item = DFResult<RecordBatch>;

    /// Drives the program future to completion, then emits its batches one
    /// per poll.
    ///
    /// * `Running`: polls the future. An error is yielded once and ends the
    ///   stream; an empty result ends the stream immediately.
    /// * `Emitting`: yields the next batch and records its row count.
    /// * `Done`: always yields `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // Clone only the `Time` handle for the scoped timer. Cloning the whole
        // `BaselineMetrics` would create a temporary whose drop marks the
        // metrics as done, recording the end timestamp after the first poll
        // instead of when the stream is actually dropped.
        let elapsed_compute = this.metrics.elapsed_compute().clone();
        let _timer = elapsed_compute.timer();
        loop {
            match &mut this.state {
                ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(batches)) => {
                        if batches.is_empty() {
                            this.state = ProgramStreamState::Done;
                            return Poll::Ready(None);
                        }
                        // Loop around so the first batch is emitted in this poll.
                        this.state = ProgramStreamState::Emitting(batches, 0);
                    }
                    Poll::Ready(Err(e)) => {
                        this.state = ProgramStreamState::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                ProgramStreamState::Emitting(batches, idx) => {
                    if *idx >= batches.len() {
                        this.state = ProgramStreamState::Done;
                        return Poll::Ready(None);
                    }
                    let batch = batches[*idx].clone();
                    *idx += 1;
                    this.metrics.record_output(batch.num_rows());
                    return Poll::Ready(Some(Ok(batch)));
                }
                ProgramStreamState::Done => return Poll::Ready(None),
            }
        }
    }
}
impl RecordBatchStream for ProgramStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
async fn execute_cypher_inline(
query: &uni_cypher::ast::Query,
schema_info: &Arc<UniSchema>,
params: &HashMap<String, Value>,
graph_ctx: &Arc<GraphExecutionContext>,
session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
storage: &Arc<StorageManager>,
) -> DFResult<Vec<FactRow>> {
let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
let logical_plan = planner.plan(query.clone()).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
})?;
let batches = execute_subplan(
&logical_plan,
params,
&HashMap::new(),
graph_ctx,
session_ctx,
storage,
schema_info,
None, )
.await?;
Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
}
/// Evaluates a complete Locy program and returns a single statistics batch.
///
/// Steps, in order:
/// 1. Under the `MaxMinProb` semiring, push one `FuzzyNotProbabilistic`
///    warning per rule that yields a PROB column (skipping rules that already
///    have one).
/// 2. Evaluate the strata in order. Recursive strata run through
///    `FixpointExec`; non-recursive strata execute each clause body once,
///    apply negated IS-ref handling, optional lineage / exact-WMC processing
///    and the post-fixpoint chain. Facts are written to the derived-scan
///    registry and to the local `DerivedStore`.
/// 3. If the interruption flag was set (timeout or iteration limit), fill
///    `incomplete_slot` with a report of incomplete and skipped rules.
/// 4. Record the total buffer size of all derived batches in
///    `peak_memory_slot`.
/// 5. Run the inline commands (Cypher, Validate, Calibrate) and store their
///    results in `command_results_slot`.
/// 6. Move the `DerivedStore` into `derived_store_slot`.
///
/// Result slots are written even if their lock was poisoned by a panicking
/// writer, matching the handling already used for `warnings_slot`.
///
/// # Errors
/// Propagates errors from plan conversion, sub-plan execution, the
/// probability helpers and the inline commands.
#[expect(
    clippy::too_many_arguments,
    reason = "program evaluation requires full graph and session context"
)]
async fn run_program(
    strata: Vec<LocyStratum>,
    commands: Vec<LocyCommand>,
    registry: Arc<DerivedScanRegistry>,
    plugin_registry: Arc<PluginRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    output_schema: SchemaRef,
    max_iterations: usize,
    timeout: Duration,
    max_derived_bytes: usize,
    deterministic_best_by: bool,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    peak_memory_slot: Arc<StdRwLock<usize>>,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    top_k_proofs: usize,
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
    semiring_kind: SemiringKind,
    classifier_registry: Arc<ClassifierRegistry>,
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let mut derived_store = DerivedStore::new();

    // Under the fuzzy semiring, warn once per rule that a PROB column will
    // carry fuzzy truth values. The guard lives only inside this block, so it
    // is never held across an `.await`.
    if semiring_kind == SemiringKind::MaxMinProb {
        let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
        let mut already: std::collections::HashSet<String> = warnings
            .iter()
            .filter(|w| w.code == uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic)
            .map(|w| w.rule_name.clone())
            .collect();
        for stratum in &strata {
            for rule in &stratum.rules {
                let has_prob = rule.yield_schema.iter().any(|c| c.is_prob);
                if has_prob && !already.contains(&rule.name) {
                    warnings.push(RuntimeWarning {
                        code: uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic,
                        message: format!(
                            "rule '{}' carries a PROB column but is being evaluated under \
                             the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
                             truth values, not probabilities",
                            rule.name
                        ),
                        rule_name: rule.name.clone(),
                        variable_count: None,
                        key_group: None,
                    });
                    already.insert(rule.name.clone());
                }
            }
        }
    }

    let total_strata = strata.len();
    let mut completed_strata = 0usize;
    // Index of the stratum that was being evaluated when an interruption hit.
    let mut partial_stratum: Option<usize> = None;

    for (stratum_idx, stratum) in strata.iter().enumerate() {
        // Make facts of already-evaluated rules visible to this stratum's IS-refs.
        write_cross_stratum_facts(&registry, &derived_store, stratum);

        let remaining_timeout = timeout.saturating_sub(start.elapsed());
        if remaining_timeout.is_zero() {
            tracing::warn!("Locy program timeout exceeded during stratum evaluation");
            interruption::set(&timeout_flag, interruption::TIMEOUT);
            break;
        }

        if stratum.is_recursive {
            let fixpoint_rules = convert_to_fixpoint_plans(
                &stratum.rules,
                &registry,
                &plugin_registry,
                deterministic_best_by,
            )?;
            let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);
            let exec = FixpointExec::new_with_semiring_classifiers_and_cache(
                fixpoint_rules,
                max_iterations,
                remaining_timeout,
                Arc::clone(&graph_ctx),
                Arc::clone(&session_ctx),
                Arc::clone(&storage),
                Arc::clone(&schema_info),
                params.clone(),
                Arc::clone(&registry),
                fixpoint_schema,
                max_derived_bytes,
                derivation_tracker.clone(),
                Arc::clone(&iteration_counts_slot),
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                Arc::clone(&warnings_slot),
                Arc::clone(&approximate_slot),
                top_k_proofs,
                Arc::clone(&timeout_flag),
                semiring_kind,
                Arc::clone(&classifier_registry),
                classifier_cache.as_ref().map(Arc::clone),
                classifier_provenance_store.as_ref().map(Arc::clone),
            );
            let task_ctx = session_ctx.read().task_ctx();
            let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
            let batches = collect_all_partitions(&exec_arc, task_ctx).await?;

            for rule in &stratum.rules {
                if rule.yield_schema.is_empty() {
                    continue;
                }
                // Batches are attributed to this rule by column count; compute
                // the expected count once per rule instead of once per batch.
                let expected_field_count =
                    yield_columns_to_arrow_schema(&rule.yield_schema).fields().len();
                let rule_entries = registry.entries_for_rule(&rule.name);
                for entry in rule_entries {
                    if !entry.is_self_ref {
                        let all_facts: Vec<RecordBatch> = batches
                            .iter()
                            .filter(|b| b.schema().fields().len() == expected_field_count)
                            .cloned()
                            .collect();
                        let mut guard = entry.data.write();
                        *guard = if all_facts.is_empty() {
                            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
                        } else {
                            all_facts
                        };
                    }
                }
                derived_store.insert(rule.name.clone(), batches.clone());
            }
        } else {
            let fixpoint_rules = convert_to_fixpoint_plans(
                &stratum.rules,
                &registry,
                &plugin_registry,
                deterministic_best_by,
            )?;
            let task_ctx = session_ctx.read().task_ctx();

            for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
                if rule.yield_schema.is_empty() {
                    continue;
                }
                // Facts per clause, tagged with the clause index for lineage recording.
                let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
                for (clause_idx, (clause, fp_clause)) in
                    rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
                {
                    let mut batches = execute_subplan(
                        &clause.body,
                        &params,
                        &HashMap::new(),
                        &graph_ctx,
                        &session_ctx,
                        &storage,
                        &schema_info,
                        None,
                    )
                    .await?;

                    // Negated IS-refs: either attach a probability-complement
                    // column or anti-join against the referenced rule's facts.
                    for binding in &fp_clause.is_ref_bindings {
                        if binding.negated
                            && !binding.anti_join_cols.is_empty()
                            && let Some(entry) = registry.get(binding.derived_scan_index)
                        {
                            let neg_facts = entry.data.read().clone();
                            if !neg_facts.is_empty() {
                                if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
                                    let complement_col =
                                        format!("__prob_complement_{}", binding.rule_name);
                                    if let Some(prob_col) = &binding.target_prob_col {
                                        batches =
                                            super::locy_fixpoint::apply_prob_complement_composite(
                                                batches,
                                                &neg_facts,
                                                &binding.anti_join_cols,
                                                prob_col,
                                                &complement_col,
                                            )?;
                                    } else {
                                        batches = super::locy_fixpoint::apply_anti_join_composite(
                                            batches,
                                            &neg_facts,
                                            &binding.anti_join_cols,
                                        )?;
                                    }
                                } else {
                                    batches = super::locy_fixpoint::apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            }
                        }
                    }

                    // Fold any complement columns added above into the PROB column.
                    let complement_cols: Vec<String> = if !batches.is_empty() {
                        batches[0]
                            .schema()
                            .fields()
                            .iter()
                            .filter(|f| f.name().starts_with("__prob_complement_"))
                            .map(|f| f.name().clone())
                            .collect()
                    } else {
                        vec![]
                    };
                    if !complement_cols.is_empty() {
                        batches = super::locy_fixpoint::multiply_prob_factors(
                            batches,
                            fp_rule.prob_column_name.as_deref(),
                            &complement_cols,
                        )?;
                    }
                    tagged_clause_facts.push((clause_idx, batches));
                }

                // Lineage recording is skipped under MaxMinProb and when no
                // tracker is installed.
                let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
                    None
                } else if let Some(ref tracker) = derivation_tracker {
                    super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
                        fp_rule,
                        &tagged_clause_facts,
                        tracker,
                        &warnings_slot,
                        &registry,
                        top_k_proofs,
                        super::locy_fixpoint::ClassifierRefs {
                            registry: &classifier_registry,
                            cache: classifier_cache.as_ref(),
                            provenance_store: classifier_provenance_store.as_ref(),
                        },
                        semiring_kind,
                    )
                    .await
                } else {
                    None
                };

                let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
                    .into_iter()
                    .flat_map(|(_, batches)| batches)
                    .collect();
                if exact_probability
                    && let Some(ref info) = shared_info
                    && let Some(ref tracker) = derivation_tracker
                {
                    all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
                        all_clause_facts,
                        fp_rule,
                        info,
                        tracker,
                        max_bdd_variables,
                        &warnings_slot,
                        &approximate_slot,
                    )?;
                }
                let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
                    all_clause_facts,
                    fp_rule,
                    &task_ctx,
                    strict_probability_domain,
                    probability_epsilon,
                    semiring_kind,
                    derivation_tracker.as_ref().map(Arc::clone),
                    top_k_proofs,
                    Some(Arc::clone(&registry)),
                )
                .await?;
                write_facts_to_registry(&registry, &rule.name, &facts);
                derived_store.insert(rule.name.clone(), facts);
            }
        }

        // An interruption raised while evaluating this stratum leaves it partial.
        if interruption::reason(&timeout_flag).is_some() {
            partial_stratum = Some(stratum_idx);
            break;
        }
        completed_strata += 1;
    }

    if let Some(reason) = interruption::reason(&timeout_flag) {
        // Strata after the partial one (or after the last completed one when
        // the timeout hit before a stratum started) were never evaluated.
        let skipped_start = match partial_stratum {
            Some(i) => i + 1,
            None => completed_strata,
        };
        let incomplete_rules: Vec<String> = partial_stratum
            .map(|i| strata[i].rules.iter().map(|r| r.name.clone()).collect())
            .unwrap_or_default();
        let skipped_rules: Vec<String> = strata[skipped_start..]
            .iter()
            .flat_map(|s| s.rules.iter().map(|r| r.name.clone()))
            .collect();
        // Rules in partial or skipped strata that use a negated IS-ref.
        let mut complement_rules_affected = Vec::new();
        for idx in partial_stratum
            .into_iter()
            .chain(skipped_start..total_strata)
        {
            for rule in &strata[idx].rules {
                if rule
                    .clauses
                    .iter()
                    .any(|c| c.is_refs.iter().any(|r| r.negated))
                {
                    complement_rules_affected.push(rule.name.clone());
                }
            }
        }
        // Recover from a poisoned lock so the report is not silently dropped.
        *incomplete_slot.write().unwrap_or_else(|e| e.into_inner()) =
            Some(uni_common::LocyIncomplete {
                reason,
                elapsed_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                limit_ms: u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX),
                max_iterations,
                completed_strata,
                total_strata,
                incomplete_rules,
                skipped_rules,
                complement_rules_affected,
            });
    }

    // Sum of the buffer sizes of every column of every derived batch held at
    // the end of evaluation.
    let peak_bytes: usize = derived_store
        .relations
        .values()
        .flat_map(|batches| batches.iter())
        .map(|b| {
            b.columns()
                .iter()
                .map(|col| col.get_buffer_memory_size())
                .sum::<usize>()
        })
        .sum();
    *peak_memory_slot.write().unwrap_or_else(|e| e.into_inner()) = peak_bytes;

    let first_derive_idx = commands
        .iter()
        .position(|c| matches!(c, LocyCommand::Derive { .. }));
    let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
    for (cmd_idx, cmd) in commands.iter().enumerate() {
        match cmd {
            LocyCommand::Cypher { query } => {
                // Cypher commands positioned after the first DERIVE command
                // are not run here.
                if first_derive_idx.is_some_and(|di| cmd_idx > di) {
                    continue;
                }
                let rows = execute_cypher_inline(
                    query,
                    &schema_info,
                    &params,
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                )
                .await?;
                inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
            }
            LocyCommand::Validate { validate } => {
                // Key columns of the validated rule (empty when the rule is unknown).
                let rule_key_cols: Vec<String> = strata
                    .iter()
                    .flat_map(|s| s.rules.iter())
                    .find(|r| r.name == validate.rule_name)
                    .map(|r| {
                        r.yield_schema
                            .iter()
                            .filter(|c| c.is_key)
                            .map(|c| c.name.clone())
                            .collect()
                    })
                    .unwrap_or_default();
                let query =
                    super::locy_validate::validate_collection_query(validate, &rule_key_cols);
                let target_rows = execute_cypher_inline(
                    &query,
                    &schema_info,
                    &params,
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                )
                .await?;
                let rule_facts: Vec<uni_locy::FactRow> = derived_store
                    .get(&validate.rule_name)
                    .map(|batches| super::locy_eval::record_batches_to_locy_rows(batches))
                    .unwrap_or_default();
                let result = super::locy_validate::run_validate(
                    validate,
                    &rule_key_cols,
                    &rule_facts,
                    target_rows,
                )
                .map_err(|e| {
                    datafusion::error::DataFusionError::Execution(format!("VALIDATE error: {e}"))
                })?;
                inline_results.push((cmd_idx, CommandResult::Validate(result)));
            }
            LocyCommand::Calibrate {
                calibrate,
                model_inputs,
            } => {
                // Minimal model description built from the command itself.
                let model_snapshot = uni_locy::CompiledModel {
                    name: calibrate.model_name.clone(),
                    inputs: model_inputs.clone(),
                    features: vec![],
                    path_context: None,
                    output_type: uni_cypher::locy_ast::OutputType::Prob,
                    output_name: String::new(),
                    xervo_alias: String::new(),
                    embedder_alias: None,
                    calibration: None,
                    version: None,
                    annotations: Default::default(),
                };
                let query =
                    super::locy_calibrate::calibrate_collection_query(calibrate, &model_snapshot);
                let rows = execute_cypher_inline(
                    &query,
                    &schema_info,
                    &params,
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                )
                .await?;
                let mut catalog = std::collections::HashMap::new();
                catalog.insert(calibrate.model_name.clone(), model_snapshot);
                let result = super::locy_calibrate::run_calibrate(
                    calibrate,
                    &catalog,
                    &classifier_registry,
                    rows,
                )
                .await
                .map_err(|e| {
                    datafusion::error::DataFusionError::Execution(format!("CALIBRATE error: {e}"))
                })?;
                inline_results.push((cmd_idx, CommandResult::Calibrate(result)));
            }
            _ => {}
        }
    }
    *command_results_slot.write().unwrap_or_else(|e| e.into_inner()) = inline_results;

    // Build the statistics batch before the store is moved into its slot.
    let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
    *derived_store_slot.write().unwrap_or_else(|e| e.into_inner()) = Some(derived_store);
    Ok(stats)
}
/// Publishes already-derived facts for every rule referenced by an IS-ref in
/// `stratum`.
///
/// For each IS-ref whose target rule has an entry in `derived_store`, the
/// stored batches are written into the registry. References to rules with no
/// stored facts are skipped.
fn write_cross_stratum_facts(
    registry: &DerivedScanRegistry,
    derived_store: &DerivedStore,
    stratum: &LocyStratum,
) {
    let referenced = stratum
        .rules
        .iter()
        .flat_map(|rule| rule.clauses.iter())
        .flat_map(|clause| clause.is_refs.iter());

    for is_ref in referenced {
        let Some(facts) = derived_store.get(&is_ref.rule_name) else {
            continue;
        };
        write_facts_to_registry(registry, &is_ref.rule_name, facts);
    }
}
/// Replaces the data of every non-self-referencing registry entry for
/// `rule_name` with `facts`.
///
/// Zero-row batches are dropped. Each remaining batch is rebuilt against the
/// entry's schema; if that rebuild fails, the batch is stored unchanged. When
/// no rows remain at all, a single empty batch with the entry's schema is
/// stored instead.
fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
    for entry in registry.entries_for_rule(rule_name) {
        if entry.is_self_ref {
            continue;
        }
        let mut slot = entry.data.write();
        let rebound: Vec<RecordBatch> = facts
            .iter()
            .filter(|batch| batch.num_rows() > 0)
            .map(|batch| {
                match RecordBatch::try_new(Arc::clone(&entry.schema), batch.columns().to_vec()) {
                    Ok(rebuilt) => rebuilt,
                    Err(_) => batch.clone(),
                }
            })
            .collect();
        *slot = if rebound.is_empty() {
            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
        } else {
            rebound
        };
    }
}
/// Converts the planner's rule plans for one stratum into `FixpointRulePlan`s.
///
/// For every rule this resolves IS-ref bindings per clause, converts fold
/// bindings and BEST BY criteria, appends a nullable `__priority` Int64 column
/// to the yield schema when the rule has a priority, and flags the rule as
/// non-linear when any clause has two or more positive IS-refs to rules of
/// the same stratum.
///
/// # Errors
/// Returns the first error produced by IS-ref, fold-binding or BEST BY
/// conversion.
fn convert_to_fixpoint_plans(
    rules: &[LocyRulePlan],
    registry: &DerivedScanRegistry,
    plugin_registry: &PluginRegistry,
    deterministic_best_by: bool,
) -> DFResult<Vec<FixpointRulePlan>> {
    let stratum_rule_names: std::collections::HashSet<&str> =
        rules.iter().map(|r| r.name.as_str()).collect();

    let mut plans: Vec<FixpointRulePlan> = Vec::with_capacity(rules.len());
    for rule in rules {
        let base_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
        let key_column_indices: Vec<usize> = rule
            .yield_schema
            .iter()
            .enumerate()
            .filter_map(|(idx, column)| column.is_key.then_some(idx))
            .collect();

        let mut clauses: Vec<FixpointClausePlan> = Vec::new();
        for clause in rule.clauses.iter() {
            let is_ref_bindings = convert_is_refs(&clause.is_refs, registry, &stratum_rule_names)?;
            clauses.push(FixpointClausePlan {
                body_logical: clause.body.clone(),
                is_ref_bindings,
                priority: clause.priority,
                along_bindings: clause.along_bindings.clone(),
                model_invocations: clause.model_invocations.clone(),
            });
        }

        let fold_bindings =
            convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema, plugin_registry)?;
        let best_by_criteria =
            convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;

        // Rules with a priority carry an extra trailing `__priority` column.
        let has_priority = rule.priority.is_some();
        let yield_schema = if has_priority {
            let mut fields: Vec<Arc<Field>> = base_schema.fields().iter().cloned().collect();
            fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
            ArrowSchema::new(fields)
        } else {
            base_schema
        };

        let prob_column_name = rule
            .yield_schema
            .iter()
            .find(|column| column.is_prob)
            .map(|column| column.name.clone());

        let non_linear = rule.clauses.iter().any(|clause| {
            let positive_same_stratum_refs = clause
                .is_refs
                .iter()
                .filter(|ir| !ir.negated && stratum_rule_names.contains(ir.rule_name.as_str()))
                .count();
            positive_same_stratum_refs >= 2
        });

        plans.push(FixpointRulePlan {
            name: rule.name.clone(),
            clauses,
            yield_schema: Arc::new(yield_schema),
            key_column_indices,
            priority: rule.priority,
            has_fold: !rule.fold_bindings.is_empty(),
            fold_bindings,
            having: rule.having.clone(),
            has_best_by: !rule.best_by_criteria.is_empty(),
            best_by_criteria,
            has_priority,
            deterministic: deterministic_best_by,
            prob_column_name,
            non_linear,
        });
    }
    Ok(plans)
}
/// Resolves each IS-ref of a clause to a derived-scan registry entry and
/// builds the corresponding `IsRefBinding`.
///
/// Entry selection: prefer the entry whose `is_self_ref` flag matches whether
/// the referenced rule belongs to the current stratum; otherwise fall back to
/// the first entry registered for that rule.
///
/// Column pairing: every subject that is a plain variable is paired with the
/// name of the entry-schema field at the same position (or with the variable
/// name itself when the schema has no field there). These pairs are the
/// provenance join columns. For negated refs they are also the anti-join
/// columns, extended by the target variable paired with the field that
/// follows the subjects, when both exist.
///
/// # Errors
/// Returns a `Plan` error when no registry entry exists for a referenced rule.
fn convert_is_refs(
    is_refs: &[LocyIsRef],
    registry: &DerivedScanRegistry,
    stratum_rule_names: &std::collections::HashSet<&str>,
) -> DFResult<Vec<IsRefBinding>> {
    is_refs
        .iter()
        .map(|is_ref| {
            let entries = registry.entries_for_rule(&is_ref.rule_name);
            let want_self_ref = stratum_rule_names.contains(is_ref.rule_name.as_str());
            let entry = entries
                .iter()
                .find(|e| e.is_self_ref == want_self_ref)
                .or_else(|| entries.first())
                .ok_or_else(|| {
                    datafusion::error::DataFusionError::Plan(format!(
                        "No derived scan entry found for IS-ref to '{}'",
                        is_ref.rule_name
                    ))
                })?;

            // Computed once and shared by the anti-join and provenance columns.
            let subject_cols: Vec<(String, String)> = is_ref
                .subjects
                .iter()
                .enumerate()
                .filter_map(|(i, s)| {
                    if let uni_cypher::ast::Expr::Variable(var) = s {
                        let right_col = entry
                            .schema
                            .fields()
                            .get(i)
                            .map(|f| f.name().clone())
                            .unwrap_or_else(|| var.clone());
                        Some((var.clone(), right_col))
                    } else {
                        None
                    }
                })
                .collect();

            let anti_join_cols = if is_ref.negated {
                let mut cols = subject_cols.clone();
                if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
                    // The target column sits right after the subject columns.
                    let target_idx = is_ref.subjects.len();
                    if let Some(field) = entry.schema.fields().get(target_idx) {
                        cols.push((target_var.clone(), field.name().clone()));
                    }
                }
                cols
            } else {
                Vec::new()
            };

            Ok(IsRefBinding {
                derived_scan_index: entry.scan_index,
                rule_name: is_ref.rule_name.clone(),
                is_self_ref: entry.is_self_ref,
                negated: is_ref.negated,
                anti_join_cols,
                target_has_prob: is_ref.target_has_prob,
                target_prob_col: is_ref.target_prob_col.clone(),
                provenance_join_cols: subject_cols,
            })
        })
        .collect()
}
fn convert_fold_bindings(
fold_bindings: &[(String, String, Expr)],
yield_schema: &[LocyYieldColumn],
plugin_registry: &PluginRegistry,
) -> DFResult<Vec<FoldBinding>> {
fold_bindings
.iter()
.map(|(name, yield_alias, expr)| {
let (agg_name, _input_col_name) = parse_fold_aggregate(expr)?;
let entry =
resolve_locy_aggregate(plugin_registry, agg_name.as_str()).ok_or_else(|| {
datafusion::error::DataFusionError::Plan(format!(
"Unknown Locy aggregate '{agg_name}' — not registered in plugin registry"
))
})?;
let aggregate = Arc::clone(&entry.aggregate);
if agg_name.as_str() == "COUNTALL" {
return Ok(FoldBinding {
output_name: yield_alias.clone(),
name: agg_name,
aggregate,
input_col_index: 0, input_col_name: None,
});
}
let input_col_index = yield_schema
.iter()
.position(|yc| yc.name == *name || yc.name == *yield_alias)
.unwrap_or(0);
Ok(FoldBinding {
output_name: yield_alias.clone(),
name: agg_name,
aggregate,
input_col_index,
input_col_name: Some(name.clone()),
})
})
.collect()
}
/// Splits a FOLD expression into `(canonical aggregate name, input column name)`.
///
/// The function name is upper-cased before matching. `COUNT`/`MCOUNT` with no
/// arguments yields `("COUNTALL", "")`. Otherwise the `M`-prefixed variants of
/// `SUM`, `MAX`, `MIN` and `COUNT` map to the unprefixed name, while `AVG`,
/// `COLLECT`, `MNOR` and `MPROD` map to themselves. The column name is taken
/// from the first argument: a variable's name, a property's name, or the
/// argument's `to_string_repr()` for any other expression.
///
/// # Errors
///
/// Returns a `DataFusionError::Plan` when `expr` is not a function call, when
/// the function is not one of the names above, or when a recognised function
/// (other than the zero-argument count) has no arguments. An unknown function
/// name is reported before a missing argument.
fn parse_fold_aggregate(expr: &Expr) -> DFResult<(smol_str::SmolStr, String)> {
    let (name, args) = match expr {
        Expr::FunctionCall { name, args, .. } => (name, args),
        _ => {
            return Err(datafusion::error::DataFusionError::Plan(
                "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
            ));
        }
    };

    let upper = name.to_uppercase();
    let canonical: &'static str = match upper.as_str() {
        // Argument-less count is its own aggregate and needs no input column.
        "COUNT" | "MCOUNT" if args.is_empty() => {
            return Ok((smol_str::SmolStr::new_static("COUNTALL"), String::new()));
        }
        "SUM" | "MSUM" => "SUM",
        "MAX" | "MMAX" => "MAX",
        "MIN" | "MMIN" => "MIN",
        "COUNT" | "MCOUNT" => "COUNT",
        "AVG" => "AVG",
        "COLLECT" => "COLLECT",
        "MNOR" => "MNOR",
        "MPROD" => "MPROD",
        _ => {
            return Err(datafusion::error::DataFusionError::Plan(format!(
                "Unknown FOLD aggregate function: {}",
                name
            )));
        }
    };

    let col_name = match args.first() {
        Some(Expr::Variable(variable)) => variable.clone(),
        Some(Expr::Property(_, property)) => property.clone(),
        Some(other) => other.to_string_repr(),
        None => {
            return Err(datafusion::error::DataFusionError::Plan(
                "FOLD aggregate function requires at least one argument".to_string(),
            ));
        }
    };

    Ok((smol_str::SmolStr::new_static(canonical), col_name))
}
fn convert_best_by_criteria(
criteria: &[(Expr, bool)],
yield_schema: &[LocyYieldColumn],
) -> DFResult<Vec<SortCriterion>> {
criteria
.iter()
.map(|(expr, ascending)| {
let col_name = match expr {
Expr::Property(_, prop) => prop.clone(),
Expr::Variable(v) => v.clone(),
_ => {
return Err(datafusion::error::DataFusionError::Plan(
"BEST BY criterion must be a variable or property reference".to_string(),
));
}
};
let col_index = yield_schema
.iter()
.position(|yc| yc.name == col_name)
.or_else(|| {
let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
yield_schema.iter().position(|yc| yc.name == short_name)
})
.ok_or_else(|| {
datafusion::error::DataFusionError::Plan(format!(
"BEST BY column '{}' not found in yield schema",
col_name
))
})?;
Ok(SortCriterion {
col_index,
ascending: *ascending,
nulls_first: false,
})
})
.collect()
}
/// Builds an Arrow schema with one field per yield column, in order.
///
/// Every field takes the column's name and data type and is marked nullable.
fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
    let mut fields: Vec<Arc<Field>> = Vec::with_capacity(columns.len());
    for column in columns {
        let field = Field::new(&column.name, column.data_type.clone(), true);
        fields.push(Arc::new(field));
    }
    ArrowSchema::new(fields)
}
/// Returns the Arrow schema derived from the first rule's yield columns, or
/// an empty schema when `rules` is empty.
fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
    let schema = match rules.first() {
        Some(first_rule) => yield_columns_to_arrow_schema(&first_rule.yield_schema),
        None => ArrowSchema::empty(),
    };
    Arc::new(schema)
}
/// Builds a `(rule_name, fact_count)` batch with one row per rule in
/// `derived_store`, ordered by rule name.
///
/// `_strata` is not used. If assembling the batch fails, an empty batch with
/// `output_schema` is returned instead of an error.
fn build_stats_batch(
    derived_store: &DerivedStore,
    _strata: &[LocyStratum],
    output_schema: SchemaRef,
) -> RecordBatch {
    let mut sorted_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
    sorted_names.sort();

    let name_col =
        arrow_array::StringArray::from_iter(sorted_names.iter().map(|name| Some(name.as_str())));
    let count_col = arrow_array::Int64Array::from_iter(
        sorted_names
            .iter()
            .map(|name| Some(derived_store.fact_count(name) as i64)),
    );

    let schema = stats_schema();
    match RecordBatch::try_new(schema, vec![Arc::new(name_col), Arc::new(count_col)]) {
        Ok(batch) => batch,
        Err(_) => RecordBatch::new_empty(output_schema),
    }
}
/// Schema of the statistics batch: a non-nullable `rule_name` (`Utf8`) column
/// followed by a non-nullable `fact_count` (`Int64`) column.
pub fn stats_schema() -> SchemaRef {
    let rule_name = Field::new("rule_name", DataType::Utf8, false);
    let fact_count = Field::new("fact_count", DataType::Int64, false);
    Arc::new(ArrowSchema::new(vec![
        Arc::new(rule_name),
        Arc::new(fact_count),
    ]))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};

    /// Builds a batch with a single nullable `LargeBinary` column named `x`
    /// containing `values` (all non-null).
    fn binary_batch(values: &[&[u8]]) -> RecordBatch {
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let column: Vec<Option<&[u8]>> = values.iter().map(|value| Some(*value)).collect();
        RecordBatch::try_new(schema, vec![Arc::new(LargeBinaryArray::from(column))]).unwrap()
    }

    /// Builds a non-probability yield column.
    fn yield_column(name: &str, is_key: bool, data_type: DataType) -> LocyYieldColumn {
        LocyYieldColumn {
            name: name.to_string(),
            is_key,
            is_prob: false,
            data_type,
        }
    }

    /// Builds `function(argument)` as a plain (non-distinct, non-window) call
    /// on a single variable.
    fn single_arg_call(function: &str, argument: &str) -> Expr {
        Expr::FunctionCall {
            name: function.to_string(),
            args: vec![Expr::Variable(argument.to_string())],
            distinct: false,
            window_spec: None,
        }
    }

    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());

        let batch = binary_batch(&[b"a" as &[u8], b"b"]);
        store.insert("test".to_string(), vec![batch]);

        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }

    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);

        // One row in the first batch, two in the second.
        let one_row = binary_batch(&[b"a" as &[u8]]);
        let two_rows = binary_batch(&[b"b" as &[u8], b"c"]);
        store.insert("test".to_string(), vec![one_row, two_rows]);

        assert_eq!(store.fact_count("test"), 3);
    }

    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        assert_eq!(schema.fields().len(), 2);

        let name_field = schema.field(0);
        assert_eq!(name_field.name(), "rule_name");
        assert_eq!(name_field.data_type(), &DataType::Utf8);

        let count_field = schema.field(1);
        assert_eq!(count_field.name(), "fact_count");
        assert_eq!(count_field.data_type(), &DataType::Int64);
    }

    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        store.insert(
            "reach".to_string(),
            vec![binary_batch(&[b"a" as &[u8], b"b"])],
        );

        let stats = build_stats_batch(&store, &[], stats_schema());
        assert_eq!(stats.num_rows(), 1);

        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "reach");

        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 2);
    }

    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let expected = [
            ("a", DataType::UInt64),
            ("b", DataType::LargeUtf8),
            ("c", DataType::Float64),
        ];
        let columns = vec![
            yield_column("a", true, DataType::UInt64),
            yield_column("b", false, DataType::LargeUtf8),
            yield_column("c", true, DataType::Float64),
        ];

        let schema = yield_columns_to_arrow_schema(&columns);

        assert_eq!(schema.fields().len(), 3);
        for (index, (name, data_type)) in expected.iter().enumerate() {
            let field = schema.field(index);
            assert_eq!(field.name(), name);
            assert_eq!(field.data_type(), data_type);
            assert!(field.is_nullable());
        }
    }

    #[test]
    fn test_key_column_indices() {
        let columns = [
            yield_column("a", true, DataType::LargeBinary),
            yield_column("b", false, DataType::LargeBinary),
            yield_column("c", true, DataType::LargeBinary),
        ];

        let mut key_indices: Vec<usize> = Vec::new();
        for (index, column) in columns.iter().enumerate() {
            if column.is_key {
                key_indices.push(index);
            }
        }

        assert_eq!(key_indices, vec![0, 2]);
    }

    #[test]
    fn test_parse_fold_aggregate_sum() {
        let (kind, col) = parse_fold_aggregate(&single_arg_call("SUM", "cost")).unwrap();
        assert_eq!(kind.as_str(), "SUM");
        assert_eq!(col, "cost");
    }

    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        // The monotonic `MMAX` spelling maps to the canonical `MAX`.
        let (kind, col) = parse_fold_aggregate(&single_arg_call("MMAX", "score")).unwrap();
        assert_eq!(kind.as_str(), "MAX");
        assert_eq!(col, "score");
    }

    #[test]
    fn test_parse_fold_aggregate_unknown() {
        let outcome = parse_fold_aggregate(&single_arg_call("UNKNOWN_AGG", "x"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_no_commands_returns_stats() {
        let empty_store = DerivedStore::new();
        let stats = build_stats_batch(&empty_store, &[], stats_schema());
        assert_eq!(stats.num_rows(), 0);
    }
}