#![cfg(any(test, feature = "testing"))]
#![cfg_attr(docsrs, doc(cfg(feature = "testing")))]
#![allow(clippy::arithmetic_side_effects)]
#![allow(clippy::panic)]
#![allow(clippy::unwrap_used)]
extern crate alloc;
use alloc::{collections::BTreeSet, sync::Arc, vec::Vec};
use core::{
cell::UnsafeCell,
mem::{self, MaybeUninit},
ops::Deref,
result::Result,
};
use aranya_crypto::{
self,
aqc::{BidiChannelId, UniChannelId},
csprng::Random,
engine::WrappedKey,
keystore::{memstore, Entry, Occupied, Vacant},
CipherSuite, DeviceId, EncryptionKey, EncryptionKeyId, EncryptionPublicKey, Engine, Id,
IdentityKey, KeyStore, Rng,
};
use aranya_policy_vm::{ActionContext, CommandContext};
use spin::Mutex;
use crate::{
ffi::{AqcBidiChannel, AqcUniChannel, Ffi},
handler::{
BidiChannelCreated, BidiChannelReceived, Handler, UniChannelCreated, UniChannelReceived,
UniPsk,
},
shared::LabelId,
};
/// Serializes `pk` to bytes with `postcard`.
///
/// # Panics
///
/// Panics if `postcard` fails to encode the key.
fn encode_enc_pk<CS: CipherSuite>(pk: &EncryptionPublicKey<CS>) -> Vec<u8> {
    let encoded = postcard::to_allocvec(pk);
    encoded.expect("should be able to encode an `EncryptionPublicKey`")
}
/// Shuffles the elements of `data` in place using [`shuffle_by`].
fn shuffle<T>(data: &mut [T]) {
    let len = data.len();
    shuffle_by(len, |a, b| data.swap(a, b));
}
/// Fisher–Yates shuffle over `n` abstract positions.
///
/// For every index `i` from `n - 1` down to `0`, picks `j = rand_intn(i + 1)`
/// (so `j <= i`) and calls `swap(i, j)`. The caller applies each swap to
/// whatever storage it owns, which lets several collections be permuted in
/// lockstep.
fn shuffle_by<F>(n: usize, mut swap: F)
where
    F: FnMut(usize, usize),
{
    let mut i = n;
    while i > 0 {
        i -= 1;
        let j = rand_intn(i + 1);
        swap(i, j);
    }
}
/// Panics (at the caller's location) if `iter` yields two equal values.
#[track_caller]
fn assert_unique<T: Ord>(iter: impl IntoIterator<Item = T>) {
    let mut uniq = BTreeSet::<T>::new();
    for v in iter {
        // `insert` returns `false` when an equal value is already present.
        assert!(uniq.insert(v));
    }
}
/// Returns a random integer in `[0, max)`.
///
/// Uses Lemire's multiply-and-reject method. The high word of `rand * max`
/// is the candidate result. Samples whose low word falls below
/// `2^BITS mod max` are rejected to remove modulo bias.
///
/// `max` must be non-zero: the threshold computation divides by it.
fn rand_intn(max: usize) -> usize {
    debug_assert!(max < usize::MAX);
    let range = max;
    // `(2^BITS - range) mod range == 2^BITS mod range`; this does not change
    // between iterations, so compute it once.
    let thresh = range.wrapping_neg() % range;
    loop {
        let (hi, lo) = widening_mul(usize::random(&mut Rng), range);
        if lo < thresh {
            continue;
        }
        debug_assert!(hi < max);
        return hi;
    }
}
/// Computes the full double-width product of `x` and `y`.
///
/// Returns `(hi, lo)`, where `hi` is the upper `usize::BITS` bits of the
/// product and `lo` is the lower `usize::BITS` bits.
#[inline(always)]
const fn widening_mul(x: usize, y: usize) -> (usize, usize) {
    // `usize` is at most 64 bits wide on every Rust target, so the exact
    // product always fits in a `u128`.
    let product = (x as u128) * (y as u128);
    let hi = (product >> usize::BITS) as usize;
    let lo = product as usize;
    (hi, lo)
}
/// A cloneable test key store that wraps `aranya_crypto`'s
/// `memstore::MemStore` behind a lock.
///
/// Clones share the same underlying storage through an [`Arc`].
#[derive(Clone, Default)]
pub struct MemStore(Arc<MemStoreInner>);

