1use std::marker::PhantomData;
4
5use scirs2_core::ndarray::{Array1, Array2, Axis};
6use scirs2_linalg::compat::ArrayLinalgExt;
7use sklears_core::{
9 error::{validate, Result, SklearsError},
10 traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
11 types::Float,
12};
13
/// Configuration for the LARS (Least Angle Regression) estimator.
#[derive(Debug, Clone)]
pub struct LarsConfig {
    /// Whether to estimate an intercept term (default: `true`).
    pub fit_intercept: bool,
    /// Whether to scale each centered feature column to unit L2 norm
    /// before fitting (default: `true`).
    pub normalize: bool,
    /// Optional cap on the number of non-zero coefficients; `None` lets up
    /// to `n_features` enter the active set.
    pub n_nonzero_coefs: Option<usize>,
    /// Numerical tolerance: used as the stopping threshold on the maximum
    /// absolute correlation and as a divide-by-zero guard in the step-size
    /// computation (default: `sqrt(machine epsilon)`).
    pub eps: Float,
}
26
27impl Default for LarsConfig {
28 fn default() -> Self {
29 Self {
30 fit_intercept: true,
31 normalize: true,
32 n_nonzero_coefs: None,
33 eps: Float::EPSILON.sqrt(),
34 }
35 }
36}
37
/// Least Angle Regression (LARS) estimator.
///
/// Uses the typestate pattern: `Lars<Untrained>` exposes builder methods
/// and `fit`; fitting yields a `Lars<Trained>` with prediction and scoring
/// methods. Fitted attributes follow the sklearn-style trailing-underscore
/// convention and are `None` until `fit` populates them.
#[derive(Debug, Clone)]
pub struct Lars<State = Untrained> {
    config: LarsConfig,
    /// Zero-sized marker carrying the `Untrained`/`Trained` typestate.
    state: PhantomData<State>,
    /// Fitted coefficients, in the original (unnormalized) feature scale.
    coef_: Option<Array1<Float>>,
    /// Fitted intercept; `None` when `fit_intercept` is disabled.
    intercept_: Option<Float>,
    /// Number of features seen during `fit`.
    n_features_: Option<usize>,
    /// Indices of the active-set features, in order of entry.
    active_: Option<Vec<usize>>,
    /// Maximum absolute correlation recorded at each LARS iteration.
    alphas_: Option<Array1<Float>>,
    /// Number of LARS iterations performed.
    n_iter_: Option<usize>,
}
51
52impl Lars<Untrained> {
53 pub fn new() -> Self {
55 Self {
56 config: LarsConfig::default(),
57 state: PhantomData,
58 coef_: None,
59 intercept_: None,
60 n_features_: None,
61 active_: None,
62 alphas_: None,
63 n_iter_: None,
64 }
65 }
66
67 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
69 self.config.fit_intercept = fit_intercept;
70 self
71 }
72
73 pub fn normalize(mut self, normalize: bool) -> Self {
75 self.config.normalize = normalize;
76 self
77 }
78
79 pub fn n_nonzero_coefs(mut self, n_nonzero_coefs: usize) -> Self {
81 self.config.n_nonzero_coefs = Some(n_nonzero_coefs);
82 self
83 }
84
85 pub fn eps(mut self, eps: Float) -> Self {
87 self.config.eps = eps;
88 self
89 }
90}
91
92impl Default for Lars<Untrained> {
93 fn default() -> Self {
94 Self::new()
95 }
96}
97
impl Estimator for Lars<Untrained> {
    type Config = LarsConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the estimator's configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
107
impl Fit<Array2<Float>, Array1<Float>> for Lars<Untrained> {
    type Fitted = Lars<Trained>;

    /// Fits the model with the LARS algorithm.
    ///
    /// Features are mean-centered (and optionally scaled to unit L2 norm);
    /// at each iteration, the feature most correlated with the current
    /// residual joins the active set, and coefficients advance along an
    /// equiangular direction until another feature's correlation catches
    /// up. Iteration stops once `n_nonzero_coefs` features are active or
    /// the maximum absolute correlation drops below `eps`.
    ///
    /// # Errors
    ///
    /// Returns an error if `x` and `y` have inconsistent lengths, or if
    /// the linear solve for the equiangular direction fails.
    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        validate::check_consistent_length(x, y)?;

        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Center the features column-wise.
        // NOTE(review): `x` is centered even when `fit_intercept` is
        // false, while `y` below is only centered when it is true —
        // confirm this asymmetry is intended (sklearn skips centering
        // entirely when no intercept is fit).
        let x_mean = x.mean_axis(Axis(0)).unwrap();
        let mut x_centered = x - &x_mean;

