/// Errors reported by the X25519 routines in this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum X25519Error {
    /// The scalar or the u-coordinate consisted solely of zero bytes.
    InvalidInput,
    /// A validation failure carrying a human-readable explanation.
    ValidationError(String),
}

impl std::fmt::Display for X25519Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::InvalidInput => {
                f.write_str("Invalid input: scalar or u-coordinate is all zeros")
            }
            Self::ValidationError(ref msg) => write!(f, "Validation error: {}", msg),
        }
    }
}

impl std::error::Error for X25519Error {}

/// Result type used by every fallible function in this module.
pub type X25519Result<T> = Result<T, X25519Error>;

/// Number of scalar bits walked by the Montgomery ladder (bits 254 down to 0).
const CURVE25519_BIT_LEN: usize = 255;
/// Size in bytes of an encoded scalar or u-coordinate.
pub const CURVE25519_BYTE_LEN: usize = 32;
/// Number of 32-bit limbs holding one field element.
const CURVE25519_WORD_LEN: usize = CURVE25519_BYTE_LEN / 4;
/// Ladder multiplier (A + 2) / 4 for the Curve25519 coefficient A = 486662.
const CURVE25519_A24: u32 = 121666;

/// Little-endian encoding of the base-point u-coordinate, u = 9.
pub const U_COORDINATE: [u8; CURVE25519_BYTE_LEN] = {
    let mut encoded = [0u8; CURVE25519_BYTE_LEN];
    encoded[0] = 9;
    encoded
};

