use std::ffi::CStr;
use zeroize::Zeroizing;
use crate::errors::{self, OlmSasError};
use crate::getrandom;
use crate::ByteBuf;
/// Wrapper around a libolm `OlmSAS` object, used for SAS (short authentication
/// string) verification between two parties.
///
/// The libolm object is cleared via `olm_clear_sas` when this value is dropped.
pub struct OlmSas {
/// Raw pointer returned by `olm_sys::olm_sas`; handed to every libolm SAS call.
sas_ptr: *mut olm_sys::OlmSAS,
/// Buffer that was passed to `olm_sys::olm_sas` at construction. It is never
/// read directly (hence the leading underscore); presumably it is held so the
/// memory behind `sas_ptr` stays allocated as long as this struct lives —
/// TODO confirm against `ByteBuf`.
_sas_buf: ByteBuf,
/// Becomes `true` once `set_their_public_key` has succeeded. `generate_bytes`
/// and `calculate_mac` return `OlmSasError::OtherPublicKeyUnset` while it is
/// still `false`.
public_key_set: bool,
}
impl Drop for OlmSas {
    /// Asks libolm to clear the SAS object when the wrapper goes out of scope.
    fn drop(&mut self) {
        let sas = self.sas_ptr;
        // The value returned by `olm_clear_sas` is intentionally discarded:
        // nothing useful can be done with it while dropping.
        unsafe {
            olm_sys::olm_clear_sas(sas);
        }
    }
}
impl Default for OlmSas {
fn default() -> Self {
Self::new()
}
}
impl OlmSas {
    /// Creates a new SAS object.
    ///
    /// Allocates a buffer of `olm_sas_size()` bytes, lets libolm initialise an
    /// `OlmSAS` inside it, and seeds it with as many random bytes as
    /// `olm_create_sas_random_length` asks for.
    ///
    /// If `olm_create_sas` reports failure, the error is handed to
    /// `errors::handle_fatal_error`.
    pub fn new() -> Self {
        let mut sas_buf = ByteBuf::new(unsafe { olm_sys::olm_sas_size() });
        let ptr = unsafe { olm_sys::olm_sas(sas_buf.as_mut_void_ptr()) };

        let random_len = unsafe { olm_sys::olm_create_sas_random_length(ptr) };
        // `Zeroizing` wipes the random seed material when the buffer is dropped.
        let mut random_buf: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0; random_len]);
        getrandom(&mut random_buf);

        let ret =
            unsafe { olm_sys::olm_create_sas(ptr, random_buf.as_mut_ptr() as *mut _, random_len) };
        if ret == errors::olm_error() {
            errors::handle_fatal_error(Self::last_error(ptr));
        }

        Self {
            sas_ptr: ptr,
            _sas_buf: sas_buf,
            public_key_set: false,
        }
    }

    /// Returns this side's public key as a string.
    ///
    /// If `olm_sas_get_pubkey` reports failure, the error is handed to
    /// `errors::handle_fatal_error`.
    pub fn public_key(&self) -> String {
        let pubkey_length = unsafe { olm_sys::olm_sas_pubkey_length(self.sas_ptr) };
        let mut buffer: Vec<u8> = vec![0; pubkey_length];

        let ret = unsafe {
            olm_sys::olm_sas_get_pubkey(self.sas_ptr, buffer.as_mut_ptr() as *mut _, pubkey_length)
        };
        if ret == errors::olm_error() {
            errors::handle_fatal_error(Self::last_error(self.sas_ptr));
        }

        // SAFETY: relies on libolm writing the key base64-encoded (ASCII), as its
        // header documents; the bytes are not re-validated here.
        unsafe { String::from_utf8_unchecked(buffer) }
    }

    /// Translates libolm's last error string for `sas_ptr` into an `OlmSasError`.
    ///
    /// Unrecognised strings, and strings that are not valid UTF-8, map to
    /// `OlmSasError::Unknown` instead of panicking.
    fn last_error(sas_ptr: *mut olm_sys::OlmSAS) -> OlmSasError {
        let error = unsafe {
            let error_raw = olm_sys::olm_sas_last_error(sas_ptr);
            // An empty string matches none of the arms below and thus yields `Unknown`.
            CStr::from_ptr(error_raw).to_str().unwrap_or("")
        };

        match error {
            "NOT_ENOUGH_RANDOM" => OlmSasError::NotEnoughRandom,
            "OUTPUT_BUFFER_TOO_SMALL" => OlmSasError::OutputBufferTooSmall,
            // NOTE(review): an *input* buffer error is reported as
            // `OutputBufferTooSmall`. This looks like a copy-paste slip; if
            // `OlmSasError` has a dedicated input-buffer variant, use it here.
            "INPUT_BUFFER_TOO_SMALL" => OlmSasError::OutputBufferTooSmall,
            _ => OlmSasError::Unknown,
        }
    }

    /// Registers the other party's public key.
    ///
    /// On success the object is marked as ready, so `generate_bytes` and
    /// `calculate_mac` stop returning `OlmSasError::OtherPublicKeyUnset`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by libolm if `olm_sas_set_their_key` fails.
    pub fn set_their_public_key(&mut self, public_key: String) -> Result<(), OlmSasError> {
        // libolm overwrites the key buffer while decoding it, so it must be handed
        // a genuinely mutable byte buffer: writing through `String::as_ptr()` is
        // undefined behaviour and could leave non-UTF-8 bytes inside a `String`.
        let mut key_bytes = public_key.into_bytes();

        let ret = unsafe {
            olm_sys::olm_sas_set_their_key(
                self.sas_ptr,
                key_bytes.as_mut_ptr() as *mut _,
                key_bytes.len(),
            )
        };

        if ret == errors::olm_error() {
            Err(Self::last_error(self.sas_ptr))
        } else {
            self.public_key_set = true;
            Ok(())
        }
    }

    /// Asks libolm to generate `length` bytes for the SAS, using `extra_info` as
    /// additional input.
    ///
    /// # Errors
    ///
    /// * `OlmSasError::OtherPublicKeyUnset` if `set_their_public_key` has not
    ///   succeeded yet.
    /// * `OlmSasError::InvalidLength` if `length` is zero.
    /// * The error reported by libolm if `olm_sas_generate_bytes` fails.
    pub fn generate_bytes(&self, extra_info: &str, length: usize) -> Result<Vec<u8>, OlmSasError> {
        if !self.public_key_set {
            return Err(OlmSasError::OtherPublicKeyUnset);
        }
        if length == 0 {
            return Err(OlmSasError::InvalidLength);
        }

        let mut out_buffer = vec![0; length];
        let ret = unsafe {
            olm_sys::olm_sas_generate_bytes(
                self.sas_ptr,
                extra_info.as_ptr() as *mut _,
                extra_info.len(),
                out_buffer.as_mut_ptr() as *mut _,
                length,
            )
        };

        if ret == errors::olm_error() {
            Err(Self::last_error(self.sas_ptr))
        } else {
            Ok(out_buffer)
        }
    }

    /// Computes a MAC over `message`, using `extra_info` as additional input, and
    /// returns it as a string.
    ///
    /// # Errors
    ///
    /// * `OlmSasError::OtherPublicKeyUnset` if `set_their_public_key` has not
    ///   succeeded yet.
    /// * The error reported by libolm if `olm_sas_calculate_mac` fails.
    pub fn calculate_mac(&self, message: &str, extra_info: &str) -> Result<String, OlmSasError> {
        if !self.public_key_set {
            return Err(OlmSasError::OtherPublicKeyUnset);
        }

        let mac_length = unsafe { olm_sys::olm_sas_mac_length(self.sas_ptr) };
        let mut mac_buffer = vec![0; mac_length];

        let ret = unsafe {
            olm_sys::olm_sas_calculate_mac(
                self.sas_ptr,
                message.as_ptr() as *mut _,
                message.len(),
                extra_info.as_ptr() as *mut _,
                extra_info.len(),
                mac_buffer.as_mut_ptr() as *mut _,
                mac_length,
            )
        };

        if ret == errors::olm_error() {
            Err(Self::last_error(self.sas_ptr))
        } else {
            // SAFETY: relies on libolm writing the MAC base64-encoded (ASCII), as
            // its header documents; the bytes are not re-validated here.
            Ok(unsafe { String::from_utf8_unchecked(mac_buffer) })
        }
    }
}
#[cfg(test)]
mod test {
    use crate::sas::OlmSas;

