use hkdf::Hkdf;
use hmac::{Hmac, Mac as _, digest::MacError};
use rand::thread_rng;
use sha2::Sha256;
use thiserror::Error;
use x25519_dalek::{EphemeralSecret, SharedSecret};
use crate::{
Curve25519PublicKey, KeyError,
utilities::{base64_decode, base64_encode},
};
/// A 32-byte HMAC-SHA256 key derived from the SAS shared secret, stored behind a `Box`
/// rather than inline.
type HmacSha256Key = Box<[u8; 32]>;
/// The raw output of a SAS MAC calculation.
pub struct Mac(Vec<u8>);

impl Mac {
    /// Encodes the MAC bytes as a base64 string using the crate's `base64_encode` helper.
    pub fn to_base64(&self) -> String {
        let Self(raw) = self;
        base64_encode(raw)
    }

    /// Borrows the raw MAC bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.0.as_slice()
    }

    /// Builds a MAC by copying the given bytes; no length check is performed.
    pub fn from_slice(bytes: &[u8]) -> Self {
        Self(bytes.to_owned())
    }

    /// Decodes a base64-encoded MAC.
    ///
    /// # Errors
    ///
    /// Returns the decoder's error if `mac` is not valid base64.
    pub fn from_base64(mac: &str) -> Result<Self, base64::DecodeError> {
        Ok(Self(base64_decode(mac)?))
    }
}
/// Error returned by `EstablishedSas::bytes_raw` when HKDF cannot expand the shared
/// secret into the requested number of bytes.
#[derive(Debug, Clone, Error)]
#[error("The given count of bytes was too large")]
pub struct InvalidCount;
/// Errors produced while verifying SAS data.
#[derive(Debug, Error)]
pub enum SasError {
    /// The MAC passed to `EstablishedSas::verify_mac` did not match the locally computed one.
    #[error("The SAS MAC validation didn't succeed: {0}")]
    Mac(#[from] MacError),
}
/// One side of a SAS exchange before the other party's public key is known.
pub struct Sas {
    /// Ephemeral X25519 secret; consumed by `Sas::diffie_hellman`.
    secret_key: EphemeralSecret,
    /// Public key derived from `secret_key`, to be sent to the other party.
    public_key: Curve25519PublicKey,
}
/// A SAS session after a successful (contributory) X25519 key agreement.
pub struct EstablishedSas {
    /// X25519 shared secret; input keying material for all HKDF derivations.
    shared_secret: SharedSecret,
    /// Our own public key used in the key agreement.
    our_public_key: Curve25519PublicKey,
    /// The other party's public key used in the key agreement.
    their_public_key: Curve25519PublicKey,
}
impl std::fmt::Debug for EstablishedSas {
    /// Prints both public keys as base64; the shared secret is not printed, and the
    /// output is marked non-exhaustive to signal the omitted field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ours = self.our_public_key.to_base64();
        let theirs = self.their_public_key.to_base64();

        let mut builder = f.debug_struct("EstablishedSas");
        builder.field("our_public_key", &ours);
        builder.field("their_public_key", &theirs);
        builder.finish_non_exhaustive()
    }
}
/// Six bytes derived from an established SAS session, from which the human-comparable
/// short authentication strings are produced.
///
/// [`SasBytes::emoji_indices`] consumes the leading 42 bits and [`SasBytes::decimals`]
/// consumes the leading 39 bits; the remaining bits are unused.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SasBytes {
    bytes: [u8; 6],
}

impl SasBytes {
    /// Returns seven emoji indices, each in the range `0..64`.
    pub fn emoji_indices(&self) -> [u8; 7] {
        Self::bytes_to_emoji_index(&self.bytes)
    }

    /// Returns three numbers, each in the range `1000..=9191`.
    pub fn decimals(&self) -> (u16, u16, u16) {
        Self::bytes_to_decimal(&self.bytes)
    }

    /// Returns the raw six bytes.
    pub const fn as_bytes(&self) -> &[u8; 6] {
        &self.bytes
    }

    /// Splits the first 42 bits of `bytes` (most significant bit first) into seven
    /// 6-bit groups.
    fn bytes_to_emoji_index(bytes: &[u8; 6]) -> [u8; 7] {
        // Pack the six bytes big-endian into the low 48 bits of a u64, without allocating.
        let num = bytes.iter().fold(0u64, |acc, &b| (acc << 8) | u64::from(b));

        // Each shift selects the next 6-bit group; `& 63` keeps the cast lossless.
        // The lowest 6 bits of `num` are intentionally unused.
        [
            ((num >> 42) & 63) as u8,
            ((num >> 36) & 63) as u8,
            ((num >> 30) & 63) as u8,
            ((num >> 24) & 63) as u8,
            ((num >> 18) & 63) as u8,
            ((num >> 12) & 63) as u8,
            ((num >> 6) & 63) as u8,
        ]
    }

