libsignal_core_syft/
curve.rs

//
// Copyright 2020-2021 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
5
6mod curve25519;
7mod utils;
8
9use std::cmp::Ordering;
10use std::fmt;
11
12use curve25519_dalek::{MontgomeryPoint, scalar};
13use rand::{CryptoRng, Rng};
14use subtle::ConstantTimeEq;
15
/// Identifies which curve scheme a key belongs to.
///
/// There is currently a single scheme. Each variant maps to a one-byte tag
/// (see `KeyType::value`) that prefixes a serialized public key.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum KeyType {
    /// Curve25519 keys; serialized with the tag byte `0x05`.
    Djb,
}
20
21impl fmt::Display for KeyType {
22    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
23        fmt::Debug::fmt(self, f)
24    }
25}
26
27impl KeyType {
28    fn value(&self) -> u8 {
29        match &self {
30            KeyType::Djb => 0x05u8,
31        }
32    }
33}
34
35#[derive(Debug, displaydoc::Display)]
36pub enum CurveError {
37    /// no key type identifier
38    NoKeyTypeIdentifier,
39    /// bad key type <{0:#04x}>
40    BadKeyType(u8),
41    /// bad key length <{1}> for key with type <{0}>
42    BadKeyLength(KeyType, usize),
43}
44
45impl std::error::Error for CurveError {}
46
47impl TryFrom<u8> for KeyType {
48    type Error = CurveError;
49
50    fn try_from(x: u8) -> Result<Self, CurveError> {
51        match x {
52            0x05u8 => Ok(KeyType::Djb),
53            t => Err(CurveError::BadKeyType(t)),
54        }
55    }
56}
57
58#[derive(Debug, Clone, Copy, Eq, PartialEq)]
59enum PublicKeyData {
60    DjbPublicKey([u8; curve25519::PUBLIC_KEY_LENGTH]),
61}
62
63#[derive(Clone, Copy, Eq, derive_more::From)]
64pub struct PublicKey {
65    key: PublicKeyData,
66}
67
68impl PublicKey {
69    fn new(key: PublicKeyData) -> Self {
70        Self { key }
71    }
72
73    pub fn deserialize(value: &[u8]) -> Result<Self, CurveError> {
74        let (key_type, value) = value.split_first().ok_or(CurveError::NoKeyTypeIdentifier)?;
75        let key_type = KeyType::try_from(*key_type)?;
76        match key_type {
77            KeyType::Djb => {
78                let (key, tail): (&[u8; curve25519::PUBLIC_KEY_LENGTH], _) = value
79                    .split_first_chunk()
80                    .ok_or(CurveError::BadKeyLength(KeyType::Djb, value.len() + 1))?;
81                // We currently allow trailing data after the public key.
82                // TODO: once this is known to not be seen in practice, make this a hard error.
83                if !tail.is_empty() {
84                    log::warn!(
85                        "ECPublicKey deserialized with {} trailing bytes",
86                        tail.len()
87                    );
88                }
89                Ok(PublicKey {
90                    key: PublicKeyData::DjbPublicKey(*key),
91                })
92            }
93        }
94    }
95
96    pub fn public_key_bytes(&self) -> &[u8] {
97        match &self.key {
98            PublicKeyData::DjbPublicKey(v) => v,
99        }
100    }
101
102    pub fn from_djb_public_key_bytes(bytes: &[u8]) -> Result<Self, CurveError> {
103        match <[u8; curve25519::PUBLIC_KEY_LENGTH]>::try_from(bytes) {
104            Err(_) => Err(CurveError::BadKeyLength(KeyType::Djb, bytes.len())),
105            Ok(key) => Ok(PublicKey {
106                key: PublicKeyData::DjbPublicKey(key),
107            }),
108        }
109    }
110
111    pub fn serialize(&self) -> Box<[u8]> {
112        let value_len = match &self.key {
113            PublicKeyData::DjbPublicKey(v) => v.len(),
114        };
115        let mut result = Vec::with_capacity(1 + value_len);
116        result.push(self.key_type().value());
117        match &self.key {
118            PublicKeyData::DjbPublicKey(v) => result.extend_from_slice(v),
119        }
120        result.into_boxed_slice()
121    }
122
123    pub fn verify_signature(&self, message: &[u8], signature: &[u8]) -> bool {
124        self.verify_signature_for_multipart_message(&[message], signature)
125    }
126
127    pub fn verify_signature_for_multipart_message(
128        &self,
129        message: &[&[u8]],
130        signature: &[u8],
131    ) -> bool {
132        match &self.key {
133            PublicKeyData::DjbPublicKey(pub_key) => {
134                let Ok(signature) = signature.try_into() else {
135                    return false;
136                };
137                curve25519::PrivateKey::verify_signature(pub_key, message, signature)
138            }
139        }
140    }
141
142    fn key_data(&self) -> &[u8] {
143        match &self.key {
144            PublicKeyData::DjbPublicKey(k) => k.as_ref(),
145        }
146    }
147
148    pub fn key_type(&self) -> KeyType {
149        match &self.key {
150            PublicKeyData::DjbPublicKey(_) => KeyType::Djb,
151        }
152    }
153
154    fn is_torsion_free(&self) -> bool {
155        match &self.key {
156            PublicKeyData::DjbPublicKey(k) => {
157                let mont_point = MontgomeryPoint(*k);
158                mont_point
159                    .to_edwards(0)
160                    .is_some_and(|ed| ed.is_torsion_free())
161            }
162        }
163    }
164
165    fn scalar_is_in_range(&self) -> bool {
166        match &self.key {
167            PublicKeyData::DjbPublicKey(k) => {
168                // it is not true that the scalar is greater than 2^255 - 19
169                // specifically, it is not true that either the high bit is set
170                // or that the high 247 bits are all 1 and the bottom byte is >(2^8 - 19)
171                !(k[31] & 0b1000_0000_u8 != 0
172                    || (k[0] >= 0u8.wrapping_sub(19) && k[1..31] == [0xFFu8; 30] && k[31] == 0x7F))
173            }
174        }
175    }
176
177    pub fn is_canonical(&self) -> bool {
178        self.is_torsion_free() && self.scalar_is_in_range()
179    }
180}
181
182impl TryFrom<&[u8]> for PublicKey {
183    type Error = CurveError;
184
185    fn try_from(value: &[u8]) -> Result<Self, CurveError> {
186        Self::deserialize(value)
187    }
188}
189
190impl subtle::ConstantTimeEq for PublicKey {
191    /// A constant-time comparison as long as the two keys have a matching type.
192    ///
193    /// If the two keys have different types, the comparison short-circuits,
194    /// much like comparing two slices of different lengths.
195    fn ct_eq(&self, other: &PublicKey) -> subtle::Choice {
196        if self.key_type() != other.key_type() {
197            return 0.ct_eq(&1);
198        }
199        self.key_data().ct_eq(other.key_data())
200    }
201}
202
203impl PartialEq for PublicKey {
204    fn eq(&self, other: &PublicKey) -> bool {
205        bool::from(self.ct_eq(other))
206    }
207}
208
209impl Ord for PublicKey {
210    fn cmp(&self, other: &Self) -> Ordering {
211        if self.key_type() != other.key_type() {
212            return self.key_type().cmp(&other.key_type());
213        }
214
215        utils::constant_time_cmp(self.key_data(), other.key_data())
216    }
217}
218
219impl PartialOrd for PublicKey {
220    fn partial_cmp(&self, other: &PublicKey) -> Option<Ordering> {
221        Some(self.cmp(other))
222    }
223}
224
225impl fmt::Debug for PublicKey {
226    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
227        write!(
228            f,
229            "PublicKey {{ key_type={}, serialize={:?} }}",
230            self.key_type(),
231            self.serialize()
232        )
233    }
234}
235
236#[derive(Debug, Clone, Copy, Eq, PartialEq)]
237enum PrivateKeyData {
238    DjbPrivateKey([u8; curve25519::PRIVATE_KEY_LENGTH]),
239}
240
241#[derive(Clone, Copy, Eq, PartialEq, derive_more::From)]
242pub struct PrivateKey {
243    key: PrivateKeyData,
244}
245
246impl PrivateKey {
247    pub fn deserialize(value: &[u8]) -> Result<Self, CurveError> {
248        let mut key: [u8; curve25519::PRIVATE_KEY_LENGTH] = value
249            .try_into()
250            .map_err(|_| CurveError::BadKeyLength(KeyType::Djb, value.len()))?;
251        // Clamping is not necessary but is kept for backward compatibility
252        key = scalar::clamp_integer(key);
253        Ok(Self {
254            key: PrivateKeyData::DjbPrivateKey(key),
255        })
256    }
257
258    pub fn serialize(&self) -> Vec<u8> {
259        match &self.key {
260            PrivateKeyData::DjbPrivateKey(v) => v.to_vec(),
261        }
262    }
263
264    pub fn public_key(&self) -> Result<PublicKey, CurveError> {
265        match &self.key {
266            PrivateKeyData::DjbPrivateKey(private_key) => {
267                let public_key =
268                    curve25519::PrivateKey::from(*private_key).derive_public_key_bytes();
269                Ok(PublicKey::new(PublicKeyData::DjbPublicKey(public_key)))
270            }
271        }
272    }
273
274    pub fn key_type(&self) -> KeyType {
275        match &self.key {
276            PrivateKeyData::DjbPrivateKey(_) => KeyType::Djb,
277        }
278    }
279
280    pub fn calculate_signature<R: CryptoRng + Rng>(
281        &self,
282        message: &[u8],
283        csprng: &mut R,
284    ) -> Result<Box<[u8]>, CurveError> {
285        self.calculate_signature_for_multipart_message(&[message], csprng)
286    }
287
288    pub fn calculate_signature_for_multipart_message<R: CryptoRng + Rng>(
289        &self,
290        message: &[&[u8]],
291        csprng: &mut R,
292    ) -> Result<Box<[u8]>, CurveError> {
293        match self.key {
294            PrivateKeyData::DjbPrivateKey(k) => {
295                let private_key = curve25519::PrivateKey::from(k);
296                Ok(Box::new(private_key.calculate_signature(csprng, message)))
297            }
298        }
299    }
300
301    pub fn calculate_agreement(&self, their_key: &PublicKey) -> Result<Box<[u8]>, CurveError> {
302        match (self.key, their_key.key) {
303            (PrivateKeyData::DjbPrivateKey(priv_key), PublicKeyData::DjbPublicKey(pub_key)) => {
304                let private_key = curve25519::PrivateKey::from(priv_key);
305                Ok(Box::new(private_key.calculate_agreement(&pub_key)))
306            }
307        }
308    }
309}
310
311impl TryFrom<&[u8]> for PrivateKey {
312    type Error = CurveError;
313
314    fn try_from(value: &[u8]) -> Result<Self, CurveError> {
315        Self::deserialize(value)
316    }
317}
318
319#[derive(Copy, Clone)]
320pub struct KeyPair {
321    pub public_key: PublicKey,
322    pub private_key: PrivateKey,
323}
324
325impl KeyPair {
326    pub fn generate<R: Rng + CryptoRng>(csprng: &mut R) -> Self {
327        let private_key = curve25519::PrivateKey::new(csprng);
328
329        let public_key = PublicKey::from(PublicKeyData::DjbPublicKey(
330            private_key.derive_public_key_bytes(),
331        ));
332        let private_key = PrivateKey::from(PrivateKeyData::DjbPrivateKey(
333            private_key.private_key_bytes(),
334        ));
335
336        Self {
337            public_key,
338            private_key,
339        }
340    }
341
342    pub fn new(public_key: PublicKey, private_key: PrivateKey) -> Self {
343        Self {
344            public_key,
345            private_key,
346        }
347    }
348
349    pub fn from_public_and_private(
350        public_key: &[u8],
351        private_key: &[u8],
352    ) -> Result<Self, CurveError> {
353        let public_key = PublicKey::try_from(public_key)?;
354        let private_key = PrivateKey::try_from(private_key)?;
355        Ok(Self {
356            public_key,
357            private_key,
358        })
359    }
360
361    pub fn calculate_signature<R: CryptoRng + Rng>(
362        &self,
363        message: &[u8],
364        csprng: &mut R,
365    ) -> Result<Box<[u8]>, CurveError> {
366        self.private_key.calculate_signature(message, csprng)
367    }
368
369    pub fn calculate_agreement(&self, their_key: &PublicKey) -> Result<Box<[u8]>, CurveError> {
370        self.private_key.calculate_agreement(their_key)
371    }
372}
373
374impl TryFrom<PrivateKey> for KeyPair {
375    type Error = CurveError;
376
377    fn try_from(value: PrivateKey) -> Result<Self, CurveError> {
378        let public_key = value.public_key()?;
379        Ok(Self::new(public_key, value))
380    }
381}
382
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use const_str::hex;
    use curve25519_dalek::constants::EIGHT_TORSION;
    use rand::TryRngCore as _;
    use rand::rngs::OsRng;

    use super::*;

    /// Wraps `bytes` as a Curve25519 public key and runs the range check on it.
    fn in_range(bytes: &[u8]) -> bool {
        PublicKey::from_djb_public_key_bytes(bytes)
            .expect("structurally valid")
            .scalar_is_in_range()
    }

    /// Signs and verifies a 1 MiB message, both as one piece and split into parts.
    #[test]
    fn test_large_signatures() -> Result<(), CurveError> {
        let mut rng = OsRng.unwrap_err();
        let pair = KeyPair::generate(&mut rng);
        let mut message = [0u8; 1024 * 1024];
        let signature = pair.private_key.calculate_signature(&message, &mut rng)?;

        // Valid as signed, invalid after flipping one bit, valid again once restored.
        assert!(pair.public_key.verify_signature(&message, &signature));
        message[0] ^= 0x01u8;
        assert!(!pair.public_key.verify_signature(&message, &signature));
        message[0] ^= 0x01u8;
        let derived_public = pair.private_key.public_key()?;
        assert!(derived_public.verify_signature(&message, &signature));

        // Splitting the message into parts must not change the verification result.
        let (head, rest) = message.split_at(7);
        assert!(derived_public.verify_signature_for_multipart_message(&[head, rest], &signature));

        // A signature made over parts verifies against the contiguous message.
        let (head, rest) = message.split_at(20);
        let signature = pair
            .private_key
            .calculate_signature_for_multipart_message(&[head, rest], &mut rng)?;
        assert!(derived_public.verify_signature(&message, &signature));

        Ok(())
    }

    /// Exercises `PublicKey` decoding with correct, short, empty, mistagged and
    /// over-long inputs.
    #[test]
    fn test_decode_size() -> Result<(), CurveError> {
        let mut rng = OsRng.unwrap_err();
        let pair = KeyPair::generate(&mut rng);
        let serialized_public = pair.public_key.serialize();

        assert_eq!(
            serialized_public,
            pair.private_key.public_key()?.serialize()
        );
        let empty: [u8; 0] = [];

        let just_right = PublicKey::try_from(&serialized_public[..])?;

        // Dropping the tag byte leaves a key byte in the tag position; decoding must fail.
        assert!(PublicKey::try_from(&serialized_public[1..]).is_err());
        assert!(PublicKey::try_from(&empty[..]).is_err());

        let mut bad_key_type = [0u8; 33];
        bad_key_type.copy_from_slice(&serialized_public);
        bad_key_type[0] = 0x01u8;
        assert!(PublicKey::try_from(&bad_key_type[..]).is_err());

        // One trailing byte is currently accepted.
        let mut extra_space = [0u8; 34];
        extra_space[..33].copy_from_slice(&serialized_public);
        let extra_space_decode = PublicKey::try_from(&extra_space[..]);
        assert!(extra_space_decode.is_ok());

        assert_eq!(&serialized_public[..], &just_right.serialize()[..]);
        assert_eq!(&serialized_public[..], &extra_space_decode?.serialize()[..]);
        Ok(())
    }

    /// `CurveError` can be boxed as `dyn Error` and downcast back.
    #[test]
    fn curve_error_impls_std_error() {
        let boxed: Box<dyn std::error::Error> = Box::new(CurveError::BadKeyType(u8::MAX));
        assert_matches!(boxed.downcast_ref(), Some(CurveError::BadKeyType(_)));
    }

    #[test]
    fn honest_keys_are_torsion_free() {
        let mut rng = OsRng.unwrap_err();
        let pair = KeyPair::generate(&mut rng);
        assert!(pair.public_key.is_torsion_free());
    }

    /// Adding any non-identity 8-torsion point to an honest key must be detected.
    #[test]
    fn tweaked_keys_are_not_torsion_free() {
        let mut rng = OsRng.unwrap_err();
        let pair = KeyPair::generate(&mut rng);
        let pk_bytes: [u8; 32] = pair.public_key.public_key_bytes().try_into().unwrap();
        let ed_pt = MontgomeryPoint(pk_bytes).to_edwards(0).unwrap();
        // Index 0 is skipped, exactly as `.iter().skip(1)` would.
        for torsion_point in &EIGHT_TORSION[1..] {
            let tweaked = ed_pt + *torsion_point; // add a torsion point
            let tweaked_pk_bytes: [u8; 32] = tweaked.to_montgomery().to_bytes();
            let tweaked_pk = PublicKey::from_djb_public_key_bytes(&tweaked_pk_bytes).unwrap();
            assert!(!tweaked_pk.is_torsion_free());
        }
    }

    #[test]
    fn keys_with_the_high_bit_set_are_out_of_range() {
        assert!(in_range(&[0; 32]), "0 should be in range");
        assert!(
            !in_range(&hex!(
                "0000000000000000000000000000000000000000000000000000000000000080"
            )),
            "2^255 should be out of range"
        );
        assert!(
            !in_range(&[0xFF; 32]),
            "2^256 - 1 should be out of range"
        );
        {
            // A freshly generated key is in range until its top bit is forced on.
            let mut rng = OsRng.unwrap_err();
            let pair = KeyPair::generate(&mut rng);
            assert!(pair.public_key.scalar_is_in_range());
            let mut pk_bytes: [u8; 32] = pair.public_key.public_key_bytes().try_into().unwrap();
            assert!(pk_bytes[31] & 0x80 == 0);
            pk_bytes[31] |= 0x80;
            assert!(!in_range(&pk_bytes), ">2^255 should be out of range");
        }
    }

    #[test]
    fn keys_above_the_prime_modulus_are_out_of_range() {
        // Curve25519 scalars use a little-endian representation.
        let two_to_the_255_minus_one =
            hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f");

        for i in 1..=19 {
            let mut pk_bytes = two_to_the_255_minus_one;
            // The literal is 2^255 - 1, so subtracting `i - 1` yields 2^255 - i.
            pk_bytes[0] -= i - 1;
            assert!(
                !in_range(&pk_bytes),
                "2^255 - {i} should be out of range",
            );

            // The out-of-range encoding names the same point as its reduced form, 19 - i.
            let mut canonical_representative = [0; 32];
            canonical_representative[0] = 19 - i;

            assert_eq!(
                MontgomeryPoint(pk_bytes),
                MontgomeryPoint(canonical_representative)
            );
        }

        let mut pk_bytes = two_to_the_255_minus_one;
        pk_bytes[0] -= 19; // resulting in the value 2^255 - 20
        assert!(in_range(&pk_bytes));
    }
}