lowess/regression.rs
1//! Local regression fitting for LOWESS smoothing.
2//!
//! This module implements the local weighted least-squares regression used by
//! the higher-level LOWESS pipeline. It provides a single-point fitter and
//! supporting utilities (weight computation, normalization, fallback policies,
//! and simple diagnostics) designed for robust production use.
7//!
8//! Global expectations
9//! - Inputs x and y are numeric and aligned; many helpers assume x is sorted
10//! ascending. Debug-only assertions validate sorting but production code
11//! returns safe fallbacks for degenerate inputs.
12//! - All numeric tolerances are conservative to avoid panics; callers may
13//! perform stricter validation upstream for performance or determinism.
14//!
15//! Primary parameters and flags
16//! - x: `&[T]` — sorted (recommended) independent variable values.
17//! - y: `&[T]` — dependent variable values aligned with x.
18//! - idx: `usize` — index of the target point to fit (0..n-1).
//! - left/right: usize — inclusive window boundaries defining the local
//!   neighborhood used for the fit. Callers must keep left <= idx <= right < n;
//!   these indices are used directly and are not clamped by any helper here.
21//! - use_robustness: bool — when true, per-observation robustness weights
22//! (from IRLS) are multiplied with kernel weights before normalization.
23//! - robustness_weights: `&[T]` — per-observation multiplicative weights from a
24//! previous robustness pass. If not used, pass a slice of ones.
25//! - weights: `&mut [T]` — scratch buffer for computed (unnormalized) weights;
26//! must be length n (or at least cover the positions accessed). The buffer
27//! is normalized in-place prior to regression.
28//! - weight_fn: WeightFunction — kernel used to compute distance-based weights
29//! (Tricube, Epanechnikov, Gaussian, etc.). Bounded kernels support a fast
30//! short-circuit for |u| >= 1.
31//! - zero_weight_fallback: ZeroWeightFallback — policy applied when the local
32//! sum of computed weights is zero. Options:
33//! * UseLocalMean — return the (unweighted) mean over `[left..=right]`.
34//! * ReturnOriginal — return `y[idx]`.
35//! * ReturnNone — propagate failure (caller decides).
36//!
37//! WeightParams specifics
38//! - x_current: T — the x location being fitted.
39//! - bandwidth: T — effective local half-width used for normalized distance u.
40//! Must be > 0 for a full regression; zero triggers constant-average fallback.
41//! - h1: T — a tiny fraction of bandwidth below which kernel weight is forced to 1.
42//! This avoids numerical cancellation for extremely close points.
//! - h9: T — slightly less than bandwidth (0.999 * bandwidth) used to truncate
//!   the effective neighbor scan and determine the rightmost point to include.
45//!
46//! Functions and behaviors
47//! - fit_point(ctx): primary entry. Computes kernel ± robustness weights,
48//! normalizes them, and runs a weighted linear least-squares fit evaluated at
49//! x_current. If weights sum to zero the configured fallback is used. If the
50//! weighted x-variance is too small, the fitter falls back to the weighted mean.
//! - compute_weights(...): fast, streaming weight computation that scans forward
//!   from `left` (up to `n`), zeroes points outside h9, stops at the first point
//!   beyond h9 to the right of x_current, applies the h1 fast path, and
//!   multiplies by robustness weights when requested. Returns the (unnormalized)
//!   total weight for the scanned region.
55//! - find_rightmost_point(...): returns the largest index within h9 of x_current.
56//! - normalize_weights(...): in-place normalization over [left..=right]. Debug
57//! builds assert sum > 0; production code expects callers to handle zero-sum.
58//! - weighted_least_squares(...): numerically stable WLS for degree-1. Falls
59//! back to weighted average when denominator is below conservatively chosen
60//! tolerance (absolute and bandwidth-scaled relative terms).
61//! - compute_weighted_average(...): assumes weights already normalized and
62//! returns ∑ wᵢ vᵢ over [left..=right].
63//!
64//! Debug & determinism
65//! - Debug-only asserts check sorted x and buffer lengths; they do not change
66//! release behavior. Median/selection helpers used elsewhere prefer
67//! select_nth_unstable for performance (linear-time, not stable ordering).
68//!
69//! Production recommendations
70//! - Pre-sort and deduplicate x/y upstream for reproducible window semantics.
71//! - Provide a pre-allocated weights buffer to avoid repeated allocations.
72//! - Choose an appropriate ZeroWeightFallback policy for your deployment:
73//! UseLocalMean for graceful smoothing, ReturnNone for strict failure modes.
74//! - Use delta-driven interpolation and robust weights at the builder layer to
75//! control performance vs. robustness trade-offs for large datasets.
76
77#[cfg(not(feature = "std"))]
78extern crate alloc;
79#[cfg(not(feature = "std"))]
80use alloc::vec::Vec;
81
82use crate::kernel::WeightFunction;
83use num_traits::Float;
84
85// ============================================================================
86// Zero-weight fallback policy
87// ============================================================================
88
/// Behavior to use when the computed total weight for a fit is zero.
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare or key on the
/// configured policy.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum ZeroWeightFallback {
    /// Return the unweighted mean of `y` over the local window (the default).
    #[default]
    UseLocalMean,
    /// Return the original `y` value at the target index.
    ReturnOriginal,
    /// Return `None` (propagate failure) so the caller can decide what to do.
    ReturnNone,
}
100
101// ============================================================================
102// Context Structures
103// ============================================================================
104
105/// Context for fitting a single point in LOWESS.
106pub struct FitContext<'a, T> {
107 /// X coordinates (sorted)
108 pub x: &'a [T],
109 /// Y coordinates
110 pub y: &'a [T],
111 /// Index of point to fit
112 pub idx: usize,
113 /// Left boundary of window
114 pub left: usize,
115 /// Right boundary of window
116 pub right: usize,
117 /// Whether to use robustness weights
118 pub use_robustness: bool,
119 /// Robustness weights from previous iteration
120 pub robustness_weights: &'a [T],
121 /// Scratch space for computed weights (length n)
122 pub weights: &'a mut [T],
123 /// Weight function to use
124 pub weight_fn: WeightFunction,
125 /// What to do if total weight is zero
126 pub zero_weight_fallback: ZeroWeightFallback,
127}
128
129/// Parameters for weight computation.
130pub struct WeightParams<T> {
131 /// Current x value being fitted
132 pub x_current: T,
133 /// Bandwidth (window half-width)
134 pub bandwidth: T,
135 /// NOTE: h1 and h9 are implementation optimizations / truncation thresholds,
136 /// not formal LOWESS parameters. They are used to fast-path near-zero
137 /// distances (h1) and to truncate the kernel support for efficiency (h9).
138 /// h1 = 0.001 * bandwidth ⇒ kernel weight forced to 1.0 for extremely
139 /// close points; h9 = 0.999 * bandwidth used to stop scanning beyond the
140 /// effective window in sorted x arrays.
141 /// Low threshold (0.001 * bandwidth) – assign weight 1.0
142 pub h1: T,
143 /// High threshold (0.999 * bandwidth) – assign weight 0.0
144 pub h9: T,
145 /// Whether to use robustness weights
146 pub use_robustness: bool,
147 /// Weight function to use
148 pub weight_fn: WeightFunction,
149}
150
151// ============================================================================
152// Debug helpers
153// ============================================================================
154
155#[cfg(debug_assertions)]
156fn assert_sorted<T: Float>(x: &[T]) {
157 // Ensure x is non-decreasing. This is a debug-only check to document
158 // and validate the sorted-x assumption made elsewhere (e.g. early break).
159 for i in 1..x.len() {
160 debug_assert!(
161 x[i] >= x[i - 1],
162 "x must be sorted non-decreasing (found x[{}] < x[{}])",
163 i,
164 i - 1
165 );
166 }
167}
168
169// ============================================================================
170// Main Fitting Function
171// ============================================================================
172
173/// Fit a single point using local weighted regression.
174///
175/// # Returns
176///
177/// `Some(fitted_y)` on success, `None` if the weight sum is zero.
178pub fn fit_point<T: Float>(ctx: FitContext<T>) -> Option<T> {
179 let n = ctx.x.len();
180 let x_current = ctx.x[ctx.idx];
181
182 // ---- bandwidth -------------------------------------------------------
183 let bandwidth = T::max(x_current - ctx.x[ctx.left], ctx.x[ctx.right] - x_current);
184 if bandwidth <= T::zero() {
185 // Degenerate window – fall back to weighted average of the whole window
186 let mut sum_w = T::zero();
187 for j in ctx.left..=ctx.right {
188 sum_w = sum_w + ctx.weights[j];
189 }
190 if sum_w > T::zero() {
191 return Some(compute_weighted_average(
192 ctx.y,
193 ctx.weights,
194 ctx.left,
195 ctx.right,
196 ));
197 } else {
198 return Some(ctx.y[ctx.idx]);
199 }
200 }
201
202 // ---- thresholds ------------------------------------------------------
203 let h9 = T::from(0.999).unwrap() * bandwidth;
204 let h1 = T::from(0.001).unwrap() * bandwidth;
205
206 let params = WeightParams {
207 x_current,
208 bandwidth,
209 h1,
210 h9,
211 use_robustness: ctx.use_robustness,
212 weight_fn: ctx.weight_fn,
213 };
214
215 // ---- compute raw weights ---------------------------------------------
216 let weight_sum = compute_weights(
217 ctx.x,
218 ctx.left,
219 n,
220 ¶ms,
221 ctx.use_robustness,
222 ctx.robustness_weights,
223 ctx.weights,
224 );
225
226 if weight_sum <= T::zero() {
227 // Configurable fallback behavior when all computed weights are zero.
228 match ctx.zero_weight_fallback {
229 ZeroWeightFallback::UseLocalMean => {
230 let cnt = T::from((ctx.right - ctx.left + 1) as f64).unwrap_or(T::one());
231 let mean = ctx.y[ctx.left..=ctx.right]
232 .iter()
233 .copied()
234 .fold(T::zero(), |acc, v| acc + v)
235 / cnt;
236 return Some(mean);
237 }
238 ZeroWeightFallback::ReturnOriginal => {
239 return Some(ctx.y[ctx.idx]);
240 }
241 ZeroWeightFallback::ReturnNone => {
242 return None;
243 }
244 }
245 }
246
247 // ---- effective rightmost point ---------------------------------------
248 let right_most = find_rightmost_point(ctx.x, x_current, ctx.left, n, h9);
249
250 // ---- normalize -------------------------------------------------------
251 normalize_weights(ctx.weights, ctx.left, right_most, weight_sum);
252
253 // ---- weighted least squares -------------------------------------------
254 Some(weighted_least_squares(
255 ctx.x,
256 ctx.y,
257 ctx.weights,
258 ctx.left,
259 right_most,
260 x_current,
261 bandwidth,
262 ))
263}
264
265// ============================================================================
266// Weight Computation
267// ============================================================================
268
269/// Compute kernel (and optional robustness) weights.
270///
271/// Uses the **fast** kernel path (`compute_weight_fast`) for speed.
272pub fn compute_weights<T: Float>(
273 x: &[T],
274 left: usize,
275 n: usize,
276 params: &WeightParams<T>,
277 use_robustness: bool,
278 robustness_weights: &[T],
279 weights: &mut [T],
280) -> T {
281 // Debug check: x must be sorted
282 #[cfg(debug_assertions)]
283 {
284 assert_sorted(x);
285 }
286
287 // Degenerate bandwidth → zero weights
288 if params.bandwidth <= T::zero() {
289 for w in weights.iter_mut().take(n).skip(left) {
290 *w = T::zero();
291 }
292 return T::zero();
293 }
294
295 let mut sum = T::zero();
296
297 // First, skip all points to the left of (x_current - h9)
298 let lower_bound = params.x_current - params.h9;
299 let mut start = left;
300 while start < n && x[start] < lower_bound {
301 weights[start] = T::zero();
302 start += 1;
303 }
304
305 // Then compute weights from start to right, breaking early if beyond h9
306 for j in start..n {
307 weights[j] = T::zero();
308
309 let xj = x[j];
310 let distance = (xj - params.x_current).abs();
311
312 if distance > params.h9 {
313 // Outside effective window
314 if xj > params.x_current {
315 break;
316 }
317 // To the left but outside h9, continue scanning
318 continue;
319 }
320
321 // ---- kernel weight ------------------------------------------------
322 let kernel = if distance <= params.h1 {
323 T::one()
324 } else {
325 let u = distance / params.bandwidth;
326 // *** fast path ***
327 params.weight_fn.compute_weight_fast(u)
328 };
329
330 // ---- combine with robustness --------------------------------------
331 let w = if use_robustness {
332 kernel * robustness_weights[j]
333 } else {
334 kernel
335 };
336
337 weights[j] = w;
338 sum = sum + w;
339 }
340
341 sum
342}
343
344/// Find the rightmost point whose distance ≤ `h9`.
345pub fn find_rightmost_point<T: Float>(
346 x: &[T],
347 x_current: T,
348 left: usize,
349 n: usize,
350 h9: T,
351) -> usize {
352 #[cfg(debug_assertions)]
353 {
354 assert_sorted(x);
355 }
356
357 let mut rightmost = left;
358 for (j, &xj) in x.iter().enumerate().take(n).skip(left) {
359 if (xj - x_current).abs() <= h9 {
360 rightmost = j;
361 } else {
362 // If this point is to the right of x_current and outside h9 we can
363 // break (future points will be farther). If it's to the left,
364 // continue scanning since later points may fall within h9.
365 if xj > x_current {
366 break;
367 }
368 continue;
369 }
370 }
371 rightmost
372}
373
374/// Normalize weights **in-place** over `[left, right]`.
375///
376/// # Panics
377///
378/// In debug builds if `sum == 0`.
379#[inline]
380pub fn normalize_weights<T: Float>(weights: &mut [T], left: usize, right: usize, sum: T) {
381 debug_assert!(sum > T::zero(), "weight sum must be positive");
382 let inv = T::one() / sum;
383 for w in weights.iter_mut().take(right + 1).skip(left) {
384 *w = *w * inv;
385 }
386}
387
388// ============================================================================
389// Weighted Least Squares
390// ============================================================================
391
392/// Weighted linear regression evaluated at `x_current`.
393///
394/// Falls back to a weighted average when the weighted variance is too small.
395pub fn weighted_least_squares<T: Float>(
396 x: &[T],
397 y: &[T],
398 weights: &[T],
399 left: usize,
400 right: usize,
401 x_current: T,
402 bandwidth: T,
403) -> T {
404 // Degenerate bandwidth → simple weighted average
405 if bandwidth <= T::zero() {
406 return compute_weighted_average(y, weights, left, right);
407 }
408
409 // ---- weighted means -------------------------------------------
410 let x_mean = compute_weighted_average(x, weights, left, right);
411 let y_mean = compute_weighted_average(y, weights, left, right);
412
413 // ---- weighted variance (denominator) and covariance (numerator) ---
414 let mut denom = T::zero();
415 let mut numer = T::zero();
416 for j in left..=right {
417 let dx = x[j] - x_mean;
418 denom = denom + weights[j] * dx * dx;
419 numer = numer + weights[j] * dx * (y[j] - y_mean);
420 }
421
422 // ---- numerical stability check ------------------------------------
423 // If the weighted x-variance (denom) is numerically tiny, fall back to
424 // the weighted average. Use a hybrid tolerance: an absolute floor to
425 // capture extremely narrow / near-collinear windows, and a tiny
426 // bandwidth-scaled relative term to avoid underflow on very large scales.
427 let abs_tol = T::from(1e-7).unwrap();
428 let rel_tol = T::epsilon() * bandwidth * bandwidth;
429 let tol = abs_tol.max(rel_tol);
430 if denom <= tol {
431 return compute_weighted_average(y, weights, left, right);
432 }
433
434 // ---- slope (β1) and fitted value ----------------------------------
435 debug_assert!(denom > tol, "Division by near-zero denom");
436 let beta1 = numer / denom;
437 y_mean + beta1 * (x_current - x_mean)
438}
439
440/// Simple weighted average (∑ wᵢ yᵢ) – assumes weights already normalized.
441#[inline]
442pub fn compute_weighted_average<T: Float>(
443 values: &[T],
444 weights: &[T],
445 left: usize,
446 right: usize,
447) -> T {
448 let mut sum = T::zero();
449 for j in left..=right {
450 sum = sum + weights[j] * values[j];
451 }
452 sum
453}
454
455// ============================================================================
456// Regression Diagnostics
457// ============================================================================
458
459/// Trace of the hat matrix – effective degrees of freedom.
460pub fn compute_effective_df<T: Float>(weights_matrix: &[Vec<T>]) -> T {
461 let n = weights_matrix.len();
462 let mut trace = T::zero();
463
464 for (i, row) in weights_matrix.iter().enumerate().take(n) {
465 trace = trace + row[i];
466 }
467
468 trace
469}
470
471/// Leverage of a single observation (diagonal of the hat matrix).
472#[inline]
473pub fn compute_leverage<T: Float>(weights: &[T], idx: usize) -> T {
474 weights[idx]
475}
476
477/// Residual variance estimate (weighted).
478///
479/// Uses `n_eff - 2` degrees of freedom (linear fit).
480pub fn compute_residual_variance<T: Float>(
481 residuals: &[T],
482 weights: &[T],
483 left: usize,
484 right: usize,
485) -> T {
486 let mut sum_sq = T::zero();
487 let mut n_eff = T::zero();
488
489 for j in left..=right {
490 if weights[j] > T::zero() {
491 sum_sq = sum_sq + weights[j] * residuals[j] * residuals[j];
492 n_eff = n_eff + T::one();
493 }
494 }
495
496 let df = n_eff - T::from(2.0).unwrap();
497 if df > T::zero() {
498 sum_sq / df
499 } else {
500 T::zero()
501 }
502}
503
504// ============================================================================
505// Alternative Fitting Methods
506// ============================================================================
507
508/// Higher-degree polynomial fit (currently only degree 1 is implemented).
509pub fn weighted_polynomial_fit<T: Float>(
510 x: &[T],
511 y: &[T],
512 weights: &[T],
513 left: usize,
514 right: usize,
515 x_current: T,
516 degree: usize,
517) -> T {
518 if degree == 1 {
519 let bandwidth = T::max(x_current - x[left], x[right] - x_current);
520 weighted_least_squares(x, y, weights, left, right, x_current, bandwidth)
521 } else {
522 // For now just fall back to constant fit
523 compute_weighted_average(y, weights, left, right)
524 }
525}
526
527/// Locally constant (degree-0) regression – just a weighted average.
528#[inline]
529pub fn locally_constant_fit<T: Float>(y: &[T], weights: &[T], left: usize, right: usize) -> T {
530 compute_weighted_average(y, weights, left, right)
531}