use std::net::{IpAddr, SocketAddr};
use crate::types::TenantId;
/// How strictly [`ClientFingerprint::matches`] compares IP addresses once the
/// tenant has already matched.
///
/// Serialized in lowercase: `"strict"`, `"subnet"`, `"disabled"`.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
)]
#[serde(rename_all = "lowercase")]
pub enum FingerprintMode {
    /// The caller's IP must equal the captured IP.
    Strict,
    /// The caller's IP must share the captured IP's IPv4 /24 or IPv6 /64 prefix.
    Subnet,
    /// The IP is not compared at all; only the tenant must match.
    Disabled,
}
impl Default for FingerprintMode {
fn default() -> Self {
Self::Subnet
}
}
/// Tenant plus source IP identifying a client, compared against a later caller
/// with [`ClientFingerprint::matches`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClientFingerprint {
    /// Tenant the client belongs to; must match exactly in every mode.
    pub tenant_id: TenantId,
    /// Source address exactly as observed (stored as-is, not canonicalized).
    pub ip: IpAddr,
}

impl ClientFingerprint {
    /// Builds a fingerprint from a connection's peer address; the port is discarded.
    pub fn from_peer(tenant_id: TenantId, peer: &SocketAddr) -> Self {
        Self::new(tenant_id, peer.ip())
    }

    /// Builds a fingerprint from an explicit tenant and IP.
    pub fn new(tenant_id: TenantId, ip: IpAddr) -> Self {
        Self { tenant_id, ip }
    }

    /// Returns `true` if `caller` is accepted as the same client as `self` under `mode`.
    ///
    /// The tenant must always match. The IP comparison depends on `mode`:
    /// - `Disabled`: the IP is ignored.
    /// - `Strict`: the addresses must be equal.
    /// - `Subnet`: the addresses must share an IPv4 /24 or an IPv6 /64.
    ///
    /// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`, as reported by dual-stack
    /// sockets) are compared as the IPv4 address they embed. A client is therefore
    /// judged the same way whichever socket family accepted it. Other mixed-family
    /// pairs never match in `Strict` or `Subnet` mode.
    pub fn matches(&self, caller: &ClientFingerprint, mode: FingerprintMode) -> bool {
        if self.tenant_id != caller.tenant_id {
            return false;
        }
        match mode {
            FingerprintMode::Disabled => true,
            FingerprintMode::Strict => canonical_ip(self.ip) == canonical_ip(caller.ip),
            FingerprintMode::Subnet => same_subnet(self.ip, caller.ip),
        }
    }
}

/// Unwraps an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) to plain IPv4.
/// All other addresses (including `::1` and `::`) are returned unchanged.
fn canonical_ip(ip: IpAddr) -> IpAddr {
    match ip {
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(ip, IpAddr::V4),
        IpAddr::V4(_) => ip,
    }
}

/// Returns `true` if `a` and `b` share an IPv4 /24 or IPv6 /64 prefix after
/// canonicalization. Addresses of different families never share a subnet.
fn same_subnet(a: IpAddr, b: IpAddr) -> bool {
    // Canonicalize first: every IPv4-mapped address has eight leading zero octets.
    // Compared as IPv6, any two IPv4 clients would otherwise land in the same "/64".
    match (canonical_ip(a), canonical_ip(b)) {
        (IpAddr::V4(x), IpAddr::V4(y)) => x.octets()[..3] == y.octets()[..3],
        (IpAddr::V6(x), IpAddr::V6(y)) => x.octets()[..8] == y.octets()[..8],
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Shorthand for an IPv4 address built from its four octets.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::from(Ipv4Addr::new(a, b, c, d))
    }

    /// Shorthand for an IPv6 address parsed from a literal; panics on bad test input.
    fn v6(s: &str) -> IpAddr {
        let parsed: Ipv6Addr = s.parse().unwrap();
        IpAddr::from(parsed)
    }

    /// Captured/caller fingerprints for tenant 1 that differ only by IP.
    fn same_tenant(captured: IpAddr, caller: IpAddr) -> (ClientFingerprint, ClientFingerprint) {
        (
            ClientFingerprint::new(TenantId::new(1), captured),
            ClientFingerprint::new(TenantId::new(1), caller),
        )
    }

    #[test]
    fn tenant_divergence_rejects_in_all_modes() {
        let shared_ip = v4(10, 0, 0, 5);
        let captured = ClientFingerprint::new(TenantId::new(1), shared_ip);
        let caller = ClientFingerprint::new(TenantId::new(2), shared_ip);
        let all_modes = [
            FingerprintMode::Strict,
            FingerprintMode::Subnet,
            FingerprintMode::Disabled,
        ];
        for mode in all_modes {
            assert!(!captured.matches(&caller, mode), "mode {mode:?}");
        }
    }

    #[test]
    fn strict_mode_rejects_any_ip_difference() {
        let (captured, caller) = same_tenant(v4(10, 0, 0, 5), v4(10, 0, 0, 6));
        assert!(!captured.matches(&caller, FingerprintMode::Strict));
    }

    #[test]
    fn subnet_mode_tolerates_ipv4_24_host_bits() {
        let (captured, caller) = same_tenant(v4(10, 0, 0, 5), v4(10, 0, 0, 99));
        assert!(captured.matches(&caller, FingerprintMode::Subnet));
    }

    #[test]
    fn subnet_mode_rejects_different_ipv4_24() {
        let (captured, caller) = same_tenant(v4(10, 0, 0, 5), v4(10, 0, 1, 5));
        assert!(!captured.matches(&caller, FingerprintMode::Subnet));
    }

    #[test]
    fn subnet_mode_tolerates_ipv6_64() {
        let (captured, caller) = same_tenant(v6("2001:db8::1"), v6("2001:db8::ffff"));
        assert!(captured.matches(&caller, FingerprintMode::Subnet));
    }

    #[test]
    fn subnet_mode_rejects_different_ipv6_64() {
        let (captured, caller) = same_tenant(v6("2001:db8::1"), v6("2001:db8:0:1::1"));
        assert!(!captured.matches(&caller, FingerprintMode::Subnet));
    }

    #[test]
    fn disabled_mode_ignores_ip() {
        let (captured, caller) = same_tenant(v4(10, 0, 0, 5), v4(192, 168, 1, 1));
        assert!(captured.matches(&caller, FingerprintMode::Disabled));
    }

    #[test]
    fn mixed_address_families_never_match_under_subnet() {
        let (captured, caller) = same_tenant(v4(10, 0, 0, 5), v6("2001:db8::1"));
        for mode in [FingerprintMode::Subnet, FingerprintMode::Strict] {
            assert!(!captured.matches(&caller, mode), "mode {mode:?}");
        }
    }
}