impl MemStore {
    /// Creates a store backed by a default-constructed `memstore::MemStore`.
    pub fn new() -> Self {
        Self::default()
    }
}
/// Every lookup takes the store's lock. The returned entry keeps the lock
/// held until it is consumed or dropped.
impl KeyStore for MemStore {
    type Error = memstore::Error;
    type Vacant<'a, T: WrappedKey> = VacantEntry<'a, T>;
    type Occupied<'a, T: WrappedKey> = OccupiedEntry<'a, T>;

    fn entry<T: WrappedKey>(&mut self, id: Id) -> Result<Entry<'_, Self, T>, Self::Error> {
        let guarded = self.0.entry(id)?;
        Ok(match guarded {
            GuardedEntry::Vacant(g) => Entry::Vacant(VacantEntry(g)),
            GuardedEntry::Occupied(g) => Entry::Occupied(OccupiedEntry(g)),
        })
    }

    fn get<T: WrappedKey>(&self, id: Id) -> Result<Option<T>, Self::Error> {
        let found = match self.0.entry(id)? {
            GuardedEntry::Occupied(g) => Some(g.get()?),
            GuardedEntry::Vacant(_) => None,
        };
        Ok(found)
    }
}
/// A vacant [`MemStore`] slot. The store's lock stays held until this entry
/// is consumed by `insert` or dropped.
pub struct VacantEntry<'a, T>(Guard<'a, memstore::VacantEntry<'a, T>>);

impl<T: WrappedKey> Vacant<T> for VacantEntry<'_, T> {
    type Error = memstore::Error;

    fn insert(self, key: T) -> Result<(), Self::Error> {
        let Self(guard) = self;
        guard.with_data(move |slot| slot.insert(key))
    }
}
/// An occupied [`MemStore`] slot. The store's lock stays held until this
/// entry is consumed by `remove` or dropped.
pub struct OccupiedEntry<'a, T>(Guard<'a, memstore::OccupiedEntry<'a, T>>);

