1use std::{
2 ffi::c_void,
3 net::SocketAddr,
4 sync::{Arc, Mutex},
5};
6
7use crate::{bridge, error::Result, ffi, SystemConfigurationError};
8
/// Raw reachability flag bits, as delivered by the reachability bridge.
///
/// The accessor methods test individual bits of the wrapped `u32`. The bit
/// positions used here (0-5 and 16-18) line up with Apple's
/// `SCNetworkReachabilityFlags` constants; bits without an accessor are
/// preserved in [`ReachabilityFlags::bits`] but never given a label.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ReachabilityFlags(
    /// The unmodified flag word.
    pub u32,
);

impl ReachabilityFlags {
    const TRANSIENT_CONNECTION: u32 = 1 << 0;
    const REACHABLE: u32 = 1 << 1;
    const CONNECTION_REQUIRED: u32 = 1 << 2;
    const CONNECTION_ON_TRAFFIC: u32 = 1 << 3;
    const INTERVENTION_REQUIRED: u32 = 1 << 4;
    const CONNECTION_ON_DEMAND: u32 = 1 << 5;
    const LOCAL_ADDRESS: u32 = 1 << 16;
    const DIRECT: u32 = 1 << 17;
    const WWAN: u32 = 1 << 18;

    /// Mask/label pairs in the order the `Display` impl prints them.
    const LABELS: [(u32, &'static str); 9] = [
        (Self::TRANSIENT_CONNECTION, "transient"),
        (Self::REACHABLE, "reachable"),
        (Self::CONNECTION_REQUIRED, "needs-connection"),
        (Self::CONNECTION_ON_TRAFFIC, "on-traffic"),
        (Self::INTERVENTION_REQUIRED, "needs-intervention"),
        (Self::CONNECTION_ON_DEMAND, "on-demand"),
        (Self::LOCAL_ADDRESS, "local-address"),
        (Self::DIRECT, "direct"),
        (Self::WWAN, "wwan"),
    ];

    /// Returns `true` when any bit of `mask` is set in the flag word.
    const fn has(self, mask: u32) -> bool {
        self.0 & mask != 0
    }

    /// Returns the raw flag word.
    #[must_use]
    pub const fn bits(self) -> u32 {
        self.0
    }

    /// Returns `true` when bit 0 (transient connection) is set.
    #[must_use]
    pub const fn is_transient_connection(self) -> bool {
        self.has(Self::TRANSIENT_CONNECTION)
    }

    /// Returns `true` when bit 1 (reachable) is set.
    #[must_use]
    pub const fn is_reachable(self) -> bool {
        self.has(Self::REACHABLE)
    }

    /// Returns `true` when bit 2 (connection required) is set.
    #[must_use]
    pub const fn needs_connection(self) -> bool {
        self.has(Self::CONNECTION_REQUIRED)
    }

    /// Returns `true` when bit 3 (connection on traffic) is set.
    #[must_use]
    pub const fn is_connection_on_traffic(self) -> bool {
        self.has(Self::CONNECTION_ON_TRAFFIC)
    }

    /// Returns `true` when bit 4 (intervention required) is set.
    #[must_use]
    pub const fn needs_intervention(self) -> bool {
        self.has(Self::INTERVENTION_REQUIRED)
    }

    /// Returns `true` when bit 5 (connection on demand) is set.
    #[must_use]
    pub const fn is_connection_on_demand(self) -> bool {
        self.has(Self::CONNECTION_ON_DEMAND)
    }

    /// Returns `true` when bit 16 (local address) is set.
    #[must_use]
    pub const fn is_local_address(self) -> bool {
        self.has(Self::LOCAL_ADDRESS)
    }

    /// Returns `true` when bit 17 (direct) is set.
    #[must_use]
    pub const fn is_direct(self) -> bool {
        self.has(Self::DIRECT)
    }

    /// Returns `true` when bit 18 (WWAN) is set.
    #[must_use]
    pub const fn is_wwan(self) -> bool {
        self.has(Self::WWAN)
    }
}

/// Formats the flags as `label|label (0x<hex>)`, or as just `0x<hex>` when no
/// labelled bit is set.
impl std::fmt::Display for ReachabilityFlags {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Labels are streamed straight into the formatter so that formatting
        // needs no intermediate `Vec` or joined `String`.
        let mut wrote_label = false;
        for &(mask, label) in Self::LABELS.iter() {
            if !self.has(mask) {
                continue;
            }
            if wrote_label {
                f.write_str("|")?;
            }
            f.write_str(label)?;
            wrote_label = true;
        }