    /// Splits the first 39 bits of `bytes` into three 13-bit numbers and adds 1000 to each.
    fn bytes_to_decimal(bytes: &[u8; 6]) -> (u16, u16, u16) {
        // Widen to u16 so the shifts below cannot overflow (max 13-bit value is 8191).
        let [b0, b1, b2, b3, b4, _] = bytes.map(u16::from);

        let first = (b0 << 5) | (b1 >> 3);
        let second = ((b1 & 0x7) << 10) | (b2 << 2) | (b3 >> 6);
        let third = ((b3 & 0x3F) << 7) | (b4 >> 1);

        (first + 1000, second + 1000, third + 1000)
    }
}
impl Default for Sas {
fn default() -> Self {
Self::new()
}
}
impl Sas {
    /// Creates a SAS object with a freshly generated ephemeral X25519 key pair,
    /// using `thread_rng` as the randomness source.
    pub fn new() -> Self {
        let secret_key = EphemeralSecret::random_from_rng(thread_rng());

        // The public key borrows the secret only for the duration of the conversion,
        // so the secret can be moved into the struct right after.
        Self { public_key: Curve25519PublicKey::from(&secret_key), secret_key }
    }

    /// Returns our public key, which must be sent to the other party.
    pub const fn public_key(&self) -> Curve25519PublicKey {
        self.public_key
    }

    /// Performs the X25519 key agreement with the other party's public key, consuming `self`.
    ///
    /// # Errors
    ///
    /// Returns [`KeyError::NonContributoryKey`] when x25519-dalek reports the resulting
    /// shared secret as non-contributory.
    pub fn diffie_hellman(
        self,
        their_public_key: Curve25519PublicKey,
    ) -> Result<EstablishedSas, KeyError> {
        let shared_secret = self.secret_key.diffie_hellman(&their_public_key.inner);

        if !shared_secret.was_contributory() {
            return Err(KeyError::NonContributoryKey);
        }

        Ok(EstablishedSas { shared_secret, our_public_key: self.public_key, their_public_key })
    }

    /// Like [`Sas::diffie_hellman`], but takes the other party's key as a base64 string.
    ///
    /// # Errors
    ///
    /// Propagates the error from decoding `other_public_key`, or any error from
    /// [`Sas::diffie_hellman`].
    pub fn diffie_hellman_with_raw(
        self,
        other_public_key: &str,
    ) -> Result<EstablishedSas, KeyError> {
        let decoded_key = Curve25519PublicKey::from_base64(other_public_key)?;
        self.diffie_hellman(decoded_key)
    }
}
impl EstablishedSas {
    /// Derives the six SAS bytes via HKDF-SHA256, using `info` as the HKDF info parameter.
    ///
    /// Both parties must use the same `info` to obtain matching bytes.
    pub fn bytes(&self, info: &str) -> SasBytes {
        let mut bytes = [0u8; 6];

        // Expand straight into the fixed-size array; HKDF can always produce 6 bytes.
        #[allow(clippy::expect_used)]
        self.get_hkdf()
            .expand(info.as_bytes(), &mut bytes)
            .expect("HKDF should always be able to generate 6 bytes");

        SasBytes { bytes }
    }

    /// Derives `count` bytes via HKDF-SHA256, using `info` as the HKDF info parameter.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidCount`] if HKDF rejects the requested output length
    /// (RFC 5869 caps HKDF-SHA256 output at 255 * 32 bytes).
    pub fn bytes_raw(&self, info: &str, count: usize) -> Result<Vec<u8>, InvalidCount> {
        let mut output = vec![0u8; count];
        self.get_hkdf().expand(info.as_bytes(), &mut output).map_err(|_| InvalidCount)?;
        Ok(output)
    }

    /// Computes an HMAC-SHA256 of `input`, keyed with a key derived from the shared secret
    /// and `info`.
    pub fn calculate_mac(&self, input: &str, info: &str) -> Mac {
        let mut mac = self.get_mac(info);
        mac.update(input.as_bytes());
        Mac(mac.finalize().into_bytes().to_vec())
    }

