#![cfg(any(test, feature = "testing"))]
#![cfg_attr(docsrs, doc(cfg(feature = "testing")))]
#![allow(clippy::arithmetic_side_effects)]
#![allow(clippy::panic)]
#![allow(clippy::unwrap_used)]
extern crate alloc;
use alloc::{collections::BTreeSet, sync::Arc, vec::Vec};
use core::{
    cell::UnsafeCell,
    mem::{self, MaybeUninit},
    ops::Deref,
    result::Result,
};
use aranya_crypto::{
    self,
    aqc::{BidiChannelId, UniChannelId},
    csprng::Random,
    engine::WrappedKey,
    keystore::{memstore, Entry, Occupied, Vacant},
    CipherSuite, DeviceId, EncryptionKey, EncryptionKeyId, EncryptionPublicKey, Engine, Id,
    IdentityKey, KeyStore, Rng,
};
use aranya_policy_vm::{ActionContext, CommandContext};
use spin::Mutex;
use crate::{
    ffi::{AqcBidiChannel, AqcUniChannel, Ffi},
    handler::{
        BidiChannelCreated, BidiChannelReceived, Handler, UniChannelCreated, UniChannelReceived,
        UniPsk,
    },
    shared::LabelId,
};
/// Serializes `pk` with `postcard` and returns the encoded bytes.
///
/// # Panics
///
/// Panics if `postcard` fails to encode the key.
fn encode_enc_pk<CS>(pk: &EncryptionPublicKey<CS>) -> Vec<u8>
where
    CS: CipherSuite,
{
    let encoded = postcard::to_allocvec(pk);
    encoded.expect("should be able to encode an `EncryptionPublicKey`")
}
/// Randomly permutes `data` in place.
///
/// Delegates the index selection to [`shuffle_by`] and performs
/// each requested swap directly on the slice.
fn shuffle<T>(data: &mut [T]) {
    let len = data.len();
    shuffle_by(len, |a, b| data.swap(a, b));
}
/// Drives a Fisher-Yates shuffle over `n` logical elements.
///
/// Walks the indices from `n - 1` down to `0` and, for each index
/// `i`, invokes `swap(i, j)` where `j` is drawn from `0..=i` via
/// [`rand_intn`]. The caller decides what "swapping" means, which
/// lets several parallel collections be permuted in lockstep.
fn shuffle_by<F>(n: usize, mut swap: F)
where
    F: FnMut(usize, usize),
{
    (0..n).rev().for_each(|i| {
        let pick = rand_intn(i + 1);
        swap(i, pick);
    });
}
/// Asserts that `iter` never yields the same item twice.
///
/// # Panics
///
/// Panics on the first item that compares equal to an item seen
/// earlier. Because of `#[track_caller]` the failure is reported
/// at the call site.
#[track_caller]
fn assert_unique<T: Ord>(iter: impl IntoIterator<Item = T>) {
    // `BTreeSet::insert` returns `false` when the value is already
    // present, which is exactly the "duplicate" signal we need.
    let mut uniq = BTreeSet::<T>::new();
    for v in iter.into_iter() {
        assert!(uniq.insert(v));
    }
}
/// Returns a uniformly distributed random integer in `0..max`.
///
/// Uses a multiply-then-reject scheme (in the style of Lemire's
/// method): a random word is multiplied by `max`, the high half of
/// the double-width product is the candidate, and candidates whose
/// low half falls below `2^BITS mod max` are rejected to remove
/// bias.
///
/// # Panics
///
/// Panics if `max` is zero (the remainder below divides by `max`).
fn rand_intn(max: usize) -> usize {
    debug_assert!(max < usize::MAX);
    loop {
        let sample = usize::random(&mut Rng);
        let (hi, lo) = widening_mul(sample, max);
        // `(2^BITS - max) % max`, i.e. `2^BITS mod max`.
        let thresh = max.wrapping_neg() % max;
        if lo < thresh {
            // Biased region: draw again.
            continue;
        }
        debug_assert!(hi < max);
        return hi;
    }
}
/// Computes the full double-width product `x * y`, returned as
/// `(high word, low word)`.
///
/// The operands are split into half-width limbs and combined with
/// schoolbook multiplication so that no intermediate value needs
/// more than `usize::BITS` bits.
#[inline(always)]
const fn widening_mul(x: usize, y: usize) -> (usize, usize) {
    const HALF: u32 = usize::BITS / 2;
    const LOW: usize = (1 << HALF) - 1;

    let (x_hi, x_lo) = (x >> HALF, x & LOW);
    let (y_hi, y_lo) = (y >> HALF, y & LOW);

    // Partial products, with carries folded in from the bottom up.
    let lo_lo = x_lo * y_lo;
    let mid = (x_hi * y_lo).wrapping_add(lo_lo >> HALF);
    let cross = (x_lo * y_hi).wrapping_add(mid & LOW);
    let carry = mid >> HALF;

    (
        (x_hi * y_hi).wrapping_add(carry).wrapping_add(cross >> HALF),
        x.wrapping_mul(y),
    )
}
/// A [`KeyStore`] backed by an in-memory [`memstore::MemStore`].
///
/// The state lives behind an [`Arc`], so every clone refers to the
/// same underlying store.
#[derive(Clone, Default)]
pub struct MemStore(Arc<MemStoreInner>);

