use std::ffi::CStr;
use zeroize::Zeroizing;
use crate::errors;
use crate::errors::{OlmPkDecryptionError, OlmPkEncryptionError, OlmPkSigningError};
use crate::{getrandom, PicklingMode};
/// The result of a public-key encryption: the ciphertext together with the
/// values a recipient needs in order to decrypt it.
///
/// Values of this type are returned by `OlmPkEncryption::encrypt` and consumed
/// by `OlmPkDecryption::decrypt`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PkMessage {
    /// The encrypted payload.
    pub ciphertext: String,
    /// The message authentication code that accompanies the ciphertext.
    pub mac: String,
    /// The ephemeral public key produced for this message during encryption.
    pub ephemeral_key: String,
}

impl PkMessage {
    /// Assembles a message from its three parts.
    ///
    /// Note that the argument order (`ephemeral_key`, `mac`, `ciphertext`) is
    /// the reverse of the order in which the fields are declared.
    pub fn new(ephemeral_key: String, mac: String, ciphertext: String) -> Self {
        Self {
            ciphertext,
            mac,
            ephemeral_key,
        }
    }
}
/// Wrapper around a libolm PK encryption object, used to encrypt messages for
/// a recipient's public key.
pub struct OlmPkEncryption {
/// Pointer returned by `olm_pk_encryption` for the memory held in `_buf`.
ptr: *mut olm_sys::OlmPkEncryption,
/// Backing allocation handed to `olm_pk_encryption`; it is never read
/// directly and is stored alongside `ptr` in the same struct.
_buf: Vec<u8>,
}
/// Clears the wrapped libolm encryption object when the wrapper is dropped.
impl Drop for OlmPkEncryption {
    fn drop(&mut self) {
        // SAFETY: `self.ptr` was obtained from `olm_pk_encryption` over
        // `self._buf`, which is still allocated while `drop` runs.
        unsafe { olm_sys::olm_clear_pk_encryption(self.ptr) };
    }
}
impl Default for OlmPkDecryption {
fn default() -> Self {
Self::new()
}
}
impl OlmPkEncryption {
    /// Creates an encryption object that encrypts for `recipient_key`.
    ///
    /// The key is handed to libolm as raw bytes via
    /// `olm_pk_encryption_set_recipient_key`.
    pub fn new(recipient_key: String) -> Self {
        let recipient_key = Zeroizing::from(recipient_key);

        // libolm constructs the object inside memory owned by the caller.
        let size = unsafe { olm_sys::olm_pk_encryption_size() };
        let mut buf = vec![0u8; size];
        let ptr = unsafe { olm_sys::olm_pk_encryption(buf.as_mut_ptr() as *mut _) };

        // NOTE(review): the status returned by
        // `olm_pk_encryption_set_recipient_key` is not checked, so a rejected
        // key goes unnoticed here. Reporting it would require this constructor
        // to become fallible - confirm whether that is wanted.
        unsafe {
            olm_sys::olm_pk_encryption_set_recipient_key(
                ptr,
                recipient_key.as_ptr() as *mut _,
                recipient_key.len(),
            );
        }

        Self { ptr, _buf: buf }
    }

    /// Fetches the last error libolm recorded for `ptr` and converts it into
    /// an [`OlmPkEncryptionError`].
    fn last_error(ptr: *mut olm_sys::OlmPkEncryption) -> OlmPkEncryptionError {
        let error = unsafe {
            let error_raw = olm_sys::olm_pk_encryption_last_error(ptr);
            CStr::from_ptr(error_raw).to_str().unwrap()
        };
        error.into()
    }

