umbral_pre/
keys.rs

1use alloc::boxed::Box;
2use alloc::format;
3use alloc::string::{String, ToString};
4use core::cmp::Ordering;
5use core::fmt;
6
7use generic_array::{
8    typenum::{Unsigned, U32, U64},
9    GenericArray,
10};
11use k256::{
12    ecdsa::{
13        signature::{DigestVerifier, RandomizedDigestSigner},
14        RecoveryId, Signature as BackendSignature, SigningKey, VerifyingKey,
15    },
16    elliptic_curve::{FieldBytes, PublicKey as BackendPublicKey, SecretKey as BackendSecretKey},
17};
18use rand_core::{CryptoRng, RngCore};
19use sha2::digest::{Digest, FixedOutput};
20use zeroize::ZeroizeOnDrop;
21
22#[cfg(feature = "default-rng")]
23use rand_core::OsRng;
24
25#[cfg(feature = "serde")]
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27
28use crate::curve::{CompressedPointSize, CurvePoint, CurveType, NonZeroCurveScalar, ScalarSize};
29use crate::dem::kdf;
30use crate::hashing::{BackendDigest, Hash, ScalarDigest};
31use crate::secret_box::SecretBox;
32use crate::traits::{fmt_public, fmt_secret, SizeMismatchError};
33
34#[cfg(feature = "serde")]
35use crate::serde_bytes::{
36    deserialize_with_encoding, serialize_with_encoding, Encoding, TryFromBytes,
37};
38
39/// ECDSA signature object.
40#[derive(Clone, Debug, PartialEq, Eq)]
41pub struct Signature(BackendSignature);
42
43impl Signature {
44    /// Returns the signature serialized as concatenated `r` and `s`
45    /// in big endian order (32+32 bytes).
46    pub fn to_be_bytes(&self) -> Box<[u8]> {
47        AsRef::<[u8]>::as_ref(&self.0.to_bytes()).into()
48    }
49
50    /// Restores the signature from concatenated `r` and `s`
51    /// in big endian order (32+32 bytes).
52    pub fn try_from_be_bytes(bytes: &[u8]) -> Result<Self, String> {
53        if bytes.len() != 64 {
54            return Err(format!(
55                "Invalid length for a serialized signature: {}, expected 64",
56                bytes.len()
57            ));
58        }
59
60        let r = FieldBytes::<CurveType>::from_slice(&bytes[0..32]);
61        let s = FieldBytes::<CurveType>::from_slice(&bytes[32..64]);
62        BackendSignature::from_scalars(*r, *s)
63            .map(Self)
64            .map_err(|err| format!("Internal backend error: {err}"))
65    }
66
67    /// Returns the signature serialized in ASN.1 DER format.
68    pub fn to_der_bytes(&self) -> Box<[u8]> {
69        self.0.to_der().as_bytes().into()
70    }
71
72    /// Restores the signature from a bytestring in ASN.1 DER format.
73    pub fn try_from_der_bytes(bytes: &[u8]) -> Result<Self, String> {
74        // Note that it will not normalize `s` automatically,
75        // and if it is not normalized, verification will fail.
76        BackendSignature::from_der(bytes)
77            .map(Self)
78            .map_err(|err| format!("Internal backend error: {err}"))
79    }
80
81    /// Verifies that the given message was signed with the secret counterpart of the given key.
82    /// The message is hashed internally.
83    pub fn verify(&self, verifying_pk: &PublicKey, message: &[u8]) -> bool {
84        verifying_pk.verify_digest(digest_for_signing(message), self)
85    }
86
87    pub(crate) fn get_recovery_id(
88        &self,
89        verifying_pk: &PublicKey,
90        message: &[u8],
91    ) -> Option<RecoveryId> {
92        let digest = digest_for_signing(message);
93        RecoveryId::trial_recovery_from_digest(&verifying_pk.0.into(), digest, &self.0).ok()
94    }
95}
96
97#[cfg(feature = "serde")]
98impl Serialize for Signature {
99    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
100    where
101        S: Serializer,
102    {
103        serialize_with_encoding(&self.to_be_bytes(), serializer, Encoding::Hex)
104    }
105}
106
107#[cfg(feature = "serde")]
108impl<'de> Deserialize<'de> for Signature {
109    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
110    where
111        D: Deserializer<'de>,
112    {
113        deserialize_with_encoding(deserializer, Encoding::Hex)
114    }
115}
116
117#[cfg(feature = "serde")]
118impl TryFromBytes for Signature {
119    type Error = String;
120
121    fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
122        Self::try_from_be_bytes(bytes)
123    }
124}
125
126impl fmt::Display for Signature {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        fmt_public("Signature", &self.to_be_bytes(), f)
129    }
130}
131
132/// A signature with the recovery byte attached.
133#[derive(Clone, Debug, PartialEq, Eq)]
134pub struct RecoverableSignature {
135    signature: Signature,
136    recovery_id: RecoveryId,
137}
138
139impl RecoverableSignature {
140    /// Create the recoverable signature from the signature proper and the recovery flag.
141    #[cfg(test)]
142    pub(crate) fn from_normalized(signature: Signature, is_y_odd: bool) -> Self {
143        Self {
144            signature,
145            // TODO: currently `RecoveryId.is_x_reduced` is ignored during recovery
146            // (see `VerifyingKey::recover_from_prehash()`), so we set it to `false`.
147            // Make sure this parameter is handled properly if this method is made public.
148            recovery_id: RecoveryId::new(is_y_odd, false),
149        }
150    }
151
152    /// Returns the signature serialized as concatenated `r`, `s`, and `v`
153    /// in big endian order (32+32+1 bytes).
154    pub fn to_be_bytes(&self) -> Box<[u8]> {
155        [
156            AsRef::<[u8]>::as_ref(&self.signature.to_be_bytes()),
157            &[self.recovery_id.to_byte()],
158        ]
159        .concat()
160        .into()
161    }
162
163    /// Restores the signature from concatenated `r`, `s`, and `v`
164    /// in big endian order (32+32+1 bytes).
165    pub fn try_from_be_bytes(bytes: &[u8]) -> Result<Self, String> {
166        // The signature size is exported from `ecdsa` crate (as `SignatureSize` type),
167        // and I don't really want to add it to dependencies just because of this constant.
168        const SIGNATURE_SIZE: usize = 64;
169
170        if bytes.len() != SIGNATURE_SIZE + 1 {
171            return Err("Invalid size of a recoverable signature".into());
172        }
173
174        let signature = Signature::try_from_be_bytes(&bytes[0..SIGNATURE_SIZE])?;
175        let recovery_id = RecoveryId::from_byte(bytes[SIGNATURE_SIZE])
176            .ok_or_else(|| "Invalid recovery byte".to_string())?;
177
178        Ok(RecoverableSignature {
179            signature,
180            recovery_id,
181        })
182    }
183}
184
185#[cfg(feature = "serde")]
186impl Serialize for RecoverableSignature {
187    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
188    where
189        S: Serializer,
190    {
191        serialize_with_encoding(&self.to_be_bytes(), serializer, Encoding::Hex)
192    }
193}
194
195#[cfg(feature = "serde")]
196impl<'de> Deserialize<'de> for RecoverableSignature {
197    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
198    where
199        D: Deserializer<'de>,
200    {
201        deserialize_with_encoding(deserializer, Encoding::Hex)
202    }
203}
204
205#[cfg(feature = "serde")]
206impl TryFromBytes for RecoverableSignature {
207    type Error = String;
208
209    fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
210        Self::try_from_be_bytes(bytes)
211    }
212}
213
214impl fmt::Display for RecoverableSignature {
215    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216        fmt_public("RecoverableSignature", &self.to_be_bytes(), f)
217    }
218}
219
220/// A secret key.
221#[derive(Clone, ZeroizeOnDrop, PartialEq, Eq)]
222pub struct SecretKey(BackendSecretKey<CurveType>);
223
224impl SecretKey {
225    fn new(sk: BackendSecretKey<CurveType>) -> Self {
226        Self(sk)
227    }
228
229    /// Creates a secret key using the given RNG.
230    pub fn random_with_rng(rng: &mut (impl CryptoRng + RngCore)) -> Self {
231        Self::new(BackendSecretKey::<CurveType>::random(rng))
232    }
233
234    /// Creates a secret key using the default RNG.
235    #[cfg(feature = "default-rng")]
236    pub fn random() -> Self {
237        Self::random_with_rng(&mut OsRng)
238    }
239
240    /// Returns a public key corresponding to this secret key.
241    pub fn public_key(&self) -> PublicKey {
242        PublicKey(self.0.public_key())
243    }
244
245    fn from_nonzero_scalar(scalar: SecretBox<NonZeroCurveScalar>) -> Self {
246        let backend_scalar_ref = scalar.as_secret().as_backend_scalar();
247        Self::new(BackendSecretKey::<CurveType>::from(backend_scalar_ref))
248    }
249
250    /// Returns a reference to the underlying scalar of the secret key.
251    pub(crate) fn to_secret_scalar(&self) -> SecretBox<NonZeroCurveScalar> {
252        let backend_scalar = SecretBox::new(self.0.to_nonzero_scalar());
253        SecretBox::new(NonZeroCurveScalar::from_backend_scalar(
254            *backend_scalar.as_secret(),
255        ))
256    }
257
258    /// Serializes the secret key as a scalar in the big-endian representation.
259    pub fn to_be_bytes(&self) -> SecretBox<GenericArray<u8, ScalarSize>> {
260        SecretBox::new(self.0.to_bytes())
261    }
262
263    /// Deserializes the secret key from a scalar in the big-endian representation.
264    pub fn try_from_be_bytes(
265        bytes: &SecretBox<GenericArray<u8, ScalarSize>>,
266    ) -> Result<Self, String> {
267        BackendSecretKey::<CurveType>::from_bytes(bytes.as_secret())
268            .map(Self::new)
269            .map_err(|err| format!("{err}"))
270    }
271}
272
273impl fmt::Display for SecretKey {
274    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275        fmt_secret("SecretKey", f)
276    }
277}
278
279pub(crate) fn digest_for_signing(message: &[u8]) -> BackendDigest {
280    Hash::new().chain_bytes(message).digest()
281}
282
283/// An object used to sign messages.
284/// For security reasons cannot be serialized.
285#[derive(Clone, ZeroizeOnDrop)]
286pub struct Signer(SigningKey);
287
288impl Signer {
289    /// Creates a new signer out of a secret key.
290    pub fn new(sk: SecretKey) -> Self {
291        Self(SigningKey::from(sk.0.clone()))
292    }
293
294    /// Signs the given message using the given RNG.
295    pub fn sign_with_rng(&self, rng: &mut (impl CryptoRng + RngCore), message: &[u8]) -> Signature {
296        let digest = digest_for_signing(message);
297        Signature(self.0.sign_digest_with_rng(rng, digest))
298    }
299
300    /// Signs the given message using the default RNG.
301    #[cfg(feature = "default-rng")]
302    pub fn sign(&self, message: &[u8]) -> Signature {
303        self.sign_with_rng(&mut OsRng, message)
304    }
305
306    /// Returns the public key that can be used to verify the signatures produced by this signer.
307    pub fn verifying_key(&self) -> PublicKey {
308        PublicKey(self.0.verifying_key().into())
309    }
310}
311
312impl fmt::Display for Signer {
313    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314        fmt_secret("Signer", f)
315    }
316}
317
318/// A public key.
319///
320/// Create using [`SecretKey::public_key`].
321#[derive(Clone, Copy, Debug, PartialEq, Eq)]
322pub struct PublicKey(BackendPublicKey<CurveType>);
323
324impl PublicKey {
325    /// Returns the underlying curve point of the public key.
326    pub(crate) fn to_point(self) -> CurvePoint {
327        CurvePoint::from_backend_point(&self.0.to_projective())
328    }
329
330    /// Verifies the signature.
331    pub(crate) fn verify_digest(
332        &self,
333        digest: impl Digest<OutputSize = U32> + FixedOutput,
334        signature: &Signature,
335    ) -> bool {
336        let verifier = VerifyingKey::from(&self.0);
337        verifier.verify_digest(digest, &signature.0).is_ok()
338    }
339
340    /// Recovers the public key from a prehashed message
341    /// and the corresponding recoverable signature.
342    pub fn recover_from_prehash(
343        prehash: &[u8],
344        signature: &RecoverableSignature,
345    ) -> Result<Self, String> {
346        let vkey = VerifyingKey::recover_from_prehash(
347            prehash,
348            &signature.signature.0,
349            signature.recovery_id,
350        )
351        .map_err(|err| format!("Internal backend error: {err}"))?;
352        Ok(Self(vkey.into()))
353    }
354
355    /// Restores the public key from a compressed curve point.
356    pub fn try_from_compressed_bytes(bytes: &[u8]) -> Result<Self, String> {
357        let cp = CurvePoint::try_from_compressed_bytes(bytes)?;
358        BackendPublicKey::<CurveType>::try_from(cp.as_backend_point())
359            .map(Self)
360            .map_err(|_| "Cannot instantiate a public key from the given curve point".into())
361    }
362
363    /// Retunrs the serialized pubic key as the compressed curve point.
364    pub fn to_compressed_bytes(self) -> Box<[u8]> {
365        let arr: GenericArray<u8, CompressedPointSize> = self.to_point().to_compressed_array();
366        let slice: &[u8] = arr.as_ref();
367        slice.into()
368    }
369
370    /// Retunrs the serialized pubic key as the uncompressed curve point.
371    pub fn to_uncompressed_bytes(self) -> Box<[u8]> {
372        self.to_point().to_uncompressed_bytes()
373    }
374}
375
376#[cfg(feature = "serde")]
377impl Serialize for PublicKey {
378    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
379    where
380        S: Serializer,
381    {
382        serialize_with_encoding(&self.to_compressed_bytes(), serializer, Encoding::Hex)
383    }
384}
385
386#[cfg(feature = "serde")]
387impl<'de> Deserialize<'de> for PublicKey {
388    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
389    where
390        D: Deserializer<'de>,
391    {
392        deserialize_with_encoding(deserializer, Encoding::Hex)
393    }
394}
395
396#[cfg(feature = "serde")]
397impl TryFromBytes for PublicKey {
398    type Error = String;
399
400    fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
401        Self::try_from_compressed_bytes(bytes)
402    }
403}
404
405impl fmt::Display for PublicKey {
406    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407        fmt_public("PublicKey", &self.to_compressed_bytes(), f)
408    }
409}
410
411type SecretKeyFactorySeedSize = U32; // the size of the seed material for key derivation
412type SecretKeyFactoryDerivedSize = U64; // the size of the derived key (before hashing to scalar)
413type SecretKeyFactorySeed = GenericArray<u8, SecretKeyFactorySeedSize>;
414
415/// This class handles keyring material for Umbral, by allowing deterministic
416/// derivation of `SecretKey` objects based on labels.
417#[derive(Clone, ZeroizeOnDrop, PartialEq)]
418pub struct SecretKeyFactory(SecretBox<SecretKeyFactorySeed>);
419
420impl SecretKeyFactory {
421    /// Creates a secret key factory using the given RNG.
422    pub fn random_with_rng(rng: &mut (impl CryptoRng + RngCore)) -> Self {
423        let mut bytes = SecretBox::new(SecretKeyFactorySeed::default());
424        rng.fill_bytes(bytes.as_mut_secret());
425        Self(bytes)
426    }
427
428    /// Creates a secret key factory using the default RNG.
429    #[cfg(feature = "default-rng")]
430    pub fn random() -> Self {
431        Self::random_with_rng(&mut OsRng)
432    }
433
434    /// Returns the seed size required by
435    /// [`from_secure_randomness`](`SecretKeyFactory::from_secure_randomness`).
436    pub fn seed_size() -> usize {
437        SecretKeyFactorySeedSize::to_usize()
438    }
439
440    /// Creates a secret key factory using the given random bytes.
441    ///
442    /// **Warning:** make sure the given seed has been obtained
443    /// from a cryptographically secure source of randomness!
444    pub fn from_secure_randomness(seed: &[u8]) -> Result<Self, SizeMismatchError> {
445        let received_size = seed.len();
446        let expected_size = Self::seed_size();
447        match received_size.cmp(&expected_size) {
448            Ordering::Greater | Ordering::Less => {
449                Err(SizeMismatchError::new(received_size, expected_size))
450            }
451            Ordering::Equal => Ok(Self(SecretBox::new(*SecretKeyFactorySeed::from_slice(
452                seed,
453            )))),
454        }
455    }
456
457    /// Creates an untyped bytestring deterministically from the given label.
458    /// This can be used externally to seed some kind of a secret key.
459    pub fn make_secret(
460        &self,
461        label: &[u8],
462    ) -> SecretBox<GenericArray<u8, SecretKeyFactoryDerivedSize>> {
463        let prefix = b"SECRET_DERIVATION/";
464        let info = [prefix, label].concat();
465        kdf::<SecretKeyFactoryDerivedSize>(self.0.as_secret(), None, Some(&info))
466    }
467
468    /// Creates a `SecretKey` deterministically from the given label.
469    pub fn make_key(&self, label: &[u8]) -> SecretKey {
470        let prefix = b"KEY_DERIVATION/";
471        let info = [prefix, label].concat();
472        let key = kdf::<SecretKeyFactoryDerivedSize>(self.0.as_secret(), None, Some(&info));
473        let nz_scalar = SecretBox::new(
474            ScalarDigest::new_with_dst(&info)
475                .chain_secret_bytes(&key)
476                .finalize(),
477        );
478        SecretKey::from_nonzero_scalar(nz_scalar)
479    }
480
481    /// Creates a `SecretKeyFactory` deterministically from the given label.
482    pub fn make_factory(&self, label: &[u8]) -> Self {
483        let prefix = b"FACTORY_DERIVATION/";
484        let info = [prefix, label].concat();
485        let derived_seed = kdf::<SecretKeyFactorySeedSize>(self.0.as_secret(), None, Some(&info));
486        Self(derived_seed)
487    }
488}
489
490impl fmt::Display for SecretKeyFactory {
491    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
492        fmt_secret("SecretKeyFactory", f)
493    }
494}
495
#[cfg(test)]
mod tests {