impl MemStore {
    /// Creates a `MemStore` wrapping a default-constructed inner
    /// store.
    pub fn new() -> Self {
        Self(Arc::default())
    }
}
impl KeyStore for MemStore {
    type Error = memstore::Error;
    type Vacant<'a, T: WrappedKey> = VacantEntry<'a, T>;
    type Occupied<'a, T: WrappedKey> = OccupiedEntry<'a, T>;

    /// Looks up the slot for `id`, returning a vacant or occupied
    /// entry. The returned entry keeps the store locked until it is
    /// consumed or dropped.
    fn entry<T: WrappedKey>(&mut self, id: Id) -> Result<Entry<'_, Self, T>, Self::Error> {
        let guarded = self.0.entry(id)?;
        Ok(match guarded {
            GuardedEntry::Occupied(guard) => Entry::Occupied(OccupiedEntry(guard)),
            GuardedEntry::Vacant(guard) => Entry::Vacant(VacantEntry(guard)),
        })
    }

    /// Returns the key stored under `id`, or `None` if the slot is
    /// vacant.
    fn get<T: WrappedKey>(&self, id: Id) -> Result<Option<T>, Self::Error> {
        let guard = match self.0.entry::<T>(id)? {
            GuardedEntry::Occupied(guard) => guard,
            GuardedEntry::Vacant(_) => return Ok(None),
        };
        let key = guard.get()?;
        Ok(Some(key))
    }
}
/// A vacant slot in a [`MemStore`].
///
/// Wraps the underlying [`memstore::VacantEntry`] in a [`Guard`],
/// so the store stays locked for as long as this entry exists.
pub struct VacantEntry<'a, T>(Guard<'a, memstore::VacantEntry<'a, T>>);

impl<T: WrappedKey> Vacant<T> for VacantEntry<'_, T> {
    type Error = memstore::Error;

    /// Stores `key` in the slot, consuming the entry.
    fn insert(self, key: T) -> Result<(), Self::Error> {
        let Self(guard) = self;
        guard.with_data(|slot| slot.insert(key))
    }
}
/// An occupied slot in a [`MemStore`].
///
/// Wraps the underlying [`memstore::OccupiedEntry`] in a [`Guard`],
/// so the store stays locked for as long as this entry exists.
pub struct OccupiedEntry<'a, T>(Guard<'a, memstore::OccupiedEntry<'a, T>>);

impl<T: WrappedKey> Occupied<T> for OccupiedEntry<'_, T> {
    type Error = memstore::Error;

    /// Reads the stored key without removing it.
    fn get(&self) -> Result<T, Self::Error> {
        let Self(guard) = self;
        guard.get()
    }

    /// Removes the key from the slot and returns it, consuming the
    /// entry.
    fn remove(self) -> Result<T, Self::Error> {
        let Self(guard) = self;
        guard.with_data(|slot| slot.remove())
    }
}
/// Shared state behind a [`MemStore`]: the real store plus the lock
/// that serializes access to it.
#[derive(Default)]
struct MemStoreInner {
    /// Held (in the "forgotten guard" sense) for as long as a
    /// [`Guard`] created by [`MemStoreInner::entry`] is alive.
    mutex: Mutex<()>,
    /// The underlying store. Only touched while `mutex` is locked.
    store: UnsafeCell<memstore::MemStore>,
}

