use super::mirror::{Mirror, TestResult};
use anyhow::Result;
use futures::stream::{self, StreamExt};
use reqwest::Client;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time::timeout;
use tracing::{debug, info, warn};
use crate::config::Config;
use crate::utils::{create_client, parse_size, SMirrorsError};
/// Progress hook for [`MirrorTester::test_all`].
///
/// Invoked once after each mirror test finishes (whether it succeeded or not)
/// with `(completed_so_far, total, mirror_url)`.
pub type ProgressCallback = Arc<dyn Fn(usize, usize, &str) + Send + Sync>;
/// Probes mirrors for latency and download speed and scores the results.
///
/// Construct with [`MirrorTester::new`] or [`MirrorTester::from_config`].
pub struct MirrorTester {
    // HTTP side.
    /// Client used for every probe request.
    client: Client,
    /// Timeout for sending each request; also the time budget of one speed test.
    timeout_duration: Duration,

    // Workload.
    /// Maximum number of bytes downloaded when measuring speed.
    test_size: usize,
    /// Number of mirrors tested concurrently.
    concurrent_tests: usize,
    /// Extra attempts made after a failed test (0 means a single attempt).
    retries: u32,

    // Scoring.
    /// Passed to `Mirror::update_from_test` alongside the measured speed.
    speed_weight: f64,
    /// Passed to `Mirror::update_from_test` alongside the measured latency.
    latency_weight: f64,
}
impl MirrorTester {
    /// Longest time to wait for a single body chunk while budget remains.
    const CHUNK_TIMEOUT: Duration = Duration::from_secs(5);

    /// Linear back-off step: retry `n` sleeps `n * RETRY_BACKOFF_MS` milliseconds.
    const RETRY_BACKOFF_MS: u64 = 500;

    /// Speed samples shorter than this many seconds are rejected as unreliable.
    const MIN_MEASURABLE_SECS: f64 = 0.01;

    /// Creates a tester with explicit settings.
    ///
    /// * `test_size` – maximum number of bytes downloaded by the speed test.
    /// * `timeout_secs` – passed to `create_client`; also used as the timeout for
    ///   sending each request and as the overall budget of one speed test.
    /// * `concurrent_tests` – number of mirrors [`Self::test_all`] tests at once
    ///   (0 is treated as 1).
    /// * `speed_weight`, `latency_weight` – forwarded to `Mirror::update_from_test`.
    /// * `retries` – extra attempts after a failed test; 0 means a single attempt.
    ///
    /// # Errors
    /// Returns an error if `create_client` fails.
    pub fn new(
        test_size: usize,
        timeout_secs: u64,
        concurrent_tests: usize,
        speed_weight: f64,
        latency_weight: f64,
        retries: u32,
    ) -> Result<Self> {
        let client = create_client(timeout_secs, None)?;
        Ok(Self {
            client,
            test_size,
            timeout_duration: Duration::from_secs(timeout_secs),
            concurrent_tests,
            speed_weight,
            latency_weight,
            retries,
        })
    }

    /// Builds a tester from the `general` and `testing` sections of `config`.
    ///
    /// # Errors
    /// Returns an error if `testing.test_file_size` cannot be parsed by
    /// `parse_size`, or if the HTTP client cannot be created.
    pub fn from_config(config: &Config) -> Result<Self> {
        let test_size = parse_size(&config.testing.test_file_size)?;
        Self::new(
            test_size,
            config.general.timeout,
            config.general.concurrent_tests,
            config.testing.speed_weight,
            config.testing.latency_weight,
            config.general.retries,
        )
    }

