use vox_types::{
ConnectionSettings, HandshakeMessage, HandshakeResult, LinkRx, LinkTx, ResumeKeyBytes, Schema,
SessionResumeKey, SessionRole,
};
/// Errors that can terminate a handshake.
#[derive(Debug)]
pub enum HandshakeError {
    /// The underlying link failed while sending or receiving a handshake frame.
    Io(std::io::Error),
    /// A handshake message could not be encoded; carries the encoder's error text.
    Encode(String),
    /// A received frame could not be decoded as a handshake message; carries the decoder's
    /// error text.
    Decode(String),
    /// The link closed before the expected handshake message arrived.
    PeerClosed,
    /// The exchange did not follow the protocol (unexpected message, resume-key mismatch, ...)
    /// or a local handshake step failed; carries a description.
    Protocol(String),
    /// The peer rejected the handshake with a `Sorry`; carries the peer's reason.
    Sorry(String),
    /// The session is not resumable.
    NotResumable,
}

impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => write!(f, "handshake io error: {e}"),
            Self::Encode(e) => write!(f, "handshake encode error: {e}"),
            Self::Decode(e) => write!(f, "handshake decode error: {e}"),
            Self::PeerClosed => write!(f, "peer closed during handshake"),
            Self::Protocol(msg) => write!(f, "handshake protocol error: {msg}"),
            Self::Sorry(reason) => write!(f, "handshake rejected: {reason}"),
            Self::NotResumable => write!(f, "session is not resumable"),
        }
    }
}

