use std::{
net::IpAddr,
sync::{
Arc,
atomic::{AtomicU32, Ordering},
},
thread,
};
use aok::{OK, Void};
use pick_fast::PickFast;
// Logging setup for this test binary: the body only calls `log_init::init()`.
// Registered through `static_init::constructor` with argument 0; presumably
// this makes it run once at start-up, before any test body executes
// (NOTE(review): confirm ordering/priority semantics in the static_init docs).
#[static_init::constructor(0)]
extern "C" fn _log_init() {
log_init::init();
}
/// Minimal node type handed to `PickFast` throughout these tests: a DNS
/// server identified solely by its IP address.
#[derive(Debug, Clone, Copy)]
struct DnsServer {
/// Address used to tell the servers apart (and to label them in the chart).
ip: IpAddr,
}
/// Drives a shared `PickFast` balancer from several threads and prints the
/// weights that the slowest and the fastest simulated server end up with.
///
/// Every pick is answered with a fixed per-server latency that is fed back via
/// `set`. The test makes no assertion; the weight ratio is only printed.
#[test]
fn test() -> Void {
    use std::net::Ipv4Addr;

    const SERVERS: [DnsServer; 8] = [
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(208, 67, 222, 222)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(9, 9, 9, 9)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(1, 0, 0, 1)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(114, 114, 114, 114)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(180, 76, 76, 76)) },
    ];
    const WORKERS: usize = 8;
    const PICKS_PER_WORKER: usize = 1000;

    let lb = Arc::new(PickFast::<DnsServer, pick_fast::Inverse>::new(SERVERS));
    println!("Load Balancer initialized with {} nodes.", SERVERS.len());

    let mut workers = Vec::with_capacity(WORKERS);
    for _ in 0..WORKERS {
        let balancer = Arc::clone(&lb);
        workers.push(thread::spawn(move || {
            for _ in 0..PICKS_PER_WORKER {
                let picked = balancer.pick();
                // Simulated latency per server; the values are 1000x the
                // millisecond figures shown in the printed labels below.
                let latency = match picked.ip {
                    IpAddr::V4(v4) => match v4.octets() {
                        [8, 8, 8, 8] => 100_000,
                        [1, 1, 1, 1] => 80_000,
                        [223, 5, 5, 5] => 5_000,
                        [208, 67, 222, 222] => 60_000,
                        [9, 9, 9, 9] => 40_000,
                        [1, 0, 0, 1] => 20_000,
                        [114, 114, 114, 114] => 70_000,
                        [180, 76, 76, 76] => 90_000,
                        // Unknown IPv4 address: treat as the slowest server.
                        _ => 100_000,
                    },
                    IpAddr::V6(_) => 100_000,
                };
                balancer.set(picked.index, latency);
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    // Index 0 is 8.8.8.8 (slowest), index 2 is 223.5.5.5 (fastest).
    let weight_of = |idx: usize| lb.li[idx].weight.load(Ordering::Relaxed);
    let w_slow = weight_of(0);
    let w_fast = weight_of(2);
    let ratio = w_fast as f64 / w_slow as f64;
    println!("Slow Node Weight Google (8.8.8.8, 100ms): {w_slow}");
    println!("Fast Node Weight AliDNS (223.5.5.5, 5ms): {w_fast}");
    println!("Ratio: {ratio:.2} (Expected ~20.0)");
    OK
}
/// Counts how often each simulated server is picked under concurrent load,
/// asserts that the fastest server is picked more often than the slowest one,
/// and renders the ranking as SVG charts via `draw_svg_histogram`.
#[test]
fn test_pick_count_with_chart() -> Void {
    use std::net::Ipv4Addr;

    const SERVERS: [DnsServer; 8] = [
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(208, 67, 222, 222)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(9, 9, 9, 9)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(1, 0, 0, 1)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(114, 114, 114, 114)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(180, 76, 76, 76)) },
    ];
    // Display latencies in milliseconds, index-aligned with SERVERS.
    const LATENCIES: [u32; 8] = [100, 80, 5, 60, 40, 20, 70, 90];
    const WORKERS: usize = 8;
    const PICKS_PER_WORKER: usize = 1250;

    let lb = Arc::new(PickFast::<DnsServer, pick_fast::Inverse>::new(SERVERS));
    let pick_counts: Arc<[AtomicU32; 8]> = Arc::new(std::array::from_fn(|_| AtomicU32::new(0)));
    println!("Running 10000 picks to verify fast node is selected more than slow node...");

    let mut workers = Vec::with_capacity(WORKERS);
    for _ in 0..WORKERS {
        let balancer = Arc::clone(&lb);
        let tally = Arc::clone(&pick_counts);
        workers.push(thread::spawn(move || {
            for _ in 0..PICKS_PER_WORKER {
                let picked = balancer.pick();
                tally[picked.index].fetch_add(1, Ordering::Relaxed);
                // Feedback value per server: 1000x the matching LATENCIES entry.
                let latency = match picked.ip {
                    IpAddr::V4(v4) => match v4.octets() {
                        [8, 8, 8, 8] => 100_000,
                        [1, 1, 1, 1] => 80_000,
                        [223, 5, 5, 5] => 5_000,
                        [208, 67, 222, 222] => 60_000,
                        [9, 9, 9, 9] => 40_000,
                        [1, 0, 0, 1] => 20_000,
                        [114, 114, 114, 114] => 70_000,
                        [180, 76, 76, 76] => 90_000,
                        _ => 100_000,
                    },
                    IpAddr::V6(_) => 100_000,
                };
                balancer.set(picked.index, latency);
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    // Snapshot the atomic counters into plain integers.
    let counts: [u32; 8] = std::array::from_fn(|i| pick_counts[i].load(Ordering::Relaxed));

    println!("\n=== 节点选择统计 ===");
    for ((server, latency), count) in SERVERS.iter().zip(LATENCIES.iter()).zip(counts.iter()) {
        println!("Node {}: {}ms -> {} times", server.ip, latency, count);
    }

    // Index 0 is the slowest server (8.8.8.8), index 2 the fastest (223.5.5.5).
    let slow_count = counts[0];
    let fast_count = counts[2];
    let ratio = fast_count as f64 / slow_count as f64;
    println!("\n慢节点 (8.8.8.8, 100ms) 被选中: {slow_count} 次");
    println!("快节点 (223.5.5.5, 5ms) 被选中: {fast_count} 次");
    println!("比例: {ratio:.2}");
    assert!(
        fast_count > slow_count,
        "Fast node should be picked more than slow node"
    );

    // Rank servers by pick count, most-picked first (stable sort keeps the
    // original order among ties), then split into the parallel slices the
    // chart renderer expects.
    let mut ranked: Vec<(&DnsServer, u32, u32)> = SERVERS
        .iter()
        .zip(LATENCIES.iter().copied())
        .zip(counts.iter().copied())
        .map(|((server, latency), count)| (server, latency, count))
        .collect();
    ranked.sort_by_key(|&(_, _, count)| std::cmp::Reverse(count));

    let mut sorted_servers: Vec<&DnsServer> = Vec::with_capacity(ranked.len());
    let mut sorted_latencies: Vec<u32> = Vec::with_capacity(ranked.len());
    let mut sorted_counts: Vec<u32> = Vec::with_capacity(ranked.len());
    for &(server, latency, count) in &ranked {
        sorted_servers.push(server);
        sorted_latencies.push(latency);
        sorted_counts.push(count);
    }

    draw_svg_histogram(&sorted_servers, &sorted_latencies, &sorted_counts)?;
    OK
}
/// Writes the ranking chart in both languages into the `readme/` directory.
///
/// Creates `readme/` if necessary, then renders the Chinese chart to
/// `readme/rank-zh.svg` and the English chart to `readme/rank-en.svg`.
/// The three slices are passed through unchanged to `draw_3d_chart`.
///
/// # Errors
/// Returns the first error from creating the directory or rendering a chart.
fn draw_svg_histogram(servers: &[&DnsServer], latencies: &[u32], counts: &[u32]) -> Void {
    std::fs::create_dir_all("readme")?;

    // (output path, use Chinese labels?) — Chinese first, then English.
    let variants = [("readme/rank-zh.svg", true), ("readme/rank-en.svg", false)];
    for (path, is_chinese) in variants {
        draw_3d_chart(servers, latencies, counts, path, is_chinese)?;
    }

    println!("SVG图表已保存到 readme/rank-zh.svg 和 readme/rank-en.svg");
    println!("SVG charts saved to readme/rank-zh.svg and readme/rank-en.svg");
    OK
}
/// Renders a pseudo-3D bar chart of pick counts and writes it to `filename`
/// as an SVG document.
///
/// One bar is drawn per entry of `counts`; entries whose count is zero are
/// skipped (their slot stays empty). Each bar consists of a front rectangle,
/// a top and a right parallelogram, the count printed above the bar (white
/// outline underneath black text), and the matching `latencies[i]` (suffixed
/// with `ms`) and `servers[i].ip` labels below the baseline.
///
/// The horizontal slot width is derived from `counts.len()`, so the bars fill
/// the chart width for any number of entries (previously the layout assumed
/// exactly eight bars).
///
/// * `servers`, `latencies`, `counts` - index-aligned parallel slices.
/// * `filename` - path of the SVG file to write.
/// * `is_chinese` - selects Chinese (`true`) or English (`false`) title and
///   axis label.
///
/// # Panics
/// Panics if `latencies` or `servers` has no element at the index of a bar
/// with a non-zero count.
///
/// # Errors
/// Returns the I/O error raised while writing `filename`.
fn draw_3d_chart(
    servers: &[&DnsServer],
    latencies: &[u32],
    counts: &[u32],
    filename: &str,
    is_chinese: bool,
) -> Void {
    use std::fs;
    use svg::{
        Document,
        node::element::{Group, Polygon, Rectangle, Text},
    };

    const FONT: &str = "Arial, sans-serif";
    const EDGE: &str = "rgba(0,0,0,0.2)";

    // Canvas geometry (SVG user units).
    let width = 1000;
    let height = 480;
    let margin = 50;
    let title_margin = 60;
    let chart_width = width - 2 * margin;
    let chart_height = height - 2 * margin - title_margin - 60;
    // y coordinate of the bottom edge of every bar.
    let baseline = margin + title_margin + chart_height;

    // Tallest bar spans the full chart height; fall back to 1 for empty input.
    let max_count = counts.iter().copied().max().unwrap_or(1);

    // Offset of the back face relative to the front face.
    let depth = 40.0;
    let angle_rad: f64 = 0.5;
    let dx = depth * angle_rad.cos();
    let dy = depth * angle_rad.sin();

    let mut document = Document::new()
        .set("viewBox", (0, 0, width, height))
        .set("width", width)
        .set("height", height);

    let front_color = "rgb(147, 197, 253)";
    let top_color = "rgb(191, 219, 254)";
    let right_color = "rgb(96, 165, 250)";

    let title = if is_chinese {
        "PickFast 使用演示:DNS 响应延时 与 选中次数"
    } else {
        "PickFast Demo: DNS Response Latency vs Selection Count"
    };
    let title_text = Text::new(title)
        .set("x", width / 2)
        .set("y", 35)
        .set("text-anchor", "middle")
        .set("font-family", FONT)
        .set("font-size", 22)
        .set("font-weight", "bold")
        .set("fill", "rgb(30, 41, 59)");
    document = document.add(title_text);

    // Each bar gets an equal horizontal slot and occupies the middle 70% of
    // it. `max(1)` avoids a division by zero when `counts` is empty.
    let bar_spacing = chart_width as f64 / counts.len().max(1) as f64;
    let bar_width = bar_spacing * 0.7;

    for (i, &count) in counts.iter().enumerate() {
        if count == 0 {
            continue;
        }
        let x = margin as f64 + i as f64 * bar_spacing + bar_spacing * 0.15;
        let bar_height = (count as f64 / max_count as f64) * chart_height as f64;
        let y = baseline as f64 - bar_height;
        let center_x = x + bar_width / 2.0;

        let mut group = Group::new();

        let front_face = Rectangle::new()
            .set("x", x)
            .set("y", y)
            .set("width", bar_width)
            .set("height", bar_height)
            .set("fill", front_color)
            .set("stroke", EDGE)
            .set("stroke-width", 1);
        group = group.add(front_face);

        let top_points = format!(
            "{},{} {},{} {},{} {},{}",
            x,
            y,
            x + bar_width,
            y,
            x + bar_width + dx,
            y - dy,
            x + dx,
            y - dy
        );
        let top_face = Polygon::new()
            .set("points", top_points)
            .set("fill", top_color)
            .set("stroke", EDGE)
            .set("stroke-width", 1);
        group = group.add(top_face);

        let right_points = format!(
            "{},{} {},{} {},{} {},{}",
            x + bar_width,
            y,
            x + bar_width,
            y + bar_height,
            x + bar_width + dx,
            y + bar_height - dy,
            x + bar_width + dx,
            y - dy
        );
        let right_face = Polygon::new()
            .set("points", right_points)
            .set("fill", right_color)
            .set("stroke", EDGE)
            .set("stroke-width", 1);
        group = group.add(right_face);

        // The count is drawn twice: a white outline first, then black text on
        // top of it, so the number stays readable over the bar's top face.
        let count_label = || {
            Text::new(count.to_string())
                .set("x", center_x)
                .set("y", y - 10.0)
                .set("text-anchor", "middle")
                .set("font-family", FONT)
                .set("font-size", 13)
                .set("font-weight", "bold")
        };
        let value_stroke = count_label()
            .set("fill", "none")
            .set("stroke", "white")
            .set("stroke-width", 3);
        group = group.add(value_stroke);
        let value_text = count_label().set("fill", "black");
        group = group.add(value_text);

        let latency_text = Text::new(format!("{}ms", latencies[i]))
            .set("x", center_x)
            .set("y", baseline + 20)
            .set("text-anchor", "middle")
            .set("font-family", FONT)
            .set("font-size", 12)
            .set("fill", "rgb(71, 85, 105)");
        group = group.add(latency_text);

        let ip_text = Text::new(servers[i].ip.to_string())
            .set("x", center_x)
            .set("y", baseline + 38)
            .set("text-anchor", "middle")
            .set("font-family", FONT)
            .set("font-size", 10)
            .set("fill", "rgb(100, 116, 139)");
        group = group.add(ip_text);

        document = document.add(group);
    }

    let y_desc = if is_chinese {
        "选择次数"
    } else {
        "Selection Count"
    };
    // Vertical axis label, rotated 90 degrees counter-clockwise around its anchor.
    let y_label_x = 35;
    let y_label_y = margin + title_margin + chart_height / 2;
    let y_label = Text::new(y_desc)
        .set("x", y_label_x)
        .set("y", y_label_y)
        .set("text-anchor", "middle")
        .set("font-family", FONT)
        .set("font-size", 16)
        .set("fill", "rgb(71, 85, 105)")
        .set("transform", format!("rotate(-90, {y_label_x}, {y_label_y})"));
    document = document.add(y_label);

    fs::write(filename, document.to_string())?;
    OK
}
/// Checks that `iter()` starts at varying positions: over 1000 fresh
/// iterators both the fastest and the slowest node must appear as the first
/// item at least once, and an iterator must yield at least four items.
#[cfg(feature = "iter")]
#[test]
fn test_iter() -> Void {
    use std::net::Ipv4Addr;

    const SERVERS: [DnsServer; 4] = [
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(208, 67, 222, 222)) },
    ];

    let lb = PickFast::<DnsServer, pick_fast::Inverse>::new(SERVERS);
    // Seed every node with a latency; index 2 is the fastest, index 0 the slowest.
    for (index, latency) in [(0, 100_000), (1, 80_000), (2, 5_000), (3, 60_000)] {
        lb.set(index, latency);
    }

    println!("Testing iter() with weighted random start position...");
    let mut start_positions = [0u32; 4];
    for _ in 0..1000 {
        let first_item = lb.iter().next().unwrap();
        // Recover the node's index by pointer identity within `lb.li`.
        let actual_index = lb
            .li
            .iter()
            .position(|candidate| std::ptr::eq(candidate, first_item))
            .unwrap();
        start_positions[actual_index] += 1;
    }

    println!("Start position distribution over 1000 iterations:");
    for (i, &count) in start_positions.iter().enumerate() {
        // 1000 samples, so count / 10 is the percentage.
        let percent = count as f32 / 10.0;
        println!("Position {i}: {count} times ({percent:.1}%)");
    }

    let [slow_node_starts, _, fast_node_starts, _] = start_positions;
    println!("Fast node (index 2) starts: {fast_node_starts}");
    println!("Slow node (index 0) starts: {slow_node_starts}");
    assert!(
        fast_node_starts > 0 && slow_node_starts > 0,
        "Both fast and slow nodes should be selected at least once (fast: {fast_node_starts}, slow: {slow_node_starts})"
    );

    let items: Vec<_> = lb.iter().take(4).collect();
    let len = items.len();
    assert_eq!(len, 4);
    println!("Iterator test passed - collected {len} items");
    OK
}
/// Verifies what `failed()` does to a node's weight: the weight must equal
/// the previous weight shifted right by one, and never drop below 1.
#[test]
fn test_failed_method() -> Void {
    use std::net::Ipv4Addr;

    const SERVERS: [DnsServer; 3] = [
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)) },
        DnsServer { ip: IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)) },
    ];

    let lb = PickFast::<DnsServer, pick_fast::Inverse>::new(SERVERS);
    for (index, latency) in [(0, 100_000), (1, 50_000), (2, 10_000)] {
        lb.set(index, latency);
    }

    // Reads the current weight of the node at `index`.
    let weight_of = |index: usize| lb.li[index].weight.load(Ordering::Relaxed);

    println!("Initial weights:");
    lb.li.iter().enumerate().for_each(|(i, node)| {
        let weight = node.weight.load(Ordering::Relaxed);
        println!("Node {i}: weight = {weight}");
    });

    // A failure halves the weight (floor of 1).
    let weight_before = weight_of(1);
    lb.failed(1);
    let weight_after = weight_of(1);
    println!("Node 1 weight before failed(): {weight_before}");
    println!("Node 1 weight after failed(): {weight_after}");
    let expected = std::cmp::max(weight_before >> 1, 1);
    assert_eq!(weight_after, expected);

    // A node already at the minimum weight must stay at 1 after a failure.
    lb.li[0].weight.store(1, Ordering::Relaxed);
    lb.failed(0);
    let final_weight = weight_of(0);
    assert_eq!(final_weight, 1);
    println!("Node 0 weight after failed() when already 1: {final_weight}");

    println!("Failed method test passed");
    OK
}
/// Combines `PickFast::iter()` with the `race` crate: tasks are started in
/// the order produced by the balancer's iterator, staggered by 500ms, and
/// the first successful lookup wins.
///
/// Each task feeds its outcome back into the balancer: `set` with the
/// measured latency in microseconds on success, `failed` otherwise. The test
/// makes no assertion; it only reports whether any lookup succeeded.
///
/// NOTE(review): `lookup_host` receives the server's own `ip:port` literal,
/// so this appears to resolve that literal rather than send a query to the
/// server — confirm that this is the intended demonstration.
#[cfg(feature = "iter")]
#[tokio::test]
async fn test_iter_with_race_dns() -> Void {
    use std::time::Duration;
    use race::Race;
    use tokio::net::lookup_host;

    const DNS_HOSTS: [&str; 4] = [
        "8.8.8.8:53",
        "1.1.1.1:53",
        "9.9.9.9:53",
        "208.67.222.222:53",
    ];

    // One balancer node per host. The node IPs are placeholders
    // (127.0.0.1..=127.0.0.4); nodes are mapped back to hosts by index.
    let servers: Vec<DnsServer> = DNS_HOSTS
        .iter()
        .enumerate()
        .map(|(i, _)| DnsServer {
            ip: IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, i as u8 + 1)),
        })
        .collect();
    let lb = Arc::new(PickFast::<DnsServer, pick_fast::Inverse>::new(servers));

    /// One lookup attempt: which host, which balancer slot, and when it was created.
    #[derive(Debug, Clone)]
    struct DnsTask {
        host: &'static str,
        index: usize,
        start: std::time::Instant,
    }

    // Printed once (the original emitted this banner twice).
    println!("Testing iter() with race crate for real DNS resolution...");

    let lb_clone = lb.clone();
    let resolve_dns_with_feedback = move |task: DnsTask| {
        let lb = lb_clone.clone();
        async move {
            match lookup_host(task.host).await {
                Ok(mut addrs) => {
                    if let Some(addr) = addrs.next() {
                        let duration = task.start.elapsed();
                        // Saturate instead of silently truncating the u128 microsecond count.
                        let latency_us = u32::try_from(duration.as_micros()).unwrap_or(u32::MAX);
                        lb.set(task.index, latency_us);
                        println!(
                            " ✅ {} resolved in {duration:?} (latency: {latency_us}μs)",
                            task.host
                        );
                        Ok(addr.ip())
                    } else {
                        let duration = task.start.elapsed();
                        lb.failed(task.index);
                        let error = std::io::Error::new(
                            std::io::ErrorKind::NotFound,
                            format!("No address found for {}", task.host),
                        );
                        println!(" ❌ {} failed after {duration:?}: {error}", task.host);
                        Err(error)
                    }
                }
                Err(e) => {
                    let duration = task.start.elapsed();
                    lb.failed(task.index);
                    println!(" ❌ {} failed after {duration:?}: {e}", task.host);
                    Err(e)
                }
            }
        }
    };

    // Turn every node yielded by the balancer into a task; the node's index is
    // recovered by pointer identity within `lb.li`.
    let server_iter = lb.iter().map(|server_node| {
        let index = lb
            .li
            .iter()
            .position(|n| std::ptr::eq(n, server_node))
            .unwrap();
        DnsTask {
            host: DNS_HOSTS[index],
            index,
            start: std::time::Instant::now(),
        }
    });

    let race = Race::new(resolve_dns_with_feedback, Duration::from_millis(500));
    let rx = race.run(server_iter);
    println!("Starting staggered DNS resolution with 500ms intervals...");

    let mut resolved_ip = None;
    while let Ok(result) = rx.recv().await {
        match result {
            Ok(ip) => {
                println!("🎯 First successful resolution: {ip}");
                resolved_ip = Some(ip);
                // Drop the receiver right away; only the first winner is wanted.
                drop(rx);
                break;
            }
            Err(e) => {
                println!("⚠️ Resolution attempt failed: {e}");
            }
        }
    }

    if resolved_ip.is_some() {
        println!("✅ Race-based DNS resolution completed successfully");
        println!("This demonstrates how PickFast + Race provides fast DNS failover");
    } else {
        println!("❌ All DNS resolutions failed (network issue?)");
    }
    println!("Real DNS resolution test completed");
    OK
}
/// Runs a fixed number of sequential host lookups, attributing each one to
/// the node that `PickFast::iter()` yields first, then prints every node's
/// final weight together with a latency inferred from it, plus a ranking.
///
/// Successful lookups report their elapsed time (microseconds) via `set`;
/// failures are reported via `failed`. The test asserts nothing and needs
/// network access to produce meaningful numbers.
///
/// NOTE(review): the lookups go through `tokio::net::lookup_host` on the test
/// domain; the selected "DNS server" only determines which balancer slot gets
/// the feedback — confirm that this is the intended model.
#[cfg(feature = "iter")]
#[tokio::test]
async fn test_dns_performance_analysis() -> Void {
    use std::time::Duration;
    use tokio::net::lookup_host;

    const DNS_HOSTS: [&str; 4] = [
        "8.8.8.8:53",
        "1.1.1.1:53",
        "9.9.9.9:53",
        "208.67.222.222:53",
    ];
    const TEST_DOMAINS: [&str; 20] = [
        "google.com:80",
        "github.com:80",
        "stackoverflow.com:80",
        "reddit.com:80",
        "youtube.com:80",
        "facebook.com:80",
        "twitter.com:80",
        "linkedin.com:80",
        "amazon.com:80",
        "microsoft.com:80",
        "apple.com:80",
        "netflix.com:80",
        "wikipedia.org:80",
        "baidu.com:80",
        "qq.com:80",
        "taobao.com:80",
        "instagram.com:80",
        "tiktok.com:80",
        "discord.com:80",
        "twitch.tv:80",
    ];
    /// Number of lookups performed; also shown in the start-up banner.
    const RESOLUTIONS: usize = 100;
    /// Numerator used to turn a weight back into a latency (`BASE / weight`).
    /// Presumably mirrors the scale the `Inverse` strategy uses internally —
    /// TODO confirm against pick_fast.
    const BASE: u32 = 1 << 22;

    // One balancer node per host. The node IPs are placeholders
    // (127.0.0.1..=127.0.0.4); nodes are mapped back to hosts by index.
    let servers: Vec<DnsServer> = DNS_HOSTS
        .iter()
        .enumerate()
        .map(|(i, _)| DnsServer {
            ip: IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, i as u8 + 1)),
        })
        .collect();
    let lb = Arc::new(PickFast::<DnsServer, pick_fast::Inverse>::new(servers));

    println!("Starting DNS performance analysis with {RESOLUTIONS} resolutions...");
    println!("Testing {} different domains", TEST_DOMAINS.len());

    /// One lookup attempt and the balancer slot that receives its feedback.
    #[derive(Debug, Clone)]
    struct DnsAnalysisTask {
        domain: &'static str,
        dns_server_host: &'static str,
        dns_server_index: usize,
        start: std::time::Instant,
    }

    let lb_clone = lb.clone();
    let resolve_with_analysis = move |task: DnsAnalysisTask| {
        let lb = lb_clone.clone();
        async move {
            match lookup_host(task.domain).await {
                Ok(mut addrs) => {
                    if let Some(addr) = addrs.next() {
                        let duration = task.start.elapsed();
                        // Saturate instead of silently truncating the u128 microsecond count.
                        let latency_us = u32::try_from(duration.as_micros()).unwrap_or(u32::MAX);
                        lb.set(task.dns_server_index, latency_us);
                        println!(
                            " ✅ {} via {} resolved in {duration:?} (latency: {latency_us}μs)",
                            task.domain, task.dns_server_host
                        );
                        Ok(addr.ip())
                    } else {
                        let duration = task.start.elapsed();
                        lb.failed(task.dns_server_index);
                        println!(
                            " ❌ {} via {} failed after {duration:?}: No address found",
                            task.domain, task.dns_server_host
                        );
                        Err(std::io::Error::new(
                            std::io::ErrorKind::NotFound,
                            format!(
                                "No address found for {} via {}",
                                task.domain, task.dns_server_host
                            ),
                        ))
                    }
                }
                Err(e) => {
                    let duration = task.start.elapsed();
                    lb.failed(task.dns_server_index);
                    println!(
                        " ❌ {} via {} failed after {duration:?}: {e}",
                        task.domain, task.dns_server_host
                    );
                    Err(e)
                }
            }
        }
    };

    for round in 0..RESOLUTIONS {
        // The first node yielded by the balancer is the one charged with this
        // lookup; its index is recovered by pointer identity within `lb.li`.
        let dns_server_index = {
            let selected_server = lb.iter().next().unwrap();
            lb.li
                .iter()
                .position(|n| std::ptr::eq(n, selected_server))
                .unwrap()
        };
        let task = DnsAnalysisTask {
            // Cycle through the domain list.
            domain: TEST_DOMAINS[round % TEST_DOMAINS.len()],
            dns_server_host: DNS_HOSTS[dns_server_index],
            dns_server_index,
            start: std::time::Instant::now(),
        };
        // Failures are already reported to the balancer and logged inside the task.
        let _ = resolve_with_analysis(task).await;
        tokio::time::sleep(Duration::from_millis(10)).await;
    }

    // Read every weight once and derive the latency it corresponds to; a
    // weight of zero maps to `u32::MAX`.
    let mut servers_with_perf: Vec<(&str, u32, u32)> = DNS_HOSTS
        .iter()
        .enumerate()
        .map(|(i, &host)| {
            let weight = lb.li[i].weight.load(Ordering::Relaxed);
            let inferred_latency_us = BASE.checked_div(weight).unwrap_or(u32::MAX);
            (host, weight, inferred_latency_us)
        })
        .collect();

    println!("\n=== DNS服务器性能分析结果 / DNS Server Performance Analysis ===");
    for &(dns_host, weight, inferred_latency_us) in &servers_with_perf {
        let inferred_latency_ms = inferred_latency_us as f64 / 1000.0;
        println!("DNS服务器 {dns_host}:");
        println!(" 权重 Weight: {weight}");
        println!(" 倒推延时 Inferred Latency: {inferred_latency_us}μs ({inferred_latency_ms:.2}ms)");
        println!();
    }

    // Lowest inferred latency first; the stable sort keeps host order among ties.
    servers_with_perf.sort_by_key(|&(_, _, latency)| latency);
    println!("=== 性能排名 Performance Ranking ===");
    for (rank, &(host, weight, latency_us)) in servers_with_perf.iter().enumerate() {
        let latency_ms = latency_us as f64 / 1000.0;
        println!(
            "#{}: {} - {latency_ms:.2}ms (权重: {weight})",
            rank + 1,
            host
        );
    }
    println!("\nDNS性能分析完成 / DNS performance analysis completed");
    OK
}