1use crate::eval::{evaluate_fast_with_context, EvalContext};
17use crate::symbol_table::SymbolTable;
18use std::sync::Arc;
19
/// Factor `quantize_value` multiplies a value by before rounding it to an
/// integer deduplication key.
const QUANTIZE_SCALE: f64 = 1e8;

/// Largest magnitude `quantize_value` will scale and round; anything larger
/// (or non-finite) is mapped to a sentinel key instead.
const MAX_QUANTIZED_VALUE: f64 = 1e10;

/// Largest absolute value an evaluated expression may have and still be kept
/// by the generator.
const MAX_GENERATED_VALUE: f64 = 1e12;
43use crate::expr::{EvaluatedExpr, Expression, MAX_EXPR_LEN};
44use crate::profile::UserConstant;
45use crate::symbol::{NumType, Seft, Symbol};
46use crate::udf::UserFunction;
47use std::collections::HashMap;
48
/// Settings that control which expressions the generator enumerates and keeps.
#[derive(Clone)]
pub struct GenConfig {
    /// Complexity ceiling for expressions that contain `x` (left-hand sides).
    pub max_lhs_complexity: u32,

    /// Complexity ceiling for expressions without `x` (right-hand sides).
    pub max_rhs_complexity: u32,

    /// An expression whose `len()` has reached this value is not extended further.
    pub max_length: usize,

    /// Leaf symbols tried at every position. `Symbol::X` is added implicitly
    /// when `generate_lhs` is set and it is not listed here.
    pub constants: Vec<Symbol>,

    /// Unary operators tried whenever at least one operand is on the stack.
    pub unary_ops: Vec<Symbol>,

    /// Binary operators tried whenever at least two operands are on the stack.
    pub binary_ops: Vec<Symbol>,

    /// When set, replaces `constants` for the dedicated RHS-only pass.
    pub rhs_constants: Option<Vec<Symbol>>,

    /// When set, replaces `unary_ops` for the dedicated RHS-only pass.
    pub rhs_unary_ops: Option<Vec<Symbol>>,

    /// When set, replaces `binary_ops` for the dedicated RHS-only pass.
    pub rhs_binary_ops: Option<Vec<Symbol>>,

    /// Per-symbol cap on how many times a symbol may appear in one expression.
    /// Symbols without an entry are uncapped.
    pub symbol_max_counts: HashMap<Symbol, u32>,

    /// When set, replaces `symbol_max_counts` for the dedicated RHS-only pass.
    pub rhs_symbol_max_counts: Option<HashMap<Symbol, u32>>,

    /// Expressions whose evaluated `num_type` compares below this are discarded.
    pub min_num_type: NumType,

    /// Whether expressions containing `x` are produced.
    pub generate_lhs: bool,

    /// Whether expressions without `x` are produced.
    pub generate_rhs: bool,

    /// User-defined constants; used to build the default evaluation context.
    pub user_constants: Vec<UserConstant>,

    /// User-defined functions; used to build the default evaluation context.
    pub user_functions: Vec<UserFunction>,

    /// When true, expressions whose evaluation fails are reported on stderr.
    pub show_pruned_arith: bool,