    use sha2::digest::Digest;

    use super::{
        digest_for_signing, PublicKey, RecoverableSignature, SecretKey, SecretKeyFactory, Signer,
    };

    #[cfg(feature = "serde")]
    use crate::serde_bytes::tests::check_serialization_roundtrip;

    /// Message signed throughout these tests.
    const MESSAGE: &[u8] = b"asdafdahsfdasdfasd";

    #[test]
    fn test_secret_key_factory() {
        let factory = SecretKeyFactory::random();
        let foo_first = factory.make_key(b"foo");
        let foo_second = factory.make_key(b"foo");
        let bar = factory.make_key(b"bar");

        // Same label gives the same key; a different label gives a different key.
        assert!(foo_first == foo_second);
        assert!(foo_first != bar);
    }

    #[test]
    fn test_sign_and_verify() {
        let secret = SecretKey::random();
        let signer = Signer::new(secret.clone());
        let signature = signer.sign(MESSAGE);

        let verifying_key = signer.verifying_key();
        assert_eq!(secret.public_key(), verifying_key);
        assert!(signature.verify(&verifying_key, MESSAGE));
    }

    #[test]
    fn test_sign_and_recover() {
        let secret = SecretKey::random();
        let expected_pk = secret.public_key();
        let prehash = digest_for_signing(MESSAGE).finalize();
        let signature = Signer::new(secret).sign(MESSAGE);

        // One of the two y-parity flags should recover the signer's key; try them in order.
        let (flag, rsig) = [false, true]
            .iter()
            .copied()
            .find_map(|is_y_odd| {
                let candidate = RecoverableSignature::from_normalized(signature.clone(), is_y_odd);
                let recovered = PublicKey::recover_from_prehash(&prehash, &candidate).unwrap();
                if recovered == expected_pk {
                    Some((is_y_odd, candidate))
                } else {
                    None
                }
            })
            .expect("Could not find a flag that would recover the original public key");

        // The byte encoding is `r || s || v`, with `v` equal to the parity flag here.
        let mut rsig_bytes = signature.to_be_bytes().to_vec();
        rsig_bytes.push(u8::from(flag));
        assert_eq!(rsig_bytes.clone().into_boxed_slice(), rsig.to_be_bytes());

        let rsig_deserialized = RecoverableSignature::try_from_be_bytes(&rsig_bytes).unwrap();
        let pk_from_deserialized =
            PublicKey::recover_from_prehash(&prehash, &rsig_deserialized).unwrap();
        assert_eq!(pk_from_deserialized, expected_pk);
        assert_eq!(rsig_deserialized, rsig);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_signature() {
        let signature = Signer::new(SecretKey::random()).sign(MESSAGE);
        check_serialization_roundtrip(&signature);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_recoverable_signature() {
        let signature = Signer::new(SecretKey::random()).sign(MESSAGE);
        let rsig = RecoverableSignature::from_normalized(signature, true);
        check_serialization_roundtrip(&rsig);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize_public_key() {
        let verifying_key = Signer::new(SecretKey::random()).verifying_key();
        check_serialization_roundtrip(&verifying_key);
    }
}