use crate::{
coloring::*,
labeled_prf as lprf,
msgs::*,
pqkem::*,
prftree::{SecretPrfTree, SecretPrfTreeBranch},
sodium::*,
util::*,
};
use anyhow::{bail, ensure, Context, Result};
use std::collections::hash_map::{
Entry::{Occupied, Vacant},
HashMap,
};
/// Length of the larger of an enveloped `InitHello` and an enveloped
/// `InitConf` — the two messages this file stores for retransmission.
/// Presumably used to size retransmission buffers elsewhere.
pub const RTX_BUFFER_SIZE: usize = {
    let init_hello_len = <Envelope<(), InitHello<()>> as LenseView>::LEN;
    let init_conf_len = <Envelope<(), InitConf<()>> as LenseView>::LEN;
    max_usize(init_hello_len, init_conf_len)
};
/// A point in time or a duration, in the unit returned by `Timebase::now()`
/// (the constants below suggest seconds).
pub type Timing = f64;
/// "Before common era": a timestamp ten thousand (365-day) years before the
/// time base's origin. It is used as the creation time of biscuit keys that
/// were never activated or have been erased.
pub const BCE: Timing = -3600.0 * 24.0 * 365.0 * 10_000.0;
/// The wait used when nothing is scheduled ("hibernate"): eight hours.
pub const UNENDING: Timing = 3600.0 * 8.0;
/// Age at which a session in which we were the responder is retired, i.e.
/// renewal becomes due. It is shorter than the initiator value, presumably so
/// that both parties do not start renewing at the same moment.
pub const REKEY_AFTER_TIME_RESPONDER: Timing = 120.0;
/// Age at which a session in which we were the initiator is retired.
pub const REKEY_AFTER_TIME_INITIATOR: Timing = 130.0;
/// Age at which sessions and initiator handshakes are discarded.
pub const REJECT_AFTER_TIME: Timing = 180.0;
/// Length of each of the two phases (active, then retired) of a biscuit key.
pub const BISCUIT_EPOCH: Timing = 300.0;
/// NOTE(review): not referenced in this part of the file; presumably the
/// point at which retransmissions are given up.
pub const RETRANSMIT_ABORT: Timing = 120.0;
/// Factor by which the retransmission delay grows with each attempt.
pub const RETRANSMIT_DELAY_GROWTH: Timing = 2.0;
/// Base retransmission delay of the first attempt, before jitter.
pub const RETRANSMIT_DELAY_BEGIN: Timing = 0.5;
/// Upper bound for the retransmission delay before jitter is applied.
pub const RETRANSMIT_DELAY_END: Timing = 10.0;
/// The delay is multiplied by `RETRANSMIT_DELAY_JITTER * (1 + r)`, where `r`
/// comes from `rand_f64()`.
pub const RETRANSMIT_DELAY_JITTER: Timing = 0.5;
/// Tolerance within which a slightly-future event already counts as due.
pub const EVENT_GRACE: Timing = 0.0025;

/// Returns `true` if an event scheduled at `ev` is due at time `now`.
///
/// An event is due if it lies in the past or less than [`EVENT_GRACE`] in the
/// future.
pub fn has_happened(ev: Timing, now: Timing) -> bool {
    (ev - now) < EVENT_GRACE
}
/// Static (long-term) KEM public key, held in a `Secret` container.
pub type SPk = Secret<{ StaticKEM::PK_LEN }>;
/// Static (long-term) KEM secret key.
pub type SSk = Secret<{ StaticKEM::SK_LEN }>;

/// Ephemeral KEM public key.
pub type EPk = Public<{ EphemeralKEM::PK_LEN }>;
/// Ephemeral KEM secret key.
pub type ESk = Secret<{ EphemeralKEM::SK_LEN }>;

/// Symmetric key of `KEY_SIZE` bytes.
pub type SymKey = Secret<KEY_SIZE>;
/// Non-secret value of `KEY_SIZE` bytes.
pub type SymHash = Public<KEY_SIZE>;
/// Peer identifier, derived from a static public key (see `Peer::pidt`).
pub type PeerId = Public<KEY_SIZE>;

/// Session identifier.
pub type SessionId = Public<SESSION_ID_LEN>;
/// Biscuit counter value (see `CryptoServer::biscuit_ctr`).
pub type BiscuitId = Public<BISCUIT_ID_LEN>;
/// Nonce for the extended-nonce AEAD used to seal biscuits.
pub type XAEADNonce = Public<XAEAD_NONCE_LEN>;
/// Message buffer of `MAX_MESSAGE_LEN` bytes.
pub type MsgBuf = Public<MAX_MESSAGE_LEN>;

/// Index of a peer in `CryptoServer::peers`.
pub type PeerNo = usize;
/// The protocol state of one party.
///
/// Holds the party's static key pair, the keys used to seal biscuits, and all
/// configured peers together with their sessions and in-flight initiator
/// handshakes.
#[derive(Debug)]
pub struct CryptoServer {
    /// Source of the current time (`now()`).
    pub timebase: Timebase,
    /// Our static secret key.
    pub sskm: SSk,
    /// Our static public key.
    pub spkm: SPk,
    /// Counter copied into each issued biscuit; incremented after each use.
    pub biscuit_ctr: BiscuitId,
    /// The two biscuit keys. The one in use is chosen by
    /// `active_biscuit_key`, and its index is encoded in the biscuit nonce.
    pub biscuit_keys: [BiscuitKey; 2],
    /// All configured peers, addressed by `PeerPtr`.
    pub peers: Vec<Peer>,
    /// Maps peer ids and our session ids to the index of the owning peer.
    pub index: HashMap<IndexKey, PeerNo>,
    /// Offset at which `poll` starts iterating over peers (round robin).
    pub peer_poll_off: usize,
}
/// A key used to encrypt biscuits, together with its creation time.
#[derive(Debug)]
pub struct BiscuitKey {
    /// Creation time. `BCE` marks a key that was never activated or was erased.
    pub created_at: Timing,
    /// The symmetric key material.
    pub key: SymKey,
}
/// Key of `CryptoServer::index`: either a peer id or one of our session ids.
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum IndexKey {
    /// Lookup by peer id.
    Peer(PeerId),
    /// Lookup by our session id (that of a session or an initiator handshake).
    Sid(SessionId),
}
/// A configured remote party.
#[derive(Debug)]
pub struct Peer {
    /// Pre-shared key mixed into the handshake (all zeros if none was given).
    pub psk: SymKey,
    /// The peer's static public key.
    pub spkt: SPk,
    /// Highest biscuit number that has established a session with this peer;
    /// older biscuits are rejected.
    pub biscuit_used: BiscuitId,
    /// The current transport session, if any.
    pub session: Option<Session>,
    /// The handshake we initiated with this peer, if one is in flight.
    pub handshake: Option<InitiatorHandshake>,
    /// Set when `poll` has requested a new initiation. It is cleared once a new
    /// initiator handshake is installed.
    pub initiation_requested: bool,
}
impl Peer {
pub fn zero() -> Self {
Self {
psk: SymKey::zero(),
spkt: SPk::zero(),
biscuit_used: BiscuitId::zero(),
session: None,
initiation_requested: false,
handshake: None,
}
}
}
/// Transcript state shared by both handshake roles.
#[derive(Debug, Clone)]
pub struct HandshakeState {
    /// Session id chosen by the initiator.
    pub sidi: SessionId,
    /// Session id chosen by the responder.
    pub sidr: SessionId,
    /// Chaining key; every handshake value is mixed into it.
    pub ck: SecretPrfTreeBranch,
}
/// The role this party played in the handshake that created a session.
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Copy, Clone)]
pub enum HandshakeRole {
    /// We sent the `InitHello`.
    Initiator,
    /// We answered with the `RespHello`.
    Responder,
}

