edge_http/io/
client.rs

1use core::mem;
2use core::net::SocketAddr;
3use core::str;
4
5use embedded_io_async::{ErrorType, Read, Write};
6
7use edge_nal::{Close, TcpConnect, TcpShutdown};
8
9use crate::{
10    ws::{upgrade_request_headers, MAX_BASE64_KEY_LEN, MAX_BASE64_KEY_RESPONSE_LEN, NONCE_LEN},
11    ConnectionType, DEFAULT_MAX_HEADERS_COUNT,
12};
13
14use super::{send_headers, send_request, Body, Error, ResponseHeaders, SendBody};
15
16#[allow(unused_imports)]
17#[cfg(feature = "embedded-svc")]
18pub use embedded_svc_compat::*;
19
20use super::Method;
21
/// Size of the scratch buffer used to drain any unread part of a response body when a
/// request-response cycle is completed.
const COMPLETION_BUF_SIZE: usize = 64;
23
24/// A client connection that can be used to send HTTP requests and receive responses.
25#[allow(private_interfaces)]
26pub enum Connection<'b, T, const N: usize = DEFAULT_MAX_HEADERS_COUNT>
27where
28    T: TcpConnect,
29{
30    Transition(TransitionState),
31    Unbound(UnboundState<'b, T, N>),
32    Request(RequestState<'b, T, N>),
33    Response(ResponseState<'b, T, N>),
34}
35
36impl<'b, T, const N: usize> Connection<'b, T, N>
37where
38    T: TcpConnect,
39{
40    /// Create a new client connection.
41    ///
42    /// Note that the connection does not have any built-in read/write timeouts:
43    /// - To add a timeout on each IO operation, wrap the `socket` type with the `edge_nal::WithTimeout` wrapper.
44    /// - To add a global request-response timeout, wrap your complete request-response processing
45    ///   logic with the `edge_nal::with_timeout` function.
46    ///
47    /// Parameters:
48    /// - `buf`: A buffer to use for reading and writing data.
49    /// - `socket`: The TCP stack to use for the connection.
50    /// - `addr`: The address of the server to connect to.
51    pub fn new(buf: &'b mut [u8], socket: &'b T, addr: SocketAddr) -> Self {
52        Self::Unbound(UnboundState {
53            buf,
54            socket,
55            addr,
56            io: None,
57        })
58    }
59
60    /// Reinitialize the connection with a new address.
61    pub async fn reinitialize(&mut self, addr: SocketAddr) -> Result<(), Error<T::Error>> {
62        let _ = self.complete().await;
63        unwrap!(self.unbound_mut(), "Unreachable").addr = addr;
64
65        Ok(())
66    }
67
68    /// Initiate an HTTP request.
69    pub async fn initiate_request(
70        &mut self,
71        http11: bool,
72        method: Method,
73        uri: &str,
74        headers: &[(&str, &str)],
75    ) -> Result<(), Error<T::Error>> {
76        self.start_request(http11, method, uri, headers).await
77    }
78
79    /// A utility method to initiate a WebSocket upgrade request.
80    pub async fn initiate_ws_upgrade_request(
81        &mut self,
82        host: Option<&str>,
83        origin: Option<&str>,
84        uri: &str,
85        version: Option<&str>,
86        nonce: &[u8; NONCE_LEN],
87        nonce_base64_buf: &mut [u8; MAX_BASE64_KEY_LEN],
88    ) -> Result<(), Error<T::Error>> {
89        let headers = upgrade_request_headers(host, origin, version, nonce, nonce_base64_buf);
90
91        self.initiate_request(true, Method::Get, uri, &headers)
92            .await
93    }
94
95    /// Return `true` if a request has been initiated.
96    pub fn is_request_initiated(&self) -> bool {
97        matches!(self, Self::Request(_))
98    }
99
100    /// Initiate an HTTP response.
101    ///
102    /// This should be called after a request has been initiated and the request body had been sent.
103    pub async fn initiate_response(&mut self) -> Result<(), Error<T::Error>> {
104        self.complete_request().await
105    }
106
107    /// Return `true` if a response has been initiated.
108    pub fn is_response_initiated(&self) -> bool {
109        matches!(self, Self::Response(_))
110    }
111
112    /// Return `true` if the server accepted the WebSocket upgrade request.
113    pub fn is_ws_upgrade_accepted(
114        &self,
115        nonce: &[u8; NONCE_LEN],
116        buf: &mut [u8; MAX_BASE64_KEY_RESPONSE_LEN],
117    ) -> Result<bool, Error<T::Error>> {
118        Ok(self.headers()?.is_ws_upgrade_accepted(nonce, buf))
119    }
120
121    /// Split the connection into its headers and body parts.
122    ///
123    /// The connection must be in response mode.
124    #[allow(clippy::type_complexity)]
125    pub fn split(&mut self) -> (&ResponseHeaders<'b, N>, &mut Body<'b, T::Socket<'b>>) {
126        let response = self.response_mut().expect("Not in response mode");
127
128        (&response.response, &mut response.io)
129    }
130
131    /// Get the headers of the response.
132    ///
133    /// The connection must be in response mode.
134    pub fn headers(&self) -> Result<&ResponseHeaders<'b, N>, Error<T::Error>> {
135        let response = self.response_ref()?;
136
137        Ok(&response.response)
138    }
139
140    /// Get a mutable reference to the raw connection.
141    ///
142    /// This can be used to send raw data over the connection.
143    pub fn raw_connection(&mut self) -> Result<&mut T::Socket<'b>, Error<T::Error>> {
144        Ok(self.io_mut())
145    }
146
147    /// Release the connection, returning the raw connection and the buffer.
148    pub fn release(mut self) -> (T::Socket<'b>, &'b mut [u8]) {
149        let mut state = self.unbind();
150
151        let io = unwrap!(state.io.take());
152
153        (io, state.buf)
154    }
155
156    async fn start_request(
157        &mut self,
158        http11: bool,
159        method: Method,
160        uri: &str,
161        headers: &[(&str, &str)],
162    ) -> Result<(), Error<T::Error>> {
163        let _ = self.complete().await;
164
165        let state = self.unbound_mut()?;
166
167        let fresh_connection = if state.io.is_none() {
168            state.io = Some(state.socket.connect(state.addr).await.map_err(Error::Io)?);
169            true
170        } else {
171            false
172        };
173
174        let mut state = self.unbind();
175
176        let result = async {
177            match send_request(http11, method, uri, unwrap!(state.io.as_mut())).await {
178                Ok(_) => (),
179                Err(Error::Io(_)) => {
180                    if !fresh_connection {
181                        // Attempt to reconnect and re-send the request
182                        state.io = None;
183                        state.io = Some(state.socket.connect(state.addr).await.map_err(Error::Io)?);
184
185                        send_request(http11, method, uri, unwrap!(state.io.as_mut())).await?;
186                    }
187                }
188                Err(other) => Err(other)?,
189            }
190
191            let io = unwrap!(state.io.as_mut());
192
193            send_headers(headers, None, true, http11, true, &mut *io).await
194        }
195        .await;
196
197        match result {
198            Ok((connection_type, body_type)) => {
199                *self = Self::Request(RequestState {
200                    buf: state.buf,
201                    socket: state.socket,
202                    addr: state.addr,
203                    connection_type,
204                    io: SendBody::new(body_type, unwrap!(state.io)),
205                });
206
207                Ok(())
208            }
209            Err(e) => {
210                state.io = None;
211                *self = Self::Unbound(state);
212
213                Err(e)
214            }
215        }
216    }
217
218    /// Complete the request-response cycle
219    ///
220    /// If the request has not been initiated, this method will do nothing.
221    /// If the response has not been initiated, it will be initiated and will be consumed.
222    pub async fn complete(&mut self) -> Result<(), Error<T::Error>> {
223        let result = async {
224            if self.request_mut().is_ok() {
225                self.complete_request().await?;
226            }
227
228            let needs_close = if self.response_mut().is_ok() {
229                self.complete_response().await?
230            } else {
231                false
232            };
233
234            Result::<_, Error<T::Error>>::Ok(needs_close)
235        }
236        .await;
237
238        let mut state = self.unbind();
239
240        match result {
241            Ok(true) | Err(_) => {
242                let io = state.io.take();
243                *self = Self::Unbound(state);
244
245                if let Some(mut io) = io {
246                    io.close(Close::Both).await.map_err(Error::Io)?;
247                    let _ = io.abort().await;
248                }
249            }
250            _ => {
251                *self = Self::Unbound(state);
252            }
253        };
254
255        result?;
256
257        Ok(())
258    }
259
260    pub async fn close(mut self) -> Result<(), Error<T::Error>> {
261        let res = self.complete().await;
262
263        if let Some(mut io) = self.unbind().io.take() {
264            io.close(Close::Both).await.map_err(Error::Io)?;
265            let _ = io.abort().await;
266        }
267
268        res
269    }
270
271    async fn complete_request(&mut self) -> Result<(), Error<T::Error>> {
272        self.request_mut()?.io.finish().await?;
273
274        let request_connection_type = self.request_mut()?.connection_type;
275
276        let mut state = self.unbind();
277        let buf_ptr: *mut [u8] = state.buf;
278        let mut response = ResponseHeaders::new();
279
280        match response
281            .receive(state.buf, &mut unwrap!(state.io.as_mut()), true)
282            .await
283        {
284            Ok((buf, read_len)) => {
285                let (connection_type, body_type) =
286                    response.resolve::<T::Error>(request_connection_type)?;
287
288                let io = Body::new(body_type, buf, read_len, unwrap!(state.io));
289
290                *self = Self::Response(ResponseState {
291                    buf: buf_ptr,
292                    response,
293                    socket: state.socket,
294                    addr: state.addr,
295                    connection_type,
296                    io,
297                });
298
299                Ok(())
300            }
301            Err(e) => {
302                state.io = None;
303                state.buf = unwrap!(unsafe { buf_ptr.as_mut() });
304
305                *self = Self::Unbound(state);
306
307                Err(e)
308            }
309        }
310    }
311
312    async fn complete_response(&mut self) -> Result<bool, Error<T::Error>> {
313        if self.request_mut().is_ok() {
314            self.complete_request().await?;
315        }
316
317        let response = self.response_mut()?;
318
319        let mut buf = [0; COMPLETION_BUF_SIZE];
320        while response.io.read(&mut buf).await? > 0 {}
321
322        let needs_close = response.needs_close();
323
324        *self = Self::Unbound(self.unbind());
325
326        Ok(needs_close)
327    }
328
329    /// Return `true` if the connection needs to be closed (i.e. the server has requested it or the connection is in an invalid state)
330    pub fn needs_close(&self) -> bool {
331        match self {
332            Self::Response(response) => response.needs_close(),
333            _ => true,
334        }
335    }
336
337    fn unbind(&mut self) -> UnboundState<'b, T, N> {
338        let state = mem::replace(self, Self::Transition(TransitionState(())));
339
340        match state {
341            Self::Unbound(unbound) => unbound,
342            Self::Request(request) => {
343                let io = request.io.release();
344
345                UnboundState {
346                    buf: request.buf,
347                    socket: request.socket,
348                    addr: request.addr,
349                    io: Some(io),
350                }
351            }
352            Self::Response(response) => {
353                let io = response.io.release();
354
355                UnboundState {
356                    buf: unwrap!(unsafe { response.buf.as_mut() }),
357                    socket: response.socket,
358                    addr: response.addr,
359                    io: Some(io),
360                }
361            }
362            _ => unreachable!(),
363        }
364    }
365
366    fn unbound_mut(&mut self) -> Result<&mut UnboundState<'b, T, N>, Error<T::Error>> {
367        if let Self::Unbound(new) = self {
368            Ok(new)
369        } else {
370            Err(Error::InvalidState)
371        }
372    }
373
374    fn request_mut(&mut self) -> Result<&mut RequestState<'b, T, N>, Error<T::Error>> {
375        if let Self::Request(request) = self {
376            Ok(request)
377        } else {
378            Err(Error::InvalidState)
379        }
380    }
381
382    fn response_mut(&mut self) -> Result<&mut ResponseState<'b, T, N>, Error<T::Error>> {
383        if let Self::Response(response) = self {
384            Ok(response)
385        } else {
386            Err(Error::InvalidState)
387        }
388    }
389
390    fn response_ref(&self) -> Result<&ResponseState<'b, T, N>, Error<T::Error>> {
391        if let Self::Response(response) = self {
392            Ok(response)
393        } else {
394            Err(Error::InvalidState)
395        }
396    }
397
398    fn io_mut(&mut self) -> &mut T::Socket<'b> {
399        match self {
400            Self::Unbound(unbound) => unwrap!(unbound.io.as_mut()),
401            Self::Request(request) => request.io.as_raw_writer(),
402            Self::Response(response) => response.io.as_raw_reader(),
403            _ => unreachable!(),
404        }
405    }
406}
407
408impl<T, const N: usize> ErrorType for Connection<'_, T, N>
409where
410    T: TcpConnect,
411{
412    type Error = Error<T::Error>;
413}
414
415impl<T, const N: usize> Read for Connection<'_, T, N>
416where
417    T: TcpConnect,
418{
419    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
420        self.response_mut()?.io.read(buf).await
421    }
422}
423
424impl<T, const N: usize> Write for Connection<'_, T, N>
425where
426    T: TcpConnect,
427{
428    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
429        self.request_mut()?.io.write(buf).await
430    }
431
432    async fn flush(&mut self) -> Result<(), Self::Error> {
433        self.request_mut()?.io.flush().await
434    }
435}
436
437struct TransitionState(());
438
439struct UnboundState<'b, T, const N: usize>
440where
441    T: TcpConnect,
442{
443    buf: &'b mut [u8],
444    socket: &'b T,
445    addr: SocketAddr,
446    io: Option<T::Socket<'b>>,
447}
448
449struct RequestState<'b, T, const N: usize>
450where
451    T: TcpConnect,
452{
453    buf: &'b mut [u8],
454    socket: &'b T,
455    addr: SocketAddr,
456    connection_type: ConnectionType,
457    io: SendBody<T::Socket<'b>>,
458}
459
460struct ResponseState<'b, T, const N: usize>
461where
462    T: TcpConnect,
463{
464    buf: *mut [u8],
465    response: ResponseHeaders<'b, N>,
466    socket: &'b T,
467    addr: SocketAddr,
468    connection_type: ConnectionType,
469    io: Body<'b, T::Socket<'b>>,
470}
471
472impl<T, const N: usize> ResponseState<'_, T, N>
473where
474    T: TcpConnect,
475{
476    fn needs_close(&self) -> bool {
477        matches!(self.connection_type, ConnectionType::Close) || self.io.needs_close()
478    }
479}
480
/// Implementations of the `embedded-svc` async HTTP client traits for [`Connection`].
#[cfg(feature = "embedded-svc")]
mod embedded_svc_compat {
    use super::*;

