use crate::eval::{evaluate_fast_with_context, EvalContext};
use crate::symbol_table::SymbolTable;
use std::sync::Arc;
/// Factor a value is multiplied by before being rounded to an integer key in `quantize_value`.
const QUANTIZE_SCALE: f64 = 1e8;
/// Largest magnitude `quantize_value` scales and rounds; anything beyond maps to a sentinel key.
const MAX_QUANTIZED_VALUE: f64 = 1e10;
/// Largest magnitude an evaluated expression may have and still be kept by the generator.
const MAX_GENERATED_VALUE: f64 = 1e12;
use crate::expr::{EvaluatedExpr, Expression, MAX_EXPR_LEN};
use crate::profile::UserConstant;
use crate::symbol::{NumType, Seft, Symbol};
use crate::udf::UserFunction;
use std::collections::HashMap;
/// Settings that drive expression generation.
///
/// Expressions containing `Symbol::X` are reported as "LHS", all others as "RHS".
#[derive(Clone)]
pub struct GenConfig {
/// Complexity ceiling for kept expressions that contain `x`.
pub max_lhs_complexity: u32,
/// Complexity ceiling for kept expressions that do not contain `x`.
pub max_rhs_complexity: u32,
/// An expression is no longer extended once its length reaches this value.
pub max_length: usize,
/// Operand symbols tried when extending an expression.
pub constants: Vec<Symbol>,
/// Unary operators tried when at least one value is on the stack.
pub unary_ops: Vec<Symbol>,
/// Binary operators tried when at least two values are on the stack.
pub binary_ops: Vec<Symbol>,
/// When set, replaces `constants` in the dedicated RHS pass.
pub rhs_constants: Option<Vec<Symbol>>,
/// When set, replaces `unary_ops` in the dedicated RHS pass.
pub rhs_unary_ops: Option<Vec<Symbol>>,
/// When set, replaces `binary_ops` in the dedicated RHS pass.
pub rhs_binary_ops: Option<Vec<Symbol>>,
/// Per-symbol cap on how often a symbol may occur in one expression.
pub symbol_max_counts: HashMap<Symbol, u32>,
/// When set, replaces `symbol_max_counts` in the dedicated RHS pass.
pub rhs_symbol_max_counts: Option<HashMap<Symbol, u32>>,
/// Results whose number type compares below this value are discarded.
pub min_num_type: NumType,
/// Whether expressions containing `x` are generated and reported.
pub generate_lhs: bool,
/// Whether expressions without `x` are reported.
pub generate_rhs: bool,
/// User constants; used to build the default evaluation context.
pub user_constants: Vec<UserConstant>,
/// User functions; used to build the default evaluation context.
pub user_functions: Vec<UserFunction>,
/// When true, expressions that fail to evaluate are logged to stderr.
pub show_pruned_arith: bool,
/// Source of per-symbol weights used for complexity bookkeeping.
pub symbol_table: Arc<SymbolTable>,
}
/// Optional structural restrictions checked by `expression_respects_constraints`.
#[derive(Debug, Clone, Copy)]
pub struct ExpressionConstraintOptions {
/// Reject `Pow` when the exponent contains `x` or its type compares below `NumType::Rational`.
pub rational_exponents: bool,
/// Reject `SinPi`/`CosPi`/`TanPi` when the argument contains `x` or its type compares below
/// `NumType::Rational`.
pub rational_trig_args: bool,
/// Maximum number of `SinPi`/`CosPi`/`TanPi` symbols allowed; `None` means no limit.
pub max_trig_cycles: Option<u32>,
/// Number type assumed for each user constant, indexed by the symbol's user-constant index.
pub user_constant_types: [NumType; 16],
/// Number type assumed for the result of each user function, indexed by its user-function index.
pub user_function_types: [NumType; 16],
}
impl Default for ExpressionConstraintOptions {
fn default() -> Self {
Self {
rational_exponents: false,
rational_trig_args: false,
max_trig_cycles: None,
user_constant_types: [NumType::Transcendental; 16],
user_function_types: [NumType::Transcendental; 16],
}
}
}
pub fn expression_respects_constraints(
expression: &Expression,
opts: ExpressionConstraintOptions,
) -> bool {
#[derive(Clone, Copy)]
struct ConstraintValue {
has_x: bool,
num_type: NumType,
}
let mut stack: Vec<ConstraintValue> = Vec::with_capacity(expression.len());
let mut trig_ops: u32 = 0;
for &sym in expression.symbols() {
match sym.seft() {
Seft::A => {
let num_type = if let Some(idx) = sym.user_constant_index() {
opts.user_constant_types[idx as usize]
} else {
sym.inherent_type()
};
stack.push(ConstraintValue {
has_x: sym == Symbol::X,
num_type,
});
}
Seft::B => {
let Some(arg) = stack.pop() else {
return false;
};
if matches!(sym, Symbol::SinPi | Symbol::CosPi | Symbol::TanPi) {
trig_ops = trig_ops.saturating_add(1);
if opts.rational_trig_args && (arg.has_x || arg.num_type < NumType::Rational) {
return false;
}
}
let num_type = match sym {
Symbol::Neg | Symbol::Square => arg.num_type,
Symbol::Recip => {
if arg.num_type >= NumType::Rational {
NumType::Rational
} else {
arg.num_type
}
}
Symbol::Sqrt => {
if arg.num_type >= NumType::Rational {
NumType::Algebraic
} else {
arg.num_type
}
}
Symbol::UserFunction0
| Symbol::UserFunction1
| Symbol::UserFunction2
| Symbol::UserFunction3
| Symbol::UserFunction4
| Symbol::UserFunction5
| Symbol::UserFunction6
| Symbol::UserFunction7
| Symbol::UserFunction8
| Symbol::UserFunction9
| Symbol::UserFunction10
| Symbol::UserFunction11
| Symbol::UserFunction12
| Symbol::UserFunction13
| Symbol::UserFunction14
| Symbol::UserFunction15 => {
let idx = sym.user_function_index().unwrap_or(0) as usize;
opts.user_function_types[idx]
}
_ => NumType::Transcendental,
};
stack.push(ConstraintValue {
has_x: arg.has_x,
num_type,
});
}
Seft::C => {
let Some(rhs) = stack.pop() else {
return false;
};
let Some(lhs) = stack.pop() else {
return false;
};
if opts.rational_exponents
&& sym == Symbol::Pow
&& (rhs.has_x || rhs.num_type < NumType::Rational)
{
return false;
}
let num_type = match sym {
Symbol::Add | Symbol::Sub | Symbol::Mul => lhs.num_type.combine(rhs.num_type),
Symbol::Div => {
let combined = lhs.num_type.combine(rhs.num_type);
if combined == NumType::Integer {
NumType::Rational
} else {
combined
}
}
Symbol::Pow => {
if rhs.has_x {
NumType::Transcendental
} else if rhs.num_type == NumType::Integer {
lhs.num_type
} else if lhs.num_type >= NumType::Rational
&& rhs.num_type >= NumType::Rational
{
NumType::Algebraic
} else {
NumType::Transcendental
}
}
Symbol::Root => NumType::Algebraic,
Symbol::Log | Symbol::Atan2 => NumType::Transcendental,
_ => NumType::Transcendental,
};
stack.push(ConstraintValue {
has_x: lhs.has_x || rhs.has_x,
num_type,
});
}
}
}
if stack.len() != 1 {
return false;
}
opts.max_trig_cycles
.is_none_or(|max_cycles| trig_ops <= max_cycles)
}
impl Default for GenConfig {
fn default() -> Self {
Self {
max_lhs_complexity: 128,
max_rhs_complexity: 128,
max_length: MAX_EXPR_LEN,
constants: Symbol::constants().to_vec(),
unary_ops: Symbol::unary_ops().to_vec(),
binary_ops: Symbol::binary_ops().to_vec(),
rhs_constants: None,
rhs_unary_ops: None,
rhs_binary_ops: None,
symbol_max_counts: HashMap::new(),
rhs_symbol_max_counts: None,
min_num_type: NumType::Transcendental,
generate_lhs: true,
generate_rhs: true,
user_constants: Vec::new(),
user_functions: Vec::new(),
show_pruned_arith: false,
symbol_table: Arc::new(SymbolTable::new()),
}
}
}
/// Result of a batch generation run.
pub struct GeneratedExprs {
/// Kept expressions that contain `x`.
pub lhs: Vec<EvaluatedExpr>,
/// Kept expressions that do not contain `x`.
pub rhs: Vec<EvaluatedExpr>,
}
/// Sinks invoked by the streaming generator for every kept expression.
///
/// A callback returning `false` stops the generation.
pub struct StreamingCallbacks<'a> {
/// Called for each kept expression that does not contain `x`.
pub on_rhs: &'a mut dyn FnMut(&EvaluatedExpr) -> bool,
/// Called for each kept expression that contains `x`.
pub on_lhs: &'a mut dyn FnMut(&EvaluatedExpr) -> bool,
}
/// Deduplication key for LHS expressions: (quantized value, quantized derivative).
pub type LhsKey = (i64, i64);
/// Maps a floating-point value to an integer key for deduplication.
///
/// * NaN maps to `i64::MAX`.
/// * Values above `MAX_QUANTIZED_VALUE` (including `+inf`) map to `i64::MAX - 1`.
/// * Values below `-MAX_QUANTIZED_VALUE` (including `-inf`) map to `i64::MIN + 1`.
/// * Everything else is scaled by `QUANTIZE_SCALE` and rounded to the nearest integer.
#[inline]
pub fn quantize_value(v: f64) -> i64 {
    if v.is_nan() {
        return i64::MAX;
    }
    if v > MAX_QUANTIZED_VALUE {
        return i64::MAX - 1;
    }
    if v < -MAX_QUANTIZED_VALUE {
        return i64::MIN + 1;
    }
    (v * QUANTIZE_SCALE).round() as i64
}
/// Generates and deduplicates all expressions for `target`, evaluating them
/// with a context built from the user constants and functions in `config`.
pub fn generate_all(config: &GenConfig, target: f64) -> GeneratedExprs {
    let eval_context = EvalContext::from_slices(&config.user_constants, &config.user_functions);
    generate_all_with_context(config, target, &eval_context)
}
pub fn generate_all_with_context(
config: &GenConfig,
target: f64,
eval_context: &EvalContext<'_>,
) -> GeneratedExprs {
let mut lhs_raw = Vec::new();
let mut rhs_raw = Vec::new();
if config.generate_lhs && config.generate_rhs && has_rhs_symbol_overrides(config) {
let mut lhs_config = config.clone();
lhs_config.generate_lhs = true;
lhs_config.generate_rhs = false;
generate_recursive(
&lhs_config,
target,
*eval_context,
&mut Expression::new(),
0,
&mut lhs_raw,
&mut rhs_raw,
);
let rhs_config = rhs_only_config(config);
generate_recursive(
&rhs_config,
target,
*eval_context,
&mut Expression::new(),
0,
&mut lhs_raw,
&mut rhs_raw,
);
} else {
generate_recursive(
config,
target,
*eval_context,
&mut Expression::new(),
0, &mut lhs_raw,
&mut rhs_raw,
);
}
let mut rhs_map: HashMap<i64, EvaluatedExpr> = HashMap::new();
for expr in rhs_raw {
let key = quantize_value(expr.value);
rhs_map
.entry(key)
.and_modify(|existing| {
if expr.expr.complexity() < existing.expr.complexity() {
*existing = expr.clone();
}
})
.or_insert(expr);
}
let mut lhs_map: HashMap<LhsKey, EvaluatedExpr> = HashMap::new();
for expr in lhs_raw {
let key = (quantize_value(expr.value), quantize_value(expr.derivative));
lhs_map
.entry(key)
.and_modify(|existing| {
if expr.expr.complexity() < existing.expr.complexity() {
*existing = expr.clone();
}
})
.or_insert(expr);
}
GeneratedExprs {
lhs: lhs_map.into_values().collect(),
rhs: rhs_map.into_values().collect(),
}
}
/// Like `generate_all`, but gives up and returns `None` once more than
/// `max_expressions` expressions have been produced.
///
/// The evaluation context is built from the user constants and functions in `config`.
pub fn generate_all_with_limit(
    config: &GenConfig,
    target: f64,
    max_expressions: usize,
) -> Option<GeneratedExprs> {
    let eval_context = EvalContext::from_slices(&config.user_constants, &config.user_functions);
    generate_all_with_limit_and_context(config, target, &eval_context, max_expressions)
}
pub fn generate_all_with_limit_and_context(
config: &GenConfig,
target: f64,
eval_context: &EvalContext<'_>,
max_expressions: usize,
) -> Option<GeneratedExprs> {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
let count = Arc::new(AtomicUsize::new(0));
let limit = max_expressions;
let mut lhs_raw = Vec::new();
let mut rhs_raw = Vec::new();
let mut callbacks = StreamingCallbacks {
on_lhs: &mut |expr| {
let current = count.fetch_add(1, Ordering::Relaxed) + 1;
if current > limit {
return false; }
lhs_raw.push(expr.clone());
true
},
on_rhs: &mut |expr| {
let current = count.fetch_add(1, Ordering::Relaxed) + 1;
if current > limit {
return false; }
rhs_raw.push(expr.clone());
true
},
};
generate_streaming_with_context(config, target, eval_context, &mut callbacks);
let final_count = count.load(Ordering::Relaxed);
if final_count > limit {
return None;
}
let mut rhs_map: HashMap<i64, EvaluatedExpr> = HashMap::new();
for expr in rhs_raw {
let key = quantize_value(expr.value);
rhs_map
.entry(key)
.and_modify(|existing| {
if expr.expr.complexity() < existing.expr.complexity() {
*existing = expr.clone();
}
})
.or_insert(expr);
}
let mut lhs_map: HashMap<LhsKey, EvaluatedExpr> = HashMap::new();
for expr in lhs_raw {
let key = (quantize_value(expr.value), quantize_value(expr.derivative));
lhs_map
.entry(key)
.and_modify(|existing| {
if expr.expr.complexity() < existing.expr.complexity() {
*existing = expr.clone();
}
})
.or_insert(expr);
}
Some(GeneratedExprs {
lhs: lhs_map.into_values().collect(),
rhs: rhs_map.into_values().collect(),
})
}
/// Streams every kept expression for `target` to `callbacks`, evaluating with
/// a context built from the user constants and functions in `config`.
pub fn generate_streaming(config: &GenConfig, target: f64, callbacks: &mut StreamingCallbacks) {
    let eval_context = EvalContext::from_slices(&config.user_constants, &config.user_functions);
    generate_streaming_with_context(config, target, &eval_context, callbacks);
}
pub fn generate_streaming_with_context(
config: &GenConfig,
target: f64,
eval_context: &EvalContext<'_>,
callbacks: &mut StreamingCallbacks,
) {
if config.generate_lhs && config.generate_rhs && has_rhs_symbol_overrides(config) {
let mut lhs_config = config.clone();
lhs_config.generate_lhs = true;
lhs_config.generate_rhs = false;
if !generate_recursive_streaming(
&lhs_config,
target,
*eval_context,
&mut Expression::new(),
0,
callbacks,
) {
return;
}
let rhs_config = rhs_only_config(config);
generate_recursive_streaming(
&rhs_config,
target,
*eval_context,
&mut Expression::new(),
0,
callbacks,
);
} else {
generate_recursive_streaming(
config,
target,
*eval_context,
&mut Expression::new(),
0, callbacks,
);
}
}
/// Reports whether `config` sets any of the RHS-specific overrides
/// (constants, unary operators, binary operators or per-symbol caps).
#[inline]
fn has_rhs_symbol_overrides(config: &GenConfig) -> bool {
    let overrides_present = [
        config.rhs_constants.is_some(),
        config.rhs_unary_ops.is_some(),
        config.rhs_binary_ops.is_some(),
        config.rhs_symbol_max_counts.is_some(),
    ];
    overrides_present.contains(&true)
}
/// Decides whether an evaluated expression is kept.
///
/// The value must be finite and no larger in magnitude than
/// `MAX_GENERATED_VALUE`, the number type must not compare below
/// `config.min_num_type`, and the side the expression belongs to (LHS when
/// `contains_x`, RHS otherwise) must be enabled with `complexity` within that
/// side's ceiling.
#[inline]
fn should_include_expression(
    result: &crate::eval::EvalResult,
    config: &GenConfig,
    complexity: u32,
    contains_x: bool,
) -> bool {
    let value_ok = result.value.is_finite() && result.value.abs() <= MAX_GENERATED_VALUE;
    let type_ok = result.num_type >= config.min_num_type;
    if !(value_ok && type_ok) {
        return false;
    }

    let (side_enabled, complexity_cap) = if contains_x {
        (config.generate_lhs, config.max_lhs_complexity)
    } else {
        (config.generate_rhs, config.max_rhs_complexity)
    };
    side_enabled && complexity <= complexity_cap
}
/// Complexity budget used while extending an expression.
///
/// Once `x` is present only the LHS ceiling applies. Without `x` the larger of
/// the two ceilings is used, since the expression may still acquire an `x`.
#[inline]
fn get_max_complexity(config: &GenConfig, contains_x: bool) -> u32 {
    let lhs_cap = config.max_lhs_complexity;
    match contains_x {
        true => lhs_cap,
        false => lhs_cap.max(config.max_rhs_complexity),
    }
}
fn rhs_only_config(config: &GenConfig) -> GenConfig {
let mut rhs_config = config.clone();
rhs_config.generate_lhs = false;
rhs_config.generate_rhs = true;
if let Some(constants) = &config.rhs_constants {
rhs_config.constants = constants.clone();
}
if let Some(unary_ops) = &config.rhs_unary_ops {
rhs_config.unary_ops = unary_ops.clone();
}
if let Some(binary_ops) = &config.rhs_binary_ops {
rhs_config.binary_ops = binary_ops.clone();
}
if let Some(rhs_symbol_max_counts) = &config.rhs_symbol_max_counts {
rhs_config.symbol_max_counts = rhs_symbol_max_counts.clone();
}
rhs_config
}
/// Reports whether appending one more `sym` would break its per-symbol cap,
/// i.e. `current` already contains at least as many occurrences as the cap.
/// Symbols without a cap are never limited.
#[inline]
fn exceeds_symbol_limit(config: &GenConfig, current: &Expression, sym: Symbol) -> bool {
    match config.symbol_max_counts.get(&sym) {
        Some(&cap) => current.count_symbol(sym) >= cap,
        None => false,
    }
}
fn generate_recursive_streaming(
config: &GenConfig,
target: f64,
eval_context: EvalContext<'_>,
current: &mut Expression,
stack_depth: usize,
callbacks: &mut StreamingCallbacks,
) -> bool {
if stack_depth == 1 && !current.is_empty() {
match evaluate_fast_with_context(current, target, &eval_context) {
Ok(result) => {
if should_include_expression(
&result,
config,
current.complexity(),
current.contains_x(),
) {
let expr = current.clone();
let eval_expr =
EvaluatedExpr::new(expr, result.value, result.derivative, result.num_type);
let should_continue = if current.contains_x() {
(callbacks.on_lhs)(&eval_expr)
} else {
(callbacks.on_rhs)(&eval_expr)
};
if !should_continue {
return false;
}
}
}
Err(e) => {
if config.show_pruned_arith {
eprintln!(
" [pruned arith] expression=\"{}\" reason={:?}",
current.to_postfix(),
e
);
}
}
}
}
if current.len() >= config.max_length {
return true;
}
let max_complexity = get_max_complexity(config, current.contains_x());
if current.complexity() >= max_complexity {
return true;
}
let min_remaining = min_complexity_to_complete(stack_depth, config);
if current.complexity() + min_remaining > max_complexity {
return true;
}
for &sym in &config.constants {
let sym_weight = config.symbol_table.weight(sym);
if current.complexity() + sym_weight > max_complexity {
continue;
}
if exceeds_symbol_limit(config, current, sym) {
continue;
}
if sym == Symbol::X && !config.generate_lhs {
continue;
}
current.push_with_table(sym, &config.symbol_table);
if !generate_recursive_streaming(
config,
target,
eval_context,
current,
stack_depth + 1,
callbacks,
) {
current.pop_with_table(&config.symbol_table);
return false;
}
current.pop_with_table(&config.symbol_table);
}
if config.generate_lhs && !config.constants.contains(&Symbol::X) {
let sym = Symbol::X;
let sym_weight = config.symbol_table.weight(sym);
if current.complexity() + sym_weight <= max_complexity
&& !exceeds_symbol_limit(config, current, sym)
{
current.push_with_table(sym, &config.symbol_table);
if !generate_recursive_streaming(
config,
target,
eval_context,
current,
stack_depth + 1,
callbacks,
) {
current.pop_with_table(&config.symbol_table);
return false;
}
current.pop_with_table(&config.symbol_table);
}
}
if stack_depth >= 1 {
for &sym in &config.unary_ops {
let sym_weight = config.symbol_table.weight(sym);
if current.complexity() + sym_weight > max_complexity {
continue;
}
if exceeds_symbol_limit(config, current, sym) {
continue;
}
if should_prune_unary(current, sym) {
continue;
}
current.push_with_table(sym, &config.symbol_table);
if !generate_recursive_streaming(
config,
target,
eval_context,
current,
stack_depth,
callbacks,
) {
current.pop_with_table(&config.symbol_table);
return false;
}
current.pop_with_table(&config.symbol_table);
}
}
if stack_depth >= 2 {
for &sym in &config.binary_ops {
let sym_weight = config.symbol_table.weight(sym);
if current.complexity() + sym_weight > max_complexity {
continue;
}
if exceeds_symbol_limit(config, current, sym) {
continue;
}
if should_prune_binary(current, sym) {
continue;
}
current.push_with_table(sym, &config.symbol_table);
if !generate_recursive_streaming(
config,
target,
eval_context,
current,
stack_depth - 1,
callbacks,
) {
current.pop_with_table(&config.symbol_table);
return false;
}
current.pop_with_table(&config.symbol_table);
}
}
true
}
fn generate_recursive(
config: &GenConfig,
target: f64,
eval_context: EvalContext<'_>,
current: &mut Expression,
stack_depth: usize,
lhs_out: &mut Vec<EvaluatedExpr>,
rhs_out: &mut Vec<EvaluatedExpr>,
) {
if stack_depth == 1 && !current.is_empty() {
match evaluate_fast_with_context(current, target, &eval_context) {
Ok(result) => {
if should_include_expression(
&result,
config,
current.complexity(),
current.contains_x(),
) {
let expr = current.clone();
let eval_expr =
EvaluatedExpr::new(expr, result.value, result.derivative, result.num_type);
if current.contains_x() {
lhs_out.push(eval_expr);
} else {
rhs_out.push(eval_expr);
}
}
}
Err(e) => {
if config.show_pruned_arith {
eprintln!(
" [pruned arith] expression=\"{}\" reason={:?}",
current.to_postfix(),
e
);
}
}
}
}
if current.len() >= config.max_length {
return;
}
let max_complexity = get_max_complexity(config, current.contains_x());
if current.complexity() >= max_complexity {
return;
}
let min_remaining = min_complexity_to_complete(stack_depth, config);
if current.complexity() + min_remaining > max_complexity {
return;
}
for &sym in &config.constants {
let sym_weight = config.symbol_table.weight(sym);
if current.complexity() + sym_weight > max_complexity {
continue;
}
if exceeds_symbol_limit(config, current, sym) {
continue;
}
if sym == Symbol::X && !config.generate_lhs {
continue;
}
current.push_with_table(sym, &config.symbol_table);
generate_recursive(
config,
target,
eval_context,
current,
stack_depth + 1,
lhs_out,
rhs_out,
);
current.pop_with_table(&config.symbol_table);
}
if config.generate_lhs && !config.constants.contains(&Symbol::X) {
let sym = Symbol::X;
let sym_weight = config.symbol_table.weight(sym);
if current.complexity() + sym_weight <= max_complexity
&& !exceeds_symbol_limit(config, current, sym)
{
current.push_with_table(sym, &config.symbol_table);
generate_recursive(
config,
target,
eval_context,
current,
stack_depth + 1,
lhs_out,
rhs_out,
);
current.pop_with_table(&config.symbol_table);
}
}
if stack_depth >= 1 {
for &sym in &config.unary_ops {
let sym_weight = config.symbol_table.weight(sym);
if current.complexity() + sym_weight > max_complexity {
continue;
}
if exceeds_symbol_limit(config, current, sym) {
continue;
}
if should_prune_unary(current, sym) {
continue;
}
current.push_with_table(sym, &config.symbol_table);
generate_recursive(
config,
target,
eval_context,
current,
stack_depth,
lhs_out,
rhs_out,
);
current.pop_with_table(&config.symbol_table);
}
}
if stack_depth >= 2 {
for &sym in &config.binary_ops {
let sym_weight = config.symbol_table.weight(sym);
if current.complexity() + sym_weight > max_complexity {
continue;
}
if exceeds_symbol_limit(config, current, sym) {
continue;
}
if should_prune_binary(current, sym) {
continue;
}
current.push_with_table(sym, &config.symbol_table);
generate_recursive(
config,
target,
eval_context,
current,
stack_depth - 1,
lhs_out,
rhs_out,
);
current.pop_with_table(&config.symbol_table);
}
}
}
/// Lower bound on the complexity still needed to reduce `stack_depth` values
/// to a single one: one binary operator per surplus value, each costed at the
/// cheapest configured binary operator (4 when none is configured).
fn min_complexity_to_complete(stack_depth: usize, config: &GenConfig) -> u32 {
    if stack_depth < 2 {
        return 0;
    }
    let merges_needed = (stack_depth - 1) as u32;
    let weights = config
        .binary_ops
        .iter()
        .map(|&op| config.symbol_table.weight(op));
    let cheapest_binary = weights.min().unwrap_or(4);
    merges_needed * cheapest_binary
}
/// Reports whether appending the unary operator `sym` to `expr` should be
/// skipped, judged solely by the last symbol already in `expr`.
///
/// The table below lists, per operator about to be appended, the preceding
/// symbols after which it is pruned. An empty expression is never pruned.
fn should_prune_unary(expr: &Expression, sym: Symbol) -> bool {
    use Symbol::*;

    let Some(&last) = expr.symbols().last() else {
        return false;
    };

    match sym {
        Neg => matches!(last, Neg | Sub),
        Recip => matches!(last, Recip | Sqrt | Square | Ln),
        Sqrt => matches!(last, Square | Sqrt),
        Square => matches!(last, Sqrt | Square),
        Ln => matches!(last, Exp),
        Exp => matches!(last, Ln | Exp),
        SinPi => matches!(last, SinPi),
        CosPi => matches!(last, CosPi),
        LambertW => matches!(last, Exp | Recip),
        _ => false,
    }
}
/// Reports whether appending the binary operator `sym` to `expr` should be
/// skipped because the result is presumably degenerate or duplicates a
/// simpler form.
///
/// The decision looks at the last two symbols of `expr` (`prev`, `last`), at
/// the third-from-last symbol for the `x + (-x)` rule, and at
/// `is_same_subexpr` for identical operands. Expressions with fewer than two
/// symbols are never pruned.
fn should_prune_binary(expr: &Expression, sym: Symbol) -> bool {
    use Symbol::*;

    let symbols = expr.symbols();
    let len = symbols.len();
    if len < 2 {
        return false;
    }
    let last = symbols[len - 1];
    let prev = symbols[len - 2];

    match sym {
        // Both operands are the same sub-expression.
        Sub | Div | Add if is_same_subexpr(symbols, 2) => true,
        // Both operands are the bare variable.
        Sub | Div | Log if last == X && prev == X => true,
        Div if last == One => true,
        // `x x neg +`, i.e. x + (-x). The right operand is `x neg`, so the
        // left operand ends at index len - 3; it must be inspected there
        // rather than at len - 2 (which is `prev` again).
        Add if last == Neg && prev == X && len >= 3 && symbols[len - 3] == X => true,
        Pow | Root if prev == One || last == One => true,
        Mul if last == One || prev == One => true,
        Root if last == Two => true,
        Log if prev == One || last == One => true,
        Log if prev == E => true,
        // Commutative operator on two constants: only one ordering is kept.
        Add | Mul if prev > last && is_constant(last) && is_constant(prev) => true,
        _ => false,
    }
}
/// Reports whether the two topmost operands left on the stack by the postfix
/// sequence `symbols` are symbol-for-symbol identical.
///
/// Only `n == 2` is supported; any other `n` yields `false`. Sequences shorter
/// than `2 * n` symbols are rejected up front, so two bare atoms (`x x`) are
/// not reported here. Malformed sequences (a binary operator with an empty
/// stack) yield `false`.
fn is_same_subexpr(symbols: &[Symbol], n: usize) -> bool {
    if n != 2 || symbols.len() < n * 2 {
        return false;
    }

    // depths[i] is the stack depth before symbols[i]; depths[len] is the final depth.
    let mut depths: Vec<usize> = Vec::with_capacity(symbols.len() + 1);
    let mut depth = 0usize;
    depths.push(depth);
    for &sym in symbols {
        depth = match sym.seft() {
            Seft::A => depth + 1,
            Seft::B => depth,
            Seft::C => match depth.checked_sub(1) {
                Some(reduced) => reduced,
                None => return false,
            },
        };
        depths.push(depth);
    }
    if depth < n {
        return false;
    }

    // An operand that ends at index `end` and rests on `base` values starts at
    // the last atom pushed from depth `base`; the depth never drops back to
    // `base` inside the operand.
    let operand_start = |end: usize, base: usize| {
        (0..end)
            .rev()
            .find(|&i| depths[i] == base && depths[i + 1] > base)
    };

    let Some(top_start) = operand_start(symbols.len(), depth - 1) else {
        return false;
    };
    let Some(below_start) = operand_start(top_start, depth - 2) else {
        return false;
    };

    symbols[below_start..top_start] == symbols[top_start..]
}
/// Reports whether `sym` is an operand symbol (`Seft::A`) other than the variable `x`.
fn is_constant(sym: Symbol) -> bool {
    if sym == Symbol::X {
        return false;
    }
    matches!(sym.seft(), Seft::A)
}
/// Parallel counterpart of `generate_all`: builds the evaluation context from
/// the user constants and functions in `config` and delegates to
/// `generate_all_parallel_with_context`.
#[cfg(feature = "parallel")]
pub fn generate_all_parallel(config: &GenConfig, target: f64) -> GeneratedExprs {
    let eval_context = EvalContext::from_slices(&config.user_constants, &config.user_functions);
    generate_all_parallel_with_context(config, target, &eval_context)
}
/// Generates and deduplicates all expressions for `target`, spreading the
/// search over rayon's thread pool.
///
/// Configurations with RHS-specific overrides fall back to the sequential
/// `generate_all_with_context`. Otherwise every one-symbol expression is
/// evaluated directly, every admissible two-symbol prefix (operand + operand,
/// or operand + unary operator) becomes an independent work item explored with
/// `generate_recursive`, and the combined results are deduplicated: RHS by
/// quantized value, LHS by (quantized value, quantized derivative), keeping
/// the first-seen expression of lowest complexity per key.
#[cfg(feature = "parallel")]
pub fn generate_all_parallel_with_context(
    config: &GenConfig,
    target: f64,
    eval_context: &EvalContext<'_>,
) -> GeneratedExprs {
    use rayon::prelude::*;

    /// Keeps, per key, the first-seen expression of lowest complexity.
    fn keep_simplest<K, F>(candidates: Vec<EvaluatedExpr>, key_of: F) -> Vec<EvaluatedExpr>
    where
        K: std::hash::Hash + Eq,
        F: Fn(&EvaluatedExpr) -> K,
    {
        use std::collections::hash_map::Entry;

        let mut simplest: HashMap<K, EvaluatedExpr> = HashMap::new();
        for candidate in candidates {
            match simplest.entry(key_of(&candidate)) {
                Entry::Occupied(mut slot) => {
                    if candidate.expr.complexity() < slot.get().expr.complexity() {
                        slot.insert(candidate);
                    }
                }
                Entry::Vacant(slot) => {
                    slot.insert(candidate);
                }
            }
        }
        simplest.into_values().collect()
    }

    if has_rhs_symbol_overrides(config) {
        return generate_all_with_context(config, target, eval_context);
    }

    // Operand symbols available at any position: the configured constants,
    // plus `x` when LHS generation is on and `x` is not already listed.
    let mut atoms: Vec<Symbol> = config.constants.clone();
    if config.generate_lhs && !atoms.contains(&Symbol::X) {
        atoms.push(Symbol::X);
    }

    let mut prefixes: Vec<(Expression, usize)> = Vec::new();
    let mut lhs_raw = Vec::new();
    let mut rhs_raw = Vec::new();

    for &sym1 in &atoms {
        // A cap of zero forbids the symbol altogether.
        if config
            .symbol_max_counts
            .get(&sym1)
            .is_some_and(|&max| max == 0)
        {
            continue;
        }

        let mut expr1 = Expression::new();
        expr1.push_with_table(sym1, &config.symbol_table);

        let max_complexity = get_max_complexity(config, expr1.contains_x());
        if expr1.complexity() > max_complexity {
            continue;
        }

        // The single symbol is itself a complete expression.
        if let Ok(result) = evaluate_fast_with_context(&expr1, target, eval_context) {
            let is_lhs = expr1.contains_x();
            if should_include_expression(&result, config, expr1.complexity(), is_lhs) {
                let evaluated = EvaluatedExpr::new(
                    expr1.clone(),
                    result.value,
                    result.derivative,
                    result.num_type,
                );
                if is_lhs {
                    lhs_raw.push(evaluated);
                } else {
                    rhs_raw.push(evaluated);
                }
            }
        }

        if expr1.len() >= config.max_length {
            continue;
        }

        // Prefixes "operand operand": two values on the stack.
        for &sym2 in &atoms {
            let next_max = get_max_complexity(config, expr1.contains_x() || sym2 == Symbol::X);
            let fits = expr1.complexity() + config.symbol_table.weight(sym2) <= next_max;
            if !fits || exceeds_symbol_limit(config, &expr1, sym2) {
                continue;
            }
            let mut expr2 = expr1.clone();
            expr2.push_with_table(sym2, &config.symbol_table);
            if expr2.complexity() + min_complexity_to_complete(2, config) <= next_max {
                prefixes.push((expr2, 2));
            }
        }

        // Prefixes "operand unary-op": one value on the stack.
        for &sym2 in &config.unary_ops {
            let fits = expr1.complexity() + config.symbol_table.weight(sym2) <= max_complexity;
            if !fits
                || exceeds_symbol_limit(config, &expr1, sym2)
                || should_prune_unary(&expr1, sym2)
            {
                continue;
            }
            let mut expr2 = expr1.clone();
            expr2.push_with_table(sym2, &config.symbol_table);
            if expr2.complexity() + min_complexity_to_complete(1, config) <= max_complexity {
                prefixes.push((expr2, 1));
            }
        }
    }

    // Explore every prefix independently.
    let per_prefix: Vec<(Vec<EvaluatedExpr>, Vec<EvaluatedExpr>)> = prefixes
        .into_par_iter()
        .map(|(mut prefix, depth)| {
            let mut lhs = Vec::new();
            let mut rhs = Vec::new();
            generate_recursive(
                config,
                target,
                *eval_context,
                &mut prefix,
                depth,
                &mut lhs,
                &mut rhs,
            );
            (lhs, rhs)
        })
        .collect();

    for (lhs, rhs) in per_prefix {
        lhs_raw.extend(lhs);
        rhs_raw.extend(rhs);
    }

    let rhs = keep_simplest(rhs_raw, |e| quantize_value(e.value));
    let lhs = keep_simplest(lhs_raw, |e| -> LhsKey {
        (quantize_value(e.value), quantize_value(e.derivative))
    });
    GeneratedExprs { lhs, rhs }
}
// Unit tests for the generator entry points and for `expression_respects_constraints`.
#[cfg(test)]
mod tests {
use super::*;
// Small configuration (few symbols, low limits) so that exhaustive generation stays fast.
fn fast_test_config() -> GenConfig {
GenConfig {
max_lhs_complexity: 20,
max_rhs_complexity: 20,
max_length: 8,
constants: vec![
Symbol::One,
Symbol::Two,
Symbol::Three,
Symbol::Four,
Symbol::Five,
Symbol::Pi,
Symbol::E,
],
unary_ops: vec![Symbol::Neg, Symbol::Recip, Symbol::Square, Symbol::Sqrt],
binary_ops: vec![Symbol::Add, Symbol::Sub, Symbol::Mul, Symbol::Div],
rhs_constants: None,
rhs_unary_ops: None,
rhs_binary_ops: None,
symbol_max_counts: HashMap::new(),
rhs_symbol_max_counts: None,
min_num_type: NumType::Transcendental,
generate_lhs: true,
generate_rhs: true,
user_constants: Vec::new(),
user_functions: Vec::new(),
show_pruned_arith: false,
symbol_table: Arc::new(SymbolTable::new()),
}
}
// With LHS generation off, RHS results exist and none of them contains x.
#[test]
fn test_generate_simple() {
let mut config = fast_test_config();
config.generate_lhs = false;
let result = generate_all(&config, 1.0);
assert!(!result.rhs.is_empty());
for expr in &result.rhs {
assert!(!expr.expr.contains_x());
}
}
// With RHS generation off, LHS results exist and each of them contains x.
#[test]
fn test_generate_lhs() {
let mut config = fast_test_config();
config.generate_rhs = false;
let result = generate_all(&config, 2.0);
assert!(!result.lhs.is_empty());
for expr in &result.lhs {
assert!(expr.expr.contains_x());
}
}
// Every result respects the complexity ceiling of its side.
#[test]
fn test_complexity_limit() {
let config = fast_test_config();
let result = generate_all(&config, 1.0);
for expr in &result.rhs {
assert!(expr.expr.complexity() <= config.max_rhs_complexity);
}
for expr in &result.lhs {
assert!(expr.expr.complexity() <= config.max_lhs_complexity);
}
}
// A limit of half the deduplicated total must be exceeded, so `None` is expected.
#[test]
fn test_generate_all_with_limit_aborts_when_exceeded() {
let mut config = fast_test_config();
config.max_lhs_complexity = 30;
config.max_rhs_complexity = 30;
let unlimited = generate_all(&config, 2.5);
let total_unlimited = unlimited.lhs.len() + unlimited.rhs.len();
assert!(
total_unlimited > 10,
"Test config should generate >10 expressions"
);
let limit = total_unlimited / 2; let result = generate_all_with_limit(&config, 2.5, limit);
assert!(
result.is_none(),
"generate_all_with_limit should return None when limit ({}) is exceeded (actual: {})",
limit,
total_unlimited
);
}
// A limit of 10_000 is expected to be enough for this configuration, so `Some` is expected.
#[test]
fn test_generate_all_with_limit_succeeds_when_within_limit() {
let mut config = fast_test_config();
config.max_lhs_complexity = 30;
config.max_rhs_complexity = 30;
let result = generate_all_with_limit(&config, 2.5, 10_000);
assert!(
result.is_some(),
"generate_all_with_limit should return Some when limit is not exceeded"
);
let generated = result.unwrap();
assert!(!generated.lhs.is_empty() || !generated.rhs.is_empty());
}
// Parses a postfix string into an `Expression`, panicking on invalid input.
fn expr_from_postfix(s: &str) -> Expression {
Expression::parse(s).expect("valid expression")
}
// Default options impose no restriction.
#[test]
fn test_constraints_default_allows_all() {
let opts = ExpressionConstraintOptions::default();
let expr = expr_from_postfix("xp^"); assert!(
expression_respects_constraints(&expr, opts),
"x^pi should be allowed with default options"
);
let expr = expr_from_postfix("eS"); assert!(
expression_respects_constraints(&expr, opts),
"sinpi(e) should be allowed with default options"
);
}
// `rational_exponents` rejects transcendental exponents.
#[test]
fn test_constraints_rational_exponents_rejects_transcendental() {
let opts = ExpressionConstraintOptions {
rational_exponents: true,
..Default::default()
};
let expr = expr_from_postfix("xp^");
assert!(
!expression_respects_constraints(&expr, opts),
"x^pi should be rejected with rational_exponents=true"
);
let expr = expr_from_postfix("xe^");
assert!(
!expression_respects_constraints(&expr, opts),
"x^e should be rejected with rational_exponents=true"
);
}
// `rational_exponents` still accepts integer exponents.
#[test]
fn test_constraints_rational_exponents_allows_integer() {
let opts = ExpressionConstraintOptions {
rational_exponents: true,
..Default::default()
};
let expr = expr_from_postfix("x2^");
assert!(
expression_respects_constraints(&expr, opts),
"x^2 should be allowed with rational_exponents=true"
);
let expr = expr_from_postfix("x1^");
assert!(
expression_respects_constraints(&expr, opts),
"x^1 should be allowed with rational_exponents=true"
);
}
// `rational_trig_args` rejects irrational constant arguments.
#[test]
fn test_constraints_rational_trig_args_rejects_irrational() {
let opts = ExpressionConstraintOptions {
rational_trig_args: true,
..Default::default()
};
let expr = expr_from_postfix("eS"); assert!(
!expression_respects_constraints(&expr, opts),
"sinpi(e) should be rejected with rational_trig_args=true"
);
let expr = expr_from_postfix("pS"); assert!(
!expression_respects_constraints(&expr, opts),
"sinpi(pi) should be rejected with rational_trig_args=true"
);
}
// `rational_trig_args` accepts integer arguments.
#[test]
fn test_constraints_rational_trig_args_allows_rational() {
let opts = ExpressionConstraintOptions {
rational_trig_args: true,
..Default::default()
};
let expr = expr_from_postfix("1S"); assert!(
expression_respects_constraints(&expr, opts),
"sinpi(1) should be allowed with rational_trig_args=true"
);
let expr = expr_from_postfix("2S");
assert!(
expression_respects_constraints(&expr, opts),
"sinpi(2) should be allowed with rational_trig_args=true"
);
}
// `rational_trig_args` rejects arguments that depend on x.
#[test]
fn test_constraints_rational_trig_args_rejects_x() {
let opts = ExpressionConstraintOptions {
rational_trig_args: true,
..Default::default()
};
let expr = expr_from_postfix("xS"); assert!(
!expression_respects_constraints(&expr, opts),
"sinpi(x) should be rejected with rational_trig_args=true"
);
}
// `max_trig_cycles` bounds the number of trig symbols in an expression.
#[test]
fn test_constraints_max_trig_cycles() {
let opts = ExpressionConstraintOptions {
max_trig_cycles: Some(2),
..Default::default()
};
let expr = expr_from_postfix("xS"); assert!(
expression_respects_constraints(&expr, opts),
"1 trig op should pass with max=2"
);
let expr = expr_from_postfix("xCS");
assert!(
expression_respects_constraints(&expr, opts),
"2 trig ops should pass with max=2"
);
let expr = expr_from_postfix("xTCS");
assert!(
!expression_respects_constraints(&expr, opts),
"3 trig ops should fail with max=2"
);
}
// `max_trig_cycles: None` places no bound on trig symbols.
#[test]
fn test_constraints_max_trig_cycles_none_unlimited() {
let opts = ExpressionConstraintOptions {
max_trig_cycles: None, ..Default::default()
};
let expr = expr_from_postfix("xTCSTCS");
assert!(
expression_respects_constraints(&expr, opts),
"Unlimited trig should pass any depth"
);
}
// All three restrictions enabled at once; each is violated in turn.
#[test]
fn test_constraints_combined() {
let opts = ExpressionConstraintOptions {
rational_exponents: true,
rational_trig_args: true,
max_trig_cycles: Some(1),
..Default::default()
};
let expr = expr_from_postfix("x2^1S+"); assert!(
expression_respects_constraints(&expr, opts),
"x^2 + sinpi(1) should pass all constraints"
);
let expr = expr_from_postfix("xp^");
assert!(
!expression_respects_constraints(&expr, opts),
"x^pi should fail rational_exponents"
);
let expr = expr_from_postfix("xS"); assert!(
!expression_respects_constraints(&expr, opts),
"sinpi(x) should fail rational_trig_args"
);
let expr = expr_from_postfix("1CS"); assert!(
!expression_respects_constraints(&expr, opts),
"double trig should fail max_trig_cycles=1"
);
}
// Malformed input: an operator without operands, and two values left on the stack.
#[test]
fn test_constraints_malformed_expression() {
let opts = ExpressionConstraintOptions::default();
let expr = Expression::from_symbols(&[crate::symbol::Symbol::Add]); assert!(
!expression_respects_constraints(&expr, opts),
"Malformed expression should return false"
);
let expr =
Expression::from_symbols(&[crate::symbol::Symbol::One, crate::symbol::Symbol::Two]);
assert!(
!expression_respects_constraints(&expr, opts),
"Incomplete expression should return false"
);
}
// Only checks that a custom user-constant type table is stored in the options.
#[test]
fn test_constraints_user_constant_types() {
let mut user_types = [NumType::Transcendental; 16];
user_types[0] = NumType::Integer;
let opts = ExpressionConstraintOptions {
rational_exponents: true,
user_constant_types: user_types,
..Default::default()
};
assert_eq!(opts.user_constant_types[0], NumType::Integer);
}
}