differential-equations 0.5.3

A Rust library for solving differential equations.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
//! LU decomposition algorithms.
//!
//! This module provides LU decomposition with partial pivoting for real and complex matrices.
//!
//! # References
//! - Hairer, E., & Wanner, G. (1996). Solving Ordinary Differential Equations II:
//!   Stiff and Differential-Algebraic Problems. Springer.

use crate::{
    linalg::{Matrix, error::LinalgError},
    traits::Real,
};

/// LU decomposition with partial pivoting
///
/// Factorizes a square matrix `A` in place by Gaussian elimination with partial (row)
/// pivoting, mathematically `PA = LU` where:
/// - P is a permutation matrix (represented by the pivot indices in `ip`)
/// - L is unit lower triangular (with implicit unit diagonal)
/// - U is upper triangular
///
/// # Arguments
/// * `a` - Square matrix to decompose (overwritten with the factors, see *Storage*)
/// * `ip` - Pivot index slice (must have length equal to the matrix size)
///
/// # Returns
/// * `Ok(())` - Decomposition successful (an empty 0x0 matrix is accepted trivially)
/// * `Err(LinalgError::Singular { step })` - A zero pivot was met at elimination step
///   `step` (1-indexed); `step == n` means the last diagonal entry of U is zero
/// * `Err(LinalgError::BadInput { .. })` / `Err(LinalgError::PivotSizeMismatch { .. })` -
///   invalid arguments, see *Errors*
///
/// # Algorithm
/// The decomposition proceeds in n-1 stages. At each stage k:
/// 1. **Pivoting**: find the largest-magnitude entry in column k on or below the diagonal
///    (on ties the topmost row wins)
/// 2. **Row exchange**: swap rows k and m in columns k..n to bring the pivot to the diagonal
/// 3. **Elimination**: store the negated multipliers `-a(i,k) / pivot` below the pivot
/// 4. **Update**: add those multiples of the pivot row to the remaining submatrix
///
/// # Pivot indices
/// `ip[k]` (0-indexed) is the row that was swapped with row k during stage k
/// (`ip[k] == k` when no interchange was needed). On success the last entry is set to
/// `ip[n - 1] = n - 1`, since the final stage performs no elimination.
///
/// # Storage
/// After a successful decomposition, `a` contains:
/// - Upper triangle and diagonal: the U factor
/// - Strict lower triangle: the *negated* elimination multipliers. A row interchange at
///   stage k touches only columns k..n, never the multipliers already stored in earlier
///   columns, so these entries are kept in elimination order (as in Hairer's Fortran `DEC`)
///   rather than forming the L of `PA = LU` directly; they are meant to be replayed
///   together with `ip`.
///
/// If a `Singular` error is returned, `a` and `ip` may already be partially modified.
///
/// # Examples
/// ```rust,ignore
/// use differential_equations::linalg::{Matrix, lu::lu_decomp};
///
/// let mut a = Matrix::from_vec(2, 2, vec![2.0, 1.0, 1.0, 1.0]);
/// let mut ip = [0; 2];
///
/// match lu_decomp(&mut a, &mut ip) {
///     Ok(()) => println!("Decomposition successful"),
///     Err(err) => println!("Decomposition failed: {}", err),
/// }
/// ```
///
/// # Errors
/// Returns [`LinalgError::BadInput`] if the matrix is not square,
/// [`LinalgError::PivotSizeMismatch`] if `ip.len()` differs from the matrix size, and
/// [`LinalgError::Singular`] if a zero pivot is encountered.
pub fn lu_decomp<T: Real>(a: &mut Matrix<T>, ip: &mut [usize]) -> Result<(), LinalgError> {
    let n = a.nrows();
    if n != a.ncols() {
        return Err(LinalgError::BadInput {
            message: format!("Matrix is not square: {}x{}", n, a.ncols()),
        });
    }

    if ip.len() != n {
        return Err(LinalgError::PivotSizeMismatch {
            expected: n,
            actual: ip.len(),
        });
    }

    // Nothing to factorize; returning here also keeps `n - 1` below from underflowing.
    if n == 0 {
        return Ok(());
    }

    for k in 0..n - 1 {
        let kp1 = k + 1;

        // Partial pivoting: strict `>` keeps the topmost row among equal magnitudes.
        let mut m = k;
        let mut max_val = a[(k, k)].abs();
        for i in kp1..n {
            let val = a[(i, k)].abs();
            if val > max_val {
                max_val = val;
                m = i;
            }
        }
        ip[k] = m;

        // Pivot value, read before any interchange.
        let pivot = a[(m, k)];
        if pivot == T::zero() {
            return Err(LinalgError::Singular { step: k + 1 });
        }

        // Bring the pivot to the diagonal (column k only; columns > k are swapped below).
        if m != k {
            a[(m, k)] = a[(k, k)];
            a[(k, k)] = pivot;
        }

        // Store the negated multipliers -a(i,k)/pivot below the diagonal.
        let t = T::one() / pivot;
        for i in kp1..n {
            a[(i, k)] = -a[(i, k)] * t;
        }

        for j in kp1..n {
            // Pivot-row entry of column j (row m before the interchange).
            let tj = a[(m, j)];
            if m != k {
                a[(m, j)] = a[(k, j)];
                a[(k, j)] = tj;
            }

            // A zero pivot-row entry leaves column j unchanged; skip the pass.
            if tj != T::zero() {
                for i in kp1..n {
                    a[(i, j)] += a[(i, k)] * tj;
                }
            }
        }
    }

    // The last diagonal entry is never used as a pivot inside the loop; check it here.
    if a[(n - 1, n - 1)] == T::zero() {
        return Err(LinalgError::Singular { step: n });
    }

    // The final stage performs no interchange: record it explicitly so no entry of `ip`
    // keeps a stale caller-provided value.
    ip[n - 1] = n - 1;

    Ok(())
}

