1use core::fmt::{self, Debug, Display};
2use core::mem::{self, MaybeUninit};
3use core::pin::pin;
4
5use edge_nal::{
6 with_timeout, Close, Readable, TcpShutdown, TcpSplit, WithTimeout, WithTimeoutError,
7};
8
9use embassy_sync::blocking_mutex::raw::NoopRawMutex;
10use embassy_sync::mutex::Mutex;
11
12use embedded_io_async::{ErrorType, Read, Write};
13
14use super::{send_headers, send_status, Body, Error, RequestHeaders, SendBody};
15
16use crate::ws::{upgrade_response_headers, MAX_BASE64_KEY_RESPONSE_LEN};
17use crate::{ConnectionType, DEFAULT_MAX_HEADERS_COUNT};
18
/// Default number of handler tasks of a [`Server`], i.e. how many connections it serves
/// concurrently (each task serves one connection at a time).
pub const DEFAULT_HANDLER_TASKS_COUNT: usize = 4;
/// Default size, in bytes, of the buffer owned by each handler task of a [`Server`].
pub const DEFAULT_BUF_SIZE: usize = 2048;

/// Size of the stack buffer used to read and discard any unread request body bytes
/// before a response is sent.
const COMPLETION_BUF_SIZE: usize = 64;
23
/// A single HTTP request/response exchange over the underlying connection `T`.
///
/// `N` is the header-count capacity used for the request's `RequestHeaders`
/// (defaults to `DEFAULT_MAX_HEADERS_COUNT`).
// The variants carry crate-private state types (`TransitionState`, `RequestState`,
// `ResponseState`) that are intentionally not exported, hence the lint allowance.
#[allow(private_interfaces)]
pub enum Connection<'b, T, const N: usize = DEFAULT_MAX_HEADERS_COUNT> {
    /// Placeholder held while switching between the other states; it is left in place
    /// if such a switch is interrupted (e.g. the switching future is dropped).
    Transition(TransitionState),
    /// The raw connection, no longer wrapped by a request or response body.
    Unbound(T),
    /// The request headers have been read; the request body can be read.
    Request(RequestState<'b, T, N>),
    /// The response status line and headers have been sent; the response body can be written.
    Response(ResponseState<T>),
}
33impl<'b, T, const N: usize> Connection<'b, T, N>
34where
35 T: Read + Write,
36{
37 pub async fn new(
48 buf: &'b mut [u8],
49 mut io: T,
50 ) -> Result<Connection<'b, T, N>, Error<T::Error>> {
51 let mut request = RequestHeaders::new();
52
53 let (buf, read_len) = request.receive(buf, &mut io, true).await?;
54
55 let (connection_type, body_type) = request.resolve::<T::Error>()?;
56
57 let io = Body::new(body_type, buf, read_len, io);
58
59 Ok(Self::Request(RequestState {
60 request,
61 io,
62 connection_type,
63 }))
64 }
65
66 pub fn is_request_initiated(&self) -> bool {
68 matches!(self, Self::Request(_))
69 }
70
71 pub fn split(&mut self) -> (&RequestHeaders<'b, N>, &mut Body<'b, T>) {
73 let req = self.request_mut().expect("Not in request mode");
74
75 (&req.request, &mut req.io)
76 }
77
78 pub fn headers(&self) -> Result<&RequestHeaders<'b, N>, Error<T::Error>> {
80 Ok(&self.request_ref()?.request)
81 }
82
83 pub fn is_ws_upgrade_request(&self) -> Result<bool, Error<T::Error>> {
85 Ok(self.headers()?.is_ws_upgrade_request())
86 }
87
88 pub async fn initiate_response(
97 &mut self,
98 status: u16,
99 message: Option<&str>,
100 headers: &[(&str, &str)],
101 ) -> Result<(), Error<T::Error>> {
102 self.complete_request(status, message, headers).await
103 }
104
105 pub async fn initiate_ws_upgrade_response(
107 &mut self,
108 buf: &mut [u8; MAX_BASE64_KEY_RESPONSE_LEN],
109 ) -> Result<(), Error<T::Error>> {
110 let headers = upgrade_response_headers(self.headers()?.headers.iter(), None, buf)?;
111
112 self.initiate_response(101, None, &headers).await
113 }
114
115 pub fn is_response_initiated(&self) -> bool {
117 matches!(self, Self::Response(_))
118 }
119
120 pub async fn complete(&mut self) -> Result<(), Error<T::Error>> {
123 if self.is_request_initiated() {
124 self.complete_request(200, Some("OK"), &[]).await?;
125 }
126
127 if self.is_response_initiated() {
128 self.complete_response().await?;
129 }
130
131 Ok(())
132 }
133
134 pub async fn complete_err(&mut self, err: &str) -> Result<(), Error<T::Error>> {
138 let result = self.request_mut();
139
140 match result {
141 Ok(_) => {
142 let headers = [("Connection", "Close"), ("Content-Type", "text/plain")];
143
144 self.complete_request(500, Some("Internal Error"), &headers)
145 .await?;
146
147 let response = self.response_mut()?;
148
149 response.io.write_all(err.as_bytes()).await?;
150 response.io.finish().await?;
151
152 Ok(())
153 }
154 Err(err) => Err(err),
155 }
156 }
157
158 pub fn needs_close(&self) -> bool {
162 match self {
163 Self::Response(response) => response.needs_close(),
164 _ => true,
165 }
166 }
167
168 pub fn unbind(&mut self) -> Result<&mut T, Error<T::Error>> {
172 let io = self.unbind_mut();
173 *self = Self::Unbound(io);
174
175 Ok(self.io_mut())
176 }
177
178 async fn complete_request(
179 &mut self,
180 status: u16,
181 reason: Option<&str>,
182 headers: &[(&str, &str)],
183 ) -> Result<(), Error<T::Error>> {
184 let request = self.request_mut()?;
185
186 let mut buf = [0; COMPLETION_BUF_SIZE];
187 while request.io.read(&mut buf).await? > 0 {}
188
189 let http11 = request.request.http11;
190 let request_connection_type = request.connection_type;
191
192 let mut io = self.unbind_mut();
193
194 let result = async {
195 send_status(http11, status, reason, &mut io).await?;
196
197 let (connection_type, body_type) = send_headers(
198 headers.iter(),
199 Some(request_connection_type),
200 false,
201 http11,
202 true,
203 &mut io,
204 )
205 .await?;
206
207 Ok((connection_type, body_type))
208 }
209 .await;
210
211 match result {
212 Ok((connection_type, body_type)) => {
213 *self = Self::Response(ResponseState {
214 io: SendBody::new(body_type, io),
215 connection_type,
216 });
217
218 Ok(())
219 }
220 Err(e) => {
221 *self = Self::Unbound(io);
222
223 Err(e)
224 }
225 }
226 }
227
228 async fn complete_response(&mut self) -> Result<(), Error<T::Error>> {
229 self.response_mut()?.io.finish().await?;
230
231 Ok(())
232 }
233
234 fn unbind_mut(&mut self) -> T {
235 let state = mem::replace(self, Self::Transition(TransitionState(())));
236
237 match state {
238 Self::Request(request) => request.io.release(),
239 Self::Response(response) => response.io.release(),
240 Self::Unbound(io) => io,
241 _ => unreachable!(),
242 }
243 }
244
245 fn request_mut(&mut self) -> Result<&mut RequestState<'b, T, N>, Error<T::Error>> {
246 if let Self::Request(request) = self {
247 Ok(request)
248 } else {
249 Err(Error::InvalidState)
250 }
251 }
252
253 fn request_ref(&self) -> Result<&RequestState<'b, T, N>, Error<T::Error>> {
254 if let Self::Request(request) = self {
255 Ok(request)
256 } else {
257 Err(Error::InvalidState)
258 }
259 }
260
261 fn response_mut(&mut self) -> Result<&mut ResponseState<T>, Error<T::Error>> {
262 if let Self::Response(response) = self {
263 Ok(response)
264 } else {
265 Err(Error::InvalidState)
266 }
267 }
268
269 fn io_mut(&mut self) -> &mut T {
270 match self {
271 Self::Request(request) => request.io.as_raw_reader(),
272 Self::Response(response) => response.io.as_raw_writer(),
273 Self::Unbound(io) => io,
274 _ => unreachable!(),
275 }
276 }
277}
278
279impl<T, const N: usize> ErrorType for Connection<'_, T, N>
280where
281 T: ErrorType,
282{
283 type Error = Error<T::Error>;
284}
285
286impl<T, const N: usize> Read for Connection<'_, T, N>
287where
288 T: Read + Write,
289{
290 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
291 self.request_mut()?.io.read(buf).await
292 }
293}
294
295impl<T, const N: usize> Write for Connection<'_, T, N>
296where
297 T: Read + Write,
298{
299 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
300 self.response_mut()?.io.write(buf).await
301 }
302
303 async fn flush(&mut self) -> Result<(), Self::Error> {
304 self.response_mut()?.io.flush().await
305 }
306}
307
/// Placeholder stored in a [`Connection`] while its I/O object is being moved from one
/// state to another.
struct TransitionState(());
309
310struct RequestState<'b, T, const N: usize> {
311 request: RequestHeaders<'b, N>,
312 io: Body<'b, T>,
313 connection_type: ConnectionType,
314}
315
316struct ResponseState<T> {
317 io: SendBody<T>,
318 connection_type: ConnectionType,
319}
320
321impl<T> ResponseState<T>
322where
323 T: Write,
324{
325 fn needs_close(&self) -> bool {
326 matches!(self.connection_type, ConnectionType::Close) || self.io.needs_close()
327 }
328}
329
/// Error returned by [`handle_request`].
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum HandlerError<T, E> {
    /// A raw I/O error (not constructed by the functions in this module).
    Io(T),
    /// A connection-level error, e.g. while reading the request or sending the response.
    Connection(Error<T>),
    /// The error returned by the handler; only surfaced by [`handle_request`] when the
    /// fallback error response could not be sent either.
    Handler(E),
}
337
338impl<T, E> From<Error<T>> for HandlerError<T, E> {
339 fn from(e: Error<T>) -> Self {
340 Self::Connection(e)
341 }
342}
343
/// A handler of HTTP requests, invoked once per request by [`handle_request`].
pub trait Handler {
    /// Error type returned by [`Handler::handle`], parameterized by the I/O error type `E`
    /// of the connection being served.
    type Error<E>: Debug
    where
        E: Debug;