impl MemStoreInner {
    /// Locks the store and returns the entry for `id`.
    ///
    /// On success the lock is handed over to the returned
    /// [`Guard`], which releases it when dropped. On failure the
    /// lock is released before the error is returned; otherwise no
    /// `Guard` would exist to unlock it and every later call would
    /// spin forever.
    fn entry<T: WrappedKey>(&self, id: Id) -> Result<GuardedEntry<'_, T>, memstore::Error> {
        // Take the lock and deliberately leak the guard so that the
        // mutex stays locked past the end of this function.
        mem::forget(self.mutex.lock());

        // SAFETY: `self.mutex` is locked by this call, and `store`
        // is only ever accessed while the mutex is held, so this is
        // the only live reference.
        let store = unsafe { &mut *self.store.get() };

        let result = store.entry(id);
        if result.is_err() {
            // SAFETY: this call acquired the lock above and has not
            // created a `Guard` that would release it.
            unsafe { self.mutex.force_unlock() };
        }
        let entry = match result? {
            Entry::Vacant(entry) => GuardedEntry::Vacant(Guard::new(&self.mutex, entry)),
            Entry::Occupied(entry) => GuardedEntry::Occupied(Guard::new(&self.mutex, entry)),
        };
        Ok(entry)
    }
}
/// The result of looking up a slot in [`MemStoreInner`]: either a
/// vacant or an occupied `memstore` entry, each wrapped in a
/// [`Guard`] tied to the store's mutex.
enum GuardedEntry<'a, T> {
    /// No key is stored in the slot.
    Vacant(Guard<'a, memstore::VacantEntry<'a, T>>),
    /// A key is stored in the slot.
    Occupied(Guard<'a, memstore::OccupiedEntry<'a, T>>),
}
/// Couples a value with a mutex that is already locked.
///
/// Dropping the guard force-unlocks the mutex. The `Drop` impl does
/// not drop `data`; the payload is only moved out (and subsequently
/// dropped by its new owner) through [`Guard::with_data`].
#[clippy::has_significant_drop]
struct Guard<'a, T> {
    /// The locked mutex to release on drop.
    mutex: &'a Mutex<()>,
    /// The protected payload; initialized until `with_data` takes
    /// it.
    data: MaybeUninit<T>,
}

impl<'a, T> Guard<'a, T> {
    /// Wraps `data`, taking over responsibility for unlocking
    /// `mutex`.
    const fn new(mutex: &'a Mutex<()>, data: T) -> Self {
        Self {
            data: MaybeUninit::new(data),
            mutex,
        }
    }

    /// Moves the payload into `f` and returns its result.
    ///
    /// The guard itself is dropped only after `f` returns, so the
    /// mutex remains locked while `f` runs.
    fn with_data<F, R>(mut self, f: F) -> R
    where
        F: FnOnce(T) -> R,
    {
        let taken = mem::replace(&mut self.data, MaybeUninit::uninit());
        // SAFETY: `data` is initialized in `new` and only ever
        // replaced here, and this method consumes `self`, so the
        // value we just took out is still initialized.
        let value = unsafe { taken.assume_init() };
        f(value)
    }
}
impl<T> Drop for Guard<'_, T> {
    /// Releases the mutex that was locked on behalf of this guard.
    fn drop(&mut self) {
        let Self { mutex, .. } = self;
        // SAFETY: a `Guard` is only constructed after the mutex has
        // been locked and its lock guard forgotten, so the mutex is
        // locked and this `Guard` is what is responsible for
        // unlocking it.
        unsafe { mutex.force_unlock() }
    }
}
impl<T> Deref for Guard<'_, T> {
    type Target = T;

    /// Borrows the guarded payload.
    fn deref(&self) -> &Self::Target {
        // SAFETY: `data` is initialized in `Guard::new` and is only
        // taken out by `with_data`, which consumes the guard, so it
        // is still initialized whenever `&self` is reachable.
        let payload: &T = unsafe { self.data.assume_init_ref() };
        payload
    }
}
/// Describes one configuration under test: which [`Engine`] and
/// which [`KeyStore`] the generic tests in this module should use.
pub trait TestImpl: Sized {
    /// The cryptography engine used by each [`Device`].
        type Engine: Engine;
    /// The key store used by each [`Device`]. It is cloned so that
    /// a device's `Ffi` and `Handler` each receive a store.
        type Store: KeyStore + Clone;
    /// Creates a new [`Device`] for this configuration.
        fn new() -> Device<Self>;
}
/// A simulated device used by the tests in this module.
pub struct Device<T: TestImpl> {
    /// The device's cryptography engine.
    eng: T::Engine,
    /// The ID of the device's freshly generated `IdentityKey`.
    device_id: DeviceId,
    /// The ID of the device's `EncryptionKey`.
    enc_key_id: EncryptionKeyId,
    /// The `postcard` encoding of the device's public encryption
    /// key.
    enc_pk: Vec<u8>,
    /// FFI entry points, backed by a clone of the device's store.
    ffi: Ffi<T::Store>,
    /// Effect handler, backed by the device's store.
    handler: Handler<T::Store>,
}

