1use ndarray::{Array1, Array2};
11use serde::{Deserialize, Serialize};
12
/// One training example: the feature vector of an error together with the
/// feature vector of the fix paired with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorFixPair {
    /// Features for the error side of the pair.
    pub error_features: Vec<f32>,
    /// Features for the fix side of the pair.
    pub fix_features: Vec<f32>,
    /// Score attached to the pair. `ErrorFixPair::new` clamps it to
    /// `[0.0, 1.0]`, and `CitlTrainer::train` uses it as the sample weight.
    pub correlation_score: f32,
}
23
24impl ErrorFixPair {
25 #[must_use]
27 pub fn new(error_features: Vec<f32>, fix_features: Vec<f32>, correlation_score: f32) -> Self {
28 Self { error_features, fix_features, correlation_score: correlation_score.clamp(0.0, 1.0) }
29 }
30}
31
/// Linear model that maps an error feature vector to a fix feature vector.
///
/// Instances are produced by `CitlTrainer::train` and queried with
/// `CitlTrainer::predict_fix`.
#[derive(Debug, Clone)]
pub struct CitlTrainer {
    /// Weight matrix with `fix_dim` rows and `error_dim` columns; a prediction
    /// is the matrix-vector product `weights · error_features`.
    weights: Array2<f32>,
    /// Length of the error feature vectors seen during training.
    error_dim: usize,
    /// Length of the fix feature vectors seen during training.
    fix_dim: usize,
}
61
62impl CitlTrainer {
63 pub fn train(pairs: &[ErrorFixPair]) -> Result<Self, crate::Error> {
72 if pairs.is_empty() {
73 return Err(crate::Error::InvalidParameter(
74 "CITL training requires at least one error-fix pair".into(),
75 ));
76 }
77
78 let error_dim = pairs[0].error_features.len();
79 let fix_dim = pairs[0].fix_features.len();
80
81 if error_dim == 0 || fix_dim == 0 {
82 return Err(crate::Error::InvalidParameter(
83 "Feature dimensions must be positive".into(),
84 ));
85 }
86
87 validate_pair_dimensions(pairs, error_dim, fix_dim)?;
89
90 let n = pairs.len();
91
92 let mut x_data = Vec::with_capacity(n * error_dim);
94 let mut y_data = Vec::with_capacity(n * fix_dim);
95 let mut sample_weights = Vec::with_capacity(n);
96
97 for pair in pairs {
98 x_data.extend_from_slice(&pair.error_features);
99 y_data.extend_from_slice(&pair.fix_features);
100 sample_weights.push(pair.correlation_score.max(1e-6)); }
102
103 let x = Array2::from_shape_vec((n, error_dim), x_data)
104 .map_err(|e| crate::Error::InvalidParameter(format!("X matrix build error: {e}")))?;
105 let y = Array2::from_shape_vec((n, fix_dim), y_data)
106 .map_err(|e| crate::Error::InvalidParameter(format!("Y matrix build error: {e}")))?;
107
108 let sqrt_w: Array1<f32> =
110 Array1::from_vec(sample_weights.iter().map(|w| w.sqrt()).collect());
111
112 let mut x_w = x.clone();
114 let mut y_w = y.clone();
115 for i in 0..n {
116 let sw = sqrt_w[i];
117 for j in 0..error_dim {
118 x_w[[i, j]] *= sw;
119 }
120 for j in 0..fix_dim {
121 y_w[[i, j]] *= sw;
122 }
123 }
124
125 let a = x_w.t().dot(&x_w);
128
129 let b = x_w.t().dot(&y_w);
131
132 let lambda = 1e-4_f32;
135 let mut a_reg = a;
136 for i in 0..error_dim {
137 a_reg[[i, i]] += lambda;
138 }
139
140 let a_inv = invert_matrix(&a_reg).map_err(|_e| {
142 crate::Error::InvalidParameter(
143 "Normal equation matrix is singular; cannot solve for weights".into(),
144 )
145 })?;
146
147 let w_t = a_inv.dot(&b);
149
150 let weights = w_t.t().to_owned();
152
153 Ok(Self { weights, error_dim, fix_dim })
154 }
155
156 #[must_use]
160 pub fn predict_fix(&self, error_features: &[f32]) -> Vec<f32> {
161 if error_features.len() != self.error_dim {
162 return vec![0.0; self.fix_dim];
163 }
164
165 let x = Array1::from_vec(error_features.to_vec());
166 let y = self.weights.dot(&x);
167 y.to_vec()
168 }
169
170 #[must_use]
172 pub fn error_dim(&self) -> usize {
173 self.error_dim
174 }
175
176 #[must_use]
178 pub fn fix_dim(&self) -> usize {
179 self.fix_dim
180 }
181
182 #[must_use]
184 pub fn weights(&self) -> &Array2<f32> {
185 &self.weights
186 }
187}
188
189fn validate_pair_dimensions(
191 pairs: &[ErrorFixPair],
192 error_dim: usize,
193 fix_dim: usize,
194) -> Result<(), crate::Error> {
195 for (i, pair) in pairs.iter().enumerate() {
196 if pair.error_features.len() != error_dim {
197 return Err(crate::Error::ShapeMismatch {
198 expected: vec![error_dim],
199 got: vec![pair.error_features.len()],
200 });
201 }
202 if pair.fix_features.len() != fix_dim {
203 return Err(crate::Error::ShapeMismatch {
204 expected: vec![fix_dim],
205 got: vec![pair.fix_features.len()],
206 });
207 }
208 if i > 0 && pair.error_features.len() != error_dim {
209 return Err(crate::Error::InvalidParameter(format!(
210 "Inconsistent error feature dimension at pair {i}"
211 )));
212 }
213 }
214 Ok(())
215}
216
217fn invert_matrix(m: &Array2<f32>) -> std::result::Result<Array2<f32>, ()> {
221 let n = m.nrows();
222 assert_eq!(n, m.ncols(), "Matrix must be square");
223
224 let mut aug = build_augmented(m, n);
225
226 for col in 0..n {
227 pivot_column(&mut aug, col, n)?;
228 eliminate_column(&mut aug, col, n);
229 }
230
231 Ok(extract_inverse(&aug, n))
232}
233
234fn build_augmented(m: &Array2<f32>, n: usize) -> Array2<f32> {
236 let mut aug = Array2::<f32>::zeros((n, 2 * n));
237 for i in 0..n {
238 for j in 0..n {
239 aug[[i, j]] = m[[i, j]];
240 }
241 aug[[i, n + i]] = 1.0;
242 }
243 aug
244}
245
246fn pivot_column(aug: &mut Array2<f32>, col: usize, n: usize) -> std::result::Result<(), ()> {
248 let mut max_val = aug[[col, col]].abs();
249 let mut max_row = col;
250 for row in (col + 1)..n {
251 let val = aug[[row, col]].abs();
252 if val > max_val {
253 max_val = val;
254 max_row = row;
255 }
256 }
257
258 if max_val < 1e-12 {
259 return Err(());
260 }
261
262 if max_row != col {
263 for j in 0..(2 * n) {
264 let tmp = aug[[col, j]];
265 aug[[col, j]] = aug[[max_row, j]];
266 aug[[max_row, j]] = tmp;
267 }
268 }
269
270 let pivot = aug[[col, col]];
271 for j in 0..(2 * n) {
272 aug[[col, j]] /= pivot;
273 }
274 Ok(())
275}
276
277fn eliminate_column(aug: &mut Array2<f32>, col: usize, n: usize) {
279 for row in 0..n {
280 if row == col {
281 continue;
282 }
283 let factor = aug[[row, col]];
284 for j in 0..(2 * n) {
285 aug[[row, j]] -= factor * aug[[col, j]];
286 }
287 }
288}
289
290fn extract_inverse(aug: &Array2<f32>, n: usize) -> Array2<f32> {
292 let mut inv = Array2::<f32>::zeros((n, n));
293 for i in 0..n {
294 for j in 0..n {
295 inv[[i, j]] = aug[[i, n + j]];
296 }
297 }
298 inv
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Three 2-D pairs with distinct weights, enough to determine a 2×2 map.
    fn simple_pairs() -> Vec<ErrorFixPair> {
        vec![
            ErrorFixPair::new(vec![1.0, 0.0], vec![0.0, 1.0], 0.9),
            ErrorFixPair::new(vec![0.0, 1.0], vec![1.0, 0.0], 0.8),
            ErrorFixPair::new(vec![1.0, 1.0], vec![1.0, 1.0], 0.7),
        ]
    }

    #[test]
    fn test_train_produces_correct_dims() {
        let trainer = CitlTrainer::train(&simple_pairs()).expect("operation should succeed");
        assert_eq!(trainer.error_dim(), 2);
        assert_eq!(trainer.fix_dim(), 2);
        assert_eq!(trainer.weights().shape(), &[2, 2]);
    }

    #[test]
    fn test_predict_fix_output_length() {
        let trainer = CitlTrainer::train(&simple_pairs()).expect("operation should succeed");
        let pred = trainer.predict_fix(&[1.0, 0.0]);
        assert_eq!(pred.len(), 2);
    }

    #[test]
    fn test_predict_fix_wrong_dim_returns_zeros() {
        let trainer = CitlTrainer::train(&simple_pairs()).expect("operation should succeed");
        let pred = trainer.predict_fix(&[1.0, 0.0, 0.0]);
        assert_eq!(pred, vec![0.0, 0.0]);
    }

    #[test]
    fn test_train_empty_pairs() {
        let result = CitlTrainer::train(&[]);
        // Pin the variant so an unrelated failure cannot satisfy the test.
        assert!(matches!(result, Err(crate::Error::InvalidParameter(_))));
    }

    #[test]
    fn test_train_zero_dim_features() {
        let pairs = vec![ErrorFixPair::new(vec![], vec![1.0], 1.0)];
        let result = CitlTrainer::train(&pairs);
        assert!(matches!(result, Err(crate::Error::InvalidParameter(_))));
    }

    #[test]
    fn test_train_inconsistent_error_dims() {
        let pairs = vec![
            ErrorFixPair::new(vec![1.0, 0.0], vec![1.0], 0.9),
            ErrorFixPair::new(vec![1.0], vec![1.0], 0.8),
        ];
        let result = CitlTrainer::train(&pairs);
        assert!(matches!(result, Err(crate::Error::ShapeMismatch { .. })));
    }

    #[test]
    fn test_train_inconsistent_fix_dims() {
        let pairs = vec![
            ErrorFixPair::new(vec![1.0], vec![1.0, 0.0], 0.9),
            ErrorFixPair::new(vec![0.0], vec![1.0], 0.8),
        ];
        let result = CitlTrainer::train(&pairs);
        assert!(matches!(result, Err(crate::Error::ShapeMismatch { .. })));
    }

    #[test]
    fn test_identity_mapping() {
        // Ten one-hot vectors mapped to themselves: the fit should be close
        // to the identity.
        let pairs: Vec<ErrorFixPair> = (0..10)
            .map(|i| {
                let mut e = vec![0.0; 3];
                e[i % 3] = 1.0;
                ErrorFixPair::new(e.clone(), e, 1.0)
            })
            .collect();

        let trainer = CitlTrainer::train(&pairs).expect("operation should succeed");
        let pred = trainer.predict_fix(&[1.0, 0.0, 0.0]);
        assert!((pred[0] - 1.0).abs() < 0.2, "pred[0]={}", pred[0]);
        assert!(pred[1].abs() < 0.2, "pred[1]={}", pred[1]);
        assert!(pred[2].abs() < 0.2, "pred[2]={}", pred[2]);
    }

    #[test]
    fn test_correlation_score_clamped() {
        let pair = ErrorFixPair::new(vec![1.0], vec![1.0], 2.0);
        assert_eq!(pair.correlation_score, 1.0);

        let pair2 = ErrorFixPair::new(vec![1.0], vec![1.0], -1.0);
        assert_eq!(pair2.correlation_score, 0.0);
    }

    #[test]
    fn test_single_pair_training() {
        let pairs = vec![ErrorFixPair::new(vec![2.0, 0.0], vec![0.0, 4.0], 1.0)];
        let trainer = CitlTrainer::train(&pairs).expect("operation should succeed");
        let pred = trainer.predict_fix(&[2.0, 0.0]);
        assert_eq!(pred.len(), 2);
        assert!(pred[1] > pred[0], "pred={pred:?}");
    }

    #[test]
    fn test_invert_identity() {
        let eye = Array2::eye(3);
        let inv = invert_matrix(&eye).expect("operation should succeed");
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((inv[[i, j]] - expected).abs() < 1e-6, "inv[{i},{j}]={}", inv[[i, j]]);
            }
        }
    }

    #[test]
    fn test_invert_2x2() {
        // [[2, 1], [1, 1]] has determinant 1 and inverse [[1, -1], [-1, 2]].
        let m = Array2::from_shape_vec((2, 2), vec![2.0, 1.0, 1.0, 1.0])
            .expect("operation should succeed");
        let inv = invert_matrix(&m).expect("operation should succeed");
        assert!((inv[[0, 0]] - 1.0).abs() < 1e-5);
        assert!((inv[[0, 1]] - (-1.0)).abs() < 1e-5);
        assert!((inv[[1, 0]] - (-1.0)).abs() < 1e-5);
        assert!((inv[[1, 1]] - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_weighted_training() {
        // Same input, conflicting targets: the heavily weighted pair must win.
        let pairs = vec![
            ErrorFixPair::new(vec![1.0, 0.0], vec![10.0, 0.0], 1.0),
            ErrorFixPair::new(vec![1.0, 0.0], vec![0.0, 10.0], 0.01),
        ];
        let trainer = CitlTrainer::train(&pairs).expect("operation should succeed");
        let pred = trainer.predict_fix(&[1.0, 0.0]);
        assert!(pred[0] > pred[1], "High-weight sample should dominate: pred={pred:?}");
    }
}