    /// Handles the request available on `connection`.
    ///
    /// `task_id` identifies the server task running this handler; the server also uses it
    /// in its own log messages.
    ///
    /// The handler does not have to finish the exchange itself. After an `Ok` return,
    /// [`handle_request`] completes the response (sending `200 OK` if none was initiated).
    /// After an `Err` return, it tries to send a `500` response instead.
    async fn handle<T, const N: usize>(
        &self,
        task_id: impl Display + Copy,
        connection: &mut Connection<'_, T, N>,
    ) -> Result<(), Self::Error<T::Error>>
    where
        T: Read + Write + TcpSplit;
}
363
364impl<H> Handler for &H
365where
366 H: Handler,
367{
368 type Error<E>
369 = H::Error<E>
370 where
371 E: Debug;
372
373 async fn handle<T, const N: usize>(
374 &self,
375 task_id: impl Display + Copy,
376 connection: &mut Connection<'_, T, N>,
377 ) -> Result<(), Self::Error<T::Error>>
378 where
379 T: Read + Write + TcpSplit,
380 {
381 (**self).handle(task_id, connection).await
382 }
383}
384
385impl<H> Handler for &mut H
386where
387 H: Handler,
388{
389 type Error<E>
390 = H::Error<E>
391 where
392 E: Debug;
393
394 async fn handle<T, const N: usize>(
395 &self,
396 task_id: impl Display + Copy,
397 connection: &mut Connection<'_, T, N>,
398 ) -> Result<(), Self::Error<T::Error>>
399 where
400 T: Read + Write + TcpSplit,
401 {
402 (**self).handle(task_id, connection).await
403 }
404}
405
406impl<H> Handler for WithTimeout<H>
407where
408 H: Handler,
409{
410 type Error<E>
411 = WithTimeoutError<H::Error<E>>
412 where
413 E: Debug;
414
415 async fn handle<T, const N: usize>(
416 &self,
417 task_id: impl Display + Copy,
418 connection: &mut Connection<'_, T, N>,
419 ) -> Result<(), Self::Error<T::Error>>
420 where
421 T: Read + Write + TcpSplit,
422 {
423 let mut io = pin!(self.io().handle(task_id, connection));
424
425 with_timeout(self.timeout_ms(), &mut io).await?;
426
427 Ok(())
428 }
429}
430
431pub async fn handle_connection<H, T, const N: usize>(
456 mut io: T,
457 buf: &mut [u8],
458 keepalive_timeout_ms: Option<u32>,
459 task_id: impl Display + Copy,
460 handler: H,
461) where
462 H: Handler,
463 T: Read + Write + Readable + TcpSplit + TcpShutdown,
464{
465 let close = loop {
466 debug!(
467 "Handler task {}: Waiting for a new request",
468 display2format!(task_id)
469 );
470
471 if let Some(keepalive_timeout_ms) = keepalive_timeout_ms {
472 let wait_data = with_timeout(keepalive_timeout_ms, io.readable()).await;
473 match wait_data {
474 Err(WithTimeoutError::Timeout) => {
475 info!(
476 "Handler task {}: Closing connection due to inactivity",
477 display2format!(task_id)
478 );
479 break true;
480 }
481 Err(e) => {
482 warn!(
483 "Handler task {}: Error when handling request: {:?}",
484 display2format!(task_id),
485 debug2format!(e)
486 );
487 break true;
488 }
489 Ok(_) => {}
490 }
491 }
492
493 let result = handle_request::<_, _, N>(buf, &mut io, task_id, &handler).await;
494
495 match result {
496 Err(HandlerError::Connection(Error::ConnectionClosed)) => {
497 debug!(
498 "Handler task {}: Connection closed",
499 display2format!(task_id)
500 );
501 break false;
502 }
503 Err(e) => {
504 warn!(
505 "Handler task {}: Error when handling request: {:?}",
506 display2format!(task_id),
507 debug2format!(e)
508 );
509 break true;
510 }
511 Ok(needs_close) => {
512 if needs_close {
513 debug!(
514 "Handler task {}: Request complete; closing connection",
515 display2format!(task_id)
516 );
517 break true;
518 } else {
519 debug!(
520 "Handler task {}: Request complete",
521 display2format!(task_id)
522 );
523 }
524 }
525 }
526 };
527
528 if close {
529 if let Err(e) = io.close(Close::Both).await {
530 warn!(
531 "Handler task {}: Error when closing the socket: {:?}",
532 display2format!(task_id),
533 debug2format!(e)
534 );
535 }
536 } else {
537 let _ = io.abort().await;
538 }
539}
540
/// An error that is either a connection-level error or an error produced by a request
/// handler.
#[derive(Debug)]
pub enum HandleRequestError<C, E> {
    /// A connection-level error wrapping the connection's I/O error type `C`.
    Connection(Error<C>),
    /// An error returned by the request handler.
    Handler(E),
}
549
550impl<T, E> From<Error<T>> for HandleRequestError<T, E> {
551 fn from(e: Error<T>) -> Self {
552 Self::Connection(e)
553 }
554}
555
556impl<C, E> fmt::Display for HandleRequestError<C, E>
557where
558 C: fmt::Display,
559 E: fmt::Display,
560{
561 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
562 match self {
563 Self::Connection(e) => write!(f, "Connection error: {}", e),
564 Self::Handler(e) => write!(f, "Handler error: {}", e),
565 }
566 }
567}
568
/// `defmt` counterpart of the `Display` implementation.
#[cfg(feature = "defmt")]
impl<C, E> defmt::Format for HandleRequestError<C, E>
where
    C: defmt::Format,
    E: defmt::Format,
{
    fn format(&self, fmt: defmt::Formatter<'_>) {
        match self {
            HandleRequestError::Connection(conn_err) => {
                defmt::write!(fmt, "Connection error: {}", conn_err)
            }
            HandleRequestError::Handler(handler_err) => {
                defmt::write!(fmt, "Handler error: {}", handler_err)
            }
        }
    }
}
582
583impl<C, E> embedded_io_async::Error for HandleRequestError<C, E>
584where
585 C: Debug + core::error::Error + embedded_io_async::Error,
586 E: Debug + core::error::Error,
587{
588 fn kind(&self) -> embedded_io_async::ErrorKind {
589 match self {
590 Self::Connection(Error::Io(e)) => e.kind(),
591 _ => embedded_io_async::ErrorKind::Other,
592 }
593 }
594}
595
/// Marker implementation relying on the trait's default methods; the `Debug` and `Display`
/// supertrait requirements are met by the derive and the `Display` impl for this type.
impl<C, E> core::error::Error for HandleRequestError<C, E>
where
    C: core::error::Error,
    E: core::error::Error,
{
}
602
603pub async fn handle_request<H, T, const N: usize>(
617 buf: &mut [u8],
618 io: T,
619 task_id: impl Display + Copy,
620 handler: H,
621) -> Result<bool, HandlerError<T::Error, H::Error<T::Error>>>
622where
623 H: Handler,
624 T: Read + Write + TcpSplit,
625{
626 let mut connection = Connection::<_, N>::new(buf, io).await?;
627
628 let result = handler.handle(task_id, &mut connection).await;
629
630 match result {
631 Result::Ok(_) => connection.complete().await?,
632 Result::Err(e) => connection
633 .complete_err("INTERNAL ERROR")
634 .await
635 .map_err(|_| HandlerError::Handler(e))?,
636 }
637
638 Ok(connection.needs_close())
639}
640
641pub type DefaultServer =
643 Server<{ DEFAULT_HANDLER_TASKS_COUNT }, { DEFAULT_BUF_SIZE }, { DEFAULT_MAX_HEADERS_COUNT }>;
644
/// Possibly-uninitialized storage for the buffers of a [`Server`]: `P` buffers (one per
/// handler task) of `B` bytes each.
pub type ServerBuffers<const P: usize, const B: usize> = MaybeUninit<[[u8; B]; P]>;
647
/// An HTTP server with statically sized, inline buffers.
///
/// - `P`: number of handler tasks; each task serves one connection at a time.
/// - `B`: size in bytes of each handler task's buffer.
/// - `N`: request-header capacity passed on to each connection.
///
/// The buffers are stored inline, so the server occupies `P * B` bytes.
#[repr(transparent)]
pub struct Server<
    const P: usize = DEFAULT_HANDLER_TASKS_COUNT,
    const B: usize = DEFAULT_BUF_SIZE,
    const N: usize = DEFAULT_MAX_HEADERS_COUNT,
