1use base64::{Engine, engine::general_purpose};
2use bytes::Bytes;
3use cookie::{Cookie, SameSite};
4use http_body_util::Full;
5use hyper::{
6 HeaderMap, Response,
7 header::{HeaderName, HeaderValue},
8};
9use mime_guess::from_path;
10use serde::Serialize;
11use std::path::Path;
12use tokio::fs;
13use tokio::io::AsyncReadExt;
14
15use crate::http::StatusCode;
16
17pub struct ResponseWriter {
18 pub body: String,
19 pub headers: HeaderMap,
20 pub status: StatusCode,
21 pub has_error: bool,
22}
23
24#[allow(dead_code)]
25impl ResponseWriter {
26 pub fn new() -> Self {
27 Self {
28 body: "".into(),
29 headers: HeaderMap::new(),
30 status: StatusCode::OK,
31 has_error: false,
32 }
33 }
34
35 pub fn status(&mut self, status: StatusCode) -> &mut Self {
36 self.status = status;
37 self
38 }
39
40 pub fn set_header(&mut self, key: &str, value: &str) -> &mut Self {
41 self.headers.insert(
42 HeaderName::from_bytes(key.as_bytes()).unwrap(),
43 HeaderValue::from_str(value).unwrap(),
44 );
45 self
46 }
47
48 pub fn get_header(&self, key: &str) -> Option<&HeaderValue> {
49 self.headers.get(key)
50 }
51
52 pub fn send(&mut self, body: &str) -> &mut Self {
53 self.body = body.into();
54 self
55 }
56
57 pub fn json<T: Serialize>(&mut self, data: &T) -> &mut Self {
58 match serde_json::to_string(data) {
59 Ok(body) => {
60 self.set_header("Content-Type", "application/json");
61 self.body = body;
62 }
63 Err(_) => {
64 self.set_header("Content-Type", "application/json");
65 self.body = r#"{"error":"Failed to serialize JSON"}"#.to_string();
66 self.status = StatusCode::InternalServerError;
67 }
68 }
69 self
70 }
71
72 pub fn html(&mut self, html: &str) -> &mut Self {
73 self.set_header("Content-Type", "text/html; charset=utf-8");
74 self.body = html.to_string();
75 self
76 }
77
78 pub async fn file<P: AsRef<Path>>(&mut self, path: P) {
79 let path_ref = path.as_ref();
80
81 match fs::File::open(path_ref).await {
82 Ok(mut file) => {
83 let mut buf = Vec::new();
84 if let Err(e) = file.read_to_end(&mut buf).await {
85 self.error(
86 StatusCode::InternalServerError,
87 &format!("Failed to read file: {}", e),
88 );
89 return;
90 }
91
92 let mime_type = from_path(path_ref).first_or_octet_stream().to_string();
93
94 self.status(StatusCode::OK)
95 .set_header("Content-Type", &mime_type)
96 .bytes(&buf);
97 }
98 Err(_) => {
99 self.error(StatusCode::NotFound, "File not found");
100 }
101 }
102 }
103
104 pub fn bytes(&mut self, bytes: &[u8]) -> &mut Self {
105 let encoded = general_purpose::STANDARD.encode(bytes);
106 self.body = encoded;
107 self.set_header("Content-Type", "application/octet-stream");
108 self
109 }
110
111 pub fn get_code(&self, code: StatusCode) -> u16 {
112 match code {
113 StatusCode::Continue => return 100,
114 StatusCode::SwitchingProtocols => return 101,
115 StatusCode::Processing => return 102,
116 StatusCode::EarlyHints => return 103,
117 StatusCode::OK => return 200,
118 StatusCode::Created => return 201,
119 StatusCode::Accepted => return 202,
120 StatusCode::NonAuthoritativeInformation => return 203,
121 StatusCode::NoContent => return 204,
122 StatusCode::ResetContent => return 205,
123 StatusCode::PartialContent => return 206,
124 StatusCode::MovedPermanently => return 301,
125 StatusCode::Found => return 302,
126 StatusCode::SeeOther => return 303,
127 StatusCode::NotModified => return 304,
128 StatusCode::TemporaryRedirect => return 307,
129 StatusCode::PermanentRedirect => return 308,
130 StatusCode::BadRequest => return 400,
131 StatusCode::Unauthorized => return 401,
132 StatusCode::PaymentRequired => return 402,
133 StatusCode::Forbidden => return 403,
134 StatusCode::NotFound => return 404,
135 StatusCode::MethodNotAllowed => return 405,
136 StatusCode::NotAcceptable => return 406,
137 StatusCode::ProxyAuthenticationRequired => return 407,
138 StatusCode::RequestTimeout => return 408,
139 StatusCode::Conflict => return 409,
140 StatusCode::Gone => return 410,
141 StatusCode::LengthRequired => return 411,
142 StatusCode::PreconditionFailed => return 412,
143 StatusCode::ContentTooLarge => return 413,
144 StatusCode::URITooLong => return 414,
145 StatusCode::UnsupportedMediaType => return 415,
146 StatusCode::TooManyRequests => return 429,
147 StatusCode::InternalServerError => return 500,
148 StatusCode::NotImplemented => return 501,
149 StatusCode::BadGateway => return 502,
150 StatusCode::ServiceUnavailable => return 503,
151 StatusCode::GatewayTimeout => return 504,
152 StatusCode::HTTPVersionNotSupported => return 505,
153 }
154 }
155
156 pub fn error(&mut self, status: StatusCode, msg: &str) -> &mut Self {
157 self.status = status;
158 self.body = msg.to_string();
159 self.has_error = true;
160 self
161 }
162
163 pub fn has_error(&self) -> bool {
164 self.has_error
165 }
166
167 pub fn cookie(
168 &mut self,
169 name: &str,
170 value: &str,
171 max_age: Option<i64>,
172 path: Option<&str>,
173 domain: Option<&str>,
174 secure: bool,
175 http_only: bool,
176 same_site: Option<&str>,
177 ) -> &mut Self {
178 let mut cookie_builder = Cookie::build((name, value))
179 .path(path.unwrap_or("/"))
180 .secure(secure)
181 .http_only(http_only);
182
183 if let Some(d) = domain {
184 cookie_builder = cookie_builder.domain(d);
185 }
186
187 if let Some(age) = max_age {
188 cookie_builder = cookie_builder.max_age(time::Duration::seconds(age));
189 }
190
191 if let Some(ss) = same_site {
192 match ss.to_lowercase().as_str() {
193 "lax" => cookie_builder = cookie_builder.same_site(SameSite::Lax),
194 "strict" => cookie_builder = cookie_builder.same_site(SameSite::Strict),
195 "none" => cookie_builder = cookie_builder.same_site(SameSite::None).secure(true),
196 _ => {}
197 }
198 }
199
200 self.headers.append(
201 hyper::header::SET_COOKIE,
202 hyper::header::HeaderValue::from_str(&cookie_builder.to_string()).unwrap(),
203 );
204
205 self
206 }
207
208 pub fn into_response(&self) -> Response<Full<Bytes>> {
209 let status = &self.status;
210
211 let status_code = self.get_code(status.clone());
212 let body = &self.body;
213 let mut builder = Response::builder().status(status_code);
214
215 for (key, value) in self.headers.iter() {
216 builder = builder.header(key, value);
217 }
218
219 builder
220 .body(Full::new(Bytes::from(body.to_owned())))
221 .unwrap()
222 }
223
224 pub fn strip_header(&mut self, key: &str) {
225 if let Ok(key_name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
226 self.headers.remove(key_name);
227 }
228 }
229}