uni-query 1.1.0

OpenCypher query parser, planner, and vectorized executor for Uni
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
//! Weighted model counting (WMC) via BDDs for shared-lineage groups.
//!
//! When MNOR/MPROD groups have shared base facts (violating the
//! independence assumption), this module computes the correct joint
//! probability via weighted model counting on Binary Decision Diagrams
//! (Sang et al. 2005; Darwiche 2011 for SDD generalization).

use std::collections::{HashMap, HashSet};

use biodivine_lib_bdd::{Bdd, BddPointer, BddVariable, BddVariableSet};

/// Semiring operation for probability aggregation.
///
/// Selects how per-row derivation terms are combined when
/// `weighted_model_count` builds the group BDD: `Disjunction` ORs the
/// row terms together, `Conjunction` ANDs them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SemiringOp {
    /// MNOR: P(any row) = 1 − ∏(1 − pᵢ)
    Disjunction,
    /// MPROD: P(all rows) = ∏ pᵢ
    Conjunction,
}

/// Result of weighted model counting for a single aggregation group.
#[derive(Debug)]
pub struct WmcResult {
    /// The exact (or fallback) probability value.
    /// When `approximated` is true this is always 0.0 and should be
    /// ignored in favor of the caller's independence-mode result.
    pub probability: f64,
    /// True if the computation fell back to independence mode because
    /// the number of unique base facts exceeded `max_bdd_variables`
    /// (or the BDD library's `u16` variable limit).
    pub approximated: bool,
    /// Number of unique base fact variables in this group (computed
    /// even when the computation is approximated).
    pub variable_count: usize,
}

/// Compute exact probability for an aggregation group via weighted model counting (WMC).
///
/// Each derivation row contributes a set of base facts (identified by
/// opaque byte-string hashes). The function builds a BDD representing
/// the Boolean combination of those facts and evaluates the probability
/// via Shannon expansion.
///
/// - `SemiringOp::Disjunction` → MNOR semantics: P(any row derives) = P(row₁ ∨ row₂ ∨ …)
/// - `SemiringOp::Conjunction` → MPROD semantics: P(all rows derive) = P(row₁ ∧ row₂ ∧ …)
///
/// Facts missing from `base_fact_weights` default to probability 0.5.
/// If the number of distinct facts exceeds `max_bdd_variables` (or the
/// BDD library's `u16` variable limit), the result is flagged
/// `approximated` with probability 0.0 and the caller should use its
/// independence-mode result instead.
pub fn weighted_model_count(
    group_lineage: &[HashSet<Vec<u8>>],
    base_fact_weights: &HashMap<Vec<u8>, f64>,
    semiring_op: SemiringOp,
    max_bdd_variables: usize,
) -> WmcResult {
    let is_disjunction = semiring_op == SemiringOp::Disjunction;

    // Empty group → semiring identity: false (0.0) for OR, true (1.0) for AND.
    if group_lineage.is_empty() {
        return WmcResult {
            probability: if is_disjunction { 0.0 } else { 1.0 },
            approximated: false,
            variable_count: 0,
        };
    }

    // Collect all unique base facts across all derivation rows.
    let mut all_facts: Vec<Vec<u8>> =
        HashSet::<&Vec<u8>>::from_iter(group_lineage.iter().flat_map(|s| s.iter()))
            .into_iter()
            .cloned()
            .collect();
    // Sort for deterministic variable ordering.
    all_facts.sort();

    let variable_count = all_facts.len();

    // Check BDD variable limit (the library indexes variables with u16).
    if variable_count > max_bdd_variables || variable_count > u16::MAX as usize {
        return WmcResult {
            probability: 0.0, // caller should use independence-mode result
            approximated: true,
            variable_count,
        };
    }

    // Map each base fact to a BDD variable index.
    let fact_to_idx: HashMap<&Vec<u8>, usize> =
        all_facts.iter().enumerate().map(|(i, f)| (f, i)).collect();

    let vars = BddVariableSet::new_anonymous(variable_count as u16);
    let bdd_vars: Vec<BddVariable> = vars.variables();

    // Build a BDD term per derivation row and fold them together:
    //   term_i  = base_fact_a ∧ base_fact_b ∧ …
    //   combined = term_0 (∨|∧) term_1 (∨|∧) …   (OR for MNOR, AND for MPROD)
    // Starting each term from `mk_true()` means a row with no base facts
    // falls out naturally as the unconditionally-true term — no separate
    // empty-row branch is needed.
    let mut combined: Option<Bdd> = None;
    for row_facts in group_lineage {
        let mut term = vars.mk_true();
        for fact in row_facts {
            if let Some(&idx) = fact_to_idx.get(fact) {
                let var_bdd = vars.mk_var(bdd_vars[idx]);
                term = term.and(&var_bdd);
            }
        }
        combined = Some(match combined {
            Some(acc) if is_disjunction => acc.or(&term),
            Some(acc) => acc.and(&term),
            None => term,
        });
    }

    // `group_lineage` is non-empty (checked above), so the fold produced a BDD.
    let bdd = combined.expect("non-empty group yields at least one term");

    // Build probability map: BddVariable → probability (default 0.5).
    let prob_map: HashMap<BddVariable, f64> = all_facts
        .iter()
        .enumerate()
        .map(|(i, fact)| {
            let p = base_fact_weights.get(fact).copied().unwrap_or(0.5);
            (bdd_vars[i], p)
        })
        .collect();

    let probability = eval_bdd_probability(&bdd, &prob_map);

    WmcResult {
        probability,
        approximated: false,
        variable_count,
    }
}

