Skip to main content

structured_proxy/shield/
matcher.rs

//! Compile Shield config patterns into runtime path matchers.
2
3use globset::GlobMatcher;
4
5use super::rate::Rate;
6use crate::config::{EndpointClassConfig, IdentifierEndpointConfig};
7use std::time::Duration;
8
9/// A compiled endpoint-class rule: requests whose path matches `matcher` are
10/// limited per client at `rate`, grouped under `class`.
11pub struct EndpointClass {
12    pub matcher: GlobMatcher,
13    pub class: String,
14    pub rate: Rate,
15}
16
17/// A compiled per-identifier rule: requests to a matching path are limited by a
18/// value read from the request body field `body_field`.
19pub struct IdentifierEndpoint {
20    pub matcher: GlobMatcher,
21    pub body_field: String,
22    pub rate: Rate,
23}
24
25/// Build a glob matcher where `*` stays within a path segment and `**` spans
26/// segments, matching the `google.api.http` / maintenance path convention.
27fn path_glob(pattern: &str) -> Result<GlobMatcher, String> {
28    globset::GlobBuilder::new(pattern)
29        .literal_separator(true)
30        .build()
31        .map(|g| g.compile_matcher())
32        .map_err(|e| format!("invalid glob pattern {pattern:?}: {e}"))
33}
34
35/// Compile endpoint-class rules, parsing each rate against `default_window`.
36pub fn compile_endpoint_classes(
37    configs: &[EndpointClassConfig],
38    default_window: Duration,
39) -> Result<Vec<EndpointClass>, String> {
40    configs
41        .iter()
42        .map(|c| {
43            Ok(EndpointClass {
44                matcher: path_glob(&c.pattern)?,
45                class: c.class.clone(),
46                rate: Rate::parse(&c.rate, default_window)?,
47            })
48        })
49        .collect()
50}
51
52/// Compile per-identifier rules, parsing each rate against `default_window`.
53pub fn compile_identifier_endpoints(
54    configs: &[IdentifierEndpointConfig],
55    default_window: Duration,
56) -> Result<Vec<IdentifierEndpoint>, String> {
57    configs
58        .iter()
59        .map(|c| {
60            Ok(IdentifierEndpoint {
61                matcher: path_glob(&c.path)?,
62                body_field: c.body_field.clone(),
63                rate: Rate::parse(&c.rate, default_window)?,
64            })
65        })
66        .collect()
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    fn ec(pattern: &str, class: &str, rate: &str) -> EndpointClassConfig {
        EndpointClassConfig {
            pattern: pattern.to_string(),
            class: class.to_string(),
            rate: rate.to_string(),
        }
    }

    #[test]
    fn endpoint_class_glob_respects_segments() {
        let classes = compile_endpoint_classes(
            &[ec("/api/v1/heavy-*", "heavy", "10/min")],
            Duration::from_secs(60),
        )
        .unwrap();
        let m = &classes[0].matcher;
        assert!(m.is_match("/api/v1/heavy-export"));
        // `*` does not cross a path separator.
        assert!(!m.is_match("/api/v1/heavy-export/sub"));
        assert!(!m.is_match("/api/v1/light"));
    }

    #[test]
    fn double_star_spans_segments() {
        let classes = compile_endpoint_classes(
            &[ec("/v1/auth/**", "auth", "20/min")],
            Duration::from_secs(60),
        )
        .unwrap();
        let m = &classes[0].matcher;
        assert!(m.is_match("/v1/auth/login"));
        assert!(m.is_match("/v1/auth/opaque/start"));
        // The literal prefix must end at a segment boundary.
        assert!(!m.is_match("/v1/authz/login"));
    }

    #[test]
    fn compiled_classes_keep_names_and_order() {
        let classes = compile_endpoint_classes(
            &[ec("/a", "alpha", "10/min"), ec("/b/*", "beta", "20/min")],
            Duration::from_secs(60),
        )
        .unwrap();
        assert_eq!(classes.len(), 2);
        assert_eq!(classes[0].class, "alpha");
        assert_eq!(classes[1].class, "beta");
        assert!(classes[1].matcher.is_match("/b/x"));
    }

    #[test]
    fn empty_config_compiles_to_no_rules() {
        let window = Duration::from_secs(60);
        assert!(compile_endpoint_classes(&[], window).unwrap().is_empty());
        assert!(compile_identifier_endpoints(&[], window).unwrap().is_empty());
    }

    #[test]
    fn invalid_rate_fails_compilation() {
        let err = compile_endpoint_classes(&[ec("/x", "c", "nonsense")], Duration::from_secs(60));
        assert!(err.is_err());
    }

    #[test]
    fn invalid_glob_fails_compilation() {
        // `.err()` rather than `unwrap_err()`: `EndpointClass` is not `Debug`.
        let err = compile_endpoint_classes(&[ec("/api/[", "c", "10/min")], Duration::from_secs(60))
            .err()
            .expect("an unclosed character class must be rejected");
        assert!(err.contains("/api/["), "error should name the pattern: {err}");
    }
}
112}