1use std::ops::{Add, AddAssign, Div, Index, IndexMut, Mul, MulAssign, Sub, SubAssign};
2
3use super::utils::*;
4use nalgebra::{DVector, DVectorView, DVectorViewMut, LpNorm};
5
6use crate::{IndexType, NalgebraContext, NalgebraMat, NalgebraScalar, Scalar, Scale, VectorHost};
7
8use super::{DefaultDenseMatrix, Vector, VectorCommon, VectorIndex, VectorView, VectorViewMut};
9
/// Owned vector of indices (`IndexType`) stored in a nalgebra `DVector`, together with the
/// backend context.
#[derive(Debug, Clone, PartialEq)]
pub struct NalgebraIndex {
    pub(crate) data: DVector<IndexType>,
    pub(crate) context: NalgebraContext,
}

/// Owned dense vector of scalars `T` stored in a nalgebra `DVector<T>`, together with the
/// backend context.
#[derive(Debug, Clone, PartialEq)]
pub struct NalgebraVec<T: NalgebraScalar> {
    pub(crate) data: DVector<T>,
    pub(crate) context: NalgebraContext,
}

/// Immutable view into nalgebra vector storage (a `DVectorView<'a, T>`), together with the
/// backend context.
#[derive(Debug, Clone, PartialEq)]
pub struct NalgebraVecRef<'a, T: NalgebraScalar> {
    pub(crate) data: DVectorView<'a, T>,
    pub(crate) context: NalgebraContext,
}

/// Mutable view into nalgebra vector storage (a `DVectorViewMut<'a, T>`), together with the
/// backend context. Unlike the immutable view, this type does not derive `Clone`.
#[derive(Debug, PartialEq)]
pub struct NalgebraVecMut<'a, T: NalgebraScalar> {
    pub(crate) data: DVectorViewMut<'a, T>,
    pub(crate) context: NalgebraContext,
}
33
34impl<T: NalgebraScalar> From<DVector<T>> for NalgebraVec<T> {
35 fn from(data: DVector<T>) -> Self {
36 Self {
37 data,
38 context: NalgebraContext,
39 }
40 }
41}
42
/// The dense matrix type paired with `NalgebraVec<T>` is the nalgebra-backed `NalgebraMat<T>`.
impl<T: NalgebraScalar> DefaultDenseMatrix for NalgebraVec<T> {
    type M = NalgebraMat<T>;
}

