use std::net::SocketAddr;
use sha2::{Digest, Sha256};
use crate::rpc_codec::auth_envelope;
use crate::rpc_codec::{MacKey, PeerSeqSender, PeerSeqWindow};
use crate::swim::error::SwimError;
use super::{codec, message::SwimMessage};
/// Derives the 64-bit identifier used for a socket address.
///
/// The address is rendered through its `Display` implementation, that text is
/// hashed with SHA-256, and the first eight digest bytes are interpreted as a
/// little-endian `u64`.
pub fn addr_hash(addr: SocketAddr) -> u64 {
    let rendered = addr.to_string();
    let digest = Sha256::digest(rendered.as_bytes());

    // A SHA-256 digest is 32 bytes long, so its eight-byte prefix always exists.
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&digest[..8]);
    u64::from_le_bytes(prefix)
}
/// State consulted by [`wrap`] and [`unwrap`] to seal outgoing SWIM messages
/// and to verify incoming ones.
#[derive(Debug)]
pub struct SwimAuth {
/// Key passed to the envelope writer and the envelope parser.
mac_key: MacKey,
/// [`addr_hash`] of the local address given to [`SwimAuth::new`]; passed to
/// the envelope writer for every outgoing message.
local_addr_hash: u64,
/// Supplies the sequence number for each outgoing envelope.
seq_out: PeerSeqSender,
/// Asked to accept the `(from_node_id, seq)` pair of each incoming envelope;
/// a refusal is reported by [`unwrap`] as a "swim replay" decode error.
seq_in: PeerSeqWindow,
}
impl SwimAuth {
    /// Builds the authentication state for the node bound to `local_addr`.
    ///
    /// The address is hashed once with [`addr_hash`] and the result is stored;
    /// the outbound sequence source and the inbound sequence window are both
    /// freshly constructed.
    pub fn new(mac_key: MacKey, local_addr: SocketAddr) -> Self {
        let local_addr_hash = addr_hash(local_addr);
        let seq_out = PeerSeqSender::new();
        let seq_in = PeerSeqWindow::new();

        Self {
            mac_key,
            local_addr_hash,
            seq_out,
            seq_in,
        }
    }

