1use super::newton::newton_raphson_with_constants;
2use super::{Match, SearchConfig, SearchContext, SearchStats, SearchTimer};
3
4use crate::expr::EvaluatedExpr;
5
6use crate::pool::TopKPool;
7
8use crate::thresholds::{
9 ADAPTIVE_COMPLEXITY_SCALE, ADAPTIVE_EXACT_MATCH_FACTOR, ADAPTIVE_POOL_FULLNESS_SCALE,
10 BASE_SEARCH_RADIUS_FACTOR, DEGENERATE_RANGE_TOLERANCE, DEGENERATE_TEST_THRESHOLD,
11 EXACT_MATCH_TOLERANCE, MAX_SEARCH_RADIUS_FACTOR, NEWTON_FINAL_TOLERANCE, TIER_0_MAX,
12 TIER_1_MAX, TIER_2_MAX,
13};
14
/// Right-hand-side (RHS) expression store used by the matcher.
///
/// RHS candidates are held in a single vector ordered ascending by `value`, so a value
/// window can be located with two binary searches.
pub struct ExprDatabase {
    /// RHS expressions, sorted by `value` when installed via `insert_rhs`.
    rhs_sorted: Vec<EvaluatedExpr>,
}
21
/// Complexity bucket used to partition expressions into four tiers.
///
/// The tier is converted with `as usize` and used directly as an array index. The
/// discriminants are therefore written out explicitly; they are the same values the
/// compiler would assign implicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ComplexityTier {
    /// Lowest-complexity bucket (see `ComplexityTier::from_complexity`).
    Tier0 = 0,
    /// Second bucket.
    Tier1 = 1,
    /// Third bucket.
    Tier2 = 2,
    /// Catch-all bucket for everything above the last threshold.
    Tier3 = 3,
}
41
42impl ComplexityTier {
43 #[inline]
45 pub fn from_complexity(complexity: u32) -> Self {
46 if complexity <= TIER_0_MAX {
47 ComplexityTier::Tier0
48 } else if complexity <= TIER_1_MAX {
49 ComplexityTier::Tier1
50 } else if complexity <= TIER_2_MAX {
51 ComplexityTier::Tier2
52 } else {
53 ComplexityTier::Tier3
54 }
55 }
56}
57
/// Expression store partitioned into four complexity tiers.
///
/// `tiers[i]` holds the expressions whose `ComplexityTier` converts to index `i`.
/// Each tier is sorted ascending by value once `finalize` has run.
pub struct TieredExprDatabase {
    /// One bucket per `ComplexityTier`, indexed by `tier as usize`.
    tiers: [Vec<EvaluatedExpr>; 4],
    /// Count of expressions inserted across all tiers.
    total_count: usize,
}
69
70impl TieredExprDatabase {
71 pub fn new() -> Self {
73 Self {
74 tiers: [Vec::new(), Vec::new(), Vec::new(), Vec::new()],
75 total_count: 0,
76 }
77 }
78
79 pub fn insert(&mut self, expr: EvaluatedExpr) {
81 let tier = ComplexityTier::from_complexity(expr.expr.complexity());
82 let tier_idx = tier as usize;
83 self.tiers[tier_idx].push(expr);
84 self.total_count += 1;
85 }
86
87 pub fn finalize(&mut self) {
89 for tier in &mut self.tiers {
90 tier.sort_by(|a, b| a.value.total_cmp(&b.value));
91 }
92 }
93
94 pub fn total_count(&self) -> usize {
96 self.total_count
97 }
98
99 #[allow(dead_code)]
101 pub fn tier_count(&self, tier: ComplexityTier) -> usize {
102 self.tiers[tier as usize].len()
103 }
104
105 #[cfg(test)]
106 pub(super) fn tier(&self, tier: ComplexityTier) -> &[EvaluatedExpr] {
107 &self.tiers[tier as usize]
108 }
109
110 #[allow(dead_code)]
112 pub fn range_in_tier(&self, tier: ComplexityTier, low: f64, high: f64) -> &[EvaluatedExpr] {
113 let tier_vec = &self.tiers[tier as usize];
114 let start = tier_vec.partition_point(|e| e.value < low);
115 let end = tier_vec.partition_point(|e| e.value <= high);
116 &tier_vec[start..end]
117 }
118
119 pub fn iter_tiers_in_range(&self, low: f64, high: f64) -> TieredRangeIter<'_> {
122 TieredRangeIter::new(self, low, high)
123 }
124}
125
126impl Default for TieredExprDatabase {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
/// Iterator over the expressions of a `TieredExprDatabase` whose value lies in the
/// closed interval `[low, high]`.
///
/// All hits from tier 0 are yielded first, then tier 1, and so on.
pub struct TieredRangeIter<'a> {
    /// Database being scanned.
    db: &'a TieredExprDatabase,
    /// Inclusive lower bound on `value`.
    low: f64,
    /// Inclusive upper bound on `value`.
    high: f64,
    /// Index of the tier currently being yielded from (reaches 4 once exhausted).
    current_tier: usize,
    /// Next index to yield within the current tier.
    current_start: usize,
    /// One past the last in-range index within the current tier.
    current_end: usize,
}
141
142impl<'a> TieredRangeIter<'a> {
143 fn new(db: &'a TieredExprDatabase, low: f64, high: f64) -> Self {
144 let mut iter = Self {
145 db,
146 low,
147 high,
148 current_tier: 0,
149 current_start: 0,
150 current_end: 0,
151 };
152 iter.find_next_nonempty_tier();
153 iter
154 }
155
156 fn calculate_tier_range(&self, tier_idx: usize) -> (usize, usize) {
158 let tier_vec = &self.db.tiers[tier_idx];
159 let start = tier_vec.partition_point(|e| e.value < self.low);
160 let end = tier_vec.partition_point(|e| e.value <= self.high);
161 (start, end)
162 }
163
164 fn find_next_nonempty_tier(&mut self) {
166 while self.current_tier < 4 {
167 let (start, end) = self.calculate_tier_range(self.current_tier);
168 self.current_start = start;
169 self.current_end = end;
170
171 if self.current_start < self.current_end {
172 return;
174 }
175 self.current_tier += 1;
176 }
177 }
178}
179
180impl<'a> Iterator for TieredRangeIter<'a> {
181 type Item = &'a EvaluatedExpr;
182
183 fn next(&mut self) -> Option<Self::Item> {
184 while self.current_tier < 4 {
185 if self.current_start < self.current_end {
186 let expr = &self.db.tiers[self.current_tier][self.current_start];
187 self.current_start += 1;
188 return Some(expr);
189 }
190 self.current_tier += 1;
191 self.find_next_nonempty_tier();
192 }
193 None
194 }
195}
196
197#[inline]
218pub(super) fn calculate_adaptive_search_radius(
219 derivative: f64,
220 complexity: u32,
221 pool_size: usize,
222 pool_capacity: usize,
223 best_error: f64,
224) -> f64 {
225 let deriv_abs = derivative.abs();
226
227 let base_radius = BASE_SEARCH_RADIUS_FACTOR * deriv_abs;
229
230 let normalized_complexity = (complexity as f64) / 50.0;
233 let complexity_factor = 1.0 / (1.0 + ADAPTIVE_COMPLEXITY_SCALE * normalized_complexity);
234
235 let pool_fraction = if pool_capacity > 0 {
237 pool_size as f64 / pool_capacity as f64
238 } else {
239 0.0
240 };
241 let pool_factor = (1.0 - ADAPTIVE_POOL_FULLNESS_SCALE * pool_fraction).max(0.1);
242
243 let exact_factor = if best_error < NEWTON_FINAL_TOLERANCE {
245 ADAPTIVE_EXACT_MATCH_FACTOR
246 } else {
247 1.0
248 };
249
250 let radius = base_radius * complexity_factor * pool_factor * exact_factor;
252
253 let min_radius = 0.1 * deriv_abs; radius
256 .max(min_radius)
257 .min(MAX_SEARCH_RADIUS_FACTOR * deriv_abs)
258}
259
260impl ExprDatabase {
261 pub fn new() -> Self {
262 Self {
263 rhs_sorted: Vec::new(),
264 }
265 }
266
267 pub fn insert_rhs(&mut self, mut exprs: Vec<EvaluatedExpr>) {
270 exprs.sort_by(|a, b| a.value.total_cmp(&b.value));
273 self.rhs_sorted = exprs;
274 }
275
276 pub fn rhs_count(&self) -> usize {
278 self.rhs_sorted.len()
279 }
280
281 #[inline]
284 pub fn range(&self, low: f64, high: f64) -> &[EvaluatedExpr] {
285 let start = self.rhs_sorted.partition_point(|e| e.value < low);
287 let end = self.rhs_sorted.partition_point(|e| e.value <= high);
288 &self.rhs_sorted[start..end]
289 }
290
291 #[allow(dead_code)]
296 pub fn find_matches(&self, lhs_exprs: &[EvaluatedExpr], config: &SearchConfig) -> Vec<Match> {
297 let (matches, _stats) = self.find_matches_with_stats(lhs_exprs, config);
298 matches
299 }
300
301 pub fn find_matches_with_context(
303 &self,
304 lhs_exprs: &[EvaluatedExpr],
305 context: &SearchContext<'_>,
306 ) -> Vec<Match> {
307 let (matches, _stats) = self.find_matches_with_stats_and_context(lhs_exprs, context);
308 matches
309 }
310
311 pub fn find_matches_with_stats(
313 &self,
314 lhs_exprs: &[EvaluatedExpr],
315 config: &SearchConfig,
316 ) -> (Vec<Match>, SearchStats) {
317 let context = SearchContext::new(config);
318 self.find_matches_with_stats_and_context(lhs_exprs, &context)
319 }
320
321 pub fn find_matches_with_stats_and_context(
323 &self,
324 lhs_exprs: &[EvaluatedExpr],
325 context: &SearchContext<'_>,
326 ) -> (Vec<Match>, SearchStats) {
327 let config = context.config;
328 let mut stats = SearchStats::new();
329 let search_start = SearchTimer::start();
330
331 let initial_max_error = config.max_error.max(1e-12);
333
334 let mut pool = TopKPool::new_with_diagnostics(
336 config.max_matches,
337 initial_max_error,
338 config.show_db_adds,
339 config.ranking_mode,
340 );
341
342 let mut sorted_lhs: Vec<_> = lhs_exprs.iter().collect();
344 sorted_lhs.sort_by_key(|e| e.expr.complexity());
345
346 let mut early_exit = false;
348
349 'outer: for lhs in sorted_lhs {
350 if early_exit {
352 break;
353 }
354 if lhs.value.abs() < config.zero_value_threshold {
358 if config.show_pruned_range {
359 eprintln!(
360 " [pruned range] value={:.6e} reason=\"near-zero\" expr=\"{}\"",
361 lhs.value,
362 lhs.expr.to_infix()
363 );
364 }
365 continue;
366 }
367
368 if lhs.derivative.abs() < DEGENERATE_TEST_THRESHOLD {
371 let test_x = config.target + std::f64::consts::E;
376 if let Ok(test_result) =
380 crate::eval::evaluate_fast_with_context(&lhs.expr, test_x, &context.eval)
381 {
382 let value_unchanged =
385 (test_result.value - lhs.value).abs() < DEGENERATE_TEST_THRESHOLD;
386 let deriv_still_zero = test_result.derivative.abs() < DEGENERATE_TEST_THRESHOLD;
387 if deriv_still_zero || value_unchanged {
388 continue;
390 }
391 }
392 let val_error = DEGENERATE_RANGE_TOLERANCE;
395 let low = lhs.value - val_error;
396 let high = lhs.value + val_error;
397
398 stats.lhs_tested += 1;
399 for rhs in self.range(low, high) {
400 if !config.rhs_symbol_allowed(&rhs.expr) {
401 continue;
402 }
403 stats.candidates_tested += 1;
404 if config.show_match_checks {
405 eprintln!(
406 " [match] checking lhs={:.6} rhs={:.6}",
407 lhs.value, rhs.value
408 );
409 }
410 let val_diff = (lhs.value - rhs.value).abs();
411 if val_diff < val_error && pool.would_accept(0.0, true) {
412 let m = Match {
413 lhs: lhs.clone(),
414 rhs: rhs.clone(),
415 x_value: config.target,
416 error: 0.0,
417 complexity: lhs.expr.complexity() + rhs.expr.complexity(),
418 };
419 pool.try_insert(m);
420 }
421 }
422 continue;
423 }
424
425 stats.lhs_tested += 1;
426
427 let min_search_radius = 0.5 * lhs.derivative.abs(); let search_radius = (pool.accept_error * lhs.derivative.abs()).max(min_search_radius);
431 let low = lhs.value - search_radius;
432 let high = lhs.value + search_radius;
433
434 let rhs_slice = self.range(low, high);
435 for rhs in rhs_slice {
438 if !config.rhs_symbol_allowed(&rhs.expr) {
439 continue;
440 }
441 stats.candidates_tested += 1;
442 if config.show_match_checks {
443 eprintln!(
444 " [match] checking lhs={:.6} rhs={:.6}",
445 lhs.value, rhs.value
446 );
447 }
448
449 let val_diff = lhs.value - rhs.value;
451 let x_delta = -val_diff / lhs.derivative;
452 let coarse_error = x_delta.abs();
453
454 let is_potentially_exact = coarse_error < NEWTON_FINAL_TOLERANCE;
457 if !pool.would_accept_strict(coarse_error, is_potentially_exact) {
458 continue;
459 }
460
461 if !config.refine_with_newton {
462 let refined_x = config.target + x_delta;
463 let refined_error = x_delta;
464 let is_exact = refined_error.abs() < EXACT_MATCH_TOLERANCE;
465
466 if pool.would_accept(refined_error.abs(), is_exact) {
467 let m = Match {
468 lhs: lhs.clone(),
469 rhs: rhs.clone(),
470 x_value: refined_x,
471 error: refined_error,
472 complexity: lhs.expr.complexity() + rhs.expr.complexity(),
473 };
474
475 pool.try_insert(m);
476
477 if config.stop_at_exact && is_exact {
478 early_exit = true;
479 break 'outer;
480 }
481 if let Some(threshold) = config.stop_below {
482 if refined_error.abs() < threshold {
483 early_exit = true;
484 break 'outer;
485 }
486 }
487 }
488 continue;
489 }
490
491 stats.newton_calls += 1;
493 if let Some(refined_x) = newton_raphson_with_constants(
494 &lhs.expr,
495 rhs.value,
496 config.target,
497 config.newton_iterations,
498 &context.eval,
499 config.show_newton,
500 config.derivative_margin,
501 ) {
502 stats.newton_success += 1;
503 let refined_error = refined_x - config.target;
504 let is_exact = refined_error.abs() < EXACT_MATCH_TOLERANCE;
505
506 if pool.would_accept(refined_error.abs(), is_exact) {
508 let m = Match {
509 lhs: lhs.clone(),
510 rhs: rhs.clone(),
511 x_value: refined_x,
512 error: refined_error,
513 complexity: lhs.expr.complexity() + rhs.expr.complexity(),
514 };
515
516 pool.try_insert(m);
518
519 if config.stop_at_exact && is_exact {
521 early_exit = true;
522 break 'outer;
523 }
524 if let Some(threshold) = config.stop_below {
525 if refined_error.abs() < threshold {
526 early_exit = true;
527 break 'outer;
528 }
529 }
530 }
531 }
532 }
533 }
534
535 stats.pool_insertions = pool.stats.insertions;
537 stats.pool_rejections_error = pool.stats.rejections_error;
538 stats.pool_rejections_dedupe = pool.stats.rejections_dedupe;
539 stats.pool_evictions = pool.stats.evictions;
540 stats.pool_final_size = pool.len();
541 stats.pool_best_error = pool.best_error;
542 stats.search_time = search_start.elapsed();
543 stats.early_exit = early_exit;
544
545 (pool.into_sorted(), stats)
547 }
548}
549
550impl Default for ExprDatabase {
551 fn default() -> Self {
552 Self::new()
553 }
554}