use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::time::{Duration, Instant};

use bytes::Buf;
/// Upper bound on a frame's size in bytes (9 KiB).
/// NOTE(review): presumably sized for jumbo Ethernet frames — confirm against the I/O code.
pub const MAX_FRAME_SIZE: usize = 9 * 1024;
/// Lower bound on a frame's size in bytes.
/// NOTE(review): 14 matches an Ethernet header (two 6-byte MACs + 2-byte EtherType) — confirm.
pub const MIN_FRAME_SIZE: usize = 14;

/// ff:ff:ff:ff:ff:ff packed into the low 48 bits of a `u64`.
const BROADCAST_MAC_ADDRESS: u64 = 0xffff_ffff_ffff;
/// 01:00:5e:00:00:00 — fixed prefix of IPv4 multicast MAC addresses.
const MULTICAST_MAC_PREFIX_V4: u64 = 0x0100_5e00_0000;
/// 33:33:00:00:00:00 — fixed prefix of IPv6 multicast MAC addresses.
const MULTICAST_MAC_PREFIX_V6: u64 = 0x3333_0000_0000;
/// Keeps the top 25 bits of a 48-bit MAC (the fixed part of the IPv4 range).
const MULTICAST_MAC_ADDRESS_MASK_V4: u64 = 0xffff_ff80_0000;
/// Keeps the top 16 bits of a 48-bit MAC (the fixed part of the IPv6 range).
const MULTICAST_MAC_ADDRESS_MASK_V6: u64 = 0xffff_0000_0000;

/// Cap on the number of learned MAC entries.
pub const MAX_MAC_COUNT: usize = 32_000;
/// Cap on the number of tracked next-hop (peer) sessions.
pub const MAX_NEXTHOP_COUNT: usize = 3000;
/// NOTE(review): presumably the largest accepted MAC aging timeout, in seconds — confirm at the use site.
pub const MAX_MAC_AGING_TIMEOUT: u64 = 600;

// Keepalive exchange. NOTE(review): timeouts/intervals presumably in seconds — confirm.
pub const KEEPALIVE_TIMEOUT: u64 = 30;
pub const KEEPALIVE_INTERVAL: u64 = 10;
/// Payload of a keepalive message.
pub const KEEPALIVE_PACKET: &[u8] = b"KEEPALIVE";

// Hello exchange. NOTE(review): timeouts/intervals presumably in seconds — confirm.
pub const HELLO_TIMEOUT: u64 = 180;
pub const HELLO_INTERVAL: u64 = 3;
/// Payload of a hello message.
pub const HELLO_PACKET: &[u8] = b"HELLO";
/// Payload of a goodbye message.
pub const BYE_PACKET: &[u8] = b"BYE";

/// Boxed, thread-transferable error type.
pub type Error = Box<dyn std::error::Error + Sync + Send>;
/// A MAC address packed into the low 48 bits of a `u64`.
type MacValue = u64;
#[derive(Clone, Debug, Copy)]
pub struct MacAddress {
value: MacValue,
}
impl MacAddress {
pub fn new(value: Vec<u8>) -> Self {
let v = bytes::Bytes::from(value);
let a = v.slice(0..2).get_u16();
let b = v.slice(2..6).get_u32();
let value = ((a as u64) << 32) + (b as u64);
MacAddress {
value
}
}
pub fn get_value(&self) -> MacValue {
self.value.clone()
}
pub fn is_broadcast(&self) -> bool {
self.value == BROADCAST_MAC_ADDRESS
}
pub fn is_multicast(&self) -> bool {
self.value & MULTICAST_MAC_ADDRESS_MASK_V4 == MULTICAST_MAC_PREFIX_V4 ||
self.value & MULTICAST_MAC_ADDRESS_MASK_V6 == MULTICAST_MAC_PREFIX_V6
}
pub fn is_unicast(&self) -> bool {
!self.is_broadcast() && !self.is_multicast()
}
}
/// A forwarding destination (peer socket address) paired with the instant
/// after which it is considered stale.
#[derive(Debug)]
pub struct NextHop {
    addr: SocketAddr,
    expiry_time: Instant,
}

impl NextHop {
    /// Creates a next hop for `addr` that expires `timeout` seconds from now.
    ///
    /// # Panics
    ///
    /// Panics if `now + timeout` overflows `Instant`.
    pub fn new(addr: SocketAddr, timeout: u64) -> Self {
        let lifetime = Duration::from_secs(timeout);
        Self {
            expiry_time: Instant::now() + lifetime,
            addr,
        }
    }

    /// Moves the expiry to `timeout` seconds from now.
    ///
    /// # Panics
    ///
    /// Panics if `now + timeout` overflows `Instant`.
    pub fn update_expiry_time(&mut self, timeout: u64) {
        let lifetime = Duration::from_secs(timeout);
        self.expiry_time = Instant::now() + lifetime;
    }