    /// Tests `mirror`, retrying a failed attempt up to `self.retries` extra times.
    ///
    /// Retry `n` is preceded by a linear back-off sleep of `n * 500` ms. The first
    /// successful result is returned; if every attempt fails, the result of the
    /// last attempt is returned.
    pub async fn test_mirror(&self, mirror: &Mirror) -> TestResult {
        debug!("Testing mirror: {}", mirror.url);
        let mut attempt: u32 = 0;
        loop {
            let result = self.test_mirror_once(mirror).await;
            if result.success || attempt >= self.retries {
                return result;
            }
            attempt += 1;
            debug!("Retry attempt {} for {}", attempt, mirror.url);
            let backoff = Self::RETRY_BACKOFF_MS.saturating_mul(u64::from(attempt));
            tokio::time::sleep(Duration::from_millis(backoff)).await;
        }
    }

    /// Runs one latency test followed by one speed test against `mirror`.
    ///
    /// On success, a copy of the mirror is scored via `update_from_test` and
    /// returned inside `TestResult::success`. The first failing probe is logged
    /// at `warn` level and turned into `TestResult::failure`.
    async fn test_mirror_once(&self, mirror: &Mirror) -> TestResult {
        let latency = match self.test_latency(&mirror.url).await {
            Ok(lat) => lat,
            Err(e) => {
                warn!("Latency test failed for {}: {}", mirror.url, e);
                return TestResult::failure(
                    mirror.clone(),
                    format!("Latency test failed: {}", e),
                );
            }
        };
        debug!("Latency for {}: {:?}", mirror.url, latency);

        let speed = match self.test_speed(&mirror.url, mirror).await {
            Ok(spd) => spd,
            Err(e) => {
                warn!("Speed test failed for {}: {}", mirror.url, e);
                return TestResult::failure(
                    mirror.clone(),
                    format!("Speed test failed: {}", e),
                );
            }
        };
        debug!("Speed for {}: {:.2} MB/s", mirror.url, speed);

        let mut temp_mirror = mirror.clone();
        temp_mirror.update_from_test(speed, latency, self.speed_weight, self.latency_weight);
        let score = temp_mirror.score.unwrap_or(0.0);
        info!(
            "Mirror {} tested: speed={:.2} MB/s, latency={:?}, score={:.2}",
            mirror.url, speed, latency, score
        );
        TestResult::success(temp_mirror, speed, latency, score)
    }

    /// Tests every mirror, running up to `concurrent_tests` tests at once.
    ///
    /// `progress`, if given, is called after each test completes with
    /// `(completed_so_far, total, mirror_url)`. Results are returned in
    /// completion order, not input order.
    pub async fn test_all(
        &self,
        mirrors: Vec<Mirror>,
        progress: Option<ProgressCallback>,
    ) -> Vec<TestResult> {
        let total = mirrors.len();
        // Depending on the `futures` version, `buffer_unordered(0)` either never
        // starts a future (the stream stalls forever) or means "unbounded";
        // neither is a sensible reading of a zero setting.
        let concurrency = self.concurrent_tests.max(1);
        info!("Testing {} mirrors with concurrency {}", total, concurrency);

        // Plain borrows suffice: every future is driven to completion before
        // this function returns.
        let completed = std::sync::atomic::AtomicUsize::new(0);
        let completed = &completed;
        let progress = progress.as_ref();

        let results: Vec<TestResult> = stream::iter(mirrors)
            .map(|mirror| async move {
                let result = self.test_mirror(&mirror).await;
                let count = completed.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1;
                if let Some(callback) = progress {
                    callback(count, total, &mirror.url_string());
                }
                result
            })
            .buffer_unordered(concurrency)
            .collect()
            .await;

        info!("Completed testing {} mirrors", total);
        results
    }

    /// Builds the `MirrorTestFailed` error reported for `url`.
    fn mirror_test_failed(url: &url::Url, reason: String) -> anyhow::Error {
        SMirrorsError::MirrorTestFailed {
            url: url.to_string(),
            reason,
        }
        .into()
    }

