http_proxy_client_async/
flow.rs

1use futures_io::{AsyncRead, AsyncWrite};
2use futures_util::io::{AsyncReadExt, AsyncWriteExt};
3use std::io::{Error, ErrorKind, Result};
4
5use crate::http::HeaderMap;
6
7mod handshake_outcome;
8mod request;
9
10pub use handshake_outcome::{HandshakeOutcome, ResponseParts};
11
12pub async fn handshake<ARW>(
13    stream: &mut ARW,
14    host: &str,
15    port: u16,
16    request_headers: &HeaderMap,
17    read_buf: &mut [u8],
18) -> Result<HandshakeOutcome>
19where
20    ARW: AsyncRead + AsyncWrite + Unpin,
21{
22    send_request(stream, host, port, request_headers).await?;
23    receive_response(stream, read_buf).await
24}
25
26pub async fn send_request<AW>(
27    stream: &mut AW,
28    host: &str,
29    port: u16,
30    headers: &HeaderMap,
31) -> Result<()>
32where
33    AW: AsyncWrite + Unpin,
34{
35    let mut buf: Vec<u8> = Vec::with_capacity(1024);
36    request::write(&mut buf, host, port, headers)?;
37    stream.write_all(buf.as_slice()).await
38}
39
40pub async fn receive_response<'buf, AR>(
41    stream: &mut AR,
42    read_buf: &mut [u8],
43) -> Result<HandshakeOutcome>
44where
45    AR: AsyncRead + Unpin,
46{
47    // Happy path - we expect the response to be reasonably small and to come in
48    // complete as a single buffer via a single read.
49    // In this case we don't need to allocate and carry-on second buffer.
50
51    let first_buf = {
52        let total = stream.read(read_buf).await?;
53        let buf = &read_buf[..total];
54
55        let mut response_headers = [httparse::EMPTY_HEADER; 16];
56        let mut response = httparse::Response::new(&mut response_headers);
57
58        let status = response
59            .parse(buf)
60            .map_err(|err| Error::new(ErrorKind::InvalidData, err))?;
61
62        match status {
63            httparse::Status::Partial => buf,
64            httparse::Status::Complete(consumed) => {
65                return Ok(HandshakeOutcome::new(response, Vec::from(&buf[consumed..])))
66            }
67        }
68    };
69
70    // We didn't exit early on error or completion, this means we're at slower
71    // path and we need a carry-on buffer.
72
73    // TODO: allow user to customize the data structure used for a carry-on
74    // buffer. This is useful in case user wants to limit the amount of memory
75    // this buffer can grow to, or for the cases when a more optimized data
76    // structure is at hand.
77    let mut carry_on_buf = Vec::from(first_buf);
78    loop {
79        let total = stream.read(read_buf).await?;
80        let buf = &read_buf[..total];
81        carry_on_buf.extend_from_slice(buf);
82
83        let mut response_headers = [httparse::EMPTY_HEADER; 16];
84        let mut response = httparse::Response::new(&mut response_headers);
85
86        let status = response
87            .parse(carry_on_buf.as_slice())
88            .map_err(|err| Error::new(ErrorKind::InvalidData, err))?;
89        match status {
90            httparse::Status::Partial => continue,
91            httparse::Status::Complete(consumed) => {
92                return Ok(HandshakeOutcome::new(
93                    response,
94                    Vec::from(&carry_on_buf[consumed..]),
95                ))
96            }
97        };
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::HeaderValue;
    use futures::{executor, io::Cursor};

    /// Bytes written into `socket` up to its current write position.
    fn written(socket: &Cursor<Vec<u8>>) -> &[u8] {
        let end = socket.position() as usize;
        &socket.get_ref()[..end]
    }

    /// Drives `receive_response` to completion over an in-memory stream
    /// holding `input`, using `read_buf` as the scratch buffer.
    fn receive_from<T>(input: T, read_buf: &mut [u8]) -> Result<HandshakeOutcome>
    where
        T: AsRef<[u8]> + Unpin,
    {
        let mut socket = Cursor::new(input);
        executor::block_on(receive_response(&mut socket, read_buf))
    }

    /// Asserts that the parsed status line is `200 OK`.
    fn assert_200_ok(outcome: &HandshakeOutcome) {
        assert_eq!(outcome.response_parts.status_code, 200);
        assert_eq!(outcome.response_parts.reason_phrase, "OK");
    }

    #[test]
    fn send_request_without_headers() -> Result<()> {
        let expected = concat!(
            "CONNECT 127.0.0.1:8080 HTTP/1.1\r\n",
            "Host: 127.0.0.1:8080\r\n",
            "\r\n",
        );
        let headers = HeaderMap::new();
        let mut socket = Cursor::new(vec![0u8; 1024]);

        executor::block_on(send_request(&mut socket, "127.0.0.1", 8080, &headers))?;

        assert_eq!(written(&socket), expected.as_bytes());
        Ok(())
    }

    #[test]
    fn send_request_with_headers() -> Result<()> {
        let expected = concat!(
            "CONNECT 127.0.0.1:8080 HTTP/1.1\r\n",
            "Host: 127.0.0.1:8080\r\n",
            "proxy-authorization: Basic aGVsbG86d29ybGQ=\r\n",
            "\r\n",
        );
        let mut headers = HeaderMap::new();
        headers.insert(
            "Proxy-Authorization",
            HeaderValue::from_static("Basic aGVsbG86d29ybGQ="),
        );
        let mut socket = Cursor::new(vec![0u8; 1024]);

        executor::block_on(send_request(&mut socket, "127.0.0.1", 8080, &headers))?;

        assert_eq!(written(&socket), expected.as_bytes());
        Ok(())
    }

    #[test]
    fn receive_response_test() -> Result<()> {
        let response = concat!(
            "HTTP/1.1 200 OK\r\n",
            "\r\n",
            "this is already the proxied content",
        );
        let mut read_buf = [0u8; 1024];

        let outcome = receive_from(response, &mut read_buf)?;

        assert_eq!(
            outcome.data_after_handshake.as_slice(),
            "this is already the proxied content".as_bytes()
        );
        assert_200_ok(&outcome);
        assert_eq!(outcome.response_parts.headers.len(), 0);
        Ok(())
    }

    #[test]
    fn receive_response_with_headers() -> Result<()> {
        let response = concat!(
            "HTTP/1.1 200 OK\r\n",
            "X-Custom: Sample Value\r\n",
            "\r\n",
            "this is already the proxied content",
        );
        let mut read_buf = [0u8; 1024];

        let outcome = receive_from(response, &mut read_buf)?;

        assert_eq!(
            outcome.data_after_handshake.as_slice(),
            "this is already the proxied content".as_bytes()
        );
        assert_200_ok(&outcome);
        assert_eq!(outcome.response_parts.headers.len(), 1);
        assert_eq!(
            outcome.response_parts.headers.get("x-custom").unwrap(),
            &"Sample Value"
        );
        Ok(())
    }

    #[test]
    fn receive_response_small_read_buf_test() -> Result<()> {
        // A tiny read buffer forces the slow, accumulating path.
        const BUF_SIZE: usize = 4;
        let handshake = "HTTP/1.1 200 OK\r\n\r\n";
        let proxied = "this is already the proxied content";
        let mut read_buf = [0u8; BUF_SIZE];

        let outcome = receive_from([handshake, proxied].concat(), &mut read_buf)?;

        // Only the bytes of the final read that overshot the end of the
        // response head are expected as leftover data.
        let overshoot = (BUF_SIZE - handshake.len() % BUF_SIZE) % BUF_SIZE;
        assert_eq!(
            outcome.data_after_handshake.as_slice(),
            proxied[..overshoot].as_bytes()
        );
        assert_200_ok(&outcome);
        assert_eq!(outcome.response_parts.headers.len(), 0);
        Ok(())
    }
}