    /// Table consulted for symbol weights while building expressions.
    pub symbol_table: Arc<SymbolTable>,
}
244
/// Optional structural restrictions checked by `expression_respects_constraints`.
#[derive(Debug, Clone, Copy)]
pub struct ExpressionConstraintOptions {
    /// Reject `Pow` when its exponent contains `x` or ranks below `NumType::Rational`.
    pub rational_exponents: bool,
    /// Reject `SinPi`/`CosPi`/`TanPi` when the argument contains `x` or ranks
    /// below `NumType::Rational`.
    pub rational_trig_args: bool,
    /// Upper bound on the number of `SinPi`/`CosPi`/`TanPi` operators in one
    /// expression; `None` means unlimited.
    pub max_trig_cycles: Option<u32>,
    /// Number type assumed for each user constant, indexed by
    /// `Symbol::user_constant_index`.
    pub user_constant_types: [NumType; 16],
    /// Number type assumed for the result of each user function, indexed by
    /// `Symbol::user_function_index`.
    pub user_function_types: [NumType; 16],
}
262
263impl Default for ExpressionConstraintOptions {
264 fn default() -> Self {
265 Self {
266 rational_exponents: false,
267 rational_trig_args: false,
268 max_trig_cycles: None,
269 user_constant_types: [NumType::Transcendental; 16],
270 user_function_types: [NumType::Transcendental; 16],
271 }
272 }
273}
274
275pub fn expression_respects_constraints(
280 expression: &Expression,
281 opts: ExpressionConstraintOptions,
282) -> bool {
283 #[derive(Clone, Copy)]
284 struct ConstraintValue {
285 has_x: bool,
286 num_type: NumType,
287 }
288
289 let mut stack: Vec<ConstraintValue> = Vec::with_capacity(expression.len());
290 let mut trig_ops: u32 = 0;
291
292 for &sym in expression.symbols() {
293 match sym.seft() {
294 Seft::A => {
295 let num_type = if let Some(idx) = sym.user_constant_index() {
296 opts.user_constant_types[idx as usize]
297 } else {
298 sym.inherent_type()
299 };
300 stack.push(ConstraintValue {
301 has_x: sym == Symbol::X,
302 num_type,
303 });
304 }
305 Seft::B => {
306 let Some(arg) = stack.pop() else {
307 return false;
308 };
309
310 if matches!(sym, Symbol::SinPi | Symbol::CosPi | Symbol::TanPi) {
311 trig_ops = trig_ops.saturating_add(1);
312 if opts.rational_trig_args && (arg.has_x || arg.num_type < NumType::Rational) {
313 return false;
314 }
315 }
316
317 let num_type = match sym {
318 Symbol::Neg | Symbol::Square => arg.num_type,
319 Symbol::Recip => {
320 if arg.num_type >= NumType::Rational {
321 NumType::Rational
322 } else {
323 arg.num_type
324 }
325 }
326 Symbol::Sqrt => {
327 if arg.num_type >= NumType::Rational {
328 NumType::Algebraic
329 } else {
330 arg.num_type
331 }
332 }
333 Symbol::UserFunction0
334 | Symbol::UserFunction1
335 | Symbol::UserFunction2
336 | Symbol::UserFunction3
337 | Symbol::UserFunction4
338 | Symbol::UserFunction5
339 | Symbol::UserFunction6
340 | Symbol::UserFunction7
341 | Symbol::UserFunction8
342 | Symbol::UserFunction9
343 | Symbol::UserFunction10
344 | Symbol::UserFunction11
345 | Symbol::UserFunction12
346 | Symbol::UserFunction13
347 | Symbol::UserFunction14
348 | Symbol::UserFunction15 => {
349 let idx = sym.user_function_index().unwrap_or(0) as usize;
350 opts.user_function_types[idx]
351 }
352 _ => NumType::Transcendental,
353 };
354
355 stack.push(ConstraintValue {
356 has_x: arg.has_x,
357 num_type,
358 });
359 }
360 Seft::C => {
361 let Some(rhs) = stack.pop() else {
362 return false;
363 };
364 let Some(lhs) = stack.pop() else {
365 return false;
366 };
367
368 if opts.rational_exponents
369 && sym == Symbol::Pow
370 && (rhs.has_x || rhs.num_type < NumType::Rational)
371 {
372 return false;
373 }
374
375 let num_type = match sym {
376 Symbol::Add | Symbol::Sub | Symbol::Mul => lhs.num_type.combine(rhs.num_type),
377 Symbol::Div => {
378 let combined = lhs.num_type.combine(rhs.num_type);
379 if combined == NumType::Integer {
380 NumType::Rational
381 } else {
382 combined
383 }
384 }
385 Symbol::Pow => {
386 if rhs.has_x {
387 NumType::Transcendental
388 } else if rhs.num_type == NumType::Integer {
389 lhs.num_type
390 } else if lhs.num_type >= NumType::Rational
391 && rhs.num_type >= NumType::Rational
392 {
393 NumType::Algebraic
394 } else {
395 NumType::Transcendental
396 }
397 }
398 Symbol::Root => NumType::Algebraic,
399 Symbol::Log | Symbol::Atan2 => NumType::Transcendental,
400 _ => NumType::Transcendental,
401 };
402
403 stack.push(ConstraintValue {
404 has_x: lhs.has_x || rhs.has_x,
405 num_type,
406 });
407 }
408 }
409 }
410
411 if stack.len() != 1 {
412 return false;
413 }
414
415 opts.max_trig_cycles
416 .is_none_or(|max_cycles| trig_ops <= max_cycles)
417}
418
419impl Default for GenConfig {
420 fn default() -> Self {
421 Self {
422 max_lhs_complexity: 128,
423 max_rhs_complexity: 128,
424 max_length: MAX_EXPR_LEN,
425 constants: Symbol::constants().to_vec(),
426 unary_ops: Symbol::unary_ops().to_vec(),
427 binary_ops: Symbol::binary_ops().to_vec(),
428 rhs_constants: None,
429 rhs_unary_ops: None,
430 rhs_binary_ops: None,
431 symbol_max_counts: HashMap::new(),
432 rhs_symbol_max_counts: None,
433 min_num_type: NumType::Transcendental,
434 generate_lhs: true,
435 generate_rhs: true,
436 user_constants: Vec::new(),
437 user_functions: Vec::new(),
438 show_pruned_arith: false,
439 symbol_table: Arc::new(SymbolTable::new()),
440 }
441 }
442}
443
/// Result of a batch generation run.
pub struct GeneratedExprs {
    /// Kept expressions that contain `x`.
    pub lhs: Vec<EvaluatedExpr>,
    /// Kept expressions that do not contain `x`.
    pub rhs: Vec<EvaluatedExpr>,
}
451
/// Sinks used by the streaming generator. Each callback receives one accepted
/// expression; returning `false` stops generation.
pub struct StreamingCallbacks<'a> {
    /// Called for every accepted expression without `x`.
    pub on_rhs: &'a mut dyn FnMut(&EvaluatedExpr) -> bool,
    /// Called for every accepted expression containing `x`.
    pub on_lhs: &'a mut dyn FnMut(&EvaluatedExpr) -> bool,
}
464
/// Deduplication key for LHS expressions: `(quantized value, quantized derivative)`.
pub type LhsKey = (i64, i64);
468
469#[inline]
471pub fn quantize_value(v: f64) -> i64 {
472 if !v.is_finite() || v.abs() > MAX_QUANTIZED_VALUE {
473 if v > MAX_QUANTIZED_VALUE {
475 return i64::MAX - 1;
476 } else if v < -MAX_QUANTIZED_VALUE {
477 return i64::MIN + 1;
478 }
479 return i64::MAX;
480 }
481 (v * QUANTIZE_SCALE).round() as i64
483}
484
485pub fn generate_all(config: &GenConfig, target: f64) -> GeneratedExprs {
487 generate_all_with_context(
488 config,
489 target,
490 &EvalContext::from_slices(&config.user_constants, &config.user_functions),
491 )
492}
493
494pub fn generate_all_with_context(
496 config: &GenConfig,
497 target: f64,
498 eval_context: &EvalContext<'_>,
499) -> GeneratedExprs {
500 let mut lhs_raw = Vec::new();
501 let mut rhs_raw = Vec::new();
502
503 if config.generate_lhs && config.generate_rhs && has_rhs_symbol_overrides(config) {
504 let mut lhs_config = config.clone();
506 lhs_config.generate_lhs = true;
507 lhs_config.generate_rhs = false;
508 generate_recursive(
509 &lhs_config,
510 target,
511 *eval_context,
512 &mut Expression::new(),
513 0,
514 &mut lhs_raw,
515 &mut rhs_raw,
516 );
517
518 let rhs_config = rhs_only_config(config);
520 generate_recursive(
521 &rhs_config,
522 target,
523 *eval_context,
524 &mut Expression::new(),
525 0,
526 &mut lhs_raw,
527 &mut rhs_raw,
528 );
529 } else {
530 generate_recursive(
532 config,
533 target,
534 *eval_context,
535 &mut Expression::new(),
536 0, &mut lhs_raw,
538 &mut rhs_raw,
539 );
540 }
541
542 let mut rhs_map: HashMap<i64, EvaluatedExpr> = HashMap::new();
544 for expr in rhs_raw {
545 let key = quantize_value(expr.value);
546 rhs_map
547 .entry(key)
548 .and_modify(|existing| {
549 if expr.expr.complexity() < existing.expr.complexity() {
550 *existing = expr.clone();
551 }
552 })
553 .or_insert(expr);
554 }
555
556 let mut lhs_map: HashMap<LhsKey, EvaluatedExpr> = HashMap::new();
558 for expr in lhs_raw {
559 let key = (quantize_value(expr.value), quantize_value(expr.derivative));
560 lhs_map
561 .entry(key)
562 .and_modify(|existing| {
563 if expr.expr.complexity() < existing.expr.complexity() {
564 *existing = expr.clone();
565 }
566 })
567 .or_insert(expr);
568 }
569
570 GeneratedExprs {
571 lhs: lhs_map.into_values().collect(),
572 rhs: rhs_map.into_values().collect(),
573 }
574}
575
576pub fn generate_all_with_limit(
595 config: &GenConfig,
596 target: f64,
597 max_expressions: usize,
598) -> Option<GeneratedExprs> {
599 generate_all_with_limit_and_context(
600 config,
601 target,
602 &EvalContext::from_slices(&config.user_constants, &config.user_functions),
603 max_expressions,
604 )
605}
606
607pub fn generate_all_with_limit_and_context(
609 config: &GenConfig,
610 target: f64,
611 eval_context: &EvalContext<'_>,
612 max_expressions: usize,
613) -> Option<GeneratedExprs> {
614 use std::sync::atomic::{AtomicUsize, Ordering};
615 use std::sync::Arc;
616
617 let count = Arc::new(AtomicUsize::new(0));
618 let limit = max_expressions;
619
620 let mut lhs_raw = Vec::new();
622 let mut rhs_raw = Vec::new();
623
624 let mut callbacks = StreamingCallbacks {
626 on_lhs: &mut |expr| {
627 let current = count.fetch_add(1, Ordering::Relaxed) + 1;
628 if current > limit {
629 return false; }
631 lhs_raw.push(expr.clone());
632 true
633 },
634 on_rhs: &mut |expr| {
635 let current = count.fetch_add(1, Ordering::Relaxed) + 1;
636 if current > limit {
637 return false; }
639 rhs_raw.push(expr.clone());
640 true
641 },
642 };
643
644 generate_streaming_with_context(config, target, eval_context, &mut callbacks);
645
646 let final_count = count.load(Ordering::Relaxed);
648 if final_count > limit {
649 return None;
650 }
651
652 let mut rhs_map: HashMap<i64, EvaluatedExpr> = HashMap::new();
654 for expr in rhs_raw {
655 let key = quantize_value(expr.value);
656 rhs_map
657 .entry(key)
658 .and_modify(|existing| {
659 if expr.expr.complexity() < existing.expr.complexity() {
660 *existing = expr.clone();
661 }
662 })
663 .or_insert(expr);
664 }
665
666 let mut lhs_map: HashMap<LhsKey, EvaluatedExpr> = HashMap::new();
667 for expr in lhs_raw {
668 let key = (quantize_value(expr.value), quantize_value(expr.derivative));
669 lhs_map
670 .entry(key)
671 .and_modify(|existing| {
672 if expr.expr.complexity() < existing.expr.complexity() {
673 *existing = expr.clone();
674 }
675 })
676 .or_insert(expr);
677 }
678
679 Some(GeneratedExprs {
680 lhs: lhs_map.into_values().collect(),
681 rhs: rhs_map.into_values().collect(),
682 })
683}
684
685pub fn generate_streaming(config: &GenConfig, target: f64, callbacks: &mut StreamingCallbacks) {
727 generate_streaming_with_context(
728 config,
729 target,
730 &EvalContext::from_slices(&config.user_constants, &config.user_functions),
731 callbacks,
732 );
733}
734
735pub fn generate_streaming_with_context(
737 config: &GenConfig,
738 target: f64,
739 eval_context: &EvalContext<'_>,
740 callbacks: &mut StreamingCallbacks,
741) {
742 if config.generate_lhs && config.generate_rhs && has_rhs_symbol_overrides(config) {
743 let mut lhs_config = config.clone();
744 lhs_config.generate_lhs = true;
745 lhs_config.generate_rhs = false;
746 if !generate_recursive_streaming(
747 &lhs_config,
748 target,
749 *eval_context,
750 &mut Expression::new(),
751 0,
752 callbacks,
753 ) {
754 return;
755 }
756
757 let rhs_config = rhs_only_config(config);
758 generate_recursive_streaming(
759 &rhs_config,
760 target,
761 *eval_context,
762 &mut Expression::new(),
763 0,
764 callbacks,
765 );
766 } else {
767 generate_recursive_streaming(
768 config,
769 target,
770 *eval_context,
771 &mut Expression::new(),
772 0, callbacks,
774 );
775 }
776}
777
778#[inline]
779fn has_rhs_symbol_overrides(config: &GenConfig) -> bool {
780 config.rhs_constants.is_some()
781 || config.rhs_unary_ops.is_some()
782 || config.rhs_binary_ops.is_some()
783 || config.rhs_symbol_max_counts.is_some()
784}
785
786#[inline]
791fn should_include_expression(
792 result: &crate::eval::EvalResult,
793 config: &GenConfig,
794 complexity: u32,
795 contains_x: bool,
796) -> bool {
797 result.value.is_finite()
798 && result.value.abs() <= MAX_GENERATED_VALUE
799 && result.num_type >= config.min_num_type
800 && if contains_x {
801 config.generate_lhs && complexity <= config.max_lhs_complexity
802 } else {
803 config.generate_rhs && complexity <= config.max_rhs_complexity
804 }
805}
806
807#[inline]
813fn get_max_complexity(config: &GenConfig, contains_x: bool) -> u32 {
814 if contains_x {
815 config.max_lhs_complexity
816 } else {
817 std::cmp::max(config.max_lhs_complexity, config.max_rhs_complexity)
820 }
821}
822
823fn rhs_only_config(config: &GenConfig) -> GenConfig {
824 let mut rhs_config = config.clone();
825 rhs_config.generate_lhs = false;
826 rhs_config.generate_rhs = true;
827 if let Some(constants) = &config.rhs_constants {
828 rhs_config.constants = constants.clone();
829 }
830 if let Some(unary_ops) = &config.rhs_unary_ops {
831 rhs_config.unary_ops = unary_ops.clone();
832 }
833 if let Some(binary_ops) = &config.rhs_binary_ops {
834 rhs_config.binary_ops = binary_ops.clone();
835 }
836 if let Some(rhs_symbol_max_counts) = &config.rhs_symbol_max_counts {
837 rhs_config.symbol_max_counts = rhs_symbol_max_counts.clone();
838 }
839 rhs_config
840}
841
842#[inline]
843fn exceeds_symbol_limit(config: &GenConfig, current: &Expression, sym: Symbol) -> bool {
844 config
845 .symbol_max_counts
846 .get(&sym)
847 .is_some_and(|&max| current.count_symbol(sym) >= max)
848}
849
850fn generate_recursive_streaming(
855 config: &GenConfig,
856 target: f64,
857 eval_context: EvalContext<'_>,
858 current: &mut Expression,
859 stack_depth: usize,
860 callbacks: &mut StreamingCallbacks,
861) -> bool {
862 if stack_depth == 1 && !current.is_empty() {
864 match evaluate_fast_with_context(current, target, &eval_context) {
866 Ok(result) => {
867 if should_include_expression(
869 &result,
870 config,
871 current.complexity(),
872 current.contains_x(),
873 ) {
874 let expr = current.clone();
875 let eval_expr =
876 EvaluatedExpr::new(expr, result.value, result.derivative, result.num_type);
877
878 let should_continue = if current.contains_x() {
880 (callbacks.on_lhs)(&eval_expr)
881 } else {
882 (callbacks.on_rhs)(&eval_expr)
883 };
884 if !should_continue {
885 return false;
886 }
887 }
888 }
889 Err(e) => {
890 if config.show_pruned_arith {
892 eprintln!(
893 " [pruned arith] expression=\"{}\" reason={:?}",
894 current.to_postfix(),
895 e
896 );
897 }
898 }
899 }
900 }
901
902 if current.len() >= config.max_length {
904 return true;
905 }
906
907 let max_complexity = get_max_complexity(config, current.contains_x());
909
910 if current.complexity() >= max_complexity {
911 return true;
912 }
913
914 let min_remaining = min_complexity_to_complete(stack_depth, config);
916 if current.complexity() + min_remaining > max_complexity {
917 return true;
918 }
919
920 for &sym in &config.constants {
924 let sym_weight = config.symbol_table.weight(sym);
925 if current.complexity() + sym_weight > max_complexity {
926 continue;
927 }
928 if exceeds_symbol_limit(config, current, sym) {
929 continue;
930 }
931
932 if sym == Symbol::X && !config.generate_lhs {
934 continue;
935 }
936
937 current.push_with_table(sym, &config.symbol_table);
938 if !generate_recursive_streaming(
939 config,
940 target,
941 eval_context,
942 current,
943 stack_depth + 1,
944 callbacks,
945 ) {
946 current.pop_with_table(&config.symbol_table);
947 return false;
948 }
949 current.pop_with_table(&config.symbol_table);
950 }
951
952 if config.generate_lhs && !config.constants.contains(&Symbol::X) {
954 let sym = Symbol::X;
955 let sym_weight = config.symbol_table.weight(sym);
956 if current.complexity() + sym_weight <= max_complexity
957 && !exceeds_symbol_limit(config, current, sym)
958 {
959 current.push_with_table(sym, &config.symbol_table);
960 if !generate_recursive_streaming(
961 config,
962 target,
963 eval_context,
964 current,
965 stack_depth + 1,
966 callbacks,
967 ) {
968 current.pop_with_table(&config.symbol_table);
969 return false;
970 }
971 current.pop_with_table(&config.symbol_table);
972 }
973 }
974
975 if stack_depth >= 1 {
977 for &sym in &config.unary_ops {
978 let sym_weight = config.symbol_table.weight(sym);
979 if current.complexity() + sym_weight > max_complexity {
980 continue;
981 }
982 if exceeds_symbol_limit(config, current, sym) {
983 continue;
984 }
985
986 if should_prune_unary(current, sym) {
988 continue;
989 }
990
991 current.push_with_table(sym, &config.symbol_table);
992 if !generate_recursive_streaming(
993 config,
994 target,
995 eval_context,
996 current,
997 stack_depth,
998 callbacks,
999 ) {
1000 current.pop_with_table(&config.symbol_table);
1001 return false;
1002 }
1003 current.pop_with_table(&config.symbol_table);
1004 }
1005 }
1006
1007 if stack_depth >= 2 {
1009 for &sym in &config.binary_ops {
1010 let sym_weight = config.symbol_table.weight(sym);
1011 if current.complexity() + sym_weight > max_complexity {
1012 continue;
1013 }
1014 if exceeds_symbol_limit(config, current, sym) {
1015 continue;
1016 }
1017
1018 if should_prune_binary(current, sym) {
1020 continue;
1021 }
1022
1023 current.push_with_table(sym, &config.symbol_table);
1024 if !generate_recursive_streaming(
1025 config,
1026 target,
1027 eval_context,
1028 current,
1029 stack_depth - 1,
1030 callbacks,
1031 ) {
1032 current.pop_with_table(&config.symbol_table);
1033 return false;
1034 }
1035 current.pop_with_table(&config.symbol_table);
1036 }
1037 }
1038
1039 true
1040}
1041
1042fn generate_recursive(
1044 config: &GenConfig,
1045 target: f64,
1046 eval_context: EvalContext<'_>,
1047 current: &mut Expression,
1048 stack_depth: usize,
1049 lhs_out: &mut Vec<EvaluatedExpr>,
1050 rhs_out: &mut Vec<EvaluatedExpr>,
1051) {
1052 if stack_depth == 1 && !current.is_empty() {
1054 match evaluate_fast_with_context(current, target, &eval_context) {
1056 Ok(result) => {
1057 if should_include_expression(
1059 &result,
1060 config,
1061 current.complexity(),
1062 current.contains_x(),
1063 ) {
1064 let expr = current.clone();
1065 let eval_expr =
1066 EvaluatedExpr::new(expr, result.value, result.derivative, result.num_type);
1067
1068 if current.contains_x() {
1070 lhs_out.push(eval_expr);
1071 } else {
1072 rhs_out.push(eval_expr);
1073 }
1074 }
1075 }
1076 Err(e) => {
1077 if config.show_pruned_arith {
1079 eprintln!(
1080 " [pruned arith] expression=\"{}\" reason={:?}",
1081 current.to_postfix(),
1082 e
1083 );
1084 }
1085 }
1086 }
1087 }
1088
1089 if current.len() >= config.max_length {
1091 return;
1092 }
1093
1094 let max_complexity = get_max_complexity(config, current.contains_x());
1096
1097 if current.complexity() >= max_complexity {
1098 return;
1099 }
1100
1101 let min_remaining = min_complexity_to_complete(stack_depth, config);
1103 if current.complexity() + min_remaining > max_complexity {
1104 return;
1105 }
1106
1107 for &sym in &config.constants {
1111 let sym_weight = config.symbol_table.weight(sym);
1112 if current.complexity() + sym_weight > max_complexity {
1113 continue;
1114 }
1115 if exceeds_symbol_limit(config, current, sym) {
1116 continue;
1117 }
1118
1119 if sym == Symbol::X && !config.generate_lhs {
1121 continue;
1122 }
1123
1124 current.push_with_table(sym, &config.symbol_table);
1125 generate_recursive(
1126 config,
1127 target,
1128 eval_context,
1129 current,
1130 stack_depth + 1,
1131 lhs_out,
1132 rhs_out,
1133 );
1134 current.pop_with_table(&config.symbol_table);
1135 }
1136
1137 if config.generate_lhs && !config.constants.contains(&Symbol::X) {
1139 let sym = Symbol::X;
1140 let sym_weight = config.symbol_table.weight(sym);
1141 if current.complexity() + sym_weight <= max_complexity
1142 && !exceeds_symbol_limit(config, current, sym)
1143 {
1144 current.push_with_table(sym, &config.symbol_table);
1145 generate_recursive(
1146 config,
1147 target,
1148 eval_context,
1149 current,
1150 stack_depth + 1,
1151 lhs_out,
1152 rhs_out,
1153 );
1154 current.pop_with_table(&config.symbol_table);
1155 }
1156 }
1157
1158 if stack_depth >= 1 {
1160 for &sym in &config.unary_ops {
1161 let sym_weight = config.symbol_table.weight(sym);
1162 if current.complexity() + sym_weight > max_complexity {
1163 continue;
1164 }
1165 if exceeds_symbol_limit(config, current, sym) {
1166 continue;
1167 }
1168
1169 if should_prune_unary(current, sym) {
1171 continue;
1172 }
1173
1174 current.push_with_table(sym, &config.symbol_table);
1175 generate_recursive(
1176 config,
1177 target,
1178 eval_context,
1179 current,
1180 stack_depth,
1181 lhs_out,
1182 rhs_out,
1183 );
1184 current.pop_with_table(&config.symbol_table);
1185 }
1186 }
1187
1188 if stack_depth >= 2 {
1190 for &sym in &config.binary_ops {
1191 let sym_weight = config.symbol_table.weight(sym);
1192 if current.complexity() + sym_weight > max_complexity {
1193 continue;
1194 }
1195 if exceeds_symbol_limit(config, current, sym) {
1196 continue;
1197 }
1198
1199 if should_prune_binary(current, sym) {
1201 continue;
1202 }
1203
1204 current.push_with_table(sym, &config.symbol_table);
1205 generate_recursive(
1206 config,
1207 target,
1208 eval_context,
1209 current,
1210 stack_depth - 1,
1211 lhs_out,
1212 rhs_out,
1213 );
1214 current.pop_with_table(&config.symbol_table);
1215 }
1216 }
1217}
1218
1219fn min_complexity_to_complete(stack_depth: usize, config: &GenConfig) -> u32 {
1221 if stack_depth <= 1 {
1222 return 0;
1223 }
1224
1225 let min_binary_weight = config
1227 .binary_ops
1228 .iter()
1229 .map(|s| config.symbol_table.weight(*s))
1230 .min()
1231 .unwrap_or(4);
1232
1233 ((stack_depth - 1) as u32) * min_binary_weight
1234}
1235
1236fn should_prune_unary(expr: &Expression, sym: Symbol) -> bool {
1238 let symbols = expr.symbols();
1239 if symbols.is_empty() {
1240 return false;
1241 }
1242
1243 let last = symbols[symbols.len() - 1];
1244
1245 use Symbol::*;
1246
1247 match (last, sym) {
1248 (Neg, Neg) => true,
1250 (Recip, Recip) => true,
1252 (Square, Sqrt) => true,
1254 (Sqrt, Square) => true,
1256 (Exp, Ln) => true,
1258 (Ln, Exp) => true,
1260
1261 (Sqrt, Recip) => true,
1264 (Square, Recip) => true,
1265 (Ln, Recip) => true,
1267 (Square, Square) => true,
1269 (Sqrt, Sqrt) => true,
1271 (Sub, Neg) => true,
1274
1275 (SinPi, SinPi) => true,
1279 (CosPi, CosPi) => true,
1280 (Exp, Exp) => true,
1286
1287 (Exp, LambertW) => true,
1289
1290 (Recip, LambertW) => true,
1293
1294 _ => false,
1295 }
1296}
1297
1298fn should_prune_binary(expr: &Expression, sym: Symbol) -> bool {
1300 let symbols = expr.symbols();
1301 if symbols.len() < 2 {
1302 return false;
1303 }
1304
1305 let last = symbols[symbols.len() - 1];
1306 let prev = symbols[symbols.len() - 2];
1307
1308 use Symbol::*;
1309
1310 match sym {
1311 Sub if is_same_subexpr(symbols, 2) => true,
1313 Sub if last == X && prev == X => true,
1315
1316 Div if is_same_subexpr(symbols, 2) => true,
1318 Div if last == X && prev == X => true,
1320 Div if last == One => true,
1322
1323 Add if is_same_subexpr(symbols, 2) => true,
1325 Add if last == Neg
1327 && symbols.len() >= 3
1328 && symbols[symbols.len() - 2] == X
1329 && prev == X =>
1330 {
1331 true
1332 }
1333
1334 Pow if prev == One => true,
1337 Pow if last == One => true,
1339
1340 Mul if last == One || prev == One => true,
1342
1343 Root if prev == One => true,
1346 Root if last == One => true,
1348 Root if last == Two => true,
1350
1351 Log if last == X && prev == X => true,
1353 Log if prev == One || last == One => true,
1355 Log if prev == E => true,
1357
1358 Add | Mul if prev > last && is_constant(last) && is_constant(prev) => true,
1360
1361 _ => false,
1362 }
1363}
1364
1365fn is_same_subexpr(symbols: &[Symbol], n: usize) -> bool {
1371 if symbols.len() < n * 2 || n < 2 {
1372 return false;
1373 }
1374
1375 let mut stack_depths: Vec<usize> = Vec::with_capacity(symbols.len() + 1);
1380 stack_depths.push(0); for &sym in symbols {
1383 let prev_depth = *stack_depths.last().unwrap();
1384 let new_depth = match sym.seft() {
1385 Seft::A => prev_depth + 1,
1386 Seft::B => prev_depth, Seft::C => prev_depth - 1, };
1389 stack_depths.push(new_depth);
1390 }
1391
1392 let final_depth = *stack_depths.last().unwrap();
1393 if final_depth < n {
1394 return false;
1395 }
1396
1397 let mut subexpr_starts: Vec<usize> = Vec::with_capacity(n);
1399 let mut target_depth = final_depth;
1400
1401 for i in (0..symbols.len()).rev() {
1402 if stack_depths[i] == target_depth && stack_depths[i + 1] > target_depth {
1403 subexpr_starts.push(i);
1404 target_depth -= 1;
1405 if subexpr_starts.len() == n {
1406 break;
1407 }
1408 }
1409 }
1410
1411 if subexpr_starts.len() != n {
1412 return false;
1413 }
1414
1415 if n == 2 && subexpr_starts.len() == 2 {
1418 let start1 = subexpr_starts[1]; let start2 = subexpr_starts[0]; let end1 = start2; let end2 = symbols.len(); if end1 - start1 == end2 - start2 {
1425 return symbols[start1..end1] == symbols[start2..end2];
1426 }
1427 }
1428
1429 false
1430}
1431
1432fn is_constant(sym: Symbol) -> bool {
1434 matches!(sym.seft(), Seft::A) && sym != Symbol::X
1435}
1436
/// Parallel counterpart of `generate_all`: builds the evaluation context from
/// the user constants and functions in `config` and delegates to
/// `generate_all_parallel_with_context`.
#[cfg(feature = "parallel")]
pub fn generate_all_parallel(config: &GenConfig, target: f64) -> GeneratedExprs {
    let context = EvalContext::from_slices(&config.user_constants, &config.user_functions);
    generate_all_parallel_with_context(config, target, &context)
}
1446
/// Generates and deduplicates all expressions for `target`, exploring
/// independent expression prefixes in parallel with rayon.
///
/// Configurations with RHS-specific overrides fall back to the sequential
/// `generate_all_with_context`. Otherwise every one-symbol expression is
/// evaluated directly, every viable two-symbol prefix (leaf + leaf, or
/// leaf + unary operator) becomes a work item, and each work item is expanded
/// with `generate_recursive`. Results are deduplicated the same way as in the
/// sequential path (lowest complexity per quantized key, first seen wins ties).
///
/// `x` is not used as a seed when `generate_lhs` is off, matching the
/// sequential generator; such prefixes could only produce discarded results.
#[cfg(feature = "parallel")]
pub fn generate_all_parallel_with_context(
    config: &GenConfig,
    target: f64,
    eval_context: &EvalContext<'_>,
) -> GeneratedExprs {
    use rayon::prelude::*;

    /// Stores `candidate` under `key` unless an entry of equal or lower
    /// complexity is already there. The candidate is moved, never cloned.
    fn keep_simplest<K: Eq + std::hash::Hash>(
        best: &mut HashMap<K, EvaluatedExpr>,
        key: K,
        candidate: EvaluatedExpr,
    ) {
        use std::collections::hash_map::Entry;
        match best.entry(key) {
            Entry::Occupied(mut slot) => {
                if candidate.expr.complexity() < slot.get().expr.complexity() {
                    slot.insert(candidate);
                }
            }
            Entry::Vacant(slot) => {
                slot.insert(candidate);
            }
        }
    }

    if has_rhs_symbol_overrides(config) {
        return generate_all_with_context(config, target, eval_context);
    }

    // Leaf symbols usable at any position: the configured constants (minus
    // `x` when LHS generation is off) plus an implicit `x` when it is on.
    let mut leaves: Vec<Symbol> = config
        .constants
        .iter()
        .copied()
        .filter(|&sym| config.generate_lhs || sym != Symbol::X)
        .collect();
    if config.generate_lhs && !leaves.contains(&Symbol::X) {
        leaves.push(Symbol::X);
    }

    // Loop-invariant lower bounds on the complexity needed to finish a prefix.
    let floor_two_operands = min_complexity_to_complete(2, config);
    let floor_one_operand = min_complexity_to_complete(1, config);

    let mut prefixes: Vec<(Expression, usize)> = Vec::new();
    let mut lhs_raw = Vec::new();
    let mut rhs_raw = Vec::new();

    for &sym1 in &leaves {
        // A cap of zero forbids the symbol entirely.
        if config
            .symbol_max_counts
            .get(&sym1)
            .is_some_and(|&max| max == 0)
        {
            continue;
        }

        let mut expr1 = Expression::new();
        expr1.push_with_table(sym1, &config.symbol_table);
        let has_x = expr1.contains_x();

        let max_complexity = get_max_complexity(config, has_x);
        if expr1.complexity() > max_complexity {
            continue;
        }

        // The single symbol is itself a complete expression.
        if let Ok(result) = evaluate_fast_with_context(&expr1, target, eval_context) {
            if should_include_expression(&result, config, expr1.complexity(), has_x) {
                let evaluated = EvaluatedExpr::new(
                    expr1.clone(),
                    result.value,
                    result.derivative,
                    result.num_type,
                );
                if has_x {
                    lhs_raw.push(evaluated);
                } else {
                    rhs_raw.push(evaluated);
                }
            }
        }

        if expr1.len() >= config.max_length {
            continue;
        }

        // Leaf + leaf prefixes (two operands on the stack).
        for &sym2 in &leaves {
            let next_max = get_max_complexity(config, has_x || sym2 == Symbol::X);
            let affordable = expr1.complexity() + config.symbol_table.weight(sym2) <= next_max;
            if !affordable || exceeds_symbol_limit(config, &expr1, sym2) {
                continue;
            }

            let mut expr2 = expr1.clone();
            expr2.push_with_table(sym2, &config.symbol_table);
            if expr2.complexity() + floor_two_operands <= next_max {
                prefixes.push((expr2, 2));
            }
        }

        // Leaf + unary prefixes (one operand on the stack).
        for &sym2 in &config.unary_ops {
            let affordable =
                expr1.complexity() + config.symbol_table.weight(sym2) <= max_complexity;
            if !affordable
                || exceeds_symbol_limit(config, &expr1, sym2)
                || should_prune_unary(&expr1, sym2)
            {
                continue;
            }

            let mut expr2 = expr1.clone();
            expr2.push_with_table(sym2, &config.symbol_table);
            if expr2.complexity() + floor_one_operand <= max_complexity {
                prefixes.push((expr2, 1));
            }
        }
    }

    // Expand every prefix independently.
    let results: Vec<(Vec<EvaluatedExpr>, Vec<EvaluatedExpr>)> = prefixes
        .into_par_iter()
        .map(|(mut expr, depth)| {
            let mut lhs = Vec::new();
            let mut rhs = Vec::new();
            generate_recursive(
                config,
                target,
                *eval_context,
                &mut expr,
                depth,
                &mut lhs,
                &mut rhs,
            );
            (lhs, rhs)
        })
        .collect();

    for (lhs, rhs) in results {
        lhs_raw.extend(lhs);
        rhs_raw.extend(rhs);
    }

    let mut rhs_map: HashMap<i64, EvaluatedExpr> = HashMap::new();
    for expr in rhs_raw {
        keep_simplest(&mut rhs_map, quantize_value(expr.value), expr);
    }

    let mut lhs_map: HashMap<LhsKey, EvaluatedExpr> = HashMap::new();
    for expr in lhs_raw {
        let key = (quantize_value(expr.value), quantize_value(expr.derivative));
        keep_simplest(&mut lhs_map, key, expr);
    }

    GeneratedExprs {
        lhs: lhs_map.into_values().collect(),
        rhs: rhs_map.into_values().collect(),
    }
}
1632
1633#[cfg(test)]
1634mod tests {
1635 use super::*;
1636
1637 fn fast_test_config() -> GenConfig {
1639 GenConfig {
1640 max_lhs_complexity: 20,
1641 max_rhs_complexity: 20,
1642 max_length: 8,
1643 constants: vec![
1644 Symbol::One,
1645 Symbol::Two,
1646 Symbol::Three,
1647 Symbol::Four,
1648 Symbol::Five,
1649 Symbol::Pi,
1650 Symbol::E,
1651 ],
1652 unary_ops: vec![Symbol::Neg, Symbol::Recip, Symbol::Square, Symbol::Sqrt],
1653 binary_ops: vec![Symbol::Add, Symbol::Sub, Symbol::Mul, Symbol::Div],
1654 rhs_constants: None,
1655 rhs_unary_ops: None,
1656 rhs_binary_ops: None,
1657 symbol_max_counts: HashMap::new(),
1658 rhs_symbol_max_counts: None,
1659 min_num_type: NumType::Transcendental,
1660 generate_lhs: true,
1661 generate_rhs: true,
1662 user_constants: Vec::new(),
1663 user_functions: Vec::new(),
1664 show_pruned_arith: false,
1665 symbol_table: Arc::new(SymbolTable::new()),
1666 }
1667 }
1668
1669 #[test]
1670 fn test_generate_simple() {
1671 let mut config = fast_test_config();
1672 config.generate_lhs = false; let result = generate_all(&config, 1.0);
1675
1676 assert!(!result.rhs.is_empty());
1678
1679 for expr in &result.rhs {
1681 assert!(!expr.expr.contains_x());
1682 }
1683 }
1684
1685 #[test]
1686 fn test_generate_lhs() {
1687 let mut config = fast_test_config();
1688 config.generate_rhs = false;
1689
1690 let result = generate_all(&config, 2.0);
1691
1692 assert!(!result.lhs.is_empty());
1694 for expr in &result.lhs {
1695 assert!(expr.expr.contains_x());
1696 }
1697 }
1698
/// Generated expressions never exceed the complexity cap configured for their side.
#[test]
fn test_complexity_limit() {
    let config = fast_test_config();
    let generated = generate_all(&config, 1.0);

    // Pair each output list with the cap that applies to it (RHS first, then LHS).
    let sides = [
        (&generated.rhs, config.max_rhs_complexity),
        (&generated.lhs, config.max_lhs_complexity),
    ];

    for (candidates, cap) in sides {
        for candidate in candidates {
            assert!(candidate.expr.complexity() <= cap);
        }
    }
}
1712
/// `generate_all_with_limit` returns `None` when the cap is below what an unlimited
/// run produces.
#[test]
fn test_generate_all_with_limit_aborts_when_exceeded() {
    let config = GenConfig {
        max_lhs_complexity: 30,
        max_rhs_complexity: 30,
        ..fast_test_config()
    };

    // Baseline: count what an uncapped run returns for the same config and target.
    let baseline = generate_all(&config, 2.5);
    let total_unlimited = baseline.lhs.len() + baseline.rhs.len();
    assert!(
        total_unlimited > 10,
        "Test config should generate >10 expressions"
    );

    // Half of the baseline count is used as the cap, so the limited run must give up.
    let limit = total_unlimited / 2;
    let capped = generate_all_with_limit(&config, 2.5, limit);

    assert!(
        capped.is_none(),
        "generate_all_with_limit should return None when limit ({}) is exceeded (actual: {})",
        limit,
        total_unlimited
    );
}
1742
/// `generate_all_with_limit` returns `Some` with at least one expression when given a
/// cap of 10,000.
#[test]
fn test_generate_all_with_limit_succeeds_when_within_limit() {
    let config = GenConfig {
        max_lhs_complexity: 30,
        max_rhs_complexity: 30,
        ..fast_test_config()
    };

    let outcome = generate_all_with_limit(&config, 2.5, 10_000);

    assert!(
        outcome.is_some(),
        "generate_all_with_limit should return Some when limit is not exceeded"
    );

    // `outcome` is known to be `Some` here; the `false` default is never taken.
    let produced_any =
        outcome.map_or(false, |generated| !(generated.lhs.is_empty() && generated.rhs.is_empty()));
    assert!(produced_any);
}
1762
1763 fn expr_from_postfix(s: &str) -> Expression {
1766 Expression::parse(s).expect("valid expression")
1767 }
1768
/// Default options accept a transcendental exponent and an irrational trig argument.
#[test]
fn test_constraints_default_allows_all() {
    let opts = ExpressionConstraintOptions::default();

    // (postfix form, message shown if the expression is rejected)
    let cases = [
        ("xp^", "x^pi should be allowed with default options"),
        ("eS", "sinpi(e) should be allowed with default options"),
    ];

    for (postfix, why) in cases {
        let expr = expr_from_postfix(postfix);
        assert!(expression_respects_constraints(&expr, opts), "{}", why);
    }
}
1787
/// With `rational_exponents` enabled, exponents of pi and e are rejected.
#[test]
fn test_constraints_rational_exponents_rejects_transcendental() {
    let opts = ExpressionConstraintOptions {
        rational_exponents: true,
        ..ExpressionConstraintOptions::default()
    };
    // Parses a postfix string and runs it through the checker with `opts`.
    let accepted = |postfix: &str| expression_respects_constraints(&expr_from_postfix(postfix), opts);

    assert!(
        !accepted("xp^"),
        "x^pi should be rejected with rational_exponents=true"
    );
    assert!(
        !accepted("xe^"),
        "x^e should be rejected with rational_exponents=true"
    );
}
1809
/// With `rational_exponents` enabled, integer exponents are still accepted.
#[test]
fn test_constraints_rational_exponents_allows_integer() {
    let opts = ExpressionConstraintOptions {
        rational_exponents: true,
        ..ExpressionConstraintOptions::default()
    };

    // (postfix form, message shown if the expression is rejected)
    let cases = [
        ("x2^", "x^2 should be allowed with rational_exponents=true"),
        ("x1^", "x^1 should be allowed with rational_exponents=true"),
    ];

    for (postfix, why) in cases {
        let expr = expr_from_postfix(postfix);
        assert!(expression_respects_constraints(&expr, opts), "{}", why);
    }
}
1831
/// With `rational_trig_args` enabled, `sinpi` of e or pi is rejected.
#[test]
fn test_constraints_rational_trig_args_rejects_irrational() {
    let opts = ExpressionConstraintOptions {
        rational_trig_args: true,
        ..ExpressionConstraintOptions::default()
    };

    // (postfix form, message shown if the expression is wrongly accepted)
    let cases = [
        ("eS", "sinpi(e) should be rejected with rational_trig_args=true"),
        ("pS", "sinpi(pi) should be rejected with rational_trig_args=true"),
    ];

    for (postfix, why) in cases {
        let expr = expr_from_postfix(postfix);
        let accepted = expression_respects_constraints(&expr, opts);
        assert!(!accepted, "{}", why);
    }
}
1853
/// With `rational_trig_args` enabled, `sinpi` of an integer constant is accepted.
#[test]
fn test_constraints_rational_trig_args_allows_rational() {
    let opts = ExpressionConstraintOptions {
        rational_trig_args: true,
        ..ExpressionConstraintOptions::default()
    };
    // Parses a postfix string and runs it through the checker with `opts`.
    let accepted = |postfix: &str| expression_respects_constraints(&expr_from_postfix(postfix), opts);

    assert!(
        accepted("1S"),
        "sinpi(1) should be allowed with rational_trig_args=true"
    );
    assert!(
        accepted("2S"),
        "sinpi(2) should be allowed with rational_trig_args=true"
    );
}
1875
/// With `rational_trig_args` enabled, a trig argument of `x` is rejected.
#[test]
fn test_constraints_rational_trig_args_rejects_x() {
    let opts = ExpressionConstraintOptions {
        rational_trig_args: true,
        ..ExpressionConstraintOptions::default()
    };

    let sinpi_of_x = expr_from_postfix("xS");
    let accepted = expression_respects_constraints(&sinpi_of_x, opts);

    assert!(
        !accepted,
        "sinpi(x) should be rejected with rational_trig_args=true"
    );
}
1890
/// With `max_trig_cycles: Some(2)`, one or two trig operators pass and three fail.
#[test]
fn test_constraints_max_trig_cycles() {
    let opts = ExpressionConstraintOptions {
        max_trig_cycles: Some(2),
        ..ExpressionConstraintOptions::default()
    };

    // (postfix form, expected verdict, message shown on a mismatch)
    let cases = [
        ("xS", true, "1 trig op should pass with max=2"),
        ("xCS", true, "2 trig ops should pass with max=2"),
        ("xTCS", false, "3 trig ops should fail with max=2"),
    ];

    for (postfix, expected, why) in cases {
        let expr = expr_from_postfix(postfix);
        let verdict = expression_respects_constraints(&expr, opts);
        assert!(verdict == expected, "{}", why);
    }
}
1921
/// With `max_trig_cycles: None`, a long chain of trig operators is accepted.
#[test]
fn test_constraints_max_trig_cycles_none_unlimited() {
    let opts = ExpressionConstraintOptions {
        max_trig_cycles: None,
        ..ExpressionConstraintOptions::default()
    };

    // Six chained T/C/S operators applied to x.
    let deeply_nested = expr_from_postfix("xTCSTCS");
    let accepted = expression_respects_constraints(&deeply_nested, opts);

    assert!(accepted, "Unlimited trig should pass any depth");
}
1937
/// With all three constraints enabled, one expression satisfies them all and three
/// others each violate one constraint.
#[test]
fn test_constraints_combined() {
    let opts = ExpressionConstraintOptions {
        rational_exponents: true,
        rational_trig_args: true,
        max_trig_cycles: Some(1),
        ..ExpressionConstraintOptions::default()
    };

    // (postfix form, expected verdict, message shown on a mismatch)
    let cases = [
        ("x2^1S+", true, "x^2 + sinpi(1) should pass all constraints"),
        ("xp^", false, "x^pi should fail rational_exponents"),
        ("xS", false, "sinpi(x) should fail rational_trig_args"),
        ("1CS", false, "double trig should fail max_trig_cycles=1"),
    ];

    for (postfix, expected, why) in cases {
        let expr = expr_from_postfix(postfix);
        let verdict = expression_respects_constraints(&expr, opts);
        assert!(verdict == expected, "{}", why);
    }
}
1975
/// Symbol sequences that do not form a single well-formed expression are reported as
/// not respecting the constraints, even with default options.
#[test]
fn test_constraints_malformed_expression() {
    let opts = ExpressionConstraintOptions::default();

    // A binary operator with no operands.
    let lone_operator = Expression::from_symbols(&[Symbol::Add]);
    let lone_operator_ok = expression_respects_constraints(&lone_operator, opts);
    assert!(!lone_operator_ok, "Malformed expression should return false");

    // Two constants with no operator combining them.
    let dangling_operands = Expression::from_symbols(&[Symbol::One, Symbol::Two]);
    let dangling_operands_ok = expression_respects_constraints(&dangling_operands, opts);
    assert!(
        !dangling_operands_ok,
        "Incomplete expression should return false"
    );
}
1995
/// Options carrying a custom user-constant type table keep that table intact and still
/// drive the exponent check for built-in symbols.
///
/// The checks below go beyond reading back the array that was just written: they run
/// `expression_respects_constraints` with the customised options. Only built-in symbols
/// are used, so this does not cover an expression that actually contains a user constant.
#[test]
fn test_constraints_user_constant_types() {
    // Slot 0 is declared an integer; the other fifteen slots stay transcendental.
    let mut user_types = [NumType::Transcendental; 16];
    user_types[0] = NumType::Integer;

    let opts = ExpressionConstraintOptions {
        rational_exponents: true,
        user_constant_types: user_types,
        ..Default::default()
    };

    // The override is carried by the options, and only slot 0 was changed.
    assert_eq!(opts.user_constant_types[0], NumType::Integer);
    assert!(
        opts.user_constant_types[1..]
            .iter()
            .all(|num_type| *num_type == NumType::Transcendental),
        "only user constant slot 0 should have been overridden"
    );

    // Built-in symbols are judged as usual when a custom type table is supplied.
    assert!(
        expression_respects_constraints(&expr_from_postfix("x2^"), opts),
        "x^2 should be allowed with rational_exponents=true and custom user constant types"
    );
    assert!(
        !expression_respects_constraints(&expr_from_postfix("xp^"), opts),
        "x^pi should be rejected with rational_exponents=true and custom user constant types"
    );
}
2013}