// reauthfi_core/lib.rs

1use std::error::Error;
2use std::fmt;
3use std::io::{self, Write};
4use std::process::Command;
5use std::result::Result;
6use std::sync::{
7    atomic::{AtomicBool, Ordering},
8    Arc,
9};
10use std::thread;
11use std::time::{Duration, Instant};
12
13use colored::Colorize;
14use regex::Regex;
15use reqwest::blocking::Client;
16
17#[derive(Debug)]
18pub enum ReauthfiError {
19    Network(reqwest::Error),
20    Io(std::io::Error),
21    NotFound,
22    CommandFailed(String),
23    UnsupportedPlatform,
24}
25
/// Outcome of running a single detection strategy.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers and tests can compare
/// outcomes directly instead of pattern-matching every time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DetectionResult {
    /// A captive portal was found; the payload is the URL to open.
    PortalFound(String),
    /// The strategy ran to completion without finding a portal.
    NoPortalDetected,
    /// The strategy could not reach a conclusion because of a network failure.
    NetworkError,
}
32
33impl fmt::Display for ReauthfiError {
34    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35        match self {
36            ReauthfiError::Network(e) => write!(f, "Network error: {}", e),
37            ReauthfiError::Io(e) => write!(f, "IO error: {}", e),
38            ReauthfiError::NotFound => write!(f, "Captive portal not found"),
39            ReauthfiError::CommandFailed(msg) => write!(f, "Command failed: {}", msg),
40            ReauthfiError::UnsupportedPlatform => write!(f, "Unsupported platform"),
41        }
42    }
43}
44
// Marker impl: all `Error` methods keep their defaults, so `source()` returns
// `None` for every variant. The wrapped `reqwest`/I/O error text is instead
// included in this type's `Display` output.
impl Error for ReauthfiError {}
46
47impl From<reqwest::Error> for ReauthfiError {
48    fn from(err: reqwest::Error) -> Self {
49        ReauthfiError::Network(err)
50    }
51}
52
53impl From<std::io::Error> for ReauthfiError {
54    fn from(err: std::io::Error) -> Self {
55        ReauthfiError::Io(err)
56    }
57}
58
/// Operating systems this crate distinguishes between.
///
/// The enum is fieldless, so it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Platform {
    /// Apple macOS.
    MacOS,
    /// Any other operating system.
    Unsupported,
}
64
65impl Platform {
66    pub fn detect() -> Self {
67        #[cfg(target_os = "macos")]
68        {
69            Platform::MacOS
70        }
71
72        #[cfg(not(target_os = "macos"))]
73        {
74            Platform::Unsupported
75        }
76    }
77
78    pub fn detection_endpoints(&self) -> &'static [DetectionEndpoint] {
79        match self {
80            Platform::MacOS => MACOS_DETECTION_ENDPOINTS,
81            Platform::Unsupported => &[],
82        }
83    }
84}
85
/// A well-known URL that can be probed to detect a captive portal.
///
/// `PartialEq`/`Eq` are derived so endpoint tables can be compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DetectionEndpoint {
    /// Short label shown in verbose output.
    pub name: &'static str,
    /// URL that is requested.
    pub url: &'static str,
    /// HTTP status that means "no portal here", if the endpoint has one.
    /// `None` means no particular status is expected.
    pub expected_status: Option<u16>,
}
92
/// Connectivity-check endpoints returned by `Platform::detection_endpoints`
/// for `Platform::MacOS`. They are probed in the order listed here.
const MACOS_DETECTION_ENDPOINTS: &[DetectionEndpoint] = &[
    // No specific status is expected from this endpoint.
    DetectionEndpoint {
        name: "Apple",
        url: "http://captive.apple.com/hotspot-detect.html",
        expected_status: None,
    },
    // A 204 response from this endpoint is the expected status.
    DetectionEndpoint {
        name: "Google",
        url: "http://connectivitycheck.gstatic.com/generate_204",
        expected_status: Some(204),
    },
];
105
/// Platform-specific settings used to locate and probe the default gateway.
///
/// All fields are `'static` references, so the struct is cheap to clone and
/// compare.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PlatformConfig {
    /// Program (first element) and arguments used to query the default route.
    pub gateway_command: &'static [&'static str],
    /// Regular expression whose first capture group is the gateway address.
    pub gateway_regex: &'static str,
    /// URL paths appended to `http://<gateway>` when probing the gateway.
    pub gateway_endpoints: &'static [&'static str],
}
112
// macOS values used to build `PlatformConfig` for `Platform::MacOS`.