    /// Hands each side the other's public key and asserts both calls succeed.
    fn exchange_keys(first: &mut OlmSas, second: &mut OlmSas) {
        let second_key = second.public_key();
        assert!(first.set_their_public_key(second_key).is_ok());
        let first_key = first.public_key();
        assert!(second.set_their_public_key(first_key).is_ok());
    }

    /// A freshly created object exposes a non-empty public key.
    #[test]
    fn test_creation() {
        let sas = OlmSas::new();
        let key = sas.public_key();
        assert!(!key.is_empty());
    }

    /// A well-formed key is accepted, an empty one is rejected.
    #[test]
    fn test_set_pubkey() {
        let mut sas = OlmSas::new();
        let own_key = sas.public_key();
        assert!(sas.set_their_public_key(own_key).is_ok());
        assert!(sas.set_their_public_key(String::from("")).is_err());
    }

    /// Both sides derive the same bytes for the same info, and different bytes
    /// for different info; generation fails before the keys are exchanged.
    #[test]
    fn test_generate_bytes() {
        let mut alice = OlmSas::new();
        let mut bob = OlmSas::new();

        assert!(alice.generate_bytes("", 5).is_err());

        exchange_keys(&mut alice, &mut bob);

        let alice_bytes = alice.generate_bytes("", 5).unwrap();
        let bob_bytes = bob.generate_bytes("", 5).unwrap();
        assert_eq!(alice_bytes, bob_bytes);

        let alice_other = alice.generate_bytes("fake", 5).unwrap();
        let bob_again = bob.generate_bytes("", 5).unwrap();
        assert_ne!(alice_other, bob_again);
    }

    /// Both sides compute the same MAC for the same message, and different MACs
    /// for different messages; calculation fails before the keys are exchanged.
    #[test]
    fn test_calculate_mac() {
        let mut alice = OlmSas::new();
        let mut bob = OlmSas::new();
        let message = String::from("It's a secret to everyone");

        assert!(alice.calculate_mac(&message, "").is_err());

        exchange_keys(&mut alice, &mut bob);

        let alice_mac = alice.calculate_mac(&message, "").unwrap();
        let bob_mac = bob.calculate_mac(&message, "").unwrap();
        assert_eq!(alice_mac, bob_mac);

        let alice_other = alice.calculate_mac("fake", "").unwrap();
        let bob_again = bob.calculate_mac(&message, "").unwrap();
        assert_ne!(alice_other, bob_again);
    }
}