use crate::layers::Layer;
use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
use metrics::{GaugeValue, Key, Recorder, Unit};
/// Recorder wrapper that drops metrics whose key name matches a configured pattern.
///
/// Operations whose key is not filtered are forwarded unchanged to `inner`.
pub struct Filter<R> {
    /// Recorder that receives every operation that is not filtered out.
    inner: R,
    /// Multi-pattern matcher built from the owning layer's patterns.
    automaton: AhoCorasick,
}

impl<R> Filter<R> {
    /// Returns `true` when any part of `key`'s name contains a match for one of the patterns.
    fn should_filter(&self, key: &Key) -> bool {
        for part in key.name().parts() {
            if self.automaton.is_match(part.as_ref()) {
                return true;
            }
        }
        false
    }
}
/// Forwards each registration and update to the inner recorder, unless the key is filtered.
///
/// Filtered operations are dropped silently: nothing is forwarded and nothing is reported.
impl<R: Recorder> Recorder for Filter<R> {
    fn register_counter(&self, key: Key, unit: Option<Unit>, description: Option<&'static str>) {
        if !self.should_filter(&key) {
            self.inner.register_counter(key, unit, description);
        }
    }

    fn register_gauge(&self, key: Key, unit: Option<Unit>, description: Option<&'static str>) {
        if !self.should_filter(&key) {
            self.inner.register_gauge(key, unit, description);
        }
    }

    fn register_histogram(&self, key: Key, unit: Option<Unit>, description: Option<&'static str>) {
        if !self.should_filter(&key) {
            self.inner.register_histogram(key, unit, description);
        }
    }

    fn increment_counter(&self, key: Key, value: u64) {
        if !self.should_filter(&key) {
            self.inner.increment_counter(key, value);
        }
    }

    fn update_gauge(&self, key: Key, value: GaugeValue) {
        if !self.should_filter(&key) {
            self.inner.update_gauge(key, value);
        }
    }

    fn record_histogram(&self, key: Key, value: f64) {
        if !self.should_filter(&key) {
            self.inner.record_histogram(key, value);
        }
    }
}
/// Layer that wraps a recorder in a [`Filter`], dropping metrics whose names match any of a
/// set of literal patterns.
///
/// By default, matching is case-sensitive and a DFA-based automaton is requested.
pub struct FilterLayer {
    /// Literal substrings; a metric whose name contains any of them is filtered out.
    patterns: Vec<String>,
    /// Whether ASCII letters in the patterns match case-insensitively.
    case_insensitive: bool,
    /// Whether to request a deterministic finite automaton when building the matcher.
    use_dfa: bool,
}

impl Default for FilterLayer {
    /// Creates a layer with no patterns, case-sensitive matching, and a DFA requested.
    ///
    /// These are the same settings [`FilterLayer::from_patterns`] applies, so both
    /// constructors agree. A derived `Default` would have set `use_dfa` to `false`.
    fn default() -> Self {
        FilterLayer {
            patterns: Vec::new(),
            case_insensitive: false,
            use_dfa: true,
        }
    }
}

impl FilterLayer {
    /// Creates a `FilterLayer` from an existing collection of patterns.
    ///
    /// All other settings take their [`Default`] values: case-sensitive, DFA requested.
    pub fn from_patterns<P, I>(patterns: P) -> Self
    where
        P: IntoIterator<Item = I>,
        I: AsRef<str>,
    {
        FilterLayer {
            patterns: patterns
                .into_iter()
                .map(|pattern| pattern.as_ref().to_string())
                .collect(),
            ..Default::default()
        }
    }

    /// Appends one more pattern to filter on.
    pub fn add_pattern<P>(&mut self, pattern: P) -> &mut FilterLayer
    where
        P: AsRef<str>,
    {
        self.patterns.push(pattern.as_ref().to_string());
        self
    }

    /// Sets whether patterns match ASCII letters case-insensitively. Defaults to `false`.
    pub fn case_insensitive(&mut self, case_insensitive: bool) -> &mut FilterLayer {
        self.case_insensitive = case_insensitive;
        self
    }

