Skip to main content

systemconfiguration/
network_reachability.rs

1use std::{
2    ffi::c_void,
3    net::SocketAddr,
4    panic::AssertUnwindSafe,
5    sync::{Arc, Mutex},
6};
7
8use crate::{bridge, error::Result, ffi, SystemConfigurationError};
9
/// Wraps the `SCNetworkReachabilityFlags` bitfield.
///
/// The raw value is public so callers can inspect bits without a dedicated accessor;
/// the associated constants name the bits this wrapper understands.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ReachabilityFlags(
    /// Wraps the raw `SCNetworkReachabilityFlags` bitfield.
    pub u32,
);

impl ReachabilityFlags {
    /// `kSCNetworkReachabilityFlagsTransientConnection` (bit 0).
    pub const TRANSIENT_CONNECTION: u32 = 1 << 0;
    /// `kSCNetworkReachabilityFlagsReachable` (bit 1).
    pub const REACHABLE: u32 = 1 << 1;
    /// `kSCNetworkReachabilityFlagsConnectionRequired` (bit 2).
    pub const CONNECTION_REQUIRED: u32 = 1 << 2;
    /// `kSCNetworkReachabilityFlagsConnectionOnTraffic` (bit 3).
    pub const CONNECTION_ON_TRAFFIC: u32 = 1 << 3;
    /// `kSCNetworkReachabilityFlagsInterventionRequired` (bit 4).
    pub const INTERVENTION_REQUIRED: u32 = 1 << 4;
    /// `kSCNetworkReachabilityFlagsConnectionOnDemand` (bit 5).
    pub const CONNECTION_ON_DEMAND: u32 = 1 << 5;
    /// `kSCNetworkReachabilityFlagsIsLocalAddress` (bit 16).
    pub const IS_LOCAL_ADDRESS: u32 = 1 << 16;
    /// `kSCNetworkReachabilityFlagsIsDirect` (bit 17).
    pub const IS_DIRECT: u32 = 1 << 17;
    /// `kSCNetworkReachabilityFlagsIsWWAN` (bit 18).
    pub const IS_WWAN: u32 = 1 << 18;

    /// Known bits paired with their `Display` labels, in the order they are printed.
    const LABELS: [(u32, &'static str); 9] = [
        (Self::TRANSIENT_CONNECTION, "transient"),
        (Self::REACHABLE, "reachable"),
        (Self::CONNECTION_REQUIRED, "needs-connection"),
        (Self::CONNECTION_ON_TRAFFIC, "on-traffic"),
        (Self::INTERVENTION_REQUIRED, "needs-intervention"),
        (Self::CONNECTION_ON_DEMAND, "on-demand"),
        (Self::IS_LOCAL_ADDRESS, "local-address"),
        (Self::IS_DIRECT, "direct"),
        (Self::IS_WWAN, "wwan"),
    ];

    /// Returns the raw `SCNetworkReachabilityFlags` value.
    pub fn bits(self) -> u32 {
        self.0
    }

    /// Returns `true` if every bit set in `mask` is also set in these flags.
    ///
    /// An empty `mask` (0) is trivially contained.
    pub fn contains(self, mask: u32) -> bool {
        (self.0 & mask) == mask
    }

    /// Returns `true` if `TRANSIENT_CONNECTION` (bit 0) is set.
    pub fn is_transient_connection(self) -> bool {
        self.contains(Self::TRANSIENT_CONNECTION)
    }

    /// Returns `true` if `REACHABLE` (bit 1) is set.
    pub fn is_reachable(self) -> bool {
        self.contains(Self::REACHABLE)
    }

    /// Returns `true` if `CONNECTION_REQUIRED` (bit 2) is set.
    pub fn needs_connection(self) -> bool {
        self.contains(Self::CONNECTION_REQUIRED)
    }

    /// Returns `true` if `CONNECTION_ON_TRAFFIC` (bit 3) is set.
    pub fn is_connection_on_traffic(self) -> bool {
        self.contains(Self::CONNECTION_ON_TRAFFIC)
    }

    /// Returns `true` if `INTERVENTION_REQUIRED` (bit 4) is set.
    pub fn needs_intervention(self) -> bool {
        self.contains(Self::INTERVENTION_REQUIRED)
    }

    /// Returns `true` if `CONNECTION_ON_DEMAND` (bit 5) is set.
    pub fn is_connection_on_demand(self) -> bool {
        self.contains(Self::CONNECTION_ON_DEMAND)
    }

    /// Returns `true` if `IS_LOCAL_ADDRESS` (bit 16) is set.
    pub fn is_local_address(self) -> bool {
        self.contains(Self::IS_LOCAL_ADDRESS)
    }

    /// Returns `true` if `IS_DIRECT` (bit 17) is set.
    pub fn is_direct(self) -> bool {
        self.contains(Self::IS_DIRECT)
    }

    /// Returns `true` if `IS_WWAN` (bit 18) is set.
    pub fn is_wwan(self) -> bool {
        self.contains(Self::IS_WWAN)
    }
}

impl std::fmt::Display for ReachabilityFlags {
    /// Formats known bits as `label|label (0x<hex>)`.
    ///
    /// When no known bit is set, only the hex value is printed (`0x<hex>`). Unknown
    /// bits never get a label but are always visible in the hex value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let labels: Vec<&str> = Self::LABELS
            .iter()
            .filter(|&&(mask, _)| self.contains(mask))
            .map(|&(_, label)| label)
            .collect();
        if labels.is_empty() {
            write!(f, "0x{:x}", self.bits())
        } else {
            write!(f, "{} (0x{:x})", labels.join("|"), self.bits())
        }
    }
}
106
/// State for a callback registered through `Reachability::set_callback`.
///
/// Its heap address is passed to SystemConfiguration as the callback `info` pointer
/// and read back by `reachability_callback_local`.
struct LocalCallbackState {
    // User closure; not required to be `Send`, which is why `Reachability` refuses to
    // pair it with a dispatch queue.
    callback: Box<dyn FnMut(ReachabilityFlags)>,
}

