1use std::ops::{Add, AddAssign, Div, Index, IndexMut, Mul, MulAssign, Sub, SubAssign};
2use std::slice;
3
4use faer::{unzip, zip, Col, ColMut, ColRef};
5
6use crate::{scalar::Scale, FaerContext, IndexType, Scalar, Vector};
7
8use crate::{FaerMat, VectorCommon, VectorHost, VectorIndex, VectorView, VectorViewMut};
9
10use super::utils::*;
11use super::DefaultDenseMatrix;
12
/// Owned dense vector backed by a faer [`Col`], together with the [`FaerContext`]
/// it was created with.
#[derive(Debug, Clone, PartialEq)]
pub struct FaerVec<T: Scalar> {
    /// Column holding the vector entries.
    pub(crate) data: Col<T>,
    /// Context carried alongside the entries (copied into views and results).
    pub(crate) context: FaerContext,
}

/// Owned list of integer indices, used to address entries of a [`FaerVec`]
/// (see `assign_at_indices`, `gather`, `scatter`).
#[derive(Debug, Clone, PartialEq)]
pub struct FaerVecIndex {
    /// The indices themselves, in order.
    pub(crate) data: Vec<IndexType>,
    /// Context associated with this index list.
    pub(crate) context: FaerContext,
}

/// Read-only borrowed view of a faer column. Cloning copies the view, which
/// still borrows the same entries for `'a`.
#[derive(Debug, Clone, PartialEq)]
pub struct FaerVecRef<'a, T: Scalar> {
    /// Borrowed column of entries.
    pub(crate) data: ColRef<'a, T>,
    /// Context of the vector this view was taken from.
    pub(crate) context: FaerContext,
}