// Shared boilerplate for the owned vector and the two view types. These macros are not defined
// in this file (presumably in `super::utils` -- TODO confirm). Each call appears to pass, in
// order: the wrapper type, its context type, the wrapped nalgebra type, and the scalar bound.
impl_vector_common!(NalgebraVec<T>, NalgebraContext, DVector<T>, NalgebraScalar);
impl_vector_common_ref!(
    NalgebraVecRef<'a, T>,
    NalgebraContext,
    DVectorView<'a, T>,
    NalgebraScalar
);
impl_vector_common_ref!(
    NalgebraVecMut<'a, T>,
    NalgebraContext,
    DVectorViewMut<'a, T>,
    NalgebraScalar
);
60
/// Generates `impl Mul<Scale<T>> for $lhs`, i.e. `lhs * Scale(s)`, returning an owned `$out`.
///
/// The body multiplies a *borrow* of `self.data` by the scale factor. This lets `$lhs` be an
/// owned vector, a reference to one, or a view. `self.context` is carried into the result.
///
/// * `$lhs`    - left-hand operand type.
/// * `$out`    - result type; must be constructible as `{ data, context }`.
/// * `$scalar` - type the factor is read as (`T` at every call site in this file).
macro_rules! impl_mul_scalar {
    ($lhs:ty, $out:ty, $scalar:ty) => {
        impl<T: NalgebraScalar> Mul<Scale<T>> for $lhs {
            type Output = $out;

            #[inline]
            fn mul(self, rhs: Scale<T>) -> Self::Output {
                let factor: $scalar = rhs.value();
                let scaled = &self.data * factor;
                Self::Output {
                    data: scaled,
                    context: self.context,
                }
            }
        }
    };
}
76
/// Generates `impl Div<Scale<T>> for $lhs`, i.e. `lhs / Scale(s)`, returning an owned `$out`.
///
/// The divisor is inverted once and the data is then multiplied by that reciprocal. Each
/// element therefore becomes `x * (1 / s)` rather than `x / s`; the two can differ in the last
/// bit. `self.data` is used by value, so `$lhs` must own its `data`, or that data must be `Copy`.
///
/// * `$lhs`    - left-hand operand type.
/// * `$out`    - result type; must be constructible as `{ data, context }`.
/// * `$scalar` - type of the reciprocal (`T` at every call site in this file). It is a `ty`
///   fragment, matching `impl_mul_scalar!`.
macro_rules! impl_div_scalar {
    ($lhs:ty, $out:ty, $scalar:ty) => {
        impl<T: NalgebraScalar> Div<Scale<T>> for $lhs {
            type Output = $out;
            #[inline]
            fn div(self, rhs: Scale<T>) -> Self::Output {
                // One scalar division up front, then an element-wise multiply.
                let inv_rhs: $scalar = <$scalar>::one() / rhs.value();
                Self::Output {
                    data: self.data * inv_rhs,
                    context: self.context,
                }
            }
        }
    };
}
92
/// Generates `impl MulAssign<Scale<T>> for $col_type`, i.e. `v *= Scale(s)`, scaling
/// `self.data` in place.
///
/// * `$col_type` - the mutated type. It may mention the lifetime `'a` declared on the
///   generated impl (e.g. a mutable view).
/// * `$scalar`   - type the factor is read as (`T` at every call site in this file).
macro_rules! impl_mul_assign_scalar {
    ($col_type:ty, $scalar:ty) => {
        impl<'a, T: NalgebraScalar> MulAssign<Scale<T>> for $col_type {
            #[inline]
            fn mul_assign(&mut self, rhs: Scale<T>) {
                let factor: $scalar = rhs.value();
                self.data *= factor;
            }
        }
    };
}
104
// `v * Scale(s)` for the owned vector, a borrowed vector and both view types; every variant
// returns a new owned `NalgebraVec<T>`.
impl_mul_scalar!(NalgebraVec<T>, NalgebraVec<T>, T);
impl_mul_scalar!(&NalgebraVec<T>, NalgebraVec<T>, T);
impl_mul_scalar!(NalgebraVecRef<'_, T>, NalgebraVec<T>, T);
impl_mul_scalar!(NalgebraVecMut<'_, T>, NalgebraVec<T>, T);
// `v / Scale(s)` is only provided for the owned vector.
impl_div_scalar!(NalgebraVec<T>, NalgebraVec<T>, T);
// In-place `v *= Scale(s)` for the mutable view and the owned vector.
impl_mul_assign_scalar!(NalgebraVecMut<'a, T>, T);
impl_mul_assign_scalar!(NalgebraVec<T>, T);
112
113impl_sub_assign!(NalgebraVec<T>, NalgebraVec<T>, NalgebraScalar);
114impl_sub_assign!(NalgebraVec<T>, &NalgebraVec<T>, NalgebraScalar);
115impl_sub_assign!(NalgebraVec<T>, NalgebraVecRef<'_, T>, NalgebraScalar);
116impl_sub_assign!(NalgebraVec<T>, &NalgebraVecRef<'_, T>, NalgebraScalar);
117
118impl_sub_assign!(NalgebraVecMut<'_, T>, NalgebraVec<T>, NalgebraScalar);
119impl_sub_assign!(NalgebraVecMut<'_, T>, &NalgebraVec<T>, NalgebraScalar);
120impl_sub_assign!(NalgebraVecMut<'_, T>, NalgebraVecRef<'_, T>, NalgebraScalar);
121impl_sub_assign!(
122 NalgebraVecMut<'_, T>,
123 &NalgebraVecRef<'_, T>,
124 NalgebraScalar
125);
126
127impl_add_assign!(NalgebraVec<T>, NalgebraVec<T>, NalgebraScalar);
128impl_add_assign!(NalgebraVec<T>, &NalgebraVec<T>, NalgebraScalar);
129impl_add_assign!(NalgebraVec<T>, NalgebraVecRef<'_, T>, NalgebraScalar);
130impl_add_assign!(NalgebraVec<T>, &NalgebraVecRef<'_, T>, NalgebraScalar);
131
132impl_add_assign!(NalgebraVecMut<'_, T>, NalgebraVec<T>, NalgebraScalar);
133impl_add_assign!(NalgebraVecMut<'_, T>, &NalgebraVec<T>, NalgebraScalar);
134impl_add_assign!(NalgebraVecMut<'_, T>, NalgebraVecRef<'_, T>, NalgebraScalar);
135impl_add_assign!(
136 NalgebraVecMut<'_, T>,
137 &NalgebraVecRef<'_, T>,
138 NalgebraScalar
139);
140
141impl_sub_both_ref!(
142 &NalgebraVec<T>,
143 &NalgebraVec<T>,
144 NalgebraVec<T>,
145 NalgebraScalar
146);
147impl_sub_rhs!(
148 &NalgebraVec<T>,
149 NalgebraVec<T>,
150 NalgebraVec<T>,
151 NalgebraScalar
152);
153impl_sub_both_ref!(
154 &NalgebraVec<T>,
155 NalgebraVecRef<'_, T>,
156 NalgebraVec<T>,
157 NalgebraScalar
158);
159impl_sub_both_ref!(
160 &NalgebraVec<T>,
161 &NalgebraVecRef<'_, T>,
162 NalgebraVec<T>,
163 NalgebraScalar
164);
165
166impl_sub_lhs!(
167 NalgebraVec<T>,
168 NalgebraVec<T>,
169 NalgebraVec<T>,
170 NalgebraScalar
171);
172impl_sub_lhs!(
173 NalgebraVec<T>,
174 &NalgebraVec<T>,
175 NalgebraVec<T>,
176 NalgebraScalar
177);
178impl_sub_lhs!(
179 NalgebraVec<T>,
180 NalgebraVecRef<'_, T>,
181 NalgebraVec<T>,
182 NalgebraScalar
183);
184impl_sub_lhs!(
185 NalgebraVec<T>,
186 &NalgebraVecRef<'_, T>,
187 NalgebraVec<T>,
188 NalgebraScalar
189);
190
191impl_sub_rhs!(
192 NalgebraVecRef<'_, T>,
193 NalgebraVec<T>,
194 NalgebraVec<T>,
195 NalgebraScalar
196);
197impl_sub_both_ref!(
198 NalgebraVecRef<'_, T>,
199 &NalgebraVec<T>,
200 NalgebraVec<T>,
201 NalgebraScalar
202);
203impl_sub_both_ref!(
204 NalgebraVecRef<'_, T>,
205 NalgebraVecRef<'_, T>,
206 NalgebraVec<T>,
207 NalgebraScalar
208);
209impl_sub_both_ref!(
210 NalgebraVecRef<'_, T>,
211 &NalgebraVecRef<'_, T>,
212 NalgebraVec<T>,
213 NalgebraScalar
214);
215
216impl_add_both_ref!(
217 &NalgebraVec<T>,
218 &NalgebraVec<T>,
219 NalgebraVec<T>,
220 NalgebraScalar
221);
222impl_add_rhs!(
223 &NalgebraVec<T>,
224 NalgebraVec<T>,
225 NalgebraVec<T>,
226 NalgebraScalar
227);
228impl_add_both_ref!(
229 &NalgebraVec<T>,
230 NalgebraVecRef<'_, T>,
231 NalgebraVec<T>,
232 NalgebraScalar
233);
234impl_add_both_ref!(
235 &NalgebraVec<T>,
236 &NalgebraVecRef<'_, T>,
237 NalgebraVec<T>,
238 NalgebraScalar
239);
240
241impl_add_lhs!(
242 NalgebraVec<T>,
243 NalgebraVec<T>,
244 NalgebraVec<T>,
245 NalgebraScalar
246);
247impl_add_lhs!(
248 NalgebraVec<T>,
249 &NalgebraVec<T>,
250 NalgebraVec<T>,
251 NalgebraScalar
252);
253impl_add_lhs!(
254 NalgebraVec<T>,
255 NalgebraVecRef<'_, T>,
256 NalgebraVec<T>,
257 NalgebraScalar
258);
259impl_add_lhs!(
260 NalgebraVec<T>,
261 &NalgebraVecRef<'_, T>,
262 NalgebraVec<T>,
263 NalgebraScalar
264);
265
266impl_add_rhs!(
267 NalgebraVecRef<'_, T>,
268 NalgebraVec<T>,
269 NalgebraVec<T>,
270 NalgebraScalar
271);
272impl_add_both_ref!(
273 NalgebraVecRef<'_, T>,
274 &NalgebraVec<T>,
275 NalgebraVec<T>,
276 NalgebraScalar
277);
278impl_add_both_ref!(
279 NalgebraVecRef<'_, T>,
280 NalgebraVecRef<'_, T>,
281 NalgebraVec<T>,
282 NalgebraScalar
283);
284impl_add_both_ref!(
285 NalgebraVecRef<'_, T>,
286 &NalgebraVecRef<'_, T>,
287 NalgebraVec<T>,
288 NalgebraScalar
289);
290
// Element indexing via macros not defined in this file (presumably `Index` / `IndexMut` --
// TODO confirm). The owned vector gets both; the immutable view gets only `impl_index!`.
impl_index!(NalgebraVec<T>, NalgebraScalar);
impl_index_mut!(NalgebraVec<T>, NalgebraScalar);

impl_index!(NalgebraVecRef<'_, T>, NalgebraScalar);
295
296impl VectorIndex for NalgebraIndex {
297 type C = NalgebraContext;
298 fn zeros(len: IndexType, ctx: Self::C) -> Self {
299 let data = DVector::from_element(len, 0);
300 Self { data, context: ctx }
301 }
302 fn len(&self) -> crate::IndexType {
303 self.data.len()
304 }
305 fn from_vec(v: Vec<IndexType>, ctx: Self::C) -> Self {
306 let data = DVector::from_vec(v);
307 Self { data, context: ctx }
308 }
309 fn clone_as_vec(&self) -> Vec<IndexType> {
310 self.data.iter().copied().collect()
311 }
312 fn context(&self) -> &Self::C {
313 &self.context
314 }
315}
316
317impl<'a, T: NalgebraScalar> VectorView<'a> for NalgebraVecRef<'a, T> {
318 type Owned = NalgebraVec<T>;
319
320 fn into_owned(self) -> Self::Owned {
321 Self::Owned {
322 data: self.data.into_owned(),
323 context: self.context,
324 }
325 }
326 fn squared_norm(&self, y: &Self::Owned, atol: &Self::Owned, rtol: Self::T) -> Self::T {
327 let mut acc = T::zero();
328 if y.len() != self.data.len() || y.len() != atol.len() {
329 panic!("Vector lengths do not match");
330 }
331 for i in 0..self.data.len() {
332 let yi = unsafe { y.data.get_unchecked(i) };
333 let ai = unsafe { atol.data.get_unchecked(i) };
334 let xi = unsafe { self.data.get_unchecked(i) };
335 let term = *xi / (yi.abs() * rtol + *ai);
336 acc += term * term;
337 }
338 acc / Self::T::from_f64(self.data.len() as f64).unwrap()
339 }
340}
341
342impl<'a, T: NalgebraScalar> VectorViewMut<'a> for NalgebraVecMut<'a, T> {
343 type Owned = NalgebraVec<T>;
344 type View = NalgebraVecRef<'a, T>;
345 type Index = NalgebraIndex;
346 fn copy_from(&mut self, other: &Self::Owned) {
347 self.data.copy_from(&other.data);
348 }
349 fn copy_from_view(&mut self, other: &Self::View) {
350 self.data.copy_from(&other.data);
351 }
352 fn axpy(&mut self, alpha: Self::T, x: &Self::Owned, beta: Self::T) {
353 self.data.axpy(alpha, &x.data, beta);
354 }
355}
356
/// Host-memory access: the elements of the wrapped `DVector` are exposed directly as a slice.
impl<T: NalgebraScalar> VectorHost for NalgebraVec<T> {
    /// Borrows all elements as a slice.
    fn as_slice(&self) -> &[Self::T] {
        self.data.as_slice()
    }
    /// Mutably borrows all elements as a slice.
    fn as_mut_slice(&mut self) -> &mut [Self::T] {
        self.data.as_mut_slice()
    }
}
365
366impl<T: NalgebraScalar> Vector for NalgebraVec<T> {
367 type View<'a> = NalgebraVecRef<'a, T>;
368 type ViewMut<'a> = NalgebraVecMut<'a, T>;
369 type Index = NalgebraIndex;
370 fn len(&self) -> IndexType {
371 self.data.len()
372 }
373 fn inner_mut(&mut self) -> &mut Self::Inner {
374 &mut self.data
375 }
376 fn context(&self) -> &Self::C {
377 &self.context
378 }
379 fn norm(&self, k: i32) -> Self::T {
380 self.data.apply_norm(&LpNorm(k))
381 }
382 fn get_index(&self, index: IndexType) -> Self::T {
383 self.data[index]
384 }
385 fn set_index(&mut self, index: IndexType, value: Self::T) {
386 self.data[index] = value;
387 }
388 fn squared_norm(&self, y: &Self, atol: &Self, rtol: Self::T) -> Self::T {
389 let mut acc = T::zero();
390 if y.len() != self.len() || y.len() != atol.len() {
391 panic!("Vector lengths do not match");
392 }
393 for i in 0..self.len() {
394 let yi = unsafe { y.data.get_unchecked(i) };
395 let ai = unsafe { atol.data.get_unchecked(i) };
396 let xi = unsafe { self.data.get_unchecked(i) };
397 let term = *xi / (yi.abs() * rtol + *ai);
398 acc += term * term;
399 }
400 acc / Self::T::from_f64(self.len() as f64).unwrap()
401 }
402 fn as_view(&self) -> Self::View<'_> {
403 Self::View {
404 data: self.data.as_view(),
405 context: self.context,
406 }
407 }
408 fn as_view_mut(&mut self) -> Self::ViewMut<'_> {
409 Self::ViewMut {
410 data: self.data.as_view_mut(),
411 context: self.context,
412 }
413 }
414 fn copy_from(&mut self, other: &Self) {
415 self.data.copy_from(&other.data);
416 }
417 fn fill(&mut self, value: Self::T) {
418 self.data.iter_mut().for_each(|x: &mut _| *x = value);
419 }
420 fn copy_from_view(&mut self, other: &Self::View<'_>) {
421 self.data.copy_from(&other.data);
422 }
423 fn from_element(nstates: usize, value: T, ctx: Self::C) -> Self {
424 let data = DVector::from_element(nstates, value);
425 Self { data, context: ctx }
426 }
427 fn from_vec(vec: Vec<T>, ctx: Self::C) -> Self {
428 let data = DVector::from_vec(vec);
429 Self { data, context: ctx }
430 }
431 fn from_slice(slice: &[T], ctx: Self::C) -> Self {
432 let data = DVector::from_column_slice(slice);
433 Self { data, context: ctx }
434 }
435 fn clone_as_vec(&self) -> Vec<Self::T> {
436 self.data.iter().copied().collect()
437 }
438 fn zeros(nstates: usize, ctx: Self::C) -> Self {
439 let data = DVector::zeros(nstates);
440 Self { data, context: ctx }
441 }
442 fn axpy(&mut self, alpha: T, x: &Self, beta: T) {
443 self.data.axpy(alpha, &x.data, beta);
444 }
445 fn axpy_v(&mut self, alpha: Self::T, x: &Self::View<'_>, beta: Self::T) {
446 self.data.axpy(alpha, &x.data, beta);
447 }
448 fn component_div_assign(&mut self, other: &Self) {
449 self.data.component_div_assign(&other.data);
450 }
451 fn component_mul_assign(&mut self, other: &Self) {
452 self.data.component_mul_assign(&other.data);
453 }
454
455 fn root_finding(&self, g1: &Self) -> (bool, Self::T, i32) {
456 let mut max_frac = T::zero();
457 let mut max_frac_index = -1;
458 let mut found_root = false;
459 assert_eq!(self.len(), g1.len(), "Vector lengths do not match");
460 for i in 0..self.len() {
461 let g0 = unsafe { *self.data.get_unchecked(i) };
462 let g1 = unsafe { *g1.data.get_unchecked(i) };
463 if g1 == T::zero() {
464 found_root = true;
465 }
466 if g0 * g1 < T::zero() {
467 let frac = (g1 / (g1 - g0)).abs();
468 if frac > max_frac {
469 max_frac = frac;
470 max_frac_index = i as i32;
471 }
472 }
473 }
474 (found_root, max_frac, max_frac_index)
475 }
476
477 fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) {
478 for i in indices.data.iter() {
479 self[*i] = value;
480 }
481 }
482
483 fn copy_from_indices(&mut self, other: &Self, indices: &Self::Index) {
484 for i in indices.data.iter() {
485 self[*i] = other[*i];
486 }
487 }
488
489 fn gather(&mut self, other: &Self, indices: &Self::Index) {
490 assert_eq!(self.len(), indices.len(), "Vector lengths do not match");
491 for (s, o) in self.data.iter_mut().zip(indices.data.iter()) {
492 *s = other[*o];
493 }
494 }
495
496 fn scatter(&self, indices: &Self::Index, other: &mut Self) {
497 assert_eq!(self.len(), indices.len(), "Vector lengths do not match");
498 for (s, o) in self.data.iter().zip(indices.data.iter()) {
499 other[*o] = *s;
500 }
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// `squared_norm` on the owned vector and on its view must both match a reference value
    /// built from the arithmetic operators.
    #[test]
    fn test_error_norm() {
        let v = NalgebraVec::from_vec(vec![1.0, -2.0, 3.0], Default::default());
        let y = NalgebraVec::from_vec(vec![1.0, 2.0, 3.0], Default::default());
        let atol = NalgebraVec::from_vec(vec![0.1, 0.2, 0.3], Default::default());
        let rtol = 0.1;

        // Reference: divide v by the per-element tolerance y * rtol + atol (y > 0 here, so
        // |y| = y), then take the mean of the squares.
        let mut tolerance = y.clone() * Scale(rtol);
        tolerance += &atol;
        let mut scaled = v.clone();
        scaled.component_div_assign(&tolerance);
        let expected = scaled.data.norm_squared() / 3.0;

        assert_eq!(v.squared_norm(&y, &atol, rtol), expected);
        let view = v.as_view();
        assert_eq!(
            VectorView::squared_norm(&view, &y, &atol, rtol),
            expected
        );
    }

    /// Runs the backend-agnostic root-finding checks against this backend.
    #[test]
    fn test_root_finding() {
        super::super::tests::test_root_finding::<NalgebraVec<f64>>();
    }

    #[test]
    fn test_from_slice() {
        let values = [1.0, 2.0, 3.0];
        let v = NalgebraVec::from_slice(&values, Default::default());
        assert_eq!(v.clone_as_vec(), values);
    }

    #[test]
    fn test_into() {
        let raw = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let v: NalgebraVec<f64> = raw.into();
        assert_eq!(v.clone_as_vec(), vec![1.0, 2.0, 3.0]);
    }
}