    /// Returns the address hash computed for the local address in [`SwimAuth::new`].
    pub fn local_addr_hash(&self) -> u64 {
        let Self {
            local_addr_hash, ..
        } = self;
        *local_addr_hash
    }
}
/// Encodes `msg` and seals it inside an authenticated envelope.
///
/// The envelope is written with this node's address hash, the next outbound
/// sequence number, and `auth`'s MAC key. The destination `_to` is accepted
/// but not used.
///
/// # Errors
///
/// Propagates any error from `codec::encode`. A failure reported by the
/// envelope writer is returned as [`SwimError::Encode`] whose detail starts
/// with `swim envelope: `.
pub fn wrap(auth: &SwimAuth, _to: SocketAddr, msg: &SwimMessage) -> Result<Vec<u8>, SwimError> {
    let payload = codec::encode(msg)?;
    let sequence = auth.seq_out.next();

    // Reserve room for the payload plus the fixed envelope overhead up front.
    let mut framed = Vec::with_capacity(auth_envelope::ENVELOPE_OVERHEAD + payload.len());
    if let Err(e) = auth_envelope::write_envelope(
        auth.local_addr_hash,
        sequence,
        &payload,
        &auth.mac_key,
        &mut framed,
    ) {
        return Err(SwimError::Encode {
            detail: format!("swim envelope: {e}"),
        });
    }

    Ok(framed)
}
/// Verifies an incoming envelope and decodes the message it carries.
///
/// Checks run in this order, and the first failure is returned:
///
/// 1. The envelope is parsed with `auth`'s MAC key.
/// 2. The sender id claimed in the envelope must equal [`addr_hash`] of
///    `from`, the address the bytes were observed to come from.
/// 3. The inbound sequence window must accept the `(sender id, seq)` pair.
/// 4. The inner frame is decoded with `codec::decode`.
///
/// # Errors
///
/// Steps 1-3 fail with [`SwimError::Decode`]; the detail is prefixed with
/// `swim envelope: ` for parse failures and `swim replay: ` for sequence
/// rejections, and names both hashes for an address mismatch. An error from
/// `codec::decode` is returned unchanged.
pub fn unwrap(auth: &SwimAuth, from: SocketAddr, bytes: &[u8]) -> Result<SwimMessage, SwimError> {
    let decode_failure = |detail: String| SwimError::Decode { detail };

    let (fields, payload) = match auth_envelope::parse_envelope(bytes, &auth.mac_key) {
        Ok(parsed) => parsed,
        Err(e) => return Err(decode_failure(format!("swim envelope: {e}"))),
    };

    // The claimed sender id must match the address the bytes actually came from.
    let observed = addr_hash(from);
    let claimed = fields.from_node_id;
    if claimed != observed {
        return Err(decode_failure(format!(
            "swim envelope from {from} claimed addr_hash {}, observed hash {}",
            claimed, observed
        )));
    }

    // The sequence window is consulted only after the envelope and address checks pass.
    if let Err(e) = auth.seq_in.accept(claimed, fields.seq) {
        return Err(decode_failure(format!("swim replay: {e}")));
    }

    codec::decode(payload)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::swim::incarnation::Incarnation;
    use crate::swim::wire::probe::{Ping, ProbeId};
    use nodedb_types::NodeId;
    use std::net::{IpAddr, Ipv4Addr};

    /// Port of the endpoint that seals messages in these tests.
    const SENDER_PORT: u16 = 7001;
    /// Port of the endpoint that verifies messages in these tests.
    const RECEIVER_PORT: u16 = 7002;

    /// Loopback IPv4 socket address on `port`.
    fn addr(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), port)
    }

    /// A minimal ping with no piggybacked payload.
    fn sample_msg() -> SwimMessage {
        let ping = Ping {
            probe_id: ProbeId::new(1),
            from: NodeId::new("a"),
            incarnation: Incarnation::ZERO,
            piggyback: vec![],
        };
        SwimMessage::Ping(ping)
    }

    /// Sender and receiver endpoints whose keys are 32 copies of the given bytes.
    fn endpoints_with_keys(sender_fill: u8, receiver_fill: u8) -> (SwimAuth, SwimAuth) {
        let sender = SwimAuth::new(MacKey::from_bytes([sender_fill; 32]), addr(SENDER_PORT));
        let receiver = SwimAuth::new(
            MacKey::from_bytes([receiver_fill; 32]),
            addr(RECEIVER_PORT),
        );
        (sender, receiver)
    }

    /// Sender and receiver endpoints sharing one key made of 32 copies of `fill`.
    fn endpoints(fill: u8) -> (SwimAuth, SwimAuth) {
        endpoints_with_keys(fill, fill)
    }

    /// Seals the sample message from `sender`, addressed to the receiver port.
    fn sealed_sample(sender: &SwimAuth) -> Vec<u8> {
        wrap(sender, addr(RECEIVER_PORT), &sample_msg()).unwrap()
    }

    #[test]
    fn addr_hash_is_deterministic() {
        let first = addr(7001);
        let second = addr(7002);
        assert_eq!(addr_hash(first), addr_hash(first));
        assert_ne!(addr_hash(first), addr_hash(second));
    }

    #[test]
    fn roundtrip_across_independent_endpoints() {
        let (sender, receiver) = endpoints(0x11);
        let wire = sealed_sample(&sender);

        let decoded = unwrap(&receiver, addr(SENDER_PORT), &wire).unwrap();
        assert_eq!(decoded, sample_msg());
    }

    #[test]
    fn rejects_spoofed_source_address() {
        let (real_sender, receiver) = endpoints(0x33);
        let wire = sealed_sample(&real_sender);

        // Present the bytes as if they had arrived from a different address.
        let failure = unwrap(&receiver, addr(9999), &wire).unwrap_err();
        assert!(failure.to_string().contains("addr_hash"));
    }

    #[test]
    fn rejects_tampered_mac() {
        let (sender, receiver) = endpoints(3);
        let mut wire = sealed_sample(&sender);

        // Invert the byte 32 positions from the end (presumably the first MAC byte).
        let mac_start = wire.len() - 32;
        wire[mac_start] ^= 0xFF;

        let failure = unwrap(&receiver, addr(SENDER_PORT), &wire).unwrap_err();
        assert!(failure.to_string().contains("MAC verification failed"));
    }

    #[test]
    fn rejects_replay() {
        let (sender, receiver) = endpoints(4);
        let wire = sealed_sample(&sender);

        // First delivery succeeds; presenting the same bytes again must fail.
        unwrap(&receiver, addr(SENDER_PORT), &wire).unwrap();
        let failure = unwrap(&receiver, addr(SENDER_PORT), &wire).unwrap_err();
        assert!(failure.to_string().contains("replayed"));
    }

    #[test]
    fn rejects_wrong_cluster_key() {
        let (sender, receiver) = endpoints_with_keys(1, 2);
        let wire = sealed_sample(&sender);

        let failure = unwrap(&receiver, addr(SENDER_PORT), &wire).unwrap_err();
        assert!(failure.to_string().contains("MAC verification failed"));
    }
}