use crate::common;
use crate::error::Error;
use crate::test_config::TestConfig;
use reqwest::Client;
use rustls::client::WebPkiServerVerifier;
use rustls::client::danger::ServerCertVerifier;
use rustls::crypto::CryptoProvider;
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, SignatureScheme};
use std::sync::Arc;
/// TLS options consulted when the HTTP client is built.
///
/// All fields default to "unset" (`None` / `false`).
#[derive(Debug, Clone, Default)]
pub struct TlsConfig {
    /// Path to a CA certificate file whose contents replace the bundled root store.
    pub ca_cert_path: Option<std::path::PathBuf>,
    /// Requested TLS protocol version as text; `"1.2"` and `"1.3"` are recognised.
    pub min_tls_version: Option<String>,
    /// When `true`, connections are restricted to speedtest.net / ookla.com hosts.
    pub pin_speedtest_certs: bool,
}

impl TlsConfig {
    /// Returns the configuration with `ca_cert_path` set to `path`.
    #[must_use]
    pub fn with_ca_cert(self, path: std::path::PathBuf) -> Self {
        Self {
            ca_cert_path: Some(path),
            ..self
        }
    }

    /// Returns the configuration with `min_tls_version` set to `version`.
    #[must_use]
    pub fn with_min_tls_version(self, version: impl Into<String>) -> Self {
        Self {
            min_tls_version: Some(version.into()),
            ..self
        }
    }

    /// Returns the configuration with host pinning switched on.
    #[must_use]
    pub fn with_cert_pinning(self) -> Self {
        Self {
            pin_speedtest_certs: true,
            ..self
        }
    }
}
/// User-Agent string used by [`Settings`] when no custom value is supplied.
pub const DEFAULT_USER_AGENT: &str = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36";
/// Options used to build the HTTP client.
#[derive(Debug, Clone)]
pub struct Settings {
    /// Request timeout, in seconds.
    pub timeout_secs: u64,
    /// Local address, as text, that outgoing connections are bound to; `None` leaves it unset.
    pub source_ip: Option<String>,
    /// Value sent as the User-Agent header.
    pub user_agent: String,
    /// Retry flag; set to `true` by both `Default` and the `From<&Config>` conversion.
    pub retry_enabled: bool,
    /// TLS-specific options.
    pub tls: TlsConfig,
}
impl From<&crate::config::Config> for Settings {
    /// Copies the network-related options out of the application config.
    ///
    /// The user agent falls back to [`DEFAULT_USER_AGENT`] when the config has
    /// no custom value, and `retry_enabled` is always `true`.
    fn from(config: &crate::config::Config) -> Self {
        let timeout_secs = config.timeout();
        let source_ip = config.source().map(String::from);
        let user_agent = match config.custom_user_agent() {
            Some(custom) => String::from(custom),
            None => DEFAULT_USER_AGENT.to_string(),
        };
        let tls = TlsConfig {
            ca_cert_path: config.ca_cert_path(),
            min_tls_version: config.tls_version().map(String::from),
            pin_speedtest_certs: config.pin_certs(),
        };

        Self {
            timeout_secs,
            source_ip,
            user_agent,
            retry_enabled: true,
            tls,
        }
    }
}
impl Default for Settings {
    /// Ten-second timeout, no source address, the default user agent,
    /// retries enabled and default TLS options.
    fn default() -> Self {
        let user_agent = String::from(DEFAULT_USER_AGENT);
        let tls = TlsConfig::default();
        Self {
            timeout_secs: 10,
            source_ip: None,
            user_agent,
            retry_enabled: true,
            tls,
        }
    }
}
impl Settings {
    /// Returns the settings with the User-Agent replaced by `user_agent`.
    #[must_use]
    pub fn with_user_agent(self, user_agent: impl Into<String>) -> Self {
        Self {
            user_agent: user_agent.into(),
            ..self
        }
    }