/// Mutable borrowed view of a faer column. Not `Clone`, since it holds a
/// mutable borrow.
#[derive(Debug, PartialEq)]
pub struct FaerVecMut<'a, T: Scalar> {
    /// Mutably borrowed column of entries.
    pub(crate) data: ColMut<'a, T>,
    /// Context of the vector this view was taken from.
    pub(crate) context: FaerContext,
}
36
37impl<T: Scalar> From<Col<T>> for FaerVec<T> {
38 fn from(data: Col<T>) -> Self {
39 Self {
40 data,
41 context: FaerContext::default(),
42 }
43 }
44}
45
/// The dense matrix type associated with `FaerVec<T>` is `FaerMat<T>`.
impl<T: Scalar> DefaultDenseMatrix for FaerVec<T> {
    type M = FaerMat<T>;
}
49
// Shared trait plumbing generated by macros from `super::utils` (not visible here);
// presumably these implement `VectorCommon` for the owned vector and both views,
// exposing the context type and the inner faer column type — confirm in `utils`.
impl_vector_common!(FaerVec<T>, FaerContext, Col<T>);
impl_vector_common_ref!(FaerVecRef<'a, T>, FaerContext, ColRef<'a, T>);
impl_vector_common_ref!(FaerVecMut<'a, T>, FaerContext, ColMut<'a, T>);
53
// Generates `$lhs * Scale<T>`, returning a freshly allocated `$out`.
// `$scalar` is the faer scale type that the crate's `Scale<T>` is converted into
// before multiplying the column.
macro_rules! impl_mul_scalar {
    ($lhs:ty, $out:ty, $scalar:ty) => {
        impl<T: Scalar> Mul<Scale<T>> for $lhs {
            type Output = $out;
            fn mul(self, rhs: Scale<T>) -> Self::Output {
                let factor: $scalar = rhs.into();
                let context = self.context.clone();
                let data = &self.data * factor;
                Self::Output { data, context }
            }
        }
    };
}
68
// Generates `$lhs / Scale<T>`, returning a freshly allocated `$out`.
//
// Every entry is divided by the scale value directly. Multiplying by a
// precomputed reciprocal (`x * (1 / s)`) rounds twice and can differ from the
// correctly rounded quotient `x / s`; for example `49.0 * (1.0 / 49.0)` is
// `0.9999999999999999`, not `1.0`.
// `$scalar` is accepted only for symmetry with `impl_mul_scalar!` and is unused.
macro_rules! impl_div_scalar {
    ($lhs:ty, $out:ty, $scalar:expr) => {
        impl<'a, T: Scalar> Div<Scale<T>> for $lhs {
            type Output = $out;
            fn div(self, rhs: Scale<T>) -> Self::Output {
                let divisor: T = rhs.value();
                let data = Col::from_fn(self.data.nrows(), |i| self.data[i] / divisor);
                Self::Output {
                    data,
                    context: self.context.clone(),
                }
            }
        }
    };
}
84
// Generates in-place `$col_type *= Scale<T>`: the wrapped faer column is scaled
// by the scale's value. `$scalar` names the faer scale type and is not otherwise used.
macro_rules! impl_mul_assign_scalar {
    ($col_type:ty, $scalar:ty) => {
        impl<'a, T: Scalar> MulAssign<Scale<T>> for $col_type {
            fn mul_assign(&mut self, rhs: Scale<T>) {
                let factor = rhs.value();
                self.data *= faer::Scale(factor);
            }
        }
    };
}
95
// Scalar arithmetic:
// - `*` for owned, borrowed, and view operands, always yielding a new owned `FaerVec`;
// - `/` only for an owned `FaerVec`;
// - `*=` for the owned vector and the mutable view.
impl_mul_scalar!(FaerVec<T>, FaerVec<T>, faer::Scale<T>);
impl_mul_scalar!(&FaerVec<T>, FaerVec<T>, faer::Scale<T>);
impl_mul_scalar!(FaerVecRef<'_, T>, FaerVec<T>, faer::Scale<T>);
impl_mul_scalar!(FaerVecMut<'_, T>, FaerVec<T>, faer::Scale<T>);
impl_div_scalar!(FaerVec<T>, FaerVec<T>, faer::Scale::<T>);
impl_mul_assign_scalar!(FaerVecMut<'a, T>, faer::Scale<T>);
impl_mul_assign_scalar!(FaerVec<T>, faer::Scale<T>);
103
// In-place `-=` on the owned vector and on the mutable view; the right-hand side may be
// an owned vector, a borrowed vector, or a read-only view (by value or by reference).
// The macros are defined in `super::utils`.
impl_sub_assign!(FaerVec<T>, FaerVec<T>);
impl_sub_assign!(FaerVec<T>, &FaerVec<T>);
impl_sub_assign!(FaerVec<T>, FaerVecRef<'_, T>);
impl_sub_assign!(FaerVec<T>, &FaerVecRef<'_, T>);

impl_sub_assign!(FaerVecMut<'_, T>, FaerVec<T>);
impl_sub_assign!(FaerVecMut<'_, T>, &FaerVec<T>);
impl_sub_assign!(FaerVecMut<'_, T>, FaerVecRef<'_, T>);
impl_sub_assign!(FaerVecMut<'_, T>, &FaerVecRef<'_, T>);

// In-place `+=`, same operand combinations as `-=` above.
impl_add_assign!(FaerVec<T>, FaerVec<T>);
impl_add_assign!(FaerVec<T>, &FaerVec<T>);
impl_add_assign!(FaerVec<T>, FaerVecRef<'_, T>);
impl_add_assign!(FaerVec<T>, &FaerVecRef<'_, T>);

impl_add_assign!(FaerVecMut<'_, T>, FaerVec<T>);
impl_add_assign!(FaerVecMut<'_, T>, &FaerVec<T>);
impl_add_assign!(FaerVecMut<'_, T>, FaerVecRef<'_, T>);
impl_add_assign!(FaerVecMut<'_, T>, &FaerVecRef<'_, T>);

// Binary `-`, always producing an owned `FaerVec`. The macro variant follows which
// operand is an owned `FaerVec`: `_lhs` when the left one is, `_rhs` when only the
// right one is, `_both_ref` when neither is. Presumably the owned operand's storage
// is reused for the result — confirm in `super::utils`.
impl_sub_both_ref!(&FaerVec<T>, &FaerVec<T>, FaerVec<T>);
impl_sub_rhs!(&FaerVec<T>, FaerVec<T>, FaerVec<T>);
impl_sub_both_ref!(&FaerVec<T>, FaerVecRef<'_, T>, FaerVec<T>);
impl_sub_both_ref!(&FaerVec<T>, &FaerVecRef<'_, T>, FaerVec<T>);

impl_sub_lhs!(FaerVec<T>, FaerVec<T>, FaerVec<T>);
impl_sub_lhs!(FaerVec<T>, &FaerVec<T>, FaerVec<T>);
impl_sub_lhs!(FaerVec<T>, FaerVecRef<'_, T>, FaerVec<T>);
impl_sub_lhs!(FaerVec<T>, &FaerVecRef<'_, T>, FaerVec<T>);

impl_sub_rhs!(FaerVecRef<'_, T>, FaerVec<T>, FaerVec<T>);
impl_sub_both_ref!(FaerVecRef<'_, T>, &FaerVec<T>, FaerVec<T>);
impl_sub_both_ref!(FaerVecRef<'_, T>, FaerVecRef<'_, T>, FaerVec<T>);
impl_sub_both_ref!(FaerVecRef<'_, T>, &FaerVecRef<'_, T>, FaerVec<T>);

// Binary `+`, same operand combinations and macro-variant rule as `-` above.
impl_add_both_ref!(&FaerVec<T>, &FaerVec<T>, FaerVec<T>);
impl_add_rhs!(&FaerVec<T>, FaerVec<T>, FaerVec<T>);
impl_add_both_ref!(&FaerVec<T>, FaerVecRef<'_, T>, FaerVec<T>);
impl_add_both_ref!(&FaerVec<T>, &FaerVecRef<'_, T>, FaerVec<T>);

impl_add_lhs!(FaerVec<T>, FaerVec<T>, FaerVec<T>);
impl_add_lhs!(FaerVec<T>, &FaerVec<T>, FaerVec<T>);
impl_add_lhs!(FaerVec<T>, FaerVecRef<'_, T>, FaerVec<T>);
impl_add_lhs!(FaerVec<T>, &FaerVecRef<'_, T>, FaerVec<T>);

impl_add_rhs!(FaerVecRef<'_, T>, FaerVec<T>, FaerVec<T>);
impl_add_both_ref!(FaerVecRef<'_, T>, &FaerVec<T>, FaerVec<T>);
impl_add_both_ref!(FaerVecRef<'_, T>, FaerVecRef<'_, T>, FaerVec<T>);
impl_add_both_ref!(FaerVecRef<'_, T>, &FaerVecRef<'_, T>, FaerVec<T>);

// Element indexing: read access for the owned vector and the read-only view,
// write access only for the owned vector.
impl_index!(FaerVec<T>);
impl_index_mut!(FaerVec<T>);
impl_index!(FaerVecRef<'_, T>);
157
158impl<T: Scalar> VectorHost for FaerVec<T> {
159 fn as_mut_slice(&mut self) -> &mut [Self::T] {
160 unsafe { slice::from_raw_parts_mut(self.data.as_ptr_mut(), self.len()) }
161 }
162 fn as_slice(&self) -> &[Self::T] {
163 unsafe { slice::from_raw_parts(self.data.as_ptr(), self.len()) }
164 }
165}
166
167impl<T: Scalar> Vector for FaerVec<T> {
168 type View<'a> = FaerVecRef<'a, T>;
169 type ViewMut<'a> = FaerVecMut<'a, T>;
170 type Index = FaerVecIndex;
171 fn context(&self) -> &Self::C {
172 &self.context
173 }
174 fn inner_mut(&mut self) -> &mut Self::Inner {
175 &mut self.data
176 }
177 fn len(&self) -> IndexType {
178 self.data.nrows()
179 }
180 fn get_index(&self, index: IndexType) -> Self::T {
181 self.data[index]
182 }
183 fn set_index(&mut self, index: IndexType, value: Self::T) {
184 self.data[index] = value;
185 }
186 fn norm(&self, k: i32) -> T {
187 match k {
188 1 => self.data.norm_l1(),
189 2 => self.data.norm_l2(),
190 _ => self
191 .data
192 .iter()
193 .fold(T::zero(), |acc, x| acc + x.pow(k))
194 .pow(T::one() / T::from(k as f64)),
195 }
196 }
197
198 fn squared_norm(&self, y: &Self, atol: &Self, rtol: Self::T) -> Self::T {
199 let mut acc = T::zero();
200 if y.len() != self.len() || y.len() != atol.len() {
201 panic!("Vector lengths do not match");
202 }
203 for i in 0..self.len() {
204 let yi = unsafe { y.data.get_unchecked(i) };
205 let ai = unsafe { atol.data.get_unchecked(i) };
206 let xi = unsafe { self.data.get_unchecked(i) };
207 acc += (*xi / (yi.abs() * rtol + *ai)).powi(2);
208 }
209 acc / Self::T::from(self.len() as f64)
210 }
211 fn as_view(&self) -> Self::View<'_> {
212 FaerVecRef {
213 data: self.data.as_ref(),
214 context: self.context.clone(),
215 }
216 }
217 fn as_view_mut(&mut self) -> Self::ViewMut<'_> {
218 FaerVecMut {
219 data: self.data.as_mut(),
220 context: self.context.clone(),
221 }
222 }
223 fn copy_from(&mut self, other: &Self) {
224 self.data.copy_from(&other.data)
225 }
226 fn copy_from_view(&mut self, other: &Self::View<'_>) {
227 self.data.copy_from(&other.data)
228 }
229 fn fill(&mut self, value: Self::T) {
230 self.data.iter_mut().for_each(|s| *s = value);
231 }
232 fn from_element(nstates: usize, value: Self::T, ctx: Self::C) -> Self {
233 let data = Col::from_fn(nstates, |_| value);
234 FaerVec { data, context: ctx }
235 }
236 fn from_vec(vec: Vec<Self::T>, ctx: Self::C) -> Self {
237 let data = Col::from_fn(vec.len(), |i| vec[i]);
238 FaerVec { data, context: ctx }
239 }
240 fn from_slice(slice: &[Self::T], ctx: Self::C) -> Self {
241 let data = Col::from_fn(slice.len(), |i| slice[i]);
242 FaerVec { data, context: ctx }
243 }
244 fn clone_as_vec(&self) -> Vec<Self::T> {
245 self.data.iter().cloned().collect()
246 }
247 fn zeros(nstates: usize, ctx: Self::C) -> Self {
248 Self::from_element(nstates, T::zero(), ctx)
249 }
250 fn axpy(&mut self, alpha: Self::T, x: &Self, beta: Self::T) {
251 zip!(self.data.as_mut(), x.data.as_ref())
252 .for_each(|unzip!(si, xi)| *si = *si * beta + *xi * alpha);
253 }
254 fn axpy_v(&mut self, alpha: Self::T, x: &Self::View<'_>, beta: Self::T) {
255 zip!(self.data.as_mut(), x.data).for_each(|unzip!(si, xi)| *si = *si * beta + *xi * alpha);
256 }
257 fn component_mul_assign(&mut self, other: &Self) {
258 zip!(self.data.as_mut(), other.data.as_ref()).for_each(|unzip!(s, o)| *s *= *o);
259 }
260 fn component_div_assign(&mut self, other: &Self) {
261 zip!(self.data.as_mut(), other.data.as_ref()).for_each(|unzip!(s, o)| *s /= *o);
262 }
263
264 fn root_finding(&self, g1: &Self) -> (bool, Self::T, i32) {
265 let mut max_frac = T::zero();
266 let mut max_frac_index = -1;
267 let mut found_root = false;
268 assert_eq!(self.len(), g1.len(), "Vector lengths do not match");
269 for i in 0..self.len() {
270 let g0 = unsafe { *self.data.get_unchecked(i) };
271 let g1 = unsafe { *g1.data.get_unchecked(i) };
272 if g1 == T::zero() {
273 found_root = true;
274 }
275 if g0 * g1 < T::zero() {
276 let frac = (g1 / (g1 - g0)).abs();
277 if frac > max_frac {
278 max_frac = frac;
279 max_frac_index = i as i32;
280 }
281 }
282 }
283 (found_root, max_frac, max_frac_index)
284 }
285
286 fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) {
287 for i in indices.data.iter() {
288 self[*i] = value;
289 }
290 }
291
292 fn copy_from_indices(&mut self, other: &Self, indices: &Self::Index) {
293 for i in indices.data.iter() {
294 self[*i] = other[*i];
295 }
296 }
297
298 fn gather(&mut self, other: &Self, indices: &Self::Index) {
299 assert_eq!(self.len(), indices.len(), "Vector lengths do not match");
300 for (s, o) in self.data.iter_mut().zip(indices.data.iter()) {
301 *s = other[*o];
302 }
303 }
304
305 fn scatter(&self, indices: &Self::Index, other: &mut Self) {
306 assert_eq!(self.len(), indices.len(), "Vector lengths do not match");
307 for (s, o) in self.data.iter().zip(indices.data.iter()) {
308 other[*o] = *s;
309 }
310 }
311}
312
313impl VectorIndex for FaerVecIndex {
314 type C = FaerContext;
315 fn zeros(len: IndexType, ctx: Self::C) -> Self {
316 Self {
317 data: vec![0; len],
318 context: ctx,
319 }
320 }
321 fn len(&self) -> IndexType {
322 self.data.len() as IndexType
323 }
324 fn from_vec(v: Vec<IndexType>, ctx: Self::C) -> Self {
325 Self {
326 data: v,
327 context: ctx,
328 }
329 }
330 fn clone_as_vec(&self) -> Vec<IndexType> {
331 self.data.clone()
332 }
333 fn context(&self) -> &Self::C {
334 &self.context
335 }
336}
337
338impl<'a, T: Scalar> VectorView<'a> for FaerVecRef<'a, T> {
339 type Owned = FaerVec<T>;
340 fn into_owned(self) -> FaerVec<T> {
341 FaerVec {
342 data: self.data.to_owned(),
343 context: self.context,
344 }
345 }
346 fn squared_norm(&self, y: &Self::Owned, atol: &Self::Owned, rtol: Self::T) -> Self::T {
347 let mut acc = T::zero();
348 if y.len() != self.data.nrows() || y.data.nrows() != atol.data.nrows() {
349 panic!("Vector lengths do not match");
350 }
351 for i in 0..self.data.nrows() {
352 let yi = unsafe { y.data.get_unchecked(i) };
353 let ai = unsafe { atol.data.get_unchecked(i) };
354 let xi = unsafe { self.data.get_unchecked(i) };
355 acc += (*xi / (yi.abs() * rtol + *ai)).powi(2);
356 }
357 acc / Self::T::from(self.data.nrows() as f64)
358 }
359}
360
361impl<'a, T: Scalar> VectorViewMut<'a> for FaerVecMut<'a, T> {
362 type Owned = FaerVec<T>;
363 type View = FaerVecRef<'a, T>;
364 type Index = FaerVecIndex;
365 fn copy_from(&mut self, other: &Self::Owned) {
366 self.data.copy_from(&other.data);
367 }
368 fn copy_from_view(&mut self, other: &Self::View) {
369 self.data.copy_from(&other.data);
370 }
371 fn axpy(&mut self, alpha: Self::T, x: &Self::Owned, beta: Self::T) {
372 zip!(self.data.as_mut(), x.data.as_ref())
373 .for_each(|unzip!(si, xi)| *si = *si * beta + *xi * alpha);
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalar::scale;

    #[test]
    fn test_mult() {
        let v = FaerVec::from_vec(vec![1.0, -2.0, 3.0], Default::default());
        let s = scale(2.0);
        let r = FaerVec::from_vec(vec![2.0, -4.0, 6.0], Default::default());
        assert_eq!(v * s, r);
    }

    #[test]
    fn test_mul_assign() {
        let mut v = FaerVec::from_vec(vec![1.0, -2.0, 3.0], Default::default());
        let s = scale(2.0);
        let r = FaerVec::from_vec(vec![2.0, -4.0, 6.0], Default::default());
        v.mul_assign(s);
        assert_eq!(v, r);
    }

    #[test]
    fn test_div() {
        let v: FaerVec<f64> = FaerVec::from_vec(vec![2.0, -4.0, 6.0], Default::default());
        let r: FaerVec<f64> = FaerVec::from_vec(vec![1.0, -2.0, 3.0], Default::default());
        assert_eq!(v / scale(2.0), r);
    }

    #[test]
    fn test_norm() {
        let v: FaerVec<f64> = FaerVec::from_vec(vec![1.0, -2.0, 3.0], Default::default());
        assert!((v.norm(1) - 6.0).abs() < 1e-12);
        assert!((v.norm(2) - 14.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_error_norm() {
        let v = FaerVec::from_vec(vec![1.0, -2.0, 3.0], Default::default());
        let y = FaerVec::from_vec(vec![1.0, 2.0, 3.0], Default::default());
        let atol = FaerVec::from_vec(vec![0.1, 0.2, 0.3], Default::default());
        let rtol = 0.1;
        let mut tmp = y.clone() * scale(rtol);
        tmp += &atol;
        let mut r = v.clone();
        r.component_div_assign(&tmp);
        let errorn_check = r.data.squared_norm_l2() / 3.0;
        // Owned-vector implementation.
        assert!(
            (v.squared_norm(&y, &atol, rtol) - errorn_check).abs() < 1e-10,
            "{} vs {}",
            v.squared_norm(&y, &atol, rtol),
            errorn_check
        );
        // Read-only view implementation must agree with the owned one.
        assert!(
            (v.as_view().squared_norm(&y, &atol, rtol) - errorn_check).abs() < 1e-10,
            "{} vs {}",
            v.as_view().squared_norm(&y, &atol, rtol),
            errorn_check
        );
    }

    #[test]
    fn test_root_finding() {
        super::super::tests::test_root_finding::<FaerVec<f64>>();
    }

    #[test]
    fn test_from_slice() {
        let slice = [1.0, 2.0, 3.0];
        let v = FaerVec::from_slice(&slice, Default::default());
        assert_eq!(v.clone_as_vec(), slice);
    }

    #[test]
    fn test_into() {
        let col: Col<f64> = Col::from_fn(3, |i| (i + 1) as f64);
        let v: FaerVec<f64> = col.into();
        assert_eq!(v.clone_as_vec(), vec![1.0, 2.0, 3.0]);
    }
}
443}