impl HandshakeRole {
    /// Returns `true` for [`HandshakeRole::Initiator`], `false` otherwise.
    pub fn is_initiator(&self) -> bool {
        matches!(self, HandshakeRole::Initiator)
    }
}
/// The next message an initiator handshake expects from the responder.
#[derive(Copy, Clone, Default, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum HandshakeStateMachine {
    /// Waiting for the `RespHello` (the state right after sending `InitHello`).
    #[default]
    RespHello,
    /// Waiting for the responder's `EmptyData` acknowledgement of `InitConf`.
    RespConf,
}
/// State of a handshake we initiated, including the last message we sent
/// (kept so it can be retransmitted).
#[derive(Debug)]
pub struct InitiatorHandshake {
    /// When the handshake was started; it expires `REJECT_AFTER_TIME` later.
    pub created_at: Timing,
    /// Which responder message is expected next.
    pub next: HandshakeStateMachine,
    /// Transcript state.
    pub core: HandshakeState,
    /// Our ephemeral secret key.
    pub eski: ESk,
    /// Our ephemeral public key.
    pub epki: EPk,
    /// NOTE(review): only zero-initialised in this part of the file.
    pub tx_at: Timing,
    /// When the stored message should be retransmitted next.
    pub tx_retry_at: Timing,
    /// Number of (re)transmissions scheduled so far; drives the back-off.
    pub tx_count: usize,
    /// Length of the stored message in `tx_buf`.
    pub tx_len: usize,
    /// The message to retransmit.
    pub tx_buf: MsgBuf,
}
/// An established transport session with a peer.
///
/// The suffix `m` denotes values belonging to us ("mine"), and `t` denotes
/// values belonging to the peer ("theirs").
#[derive(Debug)]
pub struct Session {
    /// When the session was established.
    pub created_at: Timing,
    /// Our session id.
    pub sidm: SessionId,
    /// The peer's session id.
    pub sidt: SessionId,
    /// The role we played in the handshake; it determines the rekey interval.
    pub handshake_role: HandshakeRole,
    /// Final chaining key of the handshake; the output key is derived from it.
    pub ck: SecretPrfTreeBranch,
    /// Key for messages we send.
    pub txkm: SymKey,
    /// Key for messages the peer sends.
    pub txkt: SymKey,
    /// Counter for messages we send.
    pub txnm: u64,
    /// Highest counter accepted from the peer.
    pub txnt: u64,
}
/// Life stage of a time-limited object, as computed by `MortalExt::lifecycle`.
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
enum Lifecycle {
    /// The object is absent or has no timestamps.
    Void = 0,
    /// Its death time has passed.
    Dead,
    /// Its retirement time has passed, but it is not yet dead.
    Retired,
    /// It is still before its retirement time.
    Young,
}
/// Objects with a limited lifetime. All methods return `None` when the object
/// is absent.
trait Mortal {
    /// Creation time.
    fn created_at(&self, srv: &CryptoServer) -> Option<Timing>;
    /// Time after which the object should no longer be used for new work.
    fn retire_at(&self, srv: &CryptoServer) -> Option<Timing>;
    /// Time after which the object must be discarded.
    fn die_at(&self, srv: &CryptoServer) -> Option<Timing>;
}
/// Index of a peer in `CryptoServer::peers`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct PeerPtr(pub usize);
/// Refers to the initiator-handshake slot of the peer with the same index.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct IniHsPtr(pub usize);
/// Refers to the session slot of the peer with the same index.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct SessionPtr(pub usize);
/// Index into `CryptoServer::biscuit_keys`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct BiscuitKeyPtr(pub usize);
impl PeerPtr {
pub fn get<'a>(&self, srv: &'a CryptoServer) -> &'a Peer {
&srv.peers[self.0]
}
pub fn get_mut<'a>(&self, srv: &'a mut CryptoServer) -> &'a mut Peer {
&mut srv.peers[self.0]
}
pub fn session(&self) -> SessionPtr {
SessionPtr(self.0)
}
pub fn hs(&self) -> IniHsPtr {
IniHsPtr(self.0)
}
}
/// Accessors for a peer's initiator-handshake slot, keeping
/// `CryptoServer::index` in sync with the slot's contents.
impl IniHsPtr {
    /// Borrows the handshake slot, which may be empty.
    pub fn get<'a>(&self, srv: &'a CryptoServer) -> &'a Option<InitiatorHandshake> {
        &srv.peers[self.0].handshake
    }
    /// Mutably borrows the handshake slot.
    pub fn get_mut<'a>(&self, srv: &'a mut CryptoServer) -> &'a mut Option<InitiatorHandshake> {
        &mut srv.peers[self.0].handshake
    }
    /// The peer owning this slot.
    pub fn peer(&self) -> PeerPtr {
        PeerPtr(self.0)
    }
    /// Installs `hs` as the peer's handshake and returns a reference to it.
    ///
    /// The new `sidi` is registered first. If another peer already uses that
    /// id, this fails and the existing handshake stays in place. Any previous
    /// handshake is then dropped (its id is unregistered unless still
    /// referenced), and `initiation_requested` is cleared.
    pub fn insert<'a>(
        &self,
        srv: &'a mut CryptoServer,
        hs: InitiatorHandshake,
    ) -> Result<&'a mut InitiatorHandshake> {
        srv.register_session(hs.core.sidi, self.peer())?;
        self.take(srv);
        self.peer().get_mut(srv).initiation_requested = false;
        Ok(self.peer().get_mut(srv).handshake.insert(hs))
    }
    /// Removes and returns the peer's handshake.
    ///
    /// The handshake's `sidi` is unregistered from the index unless the
    /// peer's session still uses the same id.
    pub fn take(&self, srv: &mut CryptoServer) -> Option<InitiatorHandshake> {
        let r = self.peer().get_mut(srv).handshake.take();
        if let Some(ref stale) = r {
            srv.unregister_session_if_vacant(stale.core.sidi, self.peer());
        }
        r
    }
}
/// Accessors for a peer's session slot, keeping `CryptoServer::index` in sync
/// with the slot's contents.
impl SessionPtr {
    /// Borrows the session slot, which may be empty.
    pub fn get<'a>(&self, srv: &'a CryptoServer) -> &'a Option<Session> {
        &srv.peers[self.0].session
    }
    /// Mutably borrows the session slot.
    pub fn get_mut<'a>(&self, srv: &'a mut CryptoServer) -> &'a mut Option<Session> {
        &mut srv.peers[self.0].session
    }
    /// The peer owning this slot.
    pub fn peer(&self) -> PeerPtr {
        PeerPtr(self.0)
    }
    /// Replaces the peer's session with `ses` and returns a reference to it.
    ///
    /// NOTE(review): the old session is removed before `ses.sidm` is
    /// registered. A failed registration therefore leaves the peer without a
    /// session, unlike `IniHsPtr::insert`, which registers first. Confirm this
    /// is intended.
    pub fn insert<'a>(&self, srv: &'a mut CryptoServer, ses: Session) -> Result<&'a mut Session> {
        self.take(srv);
        srv.register_session(ses.sidm, self.peer())?;
        Ok(self.peer().get_mut(srv).session.insert(ses))
    }
    /// Removes and returns the peer's session.
    ///
    /// The session's `sidm` is unregistered from the index unless the peer's
    /// handshake still uses the same id.
    pub fn take(&self, srv: &mut CryptoServer) -> Option<Session> {
        let r = self.peer().get_mut(srv).session.take();
        if let Some(ref stale) = r {
            srv.unregister_session_if_vacant(stale.sidm, self.peer());
        }
        r
    }
}
impl BiscuitKeyPtr {
pub fn get<'a>(&self, srv: &'a CryptoServer) -> &'a BiscuitKey {
&srv.biscuit_keys[self.0]
}
pub fn get_mut<'a>(&self, srv: &'a mut CryptoServer) -> &'a mut BiscuitKey {
&mut srv.biscuit_keys[self.0]
}
}
impl CryptoServer {
pub fn new(sk: SSk, pk: SPk) -> CryptoServer {
let tb = Timebase::default();
CryptoServer {
sskm: sk,
spkm: pk,
timebase: tb,
biscuit_ctr: BiscuitId::new([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), biscuit_keys: [BiscuitKey::new(), BiscuitKey::new()],
peers: Vec::new(),
index: HashMap::new(),
peer_poll_off: 0,
}
}
pub fn biscuit_key_ptrs(&self) -> impl Iterator<Item = BiscuitKeyPtr> {
(0..self.biscuit_keys.len()).map(BiscuitKeyPtr)
}
#[rustfmt::skip]
pub fn pidm(&self) -> Result<PeerId> {
Ok(Public::new(
lprf::peerid()?
.mix(self.spkm.secret())?
.into_value()))
}
pub fn peer_ptrs_off(&self, n: usize) -> impl Iterator<Item = PeerPtr> {
let l = self.peers.len();
(0..l).map(move |i| PeerPtr((i + n) % l))
}
pub fn add_peer(&mut self, psk: Option<SymKey>, pk: SPk) -> Result<PeerPtr> {
let peer = Peer {
psk: psk.unwrap_or_else(SymKey::zero),
spkt: pk,
biscuit_used: BiscuitId::zero(),
session: None,
handshake: None,
initiation_requested: false,
};
let peerid = peer.pidt()?;
let peerno = self.peers.len();
match self.index.entry(IndexKey::Peer(peerid)) {
Occupied(_) => bail!(
"Cannot insert peer with id {:?}; peer with this id already registered.",
peerid
),
Vacant(e) => e.insert(peerno),
};
self.peers.push(peer);
Ok(PeerPtr(peerno))
}
pub fn register_session(&mut self, id: SessionId, peer: PeerPtr) -> Result<()> {
match self.index.entry(IndexKey::Sid(id)) {
Occupied(p) if PeerPtr(*p.get()) == peer => {} Occupied(_) => bail!("Cannot insert session with id {:?}; id is in use.", id),
Vacant(e) => {
e.insert(peer.0);
}
};
Ok(())
}
pub fn unregister_session(&mut self, id: SessionId) {
self.index.remove(&IndexKey::Sid(id));
}
pub fn unregister_session_if_vacant(&mut self, id: SessionId, peer: PeerPtr) {
match (peer.session().get(self), peer.hs().get(self)) {
(Some(ses), _) if ses.sidm == id => {}
(_, Some(hs)) if hs.core.sidi == id => {}
_ => self.unregister_session(id),
};
}
pub fn find_peer(&self, id: PeerId) -> Option<PeerPtr> {
self.index.get(&IndexKey::Peer(id)).map(|no| PeerPtr(*no))
}
pub fn lookup_handshake(&self, id: SessionId) -> Option<IniHsPtr> {
self.index
.get(&IndexKey::Sid(id)) .map(|no| IniHsPtr(*no)) .filter(|hsptr| {
hsptr
.get(self) .as_ref()
.map(|hs| hs.core.sidi == id) .unwrap_or(false) })
}
pub fn lookup_session(&self, id: SessionId) -> Option<SessionPtr> {
self.index
.get(&IndexKey::Sid(id))
.map(|no| SessionPtr(*no))
.filter(|sptr| {
sptr.get(self)
.as_ref()
.map(|ses| ses.sidm == id)
.unwrap_or(false)
})
}
pub fn active_biscuit_key(&mut self) -> BiscuitKeyPtr {
let (a, b) = (BiscuitKeyPtr(0), BiscuitKeyPtr(1));
let (t, u) = (a.get(self).created_at, b.get(self).created_at);
let r = if t >= u { a } else { b };
if r.lifecycle(self) == Lifecycle::Young {
return r;
}
let r = if t < u { a } else { b };
let tb = self.timebase.clone();
r.get_mut(self).randomize(&tb);
r
}
}
impl Peer {
    /// Creates a peer with the given pre-shared key and static public key.
    /// It has no session, no handshake and a zero biscuit watermark.
    pub fn new(psk: SymKey, pk: SPk) -> Peer {
        Peer {
            spkt: pk,
            psk,
            biscuit_used: BiscuitId::zero(),
            handshake: None,
            session: None,
            initiation_requested: false,
        }
    }

