use std::net::{IpAddr, SocketAddr};
/// Five-tuple identifying a network flow: source and destination address,
/// source and destination port, and a protocol number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FlowId {
    /// Source IP address (IPv4 or IPv6).
    pub src_ip: IpAddr,
    /// Destination IP address (IPv4 or IPv6).
    pub dst_ip: IpAddr,
    /// Source port.
    pub src_port: u16,
    /// Destination port.
    pub dst_port: u16,
    /// Protocol number; [`FlowId::from_tcp`] stores 6 and [`FlowId::from_udp`] stores 17.
    pub protocol: u8,
}

impl FlowId {
    /// Protocol number stored by [`FlowId::from_tcp`].
    const PROTOCOL_TCP: u8 = 6;
    /// Protocol number stored by [`FlowId::from_udp`].
    const PROTOCOL_UDP: u8 = 17;

    /// Creates a flow identifier from an explicit five-tuple.
    pub fn new(
        src_ip: IpAddr,
        dst_ip: IpAddr,
        src_port: u16,
        dst_port: u16,
        protocol: u8,
    ) -> Self {
        Self {
            src_ip,
            dst_ip,
            src_port,
            dst_port,
            protocol,
        }
    }

    /// Creates a UDP flow (protocol 17) from a pair of socket addresses.
    pub fn from_udp(src: SocketAddr, dst: SocketAddr) -> Self {
        Self::new(src.ip(), dst.ip(), src.port(), dst.port(), Self::PROTOCOL_UDP)
    }

    /// Creates a TCP flow (protocol 6) from a pair of socket addresses.
    pub fn from_tcp(src: SocketAddr, dst: SocketAddr) -> Self {
        Self::new(src.ip(), dst.ip(), src.port(), dst.port(), Self::PROTOCOL_TCP)
    }

    /// Returns a 16-bit hash of the flow that is never zero.
    ///
    /// The hash is the folded 16-bit sum described on `compute_hash`; a raw
    /// value of `0` is reported as `0xffff` instead.
    pub fn flow_hash(&self) -> u16 {
        match self.compute_hash() {
            0 => 0xffff,
            hash => hash,
        }
    }

    /// Computes the raw (possibly zero) 16-bit hash.
    ///
    /// The protocol number, every big-endian 16-bit word of both addresses
    /// and both ports are added together, then any carry above bit 15 is
    /// folded back into the low 16 bits. Because the sum is commutative,
    /// swapping source and destination produces the same value.
    fn compute_hash(&self) -> u16 {
        // At most 255 + 16 * 0xffff + 2 * 0xffff, so the u32 accumulator
        // cannot actually overflow; wrapping_add just avoids a panic path.
        let mut sum = u32::from(self.protocol)
            .wrapping_add(Self::sum_ip_words(self.src_ip))
            .wrapping_add(Self::sum_ip_words(self.dst_ip))
            .wrapping_add(u32::from(self.src_port))
            .wrapping_add(u32::from(self.dst_port));

        // Fold carries back in until the value fits in 16 bits.
        while sum > 0xffff {
            sum = (sum & 0xffff) + (sum >> 16);
        }

        // Lossless: the loop above guarantees `sum <= 0xffff`.
        sum as u16
    }

    /// Sums the big-endian 16-bit words of an IPv4 or IPv6 address.
    fn sum_ip_words(ip: IpAddr) -> u32 {
        match ip {
            IpAddr::V4(addr) => Self::sum_be_words(&addr.octets()),
            IpAddr::V6(addr) => Self::sum_be_words(&addr.octets()),
        }
    }

    /// Sums `bytes` interpreted as consecutive big-endian `u16` words.
    ///
    /// Callers pass 4- or 16-byte address buffers, so there is never a
    /// trailing odd byte for `chunks_exact` to drop.
    fn sum_be_words(bytes: &[u8]) -> u32 {
        bytes
            .chunks_exact(2)
            .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
            .fold(0, u32::wrapping_add)
    }

    /// Returns a copy of this flow with the source port replaced by `port`.
    pub fn with_src_port(&self, port: u16) -> Self {
        Self {
            src_port: port,
            ..*self
        }
    }