>(ServerBuffers<P, B>);
657
658impl<const P: usize, const B: usize, const N: usize> Server<P, B, N> {
659 #[inline(always)]
661 pub const fn new() -> Self {
662 Self(MaybeUninit::uninit())
663 }
664
665 #[inline(never)]
685 #[cold]
686 pub async fn run<A, H>(
687 &mut self,
688 keepalive_timeout_ms: Option<u32>,
689 acceptor: A,
690 handler: H,
691 ) -> Result<(), Error<A::Error>>
692 where
693 A: edge_nal::TcpAccept,
694 H: Handler,
695 {
696 let mutex = Mutex::<NoopRawMutex, _>::new(());
697 let mut tasks = heapless::Vec::<_, P>::new();
698
699 info!(
700 "Creating {} handler tasks, memory: {}B",
701 P,
702 core::mem::size_of_val(&tasks)
703 );
704
705 for index in 0..P {
706 let mutex = &mutex;
707 let acceptor = &acceptor;
708 let task_id = index;
709 let handler = &handler;
710 let buf: *mut [u8; B] = &mut unsafe { self.0.assume_init_mut() }[index];
711
712 unwrap!(tasks
713 .push(async move {
714 loop {
715 debug!(
716 "Handler task {}: Waiting for connection",
717 display2format!(task_id)
718 );
719
720 let io = {
721 let _guard = mutex.lock().await;
722
723 acceptor.accept().await.map_err(Error::Io)?.1
724 };
725
726 debug!(
727 "Handler task {}: Got connection request",
728 display2format!(task_id)
729 );
730
731 handle_connection::<_, _, N>(
732 io,
733 unwrap!(unsafe { buf.as_mut() }),
734 keepalive_timeout_ms,
735 task_id,
736 handler,
737 )
738 .await;
739 }
740 })
741 .map_err(|_| ()));
742 }
743
744 let tasks = pin!(tasks);
745
746 let tasks = unsafe { tasks.map_unchecked_mut(|t| t.as_mut_slice()) };
747 let (result, _) = embassy_futures::select::select_slice(tasks).await;
748
749 warn!(
750 "Server processing loop quit abruptly: {:?}",
751 debug2format!(result)
752 );
753
754 result
755 }
756}
757
758impl<const P: usize, const B: usize, const N: usize> Default for Server<P, B, N> {
759 fn default() -> Self {
760 Self::new()
761 }
762}