1use core::{fmt::Debug, ops::Deref, str::Utf8Error};
2
3use futures::{Sink, SinkExt, Stream, StreamExt};
4use rand_core::RngCore;
5
6use crate::{
7 WebSocket, WebSocketCloseStatusCode, WebSocketContext, WebSocketOptions,
8 WebSocketReceiveMessageType, WebSocketSendMessageType, WebSocketSubProtocol, WebSocketType,
9};
10
11pub struct CloseMessage<'a> {
12 pub status_code: WebSocketCloseStatusCode,
13 pub reason: &'a [u8],
14}
15
16pub enum ReadResult<'a> {
17 Binary(&'a [u8]),
18 Text(&'a str),
19 Pong(&'a [u8]),
22 Ping(&'a [u8]),
26 Close(CloseMessage<'a>),
30}
31
32#[derive(Debug)]
33pub enum FramerError<E> {
34 Io(E),
35 FrameTooLarge(usize),
36 Utf8(Utf8Error),
37 HttpHeader(httparse::Error),
38 WebSocket(crate::Error),
39 Disconnected,
40 RxBufferTooSmall(usize),
41}
42
43pub struct Framer<TRng, TWebSocketType>
44where
45 TRng: RngCore,
46 TWebSocketType: WebSocketType,
47{
48 websocket: WebSocket<TRng, TWebSocketType>,
49 frame_cursor: usize,
50 rx_remainder_len: usize,
51}
52
53impl<TRng> Framer<TRng, crate::Server>
54where
55 TRng: RngCore,
56{
57 pub async fn accept<'a, B, E>(
58 &mut self,
59 stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
60 buffer: &'a mut [u8],
61 websocket_context: &WebSocketContext,
62 ) -> Result<(), FramerError<E>> {
63 let len = self
64 .websocket
65 .server_accept(&websocket_context.sec_websocket_key, None, buffer)
66 .map_err(FramerError::WebSocket)?;
67
68 stream.send(&buffer[..len]).await.map_err(FramerError::Io)?;
69 Ok(())
70 }
71}
72
73impl<TRng> Framer<TRng, crate::Client>
74where
75 TRng: RngCore,
76{
77 pub async fn connect<'a, B, E>(
78 &mut self,
79 stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
80 buffer: &'a mut [u8],
81 websocket_options: &WebSocketOptions<'_>,
82 ) -> Result<Option<WebSocketSubProtocol>, FramerError<E>>
83 where
84 B: AsRef<[u8]>,
85 {
86 let (tx_len, web_socket_key) = self
87 .websocket
88 .client_connect(websocket_options, buffer)
89 .map_err(FramerError::WebSocket)?;
90
91 let (tx_buf, rx_buf) = buffer.split_at_mut(tx_len);
92 stream.send(tx_buf).await.map_err(FramerError::Io)?;
93 stream.flush().await.map_err(FramerError::Io)?;
94
95 match stream.next().await {
96 Some(buf) => {
97 let buf = buf.map_err(FramerError::Io)?;
98 let buf = buf.as_ref();
99
100 match self.websocket.client_accept(&web_socket_key, buf) {
101 Ok((len, sub_protocol)) => {
102 let from = len;
107 let to = buf.len();
108 let remaining_len = to - from;
109
110 if remaining_len > 0 {
111 let rx_start = rx_buf.len() - remaining_len;
112 rx_buf[rx_start..].copy_from_slice(&buf[from..to]);
113 self.rx_remainder_len = remaining_len;
114 }
115
116 Ok(sub_protocol)
117 }
118 Err(crate::Error::HttpHeaderIncomplete) => {
119 panic!("http header not complete");
121 }
122 Err(e) => Err(FramerError::WebSocket(e)),
123 }
124 }
125 None => Err(FramerError::Disconnected),
126 }
127 }
128}
129
130impl<TRng, TWebSocketType> Framer<TRng, TWebSocketType>
131where
132 TRng: RngCore,
133 TWebSocketType: WebSocketType,
134{
135 pub fn new(websocket: WebSocket<TRng, TWebSocketType>) -> Self {
136 Self {
137 websocket,
138 frame_cursor: 0,
139 rx_remainder_len: 0,
140 }
141 }
142
143 pub fn encode<E>(
144 &mut self,
145 message_type: WebSocketSendMessageType,
146 end_of_message: bool,
147 from: &[u8],
148 to: &mut [u8],
149 ) -> Result<usize, FramerError<E>> {
150 let len = self
151 .websocket
152 .write(message_type, end_of_message, from, to)
153 .map_err(FramerError::WebSocket)?;
154
155 Ok(len)
156 }
157
158 pub async fn write<'b, E>(
159 &mut self,
160 tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
161 tx_buf: &'b mut [u8],
162 message_type: WebSocketSendMessageType,
163 end_of_message: bool,
164 frame_buf: &[u8],
165 ) -> Result<(), FramerError<E>>
166 where
167 E: Debug,
168 {
169 let len = self
170 .websocket
171 .write(message_type, end_of_message, frame_buf, tx_buf)
172 .map_err(FramerError::WebSocket)?;
173
174 tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
175 tx.flush().await.map_err(FramerError::Io)?;
176 Ok(())
177 }
178
179 pub async fn close<'b, E>(
180 &mut self,
181 tx: &mut (impl Sink<&'b [u8], Error = E> + Unpin),
182 tx_buf: &'b mut [u8],
183 close_status: WebSocketCloseStatusCode,
184 status_description: Option<&str>,
185 ) -> Result<(), FramerError<E>>
186 where
187 E: Debug,
188 {
189 let len = self
190 .websocket
191 .close(close_status, status_description, tx_buf)
192 .map_err(FramerError::WebSocket)?;
193
194 tx.send(&tx_buf[..len]).await.map_err(FramerError::Io)?;
195 tx.flush().await.map_err(FramerError::Io)?;
196 Ok(())
197 }
198
199 pub async fn read<'a, B: Deref<Target = [u8]>, E>(
204 &mut self,
205 stream: &mut (impl Stream<Item = Result<B, E>> + Sink<&'a [u8], Error = E> + Unpin),
206 buffer: &'a mut [u8],
207 ) -> Option<Result<ReadResult<'a>, FramerError<E>>>
208 where
209 E: Debug,
210 {
211 if self.rx_remainder_len == 0 {
212 match stream.next().await {
213 Some(Ok(input)) => {
214 if buffer.len() < input.len() {
215 return Some(Err(FramerError::RxBufferTooSmall(input.len())));
216 }
217
218 let rx_start = buffer.len() - input.len();
219
220 buffer[rx_start..].copy_from_slice(&input);
222 self.rx_remainder_len = input.len()
223 }
224 Some(Err(e)) => {
225 return Some(Err(FramerError::Io(e)));
226 }
227 None => return None,
228 }
229 }
230
231 let rx_start = buffer.len() - self.rx_remainder_len;
232 let (frame_buf, rx_buf) = buffer.split_at_mut(rx_start);
233
234 let ws_result = match self.websocket.read(rx_buf, frame_buf) {
235 Ok(ws_result) => ws_result,
236 Err(e) => return Some(Err(FramerError::WebSocket(e))),
237 };
238
239 self.rx_remainder_len -= ws_result.len_from;
240
241 match ws_result.message_type {
242 WebSocketReceiveMessageType::Binary => {
243 self.frame_cursor += ws_result.len_to;
244 if ws_result.end_of_message {
245 let range = 0..self.frame_cursor;
246 self.frame_cursor = 0;
247 return Some(Ok(ReadResult::Binary(&frame_buf[range])));
248 }
249 }
250 WebSocketReceiveMessageType::Text => {
251 self.frame_cursor += ws_result.len_to;
252 if ws_result.end_of_message {
253 let range = 0..self.frame_cursor;
254 self.frame_cursor = 0;
255 match core::str::from_utf8(&frame_buf[range]) {
256 Ok(text) => return Some(Ok(ReadResult::Text(text))),
257 Err(e) => return Some(Err(FramerError::Utf8(e))),
258 }
259 }
260 }
261 WebSocketReceiveMessageType::CloseMustReply => {
262 let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
263
264 let tx_buf_len = ws_result.len_to + 14; let split_at = frame_buf.len() - tx_buf_len;
267 let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
268
269 match self.websocket.write(
270 WebSocketSendMessageType::CloseReply,
271 true,
272 &frame_buf[range.start..range.end],
273 tx_buf,
274 ) {
275 Ok(len) => match stream.send(&tx_buf[..len]).await {
276 Ok(()) => {
277 self.frame_cursor = 0;
278 let status_code = ws_result
279 .close_status
280 .expect("close message must have code");
281 let reason = &frame_buf[range];
282 return Some(Ok(ReadResult::Close(CloseMessage {
283 status_code,
284 reason,
285 })));
286 }
287 Err(e) => return Some(Err(FramerError::Io(e))),
288 },
289 Err(e) => return Some(Err(FramerError::WebSocket(e))),
290 }
291 }
292 WebSocketReceiveMessageType::CloseCompleted => return None,
293 WebSocketReceiveMessageType::Pong => {
294 let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
295 return Some(Ok(ReadResult::Pong(&frame_buf[range])));
296 }
297 WebSocketReceiveMessageType::Ping => {
298 let range = self.frame_cursor..self.frame_cursor + ws_result.len_to;
299
300 let tx_buf_len = ws_result.len_to + 14; let split_at = frame_buf.len() - tx_buf_len;
303 let (frame_buf, tx_buf) = frame_buf.split_at_mut(split_at);
304
305 match self.websocket.write(
306 WebSocketSendMessageType::Pong,
307 true,
308 &frame_buf[range.start..range.end],
309 tx_buf,
310 ) {
311 Ok(len) => match stream.send(&tx_buf[..len]).await {
312 Ok(()) => {
313 return Some(Ok(ReadResult::Ping(&frame_buf[range])));
314 }
315 Err(e) => return Some(Err(FramerError::Io(e))),
316 },
317 Err(e) => return Some(Err(FramerError::WebSocket(e))),
318 }
319 }
320 }
321
322 None
323 }
324}