use std::collections::HashMap;
use std::net::Ipv4Addr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Outcome of checking an incoming ARP reply against the requests that were recorded.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ArpValidation {
    /// A request for the address was pending and the reply was checked within the timeout.
    Valid,
    /// No request for the address was pending when the reply was checked.
    Unsolicited,
    /// A request was pending, but more than the timeout had elapsed since it was recorded.
    Expired,
}

/// Tracks outstanding ARP requests so that replies can be matched against them.
///
/// Every target address maps to the instant at which its most recent request
/// was recorded. An entry disappears when its reply is validated, when it is
/// purged by [`ArpValidator::cleanup_expired`], or when the table is cleared.
#[derive(Debug)]
pub struct ArpValidator {
    /// Recording time of the latest request, keyed by target address.
    pending_requests: HashMap<Ipv4Addr, Instant>,
    /// Maximum age a request may have for its reply to still count as valid.
    timeout: Duration,
}

impl ArpValidator {
    /// Creates a validator with no pending requests and the given reply timeout.
    pub fn new(timeout: Duration) -> Self {
        Self {
            pending_requests: HashMap::new(),
            timeout,
        }
    }

    /// Creates a validator whose reply timeout is two seconds.
    pub fn with_default_timeout() -> Self {
        Self::new(Duration::from_secs(2))
    }

    /// Records that a request for `target_ip` was issued now.
    ///
    /// Recording the same address again replaces the stored instant, which
    /// restarts the timeout window for that address.
    pub fn record_request(&mut self, target_ip: Ipv4Addr) {
        self.pending_requests.insert(target_ip, Instant::now());
    }

    /// Classifies a reply claiming to come from `ip`.
    ///
    /// The pending entry for `ip` is removed whatever the verdict, so a second
    /// reply for the same request is reported as [`ArpValidation::Unsolicited`].
    pub fn validate_reply(&mut self, ip: Ipv4Addr) -> ArpValidation {
        match self.pending_requests.remove(&ip) {
            Some(sent_at) if sent_at.elapsed() <= self.timeout => ArpValidation::Valid,
            Some(_) => ArpValidation::Expired,
            None => ArpValidation::Unsolicited,
        }
    }

    /// Drops every pending request that is older than the timeout.
    ///
    /// Uses the same `<=` comparison as [`ArpValidator::validate_reply`], so an
    /// entry is only purged once a reply to it would be reported as expired.
    pub fn cleanup_expired(&mut self) {
        // Copy the timeout out before borrowing the map mutably: the closure
        // then captures only locals, so this compiles on every Rust edition
        // instead of relying on edition-2021 disjoint closure captures.
        let timeout = self.timeout;
        let now = Instant::now();
        self.pending_requests
            .retain(|_, sent_at| now.duration_since(*sent_at) <= timeout);
    }

    /// Returns the number of stored requests, including ones that have already
    /// outlived the timeout but were not yet purged.
    pub fn pending_count(&self) -> usize {
        self.pending_requests.len()
    }

    /// Forgets every pending request.
    pub fn clear(&mut self) {
        self.pending_requests.clear();
    }
}

impl Default for ArpValidator {
    /// Equivalent to [`ArpValidator::with_default_timeout`].
    fn default() -> Self {
        Self::with_default_timeout()
    }
}
/// Handle to an [`ArpValidator`] stored behind an `Arc<Mutex<_>>`.
///
/// Cloning the handle produces another handle to the same underlying
/// validator, so requests recorded through one clone are visible to all.
#[derive(Clone)]
pub struct SharedArpValidator {
    inner: Arc<Mutex<ArpValidator>>,
}

impl SharedArpValidator {
    /// Creates a shared validator with the given reply timeout.
    pub fn new(timeout: Duration) -> Self {
        Self::wrap(ArpValidator::new(timeout))
    }

    /// Creates a shared validator using the timeout chosen by
    /// [`ArpValidator::with_default_timeout`].
    pub fn with_default_timeout() -> Self {
        Self::wrap(ArpValidator::with_default_timeout())
    }

    /// Places an already constructed validator behind the shared lock.
    fn wrap(validator: ArpValidator) -> Self {
        Self {
            inner: Arc::new(Mutex::new(validator)),
        }
    }

