use crate::account::OlmAccount;
use crate::errors::{self, OlmSessionError};
use crate::getrandom;
use crate::{ByteBuf, PicklingMode};
use std::cmp::Ordering;
use std::convert::TryFrom;
use std::ffi::CStr;
use std::fmt;
use zeroize::Zeroizing;
/// A libolm session object stored in a buffer owned by this struct.
///
/// Values are produced by the `create_*_session*` constructors or by
/// [`OlmSession::unpickle`]; on drop the session is passed to
/// `olm_clear_session` before the buffer is released.
#[derive(Debug)]
pub struct OlmSession {
    /// Pointer returned by `olm_session` for the session placed inside
    /// `_olm_session_buf`.
    pub(crate) olm_session_ptr: *mut olm_sys::OlmSession,
    /// Backing storage for `olm_session_ptr`. It must stay alive as long as
    /// the pointer is used; it is not read directly in this file.
    _olm_session_buf: ByteBuf,
}
/// Ciphertext of an Olm message of type `OLM_MESSAGE_TYPE_MESSAGE`.
#[derive(Debug, Clone)]
pub struct Message(String);

impl Message {
    /// Wraps `ciphertext` as-is; no validation is performed.
    fn new(ciphertext: String) -> Self {
        Self(ciphertext)
    }
}

/// Ciphertext of an Olm message of type `OLM_MESSAGE_TYPE_PRE_KEY`, as used
/// when creating or matching inbound sessions.
#[derive(Debug, Clone)]
pub struct PreKeyMessage(String);

impl PreKeyMessage {
    /// Wraps `message` as-is; no validation is performed.
    fn new(message: String) -> Self {
        Self(message)
    }
}
/// An encrypted Olm message, tagged with its libolm message type.
#[derive(Debug, Clone)]
pub enum OlmMessage {
    /// A message of type `OLM_MESSAGE_TYPE_MESSAGE`.
    Message(Message),
    /// A message of type `OLM_MESSAGE_TYPE_PRE_KEY`.
    PreKey(PreKeyMessage),
}
/// Error returned by `OlmMessage::from_type_and_ciphertext` when the raw
/// message type matches neither of the two libolm message-type constants.
#[derive(Debug)]
pub struct UnknownOlmMessageType;

impl fmt::Display for UnknownOlmMessageType {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Unknown message type")
    }
}

