Skip to main content

aster/diagnostics/
network.rs

1//! 网络诊断检查
2//!
3//! 提供网络连接、API 可达性、代理配置等检查
4
5use super::checker::DiagnosticCheck;
6use std::time::Duration;
7
/// Network diagnostics checker.
///
/// A field-less unit struct: it carries no state and only groups the network-related
/// diagnostic checks, which are exposed as associated functions (no `self` receiver).
pub struct NetworkChecker;
10
11impl NetworkChecker {
12    /// 检查 API 连通性
13    pub async fn check_api_connectivity() -> DiagnosticCheck {
14        let endpoints = [
15            ("Anthropic API", "https://api.anthropic.com"),
16            ("OpenAI API", "https://api.openai.com"),
17        ];
18
19        let client = match reqwest::Client::builder()
20            .timeout(Duration::from_secs(5))
21            .build()
22        {
23            Ok(c) => c,
24            Err(e) => {
25                return DiagnosticCheck::fail("API 连通性", "无法创建 HTTP 客户端")
26                    .with_details(e.to_string());
27            }
28        };
29
30        let mut reachable = Vec::new();
31        let mut unreachable = Vec::new();
32
33        for (name, url) in endpoints {
34            match client.head(url).send().await {
35                Ok(resp) if resp.status().is_success() || resp.status().as_u16() == 405 => {
36                    reachable.push(name);
37                }
38                _ => {
39                    unreachable.push(name);
40                }
41            }
42        }
43
44        if unreachable.is_empty() {
45            DiagnosticCheck::pass("API 连通性", format!("可达: {}", reachable.join(", ")))
46        } else if !reachable.is_empty() {
47            DiagnosticCheck::warn(
48                "API 连通性",
49                format!("部分不可达: {}", unreachable.join(", ")),
50            )
51        } else {
52            DiagnosticCheck::fail("API 连通性", "所有 API 端点不可达")
53        }
54    }
55
56    /// 检查网络连接
57    pub async fn check_network_connectivity() -> DiagnosticCheck {
58        let endpoints = [
59            ("Internet", "https://www.google.com"),
60            ("GitHub", "https://github.com"),
61        ];
62
63        let client = match reqwest::Client::builder()
64            .timeout(Duration::from_secs(3))
65            .build()
66        {
67            Ok(c) => c,
68            Err(_) => {
69                return DiagnosticCheck::warn("网络连接", "无法创建 HTTP 客户端");
70            }
71        };
72
73        let mut results = Vec::new();
74        let mut failures = Vec::new();
75
76        for (name, url) in endpoints {
77            match client.head(url).send().await {
78                Ok(_) => results.push(name),
79                Err(_) => failures.push(name),
80            }
81        }
82
83        if failures.is_empty() {
84            DiagnosticCheck::pass("网络连接", "网络连接正常")
85        } else if !results.is_empty() {
86            DiagnosticCheck::warn(
87                "网络连接",
88                format!("部分端点不可达: {}", failures.join(", ")),
89            )
90        } else {
91            DiagnosticCheck::fail("网络连接", "无网络连接")
92        }
93    }
94
95    /// 检查代理配置
96    pub fn check_proxy_configuration() -> DiagnosticCheck {
97        let proxy_vars = [
98            "HTTP_PROXY",
99            "HTTPS_PROXY",
100            "http_proxy",
101            "https_proxy",
102            "NO_PROXY",
103            "no_proxy",
104        ];
105
106        let set_proxies: Vec<_> = proxy_vars
107            .iter()
108            .filter(|v| std::env::var(v).is_ok())
109            .collect();
110
111        if set_proxies.is_empty() {
112            DiagnosticCheck::pass("代理配置", "未配置代理")
113        } else {
114            let details: Vec<String> = set_proxies
115                .iter()
116                .map(|v| {
117                    let value = std::env::var(v).unwrap_or_default();
118                    // 隐藏凭证
119                    let masked = if value.contains('@') {
120                        value
121                            .rsplit('@')
122                            .next()
123                            .map(|s| format!("***@{}", s))
124                            .unwrap_or_else(|| "***".to_string())
125                    } else {
126                        value
127                    };
128                    format!("{}={}", v, masked)
129                })
130                .collect();
131
132            DiagnosticCheck::pass(
133                "代理配置",
134                format!("已配置 {} 个代理变量", set_proxies.len()),
135            )
136            .with_details(details.join(", "))
137        }
138    }
139
140    /// 检查 SSL 证书配置
141    pub fn check_ssl_certificates() -> DiagnosticCheck {
142        // 检查是否禁用了 SSL 验证
143        if std::env::var("SSL_CERT_FILE").is_ok() || std::env::var("SSL_CERT_DIR").is_ok() {
144            return DiagnosticCheck::pass("SSL 证书", "使用自定义 CA 证书");
145        }
146
147        // 检查是否有不安全的配置
148        if std::env::var("RUSTLS_DANGEROUS_CONFIGURATION").is_ok() {
149            return DiagnosticCheck::warn("SSL 证书", "SSL 验证可能被禁用")
150                .with_details("RUSTLS_DANGEROUS_CONFIGURATION 已设置")
151                .with_fix("移除不安全的 SSL 配置");
152        }
153
154        DiagnosticCheck::pass("SSL 证书", "使用系统 SSL 证书")
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diagnostics::checker::CheckStatus;

    /// The proxy check must yield a usable result whatever the host environment is.
    #[test]
    fn test_check_proxy_configuration() {
        let result = NetworkChecker::check_proxy_configuration();
        assert!(result.status == CheckStatus::Pass || result.status == CheckStatus::Warn);
    }

    /// The SSL check should normally pass, or warn on an insecure setup.
    #[test]
    fn test_check_ssl_certificates() {
        let result = NetworkChecker::check_ssl_certificates();
        assert!(result.status == CheckStatus::Pass || result.status == CheckStatus::Warn);
    }

    /// The network may be unavailable where tests run, so any status is accepted.
    #[tokio::test]
    async fn test_check_api_connectivity() {
        let result = NetworkChecker::check_api_connectivity().await;
        assert!(!result.name.is_empty());
    }

    /// The network may be unavailable where tests run, so any status is accepted.
    #[tokio::test]
    async fn test_check_network_connectivity() {
        let result = NetworkChecker::check_network_connectivity().await;
        assert!(!result.name.is_empty());
    }

    /// Credentials embedded in a proxy URL must never appear in the check details.
    #[test]
    fn test_proxy_credential_masking() {
        // Use a variable the checker really reads; an unrelated name would make this
        // test pass without exercising the masking at all.
        const VAR: &str = "HTTP_PROXY";
        let previous = std::env::var_os(VAR);
        std::env::set_var(
            VAR,
            "http://s3cr3t-user:s3cr3t-pw@proxy.example.com:8080",
        );

        let result = NetworkChecker::check_proxy_configuration();

        // Restore the environment before asserting so that a failing assertion does
        // not leave the fake proxy behind for other tests.
        match previous {
            Some(value) => std::env::set_var(VAR, value),
            None => std::env::remove_var(VAR),
        }

        let details = result
            .details
            .as_ref()
            .expect("a configured proxy variable must produce details");
        assert!(details.contains("proxy.example.com:8080"));
        assert!(!details.contains("s3cr3t"));
    }
}