use num_bigint::BigUint;
use super::{
ec::{pad_to_32_bytes, EcPoint, PKey, Private, Public, P256_FIELD_SIZE},
group_p256,
};
use num_traits::Zero;
/// ECDH key-agreement helper for the NIST P-256 curve.
///
/// Holds the local private scalar and, once [`Deriver::set_peer`] has succeeded, the peer's
/// public point encoded as `x || y`.
pub struct Deriver {
    /// Private scalar bytes (the first 32 bytes of the private key data).
    private_key: Vec<u8>,
    /// Peer public point coordinates, `x || y`; `None` until `set_peer` succeeds.
    peer_key: Option<Vec<u8>>,
}

// Hand-written instead of derived so that `{:?}` never prints the private scalar: a derived
// `Debug` would leak secret key material into logs, panic messages and `unwrap` output.
impl std::fmt::Debug for Deriver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Deriver")
            .field("private_key", &"<redacted>")
            .field("peer_key", &self.peer_key)
            .finish()
    }
}
impl Deriver {
    /// Builds a deriver from a P-256 private key.
    ///
    /// `private_key.key_data` must be exactly `3 * P256_FIELD_SIZE` bytes laid out as
    /// `scalar || x || y`. Only the scalar is retained. The `(x, y)` coordinates are used
    /// solely to check, via `check_point`, that the key material describes a point on the
    /// curve.
    ///
    /// # Errors
    ///
    /// Returns `ece::Error::CryptoError` if the key data has the wrong length or the
    /// embedded public point fails validation.
    pub fn new(private_key: &PKey<Private>) -> Result<Self, ece::Error> {
        let key_data: &[u8] = &private_key.key_data;
        if key_data.len() != 3 * P256_FIELD_SIZE {
            println!("Invalid key length: {}", key_data.len());
            return Err(ece::Error::CryptoError);
        }
        let (scalar, public) = key_data.split_at(P256_FIELD_SIZE);
        let (x_bytes, y_bytes) = public.split_at(P256_FIELD_SIZE);
        let x = BigUint::from_bytes_be(x_bytes);
        let y = BigUint::from_bytes_be(y_bytes);
        let group = group_p256();
        if !group.check_point(&x, &y) {
            println!("Point validation failed in Deriver::new");
            return Err(ece::Error::CryptoError);
        }
        Ok(Deriver {
            private_key: scalar.to_vec(),
            peer_key: None,
        })
    }

    /// Validates and stores the peer's public key for a later [`Deriver::derive_to_vec`].
    ///
    /// `peer_key.key_data` must be exactly `2 * P256_FIELD_SIZE` bytes laid out as `x || y`.
    /// The point must pass `check_point` and must not be the all-zero encoding. On failure
    /// any previously stored peer key is left untouched.
    ///
    /// # Errors
    ///
    /// Returns `ece::Error::CryptoError` if the length is wrong or the point is invalid.
    pub fn set_peer(&mut self, peer_key: &PKey<Public>) -> Result<(), ece::Error> {
        let key_data: &[u8] = &peer_key.key_data;
        if key_data.len() != 2 * P256_FIELD_SIZE {
            println!("peer_key.key_data.len() != 64");
            return Err(ece::Error::CryptoError);
        }
        let (x_bytes, y_bytes) = key_data.split_at(P256_FIELD_SIZE);
        let x = BigUint::from_bytes_be(x_bytes);
        let y = BigUint::from_bytes_be(y_bytes);
        let group = group_p256();
        if !group.check_point(&x, &y) {
            println!("Invalid point: curve equation failed");
            return Err(ece::Error::CryptoError);
        }
        // (0, 0) is used as an encoding of the point at infinity; reject it explicitly.
        if x.is_zero() && y.is_zero() {
            println!("Invalid point: infinity point");
            return Err(ece::Error::CryptoError);
        }
        self.peer_key = Some(peer_key.key_data.clone());
        Ok(())
    }

    /// Computes the ECDH shared secret with the peer set by [`Deriver::set_peer`].
    ///
    /// The secret is the affine x-coordinate of `d·Q`, where `d` is the private scalar and
    /// `Q` is the peer point. It is reduced mod `p` and encoded with `pad_to_32_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `ece::Error::CryptoError` in any of these cases:
    /// - no peer key has been set;
    /// - the private scalar is not in `[1, n-1]`;
    /// - the peer point is invalid;
    /// - the product is the point at infinity;
    /// - the encoded secret does not have the expected length.
    pub fn derive_to_vec(&self) -> Result<Vec<u8>, ece::Error> {
        let peer_key = self.peer_key.as_ref().ok_or(ece::Error::CryptoError)?;
        let group = group_p256();
        let p = group.prime();
        let n = group.order();

        // A valid scalar lies in [1, n-1]; zero would only ever yield the point at infinity.
        let priv_key = BigUint::from_bytes_be(&self.private_key);
        if priv_key.is_zero() || priv_key >= *n {
            println!("Private key out of range");
            return Err(ece::Error::CryptoError);
        }

        // Re-validate defensively before performing scalar multiplication with the point.
        let (x_bytes, y_bytes) = peer_key.split_at(P256_FIELD_SIZE);
        let peer_x = BigUint::from_bytes_be(x_bytes);
        let peer_y = BigUint::from_bytes_be(y_bytes);
        if !group.check_point(&peer_x, &peer_y) {
            println!("Invalid peer point");
            return Err(ece::Error::CryptoError);
        }

        let peer_point =
            EcPoint::new(group, peer_x, peer_y).map_err(|_| ece::Error::CryptoError)?;
        let product = peer_point
            .scalar_mul(&priv_key)
            .map_err(|_| ece::Error::CryptoError)?;
        if product.z.is_zero() {
            println!("Result is point at infinity");
            return Err(ece::Error::CryptoError);
        }

        // Reducing mod p already guarantees shared_x < p, so no further range check is needed.
        let (shared_x, _) = product.get_affine().map_err(|_| ece::Error::CryptoError)?;
        let shared_x = shared_x % p;

        let secret = pad_to_32_bytes(&shared_x).map_err(|_| ece::Error::CryptoError)?;
        if secret.len() != P256_FIELD_SIZE {
            println!("Invalid result length");
            return Err(ece::Error::CryptoError);
        }
        Ok(secret)
    }
}