    /// Encrypts `plaintext` for the recipient key given to [`OlmPkEncryption::new`].
    ///
    /// Returns the ciphertext together with the MAC and the ephemeral key
    /// written by `olm_pk_encrypt`.
    ///
    /// # Panics
    ///
    /// Panics if one of the buffers filled by libolm is not valid UTF-8. A
    /// failure reported by libolm is passed to `errors::handle_fatal_error`.
    pub fn encrypt(&self, plaintext: &str) -> PkMessage {
        // Fresh randomness for this message; wiped again once it goes out of scope.
        let random_length = unsafe { olm_sys::olm_pk_encrypt_random_length(self.ptr) };
        let mut random_buf = Zeroizing::new(vec![0u8; random_length]);
        getrandom(&mut random_buf);

        // Ask libolm how large each output buffer has to be.
        let ciphertext_length =
            unsafe { olm_sys::olm_pk_ciphertext_length(self.ptr, plaintext.len()) };
        let mac_length = unsafe { olm_sys::olm_pk_mac_length(self.ptr) };
        let ephemeral_key_size = unsafe { olm_sys::olm_pk_key_length() };

        let mut ciphertext = vec![0u8; ciphertext_length];
        let mut mac = vec![0u8; mac_length];
        let mut ephemeral_key = vec![0u8; ephemeral_key_size];

        let ret = unsafe {
            olm_sys::olm_pk_encrypt(
                self.ptr,
                plaintext.as_ptr() as *const _,
                plaintext.len(),
                ciphertext.as_mut_ptr() as *mut _,
                ciphertext.len(),
                mac.as_mut_ptr() as *mut _,
                mac.len(),
                ephemeral_key.as_mut_ptr() as *mut _,
                ephemeral_key.len(),
                random_buf.as_ptr() as *mut _,
                random_buf.len(),
            )
        };

        if ret == errors::olm_error() {
            errors::handle_fatal_error(OlmPkEncryption::last_error(self.ptr));
        }

        // Checked conversions, like everywhere else in this file: bytes written
        // across the FFI boundary are validated instead of being trusted blindly.
        let ciphertext = String::from_utf8(ciphertext)
            .expect("Can't convert the ciphertext buffer to a string");
        let mac = String::from_utf8(mac).expect("Can't convert the MAC buffer to a string");
        let ephemeral_key = String::from_utf8(ephemeral_key)
            .expect("Can't convert the ephemeral key buffer to a string");

        PkMessage {
            ciphertext,
            mac,
            ephemeral_key,
        }
    }
}
/// Wrapper around a libolm PK decryption object, used to decrypt messages that
/// were encrypted for its public key.
pub struct OlmPkDecryption {
/// Pointer returned by `olm_pk_decryption` for the memory held in `_buf`.
ptr: *mut olm_sys::OlmPkDecryption,
/// Backing allocation handed to `olm_pk_decryption`; never read directly.
_buf: Vec<u8>,
/// Public key written by libolm when the key is generated or unpickled.
public_key: String,
}
/// Clears the wrapped libolm decryption object when the wrapper is dropped.
impl Drop for OlmPkDecryption {
    fn drop(&mut self) {
        // SAFETY: `self.ptr` was obtained from `olm_pk_decryption` over
        // `self._buf`, which is still allocated while `drop` runs.
        unsafe { olm_sys::olm_clear_pk_decryption(self.ptr) };
    }
}
impl OlmPkDecryption {
    /// Creates a decryption object with a freshly generated private key.
    ///
    /// # Panics
    ///
    /// Panics if the public key written by libolm is not valid UTF-8. A
    /// failure reported by libolm is passed to `errors::handle_fatal_error`.
    pub fn new() -> Self {
        let (ptr, buf) = Self::init();

        // Random bytes used as the private key; wiped when dropped.
        let random_len = unsafe { olm_sys::olm_pk_private_key_length() };
        let mut random_buf = Zeroizing::new(vec![0u8; random_len]);
        getrandom(&mut random_buf);

        let key_length = unsafe { olm_sys::olm_pk_key_length() };
        let mut key_buffer = vec![0u8; key_length];

        let ret = unsafe {
            olm_sys::olm_pk_key_from_private(
                ptr,
                key_buffer.as_mut_ptr() as *mut _,
                key_buffer.len(),
                random_buf.as_mut_ptr() as *mut _,
                random_buf.len(),
            )
        };

        if ret == errors::olm_error() {
            errors::handle_fatal_error(Self::last_error(ptr));
        }

        let public_key =
            String::from_utf8(key_buffer).expect("Can't convert the public key buffer to a string");

        Self {
            ptr,
            _buf: buf,
            public_key,
        }
    }

    /// Allocates the backing memory and lets libolm construct a decryption
    /// object inside it. Returns the object pointer together with the buffer.
    fn init() -> (*mut olm_sys::OlmPkDecryption, Vec<u8>) {
        let size = unsafe { olm_sys::olm_pk_decryption_size() };
        let mut buf = vec![0u8; size];
        let ptr = unsafe { olm_sys::olm_pk_decryption(buf.as_mut_ptr() as *mut _) };
        (ptr, buf)
    }

    /// Fetches the last error libolm recorded for `ptr` and converts it into
    /// an [`OlmPkDecryptionError`].
    fn last_error(ptr: *mut olm_sys::OlmPkDecryption) -> OlmPkDecryptionError {
        let error = unsafe {
            let error_raw = olm_sys::olm_pk_decryption_last_error(ptr);
            CStr::from_ptr(error_raw).to_str().unwrap()
        };
        error.into()
    }

