Skip to main content

ries_rs/
gen.rs

1//! Expression generation
2//!
3//! Generates valid postfix expressions by enumerating "forms" (stack effect patterns).
4//!
5//! # Streaming Architecture
6//!
7//! For high complexity levels, the traditional approach of generating ALL expressions
8//! into memory before matching can cause memory exhaustion. This module provides both:
9//!
10//! - **Batch generation**: `generate_all()` returns all expressions (backward compatible)
11//! - **Streaming generation**: `generate_streaming()` processes expressions via callbacks
12//!
13//! Streaming reduces memory from O(expressions) to O(depth) by processing expressions
14//! as they're generated rather than accumulating them.
15
16use crate::eval::{evaluate_fast_with_context, EvalContext};
17use crate::symbol_table::SymbolTable;
18use std::sync::Arc;
19
// =============================================================================
// NAMED CONSTANTS FOR QUANTIZATION AND VALUE LIMITS
// =============================================================================

/// Scale factor for quantizing floating-point values to integers.
///
/// Values are quantized as `(v * QUANTIZE_SCALE).round() as i64`, i.e. to a
/// fixed *absolute* resolution of `1 / QUANTIZE_SCALE` = 1e-8 (eight decimal
/// places). This is not a relative / significant-digit precision: every value
/// with magnitude below 5e-9 quantizes to 0, while large values keep many more
/// than eight significant digits.
const QUANTIZE_SCALE: f64 = 1e8;

/// Largest magnitude that is quantized by scaling.
///
/// `MAX_QUANTIZED_VALUE * QUANTIZE_SCALE` = 1e18, which is below `i64::MAX`
/// (about 9.22e18), so a scaled value always fits in an `i64`. Values beyond
/// this threshold are mapped to sentinels instead (`i64::MAX - 1` for positive,
/// `i64::MIN + 1` for negative) so the scaling can never overflow.
const MAX_QUANTIZED_VALUE: f64 = 1e10;

