1use crate::{
9 WebSocket, WebSocketCloseStatusCode, WebSocketContext, WebSocketOptions,
10 WebSocketReceiveMessageType, WebSocketSendMessageType, WebSocketState, WebSocketSubProtocol,
11 WebSocketType,
12};
13use core::{cmp::min, str::Utf8Error};
14use rand_core::RngCore;
15
#[cfg(feature = "std")]
impl Stream<std::io::Error> for std::net::TcpStream {
    /// Reads into `buf`, retrying reads that fail with `ErrorKind::Interrupted`.
    ///
    /// `std::io::Read::read` documents `Interrupted` as non-fatal and retryable. The framer
    /// treats any error from `Stream::read` as fatal, so it is absorbed here instead of
    /// tearing down the connection.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
        loop {
            match std::io::Read::read(self, buf) {
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                result => return result,
            }
        }
    }

    /// Writes all of `buf`; `std::io::Write::write_all` already retries on `Interrupted`.
    fn write_all(&mut self, buf: &[u8]) -> Result<(), std::io::Error> {
        std::io::Write::write_all(self, buf)
    }
}
28
/// Byte-stream transport that the `Framer` reads from and writes to.
///
/// `E` is the transport's own error type; the framer wraps it in `FramerError::Io`.
pub trait Stream<E> {
    /// Reads up to `buf.len()` bytes into `buf` and returns how many bytes were read.
    ///
    /// The framer treats `Ok(0)` as the end of the stream.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, E>;

