Skip to main content

entrenar/decision/
citl.rs

1//! Correlation-Informed Transfer Learning (CITL) trainer.
2//!
//! Trains a simple linear model that maps error feature vectors to fix
//! feature vectors using correlation-weighted, ridge-regularized least-squares
//! regression solved via the normal equation.
5//!
6//! Given a set of `ErrorFixPair` samples, the trainer computes a weight
7//! matrix W such that `fix_features ≈ W * error_features`. Prediction
8//! for new errors is a simple matrix-vector multiply.
9
10use ndarray::{Array1, Array2};
11use serde::{Deserialize, Serialize};
12
13/// An error-fix training pair for CITL.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ErrorFixPair {
16    /// Feature vector describing the error.
17    pub error_features: Vec<f32>,
18    /// Feature vector describing the fix.
19    pub fix_features: Vec<f32>,
20    /// Strength of the error-fix correlation in [0.0, 1.0].
21    pub correlation_score: f32,
22}
23
24impl ErrorFixPair {
25    /// Create a new error-fix pair.
26    #[must_use]
27    pub fn new(error_features: Vec<f32>, fix_features: Vec<f32>, correlation_score: f32) -> Self {
28        Self { error_features, fix_features, correlation_score: correlation_score.clamp(0.0, 1.0) }
29    }
30}
31
32/// CITL trainer that learns a linear mapping from error features to fix features.
33///
34/// Uses weighted least-squares regression:
35///   `W = (X^T S X)^{-1} X^T S Y`
36/// where S is a diagonal matrix of correlation scores used as sample weights.
37///
38/// # Example
39///
40/// ```
41/// use entrenar::decision::{CitlTrainer, ErrorFixPair};
42///
43/// let pairs = vec![
44///     ErrorFixPair::new(vec![1.0, 0.0], vec![0.0, 1.0], 0.9),
45///     ErrorFixPair::new(vec![0.0, 1.0], vec![1.0, 0.0], 0.8),
46/// ];
47///
48/// let trainer = CitlTrainer::train(&pairs).expect("training must succeed");
49/// let prediction = trainer.predict_fix(&[1.0, 0.0]);
50/// assert_eq!(prediction.len(), 2);
51/// ```
52#[derive(Debug, Clone)]
53pub struct CitlTrainer {
54    /// Weight matrix of shape (fix_dim, error_dim).
55    weights: Array2<f32>,
56    /// Dimensionality of error features (input).
57    error_dim: usize,
58    /// Dimensionality of fix features (output).
59    fix_dim: usize,
60}
61
62impl CitlTrainer {
63    /// Train a linear correlation model from error-fix pairs.
64    ///
65    /// # Errors
66    ///
67    /// Returns an error if:
68    /// - `pairs` is empty
69    /// - Feature dimensions are inconsistent across pairs
70    /// - The normal equation matrix is singular (not invertible)
71    pub fn train(pairs: &[ErrorFixPair]) -> Result<Self, crate::Error> {
72        if pairs.is_empty() {
73            return Err(crate::Error::InvalidParameter(
74                "CITL training requires at least one error-fix pair".into(),
75            ));
76        }
77
78        let error_dim = pairs[0].error_features.len();
79        let fix_dim = pairs[0].fix_features.len();
80
81        if error_dim == 0 || fix_dim == 0 {
82            return Err(crate::Error::InvalidParameter(
83                "Feature dimensions must be positive".into(),
84            ));
85        }
86
87        // Validate consistent dimensions
88        validate_pair_dimensions(pairs, error_dim, fix_dim)?;
89
90        let n = pairs.len();
91
92        // Build X (n x error_dim) and Y (n x fix_dim) matrices
93        let mut x_data = Vec::with_capacity(n * error_dim);
94        let mut y_data = Vec::with_capacity(n * fix_dim);
95        let mut sample_weights = Vec::with_capacity(n);
96
97        for pair in pairs {
98            x_data.extend_from_slice(&pair.error_features);
99            y_data.extend_from_slice(&pair.fix_features);
100            sample_weights.push(pair.correlation_score.max(1e-6)); // avoid zero weight
101        }
102
103        let x = Array2::from_shape_vec((n, error_dim), x_data)
104            .map_err(|e| crate::Error::InvalidParameter(format!("X matrix build error: {e}")))?;
105        let y = Array2::from_shape_vec((n, fix_dim), y_data)
106            .map_err(|e| crate::Error::InvalidParameter(format!("Y matrix build error: {e}")))?;
107
108        // Build diagonal weight vector sqrt(S) for weighted least squares
109        let sqrt_w: Array1<f32> =
110            Array1::from_vec(sample_weights.iter().map(|w| w.sqrt()).collect());
111
112        // Apply weights: X_w = diag(sqrt_w) * X,  Y_w = diag(sqrt_w) * Y
113        let mut x_w = x.clone();
114        let mut y_w = y.clone();
115        for i in 0..n {
116            let sw = sqrt_w[i];
117            for j in 0..error_dim {
118                x_w[[i, j]] *= sw;
119            }
120            for j in 0..fix_dim {
121                y_w[[i, j]] *= sw;
122            }
123        }
124
125        // Normal equation: W = (X_w^T X_w)^{-1} X_w^T Y_w
126        // A = X_w^T X_w  (error_dim x error_dim)
127        let a = x_w.t().dot(&x_w);
128
129        // B = X_w^T Y_w  (error_dim x fix_dim)
130        let b = x_w.t().dot(&y_w);
131
132        // Solve A * W^T = B  via Tikhonov regularization (ridge)
133        // A_reg = A + lambda * I  to ensure invertibility
134        let lambda = 1e-4_f32;
135        let mut a_reg = a;
136        for i in 0..error_dim {
137            a_reg[[i, i]] += lambda;
138        }
139
140        // Invert A_reg using Gauss-Jordan elimination
141        let a_inv = invert_matrix(&a_reg).map_err(|_e| {
142            crate::Error::InvalidParameter(
143                "Normal equation matrix is singular; cannot solve for weights".into(),
144            )
145        })?;
146
147        // W^T = A_inv * B  (error_dim x fix_dim)
148        let w_t = a_inv.dot(&b);
149
150        // W = (W^T)^T  (fix_dim x error_dim)
151        let weights = w_t.t().to_owned();
152
153        Ok(Self { weights, error_dim, fix_dim })
154    }
155
156    /// Predict fix features from error features.
157    ///
158    /// Returns a zero vector if the input dimension does not match `error_dim`.
159    #[must_use]
160    pub fn predict_fix(&self, error_features: &[f32]) -> Vec<f32> {
161        if error_features.len() != self.error_dim {
162            return vec![0.0; self.fix_dim];
163        }
164
165        let x = Array1::from_vec(error_features.to_vec());
166        let y = self.weights.dot(&x);
167        y.to_vec()
168    }
169
170    /// Return the error feature dimensionality.
171    #[must_use]
172    pub fn error_dim(&self) -> usize {
173        self.error_dim
174    }
175
176    /// Return the fix feature dimensionality.
177    #[must_use]
178    pub fn fix_dim(&self) -> usize {
179        self.fix_dim
180    }
181
182    /// Return a reference to the weight matrix.
183    #[must_use]
184    pub fn weights(&self) -> &Array2<f32> {
185        &self.weights
186    }
187}
188
189/// Validate that all pairs have consistent error/fix feature dimensions.
190fn validate_pair_dimensions(
191    pairs: &[ErrorFixPair],
192    error_dim: usize,
193    fix_dim: usize,
194) -> Result<(), crate::Error> {
195    for (i, pair) in pairs.iter().enumerate() {
196        if pair.error_features.len() != error_dim {
197            return Err(crate::Error::ShapeMismatch {
198                expected: vec![error_dim],
199                got: vec![pair.error_features.len()],
200            });
201        }
202        if pair.fix_features.len() != fix_dim {
203            return Err(crate::Error::ShapeMismatch {
204                expected: vec![fix_dim],
205                got: vec![pair.fix_features.len()],
206            });
207        }
208        if i > 0 && pair.error_features.len() != error_dim {
209            return Err(crate::Error::InvalidParameter(format!(
210                "Inconsistent error feature dimension at pair {i}"
211            )));
212        }
213    }
214    Ok(())
215}
216
217/// Invert a square matrix using Gauss-Jordan elimination.
218///
219/// Returns `Err(())` if the matrix is singular.
220fn invert_matrix(m: &Array2<f32>) -> std::result::Result<Array2<f32>, ()> {
221    let n = m.nrows();
222    assert_eq!(n, m.ncols(), "Matrix must be square");
223
224    let mut aug = build_augmented(m, n);
225
226    for col in 0..n {
227        pivot_column(&mut aug, col, n)?;
228        eliminate_column(&mut aug, col, n);
229    }
230
231    Ok(extract_inverse(&aug, n))
232}
233
234/// Build augmented matrix [M | I].
235fn build_augmented(m: &Array2<f32>, n: usize) -> Array2<f32> {
236    let mut aug = Array2::<f32>::zeros((n, 2 * n));
237    for i in 0..n {
238        for j in 0..n {
239            aug[[i, j]] = m[[i, j]];
240        }
241        aug[[i, n + i]] = 1.0;
242    }
243    aug
244}
245
246/// Partial pivoting: find largest pivot, swap rows, scale pivot row.
247fn pivot_column(aug: &mut Array2<f32>, col: usize, n: usize) -> std::result::Result<(), ()> {
248    let mut max_val = aug[[col, col]].abs();
249    let mut max_row = col;
250    for row in (col + 1)..n {
251        let val = aug[[row, col]].abs();
252        if val > max_val {
253            max_val = val;
254            max_row = row;
255        }
256    }
257
258    if max_val < 1e-12 {
259        return Err(());
260    }
261
262    if max_row != col {
263        for j in 0..(2 * n) {
264            let tmp = aug[[col, j]];
265            aug[[col, j]] = aug[[max_row, j]];
266            aug[[max_row, j]] = tmp;
267        }
268    }
269
270    let pivot = aug[[col, col]];
271    for j in 0..(2 * n) {
272        aug[[col, j]] /= pivot;
273    }
274    Ok(())
275}
276
277/// Eliminate all rows except pivot row for a given column.
278fn eliminate_column(aug: &mut Array2<f32>, col: usize, n: usize) {
279    for row in 0..n {
280        if row == col {
281            continue;
282        }
283        let factor = aug[[row, col]];
284        for j in 0..(2 * n) {
285            aug[[row, j]] -= factor * aug[[col, j]];
286        }
287    }
288}
289
290/// Extract the inverse matrix from the right half of the augmented matrix.
291fn extract_inverse(aug: &Array2<f32>, n: usize) -> Array2<f32> {
292    let mut inv = Array2::<f32>::zeros((n, n));
293    for i in 0..n {
294        for j in 0..n {
295            inv[[i, j]] = aug[[i, n + j]];
296        }
297    }
298    inv
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Three pairs exactly explained by swapping the two coordinates.
    fn simple_pairs() -> Vec<ErrorFixPair> {
        vec![
            ErrorFixPair::new(vec![1.0, 0.0], vec![0.0, 1.0], 0.9),
            ErrorFixPair::new(vec![0.0, 1.0], vec![1.0, 0.0], 0.8),
            ErrorFixPair::new(vec![1.0, 1.0], vec![1.0, 1.0], 0.7),
        ]
    }

    /// Assert `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol, "expected {expected} ± {tol}, got {actual}");
    }

    #[test]
    fn test_train_produces_correct_dims() {
        let trainer = CitlTrainer::train(&simple_pairs()).expect("operation should succeed");
        assert_eq!(trainer.error_dim(), 2);
        assert_eq!(trainer.fix_dim(), 2);
        assert_eq!(trainer.weights().shape(), &[2, 2]);
    }

    #[test]
    fn test_predict_fix_output_length() {
        let trainer = CitlTrainer::train(&simple_pairs()).expect("operation should succeed");
        let pred = trainer.predict_fix(&[1.0, 0.0]);
        assert_eq!(pred.len(), 2);
    }

    #[test]
    fn test_recovers_consistent_linear_mapping() {
        // The data is exactly a coordinate swap, so the tiny ridge term should
        // leave the learned map very close to it.
        let trainer = CitlTrainer::train(&simple_pairs()).expect("operation should succeed");
        let pred = trainer.predict_fix(&[1.0, 0.0]);
        assert_close(pred[0], 0.0, 1e-2);
        assert_close(pred[1], 1.0, 1e-2);
        let pred = trainer.predict_fix(&[0.0, 1.0]);
        assert_close(pred[0], 1.0, 1e-2);
        assert_close(pred[1], 0.0, 1e-2);
    }

    #[test]
    fn test_predict_fix_wrong_dim_returns_zeros() {
        let trainer = CitlTrainer::train(&simple_pairs()).expect("operation should succeed");
        let pred = trainer.predict_fix(&[1.0, 0.0, 0.0]);
        assert_eq!(pred, vec![0.0, 0.0]);
    }

    #[test]
    fn test_train_empty_pairs() {
        let result = CitlTrainer::train(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_train_zero_dim_features() {
        let pairs = vec![ErrorFixPair::new(vec![], vec![1.0], 1.0)];
        let result = CitlTrainer::train(&pairs);
        assert!(result.is_err());
    }

    #[test]
    fn test_train_inconsistent_error_dims() {
        let pairs = vec![
            ErrorFixPair::new(vec![1.0, 0.0], vec![1.0], 0.9),
            ErrorFixPair::new(vec![1.0], vec![1.0], 0.8), // wrong error dim
        ];
        assert!(matches!(
            CitlTrainer::train(&pairs),
            Err(crate::Error::ShapeMismatch { expected, got }) if expected == [2] && got == [1]
        ));
    }

    #[test]
    fn test_train_inconsistent_fix_dims() {
        let pairs = vec![
            ErrorFixPair::new(vec![1.0], vec![1.0, 0.0], 0.9),
            ErrorFixPair::new(vec![0.0], vec![1.0], 0.8), // wrong fix dim
        ];
        assert!(matches!(
            CitlTrainer::train(&pairs),
            Err(crate::Error::ShapeMismatch { expected, got }) if expected == [2] && got == [1]
        ));
    }

    #[test]
    fn test_identity_mapping() {
        // Train on identity-like mapping: error = fix
        let pairs: Vec<ErrorFixPair> = (0..10)
            .map(|i| {
                let mut e = vec![0.0; 3];
                e[i % 3] = 1.0;
                ErrorFixPair::new(e.clone(), e, 1.0)
            })
            .collect();

        let trainer = CitlTrainer::train(&pairs).expect("operation should succeed");
        let pred = trainer.predict_fix(&[1.0, 0.0, 0.0]);
        // Should approximately recover the identity mapping
        assert!((pred[0] - 1.0).abs() < 0.2, "pred[0]={}", pred[0]);
        assert!(pred[1].abs() < 0.2, "pred[1]={}", pred[1]);
        assert!(pred[2].abs() < 0.2, "pred[2]={}", pred[2]);
    }

    #[test]
    fn test_correlation_score_clamped() {
        let pair = ErrorFixPair::new(vec![1.0], vec![1.0], 2.0);
        assert_eq!(pair.correlation_score, 1.0);

        let pair2 = ErrorFixPair::new(vec![1.0], vec![1.0], -1.0);
        assert_eq!(pair2.correlation_score, 0.0);
    }

    #[test]
    fn test_zero_correlation_pairs_still_train() {
        // Zero scores are floored to a tiny positive weight rather than dropped.
        let pairs = vec![
            ErrorFixPair::new(vec![1.0, 0.0], vec![2.0], 0.0),
            ErrorFixPair::new(vec![0.0, 1.0], vec![3.0], 0.0),
        ];
        let trainer = CitlTrainer::train(&pairs).expect("operation should succeed");
        let pred = trainer.predict_fix(&[1.0, 0.0]);
        assert!(pred[0].is_finite() && pred[0] > 0.0, "pred={pred:?}");
    }

    #[test]
    fn test_single_pair_training() {
        let pairs = vec![ErrorFixPair::new(vec![2.0, 0.0], vec![0.0, 4.0], 1.0)];
        let trainer = CitlTrainer::train(&pairs).expect("operation should succeed");
        let pred = trainer.predict_fix(&[2.0, 0.0]);
        // One sample plus a tiny ridge term reproduces the training target closely.
        assert_eq!(pred.len(), 2);
        assert!(pred[0].abs() < 1e-6, "pred={pred:?}");
        assert_close(pred[1], 4.0, 1e-2);
    }

    #[test]
    fn test_invert_identity() {
        let eye = Array2::eye(3);
        let inv = invert_matrix(&eye).expect("operation should succeed");
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((inv[[i, j]] - expected).abs() < 1e-6, "inv[{i},{j}]={}", inv[[i, j]]);
            }
        }
    }

    #[test]
    fn test_invert_2x2() {
        // [[2, 1], [1, 1]] -> inverse [[1, -1], [-1, 2]]
        let m = Array2::from_shape_vec((2, 2), vec![2.0, 1.0, 1.0, 1.0])
            .expect("operation should succeed");
        let inv = invert_matrix(&m).expect("operation should succeed");
        assert!((inv[[0, 0]] - 1.0).abs() < 1e-5);
        assert!((inv[[0, 1]] - (-1.0)).abs() < 1e-5);
        assert!((inv[[1, 0]] - (-1.0)).abs() < 1e-5);
        assert!((inv[[1, 1]] - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_invert_singular_matrix_errors() {
        let zero = Array2::<f32>::zeros((2, 2));
        assert!(invert_matrix(&zero).is_err());

        // Rank-one matrix: second row is twice the first.
        let rank_one = Array2::from_shape_vec((2, 2), vec![1.0_f32, 2.0, 2.0, 4.0])
            .expect("operation should succeed");
        assert!(invert_matrix(&rank_one).is_err());
    }

    #[test]
    fn test_weighted_training() {
        // Two conflicting samples: high-weight sample should dominate
        let pairs = vec![
            ErrorFixPair::new(vec![1.0, 0.0], vec![10.0, 0.0], 1.0), // high weight
            ErrorFixPair::new(vec![1.0, 0.0], vec![0.0, 10.0], 0.01), // low weight
        ];
        let trainer = CitlTrainer::train(&pairs).expect("operation should succeed");
        let pred = trainer.predict_fix(&[1.0, 0.0]);
        // High-weight sample's fix direction should dominate
        assert!(pred[0] > pred[1], "High-weight sample should dominate: pred={pred:?}");
    }
}