mod chain_key;
mod double_ratchet;
pub mod message_key;
pub mod ratchet;
mod receiver_chain;
mod root_key;
use std::fmt::Debug;
use aes::cipher::block_padding::UnpadError;
use arrayvec::ArrayVec;
use chain_key::RemoteChainKey;
use double_ratchet::DoubleRatchet;
use hmac::digest::MacError;
use ratchet::RemoteRatchetKey;
use receiver_chain::ReceiverChain;
use root_key::RemoteRootKey;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use zeroize::Zeroize;
use super::{
session_config::Version,
session_keys::SessionKeys,
shared_secret::{RemoteShared3DHSecret, Shared3DHSecret},
SessionConfig,
};
#[cfg(feature = "low-level-api")]
use crate::hazmat::olm::MessageKey;
use crate::{
olm::messages::{Message, OlmMessage, PreKeyMessage},
utilities::{pickle, unpickle},
Curve25519PublicKey, PickleError,
};
const MAX_RECEIVING_CHAINS: usize = 5;
/// Errors that can occur while decrypting an Olm message.
#[derive(Error, Debug)]
pub enum DecryptionError {
    /// The message authentication code did not verify.
    #[error("Failed decrypting Olm message, invalid MAC: {0}")]
    InvalidMAC(#[from] MacError),
    /// The MAC had an unexpected length; the fields are the expected and the
    /// actual length, in that order.
    #[error("Failed decrypting Olm message, invalid MAC length: expected {0}, got {1}")]
    InvalidMACLength(usize, usize),
    /// Removing the block-cipher padding from the decrypted data failed.
    #[error("Failed decrypting Olm message, invalid padding")]
    InvalidPadding(#[from] UnpadError),
    /// No message key could be produced for the message index carried in the field.
    #[error("The message key with the given key can't be created, message index: {0}")]
    MissingMessageKey(u64),
    /// The message was too far ahead in its chain; the fields are the gap that
    /// was encountered and the maximum gap allowed.
    #[error("The message gap was too big, got {0}, max allowed {1}")]
    TooBigMessageGap(u64, u64),
}
/// Bounded collection of a session's receiver chains, holding at most
/// `MAX_RECEIVING_CHAINS` of them.
#[derive(Serialize, Deserialize, Clone)]
struct ChainStore {
    // Fixed-capacity storage, kept in push order; `ChainStore::push` evicts
    // the element at index 0 when the store is full.
    inner: ArrayVec<ReceiverChain, MAX_RECEIVING_CHAINS>,
}
impl ChainStore {
    /// Returns a store that contains no receiver chains.
    fn new() -> Self {
        let inner = ArrayVec::new();
        Self { inner }
    }

    /// Appends `ratchet` as the most recently added chain.
    ///
    /// When the store already holds `MAX_RECEIVING_CHAINS` chains, the chain at
    /// the front (the earliest one pushed) is dropped first to make room.
    fn push(&mut self, ratchet: ReceiverChain) {
        let no_room_left = self.inner.is_full();
        if no_room_left {
            let _evicted = self.inner.pop_at(0);
        }
        self.inner.push(ratchet);
    }

    /// Returns `true` while no receiver chain has been stored.
    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Number of stored receiver chains (test-only helper).
    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns the chain at position `index`, if there is one.
    #[cfg(feature = "libolm-compat")]
    pub fn get(&self, index: usize) -> Option<&ReceiverChain> {
        self.inner.get(index)
    }

    /// Returns the first stored chain that belongs to `ratchet_key`.
    fn find_ratchet(&mut self, ratchet_key: &RemoteRatchetKey) -> Option<&mut ReceiverChain> {
        self.inner.iter_mut().find(|chain| chain.belongs_to(ratchet_key))
    }
}
impl Default for ChainStore {
fn default() -> Self {
Self::new()
}
}
/// An Olm session: a sending ratchet for outgoing messages plus the receiver
/// chains used to decrypt incoming ones.
pub struct Session {
    // Keys the session was created from; `session_id()` is derived from them
    // and they are embedded in pre-key messages by `encrypt`.
    session_keys: SessionKeys,
    // Ratchet used to encrypt outgoing messages.
    sending_ratchet: DoubleRatchet,
    // Chains used to decrypt incoming messages, looked up by ratchet key.
    receiving_chains: ChainStore,
    // Session configuration; its `version` selects the MAC variant in `encrypt`.
    config: SessionConfig,
}
impl Debug for Session {
    /// Prints the session id, sending chain index, receiver chains and config.
    /// The session keys themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `Session` forces this
        // impl to be revisited.
        let Self { session_keys: _, sending_ratchet, receiving_chains, config } = self;

        let mut out = f.debug_struct("Session");
        out.field("session_id", &self.session_id());
        out.field("sending_chain_index", &sending_ratchet.chain_index());
        out.field("receiving_chains", &receiving_chains.inner);
        out.field("config", config);
        out.finish_non_exhaustive()
    }
}
impl Session {
pub(super) fn new(
config: SessionConfig,
shared_secret: Shared3DHSecret,
session_keys: SessionKeys,
) -> Self {
let local_ratchet = DoubleRatchet::active(shared_secret);
Self {
session_keys,
sending_ratchet: local_ratchet,
receiving_chains: Default::default(),
config,
}
}
/// Creates a session from the remote side's 3DH secret and ratchet key.
///
/// The expanded secret yields a root key, which seeds an inactive sending
/// ratchet, and a chain key, which seeds the single initial receiver chain
/// for `remote_ratchet_key`.
pub(super) fn new_remote(
    config: SessionConfig,
    shared_secret: RemoteShared3DHSecret,
    remote_ratchet_key: Curve25519PublicKey,
    session_keys: SessionKeys,
) -> Self {
    let (root_key_bytes, chain_key_bytes) = shared_secret.expand();
    let their_ratchet_key = RemoteRatchetKey::from(remote_ratchet_key);

    let sending_ratchet =
        DoubleRatchet::inactive(RemoteRootKey::new(root_key_bytes), their_ratchet_key);

    let first_chain =
        ReceiverChain::new(their_ratchet_key, RemoteChainKey::new(chain_key_bytes));
    let mut receiving_chains = ChainStore::new();
    receiving_chains.push(first_chain);

    Self { session_keys, sending_ratchet, receiving_chains, config }
}
/// Returns this session's identifier, as computed by `SessionKeys::session_id`.
pub fn session_id(&self) -> String {
    self.session_keys.session_id()
}
/// Returns `true` once the session holds at least one receiver chain.
///
/// Sessions built with `new_remote` start with a receiver chain, so they
/// report `true` from the beginning.
pub fn has_received_message(&self) -> bool {
    !self.receiving_chains.is_empty()
}
/// Encrypts `plaintext` with the sending ratchet.
///
/// Version 1 sessions use `encrypt_truncated_mac`; version 2 sessions use
/// `encrypt`. While no receiver chain exists (see `has_received_message`), the
/// result is wrapped in a [`PreKeyMessage`] carrying the session keys;
/// otherwise a normal message is returned.
pub fn encrypt(&mut self, plaintext: impl AsRef<[u8]>) -> OlmMessage {
    let plaintext = plaintext.as_ref();

    let message = match self.config.version {
        Version::V1 => self.sending_ratchet.encrypt_truncated_mac(plaintext),
        Version::V2 => self.sending_ratchet.encrypt(plaintext),
    };

    if !self.has_received_message() {
        return OlmMessage::PreKey(PreKeyMessage::new(self.session_keys, message));
    }

    OlmMessage::Normal(message)
}
/// Returns a copy of the keys this session was created from.
pub fn session_keys(&self) -> SessionKeys {
    self.session_keys
}
/// Returns a copy of this session's configuration.
pub fn session_config(&self) -> SessionConfig {
    self.config
}
/// Returns the next message key of the sending ratchet (low-level API).
///
/// Delegates to `DoubleRatchet::next_message_key`; presumably this advances
/// the sending chain like `encrypt` would — confirm in `double_ratchet`.
#[cfg(feature = "low-level-api")]
pub fn next_message_key(&mut self) -> MessageKey {
    self.sending_ratchet.next_message_key()
}
/// Decrypts an [`OlmMessage`], whether it is a normal or a pre-key message.
///
/// For pre-key messages only the embedded message is decrypted.
///
/// # Errors
///
/// Returns a [`DecryptionError`] if `decrypt_decoded` fails.
pub fn decrypt(&mut self, message: &OlmMessage) -> Result<Vec<u8>, DecryptionError> {
    let inner = match message {
        OlmMessage::Normal(m) => m,
        OlmMessage::PreKey(m) => &m.message,
    };

    self.decrypt_decoded(inner)
}
/// Decrypts an already-decoded [`Message`].
///
/// If a receiver chain for the message's ratchet key exists, it is used
/// directly. Otherwise the sending ratchet is advanced with the new key. The
/// resulting ratchet and chain are only stored after the message decrypts
/// successfully, so a failed decryption leaves the session unchanged.
///
/// # Errors
///
/// Returns a [`DecryptionError`] if the receiver chain fails to decrypt.
pub(super) fn decrypt_decoded(
    &mut self,
    message: &Message,
) -> Result<Vec<u8>, DecryptionError> {
    let ratchet_key = RemoteRatchetKey::from(message.ratchet_key);

    match self.receiving_chains.find_ratchet(&ratchet_key) {
        Some(chain) => chain.decrypt(message, &self.config),
        None => {
            let (next_sending_ratchet, mut new_chain) =
                self.sending_ratchet.advance(ratchet_key);
            let plaintext = new_chain.decrypt(message, &self.config)?;

            // Commit only after a successful decryption.
            self.sending_ratchet = next_sending_ratchet;
            self.receiving_chains.push(new_chain);

            Ok(plaintext)
        }
    }
}
pub fn pickle(&self) -> SessionPickle {
SessionPickle {
session_keys: self.session_keys,
sending_ratchet: self.sending_ratchet.clone(),
receiving_chains: self.receiving_chains.clone(),
config: self.config,
}
}
pub fn from_pickle(pickle: SessionPickle) -> Self {
pickle.into()
}
/// Imports a session from a pickle produced by libolm's session pickling.
///
/// `pickle_key` is the key the libolm pickle was encrypted with.
///
/// # Errors
///
/// Returns a [`crate::LibolmPickleError`] if `unpickle_libolm` fails on the
/// input, or if the pickle contains neither a sender chain nor a receiver
/// chain.
#[cfg(feature = "libolm-compat")]
pub fn from_libolm_pickle(
    pickle: &str,
    pickle_key: &[u8],
) -> Result<Self, crate::LibolmPickleError> {
    use chain_key::ChainKey;
    use matrix_pickle::Decode;
    use message_key::RemoteMessageKey;
    use ratchet::{Ratchet, RatchetKey};
    use root_key::RootKey;

    use crate::{types::Curve25519SecretKey, utilities::unpickle_libolm};

    /// libolm's pickled sender chain.
    #[derive(Debug, Decode, Zeroize)]
    #[zeroize(drop)]
    struct SenderChain {
        public_ratchet_key: [u8; 32],
        #[secret]
        secret_ratchet_key: Box<[u8; 32]>,
        chain_key: Box<[u8; 32]>,
        chain_key_index: u32,
    }

    /// libolm's pickled receiver chain.
    #[derive(Debug, Decode, Zeroize)]
    #[zeroize(drop)]
    struct ReceivingChain {
        public_ratchet_key: [u8; 32],
        #[secret]
        chain_key: Box<[u8; 32]>,
        chain_key_index: u32,
    }

    impl From<&ReceivingChain> for ReceiverChain {
        fn from(chain: &ReceivingChain) -> Self {
            let ratchet_key = RemoteRatchetKey::from(chain.public_ratchet_key);
            let chain_key = RemoteChainKey::from_bytes_and_index(
                chain.chain_key.clone(),
                chain.chain_key_index,
            );

            ReceiverChain::new(ratchet_key, chain_key)
        }
    }

    /// libolm's pickled skipped message key.
    #[derive(Debug, Decode, Zeroize)]
    #[zeroize(drop)]
    struct MessageKey {
        ratchet_key: [u8; 32],
        #[secret]
        message_key: Box<[u8; 32]>,
        index: u32,
    }

    impl From<&MessageKey> for RemoteMessageKey {
        fn from(key: &MessageKey) -> Self {
            RemoteMessageKey { key: key.message_key.clone(), index: key.index.into() }
        }
    }

    /// Top-level layout of a libolm session pickle.
    #[derive(Decode)]
    struct Pickle {
        #[allow(dead_code)]
        version: u32,
        #[allow(dead_code)]
        received_message: bool,
        session_keys: SessionKeys,
        #[secret]
        root_key: Box<[u8; 32]>,
        sender_chains: Vec<SenderChain>,
        receiver_chains: Vec<ReceivingChain>,
        message_keys: Vec<MessageKey>,
    }

    impl Drop for Pickle {
        fn drop(&mut self) {
            // Wipe all secret material once the decoded pickle is no longer needed.
            self.root_key.zeroize();
            self.sender_chains.zeroize();
            self.receiver_chains.zeroize();
            self.message_keys.zeroize();
        }
    }

    impl TryFrom<Pickle> for Session {
        type Error = crate::LibolmPickleError;

        fn try_from(pickle: Pickle) -> Result<Self, Self::Error> {
            let mut receiving_chains = ChainStore::new();

            // libolm inserts each new receiver chain at the *front* of its list
            // (`receiver_chains.insert()` in ratchet.cpp), so the pickle stores
            // them newest-first. `ChainStore` is oldest-first and evicts index 0
            // when full, so push in reverse. Otherwise the most recent chain
            // would be the first one evicted on the next ratchet step.
            for chain in pickle.receiver_chains.iter().rev() {
                receiving_chains.push(chain.into())
            }

            // Attach skipped message keys to the chain they belong to; keys for
            // chains that are not present are dropped.
            for key in &pickle.message_keys {
                let ratchet_key =
                    RemoteRatchetKey::from(Curve25519PublicKey::from(key.ratchet_key));

                if let Some(receiving_chain) = receiving_chains.find_ratchet(&ratchet_key) {
                    receiving_chain.insert_message_key(key.into())
                }
            }

            if let Some(chain) = pickle.sender_chains.first() {
                // An existing sender chain: resume with an active sending ratchet.
                let ratchet_key = RatchetKey::from(Curve25519SecretKey::from_slice(
                    chain.secret_ratchet_key.as_ref(),
                ));
                let chain_key = ChainKey::from_bytes_and_index(
                    chain.chain_key.clone(),
                    chain.chain_key_index,
                );

                let root_key = RootKey::new(pickle.root_key.clone());

                let ratchet = Ratchet::new_with_ratchet_key(root_key, ratchet_key);
                let sending_ratchet =
                    DoubleRatchet::from_ratchet_and_chain_key(ratchet, chain_key);

                Ok(Self {
                    session_keys: pickle.session_keys,
                    sending_ratchet,
                    receiving_chains,
                    config: SessionConfig::version_1(),
                })
            } else if let Some(chain) = receiving_chains
                .inner
                .len()
                .checked_sub(1)
                .and_then(|newest| receiving_chains.get(newest))
            {
                // No sender chain: the next ratchet step must be taken against
                // the most recent remote ratchet key, which after the reversed
                // push above is the last chain in the store.
                let sending_ratchet = DoubleRatchet::inactive(
                    RemoteRootKey::new(pickle.root_key.clone()),
                    chain.ratchet_key(),
                );

                Ok(Self {
                    session_keys: pickle.session_keys,
                    sending_ratchet,
                    receiving_chains,
                    config: SessionConfig::version_1(),
                })
            } else {
                Err(crate::LibolmPickleError::InvalidSession)
            }
        }
    }

    const PICKLE_VERSION: u32 = 1;
    unpickle_libolm::<Pickle, _>(pickle, pickle_key, PICKLE_VERSION)
}
}
/// Serializable snapshot of a [`Session`], produced by `Session::pickle` and
/// turned back into a session with `Session::from_pickle`.
#[derive(Deserialize, Serialize)]
pub struct SessionPickle {
    session_keys: SessionKeys,
    sending_ratchet: DoubleRatchet,
    receiving_chains: ChainStore,
    // Pickles without a `config` field deserialize with the version 1
    // configuration returned by `default_config`.
    #[serde(default = "default_config")]
    config: SessionConfig,
}
/// Serde default for `SessionPickle::config`: the version 1 session config.
fn default_config() -> SessionConfig {
    SessionConfig::version_1()
}
impl SessionPickle {
    /// Encrypts this pickle with `pickle_key` (via `utilities::pickle`) and
    /// returns the result as a string.
    pub fn encrypt(self, pickle_key: &[u8; 32]) -> String {
        pickle(&self, pickle_key)
    }

