crypto_async_rs/
ecdh_x25519.rs

1//! # X25519 Elliptic Curve Diffie-Hellman Key Exchange
2//!
3//! This module provides a pure Rust implementation of the X25519 elliptic curve
4//! Diffie-Hellman key exchange algorithm as specified in RFC 7748.
5//!
6//! ## Features
7//!
8//! - Constant-time implementation resistant to timing attacks
9//! - Secure memory clearing using custom Drop implementation
10//! - Input validation and proper error handling
11//! - High-performance implementation with inline optimizations
12//!
13//! ## Security Considerations
14//!
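//! - The Montgomery ladder and field arithmetic avoid secret-dependent branches
//!   and memory accesses to resist timing attacks
//! - The internal ladder state is zeroed when it is dropped; copies of key
//!   material held by callers (or passed by value) are not cleared by this module
//! - All-zero scalars and all-zero u-coordinates are rejected before any computation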
18//!
19//! ## Example
20//!
21//! ```rust
22//! use crypto_async_rs::ecdh_x25519::{x25519, U_COORDINATE};
23//!
24//! // Example private keys (in practice, these should be randomly generated)
25//! let alice_private = [
26//!     0x77, 0x07, 0x6d, 0x0a, 0x73, 0x18, 0xa5, 0x7d, 0x3c, 0x16, 0xc1, 0x72, 0x51, 0xb2, 0x66, 0x45,
27//!     0xdf, 0x4c, 0x2f, 0x87, 0xeb, 0xc0, 0x99, 0x2a, 0xb1, 0x77, 0xfb, 0xa5, 0x1d, 0xb9, 0x2c, 0x2a,
28//! ];
29//! let bob_private = [
30//!     0x5d, 0xab, 0x08, 0x7e, 0x62, 0x4a, 0x8a, 0x4b, 0x79, 0xe1, 0x7f, 0x8b, 0x83, 0x80, 0x0e, 0xe6,
31//!     0x6f, 0x3b, 0xb1, 0x29, 0x26, 0x18, 0xb6, 0xfd, 0x1c, 0x2f, 0x8b, 0x27, 0xff, 0x88, 0xe0, 0xeb,
32//! ];
33//!
34//! // Compute public keys
35//! let alice_public = x25519(alice_private, U_COORDINATE)?;
36//! let bob_public = x25519(bob_private, U_COORDINATE)?;
37//!
38//! // Perform key exchange
39//! let alice_shared = x25519(alice_private, bob_public)?;
40//! let bob_shared = x25519(bob_private, alice_public)?;
41//!
42//! // Both parties now have the same shared secret
43//! assert_eq!(alice_shared, bob_shared);
44//! # Ok::<(), crypto_async_rs::ecdh_x25519::X25519Error>(())
45//! ```
46
47
/// Errors reported by the X25519 functions in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum X25519Error {
    /// The scalar or the u-coordinate was all zero bytes.
    InvalidInput,
    /// A validation step failed; the string carries a human-readable description.
    ValidationError(String),
}

impl std::fmt::Display for X25519Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidInput => {
                f.write_str("Invalid input: scalar or u-coordinate is all zeros")
            }
            Self::ValidationError(detail) => {
                f.write_str("Validation error: ")?;
                f.write_str(detail)
            }
        }
    }
}

impl std::error::Error for X25519Error {}

/// `Result` specialised to [`X25519Error`], returned by the X25519 operations.
pub type X25519Result<T> = Result<T, X25519Error>;
70
/// Bit length of a Curve25519 field element (p = 2^255 - 19); the ladder walks this many bits.
const CURVE25519_BIT_LEN: usize = 255;
/// Size in bytes of an encoded scalar or u-coordinate.
pub const CURVE25519_BYTE_LEN: usize = 32;
/// Number of 32-bit limbs holding one field element.
const CURVE25519_WORD_LEN: usize = CURVE25519_BYTE_LEN / 4;
/// Ladder constant (A + 2) / 4 for the Montgomery coefficient A = 486662.
///
/// RFC 7748 writes the ladder step as `z_2 = E * (AA + a24 * E)` with a24 = 121665.
/// The ladder in this module computes the same quantity as `E * (BB + 121666 * E)`
/// (since AA = BB + E), hence the value 121666.
const CURVE25519_A24: u32 = (486_662 + 2) / 4;

/// The X25519 base point u-coordinate: the integer 9, encoded little-endian.
pub const U_COORDINATE: [u8; CURVE25519_BYTE_LEN] = [
    9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
];
80
81
82#[repr(align(32))]
83struct X25519State {
84    k: [u32; CURVE25519_WORD_LEN],
85    u: [u32; CURVE25519_WORD_LEN],
86    x1: [u32; CURVE25519_WORD_LEN],
87    z1: [u32; CURVE25519_WORD_LEN],
88    x2: [u32; CURVE25519_WORD_LEN],
89    z2: [u32; CURVE25519_WORD_LEN],
90    t1: [u32; CURVE25519_WORD_LEN],
91    t2: [u32; CURVE25519_WORD_LEN],
92}
93
94impl Drop for X25519State {
95    fn drop(&mut self) {
96        let raw = self as *mut X25519State as *mut [u8; 256];
97        unsafe { *raw = [0; 256] };
98    }
99}
100
101/// Validates that the input is not all zeros
102#[inline]
103fn validate_input(input: &[u8; CURVE25519_BYTE_LEN]) -> X25519Result<()> {
104    let is_zero = input.iter().all(|&b| b == 0);
105    if is_zero {
106        return Err(X25519Error::InvalidInput);
107    }
108    Ok(())
109}
110
111/// Safely converts a byte array to a u32 array using little-endian interpretation
112#[inline]
113fn bytes_to_u32_array(bytes: [u8; CURVE25519_BYTE_LEN]) -> [u32; CURVE25519_WORD_LEN] {
114    let mut result = [0u32; CURVE25519_WORD_LEN];
115    for (i, chunk) in bytes.chunks_exact(4).enumerate() {
116        result[i] = u32::from_le_bytes(chunk.try_into().unwrap());
117    }
118    result
119}
120
121/// Safely converts a u32 array to a byte array using little-endian interpretation
122#[inline]
123fn u32_array_to_bytes(array: [u32; CURVE25519_WORD_LEN]) -> [u8; CURVE25519_BYTE_LEN] {
124    let mut result = [0u8; CURVE25519_BYTE_LEN];
125    for (i, &word) in array.iter().enumerate() {
126        let bytes = word.to_le_bytes();
127        result[i * 4..(i + 1) * 4].copy_from_slice(&bytes);
128    }
129    result
130}
131
132/// Performs X25519 scalar multiplication
133///
134/// # Arguments
135/// * `k` - The scalar (private key) as a 32-byte array
136/// * `u` - The u-coordinate (public key) as a 32-byte array
137///
138/// # Returns
139/// * `Ok([u8; 32])` - The resulting shared secret
140/// * `Err(X25519Error)` - If input validation fails
141///
142/// # Errors
143/// * `X25519Error::InvalidInput` - If either input is all zeros
144///
145/// # Security
146/// This function implements constant-time operations to prevent timing attacks.
147pub fn x25519(k: [u8; CURVE25519_BYTE_LEN], u: [u8; CURVE25519_BYTE_LEN]) -> X25519Result<[u8; CURVE25519_BYTE_LEN]> {
148    // Validate inputs
149    validate_input(&k)?;
150    validate_input(&u)?;
151
152    let mut swap: u32 = 0;
153    let mut b: u32;
154    let mut state = X25519State {
155        k: [0; CURVE25519_WORD_LEN],
156        u: [0; CURVE25519_WORD_LEN],
157        x1: [0; CURVE25519_WORD_LEN],
158        z1: [0; CURVE25519_WORD_LEN],
159        x2: [0; CURVE25519_WORD_LEN],
160        z2: [0; CURVE25519_WORD_LEN],
161        t1: [0; CURVE25519_WORD_LEN],
162        t2: [0; CURVE25519_WORD_LEN],
163    };
164
165    // Copy scalar using safe conversion
166    state.k = bytes_to_u32_array(k);
167    
168    // Set the three least significant bits of the first byte and the most
169    // significant bit of the last to zero, set the second most significant
170    // bit of the last byte to 1
171    state.k[0] &= 0xFFFFFFF8;
172    state.k[7] &= 0x7FFFFFFF;
173    state.k[7] |= 0x40000000;
174
175    // Copy input u-coordinate using safe conversion
176    state.u = bytes_to_u32_array(u);
177
178    // Implementations must mask the most significant bit in the final byte
179    state.u[7] &= 0x7FFFFFFF;
180
181    // Implementations must accept non-canonical values and process them as
182    // if they had been reduced modulo the field prime (refer to RFC 7748,
183    // section 5)
184    state.u = curve25519_red(state.u);
185
186    // Set Z1 = 0
187    // Set X1 = 1
188    state.x1[0] = 1;
189    // Set X2 = U
190    state.x2 = state.u;
191    // Set Z2 = 1
192    state.z2[0] = 1;
193
194    // Montgomery ladder
195    for i in (0usize..CURVE25519_BIT_LEN).rev() {
196        // The scalar is processed in a left-to-right fashion
197        b = (state.k[i / 32] >> (i % 32)) & 1;
198
199        // Conditional swap
200        curve25519_swap(&mut state.x1, &mut state.x2, swap ^ b);
201        curve25519_swap(&mut state.z1, &mut state.z2, swap ^ b);
202
203        // Save current bit value
204        swap = b;
205
206        // Compute T1 = X2 + Z2
207        state.t1 = curve25519_add(state.x2, state.z2);
208        // Compute X2 = X2 - Z2
209        state.x2 = curve25519_sub(state.x2, state.z2);
210        // Compute Z2 = X1 + Z1
211        state.z2 = curve25519_add(state.x1, state.z1);
212        // Compute X1 = X1 - Z1
213        state.x1 = curve25519_sub(state.x1, state.z1);
214        // Compute T1 = T1 * X1
215        state.t1 = curve25519_mul(state.t1, state.x1);
216        // Compute X2 = X2 * Z2
217        state.x2 = curve25519_mul(state.x2, state.z2);
218        // Compute Z2 = Z2 * Z2
219        state.z2 = curve25519_sqr(state.z2);
220        // Compute X1 = X1 * X1
221        state.x1 = curve25519_sqr(state.x1);
222        // Compute T2 = Z2 - X1
223        state.t2 = curve25519_sub(state.z2, state.x1);
224        // Compute Z1 = T2 * a24
225        state.z1 = curve25519_mul_int(state.t2, CURVE25519_A24);
226        // Compute Z1 = Z1 + X1
227        state.z1 = curve25519_add(state.z1, state.x1);
228        // Compute Z1 = Z1 * T2
229        state.z1 = curve25519_mul(state.z1, state.t2);
230        // Compute X1 = X1 * Z2
231        state.x1 = curve25519_mul(state.x1, state.z2);
232        // Compute Z2 = T1 - X2
233        state.z2 = curve25519_sub(state.t1, state.x2);
234        // Compute Z2 = Z2 * Z2
235        state.z2 = curve25519_sqr(state.z2);
236        // Compute Z2 = Z2 * U
237        state.z2 = curve25519_mul(state.z2, state.u);
238        // Compute X2 = X2 + T1
239        state.x2 = curve25519_add(state.x2, state.t1);
240        // Compute X2 = X2 * X2
241        state.x2 = curve25519_sqr(state.x2);
242    }
243
244    // Conditional swap
245    curve25519_swap(&mut state.x1, &mut state.x2, swap);
246    curve25519_swap(&mut state.z1, &mut state.z2, swap);
247
248    // Retrieve affine representation
249    state.u = curve25519_inv(state.z1);
250    state.u = curve25519_mul(state.u, state.x1);
251
252    Ok(u32_array_to_bytes(state.u))
253}
254
255/// Modular reduction
256/// 
257/// Performs modular reduction modulo the Curve25519 prime p = 2^255 - 19
258#[inline]
259fn curve25519_red(a: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
260    let mut temp: u64 = 19;
261    let mut b: [u32; CURVE25519_WORD_LEN] = Default::default();
262
263    // Compute B = A + 19
264    for i in 0..CURVE25519_WORD_LEN {
265        temp += a[i] as u64;
266        b[i] = temp as u32;
267        temp >>= 32;
268    }
269
270    // Compute B = A - (2^255 - 19)
271    b[7] = b[7].wrapping_sub(0x80000000);
272    // If B < (2^255 - 19) then R = B, else R = A
273    curve25519_select(&b, &a, (b[7] & 0x80000000) >> 31)
274}
275
276/// Select an integer based on a condition
277/// 
278/// Performs constant-time selection between two integers
279#[inline]
280fn curve25519_select(a: &[u32; CURVE25519_WORD_LEN], b: &[u32; CURVE25519_WORD_LEN], c: u32) -> [u32; CURVE25519_WORD_LEN] {
281    // The mask is the all-1 or all-0 word
282    let mask = c.wrapping_sub(1);
283    let mut r: [u32; CURVE25519_WORD_LEN] = Default::default();
284    // Select between A and B
285    for i in 0..CURVE25519_WORD_LEN {
286        // Constant time implementation
287        r[i] = (a[i] & mask) | (b[i] & !mask);
288    }
289
290    r
291}
292
293/// Conditional swap
294/// 
295/// Performs constant-time conditional swap of two integers
296#[inline]
297fn curve25519_swap(a: &mut [u32; CURVE25519_WORD_LEN], b: &mut [u32; CURVE25519_WORD_LEN], c: u32) {
298    let mut dummy: u32;
299    // The mask is the all-1 or all-0 word
300    let mask = (!c).wrapping_add(1);
301
302    for i in 0..CURVE25519_WORD_LEN {
303        dummy = mask & (a[i] ^ b[i]);
304        a[i] ^= dummy;
305        b[i] ^= dummy;
306    }
307}
308
309/// Modular addition
310/// 
311/// Performs modular addition: R = (A + B) mod p
312#[inline]
313fn curve25519_add(a: [u32; CURVE25519_WORD_LEN], b: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
314    let mut temp: u64 = 0;
315    let mut r: [u32; CURVE25519_WORD_LEN] = Default::default();
316
317    // Compute R = A + B
318    for i in 0..CURVE25519_WORD_LEN {
319        temp += a[i] as u64;
320        temp += b[i] as u64;
321        r[i] = temp as u32;
322        temp >>= 32;
323    }
324
325    // Perform modular reduction
326    curve25519_red(r)
327}
328
329/// Modular subtraction
330/// 
331/// Performs modular subtraction: R = (A - B) mod p
332#[inline]
333fn curve25519_sub(a: [u32; CURVE25519_WORD_LEN], b: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
334    let mut temp: i64 = -19;
335    let mut result: [u32; CURVE25519_WORD_LEN] = Default::default();
336
337    // Compute R = A - 19 - B
338    for i in 0..CURVE25519_WORD_LEN {
339        temp += a[i] as i64;
340        temp -= b[i] as i64;
341        result[i] = temp as u32;
342        temp >>= 32;
343    }
344
345    // Compute R = A + (2^255 - 19) - B
346    result[7] = result[7].wrapping_add(0x80000000);
347
348    // Perform modular reduction
349    curve25519_red(result)
350}
351
352/// Modular multiplication
353/// 
354/// Performs modular multiplication: R = (A * B) mod p
355#[inline]
356fn curve25519_mul(a: [u32; CURVE25519_WORD_LEN], b: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
357    let mut c: u64 = 0;
358    let mut temp: u64 = 0;
359    let mut u: [u32; 16] = Default::default();
360
361    // Comba's method is used to perform multiplication
362    for i in 0..16 {
363        // The algorithm computes the products, column by column
364        if i < CURVE25519_WORD_LEN {
365            // Inner loop
366            for j in 0..=i {
367                temp += a[j] as u64 * b[i - j] as u64;
368                c += temp >> 32;
369                temp &= 0xFFFFFFFF;
370            }
371        } else {
372            // Inner loop
373            for j in i - 7..CURVE25519_WORD_LEN {
374                temp += a[j] as u64 * b[i - j] as u64;
375                c += temp >> 32;
376                temp &= 0xFFFFFFFF;
377            }
378        }
379
380        // At the bottom of each column, the final result is written to memory
381        u[i] = temp as u32;
382
383        // Propagate the carry upwards
384        temp = c & 0xFFFFFFFF;
385        c >>= 32;
386    }
387
388    // Reduce bit 255 (2^255 = 19 mod p)
389    temp = (u[7] >> 31) as u64 * 19;
390    // Mask the most significant bit
391    u[7] &= 0x7FFFFFFF;
392
393    // Perform fast modular reduction (first pass)
394    for i in 0..CURVE25519_WORD_LEN {
395        temp += u[i] as u64;
396        temp += u[i + CURVE25519_WORD_LEN] as u64 * 38;
397        u[i] = temp as u32;
398        temp >>= 32;
399    }
400
401    // Reduce bit 256 (2^256 = 38 mod p)
402    temp *= 38;
403    // Reduce bit 255 (2^255 = 19 mod p)
404    temp += (u[7] >> 31) as u64 * 19;
405    // Mask the most significant bit
406    u[7] &= 0x7FFFFFFF;
407
408    // Perform fast modular reduction (second pass)
409    for i in 0..CURVE25519_WORD_LEN {
410        temp += u[i] as u64;
411        u[i] = temp as u32;
412        temp >>= 32;
413    }
414
415    // Reduce non-canonical values
416    let mut temp: [u32; CURVE25519_WORD_LEN] = Default::default();
417    temp.copy_from_slice(&u[..CURVE25519_WORD_LEN]);
418    curve25519_red(temp)
419}
420
421/// Modular squaring
422/// 
423/// Performs modular squaring: R = (A^2) mod p
424#[inline]
425fn curve25519_sqr(a: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
426    // Compute R = (A^2) mod p
427    curve25519_mul(a, a)
428}
429
430/// Modular multiplication by integer
431/// 
432/// Performs modular multiplication: R = (A * B) mod p where B is a 32-bit integer
433#[inline]
434fn curve25519_mul_int(a: [u32; CURVE25519_WORD_LEN], b: u32) -> [u32; CURVE25519_WORD_LEN] {
435    let mut temp: u64 = 0;
436    let mut u: [u32; CURVE25519_WORD_LEN] = Default::default();
437
438    // Compute R = A * B
439    for i in 0..CURVE25519_WORD_LEN {
440        temp += a[i] as u64 * b as u64;
441        u[i] = temp as u32;
442        temp >>= 32;
443    }
444
445    // Reduce bit 256 (2^256 = 38 mod p)
446    temp *= 38;
447    // Reduce bit 255 (2^255 = 19 mod p)
448    temp += (u[7] >> 31) as u64 * 19;
449    // Mask the most significant bit
450    u[7] &= 0x7FFFFFFF;
451
452    // Perform fast modular reduction
453    for i in 0..CURVE25519_WORD_LEN {
454        temp += u[i] as u64;
455        u[i] = temp as u32;
456        temp >>= 32;
457    }
458
459    // Reduce non-canonical values
460    curve25519_red(u)
461}
462
463
464
465/// Modular multiplicative inverse
466/// 
467/// Performs modular multiplicative inverse: R = A^-1 mod p
468#[inline]
469fn curve25519_inv(a: [u32; CURVE25519_WORD_LEN]) -> [u32; CURVE25519_WORD_LEN] {
470    let mut u: [u32; CURVE25519_WORD_LEN];
471    let mut v: [u32; CURVE25519_WORD_LEN];
472
473    // Since GF(p) is a prime field, the Fermat's little theorem can be
474    // used to find the multiplicative inverse of A modulo p
475    u = curve25519_sqr(a);
476    u = curve25519_mul(u, a); // A^(2^2 - 1)
477    u = curve25519_sqr(u);
478    v = curve25519_mul(u, a); // A^(2^3 - 1)
479
480    u = curve25519_pwr2(v, 3);
481    u = curve25519_mul(u, v); // A^(2^6 - 1)
482    u = curve25519_sqr(u);
483    v = curve25519_mul(u, a); // A^(2^7 - 1)
484
485    u = curve25519_pwr2(v, 7);
486    u = curve25519_mul(u, v); // A^(2^14 - 1)
487    u = curve25519_sqr(u);
488    v = curve25519_mul(u, a); // A^(2^15 - 1)
489
490    u = curve25519_pwr2(v, 15);
491    u = curve25519_mul(u, v); // A^(2^30 - 1)
492    u = curve25519_sqr(u);
493    v = curve25519_mul(u, a); // A^(2^31 - 1)
494
495    u = curve25519_pwr2(v, 31);
496    v = curve25519_mul(u, v); // A^(2^62 - 1)
497
498    u = curve25519_pwr2(v, 62);
499    u = curve25519_mul(u, v); // A^(2^124 - 1)
500    u = curve25519_sqr(u);
501    v = curve25519_mul(u, a); // A^(2^125 - 1)
502
503    u = curve25519_pwr2(v, 125);
504    u = curve25519_mul(u, v); // A^(2^250 - 1)
505    u = curve25519_sqr(u);
506    u = curve25519_sqr(u);
507    u = curve25519_mul(u, a);
508    u = curve25519_sqr(u);
509    u = curve25519_sqr(u);
510    u = curve25519_mul(u, a);
511    u = curve25519_sqr(u);
512    curve25519_mul(u, a) // A^(2^255 - 21)
513}
514
515/// Raise an integer to power 2^n
516/// 
517/// Performs modular exponentiation: R = (A^(2^n)) mod p
518#[inline]
519fn curve25519_pwr2(a: [u32; CURVE25519_WORD_LEN], n: usize) -> [u32; CURVE25519_WORD_LEN] {
520    // Pre-compute (A^2) mod p
521    let mut result = curve25519_sqr(a);
522
523    // Compute R = (A^(2^n)) mod p
524    for _ in 1..n {
525        result = curve25519_sqr(result);
526    }
527
528    result
529}
530
#[cfg(test)]
/// Generates a random, clamped X25519 private key.
///
/// Only compiled in test builds (`#[cfg(test)]`). The random bytes are clamped
/// as described in RFC 7748 section 5. [`x25519`] clamps its scalar again, so
/// the clamping here does not change any X25519 result; it only makes the
/// returned bytes already clamped.
///
/// # Arguments
/// * `rng` - A cryptographically secure random number generator
///
/// # Returns
/// * `Ok([u8; 32])` - A clamped private key
///
/// # Errors
/// * `X25519Error::ValidationError` - If the RNG produced 32 zero bytes
///
/// # Example
/// ```rust,ignore
/// use crypto_async_rs::ecdh_x25519::generate_private_key;
/// use rand::rngs::OsRng;
///
/// let private_key = generate_private_key(&mut OsRng)?;
/// # Ok::<(), crypto_async_rs::ecdh_x25519::X25519Error>(())
/// ```
pub fn generate_private_key<R: rand::RngCore + rand::CryptoRng>(
    rng: &mut R,
) -> X25519Result<[u8; CURVE25519_BYTE_LEN]> {
    let mut key = [0u8; CURVE25519_BYTE_LEN];
    rng.fill_bytes(&mut key);

    // Reject an all-zero draw. This must happen before clamping, which forces
    // bit 254 on. The OR-fold does not short-circuit on secret key bytes.
    if key.iter().fold(0u8, |acc, &b| acc | b) == 0 {
        return Err(X25519Error::ValidationError("Generated key is all zeros".to_string()));
    }

    // Apply X25519 key clamping
    key[0] &= 0xf8; // Clear the 3 least significant bits
    key[31] &= 0x7f; // Clear the most significant bit
    key[31] |= 0x40; // Set the second most significant bit

    Ok(key)
}
565
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::OsRng;

    /// RFC 7748 section 5.2: the two single-shot X25519 test vectors.
    #[test]
    fn test_x25519() {
        let scalar = [
            0xa5u8, 0x46, 0xe3, 0x6b, 0xf0, 0x52, 0x7c, 0x9d, 0x3b, 0x16, 0x15, 0x4b, 0x82, 0x46,
            0x5e, 0xdd, 0x62, 0x14, 0x4c, 0x0a, 0xc1, 0xfc, 0x5a, 0x18, 0x50, 0x6a, 0x22, 0x44,
            0xba, 0x44, 0x9a, 0xc4,
        ];
        let u_coordinate = [
            0xe6u8, 0xdb, 0x68, 0x67, 0x58, 0x30, 0x30, 0xdb, 0x35, 0x94, 0xc1, 0xa4, 0x24, 0xb1,
            0x5f, 0x7c, 0x72, 0x66, 0x24, 0xec, 0x26, 0xb3, 0x35, 0x3b, 0x10, 0xa9, 0x03, 0xa6,
            0xd0, 0xab, 0x1c, 0x4c,
        ];
        let result = x25519(scalar, u_coordinate).unwrap();

        assert_eq!(
            result,
            [
                0xc3, 0xda, 0x55, 0x37, 0x9d, 0xe9, 0xc6, 0x90, 0x8e, 0x94, 0xea, 0x4d, 0xf2, 0x8d,
                0x08, 0x4f, 0x32, 0xec, 0xcf, 0x03, 0x49, 0x1c, 0x71, 0xf7, 0x54, 0xb4, 0x07, 0x55,
                0x77, 0xa2, 0x85, 0x52
            ]
        );

        let scalar = [
            0x4b, 0x66, 0xe9, 0xd4, 0xd1, 0xb4, 0x67, 0x3c, 0x5a, 0xd2, 0x26, 0x91, 0x95, 0x7d,
            0x6a, 0xf5, 0xc1, 0x1b, 0x64, 0x21, 0xe0, 0xea, 0x01, 0xd4, 0x2c, 0xa4, 0x16, 0x9e,
            0x79, 0x18, 0xba, 0x0d,
        ];
        let u_coordinate = [
            0xe5, 0x21, 0x0f, 0x12, 0x78, 0x68, 0x11, 0xd3, 0xf4, 0xb7, 0x95, 0x9d, 0x05, 0x38,
            0xae, 0x2c, 0x31, 0xdb, 0xe7, 0x10, 0x6f, 0xc0, 0x3c, 0x3e, 0xfc, 0x4c, 0xd5, 0x49,
            0xc7, 0x15, 0xa4, 0x93,
        ];
        let result = x25519(scalar, u_coordinate).unwrap();

        assert_eq!(
            result,
            [
                0x95, 0xcb, 0xde, 0x94, 0x76, 0xe8, 0x90, 0x7d, 0x7a, 0xad, 0xe4, 0x5c, 0xb4, 0xb8,
                0x73, 0xf8, 0x8b, 0x59, 0x5a, 0x68, 0x79, 0x9f, 0xa1, 0x52, 0xe6, 0xf8, 0xf7, 0x64,
                0x7a, 0xac, 0x79, 0x57
            ]
        );
    }

    /// RFC 7748 section 5.2 iterated test: one iteration starting from k = u = 9.
    #[test]
    fn test_x25519_series() {
        let scalar = U_COORDINATE;
        let u_coordinate = U_COORDINATE;
        let result = x25519(scalar, u_coordinate).unwrap();

        assert_eq!(
            result,
            [
                0x42, 0x2c, 0x8e, 0x7a, 0x62, 0x27, 0xd7, 0xbc, 0xa1, 0x35, 0x0b, 0x3e, 0x2b, 0xb7,
                0x27, 0x9f, 0x78, 0x97, 0xb8, 0x7b, 0xb6, 0x85, 0x4b, 0x78, 0x3c, 0x60, 0xe8, 0x03,
                0x11, 0xae, 0x30, 0x79
            ]
        );
    }

    /// RFC 7748 section 5.2 iterated test after 1,000 iterations: each round sets
    /// k to X25519(k, u) and u to the previous k.
    ///
    /// Ignored by default because 1,000 scalar multiplications are slow in
    /// unoptimized builds; run with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn test_x25519_series_1000() {
        let mut k = U_COORDINATE;
        let mut u = U_COORDINATE;
        for _ in 0..1000 {
            let next = x25519(k, u).unwrap();
            u = k;
            k = next;
        }

        assert_eq!(
            k,
            [
                0x68, 0x4c, 0xf5, 0x9b, 0xa8, 0x33, 0x09, 0x55, 0x28, 0x00, 0xef, 0x56, 0x6f, 0x2f,
                0x4d, 0x3c, 0x1c, 0x38, 0x87, 0xc4, 0x93, 0x60, 0xe3, 0x87, 0x5f, 0x2e, 0xb9, 0x4d,
                0x99, 0x53, 0x2c, 0x51
            ]
        );
    }

    /// RFC 7748 section 6.1: Diffie-Hellman known-answer vectors (public keys and
    /// shared secret), not just agreement between the two parties.
    #[test]
    fn test_rfc7748_key_exchange_vectors() {
        let alice_private = [
            0x77, 0x07, 0x6d, 0x0a, 0x73, 0x18, 0xa5, 0x7d, 0x3c, 0x16, 0xc1, 0x72, 0x51, 0xb2,
            0x66, 0x45, 0xdf, 0x4c, 0x2f, 0x87, 0xeb, 0xc0, 0x99, 0x2a, 0xb1, 0x77, 0xfb, 0xa5,
            0x1d, 0xb9, 0x2c, 0x2a,
        ];
        let bob_private = [
            0x5d, 0xab, 0x08, 0x7e, 0x62, 0x4a, 0x8a, 0x4b, 0x79, 0xe1, 0x7f, 0x8b, 0x83, 0x80,
            0x0e, 0xe6, 0x6f, 0x3b, 0xb1, 0x29, 0x26, 0x18, 0xb6, 0xfd, 0x1c, 0x2f, 0x8b, 0x27,
            0xff, 0x88, 0xe0, 0xeb,
        ];
        let expected_alice_public = [
            0x85, 0x20, 0xf0, 0x09, 0x89, 0x30, 0xa7, 0x54, 0x74, 0x8b, 0x7d, 0xdc, 0xb4, 0x3e,
            0xf7, 0x5a, 0x0d, 0xbf, 0x3a, 0x0d, 0x26, 0x38, 0x1a, 0xf4, 0xeb, 0xa4, 0xa9, 0x8e,
            0xaa, 0x9b, 0x4e, 0x6a,
        ];
        let expected_bob_public = [
            0xde, 0x9e, 0xdb, 0x7d, 0x7b, 0x7d, 0xc1, 0xb4, 0xd3, 0x5b, 0x61, 0xc2, 0xec, 0xe4,
            0x35, 0x37, 0x3f, 0x83, 0x43, 0xc8, 0x5b, 0x78, 0x67, 0x4d, 0xad, 0xfc, 0x7e, 0x14,
            0x6f, 0x88, 0x2b, 0x4f,
        ];
        let expected_shared = [
            0x4a, 0x5d, 0x9d, 0x5b, 0xa4, 0xce, 0x2d, 0xe1, 0x72, 0x8e, 0x3b, 0xf4, 0x80, 0x35,
            0x0f, 0x25, 0xe0, 0x7e, 0x21, 0xc9, 0x47, 0xd1, 0x9e, 0x33, 0x76, 0xf0, 0x9b, 0x3c,
            0x1e, 0x16, 0x17, 0x42,
        ];

        let alice_public = x25519(alice_private, U_COORDINATE).unwrap();
        let bob_public = x25519(bob_private, U_COORDINATE).unwrap();
        assert_eq!(alice_public, expected_alice_public);
        assert_eq!(bob_public, expected_bob_public);

        assert_eq!(x25519(alice_private, bob_public).unwrap(), expected_shared);
        assert_eq!(x25519(bob_private, alice_public).unwrap(), expected_shared);
    }

    #[test]
    fn test_generate_private_key() {
        let mut rng = OsRng;
        let private_key = generate_private_key(&mut rng).unwrap();

        // Check that the key is not all zeros
        assert!(!private_key.iter().all(|&b| b == 0));

        // Check that the key has been properly clamped
        assert_eq!(private_key[0] & 0x07, 0); // Bottom 3 bits should be 0
        assert_eq!(private_key[31] & 0x80, 0); // Top bit should be 0
        assert_eq!(private_key[31] & 0x40, 0x40); // Second top bit should be 1
    }

    #[test]
    fn test_invalid_input() {
        let zero_input = [0u8; CURVE25519_BYTE_LEN];
        let valid_input = U_COORDINATE;

        // Test with zero scalar
        assert!(matches!(x25519(zero_input, valid_input), Err(X25519Error::InvalidInput)));

        // Test with zero u-coordinate
        assert!(matches!(x25519(valid_input, zero_input), Err(X25519Error::InvalidInput)));

        // Test with both zero
        assert!(matches!(x25519(zero_input, zero_input), Err(X25519Error::InvalidInput)));
    }

    #[test]
    fn test_key_exchange() {
        let mut rng = OsRng;

        // Generate private keys for Alice and Bob
        let alice_private = generate_private_key(&mut rng).unwrap();
        let bob_private = generate_private_key(&mut rng).unwrap();

        // Compute public keys
        let alice_public = x25519(alice_private, U_COORDINATE).unwrap();
        let bob_public = x25519(bob_private, U_COORDINATE).unwrap();

        // Perform key exchange
        let alice_shared = x25519(alice_private, bob_public).unwrap();
        let bob_shared = x25519(bob_private, alice_public).unwrap();

        // Both parties should have the same shared secret
        assert_eq!(alice_shared, bob_shared);
    }
}