// diffsol/vector/faer_serial.rs

use std::ops::{Add, AddAssign, Div, Index, IndexMut, Mul, MulAssign, Sub, SubAssign};
use std::slice;

use faer::{unzip, zip, Col, ColMut, ColRef};

use crate::{scalar::Scale, FaerContext, IndexType, Scalar, Vector};

use crate::{FaerMat, VectorCommon, VectorHost, VectorIndex, VectorView, VectorViewMut};

use super::utils::*;
use super::DefaultDenseMatrix;

13#[derive(Debug, Clone, PartialEq)]
14pub struct FaerVec<T: Scalar> {
15    pub(crate) data: Col<T>,
16    pub(crate) context: FaerContext,
17}
18
19#[derive(Debug, Clone, PartialEq)]
20pub struct FaerVecIndex {
21    pub(crate) data: Vec<IndexType>,
22    pub(crate) context: FaerContext,
23}
24
25#[derive(Debug, Clone, PartialEq)]
26pub struct FaerVecRef<'a, T: Scalar> {
27    pub(crate) data: ColRef<'a, T>,
28    pub(crate) context: FaerContext,
29}
30
31#[derive(Debug, PartialEq)]
32pub struct FaerVecMut<'a, T: Scalar> {
33    pub(crate) data: ColMut<'a, T>,
34    pub(crate) context: FaerContext,
35}
36
37impl<T: Scalar> From<Col<T>> for FaerVec<T> {
38    fn from(data: Col<T>) -> Self {
39        Self {
40            data,
41            context: FaerContext::default(),
42        }
43    }
44}
45
46impl<T: Scalar> DefaultDenseMatrix for FaerVec<T> {
47    type M = FaerMat<T>;
48}
49
50impl_vector_common!(FaerVec<T>, FaerContext, Col<T>);
51impl_vector_common_ref!(FaerVecRef<'a, T>, FaerContext, ColRef<'a, T>);
52impl_vector_common_ref!(FaerVecMut<'a, T>, FaerContext, ColMut<'a, T>);
53
/// Implements `lhs * Scale<T>`, producing a new owned vector.
///
/// * `$lhs` – the vector, reference or view type on the left of `*`.
/// * `$out` – the owned result type.
/// * `$scalar` – the faer scale type that `Scale<T>` is converted into
///   before multiplying the underlying column.
///
/// The left operand's context is cloned into the result.
macro_rules! impl_mul_scalar {
    ($lhs:ty, $out:ty, $scalar:ty) => {
        impl<T: Scalar> Mul<Scale<T>> for $lhs {
            type Output = $out;
            fn mul(self, rhs: Scale<T>) -> Self::Output {
                let factor: $scalar = rhs.into();
                let scaled = &self.data * factor;
                Self::Output {
                    data: scaled,
                    context: self.context.clone(),
                }
            }
        }
    };
}

/// Implements `lhs / Scale<T>`, producing a new owned vector.
///
/// The division is carried out as a multiplication by the reciprocal
/// `1 / rhs`. Only one scalar division is performed instead of one per element,
/// so results can differ from element-wise division in the last bit.
///
/// * `$lhs` – the vector type on the left of `/`.
/// * `$out` – the owned result type.
/// * `$scalar` – the faer scale constructor used to wrap the reciprocal
///   (e.g. `faer::Scale::<T>`).
///
/// The left operand's context is cloned into the result.
macro_rules! impl_div_scalar {
    ($lhs:ty, $out:ty, $scalar:expr) => {
        // No lifetime parameter: the previous `'a` was never used.
        impl<T: Scalar> Div<Scale<T>> for $lhs {
            type Output = $out;
            fn div(self, rhs: Scale<T>) -> Self::Output {
                let inv_rhs: T = T::one() / rhs.value();
                // Use the caller-supplied constructor instead of hard-coding
                // `faer::Scale`, so `$scalar` actually takes effect (mirrors
                // how `impl_mul_scalar` uses its `$scalar`).
                let scale = $scalar(inv_rhs);
                Self::Output {
                    data: &self.data * scale,
                    context: self.context.clone(),
                }
            }
        }
    };
}

