1use crate::addrinfo::*;
11use crate::Result;
12use libutp_sys::*;
13
14use std::collections::HashMap;
15use std::convert::TryInto;
16use std::ffi::c_void;
17use std::io;
18use std::marker::PhantomData;
19use std::net::SocketAddr;
20use std::ops::{Deref, DerefMut};
21
22type UtpCallback<C, S> = Box<(dyn for<'r> FnMut(UtpCallbackArgs<'r, C, S>) -> u64)>;
23
24pub struct UtpContextHandle<C, S> {
25 ctx: UtpContext<C, S>,
26}
27
28impl<C, S> Deref for UtpContextHandle<C, S> {
29 type Target = UtpContext<C, S>;
30
31 fn deref(&self) -> &Self::Target {
32 &self.ctx
33 }
34}
35
36impl<C, S> DerefMut for UtpContextHandle<C, S> {
37 fn deref_mut(&mut self) -> &mut Self::Target {
38 &mut self.ctx
39 }
40}
41
42impl<C, S> Default for UtpContextHandle<C, S> {
43 fn default() -> Self {
44 let inner = unsafe {
45 let inner = utp_init(2);
46 utp_context_set_userdata(
47 inner,
48 Box::into_raw(Box::new(ContextData::<C, S> {
49 data: std::ptr::null_mut(),
50 callbacks: Default::default(),
51 })) as *mut c_void,
52 );
53 inner
54 };
55
56 UtpContextHandle {
57 ctx: UtpContext::wrap(inner),
58 }
59 }
60}
61
62impl<C, S> Drop for UtpContextHandle<C, S> {
63 fn drop(&mut self) {
64 unsafe {
65 let ContextData::<C, S> { data, .. } =
66 try_cast_ref_mut(utp_context_get_userdata(self.ctx.inner)).unwrap();
67 if !data.is_null() {
68 Box::from_raw(*data as *mut C);
69 }
70 let _ctx_data =
72 Box::from_raw(utp_context_get_userdata(self.ctx.inner) as *mut ContextData<C, S>);
73 utp_destroy(self.ctx.inner);
74 };
75 }
76}
77
78pub struct UtpContext<C, S> {
79 inner: *mut utp_context,
80 context_data_type: PhantomData<C>,
81 socket_data_type: PhantomData<S>,
82}
83
84impl<C, S> UtpContext<C, S> {
85 pub fn wrap(inner: *mut utp_context) -> UtpContext<C, S> {
86 UtpContext {
87 inner,
88 context_data_type: PhantomData,
89 socket_data_type: PhantomData,
90 }
91 }
92
93 pub unsafe fn connect(&self, addr: SocketAddr) -> Result<UtpSocketHandle<S>> {
94 let socket = utp_create_socket(self.inner);
95 let sockaddrinfo = *getaddrinfo_from_std(addr)?;
96 match utp_connect(
97 socket,
98 sockaddrinfo.ai_addr,
99 sockaddrinfo.ai_addrlen.try_into().unwrap(),
100 ) {
101 0 => Ok(UtpSocketHandle {
102 socket: UtpSocket::wrap(socket).unwrap(),
103 }),
104 _ => {
105 utp_close(socket);
106 Err(io::Error::new(
107 io::ErrorKind::Other,
108 "utp_connect returned non-zero error code",
109 ))
110 }
111 }
112 }
113
114 pub unsafe fn utp_issue_deferred_acks(&self) {
115 utp_issue_deferred_acks(self.inner);
116 }
117
118 pub unsafe fn utp_process_udp(&self, from: SocketAddr, buf: &[u8]) -> bool {
119 let ai = *getaddrinfo_from_std(from).unwrap();
120 utp_process_udp(
121 self.inner,
122 buf.as_ptr(),
123 buf.len().try_into().unwrap(),
124 ai.ai_addr,
125 ai.ai_addrlen.try_into().unwrap(),
126 ) != 0
127 }
128
129 pub unsafe fn utp_check_timeouts(&self) {
130 utp_check_timeouts(self.inner);
131 }
132
133 pub unsafe fn set_context_data(&self, data: C) {
134 let ContextData::<C, S> { data: old_data, .. } =
135 try_cast_ref_mut(utp_context_get_userdata(self.inner)).unwrap();
136 if !old_data.is_null() {
137 Box::from_raw(*old_data as *mut C);
138 }
139 *old_data = Box::into_raw(Box::new(data)) as *mut c_void;
140 }
141
142 unsafe fn get_callback(&self, event: UtpEvent) -> Option<&mut UtpCallback<C, S>> {
143 let ContextData::<C, S> { callbacks, .. } =
144 try_cast_ref_mut(utp_context_get_userdata(self.inner)).unwrap();
145 callbacks.get_mut(&event)
146 }
147
148 pub unsafe fn get_context_data(&self) -> &C {
149 let ContextData::<C, S> { data, .. } =
150 try_cast_ref(utp_context_get_userdata(self.inner)).unwrap();
151 try_cast_ref(*data).unwrap()
152 }
153
154 pub unsafe fn get_context_data_mut(&mut self) -> &mut C {
155 let ContextData::<C, S> { data, .. } =
156 try_cast_ref_mut(utp_context_get_userdata(self.inner)).unwrap();
157 try_cast_ref_mut(*data).unwrap()
158 }
159
160 pub unsafe fn clear_callback(&self, event: UtpEvent) {
161 let ContextData::<C, S> { callbacks, .. } =
162 try_cast_ref_mut(utp_context_get_userdata(self.inner)).unwrap();
163 utp_set_callback(self.inner, event as i32, None);
164 callbacks.remove(&event);
165 }
166
167 pub unsafe fn set_callback<F>(&self, event: UtpEvent, cb: F)
168 where
169 F: FnMut(UtpCallbackArgs<C, S>) -> u64 + 'static,
170 {
171 let ContextData { callbacks, .. } =
172 try_cast_ref_mut(utp_context_get_userdata(self.inner)).unwrap();
173 callbacks.insert(event, Box::new(cb));
174
175 macro_rules! set_callback {
176 ($cb_type:expr) => {{
177 unsafe extern "C" fn cb<C, S>(args: *mut utp_callback_arguments) -> uint64 {
178 let wrapped_args: UtpCallbackArgs<'_, C, S> = UtpCallbackArgs::new(args);
179 let cb = wrapped_args
180 .context
181 .get_callback($cb_type)
182 .expect("Callback was not set");
183 (cb)(UtpCallbackArgs::new(args))
184 }
185 utp_set_callback(self.inner, $cb_type as i32, Some(cb::<C, S>));
186 }};
187 }
188
189 match event {
190 UtpEvent::Log => set_callback!(UtpEvent::Log),
191 UtpEvent::OnRead => set_callback!(UtpEvent::OnRead),
192 UtpEvent::SendTo => set_callback!(UtpEvent::SendTo),
193 UtpEvent::OnAccept => set_callback!(UtpEvent::OnAccept),
194 UtpEvent::OnError => set_callback!(UtpEvent::OnError),
195 UtpEvent::OnFirewall => set_callback!(UtpEvent::OnFirewall),
196 UtpEvent::GetUdpMTU => set_callback!(UtpEvent::GetUdpMTU),
197 UtpEvent::OnStateChange => set_callback!(UtpEvent::OnStateChange),
198 }
199 }
200}
201
202struct ContextData<C, S> {
203 data: *mut c_void,
204 callbacks: HashMap<UtpEvent, UtpCallback<C, S>>,
205}
206
207#[derive(Copy, Clone, Eq, PartialEq, Hash)]
208pub enum UtpEvent {
209 Log = UTP_LOG as isize,
210 OnRead = UTP_ON_READ as isize,
211 SendTo = UTP_SENDTO as isize,
212 OnAccept = UTP_ON_ACCEPT as isize,
213 OnError = UTP_ON_ERROR as isize,
214 OnFirewall = UTP_ON_FIREWALL as isize,
215 GetUdpMTU = UTP_GET_UDP_MTU as isize,
216 OnStateChange = UTP_ON_STATE_CHANGE as isize,
217}
218
219#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
220pub enum UtpState {
221 UtpStateConnect = UTP_STATE_CONNECT as isize,
222 UtpStateWritable = UTP_STATE_WRITABLE as isize,
223 UtpStateEOF = UTP_STATE_EOF as isize,
224 UtpStateDestroying = UTP_STATE_DESTROYING as isize,
225 UtpInvalid,
226}
227
228impl From<i32> for UtpState {
229 fn from(val: i32) -> Self {
230 use UtpState::*;
231 match val {
232 1 => UtpStateConnect,
233 2 => UtpStateWritable,
234 3 => UtpStateEOF,
235 4 => UtpStateDestroying,
236 _ => UtpInvalid,
237 }
238 }
239}
240
/// Error code as reported through `UtpCallbackArgs::error_code`.
///
/// The named variants take their discriminants from the corresponding libutp
/// error constants. Derives the same comparison and copy traits as `UtpState`
/// so that values can be compared and matched on by callers.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum UtpErrorCode {
    /// `UTP_ECONNREFUSED`
    UtpConnRefused = UTP_ECONNREFUSED as isize,
    /// `UTP_ECONNRESET`
    UtpConnReset = UTP_ECONNRESET as isize,
    /// `UTP_ETIMEDOUT`
    UtpETimedOut = UTP_ETIMEDOUT as isize,
    /// Fallback produced by `From<i32>` for any value it does not recognise.
    Invalid,
}
248
249impl From<i32> for UtpErrorCode {
250 fn from(val: i32) -> Self {
251 use UtpErrorCode::*;
252 match val {
253 0 => UtpConnRefused,
254 1 => UtpConnReset,
255 2 => UtpETimedOut,
256 _ => Invalid,
257 }
258 }
259}
260
261pub struct UtpSocketHandle<S> {
262 socket: UtpSocket<S>,
263}
264
265impl<S> Deref for UtpSocketHandle<S> {
266 type Target = UtpSocket<S>;
267
268 fn deref(&self) -> &Self::Target {
269 &self.socket
270 }
271}
272
273impl<S> DerefMut for UtpSocketHandle<S> {
274 fn deref_mut(&mut self) -> &mut Self::Target {
275 &mut self.socket
276 }
277}
278
279impl<S> Drop for UtpSocketHandle<S> {
280 fn drop(&mut self) {
281 unsafe {
282 let socket_data = utp_get_userdata(self.socket.inner);
283 if !socket_data.is_null() {
284 Box::from_raw(socket_data as *mut S);
285 }
286 utp_close(self.socket.inner);
287 };
288 }
289}
290
291pub struct UtpSocket<S> {
292 inner: *mut utp_socket,
293 socket_data_type: PhantomData<S>,
294}
295
296impl<S> UtpSocket<S> {
297 pub fn wrap(inner: *mut utp_socket) -> Option<UtpSocket<S>> {
298 if !inner.is_null() {
299 Some(UtpSocket {
300 inner,
301 socket_data_type: PhantomData,
302 })
303 } else {
304 None
305 }
306 }
307
308 pub unsafe fn accept(self) -> UtpSocketHandle<S> {
309 UtpSocketHandle { socket: self }
310 }
311
312 pub unsafe fn utp_write(&self, buf: &mut [u8]) -> usize {
313 utp_write(
314 self.inner,
315 buf.as_mut_ptr() as *mut c_void,
316 buf.len().try_into().unwrap(),
317 ) as usize
318 }
319
320 pub unsafe fn utp_read_drained(&self) {
321 utp_read_drained(self.inner);
322 }
323
324 pub unsafe fn set_socket_data(&self, data: S) {
325 let old_data = utp_get_userdata(self.inner);
326 if !old_data.is_null() {
327 Box::from_raw(old_data as *mut S);
328 }
329 utp_set_userdata(self.inner, Box::into_raw(Box::new(data)) as *mut c_void);
330 }
331
332 pub unsafe fn get_socket_data(&self) -> &S {
333 try_cast_ref(utp_get_userdata(self.inner)).unwrap()
334 }
335
336 pub unsafe fn get_socket_data_mut(&mut self) -> &mut S {
337 try_cast_ref_mut(utp_get_userdata(self.inner)).unwrap()
338 }
339}
340
341pub struct UtpCallbackArgs<'a, C, S> {
342 pub context: UtpContext<C, S>,
343 pub socket: Option<UtpSocket<S>>,
344 pub buf: Option<&'a [u8]>,
345 pub raw: *mut utp_callback_arguments,
346}
347
348impl<'a, C, S> UtpCallbackArgs<'a, C, S> {
349 unsafe fn new(args: *mut utp_callback_arguments) -> UtpCallbackArgs<'a, C, S> {
350 UtpCallbackArgs {
351 context: UtpContext::wrap((*args).context),
352 socket: UtpSocket::wrap((*args).socket),
353 buf: buf_to_slice((*args).buf as *const u8, (*args).len as usize),
354 raw: args,
355 }
356 }
357
358 pub unsafe fn address(&self) -> Option<SocketAddr> {
359 socket_addr_from_parts((*self.raw).args1.address, (*self.raw).args2.address_len)
360 }
361
362 pub unsafe fn send(&self) -> i32 {
363 (*self.raw).args1.send
364 }
365
366 pub unsafe fn sample_ms(&self) -> i32 {
367 (*self.raw).args1.sample_ms
368 }
369
370 pub unsafe fn error_code(&self) -> UtpErrorCode {
371 (*self.raw).args1.error_code.into()
372 }
373
374 pub unsafe fn state(&self) -> UtpState {
375 (*self.raw).args1.state.into()
376 }
377
378 pub unsafe fn bandwidth_type(&self) -> i32 {
379 (*self.raw).args2.type_
380 }
381}
382
/// Reinterprets an untyped pointer as `&T`, yielding `None` for a null pointer.
///
/// The caller must guarantee that a non-null `ptr` points to a valid `T` that
/// outlives `'a`.
unsafe fn try_cast_ref<'a, T>(ptr: *mut c_void) -> Option<&'a T> {
    let typed: *const T = ptr.cast::<T>();
    typed.as_ref()
}
386
/// Reinterprets an untyped pointer as `&mut T`, yielding `None` for a null
/// pointer.
///
/// The caller must guarantee that a non-null `ptr` points to a valid `T` that
/// outlives `'a` and is not aliased while the returned reference is in use.
unsafe fn try_cast_ref_mut<'a, T>(ptr: *mut c_void) -> Option<&'a mut T> {
    let typed: *mut T = ptr.cast::<T>();
    typed.as_mut()
}
390
/// Views `len` bytes starting at `buf` as a slice, yielding `None` when `buf`
/// is null.
///
/// The caller must guarantee that a non-null `buf` is valid for reads of `len`
/// bytes for the whole of `'a`.
unsafe fn buf_to_slice<'a>(buf: *const u8, len: usize) -> Option<&'a [u8]> {
    if buf.is_null() {
        return None;
    }
    Some(std::slice::from_raw_parts(buf, len))
}
398
399unsafe fn socket_addr_from_parts(addr: *const sockaddr, len: socklen_t) -> Option<SocketAddr> {
400 if !addr.is_null() {
401 socket2::SockAddr::from_raw_parts(addr, len).as_std()
402 } else {
403 None
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::get_free_socketaddr;
    use std::rc::Rc;

    /// Context data can be stored and read back through both getters.
    #[test]
    fn test_context_data() {
        unsafe {
            let mut ctx = UtpContextHandle::<u32, u32>::default();
            ctx.set_context_data(42);
            let data: u32 = *ctx.get_context_data();
            let data_mut: u32 = *ctx.get_context_data_mut();
            assert_eq!(data, 42);
            assert_eq!(data_mut, 42);
        }
    }

    /// Socket data can be stored and read back through both getters.
    #[test]
    fn test_socket_data() {
        unsafe {
            let ctx = UtpContextHandle::<u32, u32>::default();
            {
                let mut sock = ctx.connect(get_free_socketaddr()).unwrap();
                sock.set_socket_data(42);
                let data: u32 = *sock.get_socket_data();
                // Exercise the mutable getter as well, not the shared one twice.
                let data_mut: u32 = *sock.get_socket_data_mut();
                assert_eq!(data, 42);
                assert_eq!(data_mut, 42);
            }
        }
    }

    /// Replacing context data drops the old value; dropping the handle drops
    /// the current one.
    #[test]
    fn test_context_data_drop() {
        let data = Rc::new(42);
        unsafe {
            let ctx = UtpContextHandle::<Rc<u32>, ()>::default();
            ctx.set_context_data(Rc::clone(&data));
            assert_eq!(Rc::strong_count(&data), 2);
            ctx.set_context_data(Rc::clone(&data));
            assert_eq!(Rc::strong_count(&data), 2);
        }
        assert_eq!(Rc::strong_count(&data), 1);
    }

    /// Replacing socket data drops the old value; dropping the socket handle
    /// drops the current one.
    #[test]
    fn test_socket_data_drop() {
        let data = Rc::new(42);
        unsafe {
            let ctx = UtpContextHandle::<(), Rc<u32>>::default();
            {
                let sock = ctx.connect(get_free_socketaddr()).unwrap();
                sock.set_socket_data(Rc::clone(&data));
                assert_eq!(Rc::strong_count(&data), 2);
                sock.set_socket_data(Rc::clone(&data));
                assert_eq!(Rc::strong_count(&data), 2);
            }
            assert_eq!(Rc::strong_count(&data), 1);
        }
    }

    /// A registered `SendTo` callback runs and can mutate the context data.
    #[test]
    fn test_callback() {
        unsafe {
            let ctx = UtpContextHandle::<bool, ()>::default();
            ctx.set_context_data(false);
            ctx.set_callback(UtpEvent::SendTo, |mut args| {
                *args.context.get_context_data_mut() = true;
                0
            });
            assert!(!*ctx.get_context_data());
            // Connecting is expected to invoke the `SendTo` callback.
            let _sock = ctx.connect(get_free_socketaddr()).unwrap();
            assert!(*ctx.get_context_data());
        }
    }
}