Skip to main content

systemconfiguration/
network_reachability.rs

1use std::{
2    ffi::c_void,
3    net::SocketAddr,
4    sync::{Arc, Mutex},
5};
6
7use crate::{bridge, error::Result, ffi, SystemConfigurationError};
8
/// Raw reachability flag bits as reported by `sc_reachability_get_flags`.
///
/// The bit positions used by the named constants below are intended to match Apple's
/// `SCNetworkReachabilityFlags` (`kSCNetworkReachabilityFlags*`) values; verify against the SDK
/// headers before adding new flags. Unknown bits are preserved in `.0` / [`Self::bits`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ReachabilityFlags(pub u32);

impl ReachabilityFlags {
    /// Transient connection flag (bit 0).
    pub const TRANSIENT_CONNECTION: Self = Self(1);
    /// Reachable flag (bit 1).
    pub const REACHABLE: Self = Self(1 << 1);
    /// Connection-required flag (bit 2).
    pub const CONNECTION_REQUIRED: Self = Self(1 << 2);
    /// Connection-on-traffic flag (bit 3).
    pub const CONNECTION_ON_TRAFFIC: Self = Self(1 << 3);
    /// Intervention-required flag (bit 4).
    pub const INTERVENTION_REQUIRED: Self = Self(1 << 4);
    /// Connection-on-demand flag (bit 5).
    pub const CONNECTION_ON_DEMAND: Self = Self(1 << 5);
    /// Local-address flag (bit 16).
    pub const IS_LOCAL_ADDRESS: Self = Self(1 << 16);
    /// Direct flag (bit 17).
    pub const IS_DIRECT: Self = Self(1 << 17);
    /// WWAN flag (bit 18).
    pub const IS_WWAN: Self = Self(1 << 18);

