use super::newton::newton_raphson_with_constants;
use super::{Match, SearchConfig, SearchContext, SearchStats, SearchTimer};
use crate::expr::EvaluatedExpr;
use crate::pool::TopKPool;
use crate::thresholds::{
ADAPTIVE_COMPLEXITY_SCALE, ADAPTIVE_EXACT_MATCH_FACTOR, ADAPTIVE_POOL_FULLNESS_SCALE,
BASE_SEARCH_RADIUS_FACTOR, DEGENERATE_RANGE_TOLERANCE, DEGENERATE_TEST_THRESHOLD,
EXACT_MATCH_TOLERANCE, MAX_SEARCH_RADIUS_FACTOR, NEWTON_FINAL_TOLERANCE, TIER_0_MAX,
TIER_1_MAX, TIER_2_MAX,
};
/// Flat expression database: every RHS candidate in a single vector,
/// kept sorted by value so `range` can binary-search.
pub struct ExprDatabase {
    // Sorted ascending by `value` (established in `insert_rhs`);
    // `range` relies on this invariant.
    rhs_sorted: Vec<EvaluatedExpr>,
}
/// Complexity bucket for an expression; lower tiers hold simpler
/// expressions. Boundaries are the `TIER_*_MAX` thresholds (see
/// `ComplexityTier::from_complexity`). `Ord` follows declaration
/// order, so tiers compare by increasing complexity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ComplexityTier {
    // complexity <= TIER_0_MAX
    Tier0,
    // complexity <= TIER_1_MAX
    Tier1,
    // complexity <= TIER_2_MAX
    Tier2,
    // everything above TIER_2_MAX
    Tier3,
}
impl ComplexityTier {
    /// Maps a raw complexity score onto its tier bucket.
    ///
    /// Boundaries are inclusive: a score equal to `TIER_0_MAX` still
    /// lands in `Tier0`; anything above `TIER_2_MAX` falls through to
    /// `Tier3`.
    #[inline]
    pub fn from_complexity(complexity: u32) -> Self {
        match complexity {
            c if c <= TIER_0_MAX => Self::Tier0,
            c if c <= TIER_1_MAX => Self::Tier1,
            c if c <= TIER_2_MAX => Self::Tier2,
            _ => Self::Tier3,
        }
    }
}
/// Expression database partitioned by complexity tier so that searches
/// can visit simpler expressions first.
pub struct TieredExprDatabase {
    // One bucket per `ComplexityTier`; each bucket is value-sorted
    // once `finalize` has run.
    tiers: [Vec<EvaluatedExpr>; 4],
    // Total number of expressions across all tiers.
    total_count: usize,
}
impl TieredExprDatabase {
    /// Creates an empty database with one bucket per complexity tier.
    pub fn new() -> Self {
        Self {
            tiers: Default::default(),
            total_count: 0,
        }
    }

    /// Buckets `expr` by the complexity tier of its expression tree.
    pub fn insert(&mut self, expr: EvaluatedExpr) {
        let bucket = ComplexityTier::from_complexity(expr.expr.complexity()) as usize;
        self.total_count += 1;
        self.tiers[bucket].push(expr);
    }

    /// Sorts every tier by value so range queries can binary-search.
    /// Must be called after the last `insert` and before any range query.
    pub fn finalize(&mut self) {
        self.tiers
            .iter_mut()
            .for_each(|bucket| bucket.sort_by(|a, b| a.value.total_cmp(&b.value)));
    }

    /// Number of expressions stored across all tiers.
    pub fn total_count(&self) -> usize {
        self.total_count
    }

    /// Number of expressions stored in a single tier.
    #[allow(dead_code)]
    pub fn tier_count(&self, tier: ComplexityTier) -> usize {
        self.tiers[tier as usize].len()
    }

    /// Test-only view of one tier's backing slice.
    #[cfg(test)]
    pub(super) fn tier(&self, tier: ComplexityTier) -> &[EvaluatedExpr] {
        self.tiers[tier as usize].as_slice()
    }

    /// Slice of `tier` whose values fall in `[low, high]` (both
    /// inclusive). Requires `finalize` to have run.
    #[allow(dead_code)]
    pub fn range_in_tier(&self, tier: ComplexityTier, low: f64, high: f64) -> &[EvaluatedExpr] {
        let bucket = &self.tiers[tier as usize];
        let lo = bucket.partition_point(|e| e.value < low);
        let hi = bucket.partition_point(|e| e.value <= high);
        &bucket[lo..hi]
    }

