pylon_plugin/builtin/
cors.rs1use crate::{Plugin, PluginError};
2use pylon_auth::AuthContext;
/// Built-in plugin holding the server's CORS (cross-origin resource sharing) policy.
///
/// An empty `allowed_origins` list means "accept every origin"; see
/// [`CorsPlugin::allow_all`] and [`CorsPlugin::new`] for the two constructors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CorsPlugin {
    /// Origins accepted verbatim (e.g. `"https://myapp.com"`). The entry `"*"`
    /// matches any origin, and an empty list accepts every origin.
    pub allowed_origins: Vec<String>,
    /// Whether credentials are allowed on cross-origin requests.
    ///
    /// NOTE(review): not read by the visible code of this file — presumably
    /// consumed when emitting `Access-Control-Allow-Credentials`; confirm.
    pub allow_credentials: bool,
}

12impl CorsPlugin {
13 pub fn allow_all() -> Self {
15 Self {
16 allowed_origins: vec![],
17 allow_credentials: false,
18 }
19 }
20
21 pub fn new(origins: Vec<String>) -> Self {
23 Self {
24 allowed_origins: origins,
25 allow_credentials: true,
26 }
27 }
28
29 pub fn is_allowed(&self, origin: &str) -> bool {
31 if self.allowed_origins.is_empty() {
32 return true; }
34 self.allowed_origins.iter().any(|o| o == origin || o == "*")
35 }
36
37 pub fn allow_origin_header(&self, request_origin: Option<&str>) -> String {
39 if self.allowed_origins.is_empty() {
40 return "*".to_string();
41 }
42 match request_origin {
43 Some(origin) if self.is_allowed(origin) => origin.to_string(),
44 _ => String::new(),
45 }
46 }
47}
48
49impl Plugin for CorsPlugin {
50 fn name(&self) -> &str {
51 "cors"
52 }
53
54 fn on_request(
55 &self,
56 _method: &str,
57 _path: &str,
58 _auth: &AuthContext,
59 ) -> Result<(), PluginError> {
60 Ok(())
63 }
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// `allow_all` accepts any origin and answers with the `*` wildcard.
    #[test]
    fn allow_all() {
        let plugin = CorsPlugin::allow_all();

        for origin in ["http://localhost:3000", "https://example.com"] {
            assert!(plugin.is_allowed(origin), "expected {} to be allowed", origin);
        }
        assert_eq!(plugin.allow_origin_header(Some("http://localhost:3000")), "*");
    }

    /// `new` accepts exactly the listed origins and rejects others.
    #[test]
    fn specific_origins() {
        let listed = ["http://localhost:3000", "https://myapp.com"];
        let plugin = CorsPlugin::new(listed.iter().map(|o| o.to_string()).collect());

        for origin in listed {
            assert!(plugin.is_allowed(origin), "expected {} to be allowed", origin);
        }
        assert!(!plugin.is_allowed("https://evil.com"));
    }

    /// The header echoes a listed origin and is empty for anything else.
    #[test]
    fn allow_origin_header_matches() {
        let plugin = CorsPlugin::new(vec![String::from("https://myapp.com")]);

        let cases: [(Option<&str>, &str); 3] = [
            (Some("https://myapp.com"), "https://myapp.com"),
            (Some("https://evil.com"), ""),
            (None, ""),
        ];
        for (request_origin, expected) in cases {
            assert_eq!(plugin.allow_origin_header(request_origin), expected);
        }
    }
}