aranya-daemon 6.0.0

Daemon process for syncing with Aranya peers and maintaining the DAG
Documentation
//! PSK setup for rustls for use with QUIC connections

use std::{collections::HashMap, sync::Arc};

use anyhow::{bail, Context as _, Result};
use aranya_crypto::{policy::GroupId, tls::PskSeedId, Csprng, KeyStoreExt as _, PolicyId};
use aranya_daemon_api::{CipherSuiteId, Ikm, TeamId};
use s2n_quic::provider::tls::rustls::rustls::{
    client,
    crypto::{hash::HashAlgorithm, PresharedKey},
    pki_types::ServerName,
    server,
};

use crate::{keystore::LocalStore, CE, CS, KS};

type TeamIdPSKPair = (TeamId, Arc<PresharedKey>);

const QUIC_SYNC_PSK_CONTEXT: &[u8] = b"AranyaQuicSync-v1";

#[derive(Clone)]
pub(crate) struct PskSeed(pub(crate) aranya_crypto::tls::PskSeed<CS>);

impl PskSeed {
    pub(crate) fn new(rng: impl Csprng, team: TeamId) -> Self {
        let group = GroupId::transmute(team);
        Self(aranya_crypto::tls::PskSeed::new(rng, &group))
    }

    pub(crate) fn import_from_ikm(ikm: &Ikm, team: TeamId) -> Self {
        let group = GroupId::transmute(team);
        Self(aranya_crypto::tls::PskSeed::import_from_ikm(
            ikm.raw_ikm_bytes(),
            &group,
        ))
    }

    pub(crate) fn load(eng: &CE, store: &LocalStore<KS>, id: PskSeedId) -> Result<Option<Self>> {
        store
            .get_key(eng, id)
            .map(|r| r.map(Self))
            .map_err(Into::into)
    }

    pub(crate) fn into_inner(self) -> aranya_crypto::tls::PskSeed<CS> {
        self.0
    }

    pub(crate) fn generate_psks(self, team: TeamId) -> impl Iterator<Item = Result<PresharedKey>> {
        let group = GroupId::transmute(team);
        let policy = PolicyId::default();
        self.0
            .generate_psks(
                QUIC_SYNC_PSK_CONTEXT,
                group,
                policy,
                CipherSuiteId::all().iter().copied(),
            )
            .map(|r| {
                let psk = r?;
                psk_to_rustls(psk)
            })
    }
}

fn psk_to_rustls(psk: aranya_crypto::tls::Psk<CS>) -> Result<PresharedKey> {
    let identity = psk.identity().as_bytes();
    let secret = psk.raw_secret_bytes();
    let alg = match psk.identity().cipher_suite() {
        CipherSuiteId::TlsAes128GcmSha256 => HashAlgorithm::SHA256,
        CipherSuiteId::TlsAes256GcmSha384 => HashAlgorithm::SHA384,
        CipherSuiteId::TlsChaCha20Poly1305Sha256 => HashAlgorithm::SHA256,
        CipherSuiteId::TlsAes128CcmSha256 => HashAlgorithm::SHA256,
        CipherSuiteId::TlsAes128Ccm8Sha256 => HashAlgorithm::SHA256,
        cs => bail!("unknown ciphersuite {cs}"),
    };
    let psk = PresharedKey::external(identity, secret)
        .context("unable to create PSK")?
        .with_hash_alg(alg)
        .context("Invalid hash algorithm")?;
    Ok(psk)
}

/// Shared PSK store that's shared between the [`QuicListener`] and [`QuicConnector`].
///
/// [`QuicListener`]: super::QuicListener
/// [`QuicConnector`]: super::QuicConnector
#[derive(Debug, Default)]
pub(crate) struct PskStore {
    inner: std::sync::Mutex<PskStoreInner>,
    /// Lock for active team to avoid doubly setting it.
    team_lock: tokio::sync::Mutex<()>,
}

