/// Public parameters of a finite-field Diffie–Hellman group.
///
/// Both parties must agree on the same `prime` and `generator` before exchanging
/// public values. The presets provided for this type are demonstration-sized only.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DhParams {
    /// Prime modulus `p`; every exponentiation is reduced modulo this value.
    pub prime: u64,
    /// Base `g` that private exponents are applied to when forming public values.
    pub generator: u64,
}
impl DhParams {
pub fn standard() -> Self {
Self {
prime: 23,
generator: 5,
}
}
pub fn large() -> Self {
Self {
prime: 2_147_483_647,
generator: 7,
}
}
}
/// A finite-field Diffie–Hellman key pair.
///
/// `Debug` is implemented by hand so the private exponent is never written to logs,
/// `dbg!` output or panic messages; read the field directly when it is really needed.
#[derive(Clone, PartialEq, Eq)]
pub struct KeyPair {
    /// Secret exponent; keep private.
    pub private_key: u64,
    /// Public value to send to the peer.
    pub public_key: u64,
}

impl std::fmt::Debug for KeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyPair")
            .field("private_key", &format_args!("<redacted>"))
            .field("public_key", &self.public_key)
            .finish()
    }
}
/// Result of a finite-field Diffie–Hellman exchange.
///
/// `Debug` is implemented by hand so the secret value is never written to logs or panic
/// messages; read `value` directly when it is really needed.
#[derive(Clone, PartialEq, Eq)]
pub struct SharedSecret {
    /// The agreed secret; keep private.
    pub value: u64,
}

impl std::fmt::Debug for SharedSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SharedSecret")
            .field("value", &format_args!("<redacted>"))
            .finish()
    }
}
/// A 32-byte private/public key pair.
///
/// `Debug` is implemented by hand so the private bytes are never written to logs or
/// panic messages; read the field directly when it is really needed.
#[derive(Clone, PartialEq, Eq)]
pub struct X25519KeyPair {
    /// Secret key material; keep private.
    pub private_bytes: [u8; 32],
    /// Public key bytes intended to be sent to the peer.
    pub public_bytes: [u8; 32],
}

impl std::fmt::Debug for X25519KeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("X25519KeyPair")
            .field("private_bytes", &format_args!("<redacted>"))
            .field("public_bytes", &self.public_bytes)
            .finish()
    }
}
/// Finite-field Diffie–Hellman key agreement over the group described by [`DhParams`].
#[derive(Debug, Clone)]
pub struct KeyAgreementProtocol {
    /// Prime modulus and generator shared by both parties.
    pub params: DhParams,
}
impl KeyAgreementProtocol {
pub fn new(params: DhParams) -> Self {
Self { params }
}
pub fn generate_keypair(&self, private_key: u64) -> KeyPair {
let public_key = Self::modpow(self.params.generator, private_key, self.params.prime);
KeyPair {
private_key,
public_key,
}
}
pub fn compute_shared_secret(&self, private_key: u64, other_public: u64) -> SharedSecret {
SharedSecret {
value: Self::modpow(other_public, private_key, self.params.prime),
}
}
pub fn modpow(base: u64, exp: u64, modulus: u64) -> u64 {
if modulus == 0 {
return 1;
}
if modulus == 1 {
return 0;
}
let mut result: u128 = 1;
let mut b = (base % modulus) as u128;
let mut e = exp;
let m = modulus as u128;
while e > 0 {
if e & 1 == 1 {
result = result * b % m;
}
b = b * b % m;
e >>= 1;
}
result as u64
}
pub fn key_to_did_key_format(public: u64) -> String {
let bytes = public.to_be_bytes();
let significant: Vec<u8> = bytes.iter().copied().skip_while(|&b| b == 0).collect();
let encoded = if significant.is_empty() {
"1".to_string() } else {
base58_encode(&significant)
};
format!("did:key:z{}", encoded)
}
}
/// Stateful helper that hands out 32-byte key pairs and derives shared secrets.
#[derive(Debug, Clone)]
pub struct EcdhKeyAgreement {
    /// Curve label supplied at construction.
    pub curve: String,
    /// Number of key pairs generated so far; mixed into the next key's seed.
    counter: u64,
}
impl EcdhKeyAgreement {
    /// Creates a context labelled with `curve`, starting the key counter at zero.
    ///
    /// NOTE(review): the label is stored but never consulted. Every key pair and shared
    /// secret below is computed with X25519 (RFC 7748), so a context created as `"P-256"`
    /// still produces X25519 keys.
    pub fn new(curve: &str) -> Self {
        Self {
            curve: curve.to_string(),
            counter: 0,
        }
    }

