1use core::{fmt::Debug, ops::Deref, str::Utf8Error};
2
3use futures::{Sink, SinkExt, Stream, StreamExt};
4use rand_core::RngCore;
5
6use crate::{
7 WebSocket, WebSocketCloseStatusCode, WebSocketOptions, WebSocketReceiveMessageType,
8 WebSocketSendMessageType, WebSocketSubProtocol, WebSocketType,
9};
10
11pub struct CloseMessage<'a> {
12 pub status_code: WebSocketCloseStatusCode,
13 pub reason: &'a [u8],
14}
15
16pub enum ReadResult<'a> {
17 Binary(&'a [u8]),
18 Text(&'a str),
19 Pong(&'a [u8]),
22 Ping(&'a [u8]),
26 Close(CloseMessage<'a>),
30}
31
32#[derive(Debug)]
33pub enum FramerError<E> {
34 Io(E),
35 FrameTooLarge(usize),
36 Utf8(Utf8Error),
37 HttpHeader(httparse::Error),
38 WebSocket(crate::Error),
39 Disconnected,
40 RxBufferTooSmall(usize),
41}
42
43pub struct Framer<TRng, TWebSocketType>
44where
45 TRng: RngCore,
46 TWebSocketType: WebSocketType,
47{
48 websocket: WebSocket<TRng, TWebSocketType>,
49 frame_cursor: usize,
50 rx_remainder_len: usize,
51}
52
53impl<TRng> Framer<TRng, crate::Client>
54where
55 TRng: RngCore,
56{
57 pub async fn connect<'a, B, E>(
58 &mut self,
59 stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
60 buffer: &'a mut [u8],
61 websocket_options: &WebSocketOptions<'_>,
62 ) -> Result<Option<WebSocketSubProtocol>, FramerError<E>>
63 where
64 B: AsRef<[u8]>,
65 {
66 let (tx_len, web_socket_key) = self
67 .websocket
68 .client_connect(websocket_options, buffer)
69 .map_err(FramerError::WebSocket)?;
70
71 let (tx_buf, rx_buf) = buffer.split_at_mut(tx_len);
72 stream.send(tx_buf).await.map_err(FramerError::Io)?;
73 stream.flush().await.map_err(FramerError::Io)?;
74
75 loop {
76 match stream.next().await {
77 Some(buf) => {
78 let buf = buf.map_err(FramerError::Io)?;
79 let buf = buf.as_ref();
80
81 match self.websocket.client_accept(&web_socket_key, buf) {
82 Ok((len, sub_protocol)) => {
83 let from = len;
88 let to = buf.len();
89 let remaining_len = to - from;
90
91 if remaining_len > 0 {
92 let rx_start = rx_buf.len() - remaining_len;
93 rx_buf[rx_start..].copy_from_slice(&buf[from..to]);
94 self.rx_remainder_len = remaining_len;
95 }
96
97 return Ok(sub_protocol);
98 }
99 Err(crate::Error::HttpHeaderIncomplete) => {
100 panic!("oh no");
102 }
103 Err(e) => {
104 return Err(FramerError::WebSocket(e));
105 }
106 }
107 }
108 None => return Err(FramerError::Disconnected),
109 }
110 }
111 }
112}
113
114impl<TRng, TWebSocketType> Framer<TRng, TWebSocketType>
115where
116 TRng: RngCore,
117 TWebSocketType: WebSocketType,
118{
119 pub fn new(websocket: WebSocket<TRng, TWebSocketType>) -> Self {
120 Self {
121 websocket,
122 frame_cursor: 0,
123 rx_remainder_len: 0,
124 }
125 }
126
127 pub fn encode<E>(
128 &mut self,
129 message_type: WebSocketSendMessageType,
130 end_of_message: bool,
131 from: &[u8],
132 to: &mut [u8],
133 ) -> Result<usize, FramerError<E>> {
134 let len = self
135 .websocket
136 .write(message_type, end_of_message, from, to)
137 .map_err(FramerError::WebSocket)?;
138
139 Ok(len)
140 }
141
142 pub async fn write<'b, E>(
143 &mut self,
144 tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
145 tx_buf: &'b mut [u8],
146 message_type: WebSocketSendMessageType,
147 end_of_message: bool,
148 frame_buf: &[u8],
149 ) -> Result<(), FramerError<E>>
150 where
151 E: Debug,
152 {
153 let len = self
154 .websocket
155 .write(message_type, end_of_message, frame_buf, tx_buf)
156 .map_err(FramerError::WebSocket)?;
157
158 tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
159 tx.flush().await.map_err(FramerError::Io)?;
160 Ok(())
161 }
162
163 pub async fn close<'b, E>(
164 &mut self,
165 tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
166 tx_buf: &'b mut [u8],
167 close_status: WebSocketCloseStatusCode,
168 status_description: Option<&str>,
169 ) -> Result<(), FramerError<E>>
170 where
171 E: Debug,
172 {
173 let len = self
174 .websocket
175 .close(close_status, status_description, tx_buf)
176 .map_err(FramerError::WebSocket)?;
177
178 tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
179 tx.flush().await.map_err(FramerError::Io)?;
180 Ok(())
181 }
182
183 pub async fn read<'a, B: Deref<Target = [u8]>, E>(
188 &mut self,
189 stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
190 buffer: &'a mut [u8],
191 ) -> Option<Result<ReadResult<'a>, FramerError<E>>>
192 where
193 E: Debug,
194 {
195 if self.rx_remainder_len == 0 {
196 match stream.next().await {
197 Some(Ok(input)) => {
198 if buffer.len() < input.len() {
199 return Some(Err(FramerError::RxBufferTooSmall(input.len())));
200 }
201
202 let rx_start = buffer.len() - input.len();
203
204 buffer[rx_start..].copy_from_slice(&input);
206 self.rx_remainder_len = input.len()
207 }
208 Some(Err(e)) => {
209 return Some(Err(FramerError::Io(e)));
210 }
211 None => return None,
212 }
213 }
214
215 let rx_start = buffer.len() - self.rx_remainder_len;
216 let (frame_buf, rx_buf) = buffer.split_at_mut(rx_start);
217
218 let ws_result = match self.websocket.read(rx_buf, frame_buf) {
219 Ok(ws_result) => ws_result,
220 Err(e) => return Some(Err(FramerError::WebSocket(e))),
221 };
222
223 self.rx_remainder_len -= ws_result.len_from;
224
225 match ws_result.message_type {
226 WebSocketReceiveMessageType::Binary => {
227 self.frame_cursor += ws_result.len_to;
228 if ws_result.end_of_message {
229 let range = 0..self.frame_cursor;
230 self.frame_cursor = 0;
231 return Some(Ok(ReadResult::Binary(&frame_buf[range])));
232 }
233 }
234 WebSocketReceiveMessageType::Text => {
235 self.frame_cursor += ws_result.len_to;
236 if ws_result.end_of_message {
237 let range = 0..self.frame_cursor;
238 self.frame_cursor = 0;
239 match core::str::from_utf8(&frame_buf[range]) {
240 Ok(text) => return Some(Ok(ReadResult::Text(text))),
241 Err(e) => return Some(Err(FramerError::Utf8(e))),
242 }
243 }
244 }
245 WebSocketReceiveMessageType::CloseMustReply => {
246 let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
247
248 let tx_buf_len = ws_result.len_to + 14; let split_at = frame_buf.len() - tx_buf_len;
251 let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
252
253 match self.websocket.write(
254 WebSocketSendMessageType::CloseReply,
255 true,
256 &frame_buf[range.start..range.end],
257 tx_buf,
258 ) {
259 Ok(len) => match stream.send(&tx_buf[..len]).await {
260 Ok(()) => {
261 self.frame_cursor = 0;
262 let status_code = ws_result
263 .close_status
264 .expect("close message must have code");
265 let reason = &frame_buf[range];
266 return Some(Ok(ReadResult::Close(CloseMessage {
267 status_code,
268 reason,
269 })));
270 }
271 Err(e) => return Some(Err(FramerError::Io(e))),
272 },
273 Err(e) => return Some(Err(FramerError::WebSocket(e))),
274 }
275 }
276 WebSocketReceiveMessageType::CloseCompleted => return None,
277 WebSocketReceiveMessageType::Pong => {
278 let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
279 return Some(Ok(ReadResult::Pong(&frame_buf[range])));
280 }
281 WebSocketReceiveMessageType::Ping => {
282 let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
283
284 let tx_buf_len = ws_result.len_to + 14; let split_at = frame_buf.len() - tx_buf_len;
287 let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
288
289 match self.websocket.write(
290 WebSocketSendMessageType::Pong,
291 true,
292 &frame_buf[range.start..range.end],
293 tx_buf,
294 ) {
295 Ok(len) => match stream.send(&tx_buf[..len]).await {
296 Ok(()) => {
297 return Some(Ok(ReadResult::Ping(&frame_buf[range])));
298 }
299 Err(e) => return Some(Err(FramerError::Io(e))),
300 },
301 Err(e) => return Some(Err(FramerError::WebSocket(e))),
302 }
303 }
304 }
305
306 None
307 }
308}