impl std::error::Error for UnknownOlmMessageType {}
impl OlmMessage {
pub fn from_type_and_ciphertext(
message_type: usize,
ciphertext: String,
) -> Result<Self, UnknownOlmMessageType> {
match message_type {
olm_sys::OLM_MESSAGE_TYPE_PRE_KEY => {
Ok(OlmMessage::PreKey(PreKeyMessage::new(ciphertext)))
}
olm_sys::OLM_MESSAGE_TYPE_MESSAGE => Ok(OlmMessage::Message(Message::new(ciphertext))),
_ => Err(UnknownOlmMessageType),
}
}
#[allow(clippy::wrong_self_convention)]
pub fn to_tuple(self) -> (OlmMessageType, String) {
match self {
OlmMessage::Message(m) => (OlmMessageType::Message, m.0),
OlmMessage::PreKey(m) => (OlmMessageType::PreKey, m.0),
}
}
}
impl OlmSession {
pub(crate) fn create_inbound_session(
account: &OlmAccount,
mut message: PreKeyMessage,
) -> Result<Self, OlmSessionError> {
Self::create_session_with(|olm_session_ptr| unsafe {
let one_time_key_message_buf = message.0.as_bytes_mut();
olm_sys::olm_create_inbound_session(
olm_session_ptr,
account.olm_account_ptr,
one_time_key_message_buf.as_mut_ptr() as *mut _,
one_time_key_message_buf.len(),
)
})
}
pub(crate) fn create_inbound_session_from(
account: &OlmAccount,
their_identity_key: &str,
mut one_time_key_message: PreKeyMessage,
) -> Result<Self, OlmSessionError> {
Self::create_session_with(|olm_session_ptr| {
let their_identity_key_buf = their_identity_key.as_bytes();
unsafe {
let one_time_key_message_buf = one_time_key_message.0.as_bytes_mut();
olm_sys::olm_create_inbound_session_from(
olm_session_ptr,
account.olm_account_ptr,
their_identity_key_buf.as_ptr() as *const _,
their_identity_key_buf.len(),
one_time_key_message_buf.as_mut_ptr() as *mut _,
one_time_key_message_buf.len(),
)
}
})
}
pub(crate) fn create_outbound_session(
account: &OlmAccount,
their_identity_key: &str,
their_one_time_key: &str,
) -> Result<Self, OlmSessionError> {
Self::create_session_with(|olm_session_ptr| {
let their_identity_key_buf = their_identity_key.as_bytes();
let their_one_time_key_buf = their_one_time_key.as_bytes();
let random_len =
unsafe { olm_sys::olm_create_outbound_session_random_length(olm_session_ptr) };
let mut random_buf: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0; random_len]);
getrandom(&mut random_buf);
unsafe {
olm_sys::olm_create_outbound_session(
olm_session_ptr,
account.olm_account_ptr,
their_identity_key_buf.as_ptr() as *const _,
their_identity_key_buf.len(),
their_one_time_key_buf.as_ptr() as *const _,
their_one_time_key_buf.len(),
random_buf.as_mut_ptr() as *mut _,
random_len,
)
}
})
}
fn create_session_with<F: FnMut(*mut olm_sys::OlmSession) -> usize>(
mut f: F,
) -> Result<OlmSession, OlmSessionError> {
let mut olm_session_buf = ByteBuf::new(unsafe { olm_sys::olm_session_size() });
let olm_session_ptr = unsafe { olm_sys::olm_session(olm_session_buf.as_mut_void_ptr()) };
let error = f(olm_session_ptr);
if error == errors::olm_error() {
let last_error = Self::last_error(olm_session_ptr);
if last_error == OlmSessionError::NotEnoughRandom {
errors::handle_fatal_error(OlmSessionError::NotEnoughRandom);
}
Err(last_error)
} else {
Ok(OlmSession {
olm_session_ptr,
_olm_session_buf: olm_session_buf,
})
}
}
fn last_error(session_ptr: *mut olm_sys::OlmSession) -> OlmSessionError {
let error_raw = unsafe { olm_sys::olm_session_last_error(session_ptr) };
let error = unsafe { CStr::from_ptr(error_raw).to_str().unwrap() };
match error {
"BAD_ACCOUNT_KEY" => OlmSessionError::BadAccountKey,
"BAD_MESSAGE_MAC" => OlmSessionError::BadMessageMac,
"BAD_MESSAGE_FORMAT" => OlmSessionError::BadMessageFormat,
"BAD_MESSAGE_KEY_ID" => OlmSessionError::BadMessageKeyId,
"BAD_MESSAGE_VERSION" => OlmSessionError::BadMessageVersion,
"INVALID_BASE64" => OlmSessionError::InvalidBase64,
"NOT_ENOUGH_RANDOM" => OlmSessionError::NotEnoughRandom,
"OUTPUT_BUFFER_TOO_SMALL" => OlmSessionError::OutputBufferTooSmall,
_ => OlmSessionError::Unknown,
}
}
pub fn session_id(&self) -> String {
let session_id_len = unsafe { olm_sys::olm_session_id_length(self.olm_session_ptr) };
let mut session_id_buf: Vec<u8> = vec![0; session_id_len];
let error = unsafe {
olm_sys::olm_session_id(
self.olm_session_ptr,
session_id_buf.as_mut_ptr() as *mut _,
session_id_len,
)
};
let session_id_result = String::from_utf8(session_id_buf).unwrap();
if error == errors::olm_error() {
errors::handle_fatal_error(Self::last_error(self.olm_session_ptr));
}
session_id_result
}
pub fn pickle(&self, mode: PicklingMode) -> String {
let pickled_len = unsafe { olm_sys::olm_pickle_session_length(self.olm_session_ptr) };
let mut pickled_buf = vec![0; pickled_len];
let pickle_error = {
let key = Zeroizing::new(crate::convert_pickling_mode_to_key(mode));
unsafe {
olm_sys::olm_pickle_session(
self.olm_session_ptr,
key.as_ptr() as *const _,
key.len(),
pickled_buf.as_mut_ptr() as *mut _,
pickled_len,
)
}
};
let pickled_result = String::from_utf8(pickled_buf).unwrap();
if pickle_error == errors::olm_error() {
errors::handle_fatal_error(Self::last_error(self.olm_session_ptr));
}
pickled_result
}
pub fn unpickle(mut pickled: String, mode: PicklingMode) -> Result<Self, OlmSessionError> {
let key = Zeroizing::new(crate::convert_pickling_mode_to_key(mode));
Self::create_session_with(|olm_session_ptr| {
let pickled_len = pickled.len();
unsafe {
let pickled_buf = pickled.as_bytes_mut();
olm_sys::olm_unpickle_session(
olm_session_ptr,
key.as_ptr() as *const _,
key.len(),
pickled_buf.as_mut_ptr() as *mut _,
pickled_len,
)
}
})
}
pub fn encrypt(&self, plaintext: &str) -> OlmMessage {
let plaintext_buf = plaintext.as_bytes();
let plaintext_len = plaintext_buf.len();
let message_len =
unsafe { olm_sys::olm_encrypt_message_length(self.olm_session_ptr, plaintext_len) };
let mut message_buf: Vec<u8> = vec![0; message_len];
let message_type = self.encrypt_message_type();
let encrypt_error = {
let random_len = unsafe { olm_sys::olm_encrypt_random_length(self.olm_session_ptr) };
let mut random_buf: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0; random_len]);
getrandom(&mut random_buf);
unsafe {
olm_sys::olm_encrypt(
self.olm_session_ptr,
plaintext_buf.as_ptr() as *const _,
plaintext_len,
random_buf.as_mut_ptr() as *mut _,
random_len,
message_buf.as_mut_ptr() as *mut _,
message_len,
)
}
};
let message_result = String::from_utf8(message_buf).unwrap();
if encrypt_error == errors::olm_error() {
errors::handle_fatal_error(Self::last_error(self.olm_session_ptr));
}
match message_type {
OlmMessageType::Message => OlmMessage::Message(Message::new(message_result)),
OlmMessageType::PreKey => OlmMessage::PreKey(PreKeyMessage::new(message_result)),
}
}
pub fn decrypt(&self, message: OlmMessage) -> Result<String, OlmSessionError> {
let (message_type, mut ciphertext) = message.to_tuple();
let message_type_val = match message_type {
OlmMessageType::PreKey => olm_sys::OLM_MESSAGE_TYPE_PRE_KEY,
_ => olm_sys::OLM_MESSAGE_TYPE_MESSAGE,
};
let mut message_for_len = ciphertext.to_owned();
let message_buf = unsafe { message_for_len.as_bytes_mut() };
let message_len = message_buf.len();
let message_ptr = message_buf.as_mut_ptr() as *mut _;
let plaintext_max_len = unsafe {
olm_sys::olm_decrypt_max_plaintext_length(
self.olm_session_ptr,
message_type_val,
message_ptr,
message_len,
)
};
if plaintext_max_len == errors::olm_error() {
return Err(Self::last_error(self.olm_session_ptr));
}
let mut plaintext_buf = Zeroizing::new(vec![0; plaintext_max_len]);
let message_buf = unsafe { ciphertext.as_bytes_mut() };
let message_len = message_buf.len();
let message_ptr = message_buf.as_mut_ptr() as *mut _;
let plaintext_result_len = unsafe {
olm_sys::olm_decrypt(
self.olm_session_ptr,
message_type_val,
message_ptr,
message_len,
plaintext_buf.as_mut_ptr() as *mut _,
plaintext_max_len,
)
};
let decrypt_error = plaintext_result_len;
if decrypt_error == errors::olm_error() {
let last_error = Self::last_error(self.olm_session_ptr);
if last_error == OlmSessionError::OutputBufferTooSmall {
errors::handle_fatal_error(OlmSessionError::OutputBufferTooSmall);
}
return Err(last_error);
}
plaintext_buf.truncate(plaintext_result_len);
Ok(String::from_utf8_lossy(&plaintext_buf).to_string())
}
pub(crate) fn encrypt_message_type(&self) -> OlmMessageType {
let message_type_result =
unsafe { olm_sys::olm_encrypt_message_type(self.olm_session_ptr) };
let message_type_error = message_type_result;
if message_type_error == errors::olm_error() {
errors::handle_fatal_error(Self::last_error(self.olm_session_ptr));
}
match message_type_result {
olm_sys::OLM_MESSAGE_TYPE_PRE_KEY => OlmMessageType::PreKey,
_ => OlmMessageType::Message,
}
}
pub fn has_received_message(&self) -> bool {
0 != unsafe { olm_sys::olm_session_has_received_message(self.olm_session_ptr) }
}
pub fn matches_inbound_session(
&self,
mut message: PreKeyMessage,
) -> Result<bool, OlmSessionError> {
let matches_result = unsafe {
let one_time_key_message_buf = message.0.as_bytes_mut();
olm_sys::olm_matches_inbound_session(
self.olm_session_ptr,
one_time_key_message_buf.as_mut_ptr() as *mut _,
one_time_key_message_buf.len(),
)
};
let matches_error = matches_result;
if matches_error == errors::olm_error() {
Err(OlmSession::last_error(self.olm_session_ptr))
} else {
match matches_result {
0 => Ok(false),
1 => Ok(true),
_ => Err(OlmSessionError::Unknown),
}
}
}
pub fn matches_inbound_session_from(
&self,
their_identity_key: &str,
mut message: PreKeyMessage,
) -> Result<bool, OlmSessionError> {
let their_identity_key_buf = their_identity_key.as_bytes();
let their_identity_key_ptr = their_identity_key_buf.as_ptr() as *const _;
let matches_result = unsafe {
let one_time_key_message_buf = message.0.as_bytes_mut();
olm_sys::olm_matches_inbound_session_from(
self.olm_session_ptr,
their_identity_key_ptr,
their_identity_key_buf.len(),
one_time_key_message_buf.as_mut_ptr() as *mut _,
one_time_key_message_buf.len(),
)
};
let matches_error = matches_result;
if matches_error == errors::olm_error() {
Err(OlmSession::last_error(self.olm_session_ptr))
} else {
match matches_result {
0 => Ok(false),
1 => Ok(true),
_ => Err(OlmSessionError::Unknown),
}
}
}
}
/// The two kinds of Olm message this crate distinguishes.
///
/// Converts to and from libolm's raw `usize` message-type constants through
/// `From<OlmMessageType> for usize` and `TryFrom<usize>`. `Eq` and `Hash` are
/// derived so the type can serve as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OlmMessageType {
    /// A pre-key message (`OLM_MESSAGE_TYPE_PRE_KEY`).
    PreKey,
    /// A normal message (`OLM_MESSAGE_TYPE_MESSAGE`).
    Message,
}
/// Converts an [`OlmMessageType`] into libolm's raw message-type constant.
impl From<OlmMessageType> for usize {
    fn from(kind: OlmMessageType) -> Self {
        match kind {
            OlmMessageType::Message => olm_sys::OLM_MESSAGE_TYPE_MESSAGE,
            OlmMessageType::PreKey => olm_sys::OLM_MESSAGE_TYPE_PRE_KEY,
        }
    }
}

