1#![allow(clippy::missing_errors_doc, clippy::semicolon_if_nothing_returned)]
4
5use core::ffi::{c_int, c_void};
6use std::ffi::CString;
7use std::sync::{Arc, Mutex};
8
9use crate::client::{ContentContext, TcpClient};
10use crate::endpoint::Endpoint;
11use crate::error::{from_status, NetworkError};
12use crate::ffi;
13use crate::parameters::{ConnectionParameters, KeepAlives};
14use crate::path::Path;
15use crate::protocol::{ProtocolDefinition, ProtocolMetadata, ProtocolOptions};
16use doom_fish_utils::panic_safe::catch_user_panic;
17
18fn to_cstring(value: &str, field: &str) -> Result<CString, NetworkError> {
19 CString::new(value).map_err(|e| NetworkError::InvalidArgument(format!("{field} NUL byte: {e}")))
20}
21
/// Owning wrapper around a native connection-group descriptor handle.
///
/// The handle is retained through the shim when the wrapper is cloned and released when it is
/// dropped.
pub struct ConnectionGroupDescriptor {
    /// Raw pointer to the descriptor object managed by the native shim.
    handle: *mut c_void,
}

// NOTE(review): these impls assert that the native descriptor may be moved to and shared between
// threads; nothing in this file proves that, so confirm against the shim's contract.
unsafe impl Send for ConnectionGroupDescriptor {}
unsafe impl Sync for ConnectionGroupDescriptor {}

impl std::fmt::Debug for ConnectionGroupDescriptor {
    /// Prints the type name together with the raw handle address.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ConnectionGroupDescriptor");
        builder.field("handle", &self.handle);
        builder.finish()
    }
}
37
38impl ConnectionGroupDescriptor {
39 pub fn multiplex(host: &str, port: u16) -> Result<Self, NetworkError> {
41 let host = to_cstring(host, "host")?;
42 let handle = unsafe { ffi::nw_shim_group_descriptor_create_multiplex(host.as_ptr(), port) };
43 if handle.is_null() {
44 return Err(NetworkError::InvalidArgument(
45 "failed to create multiplex group descriptor".into(),
46 ));
47 }
48 Ok(Self { handle })
49 }
50
51 pub fn multicast(group_address: &str, port: u16) -> Result<Self, NetworkError> {
53 let group_address = to_cstring(group_address, "group_address")?;
54 let handle =
55 unsafe { ffi::nw_shim_group_descriptor_create_multicast(group_address.as_ptr(), port) };
56 if handle.is_null() {
57 return Err(NetworkError::InvalidArgument(
58 "failed to create multicast group descriptor".into(),
59 ));
60 }
61 Ok(Self { handle })
62 }
63
64 pub fn add_endpoint(&mut self, host: &str, port: u16) -> Result<&mut Self, NetworkError> {
66 let host = to_cstring(host, "host")?;
67 let added =
68 unsafe { ffi::nw_shim_group_descriptor_add_endpoint(self.handle, host.as_ptr(), port) };
69 if added == 0 {
70 return Err(NetworkError::InvalidArgument(
71 "failed to add endpoint to group descriptor".into(),
72 ));
73 }
74 Ok(self)
75 }
76
77 #[must_use]
79 pub fn endpoints(&self) -> Vec<Endpoint> {
80 unsafe extern "C" fn collect(endpoint: *mut c_void, user_info: *mut c_void) -> c_int {
81 if user_info.is_null() || endpoint.is_null() {
82 return 0;
83 }
84 let endpoints = unsafe { &mut *user_info.cast::<Vec<Endpoint>>() };
85 endpoints.push(unsafe { Endpoint::from_raw(endpoint) });
86 1
87 }
88
89 let mut endpoints = Vec::new();
90 unsafe {
91 ffi::nw_shim_group_descriptor_enumerate_endpoints(
92 self.handle,
93 Some(collect),
94 std::ptr::addr_of_mut!(endpoints).cast(),
95 )
96 };
97 endpoints
98 }
99
100 pub fn set_specific_source(&mut self, endpoint: &Endpoint) -> &mut Self {
102 unsafe {
103 ffi::nw_shim_multicast_group_descriptor_set_specific_source(
104 self.handle,
105 endpoint.as_ptr(),
106 )
107 };
108 self
109 }
110
111 #[must_use]
113 pub fn disable_unicast_traffic(&self) -> bool {
114 unsafe {
115 ffi::nw_shim_multicast_group_descriptor_get_disable_unicast_traffic(self.handle) != 0
116 }
117 }
118
119 pub fn set_disable_unicast_traffic(&mut self, disable_unicast_traffic: bool) -> &mut Self {
121 unsafe {
122 ffi::nw_shim_multicast_group_descriptor_set_disable_unicast_traffic(
123 self.handle,
124 c_int::from(disable_unicast_traffic),
125 )
126 };
127 self
128 }
129
130 #[must_use]
131 pub(crate) const fn as_ptr(&self) -> *mut c_void {
132 self.handle
133 }
134}
135
136impl Clone for ConnectionGroupDescriptor {
137 fn clone(&self) -> Self {
138 let handle = unsafe { ffi::nw_shim_retain_object(self.handle) };
139 Self { handle }
140 }
141}
142
143impl Drop for ConnectionGroupDescriptor {
144 fn drop(&mut self) {
145 if !self.handle.is_null() {
146 unsafe { ffi::nw_shim_release_object(self.handle) };
147 self.handle = core::ptr::null_mut();
148 }
149 }
150}
151
/// Lifecycle state of a connection group as reported by the native shim.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionGroupState {
    /// Any raw code that is not one of the recognised values.
    Invalid,
    /// Raw code 1.
    Waiting,
    /// Raw code 2.
    Ready,
    /// Raw code 3.
    Failed,
    /// Raw code 4.
    Cancelled,
}

