use std::{fs, path::PathBuf};
use log::warn;
use openssl::{
pkey::{Private, Public},
rsa::{Padding, Rsa},
};
use crate::error::AetherError;
use home::home_dir;
/// Modulus size, in bits, passed to `Rsa::generate` by `Id::new` when creating a fresh key pair.
///
/// In this file it is only read at generation time; keys loaded from PEM keep whatever size
/// they were created with.
///
/// NOTE(review): 1024-bit RSA is below the 2048-bit minimum that current guidance (e.g. NIST
/// SP 800-131A) recommends for new keys. Raising it grows ciphertexts and signatures, so
/// confirm peers / protocol framing tolerate larger blocks before changing it.
pub const RSA_SIZE: u32 = 1024;
/// The local identity: an RSA private key, from which the public half is also derived.
#[derive(Debug, Clone)]
pub struct Id {
    // A private `Rsa` also carries the public components; `Id` exports its public key from
    // this same value (`public_key_to_pem` / `public_key_to_der`).
    rsa: Rsa<Private>,
}
/// An RSA public key on its own, e.g. one received as base64 via `PublicId::from_base64`.
#[derive(Debug, Clone)]
pub struct PublicId {
    // Public components only; no private-key operations are possible with this value.
    rsa: Rsa<Public>,
}
impl Id {
    /// Generates a fresh RSA key pair of [`RSA_SIZE`] bits.
    ///
    /// The key only lives in memory; call [`Id::save`] to persist it.
    ///
    /// # Errors
    /// Returns an error (converted from OpenSSL's) if key generation fails.
    pub fn new() -> Result<Id, AetherError> {
        Ok(Id {
            rsa: Rsa::generate(RSA_SIZE)?,
        })
    }

    /// Path of the PEM file holding the private key (`private_key.pem` in the config dir).
    ///
    /// May create the config directory as a side effect.
    pub fn get_private_key_path() -> PathBuf {
        Self::get_config_dir().join("private_key.pem")
    }

    /// Path of the PEM file holding the public key (`public_key.pem` in the config dir).
    ///
    /// May create the config directory as a side effect.
    pub fn get_public_key_path() -> PathBuf {
        Self::get_config_dir().join("public_key.pem")
    }

    /// Returns `<home>/.config/aether`, creating it if necessary.
    ///
    /// Falls back to the current directory (`./`) when the home directory is unknown or the
    /// config directory cannot be created. The fallback is deliberate best-effort behaviour,
    /// but it is logged so keys do not silently end up in an unexpected place.
    fn get_config_dir() -> PathBuf {
        let home = match home_dir() {
            Some(home) => home,
            None => {
                warn!("Unable to determine home directory; using ./ for keys");
                return PathBuf::from("./");
            }
        };
        let config = home.join(".config").join("aether");
        match fs::create_dir_all(&config) {
            Ok(()) => config,
            Err(err) => {
                warn!(
                    "Unable to create config directory {}: {}; using ./ for keys",
                    config.display(),
                    err
                );
                PathBuf::from("./")
            }
        }
    }

    /// Writes both halves of the key pair as PEM to [`Id::get_private_key_path`] and
    /// [`Id::get_public_key_path`], overwriting any existing files.
    ///
    /// On Unix the private key file is restricted to mode `0o600` (owner read/write only)
    /// before any key material is written, so other local users cannot read it.
    ///
    /// # Errors
    /// Returns an error converted from OpenSSL's if PEM serialization fails, or
    /// [`AetherError::FileWrite`] if either file cannot be written. The private key is written
    /// first, so a failure on the public key can leave only the private file updated.
    pub fn save(&self) -> Result<(), AetherError> {
        let rsa_public = self.rsa.public_key_to_pem()?;
        let rsa_private = self.rsa.private_key_to_pem()?;
        Self::write_private_pem(&Self::get_private_key_path(), &rsa_private)
            .map_err(AetherError::FileWrite)?;
        fs::write(Self::get_public_key_path(), rsa_public).map_err(AetherError::FileWrite)?;
        Ok(())
    }

