Skip to main content

numra_core/
vector.rs

1//! Vector trait for numerical computation.
2//!
3//! This module defines the [`Vector`] trait which abstracts over different
4//! vector implementations, allowing algorithms to work with:
5//! - `Vec<S>` (standard library)
6//! - faer vectors (when using numra-linalg)
7//! - Fixed-size arrays `[S; N]`
8//!
9//! # Design Philosophy
10//!
11//! The trait is designed around BLAS-like operations that are fundamental
12//! to numerical algorithms: axpy, dot product, norms, etc.
13//!
14//! Author: Moussa Leblouba
15//! Date: 4 February 2026
16//! Modified: 2 May 2026
17
18#[cfg(not(feature = "std"))]
19use alloc::vec::Vec;
20
21use crate::Scalar;
22
23/// A vector type for numerical computation.
24///
25/// This trait provides the essential operations needed by ODE/SDE solvers
26/// and other numerical algorithms.
27///
28/// # Example
29///
30/// ```rust
31/// use numra_core::Vector;
32///
33/// fn compute_weighted_sum<V: Vector<f64>>(a: f64, x: &V, b: f64, y: &V) -> V {
34///     let mut result = x.clone();
35///     result.scale(a);
36///     result.axpy(b, y);
37///     result
38/// }
39/// ```
40pub trait Vector<S: Scalar>: Clone + Sized {
41    /// Create a zero vector of given length.
42    fn zeros(len: usize) -> Self;
43
44    /// Create a vector filled with a constant value.
45    fn fill(len: usize, value: S) -> Self;
46
47    /// Create from a slice.
48    fn from_slice(data: &[S]) -> Self;
49
50    /// Length of the vector.
51    fn len(&self) -> usize;
52
53    /// Check if empty.
54    #[inline]
55    fn is_empty(&self) -> bool {
56        self.len() == 0
57    }
58
59    /// Get element at index (panics if out of bounds).
60    fn get(&self, i: usize) -> S;
61
62    /// Set element at index.
63    fn set(&mut self, i: usize, value: S);
64
65    /// Get mutable reference to element.
66    fn get_mut(&mut self, i: usize) -> &mut S;
67
68    /// Return as a slice.
69    fn as_slice(&self) -> &[S];
70
71    /// Return as mutable slice.
72    fn as_mut_slice(&mut self) -> &mut [S];
73
74    /// Copy elements from another vector.
75    fn copy_from(&mut self, other: &Self);
76
77    // ===== BLAS-like Operations =====
78
79    /// AXPY: y = a*x + y (fundamental BLAS operation)
80    ///
81    /// This is the most important operation for numerical algorithms.
82    fn axpy(&mut self, a: S, x: &Self);
83
84    /// Scaled add: y = a*x + b*y
85    fn axpby(&mut self, a: S, x: &Self, b: S);
86
87    /// Dot product: x·y
88    fn dot(&self, other: &Self) -> S;
89
90    /// Scale in place: x *= a
91    fn scale(&mut self, a: S);
92
93    // ===== Norms =====
94
95    /// Euclidean norm: ||x||_2 = sqrt(x·x)
96    #[inline]
97    fn norm2(&self) -> S {
98        self.dot(self).sqrt()
99    }
100
101    /// Infinity norm: ||x||_∞ = max|x_i|
102    fn norm_inf(&self) -> S;
103
104    /// 1-norm: ||x||_1 = Σ|x_i|
105    fn norm1(&self) -> S;
106
107    /// Weighted RMS norm: sqrt(mean((x/w)²))
108    /// Used for error control in ODE solvers.
109    fn weighted_rms_norm(&self, weights: &Self) -> S {
110        let n = S::from_usize(self.len());
111        let mut sum = S::ZERO;
112        for i in 0..self.len() {
113            let xi = self.get(i) / weights.get(i);
114            sum += xi * xi;
115        }
116        (sum / n).sqrt()
117    }
118
119    // ===== Element-wise Operations =====
120
121    /// Element-wise absolute value in place.
122    fn abs_inplace(&mut self);
123
124    /// Element-wise maximum: `self[i] = max(self[i], other[i])`
125    fn max_elementwise(&mut self, other: &Self);
126
127    /// Element-wise minimum: `self[i] = min(self[i], other[i])`
128    fn min_elementwise(&mut self, other: &Self);
129
130    // ===== Reductions =====
131
132    /// Sum of all elements.
133    fn sum(&self) -> S;
134
135    /// Maximum element.
136    fn max_element(&self) -> S;
137
138    /// Minimum element.
139    fn min_element(&self) -> S;
140
141    // ===== Utility =====
142
143    /// Apply a function to each element.
144    fn map_inplace<F: Fn(S) -> S>(&mut self, f: F);
145}
146
147// ============================================================================
148// Implementation for Vec<S>
149// ============================================================================
150
151impl<S: Scalar> Vector<S> for Vec<S> {
152    #[inline]
153    fn zeros(len: usize) -> Self {
154        vec![S::ZERO; len]
155    }
156
157    #[inline]
158    fn fill(len: usize, value: S) -> Self {
159        vec![value; len]
160    }
161
162    #[inline]
163    fn from_slice(data: &[S]) -> Self {
164        data.to_vec()
165    }
166
167    #[inline]
168    fn len(&self) -> usize {
169        Vec::len(self)
170    }
171
172    #[inline]
173    fn get(&self, i: usize) -> S {
174        self[i]
175    }
176
177    #[inline]
178    fn set(&mut self, i: usize, value: S) {
179        self[i] = value;
180    }
181
182    #[inline]
183    fn get_mut(&mut self, i: usize) -> &mut S {
184        &mut self[i]
185    }
186
187    #[inline]
188    fn as_slice(&self) -> &[S] {
189        self
190    }
191
192    #[inline]
193    fn as_mut_slice(&mut self) -> &mut [S] {
194        self
195    }
196
197    #[inline]
198    fn copy_from(&mut self, other: &Self) {
199        self.copy_from_slice(other);
200    }
201
202    fn axpy(&mut self, a: S, x: &Self) {
203        debug_assert_eq!(self.len(), x.len());
204        for (yi, xi) in self.iter_mut().zip(x.iter()) {
205            *yi += a * *xi;
206        }
207    }
208
209    fn axpby(&mut self, a: S, x: &Self, b: S) {
210        debug_assert_eq!(self.len(), x.len());
211        for (yi, xi) in self.iter_mut().zip(x.iter()) {
212            *yi = a * *xi + b * *yi;
213        }
214    }
215
216    fn dot(&self, other: &Self) -> S {
217        debug_assert_eq!(self.len(), other.len());
218        self.iter()
219            .zip(other.iter())
220            .fold(S::ZERO, |acc, (a, b)| acc + *a * *b)
221    }
222
223    fn scale(&mut self, a: S) {
224        for x in self.iter_mut() {
225            *x *= a;
226        }
227    }
228
229    fn norm_inf(&self) -> S {
230        self.iter().fold(S::ZERO, |acc, x| acc.max(x.abs()))
231    }
232
233    fn norm1(&self) -> S {
234        self.iter().fold(S::ZERO, |acc, x| acc + x.abs())
235    }
236
237    fn abs_inplace(&mut self) {
238        for x in self.iter_mut() {
239            *x = x.abs();
240        }
241    }
242
243    fn max_elementwise(&mut self, other: &Self) {
244        debug_assert_eq!(self.len(), other.len());
245        for (yi, xi) in self.iter_mut().zip(other.iter()) {
246            *yi = yi.max(*xi);
247        }
248    }
249
250    fn min_elementwise(&mut self, other: &Self) {
251        debug_assert_eq!(self.len(), other.len());
252        for (yi, xi) in self.iter_mut().zip(other.iter()) {
253            *yi = yi.min(*xi);
254        }
255    }
256
257    fn sum(&self) -> S {
258        self.iter().fold(S::ZERO, |acc, x| acc + *x)
259    }
260
261    fn max_element(&self) -> S {
262        self.iter().fold(S::NEG_INFINITY, |acc, x| acc.max(*x))
263    }
264
265    fn min_element(&self) -> S {
266        self.iter().fold(S::INFINITY, |acc, x| acc.min(*x))
267    }
268
269    fn map_inplace<F: Fn(S) -> S>(&mut self, f: F) {
270        for x in self.iter_mut() {
271            *x = f(*x);
272        }
273    }
274}
275
276// ============================================================================
277// Tests
278// ============================================================================
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons in this module.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` lies within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    // ----- Construction -----

    #[test]
    fn test_zeros() {
        let v: Vec<f64> = Vector::zeros(5);
        assert_eq!(v.len(), 5);
        for x in &v {
            assert_eq!(*x, 0.0);
        }
    }

    #[test]
    fn test_fill() {
        let v: Vec<f64> = Vector::fill(3, 2.5);
        assert_eq!(v, vec![2.5, 2.5, 2.5]);
    }

    #[test]
    fn test_from_slice() {
        let v: Vec<f64> = Vector::from_slice(&[1.0, 2.0, 3.0]);
        assert_eq!(v, vec![1.0, 2.0, 3.0]);
    }

    // ----- Length and element access -----

    #[test]
    fn test_len_and_is_empty() {
        let empty: Vec<f64> = Vector::zeros(0);
        assert_eq!(<Vec<f64> as Vector<f64>>::len(&empty), 0);
        assert!(<Vec<f64> as Vector<f64>>::is_empty(&empty));

        let v: Vec<f64> = Vector::zeros(2);
        assert_eq!(<Vec<f64> as Vector<f64>>::len(&v), 2);
        assert!(!<Vec<f64> as Vector<f64>>::is_empty(&v));
    }

    #[test]
    fn test_get_set_get_mut() {
        let mut v: Vec<f64> = vec![1.0, 2.0, 3.0];
        assert_eq!(<Vec<f64> as Vector<f64>>::get(&v, 1), 2.0);

        <Vec<f64> as Vector<f64>>::set(&mut v, 1, 7.0);
        assert_eq!(<Vec<f64> as Vector<f64>>::get(&v, 1), 7.0);

        *<Vec<f64> as Vector<f64>>::get_mut(&mut v, 2) = -1.0;
        assert_eq!(v, vec![1.0, 7.0, -1.0]);
    }

    #[test]
    fn test_slices() {
        let mut v: Vec<f64> = vec![1.0, 2.0];
        assert_eq!(<Vec<f64> as Vector<f64>>::as_slice(&v), &[1.0, 2.0][..]);

        <Vec<f64> as Vector<f64>>::as_mut_slice(&mut v)[0] = 5.0;
        assert_eq!(v, vec![5.0, 2.0]);
    }

    #[test]
    fn test_copy_from() {
        let src: Vec<f64> = vec![1.0, 2.0, 3.0];
        let mut dst: Vec<f64> = Vector::zeros(3);
        dst.copy_from(&src);
        assert_eq!(dst, src);
    }

    // ----- BLAS-like operations -----

    #[test]
    fn test_axpy() {
        let x: Vec<f64> = vec![1.0, 2.0, 3.0];
        let mut y: Vec<f64> = vec![4.0, 5.0, 6.0];
        y.axpy(2.0, &x);
        assert_eq!(y, vec![6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_axpby() {
        let x: Vec<f64> = vec![1.0, 2.0, 3.0];
        let mut y: Vec<f64> = vec![4.0, 5.0, 6.0];
        // y = 2*x + 0.5*y = [2,4,6] + [2,2.5,3] = [4, 6.5, 9]
        y.axpby(2.0, &x, 0.5);
        assert_close(y[0], 4.0);
        assert_close(y[1], 6.5);
        assert_close(y[2], 9.0);
    }

    #[test]
    fn test_dot() {
        let x: Vec<f64> = vec![1.0, 2.0, 3.0];
        let y: Vec<f64> = vec![4.0, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert_close(x.dot(&y), 32.0);
    }

    #[test]
    fn test_scale() {
        let mut v: Vec<f64> = vec![1.0, 2.0, 3.0];
        v.scale(2.0);
        assert_eq!(v, vec![2.0, 4.0, 6.0]);
    }

    // ----- Norms -----

    #[test]
    fn test_norm2() {
        let v: Vec<f64> = vec![3.0, 4.0];
        assert_close(v.norm2(), 5.0);
    }

    #[test]
    fn test_norm_inf() {
        let v: Vec<f64> = vec![-5.0, 3.0, -1.0];
        assert_close(v.norm_inf(), 5.0);
    }

    #[test]
    fn test_norm1() {
        let v: Vec<f64> = vec![-1.0, 2.0, -3.0];
        assert_close(v.norm1(), 6.0);
    }

    #[test]
    fn test_weighted_rms_norm() {
        // y = [2, 4], w = [1, 2]
        // (y/w)^2 = [4, 4], mean = 4, sqrt = 2
        let y: Vec<f64> = vec![2.0, 4.0];
        let w: Vec<f64> = vec![1.0, 2.0];
        assert_close(y.weighted_rms_norm(&w), 2.0);
    }

    // ----- Element-wise operations -----

    #[test]
    fn test_abs_inplace() {
        let mut v: Vec<f64> = vec![-1.0, 2.0, -3.0];
        v.abs_inplace();
        assert_eq!(v, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_max_min_elementwise() {
        let other: Vec<f64> = vec![4.0, 2.0, 3.0];

        let mut hi: Vec<f64> = vec![1.0, 5.0, 3.0];
        hi.max_elementwise(&other);
        assert_eq!(hi, vec![4.0, 5.0, 3.0]);

        let mut lo: Vec<f64> = vec![1.0, 5.0, 3.0];
        lo.min_elementwise(&other);
        assert_eq!(lo, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_map_inplace() {
        let mut v: Vec<f64> = vec![1.0, 4.0, 9.0];
        v.map_inplace(|x| x.sqrt());
        assert_close(v[0], 1.0);
        assert_close(v[1], 2.0);
        assert_close(v[2], 3.0);
    }

    // ----- Reductions -----

    #[test]
    fn test_sum() {
        let v: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        assert_close(v.sum(), 10.0);
    }

    #[test]
    fn test_max_min_element() {
        let v: Vec<f64> = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert_close(v.max_element(), 9.0);
        assert_close(v.min_element(), 1.0);
    }

    #[test]
    fn test_reductions_on_empty_vector() {
        // The folds in the `Vec` impl return their seed value for empty input.
        let empty: Vec<f64> = Vector::zeros(0);
        assert_eq!(empty.sum(), 0.0);
        assert_eq!(empty.norm1(), 0.0);
        assert_eq!(empty.norm_inf(), 0.0);
        assert_eq!(empty.norm2(), 0.0);
        assert_eq!(empty.max_element(), f64::NEG_INFINITY);
        assert_eq!(empty.min_element(), f64::INFINITY);
    }
}