82#[repr(align(32))]
83struct X25519State {
84 k: [u32; CURVE25519_WORD_LEN],
85 u: [u32; CURVE25519_WORD_LEN],
86 x1: [u32; CURVE25519_WORD_LEN],
87 z1: [u32; CURVE25519_WORD_LEN],
88 x2: [u32; CURVE25519_WORD_LEN],
89 z2: [u32; CURVE25519_WORD_LEN],
90 t1: [u32; CURVE25519_WORD_LEN],
91 t2: [u32; CURVE25519_WORD_LEN],
92}
93
94impl Drop for X25519State {
95 fn drop(&mut self) {
96 let raw = self as *mut X25519State as *mut [u8; 256];
97 unsafe { *raw = [0; 256] };
98 }
99}
100
101#[inline]
103fn validate_input(input: &[u8; CURVE25519_BYTE_LEN]) -> X25519Result<()> {
104 let is_zero = input.iter().all(|&b| b == 0);
105 if is_zero {
106 return Err(X25519Error::InvalidInput);
107 }
108 Ok(())
109}
110
111#[inline]
113fn bytes_to_u32_array(bytes: [u8; CURVE25519_BYTE_LEN]) -> [u32; CURVE25519_WORD_LEN] {
114 let mut result = [0u32; CURVE25519_WORD_LEN];
115 for (i, chunk) in bytes.chunks_exact(4).enumerate() {
116 result[i] = u32::from_le_bytes(chunk.try_into().unwrap());
117 }
118 result
119}
120
121#[inline]
123fn u32_array_to_bytes(array: [u32; CURVE25519_WORD_LEN]) -> [u8; CURVE25519_BYTE_LEN] {
124 let mut result = [0u8; CURVE25519_BYTE_LEN];
125 for (i, &word) in array.iter().enumerate() {
126 let bytes = word.to_le_bytes();
127 result[i * 4..(i + 1) * 4].copy_from_slice(&bytes);
128 }
129 result
130}
131
/// Computes the X25519 function from RFC 7748: the u-coordinate of `k * P`,
/// where `P` is the Curve25519 point with u-coordinate `u`.
///
/// `k` and `u` are 32-byte little-endian encodings. Before use, the scalar is
/// clamped: bits 0-2 and bit 255 are cleared and bit 254 is set. Bit 255 of
/// `u` is masked off. The scalar multiplication is a Montgomery ladder; its
/// per-bit field operations and conditional swaps contain no branch on scalar
/// bits.
///
/// # Errors
///
/// Returns [`X25519Error::InvalidInput`] when `k` or `u` is all zero bytes.
/// The check runs on the raw inputs, before clamping or masking.
pub fn x25519(k: [u8; CURVE25519_BYTE_LEN], u: [u8; CURVE25519_BYTE_LEN]) -> X25519Result<[u8; CURVE25519_BYTE_LEN]> {
    // Reject an all-zero scalar or an all-zero u-coordinate.
    validate_input(&k)?;
    validate_input(&u)?;

    // `swap` holds the scalar bit from the previous ladder step. Swaps are
    // applied lazily as `swap ^ b`, so each step performs exactly one swap pair.
    let mut swap: u32 = 0;
    // Current scalar bit (0 or 1).
    let mut b: u32;
    // Every intermediate value lives in `X25519State`, whose `Drop` impl
    // overwrites it when this function returns.
    let mut state = X25519State {
        k: [0; CURVE25519_WORD_LEN],
        u: [0; CURVE25519_WORD_LEN],
        x1: [0; CURVE25519_WORD_LEN],
        z1: [0; CURVE25519_WORD_LEN],
        x2: [0; CURVE25519_WORD_LEN],
        z2: [0; CURVE25519_WORD_LEN],
        t1: [0; CURVE25519_WORD_LEN],
        t2: [0; CURVE25519_WORD_LEN],
    };

    state.k = bytes_to_u32_array(k);

    // Clamp the scalar: clear bits 0..=2, clear bit 255, set bit 254.
    state.k[0] &= 0xFFFFFFF8;
    state.k[7] &= 0x7FFFFFFF;
    state.k[7] |= 0x40000000;

    state.u = bytes_to_u32_array(u);

    // Mask the most significant bit of the u-coordinate...
    state.u[7] &= 0x7FFFFFFF;

    // ...and reduce it modulo p = 2^255 - 19 in case it is non-canonical.
    state.u = curve25519_red(state.u);

    // (x1 : z1) starts at (1 : 0), the point at infinity; (x2 : z2) at (u : 1).
    state.x1[0] = 1;
    state.x2 = state.u;
    state.z2[0] = 1;

    // Montgomery ladder over scalar bits 254 down to 0. In RFC 7748 terms,
    // (x1 : z1) plays (x_2 : z_2) and (x2 : z2) plays (x_3 : z_3).
    for i in (0usize..CURVE25519_BIT_LEN).rev() {
        b = (state.k[i / 32] >> (i % 32)) & 1;

        // Constant-time conditional swap when the bit differs from the previous one.
        curve25519_swap(&mut state.x1, &mut state.x2, swap ^ b);
        curve25519_swap(&mut state.z1, &mut state.z2, swap ^ b);

        swap = b;

        state.t1 = curve25519_add(state.x2, state.z2); // C = x_3 + z_3
        state.x2 = curve25519_sub(state.x2, state.z2); // D = x_3 - z_3
        state.z2 = curve25519_add(state.x1, state.z1); // A = x_2 + z_2
        state.x1 = curve25519_sub(state.x1, state.z1); // B = x_2 - z_2
        state.t1 = curve25519_mul(state.t1, state.x1); // CB
        state.x2 = curve25519_mul(state.x2, state.z2); // DA
        state.z2 = curve25519_sqr(state.z2); // AA
        state.x1 = curve25519_sqr(state.x1); // BB
        state.t2 = curve25519_sub(state.z2, state.x1); // E = AA - BB
        state.z1 = curve25519_mul_int(state.t2, CURVE25519_A24); // 121666 * E
        state.z1 = curve25519_add(state.z1, state.x1); // BB + 121666 * E = AA + 121665 * E
        state.z1 = curve25519_mul(state.z1, state.t2); // z_2 = E * (AA + a24 * E)
        state.x1 = curve25519_mul(state.x1, state.z2); // x_2 = AA * BB
        state.z2 = curve25519_sub(state.t1, state.x2); // CB - DA
        state.z2 = curve25519_sqr(state.z2); // (DA - CB)^2
        state.z2 = curve25519_mul(state.z2, state.u); // z_3 = u * (DA - CB)^2
        state.x2 = curve25519_add(state.x2, state.t1); // DA + CB
        state.x2 = curve25519_sqr(state.x2); // x_3 = (DA + CB)^2
    }

    // Apply the final pending swap so (x1 : z1) holds the result.
    curve25519_swap(&mut state.x1, &mut state.x2, swap);
    curve25519_swap(&mut state.z1, &mut state.z2, swap);

    // Convert from projective to affine: u = x1 / z1 = x1 * z1^(p - 2).
    state.u = curve25519_inv(state.z1);
    state.u = curve25519_mul(state.u, state.x1);

    Ok(u32_array_to_bytes(state.u))
}

