capybara_core/protocol/http/
response.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3
4use arc_swap::ArcSwap;
5use bitflags::bitflags;
6use bytes::{BufMut, Bytes, BytesMut};
7use once_cell::sync::Lazy;
8use smallvec::SmallVec;
9use tokio::io::AsyncWriteExt;
10
11use crate::Result;
12
13use super::frame::{HeadersBuilder, StatusLine};
14use super::httpfield::HttpField;
15use super::misc;
16
17type DateBytes = SmallVec<[u8; 32]>;
18
19static DATE: Lazy<ArcSwap<DateBytes>> = Lazy::new(|| ArcSwap::new(Arc::new(DateBytes::new())));
20static DATE_TICKER_STARTED: Lazy<AtomicBool> = Lazy::new(AtomicBool::default);
21
22async fn start_date_ticker() {
23    use tokio::time::{sleep, Duration};
24    // start the timer
25    info!("start interval of http date generator ok");
26    loop {
27        // 1. compute the next truncated second, as the next launch time
28        let sleep_nanos = 1_000_000_000i64 - (chrono::Utc::now().timestamp_subsec_nanos() as i64);
29        if sleep_nanos > 0 {
30            sleep(Duration::from_nanos(sleep_nanos as u64)).await;
31            // 2. refresh date str every second
32            DATE.store(Arc::new(generate_date_bytes()));
33        } else {
34            // sleep 1ms at least
35            sleep(Duration::from_millis(1)).await;
36        }
37    }
38}
39
40#[inline(always)]
41fn generate_date_bytes() -> DateBytes {
42    // Sun, 08 Oct 2023 08:49:35 GMT
43    let mut b = DateBytes::new();
44    {
45        use std::io::Write as _;
46        write!(
47            &mut b,
48            "{}",
49            chrono::Utc::now().format("%a, %d %b %Y %T GMT")
50        )
51        .ok();
52    }
53    b
54}
55
56#[inline]
57async fn write_header_date<W>(w: &mut W, lowercase: bool) -> Result<()>
58where
59    W: AsyncWriteExt + Unpin,
60{
61    if lowercase {
62        w.write_all(HttpField::Date.as_str().to_ascii_lowercase().as_bytes())
63            .await?;
64    } else {
65        w.write_all(HttpField::Date.as_bytes()).await?;
66    }
67
68    w.write_all(b": ").await?;
69
70    loop {
71        {
72            let loaded = DATE.load();
73            if !loaded.is_empty() {
74                w.write_all(&loaded[..]).await?;
75                break;
76            }
77        }
78
79        if let Ok(origin) =
80            DATE_TICKER_STARTED.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
81        {
82            if !origin {
83                // generate when starting
84                DATE.store(Arc::new(generate_date_bytes()));
85
86                tokio::spawn(async {
87                    start_date_ticker().await;
88                });
89            }
90        }
91    }
92
93    w.write_all(misc::CRLF).await?;
94    Ok(())
95}
96
97#[derive(Debug, Copy, Clone, Default, Hash, PartialEq, Eq)]
98pub struct ResponseFlags(u8);
99
100bitflags! {
101    impl ResponseFlags: u8 {
102        const LOWERCASE_DATE_HEADER = 1 << 0;
103    }
104}
105
106#[derive(Debug, Clone)]
107pub struct Response {
108    pub status_line: StatusLine,
109    pub headers: Bytes,
110    pub body: Option<Bytes>,
111}
112
113impl Response {
114    pub fn builder() -> ResponseBuilder {
115        Default::default()
116    }
117
118    pub async fn write_to<W>(&self, w: &mut W, flags: ResponseFlags) -> Result<()>
119    where
120        W: AsyncWriteExt + Unpin,
121    {
122        w.write_all(&self.status_line.0[..]).await?;
123
124        write_header_date(w, flags.contains(ResponseFlags::LOWERCASE_DATE_HEADER)).await?;
125
126        w.write_all(&self.headers[..]).await?;
127        if let Some(body) = &self.body {
128            w.write_all(&body[..]).await?;
129        }
130        Ok(())
131    }
132}
133
134#[derive(Debug, Copy, Clone, Default, Hash, PartialEq, Eq)]
135struct ResponseBuilderFlags(u16);
136
137bitflags! {
138    impl ResponseBuilderFlags: u16 {
139        const FLAG_DNT_SERVER = 1 << 0;
140        const FLAG_NO_HTTP11 = 1 << 1;
141        const FLAG_NO_KEEPALIVE = 1 << 2;
142        const FLAG_LOWERCASE_HEADER = 1 << 3;
143    }
144}
145
146pub struct ResponseBuilder {
147    headers: HeadersBuilder,
148    body_buf: BytesMut,
149    code: u16,
150    flag: ResponseBuilderFlags,
151}
152
153impl Default for ResponseBuilder {
154    fn default() -> Self {
155        ResponseBuilder {
156            code: 200,
157            headers: Default::default(),
158            body_buf: Default::default(),
159            flag: ResponseBuilderFlags::default(),
160        }
161    }
162}
163
164impl ResponseBuilder {
165    pub fn status_code(mut self, status_code: u16) -> Self {
166        self.code = status_code;
167        self
168    }
169
170    pub fn use_lowercase_header(mut self) -> Self {
171        self.flag |= ResponseBuilderFlags::FLAG_LOWERCASE_HEADER;
172        self
173    }
174
175    pub fn content_type<T>(mut self, typ: T) -> Self
176    where
177        T: AsRef<str>,
178    {
179        if self
180            .flag
181            .contains(ResponseBuilderFlags::FLAG_LOWERCASE_HEADER)
182        {
183            self.headers = self.headers.put(
184                HttpField::ContentType.as_str().to_ascii_lowercase(),
185                typ.as_ref(),
186            );
187        } else {
188            self.headers = self
189                .headers
190                .put(HttpField::ContentType.as_str(), typ.as_ref());
191        }
192
193        self
194    }
195
196    pub fn header<K, V>(mut self, key: K, value: V) -> Self
197    where
198        K: AsRef<str>,
199        V: AsRef<str>,
200    {
201        if self
202            .flag
203            .contains(ResponseBuilderFlags::FLAG_LOWERCASE_HEADER)
204        {
205            self.headers = self.headers.put(key.as_ref().to_ascii_lowercase(), value);
206        } else {
207            self.headers = self.headers.put(key, value);
208        }
209
210        self
211    }
212
213    pub fn body<B>(mut self, value: B) -> Self
214    where
215        B: AsRef<[u8]>,
216    {
217        let value = value.as_ref();
218        self.body_buf.put(value);
219        self
220    }
221
222    pub fn disable_server(mut self) -> Self {
223        self.flag |= ResponseBuilderFlags::FLAG_DNT_SERVER;
224        self
225    }
226
227    pub fn build(mut self) -> Response {
228        let lowercase_headers = self
229            .flag
230            .contains(ResponseBuilderFlags::FLAG_LOWERCASE_HEADER);
231
232        let size = self.body_buf.len();
233
234        if lowercase_headers {
235            self.headers = self.headers.put(
236                HttpField::ContentLength.as_str().to_ascii_lowercase(),
237                size.to_string(),
238            );
239        } else {
240            self.headers = self
241                .headers
242                .put(HttpField::ContentLength.as_str(), size.to_string());
243        }
244
245        // set 'Server: leap/x.y.z' if custom server header is enabled.
246        if !self.flag.contains(ResponseBuilderFlags::FLAG_DNT_SERVER) {
247            if lowercase_headers {
248                self.headers = self.headers.put(
249                    HttpField::Server.as_str().to_ascii_lowercase(),
250                    misc::SERVER.as_str(),
251                );
252            } else {
253                self.headers = self
254                    .headers
255                    .put(HttpField::Server.as_str(), misc::SERVER.as_str());
256            }
257        }
258
259        // set 'Connection: close' if keepalive is disabled.
260        if self.flag.contains(ResponseBuilderFlags::FLAG_NO_KEEPALIVE) {
261            if lowercase_headers {
262                self.headers = self
263                    .headers
264                    .put(HttpField::Connection.as_str().to_ascii_lowercase(), "close");
265            } else {
266                self.headers = self.headers.put(HttpField::Connection.as_str(), "close");
267            }
268        }
269
270        self.headers = self.headers.complete();
271
272        let Self {
273            code,
274            headers,
275            body_buf: body_,
276            ..
277        } = self;
278
279        let mut status_line = BytesMut::with_capacity(32);
280
281        if self.flag.contains(ResponseBuilderFlags::FLAG_NO_HTTP11) {
282            status_line.put(&b"HTTP/1.0 "[..]);
283        } else {
284            status_line.put(&b"HTTP/1.1 "[..]);
285        }
286
287        // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status
288        match code {
289            100 => status_line.put(&b"100 Continue"[..]),
290            101 => status_line.put(&b"101 Switching Protocols"[..]),
291            102 => status_line.put(&b"102 Processing"[..]),
292            103 => status_line.put(&b"103 Early Hints"[..]),
293            200 => status_line.put(&b"200 OK"[..]),
294            201 => status_line.put(&b"201 Created"[..]),
295            202 => status_line.put(&b"202 Accepted"[..]),
296            203 => status_line.put(&b"203 Non-Authoritative Information"[..]),
297            204 => status_line.put(&b"204 No Content"[..]),
298            205 => status_line.put(&b"205 Reset Content"[..]),
299            206 => status_line.put(&b"206 Partial Content"[..]),
300            207 => status_line.put(&b"207 Multi-Status"[..]),
301            208 => status_line.put(&b"208 Already Reported"[..]),
302            226 => status_line.put(&b"226 IM Used"[..]),
303            300 => status_line.put(&b"300 Multiple Choices"[..]),
304            301 => status_line.put(&b"301 Moved Permanently"[..]),
305            302 => status_line.put(&b"302 Found"[..]),
306            303 => status_line.put(&b"303 See Other"[..]),
307            304 => status_line.put(&b"304 Not Modified"[..]),
308            307 => status_line.put(&b"307 Temporary Redirect"[..]),
309            308 => status_line.put(&b"308 Permanent Redirect"[..]),
310            400 => status_line.put(&b"400 Bad Request"[..]),
311            401 => status_line.put(&b"401 Unauthorized"[..]),
312            402 => status_line.put(&b"402 Payment Required"[..]),
313            403 => status_line.put(&b"403 Forbidden"[..]),
314            404 => status_line.put(&b"404 Not Found"[..]),
315            405 => status_line.put(&b"405 Method Not Allowed"[..]),
316            406 => status_line.put(&b"406 Not Acceptable"[..]),
317            407 => status_line.put(&b"407 Proxy Authentication Required"[..]),
318            408 => status_line.put(&b"408 Request Timeout"[..]),
319            409 => status_line.put(&b"409 Conflict"[..]),
320            410 => status_line.put(&b"410 Gone"[..]),
321            411 => status_line.put(&b"411 Length Required"[..]),
322            412 => status_line.put(&b"412 Precondition Failed"[..]),
323            413 => status_line.put(&b"413 Content Too Large"[..]),
324            414 => status_line.put(&b"414 URI Too Long"[..]),
325            415 => status_line.put(&b"415 Unsupported Media Type"[..]),
326            416 => status_line.put(&b"416 Range Not Satisfiable"[..]),
327            417 => status_line.put(&b"417 Expectation Failed"[..]),
328            418 => status_line.put(&b"418 I'm a teapot"[..]),
329            421 => status_line.put(&b"421 Misdirected Request"[..]),
330            422 => status_line.put(&b"422 Unprocessable Content"[..]),
331            423 => status_line.put(&b"423 Locked"[..]),
332            424 => status_line.put(&b"424 Failed Dependency"[..]),
333            425 => status_line.put(&b"425 Too Early"[..]),
334            426 => status_line.put(&b"426 Upgrade Required"[..]),
335            428 => status_line.put(&b"428 Precondition Required"[..]),
336            429 => status_line.put(&b"429 Too Many Requests"[..]),
337            431 => status_line.put(&b"431 Request Header Fields Too Large"[..]),
338            451 => status_line.put(&b"451 Unavailable For Legal Reasons"[..]),
339            500 => status_line.put(&b"500 Internal Server Error"[..]),
340            501 => status_line.put(&b"501 Not Implemented"[..]),
341            502 => status_line.put(&b"502 Bad Gateway"[..]),
342            503 => status_line.put(&b"503 Service Unavailable"[..]),
343            504 => status_line.put(&b"504 Gateway Timeout"[..]),
344            505 => status_line.put(&b"505 HTTP Version Not Supported"[..]),
345            506 => status_line.put(&b"506 Variant Also Negotiates"[..]),
346            507 => status_line.put(&b"507 Insufficient Storage"[..]),
347            508 => status_line.put(&b"508 Loop Detected"[..]),
348            510 => status_line.put(&b"510 Not Extended"[..]),
349            511 => status_line.put(&b"511 Network Authentication Required"[..]),
350
351            other => {
352                use std::fmt::Write as _;
353                write!(&mut status_line, "{} UNKNOWN", other).unwrap();
354            }
355        };
356
357        status_line.put(misc::CRLF);
358
359        let headers = headers.build();
360        let body = body_.freeze();
361
362        Response {
363            status_line: StatusLine(status_line.freeze()),
364            headers: headers.into(),
365            body: if body.is_empty() { None } else { Some(body) },
366        }
367    }
368}
369
#[cfg(test)]
mod response_tests {
    use super::*;

    /// Installs the test logger; the result is discarded so repeated calls are harmless.
    fn init() {
        pretty_env_logger::try_init_timed().ok();
    }

    /// Builds a simple 200 response and checks its status line and body.
    #[test]
    fn response_builder() {
        init();

        let resp = Response::builder()
            .header("X-Ray-Id", "foobar")
            .header("Content-Type", "text/plain")
            .body(b"hello world")
            .build();

        // The default status is 200 and the default protocol version is HTTP/1.1.
        assert!(resp.status_line.0.starts_with(b"HTTP/1.1 200 OK"));
        // A non-empty body is kept exactly as it was supplied.
        assert_eq!(resp.body.as_deref(), Some(&b"hello world"[..]));

        let mut b = BytesMut::new();

        b.put_slice(&resp.status_line.0);
        b.put_slice(&resp.headers);
        if let Some(body) = &resp.body {
            b.put_slice(body);
        }

        let b = b.freeze();

        // Status line comes first and the body last in the assembled bytes.
        assert!(b.starts_with(b"HTTP/1.1 200 OK"));
        assert!(b.ends_with(b"hello world"));

        info!("response: {:?}", b);
    }
}