sklears_gaussian_process/regression.rs

1//! Gaussian Process Regression Models
2//!
3//! This module provides advanced Gaussian Process regression implementations:
4//! - `VariationalSparseGaussianProcessRegressor`: Scalable GP regression using variational inference
5//! - `MultiOutputGaussianProcessRegressor`: Multi-output GP regression with Linear Model of Coregionalization
6//!
7//! These models offer efficient solutions for large-scale regression problems with uncertainty quantification.
8
9use std::f64::consts::PI;
10
11// SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
12use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
13// SciRS2 Policy - Use scirs2-core for random number generation
14// use scirs2_core::random::Rng;
15use sklears_core::{
16    error::{Result as SklResult, SklearsError},
17    traits::{Estimator, Fit, Predict, Untrained},
18};
19
20use crate::classification::GpcConfig;
21use crate::kernels::Kernel;
22use crate::sparse_gpr;
23use crate::utils;
24
25/// Optimization method for variational sparse Gaussian processes
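///
/// A minimal configuration sketch (illustrative; assumes the `RBF` kernel exported by
/// this crate and that `VariationalOptimizer` is re-exported from the crate root):
///
/// ```ignore
/// use sklears_gaussian_process::{
///     VariationalOptimizer, VariationalSparseGaussianProcessRegressor, RBF,
/// };
///
/// // Use natural-gradient updates with a small damping term for numerical stability.
/// let vsgpr = VariationalSparseGaussianProcessRegressor::new()
///     .kernel(Box::new(RBF::new(1.0)))
///     .optimizer(VariationalOptimizer::NaturalGradients)
///     .natural_gradient_damping(1e-4);
/// ```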
26#[derive(Debug, Clone, PartialEq, Default)]
27pub enum VariationalOptimizer {
28    /// Adam optimizer with adaptive learning rates
29    #[default]
30    Adam,
31    /// Natural gradients optimizer using the Fisher information metric
32    NaturalGradients,
33    /// Doubly stochastic variational inference with mini-batches for both data and inducing points
34    DoublyStochastic,
35}
36
/// Variational Sparse Gaussian Process Regressor
///
/// Scalable GP regression that approximates the posterior with a small set of inducing
/// points and a variational distribution `q(u) = N(m, S)`, maximizing the evidence
/// lower bound (ELBO) with stochastic optimizers (Adam, natural gradients, or doubly
/// stochastic updates).
///
/// # Examples
///
/// ```
/// use sklears_gaussian_process::{VariationalSparseGaussianProcessRegressor, RBF};
/// use sklears_core::traits::{Fit, Predict};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]];
/// let y = array![0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0];
///
/// let kernel = RBF::new(2.0);
/// let vsgpr = VariationalSparseGaussianProcessRegressor::new()
///     .kernel(Box::new(kernel))
///     .n_inducing(3);
/// let fitted = vsgpr.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
48#[derive(Debug, Clone)]
49pub struct VariationalSparseGaussianProcessRegressor<S = Untrained> {
50    state: S,
51    kernel: Option<Box<dyn Kernel>>,
52    n_inducing: usize,
53    inducing_init: sparse_gpr::InducingPointInit,
54    optimizer: VariationalOptimizer,
55    learning_rate: f64,
56    max_iter: usize,
57    batch_size: Option<usize>,
58    inducing_batch_size: Option<usize>, // Mini-batch size for inducing points in doubly stochastic
59    beta1: f64,
60    beta2: f64,
61    epsilon: f64,
62    natural_gradient_damping: f64, // Damping factor for natural gradients
63    sigma_n: f64,
64    tol: f64,
65    verbose: bool,
66    random_state: Option<u64>,
67    config: GpcConfig,
68}
69
70/// Trained state for Variational Sparse Gaussian Process Regressor
71#[derive(Debug, Clone)]
72pub struct VsgprTrained {
73    /// Z
74    pub Z: Array2<f64>, // Inducing points
75    /// m
76    pub m: Array1<f64>, // Variational mean
77    /// S
78    pub S: Array2<f64>, // Variational covariance
79    /// kernel
80    pub kernel: Box<dyn Kernel>, // Kernel function
81    /// sigma_n
82    pub sigma_n: f64, // Noise standard deviation
83    /// elbo_history
84    pub elbo_history: Vec<f64>, // ELBO history during training
85    /// final_elbo
86    pub final_elbo: f64, // Final ELBO value
87}
88
89impl VariationalSparseGaussianProcessRegressor<Untrained> {
90    /// Create a new VariationalSparseGaussianProcessRegressor instance
91    pub fn new() -> Self {
92        Self {
93            state: Untrained,
94            kernel: None,
95            n_inducing: 10,
96            inducing_init: sparse_gpr::InducingPointInit::Kmeans,
97            optimizer: VariationalOptimizer::default(),
98            learning_rate: 0.01,
99            max_iter: 1000,
100            batch_size: None,
101            inducing_batch_size: None,
102            beta1: 0.9,
103            beta2: 0.999,
104            epsilon: 1e-8,
105            natural_gradient_damping: 1e-4,
106            sigma_n: 0.1,
107            tol: 1e-6,
108            verbose: false,
109            random_state: None,
110            config: GpcConfig::default(),
111        }
112    }
113
114    /// Set the kernel function
115    pub fn kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
116        self.kernel = Some(kernel);
117        self
118    }
119
120    /// Set the number of inducing points
121    pub fn n_inducing(mut self, n_inducing: usize) -> Self {
122        self.n_inducing = n_inducing;
123        self
124    }
125
126    /// Set the inducing point initialization method
127    pub fn inducing_init(mut self, inducing_init: sparse_gpr::InducingPointInit) -> Self {
128        self.inducing_init = inducing_init;
129        self
130    }
131
132    /// Set the learning rate for optimization
133    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
134        self.learning_rate = learning_rate;
135        self
136    }
137
138    /// Set the maximum number of iterations
139    pub fn max_iter(mut self, max_iter: usize) -> Self {
140        self.max_iter = max_iter;
141        self
142    }
143
144    /// Set the batch size for mini-batch optimization
145    pub fn batch_size(mut self, batch_size: Option<usize>) -> Self {
146        self.batch_size = batch_size;
147        self
148    }
149
150    /// Set the inducing batch size for doubly stochastic optimization
151    pub fn inducing_batch_size(mut self, inducing_batch_size: Option<usize>) -> Self {
152        self.inducing_batch_size = inducing_batch_size;
153        self
154    }
155
156    /// Set the noise standard deviation
157    pub fn sigma_n(mut self, sigma_n: f64) -> Self {
158        self.sigma_n = sigma_n;
159        self
160    }
161
162    /// Set convergence tolerance
163    pub fn tol(mut self, tol: f64) -> Self {
164        self.tol = tol;
165        self
166    }
167
168    /// Set verbosity
169    pub fn verbose(mut self, verbose: bool) -> Self {
170        self.verbose = verbose;
171        self
172    }
173
174    /// Set the random state
175    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
176        self.random_state = random_state;
177        self
178    }
179
180    /// Set the optimization method
181    pub fn optimizer(mut self, optimizer: VariationalOptimizer) -> Self {
182        self.optimizer = optimizer;
183        self
184    }
185
186    /// Set the damping factor for natural gradients (only used when optimizer is NaturalGradients)
187    pub fn natural_gradient_damping(mut self, damping: f64) -> Self {
188        self.natural_gradient_damping = damping;
189        self
190    }
191}
192
193impl Estimator for VariationalSparseGaussianProcessRegressor<Untrained> {
194    type Config = GpcConfig;
195    type Error = SklearsError;
196    type Float = f64;
197
198    fn config(&self) -> &Self::Config {
199        &self.config
200    }
201}
202
203impl Estimator for VariationalSparseGaussianProcessRegressor<VsgprTrained> {
204    type Config = GpcConfig;
205    type Error = SklearsError;
206    type Float = f64;
207
208    fn config(&self) -> &Self::Config {
209        &self.config
210    }
211}
212
213impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, f64>>
214    for VariationalSparseGaussianProcessRegressor<Untrained>
215{
216    type Fitted = VariationalSparseGaussianProcessRegressor<VsgprTrained>;
217
218    #[allow(non_snake_case)]
219    fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<f64>) -> SklResult<Self::Fitted> {
220        if X.nrows() != y.len() {
221            return Err(SklearsError::InvalidInput(
222                "X and y must have the same number of samples".to_string(),
223            ));
224        }
225
226        let kernel = self
227            .kernel
228            .ok_or_else(|| SklearsError::InvalidInput("Kernel must be specified".to_string()))?;
229
230        // Initialize inducing points
231        let Z = match self.inducing_init {
232            sparse_gpr::InducingPointInit::Random => {
233                utils::random_inducing_points(X, self.n_inducing, self.random_state)?
234            }
235            sparse_gpr::InducingPointInit::Uniform => {
236                utils::uniform_inducing_points(X, self.n_inducing, self.random_state)?
237            }
238            sparse_gpr::InducingPointInit::Kmeans => {
239                utils::kmeans_inducing_points(X, self.n_inducing, self.random_state)?
240            }
241        };
242
243        // Initialize variational parameters
244        let mut m = Array1::<f64>::zeros(self.n_inducing);
245        let mut S = Array2::<f64>::eye(self.n_inducing);
246
247        // Initialize optimizer-specific parameters
248        let mut m_adam_m = Array1::<f64>::zeros(self.n_inducing);
249        let mut m_adam_v = Array1::<f64>::zeros(self.n_inducing);
250        let mut S_adam_m = Array2::<f64>::zeros((self.n_inducing, self.n_inducing));
251        let mut S_adam_v = Array2::<f64>::zeros((self.n_inducing, self.n_inducing));
252
253        let mut elbo_history = Vec::new();
254        let n_data = X.nrows();
255        let batch_size = self.batch_size.unwrap_or(n_data);
256
257        // Training loop
258        for iter in 0..self.max_iter {
259            let mut total_elbo = 0.0;
260            let mut n_batches = 0;
261
262            // Process data in batches
263            for batch_start in (0..n_data).step_by(batch_size) {
264                let batch_end = (batch_start + batch_size).min(n_data);
265                let X_batch = X.slice(s![batch_start..batch_end, ..]);
266                let y_batch = y.slice(s![batch_start..batch_end]);
267
268                // Compute ELBO and gradients
269                let (elbo, grad_m, grad_S) = compute_elbo_and_gradients(
270                    &X_batch,
271                    &y_batch,
272                    &Z,
273                    &m,
274                    &S,
275                    &kernel,
276                    self.sigma_n,
277                )?;
278
279                total_elbo += elbo * (n_data as f64 / batch_size as f64);
280
281                // Apply optimizer-specific updates
282                match self.optimizer {
283                    VariationalOptimizer::Adam => {
284                        let t = (iter * (n_data / batch_size) + n_batches + 1) as f64;
285
286                        // Update m with Adam
287                        m_adam_m = self.beta1 * &m_adam_m + (1.0 - self.beta1) * &grad_m;
288                        m_adam_v =
289                            self.beta2 * &m_adam_v + (1.0 - self.beta2) * grad_m.mapv(|x| x * x);
290                        let m_hat = &m_adam_m / (1.0 - self.beta1.powf(t));
291                        let v_hat = &m_adam_v / (1.0 - self.beta2.powf(t));
292                        m = &m
293                            + self.learning_rate * &m_hat
294                                / (v_hat.mapv(|x| x.sqrt()) + self.epsilon);
295
296                        // Update S with Adam (ensure positive definiteness)
297                        S_adam_m = self.beta1 * &S_adam_m + (1.0 - self.beta1) * &grad_S;
298                        S_adam_v =
299                            self.beta2 * &S_adam_v + (1.0 - self.beta2) * grad_S.mapv(|x| x * x);
300                        let S_m_hat = &S_adam_m / (1.0 - self.beta1.powf(t));
301                        let S_v_hat = &S_adam_v / (1.0 - self.beta2.powf(t));
302                        S = &S
303                            + self.learning_rate * &S_m_hat
304                                / (S_v_hat.mapv(|x| x.sqrt() + self.epsilon));
305                    }
306                    VariationalOptimizer::NaturalGradients => {
307                        // Natural gradients using the Fisher information metric
308                        // For mean parameter: m += η * grad_m
309                        m = &m + self.learning_rate * &grad_m;
310
311                        // For covariance parameter: S += η * (S * grad_S * S + damping * I)
312                        // This uses the natural gradient based on the Fisher information
313                        let natural_grad_S = S.dot(&grad_S).dot(&S)
314                            + self.natural_gradient_damping * Array2::<f64>::eye(self.n_inducing);
315                        S = &S + self.learning_rate * &natural_grad_S;
316                    }
317                    VariationalOptimizer::DoublyStochastic => {
318                        // Doubly stochastic variational inference
319                        // Use mini-batches for both data (already done) and inducing points
320                        let inducing_batch_size =
321                            self.inducing_batch_size.unwrap_or(self.n_inducing);
322
                        if inducing_batch_size < self.n_inducing {
                            // Sample a random subset of inducing points for this update.
                            // Derive the seed from `random_state` and the current
                            // iteration/batch so that different subsets are drawn across
                            // updates (a fixed seed would select the same subset every time).
                            // SciRS2 Policy - Use scirs2-core for random number generation
                            let seed = self
                                .random_state
                                .unwrap_or(42)
                                .wrapping_add((iter * 1000 + n_batches) as u64);
                            let mut rng = scirs2_core::random::Random::seed(seed);
                            let mut indices: Vec<usize> = (0..self.n_inducing).collect();
                            // Simple shuffle using the Fisher-Yates algorithm
                            for i in (1..indices.len()).rev() {
                                let j = rng.gen_range(0..i + 1);
                                indices.swap(i, j);
                            }
                            indices.truncate(inducing_batch_size);
334
335                            // Apply gradients only to selected subset with scaling
336                            let scaling_factor =
337                                self.n_inducing as f64 / inducing_batch_size as f64;
338
339                            // Update mean parameters for selected indices
340                            for &idx in &indices {
341                                m[idx] += self.learning_rate * grad_m[idx] * scaling_factor;
342                            }
343
344                            // Update covariance parameters for selected indices
345                            // Only update the submatrix corresponding to selected indices
346                            for &idx in &indices {
347                                for &jdx in &indices {
348                                    S[[idx, jdx]] +=
349                                        self.learning_rate * grad_S[[idx, jdx]] * scaling_factor;
350                                }
351                            }
352                        } else {
353                            // Full batch update - use simple gradient ascent for doubly stochastic
354                            m = &m + self.learning_rate * &grad_m;
355                            S = &S + self.learning_rate * &grad_S;
356                        }
357                    }
358                }
359
                // Ensure S remains positive definite after the update
361                S = ensure_positive_definite(S)?;
362
363                n_batches += 1;
364            }
365
366            let avg_elbo = total_elbo / n_batches as f64;
367            elbo_history.push(avg_elbo);
368
369            if self.verbose && iter % 100 == 0 {
370                println!("Iteration {}: ELBO = {:.6}", iter, avg_elbo);
371            }
372
373            // Check convergence
374            if iter > 0 && (avg_elbo - elbo_history[iter - 1]).abs() < self.tol {
375                if self.verbose {
376                    println!("Converged at iteration {}", iter);
377                }
378                break;
379            }
380        }
381
382        let final_elbo = elbo_history.last().copied().unwrap_or(0.0);
383
384        Ok(VariationalSparseGaussianProcessRegressor {
385            state: VsgprTrained {
386                Z,
387                m,
388                S,
389                kernel,
390                sigma_n: self.sigma_n,
391                elbo_history,
392                final_elbo,
393            },
394            kernel: None,
395            n_inducing: self.n_inducing,
396            inducing_init: self.inducing_init,
397            optimizer: self.optimizer,
398            learning_rate: self.learning_rate,
399            max_iter: self.max_iter,
400            batch_size: self.batch_size,
401            inducing_batch_size: self.inducing_batch_size,
402            beta1: self.beta1,
403            beta2: self.beta2,
404            epsilon: self.epsilon,
405            natural_gradient_damping: self.natural_gradient_damping,
406            sigma_n: self.sigma_n,
407            tol: self.tol,
408            verbose: self.verbose,
409            random_state: self.random_state,
410            config: self.config.clone(),
411        })
412    }
413}
414
415impl Predict<ArrayView2<'_, f64>, Array1<f64>>
416    for VariationalSparseGaussianProcessRegressor<VsgprTrained>
417{
418    fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
419        let (mean, _) = self.predict_with_std(X)?;
420        Ok(mean)
421    }
422}
423
424impl VariationalSparseGaussianProcessRegressor<VsgprTrained> {
425    /// Predict with uncertainty estimates
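    ///
    /// A minimal usage sketch (illustrative; assumes a model already fitted as in the
    /// struct-level example above):
    ///
    /// ```ignore
    /// // Returns the predictive mean and standard deviation for each query point;
    /// // the noise variance sigma_n^2 is included in the reported uncertainty.
    /// let (mean, std) = fitted.predict_with_std(&X.view()).unwrap();
    /// assert_eq!(mean.len(), X.nrows());
    /// assert_eq!(std.len(), X.nrows());
    /// ```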
426    #[allow(non_snake_case)]
427    pub fn predict_with_std(&self, X: &ArrayView2<f64>) -> SklResult<(Array1<f64>, Array1<f64>)> {
428        // Compute kernel matrices
429        let Kzz = self
430            .state
431            .kernel
432            .compute_kernel_matrix(&self.state.Z, None)?;
433        let X_owned = X.to_owned();
434        let Kxz = self
435            .state
436            .kernel
437            .compute_kernel_matrix(&X_owned, Some(&self.state.Z))?;
438        let Kxx_diag = X
439            .axis_iter(Axis(0))
440            .map(|x| self.state.kernel.kernel(&x, &x))
441            .collect::<Array1<f64>>();
442
443        // Cholesky decomposition of Kzz
444        let L_zz = utils::robust_cholesky(&Kzz)?;
445
446        // Solve Lzz^{-1} * Kxz^T -> A
447        let mut A = Array2::<f64>::zeros((self.state.Z.nrows(), X.nrows()));
448        for i in 0..X.nrows() {
449            let kxz_i = Kxz.row(i).to_owned();
450            let a_i = utils::triangular_solve(&L_zz, &kxz_i)?;
451            A.column_mut(i).assign(&a_i);
452        }
453
454        // Predictive mean: A^T * m
455        let mean = A.t().dot(&self.state.m);
456
        // Predictive variance: diag(Kxx) + diag(A^T (S - I) A) + sigma_n^2
458        let I = Array2::<f64>::eye(self.state.S.nrows());
459        let S_diff = &self.state.S - &I;
460        let var_correction = A.t().dot(&S_diff.dot(&A));
461
462        let mut variance = Kxx_diag.clone();
463        for i in 0..X.nrows() {
464            variance[i] += var_correction[[i, i]] + self.state.sigma_n.powi(2);
465        }
466
467        let std = variance.mapv(|x| x.sqrt().max(0.0));
468
469        Ok((mean, std))
470    }
471
472    /// Get the evidence lower bound (ELBO)
473    pub fn elbo(&self) -> f64 {
474        self.state.final_elbo
475    }
476
477    /// Get the ELBO history during training
478    pub fn elbo_history(&self) -> &[f64] {
479        &self.state.elbo_history
480    }
481
482    /// Get the inducing points
483    pub fn inducing_points(&self) -> &Array2<f64> {
484        &self.state.Z
485    }
486
487    /// Get the variational mean
488    pub fn variational_mean(&self) -> &Array1<f64> {
489        &self.state.m
490    }
491
492    /// Get the variational covariance
493    pub fn variational_covariance(&self) -> &Array2<f64> {
494        &self.state.S
495    }
496
497    /// Online update with new data for streaming/scalable learning
498    ///
499    /// This method allows incrementally updating the variational parameters
500    /// with new data points without retraining from scratch, making it suitable
501    /// for streaming scenarios and very large datasets.
502    ///
503    /// # Arguments
504    /// * `X_new` - New input data points
505    /// * `y_new` - New target values
506    /// * `learning_rate` - Learning rate for the update (if None, uses model's learning rate)
507    /// * `n_iterations` - Number of update iterations (default: 10)
508    ///
509    /// # Returns
510    /// Updated model with modified variational parameters
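    ///
    /// # Examples
    ///
    /// A minimal streaming sketch (illustrative; assumes a model already fitted as in the
    /// struct-level example):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let X_new = array![[8.0], [9.0]];
    /// let y_new = array![64.0, 81.0];
    ///
    /// // Take five gradient steps on the new batch, reusing the model's learning rate.
    /// let updated = fitted
    ///     .update(&X_new.view(), &y_new.view(), None, Some(5))
    ///     .unwrap();
    /// let predictions = updated.predict(&X_new.view()).unwrap();
    /// ```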
511    pub fn update(
512        mut self,
513        X_new: &ArrayView2<f64>,
514        y_new: &ArrayView1<f64>,
515        learning_rate: Option<f64>,
516        n_iterations: Option<usize>,
517    ) -> SklResult<Self> {
518        if X_new.nrows() != y_new.len() {
519            return Err(SklearsError::InvalidInput(
520                "X_new and y_new must have the same number of samples".to_string(),
521            ));
522        }
523
524        let lr = learning_rate.unwrap_or(self.learning_rate);
525        let n_iter = n_iterations.unwrap_or(10);
526
527        // Perform incremental updates using the current optimizer
528        for _ in 0..n_iter {
529            // Compute gradients for new data
530            let (_, grad_m, grad_S) = compute_elbo_and_gradients(
531                X_new,
532                y_new,
533                &self.state.Z,
534                &self.state.m,
535                &self.state.S,
536                &self.state.kernel,
537                self.state.sigma_n,
538            )?;
539
540            // Apply optimizer-specific updates
541            match self.optimizer {
542                VariationalOptimizer::Adam => {
543                    // Simple gradient ascent for online updates
544                    self.state.m = &self.state.m + lr * &grad_m;
545                    self.state.S = &self.state.S + lr * &grad_S;
546                }
547                VariationalOptimizer::NaturalGradients => {
548                    // Natural gradients update
549                    self.state.m = &self.state.m + lr * &grad_m;
550                    let natural_grad_S = self.state.S.dot(&grad_S).dot(&self.state.S)
551                        + self.natural_gradient_damping * Array2::<f64>::eye(self.state.Z.nrows());
552                    self.state.S = &self.state.S + lr * &natural_grad_S;
553                }
554                VariationalOptimizer::DoublyStochastic => {
555                    // For streaming updates, use simple gradient ascent
556                    self.state.m = &self.state.m + lr * &grad_m;
557                    self.state.S = &self.state.S + lr * &grad_S;
558                }
559            }
560
561            // Ensure S remains positive definite
562            self.state.S = ensure_positive_definite(self.state.S)?;
563        }
564
565        // Update ELBO history with final value
566        let (final_elbo, _, _) = compute_elbo_and_gradients(
567            X_new,
568            y_new,
569            &self.state.Z,
570            &self.state.m,
571            &self.state.S,
572            &self.state.kernel,
573            self.state.sigma_n,
574        )?;
575
576        self.state.elbo_history.push(final_elbo);
577        self.state.final_elbo = final_elbo;
578
579        Ok(self)
580    }
581
582    /// Recursive Bayesian update for online GP learning
583    ///
584    /// This method implements proper recursive Bayesian updates for Gaussian processes,
585    /// maintaining the posterior mean and covariance through sequential updates without
586    /// requiring full recomputation of the ELBO.
587    ///
588    /// # Arguments
589    /// * `X_new` - New input data (n_new x n_features)
590    /// * `y_new` - New target values (n_new,)
591    /// * `forgetting_factor` - Exponential forgetting factor (0 < λ ≤ 1)
592    ///
593    /// # Returns
594    /// Updated model with recursively updated posterior
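    ///
    /// # Examples
    ///
    /// A minimal sketch (illustrative; assumes a fitted model and applies a mild
    /// forgetting factor so older observations are gradually down-weighted):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let X_new = array![[8.0]];
    /// let y_new = array![64.0];
    ///
    /// // One Kalman-style recursive update with forgetting factor 0.95.
    /// let updated = fitted
    ///     .recursive_update(&X_new.view(), &y_new.view(), Some(0.95))
    ///     .unwrap();
    /// ```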
595    #[allow(non_snake_case)]
596    pub fn recursive_update(
597        mut self,
598        X_new: &ArrayView2<f64>,
599        y_new: &ArrayView1<f64>,
600        forgetting_factor: Option<f64>,
601    ) -> SklResult<Self> {
602        if X_new.nrows() != y_new.len() {
603            return Err(SklearsError::InvalidInput(
604                "X_new and y_new must have the same number of samples".to_string(),
605            ));
606        }
607
608        let lambda = forgetting_factor.unwrap_or(1.0);
609        if lambda <= 0.0 || lambda > 1.0 {
610            return Err(SklearsError::InvalidInput(
611                "Forgetting factor must be in range (0, 1]".to_string(),
612            ));
613        }
614
615        // Apply forgetting to prior covariance (increase uncertainty)
616        if lambda < 1.0 {
617            self.state.S /= lambda;
618        }
619
620        // Compute kernel matrices for new data
621        let Kzz = self
622            .state
623            .kernel
624            .compute_kernel_matrix(&self.state.Z, None)?;
625        let X_new_owned = X_new.to_owned();
626        let Kzx_new = self
627            .state
628            .kernel
629            .compute_kernel_matrix(&self.state.Z, Some(&X_new_owned))?;
630
631        // Robust Cholesky decomposition
632        let L_zz = utils::robust_cholesky(&Kzz)?;
633
634        // For each new data point, perform recursive Bayesian update
635        for (i, &y_i) in y_new.iter().enumerate() {
636            let k_zi = Kzx_new.column(i);
637
638            // Solve L_zz * alpha = k_zi
639            let alpha = utils::triangular_solve(&L_zz, &k_zi.to_owned())?;
640
641            // Predictive variance: σ²_i = σ²_n + k_ii - α^T α
642            let k_ii = self.state.kernel.kernel(&X_new.row(i), &X_new.row(i));
643            let pred_var = self.state.sigma_n.powi(2) + k_ii - alpha.dot(&alpha);
644
645            if pred_var <= 0.0 {
646                continue; // Skip if predictive variance is non-positive
647            }
648
649            // Predictive mean: μ_i = α^T m
650            let pred_mean = alpha.dot(&self.state.m);
651
652            // Innovation (prediction error)
653            let innovation = y_i - pred_mean;
654
655            // Kalman gain: K = S * α / σ²_i
656            let kalman_gain = self.state.S.dot(&alpha) / pred_var;
657
658            // Update posterior mean: m := m + K * innovation
659            let m_update = &kalman_gain * innovation;
660            self.state.m = &self.state.m + &m_update;
661
662            // Update posterior covariance: S := S - K * α^T * S
663            let s_update = kalman_gain
664                .view()
665                .into_shape((kalman_gain.len(), 1))
666                .map_err(|_| SklearsError::FitError("Shape error in recursive update".to_string()))?
667                .dot(&alpha.view().into_shape((1, alpha.len())).map_err(|_| {
668                    SklearsError::FitError("Shape error in recursive update".to_string())
669                })?)
670                .dot(&self.state.S);
671            self.state.S = &self.state.S - &s_update;
672
673            // Ensure positive definiteness
674            self.state.S = ensure_positive_definite(self.state.S)?;
675        }
676
677        // Update ELBO history with approximated value
678        let approx_elbo = self.compute_approximate_elbo(X_new, y_new)?;
679        self.state.elbo_history.push(approx_elbo);
680        self.state.final_elbo = approx_elbo;
681
682        Ok(self)
683    }
684
685    /// Sliding window update for streaming data
686    ///
687    /// Maintains a sliding window of recent data points and uses exponential
688    /// forgetting to down-weight older observations.
689    ///
690    /// # Arguments
691    /// * `X_new` - New input data
692    /// * `y_new` - New target values
693    /// * `window_size` - Maximum number of recent observations to maintain
694    /// * `decay_rate` - Exponential decay rate for older observations
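    ///
    /// # Examples
    ///
    /// A minimal sketch (illustrative; assumes a fitted model):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let X_new = array![[8.0], [9.0], [10.0]];
    /// let y_new = array![64.0, 81.0, 100.0];
    ///
    /// // Keep at most 100 ELBO history entries and decay older observations at rate 0.9.
    /// let updated = fitted
    ///     .sliding_window_update(&X_new.view(), &y_new.view(), 100, 0.9)
    ///     .unwrap();
    /// ```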
695    pub fn sliding_window_update(
696        mut self,
697        X_new: &ArrayView2<f64>,
698        y_new: &ArrayView1<f64>,
699        window_size: usize,
700        decay_rate: f64,
701    ) -> SklResult<Self> {
702        if X_new.nrows() != y_new.len() {
703            return Err(SklearsError::InvalidInput(
704                "X_new and y_new must have the same number of samples".to_string(),
705            ));
706        }
707
708        if decay_rate <= 0.0 || decay_rate > 1.0 {
709            return Err(SklearsError::InvalidInput(
710                "Decay rate must be in range (0, 1]".to_string(),
711            ));
712        }
713
714        // Apply exponential forgetting based on window position
715        let n_new = X_new.nrows();
716        for i in 0..n_new {
717            let age_weight = decay_rate.powi((n_new - i - 1) as i32);
718            let forgetting_factor = (1.0 - age_weight).max(0.1); // Minimum forgetting factor
719
720            let x_i = X_new.row(i);
721            let y_i = Array1::from(vec![y_new[i]]);
722
723            self = self.recursive_update(
724                &x_i.view()
725                    .into_shape((1, x_i.len()))
726                    .map_err(|_| {
727                        SklearsError::FitError("Shape error in sliding window update".to_string())
728                    })?
729                    .view(),
730                &y_i.view(),
731                Some(forgetting_factor),
732            )?;
733        }
734
735        // Limit ELBO history to window size
736        if self.state.elbo_history.len() > window_size {
737            let start_idx = self.state.elbo_history.len() - window_size;
738            self.state.elbo_history = self.state.elbo_history[start_idx..].to_vec();
739        }
740
741        Ok(self)
742    }
743
744    /// Compute approximate ELBO for recursive updates
745    ///
746    /// This provides a computationally efficient approximation to the full ELBO
747    /// calculation for use in recursive updates.
748    fn compute_approximate_elbo(&self, X: &ArrayView2<f64>, y: &ArrayView1<f64>) -> SklResult<f64> {
749        let _n = X.nrows() as f64;
750
751        // Compute prediction errors
        // Predictive mean and standard deviation (the noise variance sigma_n^2 is
        // already included in the returned standard deviation)
        let (y_pred, y_std) = self.predict_with_std(X)?;

        // Log likelihood approximation
        let mut log_likelihood = 0.0;
        for i in 0..X.nrows() {
            let residual = y[i] - y_pred[i];
            let total_var = y_std[i].powi(2);
            log_likelihood -=
                0.5 * (residual.powi(2) / total_var + total_var.ln() + (2.0 * PI).ln());
        }
763        // KL divergence approximation (simplified)
764        let kl_divergence = self.compute_approximate_kl_divergence()?;
765
766        // ELBO = log likelihood - KL divergence
767        Ok(log_likelihood - kl_divergence)
768    }
769
770    /// Compute approximate KL divergence for ELBO calculation
771    #[allow(non_snake_case)]
772    fn compute_approximate_kl_divergence(&self) -> SklResult<f64> {
773        let m = self.state.Z.nrows();
774
775        // Prior: N(0, K_zz)
776        let K_zz = self
777            .state
778            .kernel
779            .compute_kernel_matrix(&self.state.Z, None)?;
780        let L_zz = utils::robust_cholesky(&K_zz)?;
781
782        // Posterior: N(m, S)
783        let L_s = utils::robust_cholesky(&self.state.S)?;
784
        // KL(q||p) = 0.5 * [tr(K_zz^{-1} S) + m^T K_zz^{-1} m - M - log|S| + log|K_zz|]
        // where M is the number of inducing points.

        // log|K_zz| = 2 * sum(log(diag(L_zz)))
        let log_det_k = 2.0 * L_zz.diag().iter().map(|x| x.ln()).sum::<f64>();

        // log|S| = 2 * sum(log(diag(L_s)))
        let log_det_s = 2.0 * L_s.diag().iter().map(|x| x.ln()).sum::<f64>();

        // m^T K_zz^{-1} m via the Cholesky factor: ||L_zz^{-1} m||^2
        let k_inv_m = utils::triangular_solve(&L_zz, &self.state.m)?;

        // tr(K_zz^{-1} S) via the Cholesky factors: ||L_zz^{-1} L_s||_F^2
        let mut k_inv_s_trace = 0.0;
        for i in 0..m {
            let ls_col = L_s.column(i).to_owned();
            let v = utils::triangular_solve(&L_zz, &ls_col)?;
            k_inv_s_trace += v.dot(&v);
        }

        let kl = 0.5 * (k_inv_s_trace + k_inv_m.dot(&k_inv_m) - m as f64 - log_det_s + log_det_k);
803
804        Ok(kl)
805    }
806
807    /// Adaptive sparse GP with dynamic inducing point management
808    ///
809    /// This method adaptively adds/removes inducing points based on the approximation
810    /// quality and computational budget, maintaining good approximation while controlling
811    /// computational cost.
812    ///
813    /// # Arguments
814    /// * `X_new` - New input data
815    /// * `y_new` - New target values
816    /// * `max_inducing` - Maximum number of inducing points allowed
817    /// * `quality_threshold` - Minimum approximation quality threshold (0.0-1.0)
818    /// * `removal_threshold` - Threshold for removing redundant inducing points
819    ///
820    /// # Returns
821    /// Updated model with adaptively adjusted inducing points
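    ///
    /// # Examples
    ///
    /// A minimal sketch (illustrative; assumes a fitted model):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let X_new = array![[8.0], [9.0]];
    /// let y_new = array![64.0, 81.0];
    ///
    /// // Budget of 20 inducing points: add points while approximation quality < 0.9,
    /// // prune the least influential ones whenever the budget is exceeded.
    /// let updated = fitted
    ///     .adaptive_sparse_update(&X_new.view(), &y_new.view(), 20, 0.9, 0.1)
    ///     .unwrap();
    /// ```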
822    pub fn adaptive_sparse_update(
823        mut self,
824        X_new: &ArrayView2<f64>,
825        y_new: &ArrayView1<f64>,
826        max_inducing: usize,
827        quality_threshold: f64,
828        removal_threshold: f64,
829    ) -> SklResult<Self> {
830        if X_new.nrows() != y_new.len() {
831            return Err(SklearsError::InvalidInput(
832                "X_new and y_new must have the same number of samples".to_string(),
833            ));
834        }
835
836        if !(0.0..=1.0).contains(&quality_threshold) {
837            return Err(SklearsError::InvalidInput(
838                "Quality threshold must be in range [0, 1]".to_string(),
839            ));
840        }
841
842        // 1. First, perform standard recursive update
843        self = self.recursive_update(X_new, y_new, None)?;
844
845        // 2. Assess current approximation quality
846        let quality = self.assess_approximation_quality(X_new)?;
847
848        // 3. If quality is below threshold and we haven't reached max inducing points, add new ones
849        if quality < quality_threshold && self.state.Z.nrows() < max_inducing {
850            self = self.add_inducing_points(X_new, y_new, max_inducing)?;
851        }
852
853        // 4. Remove redundant inducing points if we have too many
854        if self.state.Z.nrows() > max_inducing {
855            self = self.remove_redundant_inducing_points(removal_threshold, max_inducing)?;
856        }
857
858        // 5. Optionally optimize inducing point locations
859        self = self.optimize_inducing_point_locations(X_new, y_new)?;
860
861        Ok(self)
862    }
863
864    /// Assess the approximation quality of current inducing points
865    ///
866    /// Returns a quality score between 0 and 1, where 1 indicates perfect approximation
867    fn assess_approximation_quality(&self, X: &ArrayView2<f64>) -> SklResult<f64> {
868        let n_test = (X.nrows() / 4).max(10).min(50); // Sample subset for efficiency
869        let indices: Vec<usize> = (0..X.nrows()).step_by(X.nrows() / n_test + 1).collect();
870
871        let mut total_variance_explained = 0.0;
872        let mut total_variance = 0.0;
873
874        for &i in indices.iter().take(n_test) {
875            let x_i = X.row(i);
876
877            // Compute true kernel value k(x_i, x_i)
878            let k_ii = self.state.kernel.kernel(&x_i, &x_i);
879
880            // Compute approximated kernel value through inducing points
881            let kxz = self
882                .state
883                .Z
884                .axis_iter(Axis(0))
885                .map(|z| self.state.kernel.kernel(&x_i, &z))
886                .collect::<Array1<f64>>();
887
888            let Kzz = self
889                .state
890                .kernel
891                .compute_kernel_matrix(&self.state.Z, None)?;
892            let L_zz = utils::robust_cholesky(&Kzz)?;
893            let alpha = utils::triangular_solve(&L_zz, &kxz)?;
894            let k_approx = alpha.dot(&alpha);
895
896            // Variance explained by inducing points
897            let variance_explained = k_approx / k_ii;
898            total_variance_explained += variance_explained;
899            total_variance += 1.0;
900        }
901
902        let quality = total_variance_explained / total_variance;
903        Ok(quality.min(1.0).max(0.0))
904    }
905
906    /// Add new inducing points based on data coverage and uncertainty
907    fn add_inducing_points(
908        mut self,
909        X: &ArrayView2<f64>,
910        _y: &ArrayView1<f64>,
911        max_inducing: usize,
912    ) -> SklResult<Self> {
913        let current_inducing = self.state.Z.nrows();
914        let points_to_add = (max_inducing - current_inducing).min(X.nrows());
915
916        if points_to_add == 0 {
917            return Ok(self);
918        }
919
920        // Strategy 1: Select points with highest prediction uncertainty
921        let (_, uncertainties) = self.predict_with_std(X)?;
922        let mut uncertainty_indices: Vec<(usize, f64)> = uncertainties
923            .iter()
924            .enumerate()
925            .map(|(i, &u)| (i, u))
926            .collect();
927
928        // Sort by uncertainty (descending)
929        uncertainty_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
930
931        // Strategy 2: Ensure good spatial coverage by checking distances to existing inducing points
932        let mut selected_indices = Vec::new();
933        for (idx, _) in uncertainty_indices.iter().take(points_to_add * 2) {
934            let x_candidate = X.row(*idx);
935
936            // Check minimum distance to existing inducing points
937            let min_distance = self
938                .state
939                .Z
940                .axis_iter(Axis(0))
941                .map(|z| {
942                    let diff = &x_candidate - &z;
943                    diff.dot(&diff).sqrt()
944                })
945                .fold(f64::INFINITY, f64::min);
946
947            // Add point if it's sufficiently far from existing inducing points
948            if min_distance > 0.1 {
949                // Minimum distance threshold
950                selected_indices.push(*idx);
951                if selected_indices.len() >= points_to_add {
952                    break;
953                }
954            }
955        }
956
957        // Add selected points to inducing set
958        if !selected_indices.is_empty() {
959            let mut new_Z = Array2::zeros((current_inducing + selected_indices.len(), X.ncols()));
960            new_Z
961                .slice_mut(s![..current_inducing, ..])
962                .assign(&self.state.Z);
963
964            for (i, &idx) in selected_indices.iter().enumerate() {
965                new_Z.row_mut(current_inducing + i).assign(&X.row(idx));
966            }
967
968            self.state.Z = new_Z;
969
970            // Expand variational parameters
971            let new_m_size = self.state.Z.nrows();
972            let mut new_m = Array1::zeros(new_m_size);
973            new_m
974                .slice_mut(s![..current_inducing])
975                .assign(&self.state.m);
976            // Initialize new parameters with small random values
977            for i in current_inducing..new_m_size {
978                new_m[i] = 0.01 * (i as f64 - new_m_size as f64 / 2.0) / new_m_size as f64;
979            }
980            self.state.m = new_m;
981
982            // Expand covariance matrix
983            let mut new_S = Array2::eye(new_m_size) * 0.1; // Small diagonal initialization
984            new_S
985                .slice_mut(s![..current_inducing, ..current_inducing])
986                .assign(&self.state.S);
987            self.state.S = new_S;
988        }
989
990        Ok(self)
991    }
992
993    /// Remove redundant inducing points to maintain computational efficiency
994    fn remove_redundant_inducing_points(
995        mut self,
996        _removal_threshold: f64,
997        max_inducing: usize,
998    ) -> SklResult<Self> {
999        let current_inducing = self.state.Z.nrows();
1000        if current_inducing <= max_inducing {
1001            return Ok(self);
1002        }
1003
1004        let points_to_remove = current_inducing - max_inducing;
1005
1006        // Compute influence scores for each inducing point
1007        let mut influence_scores = Vec::new();
1008        for i in 0..current_inducing {
1009            // Score based on variational mean magnitude and diagonal covariance
1010            let influence = self.state.m[i].abs() + self.state.S[[i, i]];
1011            influence_scores.push((i, influence));
1012        }
1013
1014        // Sort by influence (ascending - remove least influential)
1015        influence_scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
1016
1017        // Select points to remove
1018        let indices_to_remove: Vec<usize> = influence_scores
1019            .iter()
1020            .take(points_to_remove)
1021            .map(|(i, _)| *i)
1022            .collect();
1023
1024        // Create mask of points to keep
1025        let mut keep_mask = vec![true; current_inducing];
1026        for &i in &indices_to_remove {
1027            keep_mask[i] = false;
1028        }
1029
1030        // Filter inducing points
1031        let kept_indices: Vec<usize> = (0..current_inducing).filter(|&i| keep_mask[i]).collect();
1032
1033        let new_size = kept_indices.len();
1034        let mut new_Z = Array2::zeros((new_size, self.state.Z.ncols()));
1035        let mut new_m = Array1::zeros(new_size);
1036        let mut new_S = Array2::zeros((new_size, new_size));
1037
1038        // Copy kept inducing points and parameters
1039        for (new_i, &old_i) in kept_indices.iter().enumerate() {
1040            new_Z.row_mut(new_i).assign(&self.state.Z.row(old_i));
1041            new_m[new_i] = self.state.m[old_i];
1042            for (new_j, &old_j) in kept_indices.iter().enumerate() {
1043                new_S[[new_i, new_j]] = self.state.S[[old_i, old_j]];
1044            }
1045        }
1046
1047        self.state.Z = new_Z;
1048        self.state.m = new_m;
1049        self.state.S = new_S;
1050
1051        Ok(self)
1052    }
1053
1054    /// Optimize inducing point locations to improve approximation
1055    fn optimize_inducing_point_locations(
1056        mut self,
1057        X: &ArrayView2<f64>,
1058        _y: &ArrayView1<f64>,
1059    ) -> SklResult<Self> {
1060        // Simple optimization: move inducing points towards data centroids in their neighborhoods
1061        let n_inducing = self.state.Z.nrows();
1062
1063        for i in 0..n_inducing {
1064            let z_i = self.state.Z.row(i);
1065
1066            // Find nearby data points
1067            let mut nearby_points = Vec::new();
1068            for j in 0..X.nrows() {
1069                let x_j = X.row(j);
1070                let distance = (&z_i - &x_j).mapv(|x| x.powi(2)).sum().sqrt();
1071                if distance < 1.0 {
1072                    // Distance threshold
1073                    nearby_points.push(j);
1074                }
1075            }
1076
1077            // Compute centroid of nearby points
1078            if !nearby_points.is_empty() {
1079                let mut centroid = Array1::zeros(X.ncols());
1080                for &j in &nearby_points {
1081                    centroid = centroid + X.row(j);
1082                }
1083                centroid /= nearby_points.len() as f64;
1084
1085                // Move inducing point towards centroid (with momentum)
1086                let momentum = 0.1;
1087                let new_z_i = (1.0 - momentum) * &z_i + momentum * &centroid;
1088                self.state.Z.row_mut(i).assign(&new_z_i);
1089            }
1090        }
1091
1092        Ok(self)
1093    }
1094}
1095
1096impl Default for VariationalSparseGaussianProcessRegressor<Untrained> {
1097    fn default() -> Self {
1098        Self::new()
1099    }
1100}
1101
1102/// Compute ELBO and gradients for variational sparse GP
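///
/// The returned ELBO decomposes as a Gaussian data-fit term minus the KL divergence
/// KL(q(u) || p(u)) between the variational distribution q(u) = N(m, S) and the GP
/// prior p(u) = N(0, K_zz) at the inducing points; the gradients are taken with
/// respect to the variational parameters m and S.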
1103#[allow(non_snake_case)]
1104fn compute_elbo_and_gradients(
1105    X: &ArrayView2<f64>,
1106    y: &ArrayView1<f64>,
1107    Z: &Array2<f64>,
1108    m: &Array1<f64>,
1109    S: &Array2<f64>,
1110    kernel: &Box<dyn Kernel>,
1111    sigma_n: f64,
1112) -> SklResult<(f64, Array1<f64>, Array2<f64>)> {
1113    let n = X.nrows();
1114    let m_ind = Z.nrows();
1115
1116    // Compute kernel matrices
1117    let Kzz = kernel.compute_kernel_matrix(Z, None)?;
1118    let X_owned = X.to_owned();
1119    let Kxz = kernel.compute_kernel_matrix(&X_owned, Some(Z))?;
1120    let Kxx_diag = X
1121        .axis_iter(Axis(0))
1122        .map(|x| kernel.kernel(&x, &x))
1123        .collect::<Array1<f64>>();
1124
1125    // Cholesky decomposition of Kzz
1126    let L_zz = utils::robust_cholesky(&Kzz)?;
1127
1128    // Solve Lzz^{-1} * Kxz^T -> A
1129    let mut A = Array2::<f64>::zeros((m_ind, n));
1130    for i in 0..n {
1131        let kxz_i = Kxz.row(i).to_owned();
1132        let a_i = utils::triangular_solve(&L_zz, &kxz_i)?;
1133        A.column_mut(i).assign(&a_i);
1134    }
1135
1136    // Compute predictive mean and variance
1137    let f_mean = A.t().dot(m);
1138    let A_S_At = A.t().dot(&S.dot(&A));
1139
1140    let mut f_var = Kxx_diag.clone();
1141    for i in 0..n {
1142        f_var[i] += A_S_At[[i, i]] - A.column(i).dot(&A.column(i));
1143    }
1144
1145    // Data fit term
1146    let sigma_n_sq = sigma_n * sigma_n;
1147    let mut data_fit = 0.0;
1148    for i in 0..n {
1149        let residual = y[i] - f_mean[i];
1150        let total_var = f_var[i] + sigma_n_sq;
1151        data_fit -= 0.5 * (residual * residual / total_var + total_var.ln() + (2.0 * PI).ln());
1152    }
1153
1154    // KL divergence term
1155    let I = Array2::<f64>::eye(m_ind);
1156    let Kzz_inv = utils::triangular_solve_matrix(&L_zz, &I)?;
1157    let S_Kzz_inv = S.dot(&Kzz_inv);
1158
1159    let trace_term = S_Kzz_inv.diag().sum();
1160    let quad_term = m.dot(&Kzz_inv.dot(m));
1161    let log_det_S = 2.0 * utils::robust_cholesky(S)?.diag().mapv(|x| x.ln()).sum();
1162    let log_det_Kzz = 2.0 * L_zz.diag().mapv(|x| x.ln()).sum();
1163
1164    let kl_div = 0.5 * (trace_term + quad_term - m_ind as f64 + log_det_Kzz - log_det_S);
1165
1166    let elbo = data_fit - kl_div;
1167
1168    // Compute gradients
1169    let mut grad_m = Array1::<f64>::zeros(m_ind);
1170    let mut grad_S = Array2::<f64>::zeros((m_ind, m_ind));
1171
1172    // Gradient w.r.t. m
1173    for i in 0..n {
1174        let residual = y[i] - f_mean[i];
1175        let total_var = f_var[i] + sigma_n_sq;
1176        let a_i = A.column(i);
1177        grad_m = &grad_m + residual / total_var * &a_i.to_owned();
1178    }
1179    grad_m = &grad_m - &Kzz_inv.dot(m);
1180
1181    // Gradient w.r.t. S (simplified)
1182    for i in 0..n {
1183        let residual = y[i] - f_mean[i];
1184        let total_var = f_var[i] + sigma_n_sq;
1185        let a_i = A.column(i).to_owned();
1186        let outer_a = Array2::from_shape_fn((m_ind, m_ind), |(j, k)| a_i[j] * a_i[k]);
1187
1188        grad_S = &grad_S
1189            + 0.5 * (residual * residual / (total_var * total_var) - 1.0 / total_var) * &outer_a;
1190    }
1191    grad_S = &grad_S - 0.5 * &Kzz_inv;
1192
1193    Ok((elbo, grad_m, grad_S))
1194}
1195
1196/// Ensure a matrix remains positive definite
1197fn ensure_positive_definite(mut S: Array2<f64>) -> SklResult<Array2<f64>> {
1198    // Simple approach: add small diagonal jitter if needed
1199    let min_eigenval = 1e-6;
1200
1201    // Check if matrix is symmetric
1202    let is_symmetric = S
1203        .iter()
1204        .zip(S.t().iter())
1205        .all(|(a, b)| (a - b).abs() < 1e-12);
1206    if !is_symmetric {
1207        // Force symmetry
1208        S = 0.5 * (&S + &S.t());
1209    }
1210
1211    // Add jitter to diagonal
1212    for i in 0..S.nrows() {
1213        S[[i, i]] += min_eigenval;
1214    }
1215
1216    Ok(S)
1217}
1218
1219/// Multi-output Gaussian Process Regressor with Linear Model of Coregionalization (LMC)
1220///
1221/// The Linear Model of Coregionalization (LMC) is a framework for modeling multiple
1222/// correlated outputs by expressing each output as a linear combination of independent
1223/// latent Gaussian processes. This approach captures cross-correlations between outputs
1224/// while maintaining computational efficiency.
1225///
1226/// For Q outputs and R latent GPs, the model is:
1227/// f_q(x) = Σ_r A_{q,r} * u_r(x)
1228///
1229/// where:
1230/// - f_q(x) is the q-th output function
1231/// - u_r(x) are independent latent GPs with kernel k_r(x, x')
1232/// - A is the Q×R mixing matrix that captures output correlations
1233///
1234/// # Examples
1235///
1236/// ```
1237/// use sklears_gaussian_process::{MultiOutputGaussianProcessRegressor, RBF};
1238/// use sklears_core::traits::{Fit, Predict};
1239/// // SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
1240/// use scirs2_core::ndarray::array;
1241///
1242/// let kernel = RBF::new(1.0);
1243/// let mogpr = MultiOutputGaussianProcessRegressor::new()
1244///     .n_outputs(2)
1245///     .n_latent(1)
1246///     .kernel(Box::new(kernel))
1247///     .alpha(1e-10);
1248///
1249/// let X = array![[1.0], [2.0], [3.0]];
1250/// let Y = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]; // 3 samples, 2 outputs
1251///
1252/// let fitted = mogpr.fit(&X.view(), &Y.view()).unwrap();
1253/// let predictions = fitted.predict(&X.view()).unwrap();
1254/// ```
1255#[derive(Debug, Clone)]
1256pub struct MultiOutputGaussianProcessRegressor<S = Untrained> {
1257    kernel: Option<Box<dyn Kernel>>,
1258    alpha: f64,
1259    n_outputs: usize,
1260    n_latent: usize,
1261    mixing_matrix: Option<Array2<f64>>, // Q × R mixing matrix A
1262    _state: S,
1263}
1264
1265/// Trained state for MultiOutputGaussianProcessRegressor
1266#[derive(Debug, Clone)]
1267pub struct MogprTrained {
1268    X_train: Array2<f64>,
1269    Y_train: Array2<f64>,
1270    kernel: Box<dyn Kernel>,
1271    alpha: f64,
1272    n_outputs: usize,
1273    n_latent: usize,
1274    mixing_matrix: Array2<f64>,
1275    covariance_inv: Vec<Array2<f64>>, // Inverse covariances for each latent GP
1276    y_latent: Vec<Array1<f64>>,       // Latent targets for each GP
1277}
1278
1279impl MultiOutputGaussianProcessRegressor<Untrained> {
1280    /// Create a new MultiOutputGaussianProcessRegressor instance
1281    pub fn new() -> Self {
1282        Self {
1283            kernel: None,
1284            alpha: 1e-10,
1285            n_outputs: 1,
1286            n_latent: 1,
1287            mixing_matrix: None,
1288            _state: Untrained,
1289        }
1290    }
1291
1292    /// Set the kernel function
1293    pub fn kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
1294        self.kernel = Some(kernel);
1295        self
1296    }
1297
1298    /// Set the regularization parameter
1299    pub fn alpha(mut self, alpha: f64) -> Self {
1300        self.alpha = alpha;
1301        self
1302    }
1303
1304    /// Set the number of outputs
1305    pub fn n_outputs(mut self, n_outputs: usize) -> Self {
1306        self.n_outputs = n_outputs;
1307        self
1308    }
1309
1310    /// Set the number of latent GPs
1311    pub fn n_latent(mut self, n_latent: usize) -> Self {
1312        self.n_latent = n_latent;
1313        self
1314    }
1315
1316    /// Set a custom mixing matrix A (Q × R)
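    ///
    /// A minimal sketch (illustrative): for `Q = 2` outputs driven by `R = 1` latent GP,
    /// the mixing matrix has shape `2 x 1`.
    ///
    /// ```ignore
    /// use sklears_gaussian_process::MultiOutputGaussianProcessRegressor;
    /// use scirs2_core::ndarray::array;
    ///
    /// let mogpr = MultiOutputGaussianProcessRegressor::new()
    ///     .n_outputs(2)
    ///     .n_latent(1)
    ///     .mixing_matrix(array![[1.0], [0.5]]);
    /// ```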
1317    pub fn mixing_matrix(mut self, mixing_matrix: Array2<f64>) -> Self {
1318        self.mixing_matrix = Some(mixing_matrix);
1319        self
1320    }
1321
1322    /// Initialize the mixing matrix randomly or with provided values
1323    fn initialize_mixing_matrix(&self) -> Array2<f64> {
1324        if let Some(ref matrix) = self.mixing_matrix {
1325            matrix.clone()
1326        } else {
1327            // Initialize with random values from normal distribution
1328            let mut matrix = Array2::<f64>::zeros((self.n_outputs, self.n_latent));
1329            let mut rng_state = 42u64; // Simple seed
1330
1331            for i in 0..self.n_outputs {
1332                for j in 0..self.n_latent {
1333                    // Simple Box-Muller for normal samples
1334                    let u1 = self.uniform(&mut rng_state);
1335                    let u2 = self.uniform(&mut rng_state);
1336                    let normal = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
1337                    matrix[[i, j]] = normal * 0.5; // Scale down initial values
1338                }
1339            }
1340            matrix
1341        }
1342    }
1343
1344    /// Simple uniform random number generator
1345    fn uniform(&self, state: &mut u64) -> f64 {
1346        *state = state.wrapping_mul(1103515245).wrapping_add(12345);
1347        (*state as f64) / (u64::MAX as f64)
1348    }
1349
1350    /// Optimize the mixing matrix using alternating optimization
1351    fn optimize_mixing_matrix(
1352        &self,
1353        Y: &ArrayView2<f64>,
1354        _K_inv: &[Array2<f64>],
1355        n_iter: usize,
1356    ) -> (Array2<f64>, Vec<Array1<f64>>) {
1357        let (n_samples, n_outputs) = Y.dim();
1358        let mut A = self.initialize_mixing_matrix();
1359        let mut y_latent = vec![Array1::<f64>::zeros(n_samples); self.n_latent];
1360
1361        for _iter in 0..n_iter {
1362            // Update latent targets given current mixing matrix
1363            for r in 0..self.n_latent {
1364                let mut target = Array1::<f64>::zeros(n_samples);
1365
1366                // Compute pseudo-targets for latent GP r
1367                for i in 0..n_samples {
1368                    let mut weighted_sum = 0.0;
1369                    let mut weight_sum = 0.0;
1370
1371                    for q in 0..n_outputs {
1372                        let weight = A[[q, r]].powi(2);
1373                        weighted_sum += weight * Y[[i, q]];
1374                        weight_sum += weight;
1375                    }
1376
1377                    if weight_sum > 1e-10 {
1378                        target[i] = weighted_sum / weight_sum;
1379                    }
1380                }
1381
1382                y_latent[r] = target;
1383            }
1384
1385            // Update mixing matrix given current latent targets
1386            for q in 0..n_outputs {
1387                for r in 0..self.n_latent {
1388                    let mut numerator = 0.0;
1389                    let mut denominator = 0.0;
1390
1391                    for i in 0..n_samples {
1392                        // Compute the optimal A[q,r] using least squares
1393                        let residual = Y[[i, q]];
1394                        let latent_contrib = y_latent[r][i];
1395
1396                        numerator += residual * latent_contrib;
1397                        denominator += latent_contrib * latent_contrib;
1398                    }
1399
1400                    if denominator > 1e-10 {
1401                        A[[q, r]] = numerator / denominator;
1402                    }
1403                }
1404            }
1405        }
1406
1407        (A, y_latent)
1408    }
1409}
1410
1411impl Default for MultiOutputGaussianProcessRegressor<Untrained> {
1412    fn default() -> Self {
1413        Self::new()
1414    }
1415}
1416
1417impl Estimator for MultiOutputGaussianProcessRegressor<Untrained> {
1418    type Config = ();
1419    type Error = SklearsError;
1420    type Float = f64;
1421
1422    fn config(&self) -> &Self::Config {
1423        &()
1424    }
1425}
1426
1427impl Fit<ArrayView2<'_, f64>, ArrayView2<'_, f64>, SklearsError>
1428    for MultiOutputGaussianProcessRegressor<Untrained>
1429{
1430    type Fitted = MultiOutputGaussianProcessRegressor<MogprTrained>;
1431
1432    #[allow(non_snake_case)]
1433    fn fit(self, X: &ArrayView2<f64>, Y: &ArrayView2<f64>) -> Result<Self::Fitted, SklearsError> {
1434        let kernel = self
1435            .kernel
1436            .as_ref()
1437            .ok_or_else(|| SklearsError::InvalidInput("No kernel provided".to_string()))?
1438            .clone();
1439
1440        let (n_samples, _n_features) = X.dim();
1441        let (n_samples_y, n_outputs) = Y.dim();
1442
1443        if n_samples != n_samples_y {
1444            return Err(SklearsError::InvalidInput(
1445                "Number of samples in X and Y must match".to_string(),
1446            ));
1447        }
1448
1449        if n_outputs != self.n_outputs {
1450            return Err(SklearsError::InvalidInput(format!(
1451                "Expected {} outputs, got {}",
1452                self.n_outputs, n_outputs
1453            )));
1454        }
1455
1456        // Compute kernel matrix
1457        let X_owned = X.to_owned();
1458        let K = kernel.compute_kernel_matrix(&X_owned, None)?;
1459        let mut K_reg = K.clone();
1460
1461        // Add regularization
1462        for i in 0..n_samples {
1463            K_reg[[i, i]] += self.alpha;
1464        }
1465
1466        // For multi-output, we need to solve for each latent GP
1467        let mut covariance_inv = Vec::new();
1468
1469        // Each latent GP uses the same kernel structure
1470        for _r in 0..self.n_latent {
1471            let chol_decomp = utils::robust_cholesky(&K_reg)?;
1472            let identity = Array2::eye(n_samples);
1473            let inv = solve_triangular_matrix(&chol_decomp, &identity)?;
1474            covariance_inv.push(inv);
1475        }
1476
1477        // Optimize mixing matrix and latent targets
1478        let (mixing_matrix, y_latent) = self.optimize_mixing_matrix(Y, &covariance_inv, 10);
1479
1480        Ok(MultiOutputGaussianProcessRegressor {
1481            kernel: None,
1482            alpha: 0.0,
1483            n_outputs: 0,
1484            n_latent: 0,
1485            mixing_matrix: None,
1486            _state: MogprTrained {
1487                X_train: X.to_owned(),
1488                Y_train: Y.to_owned(),
1489                kernel,
1490                alpha: self.alpha,
1491                n_outputs: self.n_outputs,
1492                n_latent: self.n_latent,
1493                mixing_matrix,
1494                covariance_inv,
1495                y_latent,
1496            },
1497        })
1498    }
1499}
1500
1501impl MultiOutputGaussianProcessRegressor<MogprTrained> {
1502    /// Access the trained state
1503    pub fn trained_state(&self) -> &MogprTrained {
1504        &self._state
1505    }
1506
1507    /// Get the learned mixing matrix
1508    pub fn mixing_matrix(&self) -> &Array2<f64> {
1509        &self._state.mixing_matrix
1510    }
1511
1512    /// Get the log marginal likelihood for model selection
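    ///
    /// A minimal sketch (illustrative; assumes a fitted multi-output model as in the
    /// struct-level example):
    ///
    /// ```ignore
    /// // Sum of the latent GPs' log marginal likelihoods; higher is better when
    /// // comparing kernels or regularization settings.
    /// let lml = fitted.log_marginal_likelihood().unwrap();
    /// ```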
1513    #[allow(non_snake_case)]
1514    pub fn log_marginal_likelihood(&self) -> SklResult<f64> {
1515        let mut total_ll = 0.0;
1516        let n_samples = self._state.X_train.nrows();
1517
1518        for r in 0..self._state.n_latent {
1519            // Compute kernel matrix
1520            let K = self
1521                ._state
1522                .kernel
1523                .compute_kernel_matrix(&self._state.X_train, None)?;
1524            let mut K_reg = K.clone();
1525
1526            // Add regularization
1527            for i in 0..n_samples {
1528                K_reg[[i, i]] += self._state.alpha;
1529            }
1530
1531            let chol_decomp = utils::robust_cholesky(&K_reg)?;
1532            let y = &self._state.y_latent[r];
1533
1534            // Compute log determinant
1535            let log_det = chol_decomp.diag().iter().map(|x| x.ln()).sum::<f64>() * 2.0;
1536
1537            // Solve for alpha = K^{-1} * y
1538            let alpha = utils::triangular_solve(&chol_decomp, y)?;
1539            let data_fit = y.dot(&alpha);
1540
1541            // Log marginal likelihood: -0.5 * (y^T K^{-1} y + log|K| + n*log(2π))
1542            let ll =
1543                -0.5 * (data_fit + log_det + n_samples as f64 * (2.0 * std::f64::consts::PI).ln());
1544            total_ll += ll;
1545        }
1546
1547        Ok(total_ll)
1548    }
1549}
1550
1551impl Predict<ArrayView2<'_, f64>, Array2<f64>>
1552    for MultiOutputGaussianProcessRegressor<MogprTrained>
1553{
1554    #[allow(non_snake_case)]
1555    fn predict(&self, X: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
1556        let (n_test, _) = X.dim();
1557        let mut predictions = Array2::<f64>::zeros((n_test, self._state.n_outputs));
1558
1559        // Predict for each latent GP and combine using mixing matrix
1560        for r in 0..self._state.n_latent {
1561            // Compute cross-covariance between test and training points
1562            let X_test_owned = X.to_owned();
1563            let K_star = self
1564                ._state
1565                .kernel
1566                .compute_kernel_matrix(&self._state.X_train, Some(&X_test_owned))?;
1567
1568            // Compute predictions for latent GP r
1569            let y_latent = &self._state.y_latent[r];
1570            let alpha = utils::triangular_solve(
1571                &utils::robust_cholesky(&{
1572                    let K = self
1573                        ._state
1574                        .kernel
1575                        .compute_kernel_matrix(&self._state.X_train, None)?;
1576                    let mut K_reg = K;
1577                    for i in 0..self._state.X_train.nrows() {
1578                        K_reg[[i, i]] += self._state.alpha;
1579                    }
1580                    K_reg
1581                })?,
1582                y_latent,
1583            )?;
1584
1585            let latent_pred = K_star.t().dot(&alpha);
1586
1587            // Combine predictions using mixing matrix
1588            for q in 0..self._state.n_outputs {
1589                let weight = self._state.mixing_matrix[[q, r]];
1590                for i in 0..n_test {
1591                    predictions[[i, q]] += weight * latent_pred[i];
1592                }
1593            }
1594        }
1595
1596        Ok(predictions)
1597    }
1598}
1599
1600/// Solve triangular system for multiple right-hand sides
1601fn solve_triangular_matrix(L: &Array2<f64>, B: &Array2<f64>) -> SklResult<Array2<f64>> {
1602    let n = L.nrows();
1603    let m = B.ncols();
1604    let mut X = Array2::<f64>::zeros((n, m));
1605
1606    for j in 0..m {
1607        let b = B.column(j);
1608        let x = utils::triangular_solve(L, &b.to_owned())?;
1609        X.column_mut(j).assign(&x);
1610    }
1611
1612    Ok(X)
1613}