use ries_rs::{expr, search, symbol, EvalContext, SymbolTable, EXACT_MATCH_TOLERANCE};
/// Selects how an expression is rendered for output.
///
/// Values are produced by `parse_display_format` and consumed by
/// `format_expression_for_display`.
#[derive(Debug, Clone, Copy)]
pub enum DisplayFormat {
/// Infix notation in the wrapped `expr::OutputFormat` flavour.
Infix(expr::OutputFormat),
/// Postfix output (selector `"0"`); rendered with `Expression::to_postfix`.
PostfixCompact,
/// Postfix output with one named token per symbol, joined by single
/// spaces (selector `"3"`).
PostfixVerbose,
/// Selector `"1"`; currently rendered exactly like `PostfixCompact`.
Condensed,
}
/// Formats a floating-point value with ten digits after the decimal point.
///
/// Scientific notation (`{:.10e}`) is used when the magnitude is at least
/// `1e6`, or when the value is non-zero with magnitude below `1e-4`.
/// Everything else, including exact zero, uses fixed notation (`{:.10}`).
pub fn format_value(v: f64) -> String {
    let magnitude = v.abs();
    let is_huge = magnitude >= 1e6;
    let is_tiny = magnitude < 1e-4 && v != 0.0;

    if is_huge || is_tiny {
        return format!("{:.10e}", v);
    }
    format!("{:.10}", v)
}
pub fn parse_display_format(s: &str) -> DisplayFormat {
match s.to_lowercase().as_str() {
"0" => DisplayFormat::PostfixCompact,
"1" => DisplayFormat::Condensed, "3" => DisplayFormat::PostfixVerbose,
"pretty" | "unicode" => DisplayFormat::Infix(expr::OutputFormat::Pretty),
"mathematica" | "math" | "mma" => DisplayFormat::Infix(expr::OutputFormat::Mathematica),
"sympy" | "python" => DisplayFormat::Infix(expr::OutputFormat::SymPy),
_ => DisplayFormat::Infix(expr::OutputFormat::Default),
}
}
/// Returns the token used for `sym` in verbose postfix output.
///
/// A handful of operators get fixed spellings; every other symbol uses its
/// own `display_name()`.
fn postfix_verbose_token(sym: symbol::Symbol) -> String {
    use symbol::Symbol;

    let fixed_spelling = match sym {
        Symbol::Neg => "neg",
        Symbol::Recip => "recip",
        Symbol::Sqrt => "sqrt",
        Symbol::Square => "dup*",
        Symbol::Pow => "**",
        Symbol::Root => "root",
        Symbol::Log => "logn",
        Symbol::Exp => "exp",
        // No override: defer to the symbol's own name.
        _ => return sym.display_name(),
    };
    String::from(fixed_spelling)
}
/// Rewrites implicit multiplication in an infix string as an explicit `*`.
///
/// A single space is replaced by `*` when the character immediately before
/// it is an ASCII digit or `)` and the character immediately after it is an
/// ASCII letter or `(`. All other characters, including every other space,
/// are copied through unchanged.
pub fn apply_explicit_multiply(infix: &str) -> String {
    let closes_factor = |c: char| c == ')' || c.is_ascii_digit();
    let opens_factor = |c: char| c == '(' || c.is_ascii_alphabetic();

    let mut rendered = String::with_capacity(infix.len() + 8);
    // Neighbours are always taken from the input, never from what has been
    // emitted, so a replaced space still counts as a space for what follows.
    let mut previous: Option<char> = None;
    let mut upcoming = infix.chars().peekable();

    while let Some(current) = upcoming.next() {
        let emitted = match (current, previous, upcoming.peek()) {
            (' ', Some(left), Some(&right)) if closes_factor(left) && opens_factor(right) => '*',
            _ => current,
        };
        rendered.push(emitted);
        previous = Some(current);
    }
    rendered
}
pub fn format_expression_for_display(
expression: &expr::Expression,
format: DisplayFormat,
explicit_multiply: bool,
table: Option<&SymbolTable>,
) -> String {
match format {
DisplayFormat::Infix(inner) => {
let infix = match inner {
expr::OutputFormat::Default => {
if let Some(t) = table {
expression.to_infix_with_table(t)
} else {
expression.to_infix()
}
}
expr::OutputFormat::Pretty => {
let base = if let Some(t) = table {
expression.to_infix_with_table(t)
} else {
expression.to_infix()
};
apply_pretty_unicode(&base)
}
expr::OutputFormat::Mathematica => expression.to_infix_mathematica(),
expr::OutputFormat::SymPy => expression.to_infix_sympy(),
};
if explicit_multiply {
apply_explicit_multiply(&infix)
} else {
infix
}
}
DisplayFormat::PostfixCompact | DisplayFormat::Condensed => expression.to_postfix(),
DisplayFormat::PostfixVerbose => expression
.symbols()
.iter()
.map(|sym| postfix_verbose_token(*sym))
.collect::<Vec<_>>()
.join(" "),
}
}
/// Substitutes Unicode glyphs for a few ASCII spellings in an infix string.
///
/// Replacements are plain substring substitutions applied in this order:
/// `pi` becomes π, `sqrt(` becomes `√(`, and `^2` becomes `²`.
///
/// The glyphs are written as `\u{...}` escapes so the source stays correct
/// even if the file is re-saved under the wrong encoding. The π literal had
/// previously been corrupted into the two-character mojibake sequence
/// produced by reading its UTF-8 bytes as Windows-1252.
fn apply_pretty_unicode(s: &str) -> String {
    s.replace("pi", "\u{03C0}") // π GREEK SMALL LETTER PI
        .replace("sqrt(", "\u{221A}(") // √ SQUARE ROOT
        .replace("^2", "\u{00B2}") // ² SUPERSCRIPT TWO
}
/// Prints one match as `lhs = rhs`, followed by its error relative to the
/// target and its complexity in braces.
///
/// When `solved_rhs` is supplied the left side is printed as the bare
/// variable `x` and the right side is `solved_rhs`; otherwise the match's own
/// LHS and RHS expressions are printed. An error whose magnitude is below
/// `EXACT_MATCH_TOLERANCE` is reported as an exact match; otherwise it is
/// shown as a signed offset in scientific notation.
///
/// `_solve` is accepted but not used by this function.
pub fn print_match_relative(
    m: &search::Match,
    _solve: bool,
    format: DisplayFormat,
    explicit_multiply: bool,
    solved_rhs: Option<&expr::Expression>,
    table: Option<&SymbolTable>,
) {
    let render = |side: &expr::Expression| {
        format_expression_for_display(side, format, explicit_multiply, table)
    };

    let (lhs_str, rhs_str) = match solved_rhs {
        Some(solved) => {
            // A solved equation is displayed as `x = <solved form>`.
            let mut bare_x = expr::Expression::new();
            bare_x.push(symbol::Symbol::X);
            (render(&bare_x), render(solved))
        }
        None => (render(&m.lhs.expr), render(&m.rhs.expr)),
    };

    let deviation = m.error.abs();
    let error_str = if deviation < EXACT_MATCH_TOLERANCE {
        String::from("('exact' match)")
    } else {
        let sign = if m.error >= 0.0 { "+" } else { "-" };
        format!("for x = T {} {:.6e}", sign, deviation)
    };

    println!(
        "{:>24} = {:<24} {} {{{}}}",
        lhs_str, rhs_str, error_str, m.complexity
    );
}
/// Prints one match as `lhs = rhs`, followed by the value of `x` at the match
/// (15 decimal places) and the match complexity in braces.
///
/// When `solved_rhs` is supplied the left side is printed as the bare
/// variable `x` and the right side is `solved_rhs`; otherwise the match's own
/// LHS and RHS expressions are printed.
///
/// `_solve` is accepted but not used by this function.
pub fn print_match_absolute(
    m: &search::Match,
    _solve: bool,
    format: DisplayFormat,
    explicit_multiply: bool,
    solved_rhs: Option<&expr::Expression>,
    table: Option<&SymbolTable>,
) {
    let render = |side: &expr::Expression| {
        format_expression_for_display(side, format, explicit_multiply, table)
    };

    let (lhs_str, rhs_str) = match solved_rhs {
        Some(solved) => {
            // A solved equation is displayed as `x = <solved form>`.
            let mut bare_x = expr::Expression::new();
            bare_x.push(symbol::Symbol::X);
            (render(&bare_x), render(solved))
        }
        None => (render(&m.lhs.expr), render(&m.rhs.expr)),
    };

    println!(
        "{:>24} = {:<24} for x = {:.15} {{{}}}",
        lhs_str, rhs_str, m.x_value, m.complexity
    );
}
/// Prints the run header: a blank line, the target value, the search level,
/// and a trailing blank line.
pub fn print_header(target: f64, level: i32) {
    // One call emits all four lines; the leading and trailing `\n` supply the
    // surrounding blank lines.
    println!("\n Target: {}\n Level: {}\n", target, level);
}
/// Prints the end-of-run summary: the saturating sum of `lhs_tested` and
/// `candidates_tested`, the LHS and RHS expression counts, and the elapsed
/// search time in seconds with millisecond precision.
pub fn print_footer(stats: &search::SearchStats, elapsed: std::time::Duration) {
    let total_tested = stats.lhs_tested.saturating_add(stats.candidates_tested);
    let seconds = elapsed.as_secs_f64();

    println!(
        concat!(
            "\n",
            " === Summary ===\n",
            " Total expressions tested: {}\n",
            " LHS expressions: {}\n",
            " RHS expressions: {}\n",
            " Search time: {:.3}s",
        ),
        total_tested, stats.lhs_count, stats.rhs_count, seconds
    );
}
/// Builds a fresh expression by pushing `symbols` onto an empty one, in order.
fn expression_from_symbols(symbols: &[symbol::Symbol]) -> expr::Expression {
    symbols
        .iter()
        .copied()
        .fold(expr::Expression::new(), |mut built, sym| {
            built.push(sym);
            built
        })
}
/// Lists the sub-expressions formed while walking `expression` symbol by
/// symbol with an operand stack, in the order they are completed.
///
/// Each `Seft::A` symbol yields a one-symbol expression. Each `Seft::B`
/// symbol pops one operand and appends itself. Each `Seft::C` symbol pops two
/// operands and yields left-operand symbols, then right-operand symbols, then
/// itself. Every produced expression is recorded and pushed back onto the
/// stack.
///
/// If an operator finds too few operands on the stack, the walk stops and
/// the steps collected so far are returned.
fn decompose_subexpressions(expression: &expr::Expression) -> Vec<expr::Expression> {
    let mut operands: Vec<expr::Expression> = Vec::new();
    let mut steps = Vec::new();

    for &sym in expression.symbols() {
        let produced = match sym.seft() {
            symbol::Seft::A => {
                let mut leaf = expr::Expression::new();
                leaf.push(sym);
                leaf
            }
            symbol::Seft::B => {
                let Some(mut operand) = operands.pop() else {
                    break;
                };
                operand.push(sym);
                operand
            }
            symbol::Seft::C => {
                // The top of the stack is the right-hand operand.
                let (Some(right), Some(left)) = (operands.pop(), operands.pop()) else {
                    break;
                };
                let mut merged = expression_from_symbols(left.symbols());
                for &right_sym in right.symbols() {
                    merged.push(right_sym);
                }
                merged.push(sym);
                merged
            }
        };
        steps.push(produced.clone());
        operands.push(produced);
    }
    steps
}
/// Prints a numbered list of the sub-expressions of `expression` under a
/// `"<label> steps:"` heading.
///
/// Each step is rendered with `format_expression_for_display` and evaluated
/// at `x` using `eval_context`. The line shows either the value and
/// derivative, or the evaluation error when evaluation fails.
#[allow(clippy::too_many_arguments)]
fn print_expression_steps(
    label: &str,
    expression: &expr::Expression,
    x: f64,
    format: DisplayFormat,
    explicit_multiply: bool,
    eval_context: &EvalContext<'_>,
    table: Option<&SymbolTable>,
) {
    println!(" {} steps:", label);

    let steps = decompose_subexpressions(expression);
    // Steps are numbered from 1 in the output.
    for (ordinal, step_expr) in (1usize..).zip(steps.iter()) {
        let rendered = format_expression_for_display(step_expr, format, explicit_multiply, table);
        let outcome = ries_rs::eval::evaluate_with_context(step_expr, x, eval_context);
        let line = match outcome {
            Ok(result) => format!(
                " {:>2}. {:<28} value={:+.12e} deriv={:+.12e}",
                ordinal, rendered, result.value, result.derivative
            ),
            Err(err) => format!(
                " {:>2}. {:<28} evaluation error: {}",
                ordinal, rendered, err
            ),
        };
        println!("{}", line);
    }
}
/// Prints the `--show-work` section: for every shown match, its `x` value
/// followed by the step-by-step breakdown of its LHS and then its RHS.
///
/// Prints nothing at all when `shown_matches` is empty.
pub fn print_show_work_details(
    shown_matches: &[&search::Match],
    format: DisplayFormat,
    explicit_multiply: bool,
    eval_context: &EvalContext<'_>,
    table: Option<&SymbolTable>,
) {
    if shown_matches.is_empty() {
        return;
    }

    println!("\n --show-work details:");
    // Matches are numbered from 1 in the output.
    for (ordinal, m) in (1usize..).zip(shown_matches) {
        println!(" Match {} at x = {:.15}", ordinal, m.x_value);
        for (label, side) in [("LHS", &m.lhs.expr), ("RHS", &m.rhs.expr)] {
            print_expression_steps(
                label,
                side,
                m.x_value,
                format,
                explicit_multiply,
                eval_context,
                table,
            );
        }
    }
}
/// Derives a matching tolerance from the number of decimal places in `target`.
///
/// The target is printed with 15 decimal places and trailing zeros are
/// dropped. With `d` digits remaining after the decimal point, the tolerance
/// is half a unit in the last remaining place, `0.5 * 10^-d`, and never less
/// than `1e-15`.
///
/// Targets that print as zero at 15 decimal places get the `1e-15` floor.
/// This covers exact zero and also tiny non-zero values such as `1e-20`,
/// which previously counted zero decimals and received a tolerance of `0.5`.
///
/// Inputs whose printed form has no decimal point (NaN, infinities) count as
/// having zero decimals and therefore yield `0.5`.
pub fn compute_significant_digits_tolerance(target: f64) -> f64 {
    const MIN_TOLERANCE: f64 = 1e-15;

    if target == 0.0 {
        return MIN_TOLERANCE;
    }

    let printed = format!("{:.15}", target);
    // The '.' stops the trim, so integer digits are never removed.
    let trimmed = printed.trim_end_matches('0');
    let (integer_part, fraction) = trimmed.split_once('.').unwrap_or((trimmed, ""));

    // A non-zero target too small to show at 15 decimals prints as "0." or
    // "-0."; treat it like zero instead of as an integer with 0 decimals.
    if fraction.is_empty() && integer_part.trim_start_matches('-') == "0" {
        return MIN_TOLERANCE;
    }

    // `fraction` holds at most 15 characters, so the cast cannot truncate.
    let digits_after_decimal = fraction.len() as i32;
    let tolerance = 0.5 * 10_f64.powi(-digits_after_decimal);
    tolerance.max(MIN_TOLERANCE)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed notation for ordinary magnitudes, scientific for large ones.
    #[test]
    #[allow(clippy::approx_constant)]
    fn test_format_value() {
        let cases = [(2.71828, "2.7182800000"), (1e10, "1.0000000000e10")];
        for (input, expected) in cases {
            assert_eq!(format_value(input), expected);
        }
    }

    /// Numeric and named selectors map onto the expected variants.
    #[test]
    fn test_parse_display_format() {
        let compact = parse_display_format("0");
        assert!(matches!(compact, DisplayFormat::PostfixCompact));

        let pretty = parse_display_format("pretty");
        assert!(matches!(
            pretty,
            DisplayFormat::Infix(expr::OutputFormat::Pretty)
        ));

        let mathematica = parse_display_format("mathematica");
        assert!(matches!(
            mathematica,
            DisplayFormat::Infix(expr::OutputFormat::Mathematica)
        ));
    }

    /// The tolerance is roughly half a unit in the last given decimal place.
    #[test]
    fn test_compute_significant_digits_tolerance() {
        let strictly_between = |value: f64, low: f64, high: f64| value > low && value < high;

        let one_decimal = compute_significant_digits_tolerance(2.5);
        assert!(strictly_between(one_decimal, 0.04, 0.06));

        let two_decimals = compute_significant_digits_tolerance(2.51);
        assert!(strictly_between(two_decimals, 0.004, 0.006));
    }
}