    /// Sends `request` bounded by `timeout_duration` and rejects non-2xx responses.
    ///
    /// Failures are reported against `url` (the mirror's base URL), so latency
    /// and speed failures read the same way.
    async fn send_checked(
        &self,
        request: reqwest::RequestBuilder,
        url: &url::Url,
    ) -> Result<reqwest::Response> {
        let response = match timeout(self.timeout_duration, request.send()).await {
            Ok(Ok(response)) => response,
            Ok(Err(e)) => {
                return Err(Self::mirror_test_failed(url, format!("Request failed: {}", e)));
            }
            Err(_) => {
                return Err(SMirrorsError::TestTimeout(self.timeout_duration.as_secs()).into());
            }
        };
        let status = response.status();
        if !status.is_success() {
            return Err(Self::mirror_test_failed(url, format!("HTTP status: {}", status)));
        }
        Ok(response)
    }

    /// Measures the round-trip time of a `HEAD` request to `url`.
    ///
    /// Timing starts before the request is sent, so it may include connection
    /// and TLS setup. Fails on transport errors, non-2xx status, or timeout.
    async fn test_latency(&self, url: &url::Url) -> Result<Duration> {
        let start = Instant::now();
        self.send_checked(self.client.head(url.as_str()), url).await?;
        Ok(start.elapsed())
    }

    /// Downloads up to `test_size` bytes of the mirror's test file and returns
    /// the throughput in MB/s (10^6 bytes per second).
    ///
    /// Timing starts before the request is sent, so the figure includes request
    /// setup and time-to-first-byte. The whole test is budgeted to
    /// `timeout_duration`. Once the budget is spent, downloading stops and the
    /// bytes received so far are measured. A chunk that takes longer than
    /// `CHUNK_TIMEOUT` while budget remains is treated as a stall.
    ///
    /// # Errors
    /// Fails on request/transport errors, non-2xx status, a stalled chunk, a
    /// sample with no data, or a sample too short to time reliably.
    async fn test_speed(&self, url: &url::Url, mirror: &Mirror) -> Result<f64> {
        let test_url = self.get_test_file_url(url, mirror)?;
        debug!("Testing speed with URL: {}", test_url);
        let start = Instant::now();
        let mut response = self.send_checked(self.client.get(&test_url), url).await?;

        let target = self.test_size as u64;
        let mut downloaded = 0u64;
        while downloaded < target {
            // Never wait past the overall budget. Running out of it mid-wait ends
            // the sample, exactly like the post-chunk elapsed check does.
            let remaining = self.timeout_duration.saturating_sub(start.elapsed());
            let wait = remaining.min(Self::CHUNK_TIMEOUT);
            match timeout(wait, response.chunk()).await {
                Ok(Ok(Some(chunk))) => {
                    downloaded += chunk.len() as u64;
                    if start.elapsed() > self.timeout_duration {
                        break;
                    }
                }
                Ok(Ok(None)) => break,
                Ok(Err(e)) => {
                    return Err(Self::mirror_test_failed(url, format!("Download failed: {}", e)));
                }
                // The wait was cut short by the overall budget, not a stall.
                Err(_) if wait < Self::CHUNK_TIMEOUT => break,
                Err(_) => {
                    return Err(Self::mirror_test_failed(url, "Chunk timeout".to_string()));
                }
            }
        }

        // Nothing received means nothing was measured; don't report 0 MB/s as success.
        if downloaded == 0 && target > 0 {
            return Err(Self::mirror_test_failed(url, "No data received".to_string()));
        }
        let elapsed = start.elapsed().as_secs_f64();
        if elapsed < Self::MIN_MEASURABLE_SECS {
            return Err(Self::mirror_test_failed(
                url,
                "Download too fast to measure accurately".to_string(),
            ));
        }
        Ok(downloaded as f64 / elapsed / 1_000_000.0)
    }