    /// Returns the next key pair and advances the internal counter.
    ///
    /// The public key is `X25519(private, 9)`, i.e. it is genuinely derived from the
    /// private key. That is what lets two parties using
    /// [`Self::derive_shared_secret`] arrive at the same secret.
    ///
    /// WARNING: the private bytes come from a xorshift generator seeded with
    /// `42 ^ counter`. Every fresh context therefore yields the same, predictable key
    /// sequence. That is only acceptable for tests and demos; real keys need a CSPRNG.
    pub fn generate_keypair(&mut self) -> X25519KeyPair {
        let seed = 42u64 ^ self.counter;
        self.counter = self.counter.wrapping_add(1);
        let private_bytes = xorshift_bytes(seed);
        let public_bytes = x25519::public_key(&private_bytes);
        X25519KeyPair {
            private_bytes,
            public_bytes,
        }
    }

    /// Computes the X25519 shared secret `X25519(local.private_bytes, remote_public)`.
    ///
    /// Because public keys are `X25519(private, 9)`, Alice's
    /// `derive_shared_secret(&alice, &bob.public_bytes)` equals Bob's
    /// `derive_shared_secret(&bob, &alice.public_bytes)`.
    ///
    /// The raw output should go through a KDF before use as a key. Per RFC 7748 §6.1, a
    /// low-order peer key yields an all-zero result; this method does not reject it.
    pub fn derive_shared_secret(
        &self,
        local: &X25519KeyPair,
        remote_public: &[u8; 32],
    ) -> [u8; 32] {
        x25519::scalar_mult(&local.private_bytes, remote_public)
    }
}

/// Self-contained X25519 (RFC 7748) over GF(2^255 - 19).
///
/// Field elements use sixteen 16-bit limbs stored in `i64` (the TweetNaCl layout), so limb
/// products and their sums stay far below `i64::MAX`. Secret-dependent selections use
/// masked swaps rather than branches.
mod x25519 {
    /// Field element: value = sum of `limb[i] * 2^(16 * i)`; limbs are not always normalised.
    type Fe = [i64; 16];

    const ZERO: Fe = [0; 16];

    /// `(A - 2) / 4` for Curve25519's `A = 486662`, i.e. `121665 = 0x1_DB41`.
    const A24: Fe = [0xdb41, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];

    /// Derives the public key `X25519(scalar, 9)`.
    pub(super) fn public_key(scalar: &[u8; 32]) -> [u8; 32] {
        let mut base = [0u8; 32];
        base[0] = 9; // u-coordinate of the Curve25519 base point
        scalar_mult(scalar, &base)
    }

    /// The X25519 function: clamps `scalar`, then runs the Montgomery ladder on `u`.
    pub(super) fn scalar_mult(scalar: &[u8; 32], u: &[u8; 32]) -> [u8; 32] {
        // Clamping per RFC 7748 §5: clear the 3 low bits, clear bit 255, set bit 254.
        let mut k = *scalar;
        k[0] &= 248;
        k[31] = (k[31] & 127) | 64;

        let x1 = unpack(u);
        let mut x2 = one();
        let mut z2 = ZERO;
        let mut x3 = x1;
        let mut z3 = one();

        for t in (0..255).rev() {
            let bit = i64::from((k[t >> 3] >> (t & 7)) & 1);
            // Swap in, do one ladder step, swap back (equivalent to RFC 7748's lazy swap).
            cswap(&mut x2, &mut x3, bit);
            cswap(&mut z2, &mut z3, bit);

            let a = add(&x2, &z2);
            let aa = mul(&a, &a);
            let b = sub(&x2, &z2);
            let bb = mul(&b, &b);
            let e = sub(&aa, &bb);
            let c = add(&x3, &z3);
            let d = sub(&x3, &z3);
            let da = mul(&d, &a);
            let cb = mul(&c, &b);
            let sum = add(&da, &cb);
            let diff = sub(&da, &cb);

            x3 = mul(&sum, &sum);
            z3 = mul(&x1, &mul(&diff, &diff));
            x2 = mul(&aa, &bb);
            z2 = mul(&e, &add(&aa, &mul(&A24, &e)));

            cswap(&mut x2, &mut x3, bit);
            cswap(&mut z2, &mut z3, bit);
        }

        pack(&mul(&x2, &invert(&z2)))
    }

