use beet_core::prelude::*;
use beet_flow::prelude::*;
/// Middleware action that stamps cache-disabling headers onto the agent's
/// [`Response`].
///
/// When the action observes [`GetOutcome`] it queues a world command that
/// looks up the `Response` on the action's agent entity and sets
/// `cache-control`, `pragma` and `expires`. A missing `Response` is only
/// logged. The action triggers [`Outcome::Pass`] in either case.
pub fn no_cache_headers() -> impl Bundle {
    // Header name/value pairs, written in this order.
    const NO_CACHE_HEADERS: [(&str, &str); 3] = [
        ("cache-control", "no-cache, no-store, must-revalidate"),
        ("pragma", "no-cache"),
        ("expires", "0"),
    ];

    (
        Name::new("No-Cache Headers Middleware"),
        OnSpawn::observe(
            |ev: On<GetOutcome>, agents: AgentQuery, mut commands: Commands| {
                let action = ev.target();
                let agent = agents.entity(action);

                commands.queue(move |world: &mut World| -> Result {
                    let mut agent_entity = world.entity_mut(agent);
                    match agent_entity.get_mut::<Response>() {
                        Some(mut response) => {
                            let parts = response.parts_mut();
                            for (name, value) in NO_CACHE_HEADERS {
                                parts.insert_header(name, value);
                            }
                        }
                        None => {
                            cross_log!(
                                "No Response found for no_cache_headers middleware"
                            );
                        }
                    }
                    Ok(())
                });

                // Queued after the header command, so it is applied after it.
                commands.entity(action).trigger_target(Outcome::Pass);
            },
        ),
    )
}
/// CORS policy consumed by the `cors_request`, `cors_response` and
/// `cors_preflight` middleware in this file.
///
/// The derived `Default` yields `allow_any_origin == false` with an empty
/// allow-list, i.e. a configuration for which `origin_allowed` is always
/// `false`.
#[derive(Debug, Default, Clone, Resource, Reflect)]
#[reflect(Resource)]
pub struct CorsConfig {
/// When `true`, every origin passes `origin_allowed`, regardless of the
/// allow-list.
pub allow_any_origin: bool,
/// Origins accepted when `allow_any_origin` is `false`; compared to the
/// incoming origin by exact string equality.
allowed_origins: Vec<String>,
}
impl CorsConfig {
    /// Wildcard origin value, used by the middleware when any origin is
    /// allowed and the request carries no `origin` header.
    pub const ANY_ORIGIN: &'static str = "*";

    /// Builds a config from an "allow everything" flag and a list of allowed
    /// origins, which are copied into owned strings.
    pub fn new(
        allow_any_origin: bool,
        allowed_origins: Vec<&'static str>,
    ) -> Self {
        let allowed_origins: Vec<String> =
            allowed_origins.into_iter().map(String::from).collect();
        Self {
            allow_any_origin,
            allowed_origins,
        }
    }

