Skip to main content

uni_plugin/traits/
locy.rs

1//! Locy aggregate and predicate plugins.
2//!
3//! Locy aggregates (used in `FOLD value AS X`) require `Semilattice`
4//! metadata so the fixpoint engine can verify its monotonicity proofs
5//! explicitly.
6//!
7//! Locy predicates evaluate to boolean (or fuzzy) columns and are the
8//! surface neural predicates plug into.
9
10use arrow_array::{Array, BooleanArray, Float64Array};
11use arrow_schema::DataType;
12use datafusion::arrow::record_batch::RecordBatch;
13use datafusion::logical_expr::{ColumnarValue, Volatility};
14use datafusion::scalar::ScalarValue;
15
16use crate::errors::FnError;
17use crate::traits::scalar::ArgType;
18
/// Value-level probability semiring under which a `FOLD` is evaluated.
///
/// This is the plugin-facing counterpart of the host's
/// `uni_locy::SemiringKind`. That enum's provenance-only variants
/// (`TopKProofs`, `BddExact`) are deliberately absent here: the executor
/// deals with them above the aggregate trait, so a plugin aggregate only ever
/// observes one of the two value-level combinators below.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum FoldSemiring {
    /// Independence-mode probability: noisy-OR (`1 − ∏(1 − pᵢ)`) / product.
    /// This is the default semiring.
    #[default]
    AddMult,
    /// Viterbi / fuzzy-truth mode: disjunction is `max`, conjunction is `min`.
    MaxMin,
}
33
34/// Per-fold evaluation context threaded into [`LocyAggState::ingest_indices`].
35///
36/// Carries the probability-domain policy (`strict`), the underflow guard
37/// (`epsilon`, used by bounded-product log-space switching), and the active
38/// [`FoldSemiring`]. Constructed once per `FOLD` execution and passed by
39/// reference to every per-group ingest call.
40#[derive(Clone, Copy, Debug, PartialEq)]
41pub struct FoldContext {
42    /// When `true`, probability-domain aggregates error on inputs outside
43    /// `[0, 1]` instead of clamping them with a warning.
44    pub strict: bool,
45    /// Underflow threshold: bounded-product switches to log-space once the
46    /// running product drops below this value. `0.0` disables the switch.
47    pub epsilon: f64,
48    /// Active probability semiring for this fold.
49    pub semiring: FoldSemiring,
50}
51
52impl Default for FoldContext {
53    fn default() -> Self {
54        Self {
55            strict: false,
56            epsilon: 0.0,
57            semiring: FoldSemiring::AddMult,
58        }
59    }
60}
61
62/// A Locy aggregate plugin (`FOLD value AS plugin_name`).
63///
64/// The fixpoint engine uses `semilattice()` metadata to verify monotonicity
65/// and prove termination; non-monotone aggregates are rejected at compile
66/// time when used inside a recursive Locy clause.
67///
68/// `Debug` is a supertrait so `Arc<dyn LocyAggregate>` can sit inside
69/// `#[derive(Debug)]` structs in the fixpoint engine. The 9 built-in
70/// impls already `#[derive(Debug)]` their unit/struct types.
71pub trait LocyAggregate: Send + Sync + std::fmt::Debug {
72    /// Lattice properties used by the fixpoint engine.
73    fn semilattice(&self) -> Semilattice;
74
75    /// Declared output type for `FOLD` results.
76    fn output_type(&self) -> DataType;
77
78    /// Output type given the aggregate's *input* column type.
79    ///
80    /// Defaults to [`LocyAggregate::output_type`] (input-independent).
81    /// Type-preserving aggregates (`MIN` / `MAX`) override this to return the
82    /// input type so an `Int64` column folds to an `Int64` result rather than
83    /// being widened to `Float64`.
84    fn output_type_for_input(&self, _input: &DataType) -> DataType {
85        self.output_type()
86    }
87
88    /// Construct a fresh per-grouping state.
89    fn create(&self) -> Box<dyn LocyAggState>;
90
91    /// Initial accumulator value for the row-level fast path used by
92    /// the Locy fixpoint engine ([`MonotonicAggState`]). For numeric
93    /// aggregates this is the identity element (`0` for SUM/COUNT/NOR,
94    /// `1` for PROD, `+inf` for MIN, `-inf` for MAX). Returns `None`
95    /// for aggregates that have no row-level fast path (`AVG`, `COLLECT`
96    /// — these run outside the fast path).
97    ///
98    /// [`MonotonicAggState`]: ../../uni_query/query/df_graph/locy_fixpoint/struct.MonotonicAggState.html
99    fn initial_accum_f64(&self) -> Option<f64> {
100        None
101    }
102
103    /// Row-level update step on a primitive `f64` accumulator.
104    ///
105    /// Returns the new accumulator value after folding `val` into `accum`.
106    /// `strict` enables strict-mode probability-domain validation for
107    /// `MNOR` / `MPROD` (inputs outside `[0, 1]` produce a `FnError`
108    /// instead of being clamped with a warning).
109    ///
110    /// Default impl returns [`FnError::CODE_UNKNOWN_FUNCTION`] indicating
111    /// the aggregate has no row-level fast path; the fixpoint engine
112    /// must use the batch-shape [`LocyAggState::ingest`] path instead.
113    ///
114    /// # Errors
115    ///
116    /// - In strict mode with an out-of-domain value for a probabilistic
117    ///   aggregate.
118    /// - When the aggregate has no row-level path (default impl).
119    fn update_step(&self, _accum: f64, _val: f64, _strict: bool) -> Result<f64, FnError> {
120        Err(FnError::new(
121            FnError::CODE_UNKNOWN_FUNCTION,
122            "aggregate has no row-level update_step; use ingest()",
123        ))
124    }
125
126    /// True if this aggregate operates on the probability domain `[0, 1]`.
127    ///
128    /// Used by the Locy fixpoint engine to trigger provenance tracking
129    /// (shared-proof detection) when any rule's stratum has a
130    /// probability-domain aggregate. Default `false`. Override `true`
131    /// for `MNOR`, `MPROD`, and future probability-domain aggregates
132    /// authored by users.
133    fn is_probability_aggregate(&self) -> bool {
134        false
135    }
136
137    /// True if this aggregate is the noisy-OR semiring (`1 − ∏(1 − pᵢ)`).
138    ///
139    /// Used by the fixpoint engine's `apply_post_fixpoint_chain` to
140    /// select the per-row probability combination operator when
141    /// multiple independent evidence sources are joined. Default
142    /// `false`. Override `true` for `MNOR`.
143    fn is_noisy_or(&self) -> bool {
144        false
145    }
146}
147
148/// Per-grouping state for a [`LocyAggregate`].
149///
150/// `'static` is required so the fixpoint engine can safely downcast
151/// `&dyn LocyAggState` to the concrete state via [`LocyAggState::as_any`]
152/// during `merge`. Implementations expose `as_any` with a one-liner
153/// `fn as_any(&self) -> &dyn std::any::Any { self }`.
154pub trait LocyAggState: Send + 'static {
155    /// Return `&dyn Any` for safe downcasting in `merge` implementations.
156    ///
157    /// Default `impl` is `self`. Implementations should not override unless
158    /// they need to expose a different concrete type than the implementor.
159    fn as_any(&self) -> &dyn std::any::Any;
160
161    /// Ingest the rows at `indices` of `col` into the state under `cx`.
162    ///
163    /// This is the primitive the non-recursive `FOLD` executor calls once per
164    /// key group: `col` is the whole fold-input column and `indices` selects
165    /// the rows belonging to one group. [`FoldContext`] carries the
166    /// strict-domain / epsilon / semiring policy. Implementations skip null
167    /// rows. Built-in aggregates override this for byte-identical, context-aware
168    /// folding; user aggregates may rely on the [`LocyAggState::ingest`] default
169    /// if they do not need per-group dispatch.
170    ///
171    /// # Errors
172    ///
173    /// Returns [`FnError`] if a value cannot be ingested (e.g., a strict
174    /// probability-domain violation or an unexpected Arrow type).
175    fn ingest_indices(
176        &mut self,
177        col: &dyn Array,
178        indices: &[usize],
179        cx: &FoldContext,
180    ) -> Result<(), FnError>;
181
182    /// Ingest every row of column `value_col` in `batch` into the state.
183    ///
184    /// Convenience wrapper over [`LocyAggState::ingest_indices`] across all
185    /// rows with a default [`FoldContext`]. Kept for callers and tests that
186    /// fold a whole batch as a single group.
187    ///
188    /// # Errors
189    ///
190    /// Returns [`FnError`] if the column cannot be ingested.
191    fn ingest(&mut self, batch: &RecordBatch, value_col: usize) -> Result<(), FnError> {
192        let indices: Vec<usize> = (0..batch.num_rows()).collect();
193        self.ingest_indices(batch.column(value_col), &indices, &FoldContext::default())
194    }
195
196    /// Merge `other`'s state into `self`.
197    ///
198    /// # Errors
199    ///
200    /// Returns [`FnError`] if the states cannot be merged (e.g., type
201    /// mismatch between aggregate instances).
202    fn merge(&mut self, other: &dyn LocyAggState) -> Result<(), FnError>;
203
204    /// Produce the final aggregated value.
205    ///
206    /// # Errors
207    ///
208    /// Returns [`FnError`] if the value cannot be finalized.
209    fn finalize(&self) -> Result<ScalarValue, FnError>;
210
211    /// Fixpoint shortcut: `true` once no further `ingest` can change state.
212    ///
213    /// `MAX` over a bounded domain returns `true` at the top; `SUM` never
214    /// returns `true`. The fixpoint engine uses this to terminate early.
215    fn is_at_top(&self) -> bool {
216        false
217    }
218}
219
/// Algebraic properties an aggregate advertises to the Locy fixpoint engine.
///
/// The engine reads these flags to verify monotonicity and prove
/// termination. They are not independent of one another: `monotone_join`
/// typically implies both `commutative` and `associative`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Semilattice {
    /// `f(x, x) == x`. Inputs to an idempotent aggregate can be deduplicated.
    pub idempotent: bool,
    /// `f(x, y) == f(y, x)`. A commutative aggregate does not depend on input
    /// order.
    pub commutative: bool,
    /// `f(f(x, y), z) == f(x, f(y, z))`. An associative aggregate can be
    /// partial-aggregated.
    pub associative: bool,
    /// `f` preserves or raises the partial order. Monotone aggregates produce
    /// sound fixpoints; non-monotone ones cannot be used inside recursive
    /// Locy clauses.
    pub monotone_join: bool,
    /// The domain is bounded, so `is_at_top()` may return `true`. This enables
    /// fixpoint shortcuts (no further ingest can change the state).
    pub has_top: bool,
}

