use crate::errors::*;
use ring::agreement;
use ring::rand::{SystemRandom, SecureRandom};
use ring::{hmac, hkdf, aead};
use untrusted;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
pub type IdentityKey = [u8; 32];
pub type IdentityKeyPrivate = [u8; 32];
pub type SignedPreKeyId = u32;
pub type SignedPreKeyPublic = Vec<u8>;
pub type SignedPreKeyPrivate = Vec<u8>;
pub type SignedPreKeySignature = Vec<u8>;
pub type OneTimeKeyId = u32;
pub type OneTimeKeyPublic = Vec<u8>;
pub type OneTimeKeyPrivate = Vec<u8>;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdentityKeyPair {
pub public_key: IdentityKey,
pub private_key: IdentityKeyPrivate,
}
impl IdentityKeyPair {
pub fn generate() -> Result<Self> {
let mut public_key = [0u8; 32];
let mut private_key = [0u8; 32];
SystemRandom::new().fill(&mut public_key)
.map_err(|_| "Failed to generate public key")?;
SystemRandom::new().fill(&mut private_key)
.map_err(|_| "Failed to generate private key")?;
Ok(IdentityKeyPair { public_key, private_key })
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SignedPreKeyPair {
pub key_id: SignedPreKeyId,
pub public_key: SignedPreKeyPublic,
pub private_key: SignedPreKeyPrivate,
pub signature: SignedPreKeySignature,
pub timestamp: u64,
}
impl SignedPreKeyPair {
pub fn generate(key_id: SignedPreKeyId) -> Result<Self> {
let (private_key, public_key) = crate::crypto::generate_keypair()?;
let mut private_bytes = vec![0u8; private_key.public_key_len()];
private_key.compute_public_key(&mut private_bytes)?;
let signature = vec![0u8; 64];
Ok(SignedPreKeyPair {
key_id,
public_key,
private_key: private_bytes,
signature,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OneTimeKeyPair {
pub key_id: OneTimeKeyId,
pub public_key: OneTimeKeyPublic,
pub private_key: OneTimeKeyPrivate,
}
impl OneTimeKeyPair {
pub fn generate(key_id: OneTimeKeyId) -> Result<Self> {
let (private_key, public_key) = crate::crypto::generate_keypair()?;
let mut private_bytes = vec![0u8; private_key.public_key_len()];
private_key.compute_public_key(&mut private_bytes)?;
Ok(OneTimeKeyPair {
key_id,
public_key,
private_key: private_bytes,
})
}
}
pub type RegistrationId = u32;
pub fn generate_registration_id() -> RegistrationId {
let mut rng = SystemRandom::new();
let mut bytes = [0u8; 4];
rng.fill(&mut bytes).unwrap();
u32::from_le_bytes(bytes) & 0x3FFF }
pub struct SessionCipher {
registration_id: RegistrationId,
identity_key_pair: IdentityKeyPair,
signed_pre_key_pair: SignedPreKeyPair,
one_time_keys: std::collections::HashMap<OneTimeKeyId, OneTimeKeyPair>,
}
impl SessionCipher {
pub fn new(
registration_id: RegistrationId,
identity_key_pair: IdentityKeyPair,
signed_pre_key_pair: SignedPreKeyPair,
) -> Self {
SessionCipher {
registration_id,
identity_key_pair,
signed_pre_key_pair,
one_time_keys: std::collections::HashMap::new(),
}
}
pub fn add_one_time_key(&mut self, key: OneTimeKeyPair) {
self.one_time_keys.insert(key.key_id, key);
}
pub fn calculate_ratchet_keys(
&self,
their_identity_key: &[u8],
our_base_key: &[u8],
their_base_key: &[u8],
) -> Result<[u8; 64]> {
let shared_secret = crate::crypto::agree_on_shared_secret(
agreement::EphemeralPrivateKey::generate(&agreement::X25519, &SystemRandom::new())?,
their_identity_key
)?;
let mut key_material = [0u8; 64];
hkdf::extract_and_expand(
&hkdf::Salt::new(hkdf::HKDF_SHA256, b"WhisperText"),
&shared_secret,
&[], &mut key_material
);
Ok(key_material)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_identity_key_pair_generation() {
let identity_key_pair = IdentityKeyPair::generate().unwrap();
assert_eq!(identity_key_pair.public_key.len(), 32);
assert_eq!(identity_key_pair.private_key.len(), 32);
}
#[test]
fn test_signed_pre_key_generation() {
let signed_pre_key = SignedPreKeyPair::generate(1).unwrap();
assert_eq!(signed_pre_key.key_id, 1);
}
#[test]
fn test_one_time_key_generation() {
let one_time_key = OneTimeKeyPair::generate(1).unwrap();
assert_eq!(one_time_key.key_id, 1);
}
#[test]
fn test_registration_id_generation() {
let reg_id = generate_registration_id();
assert!(reg_id <= 0x3FFF); }
}