    /// Iterates all in-range expressions, walking tiers from simplest
    /// to most complex.
    pub fn iter_tiers_in_range(&self, low: f64, high: f64) -> TieredRangeIter<'_> {
        TieredRangeIter::new(self, low, high)
    }
}
impl Default for TieredExprDatabase {
fn default() -> Self {
Self::new()
}
}
/// Iterator over expressions whose value lies in `[low, high]`,
/// visiting tiers in ascending complexity order.
pub struct TieredRangeIter<'a> {
    db: &'a TieredExprDatabase,
    // Inclusive lower value bound.
    low: f64,
    // Inclusive upper value bound.
    high: f64,
    // Index into `db.tiers`; reaching 4 means the iterator is exhausted.
    current_tier: usize,
    // Next index to yield within the current tier.
    current_start: usize,
    // One past the last in-range index within the current tier.
    current_end: usize,
}
impl<'a> TieredRangeIter<'a> {
fn new(db: &'a TieredExprDatabase, low: f64, high: f64) -> Self {
let mut iter = Self {
db,
low,
high,
current_tier: 0,
current_start: 0,
current_end: 0,
};
iter.find_next_nonempty_tier();
iter
}
fn calculate_tier_range(&self, tier_idx: usize) -> (usize, usize) {
let tier_vec = &self.db.tiers[tier_idx];
let start = tier_vec.partition_point(|e| e.value < self.low);
let end = tier_vec.partition_point(|e| e.value <= self.high);
(start, end)
}
fn find_next_nonempty_tier(&mut self) {
while self.current_tier < 4 {
let (start, end) = self.calculate_tier_range(self.current_tier);
self.current_start = start;
self.current_end = end;
if self.current_start < self.current_end {
return;
}
self.current_tier += 1;
}
}
}
impl<'a> Iterator for TieredRangeIter<'a> {
    type Item = &'a EvaluatedExpr;

    /// Yields the next in-range expression, walking tiers in order.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if self.current_tier >= self.db.tiers.len() {
                return None;
            }
            if self.current_start < self.current_end {
                let item = &self.db.tiers[self.current_tier][self.current_start];
                self.current_start += 1;
                return Some(item);
            }
            // Current tier exhausted: step past it, then seek the next
            // tier that has values in range (seeking alone would
            // recompute this tier's full window and re-yield items).
            self.current_tier += 1;
            self.find_next_nonempty_tier();
        }
    }
}
/// Computes how far (in value space) to search around an LHS value,
/// adapting the radius to expression complexity, pool fullness, and
/// whether a near-exact match has already been found.
///
/// The product of all factors is then bounded to
/// `[0.1, MAX_SEARCH_RADIUS_FACTOR] * |derivative|` so the window
/// neither collapses nor explodes.
#[inline]
pub(super) fn calculate_adaptive_search_radius(
    derivative: f64,
    complexity: u32,
    pool_size: usize,
    pool_capacity: usize,
    best_error: f64,
) -> f64 {
    let deriv_abs = derivative.abs();

    // Starting radius: proportional to how fast the value moves with x.
    let base_radius = BASE_SEARCH_RADIUS_FACTOR * deriv_abs;

    // Complex expressions get a tighter window; 50.0 normalizes the
    // raw complexity score to a unitless scale before damping.
    let complexity_factor =
        1.0 / (1.0 + ADAPTIVE_COMPLEXITY_SCALE * (complexity as f64 / 50.0));

    // A fuller pool lets us be pickier, but never shrink below 10%
    // on this axis.
    let fullness = match pool_capacity {
        0 => 0.0,
        cap => pool_size as f64 / cap as f64,
    };
    let pool_factor = (1.0 - ADAPTIVE_POOL_FULLNESS_SCALE * fullness).max(0.1);

    // Once an effectively exact match exists, narrow aggressively.
    let exact_factor = if best_error < NEWTON_FINAL_TOLERANCE {
        ADAPTIVE_EXACT_MATCH_FACTOR
    } else {
        1.0
    };

    let radius = base_radius * complexity_factor * pool_factor * exact_factor;
    radius
        .max(0.1 * deriv_abs)
        .min(MAX_SEARCH_RADIUS_FACTOR * deriv_abs)
}
impl ExprDatabase {
    /// Creates an empty database; populate it with `insert_rhs`.
    pub fn new() -> Self {
        Self {
            rhs_sorted: Vec::new(),
        }
    }

    /// Replaces the RHS table with `exprs`, sorted ascending by value
    /// so `range` can binary-search.
    pub fn insert_rhs(&mut self, mut exprs: Vec<EvaluatedExpr>) {
        exprs.sort_by(|a, b| a.value.total_cmp(&b.value));
        self.rhs_sorted = exprs;
    }

    /// Number of stored RHS expressions.
    pub fn rhs_count(&self) -> usize {
        self.rhs_sorted.len()
    }