impl Semilattice {
    /// Flags for a non-monotone aggregate (`SUM`, `AVG`): commutative and
    /// associative, nothing else.
    ///
    /// Such aggregates may not appear inside recursive Locy clauses.
    pub const NON_MONOTONE: Self = Self {
        commutative: true,
        associative: true,
        idempotent: false,
        monotone_join: false,
        has_top: false,
    };

    /// Flags for `MIN` / `MAX` over a bounded domain — fully monotone, with
    /// every flag set.
    pub const BOUNDED_MIN_MAX: Self = Self {
        commutative: true,
        associative: true,
        idempotent: true,
        monotone_join: true,
        has_top: true,
    };

    /// Flags for `COUNT` — monotone but unbounded. Identical to
    /// [`Semilattice::NON_MONOTONE`] except that `monotone_join` is set.
    pub const COUNT: Self = Self {
        monotone_join: true,
        ..Self::NON_MONOTONE
    };
}
273
274/// A Locy predicate plugin — boolean (or fuzzy) column over inputs.
275pub trait LocyPredicate: Send + Sync {
276    /// Static signature.
277    fn signature(&self) -> &PredSignature;
278
279    /// Evaluate the predicate over a batch of inputs to a boolean column.
280    ///
281    /// # Errors
282    ///
283    /// Returns [`FnError`] if the predicate cannot be evaluated on this input.
284    fn evaluate(&self, args: &[ColumnarValue], rows: usize) -> Result<BooleanArray, FnError>;
285
286    /// Optional fuzzy evaluation — `Some(scores)` for predicates that
287    /// participate in PROB chains, `None` otherwise.
288    ///
289    /// # Errors
290    ///
291    /// Returns [`FnError`] if fuzzy evaluation is unsupported or fails.
292    fn evaluate_fuzzy(
293        &self,
294        _args: &[ColumnarValue],
295        _rows: usize,
296    ) -> Option<Result<Float64Array, FnError>> {
297        None
298    }
299}
300
301/// Static signature of a Locy predicate.
302#[derive(Clone, Debug)]
303pub struct PredSignature {
304    /// Argument types.
305    pub args: Vec<ArgType>,
306    /// Volatility.
307    pub volatility: Volatility,
308    /// Whether `evaluate_fuzzy` returns `Some(...)`.
309    pub supports_fuzzy: bool,
310    /// Hint for batch sizing — neural predicates often prefer larger batches.
311    pub batch_hint: BatchHint,
312}
313
/// Batch size a predicate would like to be evaluated with.
///
/// Defaults to [`BatchHint::Medium`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum BatchHint {
    /// Small batches; evaluating row-at-a-time is acceptable.
    Small,
    /// Medium batches; this is the default.
    #[default]
    Medium,
    /// Large batches; the host should accumulate many rows before invoking
    /// the predicate (neural predicates benefit dramatically).
    Large,
}
327
/// Provenance / derivation tracker — placeholder.
///
/// Intended as the opaque handle the fixpoint engine uses for shared-proof
/// detection. The `LocyAggState` trait declared in this module does not yet
/// expose a `provenance` accessor that returns it, and the type carries no
/// data: its exact contents are left for the fixpoint engine to wire up.
#[derive(Clone, Debug, Default)]
pub struct DerivationTracker {
    // Private unit field: keeps the struct from being built with a literal
    // outside this module while the real contents are undecided.
    _placeholder: (),
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the flag combinations of the predefined [`Semilattice`] constants.
    #[test]
    fn semilattice_constants() {
        // Local aliases keep the assertions short; they are plain constants,
        // so everything below is still checked at compile time.
        const BOUNDED: Semilattice = Semilattice::BOUNDED_MIN_MAX;
        const PLAIN: Semilattice = Semilattice::NON_MONOTONE;
        const COUNTING: Semilattice = Semilattice::COUNT;

        const {
            // MIN / MAX: monotone and bounded.
            assert!(BOUNDED.monotone_join);
            assert!(BOUNDED.has_top);
            // SUM / AVG: not monotone.
            assert!(!PLAIN.monotone_join);
            // COUNT: monotone but unbounded.
            assert!(COUNTING.monotone_join);
            assert!(!COUNTING.has_top);
        }
    }
}
352}