    /// Computes the MAC like `calculate_mac`, but encodes it in the non-standard way
    /// that is intended to match libolm's legacy SAS MAC output.
    ///
    /// After the first chunk, the input for each 3-byte encoding step is read partly
    /// (and later entirely) from the already-encoded string `out` rather than from the
    /// MAC. This appears to emulate an encoder whose input and output buffers overlap.
    /// NOTE(review): the exact libolm behavior being mirrored should be confirmed against
    /// libolm's sources.
    #[cfg(feature = "libolm-compat")]
    pub fn calculate_mac_invalid_base64(&self, input: &str, info: &str) -> String {
        let mac = self.calculate_mac(input, info);
        let mut out = base64_encode(&mac.as_bytes()[0..3]);
        let mut bytes_from_mac = 2;
        // Two transitional chunks that mix already-encoded output bytes with MAC bytes.
        for i in (6..10).step_by(3) {
            let from_mac = &mac.as_bytes()[i - bytes_from_mac..i];
            let from_out = &out.as_bytes()[out.len() - (3 - bytes_from_mac)..];
            let bytes = [from_out, from_mac].concat();
            let encoded = base64_encode(bytes);
            bytes_from_mac -= 1;
            out = out + &encoded;
        }
        // From here on, every chunk re-encodes bytes of the output string itself.
        for i in (9..30).step_by(3) {
            let next = &out.as_bytes()[i..i + 3];
            let next_four = base64_encode(next);
            out = out + &next_four;
        }
        let next = &out.as_bytes()[30..32];
        let next = base64_encode(next);
        out + &next
    }

    /// Verifies that `tag` is the MAC of `input` under the key derived from `info`.
    ///
    /// The comparison is delegated to `hmac`'s `verify_slice`.
    ///
    /// # Errors
    ///
    /// Returns [`SasError::Mac`] if the tag does not match.
    pub fn verify_mac(&self, input: &str, info: &str, tag: &Mac) -> Result<(), SasError> {
        let mut mac = self.get_mac(info);
        mac.update(input.as_bytes());
        Ok(mac.verify_slice(&tag.0)?)
    }

    /// Returns our public key used in the key agreement.
    pub const fn our_public_key(&self) -> Curve25519PublicKey {
        self.our_public_key
    }

    /// Returns the other party's public key used in the key agreement.
    pub const fn their_public_key(&self) -> Curve25519PublicKey {
        self.their_public_key
    }

    /// Builds an HKDF-SHA256 instance over the shared secret, with no salt.
    fn get_hkdf(&self) -> Hkdf<Sha256> {
        Hkdf::new(None, self.shared_secret.as_bytes())
    }

    /// Derives a 32-byte HMAC key from the shared secret, using `info` as the HKDF info.
    fn get_mac_key(&self, info: &str) -> HmacSha256Key {
        let mut mac_key = Box::new([0u8; 32]);
        let hkdf = self.get_hkdf();

        #[allow(clippy::expect_used)]
        hkdf.expand(info.as_bytes(), mac_key.as_mut_slice())
            .expect("We should be able to expand the shared SAS secret into a MAC key");

        mac_key
    }

    /// Creates an HMAC-SHA256 instance keyed with the key derived for `info`.
    fn get_mac(&self, info: &str) -> Hmac<Sha256> {
        let mac_key = self.get_mac_key(info);

        #[allow(clippy::expect_used)]
        Hmac::<Sha256>::new_from_slice(mac_key.as_slice())
            .expect("We should be able to create a HMAC object from a 32-byte slice")
    }
}
#[cfg(test)]
mod test {
    use insta::assert_debug_snapshot;
    use olm_rs::sas::OlmSas;
    use proptest::prelude::*;

    use super::{Mac, Sas, SasBytes};
    use crate::Curve25519PublicKey;

    // Fixed identities used to build the MAC `info` strings in the MAC tests.
    const ALICE_MXID: &str = "@alice:example.com";
    const ALICE_DEVICE_ID: &str = "AAAAAAAAAA";
    const BOB_MXID: &str = "@bob:example.com";
    const BOB_DEVICE_ID: &str = "BBBBBBBBBB";

    #[test]
    fn as_bytes_is_identity() {
        let bytes = [0u8, 1, 2, 3, 4, 5];
        assert_eq!(SasBytes { bytes }.as_bytes(), &bytes);
    }

    #[test]
    fn mac_from_slice_as_bytes_is_identity() {
        let bytes = "ABCDEFGH".as_bytes();
        assert_eq!(
            Mac::from_slice(bytes).as_bytes(),
            bytes,
            "as_bytes() after from_slice() is not identity"
        );
    }

    // The random keys are overwritten with a fixed key so that the Debug snapshot is
    // deterministic.
    #[test]
    fn snapshot_debug() {
        let key = Curve25519PublicKey::from_bytes([0; 32]);
        let alice = Sas::default();
        let bob = Sas::default();
        let mut established = alice
            .diffie_hellman(bob.public_key())
            .expect("We should be able to establish a SAS session");
        established.our_public_key = key;
        established.their_public_key = key;
        assert_debug_snapshot!(established);
    }

