use crate::eval::{evaluate_with_context, EvalContext};
use crate::expr::{EvaluatedExpr, Expression};
use crate::profile::UserConstant;
use crate::search::Match;
use crate::symbol::{NumType, Symbol};
use crate::symbol_table::SymbolTable;
use std::collections::HashSet;
/// Absolute error below which a computed value counts as an exact match for the target.
const EXACT_TOLERANCE: f64 = 1.0e-14;
/// Builds an [`Expression`] by pushing every symbol of `symbols`, in order,
/// through `push_with_table` using `table`.
fn expr_from_symbols_with_table(symbols: &[Symbol], table: &SymbolTable) -> Expression {
    symbols.iter().fold(Expression::new(), |mut built, &sym| {
        built.push_with_table(sym, table);
        built
    })
}
/// Classifies the number produced by the postfix symbol sequence `symbols`.
///
/// The classification is conservative: when a shape is not recognised,
/// [`NumType::Transcendental`] is returned. A `min_num_type` filter may then
/// reject a valid expression, but it never admits one it should not.
///
/// Recognised shapes:
/// * a single symbol: its own `inherent_type()`;
/// * `d sqrt` for a digit `d`: algebraic;
/// * `c sqrt` for pi, e, gamma, Apery or Catalan: transcendental;
/// * `d recip` and `d1 d2 div` for digits: rational;
/// * otherwise, any symbol whose inherent type is transcendental makes the
///   whole expression transcendental. An expression built only from digits,
///   phi, the plastic number, `Add`, `Sub`, `Div`, `Recip`, `Sqrt` and
///   `Square` is algebraic, because algebraic numbers are closed under those
///   operations.
fn get_num_type(symbols: &[Symbol]) -> NumType {
    use Symbol::*;

    let is_digit =
        |sym: Symbol| matches!(sym, One | Two | Three | Four | Five | Six | Seven | Eight | Nine);
    // Operands/operators that keep a result algebraic when all inputs are algebraic.
    let preserves_algebraic = |sym: Symbol| {
        is_digit(sym) || matches!(sym, Phi | Plastic | Add | Sub | Div | Recip | Sqrt | Square)
    };

    match symbols {
        // An empty sequence is not a number; report the weakest class.
        [] => NumType::Transcendental,
        [only] => only.inherent_type(),
        [d, Sqrt] if is_digit(*d) => NumType::Algebraic,
        [Pi | E | Gamma | Apery | Catalan, Sqrt] => NumType::Transcendental,
        [d, Recip] if is_digit(*d) => NumType::Rational,
        [num, den, Div] if is_digit(*num) && is_digit(*den) => NumType::Rational,
        _ if symbols
            .iter()
            .any(|sym| sym.inherent_type() == NumType::Transcendental) =>
        {
            NumType::Transcendental
        }
        // Previously only "contains Phi/Plastic" was recognised here, which
        // both missed forms like `2 sqrt 1 +` and over-accepted e.g. `Phi ... Ln`.
        _ if symbols.iter().all(|&sym| preserves_algebraic(sym)) => NumType::Algebraic,
        _ => NumType::Transcendental,
    }
}
/// Returns `true` if the byte code (`sym as u8`) of any symbol in `symbols`
/// is present in `excluded`.
fn contains_excluded(symbols: &[Symbol], excluded: &HashSet<u8>) -> bool {
    symbols
        .iter()
        .map(|&sym| sym as u8)
        .any(|code| excluded.contains(&code))
}
struct FastCandidate {
symbols: &'static [Symbol],
}
fn get_constant_candidates() -> Vec<FastCandidate> {
vec![
FastCandidate {
symbols: &[Symbol::One],
},
FastCandidate {
symbols: &[Symbol::Two],
},
FastCandidate {
symbols: &[Symbol::Three],
},
FastCandidate {
symbols: &[Symbol::Four],
},
FastCandidate {
symbols: &[Symbol::Five],
},
FastCandidate {
symbols: &[Symbol::Six],
},
FastCandidate {
symbols: &[Symbol::Seven],
},
FastCandidate {
symbols: &[Symbol::Eight],
},
FastCandidate {
symbols: &[Symbol::Nine],
},
FastCandidate {
symbols: &[Symbol::Pi],
},
FastCandidate {
symbols: &[Symbol::E],
},
FastCandidate {
symbols: &[Symbol::Phi],
},
FastCandidate {
symbols: &[Symbol::Gamma],
},
FastCandidate {
symbols: &[Symbol::Plastic],
},
FastCandidate {
symbols: &[Symbol::Apery],
},
FastCandidate {
symbols: &[Symbol::Catalan],
},
FastCandidate {
symbols: &[Symbol::One, Symbol::Two, Symbol::Div],
},
FastCandidate {
symbols: &[Symbol::One, Symbol::Three, Symbol::Div],
},
FastCandidate {
symbols: &[Symbol::Two, Symbol::Three, Symbol::Div],
},
FastCandidate {
symbols: &[Symbol::One, Symbol::Four, Symbol::Div],
},
FastCandidate {
symbols: &[Symbol::Three, Symbol::Four, Symbol::Div],
},
FastCandidate {
symbols: &[Symbol::Two, Symbol::Sqrt],
},
FastCandidate {
symbols: &[Symbol::Three, Symbol::Sqrt],
},
FastCandidate {
symbols: &[Symbol::Five, Symbol::Sqrt],
},
FastCandidate {
symbols: &[Symbol::Six, Symbol::Sqrt],
},
FastCandidate {
symbols: &[Symbol::Seven, Symbol::Sqrt],
},
FastCandidate {
symbols: &[Symbol::Eight, Symbol::Sqrt],
},
FastCandidate {
symbols: &[Symbol::Pi, Symbol::Sqrt],
},
FastCandidate {
symbols: &[Symbol::E, Symbol::Sqrt],
},
FastCandidate {
symbols: &[Symbol::Two, Symbol::Ln],
},
FastCandidate {
symbols: &[Symbol::Pi, Symbol::Ln],
},
FastCandidate {
symbols: &[Symbol::E, Symbol::One, Symbol::Sub],
},
FastCandidate {
symbols: &[Symbol::E, Symbol::One, Symbol::Add],
},
FastCandidate {
symbols: &[Symbol::Pi, Symbol::One, Symbol::Sub],
},
FastCandidate {
symbols: &[Symbol::Pi, Symbol::One, Symbol::Add],
},
FastCandidate {
symbols: &[Symbol::Pi, Symbol::Two, Symbol::Sub],
},
FastCandidate {
symbols: &[Symbol::One, Symbol::Two, Symbol::Add],
},
FastCandidate {
symbols: &[Symbol::One, Symbol::Sqrt, Symbol::One, Symbol::Add],
},
FastCandidate {
symbols: &[Symbol::Two, Symbol::Sqrt, Symbol::One, Symbol::Add],
},
FastCandidate {
symbols: &[Symbol::Phi, Symbol::One, Symbol::Add],
},
FastCandidate {
symbols: &[Symbol::Phi, Symbol::Two, Symbol::Add],
},
FastCandidate {
symbols: &[Symbol::Phi, Symbol::Square],
},
FastCandidate {
symbols: &[Symbol::Pi, Symbol::Recip],
},
FastCandidate {
symbols: &[Symbol::E, Symbol::Recip],
},
FastCandidate {
symbols: &[Symbol::Phi, Symbol::Recip],
},
]
}
/// Returns `Some((n, error))` when `target` lies within [`EXACT_TOLERANCE`] of
/// an integer `n` with `|n| < 1000`, where `error` is `|target - n|`.
///
/// NaN and infinite inputs yield `None`: their rounding error is NaN, which
/// fails the tolerance comparison.
fn check_integer(target: f64) -> Option<(i64, f64)> {
    let nearest = target.round();
    let distance = (target - nearest).abs();
    let close_enough = distance < EXACT_TOLERANCE;
    let small_enough = nearest.abs() < 1000.0;
    // The cast is exact whenever the result is used: |nearest| < 1000.
    (close_enough && small_enough).then_some((nearest as i64, distance))
}
/// Filters applied to every expression the fast matcher considers.
pub struct FastMatchConfig<'a> {
    /// Symbol byte codes (`sym as u8`) that must not appear in a match.
    pub excluded_symbols: &'a HashSet<u8>,
    /// When `Some`, every symbol of a match must have its byte code in this set.
    pub allowed_symbols: Option<&'a HashSet<u8>>,
    /// Matches whose number type compares below this value are rejected.
    pub min_num_type: NumType,
}

