1use crate::eval::{evaluate_with_context, EvalContext};
8use crate::expr::{EvaluatedExpr, Expression};
9use crate::profile::UserConstant;
10use crate::search::Match;
11use crate::symbol::{NumType, Symbol};
12use crate::symbol_table::SymbolTable;
13use std::collections::HashSet;
14
/// Absolute difference below which a candidate value is treated as exactly
/// equal to the target in every fast-match stage.
const EXACT_TOLERANCE: f64 = 1e-14;
17
18fn expr_from_symbols_with_table(symbols: &[Symbol], table: &SymbolTable) -> Expression {
20 let mut expr = Expression::new();
21 for &sym in symbols {
22 expr.push_with_table(sym, table);
23 }
24 expr
25}
26
27fn get_num_type(symbols: &[Symbol]) -> NumType {
30 use Symbol::*;
38
39 if symbols.len() == 1 {
41 return symbols[0].inherent_type();
42 }
43
44 if symbols.len() == 2 {
46 if matches!(symbols[1], Sqrt) {
47 if matches!(
48 symbols[0],
49 One | Two | Three | Four | Five | Six | Seven | Eight | Nine
50 ) {
51 return NumType::Algebraic; }
53 if matches!(symbols[0], Pi | E | Gamma | Apery | Catalan) {
55 return NumType::Transcendental;
56 }
57 }
58 if matches!(symbols[1], Recip)
60 && matches!(
61 symbols[0],
62 One | Two | Three | Four | Five | Six | Seven | Eight | Nine
63 )
64 {
65 return NumType::Rational;
66 }
67 if matches!(symbols[1], Div)
69 && matches!(
70 symbols[0],
71 One | Two | Three | Four | Five | Six | Seven | Eight | Nine
72 )
73 && symbols.len() >= 3
74 {
75 return NumType::Rational;
77 }
78 }
79
80 if symbols.len() == 3 && matches!(symbols[2], Div) {
82 if matches!(
84 symbols[0],
85 One | Two | Three | Four | Five | Six | Seven | Eight | Nine
86 ) && matches!(
87 symbols[1],
88 One | Two | Three | Four | Five | Six | Seven | Eight | Nine
89 ) {
90 return NumType::Rational;
91 }
92 }
93
94 for &sym in symbols {
96 let sym_type = sym.inherent_type();
97 if sym_type == NumType::Transcendental {
98 return NumType::Transcendental;
99 }
100 }
101
102 for &sym in symbols {
104 if matches!(sym, Phi | Plastic) {
105 return NumType::Algebraic;
106 }
107 }
108
109 NumType::Transcendental
111}
112
113fn contains_excluded(symbols: &[Symbol], excluded: &HashSet<u8>) -> bool {
115 symbols.iter().any(|s| excluded.contains(&(*s as u8)))
116}
117
/// A constant expression that the fast matcher compares against the target.
struct FastCandidate {
    /// The expression as a postfix symbol sequence (e.g. `[Two, Sqrt]` is `2q`).
    symbols: &'static [Symbol],
}
123
124fn get_constant_candidates() -> Vec<FastCandidate> {
126 vec![
127 FastCandidate {
129 symbols: &[Symbol::One],
130 },
131 FastCandidate {
132 symbols: &[Symbol::Two],
133 },
134 FastCandidate {
135 symbols: &[Symbol::Three],
136 },
137 FastCandidate {
138 symbols: &[Symbol::Four],
139 },
140 FastCandidate {
141 symbols: &[Symbol::Five],
142 },
143 FastCandidate {
144 symbols: &[Symbol::Six],
145 },
146 FastCandidate {
147 symbols: &[Symbol::Seven],
148 },
149 FastCandidate {
150 symbols: &[Symbol::Eight],
151 },
152 FastCandidate {
153 symbols: &[Symbol::Nine],
154 },
155 FastCandidate {
157 symbols: &[Symbol::Pi],
158 },
159 FastCandidate {
160 symbols: &[Symbol::E],
161 },
162 FastCandidate {
163 symbols: &[Symbol::Phi],
164 },
165 FastCandidate {
166 symbols: &[Symbol::Gamma],
167 },
168 FastCandidate {
169 symbols: &[Symbol::Plastic],
170 },
171 FastCandidate {
172 symbols: &[Symbol::Apery],
173 },
174 FastCandidate {
175 symbols: &[Symbol::Catalan],
176 },
177 FastCandidate {
179 symbols: &[Symbol::One, Symbol::Two, Symbol::Div],
180 },
181 FastCandidate {
182 symbols: &[Symbol::One, Symbol::Three, Symbol::Div],
183 },
184 FastCandidate {
185 symbols: &[Symbol::Two, Symbol::Three, Symbol::Div],
186 },
187 FastCandidate {
188 symbols: &[Symbol::One, Symbol::Four, Symbol::Div],
189 },
190 FastCandidate {
191 symbols: &[Symbol::Three, Symbol::Four, Symbol::Div],
192 },
193 FastCandidate {
195 symbols: &[Symbol::Two, Symbol::Sqrt],
196 },
197 FastCandidate {
198 symbols: &[Symbol::Three, Symbol::Sqrt],
199 },
200 FastCandidate {
201 symbols: &[Symbol::Five, Symbol::Sqrt],
202 },
203 FastCandidate {
204 symbols: &[Symbol::Six, Symbol::Sqrt],
205 },
206 FastCandidate {
207 symbols: &[Symbol::Seven, Symbol::Sqrt],
208 },
209 FastCandidate {
210 symbols: &[Symbol::Eight, Symbol::Sqrt],
211 },
212 FastCandidate {
213 symbols: &[Symbol::Pi, Symbol::Sqrt],
214 },
215 FastCandidate {
216 symbols: &[Symbol::E, Symbol::Sqrt],
217 },
218 FastCandidate {
220 symbols: &[Symbol::Two, Symbol::Ln],
221 },
222 FastCandidate {
223 symbols: &[Symbol::Pi, Symbol::Ln],
224 },
225 FastCandidate {
227 symbols: &[Symbol::E, Symbol::One, Symbol::Sub],
228 },
229 FastCandidate {
230 symbols: &[Symbol::E, Symbol::One, Symbol::Add],
231 },
232 FastCandidate {
234 symbols: &[Symbol::Pi, Symbol::One, Symbol::Sub],
235 },
236 FastCandidate {
237 symbols: &[Symbol::Pi, Symbol::One, Symbol::Add],
238 },
239 FastCandidate {
240 symbols: &[Symbol::Pi, Symbol::Two, Symbol::Sub],
241 },
242 FastCandidate {
244 symbols: &[Symbol::One, Symbol::Two, Symbol::Add],
245 },
246 FastCandidate {
247 symbols: &[Symbol::One, Symbol::Sqrt, Symbol::One, Symbol::Add],
248 },
249 FastCandidate {
250 symbols: &[Symbol::Two, Symbol::Sqrt, Symbol::One, Symbol::Add],
251 },
252 FastCandidate {
254 symbols: &[Symbol::Phi, Symbol::One, Symbol::Add],
255 },
256 FastCandidate {
257 symbols: &[Symbol::Phi, Symbol::Two, Symbol::Add],
258 },
259 FastCandidate {
260 symbols: &[Symbol::Phi, Symbol::Square],
261 },
262 FastCandidate {
264 symbols: &[Symbol::Pi, Symbol::Recip],
265 },
266 FastCandidate {
267 symbols: &[Symbol::E, Symbol::Recip],
268 },
269 FastCandidate {
270 symbols: &[Symbol::Phi, Symbol::Recip],
271 },
272 ]
273}
274
275fn check_integer(target: f64) -> Option<(i64, f64)> {
277 let rounded = target.round();
278 let error = (target - rounded).abs();
279 if error < EXACT_TOLERANCE && rounded.abs() < 1000.0 {
280 Some((rounded as i64, error))
281 } else {
282 None
283 }
284}
285
/// Filters applied to every expression the fast matcher considers.
pub struct FastMatchConfig<'a> {
    /// Symbol byte codes (`Symbol as u8`) that must not appear in a match.
    pub excluded_symbols: &'a HashSet<u8>,
    /// When `Some`, every symbol of a match must have its byte code in this set;
    /// `None` places no restriction.
    pub allowed_symbols: Option<&'a HashSet<u8>>,
    /// Expressions whose number type compares below this value are skipped.
    /// The tests use `NumType::Transcendental` as the permissive setting.
    pub min_num_type: NumType,
}
295
296#[inline]
297fn passes_symbol_filters(symbols: &[Symbol], config: &FastMatchConfig<'_>) -> bool {
298 if contains_excluded(symbols, config.excluded_symbols) {
299 return false;
300 }
301 if let Some(allowed) = config.allowed_symbols {
302 if symbols.iter().any(|s| !allowed.contains(&(*s as u8))) {
303 return false;
304 }
305 }
306 true
307}
308
309pub fn find_fast_match(
314 target: f64,
315 user_constants: &[UserConstant],
316 config: &FastMatchConfig<'_>,
317 table: &SymbolTable,
318) -> Option<Match> {
319 let context = EvalContext::from_slices(user_constants, &[]);
320 find_fast_match_with_context(target, &context, config, table)
321}
322
323pub fn find_fast_match_with_context(
325 target: f64,
326 context: &EvalContext<'_>,
327 config: &FastMatchConfig<'_>,
328 table: &SymbolTable,
329) -> Option<Match> {
330 if let Some((n, error)) = check_integer(target) {
332 if (1..=9).contains(&n) {
333 let symbols: &[Symbol] = match n {
335 1 => &[Symbol::One],
336 2 => &[Symbol::Two],
337 3 => &[Symbol::Three],
338 4 => &[Symbol::Four],
339 5 => &[Symbol::Five],
340 6 => &[Symbol::Six],
341 7 => &[Symbol::Seven],
342 8 => &[Symbol::Eight],
343 9 => &[Symbol::Nine],
344 _ => return None,
345 };
346 if passes_symbol_filters(symbols, config)
348 && get_num_type(symbols) >= config.min_num_type
349 {
350 if let Some(m) = make_match(symbols, target, error, table, context) {
351 return Some(m);
352 }
353 }
354 }
355 for (idx, uc) in context.user_constants.iter().enumerate() {
357 if idx < 16 && (uc.value - target).abs() < EXACT_TOLERANCE {
358 if let Some(sym) = Symbol::from_byte(128 + idx as u8) {
359 let symbols = [sym];
360 if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type
361 {
362 if let Some(m) =
363 make_match(&symbols, target, (uc.value - target).abs(), table, context)
364 {
365 return Some(m);
366 }
367 }
368 }
369 }
370 }
371 }
372
373 for (idx, uc) in context.user_constants.iter().enumerate() {
375 if idx >= 16 {
376 break;
377 }
378 if (uc.value - target).abs() < EXACT_TOLERANCE {
379 if let Some(sym) = Symbol::from_byte(128 + idx as u8) {
380 let symbols = [sym];
381 if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type {
382 if let Some(m) =
383 make_match(&symbols, target, (uc.value - target).abs(), table, context)
384 {
385 return Some(m);
386 }
387 }
388 }
389 }
390 }
391
392 let candidates = get_constant_candidates();
394 for candidate in candidates {
395 if !passes_symbol_filters(candidate.symbols, config) {
397 continue;
398 }
399 if get_num_type(candidate.symbols) < config.min_num_type {
401 continue;
402 }
403
404 let expr = expr_from_symbols_with_table(candidate.symbols, table);
405 if let Ok(result) = evaluate_with_context(&expr, target, context) {
406 let error = (result.value - target).abs();
407 if error < EXACT_TOLERANCE {
408 if let Some(m) = make_match(candidate.symbols, target, error, table, context) {
409 return Some(m);
410 }
411 }
412 }
413 }
414
415 for (idx, uc) in context.user_constants.iter().enumerate() {
417 if idx >= 16 {
418 break;
419 }
420 if let Some(sym) = Symbol::from_byte(128 + idx as u8) {
421 if uc.value != 0.0 {
423 let recip_val = 1.0 / uc.value;
424 if (recip_val - target).abs() < EXACT_TOLERANCE {
425 let symbols = [sym, Symbol::Recip];
426 if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type
427 {
428 if let Some(m) =
429 make_match(&symbols, target, (recip_val - target).abs(), table, context)
430 {
431 return Some(m);
432 }
433 }
434 }
435 }
436 if uc.value > 0.0 {
438 let sqrt_val = uc.value.sqrt();
439 if (sqrt_val - target).abs() < EXACT_TOLERANCE {
440 let symbols = [sym, Symbol::Sqrt];
441 if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type
442 {
443 if let Some(m) =
444 make_match(&symbols, target, (sqrt_val - target).abs(), table, context)
445 {
446 return Some(m);
447 }
448 }
449 }
450 }
451 }
452 }
453
454 None
455}
456
457fn make_match(
459 symbols: &[Symbol],
460 target: f64,
461 error: f64,
462 table: &SymbolTable,
463 context: &EvalContext<'_>,
464) -> Option<Match> {
465 let lhs_expr = expr_from_symbols_with_table(&[Symbol::X], table);
466 let rhs_expr = expr_from_symbols_with_table(symbols, table);
467 let complexity = lhs_expr.complexity() + rhs_expr.complexity();
468
469 let lhs_eval = evaluate_with_context(&lhs_expr, target, context).ok()?;
470 let rhs_eval = evaluate_with_context(&rhs_expr, target, context).ok()?;
471
472 Some(Match {
473 lhs: EvaluatedExpr {
474 expr: lhs_expr,
475 value: lhs_eval.value,
476 derivative: lhs_eval.derivative,
477 num_type: NumType::Transcendental,
478 },
479 rhs: EvaluatedExpr {
480 expr: rhs_expr,
481 value: rhs_eval.value,
482 derivative: 0.0,
483 num_type: rhs_eval.num_type,
484 },
485 x_value: target,
486 error,
487 complexity,
488 })
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// One shared empty exclusion set, so configs can borrow it for `'static`.
    fn no_exclusions() -> &'static HashSet<u8> {
        static EMPTY: std::sync::OnceLock<HashSet<u8>> = std::sync::OnceLock::new();
        EMPTY.get_or_init(HashSet::new)
    }

    /// A config with no symbol restrictions and the given minimum number type.
    fn config_with_min(min_num_type: NumType) -> FastMatchConfig<'static> {
        FastMatchConfig {
            excluded_symbols: no_exclusions(),
            allowed_symbols: None,
            min_num_type,
        }
    }

    /// The permissive configuration used by most tests.
    fn default_config() -> FastMatchConfig<'static> {
        config_with_min(NumType::Transcendental)
    }

    fn default_table() -> SymbolTable {
        SymbolTable::new()
    }

    /// Runs the fast matcher against the default symbol table.
    fn run(
        target: f64,
        user_constants: &[UserConstant],
        config: &FastMatchConfig<'_>,
    ) -> Option<Match> {
        find_fast_match(target, user_constants, config, &default_table())
    }

    /// Asserts that `target` is matched exactly by the postfix expression `expected`.
    fn assert_exact_match(target: f64, expected: &str) {
        let found = run(target, &[], &default_config());
        assert!(found.is_some());
        let found = found.unwrap();
        assert!(found.error.abs() < 1e-14);
        assert_eq!(found.rhs.expr.to_postfix(), expected);
    }

    #[test]
    fn test_pi_match() {
        assert_exact_match(std::f64::consts::PI, "p");
    }

    #[test]
    fn test_pi_excluded() {
        let excluded: HashSet<u8> = HashSet::from([b'p']);
        let config = FastMatchConfig {
            excluded_symbols: &excluded,
            allowed_symbols: None,
            min_num_type: NumType::Transcendental,
        };
        let found = run(std::f64::consts::PI, &[], &config);
        assert!(found.is_none(), "Should not find pi when it's excluded");
    }

    #[test]
    fn test_pi_algebraic_only() {
        let found = run(
            std::f64::consts::PI,
            &[],
            &config_with_min(NumType::Algebraic),
        );
        assert!(
            found.is_none(),
            "Should not find pi when only algebraic allowed"
        );
    }

    #[test]
    fn test_sqrt2_algebraic_ok() {
        let found = run(2.0_f64.sqrt(), &[], &config_with_min(NumType::Algebraic));
        assert!(found.is_some(), "sqrt(2) should be found with algebraic-only");
    }

    #[test]
    fn test_e_match() {
        assert_exact_match(std::f64::consts::E, "e");
    }

    #[test]
    fn test_sqrt2_match() {
        assert_exact_match(2.0_f64.sqrt(), "2q");
    }

    #[test]
    fn test_phi_match() {
        let phi = (1.0 + 5.0_f64.sqrt()) / 2.0;
        assert_exact_match(phi, "f");
    }

    #[test]
    fn test_integer_match() {
        assert_exact_match(5.0, "5");
    }

    #[test]
    fn test_no_match_for_random() {
        assert!(run(2.506314, &[], &default_config()).is_none());
    }

    #[test]
    fn test_user_constant_match() {
        let uc = UserConstant {
            weight: 4,
            name: "myconst".to_string(),
            description: "Test constant".to_string(),
            value: std::f64::consts::E,
            num_type: NumType::Transcendental,
        };
        assert!(run(std::f64::consts::E, &[uc], &default_config()).is_some());
    }
}