1use crate::expr::Expression;
14use crate::search::Match;
15use crate::thresholds::{
16 ACCEPT_ERROR_TIGHTEN_FACTOR, BEST_ERROR_TIGHTEN_FACTOR, EXACT_MATCH_TOLERANCE,
17 NEWTON_TOLERANCE, STRICT_GATE_CAPACITY_FRACTION, STRICT_GATE_FACTOR,
18};
19use std::cmp::Ordering;
20use std::collections::{BinaryHeap, HashSet};
21
/// Tie-break policy applied after exactness and absolute error when ranking matches.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RankingMode {
    /// Order by complexity alone, lower first. This is the default mode.
    #[default]
    Complexity,
    /// Order by the legacy parity score (lower first), then by complexity.
    Parity,
}
31
32#[derive(Clone, PartialEq, Eq, Hash)]
37pub struct EqnKey {
38 lhs: Expression,
40 rhs: Expression,
42}
43
44impl EqnKey {
45 #[inline]
47 pub fn from_match(m: &Match) -> Self {
48 Self {
49 lhs: m.lhs.expr.clone(),
50 rhs: m.rhs.expr.clone(),
51 }
52 }
53}
54
55#[derive(Clone, PartialEq, Eq, Hash)]
59pub struct LhsKey {
60 lhs: Expression,
62}
63
64impl LhsKey {
65 #[inline]
67 pub fn from_match(m: &Match) -> Self {
68 Self {
69 lhs: m.lhs.expr.clone(),
70 }
71 }
72}
73
74#[derive(Clone, PartialEq, Eq, Hash)]
79pub struct SignatureKey {
80 key: Box<[u8]>,
82}
83
84impl SignatureKey {
85 pub fn from_match(m: &Match) -> Self {
86 let expected_len = m.lhs.expr.len() + m.rhs.expr.len() + 1;
88 let mut ops = Vec::with_capacity(expected_len);
89
90 for sym in m.lhs.expr.symbols() {
91 ops.push(*sym as u8);
92 }
93 ops.push(b'=');
94 for sym in m.rhs.expr.symbols() {
95 ops.push(*sym as u8);
96 }
97
98 Self {
99 key: ops.into_boxed_slice(),
100 }
101 }
102}
103
104pub fn legacy_parity_score_expr(expr: &Expression) -> i32 {
106 expr.symbols().iter().fold(0_i32, |acc, sym| {
107 acc.saturating_add(sym.legacy_parity_weight())
108 })
109}
110
111pub fn legacy_parity_score_match(m: &Match) -> i32 {
113 legacy_parity_score_expr(&m.lhs.expr).saturating_add(legacy_parity_score_expr(&m.rhs.expr))
114}
115
116#[inline]
117fn compare_expr(a: &Expression, b: &Expression) -> Ordering {
118 a.symbols()
119 .iter()
120 .map(|s| *s as u8)
121 .cmp(b.symbols().iter().map(|s| *s as u8))
122}
123
124pub fn compare_matches(a: &Match, b: &Match, ranking_mode: RankingMode) -> Ordering {
126 let a_exactness = if a.error.abs() < EXACT_MATCH_TOLERANCE {
127 0_u8
128 } else {
129 1_u8
130 };
131 let b_exactness = if b.error.abs() < EXACT_MATCH_TOLERANCE {
132 0_u8
133 } else {
134 1_u8
135 };
136
137 let mut ord = a_exactness.cmp(&b_exactness).then_with(|| {
138 a.error
139 .abs()
140 .partial_cmp(&b.error.abs())
141 .unwrap_or(Ordering::Equal)
142 });
143
144 if ord != Ordering::Equal {
145 return ord;
146 }
147
148 ord = match ranking_mode {
149 RankingMode::Complexity => a.complexity.cmp(&b.complexity),
150 RankingMode::Parity => legacy_parity_score_match(a)
151 .cmp(&legacy_parity_score_match(b))
152 .then_with(|| a.complexity.cmp(&b.complexity)),
153 };
154
155 if ord != Ordering::Equal {
156 return ord;
157 }
158
159 compare_expr(&a.lhs.expr, &b.lhs.expr).then_with(|| compare_expr(&a.rhs.expr, &b.rhs.expr))
160}
161
162#[derive(Clone)]
165struct PoolEntry {
166 m: Match,
167 rank_key: (u8, i64, i32, u32), }
169
170impl PoolEntry {
171 fn new(m: Match, ranking_mode: RankingMode) -> Self {
172 let is_exact = m.error.abs() < EXACT_MATCH_TOLERANCE;
173 let exactness_rank = if is_exact { 0 } else { 1 };
174 let error_abs = m.error.abs();
178 let error_bits = if error_abs.is_nan() {
179 i64::MAX
181 } else if error_abs.is_infinite() {
182 i64::MAX - 1
184 } else {
185 error_abs.to_bits() as i64
188 };
189 let mode_tie = match ranking_mode {
190 RankingMode::Complexity => m.complexity as i32,
191 RankingMode::Parity => legacy_parity_score_match(&m),
192 };
193 Self {
194 rank_key: (exactness_rank, error_bits, mode_tie, m.complexity),
195 m,
196 }
197 }
198}
199
200impl PartialEq for PoolEntry {
201 fn eq(&self, other: &Self) -> bool {
202 self.rank_key == other.rank_key
203 && self.m.lhs.expr == other.m.lhs.expr
204 && self.m.rhs.expr == other.m.rhs.expr
205 }
206}
207
208impl Eq for PoolEntry {}
209
210impl PartialOrd for PoolEntry {
211 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
212 Some(self.cmp(other))
213 }
214}
215
216impl Ord for PoolEntry {
217 fn cmp(&self, other: &Self) -> Ordering {
218 self.rank_key
220 .cmp(&other.rank_key)
221 .then_with(|| compare_expr(&self.m.lhs.expr, &other.m.lhs.expr))
222 .then_with(|| compare_expr(&self.m.rhs.expr, &other.m.rhs.expr))
223 }
224}
225
/// Counters describing how candidate matches offered to a `TopKPool` were handled.
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Candidates pushed onto the heap (including ones evicted right after being pushed).
    pub insertions: usize,
    /// Inexact candidates rejected because their error exceeded the acceptance threshold.
    pub rejections_error: usize,
    /// Candidates rejected because the same equation was already in the pool.
    pub rejections_dedupe: usize,
    /// Entries popped to keep the pool within its capacity.
    pub evictions: usize,
}
238
239pub struct TopKPool {
241 capacity: usize,
243 heap: BinaryHeap<PoolEntry>,
245 seen_eqn: HashSet<EqnKey>,
247 pub best_error: f64,
249 pub accept_error: f64,
251 pub stats: PoolStats,
253 show_db_adds: bool,
255 ranking_mode: RankingMode,
257}
258
259impl TopKPool {
260 #[allow(dead_code)]
262 pub fn new(capacity: usize, initial_max_error: f64) -> Self {
263 Self {
264 capacity,
265 heap: BinaryHeap::with_capacity(capacity + 1),
266 seen_eqn: HashSet::new(),
267 best_error: initial_max_error,
268 accept_error: initial_max_error,
269 stats: PoolStats::default(),
270 show_db_adds: false,
271 ranking_mode: RankingMode::Complexity,
272 }
273 }
274
275 pub fn new_with_diagnostics(
277 capacity: usize,
278 initial_max_error: f64,
279 show_db_adds: bool,
280 ranking_mode: RankingMode,
281 ) -> Self {
282 Self {
283 capacity,
284 heap: BinaryHeap::with_capacity(capacity + 1),
285 seen_eqn: HashSet::new(),
286 best_error: initial_max_error,
287 accept_error: initial_max_error,
288 stats: PoolStats::default(),
289 show_db_adds,
290 ranking_mode,
291 }
292 }
293
294 pub fn try_insert(&mut self, m: Match) -> bool {
297 let error = m.error.abs();
298 let is_exact = error < EXACT_MATCH_TOLERANCE;
299
300 if !is_exact && error > self.accept_error {
302 self.stats.rejections_error += 1;
303 return false;
304 }
305
306 let eqn_key = EqnKey::from_match(&m);
308 if self.seen_eqn.contains(&eqn_key) {
309 self.stats.rejections_dedupe += 1;
310 return false;
311 }
312
313 let entry = PoolEntry::new(m, self.ranking_mode);
315 self.seen_eqn.insert(eqn_key);
316
317 if self.show_db_adds {
319 eprintln!(
320 " [db add] lhs={:?} rhs={:?} error={:.6e} complexity={}",
321 entry.m.lhs.expr.to_postfix(),
322 entry.m.rhs.expr.to_postfix(),
323 entry.m.error,
324 entry.m.complexity
325 );
326 }
327
328 self.heap.push(entry);
329 self.stats.insertions += 1;
330
331 if is_exact {
333 self.best_error =
335 EXACT_MATCH_TOLERANCE.max(self.best_error * BEST_ERROR_TIGHTEN_FACTOR);
336 } else if error < self.best_error {
337 self.best_error = error * BEST_ERROR_TIGHTEN_FACTOR - NEWTON_TOLERANCE;
339 self.best_error = self.best_error.max(EXACT_MATCH_TOLERANCE);
340 }
341
342 if error < self.accept_error * ACCEPT_ERROR_TIGHTEN_FACTOR {
344 self.accept_error *= ACCEPT_ERROR_TIGHTEN_FACTOR;
345 }
346
347 if self.heap.len() > self.capacity {
349 if let Some(evicted) = self.heap.pop() {
350 self.seen_eqn.remove(&EqnKey::from_match(&evicted.m));
353 self.stats.evictions += 1;
354 }
355 }
356
357 true
358 }
359
360 pub fn would_accept(&self, error: f64, is_exact: bool) -> bool {
362 if is_exact {
363 return true;
364 }
365 error <= self.accept_error
366 }
367
368 pub fn would_accept_strict(&self, coarse_error: f64, is_potentially_exact: bool) -> bool {
371 if is_potentially_exact {
373 return true;
374 }
375
376 if coarse_error > self.accept_error {
378 return false;
379 }
380
381 if self.heap.len() as f64 >= self.capacity as f64 * STRICT_GATE_CAPACITY_FRACTION {
384 if coarse_error > self.accept_error * STRICT_GATE_FACTOR {
387 return false;
388 }
389 }
390
391 true
392 }
393
394 pub fn into_sorted(self) -> Vec<Match> {
396 let ranking_mode = self.ranking_mode;
397 let mut matches: Vec<Match> = self.heap.into_iter().map(|e| e.m).collect();
398 matches.sort_by(|a, b| compare_matches(a, b, ranking_mode));
399 matches
400 }
401
402 pub fn len(&self) -> usize {
404 self.heap.len()
405 }
406
407 #[allow(dead_code)]
409 pub fn is_empty(&self) -> bool {
410 self.heap.is_empty()
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{EvaluatedExpr, Expression};
    use crate::symbol::NumType;

    /// Builds a match from postfix strings; only `error` and `complexity` vary per test.
    fn make_match(lhs: &str, rhs: &str, error: f64, complexity: u32) -> Match {
        let lhs_expr = Expression::parse(lhs).unwrap();
        let rhs_expr = Expression::parse(rhs).unwrap();
        Match {
            lhs: EvaluatedExpr::new(lhs_expr, 0.0, 1.0, NumType::Integer),
            rhs: EvaluatedExpr::new(rhs_expr, 0.0, 0.0, NumType::Integer),
            x_value: 2.5,
            error,
            complexity,
        }
    }

    #[test]
    fn test_pool_basic() {
        let mut pool = TopKPool::new(5, 1.0);

        assert!(pool.try_insert(make_match("2x*", "5", 0.0, 27)));
        assert!(pool.try_insert(make_match("x1+", "35/", 0.01, 34)));

        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_pool_eviction() {
        let mut pool = TopKPool::new(2, 1.0);

        pool.try_insert(make_match("xs", "64/", 0.1, 50));
        pool.try_insert(make_match("2x*", "5", 0.0, 27));
        pool.try_insert(make_match("x1+", "35/", 0.01, 34));

        assert_eq!(pool.len(), 2);

        let sorted = pool.into_sorted();
        let remaining: Vec<_> = sorted.iter().map(|m| m.lhs.expr.to_postfix()).collect();
        assert!(
            remaining.contains(&"2x*".to_string()),
            "Expected 2x* to remain, got: {:?}",
            remaining
        );
        assert!(
            remaining.contains(&"x1+".to_string()),
            "Expected x1+ to remain, got: {:?}",
            remaining
        );
    }

    #[test]
    fn test_pool_eviction_releases_dedupe_key() {
        // An infinite threshold never tightens, so only capacity drives eviction here.
        let mut pool = TopKPool::new(1, f64::INFINITY);

        assert!(pool.try_insert(make_match("2x*", "5", 0.0, 27)));
        // Worse than the retained entry: pushed, then immediately evicted.
        assert!(pool.try_insert(make_match("x1+", "35/", 0.01, 34)));
        assert_eq!(pool.stats.evictions, 1);
        assert_eq!(pool.len(), 1);

        // The evicted equation's key was released, so it is not a dedupe rejection.
        assert!(pool.try_insert(make_match("x1+", "35/", 0.01, 34)));
        assert_eq!(pool.stats.rejections_dedupe, 0);
        assert_eq!(pool.stats.evictions, 2);
    }

    #[test]
    fn test_pool_dedupe() {
        let mut pool = TopKPool::new(10, 1.0);

        assert!(pool.try_insert(make_match("2x*", "5", 0.0, 27)));
        assert!(!pool.try_insert(make_match("2x*", "5", 0.0, 27)));

        assert_eq!(pool.len(), 1);
        assert_eq!(pool.stats.rejections_dedupe, 1);
    }

    #[test]
    fn test_pool_rejects_inexact_error_above_threshold() {
        let mut pool = TopKPool::new(5, 0.1);

        assert!(!pool.try_insert(make_match("2x*", "5", 0.5, 27)));

        assert_eq!(pool.stats.rejections_error, 1);
        assert_eq!(pool.stats.insertions, 0);
        assert!(pool.is_empty());
    }

    #[test]
    fn test_would_accept_respects_threshold_and_exactness() {
        let pool = TopKPool::new(5, 0.1);

        assert!(pool.would_accept(0.05, false));
        assert!(pool.would_accept(0.1, false));
        assert!(!pool.would_accept(0.2, false));
        assert!(pool.would_accept(0.2, true));
    }

    #[test]
    fn test_parity_score_prefers_operator_dense_form() {
        let low_operator = make_match("2x*", "5", 1e-6, 10);
        let high_operator = make_match("x1+", "3", 1e-6, 20);

        let low_score = legacy_parity_score_match(&low_operator);
        let high_score = legacy_parity_score_match(&high_operator);
        assert!(
            high_score < low_score,
            "expected operator-dense form to have lower legacy parity score ({} vs {})",
            high_score,
            low_score
        );
    }

    #[test]
    fn test_parity_ranking_changes_ordering() {
        let low_operator = make_match("2x*", "5", 1e-6, 10);
        let high_operator = make_match("x1+", "3", 1e-6, 20);

        let mut complexity_pool =
            TopKPool::new_with_diagnostics(10, 1.0, false, RankingMode::Complexity);
        complexity_pool.try_insert(low_operator.clone());
        complexity_pool.try_insert(high_operator.clone());
        let complexity_sorted = complexity_pool.into_sorted();
        assert_eq!(complexity_sorted[0].lhs.expr.to_postfix(), "2x*");

        let mut parity_pool = TopKPool::new_with_diagnostics(10, 1.0, false, RankingMode::Parity);
        parity_pool.try_insert(low_operator);
        parity_pool.try_insert(high_operator);
        let parity_sorted = parity_pool.into_sorted();
        assert_eq!(parity_sorted[0].lhs.expr.to_postfix(), "x1+");
    }

    #[test]
    fn test_pool_handles_nan_and_infinity_errors() {
        let mut pool = TopKPool::new(10, f64::INFINITY);

        let normal = make_match("x", "1", 0.01, 25);
        assert!(pool.try_insert(normal));

        let infinite = make_match("x1+", "2", f64::INFINITY, 30);
        assert!(pool.try_insert(infinite));

        let nan_match = make_match("x2*", "3", f64::NAN, 35);
        assert!(pool.try_insert(nan_match));

        assert_eq!(pool.len(), 3);

        let sorted = pool.into_sorted();
        assert_eq!(sorted[0].lhs.expr.to_postfix(), "x");
    }

    #[test]
    fn test_pool_entry_distinct_with_same_rank_key() {
        let mut pool = TopKPool::new(10, 1.0);

        let m1 = make_match("x", "1", 0.0, 25);
        let m2 = make_match("x1-", "1", 0.0, 25);

        assert!(pool.try_insert(m1));
        assert!(pool.try_insert(m2));

        assert_eq!(pool.len(), 2);
    }
}