aranya_daemon_api/
service.rs

1#![allow(clippy::disallowed_macros)] // tarpc uses unreachable
2
3use core::{
4    borrow::Borrow,
5    error, fmt,
6    hash::{Hash, Hasher},
7    net::SocketAddr,
8    ops::Deref,
9    time::Duration,
10};
11use std::collections::hash_map::{self, HashMap};
12
13use anyhow::bail;
14pub use aranya_crypto::aqc::CipherSuiteId;
15use aranya_crypto::{
16    aqc::{BidiPskId, UniPskId},
17    custom_id,
18    default::DefaultEngine,
19    id::IdError,
20    subtle::{Choice, ConstantTimeEq},
21    zeroize::{Zeroize, ZeroizeOnDrop},
22    EncryptionPublicKey, Engine, Id,
23};
24pub use aranya_policy_text::{text, Text};
25use aranya_util::Addr;
26use buggy::Bug;
27pub use semver::Version;
28use serde::{Deserialize, Serialize};
29use tracing::error;
30
31pub mod quic_sync;
32pub use quic_sync::*;
33
34/// CE = Crypto Engine
35pub type CE = DefaultEngine;
36/// CS = Cipher Suite
37pub type CS = <DefaultEngine as Engine>::CS;
38
39/// An error returned by the API.
40// TODO: enum?
41#[derive(Serialize, Deserialize, Debug)]
42pub struct Error(String);
43
44impl Error {
45    pub fn from_msg(err: &str) -> Self {
46        error!(?err);
47        Self(err.into())
48    }
49
50    pub fn from_err<E: error::Error>(err: E) -> Self {
51        error!(?err);
52        Self(format!("{err:?}"))
53    }
54}
55
56impl From<Bug> for Error {
57    fn from(err: Bug) -> Self {
58        error!(?err);
59        Self(format!("{err:?}"))
60    }
61}
62
63impl From<anyhow::Error> for Error {
64    fn from(err: anyhow::Error) -> Self {
65        error!(?err);
66        Self(format!("{err:?}"))
67    }
68}
69
70impl From<semver::Error> for Error {
71    fn from(err: semver::Error) -> Self {
72        error!(?err);
73        Self(format!("{err:?}"))
74    }
75}
76
77impl From<IdError> for Error {
78    fn from(err: IdError) -> Self {
79        error!(%err);
80        Self(err.to_string())
81    }
82}
83
84impl fmt::Display for Error {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        self.0.fmt(f)
87    }
88}
89
90impl error::Error for Error {}
91
92pub type Result<T, E = Error> = core::result::Result<T, E>;
93
custom_id! {
    /// Identifies a device.
    pub struct DeviceId;
}

custom_id! {
    /// Identifies a team (a.k.a. the team's graph ID).
    pub struct TeamId;
}

custom_id! {
    /// Identifies an AQC label.
    pub struct LabelId;
}

custom_id! {
    /// Identifies an AQC bidirectional channel.
    pub struct AqcBidiChannelId;
}

custom_id! {
    /// Identifies an AQC unidirectional channel.
    pub struct AqcUniChannelId;
}
118
/// A device's public key bundle.
///
/// Each field holds raw key bytes. NOTE(review): the byte encoding of
/// each key is not visible here — confirm against the producer.
/// Field order is part of the serde representation for positional
/// formats; do not reorder fields.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct KeyBundle {
    /// The device's identity public key bytes.
    pub identity: Vec<u8>,
    /// The device's signing public key bytes.
    pub signing: Vec<u8>,
    /// The device's `encoding` public key bytes (NOTE(review):
    /// presumably the encryption key — confirm).
    pub encoding: Vec<u8>,
}

/// A device's role on the team.
///
/// NOTE(review): variant order likely determines the serialized form
/// in index-based serde formats; append new roles rather than reorder.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub enum Role {
    Owner,
    Admin,
    Operator,
    Member,
}
135
// Note: any fields added to this type should be public
/// A configuration for adding a team in the daemon.
#[derive(Debug, Serialize, Deserialize)]
pub struct AddTeamConfig {
    /// The ID of the team to add.
    pub team_id: TeamId,
    /// Optional QUIC sync settings for the team.
    pub quic_sync: Option<AddTeamQuicSyncConfig>,
}

