libutp_rs/
wrappers.rs

1//! Thin wrappers around the original [libutp](https://github.com/bittorrent/libutp) interface
2//!
3//! [`UtpContext`] and [`UtpSocket`] are stateless wrappers around their
4//! [libutp](https://github.com/bittorrent/libutp) counterparts [`utp_context`] and [`utp_socket`]
5//! and provide more convenient interfaces to them.
6//!
7//! [`UtpContextHandle`] and [`UtpSocketHandle`] represent "ownership" of a [`UtpContext`] and [`UtpSocket`],
8//! meaning that the underlying libutp entity will be cleaned up when they are dropped.
9
10use crate::addrinfo::*;
11use crate::Result;
12use libutp_sys::*;
13
14use std::collections::HashMap;
15use std::convert::TryInto;
16use std::ffi::c_void;
17use std::io;
18use std::marker::PhantomData;
19use std::net::SocketAddr;
20use std::ops::{Deref, DerefMut};
21
22type UtpCallback<C, S> = Box<(dyn for<'r> FnMut(UtpCallbackArgs<'r, C, S>) -> u64)>;
23
24pub struct UtpContextHandle<C, S> {
25    ctx: UtpContext<C, S>,
26}
27
28impl<C, S> Deref for UtpContextHandle<C, S> {
29    type Target = UtpContext<C, S>;
30
31    fn deref(&self) -> &Self::Target {
32        &self.ctx
33    }
34}
35
36impl<C, S> DerefMut for UtpContextHandle<C, S> {
37    fn deref_mut(&mut self) -> &mut Self::Target {
38        &mut self.ctx
39    }
40}
41
42impl<C, S> Default for UtpContextHandle<C, S> {
43    fn default() -> Self {
44        let inner = unsafe {
45            let inner = utp_init(2);
46            utp_context_set_userdata(
47                inner,
48                Box::into_raw(Box::new(ContextData::<C, S> {
49                    data: std::ptr::null_mut(),
50                    callbacks: Default::default(),
51                })) as *mut c_void,
52            );
53            inner
54        };
55
56        UtpContextHandle {
57            ctx: UtpContext::wrap(inner),
58        }
59    }
60}
61
62impl<C, S> Drop for UtpContextHandle<C, S> {
63    fn drop(&mut self) {
64        unsafe {
65            let ContextData::<C, S> { data, .. } =
66                try_cast_ref_mut(utp_context_get_userdata(self.ctx.inner)).unwrap();
67            if !data.is_null() {
68                Box::from_raw(*data as *mut C);
69            }
70            // if we don't do `let _ctx_data`, callbacks will be dropped before the context is destroyed
71            let _ctx_data =
72                Box::from_raw(utp_context_get_userdata(self.ctx.inner) as *mut ContextData<C, S>);
73            utp_destroy(self.ctx.inner);
74        };
75    }
76}
77
78pub struct UtpContext<C, S> {
79    inner: *mut utp_context,
80    context_data_type: PhantomData<C>,
81    socket_data_type: PhantomData<S>,
82}
83
84impl<C, S> UtpContext<C, S> {
85    pub fn wrap(inner: *mut utp_context) -> UtpContext<C, S> {
86        UtpContext {
87            inner,
88            context_data_type: PhantomData,
89            socket_data_type: PhantomData,
90        }
91    }
92
93    pub unsafe fn connect(&self, addr: SocketAddr) -> Result<UtpSocketHandle<S>> {
94        let socket = utp_create_socket(self.inner);
95        let sockaddrinfo = *getaddrinfo_from_std(addr)?;
96        match utp_connect(
97            socket,
98            sockaddrinfo.ai_addr,
99            sockaddrinfo.ai_addrlen.try_into().unwrap(),
100        ) {
101            0 => Ok(UtpSocketHandle {
102                socket: UtpSocket::wrap(socket).unwrap(),
103            }),
104            _ => {
105                utp_close(socket);
106                Err(io::Error::new(
107                    io::ErrorKind::Other,
108                    "utp_connect returned non-zero error code",
109                ))
110            }
111        }
112    }
113
114    pub unsafe fn utp_issue_deferred_acks(&self) {
115        utp_issue_deferred_acks(self.inner);
116    }
117
118    pub unsafe fn utp_process_udp(&self, from: SocketAddr, buf: &[u8]) -> bool {
119        let ai = *getaddrinfo_from_std(from).unwrap();
120        utp_process_udp(
121            self.inner,
122            buf.as_ptr(),
123            buf.len().try_into().unwrap(),
124            ai.ai_addr,
125            ai.ai_addrlen.try_into().unwrap(),
126        ) != 0
127    }
128
129    pub unsafe fn utp_check_timeouts(&self) {
130        utp_check_timeouts(self.inner);
131    }
132
133    pub unsafe fn set_context_data(&self, data: C) {
134        let ContextData::<C, S> { data: old_data, .. } =
135            try_cast_ref_mut(utp_context_get_userdata(self.inner)).unwrap();
136        if !old_data.is_null() {
137            Box::from_raw(*old_data as *mut C);
138        }
139        *old_data = Box::into_raw(Box::new(data)) as *mut c_void;
140    }
141
142    unsafe fn get_callback(&self, event: UtpEvent) -> Option<&mut UtpCallback<C, S>> {
143        let ContextData::<C, S> { callbacks, .. } =
144            try_cast_ref_mut(utp_context_get_userdata(self.inner)).unwrap();
145        callbacks.get_mut(&event)
146    }
147
148    pub unsafe fn get_context_data(&self) -> &C {
149        let ContextData::<C, S> { data, .. } =
150            try_cast_ref(utp_context_get_userdata(self.inner)).unwrap();
151        try_cast_ref(*data).unwrap()
152    }
153
154    pub unsafe fn get_context_data_mut(&mut self) -> &mut C {
155        let ContextData::<C, S> { data, .. } =
156            try_cast_ref_mut(utp_context_get_userdata(self.inner)).unwrap();
157        try_cast_ref_mut(*data).unwrap()
158    }
159
160    pub unsafe fn clear_callback(&self, event: UtpEvent) {
161        let ContextData::<C, S> { callbacks, .. } =
162            try_cast_ref_mut(utp_context_get_userdata(self.inner)).unwrap();
163        utp_set_callback(self.inner, event as i32, None);
164        callbacks.remove(&event);
165    }
166
167    pub unsafe fn set_callback<F>(&self, event: UtpEvent, cb: F)
168    where
169        F: FnMut(UtpCallbackArgs<C, S>) -> u64 + 'static,
170    {
171        let ContextData { callbacks, .. } =
172            try_cast_ref_mut(utp_context_get_userdata(self.inner)).unwrap();
173        callbacks.insert(event, Box::new(cb));
174
175        macro_rules! set_callback {
176            ($cb_type:expr) => {{
177                unsafe extern "C" fn cb<C, S>(args: *mut utp_callback_arguments) -> uint64 {
178                    let wrapped_args: UtpCallbackArgs<'_, C, S> = UtpCallbackArgs::new(args);
179                    let cb = wrapped_args
180                        .context
181                        .get_callback($cb_type)
182                        .expect("Callback was not set");
183                    (cb)(UtpCallbackArgs::new(args))
184                }
185                utp_set_callback(self.inner, $cb_type as i32, Some(cb::<C, S>));
186            }};
187        }
188
189        match event {
190            UtpEvent::Log => set_callback!(UtpEvent::Log),
191            UtpEvent::OnRead => set_callback!(UtpEvent::OnRead),
192            UtpEvent::SendTo => set_callback!(UtpEvent::SendTo),
193            UtpEvent::OnAccept => set_callback!(UtpEvent::OnAccept),
194            UtpEvent::OnError => set_callback!(UtpEvent::OnError),
195            UtpEvent::OnFirewall => set_callback!(UtpEvent::OnFirewall),
196            UtpEvent::GetUdpMTU => set_callback!(UtpEvent::GetUdpMTU),
197            UtpEvent::OnStateChange => set_callback!(UtpEvent::OnStateChange),
198        }
199    }
200}
201
202struct ContextData<C, S> {
203    data: *mut c_void,
204    callbacks: HashMap<UtpEvent, UtpCallback<C, S>>,
205}
206
207#[derive(Copy, Clone, Eq, PartialEq, Hash)]
208pub enum UtpEvent {
209    Log = UTP_LOG as isize,
210    OnRead = UTP_ON_READ as isize,
211    SendTo = UTP_SENDTO as isize,
212    OnAccept = UTP_ON_ACCEPT as isize,
213    OnError = UTP_ON_ERROR as isize,
214    OnFirewall = UTP_ON_FIREWALL as isize,
215    GetUdpMTU = UTP_GET_UDP_MTU as isize,
216    OnStateChange = UTP_ON_STATE_CHANGE as isize,
217}
218
219#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
220pub enum UtpState {
221    UtpStateConnect = UTP_STATE_CONNECT as isize,
222    UtpStateWritable = UTP_STATE_WRITABLE as isize,
223    UtpStateEOF = UTP_STATE_EOF as isize,
224    UtpStateDestroying = UTP_STATE_DESTROYING as isize,
225    UtpInvalid,
226}
227
228impl From<i32> for UtpState {
229    fn from(val: i32) -> Self {
230        use UtpState::*;
231        match val {
232            1 => UtpStateConnect,
233            2 => UtpStateWritable,
234            3 => UtpStateEOF,
235            4 => UtpStateDestroying,
236            _ => UtpInvalid,
237        }
238    }
239}
240
/// Error code reported to error callbacks.
///
/// The first three variants take their discriminants from the matching libutp constants.
/// The type derives the same traits as [`UtpState`] so that error codes can be copied,
/// compared and used as map keys.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum UtpErrorCode {
    /// Mirrors `UTP_ECONNREFUSED`.
    UtpConnRefused = UTP_ECONNREFUSED as isize,
    /// Mirrors `UTP_ECONNRESET`.
    UtpConnReset = UTP_ECONNRESET as isize,
    /// Mirrors `UTP_ETIMEDOUT`.
    UtpETimedOut = UTP_ETIMEDOUT as isize,
    /// Fallback used when a raw value does not correspond to any known error code.
    Invalid,
}
248
249impl From<i32> for UtpErrorCode {
250    fn from(val: i32) -> Self {
251        use UtpErrorCode::*;
252        match val {
253            0 => UtpConnRefused,
254            1 => UtpConnReset,
255            2 => UtpETimedOut,
256            _ => Invalid,
257        }
258    }
259}
260
261pub struct UtpSocketHandle<S> {
262    socket: UtpSocket<S>,
263}
264
265impl<S> Deref for UtpSocketHandle<S> {
266    type Target = UtpSocket<S>;
267
268    fn deref(&self) -> &Self::Target {
269        &self.socket
270    }
271}
272
273impl<S> DerefMut for UtpSocketHandle<S> {
274    fn deref_mut(&mut self) -> &mut Self::Target {
275        &mut self.socket
276    }
277}
278
279impl<S> Drop for UtpSocketHandle<S> {
280    fn drop(&mut self) {
281        unsafe {
282            let socket_data = utp_get_userdata(self.socket.inner);
283            if !socket_data.is_null() {
284                Box::from_raw(socket_data as *mut S);
285            }
286            utp_close(self.socket.inner);
287        };
288    }
289}
290
291pub struct UtpSocket<S> {
292    inner: *mut utp_socket,
293    socket_data_type: PhantomData<S>,
294}
295
296impl<S> UtpSocket<S> {
297    pub fn wrap(inner: *mut utp_socket) -> Option<UtpSocket<S>> {
298        if !inner.is_null() {
299            Some(UtpSocket {
300                inner,
301                socket_data_type: PhantomData,
302            })
303        } else {
304            None
305        }
306    }
307
308    pub unsafe fn accept(self) -> UtpSocketHandle<S> {
309        UtpSocketHandle { socket: self }
310    }
311
312    pub unsafe fn utp_write(&self, buf: &mut [u8]) -> usize {
313        utp_write(
314            self.inner,
315            buf.as_mut_ptr() as *mut c_void,
316            buf.len().try_into().unwrap(),
317        ) as usize
318    }
319
320    pub unsafe fn utp_read_drained(&self) {
321        utp_read_drained(self.inner);
322    }
323
324    pub unsafe fn set_socket_data(&self, data: S) {
325        let old_data = utp_get_userdata(self.inner);
326        if !old_data.is_null() {
327            Box::from_raw(old_data as *mut S);
328        }
329        utp_set_userdata(self.inner, Box::into_raw(Box::new(data)) as *mut c_void);
330    }
331
332    pub unsafe fn get_socket_data(&self) -> &S {
333        try_cast_ref(utp_get_userdata(self.inner)).unwrap()
334    }
335
336    pub unsafe fn get_socket_data_mut(&mut self) -> &mut S {
337        try_cast_ref_mut(utp_get_userdata(self.inner)).unwrap()
338    }
339}
340
341pub struct UtpCallbackArgs<'a, C, S> {
342    pub context: UtpContext<C, S>,
343    pub socket: Option<UtpSocket<S>>,
344    pub buf: Option<&'a [u8]>,
345    pub raw: *mut utp_callback_arguments,
346}
347
348impl<'a, C, S> UtpCallbackArgs<'a, C, S> {
349    unsafe fn new(args: *mut utp_callback_arguments) -> UtpCallbackArgs<'a, C, S> {
350        UtpCallbackArgs {
351            context: UtpContext::wrap((*args).context),
352            socket: UtpSocket::wrap((*args).socket),
353            buf: buf_to_slice((*args).buf as *const u8, (*args).len as usize),
354            raw: args,
355        }
356    }
357
358    pub unsafe fn address(&self) -> Option<SocketAddr> {
359        socket_addr_from_parts((*self.raw).args1.address, (*self.raw).args2.address_len)
360    }
361
362    pub unsafe fn send(&self) -> i32 {
363        (*self.raw).args1.send
364    }
365
366    pub unsafe fn sample_ms(&self) -> i32 {
367        (*self.raw).args1.sample_ms
368    }
369
370    pub unsafe fn error_code(&self) -> UtpErrorCode {
371        (*self.raw).args1.error_code.into()
372    }
373
374    pub unsafe fn state(&self) -> UtpState {
375        (*self.raw).args1.state.into()
376    }
377
378    pub unsafe fn bandwidth_type(&self) -> i32 {
379        (*self.raw).args2.type_
380    }
381}
382
/// Reinterprets an opaque pointer as a shared reference to `T`, or `None` if it is null.
///
/// # Safety
/// A non-null `ptr` must point to a valid, properly aligned `T` that lives for `'a`.
unsafe fn try_cast_ref<'a, T>(ptr: *mut c_void) -> Option<&'a T> {
    if ptr.is_null() {
        None
    } else {
        Some(&*(ptr as *const T))
    }
}
386
/// Reinterprets an opaque pointer as a mutable reference to `T`, or `None` if it is null.
///
/// # Safety
/// A non-null `ptr` must point to a valid, properly aligned `T` that lives for `'a` and is not
/// aliased by any other live reference.
unsafe fn try_cast_ref_mut<'a, T>(ptr: *mut c_void) -> Option<&'a mut T> {
    if ptr.is_null() {
        None
    } else {
        Some(&mut *(ptr as *mut T))
    }
}
390
/// Views `len` bytes starting at `buf` as a slice, or returns `None` if `buf` is null.
///
/// # Safety
/// A non-null `buf` must be valid for reads of `len` bytes for the lifetime `'a`.
unsafe fn buf_to_slice<'a>(buf: *const u8, len: usize) -> Option<&'a [u8]> {
    if buf.is_null() {
        return None;
    }
    let bytes = std::slice::from_raw_parts(buf, len);
    Some(bytes)
}
398
399unsafe fn socket_addr_from_parts(addr: *const sockaddr, len: socklen_t) -> Option<SocketAddr> {
400    if !addr.is_null() {
401        socket2::SockAddr::from_raw_parts(addr, len).as_std()
402    } else {
403        None
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::get_free_socketaddr;
    use std::rc::Rc;

    /// Context data can be stored and read back through both accessors.
    #[test]
    fn test_context_data() {
        unsafe {
            let mut ctx = UtpContextHandle::<u32, u32>::default();
            ctx.set_context_data(42);
            let data: u32 = *ctx.get_context_data();
            let data_mut: u32 = *ctx.get_context_data_mut();
            assert_eq!(data, 42);
            assert_eq!(data_mut, 42);
        }
    }

    /// Socket data can be stored and read back through both the shared and the mutable accessor.
    #[test]
    fn test_socket_data() {
        unsafe {
            let ctx = UtpContextHandle::<u32, u32>::default();
            {
                let mut sock = ctx.connect(get_free_socketaddr()).unwrap();
                sock.set_socket_data(42);
                let data: u32 = *sock.get_socket_data();
                // Exercise the mutable accessor as well (previously the shared one was
                // called twice, leaving `get_socket_data_mut` untested).
                let data_mut: u32 = *sock.get_socket_data_mut();
                assert_eq!(data, 42);
                assert_eq!(data_mut, 42);
            }
        }
    }

    /// Replacing context data drops the old value, and dropping the handle drops the last one.
    #[test]
    fn test_context_data_drop() {
        let data = Rc::new(42);
        unsafe {
            let ctx = UtpContextHandle::<Rc<u32>, ()>::default();
            ctx.set_context_data(Rc::clone(&data));
            assert_eq!(Rc::strong_count(&data), 2);
            ctx.set_context_data(Rc::clone(&data));
            assert_eq!(Rc::strong_count(&data), 2);
        }
        assert_eq!(Rc::strong_count(&data), 1);
    }

    /// Replacing socket data drops the old value, and dropping the handle drops the last one.
    #[test]
    fn test_socket_data_drop() {
        let data = Rc::new(42);
        unsafe {
            let ctx = UtpContextHandle::<(), Rc<u32>>::default();
            {
                let sock = ctx.connect(get_free_socketaddr()).unwrap();
                sock.set_socket_data(Rc::clone(&data));
                assert_eq!(Rc::strong_count(&data), 2);
                sock.set_socket_data(Rc::clone(&data));
                assert_eq!(Rc::strong_count(&data), 2);
            }
            assert_eq!(Rc::strong_count(&data), 1);
        }
    }

    /// A registered callback is invoked by libutp and can mutate the context data.
    #[test]
    fn test_callback() {
        unsafe {
            let ctx = UtpContextHandle::<bool, ()>::default();
            ctx.set_context_data(false);
            ctx.set_callback(UtpEvent::SendTo, |mut args| {
                *args.context.get_context_data_mut() = true;
                0
            });
            assert!(!*ctx.get_context_data());
            // Calling connect will cause the SENDTO callback to fire
            let _sock = ctx.connect(get_free_socketaddr()).unwrap();
            assert!(*ctx.get_context_data());
        }
    }
}