selen 0.15.5

Constraint Satisfaction Problem (CSP) solver
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
//! LU decomposition with partial pivoting
//!
//! Provides LU factorization needed for solving linear systems in the Simplex method.
//! Uses partial pivoting for numerical stability.

use super::matrix::Matrix;
use super::types::LpError;

/// LU decomposition with partial pivoting: `PA = LU`, where `P` is a row permutation.
///
/// `L` is unit lower triangular and `U` is upper triangular. Both factors are packed into
/// the single matrix [`LuDecomposition::lu`]; the unit diagonal of `L` is implicit and
/// not stored.
#[derive(Debug, Clone)]
pub struct LuDecomposition {
    /// Combined L and U matrices (L strictly below the diagonal, U on and above it).
    pub lu: Matrix,

    /// Row permutation from partial pivoting: row `i` of `PA` is row `permutation[i]` of `A`.
    pub permutation: Vec<usize>,

    /// Sign of the permutation (+1 after an even number of row swaps, -1 after an odd
    /// number), used for determinant calculation.
    pub sign: i32,
}

impl LuDecomposition {
    /// Smallest `|U[i][i]|` the substitution routines will divide by.
    ///
    /// `decompose` already rejects pivots below its caller-supplied tolerance; this fixed
    /// guard additionally protects decompositions built with a smaller tolerance or
    /// assembled directly through the public fields.
    const SUBSTITUTION_PIVOT_EPS: f64 = 1e-12;

    /// Computes the LU decomposition of a square `matrix` with partial pivoting.
    ///
    /// At each elimination step the row with the largest absolute value in the current
    /// column (at or below the diagonal) is swapped into the pivot position. The input is
    /// copied; `matrix` itself is not modified.
    ///
    /// # Errors
    ///
    /// * [`LpError::NumericalInstability`] if `matrix` is not square, or if a pivot is NaN
    ///   or infinite (non-finite values in the input or produced during elimination).
    /// * [`LpError::SingularBasis`] if a pivot is exactly zero or its absolute value is
    ///   below `tolerance`, i.e. the matrix is singular within that tolerance.
    pub fn decompose(matrix: &Matrix, tolerance: f64) -> Result<Self, LpError> {
        if matrix.rows != matrix.cols {
            return Err(LpError::NumericalInstability);
        }

        let n = matrix.rows;
        let mut lu = matrix.clone();
        let mut permutation: Vec<usize> = (0..n).collect();
        let mut sign = 1;

        // Gaussian elimination with partial pivoting
        for k in 0..n {
            // Find pivot (largest absolute value in column k, from row k onwards)
            let mut pivot_row = k;
            let mut pivot_value = lu.get(k, k).abs();
            for i in (k + 1)..n {
                let val = lu.get(i, k).abs();
                if val > pivot_value {
                    pivot_row = i;
                    pivot_value = val;
                }
            }

            // Every comparison with NaN is false, so a NaN pivot would slip past the
            // tolerance test below; reject non-finite pivots explicitly.
            if !pivot_value.is_finite() {
                return Err(LpError::NumericalInstability);
            }

            // An exact zero is singular even when `tolerance <= 0`, which would otherwise
            // let the division below produce inf/NaN.
            if pivot_value == 0.0 || pivot_value < tolerance {
                return Err(LpError::SingularBasis);
            }

            // Swap whole rows (including already-stored L multipliers) so that
            // `lu` stays consistent with `permutation`.
            if pivot_row != k {
                lu.swap_rows(k, pivot_row);
                permutation.swap(k, pivot_row);
                sign = -sign;
            }

            // Eliminate below the diagonal, storing each multiplier in the L part.
            let pivot = lu.get(k, k);
            for i in (k + 1)..n {
                let factor = lu.get(i, k) / pivot;
                lu.set(i, k, factor);

                for j in (k + 1)..n {
                    let val = lu.get(i, j) - factor * lu.get(k, j);
                    lu.set(i, j, val);
                }
            }
        }

        Ok(Self {
            lu,
            permutation,
            sign,
        })
    }

    /// Returns `U[i][i]`, or `SingularBasis` if its magnitude is below
    /// [`Self::SUBSTITUTION_PIVOT_EPS`].
    fn checked_diag(&self, i: usize) -> Result<f64, LpError> {
        let diag = self.lu.get(i, i);
        if diag.abs() < Self::SUBSTITUTION_PIVOT_EPS {
            Err(LpError::SingularBasis)
        } else {
            Ok(diag)
        }
    }