    /// Returns `true` when any origin is allowed, or when `origin` is exactly
    /// equal to one of the configured origins.
    pub fn origin_allowed(&self, origin: &str) -> bool {
        if self.allow_any_origin {
            return true;
        }
        self.allowed_origins
            .iter()
            .any(|allowed| allowed.as_str() == origin)
    }
}
/// Component recording the origin accepted for the current exchange.
///
/// Inserted on the agent entity by `cors_request` and `cors_preflight` once
/// the origin passes `CorsConfig::origin_allowed`, and read back by
/// `cors_response` to fill `access-control-allow-origin`. The value is either
/// the request's `origin` header or `CorsConfig::ANY_ORIGIN` when any origin
/// is allowed and the header was absent.
#[derive(Debug, Clone, Component)]
pub struct ValidatedOrigin(pub String);
/// Middleware action that validates the request's `origin` header against
/// `config`.
///
/// On [`GetOutcome`] a world command is queued which reads the [`Request`] on
/// the agent entity (returning an error if there is none) and then:
///
/// - header present: uses it as the candidate origin;
/// - header absent and `allow_any_origin`: uses [`CorsConfig::ANY_ORIGIN`];
/// - header absent otherwise: inserts a `MalformedRequest` response on the
///   agent and triggers [`Outcome::Fail`].
///
/// A candidate rejected by [`CorsConfig::origin_allowed`] yields a `Forbidden`
/// response and [`Outcome::Fail`]; an accepted one is stored on the agent as
/// [`ValidatedOrigin`] and the action triggers [`Outcome::Pass`].
pub fn cors_request(config: CorsConfig) -> impl Bundle {
    (
        Name::new("CORS Request Middleware"),
        OnSpawn::observe(
            move |ev: On<GetOutcome>,
                  agents: AgentQuery,
                  mut commands: Commands| {
                let action = ev.target();
                let agent = agents.entity(action);
                // The queued `move` closure needs its own owned copy.
                let cors = config.clone();

                commands.queue(move |world: &mut World| -> Result {
                    let requested_origin = world
                        .entity(agent)
                        .get::<Request>()
                        .ok_or_else(|| {
                            bevyhow!("No Request found for CORS middleware")
                        })?
                        .get_header("origin")
                        .map(|value| value.to_string());

                    let origin = match requested_origin {
                        Some(origin) => origin,
                        None if cors.allow_any_origin => {
                            CorsConfig::ANY_ORIGIN.to_string()
                        }
                        None => {
                            world.entity_mut(agent).insert(
                                Response::from_status_body(
                                    StatusCode::MalformedRequest,
                                    b"Origin header not found",
                                    "text/plain",
                                ),
                            );
                            world
                                .entity_mut(action)
                                .trigger_target(Outcome::Fail);
                            return Ok(());
                        }
                    };

                    if cors.origin_allowed(&origin) {
                        world.entity_mut(agent).insert(ValidatedOrigin(origin));
                        world.entity_mut(action).trigger_target(Outcome::Pass);
                    } else {
                        world.entity_mut(agent).insert(
                            Response::from_status_body(
                                StatusCode::Forbidden,
                                b"Origin not allowed",
                                "text/plain",
                            ),
                        );
                        world.entity_mut(action).trigger_target(Outcome::Fail);
                    }
                    Ok(())
                });
            },
        ),
    )
}
pub fn cors_response(_config: CorsConfig) -> impl Bundle {
(
Name::new("CORS Response Middleware"),
OnSpawn::observe(
|ev: On<GetOutcome>, agents: AgentQuery, mut commands: Commands| {
let action = ev.target();
let agent = agents.entity(action);
commands.queue(move |world: &mut World| -> Result {
let origin = world
.entity(agent)
.get::<ValidatedOrigin>()
.ok_or_else(|| {
bevyhow!(
"No ValidatedOrigin found for CORS response middleware"
)
})?
.0
.clone();
let mut entity = world.entity_mut(agent);
let Some(mut response) = entity.get_mut::<Response>()
else {
cross_log!(
"No Response found for CORS response middleware"
);
return Ok(());
};
response
.parts_mut()
.insert_header("access-control-allow-origin", &origin);
Ok(())
});
commands.entity(action).trigger_target(Outcome::Pass);
},
),
)
}
/// Middleware action that answers CORS preflight (`OPTIONS`) requests.
///
/// On [`GetOutcome`] a world command is queued which reads the [`Request`] on
/// the agent entity (returning an error if there is none). Requests whose
/// method is not [`HttpMethod::Options`] are left untouched and the action
/// triggers [`Outcome::Pass`].
///
/// For `OPTIONS` requests the origin is resolved and checked exactly as in
/// `cors_request`: a missing header without `allow_any_origin` produces a
/// `MalformedRequest` response, a disallowed origin a `Forbidden` response,
/// both with [`Outcome::Fail`]. An accepted origin is stored as
/// [`ValidatedOrigin`], and a `Response::ok()` carrying
/// `access-control-max-age`, `access-control-allow-headers` and
/// `access-control-allow-origin` is inserted on the agent before triggering
/// [`Outcome::Pass`].
pub fn cors_preflight(config: CorsConfig) -> impl Bundle {
    (
        Name::new("CORS Preflight Middleware"),
        OnSpawn::observe(
            move |ev: On<GetOutcome>,
                  agents: AgentQuery,
                  mut commands: Commands| {
                let action = ev.target();
                let agent = agents.entity(action);
                // The queued `move` closure needs its own owned copy.
                let cors = config.clone();

                commands.queue(move |world: &mut World| -> Result {
                    let request = world
                        .entity(agent)
                        .get::<Request>()
                        .ok_or_else(|| {
                            bevyhow!(
                                "No Request found for CORS preflight middleware"
                            )
                        })?;

                    // Anything other than OPTIONS is not a preflight.
                    if *request.method() != HttpMethod::Options {
                        world.entity_mut(action).trigger_target(Outcome::Pass);
                        return Ok(());
                    }

                    let requested_origin = request
                        .get_header("origin")
                        .map(|value| value.to_string());

                    let origin = match requested_origin {
                        Some(origin) => origin,
                        None if cors.allow_any_origin => {
                            CorsConfig::ANY_ORIGIN.to_string()
                        }
                        None => {
                            world.entity_mut(agent).insert(
                                Response::from_status_body(
                                    StatusCode::MalformedRequest,
                                    b"Origin header not found",
                                    "text/plain",
                                ),
                            );
                            world
                                .entity_mut(action)
                                .trigger_target(Outcome::Fail);
                            return Ok(());
                        }
                    };

                    if !cors.origin_allowed(&origin) {
                        world.entity_mut(agent).insert(
                            Response::from_status_body(
                                StatusCode::Forbidden,
                                b"Origin not allowed",
                                "text/plain",
                            ),
                        );
                        world.entity_mut(action).trigger_target(Outcome::Fail);
                        return Ok(());
                    }

                    // Build the reply first so the origin can then be moved
                    // into `ValidatedOrigin` without a clone.
                    let mut preflight = Response::ok();
                    let headers = preflight.parts_mut();
                    headers.insert_header("access-control-max-age", "60");
                    headers.insert_header(
                        "access-control-allow-headers",
                        "content-type",
                    );
                    headers.insert_header("access-control-allow-origin", &origin);

                    // Same insertion order as before: origin, then response.
                    world
                        .entity_mut(agent)
                        .insert(ValidatedOrigin(origin))
                        .insert(preflight);
                    world.entity_mut(action).trigger_target(Outcome::Pass);
                    Ok(())
                });
            },
        ),
    )
}
// Integration-style tests: each one spawns a flow on `RouterPlugin::world()`,
// sends a single request with `oneshot`, and inspects the returned response.
#[cfg(test)]
mod test {
use super::*;
use crate::prelude::*;
use beet_net::prelude::*;
// A GET endpoint followed by `no_cache_headers` must return all three
// cache-disabling headers.
#[beet_core::test]
async fn no_cache_headers_works() {
RouterPlugin::world()
.spawn(ExchangeSpawner::new_flow(|| {
(InfallibleSequence, children![
EndpointBuilder::get().with_handler(|| "Hello"),
no_cache_headers(),
])
}))
.oneshot(Request::get("/"))
.await
.xtap(|response| {
response
.get_header("cache-control")
.unwrap()
.xpect_eq("no-cache, no-store, must-revalidate");
response.get_header("pragma").unwrap().xpect_eq("no-cache");
response.get_header("expires").unwrap().xpect_eq("0");
});
}
// An origin present in the allow-list is accepted and echoed back in
// `access-control-allow-origin`.
#[beet_core::test]
async fn cors_allows_origin() {
let config = CorsConfig::new(false, vec!["https://allowed.com"]);
RouterPlugin::world()
.spawn(ExchangeSpawner::new_flow(|| {
(InfallibleSequence, children![
cors_request(config.clone()),
EndpointBuilder::get().with_handler(|| "Hello"),
cors_response(config),
])
}))
.oneshot(
Request::get("/").with_header("origin", "https://allowed.com"),
)
.await
.xtap(|response| {
response.status().xpect_eq(StatusCode::Ok);
response
.get_header("access-control-allow-origin")
.unwrap()
.xpect_eq("https://allowed.com");
});
}
// With an empty allow-list and `allow_any_origin == false`, the request is
// answered with `Forbidden`.
// NOTE(review): uses `Sequence` rather than `InfallibleSequence`, presumably
// so the failing `cors_request` stops the remaining children - confirm
// against the `Sequence` semantics.
#[beet_core::test]
async fn cors_blocks_origin() {
let config = CorsConfig::new(false, vec![]);
RouterPlugin::world()
.spawn(ExchangeSpawner::new_flow(|| {
(Sequence, children![
cors_request(config.clone()),
EndpointBuilder::get().with_handler(|| "Hello"),
cors_response(config),
])
}))
.oneshot(
Request::get("/").with_header("origin", "https://blocked.com"),
)
.await
.status()
.xpect_eq(StatusCode::Forbidden);
}
// With `allow_any_origin == true`, an arbitrary origin is accepted and
// echoed back (the request's origin, not `*`).
#[beet_core::test]
async fn cors_allows_any() {
let config = CorsConfig::new(true, vec![]);
RouterPlugin::world()
.spawn(ExchangeSpawner::new_flow(|| {
(InfallibleSequence, children![
cors_request(config.clone()),
EndpointBuilder::get().with_handler(|| "Hello"),
cors_response(config),
])
}))
.oneshot(
Request::get("/").with_header("origin", "https://anything.com"),
)
.await
.xtap(|response| {
response.status().xpect_eq(StatusCode::Ok);
response
.get_header("access-control-allow-origin")
.unwrap()
.xpect_eq("https://anything.com");
});
}
// An OPTIONS request from an allowed origin gets the preflight response:
// status Ok, the echoed origin, and `access-control-max-age: 60`.
#[beet_core::test]
async fn cors_preflight_works() {
let config = CorsConfig::new(false, vec!["https://allowed.com"]);
RouterPlugin::world()
.spawn(ExchangeSpawner::new_flow(move || {
(Fallback, children![
cors_preflight(config.clone()),
EndpointBuilder::any_method().with_handler(|| "Hello"),
])
}))
.oneshot(
Request::options("/")
.with_header("origin", "https://allowed.com"),
)
.await
.xtap(|response| {
response.status().xpect_eq(StatusCode::Ok);
response
.get_header("access-control-allow-origin")
.unwrap()
.xpect_eq("https://allowed.com");
response
.get_header("access-control-max-age")
.unwrap()
.xpect_eq("60");
});
}
// A GET request passes through `cors_preflight` untouched, and the rest of
// the chain still yields the echoed origin header.
#[beet_core::test]
async fn cors_preflight_non_options_passthrough() {
let config = CorsConfig::new(true, vec![]);
RouterPlugin::world()
.spawn(ExchangeSpawner::new_flow(move || {
(InfallibleSequence, children![
cors_preflight(config.clone()),
cors_request(config.clone()),
EndpointBuilder::get().with_handler(|| "Hello"),
cors_response(config),
])
}))
.oneshot(
Request::get("/").with_header("origin", "https://example.com"),
)
.await
.xtap(|response| {
response.status().xpect_eq(StatusCode::Ok);
response
.get_header("access-control-allow-origin")
.unwrap()
.xpect_eq("https://example.com");
});
}
// CORS and no-cache middleware chained after one endpoint: both the CORS
// header and the `cache-control` header must be present.
#[beet_core::test]
async fn multiple_middleware_chain() {
let config = CorsConfig::new(true, vec![]);
RouterPlugin::world()
.spawn(ExchangeSpawner::new_flow(move || {
(InfallibleSequence, children![
cors_request(config.clone()),
EndpointBuilder::get().with_handler(|| "Hello"),
cors_response(config),
no_cache_headers(),
])
}))
.oneshot(
Request::get("/").with_header("origin", "https://example.com"),
)
.await
.xtap(|response| {
response.status().xpect_eq(StatusCode::Ok);
response
.get_header("access-control-allow-origin")
.unwrap()
.xpect_eq("https://example.com");
response
.get_header("cache-control")
.unwrap()
.xpect_eq("no-cache, no-store, must-revalidate");
});
}
}