/// State for a callback registered through `Reachability::set_callback_send`.
///
/// Stored behind `Arc<Mutex<_>>`; the mutex's address is the `info` pointer, and
/// `reachability_callback_send` locks it before invoking the closure.
struct SendCallbackState {
    // User closure; `Send` so it may be used together with a dispatch queue.
    callback: Box<dyn FnMut(ReachabilityFlags) + Send>,
}

/// Owns the allocation whose address was handed to SystemConfiguration as `info`,
/// keeping it alive for as long as `Reachability` stores it.
enum RegisteredCallback {
    // Registered via `set_callback`; trampoline is `reachability_callback_local`.
    Local {
        _state: Box<LocalCallbackState>,
    },
    // Registered via `set_callback_send`; trampoline is `reachability_callback_send`.
    Send {
        _state: Arc<Mutex<SendCallbackState>>,
    },
}
123
124unsafe extern "C" fn reachability_callback_local(flags: u32, info: *mut c_void) {
125    if info.is_null() {
126        return;
127    }
128
129    let state = unsafe { &mut *info.cast::<LocalCallbackState>() };
130    // Catch panics: unwinding across the Swift/C FFI boundary is UB.
131    let _ = std::panic::catch_unwind(AssertUnwindSafe(|| {
132        (state.callback)(ReachabilityFlags(flags));
133    }));
134}
135
136unsafe extern "C" fn reachability_callback_send(flags: u32, info: *mut c_void) {
137    if info.is_null() {
138        return;
139    }
140
141    let mutex = unsafe { &*info.cast::<Mutex<SendCallbackState>>() };
142    if let Ok(mut state) = mutex.lock() {
143        // Catch panics: unwinding across the Swift/C FFI boundary is UB.
144        let _ = std::panic::catch_unwind(AssertUnwindSafe(|| {
145            (state.callback)(ReachabilityFlags(flags));
146        }));
147    }
148}
149
/// Wraps `SCNetworkReachabilityRef`.
///
/// Besides the owned handle, it keeps the registered callback's state alive and records
/// which scheduling mode is active so `Drop` can undo it. Rust drops fields in
/// declaration order, so `raw` is dropped before `callback`.
pub struct Reachability {
    // Owned reachability handle passed to every bridge call.
    raw: bridge::OwnedHandle,
    // Keeps the memory behind the callback's `info` pointer alive; `None` when no
    // callback is registered.
    callback: Option<RegisteredCallback>,
    // Set by `schedule_with_run_loop_current`, cleared by `unschedule_from_run_loop_current`.
    scheduled_with_current_run_loop: bool,
    // Set by `set_dispatch_queue_global`, cleared by `clear_dispatch_queue`.
    dispatch_queue_active: bool,
}