    /// Writes every byte of `buf` to the transport, or fails with `E`.
    fn write_all(&mut self, buf: &[u8]) -> Result<(), E>;
}
33
/// Outcome of a single `Framer::read` call; payloads borrow from the caller's frame buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadResult<'a> {
    /// A complete binary message (all fragments joined).
    Binary(&'a [u8]),
    /// A complete text message (all fragments joined), already validated as UTF-8.
    Text(&'a str),
    /// The payload of a received pong frame.
    Pong(&'a [u8]),
    /// The connection is closed: a close frame was received or the stream returned 0 bytes.
    Closed,
}
40
41#[derive(Debug)]
42pub enum FramerError<E> {
43 Io(E),
44 FrameTooLarge(usize),
45 Utf8(Utf8Error),
46 HttpHeader(httparse::Error),
47 WebSocket(crate::Error),
48}
49
50pub struct Framer<'a, TRng, TWebSocketType>
51where
52 TRng: RngCore,
53 TWebSocketType: WebSocketType,
54{
55 read_buf: &'a mut [u8],
56 write_buf: &'a mut [u8],
57 read_cursor: &'a mut usize,
58 frame_cursor: usize,
59 read_len: usize,
60 websocket: &'a mut WebSocket<TRng, TWebSocketType>,
61}
62
63impl<'a, TRng> Framer<'a, TRng, crate::Client>
64where
65 TRng: RngCore,
66{
67 pub fn connect<E>(
68 &mut self,
69 stream: &mut impl Stream<E>,
70 websocket_options: &WebSocketOptions,
71 ) -> Result<Option<WebSocketSubProtocol>, FramerError<E>> {
72 let (len, web_socket_key) = self
73 .websocket
74 .client_connect(websocket_options, self.write_buf)
75 .map_err(FramerError::WebSocket)?;
76 stream
77 .write_all(&self.write_buf[..len])
78 .map_err(FramerError::Io)?;
79 *self.read_cursor = 0;
80
81 loop {
82 let received_size = stream
84 .read(&mut self.read_buf[*self.read_cursor..])
85 .map_err(FramerError::Io)?;
86
87 match self.websocket.client_accept(
88 &web_socket_key,
89 &self.read_buf[..*self.read_cursor + received_size],
90 ) {
91 Ok((len, sub_protocol)) => {
92 *self.read_cursor += received_size - len;
95 return Ok(sub_protocol);
96 }
97 Err(crate::Error::HttpHeaderIncomplete) => {
98 *self.read_cursor += received_size;
99 }
101 Err(e) => {
102 *self.read_cursor += received_size;
103 return Err(FramerError::WebSocket(e));
104 }
105 }
106 }
107 }
108}
109
110impl<'a, TRng> Framer<'a, TRng, crate::Server>
111where
112 TRng: RngCore,
113{
114 pub fn accept<E>(
115 &mut self,
116 stream: &mut impl Stream<E>,
117 websocket_context: &WebSocketContext,
118 ) -> Result<(), FramerError<E>> {
119 let len = self
120 .websocket
121 .server_accept(&websocket_context.sec_websocket_key, None, self.write_buf)
122 .map_err(FramerError::WebSocket)?;
123
124 stream
125 .write_all(&self.write_buf[..len])
126 .map_err(FramerError::Io)?;
127 Ok(())
128 }
129}
130
131impl<'a, TRng, TWebSocketType> Framer<'a, TRng, TWebSocketType>
132where
133 TRng: RngCore,
134 TWebSocketType: WebSocketType,
135{
136 pub fn new(
139 read_buf: &'a mut [u8],
140 read_cursor: &'a mut usize,
141 write_buf: &'a mut [u8],
142 websocket: &'a mut WebSocket<TRng, TWebSocketType>,
143 ) -> Self {
144 Self {
145 read_buf,
146 write_buf,
147 read_cursor,
148 frame_cursor: 0,
149 read_len: 0,
150 websocket,
151 }
152 }
153
154 pub fn state(&self) -> WebSocketState {
155 self.websocket.state
156 }
157
158 pub fn close<E>(
160 &mut self,
161 stream: &mut impl Stream<E>,
162 close_status: WebSocketCloseStatusCode,
163 status_description: Option<&str>,
164 ) -> Result<(), FramerError<E>> {
165 let len = self
166 .websocket
167 .close(close_status, status_description, self.write_buf)
168 .map_err(FramerError::WebSocket)?;
169 stream
170 .write_all(&self.write_buf[..len])
171 .map_err(FramerError::Io)?;
172 Ok(())
173 }
174
175 pub fn write<E>(
176 &mut self,
177 stream: &mut impl Stream<E>,
178 message_type: WebSocketSendMessageType,
179 end_of_message: bool,
180 frame_buf: &[u8],
181 ) -> Result<(), FramerError<E>> {
182 let len = self
183 .websocket
184 .write(message_type, end_of_message, frame_buf, self.write_buf)
185 .map_err(FramerError::WebSocket)?;
186 stream
187 .write_all(&self.write_buf[..len])
188 .map_err(FramerError::Io)?;
189 Ok(())
190 }
191
192 pub fn read<'b, E>(
196 &mut self,
197 stream: &mut impl Stream<E>,
198 frame_buf: &'b mut [u8],
199 ) -> Result<ReadResult<'b>, FramerError<E>> {
200 loop {
201 if *self.read_cursor == 0 || *self.read_cursor == self.read_len {
202 self.read_len = stream.read(self.read_buf).map_err(FramerError::Io)?;
203 *self.read_cursor = 0;
204 }
205
206 if self.read_len == 0 {
207 return Ok(ReadResult::Closed);
208 }
209
210 loop {
211 if *self.read_cursor == self.read_len {
212 break;
213 }
214
215 if self.frame_cursor == frame_buf.len() {
216 return Err(FramerError::FrameTooLarge(frame_buf.len()));
217 }
218
219 let ws_result = self
220 .websocket
221 .read(
222 &self.read_buf[*self.read_cursor..self.read_len],
223 &mut frame_buf[self.frame_cursor..],
224 )
225 .map_err(FramerError::WebSocket)?;
226
227 *self.read_cursor += ws_result.len_from;
228
229 match ws_result.message_type {
230 WebSocketReceiveMessageType::Binary => {
231 self.frame_cursor += ws_result.len_to;
232 if ws_result.end_of_message {
233 let frame = &frame_buf[..self.frame_cursor];
234 self.frame_cursor = 0;
235 return Ok(ReadResult::Binary(frame));
236 }
237 }
238 WebSocketReceiveMessageType::Text => {
239 self.frame_cursor += ws_result.len_to;
240 if ws_result.end_of_message {
241 let frame = &frame_buf[..self.frame_cursor];
242 self.frame_cursor = 0;
243 let text = core::str::from_utf8(frame).map_err(FramerError::Utf8)?;
244 return Ok(ReadResult::Text(text));
245 }
246 }
247 WebSocketReceiveMessageType::CloseMustReply => {
248 self.send_back(
249 stream,
250 frame_buf,
251 ws_result.len_to,
252 WebSocketSendMessageType::CloseReply,
253 )?;
254 return Ok(ReadResult::Closed);
255 }
256 WebSocketReceiveMessageType::CloseCompleted => return Ok(ReadResult::Closed),
257 WebSocketReceiveMessageType::Ping => {
258 self.send_back(
259 stream,
260 frame_buf,
261 ws_result.len_to,
262 WebSocketSendMessageType::Pong,
263 )?;
264 }
265 WebSocketReceiveMessageType::Pong => {
266 let bytes =
267 &frame_buf[self.frame_cursor..self.frame_cursor + ws_result.len_to];
268 return Ok(ReadResult::Pong(bytes));
269 }
270 }
271 }
272 }
273 }
274
275 fn send_back<E>(
276 &mut self,
277 stream: &mut impl Stream<E>,
278 frame_buf: &'_ mut [u8],
279 len_to: usize,
280 send_message_type: WebSocketSendMessageType,
281 ) -> Result<(), FramerError<E>> {
282 let payload_len = min(self.write_buf.len(), len_to);
283 let from = &frame_buf[self.frame_cursor..self.frame_cursor + payload_len];
284 let len = self
285 .websocket
286 .write(send_message_type, true, from, self.write_buf)
287 .map_err(FramerError::WebSocket)?;
288 stream
289 .write_all(&self.write_buf[..len])
290 .map_err(FramerError::Io)?;
291 Ok(())
292 }
293}
294
#[cfg(test)]
mod tests {
    extern crate std;