    /// The peer's id, derived from its static public key with the `peerid`
    /// PRF label.
    pub fn pidt(&self) -> Result<PeerId> {
        let digest = lprf::peerid()?.mix(self.spkt.secret())?.into_value();
        Ok(Public::new(digest))
    }
}
impl Session {
pub fn zero() -> Self {
Self {
created_at: 0.0,
sidm: SessionId::zero(),
sidt: SessionId::zero(),
handshake_role: HandshakeRole::Initiator,
ck: SecretPrfTree::zero().dup(),
txkm: SymKey::zero(),
txkt: SymKey::zero(),
txnm: 0,
txnt: 0,
}
}
}
impl BiscuitKey {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self {
created_at: BCE,
key: SymKey::random(),
}
}
pub fn erase(&mut self) {
self.key.randomize();
self.created_at = BCE;
}
pub fn randomize(&mut self, tb: &Timebase) {
self.key.randomize();
self.created_at = tb.now();
}
}
impl Mortal for IniHsPtr {
    /// Start time of the handshake, if one exists.
    fn created_at(&self, srv: &CryptoServer) -> Option<Timing> {
        let hs = self.get(srv).as_ref()?;
        Some(hs.created_at)
    }

    /// Handshakes have no separate retirement phase: they retire when they die.
    fn retire_at(&self, srv: &CryptoServer) -> Option<Timing> {
        self.die_at(srv)
    }

    /// A handshake is discarded `REJECT_AFTER_TIME` after it started.
    fn die_at(&self, srv: &CryptoServer) -> Option<Timing> {
        let born = self.created_at(srv)?;
        Some(born + REJECT_AFTER_TIME)
    }
}
impl Mortal for SessionPtr {
fn created_at(&self, srv: &CryptoServer) -> Option<Timing> {
self.get(srv).as_ref().map(|p| p.created_at)
}
fn retire_at(&self, srv: &CryptoServer) -> Option<Timing> {
use HandshakeRole::*;
self.get(srv).as_ref().map(|p| {
let wait = match p.handshake_role {
Initiator => REKEY_AFTER_TIME_INITIATOR,
Responder => REKEY_AFTER_TIME_RESPONDER,
};
p.created_at + wait
})
}
fn die_at(&self, srv: &CryptoServer) -> Option<Timing> {
self.created_at(srv).map(|t| t + REJECT_AFTER_TIME)
}
}
impl Mortal for BiscuitKeyPtr {
    /// Creation time of the key. Returns `None` for keys stamped with a
    /// negative time (such as `BCE`), i.e. keys that were never activated.
    fn created_at(&self, srv: &CryptoServer) -> Option<Timing> {
        let stamp = self.get(srv).created_at;
        if stamp < 0.0 {
            return None;
        }
        Some(stamp)
    }

    /// A key stops being chosen for new biscuits one epoch after creation.
    fn retire_at(&self, srv: &CryptoServer) -> Option<Timing> {
        let born = self.created_at(srv)?;
        Some(born + BISCUIT_EPOCH)
    }

    /// A key is erased one further epoch after retirement.
    fn die_at(&self, srv: &CryptoServer) -> Option<Timing> {
        let retired = self.retire_at(srv)?;
        Some(retired + BISCUIT_EPOCH)
    }
}
/// Time-relative helpers derived from [`Mortal`].
trait MortalExt: Mortal {
    /// Time remaining until `die_at`; negative once it has passed.
    fn life_left(&self, srv: &CryptoServer) -> Option<Timing>;
    /// Time remaining until `retire_at`; negative once it has passed.
    fn youth_left(&self, srv: &CryptoServer) -> Option<Timing>;
    /// Classifies the object's current life stage.
    fn lifecycle(&self, srv: &CryptoServer) -> Lifecycle;
}
impl<T: Mortal> MortalExt for T {
    fn life_left(&self, srv: &CryptoServer) -> Option<Timing> {
        let death = self.die_at(srv)?;
        Some(death - srv.timebase.now())
    }

    fn youth_left(&self, srv: &CryptoServer) -> Option<Timing> {
        let retirement = self.retire_at(srv)?;
        Some(retirement - srv.timebase.now())
    }