// Note: any fields added to this type should be public
/// A configuration for creating a team in the daemon.
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateTeamConfig {
    /// Optional QUIC sync settings for the new team.
    pub quic_sync: Option<CreateTeamQuicSyncConfig>,
}
150
151/// A device's network identifier.
152#[derive(Clone, Debug, Serialize, Deserialize, Eq, Ord, PartialEq, PartialOrd)]
153pub struct NetIdentifier(pub Text);
154
155impl Borrow<str> for NetIdentifier {
156    #[inline]
157    fn borrow(&self) -> &str {
158        &self.0
159    }
160}
161
162impl<T> AsRef<T> for NetIdentifier
163where
164    T: ?Sized,
165    <Self as Deref>::Target: AsRef<T>,
166{
167    #[inline]
168    fn as_ref(&self) -> &T {
169        self.deref().as_ref()
170    }
171}
172
173impl Deref for NetIdentifier {
174    type Target = str;
175
176    #[inline]
177    fn deref(&self) -> &Self::Target {
178        &self.0
179    }
180}
181
182impl fmt::Display for NetIdentifier {
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        self.0.fmt(f)
185    }
186}
187
188/// A serialized command for AQC.
189pub type AqcCtrl = Vec<Box<[u8]>>;
190
191/// A PSK IKM.
192#[derive(Clone, Serialize, Deserialize)]
193pub struct Ikm([u8; SEED_IKM_SIZE]);
194
195impl Ikm {
196    /// Provides access to the raw IKM bytes.
197    #[inline]
198    pub fn raw_ikm_bytes(&self) -> &[u8; SEED_IKM_SIZE] {
199        &self.0
200    }
201}
202
203impl From<[u8; SEED_IKM_SIZE]> for Ikm {
204    fn from(value: [u8; SEED_IKM_SIZE]) -> Self {
205        Self(value)
206    }
207}
208
209impl ConstantTimeEq for Ikm {
210    fn ct_eq(&self, other: &Self) -> Choice {
211        self.0.ct_eq(&other.0)
212    }
213}
214
215impl ZeroizeOnDrop for Ikm {}
216impl Drop for Ikm {
217    fn drop(&mut self) {
218        self.0.zeroize()
219    }
220}
221
222impl fmt::Debug for Ikm {
223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224        f.debug_struct("Ikm").finish_non_exhaustive()
225    }
226}
227
228/// A secret.
229#[derive(Clone, Serialize, Deserialize)]
230pub struct Secret(Box<[u8]>);
231
232impl Secret {
233    /// Provides access to the raw secret bytes.
234    #[inline]
235    pub fn raw_secret_bytes(&self) -> &[u8] {
236        &self.0
237    }
238}
239
240impl<T> From<T> for Secret
241where
242    T: Into<Box<[u8]>>,
243{
244    fn from(value: T) -> Self {
245        Self(value.into())
246    }
247}
248
249impl ConstantTimeEq for Secret {
250    fn ct_eq(&self, other: &Self) -> Choice {
251        self.0.ct_eq(&other.0)
252    }
253}
254
255impl ZeroizeOnDrop for Secret {}
256impl Drop for Secret {
257    fn drop(&mut self) {
258        self.0.zeroize()
259    }
260}
261
262impl fmt::Debug for Secret {
263    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264        f.debug_struct("Secret").finish_non_exhaustive()
265    }
266}
267
macro_rules! psk_map {
    (
        $(#[$meta:meta])*
        $vis:vis struct $name:ident(PskMap<$psk:ty>);
    ) => {
        $(#[$meta])*
        #[derive(Clone, Debug, Serialize, Deserialize)]
        #[cfg_attr(test, derive(PartialEq))]
        $vis struct $name {
            /// The channel ID these PSKs were created for.
            id: Id,
            /// At most one PSK per cipher suite.
            psks: HashMap<CsId, $psk>
        }

        impl $name {
            /// Returns the number of PSKs.
            pub fn len(&self) -> usize {
                self.psks.len()
            }

            /// Reports whether `self` is empty.
            pub fn is_empty(&self) -> bool {
                self.psks.is_empty()
            }

            /// Returns the channel ID.
            pub fn channel_id(&self) -> &Id {
                &self.id
            }

            /// Returns the PSK for the cipher suite, if present.
            pub fn get(&self, suite: CipherSuiteId) -> Option<&$psk> {
                self.psks.get(&CsId(suite))
            }

            /// Creates a PSK map from a function that generates
            /// a PSK for a cipher suite.
            ///
            /// `f` is called once for each suite in
            /// `CipherSuiteId::all()`, stopping at the first error.
            ///
            /// # Errors
            ///
            /// Returns an error in three cases:
            /// - `f` fails.
            /// - A generated PSK's identity names a channel other than `id`.
            /// - A generated PSK's identity names a cipher suite other than
            ///   the one it was generated for.
            pub fn try_from_fn<I, E, F>(id: I, mut f: F) -> anyhow::Result<Self>
            where
                I: Into<Id>,
                anyhow::Error: From<E>,
                F: FnMut(CipherSuiteId) -> Result<$psk, E>,
            {
                let id = id.into();
                let suites = CipherSuiteId::all();
                let mut psks = HashMap::with_capacity(suites.len());
                for &suite in suites {
                    let psk = f(suite)?;
                    let ident = psk.identity();
                    if !bool::from(ident.channel_id().into_id().ct_eq(&id)) {
                        bail!("PSK identity does not match channel ID");
                    }
                    // The map is keyed by `suite`. A PSK whose identity
                    // names another suite would otherwise be returned
                    // by `get` for the wrong cipher suite.
                    if ident.cipher_suite() != suite {
                        bail!("PSK identity does not match cipher suite");
                    }
                    psks.insert(CsId(suite), psk);
                }
                Ok(Self { id, psks })
            }
        }

        impl IntoIterator for $name {
            type Item = (CipherSuiteId, $psk);
            type IntoIter = IntoPsks<$psk>;

            fn into_iter(self) -> Self::IntoIter {
                IntoPsks {
                    iter: self.psks.into_iter(),
                }
            }
        }

        #[cfg(test)]
        impl tests::PskMap for $name {
            type Psk = $psk;

            fn new() -> Self {
                Self {
                    // TODO
                    id: Id::default(),
                    psks: HashMap::new(),
                }
            }

            fn len(&self) -> usize {
                self.psks.len()
            }

            fn insert(&mut self, psk: Self::Psk) {
                let suite = psk.cipher_suite();
                let opt = self.psks.insert(CsId(suite), psk);
                assert!(opt.is_none());
            }
        }
    };
}
psk_map! {
    /// The PSKs for a single bidirectional channel, keyed by
    /// cipher suite (at most one PSK per suite).
    pub struct AqcBidiPsks(PskMap<AqcBidiPsk>);
}

psk_map! {
    /// The PSKs for a single unidirectional channel, keyed by
    /// cipher suite (at most one PSK per suite).
    pub struct AqcUniPsks(PskMap<AqcUniPsk>);
}
369
370/// An injective mapping of PSKs to cipher suites for a single
371/// bidirectional or unidirectional channel.
372#[derive(Clone, Debug, Serialize, Deserialize)]
373pub enum AqcPsks {
374    Bidi(AqcBidiPsks),
375    Uni(AqcUniPsks),
376}
377
378impl IntoIterator for AqcPsks {
379    type IntoIter = AqcPsksIntoIter;
380    type Item = <Self::IntoIter as Iterator>::Item;
381
382    fn into_iter(self) -> Self::IntoIter {
383        match self {
384            AqcPsks::Bidi(psks) => AqcPsksIntoIter::Bidi(psks.into_iter()),
385            AqcPsks::Uni(psks) => AqcPsksIntoIter::Uni(psks.into_iter()),
386        }
387    }
388}
389
390/// An iterator over an AQC channel's PSKs.
391#[derive(Debug)]
392pub enum AqcPsksIntoIter {
393    Bidi(IntoPsks<AqcBidiPsk>),
394    Uni(IntoPsks<AqcUniPsk>),
395}
396
397impl Iterator for AqcPsksIntoIter {
398    type Item = (CipherSuiteId, AqcPsk);
399    fn next(&mut self) -> Option<Self::Item> {
400        match self {
401            AqcPsksIntoIter::Bidi(it) => it.next().map(|(s, k)| (s, AqcPsk::Bidi(k))),
402            AqcPsksIntoIter::Uni(it) => it.next().map(|(s, k)| (s, AqcPsk::Uni(k))),
403        }
404    }
405}
406
407/// An iterator over an AQC channel's PSKs.
408#[derive(Debug)]
409pub struct IntoPsks<V> {
410    iter: hash_map::IntoIter<CsId, V>,
411}
412
413impl<V> Iterator for IntoPsks<V> {
414    type Item = (CipherSuiteId, V);
415
416    fn next(&mut self) -> Option<Self::Item> {
417        self.iter.next().map(|(k, v)| (k.0, v))
418    }
419}
420
421// TODO(eric): Get rid of this once `CipherSuiteId` implements
422// `Hash`.
423#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
424#[serde(transparent)]
425struct CsId(CipherSuiteId);
426
427impl Hash for CsId {
428    fn hash<H: Hasher>(&self, state: &mut H) {
429        self.0.to_bytes().hash(state);
430    }
431}
432
433/// An AQC PSK.
434#[derive(Clone, Debug, Serialize, Deserialize)]
435pub enum AqcPsk {
436    /// Bidirectional.
437    Bidi(AqcBidiPsk),
438    /// Unidirectional.
439    Uni(AqcUniPsk),
440}
441
442impl AqcPsk {
443    /// Returns the PSK identity.
444    #[inline]
445    pub fn identity(&self) -> AqcPskId {
446        match self {
447            Self::Bidi(psk) => AqcPskId::Bidi(psk.identity),
448            Self::Uni(psk) => AqcPskId::Uni(psk.identity),
449        }
450    }
451
452    /// Returns the PSK cipher suite.
453    #[inline]
454    pub fn cipher_suite(&self) -> CipherSuiteId {
455        self.identity().cipher_suite()
456    }
457
458    /// Returns the PSK secret.
459    #[inline]
460    pub fn secret(&self) -> &[u8] {
461        match self {
462            Self::Bidi(psk) => psk.secret.raw_secret_bytes(),
463            Self::Uni(psk) => match &psk.secret {
464                Directed::Send(secret) | Directed::Recv(secret) => secret.raw_secret_bytes(),
465            },
466        }
467    }
468}
469
470impl From<AqcBidiPsk> for AqcPsk {
471    fn from(psk: AqcBidiPsk) -> Self {
472        Self::Bidi(psk)
473    }
474}
475
476impl From<AqcUniPsk> for AqcPsk {
477    fn from(psk: AqcUniPsk) -> Self {
478        Self::Uni(psk)
479    }
480}
481
482impl ConstantTimeEq for AqcPsk {
483    fn ct_eq(&self, other: &Self) -> Choice {
484        // It's fine that matching discriminants isn't constant
485        // time since it isn't secret data.
486        match (self, other) {
487            (Self::Bidi(lhs), Self::Bidi(rhs)) => lhs.ct_eq(rhs),
488            (Self::Uni(lhs), Self::Uni(rhs)) => lhs.ct_eq(rhs),
489            _ => Choice::from(0u8),
490        }
491    }
492}
493
494/// An AQC bidirectional channel PSK.
495#[derive(Clone, Debug, Serialize, Deserialize)]
496pub struct AqcBidiPsk {
497    /// The PSK identity.
498    pub identity: BidiPskId,
499    /// The PSK's secret.
500    pub secret: Secret,
501}
502
503impl AqcBidiPsk {
504    fn identity(&self) -> &BidiPskId {
505        &self.identity
506    }
507
508    #[cfg(test)]
509    fn cipher_suite(&self) -> CipherSuiteId {
510        self.identity.cipher_suite()
511    }
512}
513
514impl ConstantTimeEq for AqcBidiPsk {
515    fn ct_eq(&self, other: &Self) -> Choice {
516        let id = self.identity.ct_eq(&other.identity);
517        let secret = self.secret.ct_eq(&other.secret);
518        id & secret
519    }
520}
521
522impl ZeroizeOnDrop for AqcBidiPsk {}
523
524/// An AQC unidirectional PSK.
525#[derive(Clone, Debug, Serialize, Deserialize)]
526pub struct AqcUniPsk {
527    /// The PSK identity.
528    pub identity: UniPskId,
529    /// The PSK's secret.
530    pub secret: Directed<Secret>,
531}
532
533impl AqcUniPsk {
534    fn identity(&self) -> &UniPskId {
535        &self.identity
536    }
537
538    #[cfg(test)]
539    fn cipher_suite(&self) -> CipherSuiteId {
540        self.identity.cipher_suite()
541    }
542}
543
544impl ConstantTimeEq for AqcUniPsk {
545    fn ct_eq(&self, other: &Self) -> Choice {
546        let id = self.identity.ct_eq(&other.identity);
547        let secret = self.secret.ct_eq(&other.secret);
548        id & secret
549    }
550}
551
552impl ZeroizeOnDrop for AqcUniPsk {}
553
554/// Either send only or receive only.
555#[derive(Clone, Debug, Serialize, Deserialize)]
556pub enum Directed<T> {
557    /// Send only.
558    Send(T),
559    /// Receive only.
560    Recv(T),
561}
562
563impl<T: ConstantTimeEq> ConstantTimeEq for Directed<T> {
564    fn ct_eq(&self, other: &Self) -> Choice {
565        // It's fine that matching discriminants isn't constant
566        // time since the direction isn't secret data.
567        match (self, other) {
568            (Self::Send(lhs), Self::Send(rhs)) => lhs.ct_eq(rhs),
569            (Self::Recv(lhs), Self::Recv(rhs)) => lhs.ct_eq(rhs),
570            _ => Choice::from(0u8),
571        }
572    }
573}
574
575/// An AQC PSK identity.
576#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
577pub enum AqcPskId {
578    /// A bidirectional PSK.
579    Bidi(BidiPskId),
580    /// A unidirectional PSK.
581    Uni(UniPskId),
582}
583
584impl AqcPskId {
585    /// Returns the unique channel ID.
586    pub fn channel_id(&self) -> Id {
587        match self {
588            Self::Bidi(v) => (*v.channel_id()).into(),
589            Self::Uni(v) => (*v.channel_id()).into(),
590        }
591    }
592
593    /// Returns the cipher suite.
594    pub fn cipher_suite(&self) -> CipherSuiteId {
595        match self {
596            Self::Bidi(v) => v.cipher_suite(),
597            Self::Uni(v) => v.cipher_suite(),
598        }
599    }
600
601    /// Converts the ID to its byte encoding.
602    pub fn as_bytes(&self) -> &[u8; 34] {
603        match self {
604            Self::Bidi(v) => v.as_bytes(),
605            Self::Uni(v) => v.as_bytes(),
606        }
607    }
608}
609
/// Configuration values for syncing with a peer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SyncPeerConfig {
    /// The interval at which syncing occurs.
    pub interval: Duration,
    /// Whether the peer should be synced with immediately after it is added.
    pub sync_now: bool,
}

/// Valid channel operations for a label assignment.
///
/// NOTE(review): variant order likely determines the serialized form
/// in index-based serde formats; do not reorder.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub enum ChanOp {
    /// The device can only receive data in channels with this
    /// label.
    RecvOnly,
    /// The device can only send data in channels with this
    /// label.
    SendOnly,
    /// The device can send and receive data in channels with this
    /// label.
    SendRecv,
}

