use crate::layers::Layer;
use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
use metrics::{GaugeValue, Key, Recorder, Unit};
/// A [`Recorder`] wrapper that discards every metric whose name matches one of a set of patterns.
///
/// Created by [`FilterLayer`]. Operations on a matched key are dropped; all others are forwarded
/// unchanged to the wrapped recorder.
pub struct Filter<R> {
    /// The recorder that receives every metric that is not filtered out.
    inner: R,
    /// Compiled multi-pattern matcher that is run against each part of a metric name.
    automaton: AhoCorasick,
}

impl<R> Filter<R> {
    /// Returns `true` when any part of `key`'s name contains a match for one of the patterns.
    fn should_filter(&self, key: &Key) -> bool {
        for part in key.name().parts() {
            if self.automaton.is_match(part.as_ref()) {
                return true;
            }
        }
        false
    }
}
// Every method follows the same rule: a key that `should_filter` matches is dropped silently;
// any other key is passed through to the wrapped recorder with its arguments untouched.
impl<R: Recorder> Recorder for Filter<R> {
    fn register_counter(&self, key: Key, unit: Option<Unit>, description: Option<&'static str>) {
        if !self.should_filter(&key) {
            self.inner.register_counter(key, unit, description);
        }
    }

    fn register_gauge(&self, key: Key, unit: Option<Unit>, description: Option<&'static str>) {
        if !self.should_filter(&key) {
            self.inner.register_gauge(key, unit, description);
        }
    }

    fn register_histogram(&self, key: Key, unit: Option<Unit>, description: Option<&'static str>) {
        if !self.should_filter(&key) {
            self.inner.register_histogram(key, unit, description);
        }
    }

    fn increment_counter(&self, key: Key, value: u64) {
        if !self.should_filter(&key) {
            self.inner.increment_counter(key, value);
        }
    }

    fn update_gauge(&self, key: Key, value: GaugeValue) {
        if !self.should_filter(&key) {
            self.inner.update_gauge(key, value);
        }
    }

    fn record_histogram(&self, key: Key, value: f64) {
        if !self.should_filter(&key) {
            self.inner.record_histogram(key, value);
        }
    }
}
/// Builder for a [`Filter`] layer that drops metrics whose names match any configured pattern.
///
/// Patterns are literal strings compiled into an Aho-Corasick automaton when the layer is applied.
#[derive(Clone, Debug)]
pub struct FilterLayer {
    /// Literal patterns; a metric is dropped if any part of its name contains one of them.
    patterns: Vec<String>,
    /// Whether pattern matching ignores ASCII case.
    case_insensitive: bool,
    /// Whether the automaton should be built as a DFA rather than an NFA.
    use_dfa: bool,
}

impl Default for FilterLayer {
    /// Creates a layer with no patterns, case-sensitive matching, and the DFA option enabled.
    ///
    /// These are the same starting settings as [`FilterLayer::from_patterns`]. A derived
    /// `Default` would instead leave `use_dfa` as `false`, so the two constructors would disagree.
    fn default() -> Self {
        FilterLayer {
            patterns: Vec::new(),
            case_insensitive: false,
            use_dfa: true,
        }
    }
}
impl FilterLayer {
    
    pub fn from_patterns<P, I>(patterns: P) -> Self
    where
        P: IntoIterator<Item = I>,
        I: AsRef<str>,
    {
        FilterLayer {
            patterns: patterns
                .into_iter()
                .map(|s| s.as_ref().to_string())
                .collect(),
            case_insensitive: false,
            use_dfa: true,
        }
    }
    
    pub fn add_pattern<P>(&mut self, pattern: P) -> &mut FilterLayer
    where
        P: AsRef<str>,
    {
        self.patterns.push(pattern.as_ref().to_string());
        self
    }
    
    
    
    pub fn case_insensitive(&mut self, case_insensitive: bool) -> &mut FilterLayer {
        self.case_insensitive = case_insensitive;
        self
    }
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    pub fn use_dfa(&mut self, dfa: bool) -> &mut FilterLayer {
        self.use_dfa = dfa;
        self
    }
}
impl<R> Layer<R> for FilterLayer {
    type Output = Filter<R>;

