Skip to main content

alkahest_cas/kernel/
pool.rs

1use crate::kernel::{
2    domain::Domain,
3    expr::{BigFloat, BigInt, BigRat, ExprData, ExprId},
4};
5use std::fmt;
6
/// Canonical ∞ symbol name for [`ExprPool::pos_infinity`] / limits (V2-16).
///
/// The value is the single Unicode character U+221E (`∞`).
pub const POS_INFINITY_SYMBOL: &str = "\u{221e}";
9
10// ---------------------------------------------------------------------------
11// Lock-free arena for ExprPool nodes.
12//
13// Strategy:
14//   * The `nodes` array (ExprId → ExprData) is a `boxcar::Vec` — a
15//     lock-free, append-only, reference-stable segmented array.  Reads
16//     (`with`, `get`, `len`) acquire no lock at all; they index directly
17//     into the array via a single atomic load.
18//   * The `index` (ExprData → ExprId) still requires coordination during
19//     insertion to preserve hash-cons uniqueness:
20//     - Under `--features parallel` we use `DashMap::entry` which holds a
21//       per-shard write-lock only for the duration of the insert.  The
22//       closure passed to `or_insert_with` calls `boxcar::push` (lock-free)
23//       while the shard lock is held, so no two threads can insert the same
24//       key.
25//     - Without `parallel` the `Mutex<HashMap>` serialises all inserts as
26//       before; the boxcar push happens while the Mutex is held.
27// ---------------------------------------------------------------------------
28
29#[cfg(feature = "parallel")]
30use dashmap::DashMap;
31
32#[cfg(not(feature = "parallel"))]
33use std::collections::HashMap;
34
35#[cfg(not(feature = "parallel"))]
36use std::sync::Mutex;
37
// ---------------------------------------------------------------------------
// PoolIndex — two variants depending on build features
// ---------------------------------------------------------------------------
41
/// Deduplication index (`ExprData` → `ExprId`) for `parallel` builds, backed
/// by a `DashMap` so entries can be added through a shared reference.
#[cfg(feature = "parallel")]
struct PoolIndex(DashMap<ExprData, ExprId>);

/// Deduplication index (`ExprData` → `ExprId`) for non-`parallel` builds: a
/// plain `HashMap` that needs `&mut self` to insert.
#[cfg(not(feature = "parallel"))]
struct PoolIndex(HashMap<ExprData, ExprId>);
47
#[cfg(feature = "parallel")]
impl PoolIndex {
    /// Create an empty index.
    fn new() -> Self {
        Self(DashMap::new())
    }

    /// Return the id already recorded for `data`, or `None` if it has not
    /// been inserted.
    fn get(&self, data: &ExprData) -> Option<ExprId> {
        let entry = self.0.get(data)?;
        Some(*entry)
    }

