use crate::{Plugin, PluginError};
use pylon_auth::AuthContext;
/// Plugin holding the cross-origin (CORS) policy: which origins are accepted
/// and what value to send back as the allowed origin.
///
/// Two configurations mean "every origin is allowed":
/// an empty `allowed_origins` list, or a list containing the literal `"*"`.
#[derive(Debug, Clone)]
pub struct CorsPlugin {
    /// Exact origin strings that are accepted (compared with `==`, so scheme,
    /// host and port must all match). Empty, or containing `"*"`, means any
    /// origin is accepted.
    pub allowed_origins: Vec<String>,
    /// Credential flag stored for the caller; nothing in this type reads it.
    /// NOTE(review): presumably drives `Access-Control-Allow-Credentials` in
    /// the server layer — confirm against the code that emits the headers.
    pub allow_credentials: bool,
}

impl CorsPlugin {
    /// Creates a plugin that accepts every origin, with credentials disabled.
    pub fn allow_all() -> Self {
        Self {
            allowed_origins: Vec::new(),
            allow_credentials: false,
        }
    }

    /// Creates a plugin restricted to `origins`, with credentials enabled.
    ///
    /// Passing an empty list, or a list containing `"*"`, produces a plugin
    /// that accepts every origin.
    pub fn new(origins: Vec<String>) -> Self {
        Self {
            allowed_origins: origins,
            allow_credentials: true,
        }
    }

    /// Returns `true` when the configuration accepts every origin: either no
    /// origins were listed, or the literal `"*"` is one of the entries.
    fn allows_any_origin(&self) -> bool {
        self.allowed_origins.is_empty() || self.allowed_origins.iter().any(|o| o == "*")
    }

    /// Returns `true` when `origin` is accepted by this policy.
    ///
    /// Matching is an exact string comparison against each configured entry,
    /// unless the policy accepts every origin.
    pub fn is_allowed(&self, origin: &str) -> bool {
        self.allows_any_origin() || self.allowed_origins.iter().any(|o| o == origin)
    }

    /// Computes the value to send back as the allowed origin.
    ///
    /// * Policy accepts every origin (empty list or `"*"` entry) → `"*"`,
    ///   regardless of `request_origin`.
    /// * `request_origin` is an explicitly listed origin → that origin.
    /// * Anything else (unlisted origin, or no origin supplied) → empty string.
    ///
    /// A wildcard policy always answers with the literal `*` and never echoes
    /// the caller-supplied origin. Echoing would hand every requesting site a
    /// header naming itself, which is unsafe if credentials are enabled.
    /// NOTE(review): how `allow_credentials` is consumed is not visible here;
    /// browsers refuse credentialed responses whose allowed origin is `*`.
    pub fn allow_origin_header(&self, request_origin: Option<&str>) -> String {
        if self.allows_any_origin() {
            return "*".to_string();
        }
        // Not a wildcard policy, so `is_allowed` is an exact-match lookup here.
        request_origin
            .filter(|origin| self.is_allowed(origin))
            .map(str::to_owned)
            .unwrap_or_default()
    }
}
impl Plugin for CorsPlugin {
    /// Identifier reported for this plugin: always `"cors"`.
    fn name(&self) -> &str {
        "cors"
    }

    /// Request hook. It ignores the method, path and auth context and
    /// unconditionally succeeds, so no request is rejected here.
    fn on_request(
        &self,
        _method: &str,
        _path: &str,
        _auth: &AuthContext,
    ) -> Result<(), PluginError> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The permissive constructor accepts any origin and answers with `*`.
    #[test]
    fn allow_all() {
        let plugin = CorsPlugin::allow_all();
        for origin in ["http://localhost:3000", "https://example.com"] {
            assert!(plugin.is_allowed(origin));
        }
        assert_eq!(
            plugin.allow_origin_header(Some("http://localhost:3000")),
            "*"
        );
    }

    /// Only the listed origins pass; anything else is refused.
    #[test]
    fn specific_origins() {
        let listed = ["http://localhost:3000", "https://myapp.com"];
        let plugin = CorsPlugin::new(listed.iter().map(|o| o.to_string()).collect());
        for origin in listed {
            assert!(plugin.is_allowed(origin));
        }
        assert!(!plugin.is_allowed("https://evil.com"));
    }

    /// A listed origin is echoed back; unlisted or absent origins yield an
    /// empty header value.
    #[test]
    fn allow_origin_header_matches() {
        let plugin = CorsPlugin::new(vec![String::from("https://myapp.com")]);
        let cases = [
            (Some("https://myapp.com"), "https://myapp.com"),
            (Some("https://evil.com"), ""),
            (None, ""),
        ];
        for (request_origin, expected) in cases {
            assert_eq!(plugin.allow_origin_header(request_origin), expected);
        }
    }
}