/// A label: its ID plus its name.
///
/// The derived ordering compares `id` first, then `name`.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct Label {
    /// The label's ID.
    pub id: LabelId,
    /// The label's name.
    pub name: Text,
}
639
/// The daemon's RPC API.
///
/// This trait is the `tarpc` service definition (see the
/// `#[tarpc::service]` attribute). Every method returns this module's
/// [`Result`], whose error type is [`Error`].
#[tarpc::service]
pub trait DaemonApi {
    /// Returns the daemon's version.
    async fn version() -> Result<Version>;

    /// Gets the local address the Aranya sync server is bound to.
    async fn aranya_local_addr() -> Result<SocketAddr>;

    /// Gets the public key bundle for this device.
    async fn get_key_bundle() -> Result<KeyBundle>;

    /// Gets the public device ID.
    async fn get_device_id() -> Result<DeviceId>;

    /// Adds the peer at `addr` for automatic periodic syncing of `team`.
    async fn add_sync_peer(addr: Addr, team: TeamId, config: SyncPeerConfig) -> Result<()>;

    /// Syncs `team` with the peer at `addr` immediately.
    async fn sync_now(addr: Addr, team: TeamId, cfg: Option<SyncPeerConfig>) -> Result<()>;

    /// Removes the peer at `addr` from automatic syncing of `team`.
    async fn remove_sync_peer(addr: Addr, team: TeamId) -> Result<()>;

