1use {
7 core::ops::{Add, Div, Mul, Neg, Sub},
8 derive_more::From,
9 generic_array::{
10 typenum::{
11 operator_aliases::{Add1, Maximum},
12 Diff, Max, Sum, B1,
13 },
14 ArrayLength, GenericArray,
15 },
16};
17
/// Scalar type usable as a polynomial coefficient.
///
/// Implementors provide the four arithmetic operators and negation (all closed
/// over `Self`), the additive and multiplicative identities, and an equality
/// test. The primitive integer implementations use Rust's truncating integer
/// division, so they only behave like a field when every division performed is
/// exact.
pub trait Field
where
    Self: Copy
        + Add<Self, Output = Self>
        + Mul<Self, Output = Self>
        + Neg<Output = Self>
        + Sub<Self, Output = Self>
        + Div<Self, Output = Self>,
{
    /// Multiplicative identity.
    const ONE: Self;
    /// Additive identity.
    const ZERO: Self;
    /// Returns `true` when `self` is equal to `rhs`.
    fn equals(&self, rhs: Self) -> bool;
}

// Implements `Field` for primitive numeric types, using `==` for `equals`.
//
// `impl_field!(T1, T2, ...)` uses the literals `1` / `0` as identities (integer types).
// `impl_field!(@identities ONE, ZERO; T1, ...)` supplies the identity literals explicitly,
// which is needed for types such as floats where `1` is not a valid literal.
macro_rules! impl_field {
    (@identities $one:literal, $zero:literal; $($type:ty),+ $(,)?) => {
        $(
            impl Field for $type {
                const ONE: $type = $one;
                const ZERO: $type = $zero;
                fn equals(&self, rhs: $type) -> bool {
                    *self == rhs
                }
            }
        )+
    };
    ($($type:ty),+ $(,)?) => {
        impl_field!(@identities 1, 0; $($type),+);
    };
}