/// Probability that a random assignment (each variable `v` independently
/// true with weight `p_v`) satisfies the BDD, via Shannon expansion:
///
///   P(node) = (1 − p_v) · P(low_child) + p_v · P(high_child)
///
/// with terminal cases P(⊥) = 0 and P(⊤) = 1.
fn eval_bdd_probability(bdd: &Bdd, prob_map: &HashMap<BddVariable, f64>) -> f64 {
    // Per-node memo table so shared subgraphs are evaluated only once.
    let mut cache: HashMap<BddPointer, f64> = HashMap::new();
    let root = bdd.root_pointer();
    eval_ptr(bdd, root, prob_map, &mut cache)
}

/// Evaluate P(node) for the subgraph rooted at `ptr`, memoizing results
/// in `memo`.
///
/// Uses an explicit stack (iterative post-order) instead of recursion:
/// a BDD path can be as deep as the variable count (up to `u16::MAX`
/// given the caller's limit check), which could overflow the call stack
/// with a recursive Shannon expansion.
fn eval_ptr(
    bdd: &Bdd,
    ptr: BddPointer,
    prob_map: &HashMap<BddVariable, f64>,
    memo: &mut HashMap<BddPointer, f64>,
) -> f64 {
    // Probability of a pointer if already known: terminals are fixed
    // (P(⊥)=0, P(⊤)=1), internal nodes come from the memo table.
    fn known(ptr: BddPointer, memo: &HashMap<BddPointer, f64>) -> Option<f64> {
        if ptr.is_zero() {
            Some(0.0)
        } else if ptr.is_one() {
            Some(1.0)
        } else {
            memo.get(&ptr).copied()
        }
    }

    let mut stack: Vec<BddPointer> = vec![ptr];
    while let Some(&node) = stack.last() {
        // Already resolved (terminal or memoized) — nothing to do.
        if known(node, memo).is_some() {
            stack.pop();
            continue;
        }
        let lo = bdd.low_link_of(node);
        let hi = bdd.high_link_of(node);
        match (known(lo, memo), known(hi, memo)) {
            (Some(lo_p), Some(hi_p)) => {
                // Both children resolved → Shannon expansion:
                //   P(node) = (1 - p) · P(low) + p · P(high)
                let p = prob_map.get(&bdd.var_of(node)).copied().unwrap_or(0.5);
                memo.insert(node, (1.0 - p) * lo_p + p * hi_p);
                stack.pop();
            }
            (lo_v, hi_v) => {
                // Defer this node until its unresolved children are done.
                if lo_v.is_none() {
                    stack.push(lo);
                }
                if hi_v.is_none() {
                    stack.push(hi);
                }
            }
        }
    }

    known(ptr, memo).expect("root value computed by traversal")
}