    /// The stage is decided in this order: Dead if the death time is due,
    /// otherwise Retired if the retirement time is due, otherwise Young if
    /// both times are known, otherwise Void.
    fn lifecycle(&self, srv: &CryptoServer) -> Lifecycle {
        let youth = self.youth_left(srv);
        let life = self.life_left(srv);
        if matches!(life, Some(t) if has_happened(t, 0.0)) {
            Lifecycle::Dead
        } else if matches!(youth, Some(t) if has_happened(t, 0.0)) {
            Lifecycle::Retired
        } else if youth.is_some() && life.is_some() {
            Lifecycle::Young
        } else {
            Lifecycle::Void
        }
    }
}
impl CryptoServer {
    /// Starts a handshake with `peer` by writing a sealed `InitHello` into
    /// `tx_buf`.
    ///
    /// The message is also stored for retransmission. Returns its length.
    pub fn initiate_handshake(&mut self, peer: PeerPtr, tx_buf: &mut [u8]) -> Result<usize> {
        let mut envelope = tx_buf.envelope_truncating::<InitHello<()>>()?;
        let payload = envelope.payload_mut().init_hello()?;
        self.handle_initiation(peer, payload)?;
        let msg_len = self.seal_and_commit_msg(peer, MsgType::InitHello, envelope)?;
        let sent = &tx_buf[..msg_len];
        peer.hs().store_msg_for_retransmission(self, sent)?;
        Ok(msg_len)
    }
}
/// Outcome of `CryptoServer::handle_msg`.
#[derive(Debug)]
pub struct HandleMsgResult {
    /// The peer a key exchange completed with, if this message completed one.
    pub exchanged_with: Option<PeerPtr>,
    /// Length of the response written to the transmit buffer, if any.
    pub resp: Option<usize>,
}
impl CryptoServer {
    /// Handles one received message and writes the response, if any, into
    /// `tx_buf`.
    ///
    /// The first byte of `rx_buf` selects the message type. Each message's
    /// seal is checked before its payload is processed.
    ///
    /// # Errors
    /// Fails for empty messages, for messages of unknown or unsupported type,
    /// for broken seals, and for any error raised by the individual handlers.
    pub fn handle_msg(&mut self, rx_buf: &[u8], tx_buf: &mut [u8]) -> Result<HandleMsgResult> {
        let seal_broken = "Message seal broken!";
        ensure!(!rx_buf.is_empty(), "received empty message, ignoring it");

        // Each arm yields (peer, response length, whether a key exchange completed).
        let (peer, len, exchanged) = match rx_buf[0].try_into() {
            Ok(MsgType::InitHello) => {
                let msg_in = rx_buf.envelope::<InitHello<&[u8]>>()?;
                ensure!(msg_in.check_seal(self)?, seal_broken);
                let mut msg_out = tx_buf.envelope_truncating::<RespHello<&mut [u8]>>()?;
                let peer = self.handle_init_hello(
                    msg_in.payload().init_hello()?,
                    msg_out.payload_mut().resp_hello()?,
                )?;
                let len = self.seal_and_commit_msg(peer, MsgType::RespHello, msg_out)?;
                (peer, len, false)
            }
            Ok(MsgType::RespHello) => {
                let msg_in = rx_buf.envelope::<RespHello<&[u8]>>()?;
                ensure!(msg_in.check_seal(self)?, seal_broken);
                let mut msg_out = tx_buf.envelope_truncating::<InitConf<&mut [u8]>>()?;
                let peer = self.handle_resp_hello(
                    msg_in.payload().resp_hello()?,
                    msg_out.payload_mut().init_conf()?,
                )?;
                let len = self.seal_and_commit_msg(peer, MsgType::InitConf, msg_out)?;
                // InitConf is resent until the responder acknowledges it.
                peer.hs()
                    .store_msg_for_retransmission(self, &tx_buf[..len])?;
                (peer, len, true)
            }
            Ok(MsgType::InitConf) => {
                let msg_in = rx_buf.envelope::<InitConf<&[u8]>>()?;
                ensure!(msg_in.check_seal(self)?, seal_broken);
                let mut msg_out = tx_buf.envelope_truncating::<EmptyData<&mut [u8]>>()?;
                let (peer, if_exchanged) = self.handle_init_conf(
                    msg_in.payload().init_conf()?,
                    msg_out.payload_mut().empty_data()?,
                )?;
                let len = self.seal_and_commit_msg(peer, MsgType::EmptyData, msg_out)?;
                (peer, len, if_exchanged)
            }
            Ok(MsgType::EmptyData) => {
                let msg_in = rx_buf.envelope::<EmptyData<&[u8]>>()?;
                ensure!(msg_in.check_seal(self)?, seal_broken);
                let peer = self.handle_resp_conf(msg_in.payload().empty_data()?)?;
                (peer, 0, false)
            }
            Ok(MsgType::DataMsg) => bail!("DataMsg handling not implemented!"),
            Ok(MsgType::CookieReply) => bail!("CookieReply handling not implemented!"),
            Err(_) => bail!("Received message of unknown type {:#04x}", rx_buf[0]),
        };

        Ok(HandleMsgResult {
            exchanged_with: exchanged.then_some(peer),
            resp: (len != 0).then_some(len),
        })
    }

    /// Writes the type byte into `msg`, seals it for `peer`, and returns the
    /// full envelope length of message kind `M`.
    pub fn seal_and_commit_msg<M: LenseView>(
        &mut self,
        peer: PeerPtr,
        msg_type: MsgType,
        mut msg: Envelope<&mut [u8], M>,
    ) -> Result<usize> {
        msg.msg_type_mut()[0] = msg_type as u8;
        msg.seal(peer, self)?;
        Ok(<Envelope<(), M> as LenseView>::LEN)
    }
}
/// A relative delay until something should happen, in `Timing` units.
#[derive(Debug, Copy, Clone)]
pub struct Wait(Timing);

impl Wait {
    /// No delay.
    fn immediate() -> Self {
        Self(0.0)
    }

    /// The "nothing scheduled" delay, `UNENDING`.
    fn hibernate() -> Self {
        Self(UNENDING)
    }

    /// Hibernate if `cond` holds; otherwise act immediately.
    fn immediate_unless(cond: bool) -> Self {
        match cond {
            true => Self::hibernate(),
            false => Self::immediate(),
        }
    }

    /// The given delay, or hibernate if there is none.
    fn or_hibernate(t: Option<Timing>) -> Self {
        t.map_or_else(Self::hibernate, Wait)
    }

    /// The given delay, or immediate if there is none.
    fn or_immediate(t: Option<Timing>) -> Self {
        t.map_or_else(Self::immediate, Wait)
    }

    /// The later of the two delays, i.e. both conditions must be met.
    fn and<T: Into<Wait>>(&self, o: T) -> Self {
        let Wait(mine) = *self;
        let Wait(theirs) = o.into();
        if mine > theirs {
            Wait(mine)
        } else {
            Wait(theirs)
        }
    }
}

impl From<Timing> for Wait {
    fn from(t: Timing) -> Wait {
        Self(t)
    }
}

/// `None` means "never": it converts to hibernation.
impl From<Option<Timing>> for Wait {
    fn from(t: Option<Timing>) -> Wait {
        Self::or_hibernate(t)
    }
}
/// What `CryptoServer::poll` asks its caller to do next.
#[derive(Debug, Copy, Clone)]
pub enum PollResult {
    /// Nothing to do for the given time.
    Sleep(Timing),
    /// The peer's session expired and was removed.
    DeleteKey(PeerPtr),
    /// A new handshake with the peer should be initiated.
    SendInitiation(PeerPtr),
    /// The peer's stored handshake message should be resent.
    SendRetransmission(PeerPtr),
}
impl Default for PollResult {
fn default() -> Self {
Self::hibernate()
}
}
impl PollResult {
pub fn hibernate() -> Self {
Self::Sleep(UNENDING) }
pub fn peer(&self) -> Option<PeerPtr> {
use PollResult::*;
match *self {
DeleteKey(p) | SendInitiation(p) | SendRetransmission(p) => Some(p),
_ => None,
}
}
pub fn fold(&self, otr: PollResult) -> PollResult {
use PollResult::*;
match (*self, otr) {
(Sleep(a), Sleep(b)) => Sleep(f64::min(a, b)),
(a, Sleep(_b)) if a.saturated() => a,
(Sleep(_a), b) if b.saturated() => b,
_ => panic!(
"Do not fold two saturated poll results! It doesn't make sense; \
we would have to discard one of the events. \
As soon as some result that requires an action (i.e. something other than sleep \
is reached you should just return and have the API consumer poll again."
),
}
}
pub fn try_fold_with<F: FnOnce() -> Result<PollResult>>(&self, f: F) -> Result<PollResult> {
if self.saturated() {
Ok(*self)
} else {
Ok(self.fold(f()?))
}
}
pub fn poll_child<P: Pollable>(&self, srv: &mut CryptoServer, p: &P) -> Result<PollResult> {
self.try_fold_with(|| p.poll(srv))
}
pub fn poll_children<P, I>(&self, srv: &mut CryptoServer, iter: I) -> Result<PollResult>
where
P: Pollable,
I: Iterator<Item = P>,
{
let mut acc = *self;
for e in iter {
if acc.saturated() {
break;
}
acc = acc.fold(e.poll(srv)?);
}
Ok(acc)
}
pub fn sched<W: Into<Wait>, F: FnOnce() -> PollResult>(&self, wait: W, f: F) -> PollResult {
let wait = wait.into().0;
if self.saturated() {
*self
} else if has_happened(wait, 0.0) {
self.fold(f())
} else {
self.fold(Self::Sleep(wait))
}
}
pub fn try_sched<W: Into<Wait>, F: FnOnce() -> Result<PollResult>>(
&self,
wait: W,
f: F,
) -> Result<PollResult> {
let wait = wait.into().0;
if self.saturated() {
Ok(*self)
} else if has_happened(wait, 0.0) {
Ok(self.fold(f()?))
} else {
Ok(self.fold(Self::Sleep(wait)))
}
}
pub fn ok(&self) -> Result<PollResult> {
Ok(*self)
}
pub fn saturated(&self) -> bool {
use PollResult::*;
!matches!(self, Sleep(_))
}
}
/// Starting value for building a poll result: an unsaturated, hibernating
/// sleep.
pub fn begin_poll() -> PollResult {
    PollResult::hibernate()
}

