iterative_solvers/
utils.rs

1//! Utility functions for creating matrices and vectors.
2
3use nalgebra::{DMatrix, DVector};
4
5use crate::{IterSolverError, IterSolverResult};
6
7/// Creates a diagonal matrix with the given data placed on a specified diagonal.
8///
9/// This function constructs a matrix where the provided data is placed on a diagonal
10/// that can be offset from the main diagonal. The resulting matrix size is determined
11/// by the length of the data and the absolute value of the offset.
12///
13/// # Arguments
14///
15/// * `data` - A slice of f64 values to be placed on the diagonal
16/// * `offset` - The diagonal offset:
17///   - `0`: Main diagonal
18///   - Positive: Above the main diagonal (super-diagonal)
19///   - Negative: Below the main diagonal (sub-diagonal)
20///
21/// # Returns
22///
23/// A `DMatrix<f64>` containing the diagonal matrix. If `data` is empty, returns
24/// a 0×0 matrix.
25///
26/// # Examples
27///
28/// ```rust
29/// use nalgebra::DMatrix;
30/// use iterative_solvers::utils::diagm;
31///
32/// // Main diagonal
33/// let data = vec![1.0, 2.0, 3.0];
34/// let mat = diagm(&data, 0);
35/// // Creates:
36/// // [1.0, 0.0, 0.0]
37/// // [0.0, 2.0, 0.0]
38/// // [0.0, 0.0, 3.0]
39///
40/// // Super-diagonal (offset = 1)
41/// let mat = diagm(&data, 1);
42/// // Creates:
43/// // [0.0, 1.0, 0.0, 0.0]
44/// // [0.0, 0.0, 2.0, 0.0]
45/// // [0.0, 0.0, 0.0, 3.0]
46/// // [0.0, 0.0, 0.0, 0.0]
47/// ```
48pub fn diagm(data: &[f64], offset: i32) -> DMatrix<f64> {
49    if data.is_empty() {
50        return DMatrix::zeros(0, 0);
51    }
52    match offset {
53        0 => DMatrix::from_diagonal(&DVector::from_column_slice(data)),
54        offset => {
55            let offset_usize = offset.unsigned_abs() as usize;
56            let n = data.len() + offset_usize;
57            let mut mat = DMatrix::zeros(n, n);
58
59            unsafe {
60                if offset > 0 {
61                    for (idx, &val) in data.iter().enumerate() {
62                        *mat.get_unchecked_mut((idx, idx + offset_usize)) = val;
63                    }
64                } else {
65                    for (idx, &val) in data.iter().enumerate() {
66                        *mat.get_unchecked_mut((idx + offset_usize, idx)) = val;
67                    }
68                }
69            }
70            mat
71        }
72    }
73}
74
75/// Creates a tridiagonal matrix from diagonal, lower diagonal, and upper diagonal vectors.
76///
77/// This function constructs a tridiagonal matrix where:
78/// - The main diagonal contains elements from the `diagonal` vector
79/// - The sub-diagonal (below main) contains elements from the `lower` vector
80/// - The super-diagonal (above main) contains elements from the `upper` vector
81///
82/// # Arguments
83///
84/// * `diagonal` - A slice containing the main diagonal elements
85/// * `lower` - A slice containing the lower diagonal elements (sub-diagonal)
86/// * `upper` - A slice containing the upper diagonal elements (super-diagonal)
87///
88/// # Returns
89///
90/// * `Ok(DMatrix<f64>)` - The resulting tridiagonal matrix
91/// * `Err(IterSolverError::DimensionError)` - If the vector dimensions don't match the required pattern
92///
93/// # Dimension Requirements
94///
95/// For a valid tridiagonal matrix:
96/// - `diagonal.len()` must equal `lower.len() + 1`
97/// - `lower.len()` must equal `upper.len()`
98///
99/// This is because an n×n tridiagonal matrix has:
100/// - n diagonal elements
101/// - (n-1) sub-diagonal elements
102/// - (n-1) super-diagonal elements
103///
104/// # Examples
105///
106/// ```rust
107/// use iterative_solvers::utils::tridiagonal;
108///
109/// let diagonal = vec![2.0, 3.0, 4.0];
110/// let lower = vec![1.0, 1.0];
111/// let upper = vec![1.0, 1.0];
112///
113/// let result = tridiagonal(&diagonal, &lower, &upper).unwrap();
114/// // Creates:
115/// // [2.0, 1.0, 0.0]
116/// // [1.0, 3.0, 1.0]
117/// // [0.0, 1.0, 4.0]
118/// ```
119///
120/// # Errors
121///
122/// Returns `IterSolverError::DimensionError` if the input vectors have incompatible dimensions.
123pub fn tridiagonal(
124    diagonal: &[f64],
125    lower: &[f64],
126    upper: &[f64],
127) -> IterSolverResult<DMatrix<f64>> {
128    if diagonal.len() != lower.len() + 1 || lower.len() != upper.len() {
129        return Err(IterSolverError::DimensionError(format!(
130            "For tridiagonal matrix, the length of `diagonal` {}, the length of `lower` {} and `upper` {} do not match",
131            diagonal.len(),
132            lower.len(),
133            upper.len()
134        )));
135    }
136    Ok(diagm(diagonal, 0) + diagm(lower, -1) + diagm(upper, 1))
137}
138
139/// Creates a symmetric tridiagonal matrix from diagonal and sub-diagonal vectors.
140///
141/// This function constructs a symmetric tridiagonal matrix where the sub-diagonal
142/// and super-diagonal elements are identical. This is a common structure in numerical
143/// methods, particularly for solving differential equations and eigenvalue problems.
144///
145/// # Arguments
146///
147/// * `diagonal` - A slice containing the main diagonal elements
148/// * `sub_diagonal` - A slice containing the sub-diagonal elements, which will be
149///   mirrored to create the super-diagonal
150///
151/// # Returns
152///
153/// * `Ok(DMatrix<f64>)` - The resulting symmetric tridiagonal matrix
154/// * `Err(IterSolverError::DimensionError)` - If the vector dimensions don't match the required pattern
155///
156/// # Dimension Requirements
157///
158/// For a valid symmetric tridiagonal matrix:
159/// - `diagonal.len()` must equal `sub_diagonal.len() + 1`
160///
161/// # Examples
162///
163/// ```rust
164/// use iterative_solvers::utils::symmetric_tridiagonal;
165///
166/// let diagonal = vec![2.0, 3.0, 4.0];
167/// let sub_diagonal = vec![1.0, 1.5];
168///
169/// let result = symmetric_tridiagonal(&diagonal, &sub_diagonal).unwrap();
170/// // Creates:
171/// // [2.0, 1.0, 0.0]
172/// // [1.0, 3.0, 1.5]
173/// // [0.0, 1.5, 4.0]
174/// ```
175///
176/// # Errors
177///
178/// Returns `IterSolverError::DimensionError` if the input vectors have incompatible dimensions.
179///
180/// # Note
181///
182/// This function internally calls `tridiagonal(diagonal, sub_diagonal, sub_diagonal)`,
183/// ensuring that the lower and upper diagonals are identical.
184pub fn symmetric_tridiagonal(
185    diagonal: &[f64],
186    sub_diagonal: &[f64],
187) -> IterSolverResult<DMatrix<f64>> {
188    tridiagonal(diagonal, sub_diagonal, sub_diagonal)
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `diagm` on the main diagonal and on positive/negative offsets.
    #[test]
    fn test_diag() {
        let data = vec![1.0, 2.0, 3.0];

        let main = diagm(&data, 0);
        assert_eq!(
            main,
            DMatrix::from_row_slice(3, 3, &[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0])
        );

        for offset in [1, -1, 2, -2] {
            let mat = diagm(&data, offset);
            let shift = offset.unsigned_abs() as usize;
            let n = data.len() + shift;
            assert_eq!(mat.shape(), (n, n), "shape for offset {offset}");

            for (i, &v) in data.iter().enumerate() {
                let (r, c) = if offset > 0 { (i, i + shift) } else { (i + shift, i) };
                assert_eq!(mat[(r, c)], v, "entry {i} for offset {offset}");
            }

            // Only the provided values may be non-zero.
            let nonzero = mat.iter().filter(|&&x| x != 0.0).count();
            assert_eq!(nonzero, data.len(), "stray entries for offset {offset}");
        }
    }

    /// Empty data yields a 0×0 matrix for any offset.
    #[test]
    fn test_diag_empty() {
        assert_eq!(diagm(&[], 0).shape(), (0, 0));
        assert_eq!(diagm(&[], 3).shape(), (0, 0));
        assert_eq!(diagm(&[], -3).shape(), (0, 0));
    }

    /// Distinct lower/upper values make the orientation of each off-diagonal explicit.
    #[test]
    fn test_tridiagonal() {
        let mat = tridiagonal(&[2.0, 3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]).unwrap();
        assert_eq!(
            mat,
            DMatrix::from_row_slice(3, 3, &[2.0, 7.0, 0.0, 5.0, 3.0, 8.0, 0.0, 6.0, 4.0])
        );
    }

    /// Mismatched lengths are reported as errors rather than panicking.
    #[test]
    fn test_tridiagonal_dimension_error() {
        assert!(tridiagonal(&[1.0, 2.0], &[1.0], &[]).is_err());
        assert!(tridiagonal(&[1.0, 2.0], &[], &[]).is_err());
        assert!(tridiagonal(&[1.0], &[1.0], &[1.0]).is_err());
    }

    /// The symmetric variant mirrors the sub-diagonal and rejects bad lengths.
    #[test]
    fn test_symmetric_tridiagonal() {
        let mat = symmetric_tridiagonal(&[2.0, 3.0, 4.0], &[1.0, 1.5]).unwrap();
        assert_eq!(mat, mat.transpose());
        assert_eq!(mat[(1, 2)], 1.5);
        assert_eq!(mat[(1, 0)], 1.0);
        assert!(symmetric_tridiagonal(&[2.0, 3.0], &[1.0, 1.5]).is_err());
    }
}