Skip to main content

networkframework/connection_group/
mod.rs

1//! Connection groups built on `nw_connection_group_*`.
2
3#![allow(clippy::missing_errors_doc, clippy::semicolon_if_nothing_returned)]
4
5use core::ffi::{c_int, c_void};
6use std::ffi::CString;
7use std::sync::{Arc, Mutex};
8
9use crate::client::{ContentContext, TcpClient};
10use crate::endpoint::Endpoint;
11use crate::error::{from_status, NetworkError};
12use crate::ffi;
13use crate::parameters::{ConnectionParameters, KeepAlives};
14use crate::path::Path;
15use crate::protocol::{ProtocolDefinition, ProtocolMetadata, ProtocolOptions};
16use doom_fish_utils::panic_safe::catch_user_panic;
17
18fn to_cstring(value: &str, field: &str) -> Result<CString, NetworkError> {
19    CString::new(value).map_err(|e| NetworkError::InvalidArgument(format!("{field} NUL byte: {e}")))
20}
21
/// A group descriptor for multicast or multiplex connection groups.
///
/// Wraps a single descriptor handle from the shim layer. Cloning takes an extra
/// retain on the same handle, and dropping releases it.
pub struct ConnectionGroupDescriptor {
    /// Descriptor handle owned by this wrapper (the constructors reject null).
    handle: *mut c_void,
}

// SAFETY: the wrapper only holds a retained object handle and every mutation goes
// through `&mut self`. NOTE(review): this assumes the underlying Network.framework
// descriptor may be retained, released and queried from any thread — confirm.
unsafe impl Send for ConnectionGroupDescriptor {}
// SAFETY: see the `Send` justification directly above; shared access only reads.
unsafe impl Sync for ConnectionGroupDescriptor {}

impl std::fmt::Debug for ConnectionGroupDescriptor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { handle } = self;
        let mut out = f.debug_struct("ConnectionGroupDescriptor");
        out.field("handle", handle);
        out.finish()
    }
}
37
38impl ConnectionGroupDescriptor {
39    /// Create a multiplex group descriptor for a remote endpoint.
40    pub fn multiplex(host: &str, port: u16) -> Result<Self, NetworkError> {
41        let host = to_cstring(host, "host")?;
42        let handle = unsafe { ffi::nw_shim_group_descriptor_create_multiplex(host.as_ptr(), port) };
43        if handle.is_null() {
44            return Err(NetworkError::InvalidArgument(
45                "failed to create multiplex group descriptor".into(),
46            ));
47        }
48        Ok(Self { handle })
49    }
50
51    /// Create a multicast group descriptor from an IP multicast address.
52    pub fn multicast(group_address: &str, port: u16) -> Result<Self, NetworkError> {
53        let group_address = to_cstring(group_address, "group_address")?;
54        let handle =
55            unsafe { ffi::nw_shim_group_descriptor_create_multicast(group_address.as_ptr(), port) };
56        if handle.is_null() {
57            return Err(NetworkError::InvalidArgument(
58                "failed to create multicast group descriptor".into(),
59            ));
60        }
61        Ok(Self { handle })
62    }
63
64    /// Add another endpoint to the descriptor.
65    pub fn add_endpoint(&mut self, host: &str, port: u16) -> Result<&mut Self, NetworkError> {
66        let host = to_cstring(host, "host")?;
67        let added =
68            unsafe { ffi::nw_shim_group_descriptor_add_endpoint(self.handle, host.as_ptr(), port) };
69        if added == 0 {
70            return Err(NetworkError::InvalidArgument(
71                "failed to add endpoint to group descriptor".into(),
72            ));
73        }
74        Ok(self)
75    }
76
77    /// Enumerate the endpoints described by this connection group.
78    #[must_use]
79    pub fn endpoints(&self) -> Vec<Endpoint> {
80        unsafe extern "C" fn collect(endpoint: *mut c_void, user_info: *mut c_void) -> c_int {
81            if user_info.is_null() || endpoint.is_null() {
82                return 0;
83            }
84            let endpoints = unsafe { &mut *user_info.cast::<Vec<Endpoint>>() };
85            endpoints.push(unsafe { Endpoint::from_raw(endpoint) });
86            1
87        }
88
89        let mut endpoints = Vec::new();
90        unsafe {
91            ffi::nw_shim_group_descriptor_enumerate_endpoints(
92                self.handle,
93                Some(collect),
94                std::ptr::addr_of_mut!(endpoints).cast(),
95            )
96        };
97        endpoints
98    }
99
100    /// Restrict multicast traffic to a specific source endpoint.
101    pub fn set_specific_source(&mut self, endpoint: &Endpoint) -> &mut Self {
102        unsafe {
103            ffi::nw_shim_multicast_group_descriptor_set_specific_source(
104                self.handle,
105                endpoint.as_ptr(),
106            )
107        };
108        self
109    }
110
111    /// Whether unicast traffic is disabled for multicast descriptors.
112    #[must_use]
113    pub fn disable_unicast_traffic(&self) -> bool {
114        unsafe {
115            ffi::nw_shim_multicast_group_descriptor_get_disable_unicast_traffic(self.handle) != 0
116        }
117    }
118
119    /// Enable or disable unicast traffic for multicast descriptors.
120    pub fn set_disable_unicast_traffic(&mut self, disable_unicast_traffic: bool) -> &mut Self {
121        unsafe {
122            ffi::nw_shim_multicast_group_descriptor_set_disable_unicast_traffic(
123                self.handle,
124                c_int::from(disable_unicast_traffic),
125            )
126        };
127        self
128    }
129
130    #[must_use]
131    pub(crate) const fn as_ptr(&self) -> *mut c_void {
132        self.handle
133    }
134}
135
136impl Clone for ConnectionGroupDescriptor {
137    fn clone(&self) -> Self {
138        let handle = unsafe { ffi::nw_shim_retain_object(self.handle) };
139        Self { handle }
140    }
141}
142
143impl Drop for ConnectionGroupDescriptor {
144    fn drop(&mut self) {
145        if !self.handle.is_null() {
146            unsafe { ffi::nw_shim_release_object(self.handle) };
147            self.handle = core::ptr::null_mut();
148        }
149    }
150}
151
/// Connection group lifecycle states.
///
/// Decoded from the raw integer the shim passes to the state-change callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionGroupState {
    /// Raw value 0, and also any raw value this binding does not recognise.
    Invalid,
    /// Raw value 1.
    Waiting,
    /// Raw value 2.
    Ready,
    /// Raw value 3.
    Failed,
    /// Raw value 4.
    Cancelled,
}

