mlua_batteries/policy/http.rs
1//! HTTP URL access policy.
2
3use super::{PolicyError, Unrestricted};
4
5/// Policy that decides whether a given URL may be accessed.
6///
7/// Every function in the `http` module calls [`HttpPolicy::check_url`]
8/// before making a request.
9///
10/// # Built-in implementations
11///
12/// | Type | Behaviour |
13/// |------|-----------|
14/// | [`Unrestricted`] | No checks (default) |
15/// | [`HttpAllowList`] | Allow only listed host patterns |
16///
17/// # Custom implementations
18///
19/// ```rust,no_run
20/// use mlua_batteries::policy::{HttpPolicy, PolicyError};
21///
22/// struct BlockInternal;
23///
24/// impl HttpPolicy for BlockInternal {
25/// fn check_url(&self, url: &str, method: &str) -> Result<(), PolicyError> {
26/// if url.contains("169.254.") || url.contains("localhost") {
27/// Err(PolicyError::new(format!("{method} denied: internal URL '{url}'")))
28/// } else {
29/// Ok(())
30/// }
31/// }
32/// }
33/// ```
34pub trait HttpPolicy: Send + Sync + 'static {
35 /// Human-readable name for this policy, used in `Debug` output.
36 ///
37 /// The default implementation returns [`std::any::type_name`] of the
38 /// concrete type, which works correctly even through trait objects
39 /// because the vtable dispatches to the concrete implementation.
40 fn policy_name(&self) -> &'static str {
41 std::any::type_name::<Self>()
42 }
43
44 /// Validate `url` for `method` (e.g. "GET", "POST").
45 ///
46 /// Return `Ok(())` to allow, `Err(reason)` to deny.
47 fn check_url(&self, url: &str, method: &str) -> Result<(), PolicyError>;
48}
49
50impl HttpPolicy for Unrestricted {
51 fn check_url(&self, _url: &str, _method: &str) -> Result<(), PolicyError> {
52 Ok(())
53 }
54}
55
/// Permits requests only to hosts that match one of a fixed set of patterns.
///
/// Only the **host** of the URL takes part in matching. The scheme,
/// userinfo, port, path, query and fragment are stripped before any
/// comparison is made.
///
/// A pattern matches a host that is identical to it, or that ends with it
/// immediately after a `.` label separator. So `"example.com"` accepts
/// `https://example.com/path` and `https://api.example.com/path`, but
/// rejects `https://notexample.com/path` and
/// `https://evil.com/?ref=example.com`.
///
/// # Security
///
/// Earlier versions compared patterns against the whole URL string. Putting
/// an allowed name in the path or query string could then satisfy the check.
/// Matching on the extracted host closes that hole.
///
/// ```rust,no_run
/// use mlua_batteries::policy::HttpAllowList;
///
/// let policy = HttpAllowList::new(["api.example.com", "httpbin.org"]);
/// ```
#[derive(Debug)]
pub struct HttpAllowList {
    allowed_hosts: Vec<String>,
}

impl HttpAllowList {
    /// Build an allow-list from an iterable of host patterns.
    ///
    /// Patterns are stored verbatim, in the order given.
    pub fn new<I, S>(hosts: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut patterns: Vec<String> = Vec::new();
        for host in hosts {
            patterns.push(host.into());
        }
        HttpAllowList {
            allowed_hosts: patterns,
        }
    }
}
95
96impl HttpPolicy for HttpAllowList {
97 fn check_url(&self, url: &str, method: &str) -> Result<(), PolicyError> {
98 let host = extract_url_host(url).unwrap_or("");
99 if self
100 .allowed_hosts
101 .iter()
102 .any(|pattern| host_matches(host, pattern))
103 {
104 Ok(())
105 } else {
106 Err(PolicyError::new(format!(
107 "{method} denied: URL '{url}' does not match any allowed host"
108 )))
109 }
110 }
111}
112
/// Check whether `host` equals `pattern` or is a subdomain of it.
///
/// `"example.com"` matches `"example.com"` and `"sub.example.com"` but
/// **not** `"notexample.com"`. A suffix only counts when it is preceded by a
/// `.` label separator.
///
/// Comparison is ASCII case-insensitive, because host names are
/// case-insensitive (RFC 3986 §3.2.2). So `"API.Example.COM"` matches
/// `"example.com"`.
///
/// An empty `pattern` never matches. Without this guard, an empty pattern
/// would match an empty host. Through the `.`-boundary rule it would also
/// match every host ending in a trailing dot (e.g. `"evil.com."`).
///
/// Performs no allocation.
fn host_matches(host: &str, pattern: &str) -> bool {
    if pattern.is_empty() || host.len() < pattern.len() {
        return false;
    }

    // Split the host into "everything before the candidate suffix" and the
    // suffix itself, which is exactly `pattern.len()` bytes long.
    let (prefix, suffix) = host.as_bytes().split_at(host.len() - pattern.len());

    // Exact match (no prefix) or a subdomain (prefix ends with a `.`).
    suffix.eq_ignore_ascii_case(pattern.as_bytes())
        && (prefix.is_empty() || prefix.last() == Some(&b'.'))
}
125
126/// Extract the host portion from a URL string.
127///
128/// Handles the standard URL format: `scheme://[userinfo@]host[:port]/path...`
129///
130/// - Strips scheme (`http://`, `https://`)
131/// - Strips userinfo (`user:pass@`)
132/// - Strips port (`:8080`)
133/// - Strips path, query, and fragment
134/// - Handles IPv6 addresses (`[::1]`)
135///
136/// Returns `None` if the URL has no `://` separator.
137pub(super) fn extract_url_host(url: &str) -> Option<&str> {
138 let after_scheme = url.find("://").map(|i| i + 3)?;
139 let rest = &url[after_scheme..];
140
141 // Authority ends at the first `/`, `?`, or `#`
142 let authority_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
143 let authority = &rest[..authority_end];
144
145 // Strip userinfo (everything before the last `@`)
146 let host_start = authority.rfind('@').map(|i| i + 1).unwrap_or(0);
147 let host_part = &authority[host_start..];
148
149 if host_part.starts_with('[') {
150 // IPv6: [::1]:8080 → ::1
151 host_part.find(']').map(|i| &host_part[1..i])
152 } else {
153 // Strip port: example.com:8080 → example.com
154 Some(host_part.split(':').next().unwrap_or(host_part))
155 }
156}