Skip to main content

dhttp_access/expr/
atomics.rs

1#![allow(deprecated)]
2
3use std::{fmt::Display, net::IpAddr, str::FromStr};
4
5use derive_more::{AsRef, Display, From, Into};
6use snafu::OptionExt;
7
8use crate::{
9    action::RequestAction,
10    expr::{
11        eval::{EvalRuleError, Evaluable},
12        parse::{self, InvalidPatternExpr},
13    },
14    pattern::{ClientNamePattern, NormalPattern, Pattern},
15};
16
/// The network origin of a request, classified from the peer's IP address.
#[deprecated = "Redesign in the future"]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
    /// Local network (LAN).
    Lan,
    /// Wide-area network (WAN): any address that is not classified as LAN.
    Wan,
}

impl Source {
    /// Returns the lowercase keyword for this source: `"lan"` or `"wan"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Source::Lan => "lan",
            Source::Wan => "wan",
        }
    }

    /// Returns whether the IP address `source` belongs to this source type.
    ///
    /// An address counts as LAN when it is
    /// - IPv4 loopback, private or link-local (`Ipv4Addr::is_loopback`, `is_private`,
    ///   `is_link_local`), or
    /// - IPv6 loopback, unique local or unicast link-local (`Ipv6Addr::is_loopback`,
    ///   `is_unique_local`, `is_unicast_link_local`).
    ///
    /// Every other address counts as WAN.
    ///
    /// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are first converted to the embedded IPv4
    /// address with [`IpAddr::to_canonical`]. A dual-stack socket bound to `[::]` may report IPv4
    /// peers in this form. Without the conversion, a client such as `::ffff:192.168.1.1` would
    /// be classified as WAN.
    ///
    /// ```
    /// use dhttp_access::expr::atomics::Source;
    ///
    /// let lan = Source::Lan;
    /// let wan = Source::Wan;
    ///
    /// // LAN addresses
    /// assert!(lan.is_match("192.168.1.1".parse().unwrap()));
    /// assert!(lan.is_match("10.0.0.1".parse().unwrap()));
    /// assert!(lan.is_match("::1".parse().unwrap()));
    ///
    /// // WAN addresses
    /// assert!(wan.is_match("8.8.8.8".parse().unwrap()));
    /// assert!(wan.is_match("2001:4860:4860::8888".parse().unwrap()));
    ///
    /// // IPv4-mapped IPv6 addresses follow the embedded IPv4 address
    /// assert!(lan.is_match("::ffff:192.168.1.1".parse().unwrap()));
    /// assert!(wan.is_match("::ffff:8.8.8.8".parse().unwrap()));
    ///
    /// // Cross-checks
    /// assert!(!lan.is_match("8.8.8.8".parse().unwrap()));
    /// assert!(!wan.is_match("192.168.1.1".parse().unwrap()));
    /// ```
    pub fn is_match(&self, source: IpAddr) -> bool {
        // Normalize `::ffff:a.b.c.d` to `a.b.c.d` so the IPv4 rules apply to mapped peers.
        let source_is_lan = match source.to_canonical() {
            IpAddr::V4(ip) => ip.is_loopback() || ip.is_private() || ip.is_link_local(),
            IpAddr::V6(ip) => {
                ip.is_loopback() || ip.is_unique_local() || ip.is_unicast_link_local()
            }
        };
        match self {
            Source::Lan => source_is_lan,
            Source::Wan => !source_is_lan,
        }
    }
}
66
67impl Evaluable<IpAddr> for Source {
68    type Value = bool;
69
70    fn eval(&self, argument: &IpAddr) -> Self::Value {
71        self.is_match(*argument)
72    }
73}
74
75impl Display for Source {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        self.as_str().fmt(f)
78    }
79}
80
81#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
82pub struct AnyClient;
83
84impl Evaluable<Option<&str>> for AnyClient {
85    type Value = bool;
86
87    fn eval(&self, _: &Option<&str>) -> Self::Value {
88        true
89    }
90}
91
92impl Display for AnyClient {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        "*?".fmt(f)
95    }
96}
97
98#[derive(snafu::Snafu, Debug, Clone, Copy)]
99pub enum EvalError {
100    #[snafu(display("client name is not provided, cannot match client name pattern"))]
101    MissingClientName,
102}
103
104impl EvalRuleError<RequestAction> for EvalError {
105    fn fallback(&self, matched_action: RequestAction) -> Option<RequestAction> {
106        _ = matched_action;
107        Some(RequestAction::Deny)
108    }
109}
110
111#[derive(Debug, Display, Clone, From, Into, AsRef, PartialEq, Eq)]
112pub struct ClientName(ClientNamePattern);
113
114impl Evaluable<Option<&str>> for ClientName {
115    type Value = Result<bool, EvalError>;
116
117    fn eval(&self, argument: &Option<&str>) -> Self::Value {
118        argument
119            .map(|client_name| self.0.eval(&client_name))
120            .context(MissingClientNameSnafu)
121    }
122}
123
124#[derive(Debug, Clone, From, Into, AsRef, PartialEq, Eq)]
125pub struct Method {
126    pattern: NormalPattern,
127}
128
129#[cfg(feature = "http")]
130impl Evaluable<&http::Method> for Method {
131    type Value = bool;
132
133    fn eval(&self, method: &&http::Method) -> Self::Value {
134        self.pattern.eval(&method.as_str())
135    }
136}
137
138/// 键值对模式,用于匹配HTTP头或查询参数
139///
140/// TODO: 支持key的匹配项参与匹配?
141#[derive(Debug, Clone, PartialEq, Eq)]
142pub struct KVPattern {
143    pub key: NormalPattern,
144    pub value: NormalPattern,
145}
146
147impl Evaluable<(&str, &str)> for KVPattern {
148    type Value = bool;
149
150    fn eval(&self, (key, value): &(&str, &str)) -> Self::Value {
151        self.key.eval(key) && self.value.eval(value)
152    }
153}
154
155impl Display for KVPattern {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        write!(f, "{}:{}", self.key, self.value)
158    }
159}
160
161#[derive(Debug, Clone, From, Into, AsRef, PartialEq, Eq)]
162pub struct Header {
163    pattern: KVPattern,
164}
165
166impl Evaluable<(&str, &str)> for Header {
167    type Value = bool;
168
169    fn eval(&self, pair: &(&str, &str)) -> Self::Value {
170        self.pattern.eval(pair)
171    }
172}
173
174#[cfg(feature = "http")]
175impl Evaluable<(&http::HeaderName, &http::HeaderValue)> for Header {
176    type Value = bool;
177
178    fn eval(&self, (key, value): &(&http::HeaderName, &http::HeaderValue)) -> Self::Value {
179        let Ok(value) = value.to_str() else {
180            // TODO: support binary header value match
181            return false;
182        };
183        self.eval(&(key.as_str(), value))
184    }
185}
186
187#[derive(Debug, Clone, From, Into, AsRef, PartialEq, Eq)]
188pub struct Query {
189    pattern: KVPattern,
190}
191
192impl Evaluable<(&str, &str)> for Query {
193    type Value = bool;
194
195    fn eval(&self, pair: &(&str, &str)) -> Self::Value {
196        self.pattern.eval(pair)
197    }
198}
199
200fn escape_pattern<Kind>(pat: &Pattern<Kind>) -> String {
201    pat.as_str().replace('\\', "\\\\").replace('"', "\\\"")
202}
203
204fn to_quoted_escaped_pattern<Kind>(pat: &Pattern<Kind>) -> String {
205    format!("\"{}\"", escape_pattern(pat))
206}
207
208fn to_quoted_escaped_kv_pattern(pat: &KVPattern) -> String {
209    let KVPattern { key, value } = pat;
210    let (key, value) = (escape_pattern(key), escape_pattern(value));
211    format!("\"{}\":\"{}\"", key, value)
212}
213
214#[derive(Debug, Clone, PartialEq, Eq)]
215pub enum AtomicLocationRuleExpr {
216    Any(AnyClient), // "*?"
217    ClientName(ClientName),
218    Method(Method),
219    Header(Header),
220    Query(Query),
221}
222
223impl Display for AtomicLocationRuleExpr {
224    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225        match self {
226            Self::Any(any) => any.fmt(f),
227            Self::ClientName(pattern) => {
228                write!(f, "{}", to_quoted_escaped_pattern(pattern.as_ref()))
229            }
230            Self::Method(Method { pattern }) => {
231                write!(f, "With Method {}", to_quoted_escaped_pattern(pattern))
232            }
233            Self::Header(Header { pattern }) => {
234                write!(f, "With Header {}", to_quoted_escaped_kv_pattern(pattern))
235            }
236            Self::Query(Query { pattern }) => {
237                write!(f, "With Query {}", to_quoted_escaped_kv_pattern(pattern))
238            }
239        }
240    }
241}
242
243mod parse_atomic {
244
245    use peg::{error::ParseError, str::LineCol};
246    use snafu::ResultExt;
247
248    use super::*;
249
250    #[derive(snafu::Snafu, Debug)]
251    pub enum ParseAtomicRuleExprError {
252        #[snafu(display("failed to parse rule expr"))]
253        Pattern { source: InvalidPatternExpr },
254        #[snafu(display("failed to parse rule expr `{input}`"))]
255        Incomplete {
256            input: String,
257            source: ParseError<LineCol>,
258        },
259    }
260
261    impl FromStr for AtomicLocationRuleExpr {
262        type Err = ParseAtomicRuleExprError;
263
264        fn from_str(infix: &str) -> Result<Self, Self::Err> {
265            let tokens =
266                parse::TokenStream::new(infix).context(IncompleteSnafu { input: infix })?;
267            parse::atomic_location_rule_expr(&tokens)
268                .context(IncompleteSnafu { input: infix })?
269                .context(PatternSnafu)
270        }
271    }
272}
273
/// Borrowed view of the parts of an HTTP request that atomic rules inspect.
#[cfg(feature = "http")]
pub struct HttpRequest<'a> {
    client_name: Option<&'a str>,
    method: &'a http::Method,
    headers: &'a http::HeaderMap<http::HeaderValue>,
    /// Raw `(key, value)` pairs from the URI query, in order, without percent-decoding.
    queries: Vec<(&'a str, &'a str)>,
}

