lightws/stream/
mod.rs

1//! Websocket stream.
2//!
3//! [`Stream`] is a simple wrapper of the underlying IO source,
4//! with small stack buffers to save states.
5//!
6//! It is transparent to call `Read` or `Write` on Stream:
7//!
8//! ```ignore
9//! {
10//!     // establish connection, handshake
11//!     let stream = ...
12//!     // read some data
13//!     stream.read(&mut buf)?;
14//!     // write some data
15//!     stream.write(&buf)?;
16//! }
17//! ```
18//!
//! A newly established [`Stream`] is in [`Direct`] (default) mode, where
//! a `Read` or `Write` leads to **at most one** syscall, and
//! an `Ok(0)` will be returned if the frame head is not completely read or written.
//! It can be converted to [`Guarded`] mode with [`Stream::guard`],
//! which wraps `Read` or `Write` in a loop, where `Ok(0)` is handled internally.
24//!
25//! Stream itself does not buffer any payload data during
26//! a `Read` or `Write`, so there is no extra heap allocation.
27//!
28//! # Masking payload
29//!
30//! Data read from stream are automatically unmasked.
31//! However, data written to stream are **NOT** automatically masked,
32//! since a `Write` call requires an immutable `&[u8]`.
33//!
//! A standard client (e.g. [`StandardClient`](crate::role::StandardClient))
//! should mask the payload before sending it;
//! a non-standard client (e.g. [`Client`](crate::role::Client)), which holds an empty mask key,
//! can simply skip this step.
38//!
39//! The mask key is prepared by [`ClientRole`](crate::role::ClientRole),
40//! which can be set or fetched via [`Stream::set_mask_key`] and [`Stream::mask_key`].
41//!
42//! Example:
43//!
44//! ```no_run
45//! use std::io::{Read, Write};
46//! use std::net::TcpStream;
47//! use lightws::role::StandardClient;
48//! use lightws::endpoint::Endpoint;
49//! use lightws::frame::{new_mask_key, apply_mask4};
50//! fn write_data() -> std::io::Result<()> {  
51//!     let mut buf = [0u8; 256];
52//!     let mut tcp = TcpStream::connect("example.com:80")?;
53//!     let mut ws = Endpoint::<TcpStream, StandardClient>::connect(tcp, &mut buf, "example.com", "/ws")?;
54//!
55//!     // mask data
56//!     let key = new_mask_key();
57//!     apply_mask4(key, &mut buf);
58//!
59//!     // set mask key for next write
60//!     ws.set_mask_key(key)?;
61//!
62//!     // write some data
63//!     ws.write_all(&buf)?;
64//!     Ok(())
65//! }
66//! ```
67//!
68//! # Automatic masking
69//!
70//! It is annoying to mask the payload each time before a write,
71//! and it will block us from using convenient functions like [`std::io::copy`].
72//!
//! With the `unsafe_auto_mask_write` feature enabled, the provided immutable `&[u8]` will be cast
//! to a mutable `&mut [u8]` so that the payload data can be automatically masked.
75//!
76//! This feature only has effects on [`AutoMaskClientRole`](crate::role::AutoMaskClientRole),
77//! where its inner mask key may be updated (depends on
78//! [`AutoMaskClientRole::UPDATE_MASK_KEY`](crate::role::AutoMaskClientRole::UPDATE_MASK_KEY))
79//! and used to mask the payload before each write.
80//! Other [`ClientRole`](crate::role::ClientRole) and [`ServerRole`](crate::role::ServerRole)
81//! are not affected. Related code lies in `src/stream/detail/write#L118`.
82//!
83
84mod read;
85mod write;
86
87mod ctrl;
88mod state;
89mod detail;
90mod special;
91
92cfg_if::cfg_if! {
93    if #[cfg(feature = "async")] {
94        mod async_read;
95        mod async_write;
96    }
97}
98
99use std::marker::PhantomData;
100use state::{ReadState, WriteState, HeartBeat};
101use crate::role::RoleHelper;
102
/// Marker type selecting direct read or write.
///
/// This is the default `Guard` of a [`Stream`]: a `Read` or `Write` leads to
/// at most one syscall, and `Ok(0)` is returned if the frame head is not
/// completely read or written.
///
/// The type carries no data; it only exists at the type level.
#[derive(Debug, Clone, Copy, Default)]
pub struct Direct {}
105
/// Marker type selecting wrapped (guarded) read or write.
///
/// A [`Stream`] in this mode wraps `Read` or `Write` in a loop, so an
/// intermediate `Ok(0)` is handled internally instead of being returned to
/// the caller. A stream is switched to this mode with `Stream::guard`.
///
/// The type carries no data; it only exists at the type level.
#[derive(Debug, Clone, Copy, Default)]
pub struct Guarded {}
108
/// Websocket stream.
///
/// Depending on `IO`, [`Stream`] implements [`std::io::Read`] and [`std::io::Write`]
/// or [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`].
///
/// `Role` decides whether to mask payload data.
/// It is reserved to provide extra information to apply optimizations.
///
/// `Guard` is a type-level marker for the stream mode: [`Direct`] (the default)
/// or [`Guarded`]. It is only stored as a `PhantomData`, never as a value.
///
/// See also: `Stream::read`, `Stream::write`.
pub struct Stream<IO, Role, Guard = Direct> {
    // Underlying IO source; borrowed out through the `AsRef`/`AsMut` impls.
    io: IO,
    // Role value supplied at construction; decides whether payload is masked.
    role: Role,
    // Saved state of the read side.
    // NOTE(review): presumably progress of the frame being read — confirm in `state`.
    read_state: ReadState,
    // Saved state of the write side.
    // NOTE(review): presumably progress of the frame being written — confirm in `state`.
    write_state: WriteState,
    // NOTE(review): presumably ping/pong bookkeeping — confirm in `state`.
    heartbeat: HeartBeat,
    // Records the `Guard` mode at the type level without storing a value.
    __marker: PhantomData<Guard>,
}
126
127impl<IO, Role, Guard> AsRef<IO> for Stream<IO, Role, Guard> {
128    #[inline]
129    fn as_ref(&self) -> &IO { &self.io }
130}
131
132impl<IO, Role, Guard> AsMut<IO> for Stream<IO, Role, Guard> {
133    #[inline]
134    fn as_mut(&mut self) -> &mut IO { &mut self.io }
135}
136
137impl<IO, Role, Guard> std::fmt::Debug for Stream<IO, Role, Guard> {
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        f.debug_struct("Stream")
140            .field("read_state", &self.read_state)
141            .field("write_state", &self.write_state)
142            .field("heartbeat", &self.heartbeat)
143            .finish()
144    }
145}
146
147impl<IO, Role> Stream<IO, Role> {
148    /// Create websocket stream from IO source directly,
149    /// without a handshake.
150    #[inline]
151    pub const fn new(io: IO, role: Role) -> Self {
152        Stream {
153            io,
154            role,
155            read_state: ReadState::new(),
156            write_state: WriteState::new(),
157            heartbeat: HeartBeat::new(),
158            __marker: PhantomData,
159        }
160    }
161
162    /// Convert to a guarded stream.
163    #[inline]
164    pub fn guard(self) -> Stream<IO, Role, Guarded> {
165        Stream {
166            io: self.io,
167            role: self.role,
168            read_state: self.read_state,
169            write_state: self.write_state,
170            heartbeat: self.heartbeat,
171            __marker: PhantomData,
172        }
173    }
174}
175
#[cfg(test)]
mod test {
    use super::*;
    use std::io::{Read, Write, Result};
    use crate::frame::*;
    use crate::role::*;

