use num_bigint::BigInt;
use rsa::pkcs8::EncodePublicKey;
use rsa::rand_core::OsRng;
use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey};
use sha1::Digest;
use sha1::Sha1;
/// An RSA key pair owned by the server, stored alongside a pre-computed DER
/// encoding of its public half.
///
/// Instances are created with [`ServerKeyPair::generate`].
pub struct ServerKeyPair {
    /// Private key used by [`ServerKeyPair::decrypt`].
    private_key: RsaPrivateKey,
    /// DER bytes of the matching public key, as produced by
    /// `EncodePublicKey::to_public_key_der` when the pair is generated.
    public_key_der: Vec<u8>,
}
impl ServerKeyPair {
    /// Bit size requested from `RsaPrivateKey::new` when generating a pair.
    // NOTE(review): presumably dictated by the client protocol — confirm before changing.
    const KEY_BITS: usize = 1024;

    /// Generates a fresh RSA key pair and caches the DER encoding of its public key.
    ///
    /// Randomness is drawn from [`OsRng`].
    ///
    /// # Errors
    ///
    /// Returns an error if RSA key generation fails or if the public key
    /// cannot be encoded as DER.
    pub fn generate() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let mut rng = OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, Self::KEY_BITS)?;
        // Encode the public half once up front so the getter can hand out a borrow.
        let public_key_der = RsaPublicKey::from(&private_key)
            .to_public_key_der()?
            .to_vec();
        Ok(Self {
            private_key,
            public_key_der,
        })
    }

    /// Returns the DER-encoded public key computed in [`ServerKeyPair::generate`].
    #[must_use]
    pub fn public_key_der(&self) -> &[u8] {
        self.public_key_der.as_slice()
    }

    /// Decrypts `ciphertext` with the private key using PKCS#1 v1.5 padding.
    ///
    /// # Errors
    ///
    /// Propagates the [`rsa::Error`] reported by the underlying decryption.
    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, rsa::Error> {
        let key = &self.private_key;
        key.decrypt(Pkcs1v15Encrypt, ciphertext)
    }
}
/// Computes the server-id hash string from a shared secret and a DER public key.
///
/// The SHA-1 digest of `shared_secret` followed by `public_key_der` is
/// interpreted as a signed (two's-complement) big-endian integer and rendered
/// in base 16. As a consequence the result carries a leading `-` when the
/// digest's top bit is set, and leading zero digits are not emitted, so the
/// string may be shorter than 40 characters.
// NOTE(review): this looks like the Minecraft-style login hash — confirm against the caller.
#[must_use]
pub fn generate_server_id(shared_secret: &[u8], public_key_der: &[u8]) -> String {
    let mut sha = Sha1::new();
    // Order matters: the secret is hashed before the public key.
    for part in [shared_secret, public_key_der] {
        sha.update(part);
    }
    let digest = sha.finalize();
    BigInt::from_signed_bytes_be(&digest).to_str_radix(16)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated pair exposes a non-trivial DER-encoded public key.
    #[test]
    fn test_generate_key_pair() {
        let kp = ServerKeyPair::generate().unwrap();
        assert!(!kp.public_key_der().is_empty());
        assert!(kp.public_key_der().len() > 100);
    }

    /// Data encrypted with the public half decrypts back to the original bytes.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let kp = ServerKeyPair::generate().unwrap();
        let plaintext = b"Hello, Minecraft!";
        let public_key = RsaPublicKey::from(&kp.private_key);
        let ciphertext = public_key
            .encrypt(&mut OsRng, Pkcs1v15Encrypt, plaintext)
            .unwrap();
        let decrypted = kp.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// The output is a hex string with at most one sign, and only at the front.
    #[test]
    fn test_server_id_hash_twos_complement() {
        let hash = generate_server_id(&[0xFF; 16], &[0xFF; 16]);
        // A `-` is only legal as the very first character.
        let digits = hash.strip_prefix('-').unwrap_or(&hash);
        assert!(!digits.is_empty());
        assert!(digits.chars().all(|c| c.is_ascii_hexdigit()));
        // The hash is a pure function of its inputs.
        assert_eq!(hash, generate_server_id(&[0xFF; 16], &[0xFF; 16]));
    }

    /// Reference digests published in the Minecraft protocol encryption
    /// documentation; they cover the positive, negative and leading-zero cases.
    #[test]
    fn test_server_id_hash_known_vectors() {
        // Top bit clear: plain 40-digit hex.
        assert_eq!(
            generate_server_id(b"Notch", b""),
            "4ed1f46bbe04bc756bcb17c0c7ce3e4632f06a48"
        );
        // Top bit set: rendered as a negated value with a leading '-'.
        assert_eq!(
            generate_server_id(b"jeb_", b""),
            "-7c9d5b0044c130109a5d7b5fb5c317c02b4e28c1"
        );
        // Leading zero nibble is dropped, giving 39 digits.
        assert_eq!(
            generate_server_id(b"simon", b""),
            "88e16a1019277b15d58faf0541e11910eb756f6"
        );
        // Both arguments feed one hash stream, so the split point is irrelevant.
        assert_eq!(
            generate_server_id(b"je", b"b_"),
            "-7c9d5b0044c130109a5d7b5fb5c317c02b4e28c1"
        );
    }
}