/// Implements in-place `v *= Scale<T>` for an owned vector or a mutable view.
///
/// `$col_type` is the type being scaled. `$scalar` is accepted for symmetry
/// with the other scalar macros but is not referenced by the expansion.
macro_rules! impl_mul_assign_scalar {
    ($col_type:ty, $scalar:ty) => {
        impl<'a, T: Scalar> MulAssign<Scale<T>> for $col_type {
            fn mul_assign(&mut self, rhs: Scale<T>) {
                self.data *= faer::Scale(rhs.value());
            }
        }
    };
}

96impl_mul_scalar!(FaerVec<T>, FaerVec<T>, faer::Scale<T>);
97impl_mul_scalar!(&FaerVec<T>, FaerVec<T>, faer::Scale<T>);
98impl_mul_scalar!(FaerVecRef<'_, T>, FaerVec<T>, faer::Scale<T>);
99impl_mul_scalar!(FaerVecMut<'_, T>, FaerVec<T>, faer::Scale<T>);
100impl_div_scalar!(FaerVec<T>, FaerVec<T>, faer::Scale::<T>);
101impl_mul_assign_scalar!(FaerVecMut<'a, T>, faer::Scale<T>);
102impl_mul_assign_scalar!(FaerVec<T>, faer::Scale<T>);
103
104impl_sub_assign!(FaerVec<T>, FaerVec<T>);
105impl_sub_assign!(FaerVec<T>, &FaerVec<T>);
106impl_sub_assign!(FaerVec<T>, FaerVecRef<'_, T>);
107impl_sub_assign!(FaerVec<T>, &FaerVecRef<'_, T>);
108
109impl_sub_assign!(FaerVecMut<'_, T>, FaerVec<T>);
110impl_sub_assign!(FaerVecMut<'_, T>, &FaerVec<T>);
111impl_sub_assign!(FaerVecMut<'_, T>, FaerVecRef<'_, T>);
112impl_sub_assign!(FaerVecMut<'_, T>, &FaerVecRef<'_, T>);
113
114impl_add_assign!(FaerVec<T>, FaerVec<T>);
115impl_add_assign!(FaerVec<T>, &FaerVec<T>);
116impl_add_assign!(FaerVec<T>, FaerVecRef<'_, T>);
117impl_add_assign!(FaerVec<T>, &FaerVecRef<'_, T>);
118
119impl_add_assign!(FaerVecMut<'_, T>, FaerVec<T>);
120impl_add_assign!(FaerVecMut<'_, T>, &FaerVec<T>);
121impl_add_assign!(FaerVecMut<'_, T>, FaerVecRef<'_, T>);
122impl_add_assign!(FaerVecMut<'_, T>, &FaerVecRef<'_, T>);
123
124impl_sub_both_ref!(&FaerVec<T>, &FaerVec<T>, FaerVec<T>);
125impl_sub_rhs!(&FaerVec<T>, FaerVec<T>, FaerVec<T>);
126impl_sub_both_ref!(&FaerVec<T>, FaerVecRef<'_, T>, FaerVec<T>);
127impl_sub_both_ref!(&FaerVec<T>, &FaerVecRef<'_, T>, FaerVec<T>);
128
129impl_sub_lhs!(FaerVec<T>, FaerVec<T>, FaerVec<T>);
130impl_sub_lhs!(FaerVec<T>, &FaerVec<T>, FaerVec<T>);
131impl_sub_lhs!(FaerVec<T>, FaerVecRef<'_, T>, FaerVec<T>);
132impl_sub_lhs!(FaerVec<T>, &FaerVecRef<'_, T>, FaerVec<T>);
133
134impl_sub_rhs!(FaerVecRef<'_, T>, FaerVec<T>, FaerVec<T>);
135impl_sub_both_ref!(FaerVecRef<'_, T>, &FaerVec<T>, FaerVec<T>);
136impl_sub_both_ref!(FaerVecRef<'_, T>, FaerVecRef<'_, T>, FaerVec<T>);
137impl_sub_both_ref!(FaerVecRef<'_, T>, &FaerVecRef<'_, T>, FaerVec<T>);
138
139impl_add_both_ref!(&FaerVec<T>, &FaerVec<T>, FaerVec<T>);
140impl_add_rhs!(&FaerVec<T>, FaerVec<T>, FaerVec<T>);
141impl_add_both_ref!(&FaerVec<T>, FaerVecRef<'_, T>, FaerVec<T>);
142impl_add_both_ref!(&FaerVec<T>, &FaerVecRef<'_, T>, FaerVec<T>);
143
144impl_add_lhs!(FaerVec<T>, FaerVec<T>, FaerVec<T>);
145impl_add_lhs!(FaerVec<T>, &FaerVec<T>, FaerVec<T>);
146impl_add_lhs!(FaerVec<T>, FaerVecRef<'_, T>, FaerVec<T>);
147impl_add_lhs!(FaerVec<T>, &FaerVecRef<'_, T>, FaerVec<T>);
148
149impl_add_rhs!(FaerVecRef<'_, T>, FaerVec<T>, FaerVec<T>);
150impl_add_both_ref!(FaerVecRef<'_, T>, &FaerVec<T>, FaerVec<T>);
151impl_add_both_ref!(FaerVecRef<'_, T>, FaerVecRef<'_, T>, FaerVec<T>);
152impl_add_both_ref!(FaerVecRef<'_, T>, &FaerVecRef<'_, T>, FaerVec<T>);
153
154impl_index!(FaerVec<T>);
155impl_index_mut!(FaerVec<T>);
156impl_index!(FaerVecRef<'_, T>);
157
158impl<T: Scalar> VectorHost for FaerVec<T> {
159    fn as_mut_slice(&mut self) -> &mut [Self::T] {
160        unsafe { slice::from_raw_parts_mut(self.data.as_ptr_mut(), self.len()) }
161    }
162    fn as_slice(&self) -> &[Self::T] {
163        unsafe { slice::from_raw_parts(self.data.as_ptr(), self.len()) }
164    }
165}
166
167impl<T: Scalar> Vector for FaerVec<T> {
168    type View<'a> = FaerVecRef<'a, T>;
169    type ViewMut<'a> = FaerVecMut<'a, T>;
170    type Index = FaerVecIndex;
171    fn context(&self) -> &Self::C {
172        &self.context
173    }
174    fn inner_mut(&mut self) -> &mut Self::Inner {
175        &mut self.data
176    }
177    fn len(&self) -> IndexType {
178        self.data.nrows()
179    }
180    fn get_index(&self, index: IndexType) -> Self::T {
181        self.data[index]
182    }
183    fn set_index(&mut self, index: IndexType, value: Self::T) {
184        self.data[index] = value;
185    }
186    fn norm(&self, k: i32) -> T {
187        match k {
188            1 => self.data.norm_l1(),
189            2 => self.data.norm_l2(),
190            _ => self
191                .data
192                .iter()
193                .fold(T::zero(), |acc, x| acc + x.pow(k))
194                .pow(T::one() / T::from(k as f64)),
195        }
196    }
197
198    fn squared_norm(&self, y: &Self, atol: &Self, rtol: Self::T) -> Self::T {
199        let mut acc = T::zero();
200        if y.len() != self.len() || y.len() != atol.len() {
201            panic!("Vector lengths do not match");
202        }
203        for i in 0..self.len() {
204            let yi = unsafe { y.data.get_unchecked(i) };
205            let ai = unsafe { atol.data.get_unchecked(i) };
206            let xi = unsafe { self.data.get_unchecked(i) };
207            acc += (*xi / (yi.abs() * rtol + *ai)).powi(2);
208        }
209        acc / Self::T::from(self.len() as f64)
210    }
211    fn as_view(&self) -> Self::View<'_> {
212        FaerVecRef {
213            data: self.data.as_ref(),
214            context: self.context.clone(),
215        }
216    }
217    fn as_view_mut(&mut self) -> Self::ViewMut<'_> {
218        FaerVecMut {
219            data: self.data.as_mut(),
220            context: self.context.clone(),
221        }
222    }
223    fn copy_from(&mut self, other: &Self) {
224        self.data.copy_from(&other.data)
225    }
226    fn copy_from_view(&mut self, other: &Self::View<'_>) {
227        self.data.copy_from(&other.data)
228    }
229    fn fill(&mut self, value: Self::T) {
230        self.data.iter_mut().for_each(|s| *s = value);
231    }
232    fn from_element(nstates: usize, value: Self::T, ctx: Self::C) -> Self {
233        let data = Col::from_fn(nstates, |_| value);
234        FaerVec { data, context: ctx }
235    }
236    fn from_vec(vec: Vec<Self::T>, ctx: Self::C) -> Self {
237        let data = Col::from_fn(vec.len(), |i| vec[i]);
238        FaerVec { data, context: ctx }
239    }
240    fn from_slice(slice: &[Self::T], ctx: Self::C) -> Self {
241        let data = Col::from_fn(slice.len(), |i| slice[i]);
242        FaerVec { data, context: ctx }
243    }
244    fn clone_as_vec(&self) -> Vec<Self::T> {
245        self.data.iter().cloned().collect()
246    }
247    fn zeros(nstates: usize, ctx: Self::C) -> Self {
248        Self::from_element(nstates, T::zero(), ctx)
249    }
250    fn axpy(&mut self, alpha: Self::T, x: &Self, beta: Self::T) {
251        zip!(self.data.as_mut(), x.data.as_ref())
252            .for_each(|unzip!(si, xi)| *si = *si * beta + *xi * alpha);
253    }
254    fn axpy_v(&mut self, alpha: Self::T, x: &Self::View<'_>, beta: Self::T) {
255        zip!(self.data.as_mut(), x.data).for_each(|unzip!(si, xi)| *si = *si * beta + *xi * alpha);
256    }
257    fn component_mul_assign(&mut self, other: &Self) {
258        zip!(self.data.as_mut(), other.data.as_ref()).for_each(|unzip!(s, o)| *s *= *o);
259    }
260    fn component_div_assign(&mut self, other: &Self) {
261        zip!(self.data.as_mut(), other.data.as_ref()).for_each(|unzip!(s, o)| *s /= *o);
262    }
263
264    fn root_finding(&self, g1: &Self) -> (bool, Self::T, i32) {
265        let mut max_frac = T::zero();
266        let mut max_frac_index = -1;
267        let mut found_root = false;
268        assert_eq!(self.len(), g1.len(), "Vector lengths do not match");
269        for i in 0..self.len() {
270            let g0 = unsafe { *self.data.get_unchecked(i) };
271            let g1 = unsafe { *g1.data.get_unchecked(i) };
272            if g1 == T::zero() {
273                found_root = true;
274            }
275            if g0 * g1 < T::zero() {
276                let frac = (g1 / (g1 - g0)).abs();
277                if frac > max_frac {
278                    max_frac = frac;
279                    max_frac_index = i as i32;
280                }
281            }
282        }
283        (found_root, max_frac, max_frac_index)
284    }
285
286    fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) {
287        for i in indices.data.iter() {
288            self[*i] = value;
289        }
290    }
291
292    fn copy_from_indices(&mut self, other: &Self, indices: &Self::Index) {
293        for i in indices.data.iter() {
294            self[*i] = other[*i];
295        }
296    }
297
298    fn gather(&mut self, other: &Self, indices: &Self::Index) {
299        assert_eq!(self.len(), indices.len(), "Vector lengths do not match");
300        for (s, o) in self.data.iter_mut().zip(indices.data.iter()) {
301            *s = other[*o];
302        }
303    }
304
305    fn scatter(&self, indices: &Self::Index, other: &mut Self) {
306        assert_eq!(self.len(), indices.len(), "Vector lengths do not match");
307        for (s, o) in self.data.iter().zip(indices.data.iter()) {
308            other[*o] = *s;
309        }
310    }
311}
312
313impl VectorIndex for FaerVecIndex {
314    type C = FaerContext;
315    fn zeros(len: IndexType, ctx: Self::C) -> Self {
316        Self {
317            data: vec![0; len],
318            context: ctx,
319        }
320    }
321    fn len(&self) -> IndexType {
322        self.data.len() as IndexType
323    }
324    fn from_vec(v: Vec<IndexType>, ctx: Self::C) -> Self {
325        Self {
326            data: v,
327            context: ctx,
328        }
329    }
330    fn clone_as_vec(&self) -> Vec<IndexType> {
331        self.data.clone()
332    }
333    fn context(&self) -> &Self::C {
334        &self.context
335    }
336}
337
338impl<'a, T: Scalar> VectorView<'a> for FaerVecRef<'a, T> {
339    type Owned = FaerVec<T>;
340    fn into_owned(self) -> FaerVec<T> {
341        FaerVec {
342            data: self.data.to_owned(),
343            context: self.context,
344        }
345    }
346    fn squared_norm(&self, y: &Self::Owned, atol: &Self::Owned, rtol: Self::T) -> Self::T {
347        let mut acc = T::zero();
348        if y.len() != self.data.nrows() || y.data.nrows() != atol.data.nrows() {
349            panic!("Vector lengths do not match");
350        }
351        for i in 0..self.data.nrows() {
352            let yi = unsafe { y.data.get_unchecked(i) };
353            let ai = unsafe { atol.data.get_unchecked(i) };
354            let xi = unsafe { self.data.get_unchecked(i) };
355            acc += (*xi / (yi.abs() * rtol + *ai)).powi(2);
356        }
357        acc / Self::T::from(self.data.nrows() as f64)
358    }
359}
360
361impl<'a, T: Scalar> VectorViewMut<'a> for FaerVecMut<'a, T> {
362    type Owned = FaerVec<T>;
363    type View = FaerVecRef<'a, T>;
364    type Index = FaerVecIndex;
365    fn copy_from(&mut self, other: &Self::Owned) {
366        self.data.copy_from(&other.data);
367    }
368    fn copy_from_view(&mut self, other: &Self::View) {
369        self.data.copy_from(&other.data);
370    }
371    fn axpy(&mut self, alpha: Self::T, x: &Self::Owned, beta: Self::T) {
372        zip!(self.data.as_mut(), x.data.as_ref())
373            .for_each(|unzip!(si, xi)| *si = *si * beta + *xi * alpha);
374    }
375}
376
// tests
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalar::scale;

    #[test]
    fn test_mult() {
        let v = FaerVec::from_vec(vec![1.0, -2.0, 3.0], Default::default());
        let s = scale(2.0);
        let r = FaerVec::from_vec(vec![2.0, -4.0, 6.0], Default::default());
        assert_eq!(v * s, r);
    }

    #[test]
    fn test_mul_assign() {
        let mut v = FaerVec::from_vec(vec![1.0, -2.0, 3.0], Default::default());
        let s = scale(2.0);
        let r = FaerVec::from_vec(vec![2.0, -4.0, 6.0], Default::default());
        v.mul_assign(s);
        assert_eq!(v, r);
    }

    /// Both the owned-vector and the view implementation of `squared_norm`
    /// must match a reference value computed with faer directly.
    #[test]
    fn test_error_norm() {
        let v = FaerVec::from_vec(vec![1.0, -2.0, 3.0], Default::default());
        let y = FaerVec::from_vec(vec![1.0, 2.0, 3.0], Default::default());
        let atol = FaerVec::from_vec(vec![0.1, 0.2, 0.3], Default::default());
        let rtol = 0.1;
        let mut tmp = y.clone() * scale(rtol);
        tmp += &atol;
        let mut r = v.clone();
        r.component_div_assign(&tmp);
        let errorn_check = r.data.squared_norm_l2() / 3.0;

        let owned = v.squared_norm(&y, &atol, rtol);
        assert!(
            (owned - errorn_check).abs() < 1e-10,
            "{} vs {}",
            owned,
            errorn_check
        );

        // The view implementation is a separate code path; the original test
        // repeated the owned assertion here and never exercised it.
        let viewed = v.as_view().squared_norm(&y, &atol, rtol);
        assert!(
            (viewed - errorn_check).abs() < 1e-10,
            "{} vs {}",
            viewed,
            errorn_check
        );
    }

    #[test]
    fn test_root_finding() {
        super::super::tests::test_root_finding::<FaerVec<f64>>();
    }

    #[test]
    fn test_from_slice() {
        let slice = [1.0, 2.0, 3.0];
        let v = FaerVec::from_slice(&slice, Default::default());
        assert_eq!(v.clone_as_vec(), slice);
    }

    #[test]
    fn test_into() {
        let col: Col<f64> = Col::from_fn(3, |i| (i + 1) as f64);
        let v: FaerVec<f64> = col.into();
        assert_eq!(v.clone_as_vec(), vec![1.0, 2.0, 3.0]);
    }
}