1use core::mem;
2use core::net::SocketAddr;
3use core::str;
4
5use embedded_io_async::{ErrorType, Read, Write};
6
7use edge_nal::{Close, TcpConnect, TcpShutdown};
8
9use crate::{
10 ws::{upgrade_request_headers, MAX_BASE64_KEY_LEN, MAX_BASE64_KEY_RESPONSE_LEN, NONCE_LEN},
11 ConnectionType, DEFAULT_MAX_HEADERS_COUNT,
12};
13
14use super::{send_headers, send_request, Body, Error, ResponseHeaders, SendBody};
15
16use super::Method;
17
/// Size of the stack buffer that `complete_response` uses to drain any unread
/// response body.
const COMPLETION_BUF_SIZE: usize = 64;
19
20#[allow(private_interfaces)]
22pub enum Connection<'b, T, const N: usize = DEFAULT_MAX_HEADERS_COUNT>
23where
24 T: TcpConnect,
25{
26 Transition(TransitionState),
27 Unbound(UnboundState<'b, T, N>),
28 Request(RequestState<'b, T, N>),
29 Response(ResponseState<'b, T, N>),
30}
31
32impl<'b, T, const N: usize> Connection<'b, T, N>
33where
34 T: TcpConnect,
35{
36 pub fn new(buf: &'b mut [u8], socket: &'b T, addr: SocketAddr) -> Self {
48 Self::Unbound(UnboundState {
49 buf,
50 socket,
51 addr,
52 io: None,
53 })
54 }
55
56 pub async fn reinitialize(&mut self, addr: SocketAddr) -> Result<(), Error<T::Error>> {
58 let _ = self.complete().await;
59 unwrap!(self.unbound_mut(), "Unreachable").addr = addr;
60
61 Ok(())
62 }
63
64 pub async fn initiate_request(
66 &mut self,
67 http11: bool,
68 method: Method,
69 uri: &str,
70 headers: &[(&str, &str)],
71 ) -> Result<(), Error<T::Error>> {
72 self.start_request(http11, method, uri, headers).await
73 }
74
75 pub async fn initiate_ws_upgrade_request(
77 &mut self,
78 host: Option<&str>,
79 origin: Option<&str>,
80 uri: &str,
81 version: Option<&str>,
82 nonce: &[u8; NONCE_LEN],
83 nonce_base64_buf: &mut [u8; MAX_BASE64_KEY_LEN],
84 ) -> Result<(), Error<T::Error>> {
85 let headers = upgrade_request_headers(host, origin, version, nonce, nonce_base64_buf);
86
87 self.initiate_request(true, Method::Get, uri, &headers)
88 .await
89 }
90
91 pub fn is_request_initiated(&self) -> bool {
93 matches!(self, Self::Request(_))
94 }
95
96 pub async fn initiate_response(&mut self) -> Result<(), Error<T::Error>> {
100 self.complete_request().await
101 }
102
103 pub fn is_response_initiated(&self) -> bool {
105 matches!(self, Self::Response(_))
106 }
107
108 pub fn is_ws_upgrade_accepted(
110 &self,
111 nonce: &[u8; NONCE_LEN],
112 buf: &mut [u8; MAX_BASE64_KEY_RESPONSE_LEN],
113 ) -> Result<bool, Error<T::Error>> {
114 Ok(self.headers()?.is_ws_upgrade_accepted(nonce, buf))
115 }
116
117 #[allow(clippy::type_complexity)]
121 pub fn split(&mut self) -> (&ResponseHeaders<'b, N>, &mut Body<'b, T::Socket<'b>>) {
122 let response = self.response_mut().expect("Not in response mode");
123
124 (&response.response, &mut response.io)
125 }
126
127 pub fn headers(&self) -> Result<&ResponseHeaders<'b, N>, Error<T::Error>> {
131 let response = self.response_ref()?;
132
133 Ok(&response.response)
134 }
135
136 pub fn raw_connection(&mut self) -> Result<&mut T::Socket<'b>, Error<T::Error>> {
140 Ok(self.io_mut())
141 }
142
143 pub fn release(mut self) -> (T::Socket<'b>, &'b mut [u8]) {
145 let mut state = self.unbind();
146
147 let io = unwrap!(state.io.take());
148
149 (io, state.buf)
150 }
151
152 async fn start_request(
153 &mut self,
154 http11: bool,
155 method: Method,
156 uri: &str,
157 headers: &[(&str, &str)],
158 ) -> Result<(), Error<T::Error>> {
159 let _ = self.complete().await;
160
161 let state = self.unbound_mut()?;
162
163 let fresh_connection = if state.io.is_none() {
164 state.io = Some(state.socket.connect(state.addr).await.map_err(Error::Io)?);
165 true
166 } else {
167 false
168 };
169
170 let mut state = self.unbind();
171
172 let result = async {
173 match send_request(http11, method, uri, unwrap!(state.io.as_mut())).await {
174 Ok(_) => (),
175 Err(Error::Io(_)) => {
176 if !fresh_connection {
177 state.io = None;
179 state.io = Some(state.socket.connect(state.addr).await.map_err(Error::Io)?);
180
181 send_request(http11, method, uri, unwrap!(state.io.as_mut())).await?;
182 }
183 }
184 Err(other) => Err(other)?,
185 }
186
187 let io = unwrap!(state.io.as_mut());
188
189 send_headers(headers, None, true, http11, true, &mut *io).await
190 }
191 .await;
192
193 match result {
194 Ok((connection_type, body_type)) => {
195 *self = Self::Request(RequestState {
196 buf: state.buf,
197 socket: state.socket,
198 addr: state.addr,
199 connection_type,
200 io: SendBody::new(body_type, unwrap!(state.io)),
201 });
202
203 Ok(())
204 }
205 Err(e) => {
206 state.io = None;
207 *self = Self::Unbound(state);
208
209 Err(e)
210 }
211 }
212 }
213
214 pub async fn complete(&mut self) -> Result<(), Error<T::Error>> {
219 let result = async {
220 if self.request_mut().is_ok() {
221 self.complete_request().await?;
222 }
223
224 let needs_close = if self.response_mut().is_ok() {
225 self.complete_response().await?
226 } else {
227 false
228 };
229
230 Result::<_, Error<T::Error>>::Ok(needs_close)
231 }
232 .await;
233
234 let mut state = self.unbind();
235
236 match result {
237 Ok(true) | Err(_) => {
238 let io = state.io.take();
239 *self = Self::Unbound(state);
240
241 if let Some(mut io) = io {
242 io.close(Close::Both).await.map_err(Error::Io)?;
243 let _ = io.abort().await;
244 }
245 }
246 _ => {
247 *self = Self::Unbound(state);
248 }
249 };
250
251 result?;
252
253 Ok(())
254 }
255
256 pub async fn close(mut self) -> Result<(), Error<T::Error>> {
257 let res = self.complete().await;
258
259 if let Some(mut io) = self.unbind().io.take() {
260 io.close(Close::Both).await.map_err(Error::Io)?;
261 let _ = io.abort().await;
262 }
263
264 res
265 }
266
267 async fn complete_request(&mut self) -> Result<(), Error<T::Error>> {
268 self.request_mut()?.io.finish().await?;
269
270 let request_connection_type = self.request_mut()?.connection_type;
271
272 let mut state = self.unbind();
273 let buf_ptr: *mut [u8] = state.buf;
274 let mut response = ResponseHeaders::new();
275
276 match response
277 .receive(state.buf, &mut unwrap!(state.io.as_mut()), true)
278 .await
279 {
280 Ok((buf, read_len)) => {
281 let (connection_type, body_type) =
282 response.resolve::<T::Error>(request_connection_type)?;
283
284 let io = Body::new(body_type, buf, read_len, unwrap!(state.io));
285
286 *self = Self::Response(ResponseState {
287 buf: buf_ptr,
288 response,
289 socket: state.socket,
290 addr: state.addr,
291 connection_type,
292 io,
293 });
294
295 Ok(())
296 }
297 Err(e) => {
298 state.io = None;
299 state.buf = unwrap!(unsafe { buf_ptr.as_mut() });
300
301 *self = Self::Unbound(state);
302
303 Err(e)
304 }
305 }
306 }
307
308 async fn complete_response(&mut self) -> Result<bool, Error<T::Error>> {
309 if self.request_mut().is_ok() {
310 self.complete_request().await?;
311 }
312
313 let response = self.response_mut()?;
314
315 let mut buf = [0; COMPLETION_BUF_SIZE];
316 while response.io.read(&mut buf).await? > 0 {}
317
318 let needs_close = response.needs_close();
319
320 *self = Self::Unbound(self.unbind());
321
322 Ok(needs_close)
323 }
324
325 pub fn needs_close(&self) -> bool {
327 match self {
328 Self::Response(response) => response.needs_close(),
329 _ => true,
330 }
331 }
332
333 fn unbind(&mut self) -> UnboundState<'b, T, N> {
334 let state = mem::replace(self, Self::Transition(TransitionState(())));
335
336 match state {
337 Self::Unbound(unbound) => unbound,
338 Self::Request(request) => {
339 let io = request.io.release();
340
341 UnboundState {
342 buf: request.buf,
343 socket: request.socket,
344 addr: request.addr,
345 io: Some(io),
346 }
347 }
348 Self::Response(response) => {
349 let io = response.io.release();
350
351 UnboundState {
352 buf: unwrap!(unsafe { response.buf.as_mut() }),
353 socket: response.socket,
354 addr: response.addr,
355 io: Some(io),
356 }
357 }
358 _ => unreachable!(),
359 }
360 }
361
362 fn unbound_mut(&mut self) -> Result<&mut UnboundState<'b, T, N>, Error<T::Error>> {
363 if let Self::Unbound(new) = self {
364 Ok(new)
365 } else {
366 Err(Error::InvalidState)
367 }
368 }
369
370 fn request_mut(&mut self) -> Result<&mut RequestState<'b, T, N>, Error<T::Error>> {
371 if let Self::Request(request) = self {
372 Ok(request)
373 } else {
374 Err(Error::InvalidState)
375 }
376 }
377
378 fn response_mut(&mut self) -> Result<&mut ResponseState<'b, T, N>, Error<T::Error>> {
379 if let Self::Response(response) = self {
380 Ok(response)
381 } else {
382 Err(Error::InvalidState)
383 }
384 }
385
386 fn response_ref(&self) -> Result<&ResponseState<'b, T, N>, Error<T::Error>> {
387 if let Self::Response(response) = self {
388 Ok(response)
389 } else {
390 Err(Error::InvalidState)
391 }
392 }
393
394 fn io_mut(&mut self) -> &mut T::Socket<'b> {
395 match self {
396 Self::Unbound(unbound) => unwrap!(unbound.io.as_mut()),
397 Self::Request(request) => request.io.as_raw_writer(),
398 Self::Response(response) => response.io.as_raw_reader(),
399 _ => unreachable!(),
400 }
401 }
402}
403
404impl<T, const N: usize> ErrorType for Connection<'_, T, N>
405where
406 T: TcpConnect,
407{
408 type Error = Error<T::Error>;
409}
410
411impl<T, const N: usize> Read for Connection<'_, T, N>
412where
413 T: TcpConnect,
414{
415 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
416 self.response_mut()?.io.read(buf).await
417 }
418}
419
420impl<T, const N: usize> Write for Connection<'_, T, N>
421where
422 T: TcpConnect,
423{
424 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
425 self.request_mut()?.io.write(buf).await
426 }
427
428 async fn flush(&mut self) -> Result<(), Self::Error> {
429 self.request_mut()?.io.flush().await
430 }
431}
432
/// Placeholder payload of `Connection::Transition`, stored while the real state has
/// been temporarily moved out of the connection (see `Connection::unbind`).
struct TransitionState(());
434
435struct UnboundState<'b, T, const N: usize>
436where
437 T: TcpConnect,
438{
439 buf: &'b mut [u8],
440 socket: &'b T,
441 addr: SocketAddr,
442 io: Option<T::Socket<'b>>,
443}
444
445struct RequestState<'b, T, const N: usize>
446where
447 T: TcpConnect,
448{
449 buf: &'b mut [u8],
450 socket: &'b T,
451 addr: SocketAddr,
452 connection_type: ConnectionType,
453 io: SendBody<T::Socket<'b>>,
454}
455
456struct ResponseState<'b, T, const N: usize>
457where
458 T: TcpConnect,
459{
460 buf: *mut [u8],
461 response: ResponseHeaders<'b, N>,
462 socket: &'b T,
463 addr: SocketAddr,
464 connection_type: ConnectionType,
465 io: Body<'b, T::Socket<'b>>,
466}
467
468impl<T, const N: usize> ResponseState<'_, T, N>
469where
470 T: TcpConnect,
471{
472 fn needs_close(&self) -> bool {
473 matches!(self.connection_type, ConnectionType::Close) || self.io.needs_close()
474 }
475}