foxy/router/
predicates.rs

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
//! Predicate implementations for router matching.
//!
//! This module provides the path, method, header, and query-parameter
//! predicates used for route matching.
8
9use std::collections::HashMap;
10use async_trait::async_trait;
11use regex::Regex;
12use serde::{Serialize, Deserialize};
13
14use crate::core::{ProxyRequest, HttpMethod, ProxyError};
15use super::Predicate;
16
17/// Configuration for a path predicate.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct PathPredicateConfig {
20    /// The path pattern to match
21    pub pattern: String,
22}
23
24/// A predicate that matches on request path.
25#[derive(Debug)]
26pub struct PathPredicate {
27    /// The configuration for this predicate
28    config: PathPredicateConfig,
29    /// Compiled regex for path matching
30    regex: Regex,
31}
32
33impl PathPredicate {
34    /// Create a new path predicate with the given configuration.
35    pub fn new(config: PathPredicateConfig) -> Result<Self, ProxyError> {
36        // Convert the path pattern to a regex
37        let regex_pattern = Self::pattern_to_regex(&config.pattern);
38
39        // Compile the regex
40        let regex = Regex::new(&regex_pattern)
41            .map_err(|e| ProxyError::RoutingError(format!("Invalid path predicate regex pattern '{}': {}", config.pattern, e)))?;
42
43        Ok(Self { config, regex })
44    }
45
46    /// Convert a path pattern to a regex pattern.
47    fn pattern_to_regex(pattern: &str) -> String {
48        let mut regex_pattern = "^".to_string();
49
50        let mut chars = pattern.chars().peekable();
51        while let Some(c) = chars.next() {
52            match c {
53                // Handle path parameters like :id
54                ':' => {
55                    let mut param_name = String::new();
56                    while let Some(&next_char) = chars.peek() {
57                        if next_char.is_alphanumeric() || next_char == '_' {
58                            param_name.push(chars.next().unwrap());
59                        } else {
60                            break;
61                        }
62                    }
63
64                    // Add a capturing group for the parameter
65                    regex_pattern.push_str(&format!("([^/]+)"));
66                },
67                // Handle wildcards like *
68                '*' => {
69                    regex_pattern.push_str("(.*)");
70                },
71                // Escape special regex characters
72                '.' | '^' | '$' | '|' | '+' | '?' | '(' | ')' | '[' | ']' | '{' | '}' | '\\' => {
73                    regex_pattern.push('\\');
74                    regex_pattern.push(c);
75                },
76                // Regular characters
77                _ => {
78                    regex_pattern.push(c);
79                }
80            }
81        }
82
83        regex_pattern.push('$');
84        regex_pattern
85    }
86}
87
88#[async_trait]
89impl Predicate for PathPredicate {
90    async fn matches(&self, request: &ProxyRequest) -> bool {
91        self.regex.is_match(&request.path)
92    }
93
94    fn predicate_type(&self) -> &str {
95        "path"
96    }
97}
98
99/// Configuration for a method predicate.
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct MethodPredicateConfig {
102    /// The HTTP methods to match
103    pub methods: Vec<HttpMethod>,
104}
105
106/// A predicate that matches on HTTP method.
107#[derive(Debug)]
108pub struct MethodPredicate {
109    /// The configuration for this predicate
110    config: MethodPredicateConfig,
111}
112
113impl MethodPredicate {
114    /// Create a new method predicate with the given configuration.
115    pub fn new(config: MethodPredicateConfig) -> Self {
116        Self { config }
117    }
118}
119
120#[async_trait]
121impl Predicate for MethodPredicate {
122    async fn matches(&self, request: &ProxyRequest) -> bool {
123        self.config.methods.contains(&request.method)
124    }
125
126    fn predicate_type(&self) -> &str {
127        "method"
128    }
129}
130
131/// Configuration for a header predicate.
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct HeaderPredicateConfig {
134    /// The headers to match (name and value)
135    pub headers: HashMap<String, String>,
136    /// Whether to require exact match for header values
137    #[serde(default)]
138    pub exact_match: bool,
139}
140
141/// A predicate that matches on request headers.
142#[derive(Debug)]
143pub struct HeaderPredicate {
144    /// The configuration for this predicate
145    config: HeaderPredicateConfig,
146}
147
148impl HeaderPredicate {
149    /// Create a new header predicate with the given configuration.
150    pub fn new(config: HeaderPredicateConfig) -> Self {
151        Self { config }
152    }
153}
154
155#[async_trait]
156impl Predicate for HeaderPredicate {
157    async fn matches(&self, request: &ProxyRequest) -> bool {
158        for (name, expected_value) in &self.config.headers {
159            // Try to get the header
160            if let Some(header_value) = request.headers.get(name) {
161                // Convert to string for comparison
162                if let Ok(actual_value) = header_value.to_str() {
163                    if self.config.exact_match {
164                        // Exact match
165                        if actual_value != expected_value {
166                            return false;
167                        }
168                    } else {
169                        // Contains match
170                        if !actual_value.contains(expected_value) {
171                            return false;
172                        }
173                    }
174                } else {
175                    // Not a valid UTF-8 string
176                    return false;
177                }
178            } else {
179                // Header not found
180                return false;
181            }
182        }
183
184        // All headers matched
185        true
186    }
187
188    fn predicate_type(&self) -> &str {
189        "header"
190    }
191}
192
193/// Configuration for a query parameter predicate.
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct QueryPredicateConfig {
196    /// The query parameters to match (name and value)
197    pub params: HashMap<String, String>,
198    /// Whether to require exact match for parameter values
199    #[serde(default)]
200    pub exact_match: bool,
201}
202
203/// A predicate that matches on query parameters.
204#[derive(Debug)]
205pub struct QueryPredicate {
206    /// The configuration for this predicate
207    config: QueryPredicateConfig,
208}
209
210impl QueryPredicate {
211    /// Create a new query predicate with the given configuration.
212    pub fn new(config: QueryPredicateConfig) -> Self {
213        Self { config }
214    }
215
216    /// Parse query parameters from a query string.
217    fn parse_query_params(query: &str) -> HashMap<String, String> {
218        let mut params = HashMap::new();
219
220        for pair in query.split('&') {
221            let mut iter = pair.split('=');
222            if let (Some(key), Some(value)) = (iter.next(), iter.next()) {
223                params.insert(key.to_string(), value.to_string());
224            }
225        }
226
227        params
228    }
229}
230
231#[async_trait]
232impl Predicate for QueryPredicate {
233    async fn matches(&self, request: &ProxyRequest) -> bool {
234        // If no query parameters to match, then it's a match
235        if self.config.params.is_empty() {
236            return true;
237        }
238
239        // If the request has no query string, it's not a match
240        if let Some(query) = &request.query {
241            let params = Self::parse_query_params(query);
242
243            for (name, expected_value) in &self.config.params {
244                // Try to get the parameter
245                if let Some(actual_value) = params.get(name) {
246                    if self.config.exact_match {
247                        // Exact match
248                        if actual_value != expected_value {
249                            return false;
250                        }
251                    } else {
252                        // Contains match
253                        if !actual_value.contains(expected_value) {
254                            return false;
255                        }
256                    }
257                } else {
258                    // Parameter not found
259                    return false;
260                }
261            }
262
263            // All parameters matched
264            true
265        } else {
266            false
267        }
268    }
269
270    fn predicate_type(&self) -> &str {
271        "query"
272    }
273}