    fn one() -> Fe {
        let mut f = ZERO;
        f[0] = 1;
        f
    }

    fn add(a: &Fe, b: &Fe) -> Fe {
        let mut out = *a;
        for (o, &y) in out.iter_mut().zip(b.iter()) {
            *o += y;
        }
        out
    }

    fn sub(a: &Fe, b: &Fe) -> Fe {
        let mut out = *a;
        for (o, &y) in out.iter_mut().zip(b.iter()) {
            *o -= y;
        }
        out
    }

    /// Propagates carries so each limb returns to roughly 16 bits; the carry out of the
    /// top limb is folded back into limb 0 times 38, since 2^256 = 38 (mod p).
    fn carry(o: &mut Fe) {
        for i in 0..16 {
            o[i] += 1 << 16;
            let c = o[i] >> 16;
            if i < 15 {
                o[i + 1] += c - 1;
            } else {
                o[0] += 38 * (c - 1);
            }
            o[i] -= c << 16;
        }
    }

    fn mul(a: &Fe, b: &Fe) -> Fe {
        let mut t = [0i64; 31];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                t[i + j] += x * y;
            }
        }
        // Limbs 16..=30 carry weight 2^256 * 2^(16 i); fold them down using 2^256 = 38.
        for i in 0..15 {
            t[i] += 38 * t[i + 16];
        }
        let mut o = ZERO;
        o.copy_from_slice(&t[..16]);
        carry(&mut o);
        carry(&mut o);
        o
    }

    /// Swaps `p` and `q` when `bit == 1`, leaves them unchanged when `bit == 0`.
    fn cswap(p: &mut Fe, q: &mut Fe, bit: i64) {
        let mask = !(bit - 1);
        for (x, y) in p.iter_mut().zip(q.iter_mut()) {
            let t = mask & (*x ^ *y);
            *x ^= t;
            *y ^= t;
        }
    }

    /// Computes `x^(p - 2)`, the inverse of `x` by Fermat's little theorem.
    ///
    /// `p - 2 = 2^255 - 21` has every bit in 0..=254 set except bits 2 and 4.
    fn invert(x: &Fe) -> Fe {
        let mut c = *x;
        for bit in (0..=253).rev() {
            c = mul(&c, &c);
            if bit != 2 && bit != 4 {
                c = mul(&c, x);
            }
        }
        c
    }

    /// Decodes a little-endian u-coordinate, ignoring the top bit (RFC 7748 §5).
    fn unpack(bytes: &[u8; 32]) -> Fe {
        let mut o = ZERO;
        for (limb, pair) in o.iter_mut().zip(bytes.chunks_exact(2)) {
            *limb = i64::from(pair[0]) | (i64::from(pair[1]) << 8);
        }
        o[15] &= 0x7fff;
        o
    }

    /// Fully reduces `n` modulo p and encodes it as 32 little-endian bytes.
    fn pack(n: &Fe) -> [u8; 32] {
        let mut t = *n;
        carry(&mut t);
        carry(&mut t);
        carry(&mut t);
        // Conditionally subtract p = 2^255 - 19 twice; keep the difference when no borrow.
        for _ in 0..2 {
            let mut m = ZERO;
            m[0] = t[0] - 0xffed;
            for i in 1..15 {
                m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1);
                m[i - 1] &= 0xffff;
            }
            m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1);
            let borrow = (m[15] >> 16) & 1;
            m[14] &= 0xffff;
            cswap(&mut t, &mut m, 1 - borrow);
        }
        let mut out = [0u8; 32];
        for (pair, &limb) in out.chunks_exact_mut(2).zip(t.iter()) {
            pair[0] = (limb & 0xff) as u8;
            pair[1] = ((limb >> 8) & 0xff) as u8;
        }
        out
    }
}
/// Condenses `salt` and input keying material into a 32-byte pseudorandom key.
///
/// This mirrors the shape of HKDF-Extract, but the mixing is a simple byte-wise
/// XOR/add/rotate scheme rather than HMAC, so it provides no cryptographic guarantees.
/// Empty `salt` and `ikm` yield 32 zero bytes.
pub fn hkdf_extract(salt: &[u8], ikm: &[u8]) -> Vec<u8> {
    // Folds `input` into `acc`: XOR each byte at its position, and add `spread(byte)`
    // at `position + offset` (both wrapping around the 32-byte state).
    fn absorb(acc: &mut [u8; 32], input: &[u8], offset: usize, spread: fn(u8) -> u8) {
        for (idx, &byte) in input.iter().enumerate() {
            acc[idx % 32] ^= byte;
            let target = (idx + offset) % 32;
            acc[target] = acc[target].wrapping_add(spread(byte));
        }
    }

    let mut acc = [0u8; 32];
    absorb(&mut acc, salt, 1, |b| b.rotate_left(3));
    absorb(&mut acc, ikm, 3, |b| b.rotate_right(5));

    // In-place diffusion pass; later positions see already-updated neighbours.
    for pos in 0..32 {
        let partner = acc[(pos + 7) % 32];
        acc[pos] = acc[pos].wrapping_add(partner).rotate_left(1);
    }
    acc.to_vec()
}
/// Expands a pseudorandom key into `length` bytes of output keying material.
///
/// Mirrors the shape of HKDF-Expand: output is produced in 32-byte blocks
/// `T(n) = F(prk, T(n-1) || info || n)` and truncated to `length`. `F` is a small byte-mixing
/// function, not HMAC, so this is not a cryptographic KDF.
///
/// Because each block depends only on the inputs and the block index, a shorter request is
/// always a prefix of a longer one with the same `prk` and `info`.
///
/// Returns `vec![0; length]` when `length` is 0 or `prk` is empty. NOTE(review): the
/// all-zero output for an empty PRK is predictable key material; callers should never
/// pass an empty PRK.
///
/// The block counter is a `u8` and wraps after 255 blocks. RFC 5869 caps output at
/// 255 * HashLen, while this function keeps producing bytes past that point.
pub fn hkdf_expand(prk: &[u8], info: &[u8], length: usize) -> Vec<u8> {
    if length == 0 || prk.is_empty() {
        return vec![0u8; length];
    }
    let mut output = Vec::with_capacity(length);
    let mut previous: Vec<u8> = Vec::new();
    let mut counter: u8 = 1;
    while output.len() < length {
        let mut state = [0u8; 32];

        // Absorb the whole PRK first. The block loop below only reaches
        // PRK bytes at indices < block length, so for a short `info` most of
        // the key would otherwise never influence the first block.
        for (i, &k) in prk.iter().enumerate() {
            state[i % 32] ^= k;
            state[(i + 1) % 32] = state[(i + 1) % 32].wrapping_add(k.rotate_left(4));
        }

        // Block input is T(n-1) || info || n, streamed without an intermediate buffer.
        let block = previous
            .iter()
            .chain(info.iter())
            .chain(std::iter::once(&counter));
        for (i, &b) in block.enumerate() {
            state[i % 32] ^= b ^ prk[i % prk.len()];
            state[(i + 1) % 32] = state[(i + 1) % 32].wrapping_add(b.rotate_left(2));
        }

        // In-place diffusion pass (each step is invertible, so distinct states stay distinct).
        for i in 0..32 {
            state[i] = state[i].wrapping_add(state[(i + 5) % 32]).rotate_right(3);
        }

        output.extend_from_slice(&state);
        previous = state.to_vec();
        counter = counter.wrapping_add(1);
    }
    output.truncate(length);
    output
}
/// Fills 32 bytes from a xorshift64 generator (shift triple 13, 7, 17) seeded with `seed`.
///
/// Zero is a fixed point of xorshift, so a zero seed is replaced with a fixed non-zero
/// constant. Each 8-byte word holds the next generator state in little-endian order.
/// Deterministic and not cryptographically secure.
fn xorshift_bytes(seed: u64) -> [u8; 32] {
    let mut x = match seed {
        0 => 0x123456789ABCDEF0u64,
        nonzero => nonzero,
    };
    let mut out = [0u8; 32];
    for word in out.chunks_exact_mut(8) {
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        word.copy_from_slice(&x.to_le_bytes());
    }
    out
}
/// Encodes `bytes` as base58 using the Bitcoin alphabet.
///
/// Each leading zero byte becomes a leading `'1'`; the remaining big-endian value is written
/// most-significant digit first. An empty slice encodes to an empty string.
fn base58_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8] = b"123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";

    // Little-endian base-58 digits of the number, built by repeated multiply-by-256-and-add.
    let mut le_digits: Vec<u8> = Vec::new();
    for &byte in bytes {
        let mut carry = u32::from(byte);
        for digit in le_digits.iter_mut() {
            carry += u32::from(*digit) << 8;
            *digit = (carry % 58) as u8;
            carry /= 58;
        }
        while carry != 0 {
            le_digits.push((carry % 58) as u8);
            carry /= 58;
        }
    }

    let zero_prefix = bytes
        .iter()
        .position(|&b| b != 0)
        .unwrap_or(bytes.len());

    std::iter::repeat('1')
        .take(zero_prefix)
        .chain(
            le_digits
                .iter()
                .rev()
                .map(|&d| char::from(ALPHABET[usize::from(d)])),
        )
        .collect()
}
#[cfg(test)]
mod tests {
    //! Unit tests for the DH demo, the 32-byte key helper, the HKDF-shaped functions and
    //! base58. Known-answer values for the p = 23 group are small enough to verify by hand.
    use super::*;