    /// Known flags paired with their `Display` labels, in the order they are printed.
    const LABELS: [(Self, &'static str); 9] = [
        (Self::TRANSIENT_CONNECTION, "transient"),
        (Self::REACHABLE, "reachable"),
        (Self::CONNECTION_REQUIRED, "needs-connection"),
        (Self::CONNECTION_ON_TRAFFIC, "on-traffic"),
        (Self::INTERVENTION_REQUIRED, "needs-intervention"),
        (Self::CONNECTION_ON_DEMAND, "on-demand"),
        (Self::IS_LOCAL_ADDRESS, "local-address"),
        (Self::IS_DIRECT, "direct"),
        (Self::IS_WWAN, "wwan"),
    ];

    /// Returns the raw bit set.
    #[must_use]
    pub fn bits(self) -> u32 {
        self.0
    }

    /// Returns `true` when every bit set in `other` is also set in `self`.
    #[must_use]
    pub fn contains(self, other: Self) -> bool {
        self.0 & other.0 == other.0
    }

    /// Returns `true` if [`Self::TRANSIENT_CONNECTION`] is set.
    #[must_use]
    pub fn is_transient_connection(self) -> bool {
        self.contains(Self::TRANSIENT_CONNECTION)
    }

    /// Returns `true` if [`Self::REACHABLE`] is set.
    #[must_use]
    pub fn is_reachable(self) -> bool {
        self.contains(Self::REACHABLE)
    }

    /// Returns `true` if [`Self::CONNECTION_REQUIRED`] is set.
    #[must_use]
    pub fn needs_connection(self) -> bool {
        self.contains(Self::CONNECTION_REQUIRED)
    }

    /// Returns `true` if [`Self::CONNECTION_ON_TRAFFIC`] is set.
    #[must_use]
    pub fn is_connection_on_traffic(self) -> bool {
        self.contains(Self::CONNECTION_ON_TRAFFIC)
    }

    /// Returns `true` if [`Self::INTERVENTION_REQUIRED`] is set.
    #[must_use]
    pub fn needs_intervention(self) -> bool {
        self.contains(Self::INTERVENTION_REQUIRED)
    }

    /// Returns `true` if [`Self::CONNECTION_ON_DEMAND`] is set.
    #[must_use]
    pub fn is_connection_on_demand(self) -> bool {
        self.contains(Self::CONNECTION_ON_DEMAND)
    }

    /// Returns `true` if [`Self::IS_LOCAL_ADDRESS`] is set.
    #[must_use]
    pub fn is_local_address(self) -> bool {
        self.contains(Self::IS_LOCAL_ADDRESS)
    }

    /// Returns `true` if [`Self::IS_DIRECT`] is set.
    #[must_use]
    pub fn is_direct(self) -> bool {
        self.contains(Self::IS_DIRECT)
    }

    /// Returns `true` if [`Self::IS_WWAN`] is set.
    #[must_use]
    pub fn is_wwan(self) -> bool {
        self.contains(Self::IS_WWAN)
    }
}

impl std::fmt::Display for ReachabilityFlags {
    /// Formats the known flags as `label|label (0x<bits>)`. When no known flag is set, only the
    /// hex bits are printed (`0x<bits>`). Unknown bits appear only in the hex part.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let labels: Vec<&str> = Self::LABELS
            .iter()
            .filter(|(flag, _)| self.contains(*flag))
            .map(|&(_, label)| label)
            .collect();
        if labels.is_empty() {
            write!(f, "0x{:x}", self.bits())
        } else {
            write!(f, "{} (0x{:x})", labels.join("|"), self.bits())
        }
    }
}
91
92struct LocalCallbackState {
93    callback: Box<dyn FnMut(ReachabilityFlags)>,
94}
95
96struct SendCallbackState {
97    callback: Box<dyn FnMut(ReachabilityFlags) + Send>,
98}
99
100enum RegisteredCallback {
101    Local {
102        _state: Box<LocalCallbackState>,
103    },
104    Send {
105        _state: Arc<Mutex<SendCallbackState>>,
106    },
107}
108
109unsafe extern "C" fn reachability_callback_local(flags: u32, info: *mut c_void) {
110    if info.is_null() {
111        return;
112    }
113
114    let state = unsafe { &mut *info.cast::<LocalCallbackState>() };
115    (state.callback)(ReachabilityFlags(flags));
116}
117
118unsafe extern "C" fn reachability_callback_send(flags: u32, info: *mut c_void) {
119    if info.is_null() {
120        return;
121    }
122
123    let mutex = unsafe { &*info.cast::<Mutex<SendCallbackState>>() };
124    if let Ok(mut state) = mutex.lock() {
125        (state.callback)(ReachabilityFlags(flags));
126    }
127}
128
129pub struct Reachability {
130    raw: bridge::OwnedHandle,
131    callback: Option<RegisteredCallback>,
132    scheduled_with_current_run_loop: bool,
133    dispatch_queue_active: bool,
134}
135
136pub type NetworkReachability = Reachability;
137
138impl Reachability {
139    pub fn type_id() -> u64 {
140        unsafe { ffi::network_reachability::sc_reachability_get_type_id() }
141    }
142
143    pub fn with_name(name: &str) -> Result<Self> {
144        let name = bridge::cstring(name, "sc_reachability_create_with_name")?;
145        let raw =
146            unsafe { ffi::network_reachability::sc_reachability_create_with_name(name.as_ptr()) };
147        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_name", raw)?;
148        Ok(Self {
149            raw,
150            callback: None,
151            scheduled_with_current_run_loop: false,
152            dispatch_queue_active: false,
153        })
154    }
155
156    pub fn with_address(address: SocketAddr) -> Result<Self> {
157        let storage = socket_addr_to_bytes(address);
158        let raw = unsafe {
159            ffi::network_reachability::sc_reachability_create_with_address(
160                storage.as_ptr(),
161                isize::try_from(storage.len()).expect("socket address length exceeded isize"),
162            )
163        };
164        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address", raw)?;
165        Ok(Self {
166            raw,
167            callback: None,
168            scheduled_with_current_run_loop: false,
169            dispatch_queue_active: false,
170        })
171    }
172
173    pub fn with_address_pair(
174        local_address: Option<SocketAddr>,
175        remote_address: Option<SocketAddr>,
176    ) -> Result<Self> {
177        let local = local_address.map(socket_addr_to_bytes);
178        let remote = remote_address.map(socket_addr_to_bytes);
179        let raw = unsafe {
180            ffi::network_reachability::sc_reachability_create_with_address_pair(
181                local.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
182                local.as_ref().map_or(0, |value| {
183                    isize::try_from(value.len()).expect("socket address length exceeded isize")
184                }),
185                remote.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
186                remote.as_ref().map_or(0, |value| {
187                    isize::try_from(value.len()).expect("socket address length exceeded isize")
188                }),
189            )
190        };
191        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address_pair", raw)?;
192        Ok(Self {
193            raw,
194            callback: None,
195            scheduled_with_current_run_loop: false,
196            dispatch_queue_active: false,
197        })
198    }
199
200    pub fn flags(&self) -> Result<ReachabilityFlags> {
201        let mut flags = 0_u32;
202        let ok = unsafe {
203            ffi::network_reachability::sc_reachability_get_flags(self.raw.as_ptr(), &mut flags)
204        };
205        bridge::bool_result("sc_reachability_get_flags", ok)?;
206        Ok(ReachabilityFlags(flags))
207    }
208
209    pub fn set_callback<F>(&mut self, callback: F) -> Result<()>
210    where
211        F: FnMut(ReachabilityFlags) + 'static,
212    {
213        if self.dispatch_queue_active {
214            return Err(SystemConfigurationError::null(
215                "sc_reachability_set_callback",
216                "dispatch queues require callbacks registered via Reachability::set_callback_send; clear the dispatch queue first",
217            ));
218        }
219
220        let mut callback = Box::new(LocalCallbackState {
221            callback: Box::new(callback),
222        });
223        self.set_registered_callback(
224            Some(reachability_callback_local),
225            std::ptr::from_mut(&mut *callback).cast::<c_void>(),
226            Some(RegisteredCallback::Local { _state: callback }),
227        )
228    }
229
230    pub fn set_callback_send<F>(&mut self, callback: F) -> Result<()>
231    where
232        F: FnMut(ReachabilityFlags) + Send + 'static,
233    {
234        let callback = Arc::new(Mutex::new(SendCallbackState {
235            callback: Box::new(callback),
236        }));
237        self.set_registered_callback(
238            Some(reachability_callback_send),
239            Arc::as_ptr(&callback).cast_mut().cast::<c_void>(),
240            Some(RegisteredCallback::Send { _state: callback }),
241        )
242    }
243
244    pub fn clear_callback(&mut self) -> Result<()> {
245        if self.dispatch_queue_active {
246            self.clear_dispatch_queue()?;
247        }
248        self.set_registered_callback(None, std::ptr::null_mut(), None)
249    }
250
251    pub fn schedule_with_run_loop_current(&mut self) -> Result<()> {
252        let ok = unsafe {
253            ffi::network_reachability::sc_reachability_schedule_with_run_loop_current(
254                self.raw.as_ptr(),
255            )
256        };
257        bridge::bool_result("sc_reachability_schedule_with_run_loop_current", ok)?;
258        self.scheduled_with_current_run_loop = true;
259        Ok(())
260    }
261
262    pub fn unschedule_from_run_loop_current(&mut self) -> Result<()> {
263        let ok = unsafe {
264            ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
265                self.raw.as_ptr(),
266            )
267        };
268        bridge::bool_result("sc_reachability_unschedule_from_run_loop_current", ok)?;
269        self.scheduled_with_current_run_loop = false;
270        Ok(())
271    }
272
273    pub fn set_dispatch_queue_global(&mut self) -> Result<()> {
274        if matches!(self.callback, Some(RegisteredCallback::Local { .. })) {
275            return Err(SystemConfigurationError::null(
276                "sc_reachability_set_dispatch_queue_global",
277                "dispatch queues require callbacks registered via Reachability::set_callback_send",
278            ));
279        }
280
281        let ok = unsafe {
282            ffi::network_reachability::sc_reachability_set_dispatch_queue_global(self.raw.as_ptr())
283        };
284        bridge::bool_result("sc_reachability_set_dispatch_queue_global", ok)?;
285        self.dispatch_queue_active = true;
286        Ok(())
287    }
288
289    pub fn clear_dispatch_queue(&mut self) -> Result<()> {
290        let ok = unsafe {
291            ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
292        };
293        bridge::bool_result("sc_reachability_clear_dispatch_queue", ok)?;
294        self.dispatch_queue_active = false;
295        Ok(())
296    }
297
298    fn set_registered_callback(
299        &mut self,
300        callback: ffi::network_reachability::ReachabilityCallback,
301        info: *mut c_void,
302        registered: Option<RegisteredCallback>,
303    ) -> Result<()> {
304        let ok = unsafe {
305            ffi::network_reachability::sc_reachability_set_callback(
306                self.raw.as_ptr(),
307                callback,
308                info,
309            )
310        };
311        bridge::bool_result("sc_reachability_set_callback", ok)?;
312        self.callback = registered;
313        Ok(())
314    }
315}
316
317impl Drop for Reachability {
318    fn drop(&mut self) {
319        if self.dispatch_queue_active {
320            let _ = unsafe {
321                ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
322            };
323        }
324        if self.scheduled_with_current_run_loop {
325            let _ = unsafe {
326                ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
327                    self.raw.as_ptr(),
328                )
329            };
330        }
331        if self.callback.is_some() {
332            let _ = unsafe {
333                ffi::network_reachability::sc_reachability_set_callback(
334                    self.raw.as_ptr(),
335                    None,
336                    std::ptr::null_mut(),
337                )
338            };
339        }
340    }
341}
342
343fn socket_addr_to_bytes(address: SocketAddr) -> Vec<u8> {
344    match address {
345        SocketAddr::V4(address) => {
346            let mut storage: libc::sockaddr_in = unsafe { std::mem::zeroed() };
347            storage.sin_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in>())
348                .expect("sockaddr_in length exceeds u8");
349            storage.sin_family = u8::try_from(libc::AF_INET).expect("AF_INET exceeds u8");
350            storage.sin_port = address.port().to_be();
351            storage.sin_addr = libc::in_addr {
352                s_addr: u32::from_ne_bytes(address.ip().octets()),
353            };
354            unsafe {
355                std::slice::from_raw_parts(
356                    std::ptr::from_ref(&storage).cast::<u8>(),
357                    std::mem::size_of::<libc::sockaddr_in>(),
358                )
359                .to_vec()
360            }
361        }
362        SocketAddr::V6(address) => {
363            let mut storage: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
364            storage.sin6_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in6>())
365                .expect("sockaddr_in6 length exceeds u8");
366            storage.sin6_family = u8::try_from(libc::AF_INET6).expect("AF_INET6 exceeds u8");
367            storage.sin6_port = address.port().to_be();
368            storage.sin6_flowinfo = address.flowinfo();
369            storage.sin6_scope_id = address.scope_id();
370            storage.sin6_addr = libc::in6_addr {
371                s6_addr: address.ip().octets(),
372            };
373            unsafe {
374                std::slice::from_raw_parts(
375                    std::ptr::from_ref(&storage).cast::<u8>(),
376                    std::mem::size_of::<libc::sockaddr_in6>(),
377                )
378                .to_vec()
379            }
380        }
381    }
382}