Skip to main content

systemconfiguration/
network_reachability.rs

1use std::{ffi::c_void, net::SocketAddr};
2
3use crate::{bridge, error::Result, ffi};
4
/// Raw reachability flag word with one boolean accessor per known bit.
///
/// NOTE(review): the bit positions look like Apple's
/// `SCNetworkReachabilityFlags` values — confirm against the
/// SystemConfiguration headers.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ReachabilityFlags(pub u32);

impl ReachabilityFlags {
    // Single-bit masks tested by the accessors below.
    const TRANSIENT_CONNECTION: u32 = 1 << 0;
    const REACHABLE: u32 = 1 << 1;
    const CONNECTION_REQUIRED: u32 = 1 << 2;
    const CONNECTION_ON_TRAFFIC: u32 = 1 << 3;
    const INTERVENTION_REQUIRED: u32 = 1 << 4;
    const CONNECTION_ON_DEMAND: u32 = 1 << 5;
    const LOCAL_ADDRESS: u32 = 1 << 16;
    const DIRECT: u32 = 1 << 17;
    const WWAN: u32 = 1 << 18;

    /// Reports whether any bit of `mask` is set in the wrapped word.
    fn has(self, mask: u32) -> bool {
        (self.0 & mask) != 0
    }

    /// Returns the wrapped flag word unchanged.
    pub fn bits(self) -> u32 {
        let Self(raw) = self;
        raw
    }

    /// `true` when bit 0 is set.
    pub fn is_transient_connection(self) -> bool {
        self.has(Self::TRANSIENT_CONNECTION)
    }

    /// `true` when bit 1 is set.
    pub fn is_reachable(self) -> bool {
        self.has(Self::REACHABLE)
    }

    /// `true` when bit 2 is set.
    pub fn needs_connection(self) -> bool {
        self.has(Self::CONNECTION_REQUIRED)
    }

    /// `true` when bit 3 is set.
    pub fn is_connection_on_traffic(self) -> bool {
        self.has(Self::CONNECTION_ON_TRAFFIC)
    }

    /// `true` when bit 4 is set.
    pub fn needs_intervention(self) -> bool {
        self.has(Self::INTERVENTION_REQUIRED)
    }

    /// `true` when bit 5 is set.
    pub fn is_connection_on_demand(self) -> bool {
        self.has(Self::CONNECTION_ON_DEMAND)
    }

    /// `true` when bit 16 is set.
    pub fn is_local_address(self) -> bool {
        self.has(Self::LOCAL_ADDRESS)
    }

    /// `true` when bit 17 is set.
    pub fn is_direct(self) -> bool {
        self.has(Self::DIRECT)
    }