    #[test]
    fn test_standard_params() {
        let p = DhParams::standard();
        assert_eq!(p.prime, 23);
        assert_eq!(p.generator, 5);
    }

    #[test]
    fn test_large_params() {
        let p = DhParams::large();
        assert_eq!(p.prime, 2_147_483_647);
        assert_eq!(p.generator, 7);
    }

    #[test]
    fn test_modpow_basic() {
        // 5^6 = 15625 = 679 * 23 + 8
        assert_eq!(KeyAgreementProtocol::modpow(5, 6, 23), 8);
    }

    #[test]
    fn test_modpow_zero_exp() {
        assert_eq!(KeyAgreementProtocol::modpow(5, 0, 23), 1);
    }

    #[test]
    fn test_modpow_exp_one() {
        assert_eq!(KeyAgreementProtocol::modpow(5, 1, 23), 5);
    }

    #[test]
    fn test_modpow_modulus_one() {
        assert_eq!(KeyAgreementProtocol::modpow(100, 50, 1), 0);
    }

    #[test]
    fn test_modpow_large() {
        let result = KeyAgreementProtocol::modpow(7, 1_000_000, 2_147_483_647);
        assert!(result < 2_147_483_647);
    }

    #[test]
    fn test_generate_keypair_public_in_range() {
        let protocol = KeyAgreementProtocol::new(DhParams::standard());
        let kp = protocol.generate_keypair(6);
        assert!(kp.public_key < 23);
    }