    /// Compiles the configured patterns into an Aho-Corasick automaton and wraps `inner` in a
    /// [`Filter`] that uses it.
    fn layer(&self, inner: R) -> Self::Output {
        let mut automaton_builder = AhoCorasickBuilder::new();
        // `auto_configure` is documented as free to override any builder option other than
        // `match_kind`; aho-corasick 0.7 forces `dfa(true)` for small pattern sets. Running it
        // first lets the explicit settings below take effect, so `use_dfa(false)` is honored.
        let automaton = automaton_builder
            .auto_configure(&self.patterns)
            .ascii_case_insensitive(self.case_insensitive)
            .dfa(self.use_dfa)
            .build(&self.patterns);
        Filter { inner, automaton }
    }
}
// Unit tests for `FilterLayer`, using `DebuggingRecorder` snapshots to observe which metrics
// actually reach the inner recorder.
#[cfg(test)]
mod tests {
    use super::FilterLayer;
    use crate::debugging::DebuggingRecorder;
    use crate::layers::Layer;
    use metrics::{Key, Recorder, Unit};
    // Registers five metrics against the patterns "tokio" and "bb8". Three of them match: one
    // prefix match, one match in the middle of the name, and one bb8 prefix. Only the two
    // non-matching gauges should reach the recorder.
    #[test]
    fn test_basic_functionality() {
        let patterns = &["tokio", "bb8"];
        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();
        let filter = FilterLayer::from_patterns(patterns);
        let layered = filter.layer(recorder);
        let before = snapshotter.snapshot();
        assert_eq!(before.len(), 0);
        // (unit, description) pairs, one per registration below, in the same order.
        let ud = &[
            (Unit::Count, "counter desc"),
            (Unit::Bytes, "gauge desc"),
            (Unit::Bytes, "histogram desc"),
            (Unit::Count, "counter desc"),
            (Unit::Bytes, "gauge desc"),
        ];
        layered.register_counter(
            Key::Owned("tokio.loops".into()),
            Some(ud[0].0.clone()),
            Some(ud[0].1),
        );
        layered.register_gauge(
            Key::Owned("hyper.sent_bytes".into()),
            Some(ud[1].0.clone()),
            Some(ud[1].1),
        );
        layered.register_histogram(
            Key::Owned("hyper.tokio.sent_bytes".into()),
            Some(ud[2].0.clone()),
            Some(ud[2].1),
        );
        layered.register_counter(
            Key::Owned("bb8.conns".into()),
            Some(ud[3].0.clone()),
            Some(ud[3].1),
        );
        layered.register_gauge(
            Key::Owned("hyper.recv_bytes".into()),
            Some(ud[4].0.clone()),
            Some(ud[4].1),
        );
        let after = snapshotter.snapshot();
        assert_eq!(after.len(), 2);
        for (key, unit, desc, _value) in after {
            assert!(
                !key.key().name().to_string().contains("tokio")
                    && !key.key().name().to_string().contains("bb8")
            );
            // Both survivors are the two gauges, so each one should carry the gauge unit and
            // the gauge description.
            assert_eq!(Some(Unit::Bytes), unit);
            assert!(!desc.unwrap().is_empty() && desc.unwrap() == "gauge desc");
        }
    }
    // With case-insensitive matching, names that differ from the patterns only in ASCII case
    // ("tokiO", "Bb8") must also be filtered. Only the two "hyper.*" metrics should remain.
    #[test]
    fn test_case_insensitivity() {
        let patterns = &["tokio", "bb8"];
        let recorder = DebuggingRecorder::new();
        let snapshotter = recorder.snapshotter();
        let mut filter = FilterLayer::from_patterns(patterns.iter());
        filter.case_insensitive(true);
        let layered = filter.layer(recorder);
        let before = snapshotter.snapshot();
        assert_eq!(before.len(), 0);
        layered.register_counter(Key::Owned("tokiO.loops".into()), None, None);
        layered.register_gauge(Key::Owned("hyper.sent_bytes".into()), None, None);
        layered.register_histogram(Key::Owned("hyper.recv_bytes".into()), None, None);
        layered.register_counter(Key::Owned("bb8.conns".into()), None, None);
        layered.register_counter(Key::Owned("Bb8.conns_closed".into()), None, None);
        let after = snapshotter.snapshot();
        assert_eq!(after.len(), 2);
        for (key, _unit, _desc, _value) in &after {
            assert!(
                !key.key()
                    .name()
                    .to_string()
                    .to_lowercase()
                    .contains("tokio")
                    && !key.key().name().to_string().to_lowercase().contains("bb8")
            );
        }
    }
}