    /// Adds a team that was created by someone else to the local device
    /// store. This is not an Aranya action/command.
    async fn add_team(cfg: AddTeamConfig) -> Result<()>;

    /// Removes a team from local device storage.
    async fn remove_team(team: TeamId) -> Result<()>;

    /// Creates a new graph/team with the current device as the owner.
    async fn create_team(cfg: CreateTeamConfig) -> Result<TeamId>;
    /// Closes the team.
    async fn close_team(team: TeamId) -> Result<()>;

    /// Encrypts the PSK seed for `team` to the peer's encryption public
    /// key `peer_enc_pk`, returning the wrapped seed.
    async fn encrypt_psk_seed_for_peer(
        team: TeamId,
        peer_enc_pk: EncryptionPublicKey<CS>,
    ) -> Result<WrappedSeed>;

    /// Adds a device to the team.
    async fn add_device_to_team(team: TeamId, keys: KeyBundle) -> Result<()>;
    /// Removes a device from the team.
    async fn remove_device_from_team(team: TeamId, device: DeviceId) -> Result<()>;

    /// Assigns a role to a device.
    async fn assign_role(team: TeamId, device: DeviceId, role: Role) -> Result<()>;
    /// Revokes a role from a device.
    async fn revoke_role(team: TeamId, device: DeviceId, role: Role) -> Result<()>;

