1use crate::search::Match;
11use crate::symbol::{Seft, Symbol};
12use crate::thresholds::{DEGENERATE_TEST_THRESHOLD, EXACT_MATCH_TOLERANCE};
13use std::collections::HashMap;
14
/// Ranking metrics computed for a single match.
///
/// All fields are filled in by `MatchMetrics::from_match`; the scoring
/// methods on this type only combine the values stored here.
#[derive(Debug, Clone)]
pub struct MatchMetrics {
    /// Absolute value of the match error.
    pub error: f64,
    /// `true` when `error` is below `EXACT_MATCH_TOLERANCE`.
    pub is_exact: bool,
    /// Complexity copied from the match.
    pub complexity: u32,
    /// Penalty from `compute_ugliness` (operator count, length, special functions).
    pub ugliness: f64,
    /// Length-normalised rarity from `compute_novelty`.
    pub novelty: f64,
    /// Value in `[0, 1]` from `compute_stability`, based on the LHS derivative.
    pub stability: f64,
    /// `0.0`, `0.5` or `1.0` from `compute_diversity`, by operator families used.
    pub diversity: f64,
}
33
34impl MatchMetrics {
35 pub fn from_match(m: &Match, freq_map: Option<&OperatorFrequency>) -> Self {
37 let error = m.error.abs();
38 let is_exact = error < EXACT_MATCH_TOLERANCE;
39 let complexity = m.complexity;
40
41 let ugliness = compute_ugliness(m);
43
44 let novelty = compute_novelty(m, freq_map);
46
47 let stability = compute_stability(m);
49
50 let diversity = compute_diversity(m);
52
53 Self {
54 error,
55 is_exact,
56 complexity,
57 ugliness,
58 novelty,
59 stability,
60 diversity,
61 }
62 }
63
64 pub fn elegant_score(&self) -> f64 {
67 self.complexity as f64 + 0.1 * self.ugliness
68 }
69
70 pub fn interesting_score(&self, error_cap: f64) -> f64 {
73 if self.error > error_cap {
74 return f64::NEG_INFINITY;
75 }
76
77 let error_norm = if self.error < EXACT_MATCH_TOLERANCE {
81 0.0
82 } else {
83 let denom = error_cap.log10() + 14.0;
84 if denom.abs() < f64::EPSILON {
85 0.0
86 } else {
87 (self.error.log10() + 14.0) / denom
88 }
89 };
90
91 let complexity_norm = (self.complexity as f64) / 100.0;
93
94 self.novelty + 0.3 * self.diversity - 0.7 * error_norm - 0.2 * complexity_norm
96 }
97
98 pub fn stable_score(&self) -> f64 {
100 self.stability
101 }
102}
103
104#[derive(Default)]
106pub struct OperatorFrequency {
107 symbol_counts: HashMap<Symbol, usize>,
109 total: usize,
111 bigram_counts: HashMap<(Symbol, Symbol), usize>,
113 total_bigrams: usize,
115}
116
117impl OperatorFrequency {
118 pub fn new() -> Self {
120 Self::default()
121 }
122
123 pub fn add(&mut self, m: &Match) {
125 let lhs_syms = m.lhs.expr.symbols();
126 let rhs_syms = m.rhs.expr.symbols();
127
128 for &sym in lhs_syms.iter().chain(rhs_syms.iter()) {
130 *self.symbol_counts.entry(sym).or_insert(0) += 1;
131 self.total += 1;
132 }
133
134 for window in lhs_syms.windows(2) {
136 let bigram = (window[0], window[1]);
137 *self.bigram_counts.entry(bigram).or_insert(0) += 1;
138 self.total_bigrams += 1;
139 }
140 for window in rhs_syms.windows(2) {
141 let bigram = (window[0], window[1]);
142 *self.bigram_counts.entry(bigram).or_insert(0) += 1;
143 self.total_bigrams += 1;
144 }
145 }
146
147 pub fn symbol_rarity(&self, sym: Symbol) -> f64 {
149 if self.total == 0 {
150 return 1.0;
151 }
152 let count = self.symbol_counts.get(&sym).copied().unwrap_or(0);
153 if count == 0 {
154 return 2.0; }
156 let freq = count as f64 / self.total as f64;
157 (-freq.log10()).max(0.0)
159 }
160
161 pub fn bigram_rarity(&self, a: Symbol, b: Symbol) -> f64 {
163 if self.total_bigrams == 0 {
164 return 1.0;
165 }
166 let count = self.bigram_counts.get(&(a, b)).copied().unwrap_or(0);
167 if count == 0 {
168 return 2.0;
169 }
170 let freq = count as f64 / self.total_bigrams as f64;
171 (-freq.log10()).max(0.0)
172 }
173}
174
175fn compute_ugliness(m: &Match) -> f64 {
177 let mut score = 0.0;
178
179 let op_count = count_operators(&m.lhs) + count_operators(&m.rhs);
181 score += op_count as f64 * 0.5;
182
183 let total_len = m.lhs.expr.len() + m.rhs.expr.len();
185 if total_len > 8 {
186 score += (total_len - 8) as f64 * 0.3;
187 }
188
189 for sym in m.lhs.expr.symbols().iter().chain(m.rhs.expr.symbols()) {
191 if matches!(
192 sym,
193 Symbol::Ln
194 | Symbol::Exp
195 | Symbol::SinPi
196 | Symbol::CosPi
197 | Symbol::TanPi
198 | Symbol::LambertW
199 | Symbol::Log
200 | Symbol::Atan2
201 ) {
202 score += 1.0;
203 }
204 }
205
206 score
207}
208
209fn count_operators(expr: &crate::expr::EvaluatedExpr) -> usize {
211 expr.expr
212 .symbols()
213 .iter()
214 .filter(|s| s.seft() != Seft::A)
215 .count()
216}
217
218fn compute_novelty(m: &Match, freq_map: Option<&OperatorFrequency>) -> f64 {
220 let mut score = 0.0;
221
222 for sym in m.lhs.expr.symbols().iter().chain(m.rhs.expr.symbols()) {
224 if let Some(freq) = freq_map {
225 score += freq.symbol_rarity(*sym);
226 } else {
227 score += default_rarity(*sym);
229 }
230 }
231
232 if let Some(freq) = freq_map {
234 let lhs_syms = m.lhs.expr.symbols();
235 for window in lhs_syms.windows(2) {
236 score += freq.bigram_rarity(window[0], window[1]) * 0.5;
237 }
238 }
239
240 let len = (m.lhs.expr.len() + m.rhs.expr.len()).max(1);
242 score / len as f64
243}
244
245fn default_rarity(sym: Symbol) -> f64 {
247 match sym {
248 Symbol::One | Symbol::Two | Symbol::X => 0.1,
250 Symbol::Three | Symbol::Four | Symbol::Five => 0.2,
251 Symbol::Pi | Symbol::E => 0.3,
252 Symbol::Six | Symbol::Seven | Symbol::Eight | Symbol::Nine => 0.4,
253 Symbol::Phi => 0.6,
254 Symbol::Gamma => 0.7,
256 Symbol::Plastic => 0.7,
257 Symbol::Apery => 0.8,
258 Symbol::Catalan => 0.7,
259
260 Symbol::Add | Symbol::Sub | Symbol::Mul | Symbol::Div => 0.2,
262 Symbol::Pow | Symbol::Sqrt | Symbol::Square => 0.3,
263
264 Symbol::Recip | Symbol::Neg => 0.4,
266 Symbol::Ln | Symbol::Exp => 0.5,
267
268 Symbol::SinPi | Symbol::CosPi => 0.7,
270 Symbol::TanPi => 0.8,
271 Symbol::Root | Symbol::Log => 0.7,
272 Symbol::LambertW | Symbol::Atan2 => 1.0,
273
274 Symbol::UserConstant0
276 | Symbol::UserConstant1
277 | Symbol::UserConstant2
278 | Symbol::UserConstant3
279 | Symbol::UserConstant4
280 | Symbol::UserConstant5
281 | Symbol::UserConstant6
282 | Symbol::UserConstant7
283 | Symbol::UserConstant8
284 | Symbol::UserConstant9
285 | Symbol::UserConstant10
286 | Symbol::UserConstant11
287 | Symbol::UserConstant12
288 | Symbol::UserConstant13
289 | Symbol::UserConstant14
290 | Symbol::UserConstant15 => 0.5,
291
292 Symbol::UserFunction0
294 | Symbol::UserFunction1
295 | Symbol::UserFunction2
296 | Symbol::UserFunction3
297 | Symbol::UserFunction4
298 | Symbol::UserFunction5
299 | Symbol::UserFunction6
300 | Symbol::UserFunction7
301 | Symbol::UserFunction8
302 | Symbol::UserFunction9
303 | Symbol::UserFunction10
304 | Symbol::UserFunction11
305 | Symbol::UserFunction12
306 | Symbol::UserFunction13
307 | Symbol::UserFunction14
308 | Symbol::UserFunction15 => 0.6,
309 }
310}
311
312fn compute_stability(m: &Match) -> f64 {
314 let deriv = m.lhs.derivative.abs();
315
316 if deriv < DEGENERATE_TEST_THRESHOLD {
319 return 0.0; }
321
322 let log_deriv = deriv.log10();
323
324 let distance_from_ideal = log_deriv.abs();
327
328 (1.0 - distance_from_ideal / 5.0).max(0.0)
330}
331
332fn compute_diversity(m: &Match) -> f64 {
334 let mut has_algebraic = false;
335 let mut has_transcendental = false;
336 let mut has_trigonometric = false;
337
338 for sym in m.lhs.expr.symbols().iter().chain(m.rhs.expr.symbols()) {
339 match sym {
340 Symbol::Add
341 | Symbol::Sub
342 | Symbol::Mul
343 | Symbol::Div
344 | Symbol::Pow
345 | Symbol::Sqrt
346 | Symbol::Square
347 | Symbol::Root
348 | Symbol::Neg
349 | Symbol::Recip => has_algebraic = true,
350
351 Symbol::Ln | Symbol::Exp | Symbol::LambertW => has_transcendental = true,
352
353 Symbol::SinPi | Symbol::CosPi | Symbol::TanPi | Symbol::Atan2 => {
354 has_trigonometric = true;
355 }
356
357 _ => {}
358 }
359 }
360
361 let mut score = 0.0;
362 let count = [has_algebraic, has_transcendental, has_trigonometric]
363 .iter()
364 .filter(|&&b| b)
365 .count();
366
367 if count >= 2 {
368 score += 0.5;
369 }
370 if count >= 3 {
371 score += 0.5;
372 }
373
374 score
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{EvaluatedExpr, Expression};
    use crate::symbol::NumType;

    /// Builds a `Match` from two expression strings.
    ///
    /// Both sides are evaluated as `0.0` with `NumType::Integer`; only the
    /// LHS carries the given derivative. Complexity is the sum of both sides.
    fn make_match(lhs: &str, rhs: &str, error: f64, deriv: f64) -> Match {
        let left = Expression::parse(lhs).unwrap();
        let right = Expression::parse(rhs).unwrap();
        let complexity = left.complexity() + right.complexity();

        Match {
            lhs: EvaluatedExpr::new(left, 0.0, deriv, NumType::Integer),
            rhs: EvaluatedExpr::new(right, 0.0, 0.0, NumType::Integer),
            x_value: 2.5,
            error,
            complexity,
        }
    }

    /// A zero-error match is exact, and a derivative of 2 is considered stable.
    #[test]
    fn test_metrics_exact() {
        let exact = make_match("2x*", "5", 0.0, 2.0);
        let MatchMetrics {
            is_exact,
            stability,
            ..
        } = MatchMetrics::from_match(&exact, None);

        assert!(is_exact);
        assert!(stability > 0.5);
    }

    /// A short equation must get a lower elegant score than a longer one.
    #[test]
    fn test_elegant_score() {
        let score_of = |m: &Match| MatchMetrics::from_match(m, None).elegant_score();

        let simple_score = score_of(&make_match("2x*", "5", 0.0, 2.0));
        let complex_score = score_of(&make_match("xx^ps+", "3qE", 0.001, 1.0));

        assert!(simple_score < complex_score);
    }

    /// A unit derivative must be more stable than a nearly vanishing one.
    #[test]
    fn test_stability_extremes() {
        let stability_for =
            |deriv: f64| MatchMetrics::from_match(&make_match("x", "25/", 0.0, deriv), None).stability;

        assert!(stability_for(1.0) > stability_for(1e-12));
    }

    /// With both the error and the cap at the exact-match tolerance, the
    /// score must still be a finite number.
    #[test]
    fn test_interesting_score_finite_at_exact_tolerance_boundary() {
        let boundary = make_match("2x*", "5", EXACT_MATCH_TOLERANCE, 2.0);
        let interesting =
            MatchMetrics::from_match(&boundary, None).interesting_score(EXACT_MATCH_TOLERANCE);

        assert!(
            interesting.is_finite(),
            "interesting_score must be finite, got {interesting}"
        );
    }
}