    /// Decrypts a pickle string produced by [`SessionPickle::encrypt`].
    ///
    /// # Errors
    ///
    /// Returns a [`PickleError`] if `utilities::unpickle` fails, e.g.
    /// presumably for a wrong key or malformed input.
    pub fn from_encrypted(ciphertext: &str, pickle_key: &[u8; 32]) -> Result<Self, PickleError> {
        unpickle(ciphertext, pickle_key)
    }
}
impl From<SessionPickle> for Session {
fn from(pickle: SessionPickle) -> Self {
Self {
session_keys: pickle.session_keys,
sending_ratchet: pickle.sending_ratchet,
receiving_chains: pickle.receiving_chains,
config: pickle.config,
}
}
}
#[cfg(test)]
mod test {
    use anyhow::{bail, Result};
    use olm_rs::{
        account::OlmAccount,
        session::{OlmMessage, OlmSession},
    };

    use super::Session;
    use crate::{
        olm::{Account, SessionConfig, SessionPickle},
        Curve25519PublicKey,
    };

    // All-zero key; only used to round-trip pickles inside these tests.
    const PICKLE_KEY: [u8; 32] = [0u8; 32];

    /// Sets up a session pair between a vodozemac `Account` (Alice) and an
    /// `olm_rs` `OlmAccount` (Bob).
    ///
    /// Alice creates an outbound session from one of Bob's one-time keys and
    /// encrypts one message; Bob creates his inbound session from that pre-key
    /// message. Returns `(alice, bob, alice_session, bob_session)`.
    fn sessions() -> Result<(Account, OlmAccount, Session, OlmSession)> {
        let alice = Account::new();
        let bob = OlmAccount::new();
        bob.generate_one_time_keys(1);

        let one_time_key = bob
            .parsed_one_time_keys()
            .curve25519()
            .values()
            .next()
            .cloned()
            .expect("Couldn't find a one-time key");

        let identity_keys = bob.parsed_identity_keys();
        let curve25519_key = Curve25519PublicKey::from_base64(identity_keys.curve25519())?;
        let one_time_key = Curve25519PublicKey::from_base64(&one_time_key)?;
        let mut alice_session =
            alice.create_outbound_session(SessionConfig::version_1(), curve25519_key, one_time_key);

        let message = "It's a secret to everybody";

        let olm_message = alice_session.encrypt(message);
        bob.mark_keys_as_published();

        // The first message of an outbound session must be a pre-key message.
        if let OlmMessage::PreKey(m) = olm_message.into() {
            let session =
                bob.create_inbound_session_from(&alice.curve25519_key().to_base64(), m)?;

            Ok((alice, bob, alice_session, session))
        } else {
            bail!("Invalid message type");
        }
    }

