Skip to main content

ries_rs/search/
db.rs

1use super::newton::newton_raphson_with_constants;
2use super::{Match, SearchConfig, SearchContext, SearchStats, SearchTimer};
3
4use crate::expr::EvaluatedExpr;
5
6use crate::pool::TopKPool;
7
8use crate::thresholds::{
9    ADAPTIVE_COMPLEXITY_SCALE, ADAPTIVE_EXACT_MATCH_FACTOR, ADAPTIVE_POOL_FULLNESS_SCALE,
10    BASE_SEARCH_RADIUS_FACTOR, DEGENERATE_RANGE_TOLERANCE, DEGENERATE_TEST_THRESHOLD,
11    EXACT_MATCH_TOLERANCE, MAX_SEARCH_RADIUS_FACTOR, NEWTON_FINAL_TOLERANCE, TIER_0_MAX,
12    TIER_1_MAX, TIER_2_MAX,
13};
14
15/// Database for storing expressions sorted by value
16/// Uses a flat sorted vector for cache-friendly range scans
17pub struct ExprDatabase {
18    /// RHS expressions sorted by value (flat vector for cache locality)
19    rhs_sorted: Vec<EvaluatedExpr>,
20}
21
22// =============================================================================
23// TIERED DATABASE FOR MULTI-LEVEL INDEXING
24// =============================================================================
25
/// Complexity bucket used by the tiered expression database.
///
/// Buckets run from simplest (`Tier0`) to most complex (`Tier3`). Searches visit
/// lower tiers first, so good matches among simple expressions can be found
/// before the complex ones are scanned.
///
/// The explicit discriminants are the same as the implicit ones. They double as
/// indices into per-tier storage (`tier as usize`) and define the derived ordering.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum ComplexityTier {
    /// Simplest expressions: complexity `<= TIER_0_MAX` (previously documented as 0-15).
    Tier0 = 0,
    /// Complexity in `(TIER_0_MAX, TIER_1_MAX]` (previously documented as 16-25).
    Tier1 = 1,
    /// Complexity in `(TIER_1_MAX, TIER_2_MAX]` (previously documented as 26-35).
    Tier2 = 2,
    /// Most complex expressions: complexity `> TIER_2_MAX` (previously documented as 36+).
    Tier3 = 3,
}
41
42impl ComplexityTier {
43    /// Determine the tier for a given complexity value
44    #[inline]
45    pub fn from_complexity(complexity: u32) -> Self {
46        if complexity <= TIER_0_MAX {
47            ComplexityTier::Tier0
48        } else if complexity <= TIER_1_MAX {
49            ComplexityTier::Tier1
50        } else if complexity <= TIER_2_MAX {
51            ComplexityTier::Tier2
52        } else {
53            ComplexityTier::Tier3
54        }
55    }
56}
57
58/// Database with tiered storage for efficient priority-based searching
59///
60/// Expressions are organized by complexity tiers, allowing searches to
61/// process simpler expressions first and potentially skip higher tiers
62/// when good matches are found.
63pub struct TieredExprDatabase {
64    /// RHS expressions organized by tier, each sorted by value
65    tiers: [Vec<EvaluatedExpr>; 4],
66    /// Total count across all tiers
67    total_count: usize,
68}
69
70impl TieredExprDatabase {
71    /// Create a new empty tiered database
72    pub fn new() -> Self {
73        Self {
74            tiers: [Vec::new(), Vec::new(), Vec::new(), Vec::new()],
75            total_count: 0,
76        }
77    }
78
79    /// Insert an expression into the appropriate tier
80    pub fn insert(&mut self, expr: EvaluatedExpr) {
81        let tier = ComplexityTier::from_complexity(expr.expr.complexity());
82        let tier_idx = tier as usize;
83        self.tiers[tier_idx].push(expr);
84        self.total_count += 1;
85    }
86
87    /// Finalize the database by sorting each tier by value
88    pub fn finalize(&mut self) {
89        for tier in &mut self.tiers {
90            tier.sort_by(|a, b| a.value.total_cmp(&b.value));
91        }
92    }
93
94    /// Get total count of expressions across all tiers
95    pub fn total_count(&self) -> usize {
96        self.total_count
97    }
98
99    /// Get count for a specific tier
100    #[allow(dead_code)]
101    pub fn tier_count(&self, tier: ComplexityTier) -> usize {
102        self.tiers[tier as usize].len()
103    }
104
105    #[cfg(test)]
106    pub(super) fn tier(&self, tier: ComplexityTier) -> &[EvaluatedExpr] {
107        &self.tiers[tier as usize]
108    }
109
110    /// Find expressions in a specific tier within the value range [low, high]
111    #[allow(dead_code)]
112    pub fn range_in_tier(&self, tier: ComplexityTier, low: f64, high: f64) -> &[EvaluatedExpr] {
113        let tier_vec = &self.tiers[tier as usize];
114        let start = tier_vec.partition_point(|e| e.value < low);
115        let end = tier_vec.partition_point(|e| e.value <= high);
116        &tier_vec[start..end]
117    }
118
119    /// Create an iterator that yields expressions from all tiers in order
120    /// (Tier 0 first, then Tier 1, etc.) within a value range
121    pub fn iter_tiers_in_range(&self, low: f64, high: f64) -> TieredRangeIter<'_> {
122        TieredRangeIter::new(self, low, high)
123    }
124}
125
126impl Default for TieredExprDatabase {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
132/// Iterator over expressions in a value range, yielding from lower tiers first
133pub struct TieredRangeIter<'a> {
134    db: &'a TieredExprDatabase,
135    low: f64,
136    high: f64,
137    current_tier: usize,
138    current_start: usize,
139    current_end: usize,
140}
141
142impl<'a> TieredRangeIter<'a> {
143    fn new(db: &'a TieredExprDatabase, low: f64, high: f64) -> Self {
144        let mut iter = Self {
145            db,
146            low,
147            high,
148            current_tier: 0,
149            current_start: 0,
150            current_end: 0,
151        };
152        iter.find_next_nonempty_tier();
153        iter
154    }
155
156    /// Calculate the range [start, end) of expressions in a tier that fall within [low, high]
157    fn calculate_tier_range(&self, tier_idx: usize) -> (usize, usize) {
158        let tier_vec = &self.db.tiers[tier_idx];
159        let start = tier_vec.partition_point(|e| e.value < self.low);
160        let end = tier_vec.partition_point(|e| e.value <= self.high);
161        (start, end)
162    }
163
164    /// Advance to the next tier with matching expressions
165    fn find_next_nonempty_tier(&mut self) {
166        while self.current_tier < 4 {
167            let (start, end) = self.calculate_tier_range(self.current_tier);
168            self.current_start = start;
169            self.current_end = end;
170
171            if self.current_start < self.current_end {
172                // Found expressions in this tier
173                return;
174            }
175            self.current_tier += 1;
176        }
177    }
178}
179
180impl<'a> Iterator for TieredRangeIter<'a> {
181    type Item = &'a EvaluatedExpr;
182
183    fn next(&mut self) -> Option<Self::Item> {
184        while self.current_tier < 4 {
185            if self.current_start < self.current_end {
186                let expr = &self.db.tiers[self.current_tier][self.current_start];
187                self.current_start += 1;
188                return Some(expr);
189            }
190            self.current_tier += 1;
191            self.find_next_nonempty_tier();
192        }
193        None
194    }
195}
196
197// =============================================================================
198// ADAPTIVE SEARCH RADIUS
199// =============================================================================
200
201/// Calculate adaptive search radius based on multiple factors
202///
203/// The search radius determines how far from an LHS value we look for
204/// matching RHS expressions. A tighter radius means fewer candidates
205/// but faster search; a wider radius means more candidates but slower.
206///
207/// # Factors
208///
209/// 1. **Derivative magnitude**: Larger derivative = tighter radius (faster convergence)
210/// 2. **Complexity**: Higher complexity = tighter radius (prefer simpler matches)
211/// 3. **Pool fullness**: Fuller pool = tighter radius (be more selective)
212/// 4. **Best error found**: If we have exact matches, be very selective
213///
214/// # Returns
215///
216/// The search radius as an absolute value (not relative to derivative).
217#[inline]
218pub(super) fn calculate_adaptive_search_radius(
219    derivative: f64,
220    complexity: u32,
221    pool_size: usize,
222    pool_capacity: usize,
223    best_error: f64,
224) -> f64 {
225    let deriv_abs = derivative.abs();
226
227    // Base radius: proportional to derivative
228    let base_radius = BASE_SEARCH_RADIUS_FACTOR * deriv_abs;
229
230    // Complexity factor: reduce radius for complex expressions
231    // normalized_complexity is roughly 0-1 for typical complexity ranges
232    let normalized_complexity = (complexity as f64) / 50.0;
233    let complexity_factor = 1.0 / (1.0 + ADAPTIVE_COMPLEXITY_SCALE * normalized_complexity);
234
235    // Pool fullness factor: reduce radius as pool fills up
236    let pool_fraction = if pool_capacity > 0 {
237        pool_size as f64 / pool_capacity as f64
238    } else {
239        0.0
240    };
241    let pool_factor = (1.0 - ADAPTIVE_POOL_FULLNESS_SCALE * pool_fraction).max(0.1);
242
243    // Exact match factor: if we have good matches, be very selective
244    let exact_factor = if best_error < NEWTON_FINAL_TOLERANCE {
245        ADAPTIVE_EXACT_MATCH_FACTOR
246    } else {
247        1.0
248    };
249
250    // Combined radius
251    let radius = base_radius * complexity_factor * pool_factor * exact_factor;
252
253    // Ensure we have a reasonable minimum and cap at maximum
254    let min_radius = 0.1 * deriv_abs; // At least 0.1 * derivative
255    radius
256        .max(min_radius)
257        .min(MAX_SEARCH_RADIUS_FACTOR * deriv_abs)
258}
259
260impl ExprDatabase {
261    pub fn new() -> Self {
262        Self {
263            rhs_sorted: Vec::new(),
264        }
265    }
266
267    /// Insert RHS expressions into the database
268    /// Sorts by value for efficient range queries using partition_point
269    pub fn insert_rhs(&mut self, mut exprs: Vec<EvaluatedExpr>) {
270        // Sort by value for binary search range queries
271        // Use total_cmp for consistent ordering (NaN sorts as greater than all floats)
272        exprs.sort_by(|a, b| a.value.total_cmp(&b.value));
273        self.rhs_sorted = exprs;
274    }
275
276    /// Get total count of RHS expressions
277    pub fn rhs_count(&self) -> usize {
278        self.rhs_sorted.len()
279    }
280
281    /// Find RHS expressions in the value range [low, high]
282    /// Returns a slice of matching expressions (contiguous, cache-friendly)
283    #[inline]
284    pub fn range(&self, low: f64, high: f64) -> &[EvaluatedExpr] {
285        // Binary search for range bounds using partition_point
286        let start = self.rhs_sorted.partition_point(|e| e.value < low);
287        let end = self.rhs_sorted.partition_point(|e| e.value <= high);
288        &self.rhs_sorted[start..end]
289    }
290
291    /// Find matches for LHS expressions using streaming collection
292    ///
293    /// This method is part of the public API for library consumers who want
294    /// to perform matching without statistics collection.
295    #[allow(dead_code)]
296    pub fn find_matches(&self, lhs_exprs: &[EvaluatedExpr], config: &SearchConfig) -> Vec<Match> {
297        let (matches, _stats) = self.find_matches_with_stats(lhs_exprs, config);
298        matches
299    }
300
301    /// Find matches with an explicit per-run search context.
302    pub fn find_matches_with_context(
303        &self,
304        lhs_exprs: &[EvaluatedExpr],
305        context: &SearchContext<'_>,
306    ) -> Vec<Match> {
307        let (matches, _stats) = self.find_matches_with_stats_and_context(lhs_exprs, context);
308        matches
309    }
310
311    /// Find matches with statistics collection
312    pub fn find_matches_with_stats(
313        &self,
314        lhs_exprs: &[EvaluatedExpr],
315        config: &SearchConfig,
316    ) -> (Vec<Match>, SearchStats) {
317        let context = SearchContext::new(config);
318        self.find_matches_with_stats_and_context(lhs_exprs, &context)
319    }
320
321    /// Find matches with statistics collection using an explicit per-run search context.
322    pub fn find_matches_with_stats_and_context(
323        &self,
324        lhs_exprs: &[EvaluatedExpr],
325        context: &SearchContext<'_>,
326    ) -> (Vec<Match>, SearchStats) {
327        let config = context.config;
328        let mut stats = SearchStats::new();
329        let search_start = SearchTimer::start();
330
331        // Respect configured max error (with a tiny floor for numerical stability)
332        let initial_max_error = config.max_error.max(1e-12);
333
334        // Create bounded pool with configured capacity
335        let mut pool = TopKPool::new_with_diagnostics(
336            config.max_matches,
337            initial_max_error,
338            config.show_db_adds,
339            config.ranking_mode,
340        );
341
342        // Sort LHS by complexity so simpler expressions are processed first
343        let mut sorted_lhs: Vec<_> = lhs_exprs.iter().collect();
344        sorted_lhs.sort_by_key(|e| e.expr.complexity());
345
346        // Early exit tracking
347        let mut early_exit = false;
348
349        'outer: for lhs in sorted_lhs {
350            // Check early exit conditions
351            if early_exit {
352                break;
353            }
354            // Skip LHS with value too close to 0 - these produce floods of
355            // trivial matches (like cospi(2.5)=0 matching many RHS near 0)
356            // Original RIES: "Prune zero subexpressions"
357            if lhs.value.abs() < config.zero_value_threshold {
358                if config.show_pruned_range {
359                    eprintln!(
360                        "  [pruned range] value={:.6e} reason=\"near-zero\" expr=\"{}\"",
361                        lhs.value,
362                        lhs.expr.to_infix()
363                    );
364                }
365                continue;
366            }
367
368            // Skip degenerate expressions: contain x but derivative is 0
369            // These are trivial identities like 1^x=1, x/x=1, log_x(x)=1
370            if lhs.derivative.abs() < DEGENERATE_TEST_THRESHOLD {
371                // To distinguish true repeated roots from degenerate expressions,
372                // evaluate at a different x value. Degenerate expressions have
373                // derivative 0 everywhere; true repeated roots only at specific x.
374                // Use an irrational offset to avoid hitting special values
375                let test_x = config.target + std::f64::consts::E;
376                // Use the full evaluator (including user_functions) so that UDF-containing
377                // expressions are not silently skipped due to evaluate_with_constants
378                // returning Err for user-function symbols.
379                if let Ok(test_result) =
380                    crate::eval::evaluate_fast_with_context(&lhs.expr, test_x, &context.eval)
381                {
382                    // Check both: derivative still ~0, AND value unchanged
383                    // This catches x*(1/x)=1 type expressions
384                    let value_unchanged =
385                        (test_result.value - lhs.value).abs() < DEGENERATE_TEST_THRESHOLD;
386                    let deriv_still_zero = test_result.derivative.abs() < DEGENERATE_TEST_THRESHOLD;
387                    if deriv_still_zero || value_unchanged {
388                        // Degenerate expression - skip
389                        continue;
390                    }
391                }
392                // Derivative is non-zero at test_x, so this might be a true repeated root
393                // Check if LHS(target) ≈ some RHS
394                let val_error = DEGENERATE_RANGE_TOLERANCE;
395                let low = lhs.value - val_error;
396                let high = lhs.value + val_error;
397
398                stats.lhs_tested += 1;
399                for rhs in self.range(low, high) {
400                    if !config.rhs_symbol_allowed(&rhs.expr) {
401                        continue;
402                    }
403                    stats.candidates_tested += 1;
404                    if config.show_match_checks {
405                        eprintln!(
406                            "  [match] checking lhs={:.6} rhs={:.6}",
407                            lhs.value, rhs.value
408                        );
409                    }
410                    let val_diff = (lhs.value - rhs.value).abs();
411                    if val_diff < val_error && pool.would_accept(0.0, true) {
412                        let m = Match {
413                            lhs: lhs.clone(),
414                            rhs: rhs.clone(),
415                            x_value: config.target,
416                            error: 0.0,
417                            complexity: lhs.expr.complexity() + rhs.expr.complexity(),
418                        };
419                        pool.try_insert(m);
420                    }
421                }
422                continue;
423            }
424
425            stats.lhs_tested += 1;
426
427            // Search for RHS expressions near this LHS value
428            // Use adaptive search radius based on current thresholds
429            let min_search_radius = 0.5 * lhs.derivative.abs(); // Allow ~0.5 error in x
430            let search_radius = (pool.accept_error * lhs.derivative.abs()).max(min_search_radius);
431            let low = lhs.value - search_radius;
432            let high = lhs.value + search_radius;
433
434            let rhs_slice = self.range(low, high);
435            // Track slice sizes for optimization analysis
436            // println!("LHS {} (val={:.4}): slice size = {}", lhs.expr.to_postfix(), lhs.value, rhs_slice.len());
437            for rhs in rhs_slice {
438                if !config.rhs_symbol_allowed(&rhs.expr) {
439                    continue;
440                }
441                stats.candidates_tested += 1;
442                if config.show_match_checks {
443                    eprintln!(
444                        "  [match] checking lhs={:.6} rhs={:.6}",
445                        lhs.value, rhs.value
446                    );
447                }
448
449                // Compute initial error estimate (coarse filter)
450                let val_diff = lhs.value - rhs.value;
451                let x_delta = -val_diff / lhs.derivative;
452                let coarse_error = x_delta.abs();
453
454                // Skip if coarse estimate won't pass threshold
455                // Use strict gate to avoid expensive Newton calls for marginal candidates
456                let is_potentially_exact = coarse_error < NEWTON_FINAL_TOLERANCE;
457                if !pool.would_accept_strict(coarse_error, is_potentially_exact) {
458                    continue;
459                }
460
461                if !config.refine_with_newton {
462                    let refined_x = config.target + x_delta;
463                    let refined_error = x_delta;
464                    let is_exact = refined_error.abs() < EXACT_MATCH_TOLERANCE;
465
466                    if pool.would_accept(refined_error.abs(), is_exact) {
467                        let m = Match {
468                            lhs: lhs.clone(),
469                            rhs: rhs.clone(),
470                            x_value: refined_x,
471                            error: refined_error,
472                            complexity: lhs.expr.complexity() + rhs.expr.complexity(),
473                        };
474
475                        pool.try_insert(m);
476
477                        if config.stop_at_exact && is_exact {
478                            early_exit = true;
479                            break 'outer;
480                        }
481                        if let Some(threshold) = config.stop_below {
482                            if refined_error.abs() < threshold {
483                                early_exit = true;
484                                break 'outer;
485                            }
486                        }
487                    }
488                    continue;
489                }
490
491                // Refine with Newton-Raphson
492                stats.newton_calls += 1;
493                if let Some(refined_x) = newton_raphson_with_constants(
494                    &lhs.expr,
495                    rhs.value,
496                    config.target,
497                    config.newton_iterations,
498                    &context.eval,
499                    config.show_newton,
500                    config.derivative_margin,
501                ) {
502                    stats.newton_success += 1;
503                    let refined_error = refined_x - config.target;
504                    let is_exact = refined_error.abs() < EXACT_MATCH_TOLERANCE;
505
506                    // Check if this is acceptable
507                    if pool.would_accept(refined_error.abs(), is_exact) {
508                        let m = Match {
509                            lhs: lhs.clone(),
510                            rhs: rhs.clone(),
511                            x_value: refined_x,
512                            error: refined_error,
513                            complexity: lhs.expr.complexity() + rhs.expr.complexity(),
514                        };
515
516                        // Insert into pool (handles thresholds and eviction)
517                        pool.try_insert(m);
518
519                        // Check early exit conditions
520                        if config.stop_at_exact && is_exact {
521                            early_exit = true;
522                            break 'outer;
523                        }
524                        if let Some(threshold) = config.stop_below {
525                            if refined_error.abs() < threshold {
526                                early_exit = true;
527                                break 'outer;
528                            }
529                        }
530                    }
531                }
532            }
533        }
534
535        // Collect pool stats
536        stats.pool_insertions = pool.stats.insertions;
537        stats.pool_rejections_error = pool.stats.rejections_error;
538        stats.pool_rejections_dedupe = pool.stats.rejections_dedupe;
539        stats.pool_evictions = pool.stats.evictions;
540        stats.pool_final_size = pool.len();
541        stats.pool_best_error = pool.best_error;
542        stats.search_time = search_start.elapsed();
543        stats.early_exit = early_exit;
544
545        // Return sorted matches from pool
546        (pool.into_sorted(), stats)
547    }
548}
549
550impl Default for ExprDatabase {
551    fn default() -> Self {
552        Self::new()
553    }
554}