use crate::expr::Expression;
use crate::search::Match;
use crate::thresholds::{
ACCEPT_ERROR_TIGHTEN_FACTOR, BEST_ERROR_TIGHTEN_FACTOR, EXACT_MATCH_TOLERANCE,
NEWTON_TOLERANCE, STRICT_GATE_CAPACITY_FRACTION, STRICT_GATE_FACTOR,
};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
/// Tie-breaking policy applied once two matches agree on exactness and error.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RankingMode {
    /// Prefer the lower `complexity` value. This is the default mode.
    #[default]
    Complexity,
    /// Prefer the lower legacy parity score, falling back to `complexity`.
    Parity,
}
/// Deduplication key that identifies an equation by its full `(lhs, rhs)` pair.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct EqnKey {
    lhs: Expression,
    rhs: Expression,
}
impl EqnKey {
    /// Builds a key from `m` by cloning both of its expressions.
    #[inline]
    pub fn from_match(m: &Match) -> Self {
        let lhs = m.lhs.expr.clone();
        let rhs = m.rhs.expr.clone();
        Self { lhs, rhs }
    }
}
/// Key that identifies a match by its left-hand expression only.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct LhsKey {
    lhs: Expression,
}
impl LhsKey {
    /// Builds a key from `m` by cloning its left-hand expression.
    #[inline]
    pub fn from_match(m: &Match) -> Self {
        let lhs = m.lhs.expr.clone();
        Self { lhs }
    }
}
/// Compact byte key: the `u8` codes of the LHS symbols, a `b'='` separator,
/// then the `u8` codes of the RHS symbols.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct SignatureKey {
    key: Box<[u8]>,
}
impl SignatureKey {
    /// Encodes `m` as `lhs symbol codes ++ b'=' ++ rhs symbol codes`.
    pub fn from_match(m: &Match) -> Self {
        let (lhs, rhs) = (&m.lhs.expr, &m.rhs.expr);
        // Sized up front: one byte per symbol on each side plus the separator.
        let mut bytes: Vec<u8> = Vec::with_capacity(lhs.len() + rhs.len() + 1);
        bytes.extend(lhs.symbols().iter().map(|&sym| sym as u8));
        bytes.push(b'=');
        bytes.extend(rhs.symbols().iter().map(|&sym| sym as u8));
        SignatureKey { key: bytes.into() }
    }
}
/// Sums `legacy_parity_weight()` over every symbol of `expr`.
///
/// Addition saturates at the `i32` bounds instead of overflowing.
pub fn legacy_parity_score_expr(expr: &Expression) -> i32 {
    let mut score: i32 = 0;
    for symbol in expr.symbols().iter() {
        score = score.saturating_add(symbol.legacy_parity_weight());
    }
    score
}
/// Legacy parity score of a whole equation: LHS score plus RHS score, saturating.
///
/// `compare_matches` ranks lower scores first in `RankingMode::Parity`.
pub fn legacy_parity_score_match(m: &Match) -> i32 {
    let lhs_score = legacy_parity_score_expr(&m.lhs.expr);
    let rhs_score = legacy_parity_score_expr(&m.rhs.expr);
    lhs_score.saturating_add(rhs_score)
}
/// Lexicographically compares two expressions by the `u8` codes of their symbols.
///
/// A strict prefix orders before the longer sequence.
#[inline]
fn compare_expr(a: &Expression, b: &Expression) -> Ordering {
    let left = a.symbols().iter().map(|&s| s as u8);
    let right = b.symbols().iter().map(|&s| s as u8);
    left.cmp(right)
}
pub fn compare_matches(a: &Match, b: &Match, ranking_mode: RankingMode) -> Ordering {
let a_exactness = if a.error.abs() < EXACT_MATCH_TOLERANCE {
0_u8
} else {
1_u8
};
let b_exactness = if b.error.abs() < EXACT_MATCH_TOLERANCE {
0_u8
} else {
1_u8
};
let mut ord = a_exactness.cmp(&b_exactness).then_with(|| {
a.error
.abs()
.partial_cmp(&b.error.abs())
.unwrap_or(Ordering::Equal)
});
if ord != Ordering::Equal {
return ord;
}
ord = match ranking_mode {
RankingMode::Complexity => a.complexity.cmp(&b.complexity),
RankingMode::Parity => legacy_parity_score_match(a)
.cmp(&legacy_parity_score_match(b))
.then_with(|| a.complexity.cmp(&b.complexity)),
};
if ord != Ordering::Equal {
return ord;
}
compare_expr(&a.lhs.expr, &b.lhs.expr).then_with(|| compare_expr(&a.rhs.expr, &b.rhs.expr))
}
/// Heap entry that pairs a [`Match`] with a precomputed ranking key.
///
/// Smaller keys rank better. In `TopKPool`'s `BinaryHeap` (a max-heap), the
/// worst-ranked entry therefore sits on top and is the one evicted.
#[derive(Clone)]
struct PoolEntry {
    m: Match,
    /// `(exactness rank, ordered error key, mode tie-breaker, complexity)`.
    rank_key: (u8, i64, i32, u32),
}
impl PoolEntry {
    /// Builds an entry for `m`, computing its ranking key under `ranking_mode`.
    fn new(m: Match, ranking_mode: RankingMode) -> Self {
        let error_abs = m.error.abs();
        // `NaN < x` is false, so NaN errors are ranked as inexact.
        let exactness_rank: u8 = if error_abs < EXACT_MATCH_TOLERANCE { 0 } else { 1 };
        // For non-negative finite f64 values the IEEE-754 bit pattern grows
        // monotonically with the value, so it can be compared as an integer.
        // Infinity and NaN are pinned to the two largest keys, with NaN worst.
        let error_bits = if error_abs.is_nan() {
            i64::MAX
        } else if error_abs.is_infinite() {
            i64::MAX - 1
        } else {
            error_abs.to_bits() as i64
        };
        // In Complexity mode the trailing `complexity` element already breaks
        // ties, so this slot stays neutral. Casting the `u32` complexity to
        // `i32` here would wrap above `i32::MAX` and invert the heap order
        // relative to `compare_matches`.
        let mode_tie = match ranking_mode {
            RankingMode::Complexity => 0,
            RankingMode::Parity => legacy_parity_score_match(&m),
        };
        Self {
            rank_key: (exactness_rank, error_bits, mode_tie, m.complexity),
            m,
        }
    }
}
impl PartialEq for PoolEntry {
    /// Entries are equal when their rank keys and both expressions are equal.
    fn eq(&self, other: &Self) -> bool {
        (&self.rank_key, &self.m.lhs.expr, &self.m.rhs.expr)
            == (&other.rank_key, &other.m.lhs.expr, &other.m.rhs.expr)
    }
}
impl Eq for PoolEntry {}
impl PartialOrd for PoolEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for PoolEntry {
    /// Orders by `rank_key`, then by LHS symbols, then by RHS symbols.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.rank_key.cmp(&other.rank_key) {
            Ordering::Equal => {}
            decided => return decided,
        }
        match compare_expr(&self.m.lhs.expr, &other.m.lhs.expr) {
            Ordering::Equal => compare_expr(&self.m.rhs.expr, &other.m.rhs.expr),
            decided => decided,
        }
    }
}
/// Running counters maintained by `TopKPool::try_insert`.
#[derive(Debug, Default, Clone)]
pub struct PoolStats {
    /// Matches pushed onto the heap, including any that were evicted at once.
    pub insertions: usize,
    /// Inexact matches refused because their error exceeded `accept_error`.
    pub rejections_error: usize,
    /// Matches refused because the same `(lhs, rhs)` pair was already held.
    pub rejections_dedupe: usize,
    /// Entries popped to keep the pool within its capacity.
    pub evictions: usize,
}
pub struct TopKPool {
capacity: usize,
heap: BinaryHeap<PoolEntry>,
seen_eqn: HashSet<EqnKey>,
pub best_error: f64,
pub accept_error: f64,
pub stats: PoolStats,
show_db_adds: bool,
ranking_mode: RankingMode,
}
impl TopKPool {
#[allow(dead_code)]
pub fn new(capacity: usize, initial_max_error: f64) -> Self {
Self {
capacity,
heap: BinaryHeap::with_capacity(capacity + 1),
seen_eqn: HashSet::new(),
best_error: initial_max_error,
accept_error: initial_max_error,
stats: PoolStats::default(),
show_db_adds: false,
ranking_mode: RankingMode::Complexity,
}
}
pub fn new_with_diagnostics(
capacity: usize,
initial_max_error: f64,
show_db_adds: bool,
ranking_mode: RankingMode,
) -> Self {
Self {
capacity,
heap: BinaryHeap::with_capacity(capacity + 1),
seen_eqn: HashSet::new(),
best_error: initial_max_error,
accept_error: initial_max_error,
stats: PoolStats::default(),
show_db_adds,
ranking_mode,
}
}
pub fn try_insert(&mut self, m: Match) -> bool {
let error = m.error.abs();
let is_exact = error < EXACT_MATCH_TOLERANCE;
if !is_exact && error > self.accept_error {
self.stats.rejections_error += 1;
return false;
}
let eqn_key = EqnKey::from_match(&m);
if self.seen_eqn.contains(&eqn_key) {
self.stats.rejections_dedupe += 1;
return false;
}
let entry = PoolEntry::new(m, self.ranking_mode);
self.seen_eqn.insert(eqn_key);
if self.show_db_adds {
eprintln!(
" [db add] lhs={:?} rhs={:?} error={:.6e} complexity={}",
entry.m.lhs.expr.to_postfix(),
entry.m.rhs.expr.to_postfix(),
entry.m.error,
entry.m.complexity
);
}
self.heap.push(entry);
self.stats.insertions += 1;
if is_exact {
self.best_error =
EXACT_MATCH_TOLERANCE.max(self.best_error * BEST_ERROR_TIGHTEN_FACTOR);
} else if error < self.best_error {
self.best_error = error * BEST_ERROR_TIGHTEN_FACTOR - NEWTON_TOLERANCE;
self.best_error = self.best_error.max(EXACT_MATCH_TOLERANCE);
}
if error < self.accept_error * ACCEPT_ERROR_TIGHTEN_FACTOR {
self.accept_error *= ACCEPT_ERROR_TIGHTEN_FACTOR;
}
if self.heap.len() > self.capacity {
if let Some(evicted) = self.heap.pop() {
self.seen_eqn.remove(&EqnKey::from_match(&evicted.m));
self.stats.evictions += 1;
}
}
true
}
pub fn would_accept(&self, error: f64, is_exact: bool) -> bool {
if is_exact {
return true;
}
error <= self.accept_error
}
pub fn would_accept_strict(&self, coarse_error: f64, is_potentially_exact: bool) -> bool {
if is_potentially_exact {
return true;
}
if coarse_error > self.accept_error {
return false;
}
if self.heap.len() as f64 >= self.capacity as f64 * STRICT_GATE_CAPACITY_FRACTION {
if coarse_error > self.accept_error * STRICT_GATE_FACTOR {
return false;
}
}
true
}
pub fn into_sorted(self) -> Vec<Match> {
let ranking_mode = self.ranking_mode;
let mut matches: Vec<Match> = self.heap.into_iter().map(|e| e.m).collect();
matches.sort_by(|a, b| compare_matches(a, b, ranking_mode));
matches
}
pub fn len(&self) -> usize {
self.heap.len()
}
#[allow(dead_code)]
pub fn is_empty(&self) -> bool {
self.heap.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{EvaluatedExpr, Expression};
    use crate::symbol::NumType;

    /// Builds a `Match` from postfix LHS/RHS strings with fixed dummy evaluation data.
    fn make_match(lhs: &str, rhs: &str, error: f64, complexity: u32) -> Match {
        let parse = |src: &str| Expression::parse(src).unwrap();
        Match {
            lhs: EvaluatedExpr::new(parse(lhs), 0.0, 1.0, NumType::Integer),
            rhs: EvaluatedExpr::new(parse(rhs), 0.0, 0.0, NumType::Integer),
            x_value: 2.5,
            error,
            complexity,
        }
    }

    #[test]
    fn test_pool_basic() {
        let mut pool = TopKPool::new(5, 1.0);
        let exact = make_match("2x*", "5", 0.0, 27);
        let approx = make_match("x1+", "35/", 0.01, 34);
        assert!(pool.try_insert(exact));
        assert!(pool.try_insert(approx));
        assert_eq!(pool.len(), 2);
    }

    #[test]
    fn test_pool_eviction() {
        let mut pool = TopKPool::new(2, 1.0);
        let offers = [
            ("xs", "64/", 0.1, 50),
            ("2x*", "5", 0.0, 27),
            ("x1+", "35/", 0.01, 34),
        ];
        for &(lhs, rhs, error, complexity) in offers.iter() {
            pool.try_insert(make_match(lhs, rhs, error, complexity));
        }
        assert_eq!(pool.len(), 2);
        let sorted = pool.into_sorted();
        let remaining: Vec<_> = sorted.iter().map(|m| m.lhs.expr.to_postfix()).collect();
        assert!(
            remaining.iter().any(|p| p == "2x*"),
            "Expected 2x* to remain, got: {:?}",
            remaining
        );
        assert!(
            remaining.iter().any(|p| p == "x1+"),
            "Expected x1+ to remain, got: {:?}",
            remaining
        );
    }

    #[test]
    fn test_pool_dedupe() {
        let mut pool = TopKPool::new(10, 1.0);
        let first = pool.try_insert(make_match("2x*", "5", 0.0, 27));
        let second = pool.try_insert(make_match("2x*", "5", 0.0, 27));
        assert!(first);
        assert!(!second);
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_parity_score_prefers_operator_dense_form() {
        let low_score = legacy_parity_score_match(&make_match("2x*", "5", 1e-6, 10));
        let high_score = legacy_parity_score_match(&make_match("x1+", "3", 1e-6, 20));
        assert!(
            high_score < low_score,
            "expected operator-dense form to have lower legacy parity score ({} vs {})",
            high_score,
            low_score
        );
    }

    #[test]
    fn test_parity_ranking_changes_ordering() {
        let low_operator = make_match("2x*", "5", 1e-6, 10);
        let high_operator = make_match("x1+", "3", 1e-6, 20);
        // Leading LHS (postfix) after ranking both matches under `mode`.
        let leader = |mode: RankingMode| {
            let mut pool = TopKPool::new_with_diagnostics(10, 1.0, false, mode);
            pool.try_insert(low_operator.clone());
            pool.try_insert(high_operator.clone());
            let sorted = pool.into_sorted();
            sorted[0].lhs.expr.to_postfix()
        };
        assert_eq!(leader(RankingMode::Complexity), "2x*");
        assert_eq!(leader(RankingMode::Parity), "x1+");
    }

    #[test]
    fn test_pool_handles_nan_and_infinity_errors() {
        let mut pool = TopKPool::new(10, f64::INFINITY);
        let candidates = vec![
            make_match("x", "1", 0.01, 25),
            make_match("x1+", "2", f64::INFINITY, 30),
            make_match("x2*", "3", f64::NAN, 35),
        ];
        for candidate in candidates {
            assert!(pool.try_insert(candidate));
        }
        assert_eq!(pool.len(), 3);
        let sorted = pool.into_sorted();
        assert_eq!(sorted[0].lhs.expr.to_postfix(), "x");
    }

    #[test]
    fn test_pool_entry_distinct_with_same_rank_key() {
        let mut pool = TopKPool::new(10, 1.0);
        let plain = make_match("x", "1", 0.0, 25);
        let shifted = make_match("x1-", "1", 0.0, 25);
        assert!(pool.try_insert(plain));
        assert!(pool.try_insert(shifted));
        assert_eq!(pool.len(), 2);
    }
}