        let y_mean = if self.config.fit_intercept {
            y.mean().unwrap_or(0.0)
        } else {
            0.0
        };
        let y_centered = y - y_mean;

        // Optionally scale each centered column to unit L2 norm; columns
        // with (near-)zero norm are left untouched and their scale forced
        // to 1 so the rescaling at the end is a no-op for them.
        let x_scale = if self.config.normalize {
            let mut scale = Array1::zeros(n_features);
            for j in 0..n_features {
                let col = x_centered.column(j);
                scale[j] = col.dot(&col).sqrt();
                if scale[j] > self.config.eps {
                    x_centered.column_mut(j).mapv_inplace(|x| x / scale[j]);
                } else {
                    scale[j] = 1.0;
                }
            }
            scale
        } else {
            Array1::ones(n_features)
        };

        let mut coef = Array1::zeros(n_features);
        // Features added to the model so far, in order of entry.
        let mut active: Vec<usize> = Vec::new();
        // Max absolute correlation at each iteration (regularization path).
        let mut alphas = Vec::new();

        // At most `n_nonzero_coefs` features may enter (default: all).
        let max_features = self
            .config
            .n_nonzero_coefs
            .unwrap_or(n_features)
            .min(n_features);

        let mut residual = y_centered.clone();
        let mut correlations = x_centered.t().dot(&residual);

        let mut n_iter = 0;

        while active.len() < max_features {
            // Select the inactive feature most correlated with the residual.
            let mut max_corr = 0.0;
            let mut best_idx = 0;

            for j in 0..n_features {
                if !active.contains(&j) {
                    let corr = correlations[j].abs();
                    if corr > max_corr {
                        max_corr = corr;
                        best_idx = j;
                    }
                }
            }

            // All remaining correlations are numerically negligible: stop.
            if max_corr < self.config.eps {
                break;
            }

            active.push(best_idx);
            alphas.push(max_corr);

            // Gather the active columns into a dense working matrix.
            let n_active = active.len();
            let mut x_active = Array2::zeros((n_samples, n_active));
            for (i, &j) in active.iter().enumerate() {
                x_active.column_mut(i).assign(&x_centered.column(j));
            }

            let gram = x_active.t().dot(&x_active);

            // NOTE(review): textbook LARS (Efron et al.) solves the Gram
            // system against the vector of correlation *signs*; a vector
            // of ones is used here instead — confirm the sign handling is
            // intended.
            let ones = Array1::ones(n_active);

            // Small diagonal jitter keeps the Gram matrix invertible when
            // active columns are (nearly) collinear.
            let mut gram_reg = gram.clone();
            for i in 0..n_active {
                gram_reg[[i, i]] += 1e-10;
            }

            let gram_inv_ones = &gram_reg
                .solve(&ones)
                .map_err(|e| SklearsError::NumericalError(format!("Failed to solve: {}", e)))?;

            // Equiangular direction: the unit-normalized direction making
            // equal angles with every active column.
            let normalization = 1.0 / ones.dot(gram_inv_ones).sqrt();
            let direction = gram_inv_ones * normalization;

            let equiangular = x_active.dot(&direction);

            // Step length: smallest positive gamma at which some inactive
            // feature's correlation catches up with the active set; falls
            // back to a full step (`max_corr`) when none does.
            let mut gamma = max_corr;

            for j in 0..n_features {
                if !active.contains(&j) {
                    let a_j = x_centered.column(j).dot(&equiangular);
                    let c_j = correlations[j];

                    // `eps` in the denominators guards against division by
                    // zero when `a_j` approaches +/- `normalization`.
                    let gamma_plus = (max_corr - c_j) / (normalization - a_j + self.config.eps);
                    let gamma_minus = (max_corr + c_j) / (normalization + a_j + self.config.eps);

                    if gamma_plus > 0.0 && gamma_plus < gamma {
                        gamma = gamma_plus;
                    }
                    if gamma_minus > 0.0 && gamma_minus < gamma {
                        gamma = gamma_minus;
                    }
                }
            }

            // Advance the active coefficients along the direction.
            for (i, &j) in active.iter().enumerate() {
                coef[j] += gamma * direction[i];
            }

            residual = residual - gamma * equiangular;
            correlations = x_centered.t().dot(&residual);

            n_iter += 1;
        }

        // Map coefficients back to the original (unnormalized) feature scale.
        if self.config.normalize {
            for j in 0..n_features {
                if x_scale[j] > 0.0 {
                    coef[j] /= x_scale[j];
                }
            }
        }

        let intercept = if self.config.fit_intercept {
            Some(y_mean - x_mean.dot(&coef))
        } else {
            None
        };