    /// Locks the inner validator, recovering the guard if the mutex is poisoned.
    ///
    /// A poisoned lock only means that some thread panicked while holding it.
    /// Ignoring the lock in that case would silently drop every later request
    /// and report every later reply as unsolicited, so the guard is taken back
    /// from the poison error and the validator keeps working.
    fn guard(&self) -> std::sync::MutexGuard<'_, ArpValidator> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records that a request for `target_ip` was issued now.
    pub fn record_request(&self, target_ip: Ipv4Addr) {
        self.guard().record_request(target_ip);
    }

    /// Classifies a reply claiming to come from `ip`; see
    /// [`ArpValidator::validate_reply`].
    pub fn validate_reply(&self, ip: Ipv4Addr) -> ArpValidation {
        self.guard().validate_reply(ip)
    }

    /// Drops every pending request that has outlived the timeout; see
    /// [`ArpValidator::cleanup_expired`].
    pub fn cleanup_expired(&self) {
        self.guard().cleanup_expired();
    }

    /// Returns the number of requests currently stored.
    pub fn pending_count(&self) -> usize {
        self.guard().pending_count()
    }

    /// Forgets every pending request.
    pub fn clear(&self) {
        self.guard().clear();
    }

    /// Returns another reference to the shared, mutex-guarded validator.
    pub fn clone_inner(&self) -> Arc<Mutex<ArpValidator>> {
        Arc::clone(&self.inner)
    }
}

impl Default for SharedArpValidator {
    /// Equivalent to [`SharedArpValidator::with_default_timeout`].
    fn default() -> Self {
        Self::with_default_timeout()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a dotted-quad literal; panics if the test input is malformed.
    fn addr(text: &str) -> Ipv4Addr {
        text.parse().unwrap()
    }

    /// A reply is accepted once, and only after its request was recorded.
    #[test]
    fn test_record_and_validate_reply() {
        let mut tracker = ArpValidator::with_default_timeout();
        let gateway = addr("192.168.1.1");

        // Nothing recorded yet.
        assert_eq!(tracker.validate_reply(gateway), ArpValidation::Unsolicited);

        tracker.record_request(gateway);
        assert_eq!(tracker.validate_reply(gateway), ArpValidation::Valid);

        // The pending entry was consumed by the previous validation.
        assert_eq!(tracker.validate_reply(gateway), ArpValidation::Unsolicited);
    }

    /// A reply checked after the timeout is reported as expired.
    #[test]
    fn test_expired_request() {
        let mut tracker = ArpValidator::new(Duration::from_millis(100));
        let gateway = addr("192.168.1.1");

        tracker.record_request(gateway);
        std::thread::sleep(Duration::from_millis(150));

        assert_eq!(tracker.validate_reply(gateway), ArpValidation::Expired);
    }

    /// Cleanup removes every request that outlived the timeout.
    #[test]
    fn test_cleanup_expired() {
        let mut tracker = ArpValidator::new(Duration::from_millis(100));
        let first = addr("192.168.1.1");
        let second = addr("192.168.1.2");

        tracker.record_request(first);
        tracker.record_request(second);
        assert_eq!(tracker.pending_count(), 2);

        std::thread::sleep(Duration::from_millis(150));
        tracker.cleanup_expired();

        assert_eq!(tracker.pending_count(), 0);
    }

    /// The shared wrapper forwards recording and validation to the inner validator.
    #[test]
    fn test_shared_validator() {
        let shared = SharedArpValidator::with_default_timeout();
        let gateway = addr("192.168.1.1");

        assert_eq!(shared.validate_reply(gateway), ArpValidation::Unsolicited);

        shared.record_request(gateway);
        assert_eq!(shared.validate_reply(gateway), ArpValidation::Valid);
    }

    /// Clearing drops all pending requests at once.
    #[test]
    fn test_clear_pending() {
        let mut tracker = ArpValidator::with_default_timeout();
        let first = addr("192.168.1.1");
        let second = addr("192.168.1.2");

        tracker.record_request(first);
        tracker.record_request(second);
        assert_eq!(tracker.pending_count(), 2);

        tracker.clear();
        assert_eq!(tracker.pending_count(), 0);
    }
}