/// Command line (program followed by its arguments) run to query the default route.
const MACOS_GATEWAY_COMMAND: &[&str] = &["route", "-n", "get", "default"];
/// Pattern applied to the command's stdout; capture group 1 is a dotted-quad address
/// following `gateway:`.
const MACOS_GATEWAY_REGEX: &str = r"gateway:\s+(\d+\.\d+\.\d+\.\d+)";
/// Paths requested on the gateway host.
const MACOS_GATEWAY_ENDPOINTS: &[&str] = &["/"];
116
117impl PlatformConfig {
118    pub fn for_platform(platform: &Platform) -> Result<Self, ReauthfiError> {
119        match platform {
120            Platform::MacOS => Ok(PlatformConfig {
121                gateway_command: MACOS_GATEWAY_COMMAND,
122                gateway_regex: MACOS_GATEWAY_REGEX,
123                gateway_endpoints: MACOS_GATEWAY_ENDPOINTS,
124            }),
125            Platform::Unsupported => Err(ReauthfiError::UnsupportedPlatform),
126        }
127    }
128}
129
/// A way of looking for a captive portal.
///
/// Implemented in this file by `StandardUrlDetection` and `GatewayDetection`.
pub trait DetectionStrategy {
    /// Runs the strategy using the shared `ctx` and reports what it found.
    fn detect(&self, ctx: &DetectionContext) -> DetectionResult;
}
133
/// Identifies one of the available detection strategies.
///
/// `Debug`, `PartialEq` and `Eq` are derived so priority lists can be printed
/// and compared.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum StrategyKind {
    /// Probe the default gateway directly.
    Gateway,
    /// Probe the platform's well-known connectivity-check URLs.
    StandardUrl,
}
139
/// Strategy order that tries the gateway first; `run` uses it when `Options::gateway` is set.
pub const GATEWAY_PRIORITY: [StrategyKind; 2] = [StrategyKind::Gateway, StrategyKind::StandardUrl];
/// Strategy order that tries the standard URLs first; `run` uses it otherwise.
pub const STANDARD_PRIORITY: [StrategyKind; 2] = [StrategyKind::StandardUrl, StrategyKind::Gateway];
142
143pub struct DetectionContext<'a> {
144    pub platform: &'a Platform,
145    pub config: &'a PlatformConfig,
146    pub client: &'a Client,
147    pub options: &'a Options,
148}
149
150pub struct StandardUrlDetection;
151
152impl DetectionStrategy for StandardUrlDetection {
153    fn detect(&self, ctx: &DetectionContext) -> DetectionResult {
154        let endpoints = ctx.platform.detection_endpoints();
155        let mut saw_any_error = false;
156
157        for endpoint in endpoints {
158            if ctx.options.verbose {
159                println!(
160                    "  {} Checking {} ({})",
161                    "•".yellow(),
162                    endpoint.name,
163                    endpoint.url
164                );
165            }
166
167            match check_with_progress(endpoint.url, ctx.client, ctx.options.timeout) {
168                Ok(response) => {
169                    let status = response.status();
170
171                    if let Some(expected) = endpoint.expected_status {
172                        if status.as_u16() == expected {
173                            if ctx.options.verbose {
174                                println!("    {} Expected {} status", "✓".green(), expected);
175                            }
176                            continue; // move to next endpoint
177                        }
178                    }
179
180                    if let Some(portal_url) = redirect_location_url(&response) {
181                        if ctx.options.verbose {
182                            println!("    {} {} Redirect", "✓".green(), status.as_u16());
183                        }
184                        return DetectionResult::PortalFound(portal_url);
185                    }
186                }
187                Err(e) => {
188                    saw_any_error = true;
189                    if ctx.options.verbose {
190                        if e.is_timeout() {
191                            println!("    {} Timeout ({}s)", "⏱".yellow(), ctx.options.timeout);
192                        } else if e.is_connect() {
193                            println!("    {} Connection failed", "✗".red());
194                        } else {
195                            println!("    {} Failed: {}", "✗".red(), e);
196                        }
197                    }
198                }
199            }
200        }
201
202        // Determine result based on what happened
203        if saw_any_error {
204            DetectionResult::NetworkError
205        } else {
206            DetectionResult::NoPortalDetected
207        }
208    }
209}
210
211pub struct GatewayDetection;
212
213impl DetectionStrategy for GatewayDetection {
214    fn detect(&self, ctx: &DetectionContext) -> DetectionResult {
215        let gateway_ip = match get_gateway_ip(ctx.config) {
216            Ok(ip) => ip,
217            Err(_) => return DetectionResult::NetworkError,
218        };
219
220        if ctx.options.verbose {
221            println!("  {} Gateway IP: {}", "•".yellow(), gateway_ip);
222        }
223
224        for endpoint in ctx.config.gateway_endpoints {
225            let url = format!("http://{}{}", gateway_ip, endpoint);
226
227            if ctx.options.verbose {
228                println!("    {} Checking {}...", "•".yellow(), url);
229            }
230
231            match check_with_progress(&url, ctx.client, ctx.options.timeout) {
232                Ok(response) => {
233                    let status = response.status();
234
235                    if let Some(portal_url) = redirect_location_url(&response) {
236                        if ctx.options.verbose {
237                            println!("      {} {} Redirect", "✓".green(), status.as_u16());
238                        }
239                        return DetectionResult::PortalFound(portal_url);
240                    }
241
242                    if status.is_success() {
243                        if let Ok(html) = response.text() {
244                            if let Some(meta_url) = extract_meta_refresh(&html) {
245                                if ctx.options.verbose {
246                                    println!("      {} Found meta refresh", "✓".green());
247                                }
248                                return DetectionResult::PortalFound(meta_url);
249                            }
250                        }
251                    }
252                }
253                Err(e) => {
254                    if ctx.options.verbose {
255                        if e.is_timeout() {
256                            println!("      {} Timeout ({}s)", "⏱".yellow(), ctx.options.timeout);
257                        } else {
258                            println!("      {} Failed", "✗".red());
259                        }
260                    }
261                }
262            }
263        }
264
265        DetectionResult::NoPortalDetected
266    }
267}
268
269pub struct PortalOpenerService;
270
271impl PortalOpenerService {
272    pub fn open(url: &str) -> Result<(), ReauthfiError> {
273        #[cfg(target_os = "macos")]
274        {
275            let status = Command::new("open").arg(url).status()?;
276
277            if status.success() {
278                Ok(())
279            } else {
280                let detail = status
281                    .code()
282                    .map(|code| format!("exit code {}", code))
283                    .unwrap_or_else(|| "terminated by signal".to_string());
284                Err(ReauthfiError::CommandFailed(detail))
285            }
286        }
287
288        #[cfg(not(target_os = "macos"))]
289        {
290            let _ = url;
291            Err(ReauthfiError::UnsupportedPlatform)
292        }
293    }
294}
295
296pub fn build_client(timeout_secs: u64) -> Result<Client, ReauthfiError> {
297    let client = Client::builder()
298        .redirect(reqwest::redirect::Policy::none())
299        .timeout(Duration::from_secs(timeout_secs))
300        .build()?;
301    Ok(client)
302}
303
304pub fn print_network_not_ready(verbose: bool, detail: Option<&dyn fmt::Display>) {
305    println!(
306        "{} Network not ready - this may be a first-time Wi-Fi connection",
307        "❌".red().bold()
308    );
309    println!("  Close any macOS network popup windows and try again");
310    println!("  Or wait a few seconds for the network to stabilize");
311
312    if verbose {
313        if let Some(detail) = detail {
314            println!("  Detail: {}", detail);
315        }
316    }
317}
318
319fn print_progress(message: &str, elapsed: u64, total: u64) {
320    print!("\r  {} {} [", "•".yellow(), message);
321
322    let bar_slots = 20;
323    let safe_total = total.max(1);
324    let progress = (elapsed * bar_slots / safe_total).min(bar_slots);
325    for i in 0..bar_slots {
326        if i < progress {
327            print!("█");
328        } else {
329            print!("░");
330        }
331    }
332
333    print!("] {}s/{}s", elapsed, total);
334    io::stdout().flush().ok();
335}
336
337fn check_with_progress(
338    url: &str,
339    client: &Client,
340    timeout: u64,
341) -> Result<reqwest::blocking::Response, reqwest::Error> {
342    let start = Instant::now();
343    let done = Arc::new(AtomicBool::new(false));
344    let done_clone = done.clone();
345
346    let url_clone = url.to_string();
347    print_progress(&url_clone, 0, timeout);
348    io::stdout().flush().ok();
349
350    let handle = thread::spawn(move || {
351        while !done_clone.load(Ordering::Relaxed) {
352            let elapsed = start.elapsed().as_secs();
353            if elapsed <= timeout {
354                print_progress(&url_clone, elapsed, timeout);
355            }
356            thread::sleep(Duration::from_millis(500));
357        }
358        println!("");
359        io::stdout().flush().ok();
360    });
361
362    let result = client.get(url).send();
363
364    done.store(true, Ordering::Relaxed);
365    handle.join().ok();
366
367    result
368}
369
370fn get_gateway_ip(config: &PlatformConfig) -> Result<String, ReauthfiError> {
371    let output = Command::new(config.gateway_command[0])
372        .args(&config.gateway_command[1..])
373        .output()?;
374
375    let stdout = String::from_utf8_lossy(&output.stdout);
376    let re = Regex::new(config.gateway_regex).map_err(|_| ReauthfiError::NotFound)?;
377
378    re.captures(&stdout)
379        .and_then(|caps| caps.get(1))
380        .map(|m| m.as_str().to_string())
381        .ok_or(ReauthfiError::NotFound)
382}
383
384fn extract_meta_refresh(html: &str) -> Option<String> {
385    // Case-insensitive match for meta refresh with URL
386    let re = Regex::new(r#"(?i)content\s*=\s*["']?\d+\s*;\s*url\s*=\s*([^"'\s>]+)"#).ok()?;
387    re.captures(html)
388        .and_then(|caps| caps.get(1))
389        .map(|m| m.as_str())
390        .and_then(|url| {
391            if url.starts_with("http") {
392                Some(url.to_string())
393            } else {
394                None
395            }
396        })
397}
398
399fn redirect_location_url(response: &reqwest::blocking::Response) -> Option<String> {
400    if response.status().is_redirection() {
401        response
402            .headers()
403            .get("location")
404            .and_then(|v| v.to_str().ok())
405            .map(|s| s.to_string())
406    } else {
407        None
408    }
409}
410
/// User-tunable settings for a detection run.
///
/// `PartialEq`/`Eq` are derived so option sets can be compared in tests and
/// by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Options {
    /// Print per-step diagnostics.
    pub verbose: bool,
    /// Report the portal but do not open it in a browser.
    pub no_open: bool,
    /// Try the gateway strategy before the standard-URL strategy.
    pub gateway: bool,
    /// Per-request timeout in seconds.
    pub timeout: u64,
}

