use crate::{
App, HttpResponse, HttpResult,
headers::{ACCESS_CONTROL_REQUEST_METHOD, ORIGIN},
http::{HttpBody, Method, StatusCode, cors::CorsHeaders},
middleware::{HttpContext, Middleware, NextFn},
};
use hyper::Response;
use std::sync::Arc;
/// Middleware that answers CORS preflight requests and decorates ordinary
/// responses with CORS headers.
///
/// The configuration used for a request comes from `HttpContext::resolve_cors`,
/// which is handed `default_cors` (presumably as a fallback when no
/// route-specific configuration exists — confirm in `resolve_cors`). When no
/// configuration resolves, the request is forwarded to the next handler
/// untouched.
struct Cors {
    /// CORS configuration captured when the middleware was attached; `None`
    /// when the application had no default configuration at that time.
    default_cors: Option<Arc<CorsHeaders>>,
}
impl Middleware for Cors {
    /// Processes one request.
    ///
    /// * No CORS configuration resolved: the request goes to `next` unchanged.
    /// * Preflight request — method `OPTIONS` with an `Origin` header and an
    ///   `Access-Control-Request-Method` header that parses as an HTTP method:
    ///   answered directly with `204 No Content` plus the preflight headers;
    ///   `next` is never called.
    /// * Anything else: forwarded to `next`, and the normal CORS headers are
    ///   added to its response. An `Err` from `next` is returned before any
    ///   CORS headers are applied.
    fn call(
        &self,
        ctx: HttpContext,
        next: NextFn,
    ) -> impl Future<Output = HttpResult> + Send + 'static {
        // Clone the Arc so the returned future owns its data and satisfies `'static`.
        let default_cors = self.default_cors.clone();
        async move {
            let Some(cors) = ctx.resolve_cors(default_cors.as_ref()) else {
                return next(ctx).await;
            };

            let request = ctx.request();
            // Owned copy, looked up once: the borrow of `ctx` must end before
            // `ctx` is moved into `next`, and both paths below need the value.
            let origin = request.headers().get(&ORIGIN).cloned();
            let is_preflight = request.method() == Method::OPTIONS
                && origin.is_some()
                && request
                    .headers()
                    .get(ACCESS_CONTROL_REQUEST_METHOD)
                    .is_some_and(|v| Method::from_bytes(v.as_bytes()).is_ok());

            if is_preflight {
                let mut response = Response::new(HttpBody::empty());
                *response.status_mut() = StatusCode::NO_CONTENT;
                cors.apply_preflight_response(response.headers_mut(), origin);
                return Ok(HttpResponse::from_inner(response));
            }

            let mut response = next(ctx).await?;
            cors.apply_normal_response(response.headers_mut(), origin);
            Ok(response)
        }
    }
}
impl App {
    /// Turns on CORS handling by marking CORS as enabled and attaching the
    /// `Cors` middleware.
    ///
    /// The middleware receives a clone of the application's default CORS
    /// configuration (`self.cors.get_default()`) as it exists when this method
    /// is called.
    ///
    /// NOTE(review): nothing here guards against repeated calls; each call
    /// attaches another middleware instance — confirm whether that is intended.
    ///
    /// # Panics
    ///
    /// Panics when no CORS configuration has been registered, e.g. with
    /// `App::new().with_cors(...)`.
    pub fn use_cors(&mut self) -> &mut Self {
        assert!(
            self.cors.registered(),
            "CORS error: Missing CORS configuration, you can configure it with `App::new().with_cors(|cors| cors...)`"
        );
        self.cors.is_enabled = true;

        let captured_default = self.cors.get_default().cloned();
        self.attach(Cors {
            default_cors: captured_default,
        });
        self
    }
}
#[cfg(test)]
mod tests {
    use crate::App;

    /// `use_cors` must refuse to run when no CORS configuration is registered.
    /// `expected` pins the panic to that reason, so an unrelated panic (e.g.
    /// inside `App::new()`) cannot make this test pass.
    #[test]
    #[should_panic(expected = "Missing CORS configuration")]
    fn it_panics_due_missing_cors_config() {
        let mut app = App::new();
        app.use_cors();
    }

    /// With a registered configuration, `use_cors` succeeds and enables CORS.
    #[test]
    fn it_validates_cors_config() {
        let mut app = App::new().with_cors(|cors| cors.without_credentials());
        app.use_cors();
        assert!(app.cors.is_enabled);
    }
}