    /// Returns a copy of this flow with the destination port replaced by `port`.
    pub fn with_dst_port(&self, port: u16) -> Self {
        Self {
            dst_port: port,
            ..*self
        }
    }
}
pub fn calculate_flow_hash(
src_ip: IpAddr,
dst_ip: IpAddr,
src_port: u16,
dst_port: u16,
protocol: u8,
) -> u16 {
FlowId::new(src_ip, dst_ip, src_port, dst_port, protocol).flow_hash()
}
/// Computes the flow hash for a pair of socket addresses.
///
/// The flow is built with [`FlowId::from_tcp`] when `is_tcp` is `true` and
/// with [`FlowId::from_udp`] otherwise.
pub fn flow_hash_from_addrs(src: SocketAddr, dst: SocketAddr, is_tcp: bool) -> u16 {
    let flow = if is_tcp {
        FlowId::from_tcp(src, dst)
    } else {
        FlowId::from_udp(src, dst)
    };
    flow.flow_hash()
}
/// Produces variants of a base flow that differ only in one port, and
/// reports how many distinct flow hashes those variants yield.
#[derive(Debug)]
pub struct EcmpPathEnumerator {
    /// Flow whose addresses, protocol and untouched port are reused for every variant.
    base_flow: FlowId,
    /// First port value assigned to the varied port field.
    base_port: u16,
    /// Number of variants to generate.
    num_paths: u16,
    /// `true` to vary the source port, `false` to vary the destination port.
    use_src_port: bool,
}

impl EcmpPathEnumerator {
    /// Creates an enumerator that generates `num_paths` variants of
    /// `base_flow`, starting at `base_port`.
    pub fn new(base_flow: FlowId, base_port: u16, num_paths: u16, use_src_port: bool) -> Self {
        Self {
            base_flow,
            base_port,
            num_paths,
            use_src_port,
        }
    }

    /// Returns the generated flows in order of increasing port offset.
    ///
    /// Variant `i` carries port `base_port + i` (wrapping around at
    /// `u16::MAX`) in the selected port field; the value that `base_flow`
    /// originally held in that field is not used.
    pub fn flows(&self) -> Vec<FlowId> {
        let mut variants = Vec::with_capacity(usize::from(self.num_paths));
        for offset in 0..self.num_paths {
            let port = self.base_port.wrapping_add(offset);
            let variant = if self.use_src_port {
                self.base_flow.with_src_port(port)
            } else {
                self.base_flow.with_dst_port(port)
            };
            variants.push(variant);
        }
        variants
    }

    /// Returns the distinct flow hashes of all generated flows, sorted ascending.
    pub fn unique_hashes(&self) -> Vec<u16> {
        // An ordered set yields the same sorted, de-duplicated sequence as
        // sorting a vector and removing adjacent duplicates.
        let distinct: std::collections::BTreeSet<u16> =
            self.flows().iter().map(FlowId::flow_hash).collect();
        distinct.into_iter().collect()
    }

    /// Returns the number of distinct flow hashes among the generated flows.
    pub fn estimated_path_count(&self) -> usize {
        let hashes = self.unique_hashes();
        hashes.len()
    }
}
/// Groups flows by their [`FlowId::flow_hash`] value.
#[derive(Debug, Default)]
pub struct FlowHashBucket {
    /// Flows keyed by hash, in insertion order within each bucket.
    buckets: std::collections::HashMap<u16, Vec<FlowId>>,
}

impl FlowHashBucket {
    /// Creates an empty collection.
    pub fn new() -> Self {
        Self {
            buckets: std::collections::HashMap::new(),
        }
    }

    /// Appends `flow` to the bucket for its hash, creating the bucket if needed.
    pub fn add(&mut self, flow: FlowId) {
        let key = flow.flow_hash();
        let slot = self.buckets.entry(key).or_insert_with(Vec::new);
        slot.push(flow);
    }

    /// Returns the number of distinct hash values that have at least one flow.
    pub fn bucket_count(&self) -> usize {
        self.buckets.len()
    }

