areq_h1/
client.rs

1use {
2    crate::{
3        body::prelude::*,
4        error::Error,
5        handler::{Handler, Parser, ReadStrategy},
6        headers::{self, ContentLen},
7    },
8    async_channel::{Receiver, Sender},
9    bytes::{Buf, Bytes},
10    futures_lite::prelude::*,
11    http::{HeaderValue, Request, Response, header},
12    std::{
13        fmt, io,
14        pin::{self, Pin},
15    },
16};
17
18#[derive(Clone)]
19pub struct Config {
20    parser: Parser,
21    read_strategy: ReadStrategy,
22}
23
24impl Config {
25    #[inline]
26    pub fn read_strategy(mut self, read_strategy: ReadStrategy) -> Self {
27        self.read_strategy = read_strategy;
28        self
29    }
30
31    #[inline]
32    pub fn max_headers(mut self, n: usize) -> Self {
33        self.parser.set_max_headers(n);
34        self
35    }
36
37    #[inline]
38    pub fn handshake<I, B>(self, io: I) -> (Requester<B>, impl Future<Output = ()>)
39    where
40        I: AsyncRead + AsyncWrite,
41        B: Body,
42    {
43        let (send_req, recv_req) = async_channel::bounded(1);
44        let (send_res, recv_res) = async_channel::bounded(1);
45        let reqs = Requester { send_req, recv_res };
46        let conn = async move {
47            let io = pin::pin!(io);
48            let conn = Connection {
49                recv_req,
50                send_res,
51                io: Handler::new(io, self.read_strategy),
52                parser: self.parser,
53            };
54
55            connect(conn).await;
56        };
57
58        (reqs, conn)
59    }
60}
61
62impl Default for Config {
63    #[inline]
64    fn default() -> Self {
65        Self {
66            parser: Parser::new(),
67            read_strategy: ReadStrategy::default(),
68        }
69    }
70}
71
72struct Connection<'pin, I, B> {
73    recv_req: Receiver<Request<B>>,
74    send_res: Sender<Result<Response<FetchBody>, Error>>,
75    io: Handler<Pin<&'pin mut I>>,
76    parser: Parser,
77}
78
79async fn connect<I, B>(mut conn: Connection<'_, I, B>)
80where
81    I: AsyncRead + AsyncWrite,
82    B: Body,
83{
84    while let Ok(req) = conn.recv_req.recv().await {
85        let process = async {
86            let (parts, mut body) = req.into_parts();
87            let mut head = Request::from_parts(parts, ());
88
89            match body.size_hint() {
90                Hint::Empty => conn.io.write_header(&head).await?,
91                Hint::Full { .. } => {
92                    let full = body.take_full().await?;
93
94                    let chunk = full.as_ref().map(Buf::chunk).unwrap_or_default();
95                    let chunk_len = HeaderValue::from(chunk.len());
96
97                    head.headers_mut().insert(header::CONTENT_LENGTH, chunk_len);
98                    headers::remove_chunked_encoding(head.headers_mut());
99
100                    conn.io.write_header(&head).await?;
101                    conn.io.write_body(chunk).await?;
102                }
103                Hint::Chunked { .. } => {
104                    head.headers_mut().remove(header::CONTENT_LENGTH);
105                    headers::insert_chunked_encoding(head.headers_mut());
106
107                    conn.io.write_header(&head).await?;
108                    while let Some(chunk) = body.chunk().await {
109                        conn.io.write_chunk(chunk?.chunk()).await?;
110                        conn.io.flush().await?;
111                    }
112
113                    conn.io.write_chunk(&[]).await?;
114                }
115            }
116
117            conn.io.flush().await?;
118
119            let head = conn.io.read_header().await?;
120            let res = conn.parser.parse_header(head)?;
121
122            let headers = res.headers();
123            let state = match headers::parse_content_len(headers) {
124                ContentLen::Num(n) => ReadBodyState::Remaining(n),
125                ContentLen::None if headers::has_chunked_encoding(headers) => {
126                    ReadBodyState::Chunked
127                }
128                _ => return Err(Error::invalid_input()),
129            };
130
131            Ok((res, state))
132        };
133
134        let (give, fetch) = async_channel::bounded(16);
135        let mut state = match process.await {
136            Ok((res, state)) => {
137                let res = res.map(|_| FetchBody { fetch, end: false });
138                _ = conn.send_res.send(Ok(res)).await;
139                state
140            }
141            Err(e) => {
142                _ = conn.send_res.send(Err(e)).await;
143                continue;
144            }
145        };
146
147        loop {
148            let (frame, end) = match &mut state {
149                ReadBodyState::Remaining(0) => (Ok(Bytes::new()), true),
150                ReadBodyState::Remaining(n) => (conn.io.read_body(n).await, false),
151                ReadBodyState::Chunked => {
152                    let chunk = conn.io.read_chunk().await;
153                    let end = chunk.as_ref().is_ok_and(Bytes::is_empty);
154                    (chunk, end)
155                }
156            };
157
158            let error = frame.is_err();
159            let next = Next { frame, end };
160            if give.send(next).await.is_err() || error || end {
161                break;
162            }
163        }
164    }
165}
166
/// How the body of the current response is framed.
#[derive(Clone, Copy)]
enum ReadBodyState {
    /// Length-delimited body, initialised from `content-length`.
    /// `Remaining(0)` means the body is complete.
    Remaining(usize),
    /// `transfer-encoding: chunked` body; an empty chunk ends it.
    Chunked,
}
172
173pub struct Requester<B> {
174    send_req: Sender<Request<B>>,
175    recv_res: Receiver<Result<Response<FetchBody>, Error>>,
176}
177
178impl<B> Requester<B> {
179    #[inline]
180    pub async fn send(&self, req: Request<B>) -> Result<Response<FetchBody>, Error> {
181        self.send_req.send(req).await.map_err(|_| Error::Closed)?;
182        self.recv_res.recv().await.map_err(|_| Error::Closed)?
183    }
184}
185
186struct Next {
187    frame: Result<Bytes, Error>,
188    end: bool,
189}
190
191pub struct FetchBody {
192    fetch: Receiver<Next>,
193    end: bool,
194}
195
196impl FetchBody {
197    #[inline]
198    pub async fn frame(&mut self) -> Result<Bytes, Error> {
199        let Next { frame, end } = self.fetch.recv().await.map_err(|_| Error::Closed)?;
200        self.end = end;
201        frame
202    }
203}
204
205impl fmt::Debug for FetchBody {
206    #[inline]
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        f.debug_struct("FetchBody").finish()
209    }
210}
211
212impl Body for FetchBody {
213    type Chunk = Bytes;
214
215    #[inline]
216    async fn chunk(&mut self) -> Option<Result<Self::Chunk, io::Error>> {
217        match self.frame().await {
218            Ok(chunk) => {
219                if chunk.is_empty() {
220                    None
221                } else {
222                    Some(Ok(chunk))
223                }
224            }
225            Err(e) => Some(Err(e.into())),
226        }
227    }
228
229    fn size_hint(&self) -> Hint {
230        Hint::Chunked { end: self.end }
231    }
232}
233
#[cfg(test)]
mod tests {
    use {super::*, crate::test, futures_lite::future};