#[cfg(test)]
mod tests {
    use super::*;

    // Expected values in these tests are computed by hand from the row
    // lineages; the "independence" helpers provide the baseline that the
    // exact BDD result is compared against.

    /// Helper: compute noisy-OR under independence assumption.
    fn noisy_or_independent(probs: &[f64]) -> f64 {
        1.0 - probs.iter().fold(1.0, |acc, &p| acc * (1.0 - p))
    }

    #[test]
    fn independent_facts_mnor_matches_noisy_or() {
        // Two derivation rows with completely independent base facts.
        // Row 0: {A(0.3)}
        // Row 1: {B(0.5)}
        // MNOR = P(A ∨ B) = 1 - (1-0.3)(1-0.5) = 0.65
        // With no shared facts, exact WMC must equal the noisy-OR baseline.
        let a = b"fact_a".to_vec();
        let b = b"fact_b".to_vec();

        let rows = vec![HashSet::from([a.clone()]), HashSet::from([b.clone()])];
        let probs = HashMap::from([(a, 0.3), (b, 0.5)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);
        assert!(!result.approximated);
        assert_eq!(result.variable_count, 2);

        let expected = noisy_or_independent(&[0.3, 0.5]);
        assert!(
            (result.probability - expected).abs() < 1e-10,
            "BDD={}, expected={}",
            result.probability,
            expected
        );
    }

    #[test]
    fn shared_facts_mnor_differs_from_independence() {
        // Diamond pattern: two paths share base fact C (smelter).
        // Row 0: path through A and C → base facts {A(0.3), C(0.7)}
        // Row 1: path through B and C → base facts {B(0.5), C(0.7)}
        //
        // Independence MNOR: 1 - (1 - 0.3*0.7)(1 - 0.5*0.7) = 1 - 0.79*0.65 = 0.4865
        // Exact BDD:         P(row0 ∨ row1) = P((A∧C) ∨ (B∧C))
        //                  = P(C · (A∨B)) = 0.7 · (1 - 0.7·0.5) = 0.7 · 0.65 = 0.455
        let a = b"fact_a".to_vec();
        let b = b"fact_b".to_vec();
        let c = b"fact_c".to_vec();

        let rows = vec![
            HashSet::from([a.clone(), c.clone()]),
            HashSet::from([b.clone(), c.clone()]),
        ];
        let probs = HashMap::from([(a, 0.3), (b, 0.5), (c, 0.7)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);
        assert!(!result.approximated);
        assert_eq!(result.variable_count, 3);

        // Exact: P(C) * P(A ∨ B) = 0.7 * (1 - (1-0.3)(1-0.5)) = 0.7 * 0.65 = 0.455
        let expected_exact = 0.455;
        assert!(
            (result.probability - expected_exact).abs() < 1e-10,
            "BDD={}, expected={}",
            result.probability,
            expected_exact
        );

        // Verify it differs from independence mode.
        let independence = noisy_or_independent(&[0.3 * 0.7, 0.5 * 0.7]);
        assert!(
            (result.probability - independence).abs() > 0.01,
            "BDD result should differ from independence: BDD={}, indep={}",
            result.probability,
            independence
        );
    }

    #[test]
    fn shared_facts_mprod() {
        // MPROD with shared facts.
        // Row 0: {A(0.3), C(0.7)}
        // Row 1: {B(0.5), C(0.7)}
        //
        // MPROD = P(row0 ∧ row1) = P((A∧C) ∧ (B∧C)) = P(A∧B∧C) = 0.3 * 0.5 * 0.7 = 0.105
        // (the BDD deduplicates the shared C — it is NOT counted twice).
        let a = b"fact_a".to_vec();
        let b = b"fact_b".to_vec();
        let c = b"fact_c".to_vec();

        let rows = vec![
            HashSet::from([a.clone(), c.clone()]),
            HashSet::from([b.clone(), c.clone()]),
        ];
        let probs = HashMap::from([(a, 0.3), (b, 0.5), (c, 0.7)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Conjunction, 1000);
        assert!(!result.approximated);
        assert_eq!(result.variable_count, 3);

        let expected = 0.3 * 0.5 * 0.7;
        assert!(
            (result.probability - expected).abs() < 1e-10,
            "BDD={}, expected={}",
            result.probability,
            expected
        );
    }

    #[test]
    fn bdd_limit_exceeded_falls_back() {
        let a = b"fact_a".to_vec();
        let b = b"fact_b".to_vec();

        let rows = vec![HashSet::from([a.clone()]), HashSet::from([b.clone()])];
        let probs = HashMap::from([(a, 0.3), (b, 0.5)]);

        // Set limit to 1 — there are 2 unique facts, so it should fall back.
        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1);
        assert!(result.approximated);
        assert_eq!(result.variable_count, 2);
    }

    #[test]
    fn empty_group_returns_identity() {
        let probs = HashMap::new();

        // Identity elements: empty OR is false (0.0), empty AND is true (1.0).
        let nor_result = weighted_model_count(&[], &probs, SemiringOp::Disjunction, 1000);
        assert!(!nor_result.approximated);
        assert!((nor_result.probability - 0.0).abs() < 1e-10);

        let prod_result = weighted_model_count(&[], &probs, SemiringOp::Conjunction, 1000);
        assert!(!prod_result.approximated);
        assert!((prod_result.probability - 1.0).abs() < 1e-10);
    }

    #[test]
    fn single_row_returns_product_of_base_facts() {
        // Single row with two base facts: P = A * B = 0.3 * 0.5 = 0.15
        let a = b"fact_a".to_vec();
        let b = b"fact_b".to_vec();

        let rows = vec![HashSet::from([a.clone(), b.clone()])];
        let probs = HashMap::from([(a, 0.3), (b, 0.5)]);

        // For MNOR or MPROD with a single row, the result is the same:
        // P(A ∧ B) = 0.15
        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);
        assert!((result.probability - 0.15).abs() < 1e-10);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Conjunction, 1000);
        assert!((result.probability - 0.15).abs() < 1e-10);
    }

