capybara_core/protocol/http/
response.rs1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3
4use arc_swap::ArcSwap;
5use bitflags::bitflags;
6use bytes::{BufMut, Bytes, BytesMut};
7use once_cell::sync::Lazy;
8use smallvec::SmallVec;
9use tokio::io::AsyncWriteExt;
10
11use crate::Result;
12
13use super::frame::{HeadersBuilder, StatusLine};
14use super::httpfield::HttpField;
15use super::misc;
16
17type DateBytes = SmallVec<[u8; 32]>;
18
19static DATE: Lazy<ArcSwap<DateBytes>> = Lazy::new(|| ArcSwap::new(Arc::new(DateBytes::new())));
20static DATE_TICKER_STARTED: Lazy<AtomicBool> = Lazy::new(AtomicBool::default);
21
22async fn start_date_ticker() {
23 use tokio::time::{sleep, Duration};
24 info!("start interval of http date generator ok");
26 loop {
27 let sleep_nanos = 1_000_000_000i64 - (chrono::Utc::now().timestamp_subsec_nanos() as i64);
29 if sleep_nanos > 0 {
30 sleep(Duration::from_nanos(sleep_nanos as u64)).await;
31 DATE.store(Arc::new(generate_date_bytes()));
33 } else {
34 sleep(Duration::from_millis(1)).await;
36 }
37 }
38}
39
40#[inline(always)]
41fn generate_date_bytes() -> DateBytes {
42 let mut b = DateBytes::new();
44 {
45 use std::io::Write as _;
46 write!(
47 &mut b,
48 "{}",
49 chrono::Utc::now().format("%a, %d %b %Y %T GMT")
50 )
51 .ok();
52 }
53 b
54}
55
56#[inline]
57async fn write_header_date<W>(w: &mut W, lowercase: bool) -> Result<()>
58where
59 W: AsyncWriteExt + Unpin,
60{
61 if lowercase {
62 w.write_all(HttpField::Date.as_str().to_ascii_lowercase().as_bytes())
63 .await?;
64 } else {
65 w.write_all(HttpField::Date.as_bytes()).await?;
66 }
67
68 w.write_all(b": ").await?;
69
70 loop {
71 {
72 let loaded = DATE.load();
73 if !loaded.is_empty() {
74 w.write_all(&loaded[..]).await?;
75 break;
76 }
77 }
78
79 if let Ok(origin) =
80 DATE_TICKER_STARTED.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
81 {
82 if !origin {
83 DATE.store(Arc::new(generate_date_bytes()));
85
86 tokio::spawn(async {
87 start_date_ticker().await;
88 });
89 }
90 }
91 }
92
93 w.write_all(misc::CRLF).await?;
94 Ok(())
95}
96
97#[derive(Debug, Copy, Clone, Default, Hash, PartialEq, Eq)]
98pub struct ResponseFlags(u8);
99
100bitflags! {
101 impl ResponseFlags: u8 {
102 const LOWERCASE_DATE_HEADER = 1 << 0;
103 }
104}
105
/// A pre-serialized HTTP/1.x response: status line, header block and optional body.
///
/// The `Date` header is not stored here. [`Response::write_to`] inserts it at write time,
/// between the status line and `headers`.
#[derive(Debug, Clone)]
pub struct Response {
    /// Raw status line bytes, including the trailing CRLF when built by `ResponseBuilder`.
    pub status_line: StatusLine,
    /// Serialized header block, written verbatim after the `Date` header.
    pub headers: Bytes,
    /// Body bytes. `ResponseBuilder::build` stores `None` when the body is empty.
    pub body: Option<Bytes>,
}
112
113impl Response {
114 pub fn builder() -> ResponseBuilder {
115 Default::default()
116 }
117
118 pub async fn write_to<W>(&self, w: &mut W, flags: ResponseFlags) -> Result<()>
119 where
120 W: AsyncWriteExt + Unpin,
121 {
122 w.write_all(&self.status_line.0[..]).await?;
123
124 write_header_date(w, flags.contains(ResponseFlags::LOWERCASE_DATE_HEADER)).await?;
125
126 w.write_all(&self.headers[..]).await?;
127 if let Some(body) = &self.body {
128 w.write_all(&body[..]).await?;
129 }
130 Ok(())
131 }
132}
133
134#[derive(Debug, Copy, Clone, Default, Hash, PartialEq, Eq)]
135struct ResponseBuilderFlags(u16);
136
137bitflags! {
138 impl ResponseBuilderFlags: u16 {
139 const FLAG_DNT_SERVER = 1 << 0;
140 const FLAG_NO_HTTP11 = 1 << 1;
141 const FLAG_NO_KEEPALIVE = 1 << 2;
142 const FLAG_LOWERCASE_HEADER = 1 << 3;
143 }
144}
145
/// Fluent builder for [`Response`]. Obtain one via [`Response::builder`] or `Default`.
pub struct ResponseBuilder {
    /// Header block under construction.
    headers: HeadersBuilder,
    /// Body bytes. Each `body` call appends to this buffer.
    body_buf: BytesMut,
    /// Status code for the status line (`Default` sets 200).
    code: u16,
    /// Options toggled by the builder methods and applied in `build`.
    flag: ResponseBuilderFlags,
}
152
153impl Default for ResponseBuilder {
154 fn default() -> Self {
155 ResponseBuilder {
156 code: 200,
157 headers: Default::default(),
158 body_buf: Default::default(),
159 flag: ResponseBuilderFlags::default(),
160 }
161 }
162}
163
164impl ResponseBuilder {
165 pub fn status_code(mut self, status_code: u16) -> Self {
166 self.code = status_code;
167 self
168 }
169
170 pub fn use_lowercase_header(mut self) -> Self {
171 self.flag |= ResponseBuilderFlags::FLAG_LOWERCASE_HEADER;
172 self
173 }
174
175 pub fn content_type<T>(mut self, typ: T) -> Self
176 where
177 T: AsRef<str>,
178 {
179 if self
180 .flag
181 .contains(ResponseBuilderFlags::FLAG_LOWERCASE_HEADER)
182 {
183 self.headers = self.headers.put(
184 HttpField::ContentType.as_str().to_ascii_lowercase(),
185 typ.as_ref(),
186 );
187 } else {
188 self.headers = self
189 .headers
190 .put(HttpField::ContentType.as_str(), typ.as_ref());
191 }
192
193 self
194 }
195
196 pub fn header<K, V>(mut self, key: K, value: V) -> Self
197 where
198 K: AsRef<str>,
199 V: AsRef<str>,
200 {
201 if self
202 .flag
203 .contains(ResponseBuilderFlags::FLAG_LOWERCASE_HEADER)
204 {
205 self.headers = self.headers.put(key.as_ref().to_ascii_lowercase(), value);
206 } else {
207 self.headers = self.headers.put(key, value);
208 }
209
210 self
211 }
212
213 pub fn body<B>(mut self, value: B) -> Self
214 where
215 B: AsRef<[u8]>,
216 {
217 let value = value.as_ref();
218 self.body_buf.put(value);
219 self
220 }
221
222 pub fn disable_server(mut self) -> Self {
223 self.flag |= ResponseBuilderFlags::FLAG_DNT_SERVER;
224 self
225 }
226
227 pub fn build(mut self) -> Response {
228 let lowercase_headers = self
229 .flag
230 .contains(ResponseBuilderFlags::FLAG_LOWERCASE_HEADER);
231
232 let size = self.body_buf.len();
233
234 if lowercase_headers {
235 self.headers = self.headers.put(
236 HttpField::ContentLength.as_str().to_ascii_lowercase(),
237 size.to_string(),
238 );
239 } else {
240 self.headers = self
241 .headers
242 .put(HttpField::ContentLength.as_str(), size.to_string());
243 }
244
245 if !self.flag.contains(ResponseBuilderFlags::FLAG_DNT_SERVER) {
247 if lowercase_headers {
248 self.headers = self.headers.put(
249 HttpField::Server.as_str().to_ascii_lowercase(),
250 misc::SERVER.as_str(),
251 );
252 } else {
253 self.headers = self
254 .headers
255 .put(HttpField::Server.as_str(), misc::SERVER.as_str());
256 }
257 }
258
259 if self.flag.contains(ResponseBuilderFlags::FLAG_NO_KEEPALIVE) {
261 if lowercase_headers {
262 self.headers = self
263 .headers
264 .put(HttpField::Connection.as_str().to_ascii_lowercase(), "close");
265 } else {
266 self.headers = self.headers.put(HttpField::Connection.as_str(), "close");
267 }
268 }
269
270 self.headers = self.headers.complete();
271
272 let Self {
273 code,
274 headers,
275 body_buf: body_,
276 ..
277 } = self;
278
279 let mut status_line = BytesMut::with_capacity(32);
280
281 if self.flag.contains(ResponseBuilderFlags::FLAG_NO_HTTP11) {
282 status_line.put(&b"HTTP/1.0 "[..]);
283 } else {
284 status_line.put(&b"HTTP/1.1 "[..]);
285 }
286
287 match code {
289 100 => status_line.put(&b"100 Continue"[..]),
290 101 => status_line.put(&b"101 Switching Protocols"[..]),
291 102 => status_line.put(&b"102 Processing"[..]),
292 103 => status_line.put(&b"103 Early Hints"[..]),
293 200 => status_line.put(&b"200 OK"[..]),
294 201 => status_line.put(&b"201 Created"[..]),
295 202 => status_line.put(&b"202 Accepted"[..]),
296 203 => status_line.put(&b"203 Non-Authoritative Information"[..]),
297 204 => status_line.put(&b"204 No Content"[..]),
298 205 => status_line.put(&b"205 Reset Content"[..]),
299 206 => status_line.put(&b"206 Partial Content"[..]),
300 207 => status_line.put(&b"207 Multi-Status"[..]),
301 208 => status_line.put(&b"208 Already Reported"[..]),
302 226 => status_line.put(&b"226 IM Used"[..]),
303 300 => status_line.put(&b"300 Multiple Choices"[..]),
304 301 => status_line.put(&b"301 Moved Permanently"[..]),
305 302 => status_line.put(&b"302 Found"[..]),
306 303 => status_line.put(&b"303 See Other"[..]),
307 304 => status_line.put(&b"304 Not Modified"[..]),
308 307 => status_line.put(&b"307 Temporary Redirect"[..]),
309 308 => status_line.put(&b"308 Permanent Redirect"[..]),
310 400 => status_line.put(&b"400 Bad Request"[..]),
311 401 => status_line.put(&b"401 Unauthorized"[..]),
312 402 => status_line.put(&b"402 Payment Required"[..]),
313 403 => status_line.put(&b"403 Forbidden"[..]),
314 404 => status_line.put(&b"404 Not Found"[..]),
315 405 => status_line.put(&b"405 Method Not Allowed"[..]),
316 406 => status_line.put(&b"406 Not Acceptable"[..]),
317 407 => status_line.put(&b"407 Proxy Authentication Required"[..]),
318 408 => status_line.put(&b"408 Request Timeout"[..]),
319 409 => status_line.put(&b"409 Conflict"[..]),
320 410 => status_line.put(&b"410 Gone"[..]),
321 411 => status_line.put(&b"411 Length Required"[..]),
322 412 => status_line.put(&b"412 Precondition Failed"[..]),
323 413 => status_line.put(&b"413 Content Too Large"[..]),
324 414 => status_line.put(&b"414 URI Too Long"[..]),
325 415 => status_line.put(&b"415 Unsupported Media Type"[..]),
326 416 => status_line.put(&b"416 Range Not Satisfiable"[..]),
327 417 => status_line.put(&b"417 Expectation Failed"[..]),
328 418 => status_line.put(&b"418 I'm a teapot"[..]),
329 421 => status_line.put(&b"421 Misdirected Request"[..]),
330 422 => status_line.put(&b"422 Unprocessable Content"[..]),
331 423 => status_line.put(&b"423 Locked"[..]),
332 424 => status_line.put(&b"424 Failed Dependency"[..]),
333 425 => status_line.put(&b"425 Too Early"[..]),
334 426 => status_line.put(&b"426 Upgrade Required"[..]),
335 428 => status_line.put(&b"428 Precondition Required"[..]),
336 429 => status_line.put(&b"429 Too Many Requests"[..]),
337 431 => status_line.put(&b"431 Request Header Fields Too Large"[..]),
338 451 => status_line.put(&b"451 Unavailable For Legal Reasons"[..]),
339 500 => status_line.put(&b"500 Internal Server Error"[..]),
340 501 => status_line.put(&b"501 Not Implemented"[..]),
341 502 => status_line.put(&b"502 Bad Gateway"[..]),
342 503 => status_line.put(&b"503 Service Unavailable"[..]),
343 504 => status_line.put(&b"504 Gateway Timeout"[..]),
344 505 => status_line.put(&b"505 HTTP Version Not Supported"[..]),
345 506 => status_line.put(&b"506 Variant Also Negotiates"[..]),
346 507 => status_line.put(&b"507 Insufficient Storage"[..]),
347 508 => status_line.put(&b"508 Loop Detected"[..]),
348 510 => status_line.put(&b"510 Not Extended"[..]),
349 511 => status_line.put(&b"511 Network Authentication Required"[..]),
350
351 other => {
352 use std::fmt::Write as _;
353 write!(&mut status_line, "{} UNKNOWN", other).unwrap();
354 }
355 };
356
357 status_line.put(misc::CRLF);
358
359 let headers = headers.build();
360 let body = body_.freeze();
361
362 Response {
363 status_line: StatusLine(status_line.freeze()),
364 headers: headers.into(),
365 body: if body.is_empty() { None } else { Some(body) },
366 }
367 }
368}
369
#[cfg(test)]
mod response_tests {
    use super::*;