        Ok(Lars {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef),
            intercept_: intercept,
            n_features_: Some(n_features),
            active_: Some(active),
            alphas_: Some(Array1::from(alphas)),
            n_iter_: Some(n_iter),
        })
    }
}
279
280impl Lars<Trained> {
281 pub fn coef(&self) -> &Array1<Float> {
283 self.coef_.as_ref().expect("Model is trained")
284 }
285
286 pub fn intercept(&self) -> Option<Float> {
288 self.intercept_
289 }
290
291 pub fn active(&self) -> &[usize] {
293 self.active_.as_ref().expect("Model is trained")
294 }
295
296 pub fn alphas(&self) -> &Array1<Float> {
298 self.alphas_.as_ref().expect("Model is trained")
299 }
300
301 pub fn n_iter(&self) -> usize {
303 self.n_iter_.expect("Model is trained")
304 }
305}
306
307impl Predict<Array2<Float>, Array1<Float>> for Lars<Trained> {
308 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
309 let n_features = self.n_features_.expect("Model is trained");
310 validate::check_n_features(x, n_features)?;
311
312 let coef = self.coef_.as_ref().expect("Model is trained");
313 let mut predictions = x.dot(coef);
314
315 if let Some(intercept) = self.intercept_ {
316 predictions += intercept;
317 }
318
319 Ok(predictions)
320 }
321}
322
323impl Score<Array2<Float>, Array1<Float>> for Lars<Trained> {
324 type Float = Float;
325
326 fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
327 let predictions = self.predict(x)?;
328
329 let ss_res = (&predictions - y).mapv(|x| x * x).sum();
331 let y_mean = y.mean().unwrap_or(0.0);
332 let ss_tot = y.mapv(|yi| (yi - y_mean).powi(2)).sum();
333
334 if ss_tot == 0.0 {
335 return Ok(1.0);
336 }
337
338 Ok(1.0 - (ss_res / ss_tot))
339 }
340}
341
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Exact linear target with perfectly collinear columns: the fitted
    /// model must reproduce `y`.
    #[test]
    fn test_lars_simple() {
        let x = array![[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]];
        let y = array![3.0, 6.0, 9.0, 12.0];
        let model = Lars::new()
            .fit_intercept(false)
            .normalize(false)
            .fit(&x, &y)
            .unwrap();

        // At least one coefficient must be non-zero after fitting.
        let coef = model.coef();
        assert!(coef[0].abs() > 0.0 || coef[1].abs() > 0.0);

        let predictions = model.predict(&x).unwrap();
        for i in 0..4 {
            assert_abs_diff_eq!(predictions[i], y[i], epsilon = 1e-5);
        }
    }

    /// Orthogonal features with an exact linear relationship should give
    /// a near-perfect R² score.
    #[test]
    fn test_lars_orthogonal_features() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [0.0, 1.0],
            [0.0, 2.0],
            [3.0, 0.0],
            [0.0, 3.0],
        ];
        let y = array![2.0, 4.0, 3.0, 6.0, 6.0, 9.0];
        let model = Lars::new()
            .fit_intercept(false)
            .normalize(false)
            .fit(&x, &y)
            .unwrap();

        let _predictions = model.predict(&x).unwrap();
        let r2 = model.score(&x, &y).unwrap();
        assert!(
            r2 > 0.99,
            "R² score should be very high for perfect linear relationship"
        );
    }

    /// `n_nonzero_coefs(1)` must select exactly one feature — the most
    /// correlated one (column 0, which drives the target).
    #[test]
    fn test_lars_max_features() {
        let x = array![
            [1.0, 0.1, 0.01],
            [2.0, 0.2, 0.02],
            [3.0, 0.3, 0.03],
            [4.0, 0.4, 0.04],
            [5.0, 0.5, 0.05],
            [6.0, 0.6, 0.06],
        ];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];
        let model = Lars::new()
            .fit_intercept(false)
            .n_nonzero_coefs(1)
            .normalize(false)
            .fit(&x, &y)
            .unwrap();

        // Exactly one non-zero coefficient.
        let coef = model.coef();
        let n_nonzero = coef.iter().filter(|&&c| c.abs() > 1e-10).count();
        assert_eq!(n_nonzero, 1);

        // y = 2 * x[:, 0], so the selected coefficient should be ~2.
        assert_abs_diff_eq!(coef[0], 2.0, epsilon = 1e-3);

        // The active set records exactly the first column.
        assert_eq!(model.active().len(), 1);
        assert_eq!(model.active()[0], 0);
    }

    /// y = 2x + 1: both slope and intercept should be recovered.
    #[test]
    fn test_lars_with_intercept() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![3.0, 5.0, 7.0, 9.0];
        let model = Lars::new().fit_intercept(true).fit(&x, &y).unwrap();

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(model.intercept().unwrap(), 1.0, epsilon = 1e-5);
    }
}
442}