1use core::mem;
2use core::net::SocketAddr;
3use core::str;
4
5use embedded_io_async::{ErrorType, Read, Write};
6
7use edge_nal::{Close, TcpConnect, TcpShutdown};
8
9use crate::{
10 ws::{upgrade_request_headers, MAX_BASE64_KEY_LEN, MAX_BASE64_KEY_RESPONSE_LEN, NONCE_LEN},
11 ConnectionType, DEFAULT_MAX_HEADERS_COUNT,
12};
13
14use super::{send_headers, send_request, Body, Error, ResponseHeaders, SendBody};
15
16#[allow(unused_imports)]
17#[cfg(feature = "embedded-svc")]
18pub use embedded_svc_compat::*;
19
20use super::Method;
21
/// Size, in bytes, of the stack scratch buffer that `complete_response` reads into
/// while discarding whatever is left of a response body.
22const COMPLETION_BUF_SIZE: usize = 64;
23
24#[allow(private_interfaces)]
26pub enum Connection<'b, T, const N: usize = DEFAULT_MAX_HEADERS_COUNT>
27where
28 T: TcpConnect,
29{
30 Transition(TransitionState),
31 Unbound(UnboundState<'b, T, N>),
32 Request(RequestState<'b, T, N>),
33 Response(ResponseState<'b, T, N>),
34}
35
36impl<'b, T, const N: usize> Connection<'b, T, N>
37where
38 T: TcpConnect,
39{
40 pub fn new(buf: &'b mut [u8], socket: &'b T, addr: SocketAddr) -> Self {
52 Self::Unbound(UnboundState {
53 buf,
54 socket,
55 addr,
56 io: None,
57 })
58 }
59
60 pub async fn reinitialize(&mut self, addr: SocketAddr) -> Result<(), Error<T::Error>> {
62 let _ = self.complete().await;
63 unwrap!(self.unbound_mut(), "Unreachable").addr = addr;
64
65 Ok(())
66 }
67
68 pub async fn initiate_request(
70 &mut self,
71 http11: bool,
72 method: Method,
73 uri: &str,
74 headers: &[(&str, &str)],
75 ) -> Result<(), Error<T::Error>> {
76 self.start_request(http11, method, uri, headers).await
77 }
78
79 pub async fn initiate_ws_upgrade_request(
81 &mut self,
82 host: Option<&str>,
83 origin: Option<&str>,
84 uri: &str,
85 version: Option<&str>,
86 nonce: &[u8; NONCE_LEN],
87 nonce_base64_buf: &mut [u8; MAX_BASE64_KEY_LEN],
88 ) -> Result<(), Error<T::Error>> {
89 let headers = upgrade_request_headers(host, origin, version, nonce, nonce_base64_buf);
90
91 self.initiate_request(true, Method::Get, uri, &headers)
92 .await
93 }
94
95 pub fn is_request_initiated(&self) -> bool {
97 matches!(self, Self::Request(_))
98 }
99
100 pub async fn initiate_response(&mut self) -> Result<(), Error<T::Error>> {
104 self.complete_request().await
105 }
106
107 pub fn is_response_initiated(&self) -> bool {
109 matches!(self, Self::Response(_))
110 }
111
112 pub fn is_ws_upgrade_accepted(
114 &self,
115 nonce: &[u8; NONCE_LEN],
116 buf: &mut [u8; MAX_BASE64_KEY_RESPONSE_LEN],
117 ) -> Result<bool, Error<T::Error>> {
118 Ok(self.headers()?.is_ws_upgrade_accepted(nonce, buf))
119 }
120
121 #[allow(clippy::type_complexity)]
125 pub fn split(&mut self) -> (&ResponseHeaders<'b, N>, &mut Body<'b, T::Socket<'b>>) {
126 let response = self.response_mut().expect("Not in response mode");
127
128 (&response.response, &mut response.io)
129 }
130
131 pub fn headers(&self) -> Result<&ResponseHeaders<'b, N>, Error<T::Error>> {
135 let response = self.response_ref()?;
136
137 Ok(&response.response)
138 }
139
140 pub fn raw_connection(&mut self) -> Result<&mut T::Socket<'b>, Error<T::Error>> {
144 Ok(self.io_mut())
145 }
146
147 pub fn release(mut self) -> (T::Socket<'b>, &'b mut [u8]) {
149 let mut state = self.unbind();
150
151 let io = unwrap!(state.io.take());
152
153 (io, state.buf)
154 }
155
156 async fn start_request(
157 &mut self,
158 http11: bool,
159 method: Method,
160 uri: &str,
161 headers: &[(&str, &str)],
162 ) -> Result<(), Error<T::Error>> {
163 let _ = self.complete().await;
164
165 let state = self.unbound_mut()?;
166
167 let fresh_connection = if state.io.is_none() {
168 state.io = Some(state.socket.connect(state.addr).await.map_err(Error::Io)?);
169 true
170 } else {
171 false
172 };
173
174 let mut state = self.unbind();
175
176 let result = async {
177 match send_request(http11, method, uri, unwrap!(state.io.as_mut())).await {
178 Ok(_) => (),
179 Err(Error::Io(_)) => {
180 if !fresh_connection {
181 state.io = None;
183 state.io = Some(state.socket.connect(state.addr).await.map_err(Error::Io)?);
184
185 send_request(http11, method, uri, unwrap!(state.io.as_mut())).await?;
186 }
187 }
188 Err(other) => Err(other)?,
189 }
190
191 let io = unwrap!(state.io.as_mut());
192
193 send_headers(headers, None, true, http11, true, &mut *io).await
194 }
195 .await;
196
197 match result {
198 Ok((connection_type, body_type)) => {
199 *self = Self::Request(RequestState {
200 buf: state.buf,
201 socket: state.socket,
202 addr: state.addr,
203 connection_type,
204 io: SendBody::new(body_type, unwrap!(state.io)),
205 });
206
207 Ok(())
208 }
209 Err(e) => {
210 state.io = None;
211 *self = Self::Unbound(state);
212
213 Err(e)
214 }
215 }
216 }
217
218 pub async fn complete(&mut self) -> Result<(), Error<T::Error>> {
223 let result = async {
224 if self.request_mut().is_ok() {
225 self.complete_request().await?;
226 }
227
228 let needs_close = if self.response_mut().is_ok() {
229 self.complete_response().await?
230 } else {
231 false
232 };
233
234 Result::<_, Error<T::Error>>::Ok(needs_close)
235 }
236 .await;
237
238 let mut state = self.unbind();
239
240 match result {
241 Ok(true) | Err(_) => {
242 let io = state.io.take();
243 *self = Self::Unbound(state);
244
245 if let Some(mut io) = io {
246 io.close(Close::Both).await.map_err(Error::Io)?;
247 let _ = io.abort().await;
248 }
249 }
250 _ => {
251 *self = Self::Unbound(state);
252 }
253 };
254
255 result?;
256
257 Ok(())
258 }
259
260 pub async fn close(mut self) -> Result<(), Error<T::Error>> {
261 let res = self.complete().await;
262
263 if let Some(mut io) = self.unbind().io.take() {
264 io.close(Close::Both).await.map_err(Error::Io)?;
265 let _ = io.abort().await;
266 }
267
268 res
269 }
270
271 async fn complete_request(&mut self) -> Result<(), Error<T::Error>> {
272 self.request_mut()?.io.finish().await?;
273
274 let request_connection_type = self.request_mut()?.connection_type;
275
276 let mut state = self.unbind();
277 let buf_ptr: *mut [u8] = state.buf;
278 let mut response = ResponseHeaders::new();
279
280 match response
281 .receive(state.buf, &mut unwrap!(state.io.as_mut()), true)
282 .await
283 {
284 Ok((buf, read_len)) => {
285 let (connection_type, body_type) =
286 response.resolve::<T::Error>(request_connection_type)?;
287
288 let io = Body::new(body_type, buf, read_len, unwrap!(state.io));
289
290 *self = Self::Response(ResponseState {
291 buf: buf_ptr,
292 response,
293 socket: state.socket,
294 addr: state.addr,
295 connection_type,
296 io,
297 });
298
299 Ok(())
300 }
301 Err(e) => {
302 state.io = None;
303 state.buf = unwrap!(unsafe { buf_ptr.as_mut() });
304
305 *self = Self::Unbound(state);
306
307 Err(e)
308 }
309 }
310 }
311
312 async fn complete_response(&mut self) -> Result<bool, Error<T::Error>> {
313 if self.request_mut().is_ok() {
314 self.complete_request().await?;
315 }
316
317 let response = self.response_mut()?;
318
319 let mut buf = [0; COMPLETION_BUF_SIZE];
320 while response.io.read(&mut buf).await? > 0 {}
321
322 let needs_close = response.needs_close();
323
324 *self = Self::Unbound(self.unbind());
325
326 Ok(needs_close)
327 }
328
329 pub fn needs_close(&self) -> bool {
331 match self {
332 Self::Response(response) => response.needs_close(),
333 _ => true,
334 }
335 }
336
337 fn unbind(&mut self) -> UnboundState<'b, T, N> {
338 let state = mem::replace(self, Self::Transition(TransitionState(())));
339
340 match state {
341 Self::Unbound(unbound) => unbound,
342 Self::Request(request) => {
343 let io = request.io.release();
344
345 UnboundState {
346 buf: request.buf,
347 socket: request.socket,
348 addr: request.addr,
349 io: Some(io),
350 }
351 }
352 Self::Response(response) => {
353 let io = response.io.release();
354
355 UnboundState {
356 buf: unwrap!(unsafe { response.buf.as_mut() }),
357 socket: response.socket,
358 addr: response.addr,
359 io: Some(io),
360 }
361 }
362 _ => unreachable!(),
363 }
364 }
365
366 fn unbound_mut(&mut self) -> Result<&mut UnboundState<'b, T, N>, Error<T::Error>> {
367 if let Self::Unbound(new) = self {
368 Ok(new)
369 } else {
370 Err(Error::InvalidState)
371 }
372 }
373
374 fn request_mut(&mut self) -> Result<&mut RequestState<'b, T, N>, Error<T::Error>> {
375 if let Self::Request(request) = self {
376 Ok(request)
377 } else {
378 Err(Error::InvalidState)
379 }
380 }
381
382 fn response_mut(&mut self) -> Result<&mut ResponseState<'b, T, N>, Error<T::Error>> {
383 if let Self::Response(response) = self {
384 Ok(response)
385 } else {
386 Err(Error::InvalidState)
387 }
388 }
389
390 fn response_ref(&self) -> Result<&ResponseState<'b, T, N>, Error<T::Error>> {
391 if let Self::Response(response) = self {
392 Ok(response)
393 } else {
394 Err(Error::InvalidState)
395 }
396 }
397
398 fn io_mut(&mut self) -> &mut T::Socket<'b> {
399 match self {
400 Self::Unbound(unbound) => unwrap!(unbound.io.as_mut()),
401 Self::Request(request) => request.io.as_raw_writer(),
402 Self::Response(response) => response.io.as_raw_reader(),
403 _ => unreachable!(),
404 }
405 }
406}
407
408impl<T, const N: usize> ErrorType for Connection<'_, T, N>
409where
410 T: TcpConnect,
411{
412 type Error = Error<T::Error>;
413}
414
415impl<T, const N: usize> Read for Connection<'_, T, N>
416where
417 T: TcpConnect,
418{
419 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
420 self.response_mut()?.io.read(buf).await
421 }
422}
423
424impl<T, const N: usize> Write for Connection<'_, T, N>
425where
426 T: TcpConnect,
427{
428 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
429 self.request_mut()?.io.write(buf).await
430 }
431
432 async fn flush(&mut self) -> Result<(), Self::Error> {
433 self.request_mut()?.io.flush().await
434 }
435}
436
/// Payload of `Connection::Transition`: a marker with no data, held by the connection
/// only while `unbind` has moved the real state out. The private `()` field keeps it
/// from being constructed outside this module.
437struct TransitionState(());
438
439struct UnboundState<'b, T, const N: usize>
440where
441 T: TcpConnect,
442{
443 buf: &'b mut [u8],
444 socket: &'b T,
445 addr: SocketAddr,
446 io: Option<T::Socket<'b>>,
447}
448
449struct RequestState<'b, T, const N: usize>
450where
451 T: TcpConnect,
452{
453 buf: &'b mut [u8],
454 socket: &'b T,
455 addr: SocketAddr,
456 connection_type: ConnectionType,
457 io: SendBody<T::Socket<'b>>,
458}
459
460struct ResponseState<'b, T, const N: usize>
461where
462 T: TcpConnect,
463{
464 buf: *mut [u8],
465 response: ResponseHeaders<'b, N>,
466 socket: &'b T,
467 addr: SocketAddr,
468 connection_type: ConnectionType,
469 io: Body<'b, T::Socket<'b>>,
470}
471
472impl<T, const N: usize> ResponseState<'_, T, N>
473where
474 T: TcpConnect,
475{
476 fn needs_close(&self) -> bool {
477 matches!(self.connection_type, ConnectionType::Close) || self.io.needs_close()
478 }
479}
480
/// Implementations of the `embedded-svc` async HTTP client traits for
/// `super::Connection`, delegating to its inherent methods.
#[cfg(feature = "embedded-svc")]
mod embedded_svc_compat {
    use super::*;