impl ConnectionGroupState {
    /// Map a raw state value onto the enum; unknown values become [`Self::Invalid`].
    const fn from_raw(raw: i32) -> Self {
        match raw {
            4 => Self::Cancelled,
            3 => Self::Failed,
            2 => Self::Ready,
            1 => Self::Waiting,
            // 0 and every unrecognised value.
            _ => Self::Invalid,
        }
    }
}
173
/// An inbound connection-group message.
///
/// Built by the receive trampoline, which copies the payload out of the shim's
/// buffer. The struct therefore owns its bytes.
#[derive(Debug, Clone)]
pub struct ConnectionGroupMessage {
    /// Message payload; empty when the shim delivered a null buffer or zero length.
    pub data: Vec<u8>,
    /// Content context delivered with the message, or `None` if the shim passed none.
    pub context: Option<ContentContext>,
    /// Whether the shim flagged this message as complete (non-zero raw flag).
    pub is_complete: bool,
}
181
182type StateCallback = Mutex<Box<dyn FnMut(ConnectionGroupState) + Send + 'static>>;
183type ReceiveCallback = Mutex<Box<dyn FnMut(ConnectionGroupMessage) + Send + 'static>>;
184
/// Rust-side state behind the `user_info` pointer of the new-connection callback.
struct NewConnectionCallback {
    /// Keep-alive settings copied from the group and applied to each accepted client.
    keepalives: KeepAlives,
    /// User closure invoked with each new connection; locked for every invocation.
    callback: Mutex<Box<dyn FnMut(TcpClient) + Send + 'static>>,
}
189
190/// A running connection group.
191#[allow(clippy::type_complexity)]
192pub struct ConnectionGroup {
193    handle: *mut c_void,
194    state_callback: Option<Arc<StateCallback>>,
195    receive_callback: Option<Arc<ReceiveCallback>>,
196    new_connection_callback: Option<Arc<NewConnectionCallback>>,
197    keepalives: KeepAlives,
198}
199
200unsafe impl Send for ConnectionGroup {}
201unsafe impl Sync for ConnectionGroup {}
202
203impl std::fmt::Debug for ConnectionGroup {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        f.debug_struct("ConnectionGroup")
206            .field("handle", &self.handle)
207            .field("has_state_callback", &self.state_callback.is_some())
208            .field("has_receive_callback", &self.receive_callback.is_some())
209            .field("has_new_connection_callback", &self.new_connection_callback.is_some())
210            .finish_non_exhaustive()
211    }
212}
213
214impl ConnectionGroup {
215    /// Create a connection group from a descriptor and parameters.
216    pub fn new(
217        descriptor: &ConnectionGroupDescriptor,
218        parameters: &crate::ConnectionParameters,
219    ) -> Result<Self, NetworkError> {
220        let handle = unsafe {
221            ffi::nw_shim_connection_group_create(descriptor.as_ptr(), parameters.as_ptr())
222        };
223        if handle.is_null() {
224            return Err(NetworkError::InvalidArgument(
225                "failed to create connection group".into(),
226            ));
227        }
228        Ok(Self {
229            handle,
230            state_callback: None,
231            receive_callback: None,
232            new_connection_callback: None,
233            keepalives: parameters.keepalives(),
234        })
235    }
236
237    /// Set a state-change callback. Call before [`start`](Self::start).
238    pub fn set_state_changed_handler<F>(&mut self, callback: F)
239    where
240        F: FnMut(ConnectionGroupState) + Send + 'static,
241    {
242        let callback: Box<dyn FnMut(ConnectionGroupState) + Send + 'static> = Box::new(callback);
243        let arc = Arc::new(Mutex::new(callback));
244        let raw = Arc::into_raw(arc.clone()).cast::<c_void>().cast_mut();
245        unsafe {
246            ffi::nw_shim_connection_group_set_state_changed_handler(
247                self.handle,
248                Some(state_trampoline),
249                raw,
250            )
251        };
252        self.state_callback = Some(arc);
253    }
254
255    /// Set the receive callback. Call before [`start`](Self::start).
256    pub fn set_receive_handler<F>(
257        &mut self,
258        maximum_message_size: u32,
259        reject_oversized_messages: bool,
260        callback: F,
261    ) where
262        F: FnMut(ConnectionGroupMessage) + Send + 'static,
263    {
264        let callback: Box<dyn FnMut(ConnectionGroupMessage) + Send + 'static> = Box::new(callback);
265        let arc = Arc::new(Mutex::new(callback));
266        let raw = Arc::into_raw(arc.clone()).cast::<c_void>().cast_mut();
267        unsafe {
268            ffi::nw_shim_connection_group_set_receive_handler(
269                self.handle,
270                maximum_message_size,
271                c_int::from(reject_oversized_messages),
272                Some(receive_trampoline),
273                raw,
274            )
275        };
276        self.receive_callback = Some(arc);
277    }
278
279    /// # Safety
280    ///
281    /// `handle` must be a valid retained connection-group handle owned by the
282    /// caller and remain alive for the returned wrapper.
283    #[must_use]
284    pub(crate) const unsafe fn from_raw(handle: *mut c_void, keepalives: KeepAlives) -> Self {
285        Self {
286            handle,
287            state_callback: None,
288            receive_callback: None,
289            new_connection_callback: None,
290            keepalives,
291        }
292    }
293
294    /// Start the connection group and wait for the initial state update.
295    pub fn start(&self) -> Result<(), NetworkError> {
296        let status = unsafe { ffi::nw_shim_connection_group_start(self.handle) };
297        if status != ffi::NW_OK {
298            return Err(from_status(status));
299        }
300        Ok(())
301    }
302
303    /// Send a message using the group's default destination semantics.
304    pub fn send(&self, data: &[u8], context: &ContentContext) -> Result<(), NetworkError> {
305        let status = unsafe {
306            ffi::nw_shim_connection_group_send(
307                self.handle,
308                data.as_ptr(),
309                data.len(),
310                core::ptr::null(),
311                0,
312                context.as_ptr(),
313            )
314        };
315        if status != ffi::NW_OK {
316            return Err(from_status(status));
317        }
318        Ok(())
319    }
320
321    /// Send a message to a specific endpoint.
322    pub fn send_to(
323        &self,
324        host: &str,
325        port: u16,
326        data: &[u8],
327        context: &ContentContext,
328    ) -> Result<(), NetworkError> {
329        let host = to_cstring(host, "host")?;
330        let status = unsafe {
331            ffi::nw_shim_connection_group_send(
332                self.handle,
333                data.as_ptr(),
334                data.len(),
335                host.as_ptr(),
336                port,
337                context.as_ptr(),
338            )
339        };
340        if status != ffi::NW_OK {
341            return Err(from_status(status));
342        }
343        Ok(())
344    }
345
346    /// Copy the underlying group descriptor.
347    #[must_use]
348    pub fn descriptor(&self) -> Option<ConnectionGroupDescriptor> {
349        let handle = unsafe { ffi::nw_shim_connection_group_copy_descriptor(self.handle) };
350        (!handle.is_null()).then_some(ConnectionGroupDescriptor { handle })
351    }
352
353    /// Copy the group's parameters snapshot.
354    #[must_use]
355    pub fn parameters(&self) -> Option<ConnectionParameters> {
356        let handle = unsafe { ffi::nw_shim_connection_group_copy_parameters(self.handle) };
357        (!handle.is_null()).then_some(unsafe { ConnectionParameters::from_raw(handle) })
358    }
359
360    /// Copy the remote endpoint associated with a received message.
361    #[must_use]
362    pub fn remote_endpoint_for_message(&self, context: &ContentContext) -> Option<Endpoint> {
363        let handle = unsafe {
364            ffi::nw_shim_connection_group_copy_remote_endpoint_for_message(
365                self.handle,
366                context.as_ptr(),
367            )
368        };
369        (!handle.is_null()).then_some(unsafe { Endpoint::from_raw(handle) })
370    }
371
372    /// Copy the local endpoint associated with a received message.
373    #[must_use]
374    pub fn local_endpoint_for_message(&self, context: &ContentContext) -> Option<Endpoint> {
375        let handle = unsafe {
376            ffi::nw_shim_connection_group_copy_local_endpoint_for_message(
377                self.handle,
378                context.as_ptr(),
379            )
380        };
381        (!handle.is_null()).then_some(unsafe { Endpoint::from_raw(handle) })
382    }
383
384    /// Copy the path associated with a received message.
385    #[must_use]
386    pub fn path_for_message(&self, context: &ContentContext) -> Option<Path> {
387        let handle = unsafe {
388            ffi::nw_shim_connection_group_copy_path_for_message(self.handle, context.as_ptr())
389        };
390        (!handle.is_null()).then_some(unsafe { Path::from_raw(handle) })
391    }
392
393    /// Copy group-wide protocol metadata for a specific protocol definition.
394    #[must_use]
395    pub fn protocol_metadata(&self, definition: &ProtocolDefinition) -> Option<ProtocolMetadata> {
396        let handle = unsafe {
397            ffi::nw_shim_connection_group_copy_protocol_metadata(self.handle, definition.as_ptr())
398        };
399        (!handle.is_null()).then_some(unsafe { ProtocolMetadata::from_raw(handle) })
400    }
401
402    /// Copy per-message protocol metadata for a specific protocol definition.
403    #[must_use]
404    pub fn protocol_metadata_for_message(
405        &self,
406        context: &ContentContext,
407        definition: &ProtocolDefinition,
408    ) -> Option<ProtocolMetadata> {
409        let handle = unsafe {
410            ffi::nw_shim_connection_group_copy_protocol_metadata_for_message(
411                self.handle,
412                context.as_ptr(),
413                definition.as_ptr(),
414            )
415        };
416        (!handle.is_null()).then_some(unsafe { ProtocolMetadata::from_raw(handle) })
417    }
418
419    /// Extract a new connection corresponding to a received message.
420    pub fn extract_connection_for_message(
421        &self,
422        context: &ContentContext,
423    ) -> Result<TcpClient, NetworkError> {
424        let mut status = ffi::NW_OK;
425        let handle = unsafe {
426            ffi::nw_shim_connection_group_extract_connection_for_message(
427                self.handle,
428                context.as_ptr(),
429                &mut status,
430            )
431        };
432        if status != ffi::NW_OK || handle.is_null() {
433            return Err(from_status(status));
434        }
435        Ok(unsafe { TcpClient::from_raw_with_keepalives(handle, self.keepalives.clone()) })
436    }
437
438    /// Extract a connection for a specific remote endpoint and protocol options.
439    pub fn extract_connection(
440        &self,
441        endpoint: &Endpoint,
442        protocol_options: &ProtocolOptions,
443    ) -> Result<TcpClient, NetworkError> {
444        let mut status = ffi::NW_OK;
445        let handle = unsafe {
446            ffi::nw_shim_connection_group_extract_connection(
447                self.handle,
448                endpoint.as_ptr(),
449                protocol_options.as_ptr(),
450                &mut status,
451            )
452        };
453        if status != ffi::NW_OK || handle.is_null() {
454            return Err(from_status(status));
455        }
456        Ok(unsafe { TcpClient::from_raw_with_keepalives(handle, self.keepalives.clone()) })
457    }
458
459    /// Receive callbacks for new connections accepted by the group.
460    pub fn set_new_connection_handler<F>(&mut self, callback: F)
461    where
462        F: FnMut(TcpClient) + Send + 'static,
463    {
464        let handler = Arc::new(NewConnectionCallback {
465            keepalives: self.keepalives.clone(),
466            callback: Mutex::new(Box::new(callback)),
467        });
468        let raw = Arc::into_raw(handler.clone()).cast::<c_void>().cast_mut();
469        unsafe {
470            ffi::nw_shim_connection_group_set_new_connection_handler(
471                self.handle,
472                Some(new_connection_trampoline),
473                raw,
474            );
475        };
476        self.new_connection_callback = Some(handler);
477    }
478
479    /// Reinsert an extracted connection back into the group.
480    pub fn reinsert_extracted_connection(&self, connection: TcpClient) -> Result<(), NetworkError> {
481        let status = unsafe {
482            ffi::nw_shim_connection_group_reinsert_extracted_connection(
483                self.handle,
484                connection.as_ptr(),
485            )
486        };
487        if status != ffi::NW_OK {
488            return Err(from_status(status));
489        }
490        std::mem::forget(connection);
491        Ok(())
492    }
493
494    /// Reply to an inbound message using the group's reply path.
495    pub fn reply(
496        &self,
497        inbound_message: &ContentContext,
498        outbound_message: Option<&ContentContext>,
499        data: &[u8],
500    ) -> Result<(), NetworkError> {
501        let status = unsafe {
502            ffi::nw_shim_connection_group_reply(
503                self.handle,
504                inbound_message.as_ptr(),
505                outbound_message.map_or(core::ptr::null_mut(), ContentContext::as_ptr),
506                data.as_ptr(),
507                data.len(),
508            )
509        };
510        if status != ffi::NW_OK {
511            return Err(from_status(status));
512        }
513        Ok(())
514    }
515
516    /// Cancel the connection group.
517    pub fn cancel(&self) {
518        unsafe { ffi::nw_shim_connection_group_cancel(self.handle) };
519    }
520}
521
522impl Drop for ConnectionGroup {
523    fn drop(&mut self) {
524        if !self.handle.is_null() {
525            unsafe { ffi::nw_shim_connection_group_release(self.handle) };
526            self.handle = core::ptr::null_mut();
527        }
528    }
529}
530
531unsafe extern "C" fn state_trampoline(state: c_int, user_info: *mut c_void) {
532    if user_info.is_null() {
533        return;
534    }
535    let callback = unsafe { &*user_info.cast::<StateCallback>() };
536    let Ok(mut guard) = callback.lock() else {
537        return;
538    };
539    let state = ConnectionGroupState::from_raw(state);
540    catch_user_panic("connection_group_state_trampoline", || {
541        guard(state);
542    });
543}
544
545unsafe extern "C" fn new_connection_trampoline(connection: *mut c_void, user_info: *mut c_void) {
546    if user_info.is_null() || connection.is_null() {
547        return;
548    }
549    let callback = unsafe { &*user_info.cast::<NewConnectionCallback>() };
550    let Ok(mut guard) = callback.callback.lock() else {
551        return;
552    };
553    let client =
554        unsafe { TcpClient::from_raw_with_keepalives(connection, callback.keepalives.clone()) };
555    catch_user_panic("connection_group_new_connection_trampoline", || {
556        guard(client);
557    });
558}
559
560unsafe extern "C" fn receive_trampoline(
561    data: *const u8,
562    len: usize,
563    context: *mut c_void,
564    is_complete: c_int,
565    user_info: *mut c_void,
566) {
567    if user_info.is_null() {
568        return;
569    }
570    let callback = unsafe { &*user_info.cast::<ReceiveCallback>() };
571    let Ok(mut guard) = callback.lock() else {
572        return;
573    };
574    let bytes = if data.is_null() || len == 0 {
575        Vec::new()
576    } else {
577        unsafe { std::slice::from_raw_parts(data, len) }.to_vec()
578    };
579    let context = if context.is_null() {
580        None
581    } else {
582        Some(unsafe { ContentContext::from_raw(context) })
583    };
584    let message = ConnectionGroupMessage {
585        data: bytes,
586        context,
587        is_complete: is_complete != 0,
588    };
589    catch_user_panic("connection_group_receive_trampoline", || {
590        guard(message);
591    });
592}