    /// Serialises this object into a pickle string, using the key derived from
    /// `mode`.
    ///
    /// # Panics
    ///
    /// Panics if the pickle written by libolm is not valid UTF-8. A failure
    /// reported by libolm is passed to `errors::handle_fatal_error`.
    pub fn pickle(&self, mode: PicklingMode) -> String {
        let pickle_length = unsafe { olm_sys::olm_pickle_pk_decryption_length(self.ptr) };
        let mut pickled_buf = vec![0u8; pickle_length];

        let pickle_error = {
            // The pickling key only lives for the duration of the call and is
            // wiped afterwards.
            let key = Zeroizing::new(crate::convert_pickling_mode_to_key(mode));
            unsafe {
                olm_sys::olm_pickle_pk_decryption(
                    self.ptr,
                    key.as_ptr() as *const _,
                    key.len(),
                    pickled_buf.as_mut_ptr() as *mut _,
                    pickled_buf.len(),
                )
            }
        };

        // Inspect the status before touching the buffer: after a failure its
        // contents are meaningless and must not hide the real error behind a
        // UTF-8 panic.
        if pickle_error == errors::olm_error() {
            errors::handle_fatal_error(Self::last_error(self.ptr));
        }

        String::from_utf8(pickled_buf).expect("Pickle string is not valid utf-8")
    }

    /// Restores a decryption object from `pickle`, using the key derived from
    /// `mode`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by libolm when unpickling fails.
    ///
    /// # Panics
    ///
    /// Panics if the public key written by libolm is not valid UTF-8.
    pub fn unpickle(pickle: String, mode: PicklingMode) -> Result<Self, OlmPkDecryptionError> {
        // libolm receives the pickle as a mutable buffer and may overwrite it
        // (presumably with decoded key material). Work on plain bytes so no
        // `String` ever holds arbitrary data, and wipe the buffer on drop.
        let mut pickle = Zeroizing::new(pickle.into_bytes());

        let (ptr, buf) = Self::init();
        // Take ownership right away so that `Drop` clears the libolm object on
        // every exit path, including the error return below.
        let mut unpickled = Self {
            ptr,
            _buf: buf,
            public_key: String::new(),
        };

        // Same length query as in `new`: this buffer receives the public key
        // that belongs to this decryption object.
        let key_length = unsafe { olm_sys::olm_pk_key_length() };
        let mut key_buffer = vec![0u8; key_length];

        let unpickle_error = {
            let key = Zeroizing::new(crate::convert_pickling_mode_to_key(mode));
            unsafe {
                olm_sys::olm_unpickle_pk_decryption(
                    unpickled.ptr,
                    key.as_ptr() as *const _,
                    key.len(),
                    pickle.as_mut_ptr() as *mut _,
                    pickle.len(),
                    key_buffer.as_mut_ptr() as *mut _,
                    key_buffer.len(),
                )
            }
        };

        if unpickle_error == errors::olm_error() {
            return Err(Self::last_error(unpickled.ptr));
        }

        unpickled.public_key = String::from_utf8(key_buffer)
            .expect("Can't conver the public key buffer to a string");
        Ok(unpickled)
    }

    /// Decrypts `message` with the private key held by this object.
    ///
    /// Bytes of the decrypted payload that are not valid UTF-8 are replaced
    /// with U+FFFD.
    ///
    /// # Errors
    ///
    /// Returns the error reported by libolm when decryption fails.
    pub fn decrypt(&self, message: PkMessage) -> Result<String, OlmPkDecryptionError> {
        // The ciphertext is handed to libolm as a mutable buffer, so detach it
        // from the `String` first instead of letting C code write through it.
        let mut ciphertext = message.ciphertext.into_bytes();

        let max_plaintext =
            unsafe { olm_sys::olm_pk_max_plaintext_length(self.ptr, ciphertext.len()) };
        // Scratch space for the decrypted bytes; wiped once the owned `String`
        // has been produced.
        let mut plaintext = Zeroizing::new(vec![0u8; max_plaintext]);

        let plaintext_len = unsafe {
            olm_sys::olm_pk_decrypt(
                self.ptr,
                message.ephemeral_key.as_ptr() as *const _,
                message.ephemeral_key.len(),
                message.mac.as_ptr() as *const _,
                message.mac.len(),
                ciphertext.as_mut_ptr() as *mut _,
                ciphertext.len(),
                plaintext.as_mut_ptr() as *mut _,
                max_plaintext,
            )
        };

        if plaintext_len == errors::olm_error() {
            return Err(Self::last_error(self.ptr));
        }

        // libolm reports how many bytes of the buffer it actually used.
        plaintext.truncate(plaintext_len);
        Ok(String::from_utf8_lossy(&plaintext).into_owned())
    }