    use embedded_svc::http::client::asynch::{Connection, Headers, Method, Status};

    impl<T: TcpConnect, const N: usize> Headers for super::Connection<'_, T, N> {
        /// Looks up a response header by name.
        ///
        /// Panics when the connection is not in the response state.
        fn header(&self, name: &str) -> Option<&'_ str> {
            self.response_ref()
                .expect("Not in response state")
                .response
                .header(name)
        }
    }

    impl<T: TcpConnect, const N: usize> Status for super::Connection<'_, T, N> {
        /// Returns the response status code.
        ///
        /// Panics when the connection is not in the response state.
        fn status(&self) -> u16 {
            self.response_ref()
                .expect("Not in response state")
                .response
                .status()
        }

        /// Returns the response status message, if any.
        ///
        /// Panics when the connection is not in the response state.
        fn status_message(&self) -> Option<&'_ str> {
            self.response_ref()
                .expect("Not in response state")
                .response
                .status_message()
        }
    }

    impl<'b, T: TcpConnect, const N: usize> Connection for super::Connection<'b, T, N> {
        type Read = Body<'b, T::Socket<'b>>;

        type Headers = ResponseHeaders<'b, N>;

        type RawConnectionError = T::Error;

        type RawConnection = T::Socket<'b>;

        /// Starts a request through the inherent `initiate_request`, always with the
        /// `http11` flag set.
        async fn initiate_request(
            &mut self,
            method: Method,
            uri: &str,
            headers: &[(&str, &str)],
        ) -> Result<(), Self::Error> {
            let method = method.into();

            super::Connection::initiate_request(self, true, method, uri, headers).await
        }

        /// See the inherent `is_request_initiated`.
        fn is_request_initiated(&self) -> bool {
            super::Connection::is_request_initiated(self)
        }

        /// See the inherent `initiate_response`.
        async fn initiate_response(&mut self) -> Result<(), Self::Error> {
            super::Connection::initiate_response(self).await
        }

        /// See the inherent `is_response_initiated`.
        fn is_response_initiated(&self) -> bool {
            super::Connection::is_response_initiated(self)
        }

        /// See the inherent `split`; panics when not in the response state.
        fn split(&mut self) -> (&Self::Headers, &mut Self::Read) {
            super::Connection::split(self)
        }

        /// Not supported through this trait: always panics.
        fn raw_connection(&mut self) -> Result<&mut Self::RawConnection, Self::Error> {
            panic!("Not supported")
        }
    }
}