1use core::convert::Infallible;
95use core::fmt::{self, Write};
96use core::future::Future;
97use core::pin::Pin;
98use core::task::{Context, Poll};
99
100use alloc::boxed::Box;
101use alloc::collections::VecDeque;
102use alloc::string::String;
103use alloc::vec::Vec;
104
105use bytes::Bytes;
106use musli::alloc::Global;
107use musli::mode::Binary;
108use musli::reader::SliceReader;
109use musli::storage;
110use musli::{Decode, Encode};
111use rand::prelude::*;
112use rand::rngs::SmallRng;
113use tokio::time::{Duration, Instant, Sleep};
114
115use crate::Buf;
116use crate::api::{Broadcast, ErrorMessage, Event, Id, MessageId, RequestHeader, ResponseHeader};
117use crate::buf::InvalidFrame;
118
/// Default value for `Server::max_capacity`.
const MAX_CAPACITY: usize = 1024 * 1024;
/// Websocket close code for a normal closure (RFC 6455).
const CLOSE_NORMAL: u16 = 1000;
/// Websocket close code for a protocol error (RFC 6455).
const CLOSE_PROTOCOL_ERROR: u16 = 1002;
/// How long the close timer runs; it is pushed back whenever a matching pong
/// arrives.
const CLOSE_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay before the next ping is sent.
const PING_TIMEOUT: Duration = Duration::from_secs(10);
/// Default seed for the RNG that generates ping payloads.
const DEFAULT_SEED: u64 = 0xdeadbeef;
125
126#[derive(Debug)]
128pub(crate) enum Message {
129 Text,
131 Binary(Bytes),
133 Ping(Bytes),
135 Pong(Bytes),
137 Close,
139}
140
141pub(crate) mod socket_sealed {
142 pub trait Sealed {}
143}
144
145pub(crate) trait SocketImpl
146where
147 Self: self::socket_sealed::Sealed,
148{
149 #[doc(hidden)]
150 type Message;
151
152 #[doc(hidden)]
153 type Error: fmt::Debug;
154
155 #[doc(hidden)]
156 fn poll_next(
157 self: Pin<&mut Self>,
158 ctx: &mut Context<'_>,
159 ) -> Poll<Option<Result<Message, Self::Error>>>;
160
161 #[doc(hidden)]
162 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;
163
164 #[doc(hidden)]
165 fn start_send(self: Pin<&mut Self>, item: Self::Message) -> Result<(), Self::Error>;
166
167 #[doc(hidden)]
168 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;
169}
170
171pub(crate) mod server_sealed {
172 pub trait Sealed {}
173}
174
175pub trait ServerImpl
181where
182 Self: self::server_sealed::Sealed,
183{
184 #[doc(hidden)]
185 type Error: fmt::Debug;
186
187 #[doc(hidden)]
188 type Message;
189
190 #[doc(hidden)]
191 #[allow(private_bounds)]
192 type Socket: SocketImpl<Message = Self::Message, Error = Self::Error>;
193
194 #[doc(hidden)]
195 fn ping(data: Bytes) -> Self::Message;
196
197 #[doc(hidden)]
198 fn pong(data: Bytes) -> Self::Message;
199
200 #[doc(hidden)]
201 fn binary(data: &[u8]) -> Self::Message;
202
203 #[doc(hidden)]
204 fn close(code: u16, reason: &str) -> Self::Message;
205}
206
207#[derive(Debug)]
208enum ErrorKind {
209 #[cfg(feature = "axum-core05")]
210 AxumCore05 {
211 error: axum_core05::Error,
212 },
213 Musli {
214 error: storage::Error,
215 },
216 FormatError,
217 InvalidFrame {
218 error: InvalidFrame,
219 },
220}
221
222#[derive(Debug)]
224pub struct Error {
225 kind: ErrorKind,
226}
227
228impl Error {
229 #[inline]
230 const fn new(kind: ErrorKind) -> Self {
231 Self { kind }
232 }
233}
234
235impl fmt::Display for Error {
236 #[inline]
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 match &self.kind {
239 #[cfg(feature = "axum-core05")]
240 ErrorKind::AxumCore05 { .. } => write!(f, "Error in axum-core"),
241 ErrorKind::Musli { .. } => write!(f, "Error in musli"),
242 ErrorKind::FormatError => write!(f, "Error formatting error response"),
243 ErrorKind::InvalidFrame { error } => error.fmt(f),
244 }
245 }
246}
247
248impl core::error::Error for Error {
249 #[inline]
250 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
251 match &self.kind {
252 #[cfg(feature = "axum-core05")]
253 ErrorKind::AxumCore05 { error } => Some(error),
254 ErrorKind::Musli { error } => Some(error),
255 _ => None,
256 }
257 }
258}
259
260#[cfg(feature = "axum-core05")]
261impl From<axum_core05::Error> for Error {
262 #[inline]
263 fn from(error: axum_core05::Error) -> Self {
264 Self::new(ErrorKind::AxumCore05 { error })
265 }
266}
267
268impl From<storage::Error> for Error {
269 #[inline]
270 fn from(error: storage::Error) -> Self {
271 Self::new(ErrorKind::Musli { error })
272 }
273}
274
275impl From<ErrorKind> for Error {
276 #[inline]
277 fn from(kind: ErrorKind) -> Self {
278 Self::new(kind)
279 }
280}
281
282impl From<InvalidFrame> for Error {
283 #[inline]
284 fn from(error: InvalidFrame) -> Self {
285 Self::new(ErrorKind::InvalidFrame { error })
286 }
287}
288
289type Result<T, E = Error> = core::result::Result<T, E>;
290
/// Outcome of handling a request, produced through [`IntoResponse`].
pub struct Response {
    /// Whether the handler recognized the request. An unhandled request is
    /// answered with an error.
    handled: bool,
}
295
296pub trait IntoResponse {
298 type Error;
300
301 fn into_response(self) -> Result<Response, Self::Error>;
303}
304
305impl IntoResponse for () {
309 type Error = Infallible;
310
311 #[inline]
312 fn into_response(self) -> Result<Response, Self::Error> {
313 Ok(Response { handled: true })
314 }
315}
316
317impl IntoResponse for bool {
322 type Error = Infallible;
323
324 #[inline]
325 fn into_response(self) -> Result<Response, Self::Error> {
326 Ok(Response { handled: self })
327 }
328}
329
330impl<T, E> IntoResponse for Result<T, E>
338where
339 T: IntoResponse<Error = Infallible>,
340 E: fmt::Display,
341{
342 type Error = E;
343
344 #[inline]
345 fn into_response(self) -> Result<Response, E> {
346 match self {
347 Ok(into_response) => match IntoResponse::into_response(into_response) {
348 Ok(response) => Ok(response),
349 Err(error) => match error {},
350 },
351 Err(error) => Err(error),
352 }
353 }
354}
355
356impl<T> IntoResponse for Option<T>
362where
363 T: IntoResponse,
364{
365 type Error = T::Error;
366
367 #[inline]
368 fn into_response(self) -> Result<Response, Self::Error> {
369 match self {
370 Some(value) => value.into_response(),
371 None => Ok(Response { handled: false }),
372 }
373 }
374}
375
376pub trait Handler {
382 type Id: Id;
384 type Response: IntoResponse;
386
387 fn handle<'this>(
389 &'this mut self,
390 id: Self::Id,
391 incoming: &'this mut Incoming<'_>,
392 outgoing: &'this mut Outgoing<'_>,
393 ) -> impl Future<Output = Self::Response> + Send + 'this;
394}
395
/// State that must stay at a fixed address: the socket and the two timers.
///
/// `Sleep` is `!Unpin`, so these fields live together behind a
/// `Pin<Box<_>>` in `Server` and are reached through [`Pinned::project`].
struct Pinned<S> {
    /// The underlying websocket.
    socket: S,
    /// Timer whose expiry is reported as a connection timeout; pushed back
    /// whenever a matching pong arrives.
    close_sleep: Sleep,
    /// Timer whose expiry triggers sending the next ping.
    ping_sleep: Sleep,
}

