// aprender/preprocessing/mod.rs

1//! Preprocessing transformers for data standardization and normalization.
2//!
3//! This module provides transformers that preprocess data before training.
4//!
5//! # Example
6//!
7//! ```
8//! use aprender::prelude::*;
9//! use aprender::preprocessing::StandardScaler;
10//!
11//! // Create data with different scales
12//! let data = Matrix::from_vec(4, 2, vec![
13//!     1.0, 100.0,
14//!     2.0, 200.0,
15//!     3.0, 300.0,
16//!     4.0, 400.0,
17//! ]).expect("valid matrix dimensions");
18//!
19//! // Standardize to zero mean and unit variance
20//! let mut scaler = StandardScaler::new();
21//! let scaled = scaler.fit_transform(&data).expect("fit_transform should succeed");
22//!
23//! // Each column now has mean ≈ 0 and std ≈ 1
24//! assert!(scaled.get(0, 0).abs() < 2.0);
25//! ```
26
27use crate::error::{AprenderError, Result};
28use crate::primitives::Matrix;
29use crate::traits::Transformer;
30use serde::{Deserialize, Serialize};
31use std::path::Path;
32
/// Standardizes features by removing mean and scaling to unit variance.
///
/// The standard score of a sample x is: z = (x - mean) / std
///
/// This transformer is useful for algorithms that assume features have
/// similar scales (e.g., regularized regression, neural networks).
///
/// # Example
///
/// ```
/// use aprender::prelude::*;
/// use aprender::preprocessing::StandardScaler;
///
/// let data = Matrix::from_vec(3, 2, vec![
///     0.0, 0.0,
///     1.0, 10.0,
///     2.0, 20.0,
/// ]).expect("valid matrix dimensions");
///
/// let mut scaler = StandardScaler::new();
/// let scaled = scaler.fit_transform(&data).expect("fit_transform should succeed");
///
/// // Verify standardization
/// let (n_rows, n_cols) = scaled.shape();
/// for j in 0..n_cols {
///     let mut sum = 0.0;
///     for i in 0..n_rows {
///         sum += scaled.get(i, j);
///     }
///     let mean = sum / n_rows as f32;
///     assert!(mean.abs() < 1e-5, "Mean should be ~0");
/// }
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StandardScaler {
    /// Mean of each feature; `None` until `fit` has been called.
    mean: Option<Vec<f32>>,
    /// Standard deviation of each feature; `None` until `fit` has been called.
    std: Option<Vec<f32>>,
    /// Whether to center the data (subtract mean). Defaults to `true`.
    with_mean: bool,
    /// Whether to scale the data (divide by std). Defaults to `true`.
    with_std: bool,
}
77
78impl Default for StandardScaler {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl StandardScaler {
85    /// Creates a new `StandardScaler` with default settings.
86    ///
87    /// By default, both centering (subtract mean) and scaling (divide by std)
88    /// are enabled.
89    #[must_use]
90    pub fn new() -> Self {
91        Self {
92            mean: None,
93            std: None,
94            with_mean: true,
95            with_std: true,
96        }
97    }
98
99    /// Sets whether to center the data by subtracting the mean.
100    #[must_use]
101    pub fn with_mean(mut self, with_mean: bool) -> Self {
102        self.with_mean = with_mean;
103        self
104    }
105
106    /// Sets whether to scale the data by dividing by standard deviation.
107    #[must_use]
108    pub fn with_std(mut self, with_std: bool) -> Self {
109        self.with_std = with_std;
110        self
111    }
112
113    /// Returns the mean of each feature.
114    ///
115    /// # Panics
116    ///
117    /// Panics if the scaler is not fitted.
118    #[must_use]
119    pub fn mean(&self) -> &[f32] {
120        self.mean
121            .as_ref()
122            .expect("Scaler not fitted. Call fit() first.")
123    }
124
125    /// Returns the standard deviation of each feature.
126    ///
127    /// # Panics
128    ///
129    /// Panics if the scaler is not fitted.
130    #[must_use]
131    pub fn std(&self) -> &[f32] {
132        self.std
133            .as_ref()
134            .expect("Scaler not fitted. Call fit() first.")
135    }
136
137    /// Returns true if the scaler has been fitted.
138    #[must_use]
139    pub fn is_fitted(&self) -> bool {
140        self.mean.is_some()
141    }
142
143    /// Transforms data back to original scale.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if the scaler is not fitted or dimensions mismatch.
148    pub fn inverse_transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
149        let mean = self
150            .mean
151            .as_ref()
152            .ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
153        let std = self
154            .std
155            .as_ref()
156            .ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
157
158        let (n_samples, n_features) = x.shape();
159        if n_features != mean.len() {
160            return Err("Feature dimension mismatch".into());
161        }
162
163        let mut result = vec![0.0; n_samples * n_features];
164
165        for i in 0..n_samples {
166            for j in 0..n_features {
167                let mut val = x.get(i, j);
168
169                // Reverse scaling
170                if self.with_std && std[j] > 1e-10 {
171                    val *= std[j];
172                }
173
174                // Reverse centering
175                if self.with_mean {
176                    val += mean[j];
177                }
178
179                result[i * n_features + j] = val;
180            }
181        }
182
183        Matrix::from_vec(n_samples, n_features, result).map_err(Into::into)
184    }
185
186    /// Saves the StandardScaler to a SafeTensors file.
187    ///
188    /// # Arguments
189    ///
190    /// * `path` - Path where the SafeTensors file will be saved
191    ///
192    /// # Errors
193    ///
194    /// Returns an error if the scaler is unfitted or if saving fails.
195    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
196        use crate::serialization::safetensors;
197        use std::collections::BTreeMap;
198
199        // Check if scaler is fitted
200        let mean = self
201            .mean
202            .as_ref()
203            .ok_or_else(|| "Cannot save unfitted scaler. Call fit() first.".to_string())?;
204        let std = self
205            .std
206            .as_ref()
207            .ok_or_else(|| "Cannot save unfitted scaler. Call fit() first.".to_string())?;
208
209        let mut tensors = BTreeMap::new();
210
211        // Save mean and std vectors
212        tensors.insert("mean".to_string(), (mean.clone(), vec![mean.len()]));
213        tensors.insert("std".to_string(), (std.clone(), vec![std.len()]));
214
215        // Save hyperparameters as scalars
216        let with_mean_val = if self.with_mean { 1.0 } else { 0.0 };
217        tensors.insert("with_mean".to_string(), (vec![with_mean_val], vec![1]));
218
219        let with_std_val = if self.with_std { 1.0 } else { 0.0 };
220        tensors.insert("with_std".to_string(), (vec![with_std_val], vec![1]));
221
222        safetensors::save_safetensors(path, &tensors)?;
223        Ok(())
224    }
225
226    /// Loads a StandardScaler from a SafeTensors file.
227    ///
228    /// # Arguments
229    ///
230    /// * `path` - Path to the SafeTensors file
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if loading fails or if the file format is invalid.
235    pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
236        use crate::serialization::safetensors;
237
238        // Load SafeTensors file
239        let (metadata, raw_data) = safetensors::load_safetensors(path)?;
240
241        // Extract mean tensor
242        let mean_meta = metadata
243            .get("mean")
244            .ok_or_else(|| "Missing 'mean' tensor in SafeTensors file".to_string())?;
245        let mean = safetensors::extract_tensor(&raw_data, mean_meta)?;
246
247        // Extract std tensor
248        let std_meta = metadata
249            .get("std")
250            .ok_or_else(|| "Missing 'std' tensor in SafeTensors file".to_string())?;
251        let std = safetensors::extract_tensor(&raw_data, std_meta)?;
252
253        // Verify mean and std have same length
254        if mean.len() != std.len() {
255            return Err("Mean and std vectors have different lengths".to_string());
256        }
257
258        // Load hyperparameters
259        let with_mean_meta = metadata
260            .get("with_mean")
261            .ok_or_else(|| "Missing 'with_mean' tensor".to_string())?;
262        let with_mean_data = safetensors::extract_tensor(&raw_data, with_mean_meta)?;
263        let with_mean = with_mean_data[0] > 0.5;
264
265        let with_std_meta = metadata
266            .get("with_std")
267            .ok_or_else(|| "Missing 'with_std' tensor".to_string())?;
268        let with_std_data = safetensors::extract_tensor(&raw_data, with_std_meta)?;
269        let with_std = with_std_data[0] > 0.5;
270
271        Ok(Self {
272            mean: Some(mean),
273            std: Some(std),
274            with_mean,
275            with_std,
276        })
277    }
278}
279
280impl Transformer for StandardScaler {
281    /// Computes the mean and standard deviation of each feature.
282    fn fit(&mut self, x: &Matrix<f32>) -> Result<()> {
283        let (n_samples, n_features) = x.shape();
284
285        if n_samples == 0 {
286            return Err("Cannot fit with zero samples".into());
287        }
288
289        // Compute mean for each feature
290        let mut mean = vec![0.0; n_features];
291        for (j, mean_j) in mean.iter_mut().enumerate() {
292            let mut sum = 0.0;
293            for i in 0..n_samples {
294                sum += x.get(i, j);
295            }
296            *mean_j = sum / n_samples as f32;
297        }
298
299        // Compute standard deviation for each feature
300        let mut std = vec![0.0; n_features];
301        for (j, std_j) in std.iter_mut().enumerate() {
302            let mut sum_sq = 0.0;
303            for i in 0..n_samples {
304                let diff = x.get(i, j) - mean[j];
305                sum_sq += diff * diff;
306            }
307            // Use population std (divide by n, not n-1) like sklearn
308            *std_j = (sum_sq / n_samples as f32).sqrt();
309        }
310
311        self.mean = Some(mean);
312        self.std = Some(std);
313
314        Ok(())
315    }
316
317    /// Standardizes the data using fitted mean and std.
318    fn transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
319        let mean = self
320            .mean
321            .as_ref()
322            .ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
323        let std = self
324            .std
325            .as_ref()
326            .ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
327
328        let (n_samples, n_features) = x.shape();
329        if n_features != mean.len() {
330            return Err("Feature dimension mismatch".into());
331        }
332
333        let mut result = vec![0.0; n_samples * n_features];
334
335        for i in 0..n_samples {
336            for j in 0..n_features {
337                let mut val = x.get(i, j);
338
339                // Center
340                if self.with_mean {
341                    val -= mean[j];
342                }
343
344                // Scale
345                if self.with_std && std[j] > 1e-10 {
346                    val /= std[j];
347                }
348
349                result[i * n_features + j] = val;
350            }
351        }
352
353        Matrix::from_vec(n_samples, n_features, result).map_err(Into::into)
354    }
355}
356
/// Scales features to a given range (default [0, 1]).
///
/// The transformation is: X_scaled = (X - X_min) / (X_max - X_min)
///
/// This transformer is useful for algorithms sensitive to feature scales
/// and when you want bounded outputs (e.g., for neural networks).
///
/// # Example
///
/// ```
/// use aprender::prelude::*;
/// use aprender::preprocessing::MinMaxScaler;
///
/// let data = Matrix::from_vec(3, 2, vec![
///     0.0, 0.0,
///     5.0, 10.0,
///     10.0, 20.0,
/// ]).expect("valid matrix dimensions");
///
/// let mut scaler = MinMaxScaler::new();
/// let scaled = scaler.fit_transform(&data).expect("fit_transform should succeed");
///
/// // Verify scaling to [0, 1]
/// assert!((scaled.get(0, 0) - 0.0).abs() < 1e-6);
/// assert!((scaled.get(2, 0) - 1.0).abs() < 1e-6);
/// assert!((scaled.get(1, 0) - 0.5).abs() < 1e-6);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MinMaxScaler {
    /// Minimum value of each feature; `None` until `fit` has been called.
    data_min: Option<Vec<f32>>,
    /// Maximum value of each feature; `None` until `fit` has been called.
    data_max: Option<Vec<f32>>,
    /// Target minimum for scaling (default 0.0).
    feature_min: f32,
    /// Target maximum for scaling (default 1.0).
    feature_max: f32,
}
395
396impl Default for MinMaxScaler {
397    fn default() -> Self {
398        Self::new()
399    }
400}
401
402impl MinMaxScaler {
403    /// Creates a new `MinMaxScaler` with default range [0, 1].
404    #[must_use]
405    pub fn new() -> Self {
406        Self {
407            data_min: None,
408            data_max: None,
409            feature_min: 0.0,
410            feature_max: 1.0,
411        }
412    }
413
414    /// Sets the target range for scaling.
415    ///
416    /// # Example
417    ///
418    /// ```
419    /// use aprender::preprocessing::MinMaxScaler;
420    ///
421    /// let scaler = MinMaxScaler::new().with_range(-1.0, 1.0);
422    /// ```
423    #[must_use]
424    pub fn with_range(mut self, min: f32, max: f32) -> Self {
425        self.feature_min = min;
426        self.feature_max = max;
427        self
428    }
429
430    /// Returns the minimum value of each feature.
431    ///
432    /// # Panics
433    ///
434    /// Panics if the scaler is not fitted.
435    #[must_use]
436    pub fn data_min(&self) -> &[f32] {
437        self.data_min
438            .as_ref()
439            .expect("Scaler not fitted. Call fit() first.")
440    }
441
442    /// Returns the maximum value of each feature.
443    ///
444    /// # Panics
445    ///
446    /// Panics if the scaler is not fitted.
447    #[must_use]
448    pub fn data_max(&self) -> &[f32] {
449        self.data_max
450            .as_ref()
451            .expect("Scaler not fitted. Call fit() first.")
452    }
453
454    /// Returns true if the scaler has been fitted.
455    #[must_use]
456    pub fn is_fitted(&self) -> bool {
457        self.data_min.is_some()
458    }
459
460    /// Transforms data back to original scale.
461    ///
462    /// # Errors
463    ///
464    /// Returns an error if the scaler is not fitted or dimensions mismatch.
465    pub fn inverse_transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
466        let data_min = self
467            .data_min
468            .as_ref()
469            .ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
470        let data_max = self
471            .data_max
472            .as_ref()
473            .ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
474
475        let (n_samples, n_features) = x.shape();
476        if n_features != data_min.len() {
477            return Err("Feature dimension mismatch".into());
478        }
479
480        let feature_range = self.feature_max - self.feature_min;
481        let mut result = vec![0.0; n_samples * n_features];
482
483        for i in 0..n_samples {
484            for j in 0..n_features {
485                let val = x.get(i, j);
486                let data_range = data_max[j] - data_min[j];
487
488                let original = if data_range.abs() > 1e-10 {
489                    (val - self.feature_min) / feature_range * data_range + data_min[j]
490                } else {
491                    data_min[j]
492                };
493
494                result[i * n_features + j] = original;
495            }
496        }
497
498        Matrix::from_vec(n_samples, n_features, result).map_err(Into::into)
499    }
500}
501
502impl Transformer for MinMaxScaler {
503    /// Computes the min and max of each feature.
504    fn fit(&mut self, x: &Matrix<f32>) -> Result<()> {
505        let (n_samples, n_features) = x.shape();
506
507        if n_samples == 0 {
508            return Err("Cannot fit with zero samples".into());
509        }
510
511        let mut data_min = vec![f32::INFINITY; n_features];
512        let mut data_max = vec![f32::NEG_INFINITY; n_features];
513
514        for i in 0..n_samples {
515            for j in 0..n_features {
516                let val = x.get(i, j);
517                if val < data_min[j] {
518                    data_min[j] = val;
519                }
520                if val > data_max[j] {
521                    data_max[j] = val;
522                }
523            }
524        }
525
526        self.data_min = Some(data_min);
527        self.data_max = Some(data_max);
528
529        Ok(())
530    }
531
532    /// Scales the data to the target range.
533    fn transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
534        let data_min = self
535            .data_min
536            .as_ref()
537            .ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
538        let data_max = self
539            .data_max
540            .as_ref()
541            .ok_or_else(|| AprenderError::from("Scaler not fitted"))?;
542
543        let (n_samples, n_features) = x.shape();
544        if n_features != data_min.len() {
545            return Err("Feature dimension mismatch".into());
546        }
547
548        let feature_range = self.feature_max - self.feature_min;
549        let mut result = vec![0.0; n_samples * n_features];
550
551        for i in 0..n_samples {
552            for j in 0..n_features {
553                let val = x.get(i, j);
554                let data_range = data_max[j] - data_min[j];
555
556                let scaled = if data_range.abs() > 1e-10 {
557                    (val - data_min[j]) / data_range * feature_range + self.feature_min
558                } else {
559                    self.feature_min
560                };
561
562                result[i * n_features + j] = scaled;
563            }
564        }
565
566        Matrix::from_vec(n_samples, n_features, result).map_err(Into::into)
567    }
568}
569
/// Principal Component Analysis (PCA) for dimensionality reduction.
///
/// PCA reduces dimensionality by projecting data onto principal components
/// (directions of maximum variance).
///
/// # Example
///
/// ```
/// use aprender::preprocessing::PCA;
/// use aprender::traits::Transformer;
/// use aprender::primitives::Matrix;
///
/// let data = Matrix::from_vec(4, 3, vec![
///     1.0, 2.0, 3.0,
///     4.0, 5.0, 6.0,
///     7.0, 8.0, 9.0,
///     10.0, 11.0, 12.0,
/// ]).expect("valid matrix dimensions");
///
/// let mut pca = PCA::new(2); // Reduce to 2 components
/// let transformed = pca.fit_transform(&data).expect("fit_transform should succeed");
/// assert_eq!(transformed.shape(), (4, 2));
/// ```
#[derive(Debug, Clone)]
pub struct PCA {
    /// Number of components to keep (set at construction).
    n_components: usize,
    /// Mean of each feature; `None` until `fit` has been called.
    mean: Option<Vec<f32>>,
    /// Principal components (eigenvectors), one row per component;
    /// `None` until fitted.
    components: Option<Matrix<f32>>,
    /// Variance explained by each kept component; `None` until fitted.
    explained_variance: Option<Vec<f32>>,
    /// Ratio of variance explained by each kept component; `None` until fitted.
    explained_variance_ratio: Option<Vec<f32>>,
}
606
607impl PCA {
608    /// Creates a new PCA transformer.
609    ///
610    /// # Arguments
611    ///
612    /// * `n_components` - Number of principal components to keep
613    #[must_use]
614    pub fn new(n_components: usize) -> Self {
615        Self {
616            n_components,
617            mean: None,
618            components: None,
619            explained_variance: None,
620            explained_variance_ratio: None,
621        }
622    }
623
624    /// Returns the variance explained by each component.
625    #[must_use]
626    pub fn explained_variance(&self) -> Option<&[f32]> {
627        self.explained_variance.as_deref()
628    }
629
630    /// Returns the ratio of variance explained by each component.
631    #[must_use]
632    pub fn explained_variance_ratio(&self) -> Option<&[f32]> {
633        self.explained_variance_ratio.as_deref()
634    }
635
636    /// Returns the principal components.
637    #[must_use]
638    pub fn components(&self) -> Option<&Matrix<f32>> {
639        self.components.as_ref()
640    }
641
642    /// Reconstructs data from principal component space.
643    ///
644    /// # Errors
645    ///
646    /// Returns error if PCA is not fitted.
647    pub fn inverse_transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
648        let components = self
649            .components
650            .as_ref()
651            .ok_or_else(|| AprenderError::from("PCA not fitted"))?;
652        let mean = self
653            .mean
654            .as_ref()
655            .ok_or_else(|| AprenderError::from("PCA not fitted"))?;
656
657        let (n_samples, n_components) = x.shape();
658        let n_features = mean.len();
659
660        if n_components != self.n_components {
661            return Err("Input has wrong number of components".into());
662        }
663
664        // X_reconstructed = X_pca @ components^T + mean
665        let mut result = vec![0.0; n_samples * n_features];
666
667        for i in 0..n_samples {
668            for j in 0..n_features {
669                let mut value = mean[j];
670                for k in 0..n_components {
671                    value += x.get(i, k) * components.get(k, j);
672                }
673                result[i * n_features + j] = value;
674            }
675        }
676
677        Matrix::from_vec(n_samples, n_features, result).map_err(Into::into)
678    }
679}
680
681impl Transformer for PCA {
682    fn fit(&mut self, x: &Matrix<f32>) -> Result<()> {
683        use trueno::SymmetricEigen;
684
685        let (n_samples, n_features) = x.shape();
686
687        if self.n_components > n_features {
688            return Err("n_components cannot exceed number of features".into());
689        }
690
691        // Compute mean
692        let mut mean = vec![0.0; n_features];
693        #[allow(clippy::needless_range_loop)]
694        for j in 0..n_features {
695            let mut sum = 0.0;
696            for i in 0..n_samples {
697                sum += x.get(i, j);
698            }
699            mean[j] = sum / n_samples as f32;
700        }
701
702        // Center the data
703        let mut centered = vec![0.0; n_samples * n_features];
704        for i in 0..n_samples {
705            for j in 0..n_features {
706                centered[i * n_features + j] = x.get(i, j) - mean[j];
707            }
708        }
709
710        // Compute covariance matrix: Σ = (X^T X) / (n-1)
711        let mut cov = vec![0.0; n_features * n_features];
712        for i in 0..n_features {
713            for j in 0..n_features {
714                let mut sum = 0.0;
715                for k in 0..n_samples {
716                    sum += centered[k * n_features + i] * centered[k * n_features + j];
717                }
718                cov[i * n_features + j] = sum / (n_samples - 1) as f32;
719            }
720        }
721
722        // Convert to trueno Matrix for eigendecomposition
723        let cov_matrix = trueno::Matrix::from_vec(n_features, n_features, cov)
724            .map_err(|e| format!("Failed to create covariance matrix: {e}"))?;
725        let eigen = SymmetricEigen::new(&cov_matrix)
726            .map_err(|e| format!("Eigendecomposition failed: {e}"))?;
727
728        // trueno returns eigenvalues in descending order (largest first) - perfect for PCA
729        let eigenvalues = eigen.eigenvalues();
730        let eigenvectors = eigen.eigenvectors();
731
732        // Select top n_components (already sorted descending)
733        let mut components_data = vec![0.0; self.n_components * n_features];
734        let mut explained_variance = vec![0.0; self.n_components];
735
736        for i in 0..self.n_components {
737            explained_variance[i] = eigenvalues[i];
738            for j in 0..n_features {
739                // trueno eigenvectors: columns are eigenvectors, access with get(row, col)
740                components_data[i * n_features + j] = *eigenvectors
741                    .get(j, i)
742                    .ok_or_else(|| format!("Invalid eigenvector index ({j}, {i})"))?;
743            }
744        }
745
746        // Compute explained variance ratio
747        let total_variance: f32 = eigenvalues.iter().copied().sum();
748        let explained_variance_ratio: Vec<f32> = explained_variance
749            .iter()
750            .map(|&v| v / total_variance)
751            .collect();
752
753        self.mean = Some(mean);
754        self.components = Some(Matrix::from_vec(
755            self.n_components,
756            n_features,
757            components_data,
758        )?);
759        self.explained_variance = Some(explained_variance);
760        self.explained_variance_ratio = Some(explained_variance_ratio);
761
762        Ok(())
763    }
764
765    fn transform(&self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
766        let components = self
767            .components
768            .as_ref()
769            .ok_or_else(|| AprenderError::from("PCA not fitted"))?;
770        let mean = self
771            .mean
772            .as_ref()
773            .ok_or_else(|| AprenderError::from("PCA not fitted"))?;
774
775        let (n_samples, n_features) = x.shape();
776
777        if n_features != mean.len() {
778            return Err("Input has wrong number of features".into());
779        }
780
781        // Project onto principal components: X_pca = (X - mean) @ components^T
782        let mut result = vec![0.0; n_samples * self.n_components];
783
784        for i in 0..n_samples {
785            for j in 0..self.n_components {
786                let mut value = 0.0;
787                #[allow(clippy::needless_range_loop)]
788                for k in 0..n_features {
789                    value += (x.get(i, k) - mean[k]) * components.get(j, k);
790                }
791                result[i * self.n_components + j] = value;
792            }
793        }
794
795        Matrix::from_vec(n_samples, self.n_components, result).map_err(Into::into)
796    }
797}
798
799// ============================================================================
800// t-SNE (t-Distributed Stochastic Neighbor Embedding)
801// ============================================================================
802
/// t-SNE for dimensionality reduction and visualization.
///
/// t-Distributed Stochastic Neighbor Embedding (t-SNE) is a non-linear
/// dimensionality reduction technique optimized for visualization of
/// high-dimensional data in 2D or 3D space.
///
/// # Algorithm
///
/// 1. Compute pairwise similarities in high-D using Gaussian kernel
/// 2. Compute perplexity-based conditional probabilities
/// 3. Initialize low-D embedding (random or PCA)
/// 4. Compute pairwise similarities in low-D using Student's t-distribution
/// 5. Minimize KL divergence via gradient descent with momentum
///
/// # Example
///
/// ```
/// use aprender::prelude::*;
/// use aprender::preprocessing::TSNE;
///
/// let data = Matrix::from_vec(
///     6,
///     4,
///     vec![
///         1.0, 2.0, 3.0, 4.0,
///         1.1, 2.1, 3.1, 4.1,
///         5.0, 6.0, 7.0, 8.0,
///         5.1, 6.1, 7.1, 8.1,
///         10.0, 11.0, 12.0, 13.0,
///         10.1, 11.1, 12.1, 13.1,
///     ],
/// )
/// .expect("valid matrix dimensions");
///
/// let mut tsne = TSNE::new(2).with_perplexity(5.0).with_n_iter(250);
/// let embedding = tsne.fit_transform(&data).expect("fit_transform should succeed");
/// assert_eq!(embedding.shape(), (6, 2));
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TSNE {
    /// Number of dimensions in embedding (usually 2 or 3).
    n_components: usize,
    /// Perplexity balances local vs global structure (typical range 5-50).
    perplexity: f32,
    /// Learning rate for gradient descent.
    learning_rate: f32,
    /// Number of gradient descent iterations.
    n_iter: usize,
    /// Random seed for reproducibility; `None` means unseeded.
    random_state: Option<u64>,
    /// The learned embedding; `None` until fitted.
    embedding: Option<Matrix<f32>>,
}
856
857impl Default for TSNE {
858    fn default() -> Self {
859        Self::new(2)
860    }
861}
862
863impl TSNE {
864    /// Create a new t-SNE with default parameters.
865    ///
866    /// Default: perplexity=30.0, learning_rate=200.0, n_iter=1000
867    #[must_use]
868    pub fn new(n_components: usize) -> Self {
869        Self {
870            n_components,
871            perplexity: 30.0,
872            learning_rate: 200.0,
873            n_iter: 1000,
874            random_state: None,
875            embedding: None,
876        }
877    }
878
879    /// Set perplexity (balance between local and global structure).
880    ///
881    /// Typical range: 5-50. Higher perplexity considers more neighbors.
882    #[must_use]
883    pub fn with_perplexity(mut self, perplexity: f32) -> Self {
884        self.perplexity = perplexity;
885        self
886    }
887
888    /// Set learning rate for gradient descent.
889    #[must_use]
890    pub fn with_learning_rate(mut self, learning_rate: f32) -> Self {
891        self.learning_rate = learning_rate;
892        self
893    }
894
895    /// Set number of gradient descent iterations.
896    #[must_use]
897    pub fn with_n_iter(mut self, n_iter: usize) -> Self {
898        self.n_iter = n_iter;
899        self
900    }
901
902    /// Set random seed for reproducibility.
903    #[must_use]
904    pub fn with_random_state(mut self, seed: u64) -> Self {
905        self.random_state = Some(seed);
906        self
907    }
908
    /// Get number of components.
    ///
    /// This is the embedding dimensionality requested at construction time.
    #[must_use]
    pub fn n_components(&self) -> usize {
        self.n_components
    }
914
    /// Check if model has been fitted.
    ///
    /// Returns `true` once an embedding has been computed and stored.
    #[must_use]
    pub fn is_fitted(&self) -> bool {
        self.embedding.is_some()
    }
920
921    /// Compute pairwise squared Euclidean distances.
922    #[allow(clippy::unused_self)]
923    fn compute_pairwise_distances(&self, x: &Matrix<f32>) -> Vec<f32> {
924        let (n_samples, n_features) = x.shape();
925        let mut distances = vec![0.0; n_samples * n_samples];
926
927        for i in 0..n_samples {
928            for j in 0..n_samples {
929                if i == j {
930                    distances[i * n_samples + j] = 0.0;
931                    continue;
932                }
933
934                let mut dist_sq = 0.0;
935                for k in 0..n_features {
936                    let diff = x.get(i, k) - x.get(j, k);
937                    dist_sq += diff * diff;
938                }
939                distances[i * n_samples + j] = dist_sq;
940            }
941        }
942
943        distances
944    }
945
    /// Compute conditional probabilities P(j|i) with perplexity constraint.
    ///
    /// Uses binary search to find sigma_i such that perplexity matches target.
    ///
    /// Returns a row-major `n_samples * n_samples` buffer where entry
    /// `[i * n_samples + j]` is P(j|i); the diagonal is zero and each row is
    /// normalized to sum to 1 (when any affinity is non-zero). Note the code
    /// works with natural-log entropy throughout, so matching `perplexity.ln()`
    /// is equivalent to matching the usual base-2 definition.
    fn compute_p_conditional(&self, distances: &[f32], n_samples: usize) -> Vec<f32> {
        let mut p_conditional = vec![0.0; n_samples * n_samples];
        let target_entropy = self.perplexity.ln();

        for i in 0..n_samples {
            // Binary search for sigma that gives target perplexity
            let mut beta_min = -f32::INFINITY;
            let mut beta_max = f32::INFINITY;
            let mut beta = 1.0; // beta = 1 / (2 * sigma^2)

            for _ in 0..50 {
                // Max iterations for binary search
                // Compute P(j|i) with current beta
                let mut sum_p = 0.0;
                let mut entropy = 0.0;

                for j in 0..n_samples {
                    if i == j {
                        // A point is never its own neighbor.
                        p_conditional[i * n_samples + j] = 0.0;
                        continue;
                    }

                    // Gaussian kernel, unnormalized: exp(-beta * d^2).
                    let p_ji = (-beta * distances[i * n_samples + j]).exp();
                    p_conditional[i * n_samples + j] = p_ji;
                    sum_p += p_ji;
                }

                // Normalize and compute entropy
                // (row is overwritten from scratch on the next search step,
                // so normalizing in place here is safe).
                if sum_p > 0.0 {
                    for j in 0..n_samples {
                        if i != j {
                            let p_normalized = p_conditional[i * n_samples + j] / sum_p;
                            p_conditional[i * n_samples + j] = p_normalized;
                            if p_normalized > 1e-12 {
                                entropy -= p_normalized * p_normalized.ln();
                            }
                        }
                    }
                }

                // Check if entropy matches target
                let entropy_diff = entropy - target_entropy;
                if entropy_diff.abs() < 1e-5 {
                    break;
                }

                // Update beta via binary search. A larger beta narrows the
                // Gaussian and lowers entropy, so entropy-too-high means
                // raise beta, entropy-too-low means lower it. While a bound
                // is still infinite, step geometrically instead of bisecting.
                if entropy_diff > 0.0 {
                    beta_min = beta;
                    beta = if beta_max.is_infinite() {
                        beta * 2.0
                    } else {
                        (beta + beta_max) / 2.0
                    };
                } else {
                    beta_max = beta;
                    beta = if beta_min.is_infinite() {
                        beta / 2.0
                    } else {
                        (beta + beta_min) / 2.0
                    };
                }
            }
        }

        p_conditional
    }
1016
1017    /// Compute symmetric P matrix: P_{ij} = (P_{j|i} + P_{i|j}) / (2N).
1018    #[allow(clippy::unused_self)]
1019    fn compute_p_joint(&self, p_conditional: &[f32], n_samples: usize) -> Vec<f32> {
1020        let mut p_joint = vec![0.0; n_samples * n_samples];
1021        let normalizer = 2.0 * n_samples as f32;
1022
1023        for i in 0..n_samples {
1024            for j in 0..n_samples {
1025                p_joint[i * n_samples + j] = (p_conditional[i * n_samples + j]
1026                    + p_conditional[j * n_samples + i])
1027                    / normalizer;
1028                // Numerical stability
1029                p_joint[i * n_samples + j] = p_joint[i * n_samples + j].max(1e-12);
1030            }
1031        }
1032
1033        p_joint
1034    }
1035
1036    /// Compute Q matrix in low-dimensional space using Student's t-distribution.
1037    fn compute_q(&self, y: &[f32], n_samples: usize) -> Vec<f32> {
1038        let mut q = vec![0.0; n_samples * n_samples];
1039        let mut sum_q = 0.0;
1040
1041        // Compute Q_{ij} = (1 + ||y_i - y_j||^2)^{-1}
1042        for i in 0..n_samples {
1043            for j in 0..n_samples {
1044                if i == j {
1045                    q[i * n_samples + j] = 0.0;
1046                    continue;
1047                }
1048
1049                let mut dist_sq = 0.0;
1050                for k in 0..self.n_components {
1051                    let diff = y[i * self.n_components + k] - y[j * self.n_components + k];
1052                    dist_sq += diff * diff;
1053                }
1054
1055                let q_ij = 1.0 / (1.0 + dist_sq);
1056                q[i * n_samples + j] = q_ij;
1057                sum_q += q_ij;
1058            }
1059        }
1060
1061        // Normalize
1062        if sum_q > 0.0 {
1063            for q_val in &mut q {
1064                *q_val /= sum_q;
1065                *q_val = q_val.max(1e-12); // Numerical stability
1066            }
1067        }
1068
1069        q
1070    }
1071
1072    /// Compute gradient of KL divergence.
1073    fn compute_gradient(&self, y: &[f32], p: &[f32], q: &[f32], n_samples: usize) -> Vec<f32> {
1074        let mut gradient = vec![0.0; n_samples * self.n_components];
1075
1076        for i in 0..n_samples {
1077            for j in 0..n_samples {
1078                if i == j {
1079                    continue;
1080                }
1081
1082                let p_ij = p[i * n_samples + j];
1083                let q_ij = q[i * n_samples + j];
1084
1085                // Gradient factor: 4 * (p_ij - q_ij) * q_ij * (1 + ||y_i - y_j||^2)^{-1}
1086                // Simplified: 4 * (p_ij - q_ij) / (1 + ||y_i - y_j||^2)
1087                let mut dist_sq = 0.0;
1088                for k in 0..self.n_components {
1089                    let diff = y[i * self.n_components + k] - y[j * self.n_components + k];
1090                    dist_sq += diff * diff;
1091                }
1092
1093                let factor = 4.0 * (p_ij - q_ij) / (1.0 + dist_sq);
1094
1095                for k in 0..self.n_components {
1096                    let diff = y[i * self.n_components + k] - y[j * self.n_components + k];
1097                    gradient[i * self.n_components + k] += factor * diff;
1098                }
1099            }
1100        }
1101
1102        gradient
1103    }
1104}
1105
1106impl Transformer for TSNE {
1107    fn fit(&mut self, x: &Matrix<f32>) -> Result<()> {
1108        let (n_samples, _n_features) = x.shape();
1109
1110        // Compute pairwise distances in high-D
1111        let distances = self.compute_pairwise_distances(x);
1112
1113        // Compute conditional probabilities with perplexity
1114        let p_conditional = self.compute_p_conditional(&distances, n_samples);
1115
1116        // Compute joint probabilities (symmetric)
1117        let p_joint = self.compute_p_joint(&p_conditional, n_samples);
1118
1119        // Initialize embedding randomly
1120        use std::collections::hash_map::DefaultHasher;
1121        use std::hash::{Hash, Hasher};
1122
1123        let seed = self.random_state.unwrap_or_else(|| {
1124            let mut hasher = DefaultHasher::new();
1125            std::time::SystemTime::now().hash(&mut hasher);
1126            hasher.finish()
1127        });
1128
1129        // Simple LCG random number generator for reproducibility
1130        let mut rng_state = seed;
1131        let mut rand = || -> f32 {
1132            rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
1133            ((rng_state >> 16) as f32 / 65536.0) - 0.5
1134        };
1135
1136        let mut y = vec![0.0; n_samples * self.n_components];
1137        for val in &mut y {
1138            *val = rand() * 0.0001; // Small random initialization
1139        }
1140
1141        // Gradient descent with momentum
1142        let mut velocity = vec![0.0; n_samples * self.n_components];
1143        let momentum = 0.5;
1144        let final_momentum = 0.8;
1145        let momentum_switch_iter = 250;
1146
1147        for iter in 0..self.n_iter {
1148            // Compute Q matrix in low-D
1149            let q = self.compute_q(&y, n_samples);
1150
1151            // Compute gradient
1152            let gradient = self.compute_gradient(&y, &p_joint, &q, n_samples);
1153
1154            // Update with momentum
1155            let current_momentum = if iter < momentum_switch_iter {
1156                momentum
1157            } else {
1158                final_momentum
1159            };
1160
1161            for i in 0..(n_samples * self.n_components) {
1162                velocity[i] = current_momentum * velocity[i] - self.learning_rate * gradient[i];
1163                y[i] += velocity[i];
1164            }
1165
1166            // Early exaggeration (first 100 iterations)
1167            if iter == 100 {
1168                // Remove early exaggeration by dividing P by 4
1169                // (we multiplied by 4 implicitly in gradient computation)
1170            }
1171        }
1172
1173        self.embedding = Some(Matrix::from_vec(n_samples, self.n_components, y)?);
1174        Ok(())
1175    }
1176
1177    fn transform(&self, _x: &Matrix<f32>) -> Result<Matrix<f32>> {
1178        assert!(self.is_fitted(), "Model not fitted. Call fit() first.");
1179        // t-SNE is non-parametric, return the embedding
1180        Ok(self
1181            .embedding
1182            .as_ref()
1183            .expect("embedding should exist after is_fitted() check")
1184            .clone())
1185    }
1186
1187    fn fit_transform(&mut self, x: &Matrix<f32>) -> Result<Matrix<f32>> {
1188        self.fit(x)?;
1189        self.transform(x)
1190    }
1191}
1192
1193#[cfg(test)]
1194mod tests {
1195    use super::*;
1196
    // --- StandardScaler tests ---

    #[test]
    fn test_new() {
        let scaler = StandardScaler::new();
        assert!(!scaler.is_fitted());
    }

    #[test]
    fn test_default() {
        let scaler = StandardScaler::default();
        assert!(!scaler.is_fitted());
    }

    #[test]
    fn test_fit_basic() {
        let data = Matrix::from_vec(3, 2, vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0])
            .expect("valid matrix dimensions");

        let mut scaler = StandardScaler::new();
        scaler
            .fit(&data)
            .expect("fit should succeed with valid data");

        assert!(scaler.is_fitted());

        // Mean should be [2.0, 20.0]
        let mean = scaler.mean();
        assert!((mean[0] - 2.0).abs() < 1e-6);
        assert!((mean[1] - 20.0).abs() < 1e-6);

        // Std should be sqrt(2/3) ≈ 0.8165
        // (population std: divisor n, not n - 1 — implied by the expected value)
        let std = scaler.std();
        let expected_std = (2.0_f32 / 3.0).sqrt();
        assert!((std[0] - expected_std).abs() < 1e-4);
        assert!((std[1] - expected_std * 10.0).abs() < 1e-3);
    }

    #[test]
    fn test_transform_basic() {
        let data = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("valid matrix dimensions");

        let mut scaler = StandardScaler::new();
        scaler
            .fit(&data)
            .expect("fit should succeed with valid data");

        let transformed = scaler
            .transform(&data)
            .expect("transform should succeed after fit");

        // Mean should be 0
        let mean: f32 = (0..3).map(|i| transformed.get(i, 0)).sum::<f32>() / 3.0;
        assert!(mean.abs() < 1e-6, "Mean should be ~0, got {mean}");

        // Std should be 1
        let variance: f32 = (0..3)
            .map(|i| {
                let v = transformed.get(i, 0);
                v * v
            })
            .sum::<f32>()
            / 3.0;
        assert!(
            (variance.sqrt() - 1.0).abs() < 1e-6,
            "Std should be ~1, got {}",
            variance.sqrt()
        );
    }

    #[test]
    fn test_fit_transform() {
        let data = Matrix::from_vec(4, 2, vec![1.0, 100.0, 2.0, 200.0, 3.0, 300.0, 4.0, 400.0])
            .expect("valid matrix dimensions");

        let mut scaler = StandardScaler::new();
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed with valid data");

        // Check each column has mean ≈ 0
        for j in 0..2 {
            let mean: f32 = (0..4).map(|i| transformed.get(i, j)).sum::<f32>() / 4.0;
            assert!(mean.abs() < 1e-5, "Column {j} mean should be ~0");
        }
    }

    #[test]
    fn test_inverse_transform() {
        // Round-trip: transform followed by inverse_transform recovers input.
        let data = Matrix::from_vec(3, 2, vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0])
            .expect("valid matrix dimensions");

        let mut scaler = StandardScaler::new();
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed");
        let recovered = scaler
            .inverse_transform(&transformed)
            .expect("inverse_transform should succeed");

        // Should recover original data
        for i in 0..3 {
            for j in 0..2 {
                assert!(
                    (data.get(i, j) - recovered.get(i, j)).abs() < 1e-5,
                    "Mismatch at ({i}, {j})"
                );
            }
        }
    }

    #[test]
    fn test_transform_new_data() {
        // Statistics learned on train must be reused for unseen test data.
        let train = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("valid matrix dimensions");
        let test = Matrix::from_vec(2, 1, vec![4.0, 5.0]).expect("valid matrix dimensions");

        let mut scaler = StandardScaler::new();
        scaler
            .fit(&train)
            .expect("fit should succeed with valid data");

        let transformed = scaler
            .transform(&test)
            .expect("transform should succeed with new data");

        // Test data should be transformed using train stats
        // mean=2, std=sqrt(2/3)
        let mean = 2.0;
        let std = (2.0_f32 / 3.0).sqrt();

        let expected_0 = (4.0 - mean) / std;
        let expected_1 = (5.0 - mean) / std;

        assert!((transformed.get(0, 0) - expected_0).abs() < 1e-5);
        assert!((transformed.get(1, 0) - expected_1).abs() < 1e-5);
    }

    #[test]
    fn test_without_mean() {
        let data = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("valid matrix dimensions");

        let mut scaler = StandardScaler::new().with_mean(false);
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed");

        // Should only scale, not center
        // Original values divided by std
        let std = (2.0_f32 / 3.0).sqrt();
        assert!((transformed.get(0, 0) - 1.0 / std).abs() < 1e-5);
        assert!((transformed.get(1, 0) - 2.0 / std).abs() < 1e-5);
        assert!((transformed.get(2, 0) - 3.0 / std).abs() < 1e-5);
    }

    #[test]
    fn test_without_std() {
        let data = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("valid matrix dimensions");

        let mut scaler = StandardScaler::new().with_std(false);
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed");

        // Should only center, not scale
        // mean = 2.0
        assert!((transformed.get(0, 0) - (-1.0)).abs() < 1e-5);
        assert!((transformed.get(1, 0) - 0.0).abs() < 1e-5);
        assert!((transformed.get(2, 0) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_constant_feature() {
        // Feature with zero variance
        let data = Matrix::from_vec(3, 2, vec![1.0, 5.0, 2.0, 5.0, 3.0, 5.0])
            .expect("valid matrix dimensions");

        let mut scaler = StandardScaler::new();
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed");

        // Second column has zero std, should remain centered but not scaled
        assert!((transformed.get(0, 1) - 0.0).abs() < 1e-5);
        assert!((transformed.get(1, 1) - 0.0).abs() < 1e-5);
        assert!((transformed.get(2, 1) - 0.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_data_error() {
        // Zero-row input: fitting has no statistics to compute and must error.
        let data = Matrix::from_vec(0, 2, vec![]).expect("empty matrix should be valid");
        let mut scaler = StandardScaler::new();
        assert!(scaler.fit(&data).is_err());
    }

    #[test]
    fn test_transform_not_fitted_error() {
        let data = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("valid matrix dimensions");
        let scaler = StandardScaler::new();
        assert!(scaler.transform(&data).is_err());
    }

    #[test]
    fn test_dimension_mismatch_error() {
        // Transforming data with a different feature count than fit must error.
        let train = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid matrix dimensions");
        let test = Matrix::from_vec(3, 3, vec![1.0; 9]).expect("valid matrix dimensions");

        let mut scaler = StandardScaler::new();
        scaler.fit(&train).expect("fit should succeed");

        assert!(scaler.transform(&test).is_err());
    }

    #[test]
    fn test_single_sample() {
        let data = Matrix::from_vec(1, 2, vec![5.0, 10.0]).expect("valid matrix dimensions");

        let mut scaler = StandardScaler::new();
        scaler
            .fit(&data)
            .expect("fit should succeed with single sample");

        // With single sample, std is 0
        let std = scaler.std();
        assert!((std[0]).abs() < 1e-6);
        assert!((std[1]).abs() < 1e-6);

        // Transform should center only (std is 0, no scaling)
        let transformed = scaler.transform(&data).expect("transform should succeed");
        assert!((transformed.get(0, 0)).abs() < 1e-5);
        assert!((transformed.get(0, 1)).abs() < 1e-5);
    }

    #[test]
    fn test_builder_chain() {
        let scaler = StandardScaler::new().with_mean(false).with_std(true);

        let data = Matrix::from_vec(2, 1, vec![2.0, 4.0]).expect("valid matrix dimensions");
        let mut scaler = scaler;
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed");

        // Only scaling, no centering
        // Values: 2, 4; mean=3; std=1
        // Without centering: 2/1=2, 4/1=4
        assert!(transformed.get(0, 0) > 0.0, "Should not be centered");
    }
1443
    // --- MinMaxScaler tests ---
    #[test]
    fn test_minmax_new() {
        let scaler = MinMaxScaler::new();
        assert!(!scaler.is_fitted());
    }

    #[test]
    fn test_minmax_default() {
        let scaler = MinMaxScaler::default();
        assert!(!scaler.is_fitted());
    }

    #[test]
    fn test_minmax_fit_basic() {
        let data = Matrix::from_vec(3, 2, vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0])
            .expect("valid matrix dimensions");

        let mut scaler = MinMaxScaler::new();
        scaler
            .fit(&data)
            .expect("fit should succeed with valid data");

        assert!(scaler.is_fitted());

        // Min should be [1.0, 10.0], max should be [3.0, 30.0]
        let data_min = scaler.data_min();
        let data_max = scaler.data_max();
        assert!((data_min[0] - 1.0).abs() < 1e-6);
        assert!((data_min[1] - 10.0).abs() < 1e-6);
        assert!((data_max[0] - 3.0).abs() < 1e-6);
        assert!((data_max[1] - 30.0).abs() < 1e-6);
    }

    #[test]
    fn test_minmax_transform_basic() {
        let data = Matrix::from_vec(3, 1, vec![0.0, 5.0, 10.0]).expect("valid matrix dimensions");

        let mut scaler = MinMaxScaler::new();
        scaler
            .fit(&data)
            .expect("fit should succeed with valid data");

        let transformed = scaler
            .transform(&data)
            .expect("transform should succeed after fit");

        // Should scale to [0, 1]
        assert!((transformed.get(0, 0) - 0.0).abs() < 1e-6);
        assert!((transformed.get(1, 0) - 0.5).abs() < 1e-6);
        assert!((transformed.get(2, 0) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_minmax_fit_transform() {
        let data = Matrix::from_vec(4, 2, vec![0.0, 0.0, 10.0, 100.0, 20.0, 200.0, 30.0, 300.0])
            .expect("valid matrix dimensions");

        let mut scaler = MinMaxScaler::new();
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed with valid data");

        // Check min is 0 and max is 1 for each column
        for j in 0..2 {
            let mut min_val = f32::INFINITY;
            let mut max_val = f32::NEG_INFINITY;
            for i in 0..4 {
                let val = transformed.get(i, j);
                if val < min_val {
                    min_val = val;
                }
                if val > max_val {
                    max_val = val;
                }
            }
            assert!(min_val.abs() < 1e-5, "Column {j} min should be ~0");
            assert!((max_val - 1.0).abs() < 1e-5, "Column {j} max should be ~1");
        }
    }

    #[test]
    fn test_minmax_inverse_transform() {
        // Round-trip: transform followed by inverse_transform recovers input.
        let data = Matrix::from_vec(3, 2, vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0])
            .expect("valid matrix dimensions");

        let mut scaler = MinMaxScaler::new();
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed");
        let recovered = scaler
            .inverse_transform(&transformed)
            .expect("inverse_transform should succeed");

        // Should recover original data
        for i in 0..3 {
            for j in 0..2 {
                assert!(
                    (data.get(i, j) - recovered.get(i, j)).abs() < 1e-5,
                    "Mismatch at ({i}, {j})"
                );
            }
        }
    }

    #[test]
    fn test_minmax_transform_new_data() {
        // Values outside the fitted range extrapolate linearly, no clamping.
        let train = Matrix::from_vec(3, 1, vec![0.0, 5.0, 10.0]).expect("valid matrix dimensions");
        let test = Matrix::from_vec(2, 1, vec![15.0, -5.0]).expect("valid matrix dimensions");

        let mut scaler = MinMaxScaler::new();
        scaler
            .fit(&train)
            .expect("fit should succeed with valid data");

        let transformed = scaler
            .transform(&test)
            .expect("transform should succeed with new data");

        // 15 should map to 1.5 (beyond training range)
        // -5 should map to -0.5 (below training range)
        assert!((transformed.get(0, 0) - 1.5).abs() < 1e-5);
        assert!((transformed.get(1, 0) - (-0.5)).abs() < 1e-5);
    }

    #[test]
    fn test_minmax_custom_range() {
        let data = Matrix::from_vec(3, 1, vec![0.0, 5.0, 10.0]).expect("valid matrix dimensions");

        let mut scaler = MinMaxScaler::new().with_range(-1.0, 1.0);
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed");

        // Should scale to [-1, 1]
        assert!((transformed.get(0, 0) - (-1.0)).abs() < 1e-6);
        assert!((transformed.get(1, 0) - 0.0).abs() < 1e-6);
        assert!((transformed.get(2, 0) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_minmax_constant_feature() {
        // Feature with same min and max
        let data = Matrix::from_vec(3, 2, vec![1.0, 5.0, 2.0, 5.0, 3.0, 5.0])
            .expect("valid matrix dimensions");

        let mut scaler = MinMaxScaler::new();
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed");

        // Second column is constant, should become feature_min (0)
        assert!((transformed.get(0, 1) - 0.0).abs() < 1e-5);
        assert!((transformed.get(1, 1) - 0.0).abs() < 1e-5);
        assert!((transformed.get(2, 1) - 0.0).abs() < 1e-5);
    }

    #[test]
    fn test_minmax_empty_data_error() {
        let data = Matrix::from_vec(0, 2, vec![]).expect("empty matrix should be valid");
        let mut scaler = MinMaxScaler::new();
        assert!(scaler.fit(&data).is_err());
    }

    #[test]
    fn test_minmax_transform_not_fitted_error() {
        let data = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("valid matrix dimensions");
        let scaler = MinMaxScaler::new();
        assert!(scaler.transform(&data).is_err());
    }

    #[test]
    fn test_minmax_dimension_mismatch_error() {
        let train = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid matrix dimensions");
        let test = Matrix::from_vec(3, 3, vec![1.0; 9]).expect("valid matrix dimensions");

        let mut scaler = MinMaxScaler::new();
        scaler.fit(&train).expect("fit should succeed");

        assert!(scaler.transform(&test).is_err());
    }

    #[test]
    fn test_minmax_single_sample() {
        let data = Matrix::from_vec(1, 2, vec![5.0, 10.0]).expect("valid matrix dimensions");

        let mut scaler = MinMaxScaler::new();
        scaler
            .fit(&data)
            .expect("fit should succeed with single sample");

        // With single sample, min = max = value
        let data_min = scaler.data_min();
        let data_max = scaler.data_max();
        assert!((data_min[0] - 5.0).abs() < 1e-6);
        assert!((data_max[0] - 5.0).abs() < 1e-6);

        // Transform should give feature_min (0) since range is 0
        let transformed = scaler.transform(&data).expect("transform should succeed");
        assert!((transformed.get(0, 0)).abs() < 1e-5);
        assert!((transformed.get(0, 1)).abs() < 1e-5);
    }

    #[test]
    fn test_minmax_inverse_with_custom_range() {
        // Round-trip must also hold for a non-default output range.
        let data = Matrix::from_vec(3, 1, vec![0.0, 5.0, 10.0]).expect("valid matrix dimensions");

        let mut scaler = MinMaxScaler::new().with_range(-1.0, 1.0);
        let transformed = scaler
            .fit_transform(&data)
            .expect("fit_transform should succeed");
        let recovered = scaler
            .inverse_transform(&transformed)
            .expect("inverse_transform should succeed");

        for i in 0..3 {
            assert!(
                (data.get(i, 0) - recovered.get(i, 0)).abs() < 1e-5,
                "Mismatch at row {i}"
            );
        }
    }
1667
    // --- PCA tests ---
    #[test]
    fn test_pca_basic_fit_transform() {
        // Simple 2D data that should reduce to 1D along diagonal
        let data = Matrix::from_vec(4, 2, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0])
            .expect("valid matrix dimensions");

        let mut pca = PCA::new(1);
        let transformed = pca
            .fit_transform(&data)
            .expect("fit_transform should succeed");

        // Should reduce to (n_samples, n_components)
        assert_eq!(transformed.shape(), (4, 1));

        // Mean should be centered (approximately)
        let mut sum = 0.0;
        for i in 0..4 {
            sum += transformed.get(i, 0);
        }
        let mean = sum / 4.0;
        assert!(mean.abs() < 1e-5, "Mean should be ~0, got {mean}");
    }

    #[test]
    fn test_pca_explained_variance() {
        // Data with known variance structure
        // (all variance lives in column 0; column 1 is constant)
        let data = Matrix::from_vec(5, 2, vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0])
            .expect("valid matrix dimensions");

        let mut pca = PCA::new(2);
        pca.fit(&data).expect("fit should succeed with valid data");

        let explained_var = pca
            .explained_variance()
            .expect("explained variance should exist after fit");
        let explained_ratio = pca
            .explained_variance_ratio()
            .expect("explained variance ratio should exist after fit");

        // First component should capture all variance (second column is constant)
        assert_eq!(explained_var.len(), 2);
        assert_eq!(explained_ratio.len(), 2);

        // Ratios should sum to approximately 1.0
        let total_ratio: f32 = explained_ratio.iter().sum();
        assert!(
            (total_ratio - 1.0).abs() < 1e-5,
            "Variance ratios should sum to 1.0, got {total_ratio}"
        );

        // First component should explain most variance
        assert!(
            explained_ratio[0] > 0.99,
            "First component should explain >99% variance"
        );
    }

    #[test]
    fn test_pca_inverse_transform() {
        let data = Matrix::from_vec(
            4,
            3,
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("valid matrix dimensions");

        let mut pca = PCA::new(2);
        let transformed = pca
            .fit_transform(&data)
            .expect("fit_transform should succeed");
        let reconstructed = pca
            .inverse_transform(&transformed)
            .expect("inverse_transform should succeed");

        // Reconstruction should be close to original (with some loss since n_components < n_features)
        assert_eq!(reconstructed.shape(), data.shape());

        // Check reconstruction error is reasonable
        let mut total_error = 0.0;
        for i in 0..4 {
            for j in 0..3 {
                let error = (data.get(i, j) - reconstructed.get(i, j)).abs();
                total_error += error * error;
            }
        }
        let mse = total_error / 12.0;
        // With dimensionality reduction, some error is expected
        assert!(mse < 10.0, "Reconstruction MSE too large: {mse}");
    }

    #[test]
    fn test_pca_perfect_reconstruction() {
        // When n_components == n_features, reconstruction should be perfect
        let data = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid matrix dimensions");

        let mut pca = PCA::new(2);
        let transformed = pca
            .fit_transform(&data)
            .expect("fit_transform should succeed");
        let reconstructed = pca
            .inverse_transform(&transformed)
            .expect("inverse_transform should succeed");

        // Perfect reconstruction
        for i in 0..3 {
            for j in 0..2 {
                assert!(
                    (data.get(i, j) - reconstructed.get(i, j)).abs() < 1e-4,
                    "Perfect reconstruction failed at ({}, {}): {} vs {}",
                    i,
                    j,
                    data.get(i, j),
                    reconstructed.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_pca_n_components_exceeds_features() {
        let data = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid matrix dimensions");

        let mut pca = PCA::new(3); // More components than features
        let result = pca.fit(&data);

        assert!(
            result.is_err(),
            "Should fail when n_components > n_features"
        );
        // NOTE(review): comparing the error directly to a &str implies the
        // error type implements PartialEq<&str> (or is a String) — confirm
        // against crate::error if that ever changes.
        assert_eq!(
            result.expect_err("Should fail when n_components exceeds features"),
            "n_components cannot exceed number of features"
        );
    }

    #[test]
    fn test_pca_not_fitted_error() {
        let data = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid matrix dimensions");

        let pca = PCA::new(1);
        let result = pca.transform(&data);

        assert!(result.is_err(), "Should fail when transforming before fit");
        assert_eq!(
            result.expect_err("Should fail when PCA not fitted"),
            "PCA not fitted"
        );
    }

    #[test]
    fn test_pca_dimension_mismatch() {
        let train = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid matrix dimensions");
        let test = Matrix::from_vec(3, 3, vec![1.0; 9]).expect("valid matrix dimensions");

        let mut pca = PCA::new(1);
        pca.fit(&train).expect("fit should succeed");

        let result = pca.transform(&test);
        assert!(result.is_err(), "Should fail on dimension mismatch");
        assert_eq!(
            result.expect_err("Should fail with dimension mismatch"),
            "Input has wrong number of features"
        );
    }
1839
1840    #[test]
1841    fn test_pca_inverse_dimension_mismatch() {
1842        let train = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1843            .expect("valid matrix dimensions");
1844        let wrong_transformed =
1845            Matrix::from_vec(3, 2, vec![1.0; 6]).expect("valid matrix dimensions");
1846
1847        let mut pca = PCA::new(1);
1848        pca.fit(&train).expect("fit should succeed");
1849
1850        let result = pca.inverse_transform(&wrong_transformed);
1851        assert!(
1852            result.is_err(),
1853            "Should fail on inverse transform dimension mismatch"
1854        );
1855        assert_eq!(
1856            result.expect_err("Should fail with wrong component count"),
1857            "Input has wrong number of components"
1858        );
1859    }
1860
1861    #[test]
1862    fn test_pca_components_shape() {
1863        let data = Matrix::from_vec(5, 4, vec![1.0; 20]).expect("valid matrix dimensions");
1864
1865        let mut pca = PCA::new(2);
1866        pca.fit(&data).expect("fit should succeed with valid data");
1867
1868        let components = pca.components().expect("components should exist after fit");
1869        // Components should be (n_components, n_features)
1870        assert_eq!(components.shape(), (2, 4));
1871    }
1872
1873    #[test]
1874    fn test_pca_variance_preservation() {
1875        // Property test: total variance should be preserved
1876        let data = Matrix::from_vec(
1877            6,
1878            3,
1879            vec![
1880                1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0, 4.0, 7.0, 10.0, 5.0, 8.0, 11.0, 6.0,
1881                9.0, 12.0,
1882            ],
1883        )
1884        .expect("valid matrix dimensions");
1885
1886        let mut pca = PCA::new(3);
1887        pca.fit(&data).expect("fit should succeed with valid data");
1888
1889        let explained_var = pca
1890            .explained_variance()
1891            .expect("explained variance should exist after fit");
1892
1893        // Sum of explained variance should be close to total variance
1894        let total_explained: f32 = explained_var.iter().sum();
1895
1896        // Calculate actual variance of centered data
1897        let (n_samples, n_features) = data.shape();
1898        let mut means = vec![0.0; n_features];
1899        for (j, mean) in means.iter_mut().enumerate() {
1900            for i in 0..n_samples {
1901                *mean += data.get(i, j);
1902            }
1903            *mean /= n_samples as f32;
1904        }
1905
1906        let mut total_var = 0.0;
1907        for (j, &mean_j) in means.iter().enumerate() {
1908            for i in 0..n_samples {
1909                let diff = data.get(i, j) - mean_j;
1910                total_var += diff * diff;
1911            }
1912        }
1913        total_var /= (n_samples - 1) as f32;
1914
1915        // Explained variance should match total variance (with full components)
1916        assert!(
1917            (total_explained - total_var).abs() < 1e-3,
1918            "Total explained variance {total_explained} should match total variance {total_var}"
1919        );
1920    }
1921
1922    #[test]
1923    fn test_pca_component_orthogonality() {
1924        // Property test: principal components should be orthogonal
1925        let data = Matrix::from_vec(
1926            10,
1927            4,
1928            vec![
1929                1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, 4.0, 8.0, 12.0, 16.0,
1930                5.0, 10.0, 15.0, 20.0, 6.0, 12.0, 18.0, 24.0, 7.0, 14.0, 21.0, 28.0, 8.0, 16.0,
1931                24.0, 32.0, 9.0, 18.0, 27.0, 36.0, 10.0, 20.0, 30.0, 40.0,
1932            ],
1933        )
1934        .expect("valid matrix dimensions");
1935
1936        let mut pca = PCA::new(3);
1937        pca.fit(&data).expect("fit should succeed with valid data");
1938
1939        let components = pca.components().expect("components should exist after fit");
1940        let (_n_components, n_features) = components.shape();
1941
1942        // Check that all pairs of components are orthogonal (dot product ≈ 0)
1943        for i in 0..3 {
1944            for j in (i + 1)..3 {
1945                let mut dot_product = 0.0;
1946                for k in 0..n_features {
1947                    dot_product += components.get(i, k) * components.get(j, k);
1948                }
1949                assert!(
1950                    dot_product.abs() < 1e-4,
1951                    "Components {i} and {j} should be orthogonal, got dot product {dot_product}"
1952                );
1953            }
1954        }
1955
1956        // Check that each component is normalized (length ≈ 1)
1957        for i in 0..3 {
1958            let mut norm_sq = 0.0;
1959            for k in 0..n_features {
1960                let val = components.get(i, k);
1961                norm_sq += val * val;
1962            }
1963            let norm = norm_sq.sqrt();
1964            assert!(
1965                (norm - 1.0).abs() < 1e-4,
1966                "Component {i} should be unit length, got {norm}"
1967            );
1968        }
1969    }
1970
1971    // ========================================================================
1972    // t-SNE Tests
1973    // ========================================================================
1974
1975    #[test]
1976    fn test_tsne_new() {
1977        let tsne = TSNE::new(2);
1978        assert!(!tsne.is_fitted());
1979        assert_eq!(tsne.n_components(), 2);
1980    }
1981
1982    #[test]
1983    fn test_tsne_fit_basic() {
1984        // Simple 2D data, reduce to 2D (should work)
1985        let data = Matrix::from_vec(
1986            6,
1987            3,
1988            vec![
1989                1.0, 2.0, 3.0, 1.1, 2.1, 3.1, 5.0, 6.0, 7.0, 5.1, 6.1, 7.1, 10.0, 11.0, 12.0, 10.1,
1990                11.1, 12.1,
1991            ],
1992        )
1993        .expect("valid matrix dimensions");
1994
1995        let mut tsne = TSNE::new(2);
1996        tsne.fit(&data).expect("fit should succeed with valid data");
1997        assert!(tsne.is_fitted());
1998    }
1999
2000    #[test]
2001    fn test_tsne_transform() {
2002        let data = Matrix::from_vec(
2003            4,
2004            3,
2005            vec![
2006                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 10.0, 11.0, 12.0, 11.0, 12.0, 13.0,
2007            ],
2008        )
2009        .expect("valid matrix dimensions");
2010
2011        let mut tsne = TSNE::new(2);
2012        tsne.fit(&data).expect("fit should succeed with valid data");
2013
2014        let transformed = tsne
2015            .transform(&data)
2016            .expect("transform should succeed after fit");
2017        assert_eq!(transformed.shape(), (4, 2));
2018    }
2019
2020    #[test]
2021    fn test_tsne_fit_transform() {
2022        let data = Matrix::from_vec(
2023            4,
2024            3,
2025            vec![
2026                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 10.0, 11.0, 12.0, 11.0, 12.0, 13.0,
2027            ],
2028        )
2029        .expect("valid matrix dimensions");
2030
2031        let mut tsne = TSNE::new(2);
2032        let transformed = tsne
2033            .fit_transform(&data)
2034            .expect("fit_transform should succeed");
2035        assert_eq!(transformed.shape(), (4, 2));
2036        assert!(tsne.is_fitted());
2037    }
2038
2039    #[test]
2040    fn test_tsne_perplexity() {
2041        let data = Matrix::from_vec(
2042            10,
2043            3,
2044            vec![
2045                1.0, 2.0, 3.0, 1.1, 2.1, 3.1, 1.2, 2.2, 3.2, 5.0, 6.0, 7.0, 5.1, 6.1, 7.1, 5.2,
2046                6.2, 7.2, 10.0, 11.0, 12.0, 10.1, 11.1, 12.1, 10.2, 11.2, 12.2, 10.3, 11.3, 12.3,
2047            ],
2048        )
2049        .expect("valid matrix dimensions");
2050
2051        // Low perplexity (more local)
2052        let mut tsne_low = TSNE::new(2).with_perplexity(2.0);
2053        let result_low = tsne_low
2054            .fit_transform(&data)
2055            .expect("fit_transform should succeed with low perplexity");
2056        assert_eq!(result_low.shape(), (10, 2));
2057
2058        // High perplexity (more global)
2059        let mut tsne_high = TSNE::new(2).with_perplexity(5.0);
2060        let result_high = tsne_high
2061            .fit_transform(&data)
2062            .expect("fit_transform should succeed with high perplexity");
2063        assert_eq!(result_high.shape(), (10, 2));
2064    }
2065
2066    #[test]
2067    fn test_tsne_learning_rate() {
2068        let data = Matrix::from_vec(
2069            6,
2070            3,
2071            vec![
2072                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 10.0, 11.0, 12.0, 11.0, 12.0, 13.0,
2073                12.0, 13.0, 14.0,
2074            ],
2075        )
2076        .expect("valid matrix dimensions");
2077
2078        let mut tsne = TSNE::new(2).with_learning_rate(100.0).with_n_iter(100);
2079        let transformed = tsne
2080            .fit_transform(&data)
2081            .expect("fit_transform should succeed with custom learning rate");
2082        assert_eq!(transformed.shape(), (6, 2));
2083    }
2084
2085    #[test]
2086    fn test_tsne_n_components() {
2087        let data = Matrix::from_vec(
2088            4,
2089            5,
2090            vec![
2091                1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0, 6.0, 10.0, 11.0, 12.0, 13.0, 14.0,
2092                11.0, 12.0, 13.0, 14.0, 15.0,
2093            ],
2094        )
2095        .expect("valid matrix dimensions");
2096
2097        // 2D embedding
2098        let mut tsne_2d = TSNE::new(2);
2099        let result_2d = tsne_2d
2100            .fit_transform(&data)
2101            .expect("fit_transform should succeed for 2D");
2102        assert_eq!(result_2d.shape(), (4, 2));
2103
2104        // 3D embedding
2105        let mut tsne_3d = TSNE::new(3);
2106        let result_3d = tsne_3d
2107            .fit_transform(&data)
2108            .expect("fit_transform should succeed for 3D");
2109        assert_eq!(result_3d.shape(), (4, 3));
2110    }
2111
2112    #[test]
2113    fn test_tsne_reproducibility() {
2114        let data = Matrix::from_vec(
2115            6,
2116            3,
2117            vec![
2118                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 10.0, 11.0, 12.0, 11.0, 12.0, 13.0,
2119                12.0, 13.0, 14.0,
2120            ],
2121        )
2122        .expect("valid matrix dimensions");
2123
2124        let mut tsne1 = TSNE::new(2).with_random_state(42);
2125        let result1 = tsne1
2126            .fit_transform(&data)
2127            .expect("fit_transform should succeed");
2128
2129        let mut tsne2 = TSNE::new(2).with_random_state(42);
2130        let result2 = tsne2
2131            .fit_transform(&data)
2132            .expect("fit_transform should succeed");
2133
2134        // Results should be identical with same random state
2135        for i in 0..6 {
2136            for j in 0..2 {
2137                assert!(
2138                    (result1.get(i, j) - result2.get(i, j)).abs() < 1e-5,
2139                    "Results should be reproducible with same random state"
2140                );
2141            }
2142        }
2143    }
2144
2145    #[test]
2146    #[should_panic(expected = "Model not fitted")]
2147    fn test_tsne_transform_before_fit() {
2148        let data =
2149            Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).expect("valid matrix dimensions");
2150        let tsne = TSNE::new(2);
2151        let _ = tsne.transform(&data);
2152    }
2153
2154    #[test]
2155    fn test_tsne_preserves_local_structure() {
2156        // Create data with clear local structure
2157        let data = Matrix::from_vec(
2158            8,
2159            3,
2160            vec![
2161                // Cluster 1: tight cluster around (0, 0, 0)
2162                0.0, 0.0, 0.0, 0.1, 0.1, 0.1, // Cluster 2: tight cluster around (5, 5, 5)
2163                5.0, 5.0, 5.0, 5.1, 5.1, 5.1,
2164                // Cluster 3: tight cluster around (10, 10, 10)
2165                10.0, 10.0, 10.0, 10.1, 10.1, 10.1,
2166                // Cluster 4: tight cluster around (15, 15, 15)
2167                15.0, 15.0, 15.0, 15.1, 15.1, 15.1,
2168            ],
2169        )
2170        .expect("valid matrix dimensions");
2171
2172        let mut tsne = TSNE::new(2)
2173            .with_random_state(42)
2174            .with_n_iter(500)
2175            .with_perplexity(3.0);
2176        let embedding = tsne
2177            .fit_transform(&data)
2178            .expect("fit_transform should succeed");
2179
2180        // Points within same cluster should be close in embedding
2181        // Cluster 1: points 0, 1
2182        let dist_01 = ((embedding.get(0, 0) - embedding.get(1, 0)).powi(2)
2183            + (embedding.get(0, 1) - embedding.get(1, 1)).powi(2))
2184        .sqrt();
2185
2186        // Distance to far cluster should be larger
2187        let dist_03 = ((embedding.get(0, 0) - embedding.get(3, 0)).powi(2)
2188            + (embedding.get(0, 1) - embedding.get(3, 1)).powi(2))
2189        .sqrt();
2190
2191        // Allow some tolerance - t-SNE is stochastic
2192        // Just verify local structure is somewhat preserved
2193        assert!(
2194            dist_01 < dist_03 * 1.5,
2195            "Local structure should be roughly preserved: dist_01={:.3} should be < dist_03*1.5={:.3}",
2196            dist_01,
2197            dist_03 * 1.5
2198        );
2199    }
2200
2201    #[test]
2202    fn test_tsne_min_samples() {
2203        // t-SNE should work with minimum number of samples (> perplexity * 3)
2204        let data = Matrix::from_vec(
2205            10,
2206            3,
2207            vec![
2208                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 4.0, 5.0, 6.0, 5.0, 6.0, 7.0, 6.0,
2209                7.0, 8.0, 7.0, 8.0, 9.0, 8.0, 9.0, 10.0, 9.0, 10.0, 11.0, 10.0, 11.0, 12.0,
2210            ],
2211        )
2212        .expect("valid matrix dimensions");
2213
2214        let mut tsne = TSNE::new(2).with_perplexity(3.0);
2215        let result = tsne
2216            .fit_transform(&data)
2217            .expect("fit_transform should succeed with minimum samples");
2218        assert_eq!(result.shape(), (10, 2));
2219    }
2220
2221    #[test]
2222    fn test_tsne_embedding_finite() {
2223        let data = Matrix::from_vec(
2224            6,
2225            3,
2226            vec![
2227                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 10.0, 11.0, 12.0, 11.0, 12.0, 13.0,
2228                12.0, 13.0, 14.0,
2229            ],
2230        )
2231        .expect("valid matrix dimensions");
2232
2233        let mut tsne = TSNE::new(2).with_n_iter(100);
2234        let embedding = tsne
2235            .fit_transform(&data)
2236            .expect("fit_transform should succeed");
2237
2238        // All embedding values should be finite
2239        for i in 0..6 {
2240            for j in 0..2 {
2241                assert!(
2242                    embedding.get(i, j).is_finite(),
2243                    "Embedding should contain only finite values"
2244                );
2245            }
2246        }
2247    }
2248}