1mod curve25519;
7mod utils;
8
9use std::cmp::Ordering;
10use std::fmt;
11
12use curve25519_dalek::{MontgomeryPoint, scalar};
13use rand::{CryptoRng, Rng};
14use subtle::ConstantTimeEq;
15
/// The kinds of elliptic-curve keys this module can parse and produce.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum KeyType {
    /// Curve25519 keys; serialized public keys carry a `0x05` type prefix.
    Djb,
}
20
21impl fmt::Display for KeyType {
22 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
23 fmt::Debug::fmt(self, f)
24 }
25}
26
27impl KeyType {
28 fn value(&self) -> u8 {
29 match &self {
30 KeyType::Djb => 0x05u8,
31 }
32 }
33}
34
35#[derive(Debug, displaydoc::Display)]
36pub enum CurveError {
37 NoKeyTypeIdentifier,
39 BadKeyType(u8),
41 BadKeyLength(KeyType, usize),
43}
44
45impl std::error::Error for CurveError {}
46
47impl TryFrom<u8> for KeyType {
48 type Error = CurveError;
49
50 fn try_from(x: u8) -> Result<Self, CurveError> {
51 match x {
52 0x05u8 => Ok(KeyType::Djb),
53 t => Err(CurveError::BadKeyType(t)),
54 }
55 }
56}
57
58#[derive(Debug, Clone, Copy, Eq, PartialEq)]
59enum PublicKeyData {
60 DjbPublicKey([u8; curve25519::PUBLIC_KEY_LENGTH]),
61}
62
63#[derive(Clone, Copy, Eq, derive_more::From)]
64pub struct PublicKey {
65 key: PublicKeyData,
66}
67
68impl PublicKey {
69 fn new(key: PublicKeyData) -> Self {
70 Self { key }
71 }
72
73 pub fn deserialize(value: &[u8]) -> Result<Self, CurveError> {
74 let (key_type, value) = value.split_first().ok_or(CurveError::NoKeyTypeIdentifier)?;
75 let key_type = KeyType::try_from(*key_type)?;
76 match key_type {
77 KeyType::Djb => {
78 let (key, tail): (&[u8; curve25519::PUBLIC_KEY_LENGTH], _) = value
79 .split_first_chunk()
80 .ok_or(CurveError::BadKeyLength(KeyType::Djb, value.len() + 1))?;
81 if !tail.is_empty() {
84 log::warn!(
85 "ECPublicKey deserialized with {} trailing bytes",
86 tail.len()
87 );
88 }
89 Ok(PublicKey {
90 key: PublicKeyData::DjbPublicKey(*key),
91 })
92 }
93 }
94 }
95
96 pub fn public_key_bytes(&self) -> &[u8] {
97 match &self.key {
98 PublicKeyData::DjbPublicKey(v) => v,
99 }
100 }
101
102 pub fn from_djb_public_key_bytes(bytes: &[u8]) -> Result<Self, CurveError> {
103 match <[u8; curve25519::PUBLIC_KEY_LENGTH]>::try_from(bytes) {
104 Err(_) => Err(CurveError::BadKeyLength(KeyType::Djb, bytes.len())),
105 Ok(key) => Ok(PublicKey {
106 key: PublicKeyData::DjbPublicKey(key),
107 }),
108 }
109 }
110
111 pub fn serialize(&self) -> Box<[u8]> {
112 let value_len = match &self.key {
113 PublicKeyData::DjbPublicKey(v) => v.len(),
114 };
115 let mut result = Vec::with_capacity(1 + value_len);
116 result.push(self.key_type().value());
117 match &self.key {
118 PublicKeyData::DjbPublicKey(v) => result.extend_from_slice(v),
119 }
120 result.into_boxed_slice()
121 }
122
123 pub fn verify_signature(&self, message: &[u8], signature: &[u8]) -> bool {
124 self.verify_signature_for_multipart_message(&[message], signature)
125 }
126
127 pub fn verify_signature_for_multipart_message(
128 &self,
129 message: &[&[u8]],
130 signature: &[u8],
131 ) -> bool {
132 match &self.key {
133 PublicKeyData::DjbPublicKey(pub_key) => {
134 let Ok(signature) = signature.try_into() else {
135 return false;
136 };
137 curve25519::PrivateKey::verify_signature(pub_key, message, signature)
138 }
139 }
140 }
141
142 fn key_data(&self) -> &[u8] {
143 match &self.key {
144 PublicKeyData::DjbPublicKey(k) => k.as_ref(),
145 }
146 }
147
148 pub fn key_type(&self) -> KeyType {
149 match &self.key {
150 PublicKeyData::DjbPublicKey(_) => KeyType::Djb,
151 }
152 }
153
154 fn is_torsion_free(&self) -> bool {
155 match &self.key {
156 PublicKeyData::DjbPublicKey(k) => {
157 let mont_point = MontgomeryPoint(*k);
158 mont_point
159 .to_edwards(0)
160 .is_some_and(|ed| ed.is_torsion_free())
161 }
162 }
163 }
164
165 fn scalar_is_in_range(&self) -> bool {
166 match &self.key {
167 PublicKeyData::DjbPublicKey(k) => {
168 !(k[31] & 0b1000_0000_u8 != 0
172 || (k[0] >= 0u8.wrapping_sub(19) && k[1..31] == [0xFFu8; 30] && k[31] == 0x7F))
173 }
174 }
175 }
176
177 pub fn is_canonical(&self) -> bool {
178 self.is_torsion_free() && self.scalar_is_in_range()
179 }
180}
181
182impl TryFrom<&[u8]> for PublicKey {
183 type Error = CurveError;
184
185 fn try_from(value: &[u8]) -> Result<Self, CurveError> {
186 Self::deserialize(value)
187 }
188}
189
190impl subtle::ConstantTimeEq for PublicKey {
191 fn ct_eq(&self, other: &PublicKey) -> subtle::Choice {
196 if self.key_type() != other.key_type() {
197 return 0.ct_eq(&1);
198 }
199 self.key_data().ct_eq(other.key_data())
200 }
201}
202
203impl PartialEq for PublicKey {
204 fn eq(&self, other: &PublicKey) -> bool {
205 bool::from(self.ct_eq(other))
206 }
207}
208
209impl Ord for PublicKey {
210 fn cmp(&self, other: &Self) -> Ordering {
211 if self.key_type() != other.key_type() {
212 return self.key_type().cmp(&other.key_type());
213 }
214
215 utils::constant_time_cmp(self.key_data(), other.key_data())
216 }
217}
218
219impl PartialOrd for PublicKey {
220 fn partial_cmp(&self, other: &PublicKey) -> Option<Ordering> {
221 Some(self.cmp(other))
222 }
223}
224
225impl fmt::Debug for PublicKey {
226 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
227 write!(
228 f,
229 "PublicKey {{ key_type={}, serialize={:?} }}",
230 self.key_type(),
231 self.serialize()
232 )
233 }
234}
235
236#[derive(Debug, Clone, Copy, Eq, PartialEq)]
237enum PrivateKeyData {
238 DjbPrivateKey([u8; curve25519::PRIVATE_KEY_LENGTH]),
239}
240
241#[derive(Clone, Copy, Eq, PartialEq, derive_more::From)]
242pub struct PrivateKey {
243 key: PrivateKeyData,
244}
245
246impl PrivateKey {
247 pub fn deserialize(value: &[u8]) -> Result<Self, CurveError> {
248 let mut key: [u8; curve25519::PRIVATE_KEY_LENGTH] = value
249 .try_into()
250 .map_err(|_| CurveError::BadKeyLength(KeyType::Djb, value.len()))?;
251 key = scalar::clamp_integer(key);
253 Ok(Self {
254 key: PrivateKeyData::DjbPrivateKey(key),
255 })
256 }
257
258 pub fn serialize(&self) -> Vec<u8> {
259 match &self.key {
260 PrivateKeyData::DjbPrivateKey(v) => v.to_vec(),
261 }
262 }
263
264 pub fn public_key(&self) -> Result<PublicKey, CurveError> {
265 match &self.key {
266 PrivateKeyData::DjbPrivateKey(private_key) => {
267 let public_key =
268 curve25519::PrivateKey::from(*private_key).derive_public_key_bytes();
269 Ok(PublicKey::new(PublicKeyData::DjbPublicKey(public_key)))
270 }
271 }
272 }
273
274 pub fn key_type(&self) -> KeyType {
275 match &self.key {
276 PrivateKeyData::DjbPrivateKey(_) => KeyType::Djb,
277 }
278 }
279
280 pub fn calculate_signature<R: CryptoRng + Rng>(
281 &self,
282 message: &[u8],
283 csprng: &mut R,
284 ) -> Result<Box<[u8]>, CurveError> {
285 self.calculate_signature_for_multipart_message(&[message], csprng)
286 }
287
288 pub fn calculate_signature_for_multipart_message<R: CryptoRng + Rng>(
289 &self,
290 message: &[&[u8]],
291 csprng: &mut R,
292 ) -> Result<Box<[u8]>, CurveError> {
293 match self.key {
294 PrivateKeyData::DjbPrivateKey(k) => {
295 let private_key = curve25519::PrivateKey::from(k);
296 Ok(Box::new(private_key.calculate_signature(csprng, message)))
297 }
298 }
299 }
300
301 pub fn calculate_agreement(&self, their_key: &PublicKey) -> Result<Box<[u8]>, CurveError> {
302 match (self.key, their_key.key) {
303 (PrivateKeyData::DjbPrivateKey(priv_key), PublicKeyData::DjbPublicKey(pub_key)) => {
304 let private_key = curve25519::PrivateKey::from(priv_key);
305 Ok(Box::new(private_key.calculate_agreement(&pub_key)))
306 }
307 }
308 }
309}
310
311impl TryFrom<&[u8]> for PrivateKey {
312 type Error = CurveError;
313
314 fn try_from(value: &[u8]) -> Result<Self, CurveError> {
315 Self::deserialize(value)
316 }
317}
318
319#[derive(Copy, Clone)]
320pub struct KeyPair {
321 pub public_key: PublicKey,
322 pub private_key: PrivateKey,
323}
324
325impl KeyPair {
326 pub fn generate<R: Rng + CryptoRng>(csprng: &mut R) -> Self {
327 let private_key = curve25519::PrivateKey::new(csprng);
328
329 let public_key = PublicKey::from(PublicKeyData::DjbPublicKey(
330 private_key.derive_public_key_bytes(),
331 ));
332 let private_key = PrivateKey::from(PrivateKeyData::DjbPrivateKey(
333 private_key.private_key_bytes(),
334 ));
335
336 Self {
337 public_key,
338 private_key,
339 }
340 }
341
342 pub fn new(public_key: PublicKey, private_key: PrivateKey) -> Self {
343 Self {
344 public_key,
345 private_key,
346 }
347 }
348
349 pub fn from_public_and_private(
350 public_key: &[u8],
351 private_key: &[u8],
352 ) -> Result<Self, CurveError> {
353 let public_key = PublicKey::try_from(public_key)?;
354 let private_key = PrivateKey::try_from(private_key)?;
355 Ok(Self {
356 public_key,
357 private_key,
358 })
359 }
360
361 pub fn calculate_signature<R: CryptoRng + Rng>(
362 &self,
363 message: &[u8],
364 csprng: &mut R,
365 ) -> Result<Box<[u8]>, CurveError> {
366 self.private_key.calculate_signature(message, csprng)
367 }
368
369 pub fn calculate_agreement(&self, their_key: &PublicKey) -> Result<Box<[u8]>, CurveError> {
370 self.private_key.calculate_agreement(their_key)
371 }
372}
373
374impl TryFrom<PrivateKey> for KeyPair {
375 type Error = CurveError;
376
377 fn try_from(value: PrivateKey) -> Result<Self, CurveError> {
378 let public_key = value.public_key()?;
379 Ok(Self::new(public_key, value))
380 }
381}
382
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use const_str::hex;
    use curve25519_dalek::constants::EIGHT_TORSION;
    use rand::TryRngCore as _;
    use rand::rngs::OsRng;

    use super::*;

    /// Signs and verifies a 1 MiB message, including multipart variants that split
    /// the same bytes at different points.
    #[test]
    fn test_large_signatures() -> Result<(), CurveError> {
        let mut csprng = OsRng.unwrap_err();
        let key_pair = KeyPair::generate(&mut csprng);
        // Heap-allocated: a 1 MiB array on the stack is uncomfortably close to the
        // default test-thread stack size.
        let mut message = vec![0u8; 1024 * 1024];
        let signature = key_pair
            .private_key
            .calculate_signature(&message, &mut csprng)?;

        assert!(key_pair.public_key.verify_signature(&message, &signature));
        // Flipping one bit must invalidate the signature...
        message[0] ^= 0x01u8;
        assert!(!key_pair.public_key.verify_signature(&message, &signature));
        // ...and flipping it back must make the signature valid again.
        message[0] ^= 0x01u8;
        let public_key = key_pair.private_key.public_key()?;
        assert!(public_key.verify_signature(&message, &signature));

        // Verifying a split message is the same as verifying the whole.
        assert!(
            public_key.verify_signature_for_multipart_message(
                &[&message[..7], &message[7..]],
                &signature
            )
        );

        // Signing a split message is the same as signing the whole.
        let signature = key_pair
            .private_key
            .calculate_signature_for_multipart_message(
                &[&message[..20], &message[20..]],
                &mut csprng,
            )?;
        assert!(public_key.verify_signature(&message, &signature));

        Ok(())
    }

    #[test]
    fn test_decode_size() -> Result<(), CurveError> {
        let mut csprng = OsRng.unwrap_err();
        let key_pair = KeyPair::generate(&mut csprng);
        let serialized_public = key_pair.public_key.serialize();

        assert_eq!(
            serialized_public,
            key_pair.private_key.public_key()?.serialize()
        );
        let empty: [u8; 0] = [];

        let just_right = PublicKey::try_from(&serialized_public[..])?;

        // Dropping the type prefix shifts a random key byte into the prefix slot.
        assert!(PublicKey::try_from(&serialized_public[1..]).is_err());
        assert_matches!(
            PublicKey::try_from(&empty[..]),
            Err(CurveError::NoKeyTypeIdentifier)
        );
        // Dropping the last byte keeps a valid prefix but leaves the key one byte short;
        // the reported length is that of the whole input.
        assert_matches!(
            PublicKey::try_from(&serialized_public[..32]),
            Err(CurveError::BadKeyLength(KeyType::Djb, 32))
        );

        let mut bad_key_type = [0u8; 33];
        bad_key_type[..].copy_from_slice(&serialized_public[..]);
        bad_key_type[0] = 0x01u8;
        assert_matches!(
            PublicKey::try_from(&bad_key_type[..]),
            Err(CurveError::BadKeyType(0x01))
        );

        // Trailing bytes are tolerated (and logged) rather than rejected.
        let mut extra_space = [0u8; 34];
        extra_space[..33].copy_from_slice(&serialized_public[..]);
        let extra_space_decode = PublicKey::try_from(&extra_space[..]);
        assert!(extra_space_decode.is_ok());

        assert_eq!(&serialized_public[..], &just_right.serialize()[..]);
        assert_eq!(&serialized_public[..], &extra_space_decode?.serialize()[..]);
        Ok(())
    }

    #[test]
    fn test_private_key_decode_size() {
        let expected = curve25519::PRIVATE_KEY_LENGTH;
        // `PrivateKey` has no `Debug` impl, so `matches!` is used instead of `assert_matches!`.
        for len in [0, expected - 1, expected + 1] {
            assert!(matches!(
                PrivateKey::deserialize(&vec![0u8; len]),
                Err(CurveError::BadKeyLength(KeyType::Djb, l)) if l == len
            ));
        }
        assert!(PrivateKey::deserialize(&vec![0u8; expected]).is_ok());
    }

    #[test]
    fn curve_error_impls_std_error() {
        let error = CurveError::BadKeyType(u8::MAX);
        let error = Box::new(error) as Box<dyn std::error::Error>;
        assert_matches!(error.downcast_ref(), Some(CurveError::BadKeyType(_)));
    }

    #[test]
    fn honest_keys_are_torsion_free() {
        let mut csprng = OsRng.unwrap_err();
        let key_pair = KeyPair::generate(&mut csprng);
        assert!(key_pair.public_key.is_torsion_free());
    }

    #[test]
    fn tweaked_keys_are_not_torsion_free() {
        let mut csprng = OsRng.unwrap_err();
        let key_pair = KeyPair::generate(&mut csprng);
        let pk_bytes: [u8; 32] = key_pair.public_key.public_key_bytes().try_into().unwrap();
        let mont_pt = MontgomeryPoint(pk_bytes);
        let ed_pt = mont_pt.to_edwards(0).unwrap();
        // The first entry of EIGHT_TORSION is the identity, which would not change the point.
        for t in EIGHT_TORSION.iter().skip(1) {
            // Adding a nonzero small-order point moves the key out of the prime-order subgroup.
            let tweaked = ed_pt + *t;
            let tweaked_mont = tweaked.to_montgomery();
            let tweaked_pk_bytes: [u8; 32] = tweaked_mont.to_bytes();
            let tweaked_pk = PublicKey::from_djb_public_key_bytes(&tweaked_pk_bytes).unwrap();
            assert!(!tweaked_pk.is_torsion_free());
        }
    }

    #[test]
    fn keys_with_the_high_bit_set_are_out_of_range() {
        assert!(
            PublicKey::from_djb_public_key_bytes(&[0; 32])
                .expect("structurally valid")
                .scalar_is_in_range(),
            "0 should be in range"
        );
        assert!(
            !PublicKey::from_djb_public_key_bytes(&hex!(
                "0000000000000000000000000000000000000000000000000000000000000080"
            ))
            .expect("structurally valid")
            .scalar_is_in_range(),
            "2^255 should be out of range"
        );
        assert!(
            !PublicKey::from_djb_public_key_bytes(&[0xFF; 32])
                .expect("structurally valid")
                .scalar_is_in_range(),
            "2^256 - 1 should be out of range"
        );
        {
            // An honest key has bit 255 clear; setting it must push the key out of range.
            let mut csprng = OsRng.unwrap_err();
            let key_pair = KeyPair::generate(&mut csprng);
            assert!(key_pair.public_key.scalar_is_in_range());
            let mut pk_bytes: [u8; 32] = key_pair.public_key.public_key_bytes().try_into().unwrap();
            assert!(pk_bytes[31] & 0x80 == 0);
            pk_bytes[31] |= 0x80;
            assert!(
                !PublicKey::from_djb_public_key_bytes(&pk_bytes)
                    .expect("structurally valid")
                    .scalar_is_in_range(),
                ">2^255 should be out of range"
            );
        }
    }

    #[test]
    fn keys_above_the_prime_modulus_are_out_of_range() {
        // Little-endian encoding of 2^255 - 1.
        let two_to_the_255_minus_one =
            hex!("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f");

        for i in 1..=19 {
            let mut pk_bytes = two_to_the_255_minus_one;
            // Turn the low byte 0xff into 0xff - (i - 1), so pk_bytes encodes 2^255 - i,
            // one of the integers p = 2^255 - 19 ..= 2^255 - 1.
            pk_bytes[0] -= i;
            pk_bytes[0] += 1;
            assert!(
                !PublicKey::from_djb_public_key_bytes(&pk_bytes)
                    .expect("structurally valid")
                    .scalar_is_in_range(),
                "2^255 - {i} should be out of range",
            );

            // 2^255 - i is congruent to 19 - i modulo p.
            let mut canonical_representative = [0; 32];
            canonical_representative[0] = 19 - i;

            assert_eq!(
                MontgomeryPoint(pk_bytes),
                MontgomeryPoint(canonical_representative)
            );
        }

        // 2^255 - 20 = p - 1 is the largest canonical value.
        let mut pk_bytes = two_to_the_255_minus_one;
        pk_bytes[0] -= 19;
        assert!(
            PublicKey::from_djb_public_key_bytes(&pk_bytes)
                .expect("structurally valid")
                .scalar_is_in_range()
        );
    }
}