    /// Solves the linear system `A x = b` using the stored factorization.
    ///
    /// Applies the row permutation (`P b`), then forward substitution with the unit
    /// lower-triangular `L` (`L y = P b`), then backward substitution with `U` (`U x = y`).
    ///
    /// # Errors
    ///
    /// * [`LpError::NumericalInstability`] if `b.len()` differs from the matrix dimension.
    /// * [`LpError::SingularBasis`] if a diagonal entry of `U` has magnitude below `1e-12`.
    pub fn solve(&self, b: &[f64]) -> Result<Vec<f64>, LpError> {
        let n = self.lu.rows;

        if b.len() != n {
            return Err(LpError::NumericalInstability);
        }

        // Apply permutation: y = P b
        let mut y: Vec<f64> = self.permutation.iter().map(|&row| b[row]).collect();

        // Forward substitution, in place: L y = P b (L has 1s on its diagonal)
        for i in 0..n {
            for j in 0..i {
                y[i] -= self.lu.get(i, j) * y[j];
            }
        }

        // Backward substitution: U x = y
        let mut x = vec![0.0; n];
        for i in (0..n).rev() {
            let mut sum = y[i];
            for j in (i + 1)..n {
                sum -= self.lu.get(i, j) * x[j];
            }
            x[i] = sum / self.checked_diag(i)?;
        }

        Ok(x)
    }

    /// Solves the transposed system `A^T x = b` using the stored factorization.
    ///
    /// From `P A = L U` it follows that `A^T = U^T L^T P`, so the solve runs in three
    /// steps: forward substitution `U^T y = b`, backward substitution `L^T z = y` (unit
    /// diagonal), and finally the inverse permutation `x = P^T z`.
    ///
    /// # Errors
    ///
    /// * [`LpError::NumericalInstability`] if `b.len()` differs from the matrix dimension.
    /// * [`LpError::SingularBasis`] if a diagonal entry of `U` has magnitude below `1e-12`.
    pub fn solve_transpose(&self, b: &[f64]) -> Result<Vec<f64>, LpError> {
        let n = self.lu.rows;

        if b.len() != n {
            return Err(LpError::NumericalInstability);
        }

        // Solve U^T y = b (forward substitution: U^T is lower triangular)
        let mut y = vec![0.0; n];
        for i in 0..n {
            let mut sum = b[i];
            for j in 0..i {
                sum -= self.lu.get(j, i) * y[j]; // U^T[i][j] == U[j][i]
            }
            y[i] = sum / self.checked_diag(i)?;
        }

        // Solve L^T z = y (backward substitution: L^T is upper triangular, unit diagonal)
        let mut z = vec![0.0; n];
        for i in (0..n).rev() {
            let mut sum = y[i];
            for j in (i + 1)..n {
                sum -= self.lu.get(j, i) * z[j]; // L^T[i][j] == L[j][i]
            }
            z[i] = sum;
        }

        // Apply inverse permutation: x = P^T z, i.e. x[permutation[i]] = z[i]
        let mut x = vec![0.0; n];
        for (&row, &value) in self.permutation.iter().zip(&z) {
            x[row] = value;
        }

        Ok(x)
    }

    /// Solves `A X = B` column by column, returning `X` with the same shape as `b_matrix`.
    ///
    /// # Errors
    ///
    /// * [`LpError::NumericalInstability`] if `b_matrix` does not have as many rows as `A`.
    /// * Any error returned by [`LuDecomposition::solve`] for an individual column.
    pub fn solve_multiple(&self, b_matrix: &Matrix) -> Result<Matrix, LpError> {
        if b_matrix.rows != self.lu.rows {
            return Err(LpError::NumericalInstability);
        }

        let mut result = Matrix::zeros(b_matrix.rows, b_matrix.cols);

        for col in 0..b_matrix.cols {
            let b_col = b_matrix.col(col);
            let x = self.solve(&b_col)?;

            for (row, &value) in x.iter().enumerate() {
                result.set(row, col, value);
            }
        }

        Ok(result)
    }

    /// Computes the determinant from the factorization.
    ///
    /// `det(A) = sign * product of the diagonal entries of U`. For a 0x0 matrix this is
    /// just `sign`.
    pub fn determinant(&self) -> f64 {
        (0..self.lu.rows).fold(f64::from(self.sign), |det, i| det * self.lu.get(i, i))
    }

    /// Extracts `L` as a standalone matrix (lower triangular with 1s on the diagonal).
    pub fn extract_l(&self) -> Matrix {
        let n = self.lu.rows;
        let mut l = Matrix::identity(n);

        for i in 0..n {
            for j in 0..i {
                l.set(i, j, self.lu.get(i, j));
            }
        }
        l
    }

