1use hyper::Uri;
2use hyper::http::Method;
3use rand::RngCore;
4use secrecy::{ExposeSecret, SecretString};
5use std::fmt;
6use tracing::{debug, trace};
7use wasmtime_wasi_http::p2::bindings::http::types::ErrorCode;
8
/// Part of an outgoing HTTP request in which secret placeholders are substituted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReplacementLocation {
    /// Header values.
    Headers,
    /// The request body (only for text-like content types).
    Body,
    /// The request URI.
    Params,
}
16
17#[derive(Clone, Debug)]
20pub struct PlaceholderSecret {
21 pub placeholder: String,
23 pub real_value: SecretString,
25 pub replace_in: hashbrown::HashSet<ReplacementLocation>,
27}
28
29#[derive(Clone, Debug, PartialEq, Eq, Hash, derive_more::Display)]
31pub enum SchemePattern {
32 #[display("http")]
33 Http,
34 #[display("https")]
35 Https,
36 #[display("*")]
38 Any,
39}
40impl SchemePattern {
41 #[must_use]
43 pub fn allows_unencrypted(&self) -> bool {
44 matches!(self, SchemePattern::Http | SchemePattern::Any)
45 }
46}
47
/// Port accepted by a [`HostPattern`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PortPattern {
    /// Exactly this port.
    Specific(u16),
    /// Any port (written `:*`).
    Any,
    /// The scheme's default port: 80 for `http`, 443 for `https`.
    Default,
}
58
59#[derive(Clone, Debug, PartialEq, Eq, Hash)]
61pub enum MethodsPattern {
62 AllMethods,
64 Specific(Vec<Method>),
66}
67
68#[derive(Clone, Debug, PartialEq, Eq, Hash)]
76pub struct HostPattern {
77 pub scheme: SchemePattern,
78 pub host_pattern: String,
79 pub port: PortPattern,
80 pub methods: MethodsPattern,
82}
83
84#[derive(Debug, Clone, thiserror::Error)]
85pub enum HostPatternError {
86 #[error("wildcard `*` must be the first or last character in host pattern: `{host}`")]
87 Wildcard { host: String },
88 #[error("host pattern must not contain a path: `{input}`")]
89 ContainsPath { input: String },
90}
91
92impl HostPattern {
93 pub fn parse_with_methods(
108 input: &str,
109 methods: MethodsPattern,
110 ) -> Result<Self, HostPatternError> {
111 let mut host_pattern = Self::parse(input)?;
112 host_pattern.methods = methods;
113 Ok(host_pattern)
114 }
115
116 fn parse(input: &str) -> Result<Self, HostPatternError> {
117 let (scheme, rest) = if let Some(rest) = input.strip_prefix("*://") {
119 (SchemePattern::Any, rest)
120 } else if let Some(rest) = input.strip_prefix("https://") {
121 (SchemePattern::Https, rest)
122 } else if let Some(rest) = input.strip_prefix("http://") {
123 (SchemePattern::Http, rest)
124 } else {
125 (SchemePattern::Https, input)
126 };
127
128 if rest.contains('/') {
130 return Err(HostPatternError::ContainsPath {
131 input: input.to_string(),
132 });
133 }
134
135 let (host_port_str, any_port) = if let Some(stripped) = rest.strip_suffix(":*") {
137 (stripped, true)
138 } else {
139 (rest, false)
140 };
141
142 let (host, port) = if any_port {
143 (host_port_str.to_string(), PortPattern::Any)
144 } else if let Some((h, p)) = host_port_str.rsplit_once(':') {
145 if let Ok(port_num) = p.parse::<u16>() {
146 (h.to_string(), PortPattern::Specific(port_num))
147 } else {
148 (host_port_str.to_string(), PortPattern::Default)
150 }
151 } else {
152 (host_port_str.to_string(), PortPattern::Default)
153 };
154
155 if host.contains('*') && !host.starts_with('*') && !host.ends_with('*') {
157 return Err(HostPatternError::Wildcard { host });
158 }
159
160 Ok(HostPattern {
161 scheme,
162 host_pattern: host,
163 port,
164 methods: MethodsPattern::AllMethods,
165 })
166 }
167
168 #[must_use]
170 fn matches(&self, scheme: &str, host: &str, port: u16, method: &Method) -> bool {
171 let scheme_matches = match &self.scheme {
173 SchemePattern::Http => scheme == "http",
174 SchemePattern::Https => scheme == "https",
175 SchemePattern::Any => scheme == "http" || scheme == "https",
176 };
177 if !scheme_matches {
178 return false;
179 }
180
181 let port_matches = match &self.port {
183 PortPattern::Specific(p) => port == *p,
184 PortPattern::Any => true,
185 PortPattern::Default => {
186 match scheme {
188 "http" => port == 80,
189 "https" => port == 443,
190 _ => false,
191 }
192 }
193 };
194 if !port_matches {
195 return false;
196 }
197
198 if !match_wildcard(&self.host_pattern, host) {
200 return false;
201 }
202
203 match &self.methods {
205 MethodsPattern::AllMethods => true,
206 MethodsPattern::Specific(methods) => methods.contains(method),
207 }
208 }
209}
210
211impl fmt::Display for HostPattern {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 write!(f, "{}://{}", self.scheme, self.host_pattern)?;
215
216 match &self.port {
218 PortPattern::Specific(p) => write!(f, ":{p}")?,
219 PortPattern::Any => write!(f, ":*")?,
220 PortPattern::Default => {} }
222
223 match &self.methods {
225 MethodsPattern::AllMethods => {} MethodsPattern::Specific(methods) if methods.is_empty() => {
227 write!(f, " [NONE]")?;
228 }
229 MethodsPattern::Specific(methods) => {
230 let method_strs: Vec<&str> = methods.iter().map(Method::as_str).collect();
231 write!(f, " [{}]", method_strs.join(", "))?;
232 }
233 }
234 Ok(())
235 }
236}
237
/// Matches a request host `value` against a host `pattern`.
///
/// A lone `*` matches every host. A leading `*` matches any host that ends
/// with the rest of the pattern, and a trailing `*` matches any host that
/// starts with it. Any other pattern must equal the host. All comparisons
/// ignore ASCII case, because host names are case-insensitive
/// (RFC 3986 §3.2.2). Comparison is byte-wise, so slicing can never split a
/// UTF-8 character or panic.
fn match_wildcard(pattern: &str, value: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    let value_bytes = value.as_bytes();
    if let Some(suffix) = pattern.strip_prefix('*') {
        let suffix = suffix.as_bytes();
        return value_bytes.len() >= suffix.len()
            && value_bytes[value_bytes.len() - suffix.len()..].eq_ignore_ascii_case(suffix);
    }
    if let Some(prefix) = pattern.strip_suffix('*') {
        let prefix = prefix.as_bytes();
        return value_bytes.len() >= prefix.len()
            && value_bytes[..prefix.len()].eq_ignore_ascii_case(prefix);
    }
    pattern.eq_ignore_ascii_case(value)
}
251
252#[derive(Clone, Debug)]
254pub struct AllowedHostPolicy {
255 pub pattern: HostPattern,
256 pub secrets: Vec<PlaceholderSecret>,
257}
258
259#[derive(Clone, Debug, Default)]
261pub struct HttpRequestPolicy {
262 pub hosts: Vec<AllowedHostPolicy>,
263}
264
/// Returns `true` if a `Content-Type` value denotes a textual body in which
/// placeholders may be substituted.
///
/// Matches (ASCII case-insensitively) `text/*`, `application/json*`,
/// `application/x-www-form-urlencoded*`, and anything containing `+json`.
#[must_use]
pub fn is_text_content_type(content_type: &str) -> bool {
    const TEXT_PREFIXES: [&str; 3] = [
        "text/",
        "application/json",
        "application/x-www-form-urlencoded",
    ];
    let normalized = content_type.to_ascii_lowercase();
    normalized.contains("+json")
        || TEXT_PREFIXES
            .iter()
            .any(|prefix| normalized.starts_with(prefix))
}
274
275fn extract_request_target(uri: &hyper::Uri) -> Option<(String, String, u16)> {
277 let scheme = uri.scheme_str().unwrap_or("https").to_string();
278 let host = uri.host()?.to_string();
279 let default_port = if scheme == "https" { 443 } else { 80 };
280 let port = uri.port_u16().unwrap_or(default_port);
281 Some((scheme, host, port))
282}
283
284#[derive(Debug, thiserror::Error)]
285pub(crate) enum PolicyError {
286 #[error("outgoing HTTP request has no host in URI: {0}")]
287 RequestHasNoHost(Uri),
288 #[error("outgoing HTTP {method} request to {scheme}://{host}:{port} denied")]
289 RequestDenied {
290 method: Method,
291 scheme: String,
292 host: String,
293 port: u16,
294 },
295}
296impl From<PolicyError> for ErrorCode {
297 fn from(_value: PolicyError) -> Self {
298 ErrorCode::HttpRequestDenied
299 }
300}
301
302impl HttpRequestPolicy {
303 pub(crate) fn apply(
306 &self,
307 request: &mut hyper::Request<wasmtime_wasi_http::p2::body::HyperOutgoingBody>,
308 ) -> Result<(), PolicyError> {
309 let Some((scheme, host, port)) = extract_request_target(request.uri()) else {
310 return Err(PolicyError::RequestHasNoHost(request.uri().clone()));
311 };
312 let method = request.method().clone();
313
314 let matching: Vec<&AllowedHostPolicy> = self
316 .hosts
317 .iter()
318 .filter(|h| h.pattern.matches(&scheme, &host, port, &method))
319 .collect();
320 if matching.is_empty() {
321 return Err(PolicyError::RequestDenied {
322 method,
323 scheme,
324 host,
325 port,
326 });
327 }
328
329 let applicable: Vec<&PlaceholderSecret> =
331 matching.iter().flat_map(|h| h.secrets.iter()).collect();
332
333 if applicable.is_empty() {
334 return Ok(());
335 }
336
337 let header_secrets: Vec<_> = applicable
339 .iter()
340 .filter(|s| s.replace_in.contains(&ReplacementLocation::Headers))
341 .collect();
342 if !header_secrets.is_empty() {
343 let headers = request.headers_mut();
344 let keys: Vec<_> = headers.keys().cloned().collect();
345 for key in keys {
346 if let Some(val) = headers.get(&key)
347 && let Ok(val_str) = val.to_str()
348 {
349 let mut replaced = val_str.to_string();
350 for secret in &header_secrets {
351 replaced = replaced
352 .replace(&secret.placeholder, secret.real_value.expose_secret());
353 }
354 if replaced != val_str
355 && let Ok(new_val) = hyper::header::HeaderValue::from_str(&replaced)
356 {
357 headers.insert(&key, new_val);
358 }
359 }
360 }
361 }
362
363 let param_secrets: Vec<_> = applicable
365 .iter()
366 .filter(|s| s.replace_in.contains(&ReplacementLocation::Params))
367 .collect();
368 if !param_secrets.is_empty() {
369 let uri_str = request.uri().to_string();
370 let mut uri_replaced = uri_str.clone();
371 for secret in ¶m_secrets {
372 uri_replaced =
373 uri_replaced.replace(&secret.placeholder, secret.real_value.expose_secret());
374 }
375 if uri_replaced != uri_str
376 && let Ok(new_uri) = uri_replaced.parse::<hyper::Uri>()
377 {
378 *request.uri_mut() = new_uri;
379 }
380 }
381
382 Ok(())
386 }
387
388 fn body_secrets_for(&self, uri: &hyper::Uri, method: &Method) -> Vec<&PlaceholderSecret> {
390 let Some((scheme, host, port)) = extract_request_target(uri) else {
391 return Vec::new();
392 };
393 self.hosts
394 .iter()
395 .filter(|h| h.pattern.matches(&scheme, &host, port, method))
396 .flat_map(|h| h.secrets.iter())
397 .filter(|s| s.replace_in.contains(&ReplacementLocation::Body))
398 .collect()
399 }
400
401 pub(crate) async fn apply_body_replacement(
405 &self,
406 request: &mut hyper::Request<wasmtime_wasi_http::p2::body::HyperOutgoingBody>,
407 ) {
408 let body_secrets = self.body_secrets_for(request.uri(), request.method());
409 if body_secrets.is_empty() {
410 trace!("No secrets, no modifications to HTTP body");
411 return;
412 }
413
414 let should_replace = request
416 .headers()
417 .get(hyper::header::CONTENT_TYPE)
418 .and_then(|v| v.to_str().ok())
419 .map(is_text_content_type)
420 .unwrap_or(false);
421 if !should_replace {
422 return;
423 }
424
425 let body = std::mem::take(request.body_mut());
427 let Ok(collected) = http_body_util::BodyExt::collect(body).await else {
429 return;
430 };
431 let body_bytes = collected.to_bytes();
432 let Ok(mut body_str) = String::from_utf8(body_bytes.to_vec()) else {
433 let restored =
435 http_body_util::combinators::UnsyncBoxBody::new(http_body_util::BodyExt::map_err(
436 http_body_util::Full::new(body_bytes),
437 |_| unreachable!(),
438 ));
439 *request.body_mut() = restored;
440 debug!("Not valid UTF-8, sending original HTTP body");
441 return;
442 };
443
444 for secret in &body_secrets {
446 body_str = body_str.replace(&secret.placeholder, secret.real_value.expose_secret());
447 }
448
449 let new_body =
450 http_body_util::combinators::UnsyncBoxBody::new(http_body_util::BodyExt::map_err(
451 http_body_util::Full::new(hyper::body::Bytes::from(body_str)),
452 |_| unreachable!(),
453 ));
454 *request.body_mut() = new_body;
455 debug!("Applied secrets to HTTP body");
456 }
457}
458
459#[must_use]
461pub fn generate_placeholder() -> String {
462 let mut random_bytes = [0u8; 32];
463 rand::rng().fill_bytes(&mut random_bytes);
464 use std::fmt::Write;
465 let hex = random_bytes
466 .iter()
467 .fold(String::with_capacity(64), |mut acc, b| {
468 let _ = write!(acc, "{b:02x}");
469 acc
470 });
471 format!("OBELISK_SECRET_{hex}")
472}
473
474#[derive(Clone, Debug)]
476pub struct AllowedHostConfig {
477 pub pattern: HostPattern,
478 pub secret_env_mappings: Vec<(String, SecretString)>,
480 pub replace_in: hashbrown::HashSet<ReplacementLocation>,
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a host pattern that the test expects to be valid.
    fn pattern(input: &str) -> HostPattern {
        HostPattern::parse(input).unwrap()
    }

    /// Parses a host pattern restricted to `methods`, expecting success.
    fn pattern_with_methods(input: &str, methods: MethodsPattern) -> HostPattern {
        HostPattern::parse_with_methods(input, methods).unwrap()
    }

    /// Asserts that `p.matches(..)` gives the expected answer for every case
    /// `(scheme, host, port, method, expected)`.
    fn assert_matches(p: &HostPattern, cases: &[(&str, &str, u16, Method, bool)]) {
        for (scheme, host, port, method, expected) in cases {
            assert_eq!(
                p.matches(scheme, host, *port, method),
                *expected,
                "pattern `{p}` against {method} {scheme}://{host}:{port}"
            );
        }
    }

    #[test]
    fn parse_host_pattern_bare_hostname() {
        let p = pattern("api.openai.com");
        assert_eq!(p.scheme, SchemePattern::Https);
        assert_eq!(p.host_pattern, "api.openai.com");
        assert_eq!(p.port, PortPattern::Default);
        assert_matches(
            &p,
            &[
                ("https", "api.openai.com", 443, Method::GET, true),
                ("http", "api.openai.com", 80, Method::GET, false),
            ],
        );
    }

    #[test]
    fn parse_host_pattern_with_scheme_and_port() {
        let p = pattern("http://localhost:8080");
        assert_eq!(p.scheme, SchemePattern::Http);
        assert_eq!(p.host_pattern, "localhost");
        assert_eq!(p.port, PortPattern::Specific(8080));
        assert_matches(
            &p,
            &[
                ("http", "localhost", 8080, Method::GET, true),
                ("https", "localhost", 8080, Method::GET, false),
            ],
        );
    }

    #[test]
    fn parse_host_pattern_http_default_port() {
        let p = pattern("http://example.com");
        assert_eq!(p.scheme, SchemePattern::Http);
        assert_eq!(p.host_pattern, "example.com");
        assert_eq!(p.port, PortPattern::Default);
        assert_matches(
            &p,
            &[
                ("http", "example.com", 80, Method::GET, true),
                ("http", "example.com", 8080, Method::GET, false),
            ],
        );
    }

    #[test]
    fn parse_host_pattern_wildcard_prefix() {
        assert_matches(
            &pattern("*.example.com"),
            &[
                ("https", "api.example.com", 443, Method::GET, true),
                ("https", "foo.bar.example.com", 443, Method::POST, true),
                ("https", "example.com", 443, Method::GET, false),
            ],
        );
    }

    #[test]
    fn parse_host_pattern_wildcard_suffix() {
        assert_matches(
            &pattern("192.168.1.*"),
            &[
                ("https", "192.168.1.100", 443, Method::GET, true),
                ("https", "192.168.2.100", 443, Method::GET, false),
            ],
        );
    }

    #[test]
    fn parse_host_pattern_wildcard_all_https() {
        assert_matches(
            &pattern("*"),
            &[
                ("https", "anything.com", 443, Method::GET, true),
                ("http", "anything.com", 80, Method::GET, false),
            ],
        );
    }

    #[test]
    fn parse_host_pattern_wildcard_http() {
        assert_matches(
            &pattern("http://*"),
            &[
                ("https", "anything.com", 443, Method::GET, false),
                ("http", "anything.com", 80, Method::GET, true),
            ],
        );
    }

    #[test]
    fn parse_host_pattern_wildcard_middle_rejected() {
        assert!(HostPattern::parse("foo.*.com").is_err());
    }

    #[test]
    fn parse_host_pattern_trailing_slash_rejected() {
        for input in [
            "http://localhost:8080/",
            "https://api.example.com/v1",
            "example.com/path",
        ] {
            assert!(HostPattern::parse(input).is_err(), "`{input}` should be rejected");
        }
    }

    #[test]
    fn parse_host_pattern_https_non_default_port() {
        let p = pattern("internal.corp.com:8443");
        assert_eq!(p.scheme, SchemePattern::Https);
        assert_eq!(p.host_pattern, "internal.corp.com");
        assert_eq!(p.port, PortPattern::Specific(8443));
    }

    #[test]
    fn host_pattern_method_restriction() {
        let p = pattern_with_methods(
            "api.example.com",
            MethodsPattern::Specific(vec![Method::GET, Method::HEAD]),
        );
        assert_matches(
            &p,
            &[
                ("https", "api.example.com", 443, Method::GET, true),
                ("https", "api.example.com", 443, Method::HEAD, true),
                ("https", "api.example.com", 443, Method::POST, false),
                ("https", "api.example.com", 443, Method::DELETE, false),
            ],
        );
    }

    #[test]
    fn host_pattern_all_methods_allows_all() {
        let p = pattern("api.example.com");
        assert_eq!(p.methods, MethodsPattern::AllMethods);
        assert_matches(
            &p,
            &[
                ("https", "api.example.com", 443, Method::GET, true),
                ("https", "api.example.com", 443, Method::POST, true),
                ("https", "api.example.com", 443, Method::DELETE, true),
                ("https", "api.example.com", 443, Method::PUT, true),
            ],
        );
    }

    #[test]
    fn host_pattern_empty_methods_matches_nothing() {
        let p = pattern_with_methods("api.example.com", MethodsPattern::Specific(vec![]));
        assert_matches(
            &p,
            &[
                ("https", "api.example.com", 443, Method::GET, false),
                ("https", "api.example.com", 443, Method::POST, false),
                ("https", "api.example.com", 443, Method::DELETE, false),
            ],
        );
    }

    #[test]
    fn display_host_pattern_with_methods() {
        let p = pattern_with_methods(
            "api.example.com",
            MethodsPattern::Specific(vec![Method::GET, Method::POST]),
        );
        assert_eq!(p.to_string(), "https://api.example.com [GET, POST]");
    }

    #[test]
    fn parse_host_pattern_any_scheme_default_ports() {
        let p = pattern("*://*");
        assert_eq!(p.scheme, SchemePattern::Any);
        assert_eq!(p.host_pattern, "*");
        assert_eq!(p.port, PortPattern::Default);
        assert_matches(
            &p,
            &[
                ("http", "foo.com", 80, Method::GET, true),
                ("https", "foo.com", 443, Method::GET, true),
                ("http", "foo.com", 8080, Method::GET, false),
                ("https", "foo.com", 8443, Method::GET, false),
            ],
        );
    }

    #[test]
    fn parse_host_pattern_any_scheme_any_port() {
        let p = pattern("*://*:*");
        assert_eq!(p.scheme, SchemePattern::Any);
        assert_eq!(p.host_pattern, "*");
        assert_eq!(p.port, PortPattern::Any);
        assert_matches(
            &p,
            &[
                ("http", "foo.com", 80, Method::GET, true),
                ("https", "foo.com", 443, Method::GET, true),
                ("http", "foo.com", 8080, Method::GET, true),
                ("https", "foo.com", 8443, Method::GET, true),
                ("http", "localhost", 3000, Method::POST, true),
            ],
        );
    }

    #[test]
    fn parse_host_pattern_any_port_specific_scheme() {
        let p = pattern("http://localhost:*");
        assert_eq!(p.scheme, SchemePattern::Http);
        assert_eq!(p.host_pattern, "localhost");
        assert_eq!(p.port, PortPattern::Any);
        assert_matches(
            &p,
            &[
                ("http", "localhost", 80, Method::GET, true),
                ("http", "localhost", 8080, Method::GET, true),
                ("http", "localhost", 3000, Method::GET, true),
                ("https", "localhost", 443, Method::GET, false),
                ("http", "other.com", 80, Method::GET, false),
            ],
        );
    }

    #[test]
    fn parse_host_pattern_wildcard_host_any_port() {
        let p = pattern("http://192.*:*");
        assert_eq!(p.scheme, SchemePattern::Http);
        assert_eq!(p.host_pattern, "192.*");
        assert_eq!(p.port, PortPattern::Any);
        assert_matches(
            &p,
            &[
                ("http", "192.168.1.1", 80, Method::GET, true),
                ("http", "192.168.1.1", 8080, Method::GET, true),
                ("http", "192.0.0.1", 3000, Method::POST, true),
                ("https", "192.168.1.1", 443, Method::GET, false),
                ("http", "10.0.0.1", 80, Method::GET, false),
            ],
        );
    }

    #[test]
    fn display_host_pattern_any_scheme() {
        // These inputs are already in canonical form, so they round-trip.
        for input in ["*://*", "*://*:*", "http://localhost:*"] {
            assert_eq!(pattern(input).to_string(), input);
        }
    }

    #[test]
    fn display_host_pattern_empty_methods() {
        let p = pattern_with_methods("api.example.com", MethodsPattern::Specific(vec![]));
        assert_eq!(p.to_string(), "https://api.example.com [NONE]");
    }

    #[test]
    fn generate_placeholder_format() {
        let placeholder = generate_placeholder();
        assert!(placeholder.starts_with("OBELISK_SECRET_"));
        assert_eq!(placeholder.len(), "OBELISK_SECRET_".len() + 64);
    }

    #[test]
    fn generate_placeholder_unique() {
        assert_ne!(generate_placeholder(), generate_placeholder());
    }

    #[test]
    fn display_host_pattern() {
        for (input, expected) in [
            ("api.openai.com", "https://api.openai.com"),
            ("http://localhost:8080", "http://localhost:8080"),
            ("internal.corp.com:8443", "https://internal.corp.com:8443"),
        ] {
            assert_eq!(pattern(input).to_string(), expected);
        }
    }

    #[test]
    fn test_is_text_content_type() {
        for text in [
            "application/json",
            "application/json; charset=utf-8",
            "application/vnd.api+json",
            "text/plain",
            "text/html",
            "application/x-www-form-urlencoded",
        ] {
            assert!(is_text_content_type(text), "`{text}` should be text");
        }
        for binary in ["application/octet-stream", "image/png"] {
            assert!(!is_text_content_type(binary), "`{binary}` should not be text");
        }
    }
}