    use super::*;
    use crate::{WebSocketClient, WebSocketOpCode, WebSocketServer};

    /// In-memory `Stream` that serves reads from a fixed byte vector and records writes.
    struct DummyStream {
        pub read_buf: Vec<u8>,
        pub read_cursor: usize,
        pub write_buf: Vec<u8>,
    }

    impl DummyStream {
        pub fn new(read_buf: Vec<u8>) -> Self {
            Self {
                read_buf,
                read_cursor: 0,
                write_buf: Vec::new(),
            }
        }
    }

    impl<E> Stream<E> for DummyStream {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize, E> {
            // Hand out as much unread input as fits; 0 once the input is exhausted.
            let unread = &self.read_buf[self.read_cursor..];
            let count = unread.len().min(buf.len());
            buf[..count].copy_from_slice(&unread[..count]);
            self.read_cursor += count;
            Ok(count)
        }

        fn write_all(&mut self, buf: &[u8]) -> Result<(), E> {
            self.write_buf.extend_from_slice(buf);
            Ok(())
        }
    }

    /// Feeds `frames` to an open client framer and asserts the first message is text `expected`.
    fn assert_first_message_is_text(frames: Vec<u8>, expected: &str) {
        let mut read_buf = vec![0; 1024];
        let mut write_buf = vec![0; 1024];
        let mut read_cursor = 0;
        let mut ws_client = WebSocketClient::new_client(rand::thread_rng());
        ws_client.state = WebSocketState::Open;
        let mut client_framer = Framer::new(
            &mut read_buf,
            &mut read_cursor,
            &mut write_buf,
            &mut ws_client,
        );

        let mut frame_buf = vec![0; 1024];
        let mut stream = DummyStream::new(frames);
        match client_framer
            .read::<()>(&mut stream, &mut frame_buf)
            .unwrap()
        {
            ReadResult::Text(x) => assert_eq!(x, expected),
            _ => panic!("expected text frame"),
        }
    }

    /// Encodes each `(payload, op code, end_of_message)` entry as a server frame, in order,
    /// and concatenates the encoded bytes.
    fn encode_server_frames(parts: Vec<(&[u8], WebSocketOpCode, bool)>) -> Vec<u8> {
        let mut ws_server = WebSocketServer::new_server();
        let mut scratch = vec![0u8; 1024];
        let mut messages = Vec::new();
        for (payload, op_code, end_of_message) in parts {
            let len = ws_server
                .write_frame(payload, &mut scratch, op_code, end_of_message)
                .unwrap();
            messages.extend_from_slice(&scratch[..len]);
        }
        messages
    }

    #[test]
    fn fragmented_frames() {
        assert_first_message_is_text(get_fragmented_frames(), "hello world!");
    }

    #[test]
    fn fragmented_frames_with_control_frames() {
        assert_first_message_is_text(get_fragmented_frames_with_ping(), "hello world!");
    }

    /// "hello world!" split across a text frame and two continuation frames.
    fn get_fragmented_frames() -> Vec<u8> {
        encode_server_frames(vec![
            ("hello".as_bytes(), WebSocketOpCode::TextFrame, false),
            (" world".as_bytes(), WebSocketOpCode::ContinuationFrame, false),
            ("!".as_bytes(), WebSocketOpCode::ContinuationFrame, true),
        ])
    }

    /// "hello world!" split across two fragments with an empty ping between them.
    fn get_fragmented_frames_with_ping() -> Vec<u8> {
        encode_server_frames(vec![
            ("hello".as_bytes(), WebSocketOpCode::TextFrame, false),
            (&b""[..], WebSocketOpCode::Ping, true),
            (" world!".as_bytes(), WebSocketOpCode::ContinuationFrame, true),
        ])
    }
}