/// Adapts a side-effecting closure for use with `PollResult::sched`. When
/// called, it runs `f`, discards the value, and yields the default
/// (hibernating) poll result, so no action is reported.
pub fn void_poll<T, F: FnOnce() -> T>(f: F) -> impl FnOnce() -> PollResult {
    move || {
        let _ = f();
        PollResult::default()
    }
}
/// Something whose timers can be checked by `CryptoServer::poll`.
pub trait Pollable {
    /// Performs any due bookkeeping and reports the next action, or how long
    /// to sleep.
    fn poll(&self, srv: &mut CryptoServer) -> Result<PollResult>;
}
impl CryptoServer {
    /// Checks all timers and returns the first required action, or how long
    /// the caller may sleep.
    ///
    /// Biscuit keys are polled first, then peers, starting after the peer
    /// that produced the previous action (round robin), so that no single
    /// peer can starve the others.
    pub fn poll(&mut self) -> Result<PollResult> {
        let result = begin_poll()
            .poll_children(self, self.biscuit_key_ptrs())?
            .poll_children(self, self.peer_ptrs_off(self.peer_poll_off))?;
        self.peer_poll_off = result.peer().map_or(0, |p| p.0 + 1);
        result.ok()
    }
}
impl Pollable for BiscuitKeyPtr {
    /// Erases the key once it is dead; this never reports an action.
    fn poll(&self, srv: &mut CryptoServer) -> Result<PollResult> {
        let remaining = self.life_left(srv);
        let erase_key = void_poll(|| self.get_mut(srv).erase());
        begin_poll().sched(remaining, erase_key).ok()
    }
}
impl Pollable for PeerPtr {
    /// Polls one peer. The steps run in order and stop at the first action:
    /// 1. an expired initiator handshake is dropped (no action reported);
    /// 2. an expired session is dropped and reported as `DeleteKey`;
    /// 3. `SendInitiation` is requested once no initiation is pending and both
    ///    the session and the handshake, if present, are past their youth;
    /// 4. otherwise the handshake's retransmission timer is polled.
    ///
    /// The order matters. The arguments of each `sched` are evaluated after
    /// the previous `sched` ran, so later steps observe earlier removals.
    fn poll(&self, srv: &mut CryptoServer) -> Result<PollResult> {
        let (ses, hs) = (self.session(), self.hs());
        begin_poll()
            // 1. Drop an expired initiator handshake.
            .sched(hs.life_left(srv), void_poll(|| hs.take(srv)))
            // 2. Drop an expired session and report the deleted key.
            .sched(ses.life_left(srv), || {
                ses.take(srv);
                PollResult::DeleteKey(*self)
            })
            // 3. Wait until no initiation is pending and neither the session
            //    nor the handshake is young any more (absent counts as due).
            .sched(
                Wait::immediate_unless(self.get(srv).initiation_requested)
                    .and(Wait::or_immediate(ses.youth_left(srv)))
                    .and(Wait::or_immediate(hs.youth_left(srv))),
                || {
                    self.get_mut(srv).initiation_requested = true;
                    PollResult::SendInitiation(*self)
                },
            )
            // 4. Let the handshake schedule retransmissions.
            .poll_child(srv, &hs)
    }
}
impl Pollable for IniHsPtr {
    /// Requests a retransmission when the retry time is reached and schedules
    /// the next one. A missing handshake means hibernation.
    fn poll(&self, srv: &mut CryptoServer) -> Result<PollResult> {
        let due_in = self.retransmission_in(srv);
        begin_poll().try_sched(due_in, || {
            self.register_retransmission(srv)?;
            let peer = self.peer();
            Ok(PollResult::SendRetransmission(peer))
        })
    }
}
impl CryptoServer {
    /// Copies `peer`'s stored handshake message into `tx_buf` and returns its
    /// length.
    pub fn retransmit_handshake(&mut self, peer: PeerPtr, tx_buf: &mut [u8]) -> Result<usize> {
        let handshake = peer.hs();
        handshake.apply_retransmission(self, tx_buf)
    }
}
impl IniHsPtr {
pub fn store_msg_for_retransmission(&self, srv: &mut CryptoServer, msg: &[u8]) -> Result<()> {
let ih = self
.get_mut(srv)
.as_mut()
.with_context(|| format!("No current handshake for peer {:?}", self.peer()))?;
cpy_min(msg, &mut *ih.tx_buf);
ih.tx_count = 0;
ih.tx_len = msg.len();
self.register_retransmission(srv)?;
Ok(())
}
pub fn apply_retransmission(&self, srv: &mut CryptoServer, tx_buf: &mut [u8]) -> Result<usize> {
let ih = self
.get_mut(srv)
.as_mut()
.with_context(|| format!("No current handshake for peer {:?}", self.peer()))?;
cpy_min(&ih.tx_buf[..ih.tx_len], tx_buf);
Ok(ih.tx_len)
}
pub fn register_retransmission(&self, srv: &mut CryptoServer) -> Result<()> {
let tb = srv.timebase.clone();
let ih = self
.get_mut(srv)
.as_mut()
.with_context(|| format!("No current handshake for peer {:?}", self.peer()))?;
ih.tx_retry_at = tb.now()
+ RETRANSMIT_DELAY_BEGIN
* RETRANSMIT_DELAY_GROWTH.powf(
(RETRANSMIT_DELAY_END / RETRANSMIT_DELAY_BEGIN)
.log(RETRANSMIT_DELAY_GROWTH)
.min(ih.tx_count as f64),
)
* RETRANSMIT_DELAY_JITTER
* (rand_f64() + 1.0);
ih.tx_count += 1;
Ok(())
}
pub fn retransmission_in(&self, srv: &mut CryptoServer) -> Option<Timing> {
self.get(srv)
.as_ref()
.map(|hs| hs.tx_retry_at - srv.timebase.now())
}
}
impl<M> Envelope<&mut [u8], M>
where
    M: LenseView,
{
    /// Writes the message's MAC field.
    ///
    /// The MAC is computed from the `mac` PRF label mixed with the recipient's
    /// (`peer`'s) static public key and the bytes preceding the MAC field, and
    /// is truncated to 16 bytes. The recipient checks it with `check_seal`
    /// using its own public key.
    pub fn seal(&mut self, peer: PeerPtr, srv: &CryptoServer) -> Result<()> {
        let recipient_pk = peer.get(srv).spkt.secret();
        let prf = lprf::mac()?.mix(recipient_pk)?.mix(self.until_mac())?;
        let tag = prf.into_value();
        self.mac_mut().copy_from_slice(&tag[..16]);
        Ok(())
    }
}
impl<M> Envelope<&[u8], M>
where
    M: LenseView,
{
    /// Returns whether the MAC field matches the MAC expected for a message
    /// addressed to us, i.e. computed over our own static public key.
    pub fn check_seal(&self, srv: &CryptoServer) -> Result<bool> {
        let own_pk = srv.spkm.secret();
        let expected = lprf::mac()?.mix(own_pk)?.mix(self.until_mac())?;
        let expected = expected.into_value();
        Ok(sodium_memcmp(self.mac(), &expected[..16]))
    }
}
impl InitiatorHandshake {
    /// Returns an empty handshake stamped with the current time.
    ///
    /// Keys, session ids and the message buffer are zeroed, and the handshake
    /// expects `RespHello` next.
    pub fn zero_with_timestamp(srv: &CryptoServer) -> Self {
        let created_at = srv.timebase.now();
        Self {
            created_at,
            next: HandshakeStateMachine::RespHello,
            core: HandshakeState::zero(),
            eski: ESk::zero(),
            epki: EPk::zero(),
            tx_buf: MsgBuf::zero(),
            tx_len: 0,
            tx_count: 0,
            tx_retry_at: 0.0,
            tx_at: 0.0,
        }
    }
}
/// Transcript operations. Both parties must call these with the same values in
/// the same order so that their chaining keys agree.
impl HandshakeState {
    /// A state with zeroed session ids and an all-zero chaining key.
    pub fn zero() -> Self {
        Self {
            sidi: SessionId::zero(),
            sidr: SessionId::zero(),
            ck: SecretPrfTree::zero().dup(),
        }
    }
    /// Overwrites the chaining key with an all-zero one; session ids are kept.
    pub fn erase(&mut self) {
        self.ck = SecretPrfTree::zero().dup();
    }
    /// (Re)initialises the chaining key from the `ckinit` label and the
    /// responder's static public key `spkr`.
    pub fn init(&mut self, spkr: &[u8]) -> Result<&mut Self> {
        self.ck = lprf::ckinit()?.mix(spkr)?.into_secret_prf_tree().dup();
        Ok(self)
    }
    /// Absorbs `a` into the chaining key under the `mix` label.
    pub fn mix(&mut self, a: &[u8]) -> Result<&mut Self> {
        self.ck = self.ck.mix(&lprf::mix()?)?.mix(a)?.dup();
        Ok(self)
    }
    /// Encrypts `pt` into `ct` and then mixes the ciphertext.
    ///
    /// The key is derived from the chaining key under the `hs_enc` label, and
    /// the nonce is the constant `NONCE0`.
    /// NOTE(review): the constant nonce is presumably safe only because every
    /// call mixes `ct` afterwards, which changes the next derived key — confirm.
    pub fn encrypt_and_mix(&mut self, ct: &mut [u8], pt: &[u8]) -> Result<&mut Self> {
        let k = self.ck.mix(&lprf::hs_enc()?)?.into_secret();
        aead_enc_into(ct, k.secret(), &NONCE0, &NOTHING, pt)?;
        self.mix(ct)
    }
    /// Counterpart of `encrypt_and_mix`: decrypts `ct` into `pt`, failing if
    /// `aead_dec_into` fails, then mixes the same ciphertext.
    pub fn decrypt_and_mix(&mut self, pt: &mut [u8], ct: &[u8]) -> Result<&mut Self> {
        let k = self.ck.mix(&lprf::hs_enc()?)?.into_secret();
        aead_dec_into(pt, k.secret(), &NONCE0, &NOTHING, ct)?;
        self.mix(ct)
    }
    /// Encapsulates a fresh shared secret to `pk` with KEM `T`, writing the
    /// ciphertext to `ct`. It then mixes `pk`, the shared secret and `ct`.
    pub fn encaps_and_mix<T: KEM, const SHK_LEN: usize>(
        &mut self,
        ct: &mut [u8],
        pk: &[u8],
    ) -> Result<&mut Self> {
        let mut shk = Secret::<SHK_LEN>::zero();
        T::encaps(shk.secret_mut(), ct, pk)?;
        self.mix(pk)?.mix(shk.secret())?.mix(ct)
    }
    /// Counterpart of `encaps_and_mix`: decapsulates `ct` with `sk`, then mixes
    /// `pk`, the shared secret and `ct` in the same order.
    pub fn decaps_and_mix<T: KEM, const SHK_LEN: usize>(
        &mut self,
        sk: &[u8],
        pk: &[u8],
        ct: &[u8],
    ) -> Result<&mut Self> {
        let mut shk = Secret::<SHK_LEN>::zero();
        T::decaps(shk.secret_mut(), sk, ct)?;
        self.mix(pk)?.mix(shk.secret())?.mix(ct)
    }
    /// Seals the responder's handshake state into `biscuit_ct` and mixes the
    /// ciphertext.
    ///
    /// The biscuit plaintext holds `peer`'s id, the current biscuit counter
    /// and the chaining key. The associated data binds our static public key
    /// and both session ids. The counter is incremented afterwards. The
    /// biscuit is encrypted with the active biscuit key under a random nonce
    /// whose top bit records which of the two keys was used.
    pub fn store_biscuit(
        &mut self,
        srv: &mut CryptoServer,
        peer: PeerPtr,
        biscuit_ct: &mut [u8],
    ) -> Result<&mut Self> {
        // Zeroed secret storage, then a lense view over it.
        let mut biscuit = Secret::<BISCUIT_PT_LEN>::zero();
        let mut biscuit = (&mut biscuit.secret_mut()[..]).biscuit()?;
        biscuit
            .pidi_mut()
            .copy_from_slice(peer.get(srv).pidt()?.as_slice());
        biscuit.biscuit_no_mut().copy_from_slice(&*srv.biscuit_ctr);
        biscuit
            .ck_mut()
            .copy_from_slice(self.ck.clone().danger_into_secret().secret());
        let ad = lprf::biscuit_ad()?
            .mix(srv.spkm.secret())?
            .mix(self.sidi.as_slice())?
            .mix(self.sidr.as_slice())?
            .into_value();
        sodium_bigint_inc(&mut *srv.biscuit_ctr);
        // May rotate the biscuit keys if the newest one is no longer young.
        let bk = srv.active_biscuit_key();
        let mut n = XAEADNonce::random();
        // Top bit of the first nonce byte = index of the biscuit key.
        // NOTE(review): `load_biscuit` reads this bit from `biscuit_ct[0]`,
        // which presumes `xaead_enc_into` places the nonce at the start of
        // the ciphertext — confirm.
        n[0] &= 0b0111_1111;
        n[0] |= (bk.0 as u8 & 0x1) << 7;
        let k = bk.get(srv).key.secret();
        let pt = biscuit.all_bytes();
        xaead_enc_into(biscuit_ct, k, &*n, &ad, pt)?;
        self.mix(biscuit_ct)
    }
    /// Decrypts a biscuit produced by `store_biscuit` and rebuilds the
    /// responder's handshake state, including the final mix of the ciphertext.
    ///
    /// Returns the peer named in the biscuit, the biscuit number and the state.
    ///
    /// # Errors
    /// Fails if decryption fails (wrong key, tampering, or mismatching session
    /// ids in the associated data), if the peer is unknown, or if the biscuit
    /// number is lower than the peer's `biscuit_used`. Equal numbers are
    /// accepted here; the caller decides whether to start a new session.
    pub fn load_biscuit(
        srv: &CryptoServer,
        biscuit_ct: &[u8],
        sidi: SessionId,
        sidr: SessionId,
    ) -> Result<(PeerPtr, BiscuitId, HandshakeState)> {
        // Key index from the top bit of the first byte (see `store_biscuit`).
        let bk = BiscuitKeyPtr(((biscuit_ct[0] & 0b1000_0000) >> 7) as usize);
        let ad = lprf::biscuit_ad()?
            .mix(srv.spkm.secret())?
            .mix(sidi.as_slice())?
            .mix(sidr.as_slice())?
            .into_value();
        let mut biscuit = Secret::<BISCUIT_PT_LEN>::zero();
        let mut biscuit = (&mut biscuit.secret_mut()[..]).biscuit()?;
        xaead_dec_into(
            biscuit.all_bytes_mut(),
            bk.get(srv).key.secret(),
            &ad,
            biscuit_ct,
        )?;
        let no = BiscuitId::from_slice(biscuit.biscuit_no());
        let ck = SecretPrfTree::danger_from_secret(Secret::from_slice(biscuit.ck())).dup();
        let pid = PeerId::from_slice(biscuit.pidi());
        let mut hs = Self { sidi, sidr, ck };
        // Mirrors the final `self.mix(biscuit_ct)` of `store_biscuit`.
        hs.mix(biscuit_ct)?;
        let peer = srv
            .find_peer(pid)
            .with_context(|| format!("Could not decode biscuit for peer {pid:?}: No such peer."))?;
        ensure!(
            sodium_bigint_cmp(biscuit.biscuit_no(), &*peer.get(srv).biscuit_used) >= 0,
            "Rejecting biscuit: Outdated biscuit number"
        );
        Ok((peer, no, hs))
    }
    /// Finishes the handshake and turns it into a session for `role`.
    ///
    /// Derives the initiator and responder transmit keys from the chaining
    /// key, assigns ids and keys according to `role`, and starts both
    /// counters at zero.
    pub fn enter_live(self, srv: &CryptoServer, role: HandshakeRole) -> Result<Session> {
        let HandshakeState { ck, sidi, sidr } = self;
        let tki = ck.mix(&lprf::ini_enc()?)?.into_secret();
        let tkr = ck.mix(&lprf::res_enc()?)?.into_secret();
        let created_at = srv.timebase.now();
        let (ntx, nrx) = (0, 0);
        let (mysid, peersid, ktx, krx) = match role {
            HandshakeRole::Initiator => (sidi, sidr, tki, tkr),
            HandshakeRole::Responder => (sidr, sidi, tkr, tki),
        };
        Ok(Session {
            created_at,
            sidm: mysid,
            sidt: peersid,
            handshake_role: role,
            ck,
            txkm: ktx,
            txkt: krx,
            txnm: ntx,
            txnt: nrx,
        })
    }
}
impl CryptoServer {
    /// Derives the output shared key for `peer`'s current session from the
    /// session's chaining key (`osk` label).
    ///
    /// # Errors
    /// Fails if the peer has no session.
    pub fn osk(&self, peer: PeerPtr) -> Result<SymKey> {
        let slot = peer.session();
        let session = slot
            .get(self)
            .as_ref()
            .with_context(|| format!("No current session for peer {:?}", peer))?;
        let label = lprf::osk()?;
        Ok(session.ck.mix(&label)?.into_secret())
    }
}
/// Handlers for the individual handshake messages.
///
/// Initiator and responder must absorb the same values into their
/// `HandshakeState` in the same order, so the sequence of `mix`/`*_and_mix`
/// calls in each handler mirrors its counterpart on the other side.
impl CryptoServer {
    /// Initiator: builds an `InitHello` for `peer` into `ih` and installs the
    /// new initiator handshake for that peer.
    ///
    /// # Errors
    /// Propagates key generation, encapsulation and encryption failures. Also
    /// fails if the fresh `sidi` cannot be registered.
    pub fn handle_initiation(
        &mut self,
        peer: PeerPtr,
        mut ih: InitHello<&mut [u8]>,
    ) -> Result<PeerPtr> {
        let mut hs = InitiatorHandshake::zero_with_timestamp(self);
        // The transcript starts from the responder's static public key.
        hs.core.init(peer.get(self).spkt.secret())?;
        hs.core.sidi.randomize();
        ih.sidi_mut().copy_from_slice(&hs.core.sidi.value);
        EphemeralKEM::keygen(hs.eski.secret_mut(), &mut *hs.epki)?;
        ih.epki_mut().copy_from_slice(&hs.epki.value);
        hs.core.mix(ih.sidi())?.mix(ih.epki())?;
        hs.core
            .encaps_and_mix::<StaticKEM, { StaticKEM::SHK_LEN }>(
                ih.sctr_mut(),
                peer.get(self).spkt.secret(),
            )?;
        // Our own peer id is only transmitted encrypted.
        hs.core
            .encrypt_and_mix(ih.pidic_mut(), self.pidm()?.as_ref())?;
        hs.core
            .mix(self.spkm.secret())?
            .mix(peer.get(self).psk.secret())?;
        // Authentication tag over an empty plaintext, keyed from the transcript.
        hs.core.encrypt_and_mix(ih.auth_mut(), &NOTHING)?;
        peer.hs().insert(self, hs)?;
        Ok(peer)
    }

