stomp_agnostic/
transport.rs

1use crate::{FromServer, Message, ToServer, frame};
2use async_trait::async_trait;
3use bytes::{Buf, Bytes, BytesMut};
4use std::fmt::Debug;
5use std::str::Utf8Error;
6use thiserror::Error;
7use winnow::Partial;
8use winnow::error::{ContextError, ErrMode};
9use winnow::stream::Offset;
10
11#[async_trait]
12pub trait Transport: Send + Sync {
13    /// A side channel to shuffle arbitrary data that is not part of the STOMP communication,
14    /// e.g. WebSocket Ping/Pong.
15    type ProtocolSideChannel;
16
17    async fn write(&mut self, message: Message<ToServer>) -> Result<(), WriteError>;
18    async fn read(&mut self) -> Result<ReadResponse<Self::ProtocolSideChannel>, ReadError>;
19}
20
21/// A response coming down the line from the transport layer. When the transport layer is
22/// e.g. WebSocket, custom data such as Ping/Pong can be handled separately from STOMP data
23/// by using the `Custom` variant.
24#[derive(Debug)]
25pub enum ReadResponse<T> {
26    Binary(Bytes),
27    Custom(T),
28}
29
30/// A parsed response, either a [Message] coming from the server, or a custom protocol signal
31/// in the `Custom` variant.
32#[derive(Debug)]
33pub enum Response<T>
34where
35    T: Debug,
36{
37    Message(Message<FromServer>),
38    Custom(T),
39}
40
41#[derive(Error, Debug)]
42pub enum WriteError {
43    #[error("Utf8Error")]
44    Utf8Error(#[from] Utf8Error),
45    #[error(transparent)]
46    Other(#[from] anyhow::Error),
47}
48
49#[derive(Error, Debug)]
50pub enum ReadError {
51    /// This is the most important [Transport] error to take care of - when the connection has been
52    /// closed, this is the only error that shall be returned when reading. This is so that
53    /// implementors / users of the trait can handle this case consistently.
54    #[error("Connection closed")]
55    ConnectionClosed,
56    #[error("Unexpected message")]
57    UnexpectedMessage,
58    #[error("Parser error")]
59    Parser(ErrMode<ContextError>),
60    #[error(transparent)]
61    Other(#[from] anyhow::Error),
62}
63
64pub(crate) struct BufferedTransport<T>
65where
66    T: Transport,
67    T::ProtocolSideChannel: Debug,
68{
69    transport: T,
70    buffer: BytesMut,
71}
72
73impl<T> BufferedTransport<T>
74where
75    T: Transport,
76    T::ProtocolSideChannel: Debug,
77{
78    pub(crate) fn new(transport: T) -> Self {
79        Self {
80            transport,
81            buffer: BytesMut::with_capacity(4096),
82        }
83    }
84
85    fn append(&mut self, data: Bytes) {
86        self.buffer.extend_from_slice(&data);
87    }
88
89    fn decode(&mut self) -> Result<Option<Message<FromServer>>, ReadError> {
90        // Create a partial view of the buffer for parsing
91        let buf = &mut Partial::new(self.buffer.chunk());
92
93        // Attempt to parse a frame from the buffer
94        let item = match frame::parse_frame(buf) {
95            Ok(frame) => Message::<FromServer>::from_frame(frame),
96            // Need more data
97            Err(ErrMode::Incomplete(_)) => return Ok(None),
98            Err(e) => return Err(ReadError::Parser(e)),
99        };
100
101        // Calculate how many bytes were consumed
102        let len = buf.offset_from(&Partial::new(self.buffer.chunk()));
103
104        // Advance the buffer past the consumed bytes
105        self.buffer.advance(len);
106
107        // Return the parsed message (or error)
108        item.map_err(|e| e.into()).map(Some)
109    }
110
111    pub(crate) async fn send(&mut self, message: Message<ToServer>) -> Result<(), WriteError> {
112        self.transport.write(message).await
113    }
114
115    pub(crate) async fn next(&mut self) -> Result<Response<T::ProtocolSideChannel>, ReadError> {
116        loop {
117            let response = self.transport.read().await?;
118            match response {
119                ReadResponse::Binary(buffer) => {
120                    self.append(buffer);
121                }
122                ReadResponse::Custom(custom) => {
123                    return Ok(Response::Custom(custom));
124                }
125            }
126
127            if let Some(message) = self.decode()? {
128                return Ok(Response::Message(message));
129            }
130        }
131    }
132
133    pub(crate) fn into_transport(self) -> T {
134        self.transport
135    }
136
137    pub(crate) fn as_mut_inner(&mut self) -> &mut T {
138        &mut self.transport
139    }
140}