/// Parses libolm's raw message-type constant; any other value yields `Err(())`.
impl TryFrom<usize> for OlmMessageType {
    type Error = ();
    fn try_from(raw: usize) -> Result<Self, Self::Error> {
        if raw == olm_sys::OLM_MESSAGE_TYPE_PRE_KEY {
            Ok(Self::PreKey)
        } else if raw == olm_sys::OLM_MESSAGE_TYPE_MESSAGE {
            Ok(Self::Message)
        } else {
            Err(())
        }
    }
}
/// Sessions are ordered by their session ID (see `OlmSession::session_id`).
impl Ord for OlmSession {
    fn cmp(&self, other: &Self) -> Ordering {
        let own_id = self.session_id();
        let other_id = other.session_id();
        own_id.cmp(&other_id)
    }
}

/// Delegates to the total order defined by `Ord`.
impl PartialOrd for OlmSession {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Two sessions are equal exactly when their session IDs are equal.
impl PartialEq for OlmSession {
    fn eq(&self, other: &Self) -> bool {
        let own_id = self.session_id();
        let other_id = other.session_id();
        own_id == other_id
    }
}

/// String equality on session IDs is reflexive, so `Eq` holds.
impl Eq for OlmSession {}
/// Passes the session to `olm_clear_session` before its backing buffer is freed.
impl Drop for OlmSession {
    fn drop(&mut self) {
        // SAFETY: `olm_session_ptr` points into `_olm_session_buf` (see
        // `create_session_with`), and that buffer is still alive here:
        // struct fields are only dropped after `Drop::drop` returns.
        // The return value of `olm_clear_session` is ignored.
        unsafe {
            olm_sys::olm_clear_session(self.olm_session_ptr);
        }
    }
}
#[cfg(test)]
mod test {
    use crate::account::OlmAccount;
    use crate::session::OlmMessageType;

    /// A freshly created outbound session has not received anything yet, so
    /// it reports pre-key as its next message type.
    #[test]
    fn message_type() {
        let alice = OlmAccount::new();
        let bob = OlmAccount::new();
        alice.generate_one_time_keys(1);
        // Outbound sessions are established against the peer's Curve25519
        // identity key; the Ed25519 key is a signing key and is the wrong
        // input for `create_outbound_session`.
        let identity_key = alice.parsed_identity_keys().curve25519().to_owned();
        let one_time_key = alice
            .parsed_one_time_keys()
            .curve25519()
            .values()
            .next()
            .unwrap()
            .to_owned();
        let outbound_session = bob
            .create_outbound_session(&identity_key, &one_time_key)
            .unwrap();
        assert_eq!(
            OlmMessageType::PreKey,
            outbound_session.encrypt_message_type()
        );
        assert!(!outbound_session.has_received_message());
    }
}