    /// Sets whether the matcher should be built as a deterministic finite automaton.
    ///
    /// A DFA trades more memory and build time for faster searches. Defaults to `true`.
    pub fn use_dfa(&mut self, dfa: bool) -> &mut FilterLayer {
        self.use_dfa = dfa;
        self
    }
}
impl<R> Layer<R> for FilterLayer {
    type Output = Filter<R>;

    /// Wraps `inner` in a [`Filter`] built from this layer's patterns and settings.
    fn layer(&self, inner: R) -> Self::Output {
        let mut automaton_builder = AhoCorasickBuilder::new();
        // `auto_configure` picks heuristic settings from the patterns. It may override any
        // option except `match_kind`, including `dfa`, so it must run before the explicit
        // settings. Otherwise a caller's `use_dfa(false)` would be silently discarded.
        let automaton = automaton_builder
            .auto_configure(&self.patterns)
            .ascii_case_insensitive(self.case_insensitive)
            .dfa(self.use_dfa)
            .build(&self.patterns);
        Filter { inner, automaton }
    }
}
#[cfg(test)]
mod tests {
    use super::FilterLayer;
    use crate::debugging::DebuggingRecorder;
    use crate::layers::Layer;
    use metrics::{Key, Recorder, Unit};

    /// Metrics whose names contain a pattern are dropped. The rest reach the inner recorder
    /// with their unit and description intact.
    #[test]
    fn test_basic_functionality() {
        let patterns = &["tokio", "bb8"];
        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();
        let layered = FilterLayer::from_patterns(patterns).layer(recorder);

        // Nothing has been registered yet.
        assert_eq!(snapshotter.snapshot().len(), 0);

        layered.register_counter(
            Key::Owned("tokio.loops".into()),
            Some(Unit::Count),
            Some("counter desc"),
        );
        layered.register_gauge(
            Key::Owned("hyper.sent_bytes".into()),
            Some(Unit::Bytes),
            Some("gauge desc"),
        );
        layered.register_histogram(
            Key::Owned("hyper.tokio.sent_bytes".into()),
            Some(Unit::Bytes),
            Some("histogram desc"),
        );
        layered.register_counter(
            Key::Owned("bb8.conns".into()),
            Some(Unit::Count),
            Some("counter desc"),
        );
        layered.register_gauge(
            Key::Owned("hyper.recv_bytes".into()),
            Some(Unit::Bytes),
            Some("gauge desc"),
        );

        // Only the two `hyper.*_bytes` gauges should survive the filter.
        let after = snapshotter.snapshot();
        assert_eq!(after.len(), 2);
        for (_kind, key, unit, desc, _value) in after {
            let name = key.name().to_string();
            assert!(!name.contains("tokio") && !name.contains("bb8"));
            assert_eq!(unit, Some(Unit::Bytes));
            assert_eq!(desc, Some("gauge desc"));
        }
    }

    /// With case-insensitive matching enabled, mixed-case occurrences of a pattern are
    /// filtered as well.
    #[test]
    fn test_case_insensitivity() {
        let patterns = &["tokio", "bb8"];
        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();
        let mut filter = FilterLayer::from_patterns(patterns.iter());
        filter.case_insensitive(true);
        let layered = filter.layer(recorder);

        // Nothing has been registered yet.
        assert_eq!(snapshotter.snapshot().len(), 0);

        layered.register_counter(Key::Owned("tokiO.loops".into()), None, None);
        layered.register_gauge(Key::Owned("hyper.sent_bytes".into()), None, None);
        layered.register_histogram(Key::Owned("hyper.recv_bytes".into()), None, None);
        layered.register_counter(Key::Owned("bb8.conns".into()), None, None);
        layered.register_counter(Key::Owned("Bb8.conns_closed".into()), None, None);

        // Only the two `hyper.*` metrics should survive.
        let after = snapshotter.snapshot();
        assert_eq!(after.len(), 2);
        for (_kind, key, _unit, _desc, _value) in &after {
            let lowered = key.name().to_string().to_lowercase();
            assert!(!lowered.contains("tokio") && !lowered.contains("bb8"));
        }
    }
}