impl PskStore {
    pub(crate) fn new<I>(initial_keys: I) -> Self
    where
        I: IntoIterator<Item = TeamIdPSKPair>,
    {
        let mut team_identities = HashMap::<TeamId, Vec<Arc<PresharedKey>>>::new();
        let mut identity_psk = HashMap::new();

        for (team_id, psk) in initial_keys {
            team_identities
                .entry(team_id)
                .or_default()
                .push(Arc::clone(&psk));
            identity_psk.insert(PskIdAsKey(psk), team_id);
        }

        Self {
            inner: std::sync::Mutex::new(PskStoreInner {
                active_team: None,
                team_identities,
                identity_team: identity_psk,
            }),
            team_lock: tokio::sync::Mutex::default(),
        }
    }

    pub(crate) fn insert(&self, team_id: TeamId, psk: Arc<PresharedKey>) {
        #[allow(clippy::expect_used, reason = "poison")]
        let mut inner = self.inner.lock().expect("poisoned");
        inner
            .team_identities
            .entry(team_id)
            .or_default()
            .push(Arc::clone(&psk));
        inner.identity_team.insert(PskIdAsKey(psk), team_id);
    }

    pub(crate) fn remove(&self, team_id: TeamId) {
        #[allow(clippy::expect_used, reason = "poison")]
        let mut inner = self.inner.lock().expect("poisoned");
        inner.team_identities.remove(&team_id);
        inner.identity_team.retain(|_, other| *other != team_id);
    }

    /// Set the active team used while connecting.
    ///
    /// The guard must be held until after the connection is made.
    pub(crate) async fn set_team(&self, team_id: TeamId) -> tokio::sync::MutexGuard<'_, ()> {
        let guard = self.team_lock.lock().await;
        #[allow(clippy::expect_used, reason = "poison")]
        let mut inner = self.inner.lock().expect("poisoned");
        let _ = inner.active_team.replace(team_id);
        guard
    }

    pub(super) fn get_team_for_identity(&self, identity: &[u8]) -> Option<TeamId> {
        #[allow(clippy::expect_used, reason = "poison")]
        let inner = self.inner.lock().expect("poisoned");
        inner.identity_team.get(identity).copied()
    }
}

impl client::PresharedKeyStore for PskStore {
    #[allow(clippy::expect_used)]
    fn psks(&self, _server_name: &ServerName<'_>) -> Vec<Arc<PresharedKey>> {
        let inner = self.inner.lock().expect("poisoned mutex");

        let Some(active_team) = inner.active_team.as_ref() else {
            return Vec::new();
        };
        let Some(active_identities) = inner.team_identities.get(active_team) else {
            return Vec::new();
        };
        active_identities.clone()
    }
}

impl server::SelectsPresharedKeys for PskStore {
    #[allow(clippy::expect_used)]
    fn load_psk(&self, identity: &[u8]) -> Option<Arc<PresharedKey>> {
        let inner = self.inner.lock().expect("poison");

        let (k, _) = inner.identity_team.get_key_value(identity)?;
        Some(k.0.clone())
    }
}

#[derive(Debug, Default)]
struct PskStoreInner {
    team_identities: HashMap<TeamId, Vec<Arc<PresharedKey>>>,
    identity_team: HashMap<PskIdAsKey, TeamId>,
    /// Indicates the "active team".
    /// Used by [`PskStore`] to restrict the PSKs that are offered by the client.
    active_team: Option<TeamId>,
}

#[derive(Debug)]
struct PskIdAsKey(Arc<PresharedKey>);
impl core::hash::Hash for PskIdAsKey {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.identity().hash(state);
    }
}
impl core::borrow::Borrow<[u8]> for PskIdAsKey {
    fn borrow(&self) -> &[u8] {
        self.0.identity()
    }
}
impl Eq for PskIdAsKey {}
impl PartialEq for PskIdAsKey {
    fn eq(&self, other: &Self) -> bool {
        self.0.identity() == other.0.identity()
    }
}