impl<S> Pinned<S> {
    /// Split a pinned `Pinned` into pinned references to its fields, returned
    /// as `(close_sleep, ping_sleep, socket)`.
    #[inline]
    fn project(self: Pin<&mut Self>) -> (Pin<&mut Sleep>, Pin<&mut Sleep>, Pin<&mut S>) {
        // SAFETY: every field is treated as structurally pinned. No field is
        // ever moved out of the pinned value, and the references handed out
        // are only ever used through `Pin`. As far as this file shows,
        // `Pinned` implements neither `Drop` nor `Unpin` by hand; adding a
        // `Drop` impl that moves fields or an unconditional `Unpin` impl
        // would make this projection unsound.
        unsafe {
            let this = self.get_unchecked_mut();
            (
                Pin::new_unchecked(&mut this.close_sleep),
                Pin::new_unchecked(&mut this.ping_sleep),
                Pin::new_unchecked(&mut this.socket),
            )
        }
    }
}
415
416pub struct Server<S, H>
422where
423 S: ServerImpl,
424{
425 closing: bool,
426 outbound: Buf,
427 error: String,
428 handler: H,
429 last_ping: Option<[u8; 4]>,
430 rng: SmallRng,
431 max_capacity: usize,
432 out: VecDeque<S::Message>,
433 socket_send: bool,
434 socket_flush: bool,
435 pinned: Pin<Box<Pinned<S::Socket>>>,
436}
437
438impl<S, H> Server<S, H>
439where
440 S: ServerImpl,
441{
442 #[inline]
444 pub(crate) fn new(socket: S::Socket, handler: H) -> Self {
445 let now = Instant::now();
446
447 Self {
448 closing: false,
449 outbound: Buf::default(),
450 error: String::new(),
451 handler,
452 last_ping: None,
453 rng: SmallRng::seed_from_u64(DEFAULT_SEED),
454 max_capacity: MAX_CAPACITY,
455 out: VecDeque::new(),
456 socket_send: false,
457 socket_flush: false,
458 pinned: Box::pin(Pinned {
459 socket,
460 close_sleep: tokio::time::sleep_until(now + CLOSE_TIMEOUT),
461 ping_sleep: tokio::time::sleep_until(now + PING_TIMEOUT),
462 }),
463 }
464 }
465
466 pub fn handler(&self) -> &H {
468 &self.handler
469 }
470
471 pub fn handler_mut(&mut self) -> &mut H {
473 &mut self.handler
474 }
475
476 #[inline]
485 pub fn max_capacity(mut self, max_capacity: usize) -> Self {
486 self.max_capacity = max_capacity;
487 self
488 }
489
490 #[inline]
496 pub fn with_max_capacity(mut self, max_capacity: usize) -> Self {
497 self.max_capacity = max_capacity;
498 self
499 }
500}
501
502impl<S, H> Server<S, H>
503where
504 S: ServerImpl,
505{
506 #[inline]
512 pub fn seed(mut self, seed: u64) -> Self {
513 self.rng = SmallRng::seed_from_u64(seed);
514 self
515 }
516}
517
518impl<S, H> Server<S, H>
519where
520 S: ServerImpl,
521 Error: From<S::Error>,
522 H: Handler<Response: IntoResponse<Error: fmt::Display>>,
523{
524 pub async fn run(&mut self) -> Result<(), Error> {
528 loop {
529 if self.closing && self.out.is_empty() && self.outbound.is_empty() {
530 break;
531 }
532
533 self.handle_send()?;
534
535 let result = {
536 let inner = Select {
537 pinned: self.pinned.as_mut(),
538 wants_socket_send: !self.socket_send,
539 wants_socket_flush: self.socket_flush,
540 };
541
542 inner.await
543 };
544
545 tracing::debug!(?result);
546
547 match result {
548 Output::Close => {
549 self.out
550 .push_back(S::close(CLOSE_NORMAL, "connection timed out"));
551 self.closing = true;
552 }
553 Output::Ping => {
554 self.handle_ping()?;
555 }
556 Output::Recv(message) => {
557 let Some(message) = message else {
558 self.closing = true;
559 continue;
560 };
561
562 match message? {
563 Message::Text => {
564 self.out.push_back(S::close(
565 CLOSE_PROTOCOL_ERROR,
566 "Unsupported text message",
567 ));
568 self.closing = true;
569 }
570 Message::Binary(bytes) => {
571 self.handle_message(bytes).await?;
572 }
573 Message::Ping(payload) => {
574 self.out.push_back(S::pong(payload));
575 }
576 Message::Pong(data) => {
577 self.handle_pong(data)?;
578 }
579 Message::Close => {
580 self.closing = true;
581 }
582 }
583 }
584 Output::Send(result) => {
585 if let Err(err) = result {
586 return Err(Error::from(err));
587 };
588
589 self.socket_send = true;
590 }
591 Output::Flushed(result) => {
592 if let Err(err) = result {
593 return Err(Error::from(err));
594 };
595
596 self.socket_flush = false;
597 }
598 }
599 }
600
601 Ok(())
602 }
603
604 pub fn broadcast<T>(&mut self, message: T) -> Result<(), Error>
609 where
610 T: Event,
611 {
612 tracing::debug!(id = ?<T::Broadcast as Broadcast>::ID, "Broadcast");
613
614 self.outbound.write(ResponseHeader {
615 serial: 0,
616 broadcast: <T::Broadcast as Broadcast>::ID.get(),
617 error: 0,
618 })?;
619
620 self.outbound.write(message)?;
621 self.outbound.done();
622 Ok(())
623 }
624
625 fn format_err(&mut self, error: impl fmt::Display) -> Result<(), Error> {
626 self.error.clear();
627
628 if write!(self.error, "{error}").is_err() {
629 self.error.clear();
630 return Err(Error::new(ErrorKind::FormatError));
631 }
632
633 Ok(())
634 }
635
636 #[tracing::instrument(skip(self, bytes))]
637 async fn handle_message(&mut self, bytes: Bytes) -> Result<(), Error> {
638 let mut reader = SliceReader::new(&bytes);
639
640 let header: RequestHeader = match storage::decode(&mut reader) {
641 Ok(header) => header,
642 Err(error) => {
643 tracing::debug!(?error, "Invalid request header");
644 self.out
645 .push_back(S::close(CLOSE_PROTOCOL_ERROR, "Invalid request header"));
646 self.closing = true;
647 return Ok(());
648 }
649 };
650
651 let err = 'err: {
652 let Some(id) = MessageId::new(header.id) else {
653 self.format_err(format_args!("Unsupported message id {}", header.id))?;
654 break 'err true;
655 };
656
657 let id = <H::Id as Id>::from_id(id);
658
659 let res = match self.handle_request(reader, header.serial, id).await {
660 Ok(res) => res,
661 Err(err) => {
662 self.format_err(err)?;
663 break 'err true;
664 }
665 };
666
667 let res = match res.into_response() {
668 Ok(res) => res,
669 Err(err) => {
670 self.format_err(err)?;
671 break 'err true;
672 }
673 };
674
675 if !res.handled {
676 self.format_err(format_args!("No support for request `{}`", header.id))?;
677 break 'err true;
678 }
679
680 false
681 };
682
683 if err {
684 self.outbound.reset();
686
687 self.outbound.write(ResponseHeader {
688 serial: header.serial,
689 broadcast: 0,
690 error: MessageId::ERROR_MESSAGE.get(),
691 })?;
692
693 self.outbound.write(ErrorMessage {
694 message: &self.error,
695 })?;
696 }
697
698 self.outbound.done();
699 Ok(())
700 }
701
702 #[tracing::instrument(skip(self))]
703 fn handle_ping(&mut self) -> Result<(), Error> {
704 let (_, mut ping_sleep, _) = self.pinned.as_mut().project();
705
706 let payload = self.rng.random::<u32>();
707 let payload = payload.to_ne_bytes();
708
709 self.last_ping = Some(payload);
710
711 tracing::debug!(data = ?&payload[..], "Sending ping");
712
713 self.out
714 .push_back(S::ping(Bytes::from_owner(Vec::from(payload))));
715
716 let now = Instant::now();
717 ping_sleep.as_mut().reset(now + PING_TIMEOUT);
718 Ok(())
719 }
720
721 #[tracing::instrument(skip(self, payload))]
722 fn handle_pong(&mut self, payload: Bytes) -> Result<(), Error> {
723 let (close_sleep, ping_sleep, _) = self.pinned.as_mut().project();
724
725 tracing::debug!(payload = ?&payload[..], "Pong");
726
727 let Some(expected) = self.last_ping else {
728 tracing::debug!("No ping sent");
729 return Ok(());
730 };
731
732 if expected[..] != payload[..] {
733 tracing::debug!(?expected, ?payload, "Pong doesn't match");
734 return Ok(());
735 }
736
737 let now = Instant::now();
738
739 close_sleep.reset(now + CLOSE_TIMEOUT);
740 ping_sleep.reset(now + PING_TIMEOUT);
741 self.last_ping = None;
742 Ok(())
743 }
744
745 #[tracing::instrument(skip(self))]
746 fn handle_send(&mut self) -> Result<(), Error> {
747 let (_, _, mut socket) = self.pinned.as_mut().project();
748
749 if self.socket_send
750 && let Some(message) = self.out.pop_front()
751 {
752 socket.as_mut().start_send(message)?;
753 self.socket_flush = true;
754 self.socket_send = false;
755 }
756
757 if self.socket_send
758 && let Some(frame) = self.outbound.read()?
759 {
760 socket.as_mut().start_send(S::binary(frame))?;
761
762 if self.outbound.is_empty() {
763 self.outbound.clear();
764 }
765
766 self.socket_flush = true;
767 self.socket_send = false;
768 }
769
770 Ok(())
771 }
772
773 async fn handle_request(
774 &mut self,
775 reader: SliceReader<'_>,
776 serial: u32,
777 id: H::Id,
778 ) -> Result<H::Response, storage::Error> {
779 tracing::debug!(serial, ?id, "Got request");
780
781 self.outbound.write(ResponseHeader {
782 serial,
783 broadcast: 0,
784 error: 0,
785 })?;
786
787 let mut incoming = Incoming {
788 error: None,
789 reader,
790 };
791
792 let mut outgoing = Outgoing {
793 error: None,
794 buf: &mut self.outbound,
795 };
796
797 let response = self.handler.handle(id, &mut incoming, &mut outgoing).await;
798
799 if let Some(error) = incoming.error.take() {
800 return Err(error);
801 }
802
803 if let Some(error) = outgoing.error.take() {
804 return Err(error);
805 }
806
807 Ok(response)
808 }
809}
810
811#[derive(Debug)]
812enum Output<E> {
813 Close,
815 Ping,
817 Recv(Option<Result<Message, E>>),
819 Send(Result<(), E>),
821 Flushed(Result<(), E>),
823}
824
825struct Select<'a, S> {
826 pinned: Pin<&'a mut Pinned<S>>,
827 wants_socket_send: bool,
828 wants_socket_flush: bool,
829}
830
831impl<S> Future for Select<'_, S>
832where
833 S: SocketImpl,
834{
835 type Output = Output<S::Error>;
836
837 #[inline]
838 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
839 let close;
840 let ping;
841 let mut socket;
842 let wants_socket_send;
843 let wants_socket_flush;
844
845 unsafe {
847 let this = Pin::get_unchecked_mut(self);
848 (close, ping, socket) = this.pinned.as_mut().project();
849 wants_socket_send = this.wants_socket_send;
850 wants_socket_flush = this.wants_socket_flush;
851 };
852
853 if close.poll(cx).is_ready() {
854 return Poll::Ready(Output::Close);
855 }
856
857 if ping.poll(cx).is_ready() {
858 return Poll::Ready(Output::Ping);
859 }
860
861 if let Poll::Ready(output) = socket.as_mut().poll_next(cx) {
862 return Poll::Ready(Output::Recv(output));
863 }
864
865 if wants_socket_send && let Poll::Ready(result) = socket.as_mut().poll_ready(cx) {
866 return Poll::Ready(Output::Send(result));
867 }
868
869 if wants_socket_flush && let Poll::Ready(result) = socket.as_mut().poll_flush(cx) {
870 return Poll::Ready(Output::Flushed(result));
871 }
872
873 Poll::Pending
874 }
875}
876
877pub struct Incoming<'de> {
883 error: Option<storage::Error>,
884 reader: SliceReader<'de>,
885}
886
887impl<'de> Incoming<'de> {
888 #[inline]
896 pub fn read<T>(&mut self) -> Option<T>
897 where
898 T: Decode<'de, Binary, Global>,
899 {
900 match storage::decode(&mut self.reader) {
901 Ok(value) => Some(value),
902 Err(error) => {
903 self.error = Some(error);
904 None
905 }
906 }
907 }
908}
909
910pub struct Outgoing<'a> {
916 error: Option<storage::Error>,
917 buf: &'a mut Buf,
918}
919
920impl Outgoing<'_> {
921 pub fn write<T>(&mut self, value: T)
927 where
928 T: Encode<Binary>,
929 {
930 if let Err(error) = self.buf.write(value) {
931 self.error = Some(error);
932 }
933 }
934}