use super::error::WireVersionError;
use super::negotiation::{VersionHandshake, VersionHandshakeAck, VersionRange, negotiate};
use super::types::WireVersion;
use crate::error::{ClusterError, Result};
use crate::wire::WIRE_VERSION;
/// Largest payload length, in bytes, that `read_framed` accepts for a single
/// handshake frame (4 KiB). A length prefix above this value is rejected
/// before the payload buffer is allocated.
const MAX_HANDSHAKE_BYTES: u32 = 4 * 1024;
/// Returns the wire-version range this build advertises during the handshake:
/// from version 1 through the crate's current `WIRE_VERSION`.
pub fn local_version_range() -> VersionRange {
    let oldest = WireVersion(1);
    let newest = WireVersion(WIRE_VERSION);
    VersionRange::new(oldest, newest)
}
/// Serializes `msg` to MessagePack and writes it to `send` as one frame:
/// a 4-byte big-endian length prefix followed by the payload bytes.
///
/// # Errors
///
/// * `ClusterError::Codec` if serialization fails or the payload length does
///   not fit in a `u32`.
/// * `ClusterError::Transport` if writing the prefix or the payload fails.
async fn write_framed<T: serde::Serialize + zerompk::ToMessagePack>(
    send: &mut quinn::SendStream,
    msg: &T,
) -> Result<()> {
    let payload = match zerompk::to_msgpack_vec(msg) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(ClusterError::Codec {
                detail: format!("handshake serialize: {e}"),
            });
        }
    };

    // The length prefix is a u32, so the payload size must fit in one.
    let len = match u32::try_from(payload.len()) {
        Ok(n) => n,
        Err(_) => {
            return Err(ClusterError::Codec {
                detail: format!("handshake message too large: {} bytes", payload.len()),
            });
        }
    };
    let prefix = len.to_be_bytes();

    if let Err(e) = send.write_all(&prefix).await {
        return Err(ClusterError::Transport {
            detail: format!("handshake write length: {e}"),
        });
    }
    if let Err(e) = send.write_all(&payload).await {
        return Err(ClusterError::Transport {
            detail: format!("handshake write payload: {e}"),
        });
    }
    Ok(())
}
/// Reads one length-prefixed frame from `recv` and decodes it as a `T`.
///
/// The frame layout is a 4-byte big-endian payload length followed by that
/// many bytes of MessagePack. Frames whose declared length exceeds
/// `MAX_HANDSHAKE_BYTES` are refused before the payload buffer is allocated.
///
/// # Errors
///
/// * `ClusterError::Transport` if reading the prefix or the payload fails.
/// * `ClusterError::Codec` if the declared length is over the limit or the
///   payload does not deserialize.
async fn read_framed<T: serde::de::DeserializeOwned + for<'a> zerompk::FromMessagePack<'a>>(
    recv: &mut quinn::RecvStream,
) -> Result<T> {
    let mut prefix = [0u8; 4];
    if let Err(e) = recv.read_exact(&mut prefix).await {
        return Err(ClusterError::Transport {
            detail: format!("handshake read length: {e}"),
        });
    }

    // Check the peer-supplied length before trusting it for an allocation.
    let len = u32::from_be_bytes(prefix);
    if len > MAX_HANDSHAKE_BYTES {
        return Err(ClusterError::Codec {
            detail: format!("handshake frame too large: {len} bytes (max {MAX_HANDSHAKE_BYTES})"),
        });
    }

    let mut payload = vec![0u8; len as usize];
    if let Err(e) = recv.read_exact(&mut payload).await {
        return Err(ClusterError::Transport {
            detail: format!("handshake read payload: {e}"),
        });
    }

    match zerompk::from_msgpack(&payload) {
        Ok(decoded) => Ok(decoded),
        Err(e) => Err(ClusterError::Codec {
            detail: format!("handshake deserialize: {e}"),
        }),
    }
}
/// Runs the server side of the wire-version handshake.
///
/// Reads the client's framed `VersionHandshake`, negotiates it against
/// `local_version_range()`, and replies with a framed `VersionHandshakeAck`
/// carrying the agreed version.
///
/// If negotiation fails, `conn` is closed with error code `0x01` and a reason
/// of at most 100 bytes taken from the negotiation error's display text. The
/// reason is trimmed back to a UTF-8 character boundary so the peer never
/// receives a reason that ends in the middle of a multi-byte character.
///
/// # Errors
///
/// * Any error from reading/decoding the client frame or writing the ack is
///   propagated unchanged.
/// * `ClusterError::Transport` when the two version ranges cannot be
///   negotiated.
pub async fn perform_version_handshake_server(
    conn: &quinn::Connection,
    send: &mut quinn::SendStream,
    recv: &mut quinn::RecvStream,
) -> Result<WireVersion> {
    let client_hs: VersionHandshake = read_framed(recv).await?;
    let remote_range = client_hs.to_range();
    let local = local_version_range();

    let agreed = match negotiate(local, remote_range) {
        Ok(version) => version,
        Err(e) => {
            let reason = e.to_string();
            // Cap the reason at 100 bytes, then back up to a char boundary so
            // a multi-byte character is never split. Offset 0 is always a
            // boundary, so the loop terminates.
            let mut end = reason.len().min(100);
            while !reason.is_char_boundary(end) {
                end -= 1;
            }
            conn.close(quinn::VarInt::from_u32(0x01), &reason.as_bytes()[..end]);
            return Err(ClusterError::Transport {
                detail: format!("wire version handshake failed (server): {e}"),
            });
        }
    };

    let ack = VersionHandshakeAck::new(agreed);
    write_framed(send, &ack).await?;
    Ok(agreed)
}
/// Runs the client side of the wire-version handshake.
///
/// Sends a framed `VersionHandshake` built from `local_version_range()`,
/// waits for the server's framed `VersionHandshakeAck`, and checks that the
/// version the server picked lies inside the range this side advertised.
///
/// # Errors
///
/// * Any error from writing the handshake or reading/decoding the ack is
///   propagated unchanged.
/// * `ClusterError::Transport` if the acknowledged version is outside the
///   local range.
pub async fn perform_version_handshake_client(
    send: &mut quinn::SendStream,
    recv: &mut quinn::RecvStream,
) -> Result<WireVersion> {
    let supported = local_version_range();
    let hello = VersionHandshake::from_range(supported);
    write_framed(send, &hello).await?;

    let reply: VersionHandshakeAck = read_framed(recv).await?;
    let agreed = reply.agreed_version();

    // Do not trust the server's choice blindly: it must be a version we speak.
    if supported.contains(agreed) {
        Ok(agreed)
    } else {
        Err(ClusterError::Transport {
            detail: format!(
                "server returned agreed version {} outside our supported range {}..={}",
                agreed, supported.min, supported.max
            ),
        })
    }
}
/// Closes `conn` because of a wire-version error and returns the
/// `ClusterError` that describes it.
///
/// The connection is closed with error code `0x01`. The close reason is the
/// display text of `e`, limited to at most 100 bytes and trimmed back to a
/// UTF-8 character boundary so the peer never receives a reason that ends in
/// the middle of a multi-byte character.
///
/// Returns `ClusterError::Transport` whose detail carries the full,
/// untruncated error text.
pub fn close_on_version_error(conn: &quinn::Connection, e: WireVersionError) -> ClusterError {
    let reason = e.to_string();
    // Cap at 100 bytes, then back up to a char boundary. Offset 0 is always a
    // boundary, so the loop terminates.
    let mut end = reason.len().min(100);
    while !reason.is_char_boundary(end) {
        end -= 1;
    }
    conn.close(quinn::VarInt::from_u32(0x01), &reason.as_bytes()[..end]);
    ClusterError::Transport {
        detail: format!("wire version mismatch: {e}"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wire_version::negotiation::VersionRange;
    use crate::wire_version::types::WireVersion;

    /// Shorthand: wrap a raw number as a `WireVersion`.
    fn ver(raw: u16) -> WireVersion {
        WireVersion(raw)
    }

    /// Shorthand: build a `VersionRange` from raw bounds.
    fn span(lo: u16, hi: u16) -> VersionRange {
        VersionRange::new(ver(lo), ver(hi))
    }

    /// The advertised local range is well-ordered and runs from 1 to `WIRE_VERSION`.
    #[test]
    fn local_range_is_valid() {
        let local = local_version_range();
        assert!(
            local.min <= local.max,
            "local range min ({}) must be <= max ({})",
            local.min,
            local.max
        );
        assert_eq!(local.min, ver(1));
        assert_eq!(local.max, ver(WIRE_VERSION));
    }

    /// Negotiating a range against itself yields that range's upper bound.
    #[test]
    fn negotiate_in_range_succeeds() {
        let ours = local_version_range();
        let result = negotiate(ours, ours);
        assert!(
            result.is_ok(),
            "identical ranges must negotiate: {result:?}"
        );
        assert_eq!(result.unwrap(), ours.max);
    }

    /// Ranges with no overlap are reported as `NegotiationFailed`.
    #[test]
    fn negotiate_disjoint_range_fails() {
        let err = negotiate(span(1, 2), span(100, 200)).unwrap_err();
        assert!(
            matches!(err, WireVersionError::NegotiationFailed { .. }),
            "expected NegotiationFailed, got: {err}"
        );
    }

    /// A `VersionHandshake` survives a MessagePack encode/decode round trip.
    #[test]
    fn handshake_capabilities_roundtrip() {
        let caps = 0xABCD_1234_5678_EF01_u64;
        let original = VersionHandshake {
            range: (1, 3),
            capabilities: caps,
        };
        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let decoded: VersionHandshake = zerompk::from_msgpack(&encoded).unwrap();
        assert_eq!(decoded.capabilities, caps);
        assert_eq!(decoded.range, (1, 3));
    }

    /// A `VersionHandshakeAck` survives a MessagePack encode/decode round trip.
    #[test]
    fn ack_capabilities_roundtrip() {
        let caps = 0xFEDC_BA98_7654_3210_u64;
        let original = VersionHandshakeAck {
            agreed: 2,
            capabilities: caps,
        };
        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let decoded: VersionHandshakeAck = zerompk::from_msgpack(&encoded).unwrap();
        assert_eq!(decoded.agreed_version(), ver(2));
        assert_eq!(decoded.capabilities, caps);
    }
}