impl Default for Options {
    /// All flags off and a 10-second timeout.
    fn default() -> Self {
        Self {
            verbose: false,
            no_open: false,
            gateway: false,
            timeout: 10,
        }
    }
}
429
/// How a call to `run` ended when it did not return an error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStatus {
    /// Detection ran to completion: a portal was found (and opened unless
    /// `no_open` was set), or no portal was detected.
    Completed,
    /// The HTTP client could not be built or a strategy reported a network error.
    NetworkNotReady,
}
435
436pub fn run(options: &Options) -> Result<ExecutionStatus, ReauthfiError> {
437    let platform = Platform::detect();
438    let config = PlatformConfig::for_platform(&platform)?;
439
440    let client = match build_client(options.timeout) {
441        Ok(client) => client,
442        Err(e) => {
443            print_network_not_ready(options.verbose, Some(&e));
444            return Ok(ExecutionStatus::NetworkNotReady);
445        }
446    };
447
448    println!("{}", "🔍 Detecting Captive Portal...".cyan().bold());
449
450    let strategies: &[StrategyKind] = if options.gateway {
451        &GATEWAY_PRIORITY
452    } else {
453        &STANDARD_PRIORITY
454    };
455
456    let ctx = DetectionContext {
457        platform: &platform,
458        config: &config,
459        client: &client,
460        options,
461    };
462
463    for &strategy in strategies {
464        let detector: &dyn DetectionStrategy = match strategy {
465            StrategyKind::Gateway => &GatewayDetection,
466            StrategyKind::StandardUrl => &StandardUrlDetection,
467        };
468
469        match detector.detect(&ctx) {
470            DetectionResult::PortalFound(portal_url) => {
471                if options.verbose {
472                    println!("  {} Portal URL: {}", "→".green().bold(), portal_url);
473                }
474
475                if !options.no_open {
476                    println!("{}", "📱 Opening in browser...".cyan().bold());
477                    match PortalOpenerService::open(&portal_url) {
478                        Ok(_) => println!("{}", "✅ Done!".green().bold()),
479                        Err(e) => return Err(e),
480                    }
481                }
482                return Ok(ExecutionStatus::Completed);
483            }
484            DetectionResult::NetworkError => {
485                print_network_not_ready(options.verbose, None);
486                return Ok(ExecutionStatus::NetworkNotReady);
487            }
488            DetectionResult::NoPortalDetected => continue,
489        }
490    }
491
492    println!("{} No captive portal detected", "✅".green().bold());
493    Ok(ExecutionStatus::Completed)
494}