Skip to main content

ries_rs/
pool.rs

1//! Bounded priority pool for streaming match collection
2//!
3//! Implements a bounded pool that keeps the best matches across
4//! multiple dimensions (error, complexity) with deduplication.
5//!
6//! # Key Design
7//!
8//! Keys use `Expression` directly rather than `String` for zero-allocation
9//! deduplication. Since `Expression` uses `SmallVec<[Symbol; 21]>`, expressions
10//! within the length limit stay inline on the stack, avoiding heap allocation
11//! during hashing and comparison.
12
13use crate::expr::Expression;
14use crate::search::Match;
15use crate::thresholds::{
16    ACCEPT_ERROR_TIGHTEN_FACTOR, BEST_ERROR_TIGHTEN_FACTOR, EXACT_MATCH_TOLERANCE,
17    NEWTON_TOLERANCE, STRICT_GATE_CAPACITY_FRACTION, STRICT_GATE_FACTOR,
18};
19use std::cmp::Ordering;
20use std::collections::{BinaryHeap, HashSet};
21
/// How the pool ranks matches, both when choosing which entry to evict and
/// when producing the final sorted output.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RankingMode {
    /// Exactness, then error, then complexity (the default ordering).
    #[default]
    Complexity,
    /// Exactness, then error, then the legacy signed parity score, then complexity.
    Parity,
}
31
32/// Keys for full equation deduplication (LHS + RHS pair)
33///
34/// Uses a pair of expressions directly for zero-allocation hashing.
35/// The tuple (lhs, rhs) uniquely identifies an equation.
36#[derive(Clone, PartialEq, Eq, Hash)]
37pub struct EqnKey {
38    /// LHS expression (contains x)
39    lhs: Expression,
40    /// RHS expression (constants only)
41    rhs: Expression,
42}
43
44impl EqnKey {
45    /// Create a key from a match
46    #[inline]
47    pub fn from_match(m: &Match) -> Self {
48        Self {
49            lhs: m.lhs.expr.clone(),
50            rhs: m.rhs.expr.clone(),
51        }
52    }
53}
54
55/// Keys for LHS-only deduplication
56///
57/// Used by report.rs to prevent showing too many variants of the same LHS.
58#[derive(Clone, PartialEq, Eq, Hash)]
59pub struct LhsKey {
60    /// LHS expression
61    lhs: Expression,
62}
63
64impl LhsKey {
65    /// Create a key from a match
66    #[inline]
67    pub fn from_match(m: &Match) -> Self {
68        Self {
69            lhs: m.lhs.expr.clone(),
70        }
71    }
72}
73
/// Byte signature of an equation's symbol pattern, used for "interesting"
/// (pattern-level) deduplication.
///
/// Stored as a boxed slice: the signature never grows after construction, so
/// the spare-capacity word of a `Vec` is dropped.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct SignatureKey {
    /// Symbol bytes: LHS symbols, `b'='`, then RHS symbols.
    key: Box<[u8]>,
}
83
84impl SignatureKey {
85    pub fn from_match(m: &Match) -> Self {
86        // Build a signature from operator types and constants used
87        let expected_len = m.lhs.expr.len() + m.rhs.expr.len() + 1;
88        let mut ops = Vec::with_capacity(expected_len);
89
90        for sym in m.lhs.expr.symbols() {
91            ops.push(*sym as u8);
92        }
93        ops.push(b'=');
94        for sym in m.rhs.expr.symbols() {
95            ops.push(*sym as u8);
96        }
97
98        Self {
99            key: ops.into_boxed_slice(),
100        }
101    }
102}
103
104/// Compute legacy (original RIES style) signed parity score for an expression.
105pub fn legacy_parity_score_expr(expr: &Expression) -> i32 {
106    expr.symbols().iter().fold(0_i32, |acc, sym| {
107        acc.saturating_add(sym.legacy_parity_weight())
108    })
109}
110
111/// Compute legacy (original RIES style) signed parity score for a match.
112pub fn legacy_parity_score_match(m: &Match) -> i32 {
113    legacy_parity_score_expr(&m.lhs.expr).saturating_add(legacy_parity_score_expr(&m.rhs.expr))
114}
115
116#[inline]
117fn compare_expr(a: &Expression, b: &Expression) -> Ordering {
118    a.symbols()
119        .iter()
120        .map(|s| *s as u8)
121        .cmp(b.symbols().iter().map(|s| *s as u8))
122}
123
124/// Compare two matches according to the selected ranking mode.
125pub fn compare_matches(a: &Match, b: &Match, ranking_mode: RankingMode) -> Ordering {
126    let a_exactness = if a.error.abs() < EXACT_MATCH_TOLERANCE {
127        0_u8
128    } else {
129        1_u8
130    };
131    let b_exactness = if b.error.abs() < EXACT_MATCH_TOLERANCE {
132        0_u8
133    } else {
134        1_u8
135    };
136
137    let mut ord = a_exactness.cmp(&b_exactness).then_with(|| {
138        a.error
139            .abs()
140            .partial_cmp(&b.error.abs())
141            .unwrap_or(Ordering::Equal)
142    });
143
144    if ord != Ordering::Equal {
145        return ord;
146    }
147
148    ord = match ranking_mode {
149        RankingMode::Complexity => a.complexity.cmp(&b.complexity),
150        RankingMode::Parity => legacy_parity_score_match(a)
151            .cmp(&legacy_parity_score_match(b))
152            .then_with(|| a.complexity.cmp(&b.complexity)),
153    };
154
155    if ord != Ordering::Equal {
156        return ord;
157    }
158
159    compare_expr(&a.lhs.expr, &b.lhs.expr).then_with(|| compare_expr(&a.rhs.expr, &b.rhs.expr))
160}
161
162/// Wrapper for Match that implements ordering for the heap
163/// We keep the worst-ranked entry at the heap top for eviction.
164#[derive(Clone)]
165struct PoolEntry {
166    m: Match,
167    rank_key: (u8, i64, i32, u32), // (exactness, error_bits, mode_tie, complexity)
168}
169
170impl PoolEntry {
171    fn new(m: Match, ranking_mode: RankingMode) -> Self {
172        let is_exact = m.error.abs() < EXACT_MATCH_TOLERANCE;
173        let exactness_rank = if is_exact { 0 } else { 1 };
174        // Convert error to sortable integer, handling special values.
175        // For IEEE 754 doubles, positive values' bit patterns preserve ordering
176        // when interpreted as unsigned, but we need to handle NaN/Infinity specially.
177        let error_abs = m.error.abs();
178        let error_bits = if error_abs.is_nan() {
179            // NaN should sort as worst (largest) error
180            i64::MAX
181        } else if error_abs.is_infinite() {
182            // Infinity should also sort as worst (just below NaN)
183            i64::MAX - 1
184        } else {
185            // For normal positive floats, the bit pattern preserves ordering
186            // when cast to i64 (since all positive floats have bit patterns < i64::MAX)
187            error_abs.to_bits() as i64
188        };
189        let mode_tie = match ranking_mode {
190            RankingMode::Complexity => m.complexity as i32,
191            RankingMode::Parity => legacy_parity_score_match(&m),
192        };
193        Self {
194            rank_key: (exactness_rank, error_bits, mode_tie, m.complexity),
195            m,
196        }
197    }
198}
199
200impl PartialEq for PoolEntry {
201    fn eq(&self, other: &Self) -> bool {
202        self.rank_key == other.rank_key
203            && self.m.lhs.expr == other.m.lhs.expr
204            && self.m.rhs.expr == other.m.rhs.expr
205    }
206}
207
208impl Eq for PoolEntry {}
209
210impl PartialOrd for PoolEntry {
211    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
212        Some(self.cmp(other))
213    }
214}
215
216impl Ord for PoolEntry {
217    fn cmp(&self, other: &Self) -> Ordering {
218        // Keep worst (least exact, largest error, largest complexity) at top for eviction.
219        self.rank_key
220            .cmp(&other.rank_key)
221            .then_with(|| compare_expr(&self.m.lhs.expr, &other.m.lhs.expr))
222            .then_with(|| compare_expr(&self.m.rhs.expr, &other.m.rhs.expr))
223    }
224}
225
/// Counters describing what happened to the matches offered to a pool.
#[derive(Debug, Default, Clone)]
pub struct PoolStats {
    /// Matches that passed the filters and were pushed onto the heap.
    pub insertions: usize,
    /// Matches turned away by the error threshold.
    pub rejections_error: usize,
    /// Matches turned away as duplicates of an equation already pooled.
    pub rejections_dedupe: usize,
    /// Entries dropped to keep the pool within its capacity.
    pub evictions: usize,
}
238
239/// Bounded pool for collecting matches
240pub struct TopKPool {
241    /// Max capacity
242    capacity: usize,
243    /// Priority queue (worst at top for eviction)
244    heap: BinaryHeap<PoolEntry>,
245    /// Seen equation keys for dedupe
246    seen_eqn: HashSet<EqnKey>,
247    /// Best error seen so far (for threshold tightening)
248    pub best_error: f64,
249    /// Accept error threshold (tightens slowly for diversity)
250    pub accept_error: f64,
251    /// Statistics
252    pub stats: PoolStats,
253    /// Show diagnostic output for database adds (-DG)
254    show_db_adds: bool,
255    /// Ranking mode for eviction and output ordering
256    ranking_mode: RankingMode,
257}
258
259impl TopKPool {
260    /// Create a new pool with given capacity
261    #[allow(dead_code)]
262    pub fn new(capacity: usize, initial_max_error: f64) -> Self {
263        Self {
264            capacity,
265            heap: BinaryHeap::with_capacity(capacity + 1),
266            seen_eqn: HashSet::new(),
267            best_error: initial_max_error,
268            accept_error: initial_max_error,
269            stats: PoolStats::default(),
270            show_db_adds: false,
271            ranking_mode: RankingMode::Complexity,
272        }
273    }
274
275    /// Create a new pool with diagnostic options
276    pub fn new_with_diagnostics(
277        capacity: usize,
278        initial_max_error: f64,
279        show_db_adds: bool,
280        ranking_mode: RankingMode,
281    ) -> Self {
282        Self {
283            capacity,
284            heap: BinaryHeap::with_capacity(capacity + 1),
285            seen_eqn: HashSet::new(),
286            best_error: initial_max_error,
287            accept_error: initial_max_error,
288            stats: PoolStats::default(),
289            show_db_adds,
290            ranking_mode,
291        }
292    }
293
294    /// Try to insert a match into the pool
295    /// Returns true if inserted, false if rejected
296    pub fn try_insert(&mut self, m: Match) -> bool {
297        let error = m.error.abs();
298        let is_exact = error < EXACT_MATCH_TOLERANCE;
299
300        // Check error threshold (must be better than accept_error)
301        if !is_exact && error > self.accept_error {
302            self.stats.rejections_error += 1;
303            return false;
304        }
305
306        // Check equation-level dedupe
307        let eqn_key = EqnKey::from_match(&m);
308        if self.seen_eqn.contains(&eqn_key) {
309            self.stats.rejections_dedupe += 1;
310            return false;
311        }
312
313        // Insert
314        let entry = PoolEntry::new(m, self.ranking_mode);
315        self.seen_eqn.insert(eqn_key);
316
317        // Diagnostic output for -DG (before moving entry into heap)
318        if self.show_db_adds {
319            eprintln!(
320                "  [db add] lhs={:?} rhs={:?} error={:.6e} complexity={}",
321                entry.m.lhs.expr.to_postfix(),
322                entry.m.rhs.expr.to_postfix(),
323                entry.m.error,
324                entry.m.complexity
325            );
326        }
327
328        self.heap.push(entry);
329        self.stats.insertions += 1;
330
331        // Update thresholds
332        if is_exact {
333            // Exact match: tighten best_error aggressively but keep a floor
334            self.best_error =
335                EXACT_MATCH_TOLERANCE.max(self.best_error * BEST_ERROR_TIGHTEN_FACTOR);
336        } else if error < self.best_error {
337            // Better approximation: tighten best_error
338            self.best_error = error * BEST_ERROR_TIGHTEN_FACTOR - NEWTON_TOLERANCE;
339            self.best_error = self.best_error.max(EXACT_MATCH_TOLERANCE);
340        }
341
342        // Slowly tighten accept_error for diversity
343        if error < self.accept_error * ACCEPT_ERROR_TIGHTEN_FACTOR {
344            self.accept_error *= ACCEPT_ERROR_TIGHTEN_FACTOR;
345        }
346
347        // Evict worst if over capacity
348        if self.heap.len() > self.capacity {
349            if let Some(evicted) = self.heap.pop() {
350                // Remove the equation key so a *different* RHS for the same LHS can
351                // be inserted later.
352                self.seen_eqn.remove(&EqnKey::from_match(&evicted.m));
353                self.stats.evictions += 1;
354            }
355        }
356
357        true
358    }
359
360    /// Check if a match would be accepted (for early pruning)
361    pub fn would_accept(&self, error: f64, is_exact: bool) -> bool {
362        if is_exact {
363            return true;
364        }
365        error <= self.accept_error
366    }
367
368    /// Check if a match would be accepted, with stricter gate when pool is near capacity
369    /// This is used as a pre-Newton filter to avoid expensive refinement calls
370    pub fn would_accept_strict(&self, coarse_error: f64, is_potentially_exact: bool) -> bool {
371        // Always accept potential exact matches
372        if is_potentially_exact {
373            return true;
374        }
375
376        // Basic threshold check
377        if coarse_error > self.accept_error {
378            return false;
379        }
380
381        // Stricter check when pool is near capacity:
382        // If we're at capacity fraction and have good matches, be more aggressive
383        if self.heap.len() as f64 >= self.capacity as f64 * STRICT_GATE_CAPACITY_FRACTION {
384            // Only accept if error is below strict gate threshold
385            // This avoids Newton calls for marginal candidates
386            if coarse_error > self.accept_error * STRICT_GATE_FACTOR {
387                return false;
388            }
389        }
390
391        true
392    }
393
394    /// Get all matches, sorted by ranking mode.
395    pub fn into_sorted(self) -> Vec<Match> {
396        let ranking_mode = self.ranking_mode;
397        let mut matches: Vec<Match> = self.heap.into_iter().map(|e| e.m).collect();
398        matches.sort_by(|a, b| compare_matches(a, b, ranking_mode));
399        matches
400    }
401
402    /// Get current pool size
403    pub fn len(&self) -> usize {
404        self.heap.len()
405    }
406
407    /// Check if pool is empty
408    #[allow(dead_code)]
409    pub fn is_empty(&self) -> bool {
410        self.heap.is_empty()
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{EvaluatedExpr, Expression};
    use crate::symbol::NumType;

    /// Build a `Match` from postfix LHS/RHS strings with the given error and complexity.
    ///
    /// The evaluated values and `x_value` are fixed placeholders.
    fn make_match(lhs: &str, rhs: &str, error: f64, complexity: u32) -> Match {
        let lhs_expr = Expression::parse(lhs).unwrap();
        let rhs_expr = Expression::parse(rhs).unwrap();
        Match {
            lhs: EvaluatedExpr::new(lhs_expr, 0.0, 1.0, NumType::Integer),
            rhs: EvaluatedExpr::new(rhs_expr, 0.0, 0.0, NumType::Integer),
            x_value: 2.5,
            error,
            complexity,
        }
    }

    #[test]
    fn test_pool_basic() {
        let mut pool = TopKPool::new(5, 1.0);

        // Insert two distinct matches that both pass the error threshold.
        assert!(pool.try_insert(make_match("2x*", "5", 0.0, 27)));
        assert!(pool.try_insert(make_match("x1+", "35/", 0.01, 34)));

        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_pool_eviction() {
        let mut pool = TopKPool::new(2, 1.0);

        // Insert 3 matches into a pool of capacity 2:
        // - xs  (error 0.1,  complexity 50): largest error, so ranked worst
        // - 2x* (exact,      complexity 27): ranked best
        // - x1+ (error 0.01, complexity 34)
        pool.try_insert(make_match("xs", "64/", 0.1, 50));
        pool.try_insert(make_match("2x*", "5", 0.0, 27));
        pool.try_insert(make_match("x1+", "35/", 0.01, 34));

        // Eviction ranks by exactness, then |error|, then complexity, so the third
        // insertion evicts xs (largest error).
        assert_eq!(pool.len(), 2);

        let sorted = pool.into_sorted();
        // 2x* and x1+ should remain; xs should have been evicted.
        let remaining: Vec<_> = sorted.iter().map(|m| m.lhs.expr.to_postfix()).collect();
        assert!(
            remaining.contains(&"2x*".to_string()),
            "Expected 2x* to remain, got: {:?}",
            remaining
        );
        assert!(
            remaining.contains(&"x1+".to_string()),
            "Expected x1+ to remain, got: {:?}",
            remaining
        );
    }

    #[test]
    fn test_pool_dedupe() {
        let mut pool = TopKPool::new(10, 1.0);

        // The same (LHS, RHS) equation is rejected the second time.
        assert!(pool.try_insert(make_match("2x*", "5", 0.0, 27)));
        assert!(!pool.try_insert(make_match("2x*", "5", 0.0, 27)));

        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_parity_score_prefers_operator_dense_form() {
        let low_operator = make_match("2x*", "5", 1e-6, 10);
        let high_operator = make_match("x1+", "3", 1e-6, 20);

        let low_score = legacy_parity_score_match(&low_operator);
        let high_score = legacy_parity_score_match(&high_operator);
        assert!(
            high_score < low_score,
            "expected operator-dense form to have lower legacy parity score ({} vs {})",
            high_score,
            low_score
        );
    }

    #[test]
    fn test_parity_ranking_changes_ordering() {
        let low_operator = make_match("2x*", "5", 1e-6, 10);
        let high_operator = make_match("x1+", "3", 1e-6, 20);

        // Complexity mode: errors are equal, so lower complexity ranks first.
        let mut complexity_pool =
            TopKPool::new_with_diagnostics(10, 1.0, false, RankingMode::Complexity);
        complexity_pool.try_insert(low_operator.clone());
        complexity_pool.try_insert(high_operator.clone());
        let complexity_sorted = complexity_pool.into_sorted();
        assert_eq!(complexity_sorted[0].lhs.expr.to_postfix(), "2x*");

        // Parity mode: errors are equal, so the lower legacy parity score ranks first.
        let mut parity_pool = TopKPool::new_with_diagnostics(10, 1.0, false, RankingMode::Parity);
        parity_pool.try_insert(low_operator);
        parity_pool.try_insert(high_operator);
        let parity_sorted = parity_pool.into_sorted();
        assert_eq!(parity_sorted[0].lhs.expr.to_postfix(), "x1+");
    }

    #[test]
    fn test_pool_handles_nan_and_infinity_errors() {
        let mut pool = TopKPool::new(10, f64::INFINITY);

        // A normal error is accepted.
        let normal = make_match("x", "1", 0.01, 25);
        assert!(pool.try_insert(normal));

        // An infinite error is accepted because accept_error is infinite
        // (inf > inf is false); it ranks after every finite error.
        let infinite = make_match("x1+", "2", f64::INFINITY, 30);
        assert!(pool.try_insert(infinite));

        // A NaN error is not rejected by the threshold (NaN comparisons are false)
        // and ranks as the worst error.
        let nan_match = make_match("x2*", "3", f64::NAN, 35);
        assert!(pool.try_insert(nan_match));

        // All three are in the pool.
        assert_eq!(pool.len(), 3);

        // Sorted output puts the match with the lowest error first.
        let sorted = pool.into_sorted();
        assert_eq!(sorted[0].lhs.expr.to_postfix(), "x");
    }

    #[test]
    fn test_pool_entry_distinct_with_same_rank_key() {
        // Two matches with identical rank keys but different expressions are
        // distinct equations, so both must be insertable.
        let mut pool = TopKPool::new(10, 1.0);

        // Both are exact with the same complexity, but their LHS differs.
        let m1 = make_match("x", "1", 0.0, 25);
        let m2 = make_match("x1-", "1", 0.0, 25);

        assert!(pool.try_insert(m1));
        assert!(pool.try_insert(m2));

        // Both remain because they are different equations.
        assert_eq!(pool.len(), 2);
    }
}