1use core::fmt::{self, Debug, Display};
2use core::mem::{self, MaybeUninit};
3use core::pin::pin;
4
5use edge_nal::{
6 with_timeout, Close, Readable, TcpShutdown, TcpSplit, WithTimeout, WithTimeoutError,
7};
8
9use embassy_sync::blocking_mutex::raw::NoopRawMutex;
10use embassy_sync::mutex::Mutex;
11
12use embedded_io_async::{ErrorType, Read, Write};
13
14use super::{send_headers, send_status, Body, Error, RequestHeaders, SendBody};
15
16use crate::ws::{upgrade_response_headers, MAX_BASE64_KEY_RESPONSE_LEN};
17use crate::{ConnectionType, DEFAULT_MAX_HEADERS_COUNT};
18
19#[allow(unused_imports)]
20#[cfg(feature = "embedded-svc")]
21pub use embedded_svc_compat::*;
22
/// Default number of handler tasks; the default value of the `P` const parameter of `Server`.
pub const DEFAULT_HANDLER_TASKS_COUNT: usize = 4;
/// Default size, in bytes, of each per-task buffer; the default value of the `B` const
/// parameter of `Server`.
pub const DEFAULT_BUF_SIZE: usize = 2048;

/// Size of the scratch buffer used to drain an unread request body before a response is sent.
const COMPLETION_BUF_SIZE: usize = 64;
27
/// The state of one HTTP exchange running over an I/O object of type `T`.
///
/// `'b` is the lifetime of the buffer handed to `Connection::new`.
/// NOTE(review): `N` is presumably the maximum number of request headers — confirm against
/// `RequestHeaders`.
#[allow(private_interfaces)]
pub enum Connection<'b, T, const N: usize = DEFAULT_MAX_HEADERS_COUNT> {
    /// Placeholder left in `self` while the I/O object is being moved between states.
    Transition(TransitionState),
    /// The raw I/O object, not bound to a request or a response.
    Unbound(T),
    /// The request headers were received; the request body can be read.
    Request(RequestState<'b, T, N>),
    /// The response status line and headers were sent; the response body can be written.
    Response(ResponseState<T>),
}
36
37impl<'b, T, const N: usize> Connection<'b, T, N>
38where
39 T: Read + Write,
40{
41 pub async fn new(
52 buf: &'b mut [u8],
53 mut io: T,
54 ) -> Result<Connection<'b, T, N>, Error<T::Error>> {
55 let mut request = RequestHeaders::new();
56
57 let (buf, read_len) = request.receive(buf, &mut io, true).await?;
58
59 let (connection_type, body_type) = request.resolve::<T::Error>()?;
60
61 let io = Body::new(body_type, buf, read_len, io);
62
63 Ok(Self::Request(RequestState {
64 request,
65 io,
66 connection_type,
67 }))
68 }
69
70 pub fn is_request_initiated(&self) -> bool {
72 matches!(self, Self::Request(_))
73 }
74
75 pub fn split(&mut self) -> (&RequestHeaders<'b, N>, &mut Body<'b, T>) {
77 let req = self.request_mut().expect("Not in request mode");
78
79 (&req.request, &mut req.io)
80 }
81
82 pub fn headers(&self) -> Result<&RequestHeaders<'b, N>, Error<T::Error>> {
84 Ok(&self.request_ref()?.request)
85 }
86
87 pub fn is_ws_upgrade_request(&self) -> Result<bool, Error<T::Error>> {
89 Ok(self.headers()?.is_ws_upgrade_request())
90 }
91
92 pub async fn initiate_response(
101 &mut self,
102 status: u16,
103 message: Option<&str>,
104 headers: &[(&str, &str)],
105 ) -> Result<(), Error<T::Error>> {
106 self.complete_request(status, message, headers).await
107 }
108
109 pub async fn initiate_ws_upgrade_response(
111 &mut self,
112 buf: &mut [u8; MAX_BASE64_KEY_RESPONSE_LEN],
113 ) -> Result<(), Error<T::Error>> {
114 let headers = upgrade_response_headers(self.headers()?.headers.iter(), None, buf)?;
115
116 self.initiate_response(101, None, &headers).await
117 }
118
119 pub fn is_response_initiated(&self) -> bool {
121 matches!(self, Self::Response(_))
122 }
123
124 pub async fn complete(&mut self) -> Result<(), Error<T::Error>> {
127 if self.is_request_initiated() {
128 self.complete_request(200, Some("OK"), &[]).await?;
129 }
130
131 if self.is_response_initiated() {
132 self.complete_response().await?;
133 }
134
135 Ok(())
136 }
137
138 pub async fn complete_err(&mut self, err: &str) -> Result<(), Error<T::Error>> {
142 let result = self.request_mut();
143
144 match result {
145 Ok(_) => {
146 let headers = [("Connection", "Close"), ("Content-Type", "text/plain")];
147
148 self.complete_request(500, Some("Internal Error"), &headers)
149 .await?;
150
151 let response = self.response_mut()?;
152
153 response.io.write_all(err.as_bytes()).await?;
154 response.io.finish().await?;
155
156 Ok(())
157 }
158 Err(err) => Err(err),
159 }
160 }
161
162 pub fn needs_close(&self) -> bool {
166 match self {
167 Self::Response(response) => response.needs_close(),
168 _ => true,
169 }
170 }
171
172 pub fn unbind(&mut self) -> Result<&mut T, Error<T::Error>> {
176 let io = self.unbind_mut();
177 *self = Self::Unbound(io);
178
179 Ok(self.io_mut())
180 }
181
182 async fn complete_request(
183 &mut self,
184 status: u16,
185 reason: Option<&str>,
186 headers: &[(&str, &str)],
187 ) -> Result<(), Error<T::Error>> {
188 let request = self.request_mut()?;
189
190 let mut buf = [0; COMPLETION_BUF_SIZE];
191 while request.io.read(&mut buf).await? > 0 {}
192
193 let http11 = request.request.http11;
194 let request_connection_type = request.connection_type;
195
196 let mut io = self.unbind_mut();
197
198 let result = async {
199 send_status(http11, status, reason, &mut io).await?;
200
201 let (connection_type, body_type) = send_headers(
202 headers.iter(),
203 Some(request_connection_type),
204 false,
205 http11,
206 true,
207 &mut io,
208 )
209 .await?;
210
211 Ok((connection_type, body_type))
212 }
213 .await;
214
215 match result {
216 Ok((connection_type, body_type)) => {
217 *self = Self::Response(ResponseState {
218 io: SendBody::new(body_type, io),
219 connection_type,
220 });
221
222 Ok(())
223 }
224 Err(e) => {
225 *self = Self::Unbound(io);
226
227 Err(e)
228 }
229 }
230 }
231
232 async fn complete_response(&mut self) -> Result<(), Error<T::Error>> {
233 self.response_mut()?.io.finish().await?;
234
235 Ok(())
236 }
237
238 fn unbind_mut(&mut self) -> T {
239 let state = mem::replace(self, Self::Transition(TransitionState(())));
240
241 match state {
242 Self::Request(request) => request.io.release(),
243 Self::Response(response) => response.io.release(),
244 Self::Unbound(io) => io,
245 _ => unreachable!(),
246 }
247 }
248
249 fn request_mut(&mut self) -> Result<&mut RequestState<'b, T, N>, Error<T::Error>> {
250 if let Self::Request(request) = self {
251 Ok(request)
252 } else {
253 Err(Error::InvalidState)
254 }
255 }
256
257 fn request_ref(&self) -> Result<&RequestState<'b, T, N>, Error<T::Error>> {
258 if let Self::Request(request) = self {
259 Ok(request)
260 } else {
261 Err(Error::InvalidState)
262 }
263 }
264
265 fn response_mut(&mut self) -> Result<&mut ResponseState<T>, Error<T::Error>> {
266 if let Self::Response(response) = self {
267 Ok(response)
268 } else {
269 Err(Error::InvalidState)
270 }
271 }
272
273 fn io_mut(&mut self) -> &mut T {
274 match self {
275 Self::Request(request) => request.io.as_raw_reader(),
276 Self::Response(response) => response.io.as_raw_writer(),
277 Self::Unbound(io) => io,
278 _ => unreachable!(),
279 }
280 }
281}
282
/// The connection reports errors as `Error` wrapping the error type of the underlying I/O.
impl<T, const N: usize> ErrorType for Connection<'_, T, N>
where
    T: ErrorType,
{
    type Error = Error<T::Error>;
}
289
290impl<T, const N: usize> Read for Connection<'_, T, N>
291where
292 T: Read + Write,
293{
294 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
295 self.request_mut()?.io.read(buf).await
296 }
297}
298
299impl<T, const N: usize> Write for Connection<'_, T, N>
300where
301 T: Read + Write,
302{
303 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
304 self.response_mut()?.io.write(buf).await
305 }
306
307 async fn flush(&mut self) -> Result<(), Self::Error> {
308 self.response_mut()?.io.flush().await
309 }
310}
311
/// Payload of `Connection::Transition`; carries no data.
struct TransitionState(());
313
/// Payload of `Connection::Request`.
struct RequestState<'b, T, const N: usize> {
    /// The received request headers.
    request: RequestHeaders<'b, N>,
    /// The request body reader; owns the underlying I/O object.
    io: Body<'b, T>,
    /// The connection type resolved from the request headers.
    connection_type: ConnectionType,
}
319
/// Payload of `Connection::Response`.
struct ResponseState<T> {
    /// The response body writer; owns the underlying I/O object.
    io: SendBody<T>,
    /// The connection type returned by `send_headers` for this response.
    connection_type: ConnectionType,
}
324
325impl<T> ResponseState<T>
326where
327 T: Write,
328{
329 fn needs_close(&self) -> bool {
330 matches!(self.connection_type, ConnectionType::Close) || self.io.needs_close()
331 }
332}
333
/// Error returned by `handle_request`: either a connection-level failure or a failure
/// reported by the handler itself.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum HandlerError<T, E> {
    /// A raw I/O error of type `T`.
    /// NOTE(review): not constructed anywhere in the code visible in this file.
    Io(T),
    /// An error raised while receiving the request or sending the response.
    Connection(Error<T>),
    /// An error returned by the handler.
    Handler(E),
}
341
342impl<T, E> From<Error<T>> for HandlerError<T, E> {
343 fn from(e: Error<T>) -> Self {
344 Self::Connection(e)
345 }
346}
347
/// A request handler: given a connection, it reads the request and produces the response.
pub trait Handler {
    /// The error type returned by `handle`, parameterized by the I/O error type `E`.
    type Error<E>: Debug
    where
        E: Debug;