    /// Responder: processes an `InitHello` and writes the `RespHello` into
    /// `rh`.
    ///
    /// No per-handshake state is stored here. The chaining key is sealed into
    /// the biscuit instead; only the biscuit counter (and possibly the biscuit
    /// keys) change.
    pub fn handle_init_hello(
        &mut self,
        ih: InitHello<&[u8]>,
        mut rh: RespHello<&mut [u8]>,
    ) -> Result<PeerPtr> {
        let mut core = HandshakeState::zero();
        core.sidi = SessionId::from_slice(ih.sidi());
        core.init(self.spkm.secret())?;
        core.mix(ih.sidi())?.mix(ih.epki())?;
        core.decaps_and_mix::<StaticKEM, { StaticKEM::SHK_LEN }>(
            self.sskm.secret(),
            self.spkm.secret(),
            ih.sctr(),
        )?;
        // Recover the initiator's identity and look it up.
        let peer = {
            let mut peerid = PeerId::zero();
            core.decrypt_and_mix(&mut *peerid, ih.pidic())?;
            self.find_peer(peerid)
                .with_context(|| format!("No such peer {peerid:?}."))?
        };
        core.mix(peer.get(self).spkt.secret())?
            .mix(peer.get(self).psk.secret())?;
        // Check the initiator's authentication tag (empty plaintext).
        core.decrypt_and_mix(&mut [0u8; 0], ih.auth())?;
        core.sidr.randomize();
        rh.sidi_mut().copy_from_slice(core.sidi.as_ref());
        rh.sidr_mut().copy_from_slice(core.sidr.as_ref());
        core.mix(rh.sidr())?.mix(rh.sidi())?;
        core.encaps_and_mix::<EphemeralKEM, { EphemeralKEM::SHK_LEN }>(rh.ecti_mut(), ih.epki())?;
        core.encaps_and_mix::<StaticKEM, { StaticKEM::SHK_LEN }>(
            rh.scti_mut(),
            peer.get(self).spkt.secret(),
        )?;
        core.store_biscuit(self, peer, rh.biscuit_mut())?;
        core.encrypt_and_mix(rh.auth_mut(), &NOTHING)?;
        Ok(peer)
    }