    // ── New tests (Phase 4 coverage gap closure) ──────────────────────────

    #[test]
    fn independent_facts_mprod_matches_product() {
        // MPROD: two independent rows, no sharing.
        // Row 0: {A(0.3)}, Row 1: {B(0.5)}
        // MPROD = P(A ∧ B) = P(A) * P(B) = 0.3 * 0.5 = 0.15 (independent)
        let a = b"fact_a".to_vec();
        let b = b"fact_b".to_vec();

        let rows = vec![HashSet::from([a.clone()]), HashSet::from([b.clone()])];
        let probs = HashMap::from([(a, 0.3), (b, 0.5)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Conjunction, 1000);
        assert!(!result.approximated);
        assert_eq!(result.variable_count, 2);

        let expected = 0.3 * 0.5;
        assert!(
            (result.probability - expected).abs() < 1e-10,
            "MPROD BDD={}, expected={}",
            result.probability,
            expected
        );
    }

    #[test]
    fn bdd_limit_exceeded_returns_zero_probability() {
        // Three rows with three distinct facts, limit=2 → falls back.
        // Also pins the documented fallback contract: probability == 0.0.
        let a = b"fact_a".to_vec();
        let b = b"fact_b".to_vec();
        let c = b"fact_c".to_vec();

        let rows = vec![
            HashSet::from([a.clone()]),
            HashSet::from([b.clone()]),
            HashSet::from([c.clone()]),
        ];
        let probs = HashMap::from([(a, 0.3), (b, 0.5), (c, 0.7)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 2);
        assert!(
            result.approximated,
            "Expected approximated=true when limit exceeded"
        );
        assert_eq!(result.variable_count, 3);
        assert!(
            (result.probability - 0.0).abs() < 1e-10,
            "Fallback probability should be 0.0, got {}",
            result.probability
        );
    }

    #[test]
    fn three_way_shared_mnor() {
        // Three derivation rows all sharing base fact D(0.8).
        // Row 0: {A(0.3), D(0.8)}
        // Row 1: {B(0.5), D(0.8)}
        // Row 2: {C(0.4), D(0.8)}
        //
        // Exact: P((A∧D) ∨ (B∧D) ∨ (C∧D)) = P(D ∧ (A∨B∨C))
        //      = P(D) * P(A∨B∨C) = 0.8 * (1-(1-0.3)(1-0.5)(1-0.4))
        //      = 0.8 * (1 - 0.7*0.5*0.6) = 0.8 * (1 - 0.21) = 0.8 * 0.79 = 0.632
        let a = b"fact_a".to_vec();
        let b = b"fact_b".to_vec();
        let c = b"fact_c".to_vec();
        let d = b"fact_d".to_vec();

        let rows = vec![
            HashSet::from([a.clone(), d.clone()]),
            HashSet::from([b.clone(), d.clone()]),
            HashSet::from([c.clone(), d.clone()]),
        ];
        let probs = HashMap::from([(a, 0.3), (b, 0.5), (c, 0.4), (d, 0.8)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);
        assert!(!result.approximated);
        assert_eq!(result.variable_count, 4);

        let expected = 0.8 * (1.0 - (1.0 - 0.3) * (1.0 - 0.5) * (1.0 - 0.4));
        assert!(
            (result.probability - expected).abs() < 1e-10,
            "BDD={}, expected={}",
            result.probability,
            expected
        );

        // Verify it differs from naive independence over row-products.
        let row0_prod = 0.3 * 0.8;
        let row1_prod = 0.5 * 0.8;
        let row2_prod = 0.4 * 0.8;
        let independence = 1.0 - (1.0 - row0_prod) * (1.0 - row1_prod) * (1.0 - row2_prod);
        assert!(
            (result.probability - independence).abs() > 0.01,
            "BDD result should differ from independence: BDD={}, indep={}",
            result.probability,
            independence
        );
    }

    #[test]
    fn partially_overlapping_rows_mnor() {
        // Three rows: two share fact C, one is fully independent.
        // Row 0: {A(0.3), C(0.7)}
        // Row 1: {B(0.5), C(0.7)}
        // Row 2: {E(0.6)}             ← independent
        //
        // Exact: P((A∧C) ∨ (B∧C) ∨ E)
        //      = P(C·(A∨B) ∨ E) where C·(A∨B) and E are independent
        //      = 1 - (1 - P(C·(A∨B))) * (1 - P(E))
        // P(C·(A∨B)) = P(C) * P(A∨B) = 0.7 * (1-(1-0.3)(1-0.5)) = 0.7 * 0.65 = 0.455
        // = 1 - (1-0.455)(1-0.6) = 1 - 0.545*0.4 = 1 - 0.218 = 0.782
        let a = b"fact_a".to_vec();
        let b = b"fact_b".to_vec();
        let c = b"fact_c".to_vec();
        let e = b"fact_e".to_vec();

        let rows = vec![
            HashSet::from([a.clone(), c.clone()]),
            HashSet::from([b.clone(), c.clone()]),
            HashSet::from([e.clone()]),
        ];
        let probs = HashMap::from([(a, 0.3), (b, 0.5), (c, 0.7), (e, 0.6)]);

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);
        assert!(!result.approximated);
        assert_eq!(result.variable_count, 4);

        let p_c_times_a_or_b = 0.7 * (1.0 - (1.0 - 0.3) * (1.0 - 0.5));
        let expected = 1.0 - (1.0 - p_c_times_a_or_b) * (1.0 - 0.6);
        assert!(
            (result.probability - expected).abs() < 1e-10,
            "BDD={}, expected={}",
            result.probability,
            expected
        );
    }

    #[test]
    fn missing_probability_defaults_to_half() {
        // A fact not present in the probability map should default to 0.5.
        // Single row: {unknown_fact}
        // P(unknown) = 0.5 (default)
        let unknown = b"unknown_fact".to_vec();

        let rows = vec![HashSet::from([unknown.clone()])];
        let probs = HashMap::new(); // empty — unknown_fact has no entry

        let result = weighted_model_count(&rows, &probs, SemiringOp::Disjunction, 1000);
        assert!(!result.approximated);
        assert_eq!(result.variable_count, 1);
        assert!(
            (result.probability - 0.5).abs() < 1e-10,
            "Expected default probability 0.5, got {}",
            result.probability
        );
    }
}