Skip to main content

generic_ec/
scalar.rs

1use core::hash::{self, Hash};
2use core::{fmt, iter};
3
4use rand_core::RngCore;
5use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
6use zeroize::Zeroize;
7
8use crate::NonZero;
9use crate::{
10    as_raw::{AsRaw, FromRaw},
11    core::*,
12    encoded::EncodedScalar,
13    errors::InvalidScalar,
14};
15
16/// Scalar modulo curve `E` group order
17///
18/// Scalar is an integer modulo curve `E` group order.
19#[derive(Copy, Clone, PartialEq, Eq, Default)]
20pub struct Scalar<E: Curve>(E::Scalar);
21
22impl<E: Curve> Scalar<E> {
23    /// Returns scalar $S = 0$
24    ///
25    /// ```rust
26    /// use generic_ec::{Scalar, curves::Secp256k1};
27    /// use rand::rngs::OsRng;
28    ///
29    /// let s = Scalar::<Secp256k1>::random(&mut OsRng);
30    /// assert_eq!(s * Scalar::zero(), Scalar::zero());
31    /// assert_eq!(s + Scalar::zero(), s);
32    /// ```
33    pub fn zero() -> Self {
34        Self::from_raw(E::Scalar::zero())
35    }
36
37    /// Checks whether the scalar is zero
38    pub fn is_zero(&self) -> bool {
39        Zero::is_zero(self.as_raw()).into()
40    }
41
42    /// Returns scalar $S = 1$
43    ///
44    /// ```rust
45    /// use generic_ec::{Scalar, curves::Secp256k1};
46    /// use rand::rngs::OsRng;
47    ///
48    /// let s = Scalar::<Secp256k1>::random(&mut OsRng);
49    /// assert_eq!(s * Scalar::one(), s);
50    /// ```
51    pub fn one() -> Self {
52        Self::from_raw(E::Scalar::one())
53    }
54
55    /// Returns scalar inverse $S^{-1}$
56    ///
57    /// Inverse of scalar $S$ is a scalar $S^{-1}$ such as $S \cdot S^{-1} = 1$. Inverse doesn't
58    /// exist only for scalar $S = 0$, so function returns `None` if scalar is zero.
59    ///
60    /// ```rust
61    /// # fn func() -> Option<()> {
62    /// use generic_ec::{Scalar, curves::Secp256k1};
63    /// use rand::rngs::OsRng;
64    ///
65    /// let s = Scalar::<Secp256k1>::random(&mut OsRng);
66    /// let s_inv = s.invert()?;
67    /// assert_eq!(s * s_inv, Scalar::one());
68    /// # Some(()) }
69    /// # func();
70    /// ```
71    pub fn invert(&self) -> Option<Self> {
72        self.ct_invert().into()
73    }
74
75    /// Returns scalar inverse $S^{-1}$ (constant time)
76    ///
77    /// Same as [`Scalar::invert`] but performs constant-time check on whether it's zero
78    /// scalar
79    pub fn ct_invert(&self) -> CtOption<Self> {
80        let inv = Invertible::invert(self.as_raw());
81        inv.map(Self::from_raw)
82    }
83
84    /// Encodes scalar as bytes in big-endian order
85    ///
86    /// ```rust
87    /// use generic_ec::{Scalar, curves::Secp256k1};
88    /// use rand::rngs::OsRng;
89    ///
90    /// let s = Scalar::<Secp256k1>::random(&mut OsRng);
91    /// let bytes = s.to_be_bytes();
92    /// println!("Scalar hex representation: {}", hex::encode(bytes));
93    /// ```
94    pub fn to_be_bytes(&self) -> EncodedScalar<E> {
95        let bytes = self.as_raw().to_be_bytes();
96        EncodedScalar::new(bytes)
97    }
98
99    /// Encodes scalar as bytes in little-endian order
100    pub fn to_le_bytes(&self) -> EncodedScalar<E> {
101        let bytes = self.as_raw().to_le_bytes();
102        EncodedScalar::new(bytes)
103    }
104
105    /// Decodes scalar from its representation as bytes in big-endian order
106    ///
107    /// Returns error if encoded integer is larger than group order.
108    ///
109    /// ```rust
110    /// use generic_ec::{Scalar, curves::Secp256k1};
111    /// use rand::rngs::OsRng;
112    ///
113    /// let s = Scalar::<Secp256k1>::random(&mut OsRng);
114    /// let s_bytes = s.to_be_bytes();
115    /// let s_decoded = Scalar::from_be_bytes(&s_bytes)?;
116    /// assert_eq!(s, s_decoded);
117    /// # Ok::<(), Box<dyn std::error::Error>>(())
118    /// ```
119    pub fn from_be_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, InvalidScalar> {
120        let bytes = bytes.as_ref();
121        let mut bytes_array = E::ScalarArray::zeroes();
122        let bytes_array_len = bytes_array.as_ref().len();
123        if bytes_array_len < bytes.len() {
124            return Err(InvalidScalar);
125        }
126        bytes_array.as_mut()[bytes_array_len - bytes.len()..].copy_from_slice(bytes);
127
128        let scalar = E::Scalar::from_be_bytes_exact(&bytes_array).ok_or(InvalidScalar)?;
129        Ok(Scalar::from_raw(scalar))
130    }
131
132    /// Decodes scalar from its representation as bytes in little-endian order
133    ///
134    /// Returns error if encoded integer is larger than group order.
135    pub fn from_le_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, InvalidScalar> {
136        let bytes = bytes.as_ref();
137        let mut bytes_array = E::ScalarArray::zeroes();
138        let bytes_array_len = bytes_array.as_ref().len();
139        if bytes_array_len < bytes.len() {
140            return Err(InvalidScalar);
141        }
142        bytes_array.as_mut()[..bytes.len()].copy_from_slice(bytes);
143
144        let scalar = E::Scalar::from_le_bytes_exact(&bytes_array).ok_or(InvalidScalar)?;
145        Ok(Scalar::from_raw(scalar))
146    }
147
148    /// Interprets provided bytes as integer $i$ in big-endian order, returns scalar $s = i \mod q$
149    pub fn from_be_bytes_mod_order(bytes: impl AsRef<[u8]>) -> Self {
150        let scalar = E::Scalar::from_be_bytes_mod_order(bytes.as_ref());
151        Self::from_raw(scalar)
152    }
153
154    /// Interprets provided bytes as integer $i$ in little-endian order, returns scalar $s = i \mod q$
155    pub fn from_le_bytes_mod_order(bytes: impl AsRef<[u8]>) -> Self {
156        let scalar = E::Scalar::from_le_bytes_mod_order(bytes.as_ref());
157        Self::from_raw(scalar)
158    }
159
160    /// Samples a uniform scalar from source of randomness using constant-time algorithm
161    ///
162    /// Under the hood, it uses [`NonZero::<Scalar<E>>::random()`] method, therefore it shares
163    /// its guarantees and performance. Refer to its documentation to learn more.
164    pub fn random<R: RngCore>(rng: &mut R) -> Self {
165        NonZero::<Scalar<E>>::random(rng).into()
166    }
167
168    /// Samples a uniform scalar from source of randomness using vartime algorithm
169    ///
170    /// Under the hood, it uses [`NonZero::<Scalar<E>>::random_vartime()`] method, therefore it shares
171    /// its guarantees and performance. Refer to its documentation to learn more.
172    pub fn random_vartime<R: RngCore>(rng: &mut R) -> Self {
173        NonZero::<Scalar<E>>::random_vartime(rng).into()
174    }
175
176    #[doc = include_str!("../docs/hash_to_scalar.md")]
177    ///
178    /// ## Example
179    /// ```rust
180    /// use generic_ec::{Scalar, curves::Secp256k1};
181    /// use sha2::Sha256;
182    ///
183    /// #[derive(udigest::Digestable)]
184    /// struct Data<'a> {
185    ///     nonce: &'a [u8],
186    ///     param_a: &'a str,
187    ///     param_b: u128,
188    ///     // ...
189    /// }
190    ///
191    /// let scalar = Scalar::<Secp256k1>::from_hash::<Sha256>(&Data {
192    ///     nonce: b"some data",
193    ///     param_a: "some other data",
194    ///     param_b: 12345,
195    ///     // ...
196    /// });
197    /// ```
198    #[cfg(feature = "hash-to-scalar")]
199    pub fn from_hash<D: digest::Digest>(data: &impl udigest::Digestable) -> Self {
200        let mut rng = rand_hash::HashRng::<D, _>::from_seed(data);
201        Self::random(&mut rng)
202    }
203
204    /// Returns size of bytes buffer that can fit serialized scalar
205    pub fn serialized_len() -> usize {
206        E::ScalarArray::zeroes().as_ref().len()
207    }
208
209    /// Returns scalar big-endian representation in radix $2^4 = 16$
210    ///
211    /// Radix 16 representation is defined as sum:
212    ///
213    /// $$s = s_0 + s_1 16^1 + s_2 16^2 + \dots + s_{\log_{16}(s) - 1} 16^{\log_{16}(s) - 1}$$
214    ///
215    /// (note: we typically work with 256 bit scalars, so $\log_{16}(s) = \log_{16}(2^{256}) = 64$)
216    ///
217    /// Returns iterator over coefficients from most to least significant:
218    /// $s_{\log_{16}(s) - 1}, \dots, s_1, s_0$
219    pub fn as_radix16_be(&self) -> Radix16Iter<E> {
220        Radix16Iter::new(self.to_be_bytes(), true)
221    }
222
223    /// Returns scalar little-endian representation in radix $2^4 = 16$
224    ///
225    /// Radix 16 representation is defined as sum:
226    ///
227    /// $$s = s_0 + s_1 16^1 + s_2 16^2 + \dots + s_{\log_{16}(s) - 1} 16^{\log_{16}(s) - 1}$$
228    ///
229    /// (note: we typically work with 256 bit scalars, so $\log_{16}(s) = \log_{16}(2^{256}) = 64$)
230    ///
231    /// Returns iterator over coefficients from least to most significant:
232    /// $s_0, s_1, \dots, s_{\log_{16}(s) - 1}$
233    pub fn as_radix16_le(&self) -> Radix16Iter<E> {
234        Radix16Iter::new(self.to_le_bytes(), false)
235    }
236
237    /// Performs multiscalar multiplication
238    ///
239    /// Takes iterator of pairs `(scalar, point)`. Returns sum of `scalar * point`. Uses
240    /// [`Default`](crate::multiscalar::Default) algorithm.
241    ///
242    /// See [multiscalar module](crate::multiscalar) docs for more info.
243    pub fn multiscalar_mul<S, P>(
244        scalar_points: impl ExactSizeIterator<Item = (S, P)>,
245    ) -> crate::Point<E>
246    where
247        S: AsRef<Scalar<E>>,
248        P: AsRef<crate::Point<E>>,
249    {
250        use crate::multiscalar::MultiscalarMul;
251        crate::multiscalar::Default::multiscalar_mul(scalar_points)
252    }
253}
254
255impl<E: Curve> AsRaw for Scalar<E> {
256    type Raw = E::Scalar;
257
258    #[inline]
259    fn as_raw(&self) -> &E::Scalar {
260        &self.0
261    }
262}
263
264impl<E: Curve> Zeroize for Scalar<E> {
265    #[inline]
266    fn zeroize(&mut self) {
267        self.0.zeroize()
268    }
269}
270
271impl<E: Curve> FromRaw for Scalar<E> {
272    fn from_raw(scalar: E::Scalar) -> Self {
273        Self(scalar)
274    }
275}
276
277impl<E: Curve> ConditionallySelectable for Scalar<E> {
278    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
279        Scalar::from_raw(<E::Scalar as ConditionallySelectable>::conditional_select(
280            a.as_raw(),
281            b.as_raw(),
282            choice,
283        ))
284    }
285}
286
287impl<E: Curve> ConstantTimeEq for Scalar<E> {
288    fn ct_eq(&self, other: &Self) -> Choice {
289        self.as_raw().ct_eq(other.as_raw())
290    }
291}
292
293impl<E: Curve> AsRef<Scalar<E>> for Scalar<E> {
294    fn as_ref(&self) -> &Scalar<E> {
295        self
296    }
297}
298
299impl<E: Curve> crate::traits::IsZero for Scalar<E> {
300    fn is_zero(&self) -> bool {
301        self.is_zero()
302    }
303}
304
305impl<E: Curve> crate::traits::Zero for Scalar<E> {
306    fn zero() -> Self {
307        Scalar::zero()
308    }
309
310    fn is_zero(x: &Self) -> Choice {
311        Zero::is_zero(x.as_raw())
312    }
313}
314
315impl<E: Curve> crate::traits::One for Scalar<E> {
316    fn one() -> Self {
317        Self::one()
318    }
319
320    fn is_one(x: &Self) -> Choice {
321        One::is_one(x.as_raw())
322    }
323}
324
325impl<E: Curve> crate::traits::Samplable for Scalar<E> {
326    fn random<R: RngCore>(rng: &mut R) -> Self {
327        Self::random(rng)
328    }
329
330    fn random_vartime<R: rand_core::RngCore>(rng: &mut R) -> Self {
331        Self::random_vartime(rng)
332    }
333}
334
335impl<E: Curve> iter::Sum for Scalar<E> {
336    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
337        let Some(first_scalar) = iter.next() else {
338            return Scalar::zero();
339        };
340        iter.fold(first_scalar, |acc, x| acc + x)
341    }
342}
343
344impl<'a, E: Curve> iter::Sum<&'a Scalar<E>> for Scalar<E> {
345    fn sum<I: Iterator<Item = &'a Self>>(mut iter: I) -> Self {
346        let Some(first_scalar) = iter.next() else {
347            return Scalar::zero();
348        };
349        iter.fold(*first_scalar, |acc, x| acc + x)
350    }
351}
352
353impl<E: Curve> iter::Product for Scalar<E> {
354    fn product<I: Iterator<Item = Self>>(mut iter: I) -> Self {
355        let Some(first_scalar) = iter.next() else {
356            return Scalar::one();
357        };
358        iter.fold(first_scalar, |acc, x| acc * x)
359    }
360}
361
362impl<'a, E: Curve> iter::Product<&'a Scalar<E>> for Scalar<E> {
363    fn product<I: Iterator<Item = &'a Self>>(mut iter: I) -> Self {
364        let Some(first_scalar) = iter.next() else {
365            return Scalar::one();
366        };
367        iter.fold(*first_scalar, |acc, x| acc * x)
368    }
369}
370
/// Implements `From<$int> for Scalar<E>` for each listed unsigned integer type by
/// decoding the integer's little-endian bytes.
macro_rules! impl_from_primitive_integer {
    ($($int:ident),+) => {
        $(
            impl<E: Curve> From<$int> for Scalar<E> {
                fn from(value: $int) -> Self {
                    let le_bytes = value.to_le_bytes();
                    Scalar::from_le_bytes(&le_bytes)
                        .expect("scalar should be large enough to fit a primitive integer")
                }
            }
        )+
    };
}
381
/// Implements `From<$iint> for Scalar<E>` for each listed signed integer type.
///
/// The magnitude goes through the unsigned conversion and is then conditionally negated
/// modulo the group order.
macro_rules! impl_from_signed_integer {
    ($($iint:ident),+) => {
        $(
            impl<E: Curve> From<$iint> for Scalar<E> {
                fn from(value: $iint) -> Self {
                    use subtle::{Choice, ConditionallyNegatable};
                    // TODO: `is_negative`/`unsigned_abs` are not guaranteed to compile to
                    // constant-time code; what's a better way to do that check in constant time?
                    let negative = Choice::from(u8::from(value.is_negative()));
                    let mut scalar = Self::from(value.unsigned_abs());
                    scalar.conditional_negate(negative);
                    scalar
                }
            }
        )+
    };
}
397
398impl_from_primitive_integer! {
399    u8, u16, u32, u64, u128, usize
400}
401impl_from_signed_integer! {
402    i8, i16, i32, i64, i128
403}
404
405impl<E: Curve> fmt::Debug for Scalar<E> {
406    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407        let mut s = f.debug_struct("Scalar");
408        s.field("curve", &E::CURVE_NAME);
409        #[cfg(feature = "std")]
410        {
411            let scalar_hex = hex::encode(self.to_be_bytes());
412            s.field("value", &scalar_hex);
413        }
414        #[cfg(not(feature = "std"))]
415        {
416            s.field("value", &"...");
417        }
418        s.finish()
419    }
420}
421
422#[allow(clippy::derived_hash_with_manual_eq)]
423impl<E: Curve> Hash for Scalar<E> {
424    fn hash<H: hash::Hasher>(&self, state: &mut H) {
425        state.write(self.to_be_bytes().as_bytes())
426    }
427}
428
429impl<E: Curve> PartialOrd for Scalar<E> {
430    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
431        Some(self.cmp(other))
432    }
433}
434
435impl<E: Curve> Ord for Scalar<E> {
436    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
437        self.to_be_bytes()
438            .as_bytes()
439            .cmp(other.to_be_bytes().as_bytes())
440    }
441}
442
#[cfg(feature = "udigest")]
impl<E: Curve> udigest::Digestable for Scalar<E> {
    /// Encodes the scalar as a struct with a `curve` name field and a `scalar` field holding
    /// its big-endian bytes
    fn unambiguously_encode<B>(&self, encoder: udigest::encoding::EncodeValue<B>)
    where
        B: udigest::Buffer,
    {
        let encoded_scalar = self.to_be_bytes();
        let mut fields = encoder.encode_struct();
        fields.add_field("curve").encode_leaf_value(E::CURVE_NAME);
        fields.add_field("scalar").encode_leaf_value(encoded_scalar);
        fields.finish();
    }
}
455
456/// Iterator over scalar coefficients in radix 16 representation
457///
458/// See [`Scalar::as_radix16_be`] and [`Scalar::as_radix16_le`]
459pub struct Radix16Iter<E: Curve> {
460    /// radix 256 representation of the scalar
461    encoded_scalar: EncodedScalar<E>,
462    next_radix16: Option<u8>,
463    next_index: usize,
464
465    /// Indicates that output is in big-endian. If it's false,
466    /// output is in little-endian
467    is_be: bool,
468}
469
470impl<E: Curve> Radix16Iter<E> {
471    fn new(encoded_scalar: EncodedScalar<E>, is_be: bool) -> Self {
472        Self {
473            encoded_scalar,
474            is_be,
475            next_radix16: None,
476            next_index: 0,
477        }
478    }
479}
480
481impl<E: Curve> Iterator for Radix16Iter<E> {
482    type Item = u8;
483
484    fn next(&mut self) -> Option<Self::Item> {
485        if let Some(next_radix16) = self.next_radix16.take() {
486            return Some(next_radix16);
487        }
488
489        let next_radix256 = self.encoded_scalar.get(self.next_index)?;
490        self.next_index += 1;
491
492        let high_radix16 = next_radix256 >> 4;
493        let low_radix16 = next_radix256 & 0xF;
494        debug_assert_eq!((high_radix16 << 4) | low_radix16, *next_radix256);
495        debug_assert_eq!(high_radix16 & (!0xF), 0);
496        debug_assert_eq!(low_radix16 & (!0xF), 0);
497
498        if self.is_be {
499            self.next_radix16 = Some(low_radix16);
500            Some(high_radix16)
501        } else {
502            self.next_radix16 = Some(high_radix16);
503            Some(low_radix16)
504        }
505    }
506
507    fn size_hint(&self) -> (usize, Option<usize>) {
508        let len = self.len();
509        (len, Some(len))
510    }
511}
512
513impl<E: Curve> ExactSizeIterator for Radix16Iter<E> {
514    fn len(&self) -> usize {
515        self.encoded_scalar[self.next_index..].len() * 2
516            + if self.next_radix16.is_some() { 1 } else { 0 }
517    }
518}
519
520impl<E: Curve, const N: usize> crate::traits::Reduce<N> for Scalar<E>
521where
522    E::Scalar: crate::traits::Reduce<N>,
523{
524    fn from_be_array_mod_order(bytes: &[u8; N]) -> Self {
525        Self::from_raw(Reduce::from_be_array_mod_order(bytes))
526    }
527    fn from_le_array_mod_order(bytes: &[u8; N]) -> Self {
528        Self::from_raw(Reduce::from_le_array_mod_order(bytes))
529    }
530}