    /// Runs the connection future and the client future side by side on the
    /// current thread, finishing as soon as either of them completes.
    fn run<C, R>(conn: C, reqs: R) -> Result<(), Error>
    where
        C: Future,
        R: Future<Output = Result<(), Error>>,
    {
        let driver = async move {
            conn.await;
            Ok(())
        };
        future::block_on(future::or(driver, reqs))
    }

    /// A bodiless GET is written as a bare head, and a `content-length: 0`
    /// response yields a single empty frame.
    #[test]
    fn roundtrip_empty() -> Result<(), Error> {
        const REQUEST: [&str; 2] = ["GET / HTTP/1.1\r\n", "\r\n"];
        const RESPONSE: [&str; 3] = ["HTTP/1.1 200 OK\r\n", "content-length: 0\r\n", "\r\n"];

        let mut sent = vec![];
        let incoming = test::parts(RESPONSE.map(str::as_bytes));
        let (requester, driver) = Config::default().handshake(test::io(incoming, &mut sent));

        run(driver, async {
            let mut response = requester.send(Request::new(())).await?;
            let frame = response.body_mut().frame().await?;
            assert!(frame.is_empty());
            Ok(())
        })?;

        assert_eq!(String::from_utf8(sent), Ok(REQUEST.concat()));
        Ok(())
    }