/// Complex LU decomposition with partial pivoting
///
/// Performs LU decomposition with partial pivoting on a complex matrix stored as separate
/// real and imaginary parts. Mathematically it factorizes `P(AR + i*AI) = LU` where:
/// - P is a permutation matrix (represented by pivot indices in `ip`)
/// - L is unit lower triangular (with implicit unit diagonal)
/// - U is upper triangular
///
/// # Arguments
/// * `ar` - Real part of the square matrix to decompose (overwritten in place)
/// * `ai` - Imaginary part of the square matrix to decompose (overwritten in place)
/// * `ip` - Pivot index slice (must have length equal to the matrix size)
///
/// # Returns
/// * `Ok(())` - Decomposition successful (an empty 0x0 matrix is accepted trivially)
/// * `Err(LinalgError::Singular { step })` - A zero pivot was met at elimination step
///   `step` (1-indexed); `step == n` means the last diagonal entry of U is zero
/// * `Err(LinalgError::BadInput { .. })` / `Err(LinalgError::PivotSizeMismatch { .. })` -
///   invalid arguments, see *Errors*
///
/// # Algorithm
/// Same scheme as the real LU decomposition, with complex arithmetic:
/// 1. **Pivoting**: find the largest-magnitude complex entry in column k on or below the
///    diagonal (on ties the topmost row wins)
/// 2. **Row exchange**: swap rows k and m in columns k..n to bring the pivot to the diagonal
/// 3. **Elimination**: store the negated complex multipliers `-a(i,k) / pivot` below the pivot
/// 4. **Update**: apply the complex elimination to the remaining submatrix
///
/// The magnitude of a complex number (a + bi) is computed as |a| + |b| for efficiency.
/// All complex operations are performed using separate real and imaginary components.
///
/// # Pivot indices and storage
/// `ip[k]` (0-indexed) is the row swapped with row k during stage k; on success
/// `ip[n - 1] = n - 1`. Upper triangle and diagonal of (`ar`, `ai`) hold U; the strict lower
/// triangle holds the *negated* multipliers in elimination order (earlier columns are not
/// permuted by later interchanges). If a `Singular` error is returned, the inputs may
/// already be partially modified.
///
/// # Mathematical Background
/// Complex arithmetic is handled explicitly:
/// - Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i
/// - Complex division: (a + bi)/(c + di) = [(ac + bd) + (bc - ad)i]/(c² + d²)
///
/// # Examples
/// ```rust,ignore
/// use differential_equations::linalg::{Matrix, lu_decomp_complex};
///
/// let mut ar = Matrix::from_vec(2, 2, vec![1.0, 0.0, 0.0, 1.0]);
/// let mut ai = Matrix::from_vec(2, 2, vec![0.0, 1.0, 1.0, 0.0]);
/// let mut ip = [0; 2];
///
/// match lu_decomp_complex(&mut ar, &mut ai, &mut ip) {
///     Ok(()) => println!("Complex decomposition successful"),
///     Err(err) => println!("Complex decomposition failed: {}", err),
/// }
/// ```
///
/// # Errors
/// Returns [`LinalgError::BadInput`] if `ar` and `ai` are not square matrices of the same
/// size, [`LinalgError::PivotSizeMismatch`] if `ip.len()` differs from the matrix size, and
/// [`LinalgError::Singular`] if a zero pivot is encountered.
pub fn lu_decomp_complex<T: Real>(
    ar: &mut Matrix<T>,
    ai: &mut Matrix<T>,
    ip: &mut [usize],
) -> Result<(), LinalgError> {
    let n = ar.nrows();
    if n != ar.ncols() || n != ai.nrows() || n != ai.ncols() {
        return Err(LinalgError::BadInput {
            message: format!(
                "Matrix dimensions inconsistent: {}x{}, {}x{}",
                ar.nrows(),
                ar.ncols(),
                ai.nrows(),
                ai.ncols()
            ),
        });
    }

    if ip.len() != n {
        return Err(LinalgError::PivotSizeMismatch {
            expected: n,
            actual: ip.len(),
        });
    }

    // Nothing to factorize; returning here also keeps `n - 1` below from underflowing.
    if n == 0 {
        return Ok(());
    }

    for k in 0..n - 1 {
        let kp1 = k + 1;

        // Pivot search with the cheap |re| + |im| magnitude; strict `>` keeps the topmost row.
        let mut m = k;
        let mut max_val = ar[(k, k)].abs() + ai[(k, k)].abs();
        for i in kp1..n {
            let val = ar[(i, k)].abs() + ai[(i, k)].abs();
            if val > max_val {
                max_val = val;
                m = i;
            }
        }
        ip[k] = m;

        // Pivot value, read before any interchange.
        let piv_r = ar[(m, k)];
        let piv_i = ai[(m, k)];
        if piv_r.abs() + piv_i.abs() == T::zero() {
            return Err(LinalgError::Singular { step: k + 1 });
        }

        // Bring the pivot to the diagonal (column k only; columns > k are swapped below).
        if m != k {
            ar[(m, k)] = ar[(k, k)];
            ai[(m, k)] = ai[(k, k)];
            ar[(k, k)] = piv_r;
            ai[(k, k)] = piv_i;
        }

        // Complex reciprocal: 1/(piv_r + i*piv_i) = (piv_r - i*piv_i) / (piv_r² + piv_i²).
        let den = piv_r * piv_r + piv_i * piv_i;
        let tr = piv_r / den;
        let ti = -piv_i / den;

        // Store the negated multipliers -a(i,k)/pivot below the diagonal.
        for i in kp1..n {
            let prod_r = ar[(i, k)] * tr - ai[(i, k)] * ti;
            let prod_i = ai[(i, k)] * tr + ar[(i, k)] * ti;
            ar[(i, k)] = -prod_r;
            ai[(i, k)] = -prod_i;
        }

        for j in kp1..n {
            // Pivot-row entry of column j (row m before the interchange).
            let mr = ar[(m, j)];
            let mi = ai[(m, j)];
            if m != k {
                ar[(m, j)] = ar[(k, j)];
                ai[(m, j)] = ai[(k, j)];
                ar[(k, j)] = mr;
                ai[(k, j)] = mi;
            }

            // A zero pivot-row entry leaves column j unchanged.
            if mr.abs() + mi.abs() == T::zero() {
                continue;
            }

            // Specialised loops skip the multiplications by zero for purely real or
            // purely imaginary multipliers.
            if mi == T::zero() {
                for i in kp1..n {
                    let prod_r = ar[(i, k)] * mr;
                    let prod_i = ai[(i, k)] * mr;
                    ar[(i, j)] += prod_r;
                    ai[(i, j)] += prod_i;
                }
            } else if mr == T::zero() {
                for i in kp1..n {
                    let prod_r = -ai[(i, k)] * mi;
                    let prod_i = ar[(i, k)] * mi;
                    ar[(i, j)] += prod_r;
                    ai[(i, j)] += prod_i;
                }
            } else {
                for i in kp1..n {
                    let prod_r = ar[(i, k)] * mr - ai[(i, k)] * mi;
                    let prod_i = ai[(i, k)] * mr + ar[(i, k)] * mi;
                    ar[(i, j)] += prod_r;
                    ai[(i, j)] += prod_i;
                }
            }
        }
    }

    // The last diagonal entry is never used as a pivot inside the loop; check it here.
    if ar[(n - 1, n - 1)].abs() + ai[(n - 1, n - 1)].abs() == T::zero() {
        return Err(LinalgError::Singular { step: n });
    }

    // The final stage performs no interchange: record it explicitly so no entry of `ip`
    // keeps a stale caller-provided value.
    ip[n - 1] = n - 1;

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dec_simple() {
        // Test LU decomposition of a simple 2x2 matrix
        let mut a = Matrix::from_vec(2, 2, vec![2.0_f64, 1.0, 4.0, 3.0]);
        let mut ip = [0; 2];

        let result = lu_decomp(&mut a, &mut ip);
        assert!(result.is_ok());

        // The matrix should be factorized in-place
        // We can verify that the diagonal elements are non-zero
        assert!(a[(0, 0)].abs() > 1e-10);
        assert!(a[(1, 1)].abs() > 1e-10);
    }

    #[test]
    fn test_dec_exact_factors() {
        // Symmetric input, so the expectations do not depend on the element order used by
        // `Matrix::from_vec`. A = [[2, 1], [1, 3]] needs no row interchange (|2| > |1|).
        let mut a = Matrix::from_vec(2, 2, vec![2.0_f64, 1.0, 1.0, 3.0]);
        let mut ip = [0; 2];

        assert!(lu_decomp(&mut a, &mut ip).is_ok());
        assert_eq!(ip[0], 0);
        // U = [[2, 1], [0, 2.5]]; below the diagonal sits the negated multiplier -1/2.
        assert_eq!(a[(0, 0)], 2.0);
        assert_eq!(a[(0, 1)], 1.0);
        assert_eq!(a[(1, 0)], -0.5);
        assert_eq!(a[(1, 1)], 2.5);
    }

    #[test]
    fn test_dec_pivoting() {
        // A = [[1, 2], [2, 1]]: the larger entry of column 0 is in row 1, forcing a swap.
        let mut a = Matrix::from_vec(2, 2, vec![1.0_f64, 2.0, 2.0, 1.0]);
        let mut ip = [0; 2];

        assert!(lu_decomp(&mut a, &mut ip).is_ok());
        assert_eq!(ip[0], 1);
        // PA = [[2, 1], [1, 2]] -> U = [[2, 1], [0, 1.5]], stored multiplier -1/2.
        assert_eq!(a[(0, 0)], 2.0);
        assert_eq!(a[(0, 1)], 1.0);
        assert_eq!(a[(1, 0)], -0.5);
        assert_eq!(a[(1, 1)], 1.5);
    }

    #[test]
    fn test_dec_singular() {
        // Test with a singular matrix
        let mut a = Matrix::from_vec(2, 2, vec![1.0_f64, 0.0, 0.0, 0.0]);
        let mut ip = [0; 2];

        let result = lu_decomp(&mut a, &mut ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), LinalgError::Singular { step: 2 });
    }

    #[test]
    fn test_dec_pivot_size_mismatch() {
        // The pivot slice must have exactly one entry per row.
        let mut a = Matrix::from_vec(2, 2, vec![1.0_f64, 0.0, 0.0, 1.0]);
        let mut ip = [0; 3];

        let result = lu_decomp(&mut a, &mut ip);
        assert_eq!(
            result.unwrap_err(),
            LinalgError::PivotSizeMismatch {
                expected: 2,
                actual: 3
            }
        );
    }

    #[test]
    fn test_dec_1x1() {
        // Test with a 1x1 matrix
        let mut a = Matrix::from_vec(1, 1, vec![5.0_f64]);
        let mut ip = [0; 1];

        let result = lu_decomp(&mut a, &mut ip);
        assert!(result.is_ok());
        assert_eq!(ip[0], 0);
    }

    #[test]
    fn test_dec_1x1_singular() {
        // Test with a singular 1x1 matrix
        let mut a = Matrix::from_vec(1, 1, vec![0.0_f64]);
        let mut ip = [0; 1];

        let result = lu_decomp(&mut a, &mut ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), LinalgError::Singular { step: 1 });
    }

    #[test]
    fn test_decc_simple() {
        // A = [[1, i], [i, 1]] (both parts symmetric), det(A) = 2.
        let mut ar = Matrix::from_vec(2, 2, vec![1.0_f64, 0.0, 0.0, 1.0]);
        let mut ai = Matrix::from_vec(2, 2, vec![0.0, 1.0, 1.0, 0.0]);
        let mut ip = [0; 2];

        let result = lu_decomp_complex(&mut ar, &mut ai, &mut ip);
        assert!(result.is_ok());

        // Verify that the diagonal elements have non-zero magnitude
        let diag0_mag = ar[(0, 0)].abs() + ai[(0, 0)].abs();
        let diag1_mag = ar[(1, 1)].abs() + ai[(1, 1)].abs();
        assert!(diag0_mag > 1e-10);
        assert!(diag1_mag > 1e-10);

        // Exact factors: no interchange (tie keeps row 0), negated multiplier -i,
        // and U(1,1) = 1 - i*i = 2.
        assert_eq!(ip[0], 0);
        assert_eq!(ar[(1, 0)], 0.0);
        assert_eq!(ai[(1, 0)], -1.0);
        assert_eq!(ar[(1, 1)], 2.0);
        assert_eq!(ai[(1, 1)], 0.0);
    }

    #[test]
    fn test_decc_pivoting() {
        // A = [[0, 1], [1, 0]] (purely real): column 0 forces a row interchange.
        let mut ar = Matrix::from_vec(2, 2, vec![0.0_f64, 1.0, 1.0, 0.0]);
        let mut ai = Matrix::from_vec(2, 2, vec![0.0_f64, 0.0, 0.0, 0.0]);
        let mut ip = [0; 2];

        assert!(lu_decomp_complex(&mut ar, &mut ai, &mut ip).is_ok());
        assert_eq!(ip[0], 1);
        assert_eq!(ar[(0, 0)], 1.0);
        assert_eq!(ar[(0, 1)], 0.0);
        assert_eq!(ar[(1, 1)], 1.0);
    }

    #[test]
    fn test_decc_singular() {
        // Test with a singular complex matrix
        let mut ar = Matrix::from_vec(2, 2, vec![1.0_f64, 1.0, 1.0, 1.0]);
        let mut ai = Matrix::from_vec(2, 2, vec![0.0_f64, 0.0, 0.0, 0.0]);
        let mut ip = [0; 2];

        let result = lu_decomp_complex(&mut ar, &mut ai, &mut ip);
        assert!(result.is_err());
        // The rank deficiency only shows up in the final diagonal entry.
        assert_eq!(result.unwrap_err(), LinalgError::Singular { step: 2 });
    }

    #[test]
    fn test_decc_1x1() {
        // Test with a 1x1 complex matrix
        let mut ar = Matrix::from_vec(1, 1, vec![3.0_f64]);
        let mut ai = Matrix::from_vec(1, 1, vec![4.0_f64]); // 3 + 4i
        let mut ip = [0; 1];

        let result = lu_decomp_complex(&mut ar, &mut ai, &mut ip);
        assert!(result.is_ok());
        assert_eq!(ip[0], 0);
    }

    #[test]
    fn test_decc_1x1_singular() {
        // Test with a singular 1x1 complex matrix
        let mut ar = Matrix::from_vec(1, 1, vec![0.0_f64]);
        let mut ai = Matrix::from_vec(1, 1, vec![0.0_f64]);
        let mut ip = [0; 1];

        let result = lu_decomp_complex(&mut ar, &mut ai, &mut ip);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), LinalgError::Singular { step: 1 });
    }
}