    /// Extracts `U` as a standalone matrix (upper triangular, zeros below the diagonal).
    pub fn extract_u(&self) -> Matrix {
        let n = self.lu.rows;
        let mut u = Matrix::zeros(n, n);

        for i in 0..n {
            for j in i..n {
                u.set(i, j, self.lu.get(i, j));
            }
        }
        u
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lu_decompose_simple() {
        // Simple 2x2 matrix
        let a = Matrix::from_rows(vec![
            vec![2.0, 1.0],
            vec![4.0, 3.0],
        ]);
        
        let lu = LuDecomposition::decompose(&a, 1e-10).unwrap();
        
        // Verify we can solve a system
        let b = vec![5.0, 13.0]; 
        let x = lu.solve(&b).unwrap();
        
        // Check solution by verifying Ax = b
        let ax = a.mul_vec(&x);
        assert!((ax[0] - b[0]).abs() < 1e-10);
        assert!((ax[1] - b[1]).abs() < 1e-10);
    }
    
    #[test]
    fn test_lu_decompose_3x3() {
        let a = Matrix::from_rows(vec![
            vec![2.0, -1.0, 0.0],
            vec![-1.0, 2.0, -1.0],
            vec![0.0, -1.0, 2.0],
        ]);
        
        let lu = LuDecomposition::decompose(&a, 1e-10).unwrap();
        
        // Solve Ax = b where b = [1, 0, 1]
        let b = vec![1.0, 0.0, 1.0];
        let x = lu.solve(&b).unwrap();
        
        // Verify: Ax should equal b
        let ax = a.mul_vec(&x);
        for i in 0..3 {
            assert!((ax[i] - b[i]).abs() < 1e-10);
        }
    }
    
    #[test]
    fn test_lu_identity() {
        let identity = Matrix::identity(3);
        let lu = LuDecomposition::decompose(&identity, 1e-10).unwrap();
        
        let b = vec![1.0, 2.0, 3.0];
        let x = lu.solve(&b).unwrap();
        
        // Identity matrix: x should equal b
        assert_eq!(x, b);
    }
    
    #[test]
    fn test_lu_singular_matrix() {
        // Singular matrix (second row is 2x first row)
        let a = Matrix::from_rows(vec![
            vec![1.0, 2.0],
            vec![2.0, 4.0],
        ]);
        
        let result = LuDecomposition::decompose(&a, 1e-10);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), LpError::SingularBasis);
    }
    
    #[test]
    fn test_lu_ill_conditioned_identical_rows() {
        // Matrix with two identical rows: exactly singular (rank 2 < n = 3),
        // i.e. a linearly dependent system rather than merely ill-conditioned.
        let a = Matrix::from_rows(vec![
            vec![1.0, 2.0, 3.0],
            vec![1.0, 2.0, 3.0],  // Identical to first row
            vec![4.0, 5.0, 6.0],
        ]);
        
        // Should fail as singular (rank-deficient)
        let result = LuDecomposition::decompose(&a, 1e-10);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), LpError::SingularBasis);
    }
    
    #[test]
    fn test_lu_with_pivoting() {
        // Matrix that requires pivoting for stability
        let a = Matrix::from_rows(vec![
            vec![0.0001, 1.0],
            vec![1.0, 1.0],
        ]);
        
        let lu = LuDecomposition::decompose(&a, 1e-10).unwrap();
        
        let b = vec![1.0, 2.0];
        let x = lu.solve(&b).unwrap();
        
        // Verify solution
        let ax = a.mul_vec(&x);
        for i in 0..2 {
            assert!((ax[i] - b[i]).abs() < 1e-8);
        }
    }
    
    #[test]
    fn test_lu_determinant() {
        let a = Matrix::from_rows(vec![
            vec![2.0, 1.0],
            vec![4.0, 3.0],
        ]);
        
        let lu = LuDecomposition::decompose(&a, 1e-10).unwrap();
        let det = lu.determinant();
        
        // det([[2,1],[4,3]]) = 2*3 - 1*4 = 2
        assert!((det - 2.0).abs() < 1e-10);
    }
    
    #[test]
    fn test_lu_extract_l_u() {
        let a = Matrix::from_rows(vec![
            vec![4.0, 3.0],
            vec![6.0, 3.0],
        ]);
        
        let lu_decomp = LuDecomposition::decompose(&a, 1e-10).unwrap();
        let l = lu_decomp.extract_l();
        let u = lu_decomp.extract_u();
        
        // L should be lower triangular with 1s on diagonal
        assert_eq!(l.get(0, 0), 1.0);
        assert_eq!(l.get(1, 1), 1.0);
        assert_eq!(l.get(0, 1), 0.0); // Upper triangle is zero
        
        // U should be upper triangular
        assert_eq!(u.get(1, 0), 0.0); // Lower triangle is zero
        
        // PA = LU (where P is permutation)
        let lu_product = l.mul_matrix(&u);
        
        // Apply permutation to original matrix
        let mut pa = Matrix::zeros(a.rows, a.cols);
        for i in 0..a.rows {
            for j in 0..a.cols {
                pa.set(i, j, a.get(lu_decomp.permutation[i], j));
            }
        }
        
        // Check PA ≈ LU
        for i in 0..a.rows {
            for j in 0..a.cols {
                assert!((pa.get(i, j) - lu_product.get(i, j)).abs() < 1e-10);
            }
        }
    }
    
    #[test]
    fn test_lu_solve_multiple() {
        let a = Matrix::from_rows(vec![
            vec![2.0, 1.0],
            vec![4.0, 3.0],
        ]);
        
        let lu = LuDecomposition::decompose(&a, 1e-10).unwrap();
        
        // Solve for multiple right-hand sides
        let b = Matrix::from_rows(vec![
            vec![5.0, 3.0],
            vec![13.0, 7.0],
        ]);
        
        let x = lu.solve_multiple(&b).unwrap();
        
        // Verify AX = B
        let ax = a.mul_matrix(&x);
        for i in 0..2 {
            for j in 0..2 {
                assert!((ax.get(i, j) - b.get(i, j)).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_lu_solve_transpose() {
        // Test solving A^T x = b with a NON-symmetric matrix: with a symmetric A,
        // A^T == A and a solver that ignored the transpose would still pass.
        let a = Matrix::from_rows(vec![
            vec![2.0, 1.0, 0.0],
            vec![3.0, 4.0, 1.0],
            vec![0.0, 5.0, 6.0],
        ]);
        
        let lu = LuDecomposition::decompose(&a, 1e-10).unwrap();
        
        // Solve A^T x = b where b = [1, 2, 3].
        // A^T = [[2,3,0],[1,4,5],[0,1,6]]; the exact solution is x = [1.4, -0.6, 0.6]
        // (whereas A x = b would give [0.5, 0.0, 0.5]).
        let b = vec![1.0, 2.0, 3.0];
        let x = lu.solve_transpose(&b).unwrap();
        
        let expected = [1.4, -0.6, 0.6];
        for i in 0..3 {
            assert!(
                (x[i] - expected[i]).abs() < 1e-10,
                "x[{}] should be {}, got {}",
                i,
                expected[i],
                x[i]
            );
        }
        
        // Verify A^T x = b by computing the product
        let a_transpose = a.transpose();
        let result = a_transpose.mul_vec(&x);
        
        for i in 0..3 {
            assert!(
                (result[i] - b[i]).abs() < 1e-10,
                "A^T x = b verification failed at index {}: {} != {}",
                i,
                result[i],
                b[i]
            );
        }
    }

    #[test]
    fn test_lu_solve_transpose_simple() {
        // Simple 2x2 test
        // A = [[2, 1], [4, 3]]
        // A^T = [[2, 4], [1, 3]]
        let a = Matrix::from_rows(vec![
            vec![2.0, 1.0],
            vec![4.0, 3.0],
        ]);
        
        let lu = LuDecomposition::decompose(&a, 1e-10).unwrap();
        
        // Solve A^T x = [10, 7] where A^T = [[2, 4], [1, 3]]
        // Expected solution: x = [1, 2] because [[2,4],[1,3]] * [1,2] = [10, 7]
        let b2 = vec![10.0, 7.0];
        let x2 = lu.solve_transpose(&b2).unwrap();
        
        assert!((x2[0] - 1.0).abs() < 1e-10, "x[0] should be 1.0, got {}", x2[0]);
        assert!((x2[1] - 2.0).abs() < 1e-10, "x[1] should be 2.0, got {}", x2[1]);
    }

    #[test]
    fn test_lu_non_square_rejected() {
        // Only square matrices can be factorized.
        let a = Matrix::zeros(2, 3);
        let result = LuDecomposition::decompose(&a, 1e-10);
        assert_eq!(result.unwrap_err(), LpError::NumericalInstability);
    }

    #[test]
    fn test_lu_rhs_dimension_mismatch() {
        let a = Matrix::from_rows(vec![
            vec![2.0, 1.0],
            vec![4.0, 3.0],
        ]);
        let lu = LuDecomposition::decompose(&a, 1e-10).unwrap();

        // Right-hand sides whose length does not match the 2x2 system are rejected.
        assert_eq!(lu.solve(&[1.0]).unwrap_err(), LpError::NumericalInstability);
        assert_eq!(
            lu.solve_transpose(&[1.0, 2.0, 3.0]).unwrap_err(),
            LpError::NumericalInstability
        );
        assert_eq!(
            lu.solve_multiple(&Matrix::zeros(3, 1)).unwrap_err(),
            LpError::NumericalInstability
        );
    }
}