    // Cross-implementation check: libolm and this implementation derive the same raw bytes.
    #[test]
    fn libolm_and_vodozemac_generate_same_bytes() {
        let mut olm = OlmSas::new();
        let dalek = Sas::new();
        olm.set_their_public_key(dalek.public_key().to_base64())
            .expect("Couldn't set the public key for libolm");
        let established = dalek
            .diffie_hellman_with_raw(&olm.public_key())
            .expect("Couldn't establish SAS secret");
        assert_eq!(
            olm.generate_bytes("TEST", 10).expect("libolm couldn't generate SAS bytes"),
            established.bytes_raw("TEST", 10).expect("vodozemac couldn't generate SAS bytes")
        );
    }

    // Both sides of a vodozemac-to-vodozemac exchange must agree on keys, bytes, emoji
    // indices and decimals.
    #[test]
    fn vodozemac_and_vodozemac_generate_same_bytes() {
        let alice = Sas::default();
        let bob = Sas::default();
        let alice_public_key_encoded = alice.public_key().to_base64();
        let alice_public_key = alice.public_key().to_owned();
        let bob_public_key_encoded = bob.public_key().to_base64();
        let bob_public_key = bob.public_key();
        let alice_established = alice
            .diffie_hellman_with_raw(&bob_public_key_encoded)
            .expect("Couldn't establish SAS secret for Alice");
        let bob_established = bob
            .diffie_hellman_with_raw(&alice_public_key_encoded)
            .expect("Couldn't establish SAS secret for Bob");
        assert_eq!(alice_established.our_public_key(), alice_public_key);
        assert_eq!(alice_established.their_public_key(), bob_public_key);
        assert_eq!(bob_established.our_public_key(), bob_public_key);
        assert_eq!(bob_established.their_public_key(), alice_public_key);
        let alice_bytes = alice_established.bytes("TEST");
        let bob_bytes = bob_established.bytes("TEST");
        assert_eq!(alice_bytes, bob_bytes, "The two sides calculated different bytes.");
        assert_eq!(
            alice_bytes.emoji_indices(),
            bob_bytes.emoji_indices(),
            "The two sides calculated different emoji indices."
        );
        assert_eq!(
            alice_bytes.decimals(),
            bob_bytes.decimals(),
            "The two sides calculated different decimals."
        );
        assert_eq!(alice_bytes.as_bytes(), bob_bytes.as_bytes());
    }

    // MACs must agree between two vodozemac peers, verify on both sides, and a fixed
    // bogus 32-byte MAC must be rejected by both.
    #[test]
    fn calculate_mac_vodozemac_vodozemac() {
        let alice = Sas::new();
        let bob = Sas::new();
        let alice_public_key = alice.public_key().to_base64();
        let bob_public_key = bob.public_key().to_base64();
        let message = format!("ed25519:{BOB_DEVICE_ID}");
        let extra_info = format!(
            "MATRIX_KEY_VERIFICATION_MAC\
            {BOB_MXID}{BOB_DEVICE_ID}\
            {ALICE_MXID}{ALICE_DEVICE_ID}\
            $1234567890\
            KEY_IDS",
        );
        let alice_established = alice
            .diffie_hellman_with_raw(&bob_public_key)
            .expect("Couldn't establish SAS secret for Alice");
        let bob_established = bob
            .diffie_hellman_with_raw(&alice_public_key)
            .expect("Couldn't establish SAS secret for Bob");
        let alice_mac = alice_established.calculate_mac(&message, &extra_info);
        let bob_mac = bob_established.calculate_mac(&message, &extra_info);
        assert_eq!(
            alice_mac.to_base64(),
            bob_mac.to_base64(),
            "Two vodozemac devices calculated different SAS MACs."
        );
        alice_established
            .verify_mac(&message, &extra_info, &bob_mac)
            .expect("Alice couldn't verify Bob's MAC");
        bob_established
            .verify_mac(&message, &extra_info, &alice_mac)
            .expect("Bob couldn't verify Alice's MAC");
        let invalid_mac = Mac::from_slice(&[
            0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
            1, 0, 1,
        ]);
        alice_established
            .verify_mac(&message, &extra_info, &invalid_mac)
            .expect_err("Alice verified an invalid MAC");
        bob_established
            .verify_mac(&message, &extra_info, &invalid_mac)
            .expect_err("Bob verified an invalid MAC");
    }