/// Largest absolute value a generated expression may have.
///
/// Expressions whose value exceeds this in magnitude are considered
/// overflow-prone and unlikely to be useful, so they are filtered out during
/// generation.
const MAX_GENERATED_VALUE: f64 = 1e12;
43use crate::expr::{EvaluatedExpr, Expression, MAX_EXPR_LEN};
44use crate::profile::UserConstant;
45use crate::symbol::{NumType, Seft, Symbol};
46use crate::udf::UserFunction;
47use std::collections::HashMap;
48
49/// Configuration for expression generation
50///
51/// Controls which symbols are available, complexity limits,
52/// and various generation options for creating candidate expressions
53/// that may solve a given equation.
54///
55/// # Architecture
56///
57/// Expressions are generated in two categories:
58/// - **LHS (Left-Hand Side)**: Expressions containing `x`, representing functions like `f(x)`
59/// - **RHS (Right-Hand Side)**: Constant expressions not containing `x`, like `π²` or `sqrt(2)`
60///
61/// The generator creates all valid expressions up to the configured complexity limits,
62/// then the solver finds pairs where `LHS(target) ≈ RHS`.
63///
64/// # Example
65///
66/// ```rust
67/// use ries_rs::gen::GenConfig;
68/// use ries_rs::symbol::Symbol;
69/// use std::collections::HashMap;
70///
71/// let config = GenConfig {
72///     max_lhs_complexity: 50,
73///     max_rhs_complexity: 30,
74///     max_length: 12,
75///     constants: vec![Symbol::One, Symbol::Two, Symbol::Pi, Symbol::E],
76///     unary_ops: vec![Symbol::Neg, Symbol::Sqrt, Symbol::Square],
77///     binary_ops: vec![Symbol::Add, Symbol::Sub, Symbol::Mul, Symbol::Div],
78///     ..GenConfig::default()
79/// };
80/// ```
81#[derive(Clone)]
82pub struct GenConfig {
83    /// Maximum complexity score for left-hand-side expressions.
84    ///
85    /// LHS expressions contain `x` and represent the function side of equations.
86    /// Higher values allow more complex expressions (e.g., `sin(x) + x²`), but
87    /// exponentially increase search time and memory usage.
88    ///
89    /// Default: 128 (allows fairly complex expressions)
90    pub max_lhs_complexity: u32,
91
92    /// Maximum complexity score for right-hand-side expressions.
93    ///
94    /// RHS expressions are constants not containing `x`. Since they don't need
95    /// to be solved for, they can typically use lower complexity limits than LHS.
96    ///
97    /// Default: 128
98    pub max_rhs_complexity: u32,
99
100    /// Maximum number of symbols in a single expression.
101    ///
102    /// This is a hard limit on expression length regardless of complexity score.
103    /// Prevents pathological cases with many low-complexity symbols.
104    ///
105    /// Default: `MAX_EXPR_LEN` (255)
106    pub max_length: usize,
107
108    /// Symbols available for constants and variables (Seft::A type).
109    ///
110    /// These push a value onto the expression stack. Typically includes:
111    /// - `One`, `Two`, `Three`, etc. (numeric constants)
112    /// - `Pi`, `E` (mathematical constants)
113    /// - `X` (the variable to solve for)
114    ///
115    /// Default: All built-in constants from `Symbol::constants()`
116    pub constants: Vec<Symbol>,
117
118    /// Symbols available for unary operations (Seft::B type).
119    ///
120    /// These transform a single value: `f(a)`. Includes operations like:
121    /// - `Neg` (negation: `-a`)
122    /// - `Sqrt`, `Square` (powers and roots)
123    /// - `SinPi`, `CosPi` (trigonometric functions)
124    /// - `Ln`, `Exp` (logarithmic and exponential)
125    /// - `Recip` (reciprocal: `1/a`)
126    ///
127    /// Default: All built-in unary operators from `Symbol::unary_ops()`
128    pub unary_ops: Vec<Symbol>,
129
130    /// Symbols available for binary operations (Seft::C type).
131    ///
132    /// These combine two values: `f(a, b)`. Includes operations like:
133    /// - `Add`, `Sub`, `Mul`, `Div` (arithmetic)
134    /// - `Pow`, `Root`, `Log` (power functions and logarithms)
135    ///
136    /// Default: All built-in binary operators from `Symbol::binary_ops()`
137    pub binary_ops: Vec<Symbol>,
138
139    /// Optional override for RHS-only constant symbols.
140    ///
141    /// When set, RHS expressions use these symbols instead of `constants`.
142    /// Useful for generating LHS with more symbols but keeping RHS simple.
143    ///
144    /// Default: `None` (use `constants` for both LHS and RHS)
145    pub rhs_constants: Option<Vec<Symbol>>,
146
147    /// Optional override for RHS-only unary operators.
148    ///
149    /// When set, RHS expressions use these operators instead of `unary_ops`.
150    /// Example: allow Lambert W in LHS only, exclude from RHS constants.
151    ///
152    /// Default: `None` (use `unary_ops` for both LHS and RHS)
153    pub rhs_unary_ops: Option<Vec<Symbol>>,
154
155    /// Optional override for RHS-only binary operators.
156    ///
157    /// When set, RHS expressions use these operators instead of `binary_ops`.
158    ///
159    /// Default: `None` (use `binary_ops` for both LHS and RHS)
160    pub rhs_binary_ops: Option<Vec<Symbol>>,
161
162    /// Maximum usage count per symbol within a single expression.
163    ///
164    /// Maps each symbol to the maximum number of times it can appear.
165    /// Useful for limiting redundancy (e.g., max 2 uses of `Pi`).
166    /// Corresponds to the `-O` command-line option.
167    ///
168    /// Default: Empty (no limits)
169    pub symbol_max_counts: HashMap<Symbol, u32>,
170
171    /// Optional RHS-only symbol count limits.
172    ///
173    /// When set, applies different symbol count limits to RHS expressions.
174    /// Corresponds to the `--O-RHS` command-line option.
175    ///
176    /// Default: `None` (use `symbol_max_counts` for both)
177    pub rhs_symbol_max_counts: Option<HashMap<Symbol, u32>>,
178
179    /// Minimum numeric type required for generated expressions.
180    ///
181    /// Filters expressions by the "sophistication" of numbers they produce:
182    /// - `Integer`: Only integer results
183    /// - `Rational`: Rational numbers (fractions)
184    /// - `Algebraic`: Algebraic numbers (roots of polynomials)
185    /// - `Transcendental`: Any real number (including π, e, trig)
186    ///
187    /// Lower values restrict output to simpler mathematical constructs.
188    ///
189    /// Default: `NumType::Transcendental` (accept all)
190    pub min_num_type: NumType,
191
192    /// Whether to generate LHS expressions containing `x`.
193    ///
194    /// Set to `false` if you only need constant RHS expressions.
195    /// Can significantly reduce generation time when LHS is not needed.
196    ///
197    /// Default: `true`
198    pub generate_lhs: bool,
199
200    /// Whether to generate RHS constant expressions.
201    ///
202    /// Set to `false` if you only need LHS expressions.
203    /// Useful for specific analysis tasks.
204    ///
205    /// Default: `true`
206    pub generate_rhs: bool,
207
208    /// User-defined constants for custom searches.
209    ///
210    /// These constants are available during expression evaluation,
211    /// allowing searches involving domain-specific values.
212    /// Defined via `-N` command-line option.
213    ///
214    /// Default: Empty
215    pub user_constants: Vec<UserConstant>,
216
217    /// User-defined functions for custom searches.
218    ///
219    /// Custom functions that can appear in generated expressions,
220    /// extending the available operations beyond built-in symbols.
221    /// Defined via `-F` command-line option.
222    ///
223    /// Default: Empty
224    pub user_functions: Vec<UserFunction>,
225
226    /// Enable diagnostic output for arithmetic pruning.
227    ///
228    /// When `true`, prints information about expressions that were
229    /// discarded due to arithmetic errors (overflow, domain errors, etc.).
230    /// Useful for debugging generation behavior.
231    ///
232    /// Default: `false`
233    pub show_pruned_arith: bool,
234
235    /// Symbol table with weights and display names.
236    ///
237    /// Provides complexity weights for each symbol and custom display
238    /// names. Weights control how "expensive" each symbol is toward
239    /// the complexity limit.
240    ///
241    /// Default: Empty table (uses built-in default weights)
242    pub symbol_table: Arc<SymbolTable>,
243}
244
245/// Options for additional expression constraints
246///
247/// These constraints allow filtering expressions based on their numeric properties
248/// or structural limits (like trig cycles or exponent types).
249#[derive(Debug, Clone, Copy)]
250pub struct ExpressionConstraintOptions {
251    /// If true, power exponents must be rational (no transcendental exponents like x^pi)
252    pub rational_exponents: bool,
253    /// If true, trigonometric function arguments must be rational
254    pub rational_trig_args: bool,
255    /// Maximum number of trigonometric operations allowed in an expression
256    pub max_trig_cycles: Option<u32>,
257    /// Inherited numeric types for user-defined constants 0-15
258    pub user_constant_types: [NumType; 16],
259    /// Inherited numeric types for user-defined functions 0-15
260    pub user_function_types: [NumType; 16],
261}
262
263impl Default for ExpressionConstraintOptions {
264    fn default() -> Self {
265        Self {
266            rational_exponents: false,
267            rational_trig_args: false,
268            max_trig_cycles: None,
269            user_constant_types: [NumType::Transcendental; 16],
270            user_function_types: [NumType::Transcendental; 16],
271        }
272    }
273}
274
275/// Check if an expression respects the configured structural and numeric constraints.
276///
277/// This performs a symbolic walkthrough of the expression to verify that it
278/// matches the requested properties (e.g., no transcendental exponents).
279pub fn expression_respects_constraints(
280    expression: &Expression,
281    opts: ExpressionConstraintOptions,
282) -> bool {
283    #[derive(Clone, Copy)]
284    struct ConstraintValue {
285        has_x: bool,
286        num_type: NumType,
287    }
288
289    let mut stack: Vec<ConstraintValue> = Vec::with_capacity(expression.len());
290    let mut trig_ops: u32 = 0;
291
292    for &sym in expression.symbols() {
293        match sym.seft() {
294            Seft::A => {
295                let num_type = if let Some(idx) = sym.user_constant_index() {
296                    opts.user_constant_types[idx as usize]
297                } else {
298                    sym.inherent_type()
299                };
300                stack.push(ConstraintValue {
301                    has_x: sym == Symbol::X,
302                    num_type,
303                });
304            }
305            Seft::B => {
306                let Some(arg) = stack.pop() else {
307                    return false;
308                };
309
310                if matches!(sym, Symbol::SinPi | Symbol::CosPi | Symbol::TanPi) {
311                    trig_ops = trig_ops.saturating_add(1);
312                    if opts.rational_trig_args && (arg.has_x || arg.num_type < NumType::Rational) {
313                        return false;
314                    }
315                }
316
317                let num_type = match sym {
318                    Symbol::Neg | Symbol::Square => arg.num_type,
319                    Symbol::Recip => {
320                        if arg.num_type >= NumType::Rational {
321                            NumType::Rational
322                        } else {
323                            arg.num_type
324                        }
325                    }
326                    Symbol::Sqrt => {
327                        if arg.num_type >= NumType::Rational {
328                            NumType::Algebraic
329                        } else {
330                            arg.num_type
331                        }
332                    }
333                    Symbol::UserFunction0
334                    | Symbol::UserFunction1
335                    | Symbol::UserFunction2
336                    | Symbol::UserFunction3
337                    | Symbol::UserFunction4
338                    | Symbol::UserFunction5
339                    | Symbol::UserFunction6
340                    | Symbol::UserFunction7
341                    | Symbol::UserFunction8
342                    | Symbol::UserFunction9
343                    | Symbol::UserFunction10
344                    | Symbol::UserFunction11
345                    | Symbol::UserFunction12
346                    | Symbol::UserFunction13
347                    | Symbol::UserFunction14
348                    | Symbol::UserFunction15 => {
349                        let idx = sym.user_function_index().unwrap_or(0) as usize;
350                        opts.user_function_types[idx]
351                    }
352                    _ => NumType::Transcendental,
353                };
354
355                stack.push(ConstraintValue {
356                    has_x: arg.has_x,
357                    num_type,
358                });
359            }
360            Seft::C => {
361                let Some(rhs) = stack.pop() else {
362                    return false;
363                };
364                let Some(lhs) = stack.pop() else {
365                    return false;
366                };
367
368                if opts.rational_exponents
369                    && sym == Symbol::Pow
370                    && (rhs.has_x || rhs.num_type < NumType::Rational)
371                {
372                    return false;
373                }
374
375                let num_type = match sym {
376                    Symbol::Add | Symbol::Sub | Symbol::Mul => lhs.num_type.combine(rhs.num_type),
377                    Symbol::Div => {
378                        let combined = lhs.num_type.combine(rhs.num_type);
379                        if combined == NumType::Integer {
380                            NumType::Rational
381                        } else {
382                            combined
383                        }
384                    }
385                    Symbol::Pow => {
386                        if rhs.has_x {
387                            NumType::Transcendental
388                        } else if rhs.num_type == NumType::Integer {
389                            lhs.num_type
390                        } else if lhs.num_type >= NumType::Rational
391                            && rhs.num_type >= NumType::Rational
392                        {
393                            NumType::Algebraic
394                        } else {
395                            NumType::Transcendental
396                        }
397                    }
398                    Symbol::Root => NumType::Algebraic,
399                    Symbol::Log | Symbol::Atan2 => NumType::Transcendental,
400                    _ => NumType::Transcendental,
401                };
402
403                stack.push(ConstraintValue {
404                    has_x: lhs.has_x || rhs.has_x,
405                    num_type,
406                });
407            }
408        }
409    }
410
411    if stack.len() != 1 {
412        return false;
413    }
414
415    opts.max_trig_cycles
416        .is_none_or(|max_cycles| trig_ops <= max_cycles)
417}
418
419impl Default for GenConfig {
420    fn default() -> Self {
421        Self {
422            max_lhs_complexity: 128,
423            max_rhs_complexity: 128,
424            max_length: MAX_EXPR_LEN,
425            constants: Symbol::constants().to_vec(),
426            unary_ops: Symbol::unary_ops().to_vec(),
427            binary_ops: Symbol::binary_ops().to_vec(),
428            rhs_constants: None,
429            rhs_unary_ops: None,
430            rhs_binary_ops: None,
431            symbol_max_counts: HashMap::new(),
432            rhs_symbol_max_counts: None,
433            min_num_type: NumType::Transcendental,
434            generate_lhs: true,
435            generate_rhs: true,
436            user_constants: Vec::new(),
437            user_functions: Vec::new(),
438            show_pruned_arith: false,
439            symbol_table: Arc::new(SymbolTable::new()),
440        }
441    }
442}
443
444/// Result of expression generation
445pub struct GeneratedExprs {
446    /// LHS expressions (contain x)
447    pub lhs: Vec<EvaluatedExpr>,
448    /// RHS expressions (constants only)
449    pub rhs: Vec<EvaluatedExpr>,
450}
451
452/// Callbacks for streaming expression generation
453///
454/// Using callbacks instead of accumulation allows processing expressions
455/// as they're generated, reducing memory from O(expressions) to O(depth).
456pub struct StreamingCallbacks<'a> {
457    /// Called for each RHS (constant-only) expression generated
458    /// Return false to stop generation early
459    pub on_rhs: &'a mut dyn FnMut(&EvaluatedExpr) -> bool,
460    /// Called for each LHS (contains x) expression generated
461    /// Return false to stop generation early
462    pub on_lhs: &'a mut dyn FnMut(&EvaluatedExpr) -> bool,
463}
464
465/// Quantize a value to reduce floating-point noise
466/// Key for LHS deduplication: (quantized value, quantized derivative)
467pub type LhsKey = (i64, i64);
468
469/// Uses ~8 significant digits for deduplication
470#[inline]
471pub fn quantize_value(v: f64) -> i64 {
472    if !v.is_finite() || v.abs() > MAX_QUANTIZED_VALUE {
473        // For very large values, use a different quantization to avoid overflow
474        if v > MAX_QUANTIZED_VALUE {
475            return i64::MAX - 1;
476        } else if v < -MAX_QUANTIZED_VALUE {
477            return i64::MIN + 1;
478        }
479        return i64::MAX;
480    }
481    // Scale to preserve ~8 significant digits (avoid overflow)
482    (v * QUANTIZE_SCALE).round() as i64
483}
484
485/// Generate all valid expressions up to the configured limits
486pub fn generate_all(config: &GenConfig, target: f64) -> GeneratedExprs {
487    generate_all_with_context(
488        config,
489        target,
490        &EvalContext::from_slices(&config.user_constants, &config.user_functions),
491    )
492}
493
494/// Generate all valid expressions up to the configured limits using an explicit evaluation context.
495pub fn generate_all_with_context(
496    config: &GenConfig,
497    target: f64,
498    eval_context: &EvalContext<'_>,
499) -> GeneratedExprs {
500    let mut lhs_raw = Vec::new();
501    let mut rhs_raw = Vec::new();
502
503    if config.generate_lhs && config.generate_rhs && has_rhs_symbol_overrides(config) {
504        // LHS pass with base symbol set.
505        let mut lhs_config = config.clone();
506        lhs_config.generate_lhs = true;
507        lhs_config.generate_rhs = false;
508        generate_recursive(
509            &lhs_config,
510            target,
511            *eval_context,
512            &mut Expression::new(),
513            0,
514            &mut lhs_raw,
515            &mut rhs_raw,
516        );
517
518        // RHS pass with RHS-specific symbol overrides.
519        let rhs_config = rhs_only_config(config);
520        generate_recursive(
521            &rhs_config,
522            target,
523            *eval_context,
524            &mut Expression::new(),
525            0,
526            &mut lhs_raw,
527            &mut rhs_raw,
528        );
529    } else {
530        // Generate expressions for each possible "form" (sequence of stack effects)
531        generate_recursive(
532            config,
533            target,
534            *eval_context,
535            &mut Expression::new(),
536            0, // current stack depth
537            &mut lhs_raw,
538            &mut rhs_raw,
539        );
540    }
541
542    // Deduplicate RHS by value, keeping simplest expression for each value
543    let mut rhs_map: HashMap<i64, EvaluatedExpr> = HashMap::new();
544    for expr in rhs_raw {
545        let key = quantize_value(expr.value);
546        rhs_map
547            .entry(key)
548            .and_modify(|existing| {
549                if expr.expr.complexity() < existing.expr.complexity() {
550                    *existing = expr.clone();
551                }
552            })
553            .or_insert(expr);
554    }
555
556    // Deduplicate LHS by (value, derivative), keeping simplest expression
557    let mut lhs_map: HashMap<LhsKey, EvaluatedExpr> = HashMap::new();
558    for expr in lhs_raw {
559        let key = (quantize_value(expr.value), quantize_value(expr.derivative));
560        lhs_map
561            .entry(key)
562            .and_modify(|existing| {
563                if expr.expr.complexity() < existing.expr.complexity() {
564                    *existing = expr.clone();
565                }
566            })
567            .or_insert(expr);
568    }
569
570    GeneratedExprs {
571        lhs: lhs_map.into_values().collect(),
572        rhs: rhs_map.into_values().collect(),
573    }
574}
575
576/// Generate expressions with an early-abort limit on total count.
577///
578/// Returns `Some(expressions)` if generation completed within the limit,
579/// or `None` if the limit was exceeded (caller should use streaming mode instead).
580///
581/// This is a safety mechanism to prevent OOM from unexpectedly large generation
582/// at high complexity levels. The limit check happens during generation, not after.
583///
584/// # Arguments
585///
586/// * `config` - Generation configuration (complexity limits, symbols)
587/// * `target` - Target value for evaluation
588/// * `max_expressions` - Maximum total expressions (LHS + RHS) before aborting
589///
590/// # Returns
591///
592/// * `Some(GeneratedExprs)` - if generation completed within limit
593/// * `None` - if the limit was exceeded during generation
594pub fn generate_all_with_limit(
595    config: &GenConfig,
596    target: f64,
597    max_expressions: usize,
598) -> Option<GeneratedExprs> {
599    generate_all_with_limit_and_context(
600        config,
601        target,
602        &EvalContext::from_slices(&config.user_constants, &config.user_functions),
603        max_expressions,
604    )
605}
606
607/// Generate expressions with an early-abort limit using an explicit evaluation context.
608pub fn generate_all_with_limit_and_context(
609    config: &GenConfig,
610    target: f64,
611    eval_context: &EvalContext<'_>,
612    max_expressions: usize,
613) -> Option<GeneratedExprs> {
614    use std::sync::atomic::{AtomicUsize, Ordering};
615    use std::sync::Arc;
616
617    let count = Arc::new(AtomicUsize::new(0));
618    let limit = max_expressions;
619
620    // Collect expressions if within limit
621    let mut lhs_raw = Vec::new();
622    let mut rhs_raw = Vec::new();
623
624    // Callback that counts expressions and stops when limit is hit
625    let mut callbacks = StreamingCallbacks {
626        on_lhs: &mut |expr| {
627            let current = count.fetch_add(1, Ordering::Relaxed) + 1;
628            if current > limit {
629                return false; // Abort generation
630            }
631            lhs_raw.push(expr.clone());
632            true
633        },
634        on_rhs: &mut |expr| {
635            let current = count.fetch_add(1, Ordering::Relaxed) + 1;
636            if current > limit {
637                return false; // Abort generation
638            }
639            rhs_raw.push(expr.clone());
640            true
641        },
642    };
643
644    generate_streaming_with_context(config, target, eval_context, &mut callbacks);
645
646    // Check if we exceeded the limit
647    let final_count = count.load(Ordering::Relaxed);
648    if final_count > limit {
649        return None;
650    }
651
652    // Deduplicate (same logic as generate_all)
653    let mut rhs_map: HashMap<i64, EvaluatedExpr> = HashMap::new();
654    for expr in rhs_raw {
655        let key = quantize_value(expr.value);
656        rhs_map
657            .entry(key)
658            .and_modify(|existing| {
659                if expr.expr.complexity() < existing.expr.complexity() {
660                    *existing = expr.clone();
661                }
662            })
663            .or_insert(expr);
664    }
665
666    let mut lhs_map: HashMap<LhsKey, EvaluatedExpr> = HashMap::new();
667    for expr in lhs_raw {
668        let key = (quantize_value(expr.value), quantize_value(expr.derivative));
669        lhs_map
670            .entry(key)
671            .and_modify(|existing| {
672                if expr.expr.complexity() < existing.expr.complexity() {
673                    *existing = expr.clone();
674                }
675            })
676            .or_insert(expr);
677    }
678
679    Some(GeneratedExprs {
680        lhs: lhs_map.into_values().collect(),
681        rhs: rhs_map.into_values().collect(),
682    })
683}
684
685/// Generate expressions with streaming callbacks for memory-efficient processing
686///
687/// This function is the foundation of the streaming architecture. Instead of
688/// accumulating all expressions in memory, it calls the provided callbacks
689/// for each generated expression, allowing immediate processing.
690///
691/// # Memory Efficiency
692///
693/// - Traditional: O(expressions) memory - all expressions stored before processing
694/// - Streaming: O(depth) memory - only the recursion stack is stored
695///
696/// # Early Exit
697///
698/// The callbacks can return `false` to signal early termination. This is useful
699/// when good matches have been found and additional expressions aren't needed.
700///
701/// # Deduplication
702///
703/// The caller is responsible for deduplication if needed. This allows flexibility
704/// in deduplication strategies (e.g., per-batch, per-tier, etc.).
705///
706/// # Example
707///
708/// ```no_run
709/// use ries_rs::gen::{GenConfig, StreamingCallbacks, generate_streaming};
710/// let config = GenConfig::default();
711/// let target = 2.5_f64;
712/// let mut rhs_count = 0;
713/// let mut lhs_count = 0;
714/// let mut callbacks = StreamingCallbacks {
715///     on_rhs: &mut |_expr| {
716///         rhs_count += 1;
717///         true // continue generation
718///     },
719///     on_lhs: &mut |_expr| {
720///         lhs_count += 1;
721///         true // continue generation
722///     },
723/// };
724/// generate_streaming(&config, target, &mut callbacks);
725/// ```
726pub fn generate_streaming(config: &GenConfig, target: f64, callbacks: &mut StreamingCallbacks) {
727    generate_streaming_with_context(
728        config,
729        target,
730        &EvalContext::from_slices(&config.user_constants, &config.user_functions),
731        callbacks,
732    );
733}
734
735/// Generate expressions with streaming callbacks using an explicit evaluation context.
736pub fn generate_streaming_with_context(
737    config: &GenConfig,
738    target: f64,
739    eval_context: &EvalContext<'_>,
740    callbacks: &mut StreamingCallbacks,
741) {
742    if config.generate_lhs && config.generate_rhs && has_rhs_symbol_overrides(config) {
743        let mut lhs_config = config.clone();
744        lhs_config.generate_lhs = true;
745        lhs_config.generate_rhs = false;
746        if !generate_recursive_streaming(
747            &lhs_config,
748            target,
749            *eval_context,
750            &mut Expression::new(),
751            0,
752            callbacks,
753        ) {
754            return;
755        }
756
757        let rhs_config = rhs_only_config(config);
758        generate_recursive_streaming(
759            &rhs_config,
760            target,
761            *eval_context,
762            &mut Expression::new(),
763            0,
764            callbacks,
765        );
766    } else {
767        generate_recursive_streaming(
768            config,
769            target,
770            *eval_context,
771            &mut Expression::new(),
772            0, // current stack depth
773            callbacks,
774        );
775    }
776}
777
778#[inline]
779fn has_rhs_symbol_overrides(config: &GenConfig) -> bool {
780    config.rhs_constants.is_some()
781        || config.rhs_unary_ops.is_some()
782        || config.rhs_binary_ops.is_some()
783        || config.rhs_symbol_max_counts.is_some()
784}
785
786/// Check if an evaluated expression meets generation criteria
787///
788/// This shared helper function is used by both batch and streaming generation
789/// to validate expressions before including them in results.
790#[inline]
791fn should_include_expression(
792    result: &crate::eval::EvalResult,
793    config: &GenConfig,
794    complexity: u32,
795    contains_x: bool,
796) -> bool {
797    result.value.is_finite()
798        && result.value.abs() <= MAX_GENERATED_VALUE
799        && result.num_type >= config.min_num_type
800        && if contains_x {
801            config.generate_lhs && complexity <= config.max_lhs_complexity
802        } else {
803            config.generate_rhs && complexity <= config.max_rhs_complexity
804        }
805}
806
807/// Calculate the appropriate complexity limit based on whether expression contains x
808///
809/// For expressions containing x, uses LHS limit.
810/// For RHS-only paths, uses RHS limit.
811/// For paths that might still add x, uses the max of both limits.
812#[inline]
813fn get_max_complexity(config: &GenConfig, contains_x: bool) -> u32 {
814    if contains_x {
815        config.max_lhs_complexity
816    } else {
817        // For RHS-only paths, use RHS limit
818        // For paths that might still add x, use the max of both
819        std::cmp::max(config.max_lhs_complexity, config.max_rhs_complexity)
820    }
821}
822
823fn rhs_only_config(config: &GenConfig) -> GenConfig {
824    let mut rhs_config = config.clone();
825    rhs_config.generate_lhs = false;
826    rhs_config.generate_rhs = true;
827    if let Some(constants) = &config.rhs_constants {
828        rhs_config.constants = constants.clone();
829    }
830    if let Some(unary_ops) = &config.rhs_unary_ops {
831        rhs_config.unary_ops = unary_ops.clone();
832    }
833    if let Some(binary_ops) = &config.rhs_binary_ops {
834        rhs_config.binary_ops = binary_ops.clone();
835    }
836    if let Some(rhs_symbol_max_counts) = &config.rhs_symbol_max_counts {
837        rhs_config.symbol_max_counts = rhs_symbol_max_counts.clone();
838    }
839    rhs_config
840}
841
842#[inline]
843fn exceeds_symbol_limit(config: &GenConfig, current: &Expression, sym: Symbol) -> bool {
844    config
845        .symbol_max_counts
846        .get(&sym)
847        .is_some_and(|&max| current.count_symbol(sym) >= max)
848}
849
850/// Recursively generate expressions with streaming callbacks
851///
852/// This is the core streaming generation function. It mirrors `generate_recursive`
853/// but calls callbacks instead of accumulating expressions.
854fn generate_recursive_streaming(
855    config: &GenConfig,
856    target: f64,
857    eval_context: EvalContext<'_>,
858    current: &mut Expression,
859    stack_depth: usize,
860    callbacks: &mut StreamingCallbacks,
861) -> bool {
862    // Check if we have a complete expression
863    if stack_depth == 1 && !current.is_empty() {
864        // Try to evaluate it with user constants and functions support
865        match evaluate_fast_with_context(current, target, &eval_context) {
866            Ok(result) => {
867                // Use shared validation helper
868                if should_include_expression(
869                    &result,
870                    config,
871                    current.complexity(),
872                    current.contains_x(),
873                ) {
874                    let expr = current.clone();
875                    let eval_expr =
876                        EvaluatedExpr::new(expr, result.value, result.derivative, result.num_type);
877
878                    // Call the appropriate callback; return false if it signals stop
879                    let should_continue = if current.contains_x() {
880                        (callbacks.on_lhs)(&eval_expr)
881                    } else {
882                        (callbacks.on_rhs)(&eval_expr)
883                    };
884                    if !should_continue {
885                        return false;
886                    }
887                }
888            }
889            Err(e) => {
890                // Expression was pruned due to arithmetic error
891                if config.show_pruned_arith {
892                    eprintln!(
893                        "  [pruned arith] expression=\"{}\" reason={:?}",
894                        current.to_postfix(),
895                        e
896                    );
897                }
898            }
899        }
900    }
901
902    // Check limits before recursing
903    if current.len() >= config.max_length {
904        return true;
905    }
906
907    // Use shared helper for complexity limit calculation
908    let max_complexity = get_max_complexity(config, current.contains_x());
909
910    if current.complexity() >= max_complexity {
911        return true;
912    }
913
914    // Calculate minimum additional complexity needed to complete expression
915    let min_remaining = min_complexity_to_complete(stack_depth, config);
916    if current.complexity() + min_remaining > max_complexity {
917        return true;
918    }
919
920    // Try adding each possible symbol
921
922    // Constants (Seft::A) - always increase stack by 1
923    for &sym in &config.constants {
924        let sym_weight = config.symbol_table.weight(sym);
925        if current.complexity() + sym_weight > max_complexity {
926            continue;
927        }
928        if exceeds_symbol_limit(config, current, sym) {
929            continue;
930        }
931
932        // Skip x if we only want RHS
933        if sym == Symbol::X && !config.generate_lhs {
934            continue;
935        }
936
937        current.push_with_table(sym, &config.symbol_table);
938        if !generate_recursive_streaming(
939            config,
940            target,
941            eval_context,
942            current,
943            stack_depth + 1,
944            callbacks,
945        ) {
946            current.pop_with_table(&config.symbol_table);
947            return false;
948        }
949        current.pop_with_table(&config.symbol_table);
950    }
951
952    // Also add x for LHS generation
953    if config.generate_lhs && !config.constants.contains(&Symbol::X) {
954        let sym = Symbol::X;
955        let sym_weight = config.symbol_table.weight(sym);
956        if current.complexity() + sym_weight <= max_complexity
957            && !exceeds_symbol_limit(config, current, sym)
958        {
959            current.push_with_table(sym, &config.symbol_table);
960            if !generate_recursive_streaming(
961                config,
962                target,
963                eval_context,
964                current,
965                stack_depth + 1,
966                callbacks,
967            ) {
968                current.pop_with_table(&config.symbol_table);
969                return false;
970            }
971            current.pop_with_table(&config.symbol_table);
972        }
973    }
974
975    // Unary operators (Seft::B) - need at least 1 on stack
976    if stack_depth >= 1 {
977        for &sym in &config.unary_ops {
978            let sym_weight = config.symbol_table.weight(sym);
979            if current.complexity() + sym_weight > max_complexity {
980                continue;
981            }
982            if exceeds_symbol_limit(config, current, sym) {
983                continue;
984            }
985
986            // Apply pruning rules
987            if should_prune_unary(current, sym) {
988                continue;
989            }
990
991            current.push_with_table(sym, &config.symbol_table);
992            if !generate_recursive_streaming(
993                config,
994                target,
995                eval_context,
996                current,
997                stack_depth,
998                callbacks,
999            ) {
1000                current.pop_with_table(&config.symbol_table);
1001                return false;
1002            }
1003            current.pop_with_table(&config.symbol_table);
1004        }
1005    }
1006
1007    // Binary operators (Seft::C) - need at least 2 on stack
1008    if stack_depth >= 2 {
1009        for &sym in &config.binary_ops {
1010            let sym_weight = config.symbol_table.weight(sym);
1011            if current.complexity() + sym_weight > max_complexity {
1012                continue;
1013            }
1014            if exceeds_symbol_limit(config, current, sym) {
1015                continue;
1016            }
1017
1018            // Apply pruning rules
1019            if should_prune_binary(current, sym) {
1020                continue;
1021            }
1022
1023            current.push_with_table(sym, &config.symbol_table);
1024            if !generate_recursive_streaming(
1025                config,
1026                target,
1027                eval_context,
1028                current,
1029                stack_depth - 1,
1030                callbacks,
1031            ) {
1032                current.pop_with_table(&config.symbol_table);
1033                return false;
1034            }
1035            current.pop_with_table(&config.symbol_table);
1036        }
1037    }
1038
1039    true
1040}
1041
1042/// Recursively generate expressions
1043fn generate_recursive(
1044    config: &GenConfig,
1045    target: f64,
1046    eval_context: EvalContext<'_>,
1047    current: &mut Expression,
1048    stack_depth: usize,
1049    lhs_out: &mut Vec<EvaluatedExpr>,
1050    rhs_out: &mut Vec<EvaluatedExpr>,
1051) {
1052    // Check if we have a complete expression
1053    if stack_depth == 1 && !current.is_empty() {
1054        // Try to evaluate it with user constants and functions support
1055        match evaluate_fast_with_context(current, target, &eval_context) {
1056            Ok(result) => {
1057                // Use shared validation helper
1058                if should_include_expression(
1059                    &result,
1060                    config,
1061                    current.complexity(),
1062                    current.contains_x(),
1063                ) {
1064                    let expr = current.clone();
1065                    let eval_expr =
1066                        EvaluatedExpr::new(expr, result.value, result.derivative, result.num_type);
1067
1068                    // Keep all LHS expressions; derivative≈0 cases handled in search
1069                    if current.contains_x() {
1070                        lhs_out.push(eval_expr);
1071                    } else {
1072                        rhs_out.push(eval_expr);
1073                    }
1074                }
1075            }
1076            Err(e) => {
1077                // Expression was pruned due to arithmetic error
1078                if config.show_pruned_arith {
1079                    eprintln!(
1080                        "  [pruned arith] expression=\"{}\" reason={:?}",
1081                        current.to_postfix(),
1082                        e
1083                    );
1084                }
1085            }
1086        }
1087    }
1088
1089    // Check limits before recursing
1090    if current.len() >= config.max_length {
1091        return;
1092    }
1093
1094    // Use shared helper for complexity limit calculation
1095    let max_complexity = get_max_complexity(config, current.contains_x());
1096
1097    if current.complexity() >= max_complexity {
1098        return;
1099    }
1100
1101    // Calculate minimum additional complexity needed to complete expression
1102    let min_remaining = min_complexity_to_complete(stack_depth, config);
1103    if current.complexity() + min_remaining > max_complexity {
1104        return;
1105    }
1106
1107    // Try adding each possible symbol
1108
1109    // Constants (Seft::A) - always increase stack by 1
1110    for &sym in &config.constants {
1111        let sym_weight = config.symbol_table.weight(sym);
1112        if current.complexity() + sym_weight > max_complexity {
1113            continue;
1114        }
1115        if exceeds_symbol_limit(config, current, sym) {
1116            continue;
1117        }
1118
1119        // Skip x if we only want RHS
1120        if sym == Symbol::X && !config.generate_lhs {
1121            continue;
1122        }
1123
1124        current.push_with_table(sym, &config.symbol_table);
1125        generate_recursive(
1126            config,
1127            target,
1128            eval_context,
1129            current,
1130            stack_depth + 1,
1131            lhs_out,
1132            rhs_out,
1133        );
1134        current.pop_with_table(&config.symbol_table);
1135    }
1136
1137    // Also add x for LHS generation
1138    if config.generate_lhs && !config.constants.contains(&Symbol::X) {
1139        let sym = Symbol::X;
1140        let sym_weight = config.symbol_table.weight(sym);
1141        if current.complexity() + sym_weight <= max_complexity
1142            && !exceeds_symbol_limit(config, current, sym)
1143        {
1144            current.push_with_table(sym, &config.symbol_table);
1145            generate_recursive(
1146                config,
1147                target,
1148                eval_context,
1149                current,
1150                stack_depth + 1,
1151                lhs_out,
1152                rhs_out,
1153            );
1154            current.pop_with_table(&config.symbol_table);
1155        }
1156    }
1157
1158    // Unary operators (Seft::B) - need at least 1 on stack
1159    if stack_depth >= 1 {
1160        for &sym in &config.unary_ops {
1161            let sym_weight = config.symbol_table.weight(sym);
1162            if current.complexity() + sym_weight > max_complexity {
1163                continue;
1164            }
1165            if exceeds_symbol_limit(config, current, sym) {
1166                continue;
1167            }
1168
1169            // Apply pruning rules
1170            if should_prune_unary(current, sym) {
1171                continue;
1172            }
1173
1174            current.push_with_table(sym, &config.symbol_table);
1175            generate_recursive(
1176                config,
1177                target,
1178                eval_context,
1179                current,
1180                stack_depth,
1181                lhs_out,
1182                rhs_out,
1183            );
1184            current.pop_with_table(&config.symbol_table);
1185        }
1186    }
1187
1188    // Binary operators (Seft::C) - need at least 2 on stack
1189    if stack_depth >= 2 {
1190        for &sym in &config.binary_ops {
1191            let sym_weight = config.symbol_table.weight(sym);
1192            if current.complexity() + sym_weight > max_complexity {
1193                continue;
1194            }
1195            if exceeds_symbol_limit(config, current, sym) {
1196                continue;
1197            }
1198
1199            // Apply pruning rules
1200            if should_prune_binary(current, sym) {
1201                continue;
1202            }
1203
1204            current.push_with_table(sym, &config.symbol_table);
1205            generate_recursive(
1206                config,
1207                target,
1208                eval_context,
1209                current,
1210                stack_depth - 1,
1211                lhs_out,
1212                rhs_out,
1213            );
1214            current.pop_with_table(&config.symbol_table);
1215        }
1216    }
1217}
1218
1219/// Calculate minimum complexity needed to reduce stack to depth 1
1220fn min_complexity_to_complete(stack_depth: usize, config: &GenConfig) -> u32 {
1221    if stack_depth <= 1 {
1222        return 0;
1223    }
1224
1225    // Need (stack_depth - 1) binary operators to reduce to 1
1226    let min_binary_weight = config
1227        .binary_ops
1228        .iter()
1229        .map(|s| config.symbol_table.weight(*s))
1230        .min()
1231        .unwrap_or(4);
1232
1233    ((stack_depth - 1) as u32) * min_binary_weight
1234}
1235
1236/// Pruning rules for unary operators to avoid redundant expressions
1237fn should_prune_unary(expr: &Expression, sym: Symbol) -> bool {
1238    let symbols = expr.symbols();
1239    if symbols.is_empty() {
1240        return false;
1241    }
1242
1243    let last = symbols[symbols.len() - 1];
1244
1245    use Symbol::*;
1246
1247    match (last, sym) {
1248        // Double negation: --a = a
1249        (Neg, Neg) => true,
1250        // Double reciprocal: 1/(1/a) = a
1251        (Recip, Recip) => true,
1252        // sqrt(a^2) = |a| (we don't handle absolute value)
1253        (Square, Sqrt) => true,
1254        // (sqrt(a))^2 = a
1255        (Sqrt, Square) => true,
1256        // ln(e^a) = a
1257        (Exp, Ln) => true,
1258        // e^(ln(a)) = a
1259        (Ln, Exp) => true,
1260
1261        // Additional pruning rules for cleaner output:
1262        // 1/sqrt(a) and 1/a^2 are rare, prefer a^-0.5 or a^-2 notation
1263        (Sqrt, Recip) => true,
1264        (Square, Recip) => true,
1265        // 1/ln(a) is rarely useful
1266        (Ln, Recip) => true,
1267        // Double square: (a^2)^2 = a^4, use power directly
1268        (Square, Square) => true,
1269        // Double sqrt: sqrt(sqrt(a)) = a^0.25, use power directly
1270        (Sqrt, Sqrt) => true,
1271        // Negation after subtraction is redundant with addition
1272        // e.g., -(a-b) = b-a which we could express directly
1273        (Sub, Neg) => true,
1274
1275        // ===== ENHANCED PRUNING RULES =====
1276        // Trig reduction: asin(sin(pi*x)/pi) = x, similar for acos
1277        // These are rarely useful and add many redundant expressions
1278        (SinPi, SinPi) => true,
1279        (CosPi, CosPi) => true,
1280        // asin after sinpi is identity (mod periodicity)
1281        // acos after cospi is identity (mod periodicity)
1282        // These patterns are captured by double application above
1283
1284        // Exp grows too fast - double exp is almost never useful
1285        (Exp, Exp) => true,
1286
1287        // LambertW after exp: W(e^a) = a, so W(e^x) = x
1288        (Exp, LambertW) => true,
1289
1290        // LambertW on small values often doesn't converge usefully
1291        // W of reciprocal is rarely needed
1292        (Recip, LambertW) => true,
1293
1294        _ => false,
1295    }
1296}
1297
1298/// Pruning rules for binary operators
1299fn should_prune_binary(expr: &Expression, sym: Symbol) -> bool {
1300    let symbols = expr.symbols();
1301    if symbols.len() < 2 {
1302        return false;
1303    }
1304
1305    let last = symbols[symbols.len() - 1];
1306    let prev = symbols[symbols.len() - 2];
1307
1308    use Symbol::*;
1309
1310    match sym {
1311        // a - a = 0 (if both operands are identical)
1312        Sub if is_same_subexpr(symbols, 2) => true,
1313        // x - x = 0 (trivial - always 0)
1314        Sub if last == X && prev == X => true,
1315
1316        // a / a = 1 (degenerate if a contains x)
1317        Div if is_same_subexpr(symbols, 2) => true,
1318        // x / x = 1 (trivial identity)
1319        Div if last == X && prev == X => true,
1320        // Division by 1: a/1 = a (useless)
1321        Div if last == One => true,
1322
1323        // Prefer a*2 over a+a
1324        Add if is_same_subexpr(symbols, 2) => true,
1325        // x + (-x) = 0 - check for negated x
1326        Add if last == Neg
1327            && symbols.len() >= 3
1328            && symbols[symbols.len() - 2] == X
1329            && prev == X =>
1330        {
1331            true
1332        }
1333
1334        // 1^b = 1 (degenerate - always equals 1 regardless of b)
1335        // This catches 1^x, 1^(anything)
1336        Pow if prev == One => true,
1337        // a^1 = a (useless)
1338        Pow if last == One => true,
1339
1340        // x * 1 = x, 1 * x = x
1341        Mul if last == One || prev == One => true,
1342
1343        // a"/1 = a^(1/1) = a (1st root is identity)
1344        // But more importantly: 1"/x = 1^(1/x) = 1 (degenerate)
1345        Root if prev == One => true,
1346        // x"/1 means 1^(1/x) = 1 (degenerate)
1347        Root if last == One => true,
1348        // 2nd root is just sqrt, prefer using sqrt
1349        Root if last == Two => true,
1350
1351        // log_x(x) = 1 (trivial identity)
1352        Log if last == X && prev == X => true,
1353        // log_1(anything) is undefined/infinite, log_a(1) = 0
1354        Log if prev == One || last == One => true,
1355        // log_e(a) = ln(a) - prefer ln notation
1356        Log if prev == E => true,
1357
1358        // Ordering: prefer 2+3 over 3+2 for commutative ops
1359        Add | Mul if prev > last && is_constant(last) && is_constant(prev) => true,
1360
1361        _ => false,
1362    }
1363}
1364
1365/// Check if the last n stack items are identical subexpressions
1366///
1367/// This uses a stack-based approach to identify subexpression boundaries.
1368/// For postfix notation, we track the stack depth to find where each
1369/// subexpression starts.
1370fn is_same_subexpr(symbols: &[Symbol], n: usize) -> bool {
1371    if symbols.len() < n * 2 || n < 2 {
1372        return false;
1373    }
1374
1375    // Find the boundaries of the last n subexpressions on the stack
1376    // We need to trace backwards through the postfix to find where each
1377    // complete subexpression starts
1378
1379    let mut stack_depths: Vec<usize> = Vec::with_capacity(symbols.len() + 1);
1380    stack_depths.push(0); // Initial depth
1381
1382    for &sym in symbols {
1383        let prev_depth = *stack_depths.last().unwrap();
1384        let new_depth = match sym.seft() {
1385            Seft::A => prev_depth + 1,
1386            Seft::B => prev_depth,     // pop 1, push 1
1387            Seft::C => prev_depth - 1, // pop 2, push 1
1388        };
1389        stack_depths.push(new_depth);
1390    }
1391
1392    let final_depth = *stack_depths.last().unwrap();
1393    if final_depth < n {
1394        return false;
1395    }
1396
1397    // Find where each of the last n subexpressions starts
1398    let mut subexpr_starts: Vec<usize> = Vec::with_capacity(n);
1399    let mut target_depth = final_depth;
1400
1401    for i in (0..symbols.len()).rev() {
1402        if stack_depths[i] == target_depth && stack_depths[i + 1] > target_depth {
1403            subexpr_starts.push(i);
1404            target_depth -= 1;
1405            if subexpr_starts.len() == n {
1406                break;
1407            }
1408        }
1409    }
1410
1411    if subexpr_starts.len() != n {
1412        return false;
1413    }
1414
1415    // Check if all n subexpressions are identical
1416    // For simplicity with n=2, compare the two subexpressions
1417    if n == 2 && subexpr_starts.len() == 2 {
1418        let start1 = subexpr_starts[1]; // Earlier subexpression
1419        let start2 = subexpr_starts[0]; // Later subexpression
1420        let end1 = start2; // End of first is start of second
1421        let end2 = symbols.len(); // End of second is end of expression
1422
1423        // Compare the symbol slices
1424        if end1 - start1 == end2 - start2 {
1425            return symbols[start1..end1] == symbols[start2..end2];
1426        }
1427    }
1428
1429    false
1430}
1431
1432/// Check if a symbol is a constant (no x)
1433fn is_constant(sym: Symbol) -> bool {
1434    matches!(sym.seft(), Seft::A) && sym != Symbol::X
1435}
1436
/// Generate expressions in parallel using Rayon.
///
/// Convenience wrapper around `generate_all_parallel_with_context`. It builds
/// the evaluation context from the config's user constants and user functions.
#[cfg(feature = "parallel")]
pub fn generate_all_parallel(config: &GenConfig, target: f64) -> GeneratedExprs {
    let context = EvalContext::from_slices(&config.user_constants, &config.user_functions);
    generate_all_parallel_with_context(config, target, &context)
}
1446
/// Generate expressions in parallel using Rayon with an explicit evaluation context.
///
/// The search tree is split at prefixes of length 1 and 2. Each valid
/// length-2 prefix (two leaves, or a leaf plus a unary operator) becomes an
/// independent Rayon task running `generate_recursive`. Length-1 prefixes are
/// complete expressions that no task starts from, so they are evaluated here.
/// Results are deduplicated as follows, keeping the lowest-complexity
/// expression per key (earliest wins ties):
/// - RHS by quantized value;
/// - LHS by quantized (value, derivative).
///
/// Falls back to `generate_all_with_context` when RHS-specific symbol
/// overrides are configured, because the prefix split assumes LHS and RHS
/// share one symbol set.
#[cfg(feature = "parallel")]
pub fn generate_all_parallel_with_context(
    config: &GenConfig,
    target: f64,
    eval_context: &EvalContext<'_>,
) -> GeneratedExprs {
    use rayon::prelude::*;

    /// Deduplicate `exprs` by `key_of`, keeping the lowest-complexity
    /// expression for each key (the earliest one wins ties).
    fn keep_simplest<K: std::hash::Hash + Eq>(
        exprs: Vec<EvaluatedExpr>,
        key_of: impl Fn(&EvaluatedExpr) -> K,
    ) -> Vec<EvaluatedExpr> {
        let mut best: HashMap<K, EvaluatedExpr> = HashMap::new();
        for expr in exprs {
            let key = key_of(&expr);
            match best.get_mut(&key) {
                Some(existing) => {
                    if expr.expr.complexity() < existing.expr.complexity() {
                        *existing = expr;
                    }
                }
                None => {
                    best.insert(key, expr);
                }
            }
        }
        best.into_values().collect()
    }

    // Parallel path currently assumes shared LHS/RHS symbol sets.
    if has_rhs_symbol_overrides(config) {
        return generate_all_with_context(config, target, eval_context);
    }

    // Leaf symbols in the order `generate_recursive` tries them: the configured
    // constants, then x if LHS generation is on and x is not already a
    // constant. As in `generate_recursive`, x is not used when LHS generation
    // is off (it could never be emitted).
    let mut leaves: Vec<Symbol> = config
        .constants
        .iter()
        .copied()
        .filter(|&sym| sym != Symbol::X || config.generate_lhs)
        .collect();
    if config.generate_lhs && !config.constants.contains(&Symbol::X) {
        leaves.push(Symbol::X);
    }

    // Loop-invariant lower bounds on the complexity needed to finish a prefix.
    let finish_from_one = min_complexity_to_complete(1, config);
    let finish_from_two = min_complexity_to_complete(2, config);

    let mut prefixes: Vec<(Expression, usize)> = Vec::new();
    let mut lhs_raw: Vec<EvaluatedExpr> = Vec::new();
    let mut rhs_raw: Vec<EvaluatedExpr> = Vec::new();

    for &sym1 in &leaves {
        let mut expr1 = Expression::new();
        if exceeds_symbol_limit(config, &expr1, sym1) {
            continue;
        }
        expr1.push_with_table(sym1, &config.symbol_table);

        let has_x = expr1.contains_x();
        let max_complexity = get_max_complexity(config, has_x);
        if expr1.complexity() > max_complexity {
            continue;
        }

        // 1. Evaluate the complete length-1 expression (top of generate_recursive).
        match evaluate_fast_with_context(&expr1, target, eval_context) {
            Ok(result) => {
                if should_include_expression(&result, config, expr1.complexity(), has_x) {
                    let found = EvaluatedExpr::new(
                        expr1.clone(),
                        result.value,
                        result.derivative,
                        result.num_type,
                    );
                    if has_x {
                        lhs_raw.push(found);
                    } else {
                        rhs_raw.push(found);
                    }
                }
            }
            Err(e) => {
                // Expression was pruned due to arithmetic error.
                if config.show_pruned_arith {
                    eprintln!(
                        "  [pruned arith] expression=\"{}\" reason={:?}",
                        expr1.to_postfix(),
                        e
                    );
                }
            }
        }

        if expr1.len() >= config.max_length {
            continue;
        }

        // 2. Extend by one symbol (bottom of generate_recursive).

        // A second leaf leaves two values on the stack.
        for &sym2 in &leaves {
            let bound = get_max_complexity(config, has_x || sym2 == Symbol::X);
            if expr1.complexity() + config.symbol_table.weight(sym2) > bound
                || exceeds_symbol_limit(config, &expr1, sym2)
            {
                continue;
            }
            let mut expr2 = expr1.clone();
            expr2.push_with_table(sym2, &config.symbol_table);
            // At stack depth 2, at least one binary operator is still needed.
            if expr2.complexity() + finish_from_two <= bound {
                prefixes.push((expr2, 2));
            }
        }

        // A unary operator keeps one value on the stack.
        for &sym2 in &config.unary_ops {
            if expr1.complexity() + config.symbol_table.weight(sym2) > max_complexity
                || exceeds_symbol_limit(config, &expr1, sym2)
                || should_prune_unary(&expr1, sym2)
            {
                continue;
            }
            let mut expr2 = expr1.clone();
            expr2.push_with_table(sym2, &config.symbol_table);
            if expr2.complexity() + finish_from_one <= max_complexity {
                prefixes.push((expr2, 1));
            }
        }
    }

    let task_results: Vec<(Vec<EvaluatedExpr>, Vec<EvaluatedExpr>)> = prefixes
        .into_par_iter()
        .map(|(mut expr, depth)| {
            let mut lhs = Vec::new();
            let mut rhs = Vec::new();
            generate_recursive(
                config,
                target,
                *eval_context,
                &mut expr,
                depth,
                &mut lhs,
                &mut rhs,
            );
            (lhs, rhs)
        })
        .collect();

    // Merge task results (in prefix order) after the length-1 results.
    for (lhs, rhs) in task_results {
        lhs_raw.extend(lhs);
        rhs_raw.extend(rhs);
    }

    GeneratedExprs {
        lhs: keep_simplest(lhs_raw, |e| -> LhsKey {
            (quantize_value(e.value), quantize_value(e.derivative))
        }),
        rhs: keep_simplest(rhs_raw, |e| quantize_value(e.value)),
    }
}
1632
1633#[cfg(test)]
1634mod tests {
1635    use super::*;
1636
1637    /// Create a fast test config with limited complexity and operators
1638    fn fast_test_config() -> GenConfig {
1639        GenConfig {
1640            max_lhs_complexity: 20,
1641            max_rhs_complexity: 20,
1642            max_length: 8,
1643            constants: vec![
1644                Symbol::One,
1645                Symbol::Two,
1646                Symbol::Three,
1647                Symbol::Four,
1648                Symbol::Five,
1649                Symbol::Pi,
1650                Symbol::E,
1651            ],
1652            unary_ops: vec![Symbol::Neg, Symbol::Recip, Symbol::Square, Symbol::Sqrt],
1653            binary_ops: vec![Symbol::Add, Symbol::Sub, Symbol::Mul, Symbol::Div],
1654            rhs_constants: None,
1655            rhs_unary_ops: None,
1656            rhs_binary_ops: None,
1657            symbol_max_counts: HashMap::new(),
1658            rhs_symbol_max_counts: None,
1659            min_num_type: NumType::Transcendental,
1660            generate_lhs: true,
1661            generate_rhs: true,
1662            user_constants: Vec::new(),
1663            user_functions: Vec::new(),
1664            show_pruned_arith: false,
1665            symbol_table: Arc::new(SymbolTable::new()),
1666        }
1667    }
1668
1669    #[test]
1670    fn test_generate_simple() {
1671        let mut config = fast_test_config();
1672        config.generate_lhs = false; // Only RHS for simpler test
1673
1674        let result = generate_all(&config, 1.0);
1675
1676        // Should have some RHS expressions
1677        assert!(!result.rhs.is_empty());
1678
1679        // All should be valid (evaluate without error)
1680        for expr in &result.rhs {
1681            assert!(!expr.expr.contains_x());
1682        }
1683    }
1684
1685    #[test]
1686    fn test_generate_lhs() {
1687        let mut config = fast_test_config();
1688        config.generate_rhs = false;
1689
1690        let result = generate_all(&config, 2.0);
1691
1692        // Should have LHS expressions containing x
1693        assert!(!result.lhs.is_empty());
1694        for expr in &result.lhs {
1695            assert!(expr.expr.contains_x());
1696        }
1697    }
1698
1699    #[test]
1700    fn test_complexity_limit() {
1701        let config = fast_test_config();
1702
1703        let result = generate_all(&config, 1.0);
1704
1705        for expr in &result.rhs {
1706            assert!(expr.expr.complexity() <= config.max_rhs_complexity);
1707        }
1708        for expr in &result.lhs {
1709            assert!(expr.expr.complexity() <= config.max_lhs_complexity);
1710        }
1711    }
1712
#[test]
fn test_generate_all_with_limit_aborts_when_exceeded() {
    // Raise both complexity ceilings so the run yields a sizeable expression set.
    let mut config = fast_test_config();
    config.max_lhs_complexity = 30;
    config.max_rhs_complexity = 30;

    // Baseline: how many expressions does an unbounded run produce?
    let baseline = generate_all(&config, 2.5);
    let baseline_count = baseline.lhs.len() + baseline.rhs.len();

    // Halving the baseline is only a meaningful cap if the baseline is non-trivial.
    assert!(
        baseline_count > 10,
        "Test config should generate >10 expressions"
    );

    // A cap at half the baseline should be exceeded, so the call must bail out.
    let cap = baseline_count / 2;
    let capped = generate_all_with_limit(&config, 2.5, cap);
    assert!(
        capped.is_none(),
        "generate_all_with_limit should return None when limit ({}) is exceeded (actual: {})",
        cap,
        baseline_count
    );
}
1742
#[test]
fn test_generate_all_with_limit_succeeds_when_within_limit() {
    // Same configuration as the abort test, but with a cap that should not be reached.
    let mut config = fast_test_config();
    config.max_lhs_complexity = 30;
    config.max_rhs_complexity = 30;

    // 10_000 is assumed to sit well above what this config generates.
    let generated = match generate_all_with_limit(&config, 2.5, 10_000) {
        Some(expressions) => expressions,
        None => panic!("generate_all_with_limit should return Some when limit is not exceeded"),
    };

    // At least one side must have produced something.
    assert!(!(generated.lhs.is_empty() && generated.rhs.is_empty()));
}
1762
1763    // ==================== expression_respects_constraints tests ====================
1764
1765    fn expr_from_postfix(s: &str) -> Expression {
1766        Expression::parse(s).expect("valid expression")
1767    }
1768
#[test]
fn test_constraints_default_allows_all() {
    let opts = ExpressionConstraintOptions::default();

    // With no constraints enabled, even transcendental exponents and
    // sinpi of a transcendental constant are accepted. S = SinPi.
    let accepted_cases = [
        ("xp^", "x^pi should be allowed with default options"),
        ("eS", "sinpi(e) should be allowed with default options"),
    ];
    for &(postfix, msg) in &accepted_cases {
        let expr = expr_from_postfix(postfix);
        assert!(expression_respects_constraints(&expr, opts), "{}", msg);
    }
}
1787
#[test]
fn test_constraints_rational_exponents_rejects_transcendental() {
    let opts = ExpressionConstraintOptions {
        rational_exponents: true,
        ..Default::default()
    };

    // pi and e are transcendental, so neither may appear as an exponent.
    let rejected_cases = [
        ("xp^", "x^pi should be rejected with rational_exponents=true"),
        ("xe^", "x^e should be rejected with rational_exponents=true"),
    ];
    for &(postfix, msg) in &rejected_cases {
        let expr = expr_from_postfix(postfix);
        assert!(!expression_respects_constraints(&expr, opts), "{}", msg);
    }
}
1809
#[test]
fn test_constraints_rational_exponents_allows_integer() {
    let opts = ExpressionConstraintOptions {
        rational_exponents: true,
        ..Default::default()
    };

    // Integer exponents are rational and therefore permitted.
    let accepted_cases = [
        ("x2^", "x^2 should be allowed with rational_exponents=true"),
        ("x1^", "x^1 should be allowed with rational_exponents=true"),
    ];
    for &(postfix, msg) in &accepted_cases {
        let expr = expr_from_postfix(postfix);
        assert!(expression_respects_constraints(&expr, opts), "{}", msg);
    }
}
1831
#[test]
fn test_constraints_rational_trig_args_rejects_irrational() {
    let opts = ExpressionConstraintOptions {
        rational_trig_args: true,
        ..Default::default()
    };

    // e and pi are transcendental, so sinpi (S) of either must be rejected.
    let rejected_cases = [
        ("eS", "sinpi(e) should be rejected with rational_trig_args=true"),
        ("pS", "sinpi(pi) should be rejected with rational_trig_args=true"),
    ];
    for &(postfix, msg) in &rejected_cases {
        let expr = expr_from_postfix(postfix);
        assert!(!expression_respects_constraints(&expr, opts), "{}", msg);
    }
}
1853
#[test]
fn test_constraints_rational_trig_args_allows_rational() {
    let opts = ExpressionConstraintOptions {
        rational_trig_args: true,
        ..Default::default()
    };

    // Integers are rational, so sinpi (S) of 1 or 2 is acceptable.
    let accepted_cases = [
        ("1S", "sinpi(1) should be allowed with rational_trig_args=true"),
        ("2S", "sinpi(2) should be allowed with rational_trig_args=true"),
    ];
    for &(postfix, msg) in &accepted_cases {
        let expr = expr_from_postfix(postfix);
        assert!(expression_respects_constraints(&expr, opts), "{}", msg);
    }
}
1875
#[test]
fn test_constraints_rational_trig_args_rejects_x() {
    let opts = ExpressionConstraintOptions {
        rational_trig_args: true,
        ..Default::default()
    };

    // x is a free variable rather than a known rational constant,
    // so sinpi(x) ("xS", S = SinPi) must not pass.
    let accepted = expression_respects_constraints(&expr_from_postfix("xS"), opts);
    assert!(
        !accepted,
        "sinpi(x) should be rejected with rational_trig_args=true"
    );
}
1890
#[test]
fn test_constraints_max_trig_cycles() {
    let opts = ExpressionConstraintOptions {
        max_trig_cycles: Some(2),
        ..Default::default()
    };

    // Postfix notation: T = TanPi, C = CosPi, S = SinPi.
    // Each entry: (expression, expected verdict, failure message).
    let cases = [
        ("xS", true, "1 trig op should pass with max=2"),   // sinpi(x)
        ("xCS", true, "2 trig ops should pass with max=2"), // sinpi(cospi(x))
        ("xTCS", false, "3 trig ops should fail with max=2"), // sinpi(cospi(tanpi(x)))
    ];
    for &(postfix, should_pass, msg) in &cases {
        let expr = expr_from_postfix(postfix);
        assert_eq!(
            expression_respects_constraints(&expr, opts),
            should_pass,
            "{}",
            msg
        );
    }
}
1921
#[test]
fn test_constraints_max_trig_cycles_none_unlimited() {
    // x T C S T C S = tanpi/cospi/sinpi applied twice: 6 nested trig ops.
    let expr = expr_from_postfix("xTCSTCS");

    let unlimited = ExpressionConstraintOptions {
        max_trig_cycles: None, // No limit
        ..Default::default()
    };
    // Even deeply nested trig should pass
    assert!(
        expression_respects_constraints(&expr, unlimited),
        "Unlimited trig should pass any depth"
    );

    // Negative control: with a cap below the expression's trig count, the
    // same expression must be rejected. Without this, the assertion above
    // would also pass if trig ops were never counted at all.
    let capped = ExpressionConstraintOptions {
        max_trig_cycles: Some(5),
        ..Default::default()
    };
    assert!(
        !expression_respects_constraints(&expr, capped),
        "6 trig ops should fail with max=5"
    );
}
1937
#[test]
fn test_constraints_combined() {
    // All three constraints active at once.
    let opts = ExpressionConstraintOptions {
        rational_exponents: true,
        rational_trig_args: true,
        max_trig_cycles: Some(1),
        ..Default::default()
    };

    // Postfix notation: C = CosPi, S = SinPi.
    // The first case satisfies everything. Each remaining case violates
    // exactly the constraint named in its message.
    let cases = [
        ("x2^1S+", true, "x^2 + sinpi(1) should pass all constraints"),
        ("xp^", false, "x^pi should fail rational_exponents"),
        ("xS", false, "sinpi(x) should fail rational_trig_args"),
        ("1CS", false, "double trig should fail max_trig_cycles=1"),
    ];
    for &(postfix, should_pass, msg) in &cases {
        let expr = expr_from_postfix(postfix);
        assert_eq!(
            expression_respects_constraints(&expr, opts),
            should_pass,
            "{}",
            msg
        );
    }
}
1975
#[test]
fn test_constraints_malformed_expression() {
    use crate::symbol::Symbol;

    let opts = ExpressionConstraintOptions::default();

    // A lone binary operator has no operands: stack underflow.
    let underflow = Expression::from_symbols(&[Symbol::Add]);
    assert!(
        !expression_respects_constraints(&underflow, opts),
        "Malformed expression should return false"
    );

    // Two operands with no operator leave more than one value on the stack.
    let leftover = Expression::from_symbols(&[Symbol::One, Symbol::Two]);
    assert!(
        !expression_respects_constraints(&leftover, opts),
        "Incomplete expression should return false"
    );
}
1995
#[test]
fn test_constraints_user_constant_types() {
    // Every user-constant slot starts out Transcendental; slot 0 is then
    // overridden to Integer.
    let user_types = {
        let mut slots = [NumType::Transcendental; 16];
        slots[0] = NumType::Integer;
        slots
    };

    let opts = ExpressionConstraintOptions {
        rational_exponents: true,
        user_constant_types: user_types,
        ..Default::default()
    };

    // NOTE(review): no expression here contains a user constant. This only
    // confirms the override lands in the options struct. It does not
    // exercise how expression_respects_constraints consumes
    // user_constant_types (e.g. whether x^UserConstant0 would be accepted).
    assert_eq!(opts.user_constant_types[0], NumType::Integer);
}
2013}