    // Three messages on the same chain, decrypted in reverse order.
    #[test]
    fn out_of_order_decryption() -> Result<()> {
        let (_, _, mut alice_session, bob_session) = sessions()?;

        let message_1 = bob_session.encrypt("Message 1").into();
        let message_2 = bob_session.encrypt("Message 2").into();
        let message_3 = bob_session.encrypt("Message 3").into();

        assert_eq!("Message 3".as_bytes(), alice_session.decrypt(&message_3)?);
        assert_eq!("Message 2".as_bytes(), alice_session.decrypt(&message_2)?);
        assert_eq!("Message 1".as_bytes(), alice_session.decrypt(&message_1)?);

        Ok(())
    }

    // Messages 2 and 3 on Bob's first chain are decrypted only after Alice has
    // received message 5, sent after Bob decrypted Alice's reply; Alice ends
    // up holding two receiver chains.
    #[test]
    fn more_out_of_order_decryption() -> Result<()> {
        let (_, _, mut alice_session, bob_session) = sessions()?;

        let message_1 = bob_session.encrypt("Message 1").into();
        let message_2 = bob_session.encrypt("Message 2").into();
        let message_3 = bob_session.encrypt("Message 3").into();

        assert_eq!("Message 1".as_bytes(), alice_session.decrypt(&message_1)?);

        assert_eq!(alice_session.receiving_chains.len(), 1);

        let message_4 = alice_session.encrypt("Message 4").into();
        assert_eq!("Message 4", bob_session.decrypt(message_4)?);

        let message_5 = bob_session.encrypt("Message 5").into();
        assert_eq!("Message 5".as_bytes(), alice_session.decrypt(&message_5)?);
        assert_eq!("Message 3".as_bytes(), alice_session.decrypt(&message_3)?);
        assert_eq!("Message 2".as_bytes(), alice_session.decrypt(&message_2)?);

        assert_eq!(alice_session.receiving_chains.len(), 2);

        Ok(())
    }

