Skip to main content

ries_rs/
fast_match.rs

1//! Fast-path exact match detection
2//!
3//! Before doing expensive expression generation, check if the target
4//! is a simple exact match (like pi, e, sqrt(2), etc.) that can be
5//! found instantly.
6
7use crate::eval::{evaluate_with_context, EvalContext};
8use crate::expr::{EvaluatedExpr, Expression};
9use crate::profile::UserConstant;
10use crate::search::Match;
11use crate::symbol::{NumType, Symbol};
12use crate::symbol_table::SymbolTable;
13use std::collections::HashSet;
14
/// Absolute difference below which a candidate value is accepted as an exact
/// match for the target.
const EXACT_TOLERANCE: f64 = 1.0e-14;
17
18/// Build an expression from symbols using table-based weights
19fn expr_from_symbols_with_table(symbols: &[Symbol], table: &SymbolTable) -> Expression {
20    let mut expr = Expression::new();
21    for &sym in symbols {
22        expr.push_with_table(sym, table);
23    }
24    expr
25}
26
27/// Get the num_type of an expression based on its symbols
28/// This is a simplified type inference for the fast_match candidates
29fn get_num_type(symbols: &[Symbol]) -> NumType {
30    // For fast match candidates, we use simplified type inference:
31    // - Integer constants → Integer
32    // - Rational operations (division) → Rational
33    // - Sqrt of integer → Algebraic (constructible)
34    // - Transcendental constants → Transcendental
35    // - Any transcendental function → Transcendental
36
37    use Symbol::*;
38
39    // Handle simple patterns
40    if symbols.len() == 1 {
41        return symbols[0].inherent_type();
42    }
43
44    // Check for sqrt of integer (like 2q = sqrt(2))
45    if symbols.len() == 2 {
46        if matches!(symbols[1], Sqrt) {
47            if matches!(
48                symbols[0],
49                One | Two | Three | Four | Five | Six | Seven | Eight | Nine
50            ) {
51                return NumType::Algebraic; // sqrt of integer is algebraic
52            }
53            // sqrt of transcendental constant is transcendental
54            if matches!(symbols[0], Pi | E | Gamma | Apery | Catalan) {
55                return NumType::Transcendental;
56            }
57        }
58        // Check for reciprocal of integer
59        if matches!(symbols[1], Recip)
60            && matches!(
61                symbols[0],
62                One | Two | Three | Four | Five | Six | Seven | Eight | Nine
63            )
64        {
65            return NumType::Rational;
66        }
67        // Division: integer / integer = rational
68        if matches!(symbols[1], Div)
69            && matches!(
70                symbols[0],
71                One | Two | Three | Four | Five | Six | Seven | Eight | Nine
72            )
73            && symbols.len() >= 3
74        {
75            // This is more complex, but for simple cases it's rational
76            return NumType::Rational;
77        }
78    }
79
80    // Check for division pattern (3 symbols: num, denom, /)
81    if symbols.len() == 3 && matches!(symbols[2], Div) {
82        // Integer / Integer = Rational
83        if matches!(
84            symbols[0],
85            One | Two | Three | Four | Five | Six | Seven | Eight | Nine
86        ) && matches!(
87            symbols[1],
88            One | Two | Three | Four | Five | Six | Seven | Eight | Nine
89        ) {
90            return NumType::Rational;
91        }
92    }
93
94    // Default: check if any symbol is transcendental
95    for &sym in symbols {
96        let sym_type = sym.inherent_type();
97        if sym_type == NumType::Transcendental {
98            return NumType::Transcendental;
99        }
100    }
101
102    // If we have any algebraic constants (phi, plastic), result is algebraic
103    for &sym in symbols {
104        if matches!(sym, Phi | Plastic) {
105            return NumType::Algebraic;
106        }
107    }
108
109    // Default to transcendental (most general)
110    NumType::Transcendental
111}
112
113/// Check if any symbol in the expression is excluded
114fn contains_excluded(symbols: &[Symbol], excluded: &HashSet<u8>) -> bool {
115    symbols.iter().any(|s| excluded.contains(&(*s as u8)))
116}
117
118/// A candidate for a fast exact match
119struct FastCandidate {
120    /// The expression (as symbols)
121    symbols: &'static [Symbol],
122}
123
124/// Generate fast candidates for common constants and simple expressions
125fn get_constant_candidates() -> Vec<FastCandidate> {
126    vec![
127        // Integers
128        FastCandidate {
129            symbols: &[Symbol::One],
130        },
131        FastCandidate {
132            symbols: &[Symbol::Two],
133        },
134        FastCandidate {
135            symbols: &[Symbol::Three],
136        },
137        FastCandidate {
138            symbols: &[Symbol::Four],
139        },
140        FastCandidate {
141            symbols: &[Symbol::Five],
142        },
143        FastCandidate {
144            symbols: &[Symbol::Six],
145        },
146        FastCandidate {
147            symbols: &[Symbol::Seven],
148        },
149        FastCandidate {
150            symbols: &[Symbol::Eight],
151        },
152        FastCandidate {
153            symbols: &[Symbol::Nine],
154        },
155        // Named constants
156        FastCandidate {
157            symbols: &[Symbol::Pi],
158        },
159        FastCandidate {
160            symbols: &[Symbol::E],
161        },
162        FastCandidate {
163            symbols: &[Symbol::Phi],
164        },
165        FastCandidate {
166            symbols: &[Symbol::Gamma],
167        },
168        FastCandidate {
169            symbols: &[Symbol::Plastic],
170        },
171        FastCandidate {
172            symbols: &[Symbol::Apery],
173        },
174        FastCandidate {
175            symbols: &[Symbol::Catalan],
176        },
177        // Simple rationals (common ones)
178        FastCandidate {
179            symbols: &[Symbol::One, Symbol::Two, Symbol::Div],
180        },
181        FastCandidate {
182            symbols: &[Symbol::One, Symbol::Three, Symbol::Div],
183        },
184        FastCandidate {
185            symbols: &[Symbol::Two, Symbol::Three, Symbol::Div],
186        },
187        FastCandidate {
188            symbols: &[Symbol::One, Symbol::Four, Symbol::Div],
189        },
190        FastCandidate {
191            symbols: &[Symbol::Three, Symbol::Four, Symbol::Div],
192        },
193        // Simple roots
194        FastCandidate {
195            symbols: &[Symbol::Two, Symbol::Sqrt],
196        },
197        FastCandidate {
198            symbols: &[Symbol::Three, Symbol::Sqrt],
199        },
200        FastCandidate {
201            symbols: &[Symbol::Five, Symbol::Sqrt],
202        },
203        FastCandidate {
204            symbols: &[Symbol::Six, Symbol::Sqrt],
205        },
206        FastCandidate {
207            symbols: &[Symbol::Seven, Symbol::Sqrt],
208        },
209        FastCandidate {
210            symbols: &[Symbol::Eight, Symbol::Sqrt],
211        },
212        FastCandidate {
213            symbols: &[Symbol::Pi, Symbol::Sqrt],
214        },
215        FastCandidate {
216            symbols: &[Symbol::E, Symbol::Sqrt],
217        },
218        // Simple logs
219        FastCandidate {
220            symbols: &[Symbol::Two, Symbol::Ln],
221        },
222        FastCandidate {
223            symbols: &[Symbol::Pi, Symbol::Ln],
224        },
225        // e ± small integers
226        FastCandidate {
227            symbols: &[Symbol::E, Symbol::One, Symbol::Sub],
228        },
229        FastCandidate {
230            symbols: &[Symbol::E, Symbol::One, Symbol::Add],
231        },
232        // pi ± small integers
233        FastCandidate {
234            symbols: &[Symbol::Pi, Symbol::One, Symbol::Sub],
235        },
236        FastCandidate {
237            symbols: &[Symbol::Pi, Symbol::One, Symbol::Add],
238        },
239        FastCandidate {
240            symbols: &[Symbol::Pi, Symbol::Two, Symbol::Sub],
241        },
242        // Common combinations
243        FastCandidate {
244            symbols: &[Symbol::One, Symbol::Two, Symbol::Add],
245        },
246        FastCandidate {
247            symbols: &[Symbol::One, Symbol::Sqrt, Symbol::One, Symbol::Add],
248        },
249        FastCandidate {
250            symbols: &[Symbol::Two, Symbol::Sqrt, Symbol::One, Symbol::Add],
251        },
252        // phi combinations (golden ratio)
253        FastCandidate {
254            symbols: &[Symbol::Phi, Symbol::One, Symbol::Add],
255        },
256        FastCandidate {
257            symbols: &[Symbol::Phi, Symbol::Two, Symbol::Add],
258        },
259        FastCandidate {
260            symbols: &[Symbol::Phi, Symbol::Square],
261        },
262        // Reciprocals of constants
263        FastCandidate {
264            symbols: &[Symbol::Pi, Symbol::Recip],
265        },
266        FastCandidate {
267            symbols: &[Symbol::E, Symbol::Recip],
268        },
269        FastCandidate {
270            symbols: &[Symbol::Phi, Symbol::Recip],
271        },
272    ]
273}
274
275/// Check if target matches a simple integer
276fn check_integer(target: f64) -> Option<(i64, f64)> {
277    let rounded = target.round();
278    let error = (target - rounded).abs();
279    if error < EXACT_TOLERANCE && rounded.abs() < 1000.0 {
280        Some((rounded as i64, error))
281    } else {
282        None
283    }
284}
285
286/// Configuration for fast match filtering
287pub struct FastMatchConfig<'a> {
288    /// Symbols that are excluded (via -N flag)
289    pub excluded_symbols: &'a HashSet<u8>,
290    /// Symbols that are explicitly allowed (all symbols must be in set)
291    pub allowed_symbols: Option<&'a HashSet<u8>>,
292    /// Minimum numeric type required (via -a, -r, -i flags)
293    pub min_num_type: NumType,
294}
295
296#[inline]
297fn passes_symbol_filters(symbols: &[Symbol], config: &FastMatchConfig<'_>) -> bool {
298    if contains_excluded(symbols, config.excluded_symbols) {
299        return false;
300    }
301    if let Some(allowed) = config.allowed_symbols {
302        if symbols.iter().any(|s| !allowed.contains(&(*s as u8))) {
303            return false;
304        }
305    }
306    true
307}
308
309/// Try to find a fast exact match for the target value
310///
311/// Returns a Match if found, or None if no simple exact match exists.
312/// This function is designed to be called before expensive generation.
313pub fn find_fast_match(
314    target: f64,
315    user_constants: &[UserConstant],
316    config: &FastMatchConfig<'_>,
317    table: &SymbolTable,
318) -> Option<Match> {
319    let context = EvalContext::from_slices(user_constants, &[]);
320    find_fast_match_with_context(target, &context, config, table)
321}
322
323/// Try to find a fast exact match for the target value using an explicit evaluation context.
324pub fn find_fast_match_with_context(
325    target: f64,
326    context: &EvalContext<'_>,
327    config: &FastMatchConfig<'_>,
328    table: &SymbolTable,
329) -> Option<Match> {
330    // First check integers (fastest)
331    if let Some((n, error)) = check_integer(target) {
332        if (1..=9).contains(&n) {
333            // We have a direct constant for 1-9
334            let symbols: &[Symbol] = match n {
335                1 => &[Symbol::One],
336                2 => &[Symbol::Two],
337                3 => &[Symbol::Three],
338                4 => &[Symbol::Four],
339                5 => &[Symbol::Five],
340                6 => &[Symbol::Six],
341                7 => &[Symbol::Seven],
342                8 => &[Symbol::Eight],
343                9 => &[Symbol::Nine],
344                _ => return None,
345            };
346            // Check if excluded or wrong type
347            if passes_symbol_filters(symbols, config)
348                && get_num_type(symbols) >= config.min_num_type
349            {
350                if let Some(m) = make_match(symbols, target, error, table, context) {
351                    return Some(m);
352                }
353            }
354        }
355        // For other integers, check if they match user constants
356        for (idx, uc) in context.user_constants.iter().enumerate() {
357            if idx < 16 && (uc.value - target).abs() < EXACT_TOLERANCE {
358                if let Some(sym) = Symbol::from_byte(128 + idx as u8) {
359                    let symbols = [sym];
360                    if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type
361                    {
362                        if let Some(m) =
363                            make_match(&symbols, target, (uc.value - target).abs(), table, context)
364                        {
365                            return Some(m);
366                        }
367                    }
368                }
369            }
370        }
371    }
372
373    // Check user constants first (they're explicitly defined)
374    for (idx, uc) in context.user_constants.iter().enumerate() {
375        if idx >= 16 {
376            break;
377        }
378        if (uc.value - target).abs() < EXACT_TOLERANCE {
379            if let Some(sym) = Symbol::from_byte(128 + idx as u8) {
380                let symbols = [sym];
381                if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type {
382                    if let Some(m) =
383                        make_match(&symbols, target, (uc.value - target).abs(), table, context)
384                    {
385                        return Some(m);
386                    }
387                }
388            }
389        }
390    }
391
392    // Check known constant candidates
393    let candidates = get_constant_candidates();
394    for candidate in candidates {
395        // Skip if contains excluded symbols
396        if !passes_symbol_filters(candidate.symbols, config) {
397            continue;
398        }
399        // Skip if type doesn't meet requirement
400        if get_num_type(candidate.symbols) < config.min_num_type {
401            continue;
402        }
403
404        let expr = expr_from_symbols_with_table(candidate.symbols, table);
405        if let Ok(result) = evaluate_with_context(&expr, target, context) {
406            let error = (result.value - target).abs();
407            if error < EXACT_TOLERANCE {
408                if let Some(m) = make_match(candidate.symbols, target, error, table, context) {
409                    return Some(m);
410                }
411            }
412        }
413    }
414
415    // Check user constants with simple operations
416    for (idx, uc) in context.user_constants.iter().enumerate() {
417        if idx >= 16 {
418            break;
419        }
420        if let Some(sym) = Symbol::from_byte(128 + idx as u8) {
421            // Check 1/constant
422            if uc.value != 0.0 {
423                let recip_val = 1.0 / uc.value;
424                if (recip_val - target).abs() < EXACT_TOLERANCE {
425                    let symbols = [sym, Symbol::Recip];
426                    if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type
427                    {
428                        if let Some(m) =
429                            make_match(&symbols, target, (recip_val - target).abs(), table, context)
430                        {
431                            return Some(m);
432                        }
433                    }
434                }
435            }
436            // Check sqrt(constant)
437            if uc.value > 0.0 {
438                let sqrt_val = uc.value.sqrt();
439                if (sqrt_val - target).abs() < EXACT_TOLERANCE {
440                    let symbols = [sym, Symbol::Sqrt];
441                    if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type
442                    {
443                        if let Some(m) =
444                            make_match(&symbols, target, (sqrt_val - target).abs(), table, context)
445                        {
446                            return Some(m);
447                        }
448                    }
449                }
450            }
451        }
452    }
453
454    None
455}
456
457/// Create a Match from symbols representing the RHS value
458fn make_match(
459    symbols: &[Symbol],
460    target: f64,
461    error: f64,
462    table: &SymbolTable,
463    context: &EvalContext<'_>,
464) -> Option<Match> {
465    let lhs_expr = expr_from_symbols_with_table(&[Symbol::X], table);
466    let rhs_expr = expr_from_symbols_with_table(symbols, table);
467    let complexity = lhs_expr.complexity() + rhs_expr.complexity();
468
469    let lhs_eval = evaluate_with_context(&lhs_expr, target, context).ok()?;
470    let rhs_eval = evaluate_with_context(&rhs_expr, target, context).ok()?;
471
472    Some(Match {
473        lhs: EvaluatedExpr {
474            expr: lhs_expr,
475            value: lhs_eval.value,
476            derivative: lhs_eval.derivative,
477            num_type: NumType::Transcendental,
478        },
479        rhs: EvaluatedExpr {
480            expr: rhs_expr,
481            value: rhs_eval.value,
482            derivative: 0.0,
483            num_type: rhs_eval.num_type,
484        },
485        x_value: target,
486        error,
487        complexity,
488    })
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty exclusion set shared by the test configs (configs only borrow it).
    fn no_exclusions() -> &'static HashSet<u8> {
        static EMPTY: std::sync::OnceLock<HashSet<u8>> = std::sync::OnceLock::new();
        EMPTY.get_or_init(HashSet::new)
    }

    /// Config with no symbol filters that requires at least `min_num_type`.
    fn config_with_min_type(min_num_type: NumType) -> FastMatchConfig<'static> {
        FastMatchConfig {
            excluded_symbols: no_exclusions(),
            allowed_symbols: None,
            min_num_type,
        }
    }

    /// Config with no symbol filters that accepts every numeric type.
    fn default_config() -> FastMatchConfig<'static> {
        config_with_min_type(NumType::Transcendental)
    }

    fn default_table() -> SymbolTable {
        SymbolTable::new()
    }

    /// Fast-match `target` with no user constants and the default table.
    fn run_fast_match(target: f64, config: &FastMatchConfig<'_>) -> Option<Match> {
        find_fast_match(target, &[], config, &default_table())
    }

    /// Assert that `target` matches exactly and its RHS prints as `postfix`.
    fn assert_exact_rhs(target: f64, postfix: &str) {
        let found = run_fast_match(target, &default_config()).unwrap();
        assert!(found.error.abs() < 1e-14);
        assert_eq!(found.rhs.expr.to_postfix(), postfix);
    }

    #[test]
    fn test_pi_match() {
        assert_exact_rhs(std::f64::consts::PI, "p");
    }

    #[test]
    fn test_pi_excluded() {
        let excluded: HashSet<u8> = [b'p'].iter().copied().collect();
        let config = FastMatchConfig {
            excluded_symbols: &excluded,
            allowed_symbols: None,
            min_num_type: NumType::Transcendental,
        };
        assert!(
            run_fast_match(std::f64::consts::PI, &config).is_none(),
            "Should not find pi when it's excluded"
        );
    }

    #[test]
    fn test_pi_algebraic_only() {
        let config = config_with_min_type(NumType::Algebraic);
        assert!(
            run_fast_match(std::f64::consts::PI, &config).is_none(),
            "Should not find pi when only algebraic allowed"
        );
    }

    #[test]
    fn test_sqrt2_algebraic_ok() {
        let config = config_with_min_type(NumType::Algebraic);
        assert!(
            run_fast_match(2.0_f64.sqrt(), &config).is_some(),
            "sqrt(2) should be found with algebraic-only"
        );
    }

    #[test]
    fn test_e_match() {
        assert_exact_rhs(std::f64::consts::E, "e");
    }

    #[test]
    fn test_sqrt2_match() {
        assert_exact_rhs(2.0_f64.sqrt(), "2q");
    }

    #[test]
    fn test_phi_match() {
        let phi = (1.0 + 5.0_f64.sqrt()) / 2.0;
        assert_exact_rhs(phi, "f");
    }

    #[test]
    fn test_integer_match() {
        assert_exact_rhs(5.0, "5");
    }

    #[test]
    fn test_no_match_for_random() {
        // 2.506314 is not a simple constant
        assert!(run_fast_match(2.506314, &default_config()).is_none());
    }

    #[test]
    fn test_user_constant_match() {
        let uc = UserConstant {
            weight: 4,
            name: "myconst".to_string(),
            description: "Test constant".to_string(),
            value: std::f64::consts::E,
            num_type: NumType::Transcendental,
        };
        let found = find_fast_match(
            std::f64::consts::E,
            &[uc],
            &default_config(),
            &default_table(),
        );
        assert!(found.is_some());
    }
}