Skip to main content

nd_300/actions/fix/
vpn.rs

1use crate::config::Config;
2use crate::diagnostics::vpn::{self, VpnAdapter};
3use crate::render::progress::create_spinner;
4
5use super::cmd::{run_cmd, TIMEOUT_MEDIUM, TIMEOUT_SLOW};
6use super::{print_step_fail, print_step_ok, warn_icon};
7use crate::actions::{is_interactive, prompt_yes_no};
8
/// Info about a VPN that was disabled during the fix, for potential re-enable.
///
/// `Clone` so a copy can be stowed in the restore registry (wrapped in `Arc`)
/// while the original list is still passed to [`offer_reenable`].
#[derive(Debug, Clone)]
pub struct DisabledVpn {
    /// Adapter name as reported by VPN detection; shown in prompts/messages
    /// and emitted as `"name"` by [`vpn_json`].
    pub name: String,
    /// How the VPN was turned off, which determines how it is turned back on
    /// (see [`reenable_vpn`]) or off again (see [`redisable_vpn`]).
    pub method: DisableMethod,
}
18
/// The mechanism that was used to disconnect a VPN, carrying the data needed
/// to reverse it.
///
/// OS-specific variants only exist on their platform. `Netsh` is compiled on
/// every platform, but in this module it is only constructed by the Windows
/// fallback path.
#[derive(Debug, Clone)]
pub enum DisableMethod {
    /// Vendor CLI. The `Vec<String>` holds the *reconnect* arguments (the
    /// disconnect verb is swapped for its counterpart when the entry is
    /// recorded), so re-enabling runs `binary args...` directly.
    VendorCli(String, Vec<String>), // (binary, args)
    /// `netsh interface set interface <name> disabled|enabled`.
    Netsh(String),                  // adapter name (Windows)
    /// `scutil --nc stop|start <service>`.
    #[cfg(target_os = "macos")]
    Scutil(String), // VPN service name
    /// `nmcli connection down|up <connection>`.
    #[cfg(target_os = "linux")]
    Nmcli(String), // connection name
    /// `wg-quick down|up <interface>`.
    #[cfg(target_os = "linux")]
    WgQuick(String), // interface name
}
30
31/// Known enterprise VPN vendors — we never touch these automatically.
32fn is_enterprise_vpn(adapter: &VpnAdapter) -> bool {
33    let lower = adapter.name.to_lowercase();
34    let vendor_lower = adapter.vendor.as_deref().unwrap_or("").to_lowercase();
35    let type_lower = adapter.adapter_type.to_lowercase();
36
37    let enterprise_patterns = [
38        "cisco",
39        "anyconnect",
40        "globalprotect",
41        "palo alto",
42        "zscaler",
43        "forticlient",
44        "fortinet",
45        "pulse secure",
46        "juniper",
47        "f5 ",
48        "big-ip",
49        "checkpoint",
50        "corp",
51        "enterprise",
52        "mdm",
53        "company",
54    ];
55
56    enterprise_patterns
57        .iter()
58        .any(|p| lower.contains(p) || vendor_lower.contains(p) || type_lower.contains(p))
59}
60
61/// Try to find a vendor-specific CLI to disconnect this VPN.
62fn find_vendor_cli(adapter: &VpnAdapter) -> Option<(String, Vec<String>)> {
63    let lower = adapter.name.to_lowercase();
64    let vendor_lower = adapter.vendor.as_deref().unwrap_or("").to_lowercase();
65
66    // NordVPN
67    if lower.contains("nord") || vendor_lower.contains("nord") {
68        return Some(("nordvpn".to_string(), vec!["disconnect".to_string()]));
69    }
70    // ExpressVPN
71    if lower.contains("expressvpn") || vendor_lower.contains("expressvpn") {
72        return Some(("expressvpn".to_string(), vec!["disconnect".to_string()]));
73    }
74    // Mullvad
75    if lower.contains("mullvad") || vendor_lower.contains("mullvad") {
76        return Some(("mullvad".to_string(), vec!["disconnect".to_string()]));
77    }
78    // Tailscale
79    if lower.contains("tailscale") || vendor_lower.contains("tailscale") {
80        return Some(("tailscale".to_string(), vec!["down".to_string()]));
81    }
82    // WireGuard (wg-quick)
83    if adapter.adapter_type == "WireGuard" {
84        if let Some(ref iface) = adapter.interface_name {
85            return Some((
86                "wg-quick".to_string(),
87                vec!["down".to_string(), iface.clone()],
88            ));
89        }
90    }
91
92    // Cisco AnyConnect
93    if lower.contains("cisco")
94        || vendor_lower.contains("cisco")
95        || adapter.adapter_type.contains("Cisco")
96    {
97        #[cfg(windows)]
98        {
99            // Try common install paths
100            let paths = [
101                r"C:\Program Files (x86)\Cisco\Cisco AnyConnect Secure Mobility Client\vpncli.exe",
102                r"C:\Program Files\Cisco\Cisco AnyConnect Secure Mobility Client\vpncli.exe",
103            ];
104            for path in &paths {
105                if std::path::Path::new(path).exists() {
106                    return Some((path.to_string(), vec!["disconnect".to_string()]));
107                }
108            }
109        }
110        #[cfg(unix)]
111        {
112            return Some((
113                "/opt/cisco/anyconnect/bin/vpn".to_string(),
114                vec!["disconnect".to_string()],
115            ));
116        }
117    }
118
119    None
120}
121
122/// Detect connected VPNs and prompt user to disable them before fix stages.
123/// Returns a list of VPNs that were disabled (for later re-enable).
124pub async fn detect_and_disable(config: &Config) -> Vec<DisabledVpn> {
125    let mut disabled = Vec::new();
126
127    let spinner = create_spinner("Detecting VPN connections...");
128    let vpns = vpn::collect().await;
129    spinner.finish_and_clear();
130
131    let vpns = match vpns {
132        Some(v) => v,
133        None => return disabled,
134    };
135
136    let connected: Vec<&VpnAdapter> = vpns.iter().filter(|v| v.status == "Connected").collect();
137    if connected.is_empty() {
138        return disabled;
139    }
140
141    for adapter in connected {
142        if is_enterprise_vpn(adapter) {
143            if is_interactive(config) {
144                println!(
145                    "  {} Corporate VPN detected: {} — skipping (managed by your organization)",
146                    warn_icon(config),
147                    crate::render::color::cyan(&adapter.name, config),
148                );
149            }
150            continue;
151        }
152
153        let do_disable = if is_interactive(config) {
154            let prompt = format!(
155                "  VPN detected: {} ({}). VPN connections can interfere with network fixes. Disable? (y/N): ",
156                adapter.name, adapter.adapter_type,
157            );
158            prompt_yes_no(&prompt)
159        } else {
160            // Non-interactive mode cannot safely offer re-enable handling.
161            false
162        };
163
164        if !do_disable {
165            continue;
166        }
167
168        // Try vendor CLI first
169        if let Some((bin, args)) = find_vendor_cli(adapter) {
170            let spinner = create_spinner(&format!("Disabling {}...", adapter.name));
171            let mut cmd = tokio::process::Command::new(&bin);
172            cmd.args(&args);
173            let result = run_cmd(cmd, TIMEOUT_MEDIUM).await;
174            spinner.finish_and_clear();
175
176            if let Ok(output) = result {
177                if output.status.success() {
178                    if is_interactive(config) {
179                        print_step_ok(&format!("Disabled {}", adapter.name), config);
180                    }
181                    disabled.push(DisabledVpn {
182                        name: adapter.name.clone(),
183                        method: DisableMethod::VendorCli(
184                            bin,
185                            args.iter()
186                                .map(|a| a.replace("disconnect", "connect").replace("down", "up"))
187                                .collect(),
188                        ),
189                    });
190                    continue;
191                }
192            }
193        }
194
195        // Fallback: platform-specific adapter disable
196        let spinner = create_spinner(&format!("Disabling {}...", adapter.name));
197        let fallback_result = disable_adapter_fallback(adapter, config).await;
198        spinner.finish_and_clear();
199        match fallback_result {
200            Some(d) => disabled.push(d),
201            None => {
202                if is_interactive(config) {
203                    print_step_fail(
204                        &format!("Could not disable {}", adapter.name),
205                        "Try disconnecting manually before running fix",
206                        config,
207                    );
208                }
209            }
210        }
211    }
212
213    if !disabled.is_empty() {
214        // Give VPNs time to fully disconnect
215        let spinner = create_spinner("Waiting for VPN disconnect...");
216        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
217        spinner.finish_and_clear();
218    }
219
220    disabled
221}
222
223async fn disable_adapter_fallback(adapter: &VpnAdapter, config: &Config) -> Option<DisabledVpn> {
224    #[cfg(windows)]
225    {
226        // Try netsh to disable the adapter
227        if let Some(ref iface) = adapter.interface_name {
228            let mut cmd = tokio::process::Command::new("netsh");
229            cmd.args(["interface", "set", "interface", iface, "disabled"]);
230            if let Ok(output) = run_cmd(cmd, TIMEOUT_SLOW).await {
231                if output.status.success() {
232                    if is_interactive(config) {
233                        print_step_ok(&format!("Disabled {}", adapter.name), config);
234                    }
235                    return Some(DisabledVpn {
236                        name: adapter.name.clone(),
237                        method: DisableMethod::Netsh(iface.clone()),
238                    });
239                }
240            }
241        }
242        let _ = config;
243        None
244    }
245
246    #[cfg(target_os = "macos")]
247    {
248        // Try scutil --nc stop
249        let mut cmd = tokio::process::Command::new("scutil");
250        cmd.args(["--nc", "stop", &adapter.name]);
251        if let Ok(output) = run_cmd(cmd, TIMEOUT_MEDIUM).await {
252            if output.status.success() {
253                if is_interactive(config) {
254                    print_step_ok(&format!("Disabled {}", adapter.name), config);
255                }
256                return Some(DisabledVpn {
257                    name: adapter.name.clone(),
258                    method: DisableMethod::Scutil(adapter.name.clone()),
259                });
260            }
261        }
262        None
263    }
264
265    #[cfg(target_os = "linux")]
266    {
267        // Try nmcli connection down
268        let mut nmcli_cmd = tokio::process::Command::new("nmcli");
269        nmcli_cmd.args(["connection", "down", &adapter.name]);
270        if let Ok(output) = run_cmd(nmcli_cmd, TIMEOUT_MEDIUM).await {
271            if output.status.success() {
272                if is_interactive(config) {
273                    print_step_ok(&format!("Disabled {}", adapter.name), config);
274                }
275                return Some(DisabledVpn {
276                    name: adapter.name.clone(),
277                    method: DisableMethod::Nmcli(adapter.name.clone()),
278                });
279            }
280        }
281        // Try wg-quick down
282        if let Some(ref iface) = adapter.interface_name {
283            let mut wg_cmd = tokio::process::Command::new("wg-quick");
284            wg_cmd.args(["down", iface]);
285            if let Ok(output) = run_cmd(wg_cmd, TIMEOUT_MEDIUM).await {
286                if output.status.success() {
287                    if is_interactive(config) {
288                        print_step_ok(&format!("Disabled {}", adapter.name), config);
289                    }
290                    return Some(DisabledVpn {
291                        name: adapter.name.clone(),
292                        method: DisableMethod::WgQuick(iface.clone()),
293                    });
294                }
295            }
296        }
297        let _ = config;
298        None
299    }
300}
301
302/// Offer to re-enable VPNs that were disabled during the fix.
303pub async fn offer_reenable(disabled: &[DisabledVpn], config: &Config) {
304    if disabled.is_empty() {
305        return;
306    }
307
308    for vpn in disabled {
309        let do_reenable = if is_interactive(config) {
310            let prompt = format!("  Re-enable {}? (y/N): ", vpn.name);
311            prompt_yes_no(&prompt)
312        } else {
313            // JSON mode: skip re-enable
314            false
315        };
316
317        if !do_reenable {
318            continue;
319        }
320
321        let spinner = create_spinner(&format!("Re-enabling {}...", vpn.name));
322        let success = reenable_vpn(vpn).await;
323        spinner.finish_and_clear();
324
325        if success {
326            if is_interactive(config) {
327                print_step_ok(&format!("Re-enabled {}", vpn.name), config);
328            }
329            // Verify connectivity after re-enable
330            let spinner = create_spinner("Verifying connectivity...");
331            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
332            let connected = super::connectivity::check_connectivity().await;
333            spinner.finish_and_clear();
334
335            if !connected {
336                // Auto-disable again
337                let spinner = create_spinner(&format!("Disabling {} again...", vpn.name));
338                let _ = redisable_vpn(vpn).await;
339                spinner.finish_and_clear();
340
341                if is_interactive(config) {
342                    println!(
343                        "  {} Re-enabling {} broke connectivity. The VPN has been disabled again.",
344                        warn_icon(config),
345                        crate::render::color::cyan(&vpn.name, config),
346                    );
347                    println!(
348                        "    {}",
349                        crate::render::color::dim(
350                            "Check your VPN configuration or contact your VPN provider.",
351                            config
352                        ),
353                    );
354                }
355            }
356        } else if is_interactive(config) {
357            print_step_fail(
358                &format!("Failed to re-enable {}", vpn.name),
359                "Try reconnecting manually",
360                config,
361            );
362        }
363    }
364}
365
366pub(super) async fn reenable_vpn(vpn: &DisabledVpn) -> bool {
367    match &vpn.method {
368        DisableMethod::VendorCli(bin, reconnect_args) => {
369            let mut cmd = tokio::process::Command::new(bin);
370            cmd.args(reconnect_args);
371            if let Ok(output) = run_cmd(cmd, TIMEOUT_MEDIUM).await {
372                return output.status.success();
373            }
374            false
375        }
376        DisableMethod::Netsh(iface) => {
377            let mut cmd = tokio::process::Command::new("netsh");
378            cmd.args(["interface", "set", "interface", iface, "enabled"]);
379            if let Ok(output) = run_cmd(cmd, TIMEOUT_SLOW).await {
380                return output.status.success();
381            }
382            false
383        }
384        #[cfg(target_os = "macos")]
385        DisableMethod::Scutil(service) => {
386            let mut cmd = tokio::process::Command::new("scutil");
387            cmd.args(["--nc", "start", service]);
388            if let Ok(output) = run_cmd(cmd, TIMEOUT_MEDIUM).await {
389                return output.status.success();
390            }
391            false
392        }
393        #[cfg(target_os = "linux")]
394        DisableMethod::Nmcli(conn) => {
395            let mut cmd = tokio::process::Command::new("nmcli");
396            cmd.args(["connection", "up", conn]);
397            if let Ok(output) = run_cmd(cmd, TIMEOUT_MEDIUM).await {
398                return output.status.success();
399            }
400            false
401        }
402        #[cfg(target_os = "linux")]
403        DisableMethod::WgQuick(iface) => {
404            let mut cmd = tokio::process::Command::new("wg-quick");
405            cmd.args(["up", iface]);
406            if let Ok(output) = run_cmd(cmd, TIMEOUT_MEDIUM).await {
407                return output.status.success();
408            }
409            false
410        }
411    }
412}
413
414pub(super) async fn redisable_vpn(vpn: &DisabledVpn) -> bool {
415    match &vpn.method {
416        DisableMethod::VendorCli(bin, reconnect_args) => {
417            let disconnect_args: Vec<String> = reconnect_args
418                .iter()
419                .map(|a| a.replace("connect", "disconnect").replace("up", "down"))
420                .collect();
421            let mut cmd = tokio::process::Command::new(bin);
422            cmd.args(&disconnect_args);
423            if let Ok(output) = run_cmd(cmd, TIMEOUT_MEDIUM).await {
424                return output.status.success();
425            }
426            false
427        }
428        DisableMethod::Netsh(iface) => {
429            let mut cmd = tokio::process::Command::new("netsh");
430            cmd.args(["interface", "set", "interface", iface, "disabled"]);
431            if let Ok(output) = run_cmd(cmd, TIMEOUT_SLOW).await {
432                return output.status.success();
433            }
434            false
435        }
436        #[cfg(target_os = "macos")]
437        DisableMethod::Scutil(service) => {
438            let mut cmd = tokio::process::Command::new("scutil");
439            cmd.args(["--nc", "stop", service]);
440            if let Ok(output) = run_cmd(cmd, TIMEOUT_MEDIUM).await {
441                return output.status.success();
442            }
443            false
444        }
445        #[cfg(target_os = "linux")]
446        DisableMethod::Nmcli(conn) => {
447            let mut cmd = tokio::process::Command::new("nmcli");
448            cmd.args(["connection", "down", conn]);
449            if let Ok(output) = run_cmd(cmd, TIMEOUT_MEDIUM).await {
450                return output.status.success();
451            }
452            false
453        }
454        #[cfg(target_os = "linux")]
455        DisableMethod::WgQuick(iface) => {
456            let mut cmd = tokio::process::Command::new("wg-quick");
457            cmd.args(["down", iface]);
458            if let Ok(output) = run_cmd(cmd, TIMEOUT_MEDIUM).await {
459                return output.status.success();
460            }
461            false
462        }
463    }
464}
465
466/// Serialize VPN state for JSON output.
467pub fn vpn_json(disabled: &[DisabledVpn]) -> serde_json::Value {
468    if disabled.is_empty() {
469        return serde_json::json!(null);
470    }
471    let items: Vec<serde_json::Value> = disabled
472        .iter()
473        .map(|v| {
474            serde_json::json!({
475                "name": v.name,
476                "disabled": true,
477            })
478        })
479        .collect();
480    serde_json::json!(items)
481}