    /// Returns the public key that belongs to this decryption object.
    pub fn public_key(&self) -> &str {
        &self.public_key
    }
}
/// Wrapper around a libolm PK signing object, used to sign messages.
pub struct OlmPkSigning {
/// Pointer returned by `olm_pk_signing` for the memory held in `_buf`.
ptr: *mut olm_sys::OlmPkSigning,
/// Backing allocation handed to `olm_pk_signing`; never read directly.
_buf: Vec<u8>,
/// Public key written by libolm when the key is derived from the seed.
public_key: String,
}
/// Clears the wrapped libolm signing object when the wrapper is dropped.
impl Drop for OlmPkSigning {
    fn drop(&mut self) {
        // SAFETY: `self.ptr` was obtained from `olm_pk_signing` over
        // `self._buf`, which is still allocated while `drop` runs.
        unsafe {
            olm_sys::olm_clear_pk_signing(self.ptr);
        }
    }
}
impl OlmPkSigning {
    /// Creates a signing object whose key is derived from `seed`.
    ///
    /// The seed is wiped from memory before this function returns, whether it
    /// succeeds or fails.
    ///
    /// # Errors
    ///
    /// Returns [`OlmPkSigningError::InvalidSeed`] when `seed` does not have
    /// exactly [`OlmPkSigning::seed_length`] bytes, or the error reported by
    /// libolm when deriving the key fails.
    ///
    /// # Panics
    ///
    /// Panics if the public key written by libolm is not valid UTF-8.
    pub fn new(seed: Vec<u8>) -> Result<Self, OlmPkSigningError> {
        // The seed is private key material: make sure it is zeroed on every
        // exit path, including the early return for a wrong length.
        let mut seed = Zeroizing::new(seed);

        if seed.len() != Self::seed_length() {
            return Err(OlmPkSigningError::InvalidSeed);
        }

        let length = unsafe { olm_sys::olm_pk_signing_size() };
        let mut buffer = vec![0u8; length];
        let ptr = unsafe { olm_sys::olm_pk_signing(buffer.as_mut_ptr() as *mut _) };

        // Take ownership right away so that `Drop` clears the libolm object
        // even when deriving the key fails.
        let mut signing = Self {
            ptr,
            _buf: buffer,
            public_key: String::new(),
        };

        let pubkey_length = unsafe { olm_sys::olm_pk_signing_public_key_length() };
        let mut pubkey_buffer = vec![0u8; pubkey_length];

        let ret = unsafe {
            olm_sys::olm_pk_signing_key_from_seed(
                signing.ptr,
                pubkey_buffer.as_mut_ptr() as *mut _,
                pubkey_length,
                seed.as_mut_ptr() as *mut _,
                seed.len(),
            )
        };

        if ret == errors::olm_error() {
            return Err(Self::last_error(signing.ptr));
        }

        signing.public_key = String::from_utf8(pubkey_buffer)
            .expect("Can't conver the public key buffer to a string");
        Ok(signing)
    }

    /// Fetches the last error libolm recorded for `ptr` and converts it into
    /// an [`OlmPkSigningError`].
    fn last_error(ptr: *mut olm_sys::OlmPkSigning) -> OlmPkSigningError {
        let error = unsafe {
            let error_raw = olm_sys::olm_pk_signing_last_error(ptr);
            CStr::from_ptr(error_raw).to_str().unwrap()
        };
        error.into()
    }

    /// Returns the number of bytes a seed passed to [`OlmPkSigning::new`] must have.
    pub fn seed_length() -> usize {
        unsafe { olm_sys::olm_pk_signing_seed_length() }
    }

    /// Generates a random seed of [`OlmPkSigning::seed_length`] bytes.
    ///
    /// The intermediate buffer is wiped; the returned copy is owned by the
    /// caller, who is responsible for its lifetime.
    pub fn generate_seed() -> Vec<u8> {
        let length = Self::seed_length();
        let mut buffer = Zeroizing::new(vec![0u8; length]);
        getrandom(&mut buffer);
        buffer.to_vec()
    }

    /// Returns the public key that belongs to this signing object.
    pub fn public_key(&self) -> &str {
        &self.public_key
    }