255#[inline]
259fn curve25519_red(a: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
260 let mut temp: u64 = 19;
261 let mut b: [u32; CURVE25519_WORD_LEN] = Default::default();
262
263 for i in 0..CURVE25519_WORD_LEN {
265 temp += a[i] as u64;
266 b[i] = temp as u32;
267 temp >>= 32;
268 }
269
270 b[7] = b[7].wrapping_sub(0x80000000);
272 curve25519_select(&b, &a, (b[7] & 0x80000000) >> 31)
274}
275
276#[inline]
280fn curve25519_select(a: &[u32; CURVE25519_WORD_LEN], b: &[u32; CURVE25519_WORD_LEN], c: u32) -> [u32; CURVE25519_WORD_LEN] {
281 let mask = c.wrapping_sub(1);
283 let mut r: [u32; CURVE25519_WORD_LEN] = Default::default();
284 for i in 0..CURVE25519_WORD_LEN {
286 r[i] = (a[i] & mask) | (b[i] & !mask);
288 }
289
290 r
291}
292
293#[inline]
297fn curve25519_swap(a: &mut [u32; CURVE25519_WORD_LEN], b: &mut [u32; CURVE25519_WORD_LEN], c: u32) {
298 let mut dummy: u32;
299 let mask = (!c).wrapping_add(1);
301
302 for i in 0..CURVE25519_WORD_LEN {
303 dummy = mask & (a[i] ^ b[i]);
304 a[i] ^= dummy;
305 b[i] ^= dummy;
306 }
307}
308
309#[inline]
313fn curve25519_add(a: [u32; CURVE25519_WORD_LEN], b: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
314 let mut temp: u64 = 0;
315 let mut r: [u32; CURVE25519_WORD_LEN] = Default::default();
316
317 for i in 0..CURVE25519_WORD_LEN {
319 temp += a[i] as u64;
320 temp += b[i] as u64;
321 r[i] = temp as u32;
322 temp >>= 32;
323 }
324
325 curve25519_red(r)
327}
328
329#[inline]
333fn curve25519_sub(a: [u32; CURVE25519_WORD_LEN], b: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
334 let mut temp: i64 = -19;
335 let mut result: [u32; CURVE25519_WORD_LEN] = Default::default();
336
337 for i in 0..CURVE25519_WORD_LEN {
339 temp += a[i] as i64;
340 temp -= b[i] as i64;
341 result[i] = temp as u32;
342 temp >>= 32;
343 }
344
345 result[7] = result[7].wrapping_add(0x80000000);
347
348 curve25519_red(result)
350}
351
/// Multiplies two field elements: returns (a * b) mod p, p = 2^255 - 19.
///
/// The full 512-bit product is accumulated column by column (Comba's
/// product-scanning method) into sixteen 32-bit limbs. Two passes fold it back
/// to 256 bits using 2^255 = 19 and 2^256 = 38 (mod p). The result is finally
/// passed through `curve25519_red`, which subtracts p at most once.
#[inline]
fn curve25519_mul(a: [u32; CURVE25519_WORD_LEN], b: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
    // `temp` holds the low 32 bits of the current column; `c` accumulates the
    // carries that spill into the next columns.
    let mut c: u64 = 0;
    let mut temp: u64 = 0;
    // Full-width product, least significant limb first.
    let mut u: [u32; 16] = Default::default();

    for i in 0..16 {
        // Column i sums every a[j] * b[i - j] with both indices in 0..8.
        if i < CURVE25519_WORD_LEN {
            for j in 0..=i {
                temp += a[j] as u64 * b[i - j] as u64;
                c += temp >> 32;
                temp &= 0xFFFFFFFF;
            }
        } else {
            for j in i - 7..CURVE25519_WORD_LEN {
                temp += a[j] as u64 * b[i - j] as u64;
                c += temp >> 32;
                temp &= 0xFFFFFFFF;
            }
        }

        // The finished column is written out.
        u[i] = temp as u32;

        // Its carry seeds the next column.
        temp = c & 0xFFFFFFFF;
        c >>= 32;
    }

    // First reduction pass: fold bit 255 of the low half (2^255 = 19 mod p),
    // then add 38 * high half (2^256 = 38 mod p).
    temp = (u[7] >> 31) as u64 * 19;
    u[7] &= 0x7FFFFFFF;

    for i in 0..CURVE25519_WORD_LEN {
        temp += u[i] as u64;
        temp += u[i + CURVE25519_WORD_LEN] as u64 * 38;
        u[i] = temp as u32;
        temp >>= 32;
    }

    // Second pass: fold the carry out of bit 256 (x38) and the new bit 255 (x19).
    temp *= 38;
    temp += (u[7] >> 31) as u64 * 19;
    u[7] &= 0x7FFFFFFF;

    for i in 0..CURVE25519_WORD_LEN {
        temp += u[i] as u64;
        u[i] = temp as u32;
        temp >>= 32;
    }

    // Only the low eight limbs remain significant; make them canonical.
    let mut temp: [u32; CURVE25519_WORD_LEN] = Default::default();
    temp.copy_from_slice(&u[..CURVE25519_WORD_LEN]);
    curve25519_red(temp)
}

