kcp/
protocol.rs

1use ::bytes::{BufMut, Bytes, BytesMut};
2use ::std::{
3    collections::VecDeque,
4    ffi::CStr,
5    io,
6    os::raw::{c_char, c_int, c_long, c_void},
7    ptr::null_mut,
8    slice, str,
9    time::Instant,
10};
11
12#[path = "ffi.rs"]
13mod ffi;
14pub use ffi::*;
15
16pub struct Kcp {
17    handle: usize,
18    time_base: Instant,
19    buffer_size: (usize, usize),
20    write_buf: BytesMut,
21    read_buf: BytesMut,
22    output_queue: VecDeque<BytesMut>,
23}
24
25impl Drop for Kcp {
26    fn drop(&mut self) {
27        if self.handle != 0 {
28            unsafe { ikcp_release(self.as_mut()) }
29            self.handle = 0;
30            self.output_queue.clear();
31        }
32    }
33}
34
35impl Kcp {
36    #[inline]
37    fn as_ref(&self) -> &IKCPCB {
38        unsafe { &*(self.handle as *const IKCPCB) }
39    }
40
41    #[inline]
42    fn as_mut(&mut self) -> &mut IKCPCB {
43        unsafe { &mut *(self.handle as *mut IKCPCB) }
44    }
45}
46
/// Generates one `pub fn <field>(&self) -> u32` getter per listed field of
/// the control block returned by `self.as_ref()`, converting with `as u32`.
/// A trailing comma in the field list is accepted.
macro_rules! export_fields {
    ($($name:ident),+ $(,)?) => {
        $(
            #[doc = concat!("Returns the control block's `", stringify!($name), "` field as a `u32`.")]
            #[inline]
            pub fn $name(&self) -> u32 {
                let block = self.as_ref();
                block.$name as u32
            }
        )+
    };
}
57
58impl Kcp {
59    pub fn new(conv: u32) -> Self {
60        Self {
61            handle: unsafe { ikcp_create(conv, null_mut()) as *const _ as usize },
62            time_base: Instant::now(),
63            buffer_size: (32768, 32768),
64            write_buf: BytesMut::new(),
65            read_buf: BytesMut::new(),
66            output_queue: VecDeque::with_capacity(32),
67        }
68    }
69
70    pub fn get_system_time(&self) -> u32 {
71        let elapsed = self.time_base.elapsed();
72        (elapsed.as_secs() as u32)
73            .wrapping_mul(1000)
74            .wrapping_add(elapsed.subsec_millis())
75    }
76
77    /// # Warning
78    ///
79    /// After initialization, self must be ***pinned*** in memory.
80    pub fn initialize(&mut self) {
81        unsafe extern "C" fn _writelog(log: *const c_char, _kcp: *mut IKCPCB, _user: *mut c_void) {
82            log::trace!(
83                "{}",
84                str::from_utf8_unchecked(CStr::from_ptr(log).to_bytes())
85            );
86        }
87
88        unsafe extern "C" fn _output(
89            buf: *const c_char,
90            len: c_int,
91            _kcp: *mut IKCPCB,
92            user: *mut c_void,
93        ) -> c_int {
94            let this = &mut *(user as *const _ as *mut Kcp);
95            let size = len as usize;
96            if this.write_buf.capacity() < size {
97                this.write_buf.reserve(this.buffer_size.0.max(size));
98            }
99            this.write_buf
100                .put_slice(slice::from_raw_parts(buf as _, size));
101            this.output_queue.push_back(this.write_buf.split_to(size));
102            size as c_int
103        }
104
105        self.as_mut().user = self as *const _ as _;
106        self.as_mut().output = Some(_output);
107        self.as_mut().writelog = Some(_writelog);
108        self.update(self.get_system_time());
109    }
110
111    /// io::ErrorKind::InvalidInput - buffer is too small to contain a frame.
112    pub fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
113        match unsafe { ikcp_recv(self.as_mut(), buf.as_mut_ptr() as _, -(buf.len() as c_int)) } {
114            size if size >= 0 => Ok(size as usize),
115            -1 | -2 => Ok(0),
116            -3 => Err(io::ErrorKind::InvalidInput.into()),
117            _ => unreachable!(),
118        }
119    }
120
121    /// io::ErrorKind::InvalidInput - buffer is too small to contain a frame.
122    pub fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
123        match unsafe { ikcp_recv(self.as_mut(), buf.as_mut_ptr() as _, buf.len() as c_int) } {
124            size if size >= 0 => Ok(size as usize),
125            -1 | -2 => Ok(0),
126            -3 => Err(io::ErrorKind::InvalidInput.into()),
127            _ => unreachable!(),
128        }
129    }
130
131    pub fn recv_bytes(&mut self) -> Option<Bytes> {
132        let size = self.peek_size();
133        if size > 0 {
134            if self.read_buf.capacity() < size {
135                self.read_buf.reserve(self.buffer_size.1.max(size));
136            }
137            unsafe { self.read_buf.set_len(size) };
138            let mut buf = self.read_buf.split_to(size);
139            let _ = self.recv(&mut buf);
140            Some(buf.freeze())
141        } else {
142            None
143        }
144    }
145
146    #[inline]
147    pub fn peek_size(&self) -> usize {
148        unsafe { ikcp_peeksize(self.as_ref()).max(0) as usize }
149    }
150
151    /// io::ErrorKind::InvalidInput - frame is too large.
152    pub fn send(&mut self, data: &[u8]) -> io::Result<usize> {
153        if data.is_empty() {
154            return Ok(0);
155        }
156        match unsafe { ikcp_send(self.as_mut(), data.as_ptr() as _, data.len() as c_int) } {
157            size if size >= 0 => Ok(size as usize),
158            -1 | -2 => Err(io::ErrorKind::InvalidInput.into()),
159            _ => unreachable!(),
160        }
161    }
162
163    /// ErrorKind::NotFound - conv is inconsistent
164    ///
165    /// ErrorKind::InvalidData - Invalid packet or unrecognized command
166    pub fn input(&mut self, packet: &[u8]) -> io::Result<()> {
167        match unsafe { ikcp_input(self.as_mut(), packet.as_ptr() as _, packet.len() as c_long) } {
168            0 => Ok(()),
169            -1 => Err(io::ErrorKind::NotFound.into()),
170            -2 | -3 => Err(io::ErrorKind::InvalidData.into()),
171            _ => unreachable!(),
172        }
173    }
174
175    #[inline]
176    pub fn flush(&mut self) {
177        unsafe { ikcp_flush(self.as_mut()) }
178    }
179
180    #[inline]
181    pub fn update(&mut self, current: u32) {
182        unsafe { ikcp_update(self.as_mut(), current) }
183    }
184
185    #[inline]
186    pub fn check(&self, current: u32) -> u32 {
187        unsafe { ikcp_check(self.as_ref(), current) }
188    }
189
190    pub fn set_mtu(&mut self, mtu: u32) -> io::Result<()> {
191        if self.as_ref().mtu == mtu {
192            return Ok(());
193        }
194        match unsafe { ikcp_setmtu(self.as_mut(), mtu as c_int) } {
195            0 => Ok(()),
196            -1 => Err(io::ErrorKind::InvalidInput.into()),
197            -2 => Err(io::ErrorKind::OutOfMemory.into()),
198            _ => unreachable!(),
199        }
200    }
201
202    pub fn set_nodelay(&mut self, nodelay: bool, interval: u32, resend: u32, nc: bool) {
203        unsafe {
204            ikcp_nodelay(
205                self.as_mut(),
206                nodelay.into(),
207                interval as c_int,
208                resend as c_int,
209                nc.into(),
210            );
211        }
212    }
213
214    pub fn get_waitsnd(&self) -> u32 {
215        unsafe { ikcp_waitsnd(self.as_ref()) as u32 }
216    }
217
218    pub fn set_wndsize(&mut self, sndwnd: u32, rcvwnd: u32) {
219        unsafe {
220            ikcp_wndsize(self.as_mut(), sndwnd as c_int, rcvwnd as c_int);
221        }
222    }
223}
224
225impl Kcp {
226    export_fields! { conv, current, nsnd_que, nrcv_que, nrcv_buf }
227
228    pub fn duration_since(&self, since: u32) -> u32 {
229        (self.current().wrapping_sub(since) as i32).max(0) as u32
230    }
231
232    pub fn set_logmask(&mut self, logmask: u32) {
233        self.as_mut().logmask = logmask as i32;
234    }
235
236    pub fn set_conv(&mut self, conv: u32) {
237        self.as_mut().conv = conv;
238    }
239
240    pub fn set_stream(&mut self, stream: bool) {
241        self.as_mut().stream = stream.into();
242    }
243
244    #[inline]
245    pub fn is_dead_link(&self) -> bool {
246        self.as_ref().state == u32::MAX
247    }
248
249    #[inline]
250    pub fn is_recv_queue_full(&self) -> bool {
251        self.as_ref().nrcv_que >= self.as_ref().rcv_wnd
252    }
253
254    #[inline]
255    pub fn is_send_queue_full(&self) -> bool {
256        //self.get_waitsnd() >= self.as_ref().snd_wnd
257        self.as_ref().nsnd_que >= self.as_ref().snd_wnd
258    }
259
260    #[inline]
261    pub fn has_ouput(&mut self) -> bool {
262        !self.output_queue.is_empty()
263    }
264
265    #[inline]
266    pub fn pop_output(&mut self) -> Option<BytesMut> {
267        self.output_queue.pop_front()
268    }
269
270    pub fn write_ack_head(&self, buf: &mut BytesMut, cmd_flags: u8, payload_size: usize) {
271        buf.reserve(IKCP_OVERHEAD as usize + payload_size);
272        let kcp = self.as_ref();
273        buf.put_u32_le(kcp.conv);
274        buf.put_u8(IKCP_CMD_ACK as u8 | cmd_flags);
275        buf.put_u8(0);
276        buf.put_u16_le(kcp.rcv_wnd as u16);
277        buf.put_u32_le(self.get_system_time());
278        buf.put_u32_le(kcp.snd_nxt);
279        buf.put_u32_le(kcp.rcv_nxt);
280        buf.put_u32_le(0);
281        buf.put_u32_le(payload_size as u32);
282    }
283
284    /// Read conv from a packet buffer.
285    #[inline]
286    pub fn read_conv(buf: &[u8]) -> Option<u32> {
287        if buf.len() >= IKCP_OVERHEAD as usize {
288            Some(unsafe {
289                (*buf.get_unchecked(0) as u32)
290                    | (*buf.get_unchecked(1) as u32).wrapping_shl(8)
291                    | (*buf.get_unchecked(2) as u32).wrapping_shl(16)
292                    | (*buf.get_unchecked(3) as u32).wrapping_shl(24)
293            })
294        } else {
295            None
296        }
297    }
298
299    /// Read cmd from a packet buffer.
300    #[inline]
301    pub fn read_cmd(buf: &[u8]) -> u8 {
302        buf[4]
303    }
304
305    /// Write cmd to a packet buffer.
306    #[inline]
307    pub fn write_cmd(buf: &mut [u8], cmd: u8) {
308        buf[4] = cmd;
309    }
310
311    /// Get the first segment payload from a packet buffer.
312    #[inline]
313    pub fn read_payload_data(buf: &[u8]) -> Option<&[u8]> {
314        unsafe {
315            let mut p = buf.as_ptr();
316            let mut left = buf.len();
317            while left >= IKCP_OVERHEAD as usize {
318                let len = (*p.wrapping_add(IKCP_OVERHEAD as usize - 4) as usize)
319                    | (*p.wrapping_add(IKCP_OVERHEAD as usize - 3) as usize).wrapping_shl(8)
320                    | (*p.wrapping_add(IKCP_OVERHEAD as usize - 2) as usize).wrapping_shl(16)
321                    | (*p.wrapping_add(IKCP_OVERHEAD as usize - 1) as usize).wrapping_shl(24);
322                p = p.wrapping_add(IKCP_OVERHEAD as usize);
323                left -= IKCP_OVERHEAD as usize;
324                if (1..=left).contains(&len) {
325                    return Some(slice::from_raw_parts(p, len));
326                }
327            }
328        }
329        None
330    }
331}
332
333impl Kcp {
334    /// The conv used for SYN handshake.
335    pub const SYN_CONV: u32 = 0xFFFF_FFFE;
336
337    /// Check if a conv is valid.
338    #[inline]
339    pub fn is_valid_conv(conv: u32) -> bool {
340        conv != 0 && conv < Self::SYN_CONV
341    }
342
343    /// Generate a random conv.
344    pub fn rand_conv() -> u32 {
345        loop {
346            let conv = rand::random();
347            if Self::is_valid_conv(conv) {
348                break conv;
349            }
350        }
351    }
352
353    /// Maximum size of a data frame.
354    pub const fn max_frame_size(mtu: u32) -> u32 {
355        (mtu - IKCP_OVERHEAD) * (IKCP_WND_RCV - 1)
356    }
357}