    /// Writes secret PEM bytes to `path`, readable and writable by the owner only.
    #[cfg(unix)]
    fn write_private_pem(path: &std::path::Path, pem: &[u8]) -> std::io::Result<()> {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)?;
        // `mode` only applies when the file is newly created; a pre-existing key file may
        // still be group/world readable, so tighten it before writing the secret.
        file.set_permissions(fs::Permissions::from_mode(0o600))?;
        file.write_all(pem)
    }

    /// Writes secret PEM bytes to `path` (no portable permission control off Unix).
    #[cfg(not(unix))]
    fn write_private_pem(path: &std::path::Path, pem: &[u8]) -> std::io::Result<()> {
        fs::write(path, pem)
    }

    /// Loads the identity from the private key PEM at [`Id::get_private_key_path`].
    ///
    /// Only the private key file is read; the public key is derived from it.
    ///
    /// # Errors
    /// [`AetherError::FileRead`] if the file cannot be read, or an error converted from
    /// OpenSSL's if its contents are not a parseable private key PEM.
    pub fn load() -> Result<Id, AetherError> {
        let private_pem = fs::read(Self::get_private_key_path()).map_err(AetherError::FileRead)?;
        let rsa = Rsa::private_key_from_pem(&private_pem)?;
        Ok(Id { rsa })
    }

    /// Loads the saved identity, or generates and saves a new one if the key file cannot be
    /// read (e.g. on first run).
    ///
    /// Any other load error (such as an unparseable PEM) is returned as-is rather than being
    /// papered over by overwriting the existing file.
    ///
    /// # Errors
    /// Propagates non-read errors from [`Id::load`], and any error from [`Id::new`] or
    /// [`Id::save`] when a new identity has to be created.
    pub fn load_or_generate() -> Result<Id, AetherError> {
        match Self::load() {
            Err(AetherError::FileRead(err)) => {
                warn!("Unable to read key: {}", err);
                let new_id = Self::new()?;
                new_id.save()?;
                Ok(new_id)
            }
            other => other,
        }
    }

    /// Base64 of the DER-encoded public key; this is the format `PublicId::from_base64` parses.
    pub fn public_key_to_base64(&self) -> Result<String, AetherError> {
        let public_key_der = self.rsa.public_key_to_der()?;
        Ok(base64::encode(public_key_der))
    }

    /// Base64 of the DER-encoded private key.
    ///
    /// This exports secret key material; avoid logging or transmitting the result.
    pub fn private_key_to_base64(&self) -> Result<String, AetherError> {
        let private_key_der = self.rsa.private_key_to_der()?;
        Ok(base64::encode(private_key_der))
    }

    /// Encrypts `from` with this identity's public key using PKCS#1 v1.5 padding.
    ///
    /// # Errors
    /// Fails if OpenSSL rejects the operation (e.g. `from` is too long for the key size).
    pub fn public_encrypt(&self, from: &[u8]) -> Result<Vec<u8>, AetherError> {
        let mut buf: Vec<u8> = vec![0; self.rsa.size() as usize];
        let len = self.rsa.public_encrypt(from, &mut buf, Padding::PKCS1)?;
        // Keep only the bytes OpenSSL reports as written.
        buf.truncate(len);
        Ok(buf)
    }

    /// Raw private-key RSA operation over `from` with PKCS#1 v1.5 padding (no hashing).
    ///
    /// The tests in this module use it as a simple signature that `public_decrypt` recovers.
    ///
    /// # Errors
    /// Fails if OpenSSL rejects the operation (e.g. `from` is too long for the key size).
    pub fn private_encrypt(&self, from: &[u8]) -> Result<Vec<u8>, AetherError> {
        let mut buf: Vec<u8> = vec![0; self.rsa.size() as usize];
        let len = self.rsa.private_encrypt(from, &mut buf, Padding::PKCS1)?;
        buf.truncate(len);
        Ok(buf)
    }

    /// Recovers data produced by a matching `private_encrypt`, using the public key.
    ///
    /// # Errors
    /// Fails if `from` was not produced by the matching private key or is malformed.
    pub fn public_decrypt(&self, from: &[u8]) -> Result<Vec<u8>, AetherError> {
        let mut buf: Vec<u8> = vec![0; self.rsa.size() as usize];
        let len = self.rsa.public_decrypt(from, &mut buf, Padding::PKCS1)?;
        buf.truncate(len);
        Ok(buf)
    }

    /// Decrypts data produced by `public_encrypt` with this identity's public key.
    ///
    /// NOTE(review): PKCS#1 v1.5 encryption padding is exposed to Bleichenbacher-style
    /// padding-oracle attacks if decryption failures are observable by a peer; OAEP is the
    /// usual choice for new protocols, but switching would break existing peers.
    ///
    /// # Errors
    /// Fails if `from` is not a valid ciphertext for this key.
    pub fn private_decrypt(&self, from: &[u8]) -> Result<Vec<u8>, AetherError> {
        let mut buf: Vec<u8> = vec![0; self.rsa.size() as usize];
        let len = self.rsa.private_decrypt(from, &mut buf, Padding::PKCS1)?;
        buf.truncate(len);
        Ok(buf)
    }
}
impl PublicId {
    /// Parses a public key from base64-encoded DER, the format produced by
    /// `public_key_to_base64` on both [`Id`] and [`PublicId`].
    ///
    /// # Errors
    /// Fails if `key` is not valid base64 or the decoded bytes are not a public key OpenSSL
    /// can parse from DER.
    pub fn from_base64(key: &str) -> Result<PublicId, AetherError> {
        let der = base64::decode(key)?;
        let rsa = Rsa::public_key_from_der(&der)?;
        Ok(Self { rsa })
    }

    /// Base64 of the DER-encoded public key (round-trips through [`PublicId::from_base64`]).
    pub fn public_key_to_base64(&self) -> Result<String, AetherError> {
        let public_key_der = self.rsa.public_key_to_der()?;
        Ok(base64::encode(public_key_der))
    }

