use crate::query::df_graph::GraphExecutionContext;
use crate::query::df_graph::common::{
collect_all_partitions, compute_plan_properties, execute_subplan,
};
use crate::query::df_graph::locy_best_by::SortCriterion;
use crate::query::df_graph::locy_explain::ProvenanceStore;
use crate::query::df_graph::locy_fixpoint::{
DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
};
use crate::query::df_graph::locy_fold::{FoldAggKind, FoldBinding};
use crate::query::planner_locy_types::{
LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
};
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
use datafusion::common::Result as DFResult;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::Stream;
use parking_lot::RwLock;
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use uni_common::Value;
use uni_common::core::schema::Schema as UniSchema;
use uni_cypher::ast::Expr;
use uni_cypher::locy_ast::GoalQuery;
use uni_locy::{CommandResult, FactRow, RuntimeWarning};
use uni_store::storage::manager::StorageManager;
/// In-memory store of derived facts, keyed by rule name.
///
/// Filled while a Locy program runs and published to the caller through
/// `LocyProgramExec::derived_store_slot` once evaluation finishes.
pub struct DerivedStore {
    // One entry per derived rule; each value is that rule's complete fact
    // set as Arrow record batches.
    relations: HashMap<String, Vec<RecordBatch>>,
}
impl Default for DerivedStore {
fn default() -> Self {
Self::new()
}
}
impl DerivedStore {
pub fn new() -> Self {
Self {
relations: HashMap::new(),
}
}
pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
self.relations.insert(rule_name, facts);
}
pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
self.relations.get(rule_name)
}
pub fn fact_count(&self, rule_name: &str) -> usize {
self.relations
.get(rule_name)
.map(|batches| batches.iter().map(|b| b.num_rows()).sum())
.unwrap_or(0)
}
pub fn rule_names(&self) -> impl Iterator<Item = &str> {
self.relations.keys().map(|s| s.as_str())
}
}
/// Physical plan node that evaluates a whole Locy program (stratified rules
/// plus trailing commands) and emits a single stats batch as its stream
/// output; the full derived results are published out-of-band through the
/// `*_slot` fields below.
pub struct LocyProgramExec {
    // Rule strata in evaluation order; recursive ones run via `FixpointExec`.
    strata: Vec<LocyStratum>,
    // Trailing program commands (e.g. inline Cypher queries).
    commands: Vec<LocyCommand>,
    // Shared registry backing derived-scan operators that read rule facts.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    // Query parameters forwarded to every sub-plan execution.
    params: HashMap<String, Value>,
    output_schema: SchemaRef,
    properties: PlanProperties,
    metrics: ExecutionPlanMetricsSet,
    // Resource/behavior limits applied during fixpoint evaluation.
    max_iterations: usize,
    timeout: Duration,
    max_derived_bytes: usize,
    deterministic_best_by: bool,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    // Output side-channels, written at the end of `run_program` and read by
    // the caller after the stream completes.
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    // Optional provenance tracker, installed via `set_derivation_tracker`.
    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    peak_memory_slot: Arc<StdRwLock<usize>>,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    top_k_proofs: usize,
    // Set to true if evaluation stops because the timeout was exceeded.
    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
}
// Manual Debug: only summarizes counts and limits; the remaining fields
// (contexts, registries, slots) are intentionally omitted via
// `finish_non_exhaustive`.
impl fmt::Debug for LocyProgramExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LocyProgramExec")
            .field("strata_count", &self.strata.len())
            .field("commands_count", &self.commands.len())
            .field("max_iterations", &self.max_iterations)
            .field("timeout", &self.timeout)
            .field("output_schema", &self.output_schema)
            .field("max_derived_bytes", &self.max_derived_bytes)
            .finish_non_exhaustive()
    }
}
impl LocyProgramExec {
    /// Builds a new program node; all shared output slots start empty and the
    /// timeout flag starts false. Plan properties are derived from the
    /// output schema.
    #[expect(
        clippy::too_many_arguments,
        reason = "execution plan node requires full graph and session context"
    )]
    pub fn new(
        strata: Vec<LocyStratum>,
        commands: Vec<LocyCommand>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        output_schema: SchemaRef,
        max_iterations: usize,
        timeout: Duration,
        max_derived_bytes: usize,
        deterministic_best_by: bool,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        top_k_proofs: usize,
    ) -> Self {
        let properties = compute_plan_properties(Arc::clone(&output_schema));
        Self {
            strata,
            commands,
            derived_scan_registry,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
            max_iterations,
            timeout,
            max_derived_bytes,
            deterministic_best_by,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            derived_store_slot: Arc::new(StdRwLock::new(None)),
            approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
            derivation_tracker: Arc::new(StdRwLock::new(None)),
            iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
            peak_memory_slot: Arc::new(StdRwLock::new(0)),
            warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
            command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
            top_k_proofs,
            timeout_flag: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }
    /// Slot that receives the final [`DerivedStore`] after execution.
    pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
        Arc::clone(&self.derived_store_slot)
    }
    /// Installs a provenance tracker. Silently a no-op if the lock is
    /// poisoned (best-effort by design).
    pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
        if let Ok(mut guard) = self.derivation_tracker.write() {
            *guard = Some(tracker);
        }
    }
    /// Slot receiving per-rule fixpoint iteration counts.
    pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
        Arc::clone(&self.iteration_counts_slot)
    }
    /// Slot receiving the peak derived-fact buffer size in bytes.
    pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
        Arc::clone(&self.peak_memory_slot)
    }
    /// Slot receiving runtime warnings accumulated during evaluation.
    pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
        Arc::clone(&self.warnings_slot)
    }
    /// Slot receiving per-rule notes recorded when probability computation
    /// falls back to approximation (filled by the fixpoint/WMC paths).
    pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
        Arc::clone(&self.approximate_slot)
    }
    /// Slot receiving inline-command results, tagged with command index.
    pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
        Arc::clone(&self.command_results_slot)
    }
    /// Flag set to `true` if evaluation was cut short by the timeout.
    pub fn timeout_flag(&self) -> Arc<std::sync::atomic::AtomicBool> {
        Arc::clone(&self.timeout_flag)
    }
}
impl DisplayAs for LocyProgramExec {
    // One-line plan summary shown in EXPLAIN output; the same text is used
    // for every display format.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
            self.strata.len(),
            self.commands.len(),
            self.max_iterations,
            self.timeout,
        )
    }
}
impl ExecutionPlan for LocyProgramExec {
    fn name(&self) -> &str {
        "LocyProgramExec"
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }
    // Leaf node: the whole program is evaluated inside this operator.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // A leaf cannot accept children; reject any attempt to attach some.
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "LocyProgramExec has no children".to_string(),
            ));
        }
        Ok(self)
    }
    /// Packages the entire program evaluation into a lazy future wrapped in a
    /// `ProgramStream`; nothing runs until the returned stream is polled.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);
        // Clone (or Arc-clone) everything the async body needs, since the
        // future must be 'static and cannot borrow from `self`.
        let strata = self.strata.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let output_schema = Arc::clone(&self.output_schema);
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let max_derived_bytes = self.max_derived_bytes;
        let deterministic_best_by = self.deterministic_best_by;
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let derived_store_slot = Arc::clone(&self.derived_store_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
        let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
        // Snapshot the tracker installed via `set_derivation_tracker`, if
        // any; a poisoned lock is treated the same as "no tracker".
        let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let commands = self.commands.clone();
        let command_results_slot = Arc::clone(&self.command_results_slot);
        let top_k_proofs = self.top_k_proofs;
        let timeout_flag = Arc::clone(&self.timeout_flag);
        let fut = async move {
            run_program(
                strata,
                commands,
                registry,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                output_schema,
                max_iterations,
                timeout,
                max_derived_bytes,
                deterministic_best_by,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                derived_store_slot,
                approximate_slot,
                iteration_counts_slot,
                peak_memory_slot,
                derivation_tracker,
                warnings_slot,
                command_results_slot,
                top_k_proofs,
                timeout_flag,
            )
            .await
        };
        Ok(Box::pin(ProgramStream {
            state: ProgramStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
// State machine for `ProgramStream`:
// Running  -> program future still executing;
// Emitting -> future finished, yielding batches one by one (cursor = index);
// Done     -> stream exhausted (or errored).
enum ProgramStreamState {
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    Emitting(Vec<RecordBatch>, usize),
    Done,
}
/// Record-batch stream adapter that runs the whole Locy program future to
/// completion, then drains the resulting batches to the consumer.
struct ProgramStream {
    state: ProgramStreamState,
    schema: SchemaRef,
    // Tracks output row counts for DataFusion metrics reporting.
    metrics: BaselineMetrics,
}
impl Stream for ProgramStream {
    type Item = DFResult<RecordBatch>;
    /// Drives the program future to completion, then yields its result
    /// batches one at a time, recording output rows in the metrics.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();
        loop {
            match &mut stream.state {
                ProgramStreamState::Running(fut) => {
                    // Await the evaluation future; transition on completion.
                    let batches = match fut.as_mut().poll(cx) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(e)) => {
                            stream.state = ProgramStreamState::Done;
                            return Poll::Ready(Some(Err(e)));
                        }
                        Poll::Ready(Ok(batches)) => batches,
                    };
                    if batches.is_empty() {
                        stream.state = ProgramStreamState::Done;
                        return Poll::Ready(None);
                    }
                    stream.state = ProgramStreamState::Emitting(batches, 0);
                }
                ProgramStreamState::Emitting(batches, cursor) => match batches.get(*cursor) {
                    Some(batch) => {
                        // RecordBatch clone is cheap (shared column buffers).
                        let batch = batch.clone();
                        *cursor += 1;
                        stream.metrics.record_output(batch.num_rows());
                        return Poll::Ready(Some(Ok(batch)));
                    }
                    None => {
                        stream.state = ProgramStreamState::Done;
                        return Poll::Ready(None);
                    }
                },
                ProgramStreamState::Done => return Poll::Ready(None),
            }
        }
    }
}
impl RecordBatchStream for ProgramStream {
    /// Schema of the stats batches emitted by this stream.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}
#[allow(dead_code)]
/// Evaluates a goal query directly against already-derived facts: loads the
/// target rule's batches, applies the optional WHERE filter row by row, then
/// applies the RETURN clause.
fn execute_query_inline(
    query: &GoalQuery,
    derived_store: &DerivedStore,
    params: &HashMap<String, Value>,
) -> DFResult<Vec<FactRow>> {
    let rule_name = query.rule_name.to_string();
    let batches = derived_store.get(&rule_name).cloned().unwrap_or_default();
    let mut rows = super::locy_eval::record_batches_to_locy_rows(&batches);
    if let Some(where_expr) = query.where_expr.as_ref() {
        // Keep only rows whose WHERE predicate evaluates to true; evaluation
        // errors and non-boolean results count as false.
        rows.retain(|row| {
            let merged = super::locy_query::merge_params(row, params);
            super::locy_eval::eval_expr(where_expr, &merged)
                .map(|v| v.as_bool().unwrap_or(false))
                .unwrap_or(false)
        });
    }
    super::locy_query::apply_return_clause(rows, &query.return_clause, params)
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
}
/// Plans and executes an embedded Cypher query through the regular query
/// planner, returning its output converted to Locy fact rows.
async fn execute_cypher_inline(
    query: &uni_cypher::ast::Query,
    schema_info: &Arc<UniSchema>,
    params: &HashMap<String, Value>,
    graph_ctx: &Arc<GraphExecutionContext>,
    session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: &Arc<StorageManager>,
) -> DFResult<Vec<FactRow>> {
    let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
    // Planning failures surface as DataFusion execution errors with context.
    let logical_plan = planner.plan(query.clone()).map_err(|e| {
        datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
    })?;
    let batches = execute_subplan(
        &logical_plan,
        params,
        &HashMap::new(),
        graph_ctx,
        session_ctx,
        storage,
        schema_info,
    )
    .await?;
    Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
}
#[allow(dead_code)]
/// True when evaluating `query` touches a property access (`x.prop`) in its
/// WHERE clause or RETURN items, meaning node property values are required.
fn needs_node_enrichment(query: &GoalQuery) -> bool {
    if query
        .where_expr
        .as_ref()
        .is_some_and(expr_has_property_access)
    {
        return true;
    }
    query.return_clause.as_ref().is_some_and(|rc| {
        rc.items.iter().any(|item| match item {
            uni_cypher::ast::ReturnItem::Expr { expr, .. } => expr_has_property_access(expr),
            uni_cypher::ast::ReturnItem::All => false,
        })
    })
}
#[allow(dead_code)]
/// Recursively checks whether `expr` (or any sub-expression) contains a
/// property access. Walks every expression variant that can carry children;
/// variants without sub-expressions fall through to `false`.
fn expr_has_property_access(expr: &Expr) -> bool {
    match expr {
        // Base case: a direct property access.
        Expr::Property(..) => true,
        Expr::BinaryOp { left, right, .. } => {
            expr_has_property_access(left) || expr_has_property_access(right)
        }
        Expr::UnaryOp { expr, .. } => expr_has_property_access(expr),
        Expr::FunctionCall { args, .. } => args.iter().any(expr_has_property_access),
        Expr::List(items) => items.iter().any(expr_has_property_access),
        Expr::Map(entries) => entries.iter().any(|(_, e)| expr_has_property_access(e)),
        // CASE has three places to look: the optional scrutinee, each
        // WHEN/THEN pair, and the optional ELSE branch.
        Expr::Case {
            expr: case_expr,
            when_then,
            else_expr,
        } => {
            case_expr
                .as_ref()
                .is_some_and(|e| expr_has_property_access(e))
                || when_then
                    .iter()
                    .any(|(w, t)| expr_has_property_access(w) || expr_has_property_access(t))
                || else_expr
                    .as_ref()
                    .is_some_and(|e| expr_has_property_access(e))
        }
        Expr::IsNull(e) | Expr::IsNotNull(e) | Expr::IsUnique(e) => expr_has_property_access(e),
        Expr::In { expr, list } => expr_has_property_access(expr) || expr_has_property_access(list),
        Expr::ArrayIndex { array, index } => {
            expr_has_property_access(array) || expr_has_property_access(index)
        }
        Expr::ArraySlice { array, start, end } => {
            expr_has_property_access(array)
                || start.as_ref().is_some_and(|e| expr_has_property_access(e))
                || end.as_ref().is_some_and(|e| expr_has_property_access(e))
        }
        Expr::Quantifier {
            list, predicate, ..
        } => expr_has_property_access(list) || expr_has_property_access(predicate),
        Expr::Reduce {
            init, list, expr, ..
        } => {
            expr_has_property_access(init)
                || expr_has_property_access(list)
                || expr_has_property_access(expr)
        }
        Expr::ListComprehension {
            list,
            where_clause,
            map_expr,
            ..
        } => {
            expr_has_property_access(list)
                || where_clause
                    .as_ref()
                    .is_some_and(|e| expr_has_property_access(e))
                || expr_has_property_access(map_expr)
        }
        Expr::PatternComprehension {
            where_clause,
            map_expr,
            ..
        } => {
            where_clause
                .as_ref()
                .is_some_and(|e| expr_has_property_access(e))
                || expr_has_property_access(map_expr)
        }
        Expr::ValidAt {
            entity, timestamp, ..
        } => expr_has_property_access(entity) || expr_has_property_access(timestamp),
        Expr::MapProjection { base, .. } => expr_has_property_access(base),
        Expr::LabelCheck { expr, .. } => expr_has_property_access(expr),
        // Leaves (literals, variables, parameters, ...) carry no children.
        _ => false,
    }
}
/// Evaluates a complete Locy program and returns a single stats batch.
///
/// Strata are evaluated in order: recursive strata run through `FixpointExec`
/// until fixpoint; non-recursive strata execute each clause body exactly once
/// (applying negation/probability handling inline). Derived facts are written
/// into the shared scan registry and the `DerivedStore`. After rule
/// evaluation, any inline Cypher commands preceding the first DERIVE are
/// executed. Peak memory, command results, and the derived store are
/// published through their respective slots; the timeout flag is set if the
/// time budget is exhausted mid-program.
#[expect(
    clippy::too_many_arguments,
    reason = "program evaluation requires full graph and session context"
)]
async fn run_program(
    strata: Vec<LocyStratum>,
    commands: Vec<LocyCommand>,
    registry: Arc<DerivedScanRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    output_schema: SchemaRef,
    max_iterations: usize,
    timeout: Duration,
    max_derived_bytes: usize,
    deterministic_best_by: bool,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    peak_memory_slot: Arc<StdRwLock<usize>>,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    top_k_proofs: usize,
    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let mut derived_store = DerivedStore::new();
    for stratum in &strata {
        // Make facts derived by earlier strata visible to this stratum's
        // derived-scan operators.
        write_cross_stratum_facts(&registry, &derived_store, stratum);
        let remaining_timeout = timeout.saturating_sub(start.elapsed());
        if remaining_timeout.is_zero() {
            tracing::warn!("Locy program timeout exceeded during stratum evaluation");
            timeout_flag.store(true, std::sync::atomic::Ordering::Relaxed);
            break;
        }
        if stratum.is_recursive {
            // Recursive stratum: evaluate all its rules together via a
            // semi-naive fixpoint operator.
            let fixpoint_rules =
                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
            let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);
            let exec = FixpointExec::new(
                fixpoint_rules,
                max_iterations,
                remaining_timeout,
                Arc::clone(&graph_ctx),
                Arc::clone(&session_ctx),
                Arc::clone(&storage),
                Arc::clone(&schema_info),
                params.clone(),
                Arc::clone(&registry),
                fixpoint_schema,
                max_derived_bytes,
                derivation_tracker.clone(),
                Arc::clone(&iteration_counts_slot),
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                Arc::clone(&warnings_slot),
                Arc::clone(&approximate_slot),
                top_k_proofs,
                Arc::clone(&timeout_flag),
            );
            let task_ctx = session_ctx.read().task_ctx();
            let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
            let batches = collect_all_partitions(&exec_arc, task_ctx).await?;
            for rule in &stratum.rules {
                if rule.yield_schema.is_empty() {
                    continue;
                }
                // Hoisted out of the filter closure below: the yield schema
                // is invariant per rule, so compute its field count once
                // instead of rebuilding the Arrow schema for every batch.
                let rule_field_count =
                    yield_columns_to_arrow_schema(&rule.yield_schema).fields().len();
                let rule_entries = registry.entries_for_rule(&rule.name);
                for entry in rule_entries {
                    if !entry.is_self_ref {
                        // Batches belonging to this rule are identified by
                        // matching column count against the rule's schema.
                        let all_facts: Vec<RecordBatch> = batches
                            .iter()
                            .filter(|b| b.schema().fields().len() == rule_field_count)
                            .cloned()
                            .collect();
                        let mut guard = entry.data.write();
                        *guard = if all_facts.is_empty() {
                            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
                        } else {
                            all_facts
                        };
                    }
                }
                derived_store.insert(rule.name.clone(), batches.clone());
            }
        } else {
            // Non-recursive stratum: each rule's clauses execute once.
            let fixpoint_rules =
                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
            let task_ctx = session_ctx.read().task_ctx();
            for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
                if rule.yield_schema.is_empty() {
                    continue;
                }
                let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
                for (clause_idx, (clause, fp_clause)) in
                    rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
                {
                    let mut batches = execute_subplan(
                        &clause.body,
                        &params,
                        &HashMap::new(),
                        &graph_ctx,
                        &session_ctx,
                        &storage,
                        &schema_info,
                    )
                    .await?;
                    // Apply negated IS-references: either probability
                    // complement (probabilistic target) or plain anti-join.
                    for binding in &fp_clause.is_ref_bindings {
                        if binding.negated
                            && !binding.anti_join_cols.is_empty()
                            && let Some(entry) = registry.get(binding.derived_scan_index)
                        {
                            let neg_facts = entry.data.read().clone();
                            if !neg_facts.is_empty() {
                                if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
                                    let complement_col =
                                        format!("__prob_complement_{}", binding.rule_name);
                                    if let Some(prob_col) = &binding.target_prob_col {
                                        batches =
                                            super::locy_fixpoint::apply_prob_complement_composite(
                                                batches,
                                                &neg_facts,
                                                &binding.anti_join_cols,
                                                prob_col,
                                                &complement_col,
                                            )?;
                                    } else {
                                        batches = super::locy_fixpoint::apply_anti_join_composite(
                                            batches,
                                            &neg_facts,
                                            &binding.anti_join_cols,
                                        )?;
                                    }
                                } else {
                                    batches = super::locy_fixpoint::apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            }
                        }
                    }
                    // Fold any complement columns produced above into the
                    // rule's probability column.
                    let complement_cols: Vec<String> = if !batches.is_empty() {
                        batches[0]
                            .schema()
                            .fields()
                            .iter()
                            .filter(|f| f.name().starts_with("__prob_complement_"))
                            .map(|f| f.name().clone())
                            .collect()
                    } else {
                        vec![]
                    };
                    if !complement_cols.is_empty() {
                        batches = super::locy_fixpoint::multiply_prob_factors(
                            batches,
                            fp_rule.prob_column_name.as_deref(),
                            &complement_cols,
                        )?;
                    }
                    tagged_clause_facts.push((clause_idx, batches));
                }
                // Record provenance (and detect shared lineage) if a tracker
                // was installed.
                let shared_info = if let Some(ref tracker) = derivation_tracker {
                    super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
                        fp_rule,
                        &tagged_clause_facts,
                        tracker,
                        &warnings_slot,
                        &registry,
                        top_k_proofs,
                    )
                } else {
                    None
                };
                let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
                    .into_iter()
                    .flat_map(|(_, batches)| batches)
                    .collect();
                // Exact weighted model counting requires both shared-lineage
                // info and a tracker.
                if exact_probability
                    && let Some(ref info) = shared_info
                    && let Some(ref tracker) = derivation_tracker
                {
                    all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
                        all_clause_facts,
                        fp_rule,
                        info,
                        tracker,
                        max_bdd_variables,
                        &warnings_slot,
                        &approximate_slot,
                    )?;
                }
                let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
                    all_clause_facts,
                    fp_rule,
                    &task_ctx,
                    strict_probability_domain,
                    probability_epsilon,
                )
                .await?;
                write_facts_to_registry(&registry, &rule.name, &facts);
                derived_store.insert(rule.name.clone(), facts);
            }
        }
    }
    // Measure the total Arrow buffer footprint of all derived facts.
    let peak_bytes: usize = derived_store
        .relations
        .values()
        .flat_map(|batches| batches.iter())
        .map(|b| {
            b.columns()
                .iter()
                .map(|col| col.get_buffer_memory_size())
                .sum::<usize>()
        })
        .sum();
    *peak_memory_slot.write().unwrap() = peak_bytes;
    // Only Cypher commands at or before the first DERIVE run inline here;
    // later ones are skipped (handled elsewhere — presumably post-derive).
    let first_derive_idx = commands
        .iter()
        .position(|c| matches!(c, LocyCommand::Derive { .. }));
    let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
    for (cmd_idx, cmd) in commands.iter().enumerate() {
        if let LocyCommand::Cypher { query } = cmd {
            if first_derive_idx.is_some_and(|di| cmd_idx > di) {
                continue;
            }
            let rows = execute_cypher_inline(
                query,
                &schema_info,
                &params,
                &graph_ctx,
                &session_ctx,
                &storage,
            )
            .await?;
            inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
        }
    }
    *command_results_slot.write().unwrap() = inline_results;
    // Publish results: the stats batch is the stream output, the full store
    // goes out through the slot.
    let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
    *derived_store_slot.write().unwrap() = Some(derived_store);
    Ok(stats)
}
/// Copies facts already derived in earlier strata into the scan registry so
/// that IS-references inside `stratum` can read them during evaluation.
fn write_cross_stratum_facts(
    registry: &DerivedScanRegistry,
    derived_store: &DerivedStore,
    stratum: &LocyStratum,
) {
    let referenced = stratum
        .rules
        .iter()
        .flat_map(|rule| rule.clauses.iter())
        .flat_map(|clause| clause.is_refs.iter());
    for is_ref in referenced {
        // Only rules that have actually been derived so far are published.
        if let Some(facts) = derived_store.get(&is_ref.rule_name) {
            write_facts_to_registry(registry, &is_ref.rule_name, facts);
        }
    }
}
/// Publishes `facts` into every non-self-ref registry entry for `rule_name`,
/// re-wrapping each batch under the entry's schema. Empty or all-empty input
/// is stored as a single empty batch so readers always see a valid schema.
fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
    for entry in registry.entries_for_rule(rule_name) {
        if entry.is_self_ref {
            continue;
        }
        let has_rows = facts.iter().any(|b| b.num_rows() > 0);
        let new_data: Vec<RecordBatch> = if has_rows {
            facts
                .iter()
                .filter(|b| b.num_rows() > 0)
                .map(|b| {
                    // Rebuild under the entry's schema; on mismatch, fall
                    // back to the original batch unchanged.
                    RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
                        .unwrap_or_else(|_| b.clone())
                })
                .collect()
        } else {
            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
        };
        *entry.data.write() = new_data;
    }
}
/// Lowers planner-level rule plans into executable `FixpointRulePlan`s:
/// builds each rule's Arrow yield schema (appending a `__priority` column
/// when the rule has a priority), resolves key-column indices, IS-ref
/// bindings, fold bindings, BEST BY criteria, and the probability column.
fn convert_to_fixpoint_plans(
    rules: &[LocyRulePlan],
    registry: &DerivedScanRegistry,
    deterministic_best_by: bool,
) -> DFResult<Vec<FixpointRulePlan>> {
    rules
        .iter()
        .map(|rule| {
            let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
            // Positions of key columns within the yield schema.
            let key_column_indices: Vec<usize> = rule
                .yield_schema
                .iter()
                .enumerate()
                .filter(|(_, yc)| yc.is_key)
                .map(|(i, _)| i)
                .collect();
            let clauses: Vec<FixpointClausePlan> = rule
                .clauses
                .iter()
                .map(|clause| {
                    let is_ref_bindings = convert_is_refs(&clause.is_refs, registry)?;
                    Ok(FixpointClausePlan {
                        body_logical: clause.body.clone(),
                        is_ref_bindings,
                        priority: clause.priority,
                        along_bindings: clause.along_bindings.clone(),
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;
            let fold_bindings = convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema)?;
            let best_by_criteria =
                convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;
            let has_priority = rule.priority.is_some();
            // Rules with a priority get an extra nullable `__priority`
            // Int64 column appended to the schema.
            let yield_schema = if has_priority {
                let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
                fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
                ArrowSchema::new(fields)
            } else {
                yield_schema
            };
            // First column flagged as a probability column, if any.
            let prob_column_name = rule
                .yield_schema
                .iter()
                .find(|yc| yc.is_prob)
                .map(|yc| yc.name.clone());
            Ok(FixpointRulePlan {
                name: rule.name.clone(),
                clauses,
                yield_schema: Arc::new(yield_schema),
                key_column_indices,
                priority: rule.priority,
                has_fold: !rule.fold_bindings.is_empty(),
                fold_bindings,
                having: rule.having.clone(),
                has_best_by: !rule.best_by_criteria.is_empty(),
                best_by_criteria,
                has_priority,
                deterministic: deterministic_best_by,
                prob_column_name,
            })
        })
        .collect()
}
/// Resolves each IS-reference against the derived-scan registry and builds
/// its executable binding, including the (left, right) column pairs used for
/// anti-joins (negated refs) and provenance joins. Subject variables are
/// matched positionally to the target entry's schema fields.
fn convert_is_refs(
    is_refs: &[LocyIsRef],
    registry: &DerivedScanRegistry,
) -> DFResult<Vec<IsRefBinding>> {
    is_refs
        .iter()
        .map(|is_ref| {
            let entries = registry.entries_for_rule(&is_ref.rule_name);
            // Prefer a self-referencing entry (recursive use); otherwise any
            // entry for the rule. No entry at all is a planning error.
            let entry = entries
                .iter()
                .find(|e| e.is_self_ref)
                .or_else(|| entries.first())
                .ok_or_else(|| {
                    datafusion::error::DataFusionError::Plan(format!(
                        "No derived scan entry found for IS-ref to '{}'",
                        is_ref.rule_name
                    ))
                })?;
            // Anti-join columns only matter for negated references: pair the
            // i-th subject variable with the i-th schema field (falling back
            // to the variable name when the schema has fewer fields).
            let anti_join_cols = if is_ref.negated {
                let mut cols: Vec<(String, String)> = is_ref
                    .subjects
                    .iter()
                    .enumerate()
                    .filter_map(|(i, s)| {
                        if let uni_cypher::ast::Expr::Variable(var) = s {
                            let right_col = entry
                                .schema
                                .fields()
                                .get(i)
                                .map(|f| f.name().clone())
                                .unwrap_or_else(|| var.clone());
                            Some((var.clone(), right_col))
                        } else {
                            None
                        }
                    })
                    .collect();
                // The target (if a variable) maps to the field right after
                // the subjects.
                if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
                    let target_idx = is_ref.subjects.len();
                    if let Some(field) = entry.schema.fields().get(target_idx) {
                        cols.push((target_var.clone(), field.name().clone()));
                    }
                }
                cols
            } else {
                Vec::new()
            };
            // Provenance joins use the subject pairing regardless of negation.
            let provenance_join_cols: Vec<(String, String)> = is_ref
                .subjects
                .iter()
                .enumerate()
                .filter_map(|(i, s)| {
                    if let uni_cypher::ast::Expr::Variable(var) = s {
                        let right_col = entry
                            .schema
                            .fields()
                            .get(i)
                            .map(|f| f.name().clone())
                            .unwrap_or_else(|| var.clone());
                        Some((var.clone(), right_col))
                    } else {
                        None
                    }
                })
                .collect();
            Ok(IsRefBinding {
                derived_scan_index: entry.scan_index,
                rule_name: is_ref.rule_name.clone(),
                is_self_ref: entry.is_self_ref,
                negated: is_ref.negated,
                anti_join_cols,
                target_has_prob: is_ref.target_has_prob,
                target_prob_col: is_ref.target_prob_col.clone(),
                provenance_join_cols,
            })
        })
        .collect()
}
/// Converts planner fold bindings `(input name, yield alias, aggregate expr)`
/// into executable `FoldBinding`s, resolving each aggregate's input column
/// against the rule's yield schema.
fn convert_fold_bindings(
    fold_bindings: &[(String, String, Expr)],
    yield_schema: &[LocyYieldColumn],
) -> DFResult<Vec<FoldBinding>> {
    let mut bindings = Vec::with_capacity(fold_bindings.len());
    for (name, yield_alias, expr) in fold_bindings {
        let (kind, _input_col_name) = parse_fold_aggregate(expr)?;
        let binding = if kind == FoldAggKind::CountAll {
            // COUNT() reads no input column at all.
            FoldBinding {
                output_name: yield_alias.clone(),
                kind,
                input_col_index: 0,
                input_col_name: None,
            }
        } else {
            // Match by input name or alias; default to column 0 when absent.
            let input_col_index = yield_schema
                .iter()
                .position(|yc| yc.name == *name || yc.name == *yield_alias)
                .unwrap_or(0);
            FoldBinding {
                output_name: yield_alias.clone(),
                kind,
                input_col_index,
                input_col_name: Some(name.clone()),
            }
        };
        bindings.push(binding);
    }
    Ok(bindings)
}
/// Parses a FOLD aggregate expression (e.g. `SUM(x)`, `MCOUNT()`) into its
/// aggregate kind and input column name. Zero-argument COUNT/MCOUNT becomes
/// `CountAll` with an empty column name.
fn parse_fold_aggregate(expr: &Expr) -> DFResult<(FoldAggKind, String)> {
    let Expr::FunctionCall { name, args, .. } = expr else {
        return Err(datafusion::error::DataFusionError::Plan(
            "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
        ));
    };
    let upper = name.to_uppercase();
    // Bare COUNT()/MCOUNT() counts rows and needs no input column.
    if args.is_empty() && matches!(upper.as_str(), "COUNT" | "MCOUNT") {
        return Ok((FoldAggKind::CountAll, String::new()));
    }
    // "M"-prefixed names are the monotonic variants of the same aggregates.
    let kind = match upper.as_str() {
        "SUM" | "MSUM" => FoldAggKind::Sum,
        "MAX" | "MMAX" => FoldAggKind::Max,
        "MIN" | "MMIN" => FoldAggKind::Min,
        "COUNT" | "MCOUNT" => FoldAggKind::Count,
        "AVG" => FoldAggKind::Avg,
        "COLLECT" => FoldAggKind::Collect,
        "MNOR" => FoldAggKind::Nor,
        "MPROD" => FoldAggKind::Prod,
        _ => {
            return Err(datafusion::error::DataFusionError::Plan(format!(
                "Unknown FOLD aggregate function: {}",
                name
            )));
        }
    };
    // The first argument names the input column: a variable directly, a
    // property by its property name, anything else by its string form.
    let col_name = match args.first() {
        Some(Expr::Variable(v)) => v.clone(),
        Some(Expr::Property(_, prop)) => prop.clone(),
        Some(other) => other.to_string_repr(),
        None => {
            return Err(datafusion::error::DataFusionError::Plan(
                "FOLD aggregate function requires at least one argument".to_string(),
            ));
        }
    };
    Ok((kind, col_name))
}
/// Converts BEST BY criteria into sort criteria, resolving each variable or
/// property reference against the yield schema (first by full name, then by
/// the suffix after the last `.`).
fn convert_best_by_criteria(
    criteria: &[(Expr, bool)],
    yield_schema: &[LocyYieldColumn],
) -> DFResult<Vec<SortCriterion>> {
    let resolve = |col_name: &str| -> Option<usize> {
        yield_schema
            .iter()
            .position(|yc| yc.name == col_name)
            .or_else(|| {
                let short_name = col_name.rsplit('.').next().unwrap_or(col_name);
                yield_schema.iter().position(|yc| yc.name == short_name)
            })
    };
    criteria
        .iter()
        .map(|(expr, ascending)| {
            let col_name = match expr {
                Expr::Property(_, prop) => prop.clone(),
                Expr::Variable(v) => v.clone(),
                _ => {
                    return Err(datafusion::error::DataFusionError::Plan(
                        "BEST BY criterion must be a variable or property reference".to_string(),
                    ));
                }
            };
            let col_index = resolve(&col_name).ok_or_else(|| {
                datafusion::error::DataFusionError::Plan(format!(
                    "BEST BY column '{}' not found in yield schema",
                    col_name
                ))
            })?;
            Ok(SortCriterion {
                col_index,
                ascending: *ascending,
                nulls_first: false,
            })
        })
        .collect()
}
/// Builds an Arrow schema from yield columns; every field is nullable.
fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
    let mut fields: Vec<Arc<Field>> = Vec::with_capacity(columns.len());
    for yc in columns {
        fields.push(Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)));
    }
    ArrowSchema::new(fields)
}
/// Output schema for a fixpoint stratum: the first rule's yield schema, or an
/// empty schema when the stratum has no rules.
fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
    match rules.first() {
        Some(rule) => Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema)),
        None => Arc::new(ArrowSchema::empty()),
    }
}
/// Builds the summary batch (`rule_name`, `fact_count`) with rule names
/// sorted alphabetically. On a construction failure, falls back to an empty
/// batch with the caller's `output_schema`.
fn build_stats_batch(
    derived_store: &DerivedStore,
    _strata: &[LocyStratum],
    output_schema: SchemaRef,
) -> RecordBatch {
    let mut rule_names: Vec<String> = derived_store.rule_names().map(|s| s.to_string()).collect();
    rule_names.sort();
    let names: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
    let counts: arrow_array::Int64Array = rule_names
        .iter()
        .map(|name| Some(derived_store.fact_count(name) as i64))
        .collect();
    RecordBatch::try_new(stats_schema(), vec![Arc::new(names), Arc::new(counts)])
        .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
}
/// Schema of the stats batch: non-nullable `rule_name` (Utf8) and
/// `fact_count` (Int64).
pub fn stats_schema() -> SchemaRef {
    let rule_name = Field::new("rule_name", DataType::Utf8, false);
    let fact_count = Field::new("fact_count", DataType::Int64, false);
    Arc::new(ArrowSchema::new(vec![
        Arc::new(rule_name),
        Arc::new(fact_count),
    ]))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};
    // Round-trip: inserted batches come back intact via `get`.
    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![
                Some(b"a" as &[u8]),
                Some(b"b"),
            ]))],
        )
        .unwrap();
        store.insert("test".to_string(), vec![batch.clone()]);
        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }
    // `fact_count` sums rows across all batches of a rule; missing rule is 0.
    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![Some(b"a" as &[u8])]))],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![
                Some(b"b" as &[u8]),
                Some(b"c"),
            ]))],
        )
        .unwrap();
        store.insert("test".to_string(), vec![batch1, batch2]);
        assert_eq!(store.fact_count("test"), 3);
    }
    // Stats schema shape is part of the public contract of `stats_schema`.
    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "rule_name");
        assert_eq!(schema.field(1).name(), "fact_count");
        assert_eq!(schema.field(0).data_type(), &DataType::Utf8);
        assert_eq!(schema.field(1).data_type(), &DataType::Int64);
    }
    // One stats row per derived rule with its total row count.
    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![
                Some(b"a" as &[u8]),
                Some(b"b"),
            ]))],
        )
        .unwrap();
        store.insert("reach".to_string(), vec![batch]);
        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 1);
        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "reach");
        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 2);
    }
    // Yield columns map 1:1 to nullable Arrow fields with preserved types.
    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let columns = vec![
            LocyYieldColumn {
                name: "a".to_string(),
                is_key: true,
                is_prob: false,
                data_type: DataType::UInt64,
            },
            LocyYieldColumn {
                name: "b".to_string(),
                is_key: false,
                is_prob: false,
                data_type: DataType::LargeUtf8,
            },
            LocyYieldColumn {
                name: "c".to_string(),
                is_key: true,
                is_prob: false,
                data_type: DataType::Float64,
            },
        ];
        let schema = yield_columns_to_arrow_schema(&columns);
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "a");
        assert_eq!(schema.field(1).name(), "b");
        assert_eq!(schema.field(2).name(), "c");
        assert_eq!(schema.field(0).data_type(), &DataType::UInt64);
        assert_eq!(schema.field(1).data_type(), &DataType::LargeUtf8);
        assert_eq!(schema.field(2).data_type(), &DataType::Float64);
        for field in schema.fields() {
            assert!(field.is_nullable());
        }
    }
    // Mirrors the key-index extraction done in `convert_to_fixpoint_plans`.
    #[test]
    fn test_key_column_indices() {
        let columns = [
            LocyYieldColumn {
                name: "a".to_string(),
                is_key: true,
                is_prob: false,
                data_type: DataType::LargeBinary,
            },
            LocyYieldColumn {
                name: "b".to_string(),
                is_key: false,
                is_prob: false,
                data_type: DataType::LargeBinary,
            },
            LocyYieldColumn {
                name: "c".to_string(),
                is_key: true,
                is_prob: false,
                data_type: DataType::LargeBinary,
            },
        ];
        let key_indices: Vec<usize> = columns
            .iter()
            .enumerate()
            .filter(|(_, yc)| yc.is_key)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(key_indices, vec![0, 2]);
    }
    // SUM(x) parses to the Sum kind with its variable as input column.
    #[test]
    fn test_parse_fold_aggregate_sum() {
        let expr = Expr::FunctionCall {
            name: "SUM".to_string(),
            args: vec![Expr::Variable("cost".to_string())],
            distinct: false,
            window_spec: None,
        };
        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
        assert!(matches!(kind, FoldAggKind::Sum));
        assert_eq!(col, "cost");
    }
    // Monotonic variants (M-prefixed) map to the same kinds.
    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        let expr = Expr::FunctionCall {
            name: "MMAX".to_string(),
            args: vec![Expr::Variable("score".to_string())],
            distinct: false,
            window_spec: None,
        };
        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
        assert!(matches!(kind, FoldAggKind::Max));
        assert_eq!(col, "score");
    }
    // Unrecognized aggregate names are planning errors.
    #[test]
    fn test_parse_fold_aggregate_unknown() {
        let expr = Expr::FunctionCall {
            name: "UNKNOWN_AGG".to_string(),
            args: vec![Expr::Variable("x".to_string())],
            distinct: false,
            window_spec: None,
        };
        assert!(parse_fold_aggregate(&expr).is_err());
    }
    // An empty store yields an empty (zero-row) stats batch.
    #[test]
    fn test_no_commands_returns_stats() {
        let store = DerivedStore::new();
        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 0);
    }
}