ring_math/
vector.rs

1use scalarff::BigUint;
2use scalarff::FieldElement;
3
4#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
5#[derive(Clone, PartialEq)]
6pub struct Vector<T: FieldElement>(pub Vec<T>);
7
8impl<T: FieldElement> Default for Vector<T> {
9    fn default() -> Self {
10        Self::new()
11    }
12}
13
14impl<T: FieldElement> Vector<T> {
15    pub fn new() -> Self {
16        Vector(Vec::new())
17    }
18
19    pub fn zero(len: usize) -> Self {
20        Self(vec![T::zero(); len])
21    }
22
23    /// Compute the inner product (dot product) of two vectors.
24    /// Vectors are multiplied element-wise and then summed.
25    pub fn dot_product(&self, other: Vector<T>) -> T {
26        let mut out = T::zero();
27        for (a, b) in std::iter::zip(self.iter(), other.iter()) {
28            out += a.clone() * b.clone();
29        }
30        out
31    }
32
33    /// Calculate the l1 norm for this vector. That is
34    /// the summation of all coefficients
35    pub fn norm_l1(&self) -> BigUint {
36        self.0
37            .iter()
38            .fold(BigUint::from(0u32), |acc, x| acc + x.to_biguint())
39    }
40
41    /// Calculate the l2 norm for this vector. That is
42    /// the square root of the summation of each coefficient squared
43    ///
44    /// Specifically, we're calculating the square root in the integer
45    /// field, not the prime field
46    pub fn norm_l2(&self) -> BigUint {
47        self.0
48            .iter()
49            .fold(BigUint::from(0u32), |acc, x| {
50                acc + (x.to_biguint() * x.to_biguint())
51            })
52            .sqrt()
53    }
54
55    /// Calculate the l-infinity norm for this vector. That is
56    /// the largest coefficient
57    pub fn norm_max(&self) -> BigUint {
58        let mut max = T::zero().to_biguint();
59        for i in &self.0 {
60            if i.to_biguint() > max {
61                max = i.to_biguint();
62            }
63        }
64        max
65    }
66
67    /// Sum the elements of the array and return the result.
68    pub fn sum(&self) -> T {
69        self.0.iter().fold(T::zero(), |acc, x| acc + x.clone())
70    }
71
72    /// Sample a uniform random vector of the specified dimension
73    /// from the underlying field.
74    #[cfg(feature = "rand")]
75    pub fn sample_uniform<R: rand::Rng>(len: usize, rng: &mut R) -> Self {
76        Self((0..len).map(|_| T::sample_uniform(rng)).collect())
77    }
78
79    pub fn from_vec(v: Vec<T>) -> Self {
80        Vector(v)
81    }
82
83    pub fn to_vec(&self) -> Vec<T> {
84        self.0.clone()
85    }
86
87    pub fn to_vec_ref(&self) -> &Vec<T> {
88        &self.0
89    }
90
91    pub fn len(&self) -> usize {
92        self.0.len()
93    }
94
95    pub fn is_empty(&self) -> bool {
96        self.0.is_empty()
97    }
98
99    pub fn push(&mut self, v: T) {
100        self.0.push(v);
101    }
102
103    pub fn iter(&self) -> std::slice::Iter<T> {
104        self.0.iter()
105    }
106
107    pub fn iter_mut(&mut self) -> std::slice::IterMut<T> {
108        self.0.iter_mut()
109    }
110}
111
112impl<T: FieldElement> std::fmt::Display for Vector<T> {
113    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
114        for v in &self.0 {
115            write!(f, "{}, ", v)?;
116        }
117        Ok(())
118    }
119}
120
121impl<T: FieldElement> std::ops::Index<std::ops::Range<usize>> for Vector<T> {
122    type Output = [T];
123
124    fn index(&self, index: std::ops::Range<usize>) -> &[T] {
125        &self.0[index]
126    }
127}
128
129impl<T: FieldElement> std::ops::Index<usize> for Vector<T> {
130    type Output = T;
131
132    fn index(&self, index: usize) -> &T {
133        &self.0[index]
134    }
135}
136
137impl<T: FieldElement> std::ops::Mul<Vector<T>> for Vector<T> {
138    type Output = Vector<T>;
139
140    fn mul(self, other: Vector<T>) -> Vector<T> {
141        assert_eq!(self.0.len(), other.len(), "vector mul length mismatch");
142        let mut out = Vec::new();
143        for i in 0..self.len() {
144            out.push(self.to_vec_ref()[i].clone() * other.to_vec_ref()[i].clone());
145        }
146        Vector::from_vec(out)
147    }
148}
149
150impl<T: FieldElement> std::ops::Add<Vector<T>> for Vector<T> {
151    type Output = Vector<T>;
152
153    fn add(self, other: Vector<T>) -> Vector<T> {
154        assert_eq!(self.0.len(), other.len(), "vector mul length mismatch");
155        let mut out = Vec::new();
156        for i in 0..self.len() {
157            out.push(self.to_vec_ref()[i].clone() + other.to_vec_ref()[i].clone());
158        }
159        Vector::from_vec(out)
160    }
161}
162
163impl<T: FieldElement> std::ops::Mul<T> for Vector<T> {
164    type Output = Vector<T>;
165
166    fn mul(self, other: T) -> Vector<T> {
167        Vector::from_vec(self.iter().map(|v| v.clone() * other.clone()).collect())
168    }
169}