algebra_sparse/
ops.rs

1// Copyright (C) 2020-2025 algebra-sparse authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Mul;
16
17use na::{DMatrix, DMatrixView, DVector, DVectorView, DVectorViewMut};
18
19use crate::csv::CsVecRef;
20use crate::{
21    CscMatrixView, CsrMatrix, CsrMatrixView, CsrMatrixViewMethods, DiagonalBlockMatrixView, Real,
22};
23
24/// Multiply a sparse matrix `a` with a block diagonal matrix `b` and store the result in `o`.
25///
26/// `C = A * B`
27pub(crate) fn mul_csr_bd_to<T>(
28    a: CsrMatrixView<T>,
29    b: DiagonalBlockMatrixView<T>,
30    o: &mut CsrMatrix<T>,
31) where
32    T: Real,
33{
34    assert_eq!(a.ncols(), b.nrows());
35    assert_eq!(b.ncols(), o.ncols());
36    assert_eq!(o.nrows(), 0);
37
38    for i in 0..a.nrows() {
39        let mut or = o.new_row_builder(T::zero_threshold());
40        let ar = a.get_row(i);
41        let mut a_col_start = 0;
42        for bindex in 0..b.num_blocks() {
43            let range = b.get_block_row_range(bindex);
44            let block = b.view_block(bindex);
45            let mut a_n = 0;
46            for col in ar.indices().iter().skip(a_col_start) {
47                if *col >= range.end {
48                    break;
49                }
50                a_n += 1;
51            }
52            for j in 0..block.ncols() {
53                let mut o_ij = T::zero();
54                let col = block.column(j);
55                for (k, a_ik) in ar.iter().skip(a_col_start).take(a_n) {
56                    o_ij += a_ik * col[k - range.start];
57                }
58
59                if o_ij.abs() > T::zero_threshold() {
60                    or.push(range.start + j, o_ij);
61                }
62            }
63            a_col_start += a_n;
64        }
65    }
66}
67
68/// Csr(O) = Csr(a) * Csc(b).
69///
70/// The o is assumed empty before pass to this function.
71pub(crate) fn mul_csr_csc_to<T: Real>(
72    a: CsrMatrixView<T>,
73    b: CscMatrixView<T>,
74    o: &mut CsrMatrix<T>,
75) {
76    assert_eq!(o.nrows(), 0);
77    assert_eq!(a.ncols(), b.nrows());
78
79    for i in 0..a.nrows() {
80        let mut or = o.new_row_builder(T::zero_threshold());
81        let ar = a.get_row(i);
82
83        for j in 0..b.ncols() {
84            let bc = b.get_col(j);
85
86            let o_ij = dot_csvec(ar, bc);
87            if o_ij.abs() > T::zero_threshold() {
88                or.push(j, o_ij);
89            }
90        }
91    }
92}
93
94/// `DenseVec(o) = CsrMat(a) x DenseVec(b)`
95#[inline]
96pub(crate) fn mul_csr_dvec_to_dvec<T: Real>(
97    a: CsrMatrixView<T>,
98    b: DVectorView<T>,
99    mut o: DVectorViewMut<T>,
100) {
101    assert_eq!(a.nrows(), o.len());
102    assert!(!b.is_empty());
103
104    for i in 0..a.nrows() {
105        let ar = a.get_row(i);
106        let mut o_i = T::zero();
107
108        for (j, a_ij) in ar.iter() {
109            o_i += a_ij * b[j];
110        }
111
112        o[i] = o_i;
113    }
114}
115
116/// `DenseVec(o) = CscMat(a) x DenseVec(b)`
117pub(crate) fn mul_csc_dvec_to_dvec<T>(
118    a: CscMatrixView<T>,
119    b: DVectorView<T>,
120    mut o: DVectorViewMut<T>,
121) where
122    T: Real,
123{
124    assert_eq!(a.ncols(), b.len());
125    assert_eq!(o.len(), a.nrows());
126    assert!(!b.is_empty());
127    debug_assert!(o.iter().all(|v| v.abs() < T::zero_threshold()));
128
129    for j in 0..a.ncols() {
130        let aj = a.get_col(j);
131
132        for (i, a_ij) in aj.iter() {
133            o[i] += a_ij * b[j];
134        }
135    }
136}
137
138pub(crate) fn mul_csr_dmat_to_csr<T>(a: CsrMatrixView<T>, b: DMatrixView<T>, o: &mut CsrMatrix<T>)
139where
140    T: Real,
141{
142    assert_eq!(a.ncols(), b.nrows());
143    assert_eq!(o.ncols(), b.ncols());
144    assert_eq!(o.nrows(), 0);
145
146    for i in 0..a.nrows() {
147        let ar = a.get_row(i);
148        let mut or = o.new_row_builder(T::zero_threshold());
149        for j in 0..b.ncols() {
150            let bc = b.column(j);
151            let o_ij = dot_csv_dv(ar, bc);
152            if o_ij.abs() > T::zero_threshold() {
153                or.push(j, o_ij);
154            }
155        }
156    }
157}
158
159/// Compute the dot product of two sparse vectors `a` and `b`.
160///
161/// # Note
162///
163/// This method need CsVec's element stored in ascending order.
164pub(crate) fn dot_csvec<T: Real>(a: CsVecRef<T>, b: CsVecRef<T>) -> T {
165    let mut res = T::zero();
166
167    let col_a = a.indices();
168    let col_b = b.indices();
169    let values_a = a.values();
170    let values_b = b.values();
171    let mut ia = 0;
172    let mut ib = 0;
173
174    unsafe {
175        while ia < col_a.len() && ib < col_b.len() {
176            let ca = *col_a.get_unchecked(ia);
177            let cb = *col_b.get_unchecked(ib);
178            match ca.cmp(&cb) {
179                std::cmp::Ordering::Less => {
180                    ia += 1;
181                }
182                std::cmp::Ordering::Equal => {
183                    res += *values_a.get_unchecked(ia) * *values_b.get_unchecked(ib);
184                    ia += 1;
185                    ib += 1;
186                }
187                std::cmp::Ordering::Greater => {
188                    ib += 1;
189                }
190            }
191        }
192    }
193
194    res
195}
196
197/// Dot product between sparse vec `a` and dense vector `b`
198pub(crate) fn dot_csv_dv<T>(a: CsVecRef<T>, b: DVectorView<T>) -> T
199where
200    T: Real,
201{
202    assert_eq!(a.len(), b.len());
203    let mut res = T::zero();
204    for (i, a_ij) in a.iter() {
205        res += a_ij * b[i];
206    }
207    res
208}
209
210pub(crate) fn add_csv_dv<T>(a: CsVecRef<T>, d: DVectorView<T>, mut o: DVectorViewMut<T>)
211where
212    T: Real,
213{
214    assert_eq!(a.len(), d.len());
215    assert_eq!(a.len(), o.len());
216    o.copy_from(&d);
217    for (i, a_ij) in a.iter() {
218        o[i] += a_ij * d[i];
219    }
220}
221
222pub(crate) fn mul_bd_vec<T>(
223    a: DiagonalBlockMatrixView<T>,
224    b: DVectorView<T>,
225    mut o: DVectorViewMut<T>,
226) where
227    T: Real,
228{
229    assert_eq!(a.ncols(), b.len());
230    assert_eq!(a.nrows(), o.len());
231    let mut element_offset = 0;
232    let mut row_offset = 0;
233    for block_index in 0..a.num_blocks() {
234        let block_size = a.get_block_size(block_index);
235        let block_size2 = block_size * block_size;
236        let block = DMatrixView::from_slice(
237            &a.values()[element_offset..element_offset + block_size2],
238            block_size,
239            block_size,
240        );
241        // let block = a.view_block(block_index);
242        // let row_range = a.get_block_row_range(block_index);
243        let row_range = row_offset..row_offset + block_size;
244        let mut o = o.rows_range_mut(row_range.clone());
245        let b = b.rows_range(row_range);
246        block.mul_to(&b, &mut o);
247
248        element_offset += block_size2;
249        row_offset += block_size;
250    }
251}
252
253pub fn mul_add_diag_to_csr<T: Real>(
254    o: &mut CsrMatrix<T>,
255    diag_scale: DVectorView<T>,
256    diag_add: DVectorView<T>,
257) {
258    assert_eq!(o.ncols(), o.nrows());
259    assert_eq!(o.nrows(), diag_scale.len());
260    assert_eq!(o.nrows(), diag_add.len());
261
262    for i in 0..o.nrows() {
263        let row = o.get_row_mut(i);
264        let k = row
265            .col_indices
266            .binary_search(&i)
267            .expect("Diagonal element must be present in CSR row");
268        row.values[k] = row.values[k] * diag_scale[i] + diag_add[i];
269    }
270}
271
272impl<'a, T: Real> Mul<DVectorView<'a, T>> for DiagonalBlockMatrixView<'a, T> {
273    type Output = DVector<T>;
274
275    #[inline]
276    fn mul(self, rhs: DVectorView<'a, T>) -> DVector<T> {
277        let mut o = DVector::zeros(self.nrows());
278        mul_bd_vec(self, rhs, o.as_view_mut());
279        o
280    }
281}
282
283impl<'a, T: Real> Mul<DVector<T>> for DiagonalBlockMatrixView<'a, T> {
284    type Output = DVector<T>;
285
286    #[inline]
287    fn mul(self, rhs: DVector<T>) -> DVector<T> {
288        let mut o = DVector::zeros(self.nrows());
289        mul_bd_vec(self, rhs.as_view(), o.as_view_mut());
290        o
291    }
292}
293
294impl<'a, T: Real> Mul<DiagonalBlockMatrixView<'a, T>> for CsrMatrixView<'a, T> {
295    type Output = CsrMatrix<T>;
296
297    fn mul(self, rhs: DiagonalBlockMatrixView<'a, T>) -> Self::Output {
298        let mut o = CsrMatrix::new(rhs.ncols());
299        mul_csr_bd_to(self, rhs, &mut o);
300        o
301    }
302}
303
304impl<'a, T: Real> Mul<CscMatrixView<'a, T>> for CsrMatrixView<'a, T> {
305    type Output = CsrMatrix<T>;
306
307    fn mul(self, rhs: CscMatrixView<'a, T>) -> Self::Output {
308        let mut o = CsrMatrix::new(rhs.ncols());
309        mul_csr_csc_to(self, rhs, &mut o);
310        o
311    }
312}
313
314impl<'a, T: Real> Mul<DMatrixView<'a, T>> for CsrMatrixView<'a, T> {
315    type Output = CsrMatrix<T>;
316
317    fn mul(self, rhs: DMatrixView<'a, T>) -> Self::Output {
318        let mut o = CsrMatrix::new(rhs.ncols());
319        mul_csr_dmat_to_csr(self, rhs, &mut o);
320        o
321    }
322}
323
324impl<'a, T: Real> Mul<DVectorView<'a, T>> for CsrMatrixView<'a, T> {
325    type Output = DVector<T>;
326
327    #[inline]
328    fn mul(self, rhs: DVectorView<'a, T>) -> Self::Output {
329        let mut o = DVector::zeros(self.nrows());
330        mul_csr_dvec_to_dvec(self, rhs, o.as_view_mut());
331        o
332    }
333}
334
335impl<'a, T: Real> Mul<DVector<T>> for CsrMatrixView<'a, T> {
336    type Output = DVector<T>;
337
338    #[inline]
339    fn mul(self, rhs: DVector<T>) -> Self::Output {
340        let rhs: DVectorView<T> = rhs.as_view();
341        self * rhs
342    }
343}
344
345impl<'a, T: Real> Mul<DVectorView<'a, T>> for CscMatrixView<'a, T> {
346    type Output = DVector<T>;
347
348    #[inline]
349    fn mul(self, rhs: DVectorView<'a, T>) -> Self::Output {
350        let mut o = DVector::zeros(self.nrows());
351        mul_csc_dvec_to_dvec(self, rhs, o.as_view_mut());
352        o
353    }
354}
355
356impl<'a, T: Real> Mul<DMatrixView<'a, T>> for DiagonalBlockMatrixView<'a, T> {
357    type Output = DMatrix<T>;
358
359    fn mul(self, rhs: DMatrixView<'a, T>) -> Self::Output {
360        let mut result = DMatrix::zeros(self.nrows(), rhs.ncols());
361
362        for bindex in 0..self.num_blocks() {
363            let block = self.view_block(bindex);
364            let range = self.get_block_row_range(bindex);
365            let mut output = result.rows_range_mut(range.clone());
366            let rhs = rhs.rows_range(range);
367            block.mul_to(&rhs, &mut output);
368        }
369
370        result
371    }
372}
373
374mod add {
375
376    use std::ops::Add;
377
378    use super::*;
379
380    impl<'a, T: Real> Add<DVectorView<'a, T>> for CsVecRef<'a, T> {
381        type Output = DVector<T>;
382
383        #[inline]
384        fn add(self, rhs: DVectorView<'a, T>) -> Self::Output {
385            let mut o = DVector::zeros(self.len());
386            add_csv_dv(self, rhs, o.as_view_mut());
387            o
388        }
389    }
390}