Skip to main content

mlua_batteries/policy/
http.rs

1//! HTTP URL access policy.
2
3use super::{PolicyError, Unrestricted};
4
5/// Policy that decides whether a given URL may be accessed.
6///
7/// Every function in the `http` module calls [`HttpPolicy::check_url`]
8/// before making a request.
9///
10/// # Built-in implementations
11///
12/// | Type | Behaviour |
13/// |------|-----------|
14/// | [`Unrestricted`] | No checks (default) |
15/// | [`HttpAllowList`] | Allow only listed host patterns |
16///
17/// # Custom implementations
18///
19/// ```rust,no_run
20/// use mlua_batteries::policy::{HttpPolicy, PolicyError};
21///
22/// struct BlockInternal;
23///
24/// impl HttpPolicy for BlockInternal {
25///     fn check_url(&self, url: &str, method: &str) -> Result<(), PolicyError> {
26///         if url.contains("169.254.") || url.contains("localhost") {
27///             Err(PolicyError::new(format!("{method} denied: internal URL '{url}'")))
28///         } else {
29///             Ok(())
30///         }
31///     }
32/// }
33/// ```
34pub trait HttpPolicy: Send + Sync + 'static {
35    /// Human-readable name for this policy, used in `Debug` output.
36    ///
37    /// The default implementation returns [`std::any::type_name`] of the
38    /// concrete type, which works correctly even through trait objects
39    /// because the vtable dispatches to the concrete implementation.
40    fn policy_name(&self) -> &'static str {
41        std::any::type_name::<Self>()
42    }
43
44    /// Validate `url` for `method` (e.g. "GET", "POST").
45    ///
46    /// Return `Ok(())` to allow, `Err(reason)` to deny.
47    fn check_url(&self, url: &str, method: &str) -> Result<(), PolicyError>;
48}
49
50impl HttpPolicy for Unrestricted {
51    fn check_url(&self, _url: &str, _method: &str) -> Result<(), PolicyError> {
52        Ok(())
53    }
54}
55
/// Permits requests only to hosts that match one of a fixed set of patterns.
///
/// Only the **host** component of the URL is compared. The scheme,
/// userinfo, port, path, query and fragment are stripped before matching.
///
/// A pattern matches a host that is identical to it, or a host that ends
/// with `.` followed by the pattern (a subdomain). So `"example.com"`
/// accepts `https://example.com/path` and `https://api.example.com/path`.
/// It rejects `https://notexample.com/path` and
/// `https://evil.com/?ref=example.com`.
///
/// # Security
///
/// Earlier versions searched the whole URL string for the pattern. A pattern
/// placed in the query string or path could therefore slip past the check.
/// Matching is now restricted to the extracted host.
///
/// ```rust,no_run
/// use mlua_batteries::policy::HttpAllowList;
///
/// let policy = HttpAllowList::new(["api.example.com", "httpbin.org"]);
/// ```
#[derive(Debug)]
pub struct HttpAllowList {
    // Host patterns, kept exactly as they were passed to `new`.
    allowed_hosts: Vec<String>,
}

impl HttpAllowList {
    /// Build an allow-list from any iterable of host patterns.
    ///
    /// Each item is converted into an owned `String` as-is. No trimming or
    /// case normalisation is applied, and input order is preserved.
    pub fn new<I, S>(hosts: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let allowed_hosts: Vec<String> = hosts.into_iter().map(|host| host.into()).collect();
        Self { allowed_hosts }
    }
}
95
96impl HttpPolicy for HttpAllowList {
97    fn check_url(&self, url: &str, method: &str) -> Result<(), PolicyError> {
98        let host = extract_url_host(url).unwrap_or("");
99        if self
100            .allowed_hosts
101            .iter()
102            .any(|pattern| host_matches(host, pattern))
103        {
104            Ok(())
105        } else {
106            Err(PolicyError::new(format!(
107                "{method} denied: URL '{url}' does not match any allowed host"
108            )))
109        }
110    }
111}
112
/// Check whether `host` is allowed by `pattern`, by exact match or as a subdomain.
///
/// `"example.com"` matches `"example.com"` and `"sub.example.com"`,
/// but **not** `"notexample.com"`.
///
/// - Comparison is ASCII case-insensitive, because DNS host names are
///   case-insensitive: `"API.Example.COM"` matches `"example.com"`.
/// - An empty pattern never matches. Otherwise `""` would match the empty
///   host and every host that ends in `.`.
///
/// Works directly on bytes and performs no allocation.
fn host_matches(host: &str, pattern: &str) -> bool {
    if pattern.is_empty() || host.len() < pattern.len() {
        return false;
    }

    let (host, pattern) = (host.as_bytes(), pattern.as_bytes());
    // `split` is where the candidate suffix starts. Either it is the whole
    // host (exact match), or the byte just before it must be a label
    // separator, so that "notexample.com" does not match "example.com".
    let split = host.len() - pattern.len();
    let on_label_boundary = split == 0 || host[split - 1] == b'.';

    on_label_boundary && host[split..].eq_ignore_ascii_case(pattern)
}
125
126/// Extract the host portion from a URL string.
127///
128/// Handles the standard URL format: `scheme://[userinfo@]host[:port]/path...`
129///
130/// - Strips scheme (`http://`, `https://`)
131/// - Strips userinfo (`user:pass@`)
132/// - Strips port (`:8080`)
133/// - Strips path, query, and fragment
134/// - Handles IPv6 addresses (`[::1]`)
135///
136/// Returns `None` if the URL has no `://` separator.
137pub(super) fn extract_url_host(url: &str) -> Option<&str> {
138    let after_scheme = url.find("://").map(|i| i + 3)?;
139    let rest = &url[after_scheme..];
140
141    // Authority ends at the first `/`, `?`, or `#`
142    let authority_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
143    let authority = &rest[..authority_end];
144
145    // Strip userinfo (everything before the last `@`)
146    let host_start = authority.rfind('@').map(|i| i + 1).unwrap_or(0);
147    let host_part = &authority[host_start..];
148
149    if host_part.starts_with('[') {
150        // IPv6: [::1]:8080 → ::1
151        host_part.find(']').map(|i| &host_part[1..i])
152    } else {
153        // Strip port: example.com:8080 → example.com
154        Some(host_part.split(':').next().unwrap_or(host_part))
155    }
156}