use std::sync::Arc;
/// Protocol family that a template (and the event reporting it) belongs to.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it. It is `Copy`, comparable with `==`, and
/// serializable through `serde::Serialize`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
pub enum TemplateProtocol {
/// The "V9" protocol (presumably NetFlow v9 — confirm against the parser).
V9,
/// The "Ipfix" protocol (presumably IPFIX — confirm against the parser).
Ipfix,
}
/// Notification passed by reference to every registered template hook.
///
/// Every variant carries the same two fields:
/// * `template_id` — the template identifier, or `None` when the emitter did
///   not supply one (the exact meaning of `None` is decided by the emitter).
/// * `protocol` — the [`TemplateProtocol`] the template belongs to.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
///
/// NOTE(review): the code that emits these events is not part of this file;
/// the per-variant descriptions below are inferred from the variant names and
/// should be confirmed against the emitter.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum TemplateEvent {
/// Presumably: a template was learned (stored).
Learned {
template_id: Option<u16>,
protocol: TemplateProtocol,
},
/// Presumably: an incoming template collided with an existing one.
Collision {
template_id: Option<u16>,
protocol: TemplateProtocol,
},
/// Presumably: a template was evicted from storage.
Evicted {
template_id: Option<u16>,
protocol: TemplateProtocol,
},
/// Presumably: a template expired.
Expired {
template_id: Option<u16>,
protocol: TemplateProtocol,
},
/// Presumably: data referenced a template that was not available.
MissingTemplate {
template_id: Option<u16>,
protocol: TemplateProtocol,
},
}
/// Error type a hook may return: any boxed error that is `Send + Sync`.
pub type TemplateHookError = Box<dyn std::error::Error + Send + Sync>;
/// A shareable hook callback.
///
/// The callback receives the event by reference and reports failure through
/// its `Result`. It is stored behind an `Arc`, so cloning a hook clones the
/// pointer rather than the closure, and it must be `Send + Sync + 'static`.
pub type TemplateHook =
Arc<dyn Fn(&TemplateEvent) -> Result<(), TemplateHookError> + Send + Sync + 'static>;
/// Registry of template hooks plus a counter of failed hook invocations.
///
/// The derived `Default` yields an empty registry with a zero error count.
#[derive(Default)]
pub struct TemplateHooks {
// Registered callbacks, kept in registration order; `trigger` invokes them
// in this order.
hooks: Vec<TemplateHook>,
// Number of hook invocations that returned `Err` or panicked. Incremented
// with saturating arithmetic by `trigger`; clones start again from zero.
hook_errors: u64,
}
impl Clone for TemplateHooks {
fn clone(&self) -> Self {
Self {
hooks: self.hooks.clone(),
hook_errors: 0, }
}
}
impl std::fmt::Debug for TemplateHooks {
    /// Formats the registry as `TemplateHooks { hook_count: N }`.
    ///
    /// The callbacks themselves are opaque closures and cannot be printed, so
    /// only their number is shown; the error counter is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hook_count = self.hooks.len();
        let mut builder = f.debug_struct("TemplateHooks");
        builder.field("hook_count", &hook_count);
        builder.finish()
    }
}
impl TemplateHooks {
pub fn new() -> Self {
Self {
hooks: Vec::new(),
hook_errors: 0,
}
}
pub fn register<F>(&mut self, hook: F)
where
F: Fn(&TemplateEvent) -> Result<(), TemplateHookError> + Send + Sync + 'static,
{
self.hooks.push(Arc::new(hook));
}
pub fn clear(&mut self) {
self.hooks.clear();
}
pub fn trigger(&mut self, event: &TemplateEvent) {
for hook in &self.hooks {
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| hook(event))) {
Ok(Err(e)) => {
self.hook_errors = self.hook_errors.saturating_add(1);
#[cfg(debug_assertions)]
eprintln!("template hook error: {e}");
let _ = e;
}
Err(_panic) => {
self.hook_errors = self.hook_errors.saturating_add(1);
#[cfg(debug_assertions)]
eprintln!("template hook panicked");
}
Ok(Ok(())) => {}
}
}
}
pub fn hook_error_count(&self) -> u64 {
self.hook_errors
}
pub fn len(&self) -> usize {
self.hooks.len()
}
pub fn is_empty(&self) -> bool {
self.hooks.is_empty()
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the hook registry: registration, dispatch, failure
    //! isolation (returned errors and panics), cloning, and formatting.

    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Registering a hook changes `len` / `is_empty` accordingly.
    #[test]
    fn test_hook_registration() {
        let mut hooks = TemplateHooks::new();
        assert_eq!(hooks.len(), 0);
        assert!(hooks.is_empty());
        hooks.register(|_| Ok(()));
        assert_eq!(hooks.len(), 1);
        assert!(!hooks.is_empty());
    }

    /// Each `trigger` call invokes a registered hook exactly once.
    #[test]
    fn test_hook_triggering() {
        let mut hooks = TemplateHooks::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        hooks.register(move |_| {
            counter_clone.fetch_add(1, Ordering::SeqCst);
            Ok(())
        });
        let event = TemplateEvent::Learned {
            template_id: Some(256),
            protocol: TemplateProtocol::V9,
        };
        hooks.trigger(&event);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        hooks.trigger(&event);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    /// Every registered hook sees the event.
    #[test]
    fn test_multiple_hooks() {
        let mut hooks = TemplateHooks::new();
        let counter1 = Arc::new(AtomicUsize::new(0));
        let counter2 = Arc::new(AtomicUsize::new(0));
        let c1 = counter1.clone();
        let c2 = counter2.clone();
        hooks.register(move |_| {
            c1.fetch_add(1, Ordering::SeqCst);
            Ok(())
        });
        hooks.register(move |_| {
            c2.fetch_add(10, Ordering::SeqCst);
            Ok(())
        });
        let event = TemplateEvent::Collision {
            template_id: Some(300),
            protocol: TemplateProtocol::Ipfix,
        };
        hooks.trigger(&event);
        assert_eq!(counter1.load(Ordering::SeqCst), 1);
        assert_eq!(counter2.load(Ordering::SeqCst), 10);
    }

    /// A hook can distinguish event variants by pattern matching.
    #[test]
    fn test_hook_event_matching() {
        let mut hooks = TemplateHooks::new();
        let learned_count = Arc::new(AtomicUsize::new(0));
        let collision_count = Arc::new(AtomicUsize::new(0));
        let lc = learned_count.clone();
        let cc = collision_count.clone();
        hooks.register(move |event| {
            match event {
                TemplateEvent::Learned { .. } => {
                    lc.fetch_add(1, Ordering::SeqCst);
                }
                TemplateEvent::Collision { .. } => {
                    cc.fetch_add(1, Ordering::SeqCst);
                }
                _ => {}
            }
            Ok(())
        });
        hooks.trigger(&TemplateEvent::Learned {
            template_id: Some(256),
            protocol: TemplateProtocol::V9,
        });
        hooks.trigger(&TemplateEvent::Collision {
            template_id: Some(300),
            protocol: TemplateProtocol::Ipfix,
        });
        hooks.trigger(&TemplateEvent::Learned {
            template_id: Some(400),
            protocol: TemplateProtocol::V9,
        });
        assert_eq!(learned_count.load(Ordering::SeqCst), 2);
        assert_eq!(collision_count.load(Ordering::SeqCst), 1);
    }

    /// Cloning an event preserves its variant and payload.
    #[test]
    fn test_template_event_clone() {
        let event = TemplateEvent::Evicted {
            template_id: Some(500),
            protocol: TemplateProtocol::Ipfix,
        };
        let cloned = event.clone();
        match (event, cloned) {
            (
                TemplateEvent::Evicted {
                    template_id: id1,
                    protocol: p1,
                },
                TemplateEvent::Evicted {
                    template_id: id2,
                    protocol: p2,
                },
            ) => {
                assert_eq!(id1, id2);
                assert_eq!(p1, p2);
            }
            _ => panic!("Event didn't match after clone"),
        }
    }

    /// A hook returning `Err` is counted and does not stop later hooks.
    #[test]
    fn test_hook_error_is_counted() {
        let mut hooks = TemplateHooks::new();
        let reached = Arc::new(AtomicUsize::new(0));
        let reached_in_hook = reached.clone();
        hooks.register(|_| Err("boom".into()));
        hooks.register(move |_| {
            reached_in_hook.fetch_add(1, Ordering::SeqCst);
            Ok(())
        });
        let event = TemplateEvent::Expired {
            template_id: Some(256),
            protocol: TemplateProtocol::V9,
        };
        assert_eq!(hooks.hook_error_count(), 0);
        hooks.trigger(&event);
        assert_eq!(hooks.hook_error_count(), 1);
        assert_eq!(reached.load(Ordering::SeqCst), 1);
        hooks.trigger(&event);
        assert_eq!(hooks.hook_error_count(), 2);
        assert_eq!(reached.load(Ordering::SeqCst), 2);
    }

    /// A panicking hook is caught, counted, and does not stop later hooks.
    #[test]
    fn test_panicking_hook_is_isolated() {
        let mut hooks = TemplateHooks::new();
        let reached = Arc::new(AtomicUsize::new(0));
        let reached_in_hook = reached.clone();
        hooks.register(|_| panic!("intentional panic inside hook"));
        hooks.register(move |_| {
            reached_in_hook.fetch_add(1, Ordering::SeqCst);
            Ok(())
        });
        hooks.trigger(&TemplateEvent::MissingTemplate {
            template_id: None,
            protocol: TemplateProtocol::Ipfix,
        });
        assert_eq!(hooks.hook_error_count(), 1);
        assert_eq!(reached.load(Ordering::SeqCst), 1);
    }

    /// A clone shares the hooks but starts with a zeroed error counter.
    #[test]
    fn test_clone_resets_error_count() {
        let mut hooks = TemplateHooks::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_hook = calls.clone();
        hooks.register(move |_| {
            calls_in_hook.fetch_add(1, Ordering::SeqCst);
            Err("always fails".into())
        });
        let event = TemplateEvent::Learned {
            template_id: Some(1),
            protocol: TemplateProtocol::V9,
        };
        hooks.trigger(&event);
        assert_eq!(hooks.hook_error_count(), 1);

        let mut cloned = hooks.clone();
        assert_eq!(cloned.len(), 1);
        assert_eq!(cloned.hook_error_count(), 0);

        // The clone dispatches to the very same closure.
        cloned.trigger(&event);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        assert_eq!(cloned.hook_error_count(), 1);
        assert_eq!(hooks.hook_error_count(), 1);
    }

    /// `clear` removes the hooks, so a later trigger calls nothing.
    #[test]
    fn test_clear_removes_hooks() {
        let mut hooks = TemplateHooks::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_hook = calls.clone();
        hooks.register(move |_| {
            calls_in_hook.fetch_add(1, Ordering::SeqCst);
            Ok(())
        });
        hooks.clear();
        assert!(hooks.is_empty());
        hooks.trigger(&TemplateEvent::Evicted {
            template_id: Some(2),
            protocol: TemplateProtocol::Ipfix,
        });
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    /// The `Debug` output reports only the number of hooks.
    #[test]
    fn test_debug_reports_hook_count() {
        let mut hooks = TemplateHooks::new();
        hooks.register(|_| Ok(()));
        hooks.register(|_| Ok(()));
        assert_eq!(format!("{:?}", hooks), "TemplateHooks { hook_count: 2 }");
    }
}