    /// `true` when bit 18 is set.
    pub fn is_wwan(self) -> bool {
        self.has(Self::WWAN)
    }
}
49
50impl std::fmt::Display for ReachabilityFlags {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        let mut labels = Vec::new();
53        if self.is_transient_connection() {
54            labels.push("transient");
55        }
56        if self.is_reachable() {
57            labels.push("reachable");
58        }
59        if self.needs_connection() {
60            labels.push("needs-connection");
61        }
62        if self.is_connection_on_traffic() {
63            labels.push("on-traffic");
64        }
65        if self.needs_intervention() {
66            labels.push("needs-intervention");
67        }
68        if self.is_connection_on_demand() {
69            labels.push("on-demand");
70        }
71        if self.is_local_address() {
72            labels.push("local-address");
73        }
74        if self.is_direct() {
75            labels.push("direct");
76        }
77        if self.is_wwan() {
78            labels.push("wwan");
79        }
80        if labels.is_empty() {
81            write!(f, "0x{:x}", self.bits())
82        } else {
83            write!(f, "{} (0x{:x})", labels.join("|"), self.bits())
84        }
85    }
86}
87
/// Heap-allocated holder for the user closure; its address is what
/// `Reachability::set_callback` registers as the callback's `info` pointer.
struct CallbackState {
    /// Closure invoked by `reachability_callback` with the reported flags.
    callback: Box<dyn FnMut(ReachabilityFlags)>,
}
91
92unsafe extern "C" fn reachability_callback(flags: u32, info: *mut c_void) {
93    if info.is_null() {
94        return;
95    }
96
97    let state = &mut *info.cast::<CallbackState>();
98    (state.callback)(ReachabilityFlags(flags));
99}
100
/// Owns a reachability handle obtained through the bridge, plus the callback
/// state and run-loop scheduling flag that `Drop` uses for teardown.
pub struct Reachability {
    /// Handle returned by one of the `sc_reachability_create_*` bridge calls.
    raw: bridge::OwnedHandle,
    /// State whose address is registered as the callback `info` pointer; held
    /// here so the allocation stays alive while the callback is installed.
    callback: Option<Box<CallbackState>>,
    /// Set after a successful schedule call and cleared after a successful
    /// unschedule call; `Drop` unschedules only when this is `true`.
    scheduled_with_current_run_loop: bool,
}
106
/// Alternative name for [`Reachability`]; both refer to the same type.
pub type NetworkReachability = Reachability;
108
109impl Reachability {
110    pub fn with_name(name: &str) -> Result<Self> {
111        let name = bridge::cstring(name, "sc_reachability_create_with_name")?;
112        let raw = unsafe { ffi::network_reachability::sc_reachability_create_with_name(name.as_ptr()) };
113        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_name", raw)?;
114        Ok(Self {
115            raw,
116            callback: None,
117            scheduled_with_current_run_loop: false,
118        })
119    }
120
121    pub fn with_address(address: SocketAddr) -> Result<Self> {
122        let storage = socket_addr_to_bytes(address);
123        let raw = unsafe {
124            ffi::network_reachability::sc_reachability_create_with_address(
125                storage.as_ptr(),
126                isize::try_from(storage.len()).expect("socket address length exceeded isize"),
127            )
128        };
129        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address", raw)?;
130        Ok(Self {
131            raw,
132            callback: None,
133            scheduled_with_current_run_loop: false,
134        })
135    }
136
137    pub fn with_address_pair(local_address: Option<SocketAddr>, remote_address: Option<SocketAddr>) -> Result<Self> {
138        let local = local_address.map(socket_addr_to_bytes);
139        let remote = remote_address.map(socket_addr_to_bytes);
140        let raw = unsafe {
141            ffi::network_reachability::sc_reachability_create_with_address_pair(
142                local.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
143                local.as_ref().map_or(0, |value| {
144                    isize::try_from(value.len()).expect("socket address length exceeded isize")
145                }),
146                remote.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
147                remote.as_ref().map_or(0, |value| {
148                    isize::try_from(value.len()).expect("socket address length exceeded isize")
149                }),
150            )
151        };
152        let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address_pair", raw)?;
153        Ok(Self {
154            raw,
155            callback: None,
156            scheduled_with_current_run_loop: false,
157        })
158    }
159
160    pub fn flags(&self) -> Result<ReachabilityFlags> {
161        let mut flags = 0_u32;
162        let ok = unsafe {
163            ffi::network_reachability::sc_reachability_get_flags(self.raw.as_ptr(), &mut flags)
164        };
165        bridge::bool_result("sc_reachability_get_flags", ok)?;
166        Ok(ReachabilityFlags(flags))
167    }
168
169    pub fn set_callback<F>(&mut self, callback: F) -> Result<()>
170    where
171        F: FnMut(ReachabilityFlags) + 'static,
172    {
173        let mut callback = Box::new(CallbackState {
174            callback: Box::new(callback),
175        });
176        let ok = unsafe {
177            ffi::network_reachability::sc_reachability_set_callback(
178                self.raw.as_ptr(),
179                Some(reachability_callback),
180                std::ptr::from_mut(&mut *callback).cast::<c_void>(),
181            )
182        };
183        bridge::bool_result("sc_reachability_set_callback", ok)?;
184        self.callback = Some(callback);
185        Ok(())
186    }
187
188    pub fn clear_callback(&mut self) -> Result<()> {
189        let ok = unsafe {
190            ffi::network_reachability::sc_reachability_set_callback(
191                self.raw.as_ptr(),
192                None,
193                std::ptr::null_mut(),
194            )
195        };
196        bridge::bool_result("sc_reachability_set_callback", ok)?;
197        self.callback = None;
198        Ok(())
199    }
200
201    pub fn schedule_with_run_loop_current(&mut self) -> Result<()> {
202        let ok = unsafe {
203            ffi::network_reachability::sc_reachability_schedule_with_run_loop_current(self.raw.as_ptr())
204        };
205        bridge::bool_result("sc_reachability_schedule_with_run_loop_current", ok)?;
206        self.scheduled_with_current_run_loop = true;
207        Ok(())
208    }
209
210    pub fn unschedule_from_run_loop_current(&mut self) -> Result<()> {
211        let ok = unsafe {
212            ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(self.raw.as_ptr())
213        };
214        bridge::bool_result("sc_reachability_unschedule_from_run_loop_current", ok)?;
215        self.scheduled_with_current_run_loop = false;
216        Ok(())
217    }
218}
219
220impl Drop for Reachability {
221    fn drop(&mut self) {
222        if self.scheduled_with_current_run_loop {
223            let _ = unsafe {
224                ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
225                    self.raw.as_ptr(),
226                )
227            };
228        }
229        if self.callback.is_some() {
230            let _ = unsafe {
231                ffi::network_reachability::sc_reachability_set_callback(
232                    self.raw.as_ptr(),
233                    None,
234                    std::ptr::null_mut(),
235                )
236            };
237        }
238    }
239}
240
241fn socket_addr_to_bytes(address: SocketAddr) -> Vec<u8> {
242    match address {
243        SocketAddr::V4(address) => {
244            let mut storage: libc::sockaddr_in = unsafe { std::mem::zeroed() };
245            storage.sin_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in>())
246                .expect("sockaddr_in length exceeds u8");
247            storage.sin_family = u8::try_from(libc::AF_INET).expect("AF_INET exceeds u8");
248            storage.sin_port = address.port().to_be();
249            storage.sin_addr = libc::in_addr {
250                s_addr: u32::from_ne_bytes(address.ip().octets()),
251            };
252            unsafe {
253                std::slice::from_raw_parts(
254                    std::ptr::from_ref(&storage).cast::<u8>(),
255                    std::mem::size_of::<libc::sockaddr_in>(),
256                )
257                .to_vec()
258            }
259        }
260        SocketAddr::V6(address) => {
261            let mut storage: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
262            storage.sin6_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in6>())
263                .expect("sockaddr_in6 length exceeds u8");
264            storage.sin6_family = u8::try_from(libc::AF_INET6).expect("AF_INET6 exceeds u8");
265            storage.sin6_port = address.port().to_be();
266            storage.sin6_flowinfo = address.flowinfo();
267            storage.sin6_scope_id = address.scope_id();
268            storage.sin6_addr = libc::in6_addr {
269                s6_addr: address.ip().octets(),
270            };
271            unsafe {
272                std::slice::from_raw_parts(
273                    std::ptr::from_ref(&storage).cast::<u8>(),
274                    std::mem::size_of::<libc::sockaddr_in6>(),
275                )
276                .to_vec()
277            }
278        }
279    }
280}