    /// Signs `message` and returns the signature as a string.
    ///
    /// # Panics
    ///
    /// Panics if the signature written by libolm is not valid UTF-8. A failure
    /// reported by libolm is passed to `errors::handle_fatal_error`.
    pub fn sign(&self, message: &str) -> String {
        let signature_len = unsafe { olm_sys::olm_pk_signature_length() };
        let mut signature = vec![0u8; signature_len];

        let ret = unsafe {
            olm_sys::olm_pk_sign(
                self.ptr,
                message.as_ptr() as *mut _,
                message.len(),
                signature.as_mut_ptr() as *mut _,
                signature_len,
            )
        };

        if ret == errors::olm_error() {
            errors::handle_fatal_error(Self::last_error(self.ptr));
        }

        String::from_utf8(signature).expect("Can't conver the signature to a string")
    }
}
#[cfg(test)]
mod test {
    use crate::pk::{OlmPkDecryption, OlmPkEncryption, OlmPkSigning};
    use crate::utility::OlmUtility;
    use crate::PicklingMode;

    /// A generated seed has the right length, so building a signer succeeds.
    #[test]
    fn create_pk_sign() {
        let seed = OlmPkSigning::generate_seed();
        let signer = OlmPkSigning::new(seed);
        assert!(signer.is_ok());
    }

    /// Seeds that are empty, one byte short or one byte long are rejected.
    #[test]
    fn invalid_seed() {
        assert!(OlmPkSigning::new(Vec::new()).is_err());

        let too_short = OlmPkSigning::seed_length() - 1;
        let too_long = OlmPkSigning::seed_length() + 1;

        for &bad_len in &[too_short, too_long] {
            assert!(OlmPkSigning::new(vec![0; bad_len]).is_err());
        }
    }

    /// Two generated seeds differ from each other.
    #[test]
    fn seed_random() {
        let first = OlmPkSigning::generate_seed();
        let second = OlmPkSigning::generate_seed();

        assert_ne!(&first[..], &second[..]);
    }

    /// A signature verifies for the signed message and fails for another one.
    #[test]
    fn sign_a_message() {
        let message = "It's a secret to everyone";

        let signer = OlmPkSigning::new(OlmPkSigning::generate_seed()).unwrap();
        let verifier = OlmUtility::new();
        let signature = signer.sign(message);

        let genuine = verifier.ed25519_verify(signer.public_key(), message, &signature);
        assert!(genuine.is_ok());

        let forged = verifier.ed25519_verify(signer.public_key(), "Hello world", &signature);
        assert!(forged.is_err());
    }

    /// A message encrypted for a decryption object's key decrypts to the input.
    #[test]
    fn encrypt_a_message() {
        let message = "It's a secret to everyone";

        let receiver = OlmPkDecryption::new();
        let sender = OlmPkEncryption::new(receiver.public_key().to_owned());

        let encrypted = sender.encrypt(message);
        let decrypted = receiver.decrypt(encrypted).unwrap();

        assert_eq!(message, decrypted);
    }

    /// A decryption object restored from its pickle can still decrypt.
    #[test]
    fn pickle() {
        let message = "It's a secret to everyone";

        let original = OlmPkDecryption::new();
        let sender = OlmPkEncryption::new(original.public_key().to_owned());
        let encrypted = sender.encrypt(message);

        let pickled = original.pickle(PicklingMode::Unencrypted);
        let restored = OlmPkDecryption::unpickle(pickled, PicklingMode::Unencrypted).unwrap();

        let decrypted = restored.decrypt(encrypted).unwrap();
        assert_eq!(message, decrypted);
    }

    /// Unpickling with a different mode than was used for pickling fails.
    #[test]
    fn invalid_unpickle() {
        let decryption = OlmPkDecryption::new();
        let encrypted_mode = PicklingMode::Encrypted {
            key: Vec::from("wordpass"),
        };
        let pickled = decryption.pickle(encrypted_mode);

        let outcome = OlmPkDecryption::unpickle(pickled, PicklingMode::Unencrypted);
        assert!(outcome.is_err());
    }

    /// A message encrypted for someone else's key cannot be decrypted.
    #[test]
    fn invalid_decrypt() {
        let message = "It's a secret to everyone";

        let alice = OlmPkDecryption::new();
        let malory = OlmPkEncryption::new(OlmPkDecryption::new().public_key().to_owned());

        let encrypted = malory.encrypt(message);
        let outcome = alice.decrypt(encrypted);

        assert!(outcome.is_err());
    }
}