redirectionio 3.1.0

RedirectionIO library for matching rules, performing redirects, and filtering headers and bodies.
Documentation
use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};

#[cfg(feature = "dot")]
use dot_graph::{Edge, Graph, Node};

use super::super::{Route, RouterConfig, Trace, request_matcher::HeaderMatcher, trace::TraceInfo};
#[cfg(feature = "dot")]
use crate::dot::DotBuilder;
use crate::http::Request;

/// Dispatches routes and requests to `HeaderMatcher`s according to the HTTP method.
///
/// Routes are kept in three kinds of buckets: one for routes without a method constraint,
/// one per allowed method, and one per exclusion list (routes that apply to every method
/// *except* the listed ones).
#[derive(Debug, Clone)]
pub struct MethodMatcher<T> {
    /// Matchers for routes bound to specific methods, keyed by method name.
    methods: HashMap<String, HeaderMatcher<T>>,
    /// Matchers for routes that exclude a list of methods, keyed by that whole list.
    exclude_methods: HashMap<Vec<String>, HeaderMatcher<T>>,
    /// Matcher for routes that have no method list, or an empty one.
    any_method: HeaderMatcher<T>,
    /// Route counter maintained by `insert` and `remove`; backs `len` and `is_empty`.
    count: usize,
    /// Configuration passed to every `HeaderMatcher` created by this matcher.
    config: Arc<RouterConfig>,
}

impl<T> MethodMatcher<T> {
    /// Creates a matcher that holds no routes.
    ///
    /// `config` is stored and also handed to the any-method `HeaderMatcher` built here.
    pub fn new(config: Arc<RouterConfig>) -> Self {
        let any_method = HeaderMatcher::new(Arc::clone(&config));

        Self {
            config,
            any_method,
            count: 0,
            methods: HashMap::new(),
            exclude_methods: HashMap::new(),
        }
    }

    /// Registers `route` in the bucket that corresponds to its method constraint.
    ///
    /// * No method list, or an empty one: the route goes to the any-method matcher.
    /// * A method list on a route whose `exclude_methods()` is `Some(_)`: the route is stored
    ///   once under the whole list, and is matched for requests whose method is *not* in it.
    /// * Otherwise the route is stored under each listed method.
    ///
    /// The route counter is incremented once per call, however many buckets receive the route.
    pub fn insert(&mut self, route: Arc<Route<T>>) {
        self.count += 1;

        let methods = match route.methods() {
            Some(methods) if !methods.is_empty() => methods,
            _ => {
                self.any_method.insert(route);

                return;
            }
        };

        // Borrowed (not cloned) so the `Arc` is only cloned when a new bucket is really created.
        let config = &self.config;

        // NOTE(review): any `Some(_)` selects the exclusion bucket; if the option can carry a
        // negative flag, its value should be inspected here — confirm against `Route`.
        if route.exclude_methods().is_some() {
            self.exclude_methods
                .entry(methods.clone())
                .or_insert_with(|| HeaderMatcher::new(Arc::clone(config)))
                .insert(route);

            return;
        }

        for method in methods {
            // Single lookup through the entry API instead of contains_key + insert + get_mut.
            self.methods
                .entry(method.clone())
                .or_insert_with(|| HeaderMatcher::new(Arc::clone(config)))
                .insert(Arc::clone(&route));
        }
    }

    /// Removes the route identified by `id` and returns it, or `None` if no bucket held it.
    ///
    /// The any-method matcher is tried first. Otherwise every per-method and exclusion bucket
    /// is asked to remove the id, and buckets left empty are dropped. The route counter is
    /// decremented once when a route was found.
    pub fn remove(&mut self, id: &str) -> Option<Arc<Route<T>>> {
        if let Some(route) = self.any_method.remove(id) {
            self.count -= 1;

            return Some(route);
        }

        let mut found = None;

        self.methods.retain(|_, bucket| {
            if let hit @ Some(_) = bucket.remove(id) {
                found = hit;
            }

            !bucket.is_empty()
        });

        self.exclude_methods.retain(|_, bucket| {
            if let hit @ Some(_) = bucket.remove(id) {
                found = hit;
            }

            !bucket.is_empty()
        });

        if found.is_none() {
            return None;
        }

        self.count -= 1;

        found
    }

    /// Removes every route whose id is in `ids` from all buckets, dropping per-method and
    /// exclusion buckets that end up empty.
    ///
    /// Returns `true` when no bucket holds a route afterwards.
    ///
    /// The number of routes actually removed is not known here, so the route counter cannot
    /// be decremented exactly. It is reset to zero when everything was removed, which keeps
    /// `len()` / `is_empty()` in agreement with the returned flag; after a partial removal
    /// `len()` may still over-report.
    pub fn batch_remove(&mut self, ids: &HashSet<String>) -> bool {
        self.any_method.batch_remove(ids);

        self.methods.retain(|_, matcher| {
            matcher.batch_remove(ids);

            !matcher.is_empty()
        });

        self.exclude_methods.retain(|_, matcher| {
            matcher.batch_remove(ids);

            !matcher.is_empty()
        });

        let emptied = self.any_method.is_empty() && self.methods.is_empty() && self.exclude_methods.is_empty();

        if emptied {
            // Nothing is stored any more, so the counter must not keep its stale value.
            self.count = 0;
        }

        emptied
    }