        if wrote_label {
            write!(f, " (0x{:x})", self.0)
        } else {
            write!(f, "0x{:x}", self.0)
        }
    }
}
105
106struct LocalCallbackState {
107 callback: Box<dyn FnMut(ReachabilityFlags)>,
108}
109
110struct SendCallbackState {
111 callback: Box<dyn FnMut(ReachabilityFlags) + Send>,
112}
113
114enum RegisteredCallback {
115 Local {
116 _state: Box<LocalCallbackState>,
117 },
118 Send {
119 _state: Arc<Mutex<SendCallbackState>>,
120 },
121}
122
123unsafe extern "C" fn reachability_callback_local(flags: u32, info: *mut c_void) {
124 if info.is_null() {
125 return;
126 }
127
128 let state = unsafe { &mut *info.cast::<LocalCallbackState>() };
129 (state.callback)(ReachabilityFlags(flags));
130}
131
132unsafe extern "C" fn reachability_callback_send(flags: u32, info: *mut c_void) {
133 if info.is_null() {
134 return;
135 }
136
137 let mutex = unsafe { &*info.cast::<Mutex<SendCallbackState>>() };
138 if let Ok(mut state) = mutex.lock() {
139 (state.callback)(ReachabilityFlags(flags));
140 }
141}
142
143pub struct Reachability {
145 raw: bridge::OwnedHandle,
146 callback: Option<RegisteredCallback>,
147 scheduled_with_current_run_loop: bool,
148 dispatch_queue_active: bool,
149}
150
151pub type NetworkReachability = Reachability;
153
154impl Reachability {
155 pub fn type_id() -> u64 {
157 unsafe { ffi::network_reachability::sc_reachability_get_type_id() }
158 }
159
160 pub fn with_name(name: &str) -> Result<Self> {
162 let name = bridge::cstring(name, "sc_reachability_create_with_name")?;
163 let raw =
164 unsafe { ffi::network_reachability::sc_reachability_create_with_name(name.as_ptr()) };
165 let raw = bridge::owned_handle_or_last("sc_reachability_create_with_name", raw)?;
166 Ok(Self {
167 raw,
168 callback: None,
169 scheduled_with_current_run_loop: false,
170 dispatch_queue_active: false,
171 })
172 }
173
174 pub fn with_address(address: SocketAddr) -> Result<Self> {
176 let storage = socket_addr_to_bytes(address);
177 let raw = unsafe {
178 ffi::network_reachability::sc_reachability_create_with_address(
179 storage.as_ptr(),
180 isize::try_from(storage.len()).expect("socket address length exceeded isize"),
181 )
182 };
183 let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address", raw)?;
184 Ok(Self {
185 raw,
186 callback: None,
187 scheduled_with_current_run_loop: false,
188 dispatch_queue_active: false,
189 })
190 }
191
192 pub fn with_address_pair(
194 local_address: Option<SocketAddr>,
195 remote_address: Option<SocketAddr>,
196 ) -> Result<Self> {
197 let local = local_address.map(socket_addr_to_bytes);
198 let remote = remote_address.map(socket_addr_to_bytes);
199 let raw = unsafe {
200 ffi::network_reachability::sc_reachability_create_with_address_pair(
201 local.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
202 local.as_ref().map_or(0, |value| {
203 isize::try_from(value.len()).expect("socket address length exceeded isize")
204 }),
205 remote.as_ref().map_or(std::ptr::null(), Vec::as_ptr),
206 remote.as_ref().map_or(0, |value| {
207 isize::try_from(value.len()).expect("socket address length exceeded isize")
208 }),
209 )
210 };
211 let raw = bridge::owned_handle_or_last("sc_reachability_create_with_address_pair", raw)?;
212 Ok(Self {
213 raw,
214 callback: None,
215 scheduled_with_current_run_loop: false,
216 dispatch_queue_active: false,
217 })
218 }
219
220 pub fn flags(&self) -> Result<ReachabilityFlags> {
222 let mut flags = 0_u32;
223 let ok = unsafe {
224 ffi::network_reachability::sc_reachability_get_flags(self.raw.as_ptr(), &mut flags)
225 };
226 bridge::bool_result("sc_reachability_get_flags", ok)?;
227 Ok(ReachabilityFlags(flags))
228 }
229
230 pub fn set_callback<F>(&mut self, callback: F) -> Result<()>
232 where
233 F: FnMut(ReachabilityFlags) + 'static,
234 {
235 if self.dispatch_queue_active {
236 return Err(SystemConfigurationError::null(
237 "sc_reachability_set_callback",
238 "dispatch queues require callbacks registered via Reachability::set_callback_send; clear the dispatch queue first",
239 ));
240 }
241
242 let mut callback = Box::new(LocalCallbackState {
243 callback: Box::new(callback),
244 });
245 self.set_registered_callback(
246 Some(reachability_callback_local),
247 std::ptr::from_mut(&mut *callback).cast::<c_void>(),
248 Some(RegisteredCallback::Local { _state: callback }),
249 )
250 }
251
252 pub fn set_callback_send<F>(&mut self, callback: F) -> Result<()>
254 where
255 F: FnMut(ReachabilityFlags) + Send + 'static,
256 {
257 let callback = Arc::new(Mutex::new(SendCallbackState {
258 callback: Box::new(callback),
259 }));
260 self.set_registered_callback(
261 Some(reachability_callback_send),
262 Arc::as_ptr(&callback).cast_mut().cast::<c_void>(),
263 Some(RegisteredCallback::Send { _state: callback }),
264 )
265 }
266
267 pub fn clear_callback(&mut self) -> Result<()> {
269 if self.dispatch_queue_active {
270 self.clear_dispatch_queue()?;
271 }
272 self.set_registered_callback(None, std::ptr::null_mut(), None)
273 }
274
275 pub fn schedule_with_run_loop_current(&mut self) -> Result<()> {
277 let ok = unsafe {
278 ffi::network_reachability::sc_reachability_schedule_with_run_loop_current(
279 self.raw.as_ptr(),
280 )
281 };
282 bridge::bool_result("sc_reachability_schedule_with_run_loop_current", ok)?;
283 self.scheduled_with_current_run_loop = true;
284 Ok(())
285 }
286
287 pub fn unschedule_from_run_loop_current(&mut self) -> Result<()> {
289 let ok = unsafe {
290 ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
291 self.raw.as_ptr(),
292 )
293 };
294 bridge::bool_result("sc_reachability_unschedule_from_run_loop_current", ok)?;
295 self.scheduled_with_current_run_loop = false;
296 Ok(())
297 }
298
299 pub fn set_dispatch_queue_global(&mut self) -> Result<()> {
301 if matches!(self.callback, Some(RegisteredCallback::Local { .. })) {
302 return Err(SystemConfigurationError::null(
303 "sc_reachability_set_dispatch_queue_global",
304 "dispatch queues require callbacks registered via Reachability::set_callback_send",
305 ));
306 }
307
308 let ok = unsafe {
309 ffi::network_reachability::sc_reachability_set_dispatch_queue_global(self.raw.as_ptr())
310 };
311 bridge::bool_result("sc_reachability_set_dispatch_queue_global", ok)?;
312 self.dispatch_queue_active = true;
313 Ok(())
314 }
315
316 pub fn clear_dispatch_queue(&mut self) -> Result<()> {
318 let ok = unsafe {
319 ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
320 };
321 bridge::bool_result("sc_reachability_clear_dispatch_queue", ok)?;
322 self.dispatch_queue_active = false;
323 Ok(())
324 }
325
326 fn set_registered_callback(
327 &mut self,
328 callback: ffi::network_reachability::ReachabilityCallback,
329 info: *mut c_void,
330 registered: Option<RegisteredCallback>,
331 ) -> Result<()> {
332 let ok = unsafe {
333 ffi::network_reachability::sc_reachability_set_callback(
334 self.raw.as_ptr(),
335 callback,
336 info,
337 )
338 };
339 bridge::bool_result("sc_reachability_set_callback", ok)?;
340 self.callback = registered;
341 Ok(())
342 }
343}
344
345impl Drop for Reachability {
346 fn drop(&mut self) {
347 if self.dispatch_queue_active {
348 let _ = unsafe {
349 ffi::network_reachability::sc_reachability_clear_dispatch_queue(self.raw.as_ptr())
350 };
351 }
352 if self.scheduled_with_current_run_loop {
353 let _ = unsafe {
354 ffi::network_reachability::sc_reachability_unschedule_from_run_loop_current(
355 self.raw.as_ptr(),
356 )
357 };
358 }
359 if self.callback.is_some() {
360 let _ = unsafe {
361 ffi::network_reachability::sc_reachability_set_callback(
362 self.raw.as_ptr(),
363 None,
364 std::ptr::null_mut(),
365 )
366 };
367 }
368 }
369}
370
371fn socket_addr_to_bytes(address: SocketAddr) -> Vec<u8> {
372 match address {
373 SocketAddr::V4(address) => {
374 let mut storage: libc::sockaddr_in = unsafe { std::mem::zeroed() };
375 storage.sin_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in>())
376 .expect("sockaddr_in length exceeds u8");
377 storage.sin_family = u8::try_from(libc::AF_INET).expect("AF_INET exceeds u8");
378 storage.sin_port = address.port().to_be();
379 storage.sin_addr = libc::in_addr {
380 s_addr: u32::from_ne_bytes(address.ip().octets()),
381 };
382 unsafe {
383 std::slice::from_raw_parts(
384 std::ptr::from_ref(&storage).cast::<u8>(),
385 std::mem::size_of::<libc::sockaddr_in>(),
386 )
387 .to_vec()
388 }
389 }
390 SocketAddr::V6(address) => {
391 let mut storage: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
392 storage.sin6_len = u8::try_from(std::mem::size_of::<libc::sockaddr_in6>())
393 .expect("sockaddr_in6 length exceeds u8");
394 storage.sin6_family = u8::try_from(libc::AF_INET6).expect("AF_INET6 exceeds u8");
395 storage.sin6_port = address.port().to_be();
396 storage.sin6_flowinfo = address.flowinfo();
397 storage.sin6_scope_id = address.scope_id();
398 storage.sin6_addr = libc::in6_addr {
399 s6_addr: address.ip().octets(),
400 };
401 unsafe {
402 std::slice::from_raw_parts(
403 std::ptr::from_ref(&storage).cast::<u8>(),
404 std::mem::size_of::<libc::sockaddr_in6>(),
405 )
406 .to_vec()
407 }
408 }
409 }
410}