/// Splits a raw URI query string into `(key, value)` pairs.
///
/// Pairs are separated by `&`. Each pair is split at its first `=`, and a pair without `=`
/// gets an empty value. Empty segments are skipped, as in the WHATWG
/// `application/x-www-form-urlencoded` parser. These come from `a=1&&b`, a leading or trailing
/// `&`, or an empty query. Keys and values are returned as-is: no percent-decoding and no
/// `+`-to-space conversion.
#[cfg_attr(not(feature = "http"), allow(dead_code))]
fn split_query(query: &str) -> Vec<(&str, &str)> {
    query
        .split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| segment.split_once('=').unwrap_or((segment, "")))
        .collect()
}

#[cfg(feature = "http")]
impl<'a> HttpRequest<'a> {
    /// Stores `client_name` and borrows the method, headers and query pairs from `request`.
    ///
    /// NOTE(review): query pairs are kept in their encoded form. A rule on `a=1` therefore
    /// presumably does not match `a=%31`. Decoding would need owned storage here and in
    /// `Query` matching.
    pub fn new<T>(client_name: Option<&'a str>, request: &'a http::Request<T>) -> Self {
        Self {
            client_name,
            method: request.method(),
            headers: request.headers(),
            queries: request.uri().query().map_or_else(Vec::new, split_query),
        }
    }
}
302
#[cfg(feature = "http")]
impl Evaluable<HttpRequest<'_>> for AtomicLocationRuleExpr {
    type Value = Result<bool, EvalError>;

    /// Evaluates this atomic rule against `request`.
    ///
    /// Header and query rules match when at least one header / query pair matches.
    ///
    /// # Errors
    /// Propagates [`EvalError::MissingClientName`] from client-name rules when the request
    /// has no client name.
    fn eval(&self, request: &HttpRequest) -> Self::Value {
        match self {
            Self::Any(_) => Ok(true),
            Self::ClientName(client) => client.eval(&request.client_name),
            Self::Method(method) => Ok(method.eval(&request.method)),
            Self::Header(header) => {
                let any_header = request
                    .headers
                    .iter()
                    .any(|(name, value)| header.eval(&(name, value)));
                Ok(any_header)
            }
            Self::Query(query) => {
                let any_query = request.queries.iter().any(|pair| query.eval(pair));
                Ok(any_query)
            }
        }
    }
}