    /// Assigns a QUIC channels network identifier to a device.
    async fn assign_aqc_net_identifier(
        team: TeamId,
        device: DeviceId,
        name: NetIdentifier,
    ) -> Result<()>;
    /// Removes a QUIC channels network identifier from a device.
    async fn remove_aqc_net_identifier(
        team: TeamId,
        device: DeviceId,
        name: NetIdentifier,
    ) -> Result<()>;

    /// Creates a label named `name` and returns its ID.
    async fn create_label(team: TeamId, name: Text) -> Result<LabelId>;
    /// Deletes a label.
    async fn delete_label(team: TeamId, label_id: LabelId) -> Result<()>;
    /// Assigns a label to a device, permitting the channel operations in `op`.
    async fn assign_label(
        team: TeamId,
        device: DeviceId,
        label_id: LabelId,
        op: ChanOp,
    ) -> Result<()>;
    /// Revokes a label from a device.
    async fn revoke_label(team: TeamId, device: DeviceId, label_id: LabelId) -> Result<()>;

    /// Creates a bidirectional QUIC channel with `peer` for `label_id`,
    /// returning the serialized AQC command and the channel's PSKs.
    async fn create_aqc_bidi_channel(
        team: TeamId,
        peer: NetIdentifier,
        label_id: LabelId,
    ) -> Result<(AqcCtrl, AqcBidiPsks)>;
    /// Creates a unidirectional QUIC channel with `peer` for `label_id`,
    /// returning the serialized AQC command and the channel's PSKs.
    async fn create_aqc_uni_channel(
        team: TeamId,
        peer: NetIdentifier,
        label_id: LabelId,
    ) -> Result<(AqcCtrl, AqcUniPsks)>;
    /// Deletes a QUIC bidi channel.
    async fn delete_aqc_bidi_channel(chan: AqcBidiChannelId) -> Result<AqcCtrl>;
    /// Deletes a QUIC uni channel.
    async fn delete_aqc_uni_channel(chan: AqcUniChannelId) -> Result<AqcCtrl>;
    /// Receives an AQC ctrl message, returning the channel's label ID and PSKs.
    async fn receive_aqc_ctrl(team: TeamId, ctrl: AqcCtrl) -> Result<(LabelId, AqcPsks)>;

