dcrypt_algorithms/ec/k256/
field.rs

1//! secp256k1 field arithmetic implementation.
2//! Field prime p = 2^256 - 2^32 - 977.
3
4use crate::ec::k256::constants::K256_FIELD_ELEMENT_SIZE;
5use crate::error::{Error, Result};
6use subtle::{Choice, ConditionallySelectable};
7
/// An element of the secp256k1 base field `F_p`.
///
/// The value is held as eight 32-bit limbs in little-endian limb order:
/// `self.0[0]` carries the least-significant 32 bits and `self.0[7]` the
/// most-significant 32 bits.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FieldElement(pub(crate) [u32; 8]);
11
12impl FieldElement {
13    /// The secp256k1 prime modulus: p = 2^256 - 2^32 - 977
14    pub(crate) const MOD_LIMBS: [u32; 8] = [
15        0xFFFF_FC2F,
16        0xFFFF_FFFE,
17        0xFFFF_FFFF,
18        0xFFFF_FFFF,
19        0xFFFF_FFFF,
20        0xFFFF_FFFF,
21        0xFFFF_FFFF,
22        0xFFFF_FFFF,
23    ];
24
25    /// The additive identity element: 0
26    pub fn zero() -> Self {
27        FieldElement([0; 8])
28    }
29
30    /// The multiplicative identity element: 1
31    pub fn one() -> Self {
32        let mut limbs = [0; 8];
33        limbs[0] = 1;
34        FieldElement(limbs)
35    }
36
37    /// Create a field element from its canonical byte representation.
38    ///
39    /// Returns an error if the value is greater than or equal to the field modulus.
40    pub fn from_bytes(bytes: &[u8; K256_FIELD_ELEMENT_SIZE]) -> Result<Self> {
41        let mut limbs = [0u32; 8];
42        for (i, limb) in limbs.iter_mut().enumerate() {
43            let offset = (7 - i) * 4;
44            *limb = u32::from_be_bytes([
45                bytes[offset],
46                bytes[offset + 1],
47                bytes[offset + 2],
48                bytes[offset + 3],
49            ]);
50        }
51        let fe = FieldElement(limbs);
52        if !fe.is_valid() {
53            return Err(Error::param(
54                "FieldElement K256",
55                "Value must be less than the field modulus",
56            ));
57        }
58        Ok(fe)
59    }
60
61    /// Convert this field element to its canonical byte representation.
62    pub fn to_bytes(&self) -> [u8; K256_FIELD_ELEMENT_SIZE] {
63        let mut bytes = [0u8; K256_FIELD_ELEMENT_SIZE];
64        for i in 0..8 {
65            let limb_bytes = self.0[i].to_be_bytes();
66            let offset = (7 - i) * 4;
67            bytes[offset..offset + 4].copy_from_slice(&limb_bytes);
68        }
69        bytes
70    }
71
72    /// Check if this field element is less than the field modulus.
73    #[inline(always)]
74    pub fn is_valid(&self) -> bool {
75        let (_, borrow) = Self::sbb8(self.0, Self::MOD_LIMBS);
76        borrow == 1
77    }
78
79    /// Check if this field element is zero.
80    pub fn is_zero(&self) -> bool {
81        self.0.iter().all(|&l| l == 0)
82    }
83
84    /// Check if this field element is odd (least significant bit is 1).
85    pub fn is_odd(&self) -> bool {
86        // limbs[0] contains the least significant 32 bits
87        (self.0[0] & 1) == 1
88    }
89
90    /// Add two field elements modulo p.
91    #[inline(always)]
92    pub fn add(&self, other: &Self) -> Self {
93        let (sum, carry) = Self::adc8(self.0, other.0);
94        let (sum_minus_p, borrow) = Self::sbb8(sum, Self::MOD_LIMBS);
95        let needs_reduce = (carry | (borrow ^ 1)) & 1;
96        Self::conditional_select(&sum, &sum_minus_p, Choice::from(needs_reduce as u8))
97    }
98
99    /// Subtract two field elements modulo p.
100    pub fn sub(&self, other: &Self) -> Self {
101        let (diff, borrow) = Self::sbb8(self.0, other.0);
102        let (candidate, _) = Self::adc8(diff, Self::MOD_LIMBS);
103        Self::conditional_select(&diff, &candidate, Choice::from(borrow as u8))
104    }
105
106    /// Negate a field element modulo p.
107    pub fn negate(&self) -> Self {
108        if self.is_zero() {
109            return self.clone();
110        }
111        FieldElement(Self::MOD_LIMBS).sub(self)
112    }
113
114    /// Multiply two field elements modulo p.
115    pub fn mul(&self, other: &Self) -> Self {
116        let mut t = [0u128; 16];
117        for i in 0..8 {
118            for j in 0..8 {
119                t[i + j] += (self.0[i] as u128) * (other.0[j] as u128);
120            }
121        }
122        let mut prod = [0u32; 16];
123        let mut carry: u128 = 0;
124        for i in 0..16 {
125            let v = t[i] + carry;
126            prod[i] = (v & 0xffff_ffff) as u32;
127            carry = v >> 32;
128        }
129        Self::reduce_wide(prod)
130    }
131
132    /// Square a field element modulo p.
133    #[inline(always)]
134    pub fn square(&self) -> Self {
135        self.mul(self)
136    }
137
138    /// Double a field element (multiply by 2) modulo p.
139    pub fn double(&self) -> Self {
140        self.add(self)
141    }
142
143    /// Compute the multiplicative inverse of a field element.
144    ///
145    /// Returns an error if the element is zero.
146    pub fn invert(&self) -> Result<Self> {
147        if self.is_zero() {
148            return Err(Error::param(
149                "FieldElement K256",
150                "Inversion of zero is undefined",
151            ));
152        }
153        const P_MINUS_2: [u8; 32] = [
154            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
155            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
156            0xFF, 0xFF, 0xFC, 0x2D,
157        ];
158        self.pow(&P_MINUS_2)
159    }
160
161    /// Compute the square root of a field element.
162    ///
163    /// Returns None if the element is not a quadratic residue.
164    pub fn sqrt(&self) -> Option<Self> {
165        if self.is_zero() {
166            return Some(Self::zero());
167        }
168        // p mod 4 = 3, so sqrt(a) = a^((p+1)/4)
169        const P_PLUS_1_DIV_4: [u8; 32] = [
170            0x3F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
171            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
172            0xBF, 0xFF, 0xFF, 0x0C,
173        ];
174        let root = self.pow(&P_PLUS_1_DIV_4).ok()?;
175        if root.square() == *self {
176            Some(root)
177        } else {
178            None
179        }
180    }
181
182    fn pow(&self, exp_be: &[u8]) -> Result<Self> {
183        let mut result = Self::one();
184        let base = self.clone();
185        for &byte in exp_be.iter() {
186            for i in (0..8).rev() {
187                result = result.square();
188                if (byte >> i) & 1 == 1 {
189                    result = result.mul(&base);
190                }
191            }
192        }
193        Ok(result)
194    }
195
196    fn conditional_select(a: &[u32; 8], b: &[u32; 8], flag: Choice) -> Self {
197        let mut out = [0u32; 8];
198        for i in 0..8 {
199            out[i] = u32::conditional_select(&a[i], &b[i], flag);
200        }
201        FieldElement(out)
202    }
203
204    fn adc8(a: [u32; 8], b: [u32; 8]) -> ([u32; 8], u32) {
205        let mut r = [0u32; 8];
206        let mut carry: u64 = 0;
207        for i in 0..8 {
208            let tmp = (a[i] as u64) + (b[i] as u64) + carry;
209            r[i] = tmp as u32;
210            carry = tmp >> 32;
211        }
212        (r, carry as u32)
213    }
214
215    fn sbb8(a: [u32; 8], b: [u32; 8]) -> ([u32; 8], u32) {
216        let mut r = [0u32; 8];
217        let mut borrow: i64 = 0;
218        for i in 0..8 {
219            let tmp = (a[i] as i64) - (b[i] as i64) - borrow;
220            r[i] = tmp as u32;
221            borrow = (tmp >> 63) & 1;
222        }
223        (r, borrow as u32)
224    }
225
226    /// Reduce a 512-bit number modulo p = 2^256 - 2^32 - 977
227    /// Uses the special form of secp256k1's prime for efficient reduction
228    fn reduce_wide(t: [u32; 16]) -> Self {
229        // For p = 2^256 - 2^32 - 977, we can use the fact that
230        // 2^256 ≡ 2^32 + 977 (mod p)
231        // This allows us to reduce the high 256 bits efficiently
232
233        // Split t into low 256 bits (t_low) and high 256 bits (t_high)
234        let mut t_low = [0u32; 8];
235        let mut t_high = [0u32; 8];
236        t_low.copy_from_slice(&t[..8]);
237        t_high.copy_from_slice(&t[8..]);
238
239        // We need to compute: t_low + t_high * 2^256
240        // Since 2^256 ≡ 2^32 + 977 (mod p), we compute:
241        // t_low + t_high * (2^32 + 977)
242        // = t_low + (t_high << 32) + t_high * 977
243
244        // First, compute t_high * 977
245        let mut t_high_977 = [0u64; 9];
246        for i in 0..8 {
247            t_high_977[i] += (t_high[i] as u64) * 977u64;
248        }
249        // Propagate carries
250        for i in 0..8 {
251            t_high_977[i + 1] += t_high_977[i] >> 32;
252            t_high_977[i] &= 0xFFFF_FFFF;
253        }
254
255        // Now add: t_low + (t_high << 32) + t_high_977
256        let mut result = [0u64; 9];
257
258        // Add t_low
259        for i in 0..8 {
260            result[i] += t_low[i] as u64;
261        }
262
263        // Add t_high << 32 (which means t_high[i] goes to position i+1)
264        for i in 0..8 {
265            result[i + 1] += t_high[i] as u64;
266        }
267
268        // Add t_high_977
269        for i in 0..9 {
270            result[i] += t_high_977[i];
271        }
272
273        // Propagate all carries
274        for i in 0..8 {
275            result[i + 1] += result[i] >> 32;
276            result[i] &= 0xFFFF_FFFF;
277        }
278
279        // If result[8] is non-zero, we need another reduction step
280        if result[8] > 0 {
281            // result[8] * 2^256 ≡ result[8] * (2^32 + 977) (mod p)
282            let overflow = result[8];
283            result[8] = 0;
284
285            // Add overflow * 977 to result[0]
286            result[0] += overflow * 977;
287            // Add overflow to result[1] (for the 2^32 part)
288            result[1] += overflow;
289
290            // Propagate carries again
291            for i in 0..8 {
292                if i < 7 {
293                    result[i + 1] += result[i] >> 32;
294                }
295                result[i] &= 0xFFFF_FFFF;
296            }
297        }
298
299        // Convert back to u32 array
300        let mut r = [0u32; 8];
301        for i in 0..8 {
302            r[i] = result[i] as u32;
303        }
304
305        // Final reduction if r >= p
306        let fe = FieldElement(r);
307        if !fe.is_valid() {
308            let (reduced, _) = Self::sbb8(r, Self::MOD_LIMBS);
309            FieldElement(reduced)
310        } else {
311            fe
312        }
313    }
314}
315
#[cfg(test)]
mod field_constants_tests {
    use super::*;

    #[test]
    fn test_modulus_is_correct() {
        // Big-endian encoding of the secp256k1 prime:
        // p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
        let expected_bytes: [u8; 32] = [
            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
            0xFF, 0xFF, 0xFC, 0x2F,
        ];

        // MOD_LIMBS is little-endian by limb, so emit the limbs in reverse
        // order, each as four big-endian bytes.
        let mut mod_bytes = [0u8; 32];
        for (chunk, limb) in mod_bytes
            .chunks_exact_mut(4)
            .zip(FieldElement::MOD_LIMBS.iter().rev())
        {
            chunk.copy_from_slice(&limb.to_be_bytes());
        }

        assert_eq!(
            mod_bytes, expected_bytes,
            "MOD_LIMBS does not encode the correct secp256k1 prime"
        );
    }
}