    /// In-memory IO source whose single reads and writes are capped,
    /// used to simulate partial reads and partial writes.
    pub struct LimitReadWriter {
        /// Backing storage: writes append here, reads are served from `cursor` onwards.
        pub buf: Vec<u8>,
        /// Maximum number of bytes returned by one `read` call.
        pub rlimit: usize,
        /// Maximum number of bytes accepted by one `write` call.
        pub wlimit: usize,
        /// Offset in `buf` of the next byte to be read.
        pub cursor: usize,
    }

    impl Read for LimitReadWriter {
        /// Copy out at most `min(buf.len(), rlimit)` of the unread bytes and
        /// advance the cursor; returns `Ok(0)` when nothing is left to read.
        fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
            let pending = &self.buf[self.cursor..];
            // One clamp covers every case: caller's buffer, read limit, data left.
            let n = buf.len().min(self.rlimit).min(pending.len());
            buf[..n].copy_from_slice(&pending[..n]);
            self.cursor += n;
            Ok(n)
        }
    }

    impl Write for LimitReadWriter {
        /// Append at most `wlimit` bytes of `buf` to the backing storage.
        fn write(&mut self, buf: &[u8]) -> Result<usize> {
            let accepted = buf.len().min(self.wlimit);
            self.buf.write(&buf[..accepted])
        }

        /// Nothing is buffered outside `buf`, so flushing is a no-op.
        fn flush(&mut self) -> Result<()> { Ok(()) }
    }

    /// Encode a final-frame head with the given opcode, mask and payload length.
    ///
    /// Returns exactly the encoded head bytes.
    pub fn make_head(opcode: OpCode, mask: Mask, len: usize) -> Vec<u8> {
        // 14 bytes fits the largest websocket frame head
        // (2 base bytes + 8 extended length bytes + 4 mask key bytes).
        let mut encoded = vec![0; 14];
        let head = FrameHead::new(Fin::Y, opcode, mask, PayloadLen::from_num(len as u64));

        let head_len = head.encode(&mut encoded).unwrap();
        // Drop the unused tail instead of copying the head into a second Vec.
        encoded.truncate(head_len);
        encoded
    }

    /// Generate `len` independently random bytes.
    ///
    /// `repeat_with` draws a new byte per element; `repeat` would evaluate
    /// `rand::random` once and fill the whole payload with that single byte.
    pub fn make_data(len: usize) -> Vec<u8> {
        std::iter::repeat_with(rand::random::<u8>).take(len).collect()
    }

    /// Build a frame for role `R`, using the mask key a fresh `R` reports.
    ///
    /// Returns `(frame, payload)`, see [`make_frame_with_mask`].
    pub fn make_frame<R: RoleHelper>(opcode: OpCode, len: usize) -> (Vec<u8>, Vec<u8>) {
        let mask = R::new().mask_key();
        make_frame_with_mask(opcode, mask, len)
    }

    /// Build a frame carrying `len` random payload bytes.
    ///
    /// Returns `(frame, payload)` where `frame` is the encoded head followed
    /// by the payload. The payload inside the frame is NOT masked, whatever
    /// `mask` says; `mask` only ends up in the head.
    pub fn make_frame_with_mask(opcode: OpCode, mask: Mask, len: usize) -> (Vec<u8>, Vec<u8>) {
        let data = make_data(len);

        let mut frame = make_head(opcode, mask, len);
        let head_len = frame.len();

        frame.extend_from_slice(&data);
        assert_eq!(frame.len(), len + head_len);

        (frame, data)
    }

    /// Round-trip random payloads through a single stream under every
    /// combination of read/write limits and check nothing is lost or altered.
    #[test]
    fn read_write_stream() {
        fn read_write<R: RoleHelper>(rlimit: usize, wlimit: usize, len: usize) {
            let io = LimitReadWriter {
                buf: Vec::new(),
                rlimit,
                wlimit,
                cursor: 0,
            };
            // data written to a client stream should be read as a server stream.
            // here we read/write on the same (client/server)stream.
            // this is not correct in practice, but our program can still handle it.
            let mut stream = Stream::<_, R>::new(io, R::new());

            let data = make_data(len);
            let mut data2: Vec<u8> = Vec::new();

            let mut buf = vec![0; 0x2000];
            let mut to_write = data.len();

            while to_write > 0 {
                let wbeg = data.len() - to_write;
                // In direct mode `Ok(0)` means the frame head is not fully
                // written yet, so retry until some payload is accepted.
                let n = loop {
                    let x = stream.write(&data[wbeg..]).unwrap();
                    if x != 0 {
                        break x;
                    }
                };

                // Read back everything that was just written. The loop stops on
                // the cursor position rather than on `Ok(0)`, to avoid reading EOF.
                let mut tmp: Vec<u8> = Vec::new();
                while stream.as_ref().cursor != stream.as_ref().buf.len() {
                    let read_n = stream.read(&mut buf).unwrap();
                    tmp.extend_from_slice(&buf[..read_n]);
                }

                assert_eq!(tmp.len(), n);
                assert_eq!(&data[wbeg..wbeg + n], &tmp);

                to_write -= n;
                data2.append(&mut tmp);
            }

            assert_eq!(&data, &data2);
        }

        for limit in 1..512 {
            for len in 1..=256 {
                read_write::<Client>(limit, 512 - limit, len);
                read_write::<Server>(limit, 512 - limit, len);
            }
        }
    }
}