    /// Resolves the URL of the file downloaded by the speed test.
    ///
    /// A `test_file` metadata entry takes precedence. Otherwise a file is chosen
    /// from the `distro_type` metadata entry, defaulting to `ls-lR.gz`.
    ///
    /// The file is resolved inside the mirror's directory. A trailing `/` is
    /// added to the base first, because `Url::join` replaces the last path
    /// segment of a base without one: `https://host/debian` + `ls-lR.gz` would
    /// otherwise become `https://host/ls-lR.gz`.
    fn get_test_file_url(&self, mirror_url: &url::Url, mirror: &Mirror) -> Result<String> {
        let test_file = match mirror.metadata.get("test_file") {
            Some(custom) => custom.as_str(),
            None => {
                let distro_type = mirror
                    .metadata
                    .get("distro_type")
                    .map(|s| s.as_str())
                    .unwrap_or("unknown");
                match distro_type {
                    "debian" | "ubuntu" => "ls-lR.gz",
                    "fedora" | "rhel" => "repodata/repomd.xml",
                    "arch" => "core/os/x86_64/core.db",
                    "opensuse" => "content",
                    _ => "ls-lR.gz",
                }
            }
        };

        let mut base = mirror_url.clone();
        if !base.path().ends_with('/') {
            let dir_path = format!("{}/", base.path());
            base.set_path(&dir_path);
        }
        Ok(base.join(test_file)?.to_string())
    }

    /// Keeps only successful results, preserving their order.
    pub fn filter_successful(mut results: Vec<TestResult>) -> Vec<TestResult> {
        results.retain(|r| r.success);
        results
    }

    /// Sorts results by descending score.
    ///
    /// A missing score counts as `0.0`. NaN scores are ordered after every real
    /// score so the comparator is a total order; `sort_by` may panic when given
    /// an inconsistent comparator. The sort is stable, so equal scores keep
    /// their input order.
    pub fn sort_by_score(mut results: Vec<TestResult>) -> Vec<TestResult> {
        results.sort_by(|a, b| {
            // `None` marks a NaN score, which ranks below every real score.
            let key = |r: &TestResult| {
                r.score
                    .map_or(Some(0.0), |s| if s.is_nan() { None } else { Some(s) })
            };
            match (key(a), key(b)) {
                (Some(sa), Some(sb)) => sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal),
                (Some(_), None) => std::cmp::Ordering::Less,
                (None, Some(_)) => std::cmp::Ordering::Greater,
                (None, None) => std::cmp::Ordering::Equal,
            }
        });
        results
    }

    /// Consumes the results and returns the mirrors they carry, in order.
    pub fn extract_mirrors(results: Vec<TestResult>) -> Vec<Mirror> {
        results.into_iter().map(|r| r.mirror).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bare mirror rooted at `https://<host>.example.com`.
    fn mirror_at(host: &str) -> Mirror {
        let url = url::Url::parse(&format!("https://{}.example.com", host))
            .expect("test URL is valid");
        Mirror::new(url)
    }

    #[test]
    fn test_tester_creation() {
        let tester = MirrorTester::new(1024 * 1024, 10, 5, 0.7, 0.3, 3)
            .expect("tester should be constructible");
        assert_eq!(tester.test_size, 1024 * 1024);
        assert_eq!(tester.concurrent_tests, 5);
    }

    #[test]
    fn test_filter_successful() {
        let results = vec![
            TestResult::success(mirror_at("mirror1"), 50.0, Duration::from_millis(100), 0.8),
            TestResult::failure(mirror_at("mirror2"), "Error".to_string()),
        ];
        let kept = MirrorTester::filter_successful(results);
        assert_eq!(kept.len(), 1);
        assert!(kept.iter().all(|r| r.success));
    }

    #[test]
    fn test_sort_by_score() {
        let results = vec![
            TestResult::success(mirror_at("mirror1"), 30.0, Duration::from_millis(200), 0.5),
            TestResult::success(mirror_at("mirror2"), 60.0, Duration::from_millis(50), 0.9),
            TestResult::success(mirror_at("mirror3"), 45.0, Duration::from_millis(100), 0.7),
        ];
        let scores: Vec<_> = MirrorTester::sort_by_score(results)
            .iter()
            .map(|r| r.score.unwrap())
            .collect();
        assert_eq!(scores, vec![0.9, 0.7, 0.5]);
    }
}