    /// Atomically return the existing id for `key`, or call `f` to produce one
    /// and insert it.  The DashMap shard write-lock is held for the duration of
    /// `f`, guaranteeing at most one call to `f` per unique key.
    fn or_insert_with(&self, key: ExprData, f: impl FnOnce() -> ExprId) -> ExprId {
        let slot = self.0.entry(key).or_insert_with(f);
        *slot
    }
}
63
64#[cfg(not(feature = "parallel"))]
65impl PoolIndex {
66    fn new() -> Self {
67        PoolIndex(HashMap::new())
68    }
69    fn get(&self, data: &ExprData) -> Option<ExprId> {
70        self.0.get(data).copied()
71    }
72    fn insert(&mut self, data: ExprData, id: ExprId) {
73        self.0.insert(data, id);
74    }
75}
76
/// Owns all expression nodes. Every [`ExprId`] is valid only within its pool.
///
/// `ExprPool` is `Send + Sync`.
///
/// Read operations (`with`, `get`, `len`) go straight to the `boxcar::Vec`
/// node array and never touch the deduplication index or its lock.
/// `intern` consults the index on every call: in `parallel` builds that is a
/// `DashMap`; in non-`parallel` builds it is a `Mutex` that is held for the
/// whole call — including calls that only find an already-interned node.
pub struct ExprPool {
    /// Lock-free, append-only, reference-stable node array.
    nodes: boxcar::Vec<ExprData>,
    /// Deduplication index: ExprData → ExprId.
    #[cfg(feature = "parallel")]
    index: PoolIndex,
    #[cfg(not(feature = "parallel"))]
    index: Mutex<PoolIndex>,
}
94
// NOTE(review): these impls override the compiler's auto-trait inference and
// carry no `SAFETY` justification. They are only sound if every field
// (`boxcar::Vec<ExprData>` and the index) is itself safe to send and share
// across threads, which in turn depends on `ExprData` — TODO confirm. If the
// auto traits already hold, these impls are redundant and can be removed so
// the compiler keeps checking the invariant.
unsafe impl Send for ExprPool {}
unsafe impl Sync for ExprPool {}
97
98impl ExprPool {
99    pub fn new() -> Self {
100        ExprPool {
101            nodes: boxcar::Vec::new(),
102            #[cfg(feature = "parallel")]
103            index: PoolIndex::new(),
104            #[cfg(not(feature = "parallel"))]
105            index: Mutex::new(PoolIndex::new()),
106        }
107    }
108
109    /// Intern `data`, returning a shared [`ExprId`]. Identical structures
110    /// always return the same id; structural equality ⟺ id equality.
111    pub fn intern(&self, data: ExprData) -> ExprId {
112        #[cfg(feature = "parallel")]
113        {
114            // Fast path: lock-free DashMap read.
115            if let Some(id) = self.index.get(&data) {
116                return id;
117            }
118            // Slow path: DashMap shard write-lock ensures at most one push
119            // per unique key.  `boxcar::push` is lock-free so it can be
120            // called safely while the shard lock is held.
121            self.index
122                .or_insert_with(data.clone(), || ExprId(self.nodes.push(data) as u32))
123        }
124
125        #[cfg(not(feature = "parallel"))]
126        {
127            let mut idx = self.index.lock().expect("ExprPool index Mutex poisoned");
128            if let Some(id) = idx.get(&data) {
129                return id;
130            }
131            let id = ExprId(self.nodes.push(data.clone()) as u32);
132            idx.insert(data, id);
133            id
134        }
135    }
136
137    /// Borrow a node by id and apply `f` without cloning.  Lock-free.
138    pub fn with<R, F: FnOnce(&ExprData) -> R>(&self, id: ExprId, f: F) -> R {
139        f(self
140            .nodes
141            .get(id.0 as usize)
142            .expect("ExprPool: ExprId out of range"))
143    }
144
145    /// Clone and return the `ExprData` for `id`.
146    pub fn get(&self, id: ExprId) -> ExprData {
147        self.with(id, |d| d.clone())
148    }
149
150    /// Number of distinct expressions interned so far.  Lock-free.
151    pub fn len(&self) -> usize {
152        self.nodes.count()
153    }
154
155    pub fn is_empty(&self) -> bool {
156        self.nodes.is_empty()
157    }
158
159    // -----------------------------------------------------------------------
160    // Atom constructors
161    // -----------------------------------------------------------------------
162
163    /// Free symbol; multiplication treats it as commuting with every other factor (default).
164    pub fn symbol(&self, name: impl Into<String>, domain: Domain) -> ExprId {
165        self.symbol_commutative(name, domain, true)
166    }
167
168    /// Free symbol with explicit commutative flag (V3-2). `commutative: false` is for
169    /// matrix or operator generators where `A*B` and `B*A` must remain distinct.
170    pub fn symbol_commutative(
171        &self,
172        name: impl Into<String>,
173        domain: Domain,
174        commutative: bool,
175    ) -> ExprId {
176        self.intern(ExprData::Symbol {
177            name: name.into(),
178            domain,
179            commutative,
180        })
181    }
182
183    pub fn integer(&self, n: impl Into<rug::Integer>) -> ExprId {
184        self.intern(ExprData::Integer(BigInt(n.into())))
185    }
186
187    pub fn rational(
188        &self,
189        numer: impl Into<rug::Integer>,
190        denom: impl Into<rug::Integer>,
191    ) -> ExprId {
192        let r = rug::Rational::from((numer.into(), denom.into()));
193        self.intern(ExprData::Rational(BigRat(r)))
194    }
195
196    pub fn float(&self, value: f64, prec: u32) -> ExprId {
197        let f = rug::Float::with_val(prec, value);
198        self.intern(ExprData::Float(BigFloat { inner: f, prec }))
199    }
200
201    // -----------------------------------------------------------------------
202    // Compound constructors
203    // -----------------------------------------------------------------------
204
205    pub fn add(&self, mut args: Vec<ExprId>) -> ExprId {
206        // Sort children at construction time so that commutativity holds
207        // structurally: `a + b` and `b + a` intern to the same ExprId.
208        // The sort key is the raw ExprId (opaque u32), which gives a stable,
209        // deterministic canonical order.
210        args.sort_unstable();
211        self.intern(ExprData::Add(args))
212    }
213
214    pub fn mul(&self, mut args: Vec<ExprId>) -> ExprId {
215        // Canonical sort only when every subtree is multiplicatively commutative (V3-2).
216        let sort_ok = args
217            .iter()
218            .all(|&a| crate::kernel::expr_props::mult_tree_is_commutative(self, a));
219        if sort_ok {
220            args.sort_unstable();
221        }
222        self.intern(ExprData::Mul(args))
223    }
224
225    pub fn pow(&self, base: ExprId, exp: ExprId) -> ExprId {
226        self.intern(ExprData::Pow { base, exp })
227    }
228
229    pub fn func(&self, name: impl Into<String>, args: Vec<ExprId>) -> ExprId {
230        self.intern(ExprData::Func {
231            name: name.into(),
232            args,
233        })
234    }
235
236    // -----------------------------------------------------------------------
237    // PA-9 — Piecewise / Predicate constructors
238    // -----------------------------------------------------------------------
239
240    /// Build a `Piecewise` expression.
241    ///
242    /// Branches are `(cond, value)` pairs where `cond` must be a
243    /// `Predicate` node.  The `default` value is used when no condition
244    /// matches.
245    pub fn piecewise(&self, branches: Vec<(ExprId, ExprId)>, default: ExprId) -> ExprId {
246        self.intern(ExprData::Piecewise { branches, default })
247    }
248
249    /// Build a `Predicate` node (symbolic boolean condition).
250    pub fn predicate(&self, kind: crate::kernel::expr::PredicateKind, args: Vec<ExprId>) -> ExprId {
251        self.intern(ExprData::Predicate { kind, args })
252    }
253
254    // Convenience constructors for common predicates.
255    pub fn pred_lt(&self, a: ExprId, b: ExprId) -> ExprId {
256        self.predicate(crate::kernel::expr::PredicateKind::Lt, vec![a, b])
257    }
258    pub fn pred_le(&self, a: ExprId, b: ExprId) -> ExprId {
259        self.predicate(crate::kernel::expr::PredicateKind::Le, vec![a, b])
260    }
261    pub fn pred_gt(&self, a: ExprId, b: ExprId) -> ExprId {
262        self.predicate(crate::kernel::expr::PredicateKind::Gt, vec![a, b])
263    }
264    pub fn pred_ge(&self, a: ExprId, b: ExprId) -> ExprId {
265        self.predicate(crate::kernel::expr::PredicateKind::Ge, vec![a, b])
266    }
267    pub fn pred_eq(&self, a: ExprId, b: ExprId) -> ExprId {
268        self.predicate(crate::kernel::expr::PredicateKind::Eq, vec![a, b])
269    }
270    pub fn pred_ne(&self, a: ExprId, b: ExprId) -> ExprId {
271        self.predicate(crate::kernel::expr::PredicateKind::Ne, vec![a, b])
272    }
273    pub fn pred_and(&self, args: Vec<ExprId>) -> ExprId {
274        self.predicate(crate::kernel::expr::PredicateKind::And, args)
275    }
276    pub fn pred_or(&self, args: Vec<ExprId>) -> ExprId {
277        self.predicate(crate::kernel::expr::PredicateKind::Or, args)
278    }
279    pub fn pred_not(&self, a: ExprId) -> ExprId {
280        self.predicate(crate::kernel::expr::PredicateKind::Not, vec![a])
281    }
282    pub fn pred_true(&self) -> ExprId {
283        self.predicate(crate::kernel::expr::PredicateKind::True, vec![])
284    }
285    pub fn pred_false(&self) -> ExprId {
286        self.predicate(crate::kernel::expr::PredicateKind::False, vec![])
287    }
288
289    // V3-3 — first-order quantifiers (first-class `Formula` / FOFormula).
290    /// `∀ var . body`
291    pub fn forall(&self, var: ExprId, body: ExprId) -> ExprId {
292        self.intern(ExprData::Forall { var, body })
293    }
294
295    /// `∃ var . body`
296    pub fn exists(&self, var: ExprId, body: ExprId) -> ExprId {
297        self.intern(ExprData::Exists { var, body })
298    }
299
300    /// `O(arg)` — symbolic big-O bound used in truncated series (V2-15).
301    pub fn big_o(&self, arg: ExprId) -> ExprId {
302        self.intern(ExprData::BigO(arg))
303    }
304
305    /// Canonical `+∞` symbol for limits at infinity (V2-16).
306    pub fn pos_infinity(&self) -> ExprId {
307        self.symbol(POS_INFINITY_SYMBOL, Domain::Positive)
308    }
309
310    // -----------------------------------------------------------------------
311    // Display helper
312    // -----------------------------------------------------------------------
313
314    pub fn display(&self, id: ExprId) -> ExprDisplay<'_> {
315        ExprDisplay { id, pool: self }
316    }
317}
318
319impl Default for ExprPool {
320    fn default() -> Self {
321        Self::new()
322    }
323}
324
325// ---------------------------------------------------------------------------
326// Display — pool-aware recursive formatter
327// ---------------------------------------------------------------------------
328
/// Wraps an `(ExprId, &ExprPool)` pair so it can implement [`fmt::Display`].
///
/// Obtained from [`ExprPool::display`]. The `Debug` impl prints the same text
/// as `Display`.
pub struct ExprDisplay<'a> {
    /// Expression to render.
    pub id: ExprId,
    /// Pool that `id` is resolved against while formatting.
    pub pool: &'a ExprPool,
}
334
335impl fmt::Display for ExprDisplay<'_> {
336    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337        let data = self.pool.get(self.id);
338        fmt_data(&data, self.pool, f)
339    }
340}
341
342impl fmt::Debug for ExprDisplay<'_> {
343    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344        write!(f, "{}", self)
345    }
346}
347
348fn fmt_data(data: &ExprData, pool: &ExprPool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
349    match data {
350        ExprData::Symbol { name, .. } => write!(f, "{}", name),
351        ExprData::Integer(n) => write!(f, "{}", n),
352        ExprData::Rational(r) => write!(f, "{}", r),
353        ExprData::Float(fl) => write!(f, "{}", fl),
354        ExprData::Add(args) => {
355            write!(f, "(")?;
356            for (i, &arg) in args.iter().enumerate() {
357                if i > 0 {
358                    write!(f, " + ")?;
359                }
360                write!(f, "{}", pool.display(arg))?;
361            }
362            write!(f, ")")
363        }
364        ExprData::Mul(args) => {
365            write!(f, "(")?;
366            for (i, &arg) in args.iter().enumerate() {
367                if i > 0 {
368                    write!(f, " * ")?;
369                }
370                write!(f, "{}", pool.display(arg))?;
371            }
372            write!(f, ")")
373        }
374        ExprData::Pow { base, exp } => {
375            write!(f, "{}^{}", pool.display(*base), pool.display(*exp))
376        }
377        ExprData::Func { name, args } => {
378            write!(f, "{}(", name)?;
379            for (i, &arg) in args.iter().enumerate() {
380                if i > 0 {
381                    write!(f, ", ")?;
382                }
383                write!(f, "{}", pool.display(arg))?;
384            }
385            write!(f, ")")
386        }
387        ExprData::Piecewise { branches, default } => {
388            write!(f, "Piecewise(")?;
389            for (i, (cond, val)) in branches.iter().enumerate() {
390                if i > 0 {
391                    write!(f, ", ")?;
392                }
393                write!(f, "({}, {})", pool.display(*cond), pool.display(*val))?;
394            }
395            write!(f, "; default={})", pool.display(*default))
396        }
397        ExprData::Predicate { kind, args } => match kind {
398            crate::kernel::expr::PredicateKind::True => write!(f, "True"),
399            crate::kernel::expr::PredicateKind::False => write!(f, "False"),
400            crate::kernel::expr::PredicateKind::Not => {
401                write!(f, "¬({})", pool.display(args[0]))
402            }
403            crate::kernel::expr::PredicateKind::And | crate::kernel::expr::PredicateKind::Or => {
404                write!(f, "(")?;
405                for (i, &arg) in args.iter().enumerate() {
406                    if i > 0 {
407                        write!(f, " {} ", kind)?;
408                    }
409                    write!(f, "{}", pool.display(arg))?;
410                }
411                write!(f, ")")
412            }
413            _ => {
414                write!(
415                    f,
416                    "({} {} {})",
417                    pool.display(args[0]),
418                    kind,
419                    pool.display(args[1])
420                )
421            }
422        },
423        ExprData::Forall { var, body } => {
424            write!(f, "∀ {} . {}", pool.display(*var), pool.display(*body))
425        }
426        ExprData::Exists { var, body } => {
427            write!(f, "∃ {} . {}", pool.display(*var), pool.display(*body))
428        }
429        ExprData::BigO(arg) => {
430            write!(f, "O({})", pool.display(*arg))
431        }
432    }
433}
434
435// ---------------------------------------------------------------------------
436// Unit tests
437// ---------------------------------------------------------------------------
438
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::domain::Domain;

    /// Fresh, empty pool for each test.
    fn pool() -> ExprPool {
        ExprPool::default()
    }

    /// Intern a commutative real-valued symbol called `name`.
    fn real(p: &ExprPool, name: &str) -> ExprId {
        p.symbol(name, Domain::Real)
    }

    /// Render `id` to an owned string via the pool-aware display.
    fn shown(p: &ExprPool, id: ExprId) -> String {
        p.display(id).to_string()
    }

    // --- non-commutative symbols (V3-2) ---

    #[test]
    fn noncommutative_mul_orders_distinct() {
        let exprs = pool();
        let lhs = exprs.symbol_commutative("A", Domain::Real, false);
        let rhs = exprs.symbol_commutative("B", Domain::Real, false);
        let forward = exprs.mul(vec![lhs, rhs]);
        let backward = exprs.mul(vec![rhs, lhs]);
        assert_ne!(
            forward, backward,
            "A*B and B*A must not hash-cons together for NC symbols"
        );
    }

    #[test]
    fn symbol_commutative_is_structural() {
        let exprs = pool();
        let commuting = exprs.symbol_commutative("x", Domain::Real, true);
        let non_commuting = exprs.symbol_commutative("x", Domain::Real, false);
        assert_ne!(commuting, non_commuting);
    }

    // --- construction and equality ---

    #[test]
    fn symbol_interning() {
        let exprs = pool();
        let first = real(&exprs, "x");
        let second = real(&exprs, "x");
        assert_eq!(first, second, "same symbol must return same ExprId");
    }

    #[test]
    fn domain_is_structural() {
        let exprs = pool();
        let real_x = exprs.symbol("x", Domain::Real);
        let complex_x = exprs.symbol("x", Domain::Complex);
        assert_ne!(
            real_x, complex_x,
            "same name but different domain must be distinct"
        );
    }

    #[test]
    fn integer_interning() {
        let exprs = pool();
        let forty_two = exprs.integer(42_i32);
        let forty_two_again = exprs.integer(42_i32);
        let ninety_nine = exprs.integer(99_i32);
        assert_eq!(forty_two, forty_two_again);
        assert_ne!(forty_two, ninety_nine);
    }

    #[test]
    fn rational_canonical() {
        let exprs = pool();
        // 2/4 reduces to 1/2
        let unreduced = exprs.rational(2_i32, 4_i32);
        let reduced = exprs.rational(1_i32, 2_i32);
        assert_eq!(
            unreduced, reduced,
            "rationals must be reduced to canonical form"
        );
    }

    #[test]
    fn float_precision_is_structural() {
        let exprs = pool();
        let at_53_bits = exprs.float(1.0, 53);
        let at_64_bits = exprs.float(1.0, 64);
        assert_ne!(
            at_53_bits, at_64_bits,
            "same value but different precision is a different expr"
        );
    }

    // --- compound expressions and subexpression sharing ---

    #[test]
    fn subexpression_sharing() {
        let exprs = pool();
        let x = real(&exprs, "x");
        let two = exprs.integer(2_i32);

        // Build x^2 twice; both must return the same ExprId.
        let first = exprs.pow(x, two);
        let second = exprs.pow(x, two);
        assert_eq!(first, second);

        // Pool should have exactly 3 nodes: x, 2, x^2.
        assert_eq!(exprs.len(), 3);
    }

    #[test]
    fn add_interning() {
        let exprs = pool();
        let x = real(&exprs, "x");
        let y = real(&exprs, "y");
        let first = exprs.add(vec![x, y]);
        let second = exprs.add(vec![x, y]);
        assert_eq!(first, second);
    }

    #[test]
    fn arg_order_is_canonical() {
        // PA-3: Add/Mul children are sorted at construction time so that
        // commutativity holds structurally — a+b and b+a intern to the same ExprId.
        let exprs = pool();
        let x = real(&exprs, "x");
        let y = real(&exprs, "y");

        let sum_xy = exprs.add(vec![x, y]);
        let sum_yx = exprs.add(vec![y, x]);
        assert_eq!(
            sum_xy, sum_yx,
            "a+b and b+a must be the same expression after PA-3"
        );

        let prod_xy = exprs.mul(vec![x, y]);
        let prod_yx = exprs.mul(vec![y, x]);
        assert_eq!(
            prod_xy, prod_yx,
            "a*b and b*a must be the same expression after PA-3"
        );
    }

    #[test]
    fn func_interning() {
        let exprs = pool();
        let x = real(&exprs, "x");
        let sin_a = exprs.func("sin", vec![x]);
        let sin_b = exprs.func("sin", vec![x]);
        let cos = exprs.func("cos", vec![x]);
        assert_eq!(sin_a, sin_b);
        assert_ne!(sin_a, cos);
    }

    // --- display ---

    #[test]
    fn display_symbol() {
        let exprs = pool();
        let x = real(&exprs, "x");
        assert_eq!(shown(&exprs, x), "x");
    }

    #[test]
    fn display_integer() {
        let exprs = pool();
        let forty_two = exprs.integer(42_i32);
        assert_eq!(shown(&exprs, forty_two), "42");
    }

    #[test]
    fn display_pow() {
        let exprs = pool();
        let x = real(&exprs, "x");
        let two = exprs.integer(2_i32);
        let squared = exprs.pow(x, two);
        assert_eq!(shown(&exprs, squared), "x^2");
    }

    #[test]
    fn display_add() {
        let exprs = pool();
        let x = real(&exprs, "x");
        let y = real(&exprs, "y");
        let sum = exprs.add(vec![x, y]);
        assert_eq!(shown(&exprs, sum), "(x + y)");
    }

    #[test]
    fn display_func() {
        let exprs = pool();
        let x = real(&exprs, "x");
        let sine = exprs.func("sin", vec![x]);
        assert_eq!(shown(&exprs, sine), "sin(x)");
    }

    #[test]
    fn display_nested() {
        let exprs = pool();
        let x = real(&exprs, "x");
        let two = exprs.integer(2_i32);
        let squared = exprs.pow(x, two);
        let one = exprs.integer(1_i32);
        let sum = exprs.add(vec![squared, one]);
        assert_eq!(shown(&exprs, sum), "(x^2 + 1)");
    }

    // --- send + sync: compile-time check ---

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn pool_is_send_sync() {
        assert_send_sync::<ExprPool>();
    }
}