use super::parsing::{
find_matching_paren, split_args_balanced, DelimiterTracker, SUPPORTED_MATH_FUNCS,
};
use pyisheval::{Interpreter, Value};
use std::collections::HashMap;
/// Replaces every standalone `math.pi` token outside string literals with `pi`.
///
/// A token is standalone when the byte before it is neither an identifier byte (ASCII
/// alphanumeric or `_`) nor `.`, and the byte after it is not an identifier byte. This leaves
/// `mymath.pi`, `obj.math.pi` (an attribute of `obj`) and `math.pie` untouched. Whether a byte
/// lies inside a string literal is decided by `DelimiterTracker::in_string`, so quoted text
/// such as `'math.pi'` is preserved as well.
fn replace_math_pi_constants(expr: &str) -> String {
    const NEEDLE: &[u8] = b"math.pi";
    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let bytes = expr.as_bytes();
    let mut result = String::with_capacity(expr.len());
    let mut tracker = DelimiterTracker::new();
    // Start of the part of `expr` not yet copied into `result`.
    let mut last_copy = 0;
    let mut i = 0;
    while i < bytes.len() {
        if !tracker.in_string() && bytes[i..].starts_with(NEEDLE) {
            let end = i + NEEDLE.len();
            // `.` is rejected too, so `obj.math.pi` is not turned into `obj.pi`.
            let before_ok =
                i == 0 || !(is_ident_byte(bytes[i - 1]) || bytes[i - 1] == b'.');
            let after_ok = !matches!(bytes.get(end), Some(&b) if is_ident_byte(b));
            if before_ok && after_ok {
                // Slicing at `i` is safe: `bytes[i]` is ASCII, hence a char boundary.
                result.push_str(&expr[last_copy..i]);
                result.push_str("pi");
                // `math.pi` contains no quote or bracket bytes, so not feeding it to the
                // tracker is assumed not to change the tracker's state.
                i = end;
                last_copy = i;
                continue;
            }
        }
        tracker.process(bytes[i]);
        i += 1;
    }
    result.push_str(&expr[last_copy..]);
    result
}
/// Finds calls to supported math functions in `expr`, ignoring text inside string literals.
///
/// A call is a name from `SUPPORTED_MATH_FUNCS`, optionally prefixed by `math.`, followed by
/// optional ASCII whitespace and `(`. The match as a whole (including any `math.` prefix) must
/// start on an identifier boundary: the byte before it may not be ASCII alphanumeric, `_`, or
/// `.`. This rejects `obj.cos(x)`, `mymath.cos(x)` and `obj.math.cos(x)`.
///
/// Returns `(match_start, function_name, paren_pos)` tuples in ascending `match_start` order.
/// `match_start` is the index of the first byte of the call (the `m` of `math.` when present)
/// and `paren_pos` is the index of its opening `(`. After each match, scanning resumes just
/// past that `(`, so calls nested in the arguments are reported too.
fn scan_math_functions(expr: &str) -> Vec<(usize, &'static str, usize)> {
    let bytes = expr.as_bytes();
    let mut function_matches = Vec::new();
    let mut tracker = DelimiterTracker::new();
    let mut i = 0;
    while i < bytes.len() {
        tracker.process(bytes[i]);
        log::trace!(
            "[scan] At pos {}, char: {}, in_string: {}",
            i,
            bytes[i] as char,
            tracker.in_string()
        );
        if tracker.in_string() {
            i += 1;
            continue;
        }

        // Check the byte before the whole match rather than the byte before the function
        // name: with a `math.` prefix the latter is always the prefix's own `.`, which let
        // `xmath.cos(1)` through.
        if i > 0 {
            let prev = bytes[i - 1];
            if prev.is_ascii_alphanumeric() || prev == b'_' || prev == b'.' {
                i += 1;
                continue;
            }
        }

        let func_start = if bytes[i..].starts_with(b"math.") { i + 5 } else { i };
        let found = SUPPORTED_MATH_FUNCS.iter().copied().find_map(|func| {
            let func_end = func_start + func.len();
            if bytes.get(func_start..func_end) != Some(func.as_bytes()) {
                return None;
            }
            log::trace!("[scan] Potential match '{}' at pos {}", func, func_start);
            // Whitespace is allowed between the name and `(`, e.g. `cos (x)`.
            let paren_pos = func_end
                + bytes[func_end..]
                    .iter()
                    .take_while(|b| b.is_ascii_whitespace())
                    .count();
            if bytes.get(paren_pos) == Some(&b'(') {
                log::trace!("[scan] Confirmed match '{}(' at pos {}", func, func_start);
                Some((func, paren_pos))
            } else {
                // Not a call to this name (e.g. `atan` inside `atan2(`); try the next name.
                log::trace!(
                    "[scan] No '(' after '{}' at pos {}, char at paren_pos: {:?}",
                    func,
                    func_start,
                    bytes.get(paren_pos).map(|&b| b as char)
                );
                None
            }
        });

        match found {
            Some((func_name, paren_pos)) => {
                function_matches.push((i, func_name, paren_pos));
                // The skipped bytes are the name, whitespace and `(`, so no quote bytes are
                // withheld from the tracker.
                i = paren_pos + 1;
            }
            None => i += 1,
        }
    }
    function_matches
}
pub(super) fn preprocess_math_functions(
expr: &str,
interp: &mut Interpreter,
context: &HashMap<String, Value>,
) -> Result<String, super::EvalError> {
macro_rules! eval_math_arg {
($interp:expr, $context:expr, $arg_expr:expr, $func_name:expr, $arg_template:expr) => {
match $interp.eval_with_context($arg_expr, $context) {
Ok(Value::Number(n)) => n,
Ok(val) => {
log::warn!(
"[eval_math_arg] Expected numeric argument for {}({}), got {:?}. Skipping this match.",
$func_name, $arg_expr, val
);
continue;
}
Err(e) => {
return Err(super::EvalError::PyishEval {
expr: format!($arg_template, $func_name, $arg_expr),
source: e,
});
}
}
};
}
let mut result = replace_math_pi_constants(expr);
let mut iteration = 0;
const MAX_ITERATIONS: usize = 100;
loop {
iteration += 1;
if iteration > MAX_ITERATIONS {
return Err(super::EvalError::PyishEval {
expr: expr.to_string(),
source: pyisheval::EvalError::ParseError(
"Too many nested math function calls (possible infinite loop)".to_string(),
),
});
}
let function_matches = scan_math_functions(&result);
if log::log_enabled!(log::Level::Debug) {
let result_char_len = result.chars().count();
let max_preview_chars = 200;
let preview: String = result.chars().take(max_preview_chars).collect();
let truncated = result_char_len > max_preview_chars;
log::debug!(
"[preprocess_math_functions] Found {} function matches (len={}{}). Preview: {}",
function_matches.len(),
result_char_len,
if truncated { ", truncated" } else { "" },
preview,
);
}
if function_matches.is_empty() {
break;
}
let mut made_replacement = false;
for &(match_start, func_name, paren_pos) in function_matches.iter().rev() {
let close_pos = match find_matching_paren(&result, paren_pos) {
Some(pos) => pos,
None => continue, };
let arg = &result[paren_pos + 1..close_pos];
if func_name == "atan2" || func_name == "pow" || func_name == "log" {
log::debug!(
"[preprocess_math_functions] Found {}({}) at pos {}",
func_name,
arg,
match_start
);
let args = split_args_balanced(arg);
if args.len() == 2 {
log::debug!(
"[preprocess_math_functions] Evaluating {} arg1: '{}' with context: {:?}",
func_name,
args[0].trim(),
context.keys().collect::<Vec<_>>()
);
let first =
eval_math_arg!(interp, context, args[0].trim(), func_name, "{}({}, ...)");
let second =
eval_math_arg!(interp, context, args[1].trim(), func_name, "{}(..., {})");
let computed = match func_name {
"atan2" => first.atan2(second),
"pow" => first.powf(second),
"log" => first.log(second), _ => unreachable!(),
};
let replacement = format!("{}", computed);
result.replace_range(match_start..=close_pos, &replacement);
made_replacement = true;
break;
}
if func_name == "log" {
if args.len() == 1 {
} else if args.len() != 2 {
log::warn!("log() expects 1 or 2 arguments, but got {}.", args.len());
continue;
}
} else {
if args.len() != 2 {
log::warn!(
"{}() expects 2 arguments, but got {}.",
func_name,
args.len()
);
continue;
}
}
}
let n = eval_math_arg!(interp, context, arg, func_name, "{}({})");
if (func_name == "acos" || func_name == "asin") && !(-1.0..=1.0).contains(&n) {
log::warn!(
"{}({}) domain error: argument must be in [-1, 1], got {}",
func_name,
arg,
n
);
continue; }
let computed = match func_name {
"cos" => n.cos(),
"sin" => n.sin(),
"tan" => n.tan(),
"acos" => n.acos(),
"asin" => n.asin(),
"atan" => n.atan(),
"sqrt" => n.sqrt(),
"abs" => n.abs(),
"floor" => n.floor(),
"ceil" => n.ceil(),
"log" => n.ln(), _ => unreachable!(
"Function '{}' in SUPPORTED_MATH_FUNCS but not in match statement",
func_name
),
};
let replacement = format!("{}", computed);
result.replace_range(match_start..=close_pos, &replacement);
made_replacement = true;
break;
}
if !made_replacement {
break;
}
}
log::debug!("[preprocess_math_functions] Output: {}", result);
Ok(result)
}