421#[inline]
425fn curve25519_sqr(a: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
426 curve25519_mul(a, a)
428}
429
430#[inline]
434fn curve25519_mul_int(a: [u32; CURVE25519_WORD_LEN], b: u32) -> [u32; CURVE25519_WORD_LEN] {
435 let mut temp: u64 = 0;
436 let mut u: [u32; CURVE25519_WORD_LEN] = Default::default();
437
438 for i in 0..CURVE25519_WORD_LEN {
440 temp += a[i] as u64 * b as u64;
441 u[i] = temp as u32;
442 temp >>= 32;
443 }
444
445 temp *= 38;
447 temp += (u[7] >> 31) as u64 * 19;
449 u[7] &= 0x7FFFFFFF;
451
452 for i in 0..CURVE25519_WORD_LEN {
454 temp += u[i] as u64;
455 u[i] = temp as u32;
456 temp >>= 32;
457 }
458
459 curve25519_red(u)
461}
462
463
464
/// Computes the multiplicative inverse of `a` modulo p = 2^255 - 19.
///
/// Uses Fermat's little theorem: a^(p - 2) = a^(2^255 - 21) is evaluated with
/// a fixed chain of squarings and multiplications, so the sequence of
/// operations does not depend on `a`. For `a = 0` the result is 0.
#[inline]
fn curve25519_inv(a: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
    let mut u: [u32; CURVE25519_WORD_LEN];
    let mut v: [u32; CURVE25519_WORD_LEN];

    u = curve25519_sqr(a);
    u = curve25519_mul(u, a); // a^(2^2 - 1)
    u = curve25519_sqr(u);
    v = curve25519_mul(u, a); // a^(2^3 - 1)
    u = curve25519_pwr2(v, 3);
    u = curve25519_mul(u, v); // a^(2^6 - 1)
    u = curve25519_sqr(u);
    v = curve25519_mul(u, a); // a^(2^7 - 1)
    u = curve25519_pwr2(v, 7);
    u = curve25519_mul(u, v); // a^(2^14 - 1)
    u = curve25519_sqr(u);
    v = curve25519_mul(u, a); // a^(2^15 - 1)
    u = curve25519_pwr2(v, 15);
    u = curve25519_mul(u, v); // a^(2^30 - 1)
    u = curve25519_sqr(u);
    v = curve25519_mul(u, a); // a^(2^31 - 1)
    u = curve25519_pwr2(v, 31);
    v = curve25519_mul(u, v); // a^(2^62 - 1)
    u = curve25519_pwr2(v, 62);
    u = curve25519_mul(u, v); // a^(2^124 - 1)
    u = curve25519_sqr(u);
    v = curve25519_mul(u, a); // a^(2^125 - 1)
    u = curve25519_pwr2(v, 125);
    u = curve25519_mul(u, v); // a^(2^250 - 1)
    u = curve25519_sqr(u);
    u = curve25519_sqr(u);
    u = curve25519_mul(u, a); // a^(2^252 - 3)
    u = curve25519_sqr(u);
    u = curve25519_sqr(u);
    u = curve25519_mul(u, a); // a^(2^254 - 11)
    u = curve25519_sqr(u);
    curve25519_mul(u, a) // a^(2^255 - 21) = a^(p - 2)
}

