edge_http/io/
client.rs

1use core::mem;
2use core::net::SocketAddr;
3use core::str;
4
5use embedded_io_async::{ErrorType, Read, Write};
6
7use edge_nal::{Close, TcpConnect, TcpShutdown};
8
9use crate::{
10    ws::{upgrade_request_headers, MAX_BASE64_KEY_LEN, MAX_BASE64_KEY_RESPONSE_LEN, NONCE_LEN},
11    ConnectionType, DEFAULT_MAX_HEADERS_COUNT,
12};
13
14use super::{send_headers, send_request, Body, Error, ResponseHeaders, SendBody};
15
16use super::Method;
17
/// Size, in bytes, of the on-stack scratch buffer used to drain whatever part of a
/// response body the caller did not read, when a request-response exchange is completed.
const COMPLETION_BUF_SIZE: usize = 64;
19
20/// A client connection that can be used to send HTTP requests and receive responses.
21#[allow(private_interfaces)]
22pub enum Connection<'b, T, const N: usize = DEFAULT_MAX_HEADERS_COUNT>
23where
24    T: TcpConnect,
25{
26    Transition(TransitionState),
27    Unbound(UnboundState<'b, T, N>),
28    Request(RequestState<'b, T, N>),
29    Response(ResponseState<'b, T, N>),
30}
31
32impl<'b, T, const N: usize> Connection<'b, T, N>
33where
34    T: TcpConnect,
35{
36    /// Create a new client connection.
37    ///
38    /// Note that the connection does not have any built-in read/write timeouts:
39    /// - To add a timeout on each IO operation, wrap the `socket` type with the `edge_nal::WithTimeout` wrapper.
40    /// - To add a global request-response timeout, wrap your complete request-response processing
41    ///   logic with the `edge_nal::with_timeout` function.
42    ///
43    /// Parameters:
44    /// - `buf`: A buffer to use for reading and writing data.
45    /// - `socket`: The TCP stack to use for the connection.
46    /// - `addr`: The address of the server to connect to.
47    pub fn new(buf: &'b mut [u8], socket: &'b T, addr: SocketAddr) -> Self {
48        Self::Unbound(UnboundState {
49            buf,
50            socket,
51            addr,
52            io: None,
53        })
54    }
55
56    /// Reinitialize the connection with a new address.
57    pub async fn reinitialize(&mut self, addr: SocketAddr) -> Result<(), Error<T::Error>> {
58        let _ = self.complete().await;
59        unwrap!(self.unbound_mut(), "Unreachable").addr = addr;
60
61        Ok(())
62    }
63
64    /// Initiate an HTTP request.
65    pub async fn initiate_request(
66        &mut self,
67        http11: bool,
68        method: Method,
69        uri: &str,
70        headers: &[(&str, &str)],
71    ) -> Result<(), Error<T::Error>> {
72        self.start_request(http11, method, uri, headers).await
73    }
74
75    /// A utility method to initiate a WebSocket upgrade request.
76    pub async fn initiate_ws_upgrade_request(
77        &mut self,
78        host: Option<&str>,
79        origin: Option<&str>,
80        uri: &str,
81        version: Option<&str>,
82        nonce: &[u8; NONCE_LEN],
83        nonce_base64_buf: &mut [u8; MAX_BASE64_KEY_LEN],
84    ) -> Result<(), Error<T::Error>> {
85        let headers = upgrade_request_headers(host, origin, version, nonce, nonce_base64_buf);
86
87        self.initiate_request(true, Method::Get, uri, &headers)
88            .await
89    }
90
91    /// Return `true` if a request has been initiated.
92    pub fn is_request_initiated(&self) -> bool {
93        matches!(self, Self::Request(_))
94    }
95
96    /// Initiate an HTTP response.
97    ///
98    /// This should be called after a request has been initiated and the request body had been sent.
99    pub async fn initiate_response(&mut self) -> Result<(), Error<T::Error>> {
100        self.complete_request().await
101    }
102
103    /// Return `true` if a response has been initiated.
104    pub fn is_response_initiated(&self) -> bool {
105        matches!(self, Self::Response(_))
106    }
107
108    /// Return `true` if the server accepted the WebSocket upgrade request.
109    pub fn is_ws_upgrade_accepted(
110        &self,
111        nonce: &[u8; NONCE_LEN],
112        buf: &mut [u8; MAX_BASE64_KEY_RESPONSE_LEN],
113    ) -> Result<bool, Error<T::Error>> {
114        Ok(self.headers()?.is_ws_upgrade_accepted(nonce, buf))
115    }
116
117    /// Split the connection into its headers and body parts.
118    ///
119    /// The connection must be in response mode.
120    #[allow(clippy::type_complexity)]
121    pub fn split(&mut self) -> (&ResponseHeaders<'b, N>, &mut Body<'b, T::Socket<'b>>) {
122        let response = self.response_mut().expect("Not in response mode");
123
124        (&response.response, &mut response.io)
125    }
126
127    /// Get the headers of the response.
128    ///
129    /// The connection must be in response mode.
130    pub fn headers(&self) -> Result<&ResponseHeaders<'b, N>, Error<T::Error>> {
131        let response = self.response_ref()?;
132
133        Ok(&response.response)
134    }
135
136    /// Get a mutable reference to the raw connection.
137    ///
138    /// This can be used to send raw data over the connection.
139    pub fn raw_connection(&mut self) -> Result<&mut T::Socket<'b>, Error<T::Error>> {
140        Ok(self.io_mut())
141    }
142
143    /// Release the connection, returning the raw connection and the buffer.
144    pub fn release(mut self) -> (T::Socket<'b>, &'b mut [u8]) {
145        let mut state = self.unbind();
146
147        let io = unwrap!(state.io.take());
148
149        (io, state.buf)
150    }
151
152    async fn start_request(
153        &mut self,
154        http11: bool,
155        method: Method,
156        uri: &str,
157        headers: &[(&str, &str)],
158    ) -> Result<(), Error<T::Error>> {
159        let _ = self.complete().await;
160
161        let state = self.unbound_mut()?;
162
163        let fresh_connection = if state.io.is_none() {
164            state.io = Some(state.socket.connect(state.addr).await.map_err(Error::Io)?);
165            true
166        } else {
167            false
168        };
169
170        let mut state = self.unbind();
171
172        let result = async {
173            match send_request(http11, method, uri, unwrap!(state.io.as_mut())).await {
174                Ok(_) => (),
175                Err(Error::Io(_)) => {
176                    if !fresh_connection {
177                        // Attempt to reconnect and re-send the request
178                        state.io = None;
179                        state.io = Some(state.socket.connect(state.addr).await.map_err(Error::Io)?);
180
181                        send_request(http11, method, uri, unwrap!(state.io.as_mut())).await?;
182                    }
183                }
184                Err(other) => Err(other)?,
185            }
186
187            let io = unwrap!(state.io.as_mut());
188
189            send_headers(headers, None, true, http11, true, &mut *io).await
190        }
191        .await;
192
193        match result {
194            Ok((connection_type, body_type)) => {
195                *self = Self::Request(RequestState {
196                    buf: state.buf,
197                    socket: state.socket,
198                    addr: state.addr,
199                    connection_type,
200                    io: SendBody::new(body_type, unwrap!(state.io)),
201                });
202
203                Ok(())
204            }
205            Err(e) => {
206                state.io = None;
207                *self = Self::Unbound(state);
208
209                Err(e)
210            }
211        }
212    }
213
214    /// Complete the request-response cycle
215    ///
216    /// If the request has not been initiated, this method will do nothing.
217    /// If the response has not been initiated, it will be initiated and will be consumed.
218    pub async fn complete(&mut self) -> Result<(), Error<T::Error>> {
219        let result = async {
220            if self.request_mut().is_ok() {
221                self.complete_request().await?;
222            }
223
224            let needs_close = if self.response_mut().is_ok() {
225                self.complete_response().await?
226            } else {
227                false
228            };
229
230            Result::<_, Error<T::Error>>::Ok(needs_close)
231        }
232        .await;
233
234        let mut state = self.unbind();
235
236        match result {
237            Ok(true) | Err(_) => {
238                let io = state.io.take();
239                *self = Self::Unbound(state);
240
241                if let Some(mut io) = io {
242                    io.close(Close::Both).await.map_err(Error::Io)?;
243                    let _ = io.abort().await;
244                }
245            }
246            _ => {
247                *self = Self::Unbound(state);
248            }
249        };
250
251        result?;
252
253        Ok(())
254    }
255
256    pub async fn close(mut self) -> Result<(), Error<T::Error>> {
257        let res = self.complete().await;
258
259        if let Some(mut io) = self.unbind().io.take() {
260            io.close(Close::Both).await.map_err(Error::Io)?;
261            let _ = io.abort().await;
262        }
263
264        res
265    }
266
267    async fn complete_request(&mut self) -> Result<(), Error<T::Error>> {
268        self.request_mut()?.io.finish().await?;
269
270        let request_connection_type = self.request_mut()?.connection_type;
271
272        let mut state = self.unbind();
273        let buf_ptr: *mut [u8] = state.buf;
274        let mut response = ResponseHeaders::new();
275
276        match response
277            .receive(state.buf, &mut unwrap!(state.io.as_mut()), true)
278            .await
279        {
280            Ok((buf, read_len)) => {
281                let (connection_type, body_type) =
282                    response.resolve::<T::Error>(request_connection_type)?;
283
284                let io = Body::new(body_type, buf, read_len, unwrap!(state.io));
285
286                *self = Self::Response(ResponseState {
287                    buf: buf_ptr,
288                    response,
289                    socket: state.socket,
290                    addr: state.addr,
291                    connection_type,
292                    io,
293                });
294
295                Ok(())
296            }
297            Err(e) => {
298                state.io = None;
299                state.buf = unwrap!(unsafe { buf_ptr.as_mut() });
300
301                *self = Self::Unbound(state);
302
303                Err(e)
304            }
305        }
306    }
307
308    async fn complete_response(&mut self) -> Result<bool, Error<T::Error>> {
309        if self.request_mut().is_ok() {
310            self.complete_request().await?;
311        }
312
313        let response = self.response_mut()?;
314
315        let mut buf = [0; COMPLETION_BUF_SIZE];
316        while response.io.read(&mut buf).await? > 0 {}
317
318        let needs_close = response.needs_close();
319
320        *self = Self::Unbound(self.unbind());
321
322        Ok(needs_close)
323    }
324
325    /// Return `true` if the connection needs to be closed (i.e. the server has requested it or the connection is in an invalid state)
326    pub fn needs_close(&self) -> bool {
327        match self {
328            Self::Response(response) => response.needs_close(),
329            _ => true,
330        }
331    }
332
333    fn unbind(&mut self) -> UnboundState<'b, T, N> {
334        let state = mem::replace(self, Self::Transition(TransitionState(())));
335
336        match state {
337            Self::Unbound(unbound) => unbound,
338            Self::Request(request) => {
339                let io = request.io.release();
340
341                UnboundState {
342                    buf: request.buf,
343                    socket: request.socket,
344                    addr: request.addr,
345                    io: Some(io),
346                }
347            }
348            Self::Response(response) => {
349                let io = response.io.release();
350
351                UnboundState {
352                    buf: unwrap!(unsafe { response.buf.as_mut() }),
353                    socket: response.socket,
354                    addr: response.addr,
355                    io: Some(io),
356                }
357            }
358            _ => unreachable!(),
359        }
360    }
361
362    fn unbound_mut(&mut self) -> Result<&mut UnboundState<'b, T, N>, Error<T::Error>> {
363        if let Self::Unbound(new) = self {
364            Ok(new)
365        } else {
366            Err(Error::InvalidState)
367        }
368    }
369
370    fn request_mut(&mut self) -> Result<&mut RequestState<'b, T, N>, Error<T::Error>> {
371        if let Self::Request(request) = self {
372            Ok(request)
373        } else {
374            Err(Error::InvalidState)
375        }
376    }
377
378    fn response_mut(&mut self) -> Result<&mut ResponseState<'b, T, N>, Error<T::Error>> {
379        if let Self::Response(response) = self {
380            Ok(response)
381        } else {
382            Err(Error::InvalidState)
383        }
384    }
385
386    fn response_ref(&self) -> Result<&ResponseState<'b, T, N>, Error<T::Error>> {
387        if let Self::Response(response) = self {
388            Ok(response)
389        } else {
390            Err(Error::InvalidState)
391        }
392    }
393
394    fn io_mut(&mut self) -> &mut T::Socket<'b> {
395        match self {
396            Self::Unbound(unbound) => unwrap!(unbound.io.as_mut()),
397            Self::Request(request) => request.io.as_raw_writer(),
398            Self::Response(response) => response.io.as_raw_reader(),
399            _ => unreachable!(),
400        }
401    }
402}
403
404impl<T, const N: usize> ErrorType for Connection<'_, T, N>
405where
406    T: TcpConnect,
407{
408    type Error = Error<T::Error>;
409}
410
411impl<T, const N: usize> Read for Connection<'_, T, N>
412where
413    T: TcpConnect,
414{
415    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
416        self.response_mut()?.io.read(buf).await
417    }
418}
419
420impl<T, const N: usize> Write for Connection<'_, T, N>
421where
422    T: TcpConnect,
423{
424    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
425        self.request_mut()?.io.write(buf).await
426    }
427
428    async fn flush(&mut self) -> Result<(), Self::Error> {
429        self.request_mut()?.io.flush().await
430    }
431}
432
/// Zero-sized marker held by `Connection::Transition`, the placeholder stored in the
/// connection while its real state is temporarily moved out with `mem::replace`.
struct TransitionState(());
434
435struct UnboundState<'b, T, const N: usize>
436where
437    T: TcpConnect,
438{
439    buf: &'b mut [u8],
440    socket: &'b T,
441    addr: SocketAddr,
442    io: Option<T::Socket<'b>>,
443}
444
445struct RequestState<'b, T, const N: usize>
446where
447    T: TcpConnect,
448{
449    buf: &'b mut [u8],
450    socket: &'b T,
451    addr: SocketAddr,
452    connection_type: ConnectionType,
453    io: SendBody<T::Socket<'b>>,
454}
455
456struct ResponseState<'b, T, const N: usize>
457where
458    T: TcpConnect,
459{
460    buf: *mut [u8],
461    response: ResponseHeaders<'b, N>,
462    socket: &'b T,
463    addr: SocketAddr,
464    connection_type: ConnectionType,
465    io: Body<'b, T::Socket<'b>>,
466}
467
468impl<T, const N: usize> ResponseState<'_, T, N>
469where
470    T: TcpConnect,
471{
472    fn needs_close(&self) -> bool {
473        matches!(self.connection_type, ConnectionType::Close) || self.io.needs_close()
474    }
475}