impl std::error::Error for HandshakeError {
    /// Exposes the wrapped [`std::io::Error`] for [`HandshakeError::Io`] so error-chain
    /// reporters can walk to the root cause; every other variant carries only text.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            // Listed explicitly (no `_`) so adding a variant forces a decision here.
            Self::Encode(_)
            | Self::Decode(_)
            | Self::PeerClosed
            | Self::Protocol(_)
            | Self::Sorry(_)
            | Self::NotResumable => None,
        }
    }
}
/// Returns the schemas describing `vox_types::Message`.
///
/// Both handshake roles advertise this list to the peer as `message_payload_schema`.
///
/// # Panics
///
/// Panics if schema extraction fails. The input is the type's associated `SHAPE` constant,
/// so a failure indicates a programming error rather than a runtime condition.
fn message_schema() -> Vec<Schema> {
    let shape = <vox_types::Message<'static> as facet::Facet<'static>>::SHAPE;
    let extracted = vox_types::extract_schemas(shape).expect("schema extraction");
    extracted.schemas.clone()
}
/// Serializes `msg` with `facet_cbor` and hands the resulting bytes to `tx`.
///
/// # Errors
///
/// * [`HandshakeError::Encode`] if serialization fails.
/// * [`HandshakeError::Io`] if `tx.send` fails.
async fn send_handshake<Tx: LinkTx>(tx: &Tx, msg: &HandshakeMessage) -> Result<(), HandshakeError> {
    let encoded = match facet_cbor::to_vec(msg) {
        Ok(encoded) => encoded,
        Err(err) => return Err(HandshakeError::Encode(err.to_string())),
    };
    vox_types::dlog!(
        "[handshake] send {:?} ({} bytes)",
        handshake_tag(msg),
        encoded.len()
    );
    tx.send(encoded).await.map_err(HandshakeError::Io)
}
async fn recv_handshake<Rx: LinkRx>(rx: &mut Rx) -> Result<HandshakeMessage, HandshakeError> {
let backing = rx
.recv()
.await
.map_err(|error| HandshakeError::Io(std::io::Error::other(error.to_string())))?
.ok_or(HandshakeError::PeerClosed)?;
vox_types::dlog!(
"[handshake] recv raw frame ({} bytes)",
backing.as_bytes().len()
);
let msg = facet_cbor::from_slice(backing.as_bytes())
.map_err(|e| HandshakeError::Decode(e.to_string()))?;
vox_types::dlog!("[handshake] recv {:?}", handshake_tag(&msg));
Ok(msg)
}
/// Returns the variant name of `msg` as a static string, for debug logging.
fn handshake_tag(msg: &HandshakeMessage) -> &'static str {
    match msg {
        HandshakeMessage::Hello { .. } => "Hello",
        HandshakeMessage::HelloYourself { .. } => "HelloYourself",
        HandshakeMessage::LetsGo { .. } => "LetsGo",
        HandshakeMessage::Sorry { .. } => "Sorry",
    }
}
/// Runs the initiator side of the handshake: `Hello` → (`HelloYourself` | `Sorry`) → `LetsGo`.
///
/// The `Hello` carries our `settings` and their parity, our message schema, `supports_retry`,
/// the optional `resume_key`, and `metadata`.
///
/// On success, the [`HandshakeResult`] records:
/// * the acceptor's settings, retry support, schema and metadata from its `HelloYourself`;
/// * as `session_resume_key`, whatever key that message carried (converted with `to_key`).
///
/// `peer_resume_key` is always `None` on this side.
///
/// # Errors
///
/// * [`HandshakeError::Sorry`] if the acceptor rejects our `Hello`.
/// * [`HandshakeError::Protocol`] if the acceptor answers with any other unexpected message.
/// * Any error from sending or receiving a handshake frame (I/O, encode, decode, peer closed).
pub async fn handshake_as_initiator<Tx: LinkTx, Rx: LinkRx>(
    tx: &Tx,
    rx: &mut Rx,
    settings: ConnectionSettings,
    supports_retry: bool,
    resume_key: Option<&SessionResumeKey>,
    metadata: vox_types::Metadata<'static>,
) -> Result<HandshakeResult, HandshakeError> {
    let our_schema = message_schema();

    // Step 1: introduce ourselves.
    send_handshake(
        tx,
        &HandshakeMessage::Hello(vox_types::Hello {
            parity: settings.parity,
            connection_settings: settings.clone(),
            message_payload_schema: our_schema.clone(),
            supports_retry,
            resume_key: resume_key.map(ResumeKeyBytes::from_key),
            metadata,
        }),
    )
    .await?;

    // Step 2: the acceptor either answers or refuses.
    let reply = match recv_handshake(rx).await? {
        HandshakeMessage::HelloYourself(reply) => reply,
        HandshakeMessage::Sorry(rejection) => {
            return Err(HandshakeError::Sorry(rejection.reason));
        }
        _ => {
            return Err(HandshakeError::Protocol(
                "expected HelloYourself or Sorry".into(),
            ));
        }
    };

    // Step 3: confirm that we are ready to proceed.
    send_handshake(tx, &HandshakeMessage::LetsGo(vox_types::LetsGo {})).await?;

    Ok(HandshakeResult {
        role: SessionRole::Initiator,
        our_settings: settings,
        peer_settings: reply.connection_settings,
        peer_supports_retry: reply.supports_retry,
        session_resume_key: reply.resume_key.as_ref().and_then(|k| k.to_key()),
        peer_resume_key: None,
        our_schema,
        peer_schema: reply.message_payload_schema,
        peer_metadata: reply.metadata,
    })
}
/// Runs the acceptor side of the handshake: `Hello` → `HelloYourself` → (`LetsGo` | `Sorry`).
///
/// * `settings` — our connection settings. Their `parity` is replaced by the opposite of the
///   parity announced in the initiator's `Hello`.
/// * `supports_retry`, `metadata` — advertised to the initiator in `HelloYourself`.
/// * `resumable` — when `true`, a fresh random resume key is generated, sent in
///   `HelloYourself`, and returned as `session_resume_key`.
/// * `expected_resume_key` — when `Some`, the initiator's `Hello` must carry exactly this key.
///   Otherwise we send `Sorry` (best-effort) and fail.
///
/// # Errors
///
/// * [`HandshakeError::Protocol`] in any of these cases:
///   * the first message is not `Hello`;
///   * the resume key does not match;
///   * the reply to `HelloYourself` is neither `LetsGo` nor `Sorry`;
///   * generating our resume key fails.
/// * [`HandshakeError::Sorry`] if the initiator rejects our `HelloYourself`.
/// * Any error from sending or receiving a handshake frame (I/O, encode, decode, peer closed).
pub async fn handshake_as_acceptor<Tx: LinkTx, Rx: LinkRx>(
    tx: &Tx,
    rx: &mut Rx,
    settings: ConnectionSettings,
    supports_retry: bool,
    resumable: bool,
    expected_resume_key: Option<&SessionResumeKey>,
    metadata: vox_types::Metadata<'static>,
) -> Result<HandshakeResult, HandshakeError> {
    let hello = match recv_handshake(rx).await? {
        HandshakeMessage::Hello(h) => h,
        _ => return Err(HandshakeError::Protocol("expected Hello".into())),
    };

    if let Some(expected) = expected_resume_key {
        let presented = hello.resume_key.as_ref().and_then(|k| k.to_key());
        let key_ok = presented
            .as_ref()
            .is_some_and(|presented| resume_keys_match(presented, expected));
        if !key_ok {
            let reason = "session resume key mismatch".to_string();
            // Telling the peer why we are hanging up is a courtesy. If the link is already
            // broken, report the mismatch (the root cause) instead of the send failure.
            let sorry = HandshakeMessage::Sorry(vox_types::Sorry {
                reason: reason.clone(),
            });
            if let Err(send_error) = send_handshake(tx, &sorry).await {
                vox_types::dlog!("[handshake] failed to deliver Sorry: {}", send_error);
            }
            return Err(HandshakeError::Protocol(reason));
        }
    }

    // Our parity is always the opposite of the one the initiator announced.
    let our_settings = ConnectionSettings {
        parity: hello.parity.other(),
        ..settings
    };
    let our_resume_key = if resumable {
        Some(fresh_resume_key()?)
    } else {
        None
    };
    let our_schema = message_schema();
    let hy = vox_types::HelloYourself {
        connection_settings: our_settings.clone(),
        message_payload_schema: our_schema.clone(),
        supports_retry,
        resume_key: our_resume_key.as_ref().map(ResumeKeyBytes::from_key),
        metadata,
    };
    send_handshake(tx, &HandshakeMessage::HelloYourself(hy)).await?;

    match recv_handshake(rx).await? {
        HandshakeMessage::LetsGo(_) => {}
        HandshakeMessage::Sorry(sorry) => return Err(HandshakeError::Sorry(sorry.reason)),
        _ => return Err(HandshakeError::Protocol("expected LetsGo or Sorry".into())),
    }

    let peer_resume_key = hello.resume_key.as_ref().and_then(|k| k.to_key());
    Ok(HandshakeResult {
        role: SessionRole::Acceptor,
        our_settings,
        peer_settings: hello.connection_settings,
        peer_supports_retry: hello.supports_retry,
        session_resume_key: our_resume_key,
        peer_resume_key,
        our_schema,
        peer_schema: hello.message_payload_schema,
        peer_metadata: hello.metadata,
    })
}

/// Compares two resume keys without stopping at the first differing byte.
///
/// The presented key arrives from the remote peer. An early-exit `==` could leak, through
/// response timing, how many leading bytes of a guessed key were correct. XOR-folding every
/// byte and passing the result through `black_box` keeps the work independent of where the
/// keys differ (best-effort: `black_box` is an optimization hint, not a guarantee).
fn resume_keys_match(a: &SessionResumeKey, b: &SessionResumeKey) -> bool {
    let diff = a
        .0
        .iter()
        .zip(b.0.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    std::hint::black_box(diff) == 0
}
fn fresh_resume_key() -> Result<SessionResumeKey, HandshakeError> {
let mut bytes = [0u8; 16];
getrandom::fill(&mut bytes).map_err(|error| {
HandshakeError::Protocol(format!("failed to generate session key: {error}"))
})?;
Ok(SessionResumeKey(bytes))
}