Skip to main content

ark_ff/fields/models/small_fp/
field.rs

1use crate::fields::models::small_fp::small_fp_backend::{SmallFp, SmallFpConfig};
2use crate::{Field, LegendreSymbol, One, PrimeField, SqrtPrecomputation, Zero};
3use ark_serialize::{buffer_byte_size, EmptyFlags, Flags};
4use core::iter;
5
6impl<P: SmallFpConfig> Field for SmallFp<P> {
7    type BasePrimeField = Self;
8
9    const SQRT_PRECOMP: Option<SqrtPrecomputation<Self>> = P::SQRT_PRECOMP;
10    const ONE: Self = P::ONE;
11    const NEG_ONE: Self = P::NEG_ONE;
12
13    fn extension_degree() -> u64 {
14        1
15    }
16
17    fn from_base_prime_field(elem: Self::BasePrimeField) -> Self {
18        elem
19    }
20
21    fn to_base_prime_field_elements(&self) -> impl Iterator<Item = Self> {
22        iter::once(*self)
23    }
24
25    fn from_base_prime_field_elems(
26        elems: impl IntoIterator<Item = Self::BasePrimeField>,
27    ) -> Option<Self> {
28        let mut iter = elems.into_iter();
29        let first = iter.next()?;
30        if iter.next().is_some() {
31            None
32        } else {
33            Some(first)
34        }
35    }
36
37    #[inline]
38    fn characteristic() -> &'static [u64] {
39        Self::MODULUS.as_ref()
40    }
41
42    #[inline]
43    fn sum_of_products<const T: usize>(a: &[Self; T], b: &[Self; T]) -> Self {
44        P::sum_of_products(a, b)
45    }
46
47    #[inline]
48    fn from_random_bytes_with_flags<F: Flags>(bytes: &[u8]) -> Option<(Self, F)> {
49        if F::BIT_SIZE > 8 {
50            None
51        } else {
52            let shave_bits = Self::num_bits_to_shave();
53            let mut result_bytes: crate::const_helpers::SerBuffer<1> =
54                crate::const_helpers::SerBuffer::zeroed();
55            // Copy the input into a temporary buffer.
56            result_bytes.copy_from_u8_slice(bytes);
57            // This mask retains everything in the last limb
58            // that is below `P::MODULUS_BIT_SIZE`.
59            let last_limb_mask =
60                (u64::MAX.checked_shr(shave_bits as u32).unwrap_or(0)).to_le_bytes();
61            let mut last_bytes_mask = [0u8; 9];
62            last_bytes_mask[..8].copy_from_slice(&last_limb_mask);
63
64            // Length of the buffer containing the field element and the flag.
65            let output_byte_size = buffer_byte_size(Self::MODULUS_BIT_SIZE as usize + F::BIT_SIZE);
66            // Location of the flag is the last byte of the serialized
67            // form of the field element.
68            let flag_location = output_byte_size - 1;
69
70            // At which byte is the flag located in the last limb?
71            let flag_location_in_last_limb =
72                flag_location.saturating_sub(8 * (P::NUM_BIG_INT_LIMBS - 1));
73
74            // Take all but the last 9 bytes.
75            let last_bytes = result_bytes.last_n_plus_1_bytes_mut();
76
77            // The mask only has the last `F::BIT_SIZE` bits set
78            let flags_mask = u8::MAX.checked_shl(8 - (F::BIT_SIZE as u32)).unwrap_or(0);
79
80            // Mask away the remaining bytes, and try to reconstruct the
81            // flag
82            let mut flags: u8 = 0;
83            for (i, (b, m)) in last_bytes.zip(&last_bytes_mask).enumerate() {
84                if i == flag_location_in_last_limb {
85                    flags = *b & flags_mask
86                }
87                *b &= m;
88            }
89            // Use from_bigint (not deserialize_compressed) since these are plaintext bytes, not Montgomery-encoded.
90            let bigint = result_bytes.to_bigint();
91            Self::from_bigint(bigint).and_then(|f| F::from_u8(flags).map(|flag| (f, flag)))
92        }
93    }
94
95    #[inline]
96    fn square(&self) -> Self {
97        let mut temp = *self;
98        temp.square_in_place();
99        temp
100    }
101
102    fn square_in_place(&mut self) -> &mut Self {
103        P::square_in_place(self);
104        self
105    }
106
107    #[inline]
108    fn inverse(&self) -> Option<Self> {
109        P::inverse(self)
110    }
111
112    fn inverse_in_place(&mut self) -> Option<&mut Self> {
113        self.inverse().map(|inverse| {
114            *self = inverse;
115            self
116        })
117    }
118
119    /// The Frobenius map has no effect in a prime field.
120    #[inline]
121    fn frobenius_map_in_place(&mut self, _: usize) {}
122
123    #[inline]
124    fn legendre(&self) -> LegendreSymbol {
125        // s = self^((MODULUS - 1) // 2)
126        let s = self.pow(Self::MODULUS_MINUS_ONE_DIV_TWO);
127        if s.is_zero() {
128            LegendreSymbol::Zero
129        } else if s.is_one() {
130            LegendreSymbol::QuadraticResidue
131        } else {
132            LegendreSymbol::QuadraticNonResidue
133        }
134    }
135
136    fn mul_by_base_prime_field(&self, elem: &Self::BasePrimeField) -> Self {
137        *self * elem
138    }
139
140    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
141        Self::from_random_bytes_with_flags::<EmptyFlags>(bytes).map(|f| f.0)
142    }
143
144    fn sqrt(&self) -> Option<Self> {
145        match Self::SQRT_PRECOMP {
146            Some(tv) => tv.sqrt(self),
147            None => ark_std::unimplemented!(),
148        }
149    }
150
151    fn sqrt_in_place(&mut self) -> Option<&mut Self> {
152        (*self).sqrt().map(|sqrt| {
153            *self = sqrt;
154            self
155        })
156    }
157
158    fn frobenius_map(&self, power: usize) -> Self {
159        let mut this = *self;
160        this.frobenius_map_in_place(power);
161        this
162    }
163
164    fn pow<S: AsRef<[u64]>>(&self, exp: S) -> Self {
165        let mut res = Self::one();
166
167        for i in crate::BitIteratorBE::without_leading_zeros(exp) {
168            res.square_in_place();
169
170            if i {
171                res *= self;
172            }
173        }
174        res
175    }
176
177    fn pow_with_table<S: AsRef<[u64]>>(powers_of_2: &[Self], exp: S) -> Option<Self> {
178        let mut res = Self::one();
179        for (pow, bit) in crate::BitIteratorLE::without_trailing_zeros(exp).enumerate() {
180            if bit {
181                res *= powers_of_2.get(pow)?;
182            }
183        }
184        Some(res)
185    }
186}