    use embedded_svc::http::client::asynch::{Connection, Headers, Method, Status};

    /// Panics (via `expect`) if the connection is not in response mode.
    impl<T, const N: usize> Headers for super::Connection<'_, T, N>
    where
        T: TcpConnect,
    {
        fn header(&self, name: &str) -> Option<&'_ str> {
            let response = self.response_ref().expect("Not in response state");

            response.response.header(name)
        }
    }

    /// Panics (via `expect`) if the connection is not in response mode.
    impl<T, const N: usize> Status for super::Connection<'_, T, N>
    where
        T: TcpConnect,
    {
        fn status(&self) -> u16 {
            let response = self.response_ref().expect("Not in response state");

            response.response.status()
        }

        fn status_message(&self) -> Option<&'_ str> {
            let response = self.response_ref().expect("Not in response state");

            response.response.status_message()
        }
    }

    /// Delegates every operation to the inherent methods of `super::Connection`.
    impl<'b, T, const N: usize> Connection for super::Connection<'b, T, N>
    where
        T: TcpConnect,
    {
        type Read = Body<'b, T::Socket<'b>>;

        type Headers = ResponseHeaders<'b, N>;

        type RawConnectionError = T::Error;

        type RawConnection = T::Socket<'b>;

        async fn initiate_request(
            &mut self,
            method: Method,
            uri: &str,
            headers: &[(&str, &str)],
        ) -> Result<(), Self::Error> {
            super::Connection::initiate_request(self, true, method.into(), uri, headers).await
        }

        fn is_request_initiated(&self) -> bool {
            super::Connection::is_request_initiated(self)
        }

        async fn initiate_response(&mut self) -> Result<(), Self::Error> {
            super::Connection::initiate_response(self).await
        }

        fn is_response_initiated(&self) -> bool {
            super::Connection::is_response_initiated(self)
        }

        fn split(&mut self) -> (&Self::Headers, &mut Self::Read) {
            super::Connection::split(self)
        }

        fn raw_connection(&mut self) -> Result<&mut Self::RawConnection, Self::Error> {
            // The inherent method already returns `&mut T::Socket<'b>` (= `Self::RawConnection`)
            // with the same error type, so there is no need to panic here
            super::Connection::raw_connection(self)
        }
    }
}