    /// Queries the devices on the team.
    async fn query_devices_on_team(team: TeamId) -> Result<Vec<DeviceId>>;
    /// Queries a device's role.
    async fn query_device_role(team: TeamId, device: DeviceId) -> Result<Role>;
    /// Queries a device's key bundle.
    async fn query_device_keybundle(team: TeamId, device: DeviceId) -> Result<KeyBundle>;
    /// Queries a device's label assignments.
    async fn query_device_label_assignments(team: TeamId, device: DeviceId) -> Result<Vec<Label>>;
    /// Queries a device's AQC network identifier, if one is assigned.
    async fn query_aqc_net_identifier(
        team: TeamId,
        device: DeviceId,
    ) -> Result<Option<NetIdentifier>>;
    /// Queries the labels on the team.
    async fn query_labels(team: TeamId) -> Result<Vec<Label>>;
    /// Queries whether a label exists.
    async fn query_label_exists(team: TeamId, label: LabelId) -> Result<bool>;
}
753
#[cfg(test)]
mod tests {
    use aranya_crypto::Rng;
    use serde::de::DeserializeOwned;

    use super::*;

    /// Returns a [`Secret`] that owns a copy of `bytes`.
    fn secret(bytes: &[u8]) -> Secret {
        Secret(Box::<[u8]>::from(bytes))
    }