impl ConnectionGroupState {
    /// Decodes the integer state code delivered by the native shim.
    ///
    /// Codes 1 through 4 map to the named states; every other value yields
    /// [`ConnectionGroupState::Invalid`].
    const fn from_raw(raw: i32) -> Self {
        const WAITING: i32 = 1;
        const READY: i32 = 2;
        const FAILED: i32 = 3;
        const CANCELLED: i32 = 4;

        match raw {
            WAITING => Self::Waiting,
            READY => Self::Ready,
            FAILED => Self::Failed,
            CANCELLED => Self::Cancelled,
            _ => Self::Invalid,
        }
    }
}
173
174#[derive(Debug, Clone)]
176pub struct ConnectionGroupMessage {
177 pub data: Vec<u8>,
178 pub context: Option<ContentContext>,
179 pub is_complete: bool,
180}
181
182type StateCallback = Mutex<Box<dyn FnMut(ConnectionGroupState) + Send + 'static>>;
183type ReceiveCallback = Mutex<Box<dyn FnMut(ConnectionGroupMessage) + Send + 'static>>;
184
185struct NewConnectionCallback {
186 keepalives: KeepAlives,
187 callback: Mutex<Box<dyn FnMut(TcpClient) + Send + 'static>>,
188}
189
190#[allow(clippy::type_complexity)]
192pub struct ConnectionGroup {
193 handle: *mut c_void,
194 state_callback: Option<Arc<StateCallback>>,
195 receive_callback: Option<Arc<ReceiveCallback>>,
196 new_connection_callback: Option<Arc<NewConnectionCallback>>,
197 keepalives: KeepAlives,
198}
199
200unsafe impl Send for ConnectionGroup {}
201unsafe impl Sync for ConnectionGroup {}
202
203impl std::fmt::Debug for ConnectionGroup {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 f.debug_struct("ConnectionGroup")
206 .field("handle", &self.handle)
207 .field("has_state_callback", &self.state_callback.is_some())
208 .field("has_receive_callback", &self.receive_callback.is_some())
209 .field("has_new_connection_callback", &self.new_connection_callback.is_some())
210 .finish_non_exhaustive()
211 }
212}
213
214impl ConnectionGroup {
215 pub fn new(
217 descriptor: &ConnectionGroupDescriptor,
218 parameters: &crate::ConnectionParameters,
219 ) -> Result<Self, NetworkError> {
220 let handle = unsafe {
221 ffi::nw_shim_connection_group_create(descriptor.as_ptr(), parameters.as_ptr())
222 };
223 if handle.is_null() {
224 return Err(NetworkError::InvalidArgument(
225 "failed to create connection group".into(),
226 ));
227 }
228 Ok(Self {
229 handle,
230 state_callback: None,
231 receive_callback: None,
232 new_connection_callback: None,
233 keepalives: parameters.keepalives(),
234 })
235 }
236
237 pub fn set_state_changed_handler<F>(&mut self, callback: F)
239 where
240 F: FnMut(ConnectionGroupState) + Send + 'static,
241 {
242 let callback: Box<dyn FnMut(ConnectionGroupState) + Send + 'static> = Box::new(callback);
243 let arc = Arc::new(Mutex::new(callback));
244 let raw = Arc::into_raw(arc.clone()).cast::<c_void>().cast_mut();
245 unsafe {
246 ffi::nw_shim_connection_group_set_state_changed_handler(
247 self.handle,
248 Some(state_trampoline),
249 raw,
250 )
251 };
252 self.state_callback = Some(arc);
253 }
254
255 pub fn set_receive_handler<F>(
257 &mut self,
258 maximum_message_size: u32,
259 reject_oversized_messages: bool,
260 callback: F,
261 ) where
262 F: FnMut(ConnectionGroupMessage) + Send + 'static,
263 {
264 let callback: Box<dyn FnMut(ConnectionGroupMessage) + Send + 'static> = Box::new(callback);
265 let arc = Arc::new(Mutex::new(callback));
266 let raw = Arc::into_raw(arc.clone()).cast::<c_void>().cast_mut();
267 unsafe {
268 ffi::nw_shim_connection_group_set_receive_handler(
269 self.handle,
270 maximum_message_size,
271 c_int::from(reject_oversized_messages),
272 Some(receive_trampoline),
273 raw,
274 )
275 };
276 self.receive_callback = Some(arc);
277 }
278
279 #[must_use]
284 pub(crate) const unsafe fn from_raw(handle: *mut c_void, keepalives: KeepAlives) -> Self {
285 Self {
286 handle,
287 state_callback: None,
288 receive_callback: None,
289 new_connection_callback: None,
290 keepalives,
291 }
292 }
293
294 pub fn start(&self) -> Result<(), NetworkError> {
296 let status = unsafe { ffi::nw_shim_connection_group_start(self.handle) };
297 if status != ffi::NW_OK {
298 return Err(from_status(status));
299 }
300 Ok(())
301 }
302
303 pub fn send(&self, data: &[u8], context: &ContentContext) -> Result<(), NetworkError> {
305 let status = unsafe {
306 ffi::nw_shim_connection_group_send(
307 self.handle,
308 data.as_ptr(),
309 data.len(),
310 core::ptr::null(),
311 0,
312 context.as_ptr(),
313 )
314 };
315 if status != ffi::NW_OK {
316 return Err(from_status(status));
317 }
318 Ok(())
319 }
320
321 pub fn send_to(
323 &self,
324 host: &str,
325 port: u16,
326 data: &[u8],
327 context: &ContentContext,
328 ) -> Result<(), NetworkError> {
329 let host = to_cstring(host, "host")?;
330 let status = unsafe {
331 ffi::nw_shim_connection_group_send(
332 self.handle,
333 data.as_ptr(),
334 data.len(),
335 host.as_ptr(),
336 port,
337 context.as_ptr(),
338 )
339 };
340 if status != ffi::NW_OK {
341 return Err(from_status(status));
342 }
343 Ok(())
344 }
345
346 #[must_use]
348 pub fn descriptor(&self) -> Option<ConnectionGroupDescriptor> {
349 let handle = unsafe { ffi::nw_shim_connection_group_copy_descriptor(self.handle) };
350 (!handle.is_null()).then_some(ConnectionGroupDescriptor { handle })
351 }
352
353 #[must_use]
355 pub fn parameters(&self) -> Option<ConnectionParameters> {
356 let handle = unsafe { ffi::nw_shim_connection_group_copy_parameters(self.handle) };
357 (!handle.is_null()).then_some(unsafe { ConnectionParameters::from_raw(handle) })
358 }
359
360 #[must_use]
362 pub fn remote_endpoint_for_message(&self, context: &ContentContext) -> Option<Endpoint> {
363 let handle = unsafe {
364 ffi::nw_shim_connection_group_copy_remote_endpoint_for_message(
365 self.handle,
366 context.as_ptr(),
367 )
368 };
369 (!handle.is_null()).then_some(unsafe { Endpoint::from_raw(handle) })
370 }
371
372 #[must_use]
374 pub fn local_endpoint_for_message(&self, context: &ContentContext) -> Option<Endpoint> {
375 let handle = unsafe {
376 ffi::nw_shim_connection_group_copy_local_endpoint_for_message(
377 self.handle,
378 context.as_ptr(),
379 )
380 };
381 (!handle.is_null()).then_some(unsafe { Endpoint::from_raw(handle) })
382 }
383
384 #[must_use]
386 pub fn path_for_message(&self, context: &ContentContext) -> Option<Path> {
387 let handle = unsafe {
388 ffi::nw_shim_connection_group_copy_path_for_message(self.handle, context.as_ptr())
389 };
390 (!handle.is_null()).then_some(unsafe { Path::from_raw(handle) })
391 }
392
393 #[must_use]
395 pub fn protocol_metadata(&self, definition: &ProtocolDefinition) -> Option<ProtocolMetadata> {
396 let handle = unsafe {
397 ffi::nw_shim_connection_group_copy_protocol_metadata(self.handle, definition.as_ptr())
398 };
399 (!handle.is_null()).then_some(unsafe { ProtocolMetadata::from_raw(handle) })
400 }
401
402 #[must_use]
404 pub fn protocol_metadata_for_message(
405 &self,
406 context: &ContentContext,
407 definition: &ProtocolDefinition,
408 ) -> Option<ProtocolMetadata> {
409 let handle = unsafe {
410 ffi::nw_shim_connection_group_copy_protocol_metadata_for_message(
411 self.handle,
412 context.as_ptr(),
413 definition.as_ptr(),
414 )
415 };
416 (!handle.is_null()).then_some(unsafe { ProtocolMetadata::from_raw(handle) })
417 }
418
419 pub fn extract_connection_for_message(
421 &self,
422 context: &ContentContext,
423 ) -> Result<TcpClient, NetworkError> {
424 let mut status = ffi::NW_OK;
425 let handle = unsafe {
426 ffi::nw_shim_connection_group_extract_connection_for_message(
427 self.handle,
428 context.as_ptr(),
429 &mut status,
430 )
431 };
432 if status != ffi::NW_OK || handle.is_null() {
433 return Err(from_status(status));
434 }
435 Ok(unsafe { TcpClient::from_raw_with_keepalives(handle, self.keepalives.clone()) })
436 }
437
438 pub fn extract_connection(
440 &self,
441 endpoint: &Endpoint,
442 protocol_options: &ProtocolOptions,
443 ) -> Result<TcpClient, NetworkError> {
444 let mut status = ffi::NW_OK;
445 let handle = unsafe {
446 ffi::nw_shim_connection_group_extract_connection(
447 self.handle,
448 endpoint.as_ptr(),
449 protocol_options.as_ptr(),
450 &mut status,
451 )
452 };
453 if status != ffi::NW_OK || handle.is_null() {
454 return Err(from_status(status));
455 }
456 Ok(unsafe { TcpClient::from_raw_with_keepalives(handle, self.keepalives.clone()) })
457 }
458
459 pub fn set_new_connection_handler<F>(&mut self, callback: F)
461 where
462 F: FnMut(TcpClient) + Send + 'static,
463 {
464 let handler = Arc::new(NewConnectionCallback {
465 keepalives: self.keepalives.clone(),
466 callback: Mutex::new(Box::new(callback)),
467 });
468 let raw = Arc::into_raw(handler.clone()).cast::<c_void>().cast_mut();
469 unsafe {
470 ffi::nw_shim_connection_group_set_new_connection_handler(
471 self.handle,
472 Some(new_connection_trampoline),
473 raw,
474 );
475 };
476 self.new_connection_callback = Some(handler);
477 }
478
479 pub fn reinsert_extracted_connection(&self, connection: TcpClient) -> Result<(), NetworkError> {
481 let status = unsafe {
482 ffi::nw_shim_connection_group_reinsert_extracted_connection(
483 self.handle,
484 connection.as_ptr(),
485 )
486 };
487 if status != ffi::NW_OK {
488 return Err(from_status(status));
489 }
490 std::mem::forget(connection);
491 Ok(())
492 }
493
494 pub fn reply(
496 &self,
497 inbound_message: &ContentContext,
498 outbound_message: Option<&ContentContext>,
499 data: &[u8],
500 ) -> Result<(), NetworkError> {
501 let status = unsafe {
502 ffi::nw_shim_connection_group_reply(
503 self.handle,
504 inbound_message.as_ptr(),
505 outbound_message.map_or(core::ptr::null_mut(), ContentContext::as_ptr),
506 data.as_ptr(),
507 data.len(),
508 )
509 };
510 if status != ffi::NW_OK {
511 return Err(from_status(status));
512 }
513 Ok(())
514 }
515
516 pub fn cancel(&self) {
518 unsafe { ffi::nw_shim_connection_group_cancel(self.handle) };
519 }
520}
521
522impl Drop for ConnectionGroup {
523 fn drop(&mut self) {
524 if !self.handle.is_null() {
525 unsafe { ffi::nw_shim_connection_group_release(self.handle) };
526 self.handle = core::ptr::null_mut();
527 }
528 }
529}
530
531unsafe extern "C" fn state_trampoline(state: c_int, user_info: *mut c_void) {
532 if user_info.is_null() {
533 return;
534 }
535 let callback = unsafe { &*user_info.cast::<StateCallback>() };
536 let Ok(mut guard) = callback.lock() else {
537 return;
538 };
539 let state = ConnectionGroupState::from_raw(state);
540 catch_user_panic("connection_group_state_trampoline", || {
541 guard(state);
542 });
543}
544
545unsafe extern "C" fn new_connection_trampoline(connection: *mut c_void, user_info: *mut c_void) {
546 if user_info.is_null() || connection.is_null() {
547 return;
548 }
549 let callback = unsafe { &*user_info.cast::<NewConnectionCallback>() };
550 let Ok(mut guard) = callback.callback.lock() else {
551 return;
552 };
553 let client =
554 unsafe { TcpClient::from_raw_with_keepalives(connection, callback.keepalives.clone()) };
555 catch_user_panic("connection_group_new_connection_trampoline", || {
556 guard(client);
557 });
558}
559
560unsafe extern "C" fn receive_trampoline(
561 data: *const u8,
562 len: usize,
563 context: *mut c_void,
564 is_complete: c_int,
565 user_info: *mut c_void,
566) {
567 if user_info.is_null() {
568 return;
569 }
570 let callback = unsafe { &*user_info.cast::<ReceiveCallback>() };
571 let Ok(mut guard) = callback.lock() else {
572 return;
573 };
574 let bytes = if data.is_null() || len == 0 {
575 Vec::new()
576 } else {
577 unsafe { std::slice::from_raw_parts(data, len) }.to_vec()
578 };
579 let context = if context.is_null() {
580 None
581 } else {
582 Some(unsafe { ContentContext::from_raw(context) })
583 };
584 let message = ConnectionGroupMessage {
585 data: bytes,
586 context,
587 is_complete: is_complete != 0,
588 };
589 catch_user_panic("connection_group_receive_trampoline", || {
590 guard(message);
591 });
592}