1use core::convert::Infallible;
91use core::fmt::{self, Write};
92use core::future::Future;
93use core::pin::Pin;
94use core::task::{Context, Poll};
95
96use alloc::collections::VecDeque;
97use alloc::string::String;
98use alloc::vec::Vec;
99
100use bytes::Bytes;
101use musli::alloc::Global;
102use musli::mode::Binary;
103use musli::reader::SliceReader;
104use musli::storage;
105use musli::{Decode, Encode};
106use rand::prelude::*;
107use rand::rngs::SmallRng;
108use tokio::time::{Duration, Instant, Sleep};
109
110use crate::Buf;
111use crate::api::{Broadcast, ErrorMessage, Event, Id, MessageId, RequestHeader, ResponseHeader};
112use crate::buf::InvalidFrame;
113
/// Default value for `Server::max_capacity` (1 MiB; presumably bytes — confirm against `Buf`).
const MAX_CAPACITY: usize = 1024 * 1024;
/// WebSocket close code for a normal closure (RFC 6455 §7.4.1).
const CLOSE_NORMAL: u16 = 1000;
/// WebSocket close code for a protocol error (RFC 6455 §7.4.1).
const CLOSE_PROTOCOL_ERROR: u16 = 1002;
/// How long the connection may go without a matching pong before it is closed.
const CLOSE_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay between consecutive pings sent to the peer.
const PING_TIMEOUT: Duration = Duration::from_secs(10);
/// Default seed for the ping-payload RNG; can be overridden with `Server::seed`.
const DEFAULT_SEED: u64 = 0xdead_beef;
120
121#[derive(Debug)]
123pub(crate) enum Message {
124 Text,
126 Binary(Bytes),
128 Ping(Bytes),
130 Pong(Bytes),
132 Close,
134}
135
136pub(crate) mod socket_sealed {
137 pub trait Sealed {}
138}
139
140pub(crate) trait SocketImpl
141where
142 Self: self::socket_sealed::Sealed,
143{
144 #[doc(hidden)]
145 type Message;
146
147 #[doc(hidden)]
148 type Error: fmt::Debug;
149
150 #[doc(hidden)]
151 fn poll_next(
152 self: Pin<&mut Self>,
153 ctx: &mut Context<'_>,
154 ) -> Poll<Option<Result<Message, Self::Error>>>;
155
156 #[doc(hidden)]
157 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;
158
159 #[doc(hidden)]
160 fn start_send(self: Pin<&mut Self>, item: Self::Message) -> Result<(), Self::Error>;
161
162 #[doc(hidden)]
163 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;
164}
165
166pub(crate) mod server_sealed {
167 pub trait Sealed {}
168}
169
170pub trait ServerImpl
176where
177 Self: self::server_sealed::Sealed,
178{
179 #[doc(hidden)]
180 type Error: fmt::Debug;
181
182 #[doc(hidden)]
183 type Message;
184
185 #[doc(hidden)]
186 #[allow(private_bounds)]
187 type Socket: SocketImpl<Message = Self::Message, Error = Self::Error>;
188
189 #[doc(hidden)]
190 fn ping(data: Bytes) -> Self::Message;
191
192 #[doc(hidden)]
193 fn pong(data: Bytes) -> Self::Message;
194
195 #[doc(hidden)]
196 fn binary(data: &[u8]) -> Self::Message;
197
198 #[doc(hidden)]
199 fn close(code: u16, reason: &str) -> Self::Message;
200}
201
202#[derive(Debug)]
203enum ErrorKind {
204 #[cfg(feature = "axum-core05")]
205 AxumCore05 {
206 error: axum_core05::Error,
207 },
208 Musli {
209 error: storage::Error,
210 },
211 FormatError,
212 InvalidFrame {
213 error: InvalidFrame,
214 },
215}
216
217#[derive(Debug)]
219pub struct Error {
220 kind: ErrorKind,
221}
222
223impl Error {
224 #[inline]
225 const fn new(kind: ErrorKind) -> Self {
226 Self { kind }
227 }
228}
229
230impl fmt::Display for Error {
231 #[inline]
232 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233 match &self.kind {
234 #[cfg(feature = "axum-core05")]
235 ErrorKind::AxumCore05 { .. } => write!(f, "Error in axum-core"),
236 ErrorKind::Musli { .. } => write!(f, "Error in musli"),
237 ErrorKind::FormatError => write!(f, "Error formatting error response"),
238 ErrorKind::InvalidFrame { error } => error.fmt(f),
239 }
240 }
241}
242
243impl core::error::Error for Error {
244 #[inline]
245 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
246 match &self.kind {
247 #[cfg(feature = "axum-core05")]
248 ErrorKind::AxumCore05 { error } => Some(error),
249 ErrorKind::Musli { error } => Some(error),
250 _ => None,
251 }
252 }
253}
254
255#[cfg(feature = "axum-core05")]
256impl From<axum_core05::Error> for Error {
257 #[inline]
258 fn from(error: axum_core05::Error) -> Self {
259 Self::new(ErrorKind::AxumCore05 { error })
260 }
261}
262
263impl From<storage::Error> for Error {
264 #[inline]
265 fn from(error: storage::Error) -> Self {
266 Self::new(ErrorKind::Musli { error })
267 }
268}
269
270impl From<ErrorKind> for Error {
271 #[inline]
272 fn from(kind: ErrorKind) -> Self {
273 Self::new(kind)
274 }
275}
276
277impl From<InvalidFrame> for Error {
278 #[inline]
279 fn from(error: InvalidFrame) -> Self {
280 Self::new(ErrorKind::InvalidFrame { error })
281 }
282}
283
284type Result<T, E = Error> = core::result::Result<T, E>;
285
/// Outcome of handling a request, produced through `IntoResponse`.
pub struct Response {
    /// Whether the handler recognized the request. An unhandled request is
    /// answered with a "No support for request" error message.
    handled: bool,
}
290
291pub trait IntoResponse {
293 type Error: fmt::Display;
295
296 fn into_response(self) -> Result<Response, Self::Error>;
298}
299
300impl IntoResponse for () {
304 type Error = Infallible;
305
306 #[inline]
307 fn into_response(self) -> Result<Response, Self::Error> {
308 Ok(Response { handled: true })
309 }
310}
311
312impl IntoResponse for bool {
317 type Error = Infallible;
318
319 #[inline]
320 fn into_response(self) -> Result<Response, Self::Error> {
321 Ok(Response { handled: self })
322 }
323}
324
325impl<T, E> IntoResponse for Result<T, E>
333where
334 T: IntoResponse<Error = Infallible>,
335 E: fmt::Display,
336{
337 type Error = E;
338
339 #[inline]
340 fn into_response(self) -> Result<Response, E> {
341 match self {
342 Ok(into_response) => match IntoResponse::into_response(into_response) {
343 Ok(response) => Ok(response),
344 Err(error) => match error {},
345 },
346 Err(error) => Err(error),
347 }
348 }
349}
350
351impl<T> IntoResponse for Option<T>
357where
358 T: IntoResponse,
359{
360 type Error = T::Error;
361
362 #[inline]
363 fn into_response(self) -> Result<Response, Self::Error> {
364 match self {
365 Some(value) => value.into_response(),
366 None => Ok(Response { handled: false }),
367 }
368 }
369}
370
371pub trait Handler {
377 type Id: Id;
379 type Response: IntoResponse;
381
382 fn handle<'this>(
384 &'this mut self,
385 id: Self::Id,
386 incoming: &'this mut Incoming<'_>,
387 outgoing: &'this mut Outgoing<'_>,
388 ) -> impl Future<Output = Self::Response> + Send + 'this;
389}
390
391pub struct Server<S, H>
397where
398 S: ServerImpl,
399{
400 closing: bool,
401 buf: Buf,
402 error: String,
403 socket: S::Socket,
404 handler: H,
405 last_ping: Option<u32>,
406 rng: SmallRng,
407 close_sleep: Sleep,
408 ping_sleep: Sleep,
409 max_capacity: usize,
410 out: VecDeque<S::Message>,
411 socket_send: bool,
412 socket_flush: bool,
413}
414
415impl<S, H> Server<S, H>
416where
417 S: ServerImpl,
418{
419 #[inline]
421 pub(crate) fn new(socket: S::Socket, handler: H) -> Self {
422 let now = Instant::now();
423
424 Self {
425 closing: false,
426 buf: Buf::default(),
427 error: String::new(),
428 socket,
429 handler,
430 last_ping: None::<u32>,
431 rng: SmallRng::seed_from_u64(DEFAULT_SEED),
432 close_sleep: tokio::time::sleep_until(now + CLOSE_TIMEOUT),
433 ping_sleep: tokio::time::sleep_until(now + PING_TIMEOUT),
434 max_capacity: MAX_CAPACITY,
435 out: VecDeque::new(),
436 socket_send: false,
437 socket_flush: false,
438 }
439 }
440
441 #[inline]
447 pub fn with_max_capacity(mut self, max_capacity: usize) -> Self {
448 self.max_capacity = max_capacity;
449 self
450 }
451}
452
453impl<S, H> Server<S, H>
454where
455 S: ServerImpl,
456{
457 #[inline]
463 pub fn seed(mut self, seed: u64) -> Self {
464 self.rng = SmallRng::seed_from_u64(seed);
465 self
466 }
467}
468
469impl<S, H> Server<S, H>
470where
471 S: ServerImpl,
472 Error: From<S::Error>,
473 H: Handler,
474{
475 fn format_err(&mut self, error: impl fmt::Display) -> Result<(), Error> {
476 self.error.clear();
477
478 if write!(self.error, "{error}").is_err() {
479 self.error.clear();
480 return Err(Error::new(ErrorKind::FormatError));
481 }
482
483 Ok(())
484 }
485
486 pub async fn run(self: Pin<&mut Self>) -> Result<(), Error> {
490 let this = unsafe { Pin::get_unchecked_mut(self) };
492
493 loop {
494 if this.closing && this.out.is_empty() && this.buf.is_empty() {
495 break;
496 }
497
498 this.handle_send()?;
499
500 let result = {
501 let inner = Select {
502 close: &mut this.close_sleep,
503 ping: &mut this.ping_sleep,
504 socket: &mut this.socket,
505 wants_socket_send: !this.socket_send,
506 wants_socket_flush: this.socket_flush,
507 };
508
509 inner.await
510 };
511
512 tracing::debug!(?result);
513
514 match result {
515 Output::Close => {
516 this.out
517 .push_back(S::close(CLOSE_NORMAL, "connection timed out"));
518 this.closing = true;
519 }
520 Output::Ping => {
521 this.handle_ping()?;
522 }
523 Output::Recv(message) => {
524 let Some(message) = message else {
525 this.closing = true;
526 continue;
527 };
528
529 match message? {
530 Message::Text => {
531 this.out.push_back(S::close(
532 CLOSE_PROTOCOL_ERROR,
533 "Unsupported text message",
534 ));
535 this.closing = true;
536 }
537 Message::Binary(bytes) => {
538 this.handle_message(bytes).await?;
539 }
540 Message::Ping(payload) => {
541 this.out.push_back(S::pong(payload));
542 }
543 Message::Pong(data) => {
544 this.handle_pong(data)?;
545 }
546 Message::Close => {
547 this.closing = true;
548 }
549 }
550 }
551 Output::Send(result) => {
552 if let Err(err) = result {
553 return Err(Error::from(err));
554 };
555
556 this.socket_send = true;
557 }
558 Output::Flushed(result) => {
559 if let Err(err) = result {
560 return Err(Error::from(err));
561 };
562
563 this.socket_flush = false;
564 }
565 }
566 }
567
568 Ok(())
569 }
570
571 #[tracing::instrument(skip(self, bytes))]
572 async fn handle_message(&mut self, bytes: Bytes) -> Result<(), Error> {
573 let mut reader = SliceReader::new(&bytes);
574
575 let header: RequestHeader = match storage::decode(&mut reader) {
576 Ok(header) => header,
577 Err(error) => {
578 tracing::debug!(?error, "Invalid request header");
579 self.out
580 .push_back(S::close(CLOSE_PROTOCOL_ERROR, "Invalid request header"));
581 self.closing = true;
582 return Ok(());
583 }
584 };
585
586 let err = 'err: {
587 let Some(id) = <H::Id as Id>::from_raw(header.id) else {
588 self.format_err("Unsupported message id 0")?;
589 break 'err true;
590 };
591
592 let res = match self.handle_request(reader, header.serial, id).await {
593 Ok(res) => res,
594 Err(err) => {
595 self.format_err(err)?;
596 break 'err true;
597 }
598 };
599
600 let res = match res.into_response() {
601 Ok(res) => res,
602 Err(err) => {
603 self.format_err(err)?;
604 break 'err true;
605 }
606 };
607
608 if !res.handled {
609 self.format_err(format_args!("No support for request `{}`", header.id))?;
610 break 'err true;
611 }
612
613 false
614 };
615
616 if err {
617 self.buf.reset();
619
620 self.buf.write(ResponseHeader {
621 serial: header.serial,
622 broadcast: 0,
623 error: MessageId::ERROR_MESSAGE.get(),
624 })?;
625
626 self.buf.write(ErrorMessage {
627 message: &self.error,
628 })?;
629 }
630
631 self.buf.done();
632 Ok(())
633 }
634
635 #[tracing::instrument(skip(self))]
636 fn handle_ping(&mut self) -> Result<(), Error> {
637 let mut ping_sleep = unsafe { Pin::new_unchecked(&mut self.ping_sleep) };
638
639 let payload = self.rng.random::<u32>();
640 self.last_ping = Some(payload);
641 let data = payload.to_ne_bytes().into_iter().collect::<Vec<_>>();
642
643 tracing::debug!(data = ?&data[..], "Sending ping");
644 self.out.push_back(S::ping(data.into()));
645
646 let now = Instant::now();
647 ping_sleep.as_mut().reset(now + PING_TIMEOUT);
648
649 Ok(())
650 }
651
652 #[tracing::instrument(skip(self, data))]
653 fn handle_pong(&mut self, data: Bytes) -> Result<(), Error> {
654 let close_sleep = unsafe { Pin::new_unchecked(&mut self.close_sleep) };
655 let ping_sleep = unsafe { Pin::new_unchecked(&mut self.ping_sleep) };
656
657 tracing::debug!(data = ?&data[..], "Pong");
658
659 let Some(expected) = self.last_ping else {
660 tracing::debug!("No ping sent");
661 return Ok(());
662 };
663
664 let expected = expected.to_ne_bytes();
665
666 if expected[..] != data[..] {
667 tracing::debug!(?expected, ?data, "Pong doesn't match");
668 return Ok(());
669 }
670
671 let now = Instant::now();
672
673 close_sleep.reset(now + CLOSE_TIMEOUT);
674 ping_sleep.reset(now + PING_TIMEOUT);
675 self.last_ping = None;
676 Ok(())
677 }
678
679 #[tracing::instrument(skip(self))]
680 fn handle_send(&mut self) -> Result<(), Error> {
681 let mut socket = unsafe { Pin::new_unchecked(&mut self.socket) };
682
683 if self.socket_send
684 && let Some(message) = self.out.pop_front()
685 {
686 socket.as_mut().start_send(message)?;
687 self.socket_flush = true;
688 self.socket_send = false;
689 }
690
691 if self.socket_send
692 && let Some(frame) = self.buf.read()?
693 {
694 socket.as_mut().start_send(S::binary(frame))?;
695
696 if self.buf.is_empty() {
697 self.buf.clear();
698 }
699
700 self.socket_flush = true;
701 self.socket_send = false;
702 }
703
704 Ok(())
705 }
706
707 pub fn broadcast<T>(self: Pin<&mut Self>, message: T) -> Result<(), Error>
712 where
713 T: Event,
714 {
715 let this = unsafe { Pin::get_unchecked_mut(self) };
716
717 this.buf.write(ResponseHeader {
718 serial: 0,
719 broadcast: <T::Broadcast as Broadcast>::ID.get(),
720 error: 0,
721 })?;
722
723 this.buf.write(message)?;
724 this.buf.done();
725 Ok(())
726 }
727
728 async fn handle_request(
729 &mut self,
730 reader: SliceReader<'_>,
731 serial: u32,
732 id: H::Id,
733 ) -> Result<H::Response, storage::Error> {
734 tracing::debug!(serial, ?id, "Got request");
735
736 self.buf.write(ResponseHeader {
737 serial,
738 broadcast: 0,
739 error: 0,
740 })?;
741
742 let mut incoming = Incoming {
743 error: None,
744 reader,
745 };
746
747 let mut outgoing = Outgoing {
748 error: None,
749 buf: &mut self.buf,
750 };
751
752 let response = self.handler.handle(id, &mut incoming, &mut outgoing).await;
753
754 if let Some(error) = incoming.error.take() {
755 return Err(error);
756 }
757
758 if let Some(error) = outgoing.error.take() {
759 return Err(error);
760 }
761
762 Ok(response)
763 }
764}
765
766#[derive(Debug)]
767enum Output<E> {
768 Close,
770 Ping,
772 Recv(Option<Result<Message, E>>),
774 Send(Result<(), E>),
776 Flushed(Result<(), E>),
778}
779
780struct Select<'a, C, P, S> {
781 close: &'a mut C,
782 ping: &'a mut P,
783 socket: &'a mut S,
784 wants_socket_send: bool,
785 wants_socket_flush: bool,
786}
787
788impl<C, P, S> Future for Select<'_, C, P, S>
789where
790 C: Future,
791 P: Future,
792 S: SocketImpl,
793{
794 type Output = Output<S::Error>;
795
796 #[inline]
797 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
798 let close;
799 let ping;
800 let mut socket;
801 let wants_socket_send;
802 let wants_socket_flush;
803
804 unsafe {
806 let this = Pin::get_unchecked_mut(self);
807 close = Pin::new_unchecked(&mut *this.close);
808 ping = Pin::new_unchecked(&mut *this.ping);
809 socket = Pin::new_unchecked(&mut *this.socket);
810 wants_socket_send = this.wants_socket_send;
811 wants_socket_flush = this.wants_socket_flush;
812 };
813
814 if close.poll(cx).is_ready() {
815 return Poll::Ready(Output::Close);
816 }
817
818 if ping.poll(cx).is_ready() {
819 return Poll::Ready(Output::Ping);
820 }
821
822 if let Poll::Ready(output) = socket.as_mut().poll_next(cx) {
823 return Poll::Ready(Output::Recv(output));
824 }
825
826 if wants_socket_send && let Poll::Ready(result) = socket.as_mut().poll_ready(cx) {
827 return Poll::Ready(Output::Send(result));
828 }
829
830 if wants_socket_flush && let Poll::Ready(result) = socket.as_mut().poll_flush(cx) {
831 return Poll::Ready(Output::Flushed(result));
832 }
833
834 Poll::Pending
835 }
836}
837
838pub struct Incoming<'de> {
844 error: Option<storage::Error>,
845 reader: SliceReader<'de>,
846}
847
848impl<'de> Incoming<'de> {
849 #[inline]
857 pub fn read<T>(&mut self) -> Option<T>
858 where
859 T: Decode<'de, Binary, Global>,
860 {
861 match storage::decode(&mut self.reader) {
862 Ok(value) => Some(value),
863 Err(error) => {
864 self.error = Some(error);
865 None
866 }
867 }
868 }
869}
870
871pub struct Outgoing<'a> {
877 error: Option<storage::Error>,
878 buf: &'a mut Buf,
879}
880
881impl Outgoing<'_> {
882 pub fn write<T>(&mut self, value: T)
888 where
889 T: Encode<Binary>,
890 {
891 if let Err(error) = self.buf.write(value) {
892 self.error = Some(error);
893 }
894 }
895}