515#[inline]
519fn curve25519_pwr2(a: [u32; CURVE25519_WORD_LEN], n: usize) -> [u32; CURVE25519_WORD_LEN] {
520 let mut result = curve25519_sqr(a);
522
523 for _ in 1..n {
525 result = curve25519_sqr(result);
526 }
527
528 result
529}
530
/// Draws 32 random bytes from `rng` and clamps them into an X25519 private key.
///
/// Clamping clears the low three bits of byte 0, clears bit 7 of byte 31 and
/// sets bit 6 of byte 31. Only compiled for test builds.
///
/// # Errors
///
/// Returns [`X25519Error::ValidationError`] if the generator produced 32 zero
/// bytes.
#[cfg(test)]
pub fn generate_private_key<R>(rng: &mut R) -> X25519Result<[u8; CURVE25519_BYTE_LEN]>
where
    R: rand::RngCore + rand::CryptoRng,
{
    let mut secret = [0u8; CURVE25519_BYTE_LEN];
    rng.fill_bytes(&mut secret);

    if !secret.iter().any(|&byte| byte != 0) {
        return Err(X25519Error::ValidationError(String::from("Generated key is all zeros")));
    }

    let last = CURVE25519_BYTE_LEN - 1;
    secret[0] &= 0xf8;
    secret[last] = (secret[last] & 0x7f) | 0x40;
    Ok(secret)
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::OsRng;

    /// Decodes a 64-digit hex string into 32 bytes.
    fn hex32(hex: &str) -> [u8; CURVE25519_BYTE_LEN] {
        assert_eq!(hex.len(), 2 * CURVE25519_BYTE_LEN);
        let mut out = [0u8; CURVE25519_BYTE_LEN];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = u8::from_str_radix(&hex[2 * i..2 * i + 2], 16).unwrap();
        }
        out
    }

    #[test]
    fn test_x25519() {
        // (scalar, u-coordinate, expected output) from RFC 7748 section 5.2.
        let vectors = [
            (
                "a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4",
                "e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c",
                "c3da55379de9c6908e94ea4df28d084f32eccf03491c71f754b4075577a28552",
            ),
            (
                "4b66e9d4d1b4673c5ad22691957d6af5c11b6421e0ea01d42ca4169e7918ba0d",
                "e5210f12786811d3f4b7959d0538ae2c31dbe7106fc03c3efc4cd549c715a493",
                "95cbde9476e8907d7aade45cb4b873f88b595a68799fa152e6f8f7647aac7957",
            ),
        ];

        for &(scalar, u_coordinate, expected) in vectors.iter() {
            let result = x25519(hex32(scalar), hex32(u_coordinate)).unwrap();
            assert_eq!(result, hex32(expected));
        }
    }

    #[test]
    fn test_x25519_series() {
        // First step of the RFC 7748 iterated test: k = u = 9.
        let result = x25519(U_COORDINATE, U_COORDINATE).unwrap();
        assert_eq!(
            result,
            hex32("422c8e7a6227d7bca1350b3e2bb7279f7897b87bb6854b783c60e80311ae3079")
        );
    }

    #[test]
    fn test_generate_private_key() {
        let mut rng = OsRng;
        let key = generate_private_key(&mut rng).unwrap();

        assert!(key.iter().any(|&b| b != 0));
        assert_eq!(key[0] & 0x07, 0);
        assert_eq!(key[31] & 0x80, 0);
        assert_eq!(key[31] & 0x40, 0x40);
    }

    #[test]
    fn test_invalid_input() {
        let zero = [0u8; CURVE25519_BYTE_LEN];
        let cases = [(zero, U_COORDINATE), (U_COORDINATE, zero), (zero, zero)];

        for &(scalar, u_coordinate) in cases.iter() {
            assert!(matches!(x25519(scalar, u_coordinate), Err(X25519Error::InvalidInput)));
        }
    }

    #[test]
    fn test_key_exchange() {
        let mut rng = OsRng;
        let alice_private = generate_private_key(&mut rng).unwrap();
        let bob_private = generate_private_key(&mut rng).unwrap();

        // Each party publishes its scalar times the base point...
        let alice_public = x25519(alice_private, U_COORDINATE).unwrap();
        let bob_public = x25519(bob_private, U_COORDINATE).unwrap();

        // ...and both must derive the same shared secret.
        assert_eq!(
            x25519(alice_private, bob_public).unwrap(),
            x25519(bob_private, alice_public).unwrap()
        );
    }
}