    /// Returns the flows stored under `hash`, or `None` if no flow has that hash.
    pub fn get_bucket(&self, hash: u16) -> Option<&Vec<FlowId>> {
        let key = hash;
        self.buckets.get(&key)
    }

    /// Iterates over `(hash, flows)` pairs in the map's arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = (&u16, &Vec<FlowId>)> {
        (&self.buckets).into_iter()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Hashing the same flow twice must give the same value.
    #[test]
    fn test_flow_hash_consistency() {
        let flow = FlowId::new(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            12345,
            80,
            6,
        );
        let hash1 = flow.flow_hash();
        let hash2 = flow.flow_hash();
        assert_eq!(hash1, hash2);
    }

    /// Changing only the source port by one must change the hash.
    #[test]
    fn test_different_ports_different_hash() {
        let flow1 = FlowId::new(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            12345,
            80,
            17,
        );
        let flow2 = flow1.with_src_port(12346);
        let hash1 = flow1.flow_hash();
        let hash2 = flow2.flow_hash();
        assert_ne!(hash1, hash2);
    }

    /// The public hash is never zero, including for the all-zero tuple.
    #[test]
    fn test_flow_hash_never_zero() {
        let unspecified = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0));

        // The all-zero tuple is the input whose raw sum is zero, so it is the
        // case that actually exercises the zero-to-0xffff remapping.
        let all_zero = FlowId::new(unspecified, unspecified, 0, 0, 0);
        assert_eq!(all_zero.flow_hash(), 0xffff);

        for port in 1..1000 {
            let flow = FlowId::new(unspecified, unspecified, port, port, 0);
            assert_ne!(flow.flow_hash(), 0);
        }
    }

    /// IPv6 addresses are accepted and produce a non-zero hash.
    #[test]
    fn test_ipv6_flow_hash() {
        let flow = FlowId::new(
            IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
            IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2)),
            12345,
            443,
            6,
        );
        let hash = flow.flow_hash();
        assert_ne!(hash, 0);
    }

    /// The enumerator yields the requested number of flows, varying only the
    /// selected port, and those flows spread over more than one hash.
    #[test]
    fn test_ecmp_enumerator() {
        let base = FlowId::new(
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
            10000,
            53,
            17,
        );
        let enumerator = EcmpPathEnumerator::new(base, 10000, 16, true);
        let flows = enumerator.flows();
        assert_eq!(flows.len(), 16);
        assert_eq!(flows[0].src_port, 10000);
        assert_eq!(flows[15].src_port, 10015);
        assert!(flows.iter().all(|flow| flow.dst_port == 53));
        let unique = enumerator.estimated_path_count();
        assert!(unique > 1);
    }

    /// Every added flow is stored, and each one lands in the bucket keyed by
    /// its own hash.
    #[test]
    fn test_flow_hash_bucket() {
        let mut bucket = FlowHashBucket::new();
        for port in 10000..10016 {
            let flow = FlowId::new(
                IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
                IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
                port,
                80,
                6,
            );
            bucket.add(flow);
        }
        assert!(bucket.bucket_count() > 0);

        let total: usize = bucket.iter().map(|(_, flows)| flows.len()).sum();
        assert_eq!(total, 16);
        for (hash, flows) in bucket.iter() {
            assert!(flows.iter().all(|flow| flow.flow_hash() == *hash));
            assert_eq!(bucket.get_bucket(*hash), Some(flows));
        }
    }

    /// The socket-address constructors set the protocol number, and TCP and
    /// UDP flows with the same endpoints hash differently.
    #[test]
    fn test_from_socket_addrs() {
        let src = SocketAddr::from(([192, 168, 1, 1], 12345));
        let dst = SocketAddr::from(([10, 0, 0, 1], 80));
        let flow_tcp = FlowId::from_tcp(src, dst);
        let flow_udp = FlowId::from_udp(src, dst);
        assert_eq!(flow_tcp.protocol, 6);
        assert_eq!(flow_udp.protocol, 17);
        assert_ne!(flow_tcp.flow_hash(), flow_udp.flow_hash());
    }
}