    #[test]
    fn test_generate_keypair_private_preserved() {
        let protocol = KeyAgreementProtocol::new(DhParams::standard());
        let kp = protocol.generate_keypair(15);
        assert_eq!(kp.private_key, 15);
    }

    #[test]
    fn test_shared_secret_symmetry_standard() {
        let protocol = KeyAgreementProtocol::new(DhParams::standard());
        let alice = protocol.generate_keypair(6);
        let bob = protocol.generate_keypair(15);
        // 5^6 mod 23 = 8 and 5^15 mod 23 = 19.
        assert_eq!(alice.public_key, 8);
        assert_eq!(bob.public_key, 19);
        let alice_ss = protocol.compute_shared_secret(alice.private_key, bob.public_key);
        let bob_ss = protocol.compute_shared_secret(bob.private_key, alice.public_key);
        assert_eq!(alice_ss.value, bob_ss.value);
        // 19^6 mod 23 = 8^15 mod 23 = 2.
        assert_eq!(alice_ss.value, 2);
    }

    #[test]
    fn test_shared_secret_symmetry_large_params() {
        let protocol = KeyAgreementProtocol::new(DhParams::large());
        let alice = protocol.generate_keypair(123_456_789);
        let bob = protocol.generate_keypair(987_654_321);
        let alice_ss = protocol.compute_shared_secret(alice.private_key, bob.public_key);
        let bob_ss = protocol.compute_shared_secret(bob.private_key, alice.public_key);
        assert_eq!(alice_ss.value, bob_ss.value);
    }