    // Bob's libolm session only decrypts Alice's 11th message before being
    // pickled; the imported session must still decrypt the first message
    // (presumably via imported skipped message keys) and be able to reply.
    #[test]
    #[cfg(feature = "libolm-compat")]
    fn libolm_unpickling() -> Result<()> {
        let (_, _, mut session, olm) = sessions()?;

        let plaintext = "It's a secret to everybody";
        let old_message = session.encrypt(plaintext);

        for _ in 0..9 {
            session.encrypt("Hello");
        }

        let message = session.encrypt("Hello");
        olm.decrypt(message.into())?;

        let key = b"DEFAULT_PICKLE_KEY";
        let pickle = olm.pickle(olm_rs::PicklingMode::Encrypted { key: key.to_vec() });

        let mut unpickled = Session::from_libolm_pickle(&pickle, key)?;

        assert_eq!(olm.session_id(), unpickled.session_id());

        assert_eq!(unpickled.decrypt(&old_message)?, plaintext.as_bytes());

        let message = unpickled.encrypt(plaintext);

        assert_eq!(session.decrypt(&message)?, plaintext.as_bytes());

        Ok(())
    }

    // pickle -> encrypt -> decrypt -> unpickle -> pickle must yield the same
    // serialized JSON as the original pickle.
    #[test]
    fn session_pickling_roundtrip_is_identity() -> Result<()> {
        let (_, _, session, _) = sessions()?;

        let pickle = session.pickle().encrypt(&PICKLE_KEY);

        let decrypted_pickle = SessionPickle::from_encrypted(&pickle, &PICKLE_KEY)?;
        let unpickled_group_session = Session::from_pickle(decrypted_pickle);
        let repickle = unpickled_group_session.pickle();

        assert_eq!(session.session_id(), unpickled_group_session.session_id());

        let decrypted_pickle = SessionPickle::from_encrypted(&pickle, &PICKLE_KEY)?;
        let pickle = serde_json::to_value(decrypted_pickle)?;
        let repickle = serde_json::to_value(repickle)?;

        assert_eq!(pickle, repickle);

        Ok(())
    }
}