    /// The minimal map interface implemented by the `psk_map!`-generated
    /// types, so the serde round-trip test can be written once.
    pub(super) trait PskMap:
        fmt::Debug + PartialEq + Serialize + DeserializeOwned + Sized
    {
        /// The PSK type stored in the map.
        type Psk;
        /// Creates an empty map.
        fn new() -> Self;
        /// Returns the number of PSKs in the map.
        fn len(&self) -> usize;
        /// Adds `psk` to the map.
        ///
        /// # Panics
        ///
        /// Panics if the map already holds a PSK for `psk`'s cipher suite.
        fn insert(&mut self, psk: Self::Psk);
    }

    // Test-only equality, delegating to the constant-time comparisons.
    impl PartialEq for AqcBidiPsk {
        fn eq(&self, other: &Self) -> bool {
            self.ct_eq(other).into()
        }
    }
    impl PartialEq for AqcUniPsk {
        fn eq(&self, other: &Self) -> bool {
            self.ct_eq(other).into()
        }
    }
    impl PartialEq for AqcPsk {
        fn eq(&self, other: &Self) -> bool {
            self.ct_eq(other).into()
        }
    }

    /// Fills a map of type `M` with one PSK per cipher suite (built by
    /// `make_psk`), then checks that it survives a postcard round trip.
    #[track_caller]
    fn psk_map_test<M, F>(name: &'static str, mut make_psk: F)
    where
        M: PskMap,
        F: FnMut(Secret, Id, CipherSuiteId) -> M::Psk,
    {
        let suites = CipherSuiteId::all();

        let mut want = M::new();
        for (idx, &suite) in suites.iter().enumerate() {
            let chan = Id::random(&mut Rng);
            let sec = secret(&idx.to_le_bytes());
            want.insert(make_psk(sec, chan, suite));
        }
        assert_eq!(want.len(), suites.len(), "{name}");

        let encoded = postcard::to_allocvec(&want).unwrap();
        let got: M = postcard::from_bytes(&encoded).unwrap();
        assert_eq!(got, want, "{name}");
    }

    /// Test that we can correctly serialize and deserialize
    /// [`AqcBidiPsk`].
    #[test]
    fn test_aqc_bidi_psks_serde() {
        psk_map_test::<AqcBidiPsks, _>("AqcBidiPsk", |sec, chan, suite| {
            let identity = BidiPskId::from((chan.into(), suite));
            AqcBidiPsk {
                identity,
                secret: sec,
            }
        });
    }

    /// Test that we can correctly serialize and deserialize
    /// [`AqcUniPsk`].
    #[test]
    fn test_aqc_uni_psks_serde() {
        psk_map_test::<AqcUniPsks, _>("AqcUniPsk (send)", |sec, chan, suite| {
            let identity = UniPskId::from((chan.into(), suite));
            AqcUniPsk {
                identity,
                secret: Directed::Send(sec),
            }
        });
        psk_map_test::<AqcUniPsks, _>("AqcUniPsk (recv)", |sec, chan, suite| {
            let identity = UniPskId::from((chan.into(), suite));
            AqcUniPsk {
                identity,
                secret: Directed::Recv(sec),
            }
        });
    }
}