generic_ec/
scalar.rs

1use core::hash::{self, Hash};
2use core::{fmt, iter};
3
4use rand_core::RngCore;
5use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
6use zeroize::Zeroize;
7
8use crate::NonZero;
9use crate::{
10    as_raw::{AsRaw, FromRaw},
11    core::*,
12    encoded::EncodedScalar,
13    errors::InvalidScalar,
14};
15
16/// Scalar modulo curve `E` group order
17///
18/// Scalar is an integer modulo curve `E` group order.
19#[derive(Copy, Clone, PartialEq, Eq, Default)]
20pub struct Scalar<E: Curve>(E::Scalar);
21
22impl<E: Curve> Scalar<E> {
23    /// Returns scalar $S = 0$
24    ///
25    /// ```rust
26    /// use generic_ec::{Scalar, curves::Secp256k1};
27    /// use rand::rngs::OsRng;
28    ///
29    /// let s = Scalar::<Secp256k1>::random(&mut OsRng);
30    /// assert_eq!(s * Scalar::zero(), Scalar::zero());
31    /// assert_eq!(s + Scalar::zero(), s);
32    /// ```
33    pub fn zero() -> Self {
34        Self::from_raw(E::Scalar::zero())
35    }
36
37    /// Checks whether the scalar is zero
38    pub fn is_zero(&self) -> bool {
39        Zero::is_zero(self.as_raw()).into()
40    }
41
42    /// Returns scalar $S = 1$
43    ///
44    /// ```rust
45    /// use generic_ec::{Scalar, curves::Secp256k1};
46    /// use rand::rngs::OsRng;
47    ///
48    /// let s = Scalar::<Secp256k1>::random(&mut OsRng);
49    /// assert_eq!(s * Scalar::one(), s);
50    /// ```
51    pub fn one() -> Self {
52        Self::from_raw(E::Scalar::one())
53    }
54
55    /// Returns scalar inverse $S^{-1}$
56    ///
57    /// Inverse of scalar $S$ is a scalar $S^{-1}$ such as $S \cdot S^{-1} = 1$. Inverse doesn't
58    /// exist only for scalar $S = 0$, so function returns `None` if scalar is zero.
59    ///
60    /// ```rust
61    /// # fn func() -> Option<()> {
62    /// use generic_ec::{Scalar, curves::Secp256k1};
63    /// use rand::rngs::OsRng;
64    ///
65    /// let s = Scalar::<Secp256k1>::random(&mut OsRng);
66    /// let s_inv = s.invert()?;
67    /// assert_eq!(s * s_inv, Scalar::one());
68    /// # Some(()) }
69    /// # func();
70    /// ```
71    pub fn invert(&self) -> Option<Self> {
72        self.ct_invert().into()
73    }
74
75    /// Returns scalar inverse $S^{-1}$ (constant time)
76    ///
77    /// Same as [`Scalar::invert`] but performs constant-time check on whether it's zero
78    /// scalar
79    pub fn ct_invert(&self) -> CtOption<Self> {
80        let inv = Invertible::invert(self.as_raw());
81        inv.map(Self::from_raw)
82    }
83
84    /// Encodes scalar as bytes in big-endian order
85    ///
86    /// ```rust
87    /// use generic_ec::{Scalar, curves::Secp256k1};
88    /// use rand::rngs::OsRng;
89    ///
90    /// let s = Scalar::<Secp256k1>::random(&mut OsRng);
91    /// let bytes = s.to_be_bytes();
92    /// println!("Scalar hex representation: {}", hex::encode(bytes));
93    /// ```
94    pub fn to_be_bytes(&self) -> EncodedScalar<E> {
95        let bytes = self.as_raw().to_be_bytes();
96        EncodedScalar::new(bytes)
97    }
98
99    /// Encodes scalar as bytes in little-endian order
100    pub fn to_le_bytes(&self) -> EncodedScalar<E> {
101        let bytes = self.as_raw().to_le_bytes();
102        EncodedScalar::new(bytes)
103    }
104
105    /// Decodes scalar from its representation as bytes in big-endian order
106    ///
107    /// Returns error if encoded integer is larger than group order.
108    ///
109    /// ```rust
110    /// use generic_ec::{Scalar, curves::Secp256k1};
111    /// use rand::rngs::OsRng;
112    ///
113    /// let s = Scalar::<Secp256k1>::random(&mut OsRng);
114    /// let s_bytes = s.to_be_bytes();
115    /// let s_decoded = Scalar::from_be_bytes(&s_bytes)?;
116    /// assert_eq!(s, s_decoded);
117    /// # Ok::<(), Box<dyn std::error::Error>>(())
118    /// ```
119    pub fn from_be_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, InvalidScalar> {
120        let bytes = bytes.as_ref();
121        let mut bytes_array = E::ScalarArray::zeroes();
122        let bytes_array_len = bytes_array.as_ref().len();
123        if bytes_array_len < bytes.len() {
124            return Err(InvalidScalar);
125        }
126        bytes_array.as_mut()[bytes_array_len - bytes.len()..].copy_from_slice(bytes);
127
128        let scalar = E::Scalar::from_be_bytes_exact(&bytes_array).ok_or(InvalidScalar)?;
129        Ok(Scalar::from_raw(scalar))
130    }
131
132    /// Decodes scalar from its representation as bytes in little-endian order
133    ///
134    /// Returns error if encoded integer is larger than group order.
135    pub fn from_le_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, InvalidScalar> {
136        let bytes = bytes.as_ref();
137        let mut bytes_array = E::ScalarArray::zeroes();
138        let bytes_array_len = bytes_array.as_ref().len();
139        if bytes_array_len < bytes.len() {
140            return Err(InvalidScalar);
141        }
142        bytes_array.as_mut()[..bytes.len()].copy_from_slice(bytes);
143
144        let scalar = E::Scalar::from_le_bytes_exact(&bytes_array).ok_or(InvalidScalar)?;
145        Ok(Scalar::from_raw(scalar))
146    }
147
148    /// Interprets provided bytes as integer $i$ in big-endian order, returns scalar $s = i \mod q$
149    pub fn from_be_bytes_mod_order(bytes: impl AsRef<[u8]>) -> Self {
150        let scalar = E::Scalar::from_be_bytes_mod_order(bytes.as_ref());
151        Self::from_raw(scalar)
152    }
153
154    /// Interprets provided bytes as integer $i$ in little-endian order, returns scalar $s = i \mod q$
155    pub fn from_le_bytes_mod_order(bytes: impl AsRef<[u8]>) -> Self {
156        let scalar = E::Scalar::from_le_bytes_mod_order(bytes.as_ref());
157        Self::from_raw(scalar)
158    }
159
160    /// Generates random non-zero scalar
161    ///
162    /// Algorithm is based on rejection sampling: we sample a scalar, if it's zero try again.
163    /// It may be considered constant-time as zero scalar appears with $2^{-256}$ probability
164    /// which is considered to be negligible.
165    ///
166    /// ## Panics
167    /// Panics if randomness source returned 100 zero scalars in a row. It happens with
168    /// $2^{-25600}$ probability, which practically means that randomness source is broken.
169    pub fn random<R: RngCore>(rng: &mut R) -> Self {
170        NonZero::<Scalar<E>>::random(rng).into()
171    }
172
173    #[doc = include_str!("../docs/hash_to_scalar.md")]
174    ///
175    /// ## Example
176    /// ```rust
177    /// use generic_ec::{Scalar, curves::Secp256k1};
178    /// use sha2::Sha256;
179    ///
180    /// #[derive(udigest::Digestable)]
181    /// struct Data<'a> {
182    ///     nonce: &'a [u8],
183    ///     param_a: &'a str,
184    ///     param_b: u128,
185    ///     // ...
186    /// }
187    ///
188    /// let scalar = Scalar::<Secp256k1>::from_hash::<Sha256>(&Data {
189    ///     nonce: b"some data",
190    ///     param_a: "some other data",
191    ///     param_b: 12345,
192    ///     // ...
193    /// });
194    /// ```
195    #[cfg(feature = "hash-to-scalar")]
196    pub fn from_hash<D: digest::Digest>(data: &impl udigest::Digestable) -> Self {
197        let mut rng = rand_hash::HashRng::<D, _>::from_seed(data);
198        Self::random(&mut rng)
199    }
200
201    /// Returns size of bytes buffer that can fit serialized scalar
202    pub fn serialized_len() -> usize {
203        E::ScalarArray::zeroes().as_ref().len()
204    }
205
206    /// Returns scalar big-endian representation in radix $2^4 = 16$
207    ///
208    /// Radix 16 representation is defined as sum:
209    ///
210    /// $$s = s_0 + s_1 16^1 + s_2 16^2 + \dots + s_{\log_{16}(s) - 1} 16^{\log_{16}(s) - 1}$$
211    ///
212    /// (note: we typically work with 256 bit scalars, so $\log_{16}(s) = \log_{16}(2^{256}) = 64$)
213    ///
214    /// Returns iterator over coefficients from most to least significant:
215    /// $s_{\log_{16}(s) - 1}, \dots, s_1, s_0$
216    pub fn as_radix16_be(&self) -> Radix16Iter<E> {
217        Radix16Iter::new(self.to_be_bytes(), true)
218    }
219
220    /// Returns scalar little-endian representation in radix $2^4 = 16$
221    ///
222    /// Radix 16 representation is defined as sum:
223    ///
224    /// $$s = s_0 + s_1 16^1 + s_2 16^2 + \dots + s_{\log_{16}(s) - 1} 16^{\log_{16}(s) - 1}$$
225    ///
226    /// (note: we typically work with 256 bit scalars, so $\log_{16}(s) = \log_{16}(2^{256}) = 64$)
227    ///
228    /// Returns iterator over coefficients from least to most significant:
229    /// $s_0, s_1, \dots, s_{\log_{16}(s) - 1}$
230    pub fn as_radix16_le(&self) -> Radix16Iter<E> {
231        Radix16Iter::new(self.to_le_bytes(), false)
232    }
233
234    /// Performs multiscalar multiplication
235    ///
236    /// Takes iterator of pairs `(scalar, point)`. Returns sum of `scalar * point`. Uses
237    /// [`Default`](crate::multiscalar::Default) algorithm.
238    ///
239    /// See [multiscalar module](crate::multiscalar) docs for more info.
240    pub fn multiscalar_mul<S, P>(
241        scalar_points: impl ExactSizeIterator<Item = (S, P)>,
242    ) -> crate::Point<E>
243    where
244        S: AsRef<Scalar<E>>,
245        P: AsRef<crate::Point<E>>,
246    {
247        use crate::multiscalar::MultiscalarMul;
248        crate::multiscalar::Default::multiscalar_mul(scalar_points)
249    }
250}
251
252impl<E: Curve> AsRaw for Scalar<E> {
253    type Raw = E::Scalar;
254
255    #[inline]
256    fn as_raw(&self) -> &E::Scalar {
257        &self.0
258    }
259}
260
261impl<E: Curve> Zeroize for Scalar<E> {
262    #[inline]
263    fn zeroize(&mut self) {
264        self.0.zeroize()
265    }
266}
267
268impl<E: Curve> FromRaw for Scalar<E> {
269    fn from_raw(scalar: E::Scalar) -> Self {
270        Self(scalar)
271    }
272}
273
274impl<E: Curve> ConditionallySelectable for Scalar<E> {
275    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
276        Scalar::from_raw(<E::Scalar as ConditionallySelectable>::conditional_select(
277            a.as_raw(),
278            b.as_raw(),
279            choice,
280        ))
281    }
282}
283
284impl<E: Curve> ConstantTimeEq for Scalar<E> {
285    fn ct_eq(&self, other: &Self) -> Choice {
286        self.as_raw().ct_eq(other.as_raw())
287    }
288}
289
290impl<E: Curve> AsRef<Scalar<E>> for Scalar<E> {
291    fn as_ref(&self) -> &Scalar<E> {
292        self
293    }
294}
295
296impl<E: Curve> crate::traits::IsZero for Scalar<E> {
297    fn is_zero(&self) -> bool {
298        self.is_zero()
299    }
300}
301
302impl<E: Curve> crate::traits::Zero for Scalar<E> {
303    fn zero() -> Self {
304        Scalar::zero()
305    }
306
307    fn is_zero(x: &Self) -> Choice {
308        Zero::is_zero(x.as_raw())
309    }
310}
311
312impl<E: Curve> crate::traits::One for Scalar<E> {
313    fn one() -> Self {
314        Self::one()
315    }
316
317    fn is_one(x: &Self) -> Choice {
318        One::is_one(x.as_raw())
319    }
320}
321
322impl<E: Curve> crate::traits::Samplable for Scalar<E> {
323    fn random<R: RngCore>(rng: &mut R) -> Self {
324        Self::random(rng)
325    }
326}
327
328impl<E: Curve> iter::Sum for Scalar<E> {
329    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
330        let Some(first_scalar) = iter.next() else {
331            return Scalar::zero();
332        };
333        iter.fold(first_scalar, |acc, x| acc + x)
334    }
335}
336
337impl<'a, E: Curve> iter::Sum<&'a Scalar<E>> for Scalar<E> {
338    fn sum<I: Iterator<Item = &'a Self>>(mut iter: I) -> Self {
339        let Some(first_scalar) = iter.next() else {
340            return Scalar::zero();
341        };
342        iter.fold(*first_scalar, |acc, x| acc + x)
343    }
344}
345
346impl<E: Curve> iter::Product for Scalar<E> {
347    fn product<I: Iterator<Item = Self>>(mut iter: I) -> Self {
348        let Some(first_scalar) = iter.next() else {
349            return Scalar::one();
350        };
351        iter.fold(first_scalar, |acc, x| acc * x)
352    }
353}
354
355impl<'a, E: Curve> iter::Product<&'a Scalar<E>> for Scalar<E> {
356    fn product<I: Iterator<Item = &'a Self>>(mut iter: I) -> Self {
357        let Some(first_scalar) = iter.next() else {
358            return Scalar::one();
359        };
360        iter.fold(*first_scalar, |acc, x| acc * x)
361    }
362}
363
/// Implements `From<uN>` for `Scalar<E>` for every listed unsigned integer type.
///
/// The integer is encoded little-endian and decoded with `Scalar::from_le_bytes`. That call
/// panics (via `expect`) if decoding fails, which is presumed impossible for curves whose
/// scalars are wide enough to hold any primitive integer.
macro_rules! impl_from_primitive_integer {
    ($($uint:ident),+) => {
        $(
            impl<E: Curve> From<$uint> for Scalar<E> {
                fn from(value: $uint) -> Self {
                    let le_bytes = value.to_le_bytes();
                    Scalar::from_le_bytes(&le_bytes)
                        .expect("scalar should be large enough to fit a primitive integer")
                }
            }
        )+
    };
}
374
/// Implements `From<iN>` for `Scalar<E>` for every listed signed integer type.
///
/// The magnitude is converted through the matching unsigned `From` impl, then conditionally
/// negated when the input is negative.
macro_rules! impl_from_signed_integer {
    ($($signed:ident),+) => {
        $(
            impl<E: Curve> From<$signed> for Scalar<E> {
                fn from(value: $signed) -> Self {
                    use subtle::{Choice, ConditionallyNegatable};
                    // TODO: `is_negative`/`unsigned_abs` are not guaranteed to run in
                    // constant time; find a better way to do this sign check.
                    let is_negative = Choice::from(u8::from(value.is_negative()));
                    let mut scalar = Scalar::from(value.unsigned_abs());
                    scalar.conditional_negate(is_negative);
                    scalar
                }
            }
        )+
    };
}
390
391impl_from_primitive_integer! {
392    u8, u16, u32, u64, u128, usize
393}
394impl_from_signed_integer! {
395    i8, i16, i32, i64, i128
396}
397
398impl<E: Curve> fmt::Debug for Scalar<E> {
399    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400        let mut s = f.debug_struct("Scalar");
401        s.field("curve", &E::CURVE_NAME);
402        #[cfg(feature = "std")]
403        {
404            let scalar_hex = hex::encode(self.to_be_bytes());
405            s.field("value", &scalar_hex);
406        }
407        #[cfg(not(feature = "std"))]
408        {
409            s.field("value", &"...");
410        }
411        s.finish()
412    }
413}
414
415#[allow(clippy::derived_hash_with_manual_eq)]
416impl<E: Curve> Hash for Scalar<E> {
417    fn hash<H: hash::Hasher>(&self, state: &mut H) {
418        state.write(self.to_be_bytes().as_bytes())
419    }
420}
421
422impl<E: Curve> PartialOrd for Scalar<E> {
423    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
424        Some(self.cmp(other))
425    }
426}
427
428impl<E: Curve> Ord for Scalar<E> {
429    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
430        self.to_be_bytes()
431            .as_bytes()
432            .cmp(other.to_be_bytes().as_bytes())
433    }
434}
435
#[cfg(feature = "udigest")]
impl<E: Curve> udigest::Digestable for Scalar<E> {
    /// Encodes the scalar as a struct with the curve name and the big-endian scalar bytes.
    fn unambiguously_encode<B>(&self, encoder: udigest::encoding::EncodeValue<B>)
    where
        B: udigest::Buffer,
    {
        let encoded = self.to_be_bytes();
        let mut fields = encoder.encode_struct();
        fields.add_field("curve").encode_leaf_value(E::CURVE_NAME);
        fields.add_field("scalar").encode_leaf_value(encoded);
        fields.finish();
    }
}
448
449/// Iterator over scalar coefficients in radix 16 representation
450///
451/// See [`Scalar::as_radix16_be`] and [`Scalar::as_radix16_le`]
452pub struct Radix16Iter<E: Curve> {
453    /// radix 256 representation of the scalar
454    encoded_scalar: EncodedScalar<E>,
455    next_radix16: Option<u8>,
456    next_index: usize,
457
458    /// Indicates that output is in big-endian. If it's false,
459    /// output is in little-endian
460    is_be: bool,
461}
462
463impl<E: Curve> Radix16Iter<E> {
464    fn new(encoded_scalar: EncodedScalar<E>, is_be: bool) -> Self {
465        Self {
466            encoded_scalar,
467            is_be,
468            next_radix16: None,
469            next_index: 0,
470        }
471    }
472}
473
474impl<E: Curve> Iterator for Radix16Iter<E> {
475    type Item = u8;
476
477    fn next(&mut self) -> Option<Self::Item> {
478        if let Some(next_radix16) = self.next_radix16.take() {
479            return Some(next_radix16);
480        }
481
482        let next_radix256 = self.encoded_scalar.get(self.next_index)?;
483        self.next_index += 1;
484
485        let high_radix16 = next_radix256 >> 4;
486        let low_radix16 = next_radix256 & 0xF;
487        debug_assert_eq!((high_radix16 << 4) | low_radix16, *next_radix256);
488        debug_assert_eq!(high_radix16 & (!0xF), 0);
489        debug_assert_eq!(low_radix16 & (!0xF), 0);
490
491        if self.is_be {
492            self.next_radix16 = Some(low_radix16);
493            Some(high_radix16)
494        } else {
495            self.next_radix16 = Some(high_radix16);
496            Some(low_radix16)
497        }
498    }
499
500    fn size_hint(&self) -> (usize, Option<usize>) {
501        let len = self.len();
502        (len, Some(len))
503    }
504}
505
506impl<E: Curve> ExactSizeIterator for Radix16Iter<E> {
507    fn len(&self) -> usize {
508        self.encoded_scalar[self.next_index..].len() * 2
509            + if self.next_radix16.is_some() { 1 } else { 0 }
510    }
511}
512
513impl<E: Curve, const N: usize> crate::traits::Reduce<N> for Scalar<E>
514where
515    E::Scalar: crate::traits::Reduce<N>,
516{
517    fn from_be_array_mod_order(bytes: &[u8; N]) -> Self {
518        Self::from_raw(Reduce::from_be_array_mod_order(bytes))
519    }
520    fn from_le_array_mod_order(bytes: &[u8; N]) -> Self {
521        Self::from_raw(Reduce::from_le_array_mod_order(bytes))
522    }
523}