/// Alias for the `SCNetworkReachabilityRef` wrapper.
pub type NetworkReachability = Reachability;
160
161impl Reachability {
162    /// Wraps `SCReachabilityGetTypeID`.
163    pub fn type_id() -> u64 {
164        unsafe { ffi::network_reachability::sc_reachability_get_type_id() }
165    }
166
167    /// Wraps `SCReachabilityCreateWithName`.
168    pub fn with_name(name: &str) -> Result<Self> {
169        let name = bridge::cstring(name, "sc_reachability_create_with_name")?;
170        let raw =
171            unsafe { ffi::network_reachability::sc_reachability_create_with_name(name.as_ptr()) };
172        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_name", raw)?;
173        Ok(Self {
174            raw,
175            callback: None,
176            scheduled_with_current_run_loop: false,
177            dispatch_queue_active: false,
178        })
179    }
180
181    /// Wraps `SCReachabilityCreateWithAddress`.
182    pub fn with_address(address: SocketAddr) -> Result<Self> {
183        let storage = socket_addr_to_bytes(address);
184        let raw = unsafe {
185            ffi::network_reachability::sc_reachability_create_with_address(
186                storage.as_ptr(),
187                isize::try_from(storage.len()).expect("socket address length exceeded isize"),
188            )
189        };
190        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address", raw)?;
191        Ok(Self {
192            raw,
193            callback: None,
194            scheduled_with_current_run_loop: false,
195            dispatch_queue_active: false,
196        })
197    }
198
199    /// Wraps `SCReachabilityCreateWithAddressPair`.
200    pub fn with_address_pair(
201        local_address: Option<SocketAddr>,
202        remote_address: Option<SocketAddr>,
203    ) -> Result<Self> {
204        let local = local_address.map(socket_addr_to_bytes);
205        let remote = remote_address.map(socket_addr_to_bytes);
206        let raw = unsafe {
207            ffi::network_reachability::sc_reachability_create_with_address_pair(
208                local.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
209                local.as_ref().map_or(0, |value| {
210                    isize::try_from(value.len()).expect("socket address length exceeded isize")
211                }),
212                remote.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
213                remote.as_ref().map_or(0, |value| {
214                    isize::try_from(value.len()).expect("socket address length exceeded isize")
215                }),
216            )
217        };
218        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address_pair", raw)?;
219        Ok(Self {
220            raw,
221            callback: None,
222            scheduled_with_current_run_loop: false,
223            dispatch_queue_active: false,
224        })
225    }
226
227    /// Wraps `SCReachabilityGetFlags`.
228    pub fn flags(&self) -> Result<ReachabilityFlags> {
229        let mut flags = 0_u32;
230        let ok = unsafe {
231            ffi::network_reachability::sc_reachability_get_flags(self.raw.as_ptr(), &mut flags)
232        };
233        bridge::bool_result("sc_reachability_get_flags", ok)?;
234        Ok(ReachabilityFlags(flags))
235    }
236
237    /// Wraps a helper on `SCNetworkReachabilityRef`.
238    pub fn set_callback<F>(&mut self, callback: F) -> Result<()>
239    where
240        F: FnMut(ReachabilityFlags) + 'static,
241    {
242        if self.dispatch_queue_active {
243            return Err(SystemConfigurationError::null(
244                "sc_reachability_set_callback",
245                "dispatch queues require callbacks registered via Reachability::set_callback_send; clear the dispatch queue first",
246            ));
247        }
248
249        let mut callback = Box::new(LocalCallbackState {
250            callback: Box::new(callback),
251        });
252        self.set_registered_callback(
253            Some(reachability_callback_local),
254            std::ptr::from_mut(&mut *callback).cast::<c_void>(),
255            Some(RegisteredCallback::Local { _state: callback }),
256        )
257    }
258
259    /// Wraps a helper on `SCNetworkReachabilityRef`.
260    pub fn set_callback_send<F>(&mut self, callback: F) -> Result<()>
261    where
262        F: FnMut(ReachabilityFlags) + Send + 'static,
263    {
264        let callback = Arc::new(Mutex::new(SendCallbackState {
265            callback: Box::new(callback),
266        }));
267        self.set_registered_callback(
268            Some(reachability_callback_send),
269            Arc::as_ptr(&callback).cast_mut().cast::<c_void>(),
270            Some(RegisteredCallback::Send { _state: callback }),
271        )
272    }
273
274    /// Wraps a helper on `SCNetworkReachabilityRef`.
275    pub fn clear_callback(&mut self) -> Result<()> {
276        if self.dispatch_queue_active {
277            self.clear_dispatch_queue()?;
278        }
279        self.set_registered_callback(None, std::ptr::null_mut(), None)
280    }
281
282    /// Wraps `SCReachabilityScheduleWithRunLoopCurrent`.
283    pub fn schedule_with_run_loop_current(&mut self) -> Result<()> {
284        let ok = unsafe {
285            ffi::network_reachability::sc_reachability_schedule_with_run_loop_current(
286                self.raw.as_ptr(),
287            )
288        };
289        bridge::bool_result("sc_reachability_schedule_with_run_loop_current", ok)?;
290        self.scheduled_with_current_run_loop = true;
291        Ok(())
292    }
293
294    /// Wraps `SCReachabilityUnscheduleFromRunLoopCurrent`.
295    pub fn unschedule_from_run_loop_current(&mut self) -> Result<()> {
296        let ok = unsafe {
297            ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
298                self.raw.as_ptr(),
299            )
300        };
301        bridge::bool_result("sc_reachability_unschedule_from_run_loop_current", ok)?;
302        self.scheduled_with_current_run_loop = false;
303        Ok(())
304    }
305
306    /// Wraps `SCReachabilitySetDispatchQueueGlobal`.
307    pub fn set_dispatch_queue_global(&mut self) -> Result<()> {
308        if matches!(self.callback, Some(RegisteredCallback::Local { .. })) {
309            return Err(SystemConfigurationError::null(
310                "sc_reachability_set_dispatch_queue_global",
311                "dispatch queues require callbacks registered via Reachability::set_callback_send",
312            ));
313        }
314
315        let ok = unsafe {
316            ffi::network_reachability::sc_reachability_set_dispatch_queue_global(self.raw.as_ptr())
317        };
318        bridge::bool_result("sc_reachability_set_dispatch_queue_global", ok)?;
319        self.dispatch_queue_active = true;
320        Ok(())
321    }
322
323    /// Wraps `SCReachabilityClearDispatchQueue`.
324    pub fn clear_dispatch_queue(&mut self) -> Result<()> {
325        let ok = unsafe {
326            ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
327        };
328        bridge::bool_result("sc_reachability_clear_dispatch_queue", ok)?;
329        self.dispatch_queue_active = false;
330        Ok(())
331    }
332
333    fn set_registered_callback(
334        &mut self,
335        callback: ffi::network_reachability::ReachabilityCallback,
336        info: *mut c_void,
337        registered: Option<RegisteredCallback>,
338    ) -> Result<()> {
339        let ok = unsafe {
340            ffi::network_reachability::sc_reachability_set_callback(
341                self.raw.as_ptr(),
342                callback,
343                info,
344            )
345        };
346        bridge::bool_result("sc_reachability_set_callback", ok)?;
347        self.callback = registered;
348        Ok(())
349    }
350}
351
352impl Drop for Reachability {
353    fn drop(&mut self) {
354        if self.dispatch_queue_active {
355            let _ = unsafe {
356                ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
357            };
358        }
359        if self.scheduled_with_current_run_loop {
360            let _ = unsafe {
361                ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
362                    self.raw.as_ptr(),
363                )
364            };
365        }
366        if self.callback.is_some() {
367            let _ = unsafe {
368                ffi::network_reachability::sc_reachability_set_callback(
369                    self.raw.as_ptr(),
370                    None,
371                    std::ptr::null_mut(),
372                )
373            };
374        }
375    }
376}
377
378fn socket_addr_to_bytes(address: SocketAddr) -> Vec<u8> {
379    match address {
380        SocketAddr::V4(address) => {
381            let mut storage: libc::sockaddr_in = unsafe { std::mem::zeroed() };
382            storage.sin_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in>())
383                .expect("sockaddr_in length exceeds u8");
384            storage.sin_family = u8::try_from(libc::AF_INET).expect("AF_INET exceeds u8");
385            storage.sin_port = address.port().to_be();
386            storage.sin_addr = libc::in_addr {
387                s_addr: u32::from_ne_bytes(address.ip().octets()),
388            };
389            unsafe {
390                std::slice::from_raw_parts(
391                    std::ptr::from_ref(&storage).cast::<u8>(),
392                    std::mem::size_of::<libc::sockaddr_in>(),
393                )
394                .to_vec()
395            }
396        }
397        SocketAddr::V6(address) => {
398            let mut storage: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
399            storage.sin6_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in6>())
400                .expect("sockaddr_in6 length exceeds u8");
401            storage.sin6_family = u8::try_from(libc::AF_INET6).expect("AF_INET6 exceeds u8");
402            storage.sin6_port = address.port().to_be();
403            storage.sin6_flowinfo = address.flowinfo();
404            storage.sin6_scope_id = address.scope_id();
405            storage.sin6_addr = libc::in6_addr {
406                s6_addr: address.ip().octets(),
407            };
408            unsafe {
409                std::slice::from_raw_parts(
410                    std::ptr::from_ref(&storage).cast::<u8>(),
411                    std::mem::size_of::<libc::sockaddr_in6>(),
412                )
413                .to_vec()
414            }
415        }
416    }
417}