    // libolm's fixed-base64 MAC must equal ours and must verify with `verify_mac`.
    #[test]
    fn calculate_mac_vodozemac_libolm() {
        let alice_on_dalek = Sas::new();
        let mut bob_on_libolm = OlmSas::new();
        let alice_public_key = alice_on_dalek.public_key().to_base64();
        let bob_public_key = bob_on_libolm.public_key();
        let message = format!("ed25519:{BOB_DEVICE_ID}");
        let extra_info = format!(
            "MATRIX_KEY_VERIFICATION_MAC\
            {BOB_MXID}{BOB_DEVICE_ID}\
            {ALICE_MXID}{ALICE_DEVICE_ID}\
            $1234567890\
            KEY_IDS",
        );
        bob_on_libolm
            .set_their_public_key(alice_public_key)
            .expect("Couldn't set the public key for libolm");
        let established = alice_on_dalek
            .diffie_hellman_with_raw(&bob_public_key)
            .expect("Couldn't establish SAS secret");
        let olm_mac = bob_on_libolm
            .calculate_mac_fixed_base64(&message, &extra_info)
            .expect("libolm couldn't calculate SAS MAC.");
        assert_eq!(olm_mac, established.calculate_mac(&message, &extra_info).to_base64());
        let olm_mac =
            Mac::from_base64(&olm_mac).expect("SAS MAC generated by libolm wasn't valid base64.");
        established.verify_mac(&message, &extra_info, &olm_mac).expect("Couldn't verify MAC");
    }

    // The compat encoder must reproduce libolm's legacy (non-fixed) `calculate_mac` output.
    #[test]
    #[cfg(feature = "libolm-compat")]
    fn calculate_mac_invalid_base64() {
        let mut olm = OlmSas::new();
        let dalek = Sas::new();
        olm.set_their_public_key(dalek.public_key().to_base64())
            .expect("Couldn't set the public key for libolm");
        let established = dalek
            .diffie_hellman_with_raw(&olm.public_key())
            .expect("Couldn't establish SAS secret");
        let olm_mac = olm.calculate_mac("", "").expect("libolm couldn't calculate a MAC");
        assert_eq!(olm_mac, established.calculate_mac_invalid_base64("", ""));
    }

    // Boundary values: all-zero bytes map to index 0, all-ones bytes map to index 63.
    #[test]
    fn emoji_generation() {
        let bytes: [u8; 6] = [0, 0, 0, 0, 0, 0];
        let index: [u8; 7] = [0, 0, 0, 0, 0, 0, 0];
        assert_eq!(SasBytes::bytes_to_emoji_index(&bytes), index.as_ref());
        assert_eq!(SasBytes { bytes }.emoji_indices(), index.as_ref());
        let bytes: [u8; 6] = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
        let index: [u8; 7] = [63, 63, 63, 63, 63, 63, 63];
        assert_eq!(SasBytes::bytes_to_emoji_index(&bytes), index.as_ref());
        assert_eq!(SasBytes { bytes }.emoji_indices(), index.as_ref());
    }

    // Boundary values: decimals span 1000 (all zero bits) to 9191 (all 13 bits set, plus 1000).
    #[test]
    fn decimal_generation() {
        let bytes: [u8; 6] = [0, 0, 0, 0, 0, 0];
        let decimal: (u16, u16, u16) = (1000, 1000, 1000);
        assert_eq!(SasBytes::bytes_to_decimal(&bytes), decimal);
        assert_eq!(SasBytes { bytes }.decimals(), decimal);
        let bytes: [u8; 6] = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
        let decimal: (u16, u16, u16) = (9191, 9191, 9191);
        assert_eq!(SasBytes::bytes_to_decimal(&bytes), decimal);
        assert_eq!(SasBytes { bytes }.decimals(), decimal);
    }

    // Property: every emoji index fits in 6 bits for arbitrary input bytes.
    proptest! {
        #[test]
        fn proptest_emoji(bytes in prop::array::uniform6(0u8..)) {
            let numbers = SasBytes::bytes_to_emoji_index(&bytes);
            for number in numbers.iter() {
                prop_assert!(*number < 64);
            }
        }
    }

    // Property: every decimal lies in 1000..=9191 for arbitrary input bytes.
    proptest! {
        #[test]
        fn proptest_decimals(bytes in prop::array::uniform6(0u8..)) {
            let (first, second, third) = SasBytes::bytes_to_decimal(&bytes);
            prop_assert!((1000..=9191).contains(&first));
            prop_assert!((1000..=9191).contains(&second));
            prop_assert!((1000..=9191).contains(&third));
        }
    }
}