impl_field!(isize, i8, i16, i32, i64, i128);
impl_field!(@identities 1.0, 0.0; f32, f64);
48
49#[derive(Clone, Debug, Default, From, PartialEq)]
50pub struct Polynomial<F: Field, N: ArrayLength<F> + Add<B1>>(GenericArray<F, Add1<N>>)
51where
52 Add1<N>: ArrayLength<F>;
53
54impl<F: Field, N: ArrayLength<F> + Add<B1>> Polynomial<F, N>
55where
56 Add1<N>: ArrayLength<F>,
57{
58 pub fn deg(&self) -> usize {
59 self.0
60 .iter()
61 .enumerate()
62 .rev()
63 .find(|(_, a)| !a.equals(F::ZERO))
64 .map(|(idx, _)| idx)
65 .unwrap_or(0)
66 }
67
68 pub fn leading_coefficient(&self) -> F {
69 *self
70 .0
71 .iter()
72 .enumerate()
73 .find(|(index, _)| *index == self.deg())
74 .map(|(_, value)| value)
75 .unwrap_or(&F::ZERO)
76 }
77}
78
79impl<F: Field, N: ArrayLength<F> + Add<B1>> Neg for Polynomial<F, N>
80where
81 Add1<N>: ArrayLength<F>,
82{
83 type Output = Self;
84
85 fn neg(self) -> Self::Output {
86 GenericArray::from(self.0.iter().map(|a| -*a).collect()).into()
87 }
88}
89
90impl<F: Field, N: ArrayLength<F> + Add<B1> + Max<M>, M: ArrayLength<F> + Add<B1>>
91 Add<Polynomial<F, M>> for Polynomial<F, N>
92where
93 Add1<N>: ArrayLength<F>,
94 Add1<M>: ArrayLength<F>,
95 Maximum<N, M>: ArrayLength<F> + Add<B1>,
96 Add1<Maximum<N, M>>: ArrayLength<F>,
97{
98 type Output = Polynomial<F, Maximum<N, M>>;
99
100 fn add(self, rhs: Polynomial<F, M>) -> Self::Output {
101 GenericArray::from(
102 (0..=usize::max(N::USIZE, M::USIZE))
103 .map(|i| *self.0.get(i).unwrap_or(&F::ZERO) + *rhs.0.get(i).unwrap_or(&F::ZERO))
104 .collect(),
105 )
106 .into()
107 }
108}
109
110impl<F: Field, N: ArrayLength<F> + Add<B1> + Max<M>, M: ArrayLength<F> + Add<B1>>
111 Sub<Polynomial<F, M>> for Polynomial<F, N>
112where
113 Add1<N>: ArrayLength<F>,
114 Add1<M>: ArrayLength<F>,
115 Maximum<N, M>: ArrayLength<F> + Add<B1>,
116 Add1<Maximum<N, M>>: ArrayLength<F>,
117{
118 type Output = Polynomial<F, Maximum<N, M>>;
119
120 fn sub(self, rhs: Polynomial<F, M>) -> Self::Output {
121 GenericArray::from(
122 (0..=usize::max(N::USIZE, M::USIZE))
123 .map(|i| *self.0.get(i).unwrap_or(&F::ZERO) - *rhs.0.get(i).unwrap_or(&F::ZERO))
124 .collect(),
125 )
126 .into()
127 }
128}
129
130impl<F: Field, N: ArrayLength<F> + Add<M> + Add<B1>, M: ArrayLength<F> + Add<B1>>
131 Mul<Polynomial<F, M>> for Polynomial<F, N>
132where
133 Add1<N>: ArrayLength<F>,
134 Add1<M>: ArrayLength<F>,
135 Sum<N, M>: ArrayLength<F> + Add<B1>,
136 Add1<Sum<N, M>>: ArrayLength<F>,
137{
138 type Output = Polynomial<F, Sum<N, M>>;
139
140 fn mul(self, rhs: Polynomial<F, M>) -> Self::Output {
141 use std::iter::repeat;
142 self.0
143 .iter()
144 .enumerate()
145 .flat_map(|(i, p)| rhs.0.iter().enumerate().map(move |(j, q)| (i + j, *p * *q)))
146 .fold(
147 GenericArray::from(repeat(F::ZERO).take(N::USIZE + M::USIZE + 1).collect()),
148 |acc, (k, c)| {
149 GenericArray::from(
150 acc.iter()
151 .enumerate()
152 .map(|(i, &v)| if i == k { v + c } else { v })
153 .collect(),
154 )
155 },
156 )
157 .into()
158 }
159}
160
161impl<
162 F: Field + PartialEq,
163 N: ArrayLength<F> + Add<B1> + Sub<M> + Max<Sum<Diff<N, M>, M>, Output = N>,
164 M: ArrayLength<F> + Add<B1>,
165 > Div<Polynomial<F, M>> for Polynomial<F, N>
166where
167 Add1<N>: ArrayLength<F>,
168 Add1<M>: ArrayLength<F>,
169 Diff<N, M>: ArrayLength<F> + Add<B1> + Add<M> + Max<Output = Diff<N, M>> + Add<Diff<N, M>>,
170 Add1<Diff<N, M>>: ArrayLength<F>,
171 Sum<Diff<N, M>, M>: ArrayLength<F> + Add<B1>,
172 Add1<Sum<Diff<N, M>, M>>: ArrayLength<F>,
173 Maximum<N, Sum<Diff<N, M>, M>>: ArrayLength<F> + Add<B1>,
174 Add1<Maximum<N, Sum<Diff<N, M>, M>>>: ArrayLength<F>,
175{
176 type Output = Result<Polynomial<F, Diff<N, M>>, ()>;
177
178 fn div(self, rhs: Polynomial<F, M>) -> Self::Output {
179 let (self_degree, rhs_degree) = (self.deg(), rhs.deg());
180 match () {
181 _ if self_degree < rhs_degree => {
182 Ok(GenericArray::from((0..=N::USIZE - M::USIZE).map(|_| F::ZERO).collect()).into())
183 }
184 _ if self_degree == 0 && *rhs.0.first().unwrap_or(&F::ZERO) == F::ZERO => Err(()),
185 _ if self_degree == 0 => Ok(Polynomial::from(GenericArray::from(
186 (0..=N::USIZE - M::USIZE)
187 .map(|index| {
188 if index == 0 {
189 *self.0.first().unwrap_or(&F::ZERO) / *rhs.0.first().unwrap_or(&F::ZERO)
190 } else {
191 F::ZERO
192 }
193 })
194 .collect(),
195 ))),
196 _ => {
197 let first_quotient = Polynomial::<F, Diff<N, M>>::from(GenericArray::from(
198 (0..=N::USIZE - M::USIZE)
199 .map(|index| {
200 if index == self.deg() - rhs.deg() {
201 self.leading_coefficient() / rhs.leading_coefficient()
202 } else {
203 F::ZERO
204 }
205 })
206 .collect(),
207 ));
208 let remainder = self - first_quotient.clone() * rhs.clone();
209 (remainder / rhs).map(|second_quotient| first_quotient + second_quotient)
210 }
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use {
        super::Polynomial,
        generic_array::{
            typenum::consts::{U0, U1, U2, U3, U4},
            GenericArray,
        },
    };

    #[test]
    fn test_add() {
        let poly1 = Polynomial::<isize, U2>::from(GenericArray::from([1, 2, 3]));
        let poly2 = Polynomial::<isize, U2>::from(GenericArray::from([4, 5, 6]));
        let expected_result = Polynomial::<isize, U2>::from(GenericArray::from([5, 7, 9]));
        assert_eq!(poly1 + poly2, expected_result);

        let poly1 = Polynomial::<isize, U0>::from(GenericArray::from([1]));
        let poly2 = Polynomial::<isize, U0>::from(GenericArray::from([2]));
        let expected_result = Polynomial::<isize, U0>::from(GenericArray::from([3]));
        assert_eq!(poly1 + poly2, expected_result);

        let poly1 = Polynomial::<isize, U3>::from(GenericArray::from([1, 2, 3, 4]));
        let poly2 = Polynomial::<isize, U3>::from(GenericArray::from([5, 6, 7, 8]));
        let expected_result = Polynomial::<isize, U3>::from(GenericArray::from([6, 8, 10, 12]));
        assert_eq!(poly1 + poly2, expected_result);

        let poly1 = Polynomial::<isize, U3>::from(GenericArray::from([1, 2, 3, 4]));
        let poly2 = Polynomial::<isize, U4>::from(GenericArray::from([5, 6, 7, 8, 6]));
        let expected_result = Polynomial::<isize, U4>::from(GenericArray::from([6, 8, 10, 12, 6]));
        assert_eq!(poly1 + poly2, expected_result);
    }

    #[test]
    fn test_sub() {
        let poly1 = Polynomial::<isize, U2>::from(GenericArray::from([1, 2, 3]));
        let poly2 = Polynomial::<isize, U2>::from(GenericArray::from([4, 5, 6]));
        let expected_result = Polynomial::<isize, U2>::from(GenericArray::from([-3, -3, -3]));
        assert_eq!(poly1 - poly2, expected_result);

        let poly1 = Polynomial::<isize, U0>::from(GenericArray::from([1]));
        let poly2 = Polynomial::<isize, U0>::from(GenericArray::from([2]));
        let expected_result = Polynomial::<isize, U0>::from(GenericArray::from([-1]));
        assert_eq!(poly1 - poly2, expected_result);

        let poly1 = Polynomial::<isize, U3>::from(GenericArray::from([1, 2, 3, 4]));
        let poly2 = Polynomial::<isize, U3>::from(GenericArray::from([5, 6, 7, 8]));
        let expected_result = Polynomial::<isize, U3>::from(GenericArray::from([-4, -4, -4, -4]));
        assert_eq!(poly1 - poly2, expected_result);

        let poly1 = Polynomial::<isize, U3>::from(GenericArray::from([1, 2, 3, 4]));
        let poly2 = Polynomial::<isize, U2>::from(GenericArray::from([5, 6, 7]));
        let expected_result = Polynomial::<isize, U3>::from(GenericArray::from([-4, -4, -4, 4]));
        assert_eq!(poly1 - poly2, expected_result);
    }

    #[test]
    fn test_neg() {
        let poly = Polynomial::<isize, U3>::from(GenericArray::from([1, 2, 3, 4]));
        let expected_result = Polynomial::<isize, U3>::from(GenericArray::from([-1, -2, -3, -4]));
        assert_eq!(-poly, expected_result);
    }

    #[test]
    fn test_mul() {
        let p1 = Polynomial::<isize, U3>::from(GenericArray::from([2, 3, 4, 5]));
        let p2 = Polynomial::<isize, U2>::from(GenericArray::from([7, 6, 5]));
        let expected = Polynomial::from(GenericArray::from([14, 33, 56, 74, 50, 25]));
        assert_eq!(p1 * p2, expected);

        let p1 = Polynomial::<isize, U2>::from(GenericArray::from([1, 2, 3]));
        let p2 = Polynomial::<isize, U2>::from(GenericArray::from([4, 5, 6]));
        let expected = Polynomial::from(GenericArray::from([4, 13, 28, 27, 18]));
        assert_eq!(p1 * p2, expected);

        let p1 = Polynomial::<isize, U0>::from(GenericArray::from([3]));
        let p2 = Polynomial::<isize, U1>::from(GenericArray::from([3, 1]));
        let expected = Polynomial::from(GenericArray::from([9, 3]));
        assert_eq!(p1 * p2, expected);

        let p1 = Polynomial::<isize, U2>::from(GenericArray::from([1, 2, 2]));
        let p2 = Polynomial::<isize, U1>::from(GenericArray::from([-1, 1]));
        let expected = Polynomial::from(GenericArray::from([-1, -1, 0, 2]));
        assert_eq!(p1 * p2, expected);
    }

    #[test]
    fn test_div() {
        let p1 = Polynomial::<isize, U3>::from(GenericArray::from([-1, -1, 0, 2]));
        let p2 = Polynomial::<isize, U1>::from(GenericArray::from([-1, 1]));
        let expected = Polynomial::from(GenericArray::from([1, 2, 2]));
        assert_eq!((p1 / p2).unwrap(), expected);

        let p1 = Polynomial::<isize, U4>::from(GenericArray::from([4, 13, 28, 27, 18]));
        let p2 = Polynomial::<isize, U2>::from(GenericArray::from([4, 5, 6]));
        let expected = Polynomial::from(GenericArray::from([1, 2, 3]));
        assert_eq!((p1 / p2).unwrap(), expected);
    }

    #[test]
    fn test_div_edge_cases() {
        // A constant divided by the zero constant is an error.
        let p1 = Polynomial::<isize, U0>::from(GenericArray::from([6]));
        let p2 = Polynomial::<isize, U0>::from(GenericArray::from([0]));
        assert!((p1 / p2).is_err());

        // A constant divided by a non-zero constant.
        let p1 = Polynomial::<isize, U1>::from(GenericArray::from([6, 0]));
        let p2 = Polynomial::<isize, U0>::from(GenericArray::from([3]));
        let expected = Polynomial::<isize, U1>::from(GenericArray::from([2, 0]));
        assert_eq!((p1 / p2).unwrap(), expected);

        // A dividend of lower degree than the divisor yields the zero quotient.
        let p1 = Polynomial::<isize, U2>::from(GenericArray::from([1, 1, 0]));
        let p2 = Polynomial::<isize, U2>::from(GenericArray::from([0, 0, 1]));
        let expected = Polynomial::<isize, U0>::from(GenericArray::from([0]));
        assert_eq!((p1 / p2).unwrap(), expected);
    }

    #[test]
    fn test_deg_and_leading_coefficient() {
        let zero = Polynomial::<isize, U2>::from(GenericArray::from([0, 0, 0]));
        assert_eq!(zero.deg(), 0);
        assert_eq!(zero.leading_coefficient(), 0);

        let linear = Polynomial::<isize, U2>::from(GenericArray::from([1, 2, 0]));
        assert_eq!(linear.deg(), 1);
        assert_eq!(linear.leading_coefficient(), 2);

        let constant = Polynomial::<isize, U0>::from(GenericArray::from([5]));
        assert_eq!(constant.deg(), 0);
        assert_eq!(constant.leading_coefficient(), 5);
    }
}