1use std::fmt::Debug;
30use std::fmt::Display;
31use std::hash::Hash;
32use std::ops::Add;
33use std::ops::AddAssign;
34use std::ops::Div;
35use std::ops::Mul;
36use std::ops::MulAssign;
37use std::ops::Neg;
38use std::ops::Sub;
39use std::ops::SubAssign;
40use std::str::FromStr;
41
42use num_integer::Integer;
43
44#[macro_use]
45mod custom;
46
47#[cfg(feature = "alt_bn128")]
48pub mod alt_bn128;
49#[cfg(feature = "curve25519")]
50pub mod curve_25519;
51#[cfg(feature = "oxfoi")]
52pub mod oxfoi;
53
54pub mod matrix;
55pub mod timing;
56
57#[cfg(feature = "alt_bn128")]
58pub use alt_bn128::Bn128FieldElement;
59#[cfg(feature = "curve25519")]
60pub use curve_25519::Curve25519FieldElement;
61pub use num_bigint::BigUint;
62#[cfg(all(feature = "oxfoi"))]
63pub use oxfoi::OxfoiFieldElement;
64
65pub trait FieldElement:
71 Add<Output = Self>
72 + AddAssign
73 + Div<Output = Self>
74 + Mul<Output = Self>
75 + MulAssign
76 + Neg<Output = Self>
77 + Sub<Output = Self>
78 + SubAssign
79 + FromStr<Err = anyhow::Error>
80 + PartialEq
81 + Clone
82 + Hash
83 + Debug
84 + From<u64>
85 + Display
86{
87 fn zero() -> Self {
89 Self::from(0)
90 }
91
92 fn one() -> Self {
94 Self::from(1)
95 }
96
97 fn byte_len() -> usize;
100
101 #[cfg(feature = "random")]
104 fn sample_uniform<R: rand::Rng>(src: &mut R) -> Self {
105 let bytes = vec![0; Self::byte_len()]
106 .iter()
107 .map(|_| src.gen_range(0..=255))
108 .collect::<Vec<_>>();
109 Self::from_bytes_le(&bytes)
110 }
111
112 fn prime() -> BigUint {
115 (-Self::one()).to_biguint() + 1_u32
119 }
120
121 fn name_str() -> &'static str;
123
124 fn from_usize(value: usize) -> Self {
128 Self::from(u64::try_from(value).unwrap())
132 }
133
134 fn to_biguint(&self) -> num_bigint::BigUint {
137 BigUint::from_bytes_le(self.to_bytes_le().as_slice())
140 }
141
142 fn from_biguint(v: &BigUint) -> Self {
145 Self::from_bytes_le(&v.clone().to_bytes_le()[..])
146 }
147
148 fn from_bytes_le(bytes: &[u8]) -> Self;
152
153 fn to_bytes_le(&self) -> Vec<u8>;
157
158 fn lower60_string(&self) -> String {
164 const POW: u32 = 60;
165 let two_pow = BigUint::from(2_u64.pow(POW));
168 let plain_str = self.to_string();
169 let l60_str = format!("{}_L60", self.to_biguint() % two_pow);
170 if l60_str.len() + 3 < plain_str.len() {
173 l60_str
174 } else {
175 plain_str
176 }
177 }
178
179 fn log_floor(&self, b: Self) -> u32 {
183 if b.to_biguint() > self.to_biguint() {
184 return 0;
185 } else if b == *self {
186 return 1;
187 }
188 let e = self.to_biguint();
189 let b = b.to_biguint();
190 let mut x = b.clone();
191 let mut i = 1;
192 while x < e {
193 x *= b.clone();
194 if x >= e {
195 return i;
196 }
197 i += 1;
198 }
199 unreachable!();
200 }
201
202 fn legendre(&self) -> i32 {
206 if self == &Self::zero() {
207 return 0;
208 }
209 let neg_one = Self::prime() - 1_u32;
210 let one = BigUint::from(1_u32);
211 let e = (-Self::one()) / (Self::one() + Self::one());
212 let e_bigint = BigUint::from_str(&e.to_string()).unwrap();
213 let a = BigUint::from_str(&self.to_string()).unwrap();
214 let l = a.modpow(&e_bigint, &Self::prime());
215 if l == neg_one {
216 -1
217 } else if l == one {
218 return 1;
219 } else {
220 panic!("legendre symbol is not 1, -1, or 0");
221 }
222 }
223
224 fn sqrt(&self) -> Self {
227 if self == &Self::zero() {
228 return Self::zero();
229 }
230 if self.legendre() != 1 {
231 panic!("legendre symbol is not 1: root does not exist or input is 0");
232 }
233 let mut x = Self::one() + Self::one();
235 let non_residue;
236 loop {
237 if x.legendre() == -1 {
238 non_residue = x.clone();
239 break;
240 }
241 x += Self::one();
242 }
243 let b = BigUint::from_str(&non_residue.to_string()).unwrap();
244
245 let a = BigUint::from_str(&self.to_string()).unwrap();
246 let two = Self::one() + Self::one();
247 let m = (-Self::one()) / two.clone();
248 let mut apow = -Self::one();
249 let mut bpow = Self::zero();
250 while BigUint::from_str(&apow.to_string()).unwrap().is_even() {
251 apow = apow / two.clone();
252 bpow = bpow / two.clone();
253 let a_ = a.modpow(
254 &BigUint::from_str(&apow.to_string()).unwrap(),
255 &Self::prime(),
256 );
257 let b_ = b.modpow(
258 &BigUint::from_str(&bpow.to_string()).unwrap(),
259 &Self::prime(),
260 );
261 if (a_ * b_) % Self::prime() == Self::prime() - 1_u32 {
262 bpow += m.clone();
263 }
264 }
265 apow = (apow + Self::one()) / two.clone();
266 bpow = bpow / two;
267 let a_ = a.modpow(
268 &BigUint::from_str(&apow.to_string()).unwrap(),
269 &Self::prime(),
270 );
271 let b_ = b.modpow(
272 &BigUint::from_str(&bpow.to_string()).unwrap(),
273 &Self::prime(),
274 );
275 let root = (a_ * b_) % Self::prime();
276 let other_root = Self::prime() - root.clone();
277 if root > other_root {
278 Self::from_biguint(&other_root)
279 } else {
280 Self::from_biguint(&root)
281 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Squares 1, 2, ..., 1000 (wrapping around in small fields) and checks
    /// that `sqrt` of each square squares back to it.
    fn test_sqrt<T: FieldElement>() {
        let mut x = T::one();
        for _ in 0..1000 {
            let square = x.clone() * x.clone();
            let root = square.sqrt();
            assert_eq!(square, root.clone() * root.clone());
            x += T::one();
        }
    }

    scalar_ring!(F13FieldElement, 13, "f13");

    #[test]
    fn sqrt_scalar_ring() {
        test_sqrt::<F13FieldElement>();
    }

    #[test]
    fn sqrt_returns_smaller_root_scalar_ring() {
        // 4 = 2^2 = 11^2 and 3 = 4^2 = 9^2 (mod 13).
        assert_eq!(F13FieldElement::from(4).sqrt(), F13FieldElement::from(2));
        assert_eq!(F13FieldElement::from(3).sqrt(), F13FieldElement::from(4));
    }

    #[test]
    fn legendre_scalar_ring() {
        // The non-zero quadratic residues mod 13 are {1, 3, 4, 9, 10, 12}.
        let residues = [1_u64, 3, 4, 9, 10, 12];
        assert_eq!(F13FieldElement::from(0).legendre(), 0);
        for v in 1..13_u64 {
            let expected = if residues.contains(&v) { 1 } else { -1 };
            assert_eq!(F13FieldElement::from(v).legendre(), expected, "v = {v}");
        }
    }

    #[test]
    fn log_floor_scalar_ring() {
        let f = F13FieldElement::from;
        assert_eq!(f(1).log_floor(f(2)), 0);
        assert_eq!(f(5).log_floor(f(5)), 1);
        assert_eq!(f(3).log_floor(f(2)), 1);
        assert_eq!(f(9).log_floor(f(2)), 3);
        assert_eq!(f(12).log_floor(f(3)), 2);
    }

    // These fields live behind cargo features; without the gates, `cargo test`
    // fails to compile unless every feature is enabled.
    #[cfg(feature = "oxfoi")]
    #[test]
    fn sqrt_oxfoi() {
        test_sqrt::<oxfoi::OxfoiFieldElement>();
    }

    #[cfg(feature = "alt_bn128")]
    #[test]
    fn sqrt_bn128() {
        test_sqrt::<alt_bn128::Bn128FieldElement>();
    }

    #[cfg(feature = "curve25519")]
    #[test]
    fn sqrt_curve25519() {
        test_sqrt::<curve_25519::Curve25519FieldElement>();
    }
}