    #[test]
    fn test_shared_secret_different_private_keys_differ() {
        let protocol = KeyAgreementProtocol::new(DhParams::standard());
        let alice = protocol.generate_keypair(6);
        let carol = protocol.generate_keypair(3);
        let mallory = protocol.generate_keypair(11);
        // Carol's public value is 5^3 mod 23 = 10.
        assert_eq!(carol.public_key, 10);
        let alice_ss = protocol.compute_shared_secret(alice.private_key, carol.public_key);
        let wrong_ss = protocol.compute_shared_secret(mallory.private_key, carol.public_key);
        // 10^6 mod 23 = 6, while 10^11 mod 23 = 22.
        assert_eq!(alice_ss.value, 6);
        assert_eq!(wrong_ss.value, 22);
        assert_ne!(alice_ss, wrong_ss);
    }

    #[test]
    fn test_key_to_did_key_format_starts_with_prefix() {
        let uri = KeyAgreementProtocol::key_to_did_key_format(8);
        assert!(uri.starts_with("did:key:z"), "URI = {}", uri);
    }

    #[test]
    fn test_key_to_did_key_format_zero() {
        let uri = KeyAgreementProtocol::key_to_did_key_format(0);
        assert!(uri.starts_with("did:key:z"), "URI for 0 = {}", uri);
        assert_eq!(uri, "did:key:z1");
    }

    #[test]
    fn test_key_to_did_key_format_large_value() {
        let uri = KeyAgreementProtocol::key_to_did_key_format(u64::MAX);
        assert!(uri.starts_with("did:key:z"), "URI = {}", uri);
    }

    #[test]
    fn test_key_to_did_key_format_different_values_differ() {
        let u1 = KeyAgreementProtocol::key_to_did_key_format(8);
        let u2 = KeyAgreementProtocol::key_to_did_key_format(9);
        assert_ne!(u1, u2);
    }

    #[test]
    fn test_key_to_did_key_format_known_values() {
        assert_eq!(KeyAgreementProtocol::key_to_did_key_format(8), "did:key:z9");
        assert_eq!(KeyAgreementProtocol::key_to_did_key_format(9), "did:key:zA");
        // 256 -> significant bytes [1, 0] -> base58 "5R".
        assert_eq!(KeyAgreementProtocol::key_to_did_key_format(256), "did:key:z5R");
    }

    #[test]
    fn test_ecdh_generate_keypair_length() {
        let mut ecdh = EcdhKeyAgreement::new("X25519");
        let kp = ecdh.generate_keypair();
        assert_eq!(kp.private_bytes.len(), 32);
        assert_eq!(kp.public_bytes.len(), 32);
    }

    #[test]
    fn test_ecdh_successive_keypairs_differ() {
        let mut ecdh = EcdhKeyAgreement::new("X25519");
        let kp1 = ecdh.generate_keypair();
        let kp2 = ecdh.generate_keypair();
        assert_ne!(kp1.public_bytes, kp2.public_bytes);
    }

    #[test]
    fn test_ecdh_derive_shared_secret_length() {
        let mut ecdh = EcdhKeyAgreement::new("X25519");
        let alice = ecdh.generate_keypair();
        let bob = ecdh.generate_keypair();
        let shared = ecdh.derive_shared_secret(&alice, &bob.public_bytes);
        assert_eq!(shared.len(), 32);
    }