    /// All RHS expressions with value in `[low, high]` (both inclusive),
    /// located via two binary searches on the sorted table.
    #[inline]
    pub fn range(&self, low: f64, high: f64) -> &[EvaluatedExpr] {
        let start = self.rhs_sorted.partition_point(|e| e.value < low);
        let end = self.rhs_sorted.partition_point(|e| e.value <= high);
        &self.rhs_sorted[start..end]
    }

    /// Convenience wrapper: matches only, stats discarded.
    #[allow(dead_code)]
    pub fn find_matches(&self, lhs_exprs: &[EvaluatedExpr], config: &SearchConfig) -> Vec<Match> {
        let (matches, _stats) = self.find_matches_with_stats(lhs_exprs, config);
        matches
    }

    /// Convenience wrapper: matches only, using a caller-provided
    /// search context.
    pub fn find_matches_with_context(
        &self,
        lhs_exprs: &[EvaluatedExpr],
        context: &SearchContext<'_>,
    ) -> Vec<Match> {
        let (matches, _stats) = self.find_matches_with_stats_and_context(lhs_exprs, context);
        matches
    }

    /// Convenience wrapper: builds a fresh `SearchContext` from `config`
    /// and delegates to the full search.
    pub fn find_matches_with_stats(
        &self,
        lhs_exprs: &[EvaluatedExpr],
        config: &SearchConfig,
    ) -> (Vec<Match>, SearchStats) {
        let context = SearchContext::new(config);
        self.find_matches_with_stats_and_context(lhs_exprs, &context)
    }