    /// Collects the routes that match `request`.
    ///
    /// Results are gathered, in this order, from the any-method matcher, from the bucket
    /// registered for the request method (if any), and from every exclusion bucket whose
    /// list does not contain the request method.
    pub fn match_request(&self, request: &Request) -> Vec<Arc<Route<T>>> {
        let mut routes = self.any_method.match_request(request);
        let request_method = request.method();

        if let Some(matcher) = self.methods.get(request_method) {
            routes.extend(matcher.match_request(request));
        }

        for (excluded, matcher) in &self.exclude_methods {
            // Compare in place: building an owned `String` per bucket just to call
            // `contains` would allocate on every request.
            let is_excluded = excluded.iter().any(|name| name == request_method);

            if !is_excluded {
                routes.extend(matcher.match_request(request));
            }
        }

        routes
    }

    /// Builds the trace entries describing how `request` is dispatched by method.
    ///
    /// The result starts with the traces of the any-method matcher. Every exclusion bucket
    /// and every per-method bucket then contributes one entry: buckets that apply to the
    /// request carry the nested traces of their matcher and have both flags set to `true`,
    /// the others have both flags `false` and no children. When no exclusion or per-method
    /// bucket applied, a final `TraceInfo::Method` entry with `against: None` is appended.
    pub fn trace(&self, request: &Request) -> Vec<Trace<T>> {
        let mut traces = self.any_method.trace(request);
        let request_method = request.method();
        let mut found = false;

        for (methods, matcher) in &self.exclude_methods {
            // The bucket applies when the request method is NOT in its exclusion list.
            let matched = !methods.iter().any(|name| name == request_method);
            found |= matched;

            let children = if matched { matcher.trace(request) } else { Vec::new() };

            traces.push(Trace::new(
                matched,
                matched,
                matcher.len() as u64,
                children,
                TraceInfo::ExcludeMethods {
                    request: request_method.to_string(),
                    against: Some(methods.clone()),
                },
            ));
        }

        for (method, matcher) in &self.methods {
            let matched = method == request_method;
            found |= matched;

            let children = if matched { matcher.trace(request) } else { Vec::new() };

            traces.push(Trace::new(
                matched,
                matched,
                matcher.len() as u64,
                children,
                TraceInfo::Method {
                    request: request_method.to_string(),
                    against: Some(method.clone()),
                },
            ));
        }

        if !found {
            traces.push(Trace::new(
                true,
                false,
                0,
                Vec::new(),
                TraceInfo::Method {
                    request: request_method.to_string(),
                    against: None,
                },
            ));
        }

        traces
    }

    /// Forwards the cache request to every sub-matcher and returns the last value obtained.
    ///
    /// The any-method matcher receives `limit`; the value each matcher returns is passed as
    /// the limit of the next one (per-method buckets first, then exclusion buckets).
    /// `level` is passed through unchanged. Presumably the threaded value is the remaining
    /// cache budget — confirm against `HeaderMatcher::cache`.
    pub fn cache(&mut self, limit: u64, level: u64) -> u64 {
        let after_any = self.any_method.cache(limit, level);

        self.methods
            .values_mut()
            .chain(self.exclude_methods.values_mut())
            .fold(after_any, |remaining, matcher| matcher.cache(remaining, level))
    }

    /// Returns the route counter maintained by `insert` and `remove`.
    ///
    /// A route registered under several methods is counted once.
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` when the route counter is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

#[cfg(feature = "dot")]
impl<V> DotBuilder for MethodMatcher<V> {
    /// Adds this matcher and its sub-matchers to `graph` and returns the name of its node.
    ///
    /// The node is named `method_matcher_<id>`, and `id` is incremented before the
    /// sub-matchers are visited. An edge is added to each sub-matcher that produced a node:
    /// labelled `any method` for the any-method matcher, with the method name for per-method
    /// buckets, and `not <m1>, <m2>, ...` for exclusion buckets.
    fn graph(&self, id: &mut u32, graph: &mut Graph) -> Option<String> {
        let node_name = format!("method_matcher_{}", id);
        *id += 1;
        graph.add_node(Node::new(&node_name));

        if let Some(key) = self.any_method.graph(id, graph) {
            graph.add_edge(Edge::new(&node_name, &key, "any method"));
        }

        for (method, matcher) in &self.methods {
            if let Some(key) = matcher.graph(id, graph) {
                graph.add_edge(Edge::new(&node_name, &key, method));
            }
        }

        // Exclusion buckets hold routes too; without them the graph would silently omit
        // every route declared with excluded methods.
        for (methods, matcher) in &self.exclude_methods {
            if let Some(key) = matcher.graph(id, graph) {
                let label = format!("not {}", methods.join(", "));
                graph.add_edge(Edge::new(&node_name, &key, &label));
            }
        }

        Some(node_name)
    }
}