    /// Initiator: processes the `RespHello`, writes the `InitConf` into `ic`,
    /// and installs the new session with the initiator role.
    ///
    /// The handshake is kept, with its chaining key erased and `next` set to
    /// `RespConf`, so that `InitConf` can be retransmitted until it is
    /// acknowledged.
    pub fn handle_resp_hello(
        &mut self,
        rh: RespHello<&[u8]>,
        mut ic: InitConf<&mut [u8]>,
    ) -> Result<PeerPtr> {
        let peer = self
            .lookup_handshake(SessionId::from_slice(rh.sidi()))
            .with_context(|| {
                format!(
                    "Got RespHello packet for non-existent session {:?}",
                    rh.sidi()
                )
            })?
            .peer();
        // Shorthands for the handshake found above. `lookup_handshake` only
        // returns pointers to existing handshakes, so the unwraps hold.
        macro_rules! hs {
            () => {
                peer.hs().get(self).as_ref().unwrap()
            };
        }
        macro_rules! hs_mut {
            () => {
                peer.hs().get_mut(self).as_mut().unwrap()
            };
        }
        let exp = hs!().next;
        let got = HandshakeStateMachine::RespHello;
        ensure!(
            exp == got,
            "Unexpected package in session {:?}. Expected {:?}, got {:?}.",
            SessionId::from_slice(rh.sidi()),
            exp,
            got
        );
        // Work on a copy so that a failing check leaves the handshake intact.
        let mut core = hs!().core.clone();
        core.sidr.copy_from_slice(rh.sidr());
        core.mix(rh.sidr())?.mix(rh.sidi())?;
        core.decaps_and_mix::<EphemeralKEM, { EphemeralKEM::SHK_LEN }>(
            hs!().eski.secret(),
            &*hs!().epki,
            rh.ecti(),
        )?;
        core.decaps_and_mix::<StaticKEM, { StaticKEM::SHK_LEN }>(
            self.sskm.secret(),
            self.spkm.secret(),
            rh.scti(),
        )?;
        core.mix(rh.biscuit())?;
        core.decrypt_and_mix(&mut [0u8; 0], rh.auth())?;
        ic.sidi_mut().copy_from_slice(rh.sidi());
        ic.sidr_mut().copy_from_slice(rh.sidr());
        core.mix(ic.sidi())?.mix(ic.sidr())?;
        // The biscuit is echoed back so the responder can restore its state.
        ic.biscuit_mut().copy_from_slice(rh.biscuit());
        core.encrypt_and_mix(ic.auth_mut(), &NOTHING)?;
        peer.session()
            .insert(self, core.enter_live(self, HandshakeRole::Initiator)?)?;
        hs_mut!().core.erase();
        hs_mut!().next = HandshakeStateMachine::RespConf;
        Ok(peer)
    }