impl<T: TestImpl> Device<T> {
    /// Creates a device with a new identity and a new encryption
    /// key, wrapping the encryption key with `eng` and inserting it
    /// into `store` under its key ID.
    ///
    /// # Panics
    ///
    /// Panics if any key ID or public key cannot be derived, or if
    /// the encryption key cannot be wrapped or inserted.
    pub fn new(mut eng: T::Engine, mut store: T::Store) -> Self {
        let device_id = IdentityKey::<<T::Engine as Engine>::CS>::new(&mut eng)
            .id()
            .expect("device ID should be valid");

        let enc_sk = EncryptionKey::new(&mut eng);
        let enc_key_id = enc_sk.id().expect("encryption key ID should be valid");
        let enc_pub = enc_sk
            .public()
            .expect("encryption public key should be valid");
        let enc_pk = encode_enc_pk(&enc_pub);

        // Persist the secret half so `Ffi` and `Handler` can load
        // it from the store by ID.
        let wrapped = eng
            .wrap(enc_sk)
            .expect("should be able to wrap `EncryptionKey`");
        store
            .try_insert(enc_key_id.into(), wrapped)
            .expect("should be able to insert wrapped `EncryptionKey`");

        let ffi = Ffi::new(store.clone());
        let handler = Handler::new(device_id, store);
        Self {
            eng,
            device_id,
            enc_key_id,
            enc_pk,
            ffi,
            handler,
        }
    }
}
/// Instantiates every generic test in this module for one
/// [`TestImpl`].
///
/// `test_all!(name, Impl)` expands to a module called `name` that
/// contains one `#[test]` function per listed test, each of which
/// calls `$crate::testing::<test>::<Impl>()`.
#[macro_export]
macro_rules! test_all {
    ($name:ident, $impl:ty) => {
        mod $name {
            #[allow(unused_imports)]
            use super::*;
            // Wraps a single generic test function in a `#[test]`
            // of the same name, specialized to `$impl`.
            macro_rules! test {
                ($test:ident) => {
                    #[test]
                    fn $test() {
                        $crate::testing::$test::<$impl>();
                    }
                };
            }
            test!(test_create_bidi_channel);
            test!(test_create_multi_bidi_channels_same_label);
            test!(test_create_multi_bidi_channels_same_parent_cmd_id);
            test!(test_create_multi_bidi_channels_same_label_multi_peers);
            test!(test_create_send_only_uni_channel);
            test!(test_create_recv_only_uni_channel);
        }
    };
}
// Re-export so the macro can also be reached by path from this module.
pub use test_all;
/// Creates a single bidirectional channel between two devices and
/// checks that both sides end up with the same PSK.
///
/// The author creates the channel through its `Ffi`, then loads its
/// PSK from a `BidiChannelCreated` event; the peer loads its PSK
/// from the matching `BidiChannelReceived` event. The test asserts
/// that the PSK identity equals the channel ID and that both PSKs
/// have identical identities and secret bytes.
pub fn test_create_bidi_channel<T: TestImpl>() {
    let mut author = T::new();
    let mut peer = T::new();

    let label_id = LabelId::random(&mut Rng);
    let parent_cmd_id = Id::random(&mut Rng);
    let action = ActionContext {
        name: "CreateBidiChannel",
        head_id: parent_cmd_id,
    };
    let ctx = CommandContext::Action(action);

    // Author side: create the channel.
    let channel = author
        .ffi
        .create_bidi_channel(
            &ctx,
            &mut author.eng,
            parent_cmd_id,
            author.enc_key_id,
            author.device_id,
            peer.enc_pk.clone(),
            peer.device_id,
            label_id,
        )
        .expect("author should be able to create a bidi channel");
    let AqcBidiChannel {
        peer_encap,
        channel_id,
        author_secrets_id,
        psk_length_in_bytes,
    } = channel;

    // Author side: load the PSK from the "created" event.
    let created = BidiChannelCreated {
        channel_id: channel_id.into(),
        parent_cmd_id,
        author_id: author.device_id,
        author_enc_key_id: author.enc_key_id,
        peer_id: peer.device_id,
        peer_enc_pk: &peer.enc_pk,
        label_id,
        author_secrets_id: author_secrets_id.into(),
        psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
    };
    let author_psk = author
        .handler
        .bidi_channel_created(&mut author.eng, &created)
        .expect("author should be able to load bidi PSK");

    // Peer side: load the PSK from the "received" event.
    let received = BidiChannelReceived {
        channel_id: channel_id.into(),
        parent_cmd_id,
        author_id: author.device_id,
        author_enc_pk: &author.enc_pk,
        peer_id: peer.device_id,
        peer_enc_key_id: peer.enc_key_id,
        label_id,
        encap: &peer_encap,
        psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
    };
    let peer_psk = peer
        .handler
        .bidi_channel_received(&mut peer.eng, &received)
        .expect("peer should be able to load bidi keys");

    assert_eq!(BidiChannelId::from(channel_id), author_psk.identity());
    assert_eq!(author_psk.identity(), peer_psk.identity());
    assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
}
/// Creates a unidirectional channel in which the author is the
/// sender and the peer is the receiver, then checks that both sides
/// derive the same PSK.
///
/// The author must obtain a `UniPsk::SendOnly` and the peer a
/// `UniPsk::RecvOnly`; their identities must equal the channel ID
/// and their secret bytes must match.
pub fn test_create_send_only_uni_channel<T: TestImpl>() {
    let mut author = T::new();
    let mut peer = T::new();

    let label_id = LabelId::random(&mut Rng);
    let parent_cmd_id = Id::random(&mut Rng);
    let action = ActionContext {
        name: "CreateUniSendOnlyChannel",
        head_id: parent_cmd_id,
    };
    let ctx = CommandContext::Action(action);

    // Author side: create the channel with author -> peer direction.
    let channel = author
        .ffi
        .create_uni_channel(
            &ctx,
            &mut author.eng,
            parent_cmd_id,
            author.enc_key_id,
            peer.enc_pk.clone(),
            author.device_id,
            peer.device_id,
            label_id,
        )
        .expect("author should be able to create a uni channel");
    let AqcUniChannel {
        peer_encap,
        channel_id,
        author_secrets_id,
        psk_length_in_bytes,
    } = channel;

    // Author side: load the PSK from the "created" event.
    let created = UniChannelCreated {
        channel_id: channel_id.into(),
        parent_cmd_id,
        author_id: author.device_id,
        send_id: author.device_id,
        recv_id: peer.device_id,
        author_enc_key_id: author.enc_key_id,
        peer_enc_pk: &peer.enc_pk,
        label_id,
        author_secrets_id: author_secrets_id.into(),
        psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
    };
    let author_psk = author
        .handler
        .uni_channel_created(&mut author.eng, &created)
        .expect("author should be able to load encryption key");
    assert!(matches!(author_psk, UniPsk::SendOnly(_)));

    // Peer side: load the PSK from the "received" event.
    let received = UniChannelReceived {
        channel_id: channel_id.into(),
        parent_cmd_id,
        author_id: author.device_id,
        send_id: author.device_id,
        recv_id: peer.device_id,
        author_enc_pk: &author.enc_pk,
        peer_enc_key_id: peer.enc_key_id,
        label_id,
        encap: &peer_encap,
        psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
    };
    let peer_psk = peer
        .handler
        .uni_channel_received(&mut peer.eng, &received)
        .expect("peer should be able to load decryption key");
    assert!(matches!(peer_psk, UniPsk::RecvOnly(_)));

    assert_eq!(UniChannelId::from(channel_id), author_psk.identity());
    assert_eq!(author_psk.identity(), peer_psk.identity());
    assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
}
/// Creates a unidirectional channel in which the author is the
/// receiver and the peer is the sender, then checks that both sides
/// derive the same PSK.
///
/// The author must obtain a `UniPsk::RecvOnly` and the peer a
/// `UniPsk::SendOnly`; their identities must equal the channel ID
/// and their secret bytes must match.
pub fn test_create_recv_only_uni_channel<T: TestImpl>() {
    let mut author = T::new(); // receives only
    let mut peer = T::new(); // sends only

    let label_id = LabelId::random(&mut Rng);
    let parent_cmd_id = Id::random(&mut Rng);
    let ctx = CommandContext::Action(ActionContext {
        name: "CreateUniRecvOnlyChannel",
        head_id: parent_cmd_id,
    });

    // Author side: create the channel. The two device IDs are given
    // sender first, receiver second, so that they agree with the
    // `send_id`/`recv_id` fields of the events below: here the peer
    // sends and the author receives.
    let AqcUniChannel {
        peer_encap,
        channel_id,
        author_secrets_id,
        psk_length_in_bytes,
    } = author
        .ffi
        .create_uni_channel(
            &ctx,
            &mut author.eng,
            parent_cmd_id,
            author.enc_key_id,
            peer.enc_pk.clone(),
            peer.device_id,
            author.device_id,
            label_id,
        )
        .expect("author should be able to create a uni channel");

    // Author side: load the PSK from the "created" event.
    let author_psk = author
        .handler
        .uni_channel_created(
            &mut author.eng,
            &UniChannelCreated {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                send_id: peer.device_id,
                recv_id: author.device_id,
                author_enc_key_id: author.enc_key_id,
                peer_enc_pk: &peer.enc_pk,
                label_id,
                author_secrets_id: author_secrets_id.into(),
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            },
        )
        .expect("author should be able to load decryption key");
    assert!(matches!(author_psk, UniPsk::RecvOnly(_)));

    // Peer side: load the PSK from the "received" event.
    let peer_psk = peer
        .handler
        .uni_channel_received(
            &mut peer.eng,
            &UniChannelReceived {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                send_id: peer.device_id,
                recv_id: author.device_id,
                author_enc_pk: &author.enc_pk,
                peer_enc_key_id: peer.enc_key_id,
                label_id,
                encap: &peer_encap,
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            },
        )
        .expect("peer should be able to load encryption key");
    assert!(matches!(peer_psk, UniPsk::SendOnly(_)));

    assert_eq!(UniChannelId::from(channel_id), author_psk.identity());
    assert_eq!(author_psk.identity(), peer_psk.identity());
    assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
}
/// Creates 50 bidirectional channels between the same two devices,
/// all with the same label but each with its own parent command ID,
/// and checks that every channel yields matching PSKs.
///
/// The channels are created up front and then processed in a
/// shuffled order, so the handlers cannot rely on creation order.
/// Channel IDs must be unique, each author PSK identity must equal
/// its channel ID, and author and peer PSKs must agree.
pub fn test_create_multi_bidi_channels_same_label<T: TestImpl>() {
    let mut author = T::new();
    let mut peer = T::new();

    let label_id = LabelId::random(&mut Rng);

    let (mut expect, peer_encaps): (Vec<_>, Vec<_>) = (0..50)
        .map(|_| {
            let parent_cmd_id = Id::random(&mut Rng);
            let ctx = CommandContext::Action(ActionContext {
                name: "CreateBidiChannel",
                head_id: parent_cmd_id,
            });

            // Author side: create the channel.
            let AqcBidiChannel {
                peer_encap,
                channel_id,
                author_secrets_id,
                psk_length_in_bytes,
            } = author
                .ffi
                .create_bidi_channel(
                    &ctx,
                    &mut author.eng,
                    parent_cmd_id,
                    author.enc_key_id,
                    author.device_id,
                    peer.enc_pk.clone(),
                    peer.device_id,
                    label_id,
                )
                .expect("author should be able to create a bidi channel");

            let created = BidiChannelCreated {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                author_enc_key_id: author.enc_key_id,
                peer_id: peer.device_id,
                peer_enc_pk: &peer.enc_pk,
                label_id,
                author_secrets_id: author_secrets_id.into(),
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            };
            let received = BidiChannelReceived {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                author_enc_pk: &author.enc_pk,
                peer_id: peer.device_id,
                peer_enc_key_id: peer.enc_key_id,
                label_id,
                // Placeholder: the real encapsulation is owned by
                // `peer_encaps` and is borrowed in below.
                encap: &[],
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            };
            ((created, received), peer_encap)
        })
        .unzip();

    // Now that `peer_encaps` has a stable home, point each
    // "received" event at its encapsulation.
    for ((_, received), encap) in expect.iter_mut().zip(&peer_encaps) {
        received.encap = encap;
    }

    shuffle(&mut expect);
    assert_unique(expect.iter().map(|(created, _)| created.channel_id));

    for (created, received) in &expect {
        let author_psk = author
            .handler
            .bidi_channel_created(&mut author.eng, created)
            .expect("author should be able to load bidi PSK");
        let peer_psk = peer
            .handler
            .bidi_channel_received(&mut peer.eng, received)
            .expect("peer should be able to load bidi keys");

        // The PSK must belong to the channel it was loaded for, not
        // merely agree between the two sides.
        assert_eq!(created.channel_id, author_psk.identity());
        assert_eq!(author_psk.identity(), peer_psk.identity());
        assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
    }
}
/// Creates 50 bidirectional channels between the same two devices,
/// all sharing one parent command ID but each with its own random
/// label, and checks that every channel yields matching PSKs.
///
/// The channels are created up front and then processed in a
/// shuffled order. Channel IDs must be unique, each author PSK
/// identity must equal its channel ID, and author and peer PSKs
/// must agree.
pub fn test_create_multi_bidi_channels_same_parent_cmd_id<T: TestImpl>() {
    const CHANNELS: usize = 50;

    let mut author = T::new();
    let mut peer = T::new();

    let parent_cmd_id = Id::random(&mut Rng);
    let ctx = CommandContext::Action(ActionContext {
        name: "CreateBidiChannel",
        head_id: parent_cmd_id,
    });

    let mut expect = Vec::with_capacity(CHANNELS);
    let mut peer_encaps = Vec::with_capacity(CHANNELS);
    for _ in 0..CHANNELS {
        let label_id = LabelId::random(&mut Rng);

        // Author side: create the channel.
        let channel = author
            .ffi
            .create_bidi_channel(
                &ctx,
                &mut author.eng,
                parent_cmd_id,
                author.enc_key_id,
                author.device_id,
                peer.enc_pk.clone(),
                peer.device_id,
                label_id,
            )
            .expect("author should be able to create a bidi channel");
        let AqcBidiChannel {
            peer_encap,
            channel_id,
            author_secrets_id,
            psk_length_in_bytes,
        } = channel;

        let created = BidiChannelCreated {
            channel_id: channel_id.into(),
            parent_cmd_id,
            author_id: author.device_id,
            author_enc_key_id: author.enc_key_id,
            peer_id: peer.device_id,
            peer_enc_pk: &peer.enc_pk,
            label_id,
            author_secrets_id: author_secrets_id.into(),
            psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
        };
        let received = BidiChannelReceived {
            channel_id: channel_id.into(),
            parent_cmd_id,
            author_id: author.device_id,
            author_enc_pk: &author.enc_pk,
            peer_id: peer.device_id,
            peer_enc_key_id: peer.enc_key_id,
            label_id,
            // Placeholder: patched once `peer_encaps` is complete.
            encap: &[],
            psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
        };

        expect.push((created, received));
        peer_encaps.push(peer_encap);
    }

    // Point each "received" event at the encapsulation stored for
    // it in `peer_encaps`.
    for (pair, encap) in expect.iter_mut().zip(&peer_encaps) {
        let (_, received) = pair;
        received.encap = encap;
    }

    shuffle(&mut expect);
    assert_unique(expect.iter().map(|(created, _)| created.channel_id));

    for pair in &expect {
        let (created, received) = pair;

        let author_psk = author
            .handler
            .bidi_channel_created(&mut author.eng, created)
            .expect("author should be able to load bidi PSK");
        let peer_psk = peer
            .handler
            .bidi_channel_received(&mut peer.eng, received)
            .expect("peer should be able to load bidi keys");

        assert_eq!(created.channel_id, author_psk.identity());
        assert_eq!(author_psk.identity(), peer_psk.identity());
        assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
    }
}
/// Creates one bidirectional channel from a single author to each
/// of 50 different peers, all with the same label, and checks that
/// every channel yields matching PSKs.
///
/// The channels are created up front and then processed in a
/// shuffled order, with the peers shuffled in lockstep so each
/// channel is still handled by the peer it was created for. Channel
/// IDs must be unique, each author PSK identity must equal its
/// channel ID, and author and peer PSKs must agree.
pub fn test_create_multi_bidi_channels_same_label_multi_peers<T: TestImpl>() {
    let mut author = T::new();
    let mut peers = (0..50).map(|_| T::new()).collect::<Vec<_>>();

    // Copy the public keys out so the events below do not borrow
    // from `peers`, which is reordered and consumed later.
    let peer_enc_pks = peers
        .iter()
        .map(|peer| peer.enc_pk.clone())
        .collect::<Vec<_>>();

    let label_id = LabelId::random(&mut Rng);

    let (mut expect, peer_encaps): (Vec<_>, Vec<_>) = peers
        .iter()
        .enumerate()
        .map(|(i, peer)| {
            let parent_cmd_id = Id::random(&mut Rng);
            let ctx = CommandContext::Action(ActionContext {
                name: "CreateBidiChannel",
                head_id: parent_cmd_id,
            });

            // Author side: create the channel.
            let AqcBidiChannel {
                peer_encap,
                channel_id,
                author_secrets_id,
                psk_length_in_bytes,
            } = author
                .ffi
                .create_bidi_channel(
                    &ctx,
                    &mut author.eng,
                    parent_cmd_id,
                    author.enc_key_id,
                    author.device_id,
                    peer.enc_pk.clone(),
                    peer.device_id,
                    label_id,
                )
                .expect("author should be able to create a bidi channel");

            let created = BidiChannelCreated {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                author_enc_key_id: author.enc_key_id,
                peer_id: peer.device_id,
                peer_enc_pk: &peer_enc_pks[i],
                label_id,
                author_secrets_id: author_secrets_id.into(),
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            };
            let received = BidiChannelReceived {
                channel_id: channel_id.into(),
                parent_cmd_id,
                author_id: author.device_id,
                author_enc_pk: &author.enc_pk,
                peer_id: peer.device_id,
                peer_enc_key_id: peer.enc_key_id,
                label_id,
                // Placeholder: the real encapsulation is owned by
                // `peer_encaps` and is borrowed in below.
                encap: &[],
                psk_length_in_bytes: psk_length_in_bytes.try_into().unwrap(),
            };
            ((created, received), peer_encap)
        })
        .unzip();

    // Now that `peer_encaps` has a stable home, point each
    // "received" event at its encapsulation.
    for ((_, received), encap) in expect.iter_mut().zip(&peer_encaps) {
        received.encap = encap;
    }

    // Apply the same permutation to both vectors so that
    // `expect[k]` still belongs to `peers[k]`.
    shuffle_by(expect.len(), |i, j| {
        expect.swap(i, j);
        peers.swap(i, j);
    });
    assert_unique(expect.iter().map(|(created, _)| created.channel_id));

    for ((created, received), mut peer) in expect.iter().zip(peers) {
        let author_psk = author
            .handler
            .bidi_channel_created(&mut author.eng, created)
            .expect("author should be able to load bidi PSK");
        let peer_psk = peer
            .handler
            .bidi_channel_received(&mut peer.eng, received)
            .expect("peer should be able to load bidi keys");

        // The PSK must belong to the channel it was loaded for, not
        // merely agree between the two sides.
        assert_eq!(created.channel_id, author_psk.identity());
        assert_eq!(author_psk.identity(), peer_psk.identity());
        assert_eq!(author_psk.raw_secret_bytes(), peer_psk.raw_secret_bytes());
    }
}
#[cfg(test)]
mod tests {
    use alloc::collections::BTreeSet;

    use super::*;

    /// Checks that `shuffle` can produce every permutation of a
    /// three-element array.
    ///
    /// Keeps shuffling the same array and crossing off each
    /// arrangement as it appears; fails if some arrangement is
    /// still outstanding after more than 1000 shuffles.
    #[test]
    fn test_shuffle() {
        const R: i32 = 1;
        const G: i32 = 2;
        const B: i32 = 3;

        let mut remaining = BTreeSet::from_iter([
            [R, G, B],
            [R, B, G],
            [G, R, B],
            [G, B, R],
            [B, R, G],
            [B, G, R],
        ]);

        let mut data = [R, G, B];
        let mut iters = 0;
        loop {
            if remaining.is_empty() {
                break;
            }
            shuffle(&mut data);
            remaining.remove(&data);
            iters += 1;
            // Bound the loop so a broken shuffle fails instead of
            // hanging.
            assert!(iters <= 1000, "too many iters");
        }
    }
}