    #[test]
    fn test_ecdh_derive_shared_secret_symmetry() {
        let mut ecdh = EcdhKeyAgreement::new("P-256");
        let alice = ecdh.generate_keypair();
        let bob = ecdh.generate_keypair();
        let s1 = ecdh.derive_shared_secret(&alice, &bob.public_bytes);
        let s2 = ecdh.derive_shared_secret(&bob, &alice.public_bytes);
        assert_eq!(s1.len(), 32);
        assert_eq!(s2.len(), 32);
    }

    #[test]
    fn test_ecdh_curve_name_stored() {
        let ecdh = EcdhKeyAgreement::new("P-256");
        assert_eq!(ecdh.curve, "P-256");
    }

    #[test]
    fn test_hkdf_extract_output_length() {
        let prk = hkdf_extract(b"salt", b"input_key_material");
        assert_eq!(prk.len(), 32);
    }

    #[test]
    fn test_hkdf_extract_different_salts_differ() {
        let prk1 = hkdf_extract(b"salt1", b"ikm");
        let prk2 = hkdf_extract(b"salt2", b"ikm");
        assert_ne!(prk1, prk2);
    }

    #[test]
    fn test_hkdf_extract_different_ikm_differ() {
        let prk1 = hkdf_extract(b"salt", b"ikm1");
        let prk2 = hkdf_extract(b"salt", b"ikm2");
        assert_ne!(prk1, prk2);
    }

    #[test]
    fn test_hkdf_extract_deterministic() {
        let p1 = hkdf_extract(b"s", b"k");
        let p2 = hkdf_extract(b"s", b"k");
        assert_eq!(p1, p2);
    }

    #[test]
    fn test_hkdf_extract_empty_inputs() {
        let prk = hkdf_extract(b"", b"");
        assert_eq!(prk.len(), 32);
    }

    #[test]
    fn test_hkdf_expand_correct_length() {
        let prk = hkdf_extract(b"salt", b"ikm");
        let okm = hkdf_expand(&prk, b"info", 64);
        assert_eq!(okm.len(), 64);
    }

    #[test]
    fn test_hkdf_expand_zero_length() {
        let prk = hkdf_extract(b"salt", b"ikm");
        let okm = hkdf_expand(&prk, b"info", 0);
        assert!(okm.is_empty());
    }

    #[test]
    fn test_hkdf_expand_different_info_differ() {
        let prk = hkdf_extract(b"salt", b"ikm");
        let okm1 = hkdf_expand(&prk, b"info1", 32);
        let okm2 = hkdf_expand(&prk, b"info2", 32);
        assert_ne!(okm1, okm2);
    }

    #[test]
    fn test_hkdf_expand_deterministic() {
        let prk = hkdf_extract(b"salt", b"ikm");
        let o1 = hkdf_expand(&prk, b"context", 48);
        let o2 = hkdf_expand(&prk, b"context", 48);
        assert_eq!(o1, o2);
    }

    #[test]
    fn test_hkdf_full_pipeline() {
        let prk = hkdf_extract(b"shared_secret_salt", b"dh_shared_value");
        let session_key = hkdf_expand(&prk, b"session_key_v1", 32);
        assert_eq!(session_key.len(), 32);
        assert!(session_key.iter().any(|&b| b != 0));
    }

    #[test]
    fn test_xorshift_zero_seed_is_substituted() {
        assert_eq!(xorshift_bytes(0), xorshift_bytes(0x123456789ABCDEF0));
        assert_ne!(xorshift_bytes(0), [0u8; 32]);
    }

    #[test]
    fn test_base58_encode_non_empty() {
        let encoded = base58_encode(&[1, 2, 3, 4]);
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_base58_encode_only_alphabet_chars() {
        const ALPHA: &str = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";
        let encoded = base58_encode(&[255, 128, 64]);
        for ch in encoded.chars() {
            assert!(ALPHA.contains(ch), "Invalid base58 char: {}", ch);
        }
    }

    #[test]
    fn test_base58_encode_known_values() {
        assert_eq!(base58_encode(&[]), "");
        assert_eq!(base58_encode(&[0]), "1");
        assert_eq!(base58_encode(&[57]), "z");
        assert_eq!(base58_encode(&[58]), "21");
        assert_eq!(base58_encode(&[0, 1]), "12");
        assert_eq!(base58_encode(&[1, 0]), "5R");
    }
}