    /// Handles one request on `connection`.
    ///
    /// `task_id` identifies the handler task on whose behalf the request is processed; the
    /// callers in this file use it in their log messages.
    async fn handle<T, const N: usize>(
        &self,
        task_id: impl Display + Copy,
        connection: &mut Connection<'_, T, N>,
    ) -> Result<(), Self::Error<T::Error>>
    where
        T: Read + Write + TcpSplit;
}
367
368impl<H> Handler for &H
369where
370 H: Handler,
371{
372 type Error<E>
373 = H::Error<E>
374 where
375 E: Debug;
376
377 async fn handle<T, const N: usize>(
378 &self,
379 task_id: impl Display + Copy,
380 connection: &mut Connection<'_, T, N>,
381 ) -> Result<(), Self::Error<T::Error>>
382 where
383 T: Read + Write + TcpSplit,
384 {
385 (**self).handle(task_id, connection).await
386 }
387}
388
389impl<H> Handler for &mut H
390where
391 H: Handler,
392{
393 type Error<E>
394 = H::Error<E>
395 where
396 E: Debug;
397
398 async fn handle<T, const N: usize>(
399 &self,
400 task_id: impl Display + Copy,
401 connection: &mut Connection<'_, T, N>,
402 ) -> Result<(), Self::Error<T::Error>>
403 where
404 T: Read + Write + TcpSplit,
405 {
406 (**self).handle(task_id, connection).await
407 }
408}
409
410impl<H> Handler for WithTimeout<H>
411where
412 H: Handler,
413{
414 type Error<E>
415 = WithTimeoutError<H::Error<E>>
416 where
417 E: Debug;
418
419 async fn handle<T, const N: usize>(
420 &self,
421 task_id: impl Display + Copy,
422 connection: &mut Connection<'_, T, N>,
423 ) -> Result<(), Self::Error<T::Error>>
424 where
425 T: Read + Write + TcpSplit,
426 {
427 let mut io = pin!(self.io().handle(task_id, connection));
428
429 with_timeout(self.timeout_ms(), &mut io).await?;
430
431 Ok(())
432 }
433}
434
435pub async fn handle_connection<H, T, const N: usize>(
460 mut io: T,
461 buf: &mut [u8],
462 keepalive_timeout_ms: Option<u32>,
463 task_id: impl Display + Copy,
464 handler: H,
465) where
466 H: Handler,
467 T: Read + Write + Readable + TcpSplit + TcpShutdown,
468{
469 let close = loop {
470 debug!(
471 "Handler task {}: Waiting for a new request",
472 display2format!(task_id)
473 );
474
475 if let Some(keepalive_timeout_ms) = keepalive_timeout_ms {
476 let wait_data = with_timeout(keepalive_timeout_ms, io.readable()).await;
477 match wait_data {
478 Err(WithTimeoutError::Timeout) => {
479 info!(
480 "Handler task {}: Closing connection due to inactivity",
481 display2format!(task_id)
482 );
483 break true;
484 }
485 Err(e) => {
486 warn!(
487 "Handler task {}: Error when handling request: {:?}",
488 display2format!(task_id),
489 debug2format!(e)
490 );
491 break true;
492 }
493 Ok(_) => {}
494 }
495 }
496
497 let result = handle_request::<_, _, N>(buf, &mut io, task_id, &handler).await;
498
499 match result {
500 Err(HandlerError::Connection(Error::ConnectionClosed)) => {
501 debug!(
502 "Handler task {}: Connection closed",
503 display2format!(task_id)
504 );
505 break false;
506 }
507 Err(e) => {
508 warn!(
509 "Handler task {}: Error when handling request: {:?}",
510 display2format!(task_id),
511 debug2format!(e)
512 );
513 break true;
514 }
515 Ok(needs_close) => {
516 if needs_close {
517 debug!(
518 "Handler task {}: Request complete; closing connection",
519 display2format!(task_id)
520 );
521 break true;
522 } else {
523 debug!(
524 "Handler task {}: Request complete",
525 display2format!(task_id)
526 );
527 }
528 }
529 }
530 };
531
532 if close {
533 if let Err(e) = io.close(Close::Both).await {
534 warn!(
535 "Handler task {}: Error when closing the socket: {:?}",
536 display2format!(task_id),
537 debug2format!(e)
538 );
539 }
540 } else {
541 let _ = io.abort().await;
542 }
543}
544
/// An error raised while handling a request: either a connection-level failure or a failure
/// reported by the handler.
#[derive(Debug)]
pub enum HandleRequestError<C, E> {
    /// A connection error (receiving the request or sending the response).
    Connection(Error<C>),
    /// An error returned by the handler.
    Handler(E),
}
553
554impl<T, E> From<Error<T>> for HandleRequestError<T, E> {
555 fn from(e: Error<T>) -> Self {
556 Self::Connection(e)
557 }
558}
559
560impl<C, E> fmt::Display for HandleRequestError<C, E>
561where
562 C: fmt::Display,
563 E: fmt::Display,
564{
565 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
566 match self {
567 Self::Connection(e) => write!(f, "Connection error: {}", e),
568 Self::Handler(e) => write!(f, "Handler error: {}", e),
569 }
570 }
571}
572
/// `defmt` rendering, available with the `defmt` feature; uses the same text as the
/// `Display` implementation.
#[cfg(feature = "defmt")]
impl<C, E> defmt::Format for HandleRequestError<C, E>
where
    C: defmt::Format,
    E: defmt::Format,
{
    fn format(&self, f: defmt::Formatter<'_>) {
        match self {
            Self::Connection(e) => defmt::write!(f, "Connection error: {}", e),
            Self::Handler(e) => defmt::write!(f, "Handler error: {}", e),
        }
    }
}
586
587impl<C, E> embedded_io_async::Error for HandleRequestError<C, E>
588where
589 C: Debug + embedded_io_async::Error,
590 E: Debug,
591{
592 fn kind(&self) -> embedded_io_async::ErrorKind {
593 match self {
594 Self::Connection(Error::Io(e)) => e.kind(),
595 _ => embedded_io_async::ErrorKind::Other,
596 }
597 }
598}
599
/// Marks `HandleRequestError` as a standard error type when the `std` feature is enabled.
/// The impl is empty, so only the trait's default methods apply.
#[cfg(feature = "std")]
impl<C, E> std::error::Error for HandleRequestError<C, E>
where
    C: std::error::Error,
    E: std::error::Error,
{
}
607
608pub async fn handle_request<H, T, const N: usize>(
622 buf: &mut [u8],
623 io: T,
624 task_id: impl Display + Copy,
625 handler: H,
626) -> Result<bool, HandlerError<T::Error, H::Error<T::Error>>>
627where
628 H: Handler,
629 T: Read + Write + TcpSplit,
630{
631 let mut connection = Connection::<_, N>::new(buf, io).await?;
632
633 let result = handler.handle(task_id, &mut connection).await;
634
635 match result {
636 Result::Ok(_) => connection.complete().await?,
637 Result::Err(e) => connection
638 .complete_err("INTERNAL ERROR")
639 .await
640 .map_err(|_| HandlerError::Handler(e))?,
641 }
642
643 Ok(connection.needs_close())
644}
645
/// A `Server` using the default task count, buffer size and header count constants.
pub type DefaultServer =
    Server<{ DEFAULT_HANDLER_TASKS_COUNT }, { DEFAULT_BUF_SIZE }, { DEFAULT_MAX_HEADERS_COUNT }>;

