lowess/
regression.rs

1//! Local regression fitting for LOWESS smoothing.
2//!
3//! This module implements the local weighted least-squares regression used by
//! the higher-level LOWESS pipeline. It provides a single-point fitter and
5//! supporting utilities (weight computation, normalization, fallback policies,
6//! and simple diagnostics) designed for robust production use.
7//!
8//! Global expectations
9//! - Inputs x and y are numeric and aligned; many helpers assume x is sorted
10//!   ascending. Debug-only assertions validate sorting but production code
11//!   returns safe fallbacks for degenerate inputs.
12//! - All numeric tolerances are conservative to avoid panics; callers may
13//!   perform stricter validation upstream for performance or determinism.
14//!
15//! Primary parameters and flags
//! - x: `&[T]` — independent variable values, sorted ascending (debug-asserted).
17//! - y: `&[T]` — dependent variable values aligned with x.
18//! - idx: `usize` — index of the target point to fit (0..n-1).
//! - left/right: usize — inclusive window boundaries defining the local
//!   neighborhood used for the fit. They are not clamped here; callers must
//!   keep them within [0, n-1].
21//! - use_robustness: bool — when true, per-observation robustness weights
22//!   (from IRLS) are multiplied with kernel weights before normalization.
23//! - robustness_weights: `&[T]` — per-observation multiplicative weights from a
24//!   previous robustness pass. If not used, pass a slice of ones.
25//! - weights: `&mut [T]` — scratch buffer for computed (unnormalized) weights;
26//!   must be length n (or at least cover the positions accessed). The buffer
27//!   is normalized in-place prior to regression.
28//! - weight_fn: WeightFunction — kernel used to compute distance-based weights
29//!   (Tricube, Epanechnikov, Gaussian, etc.). Bounded kernels support a fast
30//!   short-circuit for |u| >= 1.
31//! - zero_weight_fallback: ZeroWeightFallback — policy applied when the local
32//!   sum of computed weights is zero. Options:
33//!     * UseLocalMean — return the (unweighted) mean over `[left..=right]`.
34//!     * ReturnOriginal — return `y[idx]`.
35//!     * ReturnNone — propagate failure (caller decides).
36//!
37//! WeightParams specifics
38//! - x_current: T — the x location being fitted.
39//! - bandwidth: T — effective local half-width used for normalized distance u.
40//!   Must be > 0 for a full regression; zero triggers constant-average fallback.
41//! - h1: T — a tiny fraction of bandwidth below which kernel weight is forced to 1.
42//!   This avoids numerical cancellation for extremely close points.
43//! - h9: T — slightly less than bandwidth (e.g. 0.999*h) used to truncate the
44//!   effective neighbor scan and determine the rightmost point to include.
45//!
46//! Functions and behaviors
47//! - fit_point(ctx): primary entry. Computes kernel ± robustness weights,
48//!   normalizes them, and runs a weighted linear least-squares fit evaluated at
49//!   x_current. If weights sum to zero the configured fallback is used. If the
50//!   weighted x-variance is too small, the fitter falls back to the weighted mean.
51//! - compute_weights(...): fast, streaming weight computation that scans from
52//!   left to right, short-circuits outside h9, applies h1 fast-path, and
53//!   multiplies by robustness weights when requested. Returns the (unnormalized)
54//!   total weight for the scanned region.
//! - find_rightmost_point(...): returns the largest index within h9 of x_current
//!   (or `left` if no scanned point qualifies).
56//! - normalize_weights(...): in-place normalization over [left..=right]. Debug
57//!   builds assert sum > 0; production code expects callers to handle zero-sum.
//! - weighted_least_squares(...): numerically stable WLS for degree 1. Falls
//!   back to the weighted average when the weighted x-variance (the slope
//!   denominator) is below a small tolerance.
61//! - compute_weighted_average(...): assumes weights already normalized and
62//!   returns ∑ wᵢ vᵢ over [left..=right].
63//!
64//! Debug & determinism
//! - Debug-only asserts check that x is sorted and that weight sums are positive
//!   before normalization; they do not change release behavior. Median/selection
//!   helpers used elsewhere prefer
67//!   select_nth_unstable for performance (linear-time, not stable ordering).
68//!
69//! Production recommendations
70//! - Pre-sort and deduplicate x/y upstream for reproducible window semantics.
71//! - Provide a pre-allocated weights buffer to avoid repeated allocations.
72//! - Choose an appropriate ZeroWeightFallback policy for your deployment:
73//!   UseLocalMean for graceful smoothing, ReturnNone for strict failure modes.
74//! - Use delta-driven interpolation and robust weights at the builder layer to
75//!   control performance vs. robustness trade-offs for large datasets.
76
77#[cfg(not(feature = "std"))]
78extern crate alloc;
79#[cfg(not(feature = "std"))]
80use alloc::vec::Vec;
81
82use crate::kernel::WeightFunction;
83use num_traits::Float;
84
85// ============================================================================
86// Zero-weight fallback policy
87// ============================================================================
88
/// Policy for [`fit_point`] when the computed kernel (times robustness)
/// weights in the local window sum to zero.
///
/// A zero total weight leaves the weighted regression undefined, so a
/// substitute result is needed. `UseLocalMean` is the default.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum ZeroWeightFallback {
    /// Return the unweighted mean of `y` over the local window
    /// `[left, right]`.
    #[default]
    UseLocalMean,
    /// Return the original `y` value at the target index.
    ReturnOriginal,
    /// Return `None` (propagate failure) so the caller can decide what to do.
    ReturnNone,
}
100
101// ============================================================================
102// Context Structures
103// ============================================================================
104
105/// Context for fitting a single point in LOWESS.
106pub struct FitContext<'a, T> {
107    /// X coordinates (sorted)
108    pub x: &'a [T],
109    /// Y coordinates
110    pub y: &'a [T],
111    /// Index of point to fit
112    pub idx: usize,
113    /// Left boundary of window
114    pub left: usize,
115    /// Right boundary of window
116    pub right: usize,
117    /// Whether to use robustness weights
118    pub use_robustness: bool,
119    /// Robustness weights from previous iteration
120    pub robustness_weights: &'a [T],
121    /// Scratch space for computed weights (length n)
122    pub weights: &'a mut [T],
123    /// Weight function to use
124    pub weight_fn: WeightFunction,
125    /// What to do if total weight is zero
126    pub zero_weight_fallback: ZeroWeightFallback,
127}
128
129/// Parameters for weight computation.
130pub struct WeightParams<T> {
131    /// Current x value being fitted
132    pub x_current: T,
133    /// Bandwidth (window half-width)
134    pub bandwidth: T,
135    /// NOTE: h1 and h9 are implementation optimizations / truncation thresholds,
136    /// not formal LOWESS parameters. They are used to fast-path near-zero
137    /// distances (h1) and to truncate the kernel support for efficiency (h9).
138    /// h1 = 0.001 * bandwidth ⇒ kernel weight forced to 1.0 for extremely
139    /// close points; h9 = 0.999 * bandwidth used to stop scanning beyond the
140    /// effective window in sorted x arrays.
141    /// Low threshold (0.001 * bandwidth) – assign weight 1.0
142    pub h1: T,
143    /// High threshold (0.999 * bandwidth) – assign weight 0.0
144    pub h9: T,
145    /// Whether to use robustness weights
146    pub use_robustness: bool,
147    /// Weight function to use
148    pub weight_fn: WeightFunction,
149}
150
151// ============================================================================
152// Debug helpers
153// ============================================================================
154
155#[cfg(debug_assertions)]
156fn assert_sorted<T: Float>(x: &[T]) {
157    // Ensure x is non-decreasing. This is a debug-only check to document
158    // and validate the sorted-x assumption made elsewhere (e.g. early break).
159    for i in 1..x.len() {
160        debug_assert!(
161            x[i] >= x[i - 1],
162            "x must be sorted non-decreasing (found x[{}] < x[{}])",
163            i,
164            i - 1
165        );
166    }
167}
168
169// ============================================================================
170// Main Fitting Function
171// ============================================================================
172
173/// Fit a single point using local weighted regression.
174///
175/// # Returns
176///
177/// `Some(fitted_y)` on success, `None` if the weight sum is zero.
178pub fn fit_point<T: Float>(ctx: FitContext<T>) -> Option<T> {
179    let n = ctx.x.len();
180    let x_current = ctx.x[ctx.idx];
181
182    // ---- bandwidth -------------------------------------------------------
183    let bandwidth = T::max(x_current - ctx.x[ctx.left], ctx.x[ctx.right] - x_current);
184    if bandwidth <= T::zero() {
185        // Degenerate window – fall back to weighted average of the whole window
186        let mut sum_w = T::zero();
187        for j in ctx.left..=ctx.right {
188            sum_w = sum_w + ctx.weights[j];
189        }
190        if sum_w > T::zero() {
191            return Some(compute_weighted_average(
192                ctx.y,
193                ctx.weights,
194                ctx.left,
195                ctx.right,
196            ));
197        } else {
198            return Some(ctx.y[ctx.idx]);
199        }
200    }
201
202    // ---- thresholds ------------------------------------------------------
203    let h9 = T::from(0.999).unwrap() * bandwidth;
204    let h1 = T::from(0.001).unwrap() * bandwidth;
205
206    let params = WeightParams {
207        x_current,
208        bandwidth,
209        h1,
210        h9,
211        use_robustness: ctx.use_robustness,
212        weight_fn: ctx.weight_fn,
213    };
214
215    // ---- compute raw weights ---------------------------------------------
216    let weight_sum = compute_weights(
217        ctx.x,
218        ctx.left,
219        n,
220        &params,
221        ctx.use_robustness,
222        ctx.robustness_weights,
223        ctx.weights,
224    );
225
226    if weight_sum <= T::zero() {
227        // Configurable fallback behavior when all computed weights are zero.
228        match ctx.zero_weight_fallback {
229            ZeroWeightFallback::UseLocalMean => {
230                let cnt = T::from((ctx.right - ctx.left + 1) as f64).unwrap_or(T::one());
231                let mean = ctx.y[ctx.left..=ctx.right]
232                    .iter()
233                    .copied()
234                    .fold(T::zero(), |acc, v| acc + v)
235                    / cnt;
236                return Some(mean);
237            }
238            ZeroWeightFallback::ReturnOriginal => {
239                return Some(ctx.y[ctx.idx]);
240            }
241            ZeroWeightFallback::ReturnNone => {
242                return None;
243            }
244        }
245    }
246
247    // ---- effective rightmost point ---------------------------------------
248    let right_most = find_rightmost_point(ctx.x, x_current, ctx.left, n, h9);
249
250    // ---- normalize -------------------------------------------------------
251    normalize_weights(ctx.weights, ctx.left, right_most, weight_sum);
252
253    // ---- weighted least squares -------------------------------------------
254    Some(weighted_least_squares(
255        ctx.x,
256        ctx.y,
257        ctx.weights,
258        ctx.left,
259        right_most,
260        x_current,
261        bandwidth,
262    ))
263}
264
265// ============================================================================
266// Weight Computation
267// ============================================================================
268
269/// Compute kernel (and optional robustness) weights.
270///
271/// Uses the **fast** kernel path (`compute_weight_fast`) for speed.
272pub fn compute_weights<T: Float>(
273    x: &[T],
274    left: usize,
275    n: usize,
276    params: &WeightParams<T>,
277    use_robustness: bool,
278    robustness_weights: &[T],
279    weights: &mut [T],
280) -> T {
281    // Debug check: x must be sorted
282    #[cfg(debug_assertions)]
283    {
284        assert_sorted(x);
285    }
286
287    // Degenerate bandwidth → zero weights
288    if params.bandwidth <= T::zero() {
289        for w in weights.iter_mut().take(n).skip(left) {
290            *w = T::zero();
291        }
292        return T::zero();
293    }
294
295    let mut sum = T::zero();
296
297    // First, skip all points to the left of (x_current - h9)
298    let lower_bound = params.x_current - params.h9;
299    let mut start = left;
300    while start < n && x[start] < lower_bound {
301        weights[start] = T::zero();
302        start += 1;
303    }
304
305    // Then compute weights from start to right, breaking early if beyond h9
306    for j in start..n {
307        weights[j] = T::zero();
308
309        let xj = x[j];
310        let distance = (xj - params.x_current).abs();
311
312        if distance > params.h9 {
313            // Outside effective window
314            if xj > params.x_current {
315                break;
316            }
317            // To the left but outside h9, continue scanning
318            continue;
319        }
320
321        // ---- kernel weight ------------------------------------------------
322        let kernel = if distance <= params.h1 {
323            T::one()
324        } else {
325            let u = distance / params.bandwidth;
326            // *** fast path ***
327            params.weight_fn.compute_weight_fast(u)
328        };
329
330        // ---- combine with robustness --------------------------------------
331        let w = if use_robustness {
332            kernel * robustness_weights[j]
333        } else {
334            kernel
335        };
336
337        weights[j] = w;
338        sum = sum + w;
339    }
340
341    sum
342}
343
344/// Find the rightmost point whose distance ≤ `h9`.
345pub fn find_rightmost_point<T: Float>(
346    x: &[T],
347    x_current: T,
348    left: usize,
349    n: usize,
350    h9: T,
351) -> usize {
352    #[cfg(debug_assertions)]
353    {
354        assert_sorted(x);
355    }
356
357    let mut rightmost = left;
358    for (j, &xj) in x.iter().enumerate().take(n).skip(left) {
359        if (xj - x_current).abs() <= h9 {
360            rightmost = j;
361        } else {
362            // If this point is to the right of x_current and outside h9 we can
363            // break (future points will be farther). If it's to the left,
364            // continue scanning since later points may fall within h9.
365            if xj > x_current {
366                break;
367            }
368            continue;
369        }
370    }
371    rightmost
372}
373
374/// Normalize weights **in-place** over `[left, right]`.
375///
376/// # Panics
377///
378/// In debug builds if `sum == 0`.
379#[inline]
380pub fn normalize_weights<T: Float>(weights: &mut [T], left: usize, right: usize, sum: T) {
381    debug_assert!(sum > T::zero(), "weight sum must be positive");
382    let inv = T::one() / sum;
383    for w in weights.iter_mut().take(right + 1).skip(left) {
384        *w = *w * inv;
385    }
386}
387
388// ============================================================================
389// Weighted Least Squares
390// ============================================================================
391
392/// Weighted linear regression evaluated at `x_current`.
393///
394/// Falls back to a weighted average when the weighted variance is too small.
395pub fn weighted_least_squares<T: Float>(
396    x: &[T],
397    y: &[T],
398    weights: &[T],
399    left: usize,
400    right: usize,
401    x_current: T,
402    bandwidth: T,
403) -> T {
404    // Degenerate bandwidth → simple weighted average
405    if bandwidth <= T::zero() {
406        return compute_weighted_average(y, weights, left, right);
407    }
408
409    // ---- weighted means -------------------------------------------
410    let x_mean = compute_weighted_average(x, weights, left, right);
411    let y_mean = compute_weighted_average(y, weights, left, right);
412
413    // ---- weighted variance (denominator) and covariance (numerator) ---
414    let mut denom = T::zero();
415    let mut numer = T::zero();
416    for j in left..=right {
417        let dx = x[j] - x_mean;
418        denom = denom + weights[j] * dx * dx;
419        numer = numer + weights[j] * dx * (y[j] - y_mean);
420    }
421
422    // ---- numerical stability check ------------------------------------
423    // If the weighted x-variance (denom) is numerically tiny, fall back to
424    // the weighted average. Use a hybrid tolerance: an absolute floor to
425    // capture extremely narrow / near-collinear windows, and a tiny
426    // bandwidth-scaled relative term to avoid underflow on very large scales.
427    let abs_tol = T::from(1e-7).unwrap();
428    let rel_tol = T::epsilon() * bandwidth * bandwidth;
429    let tol = abs_tol.max(rel_tol);
430    if denom <= tol {
431        return compute_weighted_average(y, weights, left, right);
432    }
433
434    // ---- slope (β1) and fitted value ----------------------------------
435    debug_assert!(denom > tol, "Division by near-zero denom");
436    let beta1 = numer / denom;
437    y_mean + beta1 * (x_current - x_mean)
438}
439
440/// Simple weighted average (∑ wᵢ yᵢ) – assumes weights already normalized.
441#[inline]
442pub fn compute_weighted_average<T: Float>(
443    values: &[T],
444    weights: &[T],
445    left: usize,
446    right: usize,
447) -> T {
448    let mut sum = T::zero();
449    for j in left..=right {
450        sum = sum + weights[j] * values[j];
451    }
452    sum
453}
454
455// ============================================================================
456// Regression Diagnostics
457// ============================================================================
458
459/// Trace of the hat matrix – effective degrees of freedom.
460pub fn compute_effective_df<T: Float>(weights_matrix: &[Vec<T>]) -> T {
461    let n = weights_matrix.len();
462    let mut trace = T::zero();
463
464    for (i, row) in weights_matrix.iter().enumerate().take(n) {
465        trace = trace + row[i];
466    }
467
468    trace
469}
470
471/// Leverage of a single observation (diagonal of the hat matrix).
472#[inline]
473pub fn compute_leverage<T: Float>(weights: &[T], idx: usize) -> T {
474    weights[idx]
475}
476
477/// Residual variance estimate (weighted).
478///
479/// Uses `n_eff - 2` degrees of freedom (linear fit).
480pub fn compute_residual_variance<T: Float>(
481    residuals: &[T],
482    weights: &[T],
483    left: usize,
484    right: usize,
485) -> T {
486    let mut sum_sq = T::zero();
487    let mut n_eff = T::zero();
488
489    for j in left..=right {
490        if weights[j] > T::zero() {
491            sum_sq = sum_sq + weights[j] * residuals[j] * residuals[j];
492            n_eff = n_eff + T::one();
493        }
494    }
495
496    let df = n_eff - T::from(2.0).unwrap();
497    if df > T::zero() {
498        sum_sq / df
499    } else {
500        T::zero()
501    }
502}
503
504// ============================================================================
505// Alternative Fitting Methods
506// ============================================================================
507
508/// Higher-degree polynomial fit (currently only degree 1 is implemented).
509pub fn weighted_polynomial_fit<T: Float>(
510    x: &[T],
511    y: &[T],
512    weights: &[T],
513    left: usize,
514    right: usize,
515    x_current: T,
516    degree: usize,
517) -> T {
518    if degree == 1 {
519        let bandwidth = T::max(x_current - x[left], x[right] - x_current);
520        weighted_least_squares(x, y, weights, left, right, x_current, bandwidth)
521    } else {
522        // For now just fall back to constant fit
523        compute_weighted_average(y, weights, left, right)
524    }
525}
526
527/// Locally constant (degree-0) regression – just a weighted average.
528#[inline]
529pub fn locally_constant_fit<T: Float>(y: &[T], weights: &[T], left: usize, right: usize) -> T {
530    compute_weighted_average(y, weights, left, right)
531}