bolt_web/
response.rs

1use base64::{Engine, engine::general_purpose};
2use bytes::Bytes;
3use cookie::{Cookie, SameSite};
4use http_body_util::Full;
5use hyper::{
6    HeaderMap, Response,
7    header::{HeaderName, HeaderValue},
8};
9use mime_guess::from_path;
10use serde::Serialize;
11use std::path::Path;
12use tokio::fs;
13use tokio::io::AsyncReadExt;
14
15use crate::http::StatusCode;
16
17pub struct ResponseWriter {
18    pub body: String,
19    pub headers: HeaderMap,
20    pub status: StatusCode,
21    pub has_error: bool,
22}
23
24#[allow(dead_code)]
25impl ResponseWriter {
26    pub fn new() -> Self {
27        Self {
28            body: "".into(),
29            headers: HeaderMap::new(),
30            status: StatusCode::OK,
31            has_error: false,
32        }
33    }
34
35    pub fn status(&mut self, status: StatusCode) -> &mut Self {
36        self.status = status;
37        self
38    }
39
40    pub fn set_header(&mut self, key: &str, value: &str) -> &mut Self {
41        self.headers.insert(
42            HeaderName::from_bytes(key.as_bytes()).unwrap(),
43            HeaderValue::from_str(value).unwrap(),
44        );
45        self
46    }
47
48    pub fn get_header(&self, key: &str) -> Option<&HeaderValue> {
49        self.headers.get(key)
50    }
51
52    pub fn send(&mut self, body: &str) -> &mut Self {
53        self.body = body.into();
54        self
55    }
56
57    pub fn json<T: Serialize>(&mut self, data: &T) -> &mut Self {
58        match serde_json::to_string(data) {
59            Ok(body) => {
60                self.set_header("Content-Type", "application/json");
61                self.body = body;
62            }
63            Err(_) => {
64                self.set_header("Content-Type", "application/json");
65                self.body = r#"{"error":"Failed to serialize JSON"}"#.to_string();
66                self.status = StatusCode::InternalServerError;
67            }
68        }
69        self
70    }
71
72    pub fn html(&mut self, html: &str) -> &mut Self {
73        self.set_header("Content-Type", "text/html; charset=utf-8");
74        self.body = html.to_string();
75        self
76    }
77
78    pub async fn file<P: AsRef<Path>>(&mut self, path: P) {
79        let path_ref = path.as_ref();
80
81        match fs::File::open(path_ref).await {
82            Ok(mut file) => {
83                let mut buf = Vec::new();
84                if let Err(e) = file.read_to_end(&mut buf).await {
85                    self.error(
86                        StatusCode::InternalServerError,
87                        &format!("Failed to read file: {}", e),
88                    );
89                    return;
90                }
91
92                let mime_type = from_path(path_ref).first_or_octet_stream().to_string();
93
94                self.status(StatusCode::OK)
95                    .set_header("Content-Type", &mime_type)
96                    .bytes(&buf);
97            }
98            Err(_) => {
99                self.error(StatusCode::NotFound, "File not found");
100            }
101        }
102    }
103
104    pub fn bytes(&mut self, bytes: &[u8]) -> &mut Self {
105        let encoded = general_purpose::STANDARD.encode(bytes);
106        self.body = encoded;
107        self.set_header("Content-Type", "application/octet-stream");
108        self
109    }
110
111    pub fn get_code(&self, code: StatusCode) -> u16 {
112        match code {
113            StatusCode::Continue => return 100,
114            StatusCode::SwitchingProtocols => return 101,
115            StatusCode::Processing => return 102,
116            StatusCode::EarlyHints => return 103,
117            StatusCode::OK => return 200,
118            StatusCode::Created => return 201,
119            StatusCode::Accepted => return 202,
120            StatusCode::NonAuthoritativeInformation => return 203,
121            StatusCode::NoContent => return 204,
122            StatusCode::ResetContent => return 205,
123            StatusCode::PartialContent => return 206,
124            StatusCode::MovedPermanently => return 301,
125            StatusCode::Found => return 302,
126            StatusCode::SeeOther => return 303,
127            StatusCode::NotModified => return 304,
128            StatusCode::TemporaryRedirect => return 307,
129            StatusCode::PermanentRedirect => return 308,
130            StatusCode::BadRequest => return 400,
131            StatusCode::Unauthorized => return 401,
132            StatusCode::PaymentRequired => return 402,
133            StatusCode::Forbidden => return 403,
134            StatusCode::NotFound => return 404,
135            StatusCode::MethodNotAllowed => return 405,
136            StatusCode::NotAcceptable => return 406,
137            StatusCode::ProxyAuthenticationRequired => return 407,
138            StatusCode::RequestTimeout => return 408,
139            StatusCode::Conflict => return 409,
140            StatusCode::Gone => return 410,
141            StatusCode::LengthRequired => return 411,
142            StatusCode::PreconditionFailed => return 412,
143            StatusCode::ContentTooLarge => return 413,
144            StatusCode::URITooLong => return 414,
145            StatusCode::UnsupportedMediaType => return 415,
146            StatusCode::TooManyRequests => return 429,
147            StatusCode::InternalServerError => return 500,
148            StatusCode::NotImplemented => return 501,
149            StatusCode::BadGateway => return 502,
150            StatusCode::ServiceUnavailable => return 503,
151            StatusCode::GatewayTimeout => return 504,
152            StatusCode::HTTPVersionNotSupported => return 505,
153        }
154    }
155
156    pub fn error(&mut self, status: StatusCode, msg: &str) -> &mut Self {
157        self.status = status;
158        self.body = msg.to_string();
159        self.has_error = true;
160        self
161    }
162
163    pub fn has_error(&self) -> bool {
164        self.has_error
165    }
166
167    pub fn cookie(
168        &mut self,
169        name: &str,
170        value: &str,
171        max_age: Option<i64>,
172        path: Option<&str>,
173        domain: Option<&str>,
174        secure: bool,
175        http_only: bool,
176        same_site: Option<&str>,
177    ) -> &mut Self {
178        let mut cookie_builder = Cookie::build((name, value))
179            .path(path.unwrap_or("/"))
180            .secure(secure)
181            .http_only(http_only);
182
183        if let Some(d) = domain {
184            cookie_builder = cookie_builder.domain(d);
185        }
186
187        if let Some(age) = max_age {
188            cookie_builder = cookie_builder.max_age(time::Duration::seconds(age));
189        }
190
191        if let Some(ss) = same_site {
192            match ss.to_lowercase().as_str() {
193                "lax" => cookie_builder = cookie_builder.same_site(SameSite::Lax),
194                "strict" => cookie_builder = cookie_builder.same_site(SameSite::Strict),
195                "none" => cookie_builder = cookie_builder.same_site(SameSite::None).secure(true),
196                _ => {}
197            }
198        }
199
200        self.headers.append(
201            hyper::header::SET_COOKIE,
202            hyper::header::HeaderValue::from_str(&cookie_builder.to_string()).unwrap(),
203        );
204
205        self
206    }
207
208    pub fn into_response(&self) -> Response<Full<Bytes>> {
209        let status = &self.status;
210
211        let status_code = self.get_code(status.clone());
212        let body = &self.body;
213        let mut builder = Response::builder().status(status_code);
214
215        for (key, value) in self.headers.iter() {
216            builder = builder.header(key, value);
217        }
218
219        builder
220            .body(Full::new(Bytes::from(body.to_owned())))
221            .unwrap()
222    }
223
224    pub fn strip_header(&mut self, key: &str) {
225        if let Ok(key_name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
226            self.headers.remove(key_name);
227        }
228    }
229}