    /// Core search: for each LHS expression (simplest first), scans the
    /// value-sorted RHS table for candidates whose value is close enough
    /// that a small shift of x could make LHS == RHS, optionally refines
    /// the shift with Newton-Raphson, and collects the best matches in a
    /// top-K pool. Returns the sorted matches plus search statistics.
    pub fn find_matches_with_stats_and_context(
        &self,
        lhs_exprs: &[EvaluatedExpr],
        context: &SearchContext<'_>,
    ) -> (Vec<Match>, SearchStats) {
        let config = context.config;
        let mut stats = SearchStats::new();
        let search_start = SearchTimer::start();
        // Floor the error bound so a zero max_error still admits candidates.
        let initial_max_error = config.max_error.max(1e-12);
        let mut pool = TopKPool::new_with_diagnostics(
            config.max_matches,
            initial_max_error,
            config.show_db_adds,
            config.ranking_mode,
        );
        // Process simplest LHS expressions first so early pool entries
        // (and any early exit) favor low complexity.
        let mut sorted_lhs: Vec<_> = lhs_exprs.iter().collect();
        sorted_lhs.sort_by_key(|e| e.expr.complexity());
        let mut early_exit = false;
        'outer: for lhs in sorted_lhs {
            if early_exit {
                break;
            }
            // Near-zero LHS values are skipped: they match trivially.
            if lhs.value.abs() < config.zero_value_threshold {
                if config.show_pruned_range {
                    eprintln!(
                        " [pruned range] value={:.6e} reason=\"near-zero\" expr=\"{}\"",
                        lhs.value,
                        lhs.expr.to_infix()
                    );
                }
                continue;
            }
            // Degenerate case: derivative ~ 0, so the linear shift
            // estimate below would blow up. Probe at a shifted x to
            // distinguish a genuinely constant expression (skip) from
            // one that is merely flat at the target.
            if lhs.derivative.abs() < DEGENERATE_TEST_THRESHOLD {
                // E is an arbitrary irrational offset, unlikely to land
                // on another stationary point.
                let test_x = config.target + std::f64::consts::E;
                if let Ok(test_result) =
                    crate::eval::evaluate_fast_with_context(&lhs.expr, test_x, &context.eval)
                {
                    let value_unchanged =
                        (test_result.value - lhs.value).abs() < DEGENERATE_TEST_THRESHOLD;
                    let deriv_still_zero = test_result.derivative.abs() < DEGENERATE_TEST_THRESHOLD;
                    if deriv_still_zero || value_unchanged {
                        continue;
                    }
                }
                // Flat-but-not-constant: fall back to a direct value
                // comparison within a fixed tolerance; such matches are
                // recorded with error 0.0 at the unshifted target.
                let val_error = DEGENERATE_RANGE_TOLERANCE;
                let low = lhs.value - val_error;
                let high = lhs.value + val_error;
                stats.lhs_tested += 1;
                for rhs in self.range(low, high) {
                    if !config.rhs_symbol_allowed(&rhs.expr) {
                        continue;
                    }
                    stats.candidates_tested += 1;
                    if config.show_match_checks {
                        eprintln!(
                            " [match] checking lhs={:.6} rhs={:.6}",
                            lhs.value, rhs.value
                        );
                    }
                    let val_diff = (lhs.value - rhs.value).abs();
                    if val_diff < val_error && pool.would_accept(0.0, true) {
                        let m = Match {
                            lhs: lhs.clone(),
                            rhs: rhs.clone(),
                            x_value: config.target,
                            error: 0.0,
                            complexity: lhs.expr.complexity() + rhs.expr.complexity(),
                        };
                        pool.try_insert(m);
                    }
                }
                continue;
            }
            stats.lhs_tested += 1;
            // Normal case: the value window scales with |derivative|
            // and with the pool's current acceptance threshold, floored
            // at half a derivative unit so the window never vanishes.
            let min_search_radius = 0.5 * lhs.derivative.abs();
            let search_radius = (pool.accept_error * lhs.derivative.abs()).max(min_search_radius);
            let low = lhs.value - search_radius;
            let high = lhs.value + search_radius;
            let rhs_slice = self.range(low, high);
            for rhs in rhs_slice {
                if !config.rhs_symbol_allowed(&rhs.expr) {
                    continue;
                }
                stats.candidates_tested += 1;
                if config.show_match_checks {
                    eprintln!(
                        " [match] checking lhs={:.6} rhs={:.6}",
                        lhs.value, rhs.value
                    );
                }
                // First-order estimate of the x shift that would make
                // LHS(x) equal rhs.value: delta = -f(x)/f'(x).
                let val_diff = lhs.value - rhs.value;
                let x_delta = -val_diff / lhs.derivative;
                let coarse_error = x_delta.abs();
                let is_potentially_exact = coarse_error < NEWTON_FINAL_TOLERANCE;
                // Cheap pre-filter before any refinement work.
                if !pool.would_accept_strict(coarse_error, is_potentially_exact) {
                    continue;
                }
                if !config.refine_with_newton {
                    // Accept the linear estimate as-is.
                    let refined_x = config.target + x_delta;
                    let refined_error = x_delta;
                    let is_exact = refined_error.abs() < EXACT_MATCH_TOLERANCE;
                    if pool.would_accept(refined_error.abs(), is_exact) {
                        let m = Match {
                            lhs: lhs.clone(),
                            rhs: rhs.clone(),
                            x_value: refined_x,
                            error: refined_error,
                            complexity: lhs.expr.complexity() + rhs.expr.complexity(),
                        };
                        pool.try_insert(m);
                        if config.stop_at_exact && is_exact {
                            early_exit = true;
                            break 'outer;
                        }
                        if let Some(threshold) = config.stop_below {
                            if refined_error.abs() < threshold {
                                early_exit = true;
                                break 'outer;
                            }
                        }
                    }
                    continue;
                }
                // Refine the root of LHS(x) - rhs.value with Newton-Raphson.
                stats.newton_calls += 1;
                if let Some(refined_x) = newton_raphson_with_constants(
                    &lhs.expr,
                    rhs.value,
                    config.target,
                    config.newton_iterations,
                    &context.eval,
                    config.show_newton,
                    config.derivative_margin,
                ) {
                    stats.newton_success += 1;
                    // Error is the x-distance from the original target.
                    let refined_error = refined_x - config.target;
                    let is_exact = refined_error.abs() < EXACT_MATCH_TOLERANCE;
                    if pool.would_accept(refined_error.abs(), is_exact) {
                        let m = Match {
                            lhs: lhs.clone(),
                            rhs: rhs.clone(),
                            x_value: refined_x,
                            error: refined_error,
                            complexity: lhs.expr.complexity() + rhs.expr.complexity(),
                        };
                        pool.try_insert(m);
                        if config.stop_at_exact && is_exact {
                            early_exit = true;
                            break 'outer;
                        }
                        if let Some(threshold) = config.stop_below {
                            if refined_error.abs() < threshold {
                                early_exit = true;
                                break 'outer;
                            }
                        }
                    }
                }
            }
        }
        // Copy the pool's internal diagnostics into the returned stats.
        stats.pool_insertions = pool.stats.insertions;
        stats.pool_rejections_error = pool.stats.rejections_error;
        stats.pool_rejections_dedupe = pool.stats.rejections_dedupe;
        stats.pool_evictions = pool.stats.evictions;
        stats.pool_final_size = pool.len();
        stats.pool_best_error = pool.best_error;
        stats.search_time = search_start.elapsed();
        stats.early_exit = early_exit;
        (pool.into_sorted(), stats)
    }
}
impl Default for ExprDatabase {
fn default() -> Self {
Self::new()
}
}