impl<T: WrappedKey> Occupied<T> for OccupiedEntry<'_, T> {
    type Error = memstore::Error;

    fn get(&self) -> Result<T, Self::Error> {
        let Self(guard) = self;
        guard.get()
    }

    fn remove(self) -> Result<T, Self::Error> {
        let Self(guard) = self;
        guard.with_data(|occupied| occupied.remove())
    }
}
/// Shared state behind a [`MemStore`].
///
/// `store` is only accessed while `mutex` is held. The mutex protects no
/// data of its own (`Mutex<()>`); it is locked in `MemStoreInner::entry`
/// and unlocked manually by [`Guard`]'s `Drop`.
#[derive(Default)]
struct MemStoreInner {
    /// Lock serializing access to `store`.
    mutex: Mutex<()>,
    /// The wrapped store. It is an `UnsafeCell` so that `&mut` access can be
    /// produced from `&self` while the lock is held.
    store: UnsafeCell<memstore::MemStore>,
}
impl MemStoreInner {
    /// Looks up `id` in the inner store while holding `self.mutex`.
    ///
    /// On success, the held lock is handed to the [`Guard`] inside the
    /// returned entry, and that guard releases it when dropped.
    ///
    /// # Errors
    ///
    /// Returns the inner store's error if the lookup fails. In that case the
    /// lock is released before returning, so a failed lookup never leaves the
    /// store permanently locked.
    fn entry<T: WrappedKey>(&self, id: Id) -> Result<GuardedEntry<'_, T>, memstore::Error> {
        // Keep the RAII guard for now, so an early return (or a panic) below
        // still unlocks the mutex.
        let lock = self.mutex.lock();
        // SAFETY: `self.mutex` is held, so no other access to the store is in
        // progress. The `&mut` is only reachable through the entry created
        // below, which is wrapped in a `Guard` that keeps the lock held until
        // the entry is consumed or dropped.
        let store = unsafe { &mut *self.store.get() };
        // On error, `?` drops `lock`, which releases the mutex.
        let entry = store.entry(id)?;
        // From here on the `Guard` owns the lock and calls `force_unlock` on
        // drop. Forget the RAII guard so the mutex is not unlocked twice.
        mem::forget(lock);
        Ok(match entry {
            Entry::Vacant(entry) => GuardedEntry::Vacant(Guard::new(&self.mutex, entry)),
            Entry::Occupied(entry) => GuardedEntry::Occupied(Guard::new(&self.mutex, entry)),
        })
    }
}
/// An entry from the inner `memstore::MemStore`, wrapped in a [`Guard`] that
/// keeps the store's lock held until the entry is consumed or dropped.
enum GuardedEntry<'a, T> {
    /// No key is stored under the requested ID.
    Vacant(Guard<'a, memstore::VacantEntry<'a, T>>),
    /// A key is stored under the requested ID.
    Occupied(Guard<'a, memstore::OccupiedEntry<'a, T>>),
}
/// Owns a value obtained while `mutex` was locked, and releases that lock
/// when dropped.
///
/// Before calling [`Guard::new`], the caller must already hold `mutex` with
/// its `spin` guard forgotten (as `MemStoreInner::entry` does). This guard
/// then takes over responsibility for unlocking.
///
/// NOTE(review): `data` lives in a `MaybeUninit`, which never runs `T`'s
/// destructor, and `Drop` below only unlocks. A guard dropped without
/// [`Guard::with_data`] therefore leaks its value instead of dropping it.
#[clippy::has_significant_drop]
struct Guard<'a, T> {
    /// The mutex this guard unlocks on drop.
    mutex: &'a Mutex<()>,
    /// The wrapped value: initialized by `new`, moved out by `with_data`.
    data: MaybeUninit<T>,
}
impl<'a, T> Guard<'a, T> {
    /// Wraps `data` and takes over the already-held lock on `mutex`.
    const fn new(mutex: &'a Mutex<()>, data: T) -> Self {
        Self {
            mutex,
            data: MaybeUninit::new(data),
        }
    }
    /// Consumes the guard and passes the wrapped value to `f`.
    ///
    /// `self` is dropped when this function returns, i.e. after `f` has run,
    /// so `f` executes while the lock is still held.
    fn with_data<F, R>(mut self, f: F) -> R
    where
        F: FnOnce(T) -> R,
    {
        // Move the value out, leaving `self.data` uninitialized. Nothing reads
        // it afterwards: `self` is consumed here and `Drop` only touches
        // `self.mutex`.
        let data = mem::replace(&mut self.data, MaybeUninit::uninit());
        // SAFETY: `data` was initialized by `Guard::new`, and this is the only
        // place it is moved out (which consumes the guard).
        f(unsafe { data.assume_init() })
    }
}
impl<T> Drop for Guard<'_, T> {
    fn drop(&mut self) {
        // SAFETY: in this file a `Guard` is only constructed while `mutex` is
        // locked and its RAII guard has been forgotten. This guard therefore
        // owns the lock and releases it exactly once.
        unsafe { self.mutex.force_unlock() }
    }
}
impl<T> Deref for Guard<'_, T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        // SAFETY: `data` only becomes uninitialized when `with_data` moves it
        // out. `with_data` takes `self` by value, so no `&self` can observe
        // that state.
        unsafe { self.data.assume_init_ref() }
    }
}
/// The engine/store combination that the generic tests in this module run
/// against.
pub trait TestImpl: Sized {
    /// Crypto engine used by each test device.
    type Engine: Engine;
    /// Key store used by each test device. It is cloned so that a device's
    /// `Ffi` and `Handler` share one store.
    type Store: KeyStore + Clone;
    /// Creates a test device. Each test calls this once per participant.
    fn new() -> Device<Self>;
}
/// One test participant: its engine, identity, encryption key, and the FFI
/// and handler under test, which share a single key store.
pub struct Device<T: TestImpl> {
    /// Crypto engine passed to the FFI and handler calls.
    eng: T::Engine,
    /// ID of the device's identity key.
    device_id: DeviceId,
    /// ID of the device's encryption key, stored wrapped in the key store.
    enc_key_id: EncryptionKeyId,
    /// `postcard`-encoded encryption public key.
    enc_pk: Vec<u8>,
    /// FFI under test.
    ffi: Ffi<T::Store>,
    /// Effect handler under test.
    handler: Handler<T::Store>,
}
impl<T: TestImpl> Device<T> {
    /// Creates a test device from an engine and a key store.
    ///
    /// - Generates a fresh identity key, whose ID becomes the device ID.
    /// - Generates a fresh encryption key, wraps the secret half with `eng`,
    ///   and inserts it into `store` under its key ID.
    /// - Backs the device's [`Ffi`] with a clone of `store`, and its
    ///   [`Handler`] with the original.
    ///
    /// # Panics
    ///
    /// Panics if any key operation or the store insert fails.
    pub fn new(mut eng: T::Engine, mut store: T::Store) -> Self {
        let device_id = IdentityKey::<<T::Engine as Engine>::CS>::new(&mut eng)
            .id()
            .expect("device ID should be valid");

        let enc_sk = EncryptionKey::new(&mut eng);
        let enc_key_id = enc_sk.id().expect("encryption key ID should be valid");
        let enc_pub = enc_sk
            .public()
            .expect("encryption public key should be valid");
        let enc_pk = encode_enc_pk(&enc_pub);

        let wrapped = eng
            .wrap(enc_sk)
            .expect("should be able to wrap `EncryptionKey`");
        store
            .try_insert(enc_key_id.into(), wrapped)
            .expect("should be able to insert wrapped `EncryptionKey`");

        let ffi = Ffi::new(store.clone());
        let handler = Handler::new(device_id, store);
        Self {
            eng,
            device_id,
            enc_key_id,
            enc_pk,
            ffi,
            handler,
        }
    }
}
/// Generates one `#[test]` per generic test in this module for a given
/// [`TestImpl`].
///
/// `test_all!(name, Impl)` expands to a module `name`. Each of its tests
/// calls `$crate::testing::<test>::<Impl>()`.
///
/// NOTE(review): the list of tests below is maintained by hand and must be
/// kept in sync with the `pub fn test_*` functions in this module.
#[macro_export]
macro_rules! test_all {
    ($name:ident, $impl:ty) => {
        mod $name {
            // Bring the invoking module's items into scope so a relative
            // `$impl` path still resolves inside this new module.
            #[allow(unused_imports)]
            use super::*;
            // Expands to a `#[test]` that runs the generic test `$test`
            // with `$impl`.
            macro_rules! test {
                ($test:ident) => {
                    #[test]
                    fn $test() {
                        $crate::testing::$test::<$impl>();
                    }
                };
            }
            test!(test_create_bidi_channel);
            test!(test_create_multi_bidi_channels_same_label);
            test!(test_create_multi_bidi_channels_same_parent_cmd_id);
            test!(test_create_multi_bidi_channels_same_label_multi_peers);
            test!(test_create_send_only_uni_channel);
            test!(test_create_recv_only_uni_channel);
        }
    };
}
// Also expose the exported macro under this module's path.
pub use test_all;
/// Creates one bidirectional channel and checks three things:
/// - the FFI's channel ID matches the author's PSK identity;
/// - the author and the peer agree on that identity;
/// - they derive the same secret bytes.
pub fn test_create_bidi_channel<T: TestImpl>() {
    let mut author = T::new();
    let mut peer = T::new();

    let label_id = LabelId::random(&mut Rng);
    let parent_cmd_id = Id::random(&mut Rng);
    let ctx = CommandContext::Action(ActionContext {
        name: "CreateBidiChannel",
        head_id: parent_cmd_id,
    });

    // Step 1: the author runs the FFI, which returns the peer's encap and an
    // ID for the author's secrets.
    let channel = author
        .ffi
        .create_bidi_channel(
            &ctx,
            &mut author.eng,
            parent_cmd_id,
            author.enc_key_id,
            author.device_id,
            peer.enc_pk.clone(),
            peer.device_id,
            label_id,
        )
        .expect("author should be able to create a bidi channel");
    let AqcBidiChannel {
        peer_encap,
        channel_id,
        author_secrets_id,
        psk_length_in_bytes,
    } = channel;

    // Step 2: the author's handler processes the "created" effect.
    let created = BidiChannelCreated {
        channel_id: channel_id.into(),
        parent_cmd_id,
        author_id: author.device_id,
        author_enc_key_id: author.enc_key_id,
        peer_id: peer.device_id,
        peer_enc_pk: &peer.enc_pk,
        label_id,
        author_secrets_id: author_secrets_id.into(),
        psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
    };
    let author_psk = author
        .handler
        .bidi_channel_created(&mut author.eng, &created)
        .expect("author should be able to load bidi PSK");

    // Step 3: the peer's handler processes the "received" effect.
    let received = BidiChannelReceived {
        channel_id: channel_id.into(),
        parent_cmd_id,
        author_id: author.device_id,
        author_enc_pk: &author.enc_pk,
        peer_id: peer.device_id,
        peer_enc_key_id: peer.enc_key_id,
        label_id,
        encap: &peer_encap,
        psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
    };
    let peer_psk = peer
        .handler
        .bidi_channel_received(&mut peer.eng, &received)
        .expect("peer should be able to load bidi keys");

    assert_eq!(BidiChannelId::from(channel_id), author_psk.identity());
    assert_eq!(author_psk.identity(), peer_psk.identity());
    assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
}
/// Creates a unidirectional channel in which the author sends and the peer
/// receives.
///
/// Checks the PSK directions, that the channel ID matches the PSK identity,
/// and that both sides derive the same secret bytes.
pub fn test_create_send_only_uni_channel<T: TestImpl>() {
    let mut author = T::new();
    let mut peer = T::new();

    let label_id = LabelId::random(&mut Rng);
    let parent_cmd_id = Id::random(&mut Rng);
    let ctx = CommandContext::Action(ActionContext {
        name: "CreateUniSendOnlyChannel",
        head_id: parent_cmd_id,
    });

    // The author is the sender and the peer is the receiver.
    let (send_id, recv_id) = (author.device_id, peer.device_id);

    let channel = author
        .ffi
        .create_uni_channel(
            &ctx,
            &mut author.eng,
            parent_cmd_id,
            author.enc_key_id,
            peer.enc_pk.clone(),
            send_id,
            recv_id,
            label_id,
        )
        .expect("author should be able to create a uni channel");
    let AqcUniChannel {
        peer_encap,
        channel_id,
        author_secrets_id,
        psk_length_in_bytes,
    } = channel;

    let created = UniChannelCreated {
        channel_id: channel_id.into(),
        parent_cmd_id,
        author_id: author.device_id,
        send_id,
        recv_id,
        author_enc_key_id: author.enc_key_id,
        peer_enc_pk: &peer.enc_pk,
        label_id,
        author_secrets_id: author_secrets_id.into(),
        psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
    };
    let author_psk = author
        .handler
        .uni_channel_created(&mut author.eng, &created)
        .expect("author should be able to load encryption key");
    assert!(matches!(author_psk, UniPsk::SendOnly(_)));

    let received = UniChannelReceived {
        channel_id: channel_id.into(),
        parent_cmd_id,
        author_id: author.device_id,
        send_id,
        recv_id,
        author_enc_pk: &author.enc_pk,
        peer_enc_key_id: peer.enc_key_id,
        label_id,
        encap: &peer_encap,
        psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
    };
    let peer_psk = peer
        .handler
        .uni_channel_received(&mut peer.eng, &received)
        .expect("peer should be able to load decryption key");
    assert!(matches!(peer_psk, UniPsk::RecvOnly(_)));

    assert_eq!(UniChannelId::from(channel_id), author_psk.identity());
    assert_eq!(author_psk.identity(), peer_psk.identity());
    assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
}
/// Creates a unidirectional channel in which the author only receives and
/// the peer only sends.
///
/// Checks the PSK directions, that the channel ID matches the PSK identity,
/// and that both sides derive the same secret bytes.
pub fn test_create_recv_only_uni_channel<T: TestImpl>() {
    let mut author = T::new();
    let mut peer = T::new();

    let label_id = LabelId::random(&mut Rng);
    let parent_cmd_id = Id::random(&mut Rng);
    let ctx = CommandContext::Action(ActionContext {
        name: "CreateUniRecvOnlyChannel",
        head_id: parent_cmd_id,
    });

    // The peer is the sender and the author is the receiver. The FFI call
    // and both effects must use the same (sender, receiver) pair.
    // `test_create_send_only_uni_channel` passes (sender, receiver) in the
    // same argument slots.
    let (send_id, recv_id) = (peer.device_id, author.device_id);

    let AqcUniChannel {
        peer_encap,
        channel_id,
        author_secrets_id,
        psk_length_in_bytes,
    } = author
        .ffi
        .create_uni_channel(
            &ctx,
            &mut author.eng,
            parent_cmd_id,
            author.enc_key_id,
            peer.enc_pk.clone(),
            send_id,
            recv_id,
            label_id,
        )
        .expect("author should be able to create a uni channel");

    let author_psk = author
        .handler
        .uni_channel_created(
            &mut author.eng,
            &UniChannelCreated {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                send_id,
                recv_id,
                author_enc_key_id: author.enc_key_id,
                peer_enc_pk: &peer.enc_pk,
                label_id,
                author_secrets_id: author_secrets_id.into(),
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            },
        )
        .expect("author should be able to load decryption key");
    assert!(matches!(author_psk, UniPsk::RecvOnly(_)));

    let peer_psk = peer
        .handler
        .uni_channel_received(
            &mut peer.eng,
            &UniChannelReceived {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                send_id,
                recv_id,
                author_enc_pk: &author.enc_pk,
                peer_enc_key_id: peer.enc_key_id,
                label_id,
                encap: &peer_encap,
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            },
        )
        .expect("peer should be able to load encryption key");
    assert!(matches!(peer_psk, UniPsk::SendOnly(_)));

    assert_eq!(UniChannelId::from(channel_id), author_psk.identity());
    assert_eq!(author_psk.identity(), peer_psk.identity());
    assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
}
/// Creates 50 bidirectional channels with one peer under the same label,
/// each with its own random parent command ID, then loads them in random
/// order.
///
/// Checks that:
/// - every channel ID is unique;
/// - each channel ID matches the author's PSK identity;
/// - both sides derive the same PSK.
pub fn test_create_multi_bidi_channels_same_label<T: TestImpl>() {
    let mut author = T::new();
    let mut peer = T::new();
    let label_id = LabelId::random(&mut Rng);
    // Build the author/peer effects for each channel. The peer encaps are
    // collected separately because `BidiChannelReceived` only borrows its
    // encap; they are linked in after `unzip`.
    let (mut expect, peer_encaps): (Vec<_>, Vec<_>) = (0..50)
        .map(|_| {
            let parent_cmd_id = Id::random(&mut Rng);
            let ctx = CommandContext::Action(ActionContext {
                name: "CreateBidiChannel",
                head_id: parent_cmd_id,
            });
            let AqcBidiChannel {
                peer_encap,
                channel_id,
                author_secrets_id,
                psk_length_in_bytes,
            } = author
                .ffi
                .create_bidi_channel(
                    &ctx,
                    &mut author.eng,
                    parent_cmd_id,
                    author.enc_key_id,
                    author.device_id,
                    peer.enc_pk.clone(),
                    peer.device_id,
                    label_id,
                )
                .expect("author should be able to create a bidi channel");
            let created = BidiChannelCreated {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                author_enc_key_id: author.enc_key_id,
                peer_id: peer.device_id,
                peer_enc_pk: &peer.enc_pk,
                label_id,
                author_secrets_id: author_secrets_id.into(),
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            };
            let received = BidiChannelReceived {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                author_enc_pk: &author.enc_pk,
                peer_id: peer.device_id,
                peer_enc_key_id: peer.enc_key_id,
                label_id,
                // Placeholder; replaced below once `peer_encaps` exists.
                encap: &[],
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            };
            ((created, received), peer_encap)
        })
        .unzip();
    for ((_, received), encap) in expect.iter_mut().zip(&peer_encaps) {
        received.encap = encap;
    }
    // Load the channels in an order unrelated to their creation order.
    shuffle(&mut expect);
    assert_unique(expect.iter().map(|(created, _)| created.channel_id));
    for (created, received) in &expect {
        let author_psk = author
            .handler
            .bidi_channel_created(&mut author.eng, created)
            .expect("author should be able to load bidi PSK");
        let peer_psk = peer
            .handler
            .bidi_channel_received(&mut peer.eng, received)
            .expect("peer should be able to load bidi keys");
        assert_eq!(created.channel_id, author_psk.identity());
        assert_eq!(author_psk.identity(), peer_psk.identity());
        assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
    }
}
/// Creates 50 bidirectional channels with one peer, all under the same
/// parent command ID but each with a fresh random label, then loads them in
/// random order.
///
/// Checks that:
/// - every channel ID is unique;
/// - each channel ID matches the author's PSK identity;
/// - both sides derive the same PSK.
pub fn test_create_multi_bidi_channels_same_parent_cmd_id<T: TestImpl>() {
    let mut author = T::new();
    let mut peer = T::new();

    let parent_cmd_id = Id::random(&mut Rng);
    let ctx = CommandContext::Action(ActionContext {
        name: "CreateBidiChannel",
        head_id: parent_cmd_id,
    });

    // `effects[k]` is the (created, received) pair for channel `k`, and
    // `encaps[k]` is the peer encap that `effects[k].1` must point at.
    let mut effects = Vec::new();
    let mut encaps = Vec::new();
    for _ in 0..50 {
        let label_id = LabelId::random(&mut Rng);
        let channel = author
            .ffi
            .create_bidi_channel(
                &ctx,
                &mut author.eng,
                parent_cmd_id,
                author.enc_key_id,
                author.device_id,
                peer.enc_pk.clone(),
                peer.device_id,
                label_id,
            )
            .expect("author should be able to create a bidi channel");
        let AqcBidiChannel {
            peer_encap,
            channel_id,
            author_secrets_id,
            psk_length_in_bytes,
        } = channel;

        let created = BidiChannelCreated {
            channel_id: channel_id.into(),
            parent_cmd_id,
            author_id: author.device_id,
            author_enc_key_id: author.enc_key_id,
            peer_id: peer.device_id,
            peer_enc_pk: &peer.enc_pk,
            label_id,
            author_secrets_id: author_secrets_id.into(),
            psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
        };
        let received = BidiChannelReceived {
            channel_id: channel_id.into(),
            parent_cmd_id,
            author_id: author.device_id,
            author_enc_pk: &author.enc_pk,
            peer_id: peer.device_id,
            peer_enc_key_id: peer.enc_key_id,
            label_id,
            // Placeholder; linked to `encaps` once the loop is done.
            encap: &[],
            psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
        };
        effects.push((created, received));
        encaps.push(peer_encap);
    }

    for ((_, received), encap) in effects.iter_mut().zip(&encaps) {
        received.encap = encap;
    }

    shuffle(&mut effects);
    assert_unique(effects.iter().map(|(created, _)| created.channel_id));

    for (created, received) in &effects {
        let author_psk = author
            .handler
            .bidi_channel_created(&mut author.eng, created)
            .expect("author should be able to load bidi PSK");
        let peer_psk = peer
            .handler
            .bidi_channel_received(&mut peer.eng, received)
            .expect("peer should be able to load bidi keys");
        assert_eq!(created.channel_id, author_psk.identity());
        assert_eq!(author_psk.identity(), peer_psk.identity());
        assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
    }
}
/// Creates one bidirectional channel with each of 50 peers under the same
/// label, then loads them in random order (peers shuffled in lockstep).
///
/// Checks that:
/// - every channel ID is unique;
/// - each channel ID matches the author's PSK identity;
/// - each author/peer pair derives the same PSK.
pub fn test_create_multi_bidi_channels_same_label_multi_peers<T: TestImpl>() {
    let mut author = T::new();
    let mut peers = (0..50).map(|_| T::new()).collect::<Vec<_>>();
    // Copies of the peers' public keys. The effects borrow these rather than
    // `peers`, so that `peers` can be swapped and moved later.
    let peer_enc_pks = peers
        .iter()
        .map(|peer| peer.enc_pk.clone())
        .collect::<Vec<_>>();
    let label_id = LabelId::random(&mut Rng);
    let (mut expect, peer_encaps): (Vec<_>, Vec<_>) = peers
        .iter()
        .enumerate()
        .map(|(i, peer)| {
            let parent_cmd_id = Id::random(&mut Rng);
            let ctx = CommandContext::Action(ActionContext {
                name: "CreateBidiChannel",
                head_id: parent_cmd_id,
            });
            let AqcBidiChannel {
                peer_encap,
                channel_id,
                author_secrets_id,
                psk_length_in_bytes,
            } = author
                .ffi
                .create_bidi_channel(
                    &ctx,
                    &mut author.eng,
                    parent_cmd_id,
                    author.enc_key_id,
                    author.device_id,
                    peer.enc_pk.clone(),
                    peer.device_id,
                    label_id,
                )
                .expect("author should be able to create a bidi channel");
            let created = BidiChannelCreated {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                author_enc_key_id: author.enc_key_id,
                peer_id: peer.device_id,
                peer_enc_pk: &peer_enc_pks[i],
                label_id,
                author_secrets_id: author_secrets_id.into(),
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            };
            let received = BidiChannelReceived {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                author_enc_pk: &author.enc_pk,
                peer_id: peer.device_id,
                peer_enc_key_id: peer.enc_key_id,
                label_id,
                // Placeholder; replaced below once `peer_encaps` exists.
                encap: &[],
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            };
            ((created, received), peer_encap)
        })
        .unzip();
    for ((_, received), encap) in expect.iter_mut().zip(&peer_encaps) {
        received.encap = encap;
    }
    // Shuffle the effects and the peers with the same permutation, so that
    // `expect[k]` still belongs to `peers[k]`.
    shuffle_by(expect.len(), |i, j| {
        expect.swap(i, j);
        peers.swap(i, j);
    });
    assert_unique(expect.iter().map(|(created, _)| created.channel_id));
    for ((created, received), mut peer) in expect.iter().zip(peers) {
        let author_psk = author
            .handler
            .bidi_channel_created(&mut author.eng, created)
            .expect("author should be able to load bidi PSK");
        let peer_psk = peer
            .handler
            .bidi_channel_received(&mut peer.eng, received)
            .expect("peer should be able to load bidi keys");
        assert_eq!(created.channel_id, author_psk.identity());
        assert_eq!(author_psk.identity(), peer_psk.identity());
        assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
    }
}
#[cfg(test)]
mod tests {
    use alloc::collections::BTreeSet;

    use super::*;

    /// Repeatedly shuffling a three-element array should eventually produce
    /// all six permutations. Give up after 1000 shuffles.
    #[test]
    fn test_shuffle() {
        const MAX_ITERS: usize = 1000;

        let (r, g, b) = (1, 2, 3);
        let mut unseen = BTreeSet::from([
            [r, g, b],
            [r, b, g],
            [g, r, b],
            [g, b, r],
            [b, r, g],
            [b, g, r],
        ]);

        let mut data = [r, g, b];
        let mut iters = 0;
        while !unseen.is_empty() {
            iters += 1;
            if iters > MAX_ITERS {
                panic!("too many iters");
            }
            shuffle(&mut data);
            unseen.remove(&data);
        }
    }
}