    /// Responder: processes an `InitConf`, establishes the session with the
    /// responder role, and writes an `EmptyData` acknowledgement into `rc`.
    ///
    /// The handshake state is restored from the biscuit. A new session is only
    /// created if the biscuit number is strictly greater than the last one
    /// used with this peer. A repeated `InitConf` (e.g. a retransmission) is
    /// answered with a fresh acknowledgement from the existing session.
    ///
    /// Returns the peer and whether a new session was established.
    pub fn handle_init_conf(
        &mut self,
        ic: InitConf<&[u8]>,
        mut rc: EmptyData<&mut [u8]>,
    ) -> Result<(PeerPtr, bool)> {
        let mut exchanged = false;
        let (peer, biscuit_no, mut core) = HandshakeState::load_biscuit(
            self,
            ic.biscuit(),
            SessionId::from_slice(ic.sidi()),
            SessionId::from_slice(ic.sidr()),
        )?;
        // Replay the RespHello auth step of `handle_init_hello` so that the
        // transcript matches the initiator's; the tag itself is discarded.
        core.encrypt_and_mix(&mut [0u8; AEAD_TAG_LEN], &NOTHING)?;
        core.mix(ic.sidi())?.mix(ic.sidr())?;
        core.decrypt_and_mix(&mut [0u8; 0], ic.auth())?;
        // Only a newer biscuit may (re)establish the session.
        if sodium_bigint_cmp(&*biscuit_no, &*peer.get(self).biscuit_used) > 0 {
            peer.get_mut(self).biscuit_used = biscuit_no;
            peer.session()
                .insert(self, core.enter_live(self, HandshakeRole::Responder)?)?;
            // Drop any handshake we had initiated with this peer ourselves.
            peer.hs().take(self);
            exchanged = true;
        }
        let ses = peer
            .session()
            .get_mut(self)
            .as_mut()
            .context("Cannot send acknowledgement. No session.")?;
        rc.sid_mut().copy_from_slice(&ses.sidt.value);
        rc.ctr_mut().copy_from_slice(&ses.txnm.to_le_bytes());
        ses.txnm += 1;
        // Nonce: the 8-byte little-endian counter followed by four zero bytes.
        let n = cat!(AEAD_NONCE_LEN; rc.ctr(), &[0u8; 4]);
        let k = ses.txkm.secret();
        aead_enc_into(rc.auth_mut(), k, &n, &NOTHING, &NOTHING)?;
        Ok((peer, exchanged))
    }

    /// Initiator: processes the responder's `EmptyData` acknowledgement of
    /// `InitConf` and completes the handshake, dropping the retransmission
    /// state.
    ///
    /// # Errors
    /// Fails if there is no matching handshake in the `RespConf` state, if the
    /// session is missing, if the counter is older than the last accepted one,
    /// or if authentication fails.
    pub fn handle_resp_conf(&mut self, rc: EmptyData<&[u8]>) -> Result<PeerPtr> {
        let sid = SessionId::from_slice(rc.sid());
        let hs = self
            .lookup_handshake(sid)
            .with_context(|| format!("Got RespConf packet for non-existent session {sid:?}"))?;
        let ses = hs.peer().session();
        let exp = hs.get(self).as_ref().map(|h| h.next);
        let got = Some(HandshakeStateMachine::RespConf);
        ensure!(
            exp == got,
            "Unexpected package in session {:?}. Expected {:?}, got {:?}.",
            sid,
            exp,
            got
        );
        {
            let s = ses.get_mut(self).as_mut().with_context(|| {
                format!("Cannot validate EmptyData message. Missing encryption session for {sid:?}")
            })?;
            let ctr: [u8; 8] = rc.ctr().try_into()?;
            let n = u64::from_le_bytes(ctr);
            ensure!(n >= s.txnt, "Stale nonce");
            aead_dec_into(
                &mut [0u8; 0],
                s.txkt.secret(),
                &cat!(AEAD_NONCE_LEN; rc.ctr(), &[0u8; 4]),
                &NOTHING,
                rc.auth(),
            )?;
            // Commit the counter only after the message authenticated.
            // Otherwise a forged packet with a large counter could advance
            // `txnt` and make genuine acknowledgements fail as stale.
            s.txnt = n;
        }
        hs.take(self);
        Ok(hs.peer())
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Runs a full handshake between two servers. At every step, the current
    /// message is first fed in at every wrong length (up to 20% longer) to
    /// check that it is rejected without touching the response buffer. The
    /// correct message is then delivered. At the end both sides must agree
    /// on the output key.
    #[test]
    fn handles_incorrect_size_messages() {
        crate::sodium::sodium_init().unwrap();
        stacker::grow(8 * 1024 * 1024, || {
            const OVERSIZED_MESSAGE: usize = ((MAX_MESSAGE_LEN as f32) * 1.2) as usize;
            type MsgBufPlus = Public<OVERSIZED_MESSAGE>;
            const PEER0: PeerPtr = PeerPtr(0);
            let (mut me, mut they) = make_server_pair().unwrap();
            let (mut msgbuf, mut resbuf) = (MsgBufPlus::zero(), MsgBufPlus::zero());
            let mut msglen = Some(me.initiate_handshake(PEER0, &mut *resbuf).unwrap());
            // Alternate roles: the server that just produced a response hands
            // it over to the other server until no response is produced.
            while let Some(l) = msglen {
                std::mem::swap(&mut me, &mut they);
                std::mem::swap(&mut msgbuf, &mut resbuf);
                msglen = test_incorrect_sizes_for_msg(&mut me, &*msgbuf, l, &mut *resbuf);
            }
            assert_eq!(
                me.osk(PEER0).unwrap().secret(),
                they.osk(PEER0).unwrap().secret()
            );
        });
    }

    /// Feeds `msgbuf` to `srv` truncated or extended to every length other
    /// than `msglen` (up to 20% longer), asserting that each attempt fails and
    /// leaves `resbuf` all zero. It then delivers the correctly sized message
    /// and returns the length of the response, if any.
    fn test_incorrect_sizes_for_msg(
        srv: &mut CryptoServer,
        msgbuf: &[u8],
        msglen: usize,
        resbuf: &mut [u8],
    ) -> Option<usize> {
        resbuf.fill(0);
        for l in 0..(((msglen as f32) * 1.2) as usize) {
            if l == msglen {
                continue;
            }
            let res = srv.handle_msg(&msgbuf[..l], resbuf);
            assert!(
                res.is_err(),
                "message of length {l} (expected {msglen}) was accepted"
            );
            assert!(
                resbuf.iter().all(|&x| x == 0),
                "response buffer was written while rejecting a message of length {l}"
            );
        }
        srv.handle_msg(&msgbuf[..msglen], resbuf).unwrap().resp
    }

    /// Generates a fresh static key pair.
    fn keygen() -> Result<(SSk, SPk)> {
        let (mut sk, mut pk) = (SSk::zero(), SPk::zero());
        StaticKEM::keygen(sk.secret_mut(), pk.secret_mut())?;
        Ok((sk, pk))
    }

    /// Creates two servers that know each other as peer 0 and share a random
    /// pre-shared key.
    fn make_server_pair() -> Result<(CryptoServer, CryptoServer)> {
        let psk = SymKey::random();
        let ((ska, pka), (skb, pkb)) = (keygen()?, keygen()?);
        let (mut a, mut b) = (
            CryptoServer::new(ska, pka.clone()),
            CryptoServer::new(skb, pkb.clone()),
        );
        a.add_peer(Some(psk.clone()), pkb)?;
        b.add_peer(Some(psk), pka)?;
        Ok((a, b))
    }
}