bolt_web/middleware/
cors.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use crate::{request::RequestBody, response::ResponseWriter, types::Middleware};
6
/// Settings consumed by the [`Cors`] middleware.
///
/// The `allow_methods` / `allow_headers` strings are written as-is into the
/// `Access-Control-Allow-Methods` / `Access-Control-Allow-Headers` response headers.
// NOTE(review): the `dead_code` allowance is kept from the original; presumably some
// fields are unread in certain builds of this crate — confirm and drop if obsolete.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Origins that are echoed back in `Access-Control-Allow-Origin` when
    /// `allow_all` is `false`. Matching is an exact, case-sensitive string compare.
    pub allowed_origins: Vec<String>,
    /// When `true`, `Access-Control-Allow-Origin: *` is sent and
    /// `allowed_origins` is not consulted.
    pub allow_all: bool,
    /// Value of the `Access-Control-Allow-Methods` header, e.g. `"GET, POST"`.
    pub allow_methods: String,
    /// Value of the `Access-Control-Allow-Headers` header, e.g. `"Content-Type"`.
    pub allow_headers: String,
    /// When `true`, `Access-Control-Allow-Credentials: true` is sent.
    pub allow_credentials: bool,
    /// When `Some`, sent as `Access-Control-Max-Age` (seconds a preflight result
    /// may be cached); when `None`, the header is omitted.
    pub max_age: Option<u32>,
}
16
17impl Default for CorsConfig {
18    fn default() -> Self {
19        Self {
20            allowed_origins: vec!["*".into()],
21            allow_all: true,
22            allow_methods: "GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD".into(),
23            allow_headers: "Content-Type, Authorization".into(),
24            allow_credentials: false,
25            max_age: Some(86400),
26        }
27    }
28}
29
30#[allow(dead_code)]
31pub struct Cors {
32    pub config: Arc<CorsConfig>,
33}
34
35#[async_trait]
36impl Middleware for Cors {
37    async fn run(&self, req: &mut RequestBody, res: &mut ResponseWriter) {
38        let cfg = &self.config;
39
40        res.set_header("Access-Control-Allow-Methods", &cfg.allow_methods)
41            .set_header("Access-Control-Allow-Headers", &cfg.allow_headers);
42
43        if cfg.allow_all {
44            res.set_header("Access-Control-Allow-Origin", "*");
45        } else if let Some(origin) = req.get_headers("Origin") {
46            let origin_str = origin.to_str().unwrap_or("");
47            if cfg.allowed_origins.contains(&origin_str.to_string()) {
48                res.set_header("Access-Control-Allow-Origin", origin_str);
49            }
50        }
51
52        if cfg.allow_credentials {
53            res.set_header("Access-Control-Allow-Credentials", "true");
54        }
55
56        if let Some(max) = cfg.max_age {
57            res.set_header("Access-Control-Max-Age", &max.to_string());
58        }
59
60        if *req.method() == hyper::Method::OPTIONS {
61            res.status(204);
62        }
63    }
64}