/// Storage for `P` buffers of `B` bytes each, not necessarily initialized.
pub type ServerBuffers<const P: usize, const B: usize> = MaybeUninit<[[u8; B]; P]>;
652
/// An HTTP server that runs `P` handler tasks, each with its own `B`-byte buffer.
///
/// `N` is passed on to `handle_connection` for every accepted connection.
/// The struct is a transparent wrapper around its buffers.
#[repr(transparent)]
pub struct Server<
    const P: usize = DEFAULT_HANDLER_TASKS_COUNT,
    const B: usize = DEFAULT_BUF_SIZE,
    const N: usize = DEFAULT_MAX_HEADERS_COUNT,
>(ServerBuffers<P, B>);
662
impl<const P: usize, const B: usize, const N: usize> Server<P, B, N> {
    /// Creates a server whose buffers are left uninitialized.
    #[inline(always)]
    pub const fn new() -> Self {
        Self(MaybeUninit::uninit())
    }

    /// Runs the server: accepts connections from `acceptor` and serves them with `handler`.
    ///
    /// Builds `P` task futures and polls them together with `select_slice`. Each task loops
    /// forever: it takes the accept mutex, accepts one connection, releases the mutex, and
    /// serves the connection with `handle_connection` using the task's own buffer.
    ///
    /// - `keepalive_timeout_ms`: forwarded to `handle_connection`
    /// - `acceptor`: the source of incoming connections
    /// - `handler`: shared by reference between all tasks
    ///
    /// # Errors
    /// Returns as soon as any task finishes. The only exit from a task's loop is a failed
    /// `accept`, which is returned as `Error::Io`.
    #[inline(never)]
    #[cold]
    pub async fn run<A, H>(
        &mut self,
        keepalive_timeout_ms: Option<u32>,
        acceptor: A,
        handler: H,
    ) -> Result<(), Error<A::Error>>
    where
        A: edge_nal::TcpAccept,
        H: Handler,
    {
        // Held across `accept` so that only one task at a time waits for a connection.
        let mutex = Mutex::<NoopRawMutex, _>::new(());
        let mut tasks = heapless::Vec::<_, P>::new();

        info!(
            "Creating {} handler tasks, memory: {}B",
            P,
            core::mem::size_of_val(&tasks)
        );

        for index in 0..P {
            let mutex = &mutex;
            let acceptor = &acceptor;
            let task_id = index;
            let handler = &handler;
            // Each task gets a raw pointer to its own row (`index`) of the buffer array.
            // NOTE(review): a raw pointer is presumably used so the `async move` block does
            // not hold a borrow of `self` — confirm.
            // SAFETY (NOTE(review), to be confirmed): the buffers are never explicitly
            // initialized before `assume_init_mut`; this appears to rely on `u8` having no
            // invalid bit patterns and on the buffers only being written before being read.
            let buf: *mut [u8; B] = &mut unsafe { self.0.assume_init_mut() }[index];

            unwrap!(tasks
                .push(async move {
                    loop {
                        debug!(
                            "Handler task {}: Waiting for connection",
                            display2format!(task_id)
                        );

                        let io = {
                            // The guard is dropped at the end of this block, i.e. right
                            // after the connection was accepted.
                            let _guard = mutex.lock().await;

                            // `.1` keeps the second element of the accept result and
                            // discards the first.
                            acceptor.accept().await.map_err(Error::Io)?.1
                        };

                        debug!(
                            "Handler task {}: Got connection request",
                            display2format!(task_id)
                        );

                        handle_connection::<_, _, N>(
                            io,
                            // SAFETY (NOTE(review), to be confirmed): `buf` points into
                            // `self.0`, which is mutably borrowed for the whole call, and
                            // every task uses a distinct `index`.
                            unwrap!(unsafe { buf.as_mut() }),
                            keepalive_timeout_ms,
                            task_id,
                            handler,
                        )
                        .await;
                    }
                })
                .map_err(|_| ()));
        }

        let tasks = pin!(tasks);

        // Projects the pinned vector to a pinned slice of its elements.
        // SAFETY (NOTE(review), to be confirmed): the vector is not accessed again after
        // this point, so its elements are not moved while pinned.
        let tasks = unsafe { tasks.map_unchecked_mut(|t| t.as_mut_slice()) };
        let (result, _) = embassy_futures::select::select_slice(tasks).await;

        warn!(
            "Server processing loop quit abruptly: {:?}",
            debug2format!(result)
        );

        result
    }
}
762
763impl<const P: usize, const B: usize, const N: usize> Default for Server<P, B, N> {
764 fn default() -> Self {
765 Self::new()
766 }
767}
768
/// Implementations of the `embedded-svc` async HTTP server traits for `Connection`,
/// compiled only with the `embedded-svc` feature.
#[cfg(feature = "embedded-svc")]
mod embedded_svc_compat {
    use embedded_io_async::{Read, Write};

