Skip to main content

ries_rs/
metrics.rs

1//! Metrics and scoring for match categorization
2//!
3//! Computes scores across multiple dimensions to categorize matches:
//! - Exact: error below `EXACT_MATCH_TOLERANCE`
5//! - Best: lowest error approximations
6//! - Elegant: simplest/cleanest expressions
7//! - Interesting: novel/unexpected combinations
8//! - Stable: robust matches (good Newton conditioning)
9
10use crate::search::Match;
11use crate::symbol::{Seft, Symbol};
12use crate::thresholds::{DEGENERATE_TEST_THRESHOLD, EXACT_MATCH_TOLERANCE};
13use std::collections::HashMap;
14
/// Per-match scores used to sort matches into categories
/// (exact, best, elegant, interesting, stable).
#[derive(Debug, Clone)]
pub struct MatchMetrics {
    /// Absolute error of the match relative to the target.
    pub error: f64,
    /// True when `error` is below `EXACT_MATCH_TOLERANCE`.
    pub is_exact: bool,
    /// Total complexity score of the match.
    pub complexity: u32,
    /// Aesthetic penalty (operator count, excess length, transcendental operators).
    pub ugliness: f64,
    /// Rarity of the operators and constants used.
    pub novelty: f64,
    /// Newton-conditioning score; higher is better.
    pub stability: f64,
    /// Bonus for mixing different operator families.
    pub diversity: f64,
}
33
34impl MatchMetrics {
35    /// Compute metrics for a match
36    pub fn from_match(m: &Match, freq_map: Option<&OperatorFrequency>) -> Self {
37        let error = m.error.abs();
38        let is_exact = error < EXACT_MATCH_TOLERANCE;
39        let complexity = m.complexity;
40
41        // Ugliness: penalize deep nesting and operator count
42        let ugliness = compute_ugliness(m);
43
44        // Novelty: based on operator rarity
45        let novelty = compute_novelty(m, freq_map);
46
47        // Stability: based on derivative magnitude at solution
48        let stability = compute_stability(m);
49
50        // Diversity: bonus for mixed operator families
51        let diversity = compute_diversity(m);
52
53        Self {
54            error,
55            is_exact,
56            complexity,
57            ugliness,
58            novelty,
59            stability,
60            diversity,
61        }
62    }
63
64    /// Elegant score: lower is better
65    /// Optimizes for simplicity and cleanliness
66    pub fn elegant_score(&self) -> f64 {
67        self.complexity as f64 + 0.1 * self.ugliness
68    }
69
70    /// Interesting score: higher is better
71    /// Optimizes for novelty while maintaining reasonable error
72    pub fn interesting_score(&self, error_cap: f64) -> f64 {
73        if self.error > error_cap {
74            return f64::NEG_INFINITY;
75        }
76
77        // Normalize error to [0, 1] range within cap.
78        // When error_cap == EXACT_MATCH_TOLERANCE (1e-14) the denominator is 0;
79        // treat it as a near-exact match (error_norm = 0).
80        let error_norm = if self.error < EXACT_MATCH_TOLERANCE {
81            0.0
82        } else {
83            let denom = error_cap.log10() + 14.0;
84            if denom.abs() < f64::EPSILON {
85                0.0
86            } else {
87                (self.error.log10() + 14.0) / denom
88            }
89        };
90
91        // Normalize complexity to rough [0, 1] range (100 = max typical)
92        let complexity_norm = (self.complexity as f64) / 100.0;
93
94        // Score formula: novelty is king, but penalize high error and complexity
95        self.novelty + 0.3 * self.diversity - 0.7 * error_norm - 0.2 * complexity_norm
96    }
97
98    /// Stable score: higher is better
99    pub fn stable_score(&self) -> f64 {
100        self.stability
101    }
102}
103
104/// Operator frequency map for computing rarity
105#[derive(Default)]
106pub struct OperatorFrequency {
107    /// Count of each symbol across all matches
108    symbol_counts: HashMap<Symbol, usize>,
109    /// Total symbol occurrences
110    total: usize,
111    /// Bigram counts (consecutive symbols)
112    bigram_counts: HashMap<(Symbol, Symbol), usize>,
113    /// Total bigrams
114    total_bigrams: usize,
115}
116
117impl OperatorFrequency {
118    /// Create a new frequency map
119    pub fn new() -> Self {
120        Self::default()
121    }
122
123    /// Add a match to the frequency counts
124    pub fn add(&mut self, m: &Match) {
125        let lhs_syms = m.lhs.expr.symbols();
126        let rhs_syms = m.rhs.expr.symbols();
127
128        // Count symbols
129        for &sym in lhs_syms.iter().chain(rhs_syms.iter()) {
130            *self.symbol_counts.entry(sym).or_insert(0) += 1;
131            self.total += 1;
132        }
133
134        // Count bigrams
135        for window in lhs_syms.windows(2) {
136            let bigram = (window[0], window[1]);
137            *self.bigram_counts.entry(bigram).or_insert(0) += 1;
138            self.total_bigrams += 1;
139        }
140        for window in rhs_syms.windows(2) {
141            let bigram = (window[0], window[1]);
142            *self.bigram_counts.entry(bigram).or_insert(0) += 1;
143            self.total_bigrams += 1;
144        }
145    }
146
147    /// Get rarity score for a symbol (higher = rarer)
148    pub fn symbol_rarity(&self, sym: Symbol) -> f64 {
149        if self.total == 0 {
150            return 1.0;
151        }
152        let count = self.symbol_counts.get(&sym).copied().unwrap_or(0);
153        if count == 0 {
154            return 2.0; // Very rare (not seen)
155        }
156        let freq = count as f64 / self.total as f64;
157        // Inverse log frequency as rarity
158        (-freq.log10()).max(0.0)
159    }
160
161    /// Get rarity score for a bigram
162    pub fn bigram_rarity(&self, a: Symbol, b: Symbol) -> f64 {
163        if self.total_bigrams == 0 {
164            return 1.0;
165        }
166        let count = self.bigram_counts.get(&(a, b)).copied().unwrap_or(0);
167        if count == 0 {
168            return 2.0;
169        }
170        let freq = count as f64 / self.total_bigrams as f64;
171        (-freq.log10()).max(0.0)
172    }
173}
174
175/// Compute ugliness score
176fn compute_ugliness(m: &Match) -> f64 {
177    let mut score = 0.0;
178
179    // Penalize total operator count
180    let op_count = count_operators(&m.lhs) + count_operators(&m.rhs);
181    score += op_count as f64 * 0.5;
182
183    // Penalize nesting depth (approximated by expression length)
184    let total_len = m.lhs.expr.len() + m.rhs.expr.len();
185    if total_len > 8 {
186        score += (total_len - 8) as f64 * 0.3;
187    }
188
189    // Penalize transcendental operators (they're "expensive")
190    for sym in m.lhs.expr.symbols().iter().chain(m.rhs.expr.symbols()) {
191        if matches!(
192            sym,
193            Symbol::Ln
194                | Symbol::Exp
195                | Symbol::SinPi
196                | Symbol::CosPi
197                | Symbol::TanPi
198                | Symbol::LambertW
199                | Symbol::Log
200                | Symbol::Atan2
201        ) {
202            score += 1.0;
203        }
204    }
205
206    score
207}
208
209/// Count operators in an expression
210fn count_operators(expr: &crate::expr::EvaluatedExpr) -> usize {
211    expr.expr
212        .symbols()
213        .iter()
214        .filter(|s| s.seft() != Seft::A)
215        .count()
216}
217
218/// Compute novelty score based on operator rarity
219fn compute_novelty(m: &Match, freq_map: Option<&OperatorFrequency>) -> f64 {
220    let mut score = 0.0;
221
222    // Base novelty from using uncommon operators
223    for sym in m.lhs.expr.symbols().iter().chain(m.rhs.expr.symbols()) {
224        if let Some(freq) = freq_map {
225            score += freq.symbol_rarity(*sym);
226        } else {
227            // Default rarity based on operator type
228            score += default_rarity(*sym);
229        }
230    }
231
232    // Bonus for bigram rarity
233    if let Some(freq) = freq_map {
234        let lhs_syms = m.lhs.expr.symbols();
235        for window in lhs_syms.windows(2) {
236            score += freq.bigram_rarity(window[0], window[1]) * 0.5;
237        }
238    }
239
240    // Normalize by expression length
241    let len = (m.lhs.expr.len() + m.rhs.expr.len()).max(1);
242    score / len as f64
243}
244
245/// Default rarity for operators (when no frequency map available)
246fn default_rarity(sym: Symbol) -> f64 {
247    match sym {
248        // Common constants
249        Symbol::One | Symbol::Two | Symbol::X => 0.1,
250        Symbol::Three | Symbol::Four | Symbol::Five => 0.2,
251        Symbol::Pi | Symbol::E => 0.3,
252        Symbol::Six | Symbol::Seven | Symbol::Eight | Symbol::Nine => 0.4,
253        Symbol::Phi => 0.6,
254        // New constants - medium-high rarity (less common)
255        Symbol::Gamma => 0.7,
256        Symbol::Plastic => 0.7,
257        Symbol::Apery => 0.8,
258        Symbol::Catalan => 0.7,
259
260        // Common operators
261        Symbol::Add | Symbol::Sub | Symbol::Mul | Symbol::Div => 0.2,
262        Symbol::Pow | Symbol::Sqrt | Symbol::Square => 0.3,
263
264        // Less common operators
265        Symbol::Recip | Symbol::Neg => 0.4,
266        Symbol::Ln | Symbol::Exp => 0.5,
267
268        // Uncommon operators (higher novelty)
269        Symbol::SinPi | Symbol::CosPi => 0.7,
270        Symbol::TanPi => 0.8,
271        Symbol::Root | Symbol::Log => 0.7,
272        Symbol::LambertW | Symbol::Atan2 => 1.0,
273
274        // User constants - medium rarity
275        Symbol::UserConstant0
276        | Symbol::UserConstant1
277        | Symbol::UserConstant2
278        | Symbol::UserConstant3
279        | Symbol::UserConstant4
280        | Symbol::UserConstant5
281        | Symbol::UserConstant6
282        | Symbol::UserConstant7
283        | Symbol::UserConstant8
284        | Symbol::UserConstant9
285        | Symbol::UserConstant10
286        | Symbol::UserConstant11
287        | Symbol::UserConstant12
288        | Symbol::UserConstant13
289        | Symbol::UserConstant14
290        | Symbol::UserConstant15 => 0.5,
291
292        // User functions - medium-high rarity (custom operations)
293        Symbol::UserFunction0
294        | Symbol::UserFunction1
295        | Symbol::UserFunction2
296        | Symbol::UserFunction3
297        | Symbol::UserFunction4
298        | Symbol::UserFunction5
299        | Symbol::UserFunction6
300        | Symbol::UserFunction7
301        | Symbol::UserFunction8
302        | Symbol::UserFunction9
303        | Symbol::UserFunction10
304        | Symbol::UserFunction11
305        | Symbol::UserFunction12
306        | Symbol::UserFunction13
307        | Symbol::UserFunction14
308        | Symbol::UserFunction15 => 0.6,
309    }
310}
311
312/// Compute stability score based on Newton conditioning
313fn compute_stability(m: &Match) -> f64 {
314    let deriv = m.lhs.derivative.abs();
315
316    // Ideal: derivative magnitude near 1 (order-1 updates)
317    // Bad: too small (sensitive) or too large (ill-conditioned)
318    if deriv < DEGENERATE_TEST_THRESHOLD {
319        return 0.0; // Very unstable (degenerate)
320    }
321
322    let log_deriv = deriv.log10();
323
324    // Sweet spot: log10(deriv) near 0
325    // Penalize extremes
326    let distance_from_ideal = log_deriv.abs();
327
328    // Score: higher is better, max at 1.0
329    (1.0 - distance_from_ideal / 5.0).max(0.0)
330}
331
332/// Compute diversity score (bonus for mixed operator families)
333fn compute_diversity(m: &Match) -> f64 {
334    let mut has_algebraic = false;
335    let mut has_transcendental = false;
336    let mut has_trigonometric = false;
337
338    for sym in m.lhs.expr.symbols().iter().chain(m.rhs.expr.symbols()) {
339        match sym {
340            Symbol::Add
341            | Symbol::Sub
342            | Symbol::Mul
343            | Symbol::Div
344            | Symbol::Pow
345            | Symbol::Sqrt
346            | Symbol::Square
347            | Symbol::Root
348            | Symbol::Neg
349            | Symbol::Recip => has_algebraic = true,
350
351            Symbol::Ln | Symbol::Exp | Symbol::LambertW => has_transcendental = true,
352
353            Symbol::SinPi | Symbol::CosPi | Symbol::TanPi | Symbol::Atan2 => {
354                has_trigonometric = true;
355            }
356
357            _ => {}
358        }
359    }
360
361    let mut score = 0.0;
362    let count = [has_algebraic, has_transcendental, has_trigonometric]
363        .iter()
364        .filter(|&&b| b)
365        .count();
366
367    if count >= 2 {
368        score += 0.5;
369    }
370    if count >= 3 {
371        score += 0.5;
372    }
373
374    score
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{EvaluatedExpr, Expression};
    use crate::symbol::NumType;

    /// Build a `Match` from two postfix expression strings (e.g. `"2x*"`).
    ///
    /// Both sides are built with value 0.0 and `NumType::Integer`; `deriv` is
    /// passed for the LHS and 0.0 for the RHS. `x_value` is fixed at 2.5 and
    /// the complexity is the sum of both sides' complexities.
    fn make_match(lhs: &str, rhs: &str, error: f64, deriv: f64) -> Match {
        let lhs_expr = Expression::parse(lhs).expect("test LHS should parse");
        let rhs_expr = Expression::parse(rhs).expect("test RHS should parse");
        // Taken before the expressions are moved into `EvaluatedExpr`, so no
        // clones are needed.
        let complexity = lhs_expr.complexity() + rhs_expr.complexity();

        Match {
            lhs: EvaluatedExpr::new(lhs_expr, 0.0, deriv, NumType::Integer),
            rhs: EvaluatedExpr::new(rhs_expr, 0.0, 0.0, NumType::Integer),
            x_value: 2.5,
            error,
            complexity,
        }
    }

    #[test]
    fn test_metrics_exact() {
        let metrics = MatchMetrics::from_match(&make_match("2x*", "5", 0.0, 2.0), None);

        assert!(metrics.is_exact);
        // A derivative magnitude of 2 is well within a decade of 1.
        assert!(metrics.stability > 0.5);
    }

    #[test]
    fn test_elegant_score() {
        let elegant = |lhs: &str, rhs: &str, error: f64, deriv: f64| {
            MatchMetrics::from_match(&make_match(lhs, rhs, error, deriv), None).elegant_score()
        };

        // The simpler expression must get the lower (better) elegant score.
        assert!(elegant("2x*", "5", 0.0, 2.0) < elegant("xx^ps+", "3qE", 0.001, 1.0));
    }

    #[test]
    fn test_stability_extremes() {
        let stability_at = |deriv: f64| {
            MatchMetrics::from_match(&make_match("x", "25/", 0.0, deriv), None).stability
        };

        // A derivative of 1 is ideal; 1e-12 is nearly flat.
        assert!(stability_at(1.0) > stability_at(1e-12));
    }

    /// Regression: with `error_cap == EXACT_MATCH_TOLERANCE` (1e-14) the
    /// log-scale normalisation denominator in `interesting_score` is zero, and
    /// an error equal to the cap used to produce 0/0 = NaN.
    #[test]
    fn test_interesting_score_finite_at_exact_tolerance_boundary() {
        // error == cap == tolerance: not strictly below the tolerance, so the
        // division branch is exercised.
        let boundary = make_match("2x*", "5", EXACT_MATCH_TOLERANCE, 2.0);
        let interesting =
            MatchMetrics::from_match(&boundary, None).interesting_score(EXACT_MATCH_TOLERANCE);
        assert!(
            interesting.is_finite(),
            "interesting_score must be finite, got {interesting}"
        );
    }
}