rasi_ext/net/http/
client.rs

//! Utilities for writing HTTP client programs.
2
3use std::{
4    io,
5    net::{SocketAddr, ToSocketAddrs},
6    path::{Path, PathBuf},
7    pin::Pin,
8    time::Duration,
9};
10
11use boring::ssl::{SslConnector, SslMethod};
12use futures::{AsyncRead, AsyncWrite};
13use http::{uri::Scheme, Request, Response};
14use rasi::{net::TcpStream, time::TimeoutExt};
15
16use crate::net::{
17    http::{parse::Responser, writer::RequestWriter},
18    tls::{connect, SslStream},
19};
20
21use super::parse::BodyReader;
22
23pub enum HttpClientWrite {
24    TcpStream(TcpStream),
25    TlsStream(SslStream<TcpStream>),
26}
27
28impl Into<HttpClientRead> for HttpClientWrite {
29    fn into(self) -> HttpClientRead {
30        match self {
31            HttpClientWrite::TcpStream(stream) => HttpClientRead::TcpStream(stream),
32            HttpClientWrite::TlsStream(stream) => HttpClientRead::TlsStream(stream),
33        }
34    }
35}
36
37impl AsyncWrite for HttpClientWrite {
38    fn poll_write(
39        mut self: Pin<&mut Self>,
40        cx: &mut std::task::Context<'_>,
41        buf: &[u8],
42    ) -> std::task::Poll<io::Result<usize>> {
43        match &mut *self {
44            Self::TcpStream(stream) => Pin::new(stream).poll_write(cx, buf),
45            Self::TlsStream(stream) => Pin::new(stream).poll_write(cx, buf),
46        }
47    }
48
49    fn poll_flush(
50        mut self: Pin<&mut Self>,
51        cx: &mut std::task::Context<'_>,
52    ) -> std::task::Poll<io::Result<()>> {
53        match &mut *self {
54            Self::TcpStream(stream) => Pin::new(stream).poll_flush(cx),
55            Self::TlsStream(stream) => Pin::new(stream).poll_flush(cx),
56        }
57    }
58
59    fn poll_close(
60        mut self: Pin<&mut Self>,
61        cx: &mut std::task::Context<'_>,
62    ) -> std::task::Poll<io::Result<()>> {
63        match &mut *self {
64            Self::TcpStream(stream) => Pin::new(stream).poll_close(cx),
65            Self::TlsStream(stream) => Pin::new(stream).poll_close(cx),
66        }
67    }
68}
69
70pub enum HttpClientRead {
71    TcpStream(TcpStream),
72    TlsStream(SslStream<TcpStream>),
73}
74
75impl AsyncRead for HttpClientRead {
76    fn poll_read(
77        mut self: std::pin::Pin<&mut Self>,
78        cx: &mut std::task::Context<'_>,
79        buf: &mut [u8],
80    ) -> std::task::Poll<io::Result<usize>> {
81        match &mut *self {
82            HttpClientRead::TcpStream(stream) => Pin::new(stream).poll_read(cx, buf),
83            HttpClientRead::TlsStream(stream) => Pin::new(stream).poll_read(cx, buf),
84        }
85    }
86}
87
88/// A extension trait for http [`Request`] builder.
89pub trait HttpRequestSend {
90    type Body;
91
92    fn send(self) -> HttpRequestSender<Self::Body>;
93}
94
95/// A builder to create a task to send http request.
96#[must_use = "Must call response function to invoke real sending action."]
97pub struct HttpRequestSender<T> {
98    request: http::Result<Request<T>>,
99    timeout: Duration,
100    raddrs: Option<io::Result<Vec<SocketAddr>>>,
101    server_name: Option<String>,
102    ca_file: Option<PathBuf>,
103}
104
105impl<T> HttpRequestSender<T> {
106    /// Create `HttpRequestSender` with the default configuration other than provided [`Request`].
107    pub fn new(request: http::Result<Request<T>>) -> Self {
108        Self {
109            request,
110            timeout: Duration::from_secs(30),
111            raddrs: None,
112            server_name: None,
113            ca_file: None,
114        }
115    }
116
117    /// Set the request send / response header recv timeout. the default value is `30s`.
118    pub fn with_timeout(mut self, timeout: Duration) -> Self {
119        self.timeout = timeout;
120        self
121    }
122
123    /// Rewrite http request's host:port fields and send request to the specified `raddrs`.
124    pub fn redirect<R: ToSocketAddrs>(mut self, raddrs: R) -> Self {
125        self.raddrs = Some(
126            raddrs
127                .to_socket_addrs()
128                .map(|iter| iter.collect::<Vec<_>>()),
129        );
130
131        self
132    }
133
134    /// Set remote server's server name, this option will rewrite request's host field.
135    pub fn with_server_name(mut self, server_name: &str) -> Self {
136        self.server_name = Some(server_name.to_string());
137
138        self
139    }
140
141    /// Set the server verification ca file, this is useful for self signed server.
142    pub fn with_ca_file<P: AsRef<Path>>(mut self, ca_file: P) -> Self {
143        self.ca_file = Some(ca_file.as_ref().to_path_buf());
144        self
145    }
146
147    /// Consume self and create new [`HttpClientWrite`].
148    ///
149    /// On success, The response wait `timeout` and the original [`Request`] are returned together.
150    pub async fn create(self) -> io::Result<(Request<T>, HttpClientWrite, Duration)> {
151        let request = self
152            .request
153            .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
154
155        let scheme = request.uri().scheme().ok_or(io::Error::new(
156            io::ErrorKind::InvalidInput,
157            "Unspecified request scheme",
158        ))?;
159
160        let host = request.uri().host().ok_or(io::Error::new(
161            io::ErrorKind::InvalidInput,
162            "Unspecified request uri",
163        ))?;
164
165        let port =
166            request.uri().port_u16().unwrap_or_else(
167                || {
168                    if scheme == &Scheme::HTTP {
169                        80
170                    } else {
171                        440
172                    }
173                },
174            );
175
176        let raddr = if let Some(raddr) = self.raddrs {
177            raddr?
178        } else {
179            vec![format!("{}:{}", host, port,)
180                .parse()
181                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?]
182        };
183
184        let stream = if scheme == &Scheme::HTTP {
185            let stream = TcpStream::connect(raddr.as_slice()).await?;
186
187            HttpClientWrite::TcpStream(stream)
188        } else {
189            let stream = TcpStream::connect(raddr.as_slice()).await?;
190
191            let mut config = SslConnector::builder(SslMethod::tls())
192                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
193
194            if let Some(ca_file) = self.ca_file {
195                config
196                    .set_ca_file(ca_file)
197                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
198            }
199
200            let config = config.build().configure().unwrap();
201
202            let stream = connect(config, host, stream)
203                .await
204                .map_err(|err| io::Error::new(io::ErrorKind::ConnectionRefused, err))?;
205
206            HttpClientWrite::TlsStream(stream)
207        };
208
209        Ok((request, stream, self.timeout))
210    }
211
212    /// Start a new
213    pub async fn response(self) -> io::Result<Response<BodyReader<HttpClientRead>>>
214    where
215        T: AsRef<[u8]>,
216    {
217        let (request, mut stream, timeout) = self.create().await?;
218
219        let writer = RequestWriter::new(&mut stream);
220
221        match writer.write(request).timeout(timeout).await {
222            Some(Ok(_)) => {}
223            Some(Err(err)) => return Err(err),
224            None => {
225                return Err(io::Error::new(
226                    io::ErrorKind::TimedOut,
227                    "send http request timeout",
228                ));
229            }
230        }
231
232        let stream: HttpClientRead = stream.into();
233
234        match Responser::new(stream).parse().timeout(timeout).await {
235            Some(Ok(response)) => Ok(response),
236            Some(Err(err)) => return Err(err.into()),
237            None => {
238                return Err(io::Error::new(
239                    io::ErrorKind::TimedOut,
240                    "recv http response header timeout",
241                ));
242            }
243        }
244    }
245
246    pub async fn stream_response(self) -> io::Result<Response<BodyReader<HttpClientRead>>>
247    where
248        T: AsyncRead + Unpin,
249    {
250        let (request, mut stream, timeout) = self.create().await?;
251
252        let writer = RequestWriter::new(&mut stream);
253
254        match writer
255            .write_with_stream_body(request)
256            .timeout(timeout)
257            .await
258        {
259            Some(Ok(_)) => {}
260            Some(Err(err)) => return Err(err),
261            None => {
262                return Err(io::Error::new(
263                    io::ErrorKind::TimedOut,
264                    "send http request timeout",
265                ));
266            }
267        }
268
269        let stream: HttpClientRead = stream.into();
270
271        match Responser::new(stream).parse().timeout(timeout).await {
272            Some(Ok(response)) => Ok(response),
273            Some(Err(err)) => return Err(err.into()),
274            None => {
275                return Err(io::Error::new(
276                    io::ErrorKind::TimedOut,
277                    "recv http response header timeout",
278                ));
279            }
280        }
281    }
282}
283
284impl<T> HttpRequestSend for http::Result<Request<T>>
285where
286    T: AsRef<[u8]>,
287{
288    type Body = T;
289
290    fn send(self) -> HttpRequestSender<Self::Body> {
291        HttpRequestSender::new(self)
292    }
293}