1use scalarff::BigUint;
2use scalarff::FieldElement;
3
/// A vector of field elements of type `T`, backed by a plain `Vec<T>`.
///
/// The inner `Vec` is public, so callers may read or build it directly.
/// With the `serde` feature enabled the type also derives
/// `serde::Serialize` / `serde::Deserialize`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, PartialEq)]
pub struct Vector<T: FieldElement>(pub Vec<T>);
7
8impl<T: FieldElement> Default for Vector<T> {
9 fn default() -> Self {
10 Self::new()
11 }
12}
13
14impl<T: FieldElement> Vector<T> {
15 pub fn new() -> Self {
16 Vector(Vec::new())
17 }
18
19 pub fn zero(len: usize) -> Self {
20 Self(vec![T::zero(); len])
21 }
22
23 pub fn dot_product(&self, other: Vector<T>) -> T {
26 let mut out = T::zero();
27 for (a, b) in std::iter::zip(self.iter(), other.iter()) {
28 out += a.clone() * b.clone();
29 }
30 out
31 }
32
33 pub fn norm_l1(&self) -> BigUint {
36 self.0
37 .iter()
38 .fold(BigUint::from(0u32), |acc, x| acc + x.to_biguint())
39 }
40
41 pub fn norm_l2(&self) -> BigUint {
47 self.0
48 .iter()
49 .fold(BigUint::from(0u32), |acc, x| {
50 acc + (x.to_biguint() * x.to_biguint())
51 })
52 .sqrt()
53 }
54
55 pub fn norm_max(&self) -> BigUint {
58 let mut max = T::zero().to_biguint();
59 for i in &self.0 {
60 if i.to_biguint() > max {
61 max = i.to_biguint();
62 }
63 }
64 max
65 }
66
67 pub fn sum(&self) -> T {
69 self.0.iter().fold(T::zero(), |acc, x| acc + x.clone())
70 }
71
72 #[cfg(feature = "rand")]
75 pub fn sample_uniform<R: rand::Rng>(len: usize, rng: &mut R) -> Self {
76 Self((0..len).map(|_| T::sample_uniform(rng)).collect())
77 }
78
79 pub fn from_vec(v: Vec<T>) -> Self {
80 Vector(v)
81 }
82
83 pub fn to_vec(&self) -> Vec<T> {
84 self.0.clone()
85 }
86
87 pub fn to_vec_ref(&self) -> &Vec<T> {
88 &self.0
89 }
90
91 pub fn len(&self) -> usize {
92 self.0.len()
93 }
94
95 pub fn is_empty(&self) -> bool {
96 self.0.is_empty()
97 }
98
99 pub fn push(&mut self, v: T) {
100 self.0.push(v);
101 }
102
103 pub fn iter(&self) -> std::slice::Iter<T> {
104 self.0.iter()
105 }
106
107 pub fn iter_mut(&mut self) -> std::slice::IterMut<T> {
108 self.0.iter_mut()
109 }
110}
111
112impl<T: FieldElement> std::fmt::Display for Vector<T> {
113 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
114 for v in &self.0 {
115 write!(f, "{}, ", v)?;
116 }
117 Ok(())
118 }
119}
120
121impl<T: FieldElement> std::ops::Index<std::ops::Range<usize>> for Vector<T> {
122 type Output = [T];
123
124 fn index(&self, index: std::ops::Range<usize>) -> &[T] {
125 &self.0[index]
126 }
127}
128
129impl<T: FieldElement> std::ops::Index<usize> for Vector<T> {
130 type Output = T;
131
132 fn index(&self, index: usize) -> &T {
133 &self.0[index]
134 }
135}
136
137impl<T: FieldElement> std::ops::Mul<Vector<T>> for Vector<T> {
138 type Output = Vector<T>;
139
140 fn mul(self, other: Vector<T>) -> Vector<T> {
141 assert_eq!(self.0.len(), other.len(), "vector mul length mismatch");
142 let mut out = Vec::new();
143 for i in 0..self.len() {
144 out.push(self.to_vec_ref()[i].clone() * other.to_vec_ref()[i].clone());
145 }
146 Vector::from_vec(out)
147 }
148}
149
150impl<T: FieldElement> std::ops::Add<Vector<T>> for Vector<T> {
151 type Output = Vector<T>;
152
153 fn add(self, other: Vector<T>) -> Vector<T> {
154 assert_eq!(self.0.len(), other.len(), "vector mul length mismatch");
155 let mut out = Vec::new();
156 for i in 0..self.len() {
157 out.push(self.to_vec_ref()[i].clone() + other.to_vec_ref()[i].clone());
158 }
159 Vector::from_vec(out)
160 }
161}
162
163impl<T: FieldElement> std::ops::Mul<T> for Vector<T> {
164 type Output = Vector<T>;
165
166 fn mul(self, other: T) -> Vector<T> {
167 Vector::from_vec(self.iter().map(|v| v.clone() * other.clone()).collect())
168 }
169}