    /// Returns the settings with `retry_enabled` cleared.
    #[must_use]
    pub fn with_retry_disabled(self) -> Self {
        Self {
            retry_enabled: false,
            ..self
        }
    }
}
pub fn create_client(settings: &Settings) -> Result<Client, Error> {
let mut builder = Client::builder()
.timeout(std::time::Duration::from_secs(settings.timeout_secs))
.http1_only()
.no_gzip()
.use_rustls_tls()
.user_agent(&settings.user_agent);
if let Some(ref source_ip) = settings.source_ip {
let addr: std::net::SocketAddr = source_ip
.parse()
.map_err(|e| Error::with_source("Invalid source IP", e))?;
builder = builder.local_address(addr.ip());
}
if settings.tls.ca_cert_path.is_some()
|| settings.tls.min_tls_version.is_some()
|| settings.tls.pin_speedtest_certs
{
let tls_config = build_tls_config(&settings.tls)?;
builder = builder.use_preconfigured_tls(tls_config);
}
let client = builder.build().map_err(Error::NetworkError)?;
Ok(client)
}
/// Builds the rustls [`ClientConfig`] described by `tls`.
///
/// * `min_tls_version` is a *minimum*: `"1.2"` enables TLS 1.2 and TLS 1.3,
///   `"1.3"` enables TLS 1.3 only. Any other value prints a warning and falls
///   back to the rustls default versions.
/// * Trust anchors come from `ca_cert_path` when set, otherwise from the
///   bundled `webpki_roots` store.
/// * With `pin_speedtest_certs` the trust anchors are wrapped in a
///   [`PinningVerifier`], which also restricts the allowed host names.
///
/// # Errors
///
/// Returns an error when rustls rejects the protocol versions, the CA file
/// cannot be loaded, or the pinning verifier cannot be built.
fn build_tls_config(tls: &TlsConfig) -> Result<ClientConfig, Error> {
    let versions: &[&rustls::SupportedProtocolVersion] = match tls.min_tls_version.as_deref() {
        // A floor of 1.2 must not disable 1.3; listing only TLS12 here would
        // pin the connection to the older protocol.
        Some("1.2") => &[&rustls::version::TLS13, &rustls::version::TLS12],
        Some("1.3") => &[&rustls::version::TLS13],
        Some(v) => {
            eprintln!("Warning: Unknown TLS version '{}', using defaults", v);
            rustls::DEFAULT_VERSIONS
        }
        None => rustls::DEFAULT_VERSIONS,
    };

    if tls.pin_speedtest_certs && tls.ca_cert_path.is_some() {
        eprintln!(
            "Warning: Both --ca-cert and --pin-certs are set. Custom CA verification will be used before the speedtest.net domain restriction."
        );
    }

    // The same provider instance backs both the config and the verifier.
    let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
    let builder = ClientConfig::builder_with_provider(Arc::clone(&provider))
        .with_protocol_versions(versions)
        .map_err(|e| Error::context(format!("Invalid TLS configuration: {e}")))?;

    let root_store = match tls.ca_cert_path.as_deref() {
        Some(ca_path) => load_custom_ca_cert(ca_path)?,
        None => default_root_store(),
    };

    if tls.pin_speedtest_certs {
        let verifier = PinningVerifier::try_new(root_store, provider)?;
        return Ok(builder
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(verifier))
            .with_no_client_auth());
    }

    Ok(builder
        .with_root_certificates(root_store)
        .with_no_client_auth())
}
/// Returns a root store filled with the trust anchors bundled in `webpki_roots`.
fn default_root_store() -> RootCertStore {
    let bundled = webpki_roots::TLS_SERVER_ROOTS.iter().cloned();
    let mut anchors = RootCertStore::empty();
    anchors.extend(bundled);
    anchors
}
fn load_custom_ca_cert(path: &std::path::Path) -> Result<RootCertStore, Error> {
let pem_data = std::fs::read(path)
.map_err(|e| Error::context(format!("Failed to read CA cert: {}", e)))?;
let mut store = RootCertStore::empty();
let mut cursor = std::io::Cursor::new(&pem_data);
let mut found_cert = false;
for cert_result in rustls_pemfile::certs(&mut cursor) {
match cert_result {
Ok(cert) => {
store
.add(cert)
.map_err(|e| Error::context(format!("Failed to add cert: {}", e)))?;
found_cert = true;
}
Err(e) => {
eprintln!("Warning: Failed to parse PEM cert: {}", e);
}
}
}
if !found_cert {
store
.add(CertificateDer::from(pem_data))
.map_err(|e| Error::context(format!("Failed to parse cert: {}", e)))?;
}
Ok(store)
}
/// Certificate verifier that rejects non-allow-listed host names and
/// otherwise delegates to the wrapped WebPKI verifier.
#[derive(Debug)]
struct PinningVerifier {
    /// Verifier that handles certificate checks and signature verification.
    inner: Arc<WebPkiServerVerifier>,
}
impl PinningVerifier {
#[cfg(test)]
fn new() -> Self {
Self::try_new(
default_root_store(),
Arc::new(rustls::crypto::aws_lc_rs::default_provider()),
)
.expect("default TLS verifier should build")
}
fn try_new(root_store: RootCertStore, provider: Arc<CryptoProvider>) -> Result<Self, Error> {
let inner = WebPkiServerVerifier::builder_with_provider(Arc::new(root_store), provider)
.build()
.map_err(|e| Error::context(format!("Failed to build TLS verifier: {e:?}")))?;
Ok(Self { inner })
}
fn is_valid_domain(host: &str) -> bool {
host == "speedtest.net"
|| host == "ookla.com"
|| host.ends_with(".speedtest.net")
|| host.ends_with(".ookla.com")
}
}
impl ServerCertVerifier for PinningVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediate_certs: &[CertificateDer<'_>],
server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
let hostname = match server_name {
ServerName::DnsName(name) => name.as_ref(),
_ => {
return Err(rustls::Error::General(
"Unsupported server name type".to_string(),
));
}
};
if !Self::is_valid_domain(hostname) {
return Err(rustls::Error::General(format!(
"'{}' is not a speedtest.net domain",
hostname
)));
}
self.inner.verify_server_cert(
end_entity,
_intermediate_certs,
server_name,
_ocsp_response,
_now,
)
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
/// Reports whether `e` is a timeout, a connection failure, or carries an
/// HTTP status code of 500 or above.
fn is_transient_error(e: &reqwest::Error) -> bool {
    e.is_timeout()
        || e.is_connect()
        || matches!(e.status(), Some(status) if status.as_u16() >= 500)
}
pub async fn with_retry<R, F, Fut>(mut request: F) -> Result<R, Error>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<R, reqwest::Error>>,
{
let config = TestConfig::default();
let max_attempts = config.http_retry_attempts;
for attempt in 0..max_attempts {
let result = request().await;
if let Ok(r) = result {
return Ok(r);
}
if let Err(e) = &result {
let (delay, should_retry) = TestConfig::retry_delay(attempt);
#[allow(clippy::collapsible_if)]
if should_retry && is_transient_error(e) && attempt < max_attempts - 1 {
tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
continue;
}
return result.map_err(Error::NetworkError);
}
}
Err(Error::context("retry loop ended without result or error"))
}
pub async fn discover_client_ip(client: &Client) -> Result<String, Error> {
if let Ok(response) = client
.get("https://www.speedtest.net/api/ip.php")
.send()
.await
{
if let Ok(text) = response.text().await {
let trimmed = text.trim().to_string();
if common::is_valid_ipv4(&trimmed) {
return Ok(trimmed);
}
}
}
if let Ok(response) = client
.get("https://www.speedtest.net/api/ios-config.php")
.send()
.await
{
if let Ok(text) = response.text().await {
if let Some(ip) = parse_ip_from_xml(&text) {
return Ok(ip);
}
}
}
Ok("unknown".to_string())
}
/// Extracts the `ip` attribute of the `<client>` element from `xml`.
///
/// Returns `None` when the document cannot be deserialized, the attribute is
/// missing, or its value is not a valid IPv4 address.
fn parse_ip_from_xml(xml: &str) -> Option<String> {
    #[derive(serde::Deserialize)]
    struct Settings {
        client: ClientElement,
    }

    #[derive(serde::Deserialize)]
    struct ClientElement {
        #[serde(rename = "@ip")]
        ip: Option<String>,
    }

    let parsed: Settings = quick_xml::de::from_str(xml).ok()?;
    parsed
        .client
        .ip
        .filter(|candidate| common::is_valid_ipv4(candidate))
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
#[test]
fn test_tls_config_with_ca_cert() {
let config = TlsConfig::default();
assert!(config.ca_cert_path.is_none());
let config = config.with_ca_cert(std::path::PathBuf::from("/path/to/cert.pem"));
assert_eq!(
config.ca_cert_path,
Some(std::path::PathBuf::from("/path/to/cert.pem"))
);
}
#[test]
fn test_tls_config_with_min_tls_version() {
let config = TlsConfig::default();
assert!(config.min_tls_version.is_none());
let config = config.with_min_tls_version("1.2");
assert_eq!(config.min_tls_version, Some("1.2".to_string()));
let config = TlsConfig::default().with_min_tls_version("1.3");
assert_eq!(config.min_tls_version, Some("1.3".to_string()));
}
#[test]
fn test_tls_config_with_cert_pinning() {
let config = TlsConfig::default();
assert!(!config.pin_speedtest_certs);
let config = config.with_cert_pinning();
assert!(config.pin_speedtest_certs);
}
#[test]
fn test_tls_config_builder_chaining() {
let config = TlsConfig::default()
.with_ca_cert(std::path::PathBuf::from("/custom/ca.pem"))
.with_min_tls_version("1.3")
.with_cert_pinning();
assert_eq!(
config.ca_cert_path,
Some(std::path::PathBuf::from("/custom/ca.pem"))
);
assert_eq!(config.min_tls_version, Some("1.3".to_string()));
assert!(config.pin_speedtest_certs);
}
#[test]
fn test_settings_default_values() {
let settings = Settings::default();
assert_eq!(settings.timeout_secs, 10);
assert!(settings.source_ip.is_none());
assert_eq!(settings.user_agent, DEFAULT_USER_AGENT);
assert!(settings.retry_enabled);
assert!(settings.tls.ca_cert_path.is_none());
assert!(settings.tls.min_tls_version.is_none());
assert!(!settings.tls.pin_speedtest_certs);
}
#[test]
fn test_settings_with_user_agent() {
let settings = Settings::default().with_user_agent("Custom Agent/1.0");
assert_eq!(settings.user_agent, "Custom Agent/1.0");
}
#[test]
fn test_settings_with_user_agent_chaining() {
let settings = Settings::default()
.with_user_agent("Test Agent")
.with_retry_disabled();
assert_eq!(settings.user_agent, "Test Agent");
assert!(!settings.retry_enabled);
}
#[test]
fn test_settings_with_retry_disabled() {
let settings = Settings::default();
assert!(settings.retry_enabled);
let settings = settings.with_retry_disabled();
assert!(!settings.retry_enabled);
}
#[test]
fn test_settings_debug_trait() {
let settings = Settings::default();
let debug_str = format!("{:?}", settings);
assert!(debug_str.contains("timeout_secs"));
assert!(debug_str.contains("user_agent"));
}
#[test]
fn test_settings_clone() {
let settings = Settings::default();
let cloned = settings.clone();
assert_eq!(settings.user_agent, cloned.user_agent);
assert_eq!(settings.timeout_secs, cloned.timeout_secs);
}
#[test]
#[ignore]
fn test_build_tls_config_unknown_tls_version() {
let tls = TlsConfig {
min_tls_version: Some("99.0".to_string()),
..Default::default()
};
let result = build_tls_config(&tls);
assert!(result.is_ok());
}
#[test]
#[ignore]
fn test_build_tls_config_tls12() {
let tls = TlsConfig {
min_tls_version: Some("1.2".to_string()),
..Default::default()
};
let result = build_tls_config(&tls);
assert!(result.is_ok());
}
#[test]
#[ignore]
fn test_build_tls_config_tls13() {
let tls = TlsConfig {
min_tls_version: Some("1.3".to_string()),
..Default::default()
};
let result = build_tls_config(&tls);
assert!(result.is_ok());
}
#[test]
#[ignore]
fn test_build_tls_config_pinning_takes_precedence() {
let tls = TlsConfig {
ca_cert_path: Some(std::path::PathBuf::from("/path/to/ca.pem")),
pin_speedtest_certs: true,
..Default::default()
};
let result = build_tls_config(&tls);
assert!(result.is_ok());
}
#[test]
#[ignore]
fn test_build_tls_config_pinning_only() {
let tls = TlsConfig {
pin_speedtest_certs: true,
..Default::default()
};
let result = build_tls_config(&tls);
assert!(result.is_ok());
}
#[test]
#[ignore]
fn test_build_tls_config_no_options() {
let tls = TlsConfig::default();
let result = build_tls_config(&tls);
assert!(result.is_ok());
}
#[test]
fn test_load_custom_ca_cert_file_not_found() {
let result = load_custom_ca_cert(std::path::Path::new("/nonexistent/cert.pem"));
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = format!("{:?}", err);
assert!(err_msg.contains("nonexistent") || err_msg.contains("Failed to read CA cert"));
}
#[test]
fn test_load_custom_ca_cert_invalid_path() {
let result = load_custom_ca_cert(std::path::Path::new("/tmp"));
assert!(result.is_err());
}
#[test]
fn test_create_client_source_ip_v4() {
let settings = Settings {
source_ip: Some("192.168.1.100".to_string()),
..Default::default()
};
let result = create_client(&settings);
match result {
Ok(_) => {}
Err(Error::Context { .. }) => {} Err(e) => panic!("Unexpected error type for valid IPv4: {e:?}"),
}
}
#[test]
fn test_create_client_source_ip_v6() {
let settings = Settings {
source_ip: Some("::1".to_string()),
..Default::default()
};
let result = create_client(&settings);
match result {
Ok(_) => {}
Err(Error::NetworkError(_) | Error::Context { .. }) => {} Err(e) => panic!("Unexpected error type: {e:?}"),
}
}
#[test]
#[ignore]
fn test_create_client_with_ca_cert() {
let settings = Settings {
tls: TlsConfig {
ca_cert_path: Some(std::path::PathBuf::from("/nonexistent/ca.pem")),
..Default::default()
},
..Default::default()
};
let result = create_client(&settings);
assert!(result.is_err());
}
#[test]
#[ignore]
fn test_create_client_with_pinning() {
let settings = Settings {
tls: TlsConfig {
pin_speedtest_certs: true,
..Default::default()
},
..Default::default()
};
let result = create_client(&settings);
assert!(result.is_ok());
}
#[test]
fn test_create_client_with_retry_disabled() {
let settings = Settings::default().with_retry_disabled();
let result = create_client(&settings);
assert!(result.is_ok());
}
#[test]
fn test_create_client_timeout_30() {
let settings = Settings {
timeout_secs: 30,
..Default::default()
};
let result = create_client(&settings);
assert!(result.is_ok());
}
#[test]
fn test_create_client_timeout_60() {
let settings = Settings {
timeout_secs: 60,
..Default::default()
};
let result = create_client(&settings);
assert!(result.is_ok());
}
#[test]
fn test_pinning_verifier_is_valid_domain_speedtest() {
assert!(PinningVerifier::is_valid_domain("speedtest.net"));
assert!(PinningVerifier::is_valid_domain("www.speedtest.net"));
assert!(PinningVerifier::is_valid_domain("api.speedtest.net"));
assert!(PinningVerifier::is_valid_domain("foo.bar.speedtest.net"));
assert!(PinningVerifier::is_valid_domain("fake.speedtest.net"));
assert!(!PinningVerifier::is_valid_domain("evilsite.net"));
assert!(!PinningVerifier::is_valid_domain("speedtest.com"));
assert!(!PinningVerifier::is_valid_domain("notspeedtest.net"));
}
#[test]
fn test_pinning_verifier_is_valid_domain_ookla() {
assert!(PinningVerifier::is_valid_domain("ookla.com"));
assert!(PinningVerifier::is_valid_domain("www.ookla.com"));
assert!(PinningVerifier::is_valid_domain("api.ookla.com"));
assert!(PinningVerifier::is_valid_domain("foo.bar.ookla.com"));
assert!(PinningVerifier::is_valid_domain("fake.ookla.com"));
assert!(!PinningVerifier::is_valid_domain("ookla.net"));
}
#[test]
fn test_pinning_verifier_edge_cases() {
assert!(!PinningVerifier::is_valid_domain(""));
assert!(!PinningVerifier::is_valid_domain("speedtestXnet")); assert!(!PinningVerifier::is_valid_domain("attack.com")); }
#[test]
fn test_pinning_verifier_exact_domains() {
assert!(PinningVerifier::is_valid_domain("speedtest.net"));
assert!(PinningVerifier::is_valid_domain("ookla.com"));
}
#[test]
fn test_pinning_verifier_subdomains() {
assert!(PinningVerifier::is_valid_domain("www.speedtest.net"));
assert!(PinningVerifier::is_valid_domain("api.speedtest.net"));
assert!(PinningVerifier::is_valid_domain("a.b.c.speedtest.net"));
assert!(PinningVerifier::is_valid_domain("www.ookla.com"));
assert!(PinningVerifier::is_valid_domain("api.www.ookla.com"));
}
#[test]
fn test_pinning_verifier_invalid_suffixes() {
assert!(!PinningVerifier::is_valid_domain("xspeedtest.net")); assert!(!PinningVerifier::is_valid_domain("fake-speedtest.net")); assert!(!PinningVerifier::is_valid_domain("speedtest.net.evil.com")); assert!(!PinningVerifier::is_valid_domain("ookla.com.evil.com")); assert!(!PinningVerifier::is_valid_domain("fooookla.com")); }
#[test]
fn test_pinning_verifier_case_sensitivity() {
assert!(!PinningVerifier::is_valid_domain("Speedtest.net")); assert!(!PinningVerifier::is_valid_domain("SPEEDTEST.NET")); assert!(!PinningVerifier::is_valid_domain("www.Speedtest.net")); assert!(!PinningVerifier::is_valid_domain("OOKLA.COM")); }
#[test]
fn test_pinning_verifier_special_characters() {
assert!(!PinningVerifier::is_valid_domain("speedtest.net/")); assert!(!PinningVerifier::is_valid_domain("speedtest.net:443")); assert!(!PinningVerifier::is_valid_domain("speedtest.net/path")); }
#[test]
fn test_pinning_verifier_numeric_domains() {
assert!(PinningVerifier::is_valid_domain("123.speedtest.net")); assert!(PinningVerifier::is_valid_domain("1.2.3.speedtest.net")); assert!(!PinningVerifier::is_valid_domain("speedtest123.net")); assert!(!PinningVerifier::is_valid_domain("123speedtest.net")); }
#[test]
fn test_pinning_verifier_new_returns_self() {
let verifier = PinningVerifier::new();
assert!(!verifier.supported_verify_schemes().is_empty());
}
#[test]
fn test_pinning_verifier_debug_trait() {
let verifier = PinningVerifier::new();
let debug_str = format!("{:?}", verifier);
assert!(debug_str.contains("PinningVerifier"));
}
#[test]
fn test_pinning_verifier_supported_verify_schemes() {
let verifier = PinningVerifier::new();
let schemes = verifier.supported_verify_schemes();
assert!(schemes.contains(&SignatureScheme::RSA_PKCS1_SHA256));
assert!(schemes.contains(&SignatureScheme::RSA_PKCS1_SHA384));
assert!(schemes.contains(&SignatureScheme::RSA_PKCS1_SHA512));
assert!(schemes.contains(&SignatureScheme::ECDSA_NISTP256_SHA256));
assert!(schemes.contains(&SignatureScheme::ECDSA_NISTP384_SHA384));
assert!(schemes.contains(&SignatureScheme::RSA_PSS_SHA256));
assert!(schemes.contains(&SignatureScheme::RSA_PSS_SHA384));
assert!(schemes.contains(&SignatureScheme::RSA_PSS_SHA512));
assert!(schemes.len() >= 8);
}
#[test]
fn test_pinning_verifier_verify_server_cert_rejects_invalid_domain() {
let verifier = PinningVerifier::new();
let dns_name = rustls::pki_types::DnsName::try_from("evil.com".to_string()).unwrap();
let server_name = ServerName::DnsName(dns_name);
let cert_der = CertificateDer::from(vec![]);
let result =
verifier.verify_server_cert(&cert_der, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = format!("{:?}", err);
assert!(err_msg.contains("evil.com") || err_msg.contains("not a speedtest.net domain"));
}
#[test]
fn test_pinning_verifier_verify_server_cert_rejects_unsupported_name_type() {
let verifier = PinningVerifier::new();
let ip_addr = std::net::IpAddr::from([127, 0, 0, 1]);
let server_name = ServerName::IpAddress(ip_addr.into());
let cert_der = CertificateDer::from(vec![]);
let result =
verifier.verify_server_cert(&cert_der, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = format!("{:?}", err);
assert!(err_msg.contains("Unsupported server name type"));
}
#[test]
fn test_pinning_verifier_verify_server_cert_rejects_invalid_certificate() {
let verifier = PinningVerifier::new();
let dns_name =
rustls::pki_types::DnsName::try_from("www.speedtest.net".to_string()).unwrap();
let server_name = ServerName::DnsName(dns_name);
let cert_der = CertificateDer::from(vec![]);
let result =
verifier.verify_server_cert(&cert_der, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = format!("{:?}", err);
assert!(!err_msg.contains("not a speedtest.net domain"));
}
#[test]
fn test_pinning_verifier_domain_checked_before_cert_parse_speedtest() {
let verifier = PinningVerifier::new();
let dns_name = rustls::pki_types::DnsName::try_from("speedtest.net".to_string()).unwrap();
let server_name = ServerName::DnsName(dns_name);
let cert_der = CertificateDer::from(vec![]);
let result =
verifier.verify_server_cert(&cert_der, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = format!("{:?}", err);
assert!(!err_msg.contains("not a speedtest.net domain"));
}
#[test]
fn test_pinning_verifier_ipv6_address_rejected() {
let verifier = PinningVerifier::new();
let ip_addr = std::net::IpAddr::from([0, 0, 0, 0, 0, 0, 0, 1]); let server_name = ServerName::IpAddress(ip_addr.into());
let cert_der = CertificateDer::from(vec![]);
let result =
verifier.verify_server_cert(&cert_der, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err());
}
#[test]
fn test_pinning_verifier_domain_checked_before_cert_parse_ookla() {
let verifier = PinningVerifier::new();
let dns_name = rustls::pki_types::DnsName::try_from("ookla.com".to_string()).unwrap();
let server_name = ServerName::DnsName(dns_name);
let cert_der = CertificateDer::from(vec![]);
let result =
verifier.verify_server_cert(&cert_der, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err());
}
#[test]
fn test_pinning_verifier_domain_validation_order() {
let verifier = PinningVerifier::new();
let dns_name = rustls::pki_types::DnsName::try_from("attacker.com".to_string()).unwrap();
let server_name = ServerName::DnsName(dns_name);
let cert_der = CertificateDer::from(vec![]);
let result =
verifier.verify_server_cert(&cert_der, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = format!("{:?}", err);
assert!(
err_msg.contains("not a speedtest.net domain"),
"Expected domain validation error, got: {}",
err_msg
);
}
#[test]
fn test_pinning_verifier_verify_server_cert_rejects_different_tld() {
let verifier = PinningVerifier::new();
let dns_name =
rustls::pki_types::DnsName::try_from("speedtest.net.org".to_string()).unwrap();
let server_name = ServerName::DnsName(dns_name);
let cert_der = CertificateDer::from(vec![]);
let result =
verifier.verify_server_cert(&cert_der, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = format!("{:?}", err);
assert!(err_msg.contains("not a speedtest.net domain"));
}
#[test]
fn test_pinning_verifier_intermediate_certs_ignored() {
let verifier = PinningVerifier::new();
let dns_name =
rustls::pki_types::DnsName::try_from("www.speedtest.net".to_string()).unwrap();
let server_name = ServerName::DnsName(dns_name);
let cert_der = CertificateDer::from(vec![]);
let intermediate_cert = CertificateDer::from(vec![0u8; 10]);
let result = verifier.verify_server_cert(
&cert_der,
&[intermediate_cert],
&server_name,
&[],
UnixTime::now(),
);
assert!(result.is_err());
}
#[test]
fn test_pinning_verifier_ocsp_response_ignored() {
let verifier = PinningVerifier::new();
let dns_name = rustls::pki_types::DnsName::try_from("api.ookla.com".to_string()).unwrap();
let server_name = ServerName::DnsName(dns_name);
let cert_der = CertificateDer::from(vec![]);
let ocsp_response = vec![0x30, 0x03, 0x01, 0x00];
let result = verifier.verify_server_cert(
&cert_der,
&[],
&server_name,
&ocsp_response,
UnixTime::now(),
);
assert!(result.is_err());
}
#[test]
fn test_pinning_verifier_all_valid_subdomains() {
let valid_subdomains = [
"www.speedtest.net",
"api.speedtest.net",
"test.speedtest.net",
"staging.speedtest.net",
"prod.speedtest.net",
"cdn.speedtest.net",
"a.speedtest.net",
"z.speedtest.net",
"a1b2c3.speedtest.net",
"my-site.speedtest.net",
"www.ookla.com",
"api.ookla.com",
"test.ookla.com",
];
for domain in valid_subdomains {
assert!(
PinningVerifier::is_valid_domain(domain),
"Domain '{}' should be valid",
domain
);
}
}
#[test]
fn test_pinning_verifier_all_invalid_domains() {
let invalid_domains = [
"evilsite.net",
"speedtest.net.evil.com",
"ookla.com.evil.com",
"speedtest.com",
"ookla.net",
"notspeedtest.net",
"notookla.com",
"fake-speedtest.net",
"fake-ookla.com",
"attacker.speedtest.net.fake.com",
"attacker.ookla.com.fake.com",
];
for domain in invalid_domains {
assert!(
!PinningVerifier::is_valid_domain(domain),
"Domain '{}' should be invalid",
domain
);
}
}
#[test]
fn test_parse_ip_from_xml() {
let xml = r#"<settings><client country="CA" ip="173.35.57.235" isp="Rogers"/></settings>"#;
assert_eq!(parse_ip_from_xml(xml), Some("173.35.57.235".to_string()));
}
#[test]
fn test_parse_ip_from_xml_full_response() {
let xml = r#"<?xml version="1.0"?>
<settings>
<config downloadThreadCountV3="4"/>
<client country="CA" ip="173.35.57.235" isp="Rogers"/>
</settings>"#;
assert_eq!(parse_ip_from_xml(xml), Some("173.35.57.235".to_string()));
}
#[test]
fn test_parse_ip_from_xml_invalid() {
assert!(parse_ip_from_xml("not xml").is_none());
assert!(parse_ip_from_xml("<html></html>").is_none());
assert!(parse_ip_from_xml("<settings><client ip=\"invalid\"/></settings>").is_none());
}
#[test]
fn test_create_client_invalid_source_ip() {
let source = crate::config::ConfigSource::default();
let config = crate::config::Config::from_source(&source);
let mut settings = Settings::from(&config);
settings.source_ip = Some("invalid-ip".to_string());
let result = create_client(&settings);
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), Error::Context { .. }));
}
#[test]
fn test_create_client_valid_config() {
let source = crate::config::ConfigSource::default();
let config = crate::config::Config::from_source(&source);
let settings = Settings::from(&config);
let result = create_client(&settings);
assert!(result.is_ok());
}
#[test]
fn test_create_client_with_source_ip() {
let source = crate::config::ConfigSource {
network: crate::config::NetworkSource {
source: Some("0.0.0.0".into()),
..Default::default()
},
..Default::default()
};
let config = crate::config::Config::from_source(&source);
let settings = Settings::from(&config);
let result = create_client(&settings);
match result {
Ok(_) | Err(Error::NetworkError(_) | Error::Context { .. }) => {}
Err(e) => panic!("Unexpected error type: {e:?}"),
}
}
#[test]
fn test_create_client_custom_timeout() {
let source = crate::config::ConfigSource {
network: crate::config::NetworkSource {
timeout: 30,
..Default::default()
},
..Default::default()
};
let config = crate::config::Config::from_source(&source);
let settings = Settings::from(&config);
let result = create_client(&settings);
assert!(result.is_ok());
}
#[test]
fn test_settings_from_config_with_source_ip() {
let source = crate::config::ConfigSource {
network: crate::config::NetworkSource {
source: Some("192.168.1.50".to_string()),
..Default::default()
},
..Default::default()
};
let config = crate::config::Config::from_source(&source);
let settings = Settings::from(&config);
assert_eq!(settings.source_ip, Some("192.168.1.50".to_string()));
}
#[test]
fn test_settings_from_config_with_ca_cert() {
let source = crate::config::ConfigSource {
network: crate::config::NetworkSource {
ca_cert: Some("/path/to/ca.pem".to_string()),
..Default::default()
},
..Default::default()
};
let config = crate::config::Config::from_source(&source);
let settings = Settings::from(&config);
assert_eq!(
settings.tls.ca_cert_path,
Some(std::path::PathBuf::from("/path/to/ca.pem"))
);
}
#[test]
fn test_settings_from_config_with_tls_version() {
let source = crate::config::ConfigSource {
network: crate::config::NetworkSource {
tls_version: Some("1.2".to_string()),
..Default::default()
},
..Default::default()
};
let config = crate::config::Config::from_source(&source);
let settings = Settings::from(&config);
assert_eq!(settings.tls.min_tls_version, Some("1.2".to_string()));
}
#[test]
fn test_settings_from_config_with_pinning() {
let source = crate::config::ConfigSource {
network: crate::config::NetworkSource {
pin_certs: Some(true),
..Default::default()
},
..Default::default()
};
let config = crate::config::Config::from_source(&source);
let settings = Settings::from(&config);
assert!(settings.tls.pin_speedtest_certs);
}
#[test]
fn test_settings_from_config_timeout() {
let source = crate::config::ConfigSource {
network: crate::config::NetworkSource {
timeout: 45,
..Default::default()
},
..Default::default()
};
let config = crate::config::Config::from_source(&source);
let settings = Settings::from(&config);
assert_eq!(settings.timeout_secs, 45);
}
#[test]
fn test_settings_from_config_default_user_agent() {
let config = crate::config::Config::from_source(&crate::config::ConfigSource::default());
let settings = Settings::from(&config);
assert_eq!(settings.user_agent, DEFAULT_USER_AGENT);
}
#[test]
fn test_settings_from_config_retry_enabled_by_default() {
let config = crate::config::Config::from_source(&crate::config::ConfigSource::default());
let settings = Settings::from(&config);
assert!(settings.retry_enabled);
}
#[tokio::test]
async fn test_with_retry_immediate_success() {
let counter = Arc::new(AtomicUsize::new(0));
let count = Arc::clone(&counter);
let result = with_retry(|| {
let c = Arc::clone(&count);
async move {
c.fetch_add(1, Ordering::SeqCst);
Ok::<_, reqwest::Error>(42)
}
})
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 42);
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_with_retry_with_mock_request() {
let result = with_retry(|| async { Ok::<_, reqwest::Error>(100) }).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 100);
}
#[tokio::test]
async fn test_with_retry_counter_increment() {
let counter = Arc::new(AtomicUsize::new(0));
let count = Arc::clone(&counter);
let _result = with_retry(|| {
let c = Arc::clone(&count);
async move {
c.fetch_add(1, Ordering::SeqCst);
Ok::<_, reqwest::Error>(1)
}
})
.await;
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_with_retry_different_value_types() {
let result_str = with_retry(|| async { Ok::<_, reqwest::Error>("hello") }).await;
assert!(result_str.is_ok());
assert_eq!(result_str.unwrap(), "hello");
let result_u64 = with_retry(|| async { Ok::<_, reqwest::Error>(999u64) }).await;
assert!(result_u64.is_ok());
assert_eq!(result_u64.unwrap(), 999);
let result_vec = with_retry(|| async { Ok::<_, reqwest::Error>(vec![1, 2, 3]) }).await;
assert!(result_vec.is_ok());
assert_eq!(result_vec.unwrap(), vec![1, 2, 3]);
}
#[tokio::test]
async fn test_with_retry_multiple_sequential_calls() {
for i in 0..3 {
let result = with_retry(|| async { Ok::<_, reqwest::Error>(i) }).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), i);
}
}
#[test]
fn test_parse_ip_from_xml_missing_client_element() {
let xml = r#"<settings><server ip="127.0.0.1"/></settings>"#;
assert!(parse_ip_from_xml(xml).is_none());
}
#[test]
fn test_parse_ip_from_xml_empty_ip() {
let xml = r#"<settings><client ip=""/></settings>"#;
assert!(parse_ip_from_xml(xml).is_none());
}
#[test]
fn test_parse_ip_from_xml_whitespace_ip() {
let xml = r#"<settings><client ip=" " /></settings>"#;
assert!(parse_ip_from_xml(xml).is_none());
}
/// The IPv6 loopback literal `::1` in the `ip` attribute is expected to be
/// rejected, i.e. the parser returns `None` for it.
#[test]
fn test_parse_ip_from_xml_ipv6_format() {
    let document = r#"<settings><client ip="::1"/></settings>"#;
    assert_eq!(parse_ip_from_xml(document), None);
}
/// Extracts the address when `ip` sits between other attributes on `<client>`.
///
/// Despite the test name, the input contains no special characters; what it
/// exercises is attribute ordering.
#[test]
fn test_parse_ip_from_xml_special_characters() {
    let document = r#"<settings><client country="US" ip="192.168.1.1" isp="ISP"/></settings>"#;
    let parsed = parse_ip_from_xml(document);
    assert_eq!(parsed.as_deref(), Some("192.168.1.1"));
}
/// Trailing non-XML text after the closing tag must not prevent the address
/// from being extracted.
#[test]
fn test_parse_ip_from_xml_garbage_after_xml() {
    let document = r#"<settings><client ip="1.2.3.4" /></settings>GARBAGE"#;
    let parsed = parse_ip_from_xml(document);
    assert_eq!(parsed.as_deref(), Some("1.2.3.4"));
}
/// Truncated, unbalanced, and empty inputs all yield no address.
#[test]
fn test_parse_ip_from_xml_malformed_xml() {
    // Element cut off before any attribute.
    assert_eq!(parse_ip_from_xml("<settings><client"), None);
    // Closing tag with no matching opening tag.
    assert_eq!(parse_ip_from_xml("</settings>"), None);
    // Empty document.
    assert_eq!(parse_ip_from_xml(""), None);
}
#[tokio::test]
async fn test_discover_client_ip_handles_network_failure() {
let settings = Settings::default().with_retry_disabled();
let client = create_client(&settings).unwrap();
let result = discover_client_ip(&client).await;
match result {
Ok(ip) => {
assert!(ip == "unknown" || common::is_valid_ipv4(&ip));
}
Err(e) => {
assert!(matches!(e, Error::NetworkError(_)));
}
}
}
/// The `Debug` output of `TlsConfig` names the type.
#[test]
fn test_tls_config_debug() {
    let rendered = format!("{:?}", TlsConfig::default());
    assert!(rendered.contains("TlsConfig"));
}
/// Cloning a fully populated `TlsConfig` reproduces every field.
#[test]
fn test_tls_config_clone() {
    let original = TlsConfig::default()
        .with_ca_cert("/test.pem".into())
        .with_min_tls_version("1.3")
        .with_cert_pinning();
    let copy = original.clone();

    assert_eq!(copy.ca_cert_path, original.ca_cert_path);
    assert_eq!(copy.min_tls_version, original.min_tls_version);
    assert_eq!(copy.pin_speedtest_certs, original.pin_speedtest_certs);
}
/// `TlsConfig::default()` leaves both optional fields unset and pinning off.
#[test]
fn test_tls_config_default_trait() {
    let TlsConfig {
        ca_cert_path,
        min_tls_version,
        pin_speedtest_certs,
    } = TlsConfig::default();

    assert!(ca_cert_path.is_none());
    assert!(min_tls_version.is_none());
    assert!(!pin_speedtest_certs);
}
/// A `source_ip` set through struct-update syntax survives a clone.
#[test]
fn test_settings_with_source_ip() {
    let settings = Settings {
        source_ip: Some(String::from("10.0.0.1")),
        ..Settings::default()
    };
    let copy = settings.clone();
    assert_eq!(copy.source_ip.as_deref(), Some("10.0.0.1"));
}
/// Chaining `with_user_agent` and `with_retry_disabled` applies both changes.
#[test]
fn test_settings_builder_full_chain() {
    let built = Settings::default()
        .with_user_agent("Test/1.0")
        .with_retry_disabled();

    assert!(!built.retry_enabled);
    assert_eq!(built.user_agent, "Test/1.0");
}
/// A clone of `Settings` is unaffected by later mutation of the original.
#[test]
fn test_settings_clone_is_independent() {
    let mut original = Settings {
        timeout_secs: 60,
        ..Settings::default()
    };
    let snapshot = original.clone();
    assert_eq!(snapshot.timeout_secs, 60);

    // Changing the source afterwards must leave the snapshot untouched.
    original.timeout_secs = 120;
    assert_eq!(snapshot.timeout_secs, 60);
}
/// `create_client` succeeds with default settings, which carry no source IP.
#[test]
fn test_create_client_with_source_ip_none() {
    let settings = Settings::default();
    assert!(create_client(&settings).is_ok());
}
/// `create_client` succeeds when the user agent has been overridden.
#[test]
fn test_create_client_with_custom_user_agent() {
    let settings = Settings::default().with_user_agent("TestAgent/1.0");
    assert!(create_client(&settings).is_ok());
}
/// `create_client` succeeds with a timeout of zero seconds.
#[test]
fn test_create_client_timeout_zero() {
    let settings = Settings {
        timeout_secs: 0,
        ..Settings::default()
    };
    assert!(create_client(&settings).is_ok());
}
/// `create_client` succeeds with a 300-second timeout.
#[test]
fn test_create_client_timeout_large() {
    let settings = Settings {
        timeout_secs: 300,
        ..Settings::default()
    };
    assert!(create_client(&settings).is_ok());
}
/// The `Debug` rendering of `Error::context` includes the supplied message.
#[test]
fn test_error_context_message() {
    let rendered = format!("{:?}", Error::context("test error"));
    assert!(rendered.contains("test error"));
}
/// The `Debug` rendering of `Error::with_source` mentions the context message
/// or the wrapped I/O error's message (either one is accepted).
#[test]
fn test_error_context_with_source() {
    let source = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
    let rendered = format!("{:?}", Error::with_source("operation failed", source));

    let mentions_either = ["operation failed", "file not found"]
        .iter()
        .any(|needle| rendered.contains(*needle));
    assert!(mentions_either);
}
/// The `Debug` rendering of `Error::ServerNotFound` carries the payload text
/// or the variant name (either one is accepted).
#[test]
fn test_error_server_not_found() {
    let rendered = format!("{:?}", Error::ServerNotFound("no servers available".into()));
    let recognisable =
        rendered.contains("no servers available") || rendered.contains("ServerNotFound");
    assert!(recognisable);
}
/// The `Debug` rendering of `Error::DownloadFailure` carries the payload text
/// or the variant name (either one is accepted).
#[test]
fn test_error_download_failure() {
    let rendered = format!("{:?}", Error::DownloadFailure("test download failed".into()));
    let recognisable =
        rendered.contains("test download failed") || rendered.contains("DownloadFailure");
    assert!(recognisable);
}
/// The `Debug` rendering of `Error::UploadFailure` carries the payload text
/// or the variant name (either one is accepted).
#[test]
fn test_error_upload_failure() {
    let rendered = format!("{:?}", Error::UploadFailure("test upload failed".into()));
    let recognisable =
        rendered.contains("test upload failed") || rendered.contains("UploadFailure");
    assert!(recognisable);
}
/// The `Debug` rendering of a context error contains both the text `Context`
/// and the message it was built from.
#[test]
fn test_error_context_debug() {
    let rendered = format!("{:?}", Error::context("context debug"));
    assert!(rendered.contains("Context"));
    assert!(rendered.contains("context debug"));
}
/// The `Display` rendering of a context error is exactly its message.
#[test]
fn test_error_context_display() {
    let shown = Error::context("context display").to_string();
    assert_eq!(shown, "context display");
}
/// The `Display` rendering of `Error::DownloadFailure` includes its payload.
#[test]
fn test_error_download_failure_display() {
    let shown = Error::DownloadFailure("download failed".into()).to_string();
    assert!(shown.contains("download failed"));
}
/// The `Display` rendering of `Error::UploadFailure` includes its payload.
#[test]
fn test_error_upload_failure_display() {
    let shown = Error::UploadFailure("upload failed".into()).to_string();
    assert!(shown.contains("upload failed"));
}
/// The `Display` rendering of `Error::ServerNotFound` contains both the
/// capitalised text `Server not found` and the lowercase payload.
#[test]
fn test_error_server_not_found_display() {
    let shown = Error::ServerNotFound("server not found".into()).to_string();
    // Capitalised text, distinct from the lowercase payload checked below.
    assert!(shown.contains("Server not found"));
    // Payload passed to the variant.
    assert!(shown.contains("server not found"));
}
/// Default settings use a 10-second timeout.
#[test]
fn test_settings_default_timeout_10() {
    let Settings { timeout_secs, .. } = Settings::default();
    assert_eq!(timeout_secs, 10);
}
/// Default settings have retries enabled.
#[test]
fn test_settings_default_retry_true() {
    let Settings { retry_enabled, .. } = Settings::default();
    assert!(retry_enabled);
}
/// A timeout supplied through struct-update syntax is stored as given.
#[test]
fn test_settings_with_timeout() {
    let custom = Settings {
        timeout_secs: 120,
        ..Settings::default()
    };
    assert_eq!(custom.timeout_secs, 120);
}
/// A default `TlsConfig` has no CA path, no minimum TLS version, and
/// certificate pinning switched off.
#[test]
fn test_tls_config_default_values() {
    let defaults = TlsConfig::default();
    assert_eq!(defaults.ca_cert_path, None);
    assert_eq!(defaults.min_tls_version, None);
    assert!(!defaults.pin_speedtest_certs);
}
/// Builds a `TlsConfig` with both a CA certificate path and a minimum TLS
/// version, and checks that each builder stored exactly the value it was given.
#[test]
fn test_tls_config_multiple_options() {
    let tls = TlsConfig::default()
        .with_ca_cert("/path/to/ca.pem".into())
        .with_min_tls_version("1.2");

    // Compare the stored values rather than mere presence, so a builder that
    // writes the wrong value (or the wrong field) is caught.
    assert_eq!(
        tls.ca_cert_path.as_deref(),
        Some(std::path::Path::new("/path/to/ca.pem"))
    );
    assert_eq!(tls.min_tls_version.as_deref(), Some("1.2"));
    // Pinning was never requested, so it must keep its default.
    assert!(!tls.pin_speedtest_certs);
}
/// When `with_user_agent` is called twice in a chain the last value wins, and
/// an intervening `with_retry_disabled` still takes effect.
#[test]
fn test_settings_chained_modifications() {
    let built = Settings::default()
        .with_user_agent("Test/1.0")
        .with_retry_disabled()
        .with_user_agent("Test/2.0");

    assert!(!built.retry_enabled);
    assert_eq!(built.user_agent, "Test/2.0");
}
/// `DEFAULT_USER_AGENT` is non-empty and contains the `Mozilla` and `Chrome`
/// tokens.
#[test]
fn test_default_user_agent_is_valid() {
    assert!(!DEFAULT_USER_AGENT.is_empty());
    for token in ["Mozilla", "Chrome"] {
        assert!(DEFAULT_USER_AGENT.contains(token));
    }
}
/// Default settings use `DEFAULT_USER_AGENT` as the user agent.
#[test]
fn test_default_user_agent_in_settings() {
    let Settings { user_agent, .. } = Settings::default();
    assert_eq!(user_agent, DEFAULT_USER_AGENT);
}
/// `create_client` succeeds when every setting is left at its default.
#[test]
fn test_create_client_all_defaults() {
    let outcome = create_client(&Settings::default());
    assert!(outcome.is_ok());
}
/// `create_client` succeeds when the TLS section is explicitly set to its
/// default value.
#[test]
fn test_create_client_minimal_tls_config() {
    let settings = Settings {
        tls: TlsConfig::default(),
        ..Settings::default()
    };
    assert!(create_client(&settings).is_ok());
}
/// `create_client` succeeds with default settings.
///
/// NOTE(review): nothing here inspects the negotiated HTTP version; the test
/// only checks that a client can be built. Confirm whether the name still
/// reflects the intent.
#[test]
fn test_create_client_http1_only() {
    let outcome = create_client(&Settings::default());
    assert!(outcome.is_ok());
}
/// One-character subdomains of `speedtest.net` and `ookla.com` are accepted.
#[test]
fn test_pinning_verifier_single_char_subdomain() {
    for domain in ["a.speedtest.net", "z.ookla.com"] {
        assert!(PinningVerifier::is_valid_domain(domain));
    }
}
/// Digits are fine as a separate subdomain label, but digits glued onto the
/// `speedtest` label itself make the domain invalid.
#[test]
fn test_pinning_verifier_numbers_in_subdomain() {
    // Numeric label in front of the pinned domain.
    assert!(PinningVerifier::is_valid_domain("123.speedtest.net"));
    // Digits appended to the registrable label.
    assert!(!PinningVerifier::is_valid_domain("speedtest123.net"));
    // Digits prepended to the registrable label.
    assert!(!PinningVerifier::is_valid_domain("123speedtest.net"));
}
/// A subdomain label containing a non-ASCII character is accepted.
///
/// NOTE(review): the label is given in raw Unicode rather than an ASCII
/// (punycode) form; confirm that accepting it is intended.
#[test]
fn test_pinning_verifier_unicode_in_subdomain() {
    let accepted = PinningVerifier::is_valid_domain("münchen.speedtest.net");
    assert!(accepted);
}
/// Verification fails for an empty certificate even when the server name is
/// `cdn.speedtest.net`.
#[test]
fn test_pinning_verifier_empty_cert_with_valid_domain() {
    let host = rustls::pki_types::DnsName::try_from(String::from("cdn.speedtest.net")).unwrap();
    let server_name = ServerName::DnsName(host);
    // Zero-length DER payload standing in for the end-entity certificate.
    let empty_cert = CertificateDer::from(vec![]);

    let verdict = PinningVerifier::new().verify_server_cert(
        &empty_cert,
        &[],
        &server_name,
        &[],
        UnixTime::now(),
    );
    assert!(verdict.is_err());
}
/// Hyphenated subdomain labels are accepted.
#[test]
fn test_pinning_verifier_subdomain_with_dashes() {
    for domain in ["my-custom-subdomain.speedtest.net", "api-v2.ookla.com"] {
        assert!(PinningVerifier::is_valid_domain(domain));
    }
}
/// A 63-character subdomain label is accepted.
#[test]
fn test_pinning_verifier_long_subdomain() {
    let domain = format!("{}.speedtest.net", "a".repeat(63));
    assert!(PinningVerifier::is_valid_domain(&domain));
}
/// Names in which the dot before the TLD is replaced by another character
/// are rejected.
#[test]
fn test_pinning_verifier_concatenation_attack() {
    let lookalikes = ["speedtestXnet", "speedtestXcom", "ooklaXcom", "ooklaXnet"];
    for domain in lookalikes {
        assert!(!PinningVerifier::is_valid_domain(domain));
    }
}
/// `with_retry_disabled` turns retries off and leaves the default timeout and
/// user agent untouched.
#[test]
fn test_settings_retry_disabled_chain() {
    let Settings {
        retry_enabled,
        timeout_secs,
        user_agent,
        ..
    } = Settings::default().with_retry_disabled();

    assert!(!retry_enabled);
    assert_eq!(timeout_secs, 10);
    assert_eq!(user_agent, DEFAULT_USER_AGENT);
}
/// The second of two consecutive `with_user_agent` calls overrides the first.
#[test]
fn test_settings_user_agent_chain() {
    let first = Settings::default().with_user_agent("Custom/1.0");
    let second = first.with_user_agent("Custom/2.0");
    assert_eq!(second.user_agent, "Custom/2.0");
}
/// Building a client bound to the IPv4 loopback address either succeeds or
/// fails with `Error::NetworkError` / `Error::Context`; any other error
/// fails the test.
#[test]
fn test_create_client_source_ip_loopback_v4() {
    let settings = Settings {
        source_ip: Some(String::from("127.0.0.1")),
        ..Settings::default()
    };
    if let Err(e) = create_client(&settings) {
        assert!(
            matches!(e, Error::NetworkError(_) | Error::Context { .. }),
            "Unexpected error: {e:?}"
        );
    }
}
/// Building a client bound to the IPv6 loopback address either succeeds or
/// fails with `Error::NetworkError` / `Error::Context`; any other error
/// fails the test.
#[test]
fn test_create_client_source_ip_loopback_v6() {
    let settings = Settings {
        source_ip: Some(String::from("::1")),
        ..Settings::default()
    };
    if let Err(e) = create_client(&settings) {
        assert!(
            matches!(e, Error::NetworkError(_) | Error::Context { .. }),
            "Unexpected error: {e:?}"
        );
    }
}
/// Building a client bound to `0.0.0.0` either succeeds or fails with
/// `Error::NetworkError` / `Error::Context`; any other error fails the test.
#[test]
fn test_create_client_source_ip_unspecified() {
    let settings = Settings {
        source_ip: Some(String::from("0.0.0.0")),
        ..Settings::default()
    };
    match create_client(&settings) {
        Ok(_) => {}
        Err(Error::NetworkError(_) | Error::Context { .. }) => {}
        Err(e) => panic!("Unexpected error: {e:?}"),
    }
}
/// Building a client with both a loopback source IP and an explicit default
/// TLS section either succeeds or fails with `Error::NetworkError` /
/// `Error::Context`; any other error fails the test.
#[test]
fn test_create_client_source_ip_with_tls() {
    let settings = Settings {
        source_ip: Some(String::from("127.0.0.1")),
        tls: TlsConfig::default(),
        ..Settings::default()
    };
    match create_client(&settings) {
        Ok(_) => {}
        Err(Error::NetworkError(_) | Error::Context { .. }) => {}
        Err(e) => panic!("Unexpected error: {e:?}"),
    }
}
}