    /// Encrypts `from` to the holder of the matching private key, using PKCS#1 v1.5 padding.
    ///
    /// # Errors
    /// Fails if OpenSSL rejects the operation (e.g. `from` is too long for the key size).
    pub fn public_encrypt(&self, from: &[u8]) -> Result<Vec<u8>, AetherError> {
        let mut buf: Vec<u8> = vec![0; self.rsa.size() as usize];
        let len = self.rsa.public_encrypt(from, &mut buf, Padding::PKCS1)?;
        // Keep only the bytes OpenSSL reports as written.
        buf.truncate(len);
        Ok(buf)
    }

    /// Recovers data produced by the matching private key's `private_encrypt`.
    ///
    /// A successful result shows `from` was produced with the corresponding private key; the
    /// caller still has to compare the recovered bytes with what it expected.
    ///
    /// # Errors
    /// Fails if `from` was not produced by the matching private key or is malformed.
    pub fn public_decrypt(&self, from: &[u8]) -> Result<Vec<u8>, AetherError> {
        let mut buf: Vec<u8> = vec![0; self.rsa.size() as usize];
        let len = self.rsa.public_decrypt(from, &mut buf, Padding::PKCS1)?;
        buf.truncate(len);
        Ok(buf)
    }
}
#[cfg(test)]
mod tests {
    use std::{fs, io::ErrorKind, path::PathBuf};

    use crate::util::gen_nonce;
    use super::{Id, PublicId};

    /// Snapshots the real on-disk key files and puts them back when dropped.
    ///
    /// `Id::save` has no path parameter and writes to the user's actual key location, so
    /// without this guard `save_test` would permanently replace the developer's identity.
    /// Restoring from `Drop` also runs when an assertion in the test panics.
    struct KeyFilesGuard {
        snapshots: Vec<(PathBuf, Option<Vec<u8>>)>,
    }

    impl KeyFilesGuard {
        fn capture() -> Self {
            let snapshots = vec![Id::get_private_key_path(), Id::get_public_key_path()]
                .into_iter()
                .map(|path| {
                    let contents = match fs::read(&path) {
                        Ok(bytes) => Some(bytes),
                        // Absent before the test: remove it again afterwards.
                        Err(err) if err.kind() == ErrorKind::NotFound => None,
                        // Refuse to run rather than risk clobbering a file we could not back up.
                        Err(err) => panic!("cannot back up {}: {}", path.display(), err),
                    };
                    (path, contents)
                })
                .collect();
            KeyFilesGuard { snapshots }
        }
    }

    impl Drop for KeyFilesGuard {
        fn drop(&mut self) {
            for (path, contents) in &self.snapshots {
                // Best effort: a restore failure must not mask the test's own outcome.
                let _ = match contents {
                    Some(bytes) => fs::write(path, bytes),
                    None => fs::remove_file(path),
                };
            }
        }
    }

    #[test]
    fn save_test() {
        // Must stay bound (not `let _ =`) so the files are restored at the end of the test.
        let _restore = KeyFilesGuard::capture();
        let id = Id::new().unwrap();
        id.save().unwrap();
        let id_new = Id::load().unwrap();
        assert_eq!(
            id.public_key_to_base64().unwrap(),
            id_new.public_key_to_base64().unwrap()
        );
        assert_eq!(
            id.private_key_to_base64().unwrap(),
            id_new.private_key_to_base64().unwrap()
        );
    }

    #[test]
    fn encrypt_test() {
        let message = String::from("This is a small message");
        let message_bytes = message.as_bytes();
        let id = Id::new().unwrap();
        let message_encrypted = id.public_encrypt(message_bytes).unwrap();
        let message_decrypted = id.private_decrypt(&message_encrypted).unwrap();
        let message_out = String::from_utf8(message_decrypted).unwrap();
        assert_eq!(message, message_out);
    }

    #[test]
    fn signature_test() {
        let alice_id = Id::new().unwrap();
        let alice_public =
            PublicId::from_base64(&alice_id.public_key_to_base64().unwrap()).unwrap();
        let alice_message = "A message to be signed";
        let alice_message_signed = alice_id.private_encrypt(alice_message.as_bytes()).unwrap();
        let bob_decrypted_bytes = alice_public.public_decrypt(&alice_message_signed).unwrap();
        let bob_message = String::from_utf8(bob_decrypted_bytes).unwrap();
        assert_eq!(alice_message, bob_message);
    }

    #[test]
    fn authentication_test() {
        let alice_id = Id::new().unwrap();
        let alice_public =
            PublicId::from_base64(&alice_id.public_key_to_base64().unwrap()).unwrap();
        let bob_nonce = gen_nonce(32);
        let bob_challenge = alice_public.public_encrypt(&bob_nonce).unwrap();
        let alice_response = alice_id.private_decrypt(&bob_challenge).unwrap();
        println!(
            "{} == {}",
            base64::encode(&bob_nonce),
            base64::encode(&alice_response)
        );
        assert_eq!(bob_nonce, alice_response);
    }
}