    fn init() {
        pretty_env_logger::try_init_timed().ok();
    }

    /// Builds a small `200` response and checks the parts `ResponseBuilder::build` controls
    /// directly: the status line and the body. The exact header layout belongs to
    /// `HeadersBuilder`, so only the presence of the custom header value is checked.
    #[test]
    fn response_builder() {
        init();

        let resp = Response::builder()
            .header("X-Ray-Id", "foobar")
            .header("Content-Type", "text/plain")
            .body(b"hello world")
            .build();

        assert_eq!(&resp.status_line.0[..], &b"HTTP/1.1 200 OK\r\n"[..]);
        assert_eq!(resp.body.as_deref(), Some(&b"hello world"[..]));
        assert!(resp
            .headers
            .windows(b"foobar".len())
            .any(|window| window == b"foobar"));

        let mut b = BytesMut::new();

        b.put_slice(&resp.status_line.0);
        b.put_slice(&resp.headers);
        if let Some(body) = &resp.body {
            b.put_slice(body);
        }

        let b = b.freeze();

        info!("response: {:?}", b);
    }

    /// Known codes get their registered reason phrase; unknown ones fall back to `UNKNOWN`.
    #[test]
    fn status_lines() {
        init();

        let not_found = Response::builder().status_code(404).build();
        assert_eq!(
            &not_found.status_line.0[..],
            &b"HTTP/1.1 404 Not Found\r\n"[..]
        );

        let unknown = Response::builder().status_code(599).build();
        assert_eq!(
            &unknown.status_line.0[..],
            &b"HTTP/1.1 599 UNKNOWN\r\n"[..]
        );
    }
}