    /// Returns `true` once the current instant is strictly past the expiry time.
    pub fn is_expired(&self) -> bool {
        Instant::now() > self.expiry_time
    }

    /// Returns the peer's socket address.
    pub fn get_addr(&self) -> SocketAddr {
        self.addr
    }
}
#[derive(Debug)]
pub struct ForwardingTable {
mac_aging_timeout: u64,
mac_table: HashMap<MacValue, NextHop>,
nexthop_timeout: u64,
nexthop_table: HashMap<SocketAddr, Instant>,
}
impl ForwardingTable {
pub fn new(nexthop_timeout: u64, mac_aging_timeout: u64) -> ForwardingTable {
ForwardingTable {
nexthop_timeout: if nexthop_timeout == 0 { 1 } else { nexthop_timeout },
mac_aging_timeout: if mac_aging_timeout == 0 { 1 } else { mac_aging_timeout },
nexthop_table: HashMap::new(),
mac_table: HashMap::new(),
}
}
pub fn nexthop_count(&self) -> u64 {
self.nexthop_table.len() as u64
}
pub fn mac_count(&self) -> u64 {
self.mac_table.len() as u64
}
pub fn update_or_insert_nexthop(&mut self, addr: &SocketAddr) {
match self.nexthop_table.get_mut(&addr) {
Some(created_time) => {
*created_time = Instant::now() + Duration::from_secs(self.nexthop_timeout);
trace!("renew socket session creation time: {:?}", addr);
}
None => {
if self.nexthop_table.len() < MAX_NEXTHOP_COUNT {
self.nexthop_table.insert(*addr, Instant::now() + Duration::from_secs(self.nexthop_timeout));
info!("new socket session: {:?}", addr);
} else {
error!("nexthop_table overflow!");
}
}
};
}
pub fn remove_nexthop(&mut self, addr: &SocketAddr) {
let mut mac_list = Vec::new();
for (mac, nexthop) in self.mac_table.iter() {
if nexthop.get_addr() == *addr {
mac_list.push(mac.clone());
}
}
trace!("remove mac entries pointing to the deleted socket session {:?}", addr);
for m in mac_list {
self.mac_table.remove(&m);
}
match self.nexthop_table.remove(addr) {
Some(_) => info!("remove socket session: {:?}", addr),
None => {
debug!("no socket session removed for non-existing socket session: {:?}", addr);
}
}
}
pub fn remove_expired_nexthop(&mut self) {
let mut nexthop_list: Vec<SocketAddr> = Vec::new();
for (s, t) in self.nexthop_table.iter() {
if *t < Instant::now() {
nexthop_list.push(s.clone());
}
}
for s in nexthop_list {
self.remove_nexthop(&s);
trace!("remove expired socket session: {:?}", s);
}
}
pub fn mac_learning(&mut self, mac: &MacAddress, nexthop: &SocketAddr) {
if mac.is_unicast() {
match self.mac_table.get_mut(&mac.get_value()) {
Some(nh) => {
nh.update_expiry_time(self.mac_aging_timeout);
}
None => {
if self.mac_table.len() >= MAX_MAC_COUNT {
trace!("number of mac entries is overflow. try to remove expired mac entires");
self.remove_expired_mac();
}
if self.mac_table.len() < MAX_MAC_COUNT {
trace!("insert a new mac entry: {:?} {:?}", mac, nexthop);
self.mac_table.insert(mac.get_value(), NextHop::new(*nexthop, self.mac_aging_timeout));
} else {
error!("Error: total count of mac entries > MAX_MAC_COUNT")
}
}
}
}
}
pub fn get_nexthop(&self, mac: &MacAddress) -> Vec<SocketAddr> {
let mut v = Vec::new();
if mac.is_broadcast() {
trace!("mac is broadcast mac - flood to all socket sessions");
for addr in self.nexthop_table.keys() {
v.push(*addr);
}
} else if mac.is_multicast() {
trace!("mac is muticast mac - flood to all socket sessions");
for addr in self.nexthop_table.keys() {
v.push(*addr);
}
} else {
match self.mac_table.get(&mac.get_value()) {
Some(nexthop) => {
let nh = nexthop.get_addr().clone();
v.push(nh)
}
None => {
for addr in self.nexthop_table.keys() {
v.push(*addr);
}
}
}
}
v
}
pub fn remove_expired_mac(&mut self) {
let mut mac_list = Vec::new();
for (mac, nexthop) in self.mac_table.iter() {
if nexthop.is_expired() {
mac_list.push(mac.clone());
}
}
for m in mac_list {
self.mac_table.remove(&m);
trace!("mac expired: {:?}", m);
}
}
}
#[cfg(test)]
mod tests {
    // No unit tests yet.
    // NOTE(review): MacAddress classification, NextHop expiry and
    // ForwardingTable learning/flooding are natural candidates for coverage.
}