// diffsol/vector/nalgebra_serial.rs

1use std::ops::{Add, AddAssign, Div, Index, IndexMut, Mul, MulAssign, Sub, SubAssign};
2
3use super::utils::*;
4use nalgebra::{DVector, DVectorView, DVectorViewMut, LpNorm};
5
6use crate::{IndexType, NalgebraContext, NalgebraMat, NalgebraScalar, Scalar, Scale, VectorHost};
7
8use super::{DefaultDenseMatrix, Vector, VectorCommon, VectorIndex, VectorView, VectorViewMut};
9
/// Integer index vector for the nalgebra backend, stored as a `DVector<IndexType>`.
///
/// It is the `Index` type used by `NalgebraVec` (e.g. in `assign_at_indices`,
/// `copy_from_indices`, `gather` and `scatter`).
#[derive(Debug, Clone, PartialEq)]
pub struct NalgebraIndex {
    /// The index values.
    pub(crate) data: DVector<IndexType>,
    /// Backend context carried alongside the indices.
    pub(crate) context: NalgebraContext,
}
15
/// Owned, dynamically sized dense vector for the nalgebra backend.
#[derive(Debug, Clone, PartialEq)]
pub struct NalgebraVec<T: NalgebraScalar> {
    /// Underlying nalgebra column vector.
    pub(crate) data: DVector<T>,
    /// Backend context; copied into views and into results derived from this vector.
    pub(crate) context: NalgebraContext,
}
21
/// Immutable borrowed view of nalgebra column-vector storage (e.g. from
/// `NalgebraVec::as_view`).
#[derive(Debug, Clone, PartialEq)]
pub struct NalgebraVecRef<'a, T: NalgebraScalar> {
    /// Borrowed nalgebra view of the elements.
    pub(crate) data: DVectorView<'a, T>,
    /// Backend context copied from the viewed vector.
    pub(crate) context: NalgebraContext,
}
27
/// Mutable borrowed view of nalgebra column-vector storage (e.g. from
/// `NalgebraVec::as_view_mut`).
///
/// Unlike `NalgebraVecRef` this does not derive `Clone`: it wraps a mutable view.
#[derive(Debug, PartialEq)]
pub struct NalgebraVecMut<'a, T: NalgebraScalar> {
    /// Mutably borrowed nalgebra view of the elements.
    pub(crate) data: DVectorViewMut<'a, T>,
    /// Backend context copied from the viewed vector.
    pub(crate) context: NalgebraContext,
}
33
34impl<T: NalgebraScalar> From<DVector<T>> for NalgebraVec<T> {
35    fn from(data: DVector<T>) -> Self {
36        Self {
37            data,
38            context: NalgebraContext,
39        }
40    }
41}
42
/// The dense matrix type associated with `NalgebraVec<T>` is `NalgebraMat<T>`.
impl<T: NalgebraScalar> DefaultDenseMatrix for NalgebraVec<T> {
    type M = NalgebraMat<T>;
}
46
// Shared boilerplate for the owned vector and both view types. These macros come in via
// `use super::utils::*`; presumably they implement `VectorCommon` (supplying the `T`,
// `C` and `Inner` associated types used below) from the context type, inner nalgebra
// type and scalar bound given here — confirm in `utils`.
impl_vector_common!(NalgebraVec<T>, NalgebraContext, DVector<T>, NalgebraScalar);
impl_vector_common_ref!(
    NalgebraVecRef<'a, T>,
    NalgebraContext,
    DVectorView<'a, T>,
    NalgebraScalar
);
impl_vector_common_ref!(
    NalgebraVecMut<'a, T>,
    NalgebraContext,
    DVectorViewMut<'a, T>,
    NalgebraScalar
);
60
/// Implements `lhs * Scale(s)` for a vector-like type, producing a new vector.
///
/// * `$lhs`    – left operand type (owned value, reference or view).
/// * `$out`    – result type; built from `data` and `context` fields.
/// * `$scalar` – type the scale factor is bound to (the element type).
///
/// The left operand's `data` is only borrowed, so one body serves owned values,
/// references and views alike.
macro_rules! impl_mul_scalar {
    ($lhs:ty, $out:ty, $scalar:ty) => {
        impl<T: NalgebraScalar> Mul<Scale<T>> for $lhs {
            type Output = $out;

            #[inline]
            fn mul(self, rhs: Scale<T>) -> Self::Output {
                let factor: $scalar = rhs.value();
                let data = &self.data * factor;
                let context = self.context;
                Self::Output { data, context }
            }
        }
    };
}
76
/// Implements `lhs / Scale(s)` for an owned vector type, consuming the left operand.
///
/// * `$lhs`    – left operand type; its `data` field is moved into the multiplication,
///   so only owned types are supported (not references or views).
/// * `$out`    – result type; built from `data` and `context` fields.
/// * `$scalar` – type the reciprocal is bound to (the element type). This is a `ty`
///   fragment, matching `impl_mul_scalar!` and `impl_mul_assign_scalar!`.
///
/// The division is performed as a multiplication by `1 / s`. This costs one scalar
/// division plus `n` multiplications instead of `n` divisions. Because of rounding, the
/// result may differ slightly from true element-wise division.
macro_rules! impl_div_scalar {
    ($lhs:ty, $out:ty, $scalar:ty) => {
        impl<T: NalgebraScalar> Div<Scale<T>> for $lhs {
            type Output = $out;

            #[inline]
            fn div(self, rhs: Scale<T>) -> Self::Output {
                // Reciprocal computed once; see the macro docs for the rounding caveat.
                let inv_rhs: $scalar = T::one() / rhs.value();
                Self::Output {
                    data: self.data * inv_rhs,
                    context: self.context,
                }
            }
        }
    };
}
92
/// Implements `target *= Scale(s)` by scaling the wrapped `data` in place.
///
/// * `$col_type` – target type (owned vector or mutable view); it may refer to the
///   lifetime `'a` introduced by the generated impl.
/// * `$scalar`   – type the scale factor is bound to (the element type).
macro_rules! impl_mul_assign_scalar {
    ($col_type:ty, $scalar:ty) => {
        impl<'a, T: NalgebraScalar> MulAssign<Scale<T>> for $col_type {
            #[inline]
            fn mul_assign(&mut self, rhs: Scale<T>) {
                let factor: $scalar = rhs.value();
                self.data *= factor;
            }
        }
    };
}
104
// Scalar scaling operators:
// - `v * Scale(s)` for owned vectors, references to them and both view types, always
//   producing a new owned `NalgebraVec`;
// - `v / Scale(s)` for owned vectors only;
// - `v *= Scale(s)` in place for mutable views and owned vectors.
impl_mul_scalar!(NalgebraVec<T>, NalgebraVec<T>, T);
impl_mul_scalar!(&NalgebraVec<T>, NalgebraVec<T>, T);
impl_mul_scalar!(NalgebraVecRef<'_, T>, NalgebraVec<T>, T);
impl_mul_scalar!(NalgebraVecMut<'_, T>, NalgebraVec<T>, T);
impl_div_scalar!(NalgebraVec<T>, NalgebraVec<T>, T);
impl_mul_assign_scalar!(NalgebraVecMut<'a, T>, T);
impl_mul_assign_scalar!(NalgebraVec<T>, T);
112
// In-place `-=` for owned vectors and mutable views. The right-hand side may be an
// owned vector, a reference to one, an immutable view, or a reference to a view.
// The macros come from `super::utils`.
impl_sub_assign!(NalgebraVec<T>, NalgebraVec<T>, NalgebraScalar);
impl_sub_assign!(NalgebraVec<T>, &NalgebraVec<T>, NalgebraScalar);
impl_sub_assign!(NalgebraVec<T>, NalgebraVecRef<'_, T>, NalgebraScalar);
impl_sub_assign!(NalgebraVec<T>, &NalgebraVecRef<'_, T>, NalgebraScalar);

impl_sub_assign!(NalgebraVecMut<'_, T>, NalgebraVec<T>, NalgebraScalar);
impl_sub_assign!(NalgebraVecMut<'_, T>, &NalgebraVec<T>, NalgebraScalar);
impl_sub_assign!(NalgebraVecMut<'_, T>, NalgebraVecRef<'_, T>, NalgebraScalar);
impl_sub_assign!(
    NalgebraVecMut<'_, T>,
    &NalgebraVecRef<'_, T>,
    NalgebraScalar
);

// In-place `+=`, with the same combinations of left- and right-hand sides as `-=` above.
impl_add_assign!(NalgebraVec<T>, NalgebraVec<T>, NalgebraScalar);
impl_add_assign!(NalgebraVec<T>, &NalgebraVec<T>, NalgebraScalar);
impl_add_assign!(NalgebraVec<T>, NalgebraVecRef<'_, T>, NalgebraScalar);
impl_add_assign!(NalgebraVec<T>, &NalgebraVecRef<'_, T>, NalgebraScalar);

impl_add_assign!(NalgebraVecMut<'_, T>, NalgebraVec<T>, NalgebraScalar);
impl_add_assign!(NalgebraVecMut<'_, T>, &NalgebraVec<T>, NalgebraScalar);
impl_add_assign!(NalgebraVecMut<'_, T>, NalgebraVecRef<'_, T>, NalgebraScalar);
impl_add_assign!(
    NalgebraVecMut<'_, T>,
    &NalgebraVecRef<'_, T>,
    NalgebraScalar
);
140
// Binary subtraction `lhs - rhs`, always producing an owned `NalgebraVec`.
// Left-hand sides: `&NalgebraVec`, `NalgebraVec`, `NalgebraVecRef`.
// Right-hand sides: owned vector, `&` owned vector, view, `&` view.
// The `_lhs` variant is used when the left operand is owned. The `_rhs` variant is used
// when only the right operand is owned. The `_both_ref` variant is used when neither is.
// Presumably this lets an owned operand's buffer be reused — confirm in `super::utils`.

// Left operand: `&NalgebraVec`.
impl_sub_both_ref!(
    &NalgebraVec<T>,
    &NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_sub_rhs!(
    &NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_sub_both_ref!(
    &NalgebraVec<T>,
    NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_sub_both_ref!(
    &NalgebraVec<T>,
    &NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);

// Left operand: owned `NalgebraVec`.
impl_sub_lhs!(
    NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_sub_lhs!(
    NalgebraVec<T>,
    &NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_sub_lhs!(
    NalgebraVec<T>,
    NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_sub_lhs!(
    NalgebraVec<T>,
    &NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);

// Left operand: immutable view `NalgebraVecRef`.
impl_sub_rhs!(
    NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_sub_both_ref!(
    NalgebraVecRef<'_, T>,
    &NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_sub_both_ref!(
    NalgebraVecRef<'_, T>,
    NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_sub_both_ref!(
    NalgebraVecRef<'_, T>,
    &NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);
215
// Binary addition `lhs + rhs`, always producing an owned `NalgebraVec`.
// Left-hand sides: `&NalgebraVec`, `NalgebraVec`, `NalgebraVecRef`.
// Right-hand sides: owned vector, `&` owned vector, view, `&` view.
// The `_lhs` variant is used when the left operand is owned. The `_rhs` variant is used
// when only the right operand is owned. The `_both_ref` variant is used when neither is.
// Presumably this lets an owned operand's buffer be reused — confirm in `super::utils`.

// Left operand: `&NalgebraVec`.
impl_add_both_ref!(
    &NalgebraVec<T>,
    &NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_add_rhs!(
    &NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_add_both_ref!(
    &NalgebraVec<T>,
    NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_add_both_ref!(
    &NalgebraVec<T>,
    &NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);

// Left operand: owned `NalgebraVec`.
impl_add_lhs!(
    NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_add_lhs!(
    NalgebraVec<T>,
    &NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_add_lhs!(
    NalgebraVec<T>,
    NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_add_lhs!(
    NalgebraVec<T>,
    &NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);

// Left operand: immutable view `NalgebraVecRef`.
impl_add_rhs!(
    NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_add_both_ref!(
    NalgebraVecRef<'_, T>,
    &NalgebraVec<T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_add_both_ref!(
    NalgebraVecRef<'_, T>,
    NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);
impl_add_both_ref!(
    NalgebraVecRef<'_, T>,
    &NalgebraVecRef<'_, T>,
    NalgebraVec<T>,
    NalgebraScalar
);
290
// Element indexing (`v[i]`), generated by macros from `super::utils`. In this file the
// owned vector gets both read and write indexing, while the immutable view gets
// read-only indexing.
impl_index!(NalgebraVec<T>, NalgebraScalar);
impl_index_mut!(NalgebraVec<T>, NalgebraScalar);

impl_index!(NalgebraVecRef<'_, T>, NalgebraScalar);
295
296impl VectorIndex for NalgebraIndex {
297    type C = NalgebraContext;
298    fn zeros(len: IndexType, ctx: Self::C) -> Self {
299        let data = DVector::from_element(len, 0);
300        Self { data, context: ctx }
301    }
302    fn len(&self) -> crate::IndexType {
303        self.data.len()
304    }
305    fn from_vec(v: Vec<IndexType>, ctx: Self::C) -> Self {
306        let data = DVector::from_vec(v);
307        Self { data, context: ctx }
308    }
309    fn clone_as_vec(&self) -> Vec<IndexType> {
310        self.data.iter().copied().collect()
311    }
312    fn context(&self) -> &Self::C {
313        &self.context
314    }
315}
316
317impl<'a, T: NalgebraScalar> VectorView<'a> for NalgebraVecRef<'a, T> {
318    type Owned = NalgebraVec<T>;
319
320    fn into_owned(self) -> Self::Owned {
321        Self::Owned {
322            data: self.data.into_owned(),
323            context: self.context,
324        }
325    }
326    fn squared_norm(&self, y: &Self::Owned, atol: &Self::Owned, rtol: Self::T) -> Self::T {
327        let mut acc = T::zero();
328        if y.len() != self.data.len() || y.len() != atol.len() {
329            panic!("Vector lengths do not match");
330        }
331        for i in 0..self.data.len() {
332            let yi = unsafe { y.data.get_unchecked(i) };
333            let ai = unsafe { atol.data.get_unchecked(i) };
334            let xi = unsafe { self.data.get_unchecked(i) };
335            let term = *xi / (yi.abs() * rtol + *ai);
336            acc += term * term;
337        }
338        acc / Self::T::from_f64(self.data.len() as f64).unwrap()
339    }
340}
341
342impl<'a, T: NalgebraScalar> VectorViewMut<'a> for NalgebraVecMut<'a, T> {
343    type Owned = NalgebraVec<T>;
344    type View = NalgebraVecRef<'a, T>;
345    type Index = NalgebraIndex;
346    fn copy_from(&mut self, other: &Self::Owned) {
347        self.data.copy_from(&other.data);
348    }
349    fn copy_from_view(&mut self, other: &Self::View) {
350        self.data.copy_from(&other.data);
351    }
352    fn axpy(&mut self, alpha: Self::T, x: &Self::Owned, beta: Self::T) {
353        self.data.axpy(alpha, &x.data, beta);
354    }
355}
356
357impl<T: NalgebraScalar> VectorHost for NalgebraVec<T> {
358    fn as_slice(&self) -> &[Self::T] {
359        self.data.as_slice()
360    }
361    fn as_mut_slice(&mut self) -> &mut [Self::T] {
362        self.data.as_mut_slice()
363    }
364}
365
366impl<T: NalgebraScalar> Vector for NalgebraVec<T> {
367    type View<'a> = NalgebraVecRef<'a, T>;
368    type ViewMut<'a> = NalgebraVecMut<'a, T>;
369    type Index = NalgebraIndex;
370    fn len(&self) -> IndexType {
371        self.data.len()
372    }
373    fn inner_mut(&mut self) -> &mut Self::Inner {
374        &mut self.data
375    }
376    fn context(&self) -> &Self::C {
377        &self.context
378    }
379    fn norm(&self, k: i32) -> Self::T {
380        self.data.apply_norm(&LpNorm(k))
381    }
382    fn get_index(&self, index: IndexType) -> Self::T {
383        self.data[index]
384    }
385    fn set_index(&mut self, index: IndexType, value: Self::T) {
386        self.data[index] = value;
387    }
388    fn squared_norm(&self, y: &Self, atol: &Self, rtol: Self::T) -> Self::T {
389        let mut acc = T::zero();
390        if y.len() != self.len() || y.len() != atol.len() {
391            panic!("Vector lengths do not match");
392        }
393        for i in 0..self.len() {
394            let yi = unsafe { y.data.get_unchecked(i) };
395            let ai = unsafe { atol.data.get_unchecked(i) };
396            let xi = unsafe { self.data.get_unchecked(i) };
397            let term = *xi / (yi.abs() * rtol + *ai);
398            acc += term * term;
399        }
400        acc / Self::T::from_f64(self.len() as f64).unwrap()
401    }
402    fn as_view(&self) -> Self::View<'_> {
403        Self::View {
404            data: self.data.as_view(),
405            context: self.context,
406        }
407    }
408    fn as_view_mut(&mut self) -> Self::ViewMut<'_> {
409        Self::ViewMut {
410            data: self.data.as_view_mut(),
411            context: self.context,
412        }
413    }
414    fn copy_from(&mut self, other: &Self) {
415        self.data.copy_from(&other.data);
416    }
417    fn fill(&mut self, value: Self::T) {
418        self.data.iter_mut().for_each(|x: &mut _| *x = value);
419    }
420    fn copy_from_view(&mut self, other: &Self::View<'_>) {
421        self.data.copy_from(&other.data);
422    }
423    fn from_element(nstates: usize, value: T, ctx: Self::C) -> Self {
424        let data = DVector::from_element(nstates, value);
425        Self { data, context: ctx }
426    }
427    fn from_vec(vec: Vec<T>, ctx: Self::C) -> Self {
428        let data = DVector::from_vec(vec);
429        Self { data, context: ctx }
430    }
431    fn from_slice(slice: &[T], ctx: Self::C) -> Self {
432        let data = DVector::from_column_slice(slice);
433        Self { data, context: ctx }
434    }
435    fn clone_as_vec(&self) -> Vec<Self::T> {
436        self.data.iter().copied().collect()
437    }
438    fn zeros(nstates: usize, ctx: Self::C) -> Self {
439        let data = DVector::zeros(nstates);
440        Self { data, context: ctx }
441    }
442    fn axpy(&mut self, alpha: T, x: &Self, beta: T) {
443        self.data.axpy(alpha, &x.data, beta);
444    }
445    fn axpy_v(&mut self, alpha: Self::T, x: &Self::View<'_>, beta: Self::T) {
446        self.data.axpy(alpha, &x.data, beta);
447    }
448    fn component_div_assign(&mut self, other: &Self) {
449        self.data.component_div_assign(&other.data);
450    }
451    fn component_mul_assign(&mut self, other: &Self) {
452        self.data.component_mul_assign(&other.data);
453    }
454
455    fn root_finding(&self, g1: &Self) -> (bool, Self::T, i32) {
456        let mut max_frac = T::zero();
457        let mut max_frac_index = -1;
458        let mut found_root = false;
459        assert_eq!(self.len(), g1.len(), "Vector lengths do not match");
460        for i in 0..self.len() {
461            let g0 = unsafe { *self.data.get_unchecked(i) };
462            let g1 = unsafe { *g1.data.get_unchecked(i) };
463            if g1 == T::zero() {
464                found_root = true;
465            }
466            if g0 * g1 < T::zero() {
467                let frac = (g1 / (g1 - g0)).abs();
468                if frac > max_frac {
469                    max_frac = frac;
470                    max_frac_index = i as i32;
471                }
472            }
473        }
474        (found_root, max_frac, max_frac_index)
475    }
476
477    fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) {
478        for i in indices.data.iter() {
479            self[*i] = value;
480        }
481    }
482
483    fn copy_from_indices(&mut self, other: &Self, indices: &Self::Index) {
484        for i in indices.data.iter() {
485            self[*i] = other[*i];
486        }
487    }
488
489    fn gather(&mut self, other: &Self, indices: &Self::Index) {
490        assert_eq!(self.len(), indices.len(), "Vector lengths do not match");
491        for (s, o) in self.data.iter_mut().zip(indices.data.iter()) {
492            *s = other[*o];
493        }
494    }
495
496    fn scatter(&self, indices: &Self::Index, other: &mut Self) {
497        assert_eq!(self.len(), indices.len(), "Vector lengths do not match");
498        for (s, o) in self.data.iter().zip(indices.data.iter()) {
499            other[*o] = *s;
500        }
501    }
502}
503
// Unit tests for the nalgebra serial vector backend.
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks both `squared_norm` implementations (owned vector and view) against a
    /// reference computed with vector operations: `mean((v / (y * rtol + atol))^2)`.
    #[test]
    fn test_error_norm() {
        let v = NalgebraVec::from_vec(vec![1.0, -2.0, 3.0], Default::default());
        let y = NalgebraVec::from_vec(vec![1.0, 2.0, 3.0], Default::default());
        let atol = NalgebraVec::from_vec(vec![0.1, 0.2, 0.3], Default::default());
        let rtol = 0.1;
        // Reference denominator: `y` is all-positive here, so `y * rtol` equals `|y| * rtol`.
        let mut tmp = y.clone() * Scale(rtol);
        tmp += &atol;
        let mut r = v.clone();
        r.component_div_assign(&tmp);
        let errorn_check = r.data.norm_squared() / 3.0;
        // NOTE(review): exact float equality relies on both paths rounding identically;
        // switch to a tolerance-based comparison if this ever becomes flaky.
        assert_eq!(v.squared_norm(&y, &atol, rtol), errorn_check);
        let vview = v.as_view();
        assert_eq!(
            VectorView::squared_norm(&vview, &y, &atol, rtol),
            errorn_check
        );
    }

    /// Delegates to the shared root-finding test in `super::super::tests`.
    #[test]
    fn test_root_finding() {
        super::super::tests::test_root_finding::<NalgebraVec<f64>>();
    }

    /// `from_slice` followed by `clone_as_vec` round-trips the data.
    #[test]
    fn test_from_slice() {
        let slice = [1.0, 2.0, 3.0];
        let v = NalgebraVec::from_slice(&slice, Default::default());
        assert_eq!(v.clone_as_vec(), slice);
    }

    /// Conversion from a nalgebra `DVector` via `From`/`Into` keeps the elements.
    #[test]
    fn test_into() {
        let vec = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let v: NalgebraVec<f64> = vec.into();
        assert_eq!(v.clone_as_vec(), vec![1.0, 2.0, 3.0]);
    }
}
542        let vec = DVector::from_vec(vec![1.0, 2.0, 3.0]);
543        let v: NalgebraVec<f64> = vec.into();
544        assert_eq!(v.clone_as_vec(), vec![1.0, 2.0, 3.0]);
545    }
546}