/// Returns `true` when `symbols` contains no excluded symbol and, if an
/// allow-list is configured, consists only of allowed symbols.
#[inline]
fn passes_symbol_filters(symbols: &[Symbol], config: &FastMatchConfig<'_>) -> bool {
    if contains_excluded(symbols, config.excluded_symbols) {
        return false;
    }
    match config.allowed_symbols {
        None => true,
        Some(allowed) => symbols.iter().all(|&sym| allowed.contains(&(sym as u8))),
    }
}
/// Searches for an exact closed form of `target` using a fresh evaluation
/// context built from `user_constants` (the second slice passed to
/// `EvalContext::from_slices` is empty).
///
/// Delegates to `find_fast_match_with_context`; returns `None` when nothing
/// matches.
pub fn find_fast_match(
    target: f64,
    user_constants: &[UserConstant],
    config: &FastMatchConfig<'_>,
    table: &SymbolTable,
) -> Option<Match> {
    find_fast_match_with_context(
        target,
        &EvalContext::from_slices(user_constants, &[]),
        config,
        table,
    )
}
pub fn find_fast_match_with_context(
target: f64,
context: &EvalContext<'_>,
config: &FastMatchConfig<'_>,
table: &SymbolTable,
) -> Option<Match> {
if let Some((n, error)) = check_integer(target) {
if (1..=9).contains(&n) {
let symbols: &[Symbol] = match n {
1 => &[Symbol::One],
2 => &[Symbol::Two],
3 => &[Symbol::Three],
4 => &[Symbol::Four],
5 => &[Symbol::Five],
6 => &[Symbol::Six],
7 => &[Symbol::Seven],
8 => &[Symbol::Eight],
9 => &[Symbol::Nine],
_ => return None,
};
if passes_symbol_filters(symbols, config)
&& get_num_type(symbols) >= config.min_num_type
{
if let Some(m) = make_match(symbols, target, error, table, context) {
return Some(m);
}
}
}
for (idx, uc) in context.user_constants.iter().enumerate() {
if idx < 16 && (uc.value - target).abs() < EXACT_TOLERANCE {
if let Some(sym) = Symbol::from_byte(128 + idx as u8) {
let symbols = [sym];
if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type
{
if let Some(m) =
make_match(&symbols, target, (uc.value - target).abs(), table, context)
{
return Some(m);
}
}
}
}
}
}
for (idx, uc) in context.user_constants.iter().enumerate() {
if idx >= 16 {
break;
}
if (uc.value - target).abs() < EXACT_TOLERANCE {
if let Some(sym) = Symbol::from_byte(128 + idx as u8) {
let symbols = [sym];
if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type {
if let Some(m) =
make_match(&symbols, target, (uc.value - target).abs(), table, context)
{
return Some(m);
}
}
}
}
}
let candidates = get_constant_candidates();
for candidate in candidates {
if !passes_symbol_filters(candidate.symbols, config) {
continue;
}
if get_num_type(candidate.symbols) < config.min_num_type {
continue;
}
let expr = expr_from_symbols_with_table(candidate.symbols, table);
if let Ok(result) = evaluate_with_context(&expr, target, context) {
let error = (result.value - target).abs();
if error < EXACT_TOLERANCE {
if let Some(m) = make_match(candidate.symbols, target, error, table, context) {
return Some(m);
}
}
}
}
for (idx, uc) in context.user_constants.iter().enumerate() {
if idx >= 16 {
break;
}
if let Some(sym) = Symbol::from_byte(128 + idx as u8) {
if uc.value != 0.0 {
let recip_val = 1.0 / uc.value;
if (recip_val - target).abs() < EXACT_TOLERANCE {
let symbols = [sym, Symbol::Recip];
if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type
{
if let Some(m) =
make_match(&symbols, target, (recip_val - target).abs(), table, context)
{
return Some(m);
}
}
}
}
if uc.value > 0.0 {
let sqrt_val = uc.value.sqrt();
if (sqrt_val - target).abs() < EXACT_TOLERANCE {
let symbols = [sym, Symbol::Sqrt];
if passes_symbol_filters(&symbols, config) && uc.num_type >= config.min_num_type
{
if let Some(m) =
make_match(&symbols, target, (sqrt_val - target).abs(), table, context)
{
return Some(m);
}
}
}
}
}
}
None
}
/// Builds the match `x = <symbols>` for `target`.
///
/// Both sides are evaluated at `target` via `evaluate_with_context`; returns
/// `None` if either evaluation fails. The left side is the bare `X` symbol and
/// is tagged `NumType::Transcendental`. The right side's derivative is stored
/// as `0.0`; callers pass `X`-free symbol sequences. `error` is stored as given,
/// and `complexity` is the sum of both sides' complexities.
fn make_match(
    symbols: &[Symbol],
    target: f64,
    error: f64,
    table: &SymbolTable,
    context: &EvalContext<'_>,
) -> Option<Match> {
    let lhs_expr = expr_from_symbols_with_table(&[Symbol::X], table);
    let rhs_expr = expr_from_symbols_with_table(symbols, table);
    let complexity = lhs_expr.complexity() + rhs_expr.complexity();

    let lhs_result = evaluate_with_context(&lhs_expr, target, context).ok()?;
    let rhs_result = evaluate_with_context(&rhs_expr, target, context).ok()?;

    let lhs = EvaluatedExpr {
        expr: lhs_expr,
        value: lhs_result.value,
        derivative: lhs_result.derivative,
        num_type: NumType::Transcendental,
    };
    let rhs = EvaluatedExpr {
        expr: rhs_expr,
        value: rhs_result.value,
        derivative: 0.0,
        num_type: rhs_result.num_type,
    };

    Some(Match {
        lhs,
        rhs,
        x_value: target,
        error,
        complexity,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Process-wide empty symbol set, so configs can borrow it for `'static`.
    fn empty_set() -> &'static HashSet<u8> {
        static EMPTY: std::sync::OnceLock<HashSet<u8>> = std::sync::OnceLock::new();
        EMPTY.get_or_init(HashSet::new)
    }

    /// Config with no exclusions, no allow-list and the given minimum type.
    fn config_with_min(min_num_type: NumType) -> FastMatchConfig<'static> {
        FastMatchConfig {
            excluded_symbols: empty_set(),
            allowed_symbols: None,
            min_num_type,
        }
    }

    fn default_config() -> FastMatchConfig<'static> {
        config_with_min(NumType::Transcendental)
    }

    fn default_table() -> SymbolTable {
        SymbolTable::new()
    }

    /// Runs the fast matcher on `target` with no user constants.
    fn run(target: f64, config: &FastMatchConfig<'_>) -> Option<Match> {
        find_fast_match(target, &[], config, &default_table())
    }

    /// Asserts that `found` exists, is exact, and its RHS prints as `postfix`.
    fn assert_exact(found: Option<Match>, postfix: &str) {
        assert!(found.is_some());
        let found = found.unwrap();
        assert!(found.error.abs() < 1e-14);
        assert_eq!(found.rhs.expr.to_postfix(), postfix);
    }

    #[test]
    fn test_pi_match() {
        assert_exact(run(std::f64::consts::PI, &default_config()), "p");
    }

    #[test]
    fn test_pi_excluded() {
        let excluded: HashSet<u8> = HashSet::from([b'p']);
        let config = FastMatchConfig {
            excluded_symbols: &excluded,
            allowed_symbols: None,
            min_num_type: NumType::Transcendental,
        };
        let found = run(std::f64::consts::PI, &config);
        assert!(found.is_none(), "Should not find pi when it's excluded");
    }

    #[test]
    fn test_pi_algebraic_only() {
        let found = run(std::f64::consts::PI, &config_with_min(NumType::Algebraic));
        assert!(
            found.is_none(),
            "Should not find pi when only algebraic allowed"
        );
    }

    #[test]
    fn test_sqrt2_algebraic_ok() {
        let found = run(2.0_f64.sqrt(), &config_with_min(NumType::Algebraic));
        assert!(found.is_some(), "sqrt(2) should be found with algebraic-only");
    }

    #[test]
    fn test_e_match() {
        assert_exact(run(std::f64::consts::E, &default_config()), "e");
    }

    #[test]
    fn test_sqrt2_match() {
        assert_exact(run(2.0_f64.sqrt(), &default_config()), "2q");
    }

    #[test]
    fn test_phi_match() {
        let phi = (1.0 + 5.0_f64.sqrt()) / 2.0;
        assert_exact(run(phi, &default_config()), "f");
    }

    #[test]
    fn test_integer_match() {
        assert_exact(run(5.0, &default_config()), "5");
    }

    #[test]
    fn test_no_match_for_random() {
        assert!(run(2.506314, &default_config()).is_none());
    }

    #[test]
    fn test_user_constant_match() {
        let uc = UserConstant {
            weight: 4,
            name: "myconst".to_string(),
            description: "Test constant".to_string(),
            value: std::f64::consts::E,
            num_type: NumType::Transcendental,
        };
        let found = find_fast_match(
            std::f64::consts::E,
            &[uc],
            &default_config(),
            &default_table(),
        );
        assert!(found.is_some());
    }
}