    /// A full request body gets a `content-length`, and a length-delimited
    /// response is returned as one frame followed by an empty one.
    #[test]
    fn roundtrip_full() -> Result<(), Error> {
        const REQUEST_BODY: &str = "Hello, request!";
        const REQUEST: [&str; 4] = [
            "GET / HTTP/1.1\r\n",
            "content-length: 15\r\n",
            "\r\n",
            REQUEST_BODY,
        ];

        const RESPONSE_BODY: &str = "Hello, response!";
        const RESPONSE: [&str; 4] = [
            "HTTP/1.1 200 OK\r\n",
            "content-length: 16\r\n",
            "\r\n",
            RESPONSE_BODY,
        ];

        let mut sent = vec![];
        let incoming = test::parts(RESPONSE.map(str::as_bytes));
        let (requester, driver) = Config::default().handshake(test::io(incoming, &mut sent));

        run(driver, async {
            let mut response = requester.send(Request::new(REQUEST_BODY)).await?;
            let body = response.body_mut();

            let payload = body.frame().await?;
            assert_eq!(payload, RESPONSE_BODY);

            let rest = body.frame().await?;
            assert!(rest.is_empty());
            Ok(())
        })?;

        assert_eq!(String::from_utf8(sent), Ok(REQUEST.concat()));
        Ok(())
    }

    /// A streamed request body is chunk-encoded, and each response chunk
    /// comes back as its own frame.
    #[test]
    fn roundtrip_chunked() -> Result<(), Error> {
        use {
            crate::body::Chunked,
            futures_lite::{StreamExt, stream},
        };

        const CHUNKS: [&str; 5] = ["hello", "from", "the", "internet", ":3"];

        const REQUEST: [&str; 14] = [
            "GET / HTTP/1.1\r\n",
            "transfer-encoding: chunked\r\n",
            "\r\n5\r\n",
            CHUNKS[0],
            "\r\n4\r\n",
            CHUNKS[1],
            "\r\n3\r\n",
            CHUNKS[2],
            "\r\n8\r\n",
            CHUNKS[3],
            "\r\n2\r\n",
            CHUNKS[4],
            "\r\n0\r\n",
            "\r\n",
        ];

        const RESPONSE: [&str; 14] = [
            "HTTP/1.1 200 OK\r\n",
            "transfer-encoding: chunked\r\n",
            "\r\n5\r\n",
            CHUNKS[0],
            "\r\n4\r\n",
            CHUNKS[1],
            "\r\n3\r\n",
            CHUNKS[2],
            "\r\n8\r\n",
            CHUNKS[3],
            "\r\n2\r\n",
            CHUNKS[4],
            "\r\n0\r\n",
            "\r\n",
        ];

        let mut sent = vec![];
        let incoming = test::parts(RESPONSE.map(str::as_bytes));
        let (requester, driver) = Config::default().handshake(test::io(incoming, &mut sent));

        run(driver, async {
            let frames = stream::iter(CHUNKS).map(str::as_bytes).map(Ok);
            let mut response = requester.send(Request::new(Chunked(frames))).await?;
            let body = response.body_mut();

            for expected in CHUNKS {
                let got = body.frame().await?;
                assert_eq!(got, expected);
            }

            let rest = body.frame().await?;
            assert!(rest.is_empty());
            Ok(())
        })?;

        assert_eq!(String::from_utf8(sent), Ok(REQUEST.concat()));
        Ok(())
    }

    /// Both halves of a handshake can be moved to another thread.
    #[test]
    fn handshake_is_send() {
        fn assert_send<S: Send>(_: S) {}

        let reader: &[u8] = &[];
        let writer = vec![];
        let (requester, driver): (Requester<&[u8]>, _) =
            Config::default().handshake(test::io(reader, writer));
        assert_send(requester);
        assert_send(driver);
    }
}