    use embedded_svc::http::server::asynch::{Connection, Headers, Query};

    use crate::io::Body;
    use crate::RequestHeaders;

    impl<T, const N: usize> Headers for super::Connection<'_, T, N>
    where
        T: Read + Write,
    {
        /// Looks up a request header by name.
        ///
        /// Panics if the connection is not in request state.
        fn header(&self, name: &str) -> Option<&'_ str> {
            let state = self.request_ref().expect("Not in request mode");

            state.request.header(name)
        }
    }

    impl<T, const N: usize> Query for super::Connection<'_, T, N>
    where
        T: Read + Write,
    {
        /// Returns the request URI.
        ///
        /// Panics if the connection is not in request state.
        fn uri(&self) -> &'_ str {
            let state = self.request_ref().expect("Not in request mode");

            state.request.uri()
        }

        /// Returns the request method.
        ///
        /// Panics if the connection is not in request state.
        fn method(&self) -> embedded_svc::http::Method {
            let state = self.request_ref().expect("Not in request mode");

            state.request.method()
        }
    }

    impl<'b, T, const N: usize> Connection for super::Connection<'b, T, N>
    where
        T: Read + Write,
    {
        type Headers = RequestHeaders<'b, N>;

        type Read = Body<'b, T>;

        type RawConnectionError = T::Error;

        type RawConnection = T;

        /// Delegates to the inherent `Connection::split`.
        fn split(&mut self) -> (&Self::Headers, &mut Self::Read) {
            super::Connection::split(self)
        }

        /// Delegates to the inherent `Connection::initiate_response`.
        async fn initiate_response(
            &mut self,
            status: u16,
            message: Option<&str>,
            headers: &[(&str, &str)],
        ) -> Result<(), Self::Error> {
            let initiated =
                super::Connection::initiate_response(self, status, message, headers).await;

            initiated
        }

        /// Delegates to the inherent `Connection::is_response_initiated`.
        fn is_response_initiated(&self) -> bool {
            super::Connection::is_response_initiated(self)
        }

        /// Not supported by this implementation; always panics.
        fn raw_connection(&mut self) -> Result<&mut Self::RawConnection, Self::Error> {
            panic!("Not supported")
        }
    }
}