1use std::ops::Mul;
16
17use na::{DMatrix, DMatrixView, DVector, DVectorView, DVectorViewMut};
18
19use crate::csv::CsVecRef;
20use crate::{
21 CscMatrixView, CsrMatrix, CsrMatrixView, CsrMatrixViewMethods, DiagonalBlockMatrixView, Real,
22};
23
24pub(crate) fn mul_csr_bd_to<T>(
28 a: CsrMatrixView<T>,
29 b: DiagonalBlockMatrixView<T>,
30 o: &mut CsrMatrix<T>,
31) where
32 T: Real,
33{
34 assert_eq!(a.ncols(), b.nrows());
35 assert_eq!(b.ncols(), o.ncols());
36 assert_eq!(o.nrows(), 0);
37
38 for i in 0..a.nrows() {
39 let mut or = o.new_row_builder(T::zero_threshold());
40 let ar = a.get_row(i);
41 let mut a_col_start = 0;
42 for bindex in 0..b.num_blocks() {
43 let range = b.get_block_row_range(bindex);
44 let block = b.view_block(bindex);
45 let mut a_n = 0;
46 for col in ar.indices().iter().skip(a_col_start) {
47 if *col >= range.end {
48 break;
49 }
50 a_n += 1;
51 }
52 for j in 0..block.ncols() {
53 let mut o_ij = T::zero();
54 let col = block.column(j);
55 for (k, a_ik) in ar.iter().skip(a_col_start).take(a_n) {
56 o_ij += a_ik * col[k - range.start];
57 }
58
59 if o_ij.abs() > T::zero_threshold() {
60 or.push(range.start + j, o_ij);
61 }
62 }
63 a_col_start += a_n;
64 }
65 }
66}
67
68pub(crate) fn mul_csr_csc_to<T: Real>(
72 a: CsrMatrixView<T>,
73 b: CscMatrixView<T>,
74 o: &mut CsrMatrix<T>,
75) {
76 assert_eq!(o.nrows(), 0);
77 assert_eq!(a.ncols(), b.nrows());
78
79 for i in 0..a.nrows() {
80 let mut or = o.new_row_builder(T::zero_threshold());
81 let ar = a.get_row(i);
82
83 for j in 0..b.ncols() {
84 let bc = b.get_col(j);
85
86 let o_ij = dot_csvec(ar, bc);
87 if o_ij.abs() > T::zero_threshold() {
88 or.push(j, o_ij);
89 }
90 }
91 }
92}
93
94#[inline]
96pub(crate) fn mul_csr_dvec_to_dvec<T: Real>(
97 a: CsrMatrixView<T>,
98 b: DVectorView<T>,
99 mut o: DVectorViewMut<T>,
100) {
101 assert_eq!(a.nrows(), o.len());
102 assert!(!b.is_empty());
103
104 for i in 0..a.nrows() {
105 let ar = a.get_row(i);
106 let mut o_i = T::zero();
107
108 for (j, a_ij) in ar.iter() {
109 o_i += a_ij * b[j];
110 }
111
112 o[i] = o_i;
113 }
114}
115
116pub(crate) fn mul_csc_dvec_to_dvec<T>(
118 a: CscMatrixView<T>,
119 b: DVectorView<T>,
120 mut o: DVectorViewMut<T>,
121) where
122 T: Real,
123{
124 assert_eq!(a.ncols(), b.len());
125 assert_eq!(o.len(), a.nrows());
126 assert!(!b.is_empty());
127 debug_assert!(o.iter().all(|v| v.abs() < T::zero_threshold()));
128
129 for j in 0..a.ncols() {
130 let aj = a.get_col(j);
131
132 for (i, a_ij) in aj.iter() {
133 o[i] += a_ij * b[j];
134 }
135 }
136}
137
138pub(crate) fn mul_csr_dmat_to_csr<T>(a: CsrMatrixView<T>, b: DMatrixView<T>, o: &mut CsrMatrix<T>)
139where
140 T: Real,
141{
142 assert_eq!(a.ncols(), b.nrows());
143 assert_eq!(o.ncols(), b.ncols());
144 assert_eq!(o.nrows(), 0);
145
146 for i in 0..a.nrows() {
147 let ar = a.get_row(i);
148 let mut or = o.new_row_builder(T::zero_threshold());
149 for j in 0..b.ncols() {
150 let bc = b.column(j);
151 let o_ij = dot_csv_dv(ar, bc);
152 if o_ij.abs() > T::zero_threshold() {
153 or.push(j, o_ij);
154 }
155 }
156 }
157}
158
159pub(crate) fn dot_csvec<T: Real>(a: CsVecRef<T>, b: CsVecRef<T>) -> T {
165 let mut res = T::zero();
166
167 let col_a = a.indices();
168 let col_b = b.indices();
169 let values_a = a.values();
170 let values_b = b.values();
171 let mut ia = 0;
172 let mut ib = 0;
173
174 unsafe {
175 while ia < col_a.len() && ib < col_b.len() {
176 let ca = *col_a.get_unchecked(ia);
177 let cb = *col_b.get_unchecked(ib);
178 match ca.cmp(&cb) {
179 std::cmp::Ordering::Less => {
180 ia += 1;
181 }
182 std::cmp::Ordering::Equal => {
183 res += *values_a.get_unchecked(ia) * *values_b.get_unchecked(ib);
184 ia += 1;
185 ib += 1;
186 }
187 std::cmp::Ordering::Greater => {
188 ib += 1;
189 }
190 }
191 }
192 }
193
194 res
195}
196
197pub(crate) fn dot_csv_dv<T>(a: CsVecRef<T>, b: DVectorView<T>) -> T
199where
200 T: Real,
201{
202 assert_eq!(a.len(), b.len());
203 let mut res = T::zero();
204 for (i, a_ij) in a.iter() {
205 res += a_ij * b[i];
206 }
207 res
208}
209
210pub(crate) fn add_csv_dv<T>(a: CsVecRef<T>, d: DVectorView<T>, mut o: DVectorViewMut<T>)
211where
212 T: Real,
213{
214 assert_eq!(a.len(), d.len());
215 assert_eq!(a.len(), o.len());
216 o.copy_from(&d);
217 for (i, a_ij) in a.iter() {
218 o[i] += a_ij * d[i];
219 }
220}
221
222pub(crate) fn mul_bd_vec<T>(
223 a: DiagonalBlockMatrixView<T>,
224 b: DVectorView<T>,
225 mut o: DVectorViewMut<T>,
226) where
227 T: Real,
228{
229 assert_eq!(a.ncols(), b.len());
230 assert_eq!(a.nrows(), o.len());
231 let mut element_offset = 0;
232 let mut row_offset = 0;
233 for block_index in 0..a.num_blocks() {
234 let block_size = a.get_block_size(block_index);
235 let block_size2 = block_size * block_size;
236 let block = DMatrixView::from_slice(
237 &a.values()[element_offset..element_offset + block_size2],
238 block_size,
239 block_size,
240 );
241 let row_range = row_offset..row_offset + block_size;
244 let mut o = o.rows_range_mut(row_range.clone());
245 let b = b.rows_range(row_range);
246 block.mul_to(&b, &mut o);
247
248 element_offset += block_size2;
249 row_offset += block_size;
250 }
251}
252
253pub fn mul_add_diag_to_csr<T: Real>(
254 o: &mut CsrMatrix<T>,
255 diag_scale: DVectorView<T>,
256 diag_add: DVectorView<T>,
257) {
258 assert_eq!(o.ncols(), o.nrows());
259 assert_eq!(o.nrows(), diag_scale.len());
260 assert_eq!(o.nrows(), diag_add.len());
261
262 for i in 0..o.nrows() {
263 let row = o.get_row_mut(i);
264 let k = row
265 .col_indices
266 .binary_search(&i)
267 .expect("Diagonal element must be present in CSR row");
268 row.values[k] = row.values[k] * diag_scale[i] + diag_add[i];
269 }
270}
271
272impl<'a, T: Real> Mul<DVectorView<'a, T>> for DiagonalBlockMatrixView<'a, T> {
273 type Output = DVector<T>;
274
275 #[inline]
276 fn mul(self, rhs: DVectorView<'a, T>) -> DVector<T> {
277 let mut o = DVector::zeros(self.nrows());
278 mul_bd_vec(self, rhs, o.as_view_mut());
279 o
280 }
281}
282
283impl<'a, T: Real> Mul<DVector<T>> for DiagonalBlockMatrixView<'a, T> {
284 type Output = DVector<T>;
285
286 #[inline]
287 fn mul(self, rhs: DVector<T>) -> DVector<T> {
288 let mut o = DVector::zeros(self.nrows());
289 mul_bd_vec(self, rhs.as_view(), o.as_view_mut());
290 o
291 }
292}
293
294impl<'a, T: Real> Mul<DiagonalBlockMatrixView<'a, T>> for CsrMatrixView<'a, T> {
295 type Output = CsrMatrix<T>;
296
297 fn mul(self, rhs: DiagonalBlockMatrixView<'a, T>) -> Self::Output {
298 let mut o = CsrMatrix::new(rhs.ncols());
299 mul_csr_bd_to(self, rhs, &mut o);
300 o
301 }
302}
303
304impl<'a, T: Real> Mul<CscMatrixView<'a, T>> for CsrMatrixView<'a, T> {
305 type Output = CsrMatrix<T>;
306
307 fn mul(self, rhs: CscMatrixView<'a, T>) -> Self::Output {
308 let mut o = CsrMatrix::new(rhs.ncols());
309 mul_csr_csc_to(self, rhs, &mut o);
310 o
311 }
312}
313
314impl<'a, T: Real> Mul<DMatrixView<'a, T>> for CsrMatrixView<'a, T> {
315 type Output = CsrMatrix<T>;
316
317 fn mul(self, rhs: DMatrixView<'a, T>) -> Self::Output {
318 let mut o = CsrMatrix::new(rhs.ncols());
319 mul_csr_dmat_to_csr(self, rhs, &mut o);
320 o
321 }
322}
323
324impl<'a, T: Real> Mul<DVectorView<'a, T>> for CsrMatrixView<'a, T> {
325 type Output = DVector<T>;
326
327 #[inline]
328 fn mul(self, rhs: DVectorView<'a, T>) -> Self::Output {
329 let mut o = DVector::zeros(self.nrows());
330 mul_csr_dvec_to_dvec(self, rhs, o.as_view_mut());
331 o
332 }
333}
334
335impl<'a, T: Real> Mul<DVector<T>> for CsrMatrixView<'a, T> {
336 type Output = DVector<T>;
337
338 #[inline]
339 fn mul(self, rhs: DVector<T>) -> Self::Output {
340 let rhs: DVectorView<T> = rhs.as_view();
341 self * rhs
342 }
343}
344
345impl<'a, T: Real> Mul<DVectorView<'a, T>> for CscMatrixView<'a, T> {
346 type Output = DVector<T>;
347
348 #[inline]
349 fn mul(self, rhs: DVectorView<'a, T>) -> Self::Output {
350 let mut o = DVector::zeros(self.nrows());
351 mul_csc_dvec_to_dvec(self, rhs, o.as_view_mut());
352 o
353 }
354}
355
356impl<'a, T: Real> Mul<DMatrixView<'a, T>> for DiagonalBlockMatrixView<'a, T> {
357 type Output = DMatrix<T>;
358
359 fn mul(self, rhs: DMatrixView<'a, T>) -> Self::Output {
360 let mut result = DMatrix::zeros(self.nrows(), rhs.ncols());
361
362 for bindex in 0..self.num_blocks() {
363 let block = self.view_block(bindex);
364 let range = self.get_block_row_range(bindex);
365 let mut output = result.rows_range_mut(range.clone());
366 let rhs = rhs.rows_range(range);
367 block.mul_to(&rhs, &mut output);
368 }
369
370 result
371 }
372}
373
374mod add {
375
376 use std::ops::Add;
377
378 use super::*;
379
380 impl<'a, T: Real> Add<DVectorView<'a, T>> for CsVecRef<'a, T> {
381 type Output = DVector<T>;
382
383 #[inline]
384 fn add(self, rhs: DVectorView<'a, T>) -> Self::Output {
385 let mut o = DVector::zeros(self.len());
386 add_csv_dv(self, rhs, o.as_view_mut());
387 o
388 }
389 }
390}