use rand::Rng;
use std::env;
use std::fs;
use std::io::Write;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::mpsc;
use std::sync::{Arc, Barrier};
use std::thread;
use std::time::{Duration, Instant};
use walrus_rust::wal::{FsyncSchedule, ReadConsistency, Walrus};
/// Samples system memory as `(total_kb, dirty_kb, dirty_percent)`.
///
/// The implementation is chosen at compile time: Linux delegates to
/// `get_linux_memory_info`, macOS combines `get_macos_total_memory` and
/// `get_macos_dirty_pages`, and every other target reports `(0, 0, 0.0)`.
/// On macOS the percentage is 0.0 when the total could not be determined.
fn get_memory_info() -> (u64, u64, f64) {
    #[cfg(target_os = "linux")]
    {
        get_linux_memory_info()
    }
    #[cfg(target_os = "macos")]
    {
        let total_kb = get_macos_total_memory();
        let dirty_kb = get_macos_dirty_pages();
        // Avoid dividing by zero when `sysctl` could not report the total.
        let dirty_percent = match total_kb {
            0 => 0.0,
            total => (dirty_kb as f64 / total as f64) * 100.0,
        };
        (total_kb, dirty_kb, dirty_percent)
    }
    #[cfg(not(any(target_os = "linux", target_os = "macos")))]
    {
        (0, 0, 0.0)
    }
}
/// Returns total physical memory in KiB, read from `sysctl -n hw.memsize` (which reports bytes).
///
/// Returns 0 if the command cannot be run, its output is not UTF-8, or it does not parse
/// as an integer.
#[cfg(target_os = "macos")]
fn get_macos_total_memory() -> u64 {
    use std::process::Command;
    Command::new("sysctl")
        .args(&["-n", "hw.memsize"])
        .output()
        .ok()
        .and_then(|output| String::from_utf8(output.stdout).ok())
        .and_then(|text| text.trim().parse::<u64>().ok())
        .map(|bytes| bytes / 1024)
        .unwrap_or(0)
}
/// Extracts the `Pages modified:` count from `vm_stat` output and converts it to KiB.
///
/// `vm_stat` reports page *counts*. Its header line announces the page size, e.g.
/// `Mach Virtual Memory Statistics: (page size of 16384 bytes)`. Apple Silicon uses
/// 16 KiB pages, so assuming 4 KiB under-reports by 4x. The header value is used when
/// present; otherwise 4096 bytes is assumed.
///
/// Returns 0 when no `Pages modified:` line has a parseable count.
#[cfg_attr(not(target_os = "macos"), allow(dead_code))]
fn parse_vm_stat_dirty_kb(vm_stat_output: &str) -> u64 {
    let page_size_bytes = vm_stat_output
        .lines()
        .find_map(|line| {
            let (_, after) = line.split_once("page size of ")?;
            after.split_whitespace().next()?.parse::<u64>().ok()
        })
        .unwrap_or(4096);
    vm_stat_output
        .lines()
        .filter(|line| line.contains("Pages modified:"))
        .find_map(|line| {
            // Lines look like `Pages modified:   12345.`; the count is the third token
            // and carries a trailing period.
            line.split_whitespace()
                .nth(2)?
                .trim_end_matches('.')
                .parse::<u64>()
                .ok()
        })
        .map(|pages| pages.saturating_mul(page_size_bytes) / 1024)
        .unwrap_or(0)
}

/// Returns modified ("dirty") memory in KiB as reported by `vm_stat`, or 0 on any failure.
#[cfg(target_os = "macos")]
fn get_macos_dirty_pages() -> u64 {
    use std::process::Command;
    Command::new("vm_stat")
        .output()
        .ok()
        .and_then(|output| String::from_utf8(output.stdout).ok())
        .map(|text| parse_vm_stat_dirty_kb(&text))
        .unwrap_or(0)
}
/// Reads `/proc/meminfo` and returns `(MemTotal, Dirty, Dirty as a percentage of MemTotal)`.
///
/// Both quantities are the numeric second field of their lines (kB in `/proc/meminfo`).
/// An unreadable file, a missing line, or an unparseable value leaves that field at 0.
/// The percentage is 0.0 when `MemTotal` is 0.
#[cfg(target_os = "linux")]
fn get_linux_memory_info() -> (u64, u64, f64) {
    let meminfo = std::fs::read_to_string("/proc/meminfo").unwrap_or_default();
    let (mut total_kb, mut dirty_kb) = (0u64, 0u64);
    for line in meminfo.lines() {
        // Pick which counter this line feeds; skip everything else.
        let slot = if line.starts_with("MemTotal:") {
            &mut total_kb
        } else if line.starts_with("Dirty:") {
            &mut dirty_kb
        } else {
            continue;
        };
        if let Some(field) = line.split_whitespace().nth(1) {
            *slot = field.parse().unwrap_or(0);
        }
    }
    let dirty_percent = if total_kb == 0 {
        0.0
    } else {
        (dirty_kb as f64 / total_kb as f64) * 100.0
    };
    (total_kb, dirty_kb, dirty_percent)
}
/// Resolves the fsync schedule for the benchmark.
///
/// Sources, in priority order:
/// 1. the `WALRUS_FSYNC` environment variable;
/// 2. the first `--fsync <schedule>` pair on the command line whose value parses;
/// 3. the default, `FsyncSchedule::Milliseconds(1000)`.
///
/// An unparseable value in one source is ignored and the next source is consulted.
/// See `parse_fsync_value` for the accepted spellings.
fn parse_fsync_schedule() -> FsyncSchedule {
    if let Some(schedule) = env::var("WALRUS_FSYNC")
        .ok()
        .and_then(|value| parse_fsync_value(&value))
    {
        return schedule;
    }
    let args: Vec<String> = env::args().collect();
    args.windows(2)
        .filter(|pair| pair[0] == "--fsync")
        .find_map(|pair| parse_fsync_value(&pair[1]))
        .unwrap_or(FsyncSchedule::Milliseconds(1000))
}

/// Parses a single fsync-schedule spelling; surrounding whitespace is ignored.
///
/// - `sync-each` → `SyncEach`
/// - `no-fsync` or `none` → `NoFsync`
/// - `async` → `Milliseconds(1000)`
/// - `<N>ms` or bare `<N>` → `Milliseconds(N)`
///
/// Returns `None` for anything else.
fn parse_fsync_value(value: &str) -> Option<FsyncSchedule> {
    match value.trim() {
        "sync-each" => Some(FsyncSchedule::SyncEach),
        "no-fsync" | "none" => Some(FsyncSchedule::NoFsync),
        "async" => Some(FsyncSchedule::Milliseconds(1000)),
        other => other
            .strip_suffix("ms")
            .unwrap_or(other)
            .parse::<u64>()
            .ok()
            .map(FsyncSchedule::Milliseconds),
    }
}
/// Resolves how long the write phase runs.
///
/// `WALRUS_DURATION` wins when it parses. Otherwise the first `--duration <value>` pair whose
/// value parses is used. Otherwise the default is 120 seconds.
fn parse_duration() -> Duration {
    let from_env = env::var("WALRUS_DURATION")
        .ok()
        .and_then(|raw| parse_duration_string(&raw));
    if let Some(duration) = from_env {
        return duration;
    }
    let argv: Vec<String> = env::args().collect();
    argv.windows(2)
        .filter(|pair| pair[0] == "--duration")
        .find_map(|pair| parse_duration_string(&pair[1]))
        .unwrap_or_else(|| Duration::from_secs(120))
}
/// Parses a human-friendly duration: `<N>s`, `<N>m`, `<N>h`, or a bare `<N>` (seconds).
///
/// Surrounding whitespace is ignored. Returns `None` when the numeric part is not a
/// non-negative integer, or when converting minutes/hours to seconds would overflow `u64`.
/// Previously that overflow panicked in debug builds and wrapped silently in release.
fn parse_duration_string(duration_str: &str) -> Option<Duration> {
    let s = duration_str.trim();
    // Check `s` before `m`: `"5ms"` strips to `"5m"`, which then fails to parse — same as before.
    let (digits, seconds_per_unit) = if let Some(rest) = s.strip_suffix('s') {
        (rest, 1)
    } else if let Some(rest) = s.strip_suffix('m') {
        (rest, 60)
    } else if let Some(rest) = s.strip_suffix('h') {
        (rest, 3600)
    } else {
        (s, 1)
    };
    let secs = digits.parse::<u64>().ok()?.checked_mul(seconds_per_unit)?;
    Some(Duration::from_secs(secs))
}
/// Prints how to configure and run this benchmark.
///
/// NOTE(review): the default libtest harness handles `--help`/`-h` itself and rejects
/// unrecognized options such as `--fsync`/`--duration`. The `-- --fsync ...` form may
/// therefore never reach this binary's argument parsing, and the environment variables are
/// the dependable route. Confirm against the project's harness configuration.
fn print_usage() {
    // The `#[test]` in this file is `multithreaded_benchmark`. `cargo test`'s positional
    // argument is a substring filter on test names, so the former `..._writes` spelling
    // selected no tests.
    const RUN: &str = "cargo test multithreaded_benchmark";
    println!(
        "Usage: WALRUS_FSYNC=<schedule> WALRUS_DURATION=<duration> {}",
        RUN
    );
    println!(" or: {} -- --fsync <schedule> --duration <duration>", RUN);
    println!();
    println!("Fsync Schedule Options:");
    println!(" sync-each Fsync after every write (slowest, most durable)");
    println!(" no-fsync Disable fsyncing entirely (fastest, no durability)");
    println!(" none Same as no-fsync");
    println!(" async Async fsync every 1000ms (default)");
    println!(" <number>ms Async fsync every N milliseconds (e.g., 500ms)");
    println!(" <number> Async fsync every N milliseconds (e.g., 500)");
    println!();
    println!("Duration Options:");
    println!(" <number>s Duration in seconds (e.g., 30s, 120s)");
    println!(" <number>m Duration in minutes (e.g., 2m, 5m)");
    println!(" <number>h Duration in hours (e.g., 1h, 2h)");
    println!(" <number> Duration in seconds (e.g., 120, 300)");
    println!(" Default: 2m (120 seconds)");
    println!();
    println!("Examples:");
    println!(" WALRUS_FSYNC=sync-each WALRUS_DURATION=30s {}", RUN);
    println!(" WALRUS_FSYNC=no-fsync WALRUS_DURATION=1m {}", RUN);
    println!(" WALRUS_FSYNC=500ms WALRUS_DURATION=5m {}", RUN);
    println!(" {} -- --fsync no-fsync --duration 1m", RUN);
    println!(" make bench-writes-sync # Uses Makefile convenience targets");
    println!();
    println!("Makefile targets:");
    println!(" make bench-writes # Default (async 1000ms, 2m duration)");
    println!(" make bench-writes-sync # Sync each write");
    println!(" make bench-writes-fast # Fast async (100ms)");
}
/// Best-effort removal of the `wal_files` directory, followed by a 100 ms pause.
///
/// Any removal error, including the directory not existing, is deliberately ignored.
/// NOTE(review): the pause presumably lets the filesystem settle before the WAL is
/// reopened; its purpose is not stated here.
fn cleanup_wal() {
    const WAL_DIR: &str = "wal_files";
    const SETTLE: Duration = Duration::from_millis(100);
    let _ = fs::remove_dir_all(WAL_DIR);
    thread::sleep(SETTLE);
}
/// Number of appends a writer attempts in its `batch_number`-th batch (0-based).
///
/// Ramps by 50k per batch (50k, 100k, ..., 450k) and stays at 500k from the tenth batch on.
/// This is the ramp-up advertised in the benchmark banner.
fn batch_size_for(batch_number: usize) -> usize {
    50_000usize
        .saturating_mul(batch_number.saturating_add(1))
        .min(500_000)
}

/// Current wall-clock time as whole seconds since the Unix epoch (the CSV `timestamp` column).
fn unix_timestamp_secs() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_secs()
}

/// Multi-threaded write benchmark for `Walrus`.
///
/// Setup:
/// - Opens a WAL with `AtLeastOnce { persist_every: 5000 }` and the fsync schedule from
///   `parse_fsync_schedule`.
/// - Starts `num_threads` writers, each appending random 500..=1024-byte records to its own
///   topic for the duration from `parse_duration`.
///
/// Each writer runs ramping batches (see `batch_size_for`) separated by 500 ms pauses.
///
/// A monitor thread samples throughput and dirty memory every ~500 ms into
/// `benchmark_throughput.csv`. It keeps sampling for 30 s after the configured duration.
///
/// Asserts that more than 1000 writes succeeded and errors stayed under 10% of writes.
#[test]
fn multithreaded_benchmark() {
    // NOTE(review): libtest normally intercepts `--help`/`-h` before any test runs, so this
    // branch may be unreachable under the default harness.
    let args: Vec<String> = env::args().collect();
    if args.iter().any(|arg| arg == "--help" || arg == "-h") {
        print_usage();
        return;
    }
    cleanup_wal();
    // SAFETY: `set_var` is unsafe because environment access from other threads is not
    // synchronized. This test has not spawned any threads yet, and it is the only test in
    // this file. NOTE(review): libtest may still run tests from the same binary in parallel.
    unsafe {
        std::env::set_var("WALRUS_QUIET", "1");
    }
    let fsync_schedule = parse_fsync_schedule();
    let write_duration = parse_duration();
    let num_threads = 10;
    println!("=== Multi-threaded WAL Benchmark ===");
    println!(
        "Configuration: {} threads, {}s write phase only",
        num_threads,
        write_duration.as_secs()
    );
    println!("Fsync schedule: {:?}", fsync_schedule);
    println!(
        "Duration: {:?} (batch ramp-up: 50k→100k→150k...→500k, 500ms delays)",
        write_duration
    );
    let wal = Arc::new(
        Walrus::with_consistency_and_schedule(
            ReadConsistency::AtLeastOnce {
                persist_every: 5000,
            },
            fsync_schedule,
        )
        .expect("Failed to create Walrus"),
    );
    let total_writes = Arc::new(AtomicU64::new(0));
    let total_write_bytes = Arc::new(AtomicU64::new(0));
    let write_errors = Arc::new(AtomicU64::new(0));

    let csv_path = "benchmark_throughput.csv";
    {
        // Truncate any previous results and write the header. The monitor thread appends
        // rows through its own handle, so this one is closed at the end of the scope.
        let mut csv_file = fs::File::create(csv_path).expect("Failed to create CSV file");
        writeln!(
            csv_file,
            "timestamp,elapsed_seconds,writes_per_second,bytes_per_second,total_writes,total_bytes,dirty_pages_kb,dirty_ratio_percent"
        )
        .expect("Failed to write CSV header");
    }

    // Signals the monitor that the writers have been released.
    let (throughput_tx, throughput_rx) = mpsc::channel::<()>();
    // Every writer plus the main thread meet here so all writers start together.
    let start_barrier = Arc::new(Barrier::new(num_threads + 1));
    println!("Starting {} writer threads...", num_threads);

    let total_writes_monitor = Arc::clone(&total_writes);
    let total_write_bytes_monitor = Arc::clone(&total_write_bytes);
    let monitor_handle = thread::spawn(move || {
        let mut csv_file = fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(csv_path)
            .expect("Failed to open CSV file");
        // An Err means the sender was dropped; sampling proceeds regardless.
        let _ = throughput_rx.recv();
        let timestamp = unix_timestamp_secs();
        let (_, initial_dirty_kb, initial_dirty_ratio) = get_memory_info();
        writeln!(
            csv_file,
            "{},{:.2},{:.0},{:.0},{},{},{},{:.2}",
            timestamp, 0.0, 0.0, 0.0, 0, 0, initial_dirty_kb, initial_dirty_ratio
        )
        .expect("Failed to write initial CSV entry");
        csv_file.flush().expect("Failed to flush CSV");

        // Rates and elapsed time come from measured instants, not from the nominal 500 ms
        // tick. The pre-loop sleep makes the first interval ~1 s, and `get_memory_info`
        // adds time to every tick. Dividing by a fixed 0.5 s therefore inflated the first
        // rate ~2x and mislabelled elapsed time.
        let start_time = Instant::now();
        let mut last_time = start_time;
        let mut last_writes = 0u64;
        let mut last_bytes = 0u64;
        let mut tick_index: u64 = 0;
        // Keep sampling 30 s past the write phase, presumably to observe dirty pages being
        // flushed after writes stop.
        let max_monitor_time = write_duration.as_secs_f64() + 30.0;
        thread::sleep(Duration::from_millis(500));
        loop {
            thread::sleep(Duration::from_millis(500));
            tick_index += 1;
            let current_time = Instant::now();
            let interval_s = current_time.duration_since(last_time).as_secs_f64();
            let elapsed_total = current_time.duration_since(start_time).as_secs_f64();
            let current_writes = total_writes_monitor.load(Ordering::Relaxed);
            let current_bytes = total_write_bytes_monitor.load(Ordering::Relaxed);
            let writes_per_second = (current_writes - last_writes) as f64 / interval_s;
            let bytes_per_second = (current_bytes - last_bytes) as f64 / interval_s;
            let (_, dirty_kb, dirty_ratio) = get_memory_info();
            // Log every tick that saw new writes, plus a heartbeat every 4th tick.
            let should_log = current_writes != last_writes || tick_index % 4 == 0;
            if should_log {
                writeln!(
                    csv_file,
                    "{},{:.2},{:.0},{:.0},{},{},{},{:.2}",
                    unix_timestamp_secs(),
                    elapsed_total,
                    writes_per_second,
                    bytes_per_second,
                    current_writes,
                    current_bytes,
                    dirty_kb,
                    dirty_ratio
                )
                .expect("Failed to write to CSV");
                csv_file.flush().expect("Failed to flush CSV");
                if current_writes > last_writes {
                    println!(
                        "[Monitor] {:.1}s: {:.0} writes/sec, {:.2} MB/sec, total: {} writes, dirty: {:.2}% ({} KB)",
                        elapsed_total,
                        writes_per_second,
                        bytes_per_second / (1024.0 * 1024.0),
                        current_writes,
                        dirty_ratio,
                        dirty_kb
                    );
                }
            }
            last_writes = current_writes;
            last_bytes = current_bytes;
            last_time = current_time;
            if elapsed_total > max_monitor_time {
                break;
            }
        }
    });

    let mut handles = Vec::with_capacity(num_threads);
    for thread_id in 0..num_threads {
        let wal = Arc::clone(&wal);
        let total_writes = Arc::clone(&total_writes);
        let total_write_bytes = Arc::clone(&total_write_bytes);
        let write_errors = Arc::clone(&write_errors);
        let start_barrier = Arc::clone(&start_barrier);
        let topic = format!("topic_{}", thread_id);
        handles.push(thread::spawn(move || {
            start_barrier.wait();
            let start_time = Instant::now();
            let batch_delay = Duration::from_millis(500);
            let mut rng = rand::thread_rng();
            let mut local_writes = 0u64;
            let mut local_write_bytes = 0u64;
            let mut local_errors = 0u64;
            let mut counter = 0u64;
            let mut batch_number = 0usize;
            while start_time.elapsed() < write_duration {
                for _ in 0..batch_size_for(batch_number) {
                    if start_time.elapsed() >= write_duration {
                        break;
                    }
                    let size = rng.gen_range(500..=1024);
                    let data = vec![(counter % 256) as u8; size];
                    match wal.append_for_topic(&topic, &data) {
                        Ok(_) => {
                            local_writes += 1;
                            local_write_bytes += data.len() as u64;
                            total_writes.fetch_add(1, Ordering::Relaxed);
                            total_write_bytes.fetch_add(data.len() as u64, Ordering::Relaxed);
                        }
                        Err(_) => {
                            local_errors += 1;
                        }
                    }
                    counter += 1;
                }
                batch_number += 1;
                if start_time.elapsed() < write_duration {
                    thread::sleep(batch_delay);
                }
            }
            write_errors.fetch_add(local_errors, Ordering::Relaxed);
            println!(
                "Thread {} ({}): {} writes, {} KB, {} errors",
                thread_id,
                topic,
                local_writes,
                local_write_bytes / 1024,
                local_errors
            );
        }));
    }

    let benchmark_start = Instant::now();
    start_barrier.wait();
    println!("All threads started! Write phase beginning...");
    let _ = throughput_tx.send(());
    // Joining the writers marks the end of the write phase. Unlike a second rendezvous
    // barrier, a writer that panics fails the test here instead of blocking the main
    // thread forever.
    for handle in handles {
        handle.join().expect("writer thread panicked");
    }
    let write_elapsed = benchmark_start.elapsed();
    println!("Write phase completed in {:?}", write_elapsed);
    let final_writes = total_writes.load(Ordering::Relaxed);
    let final_write_bytes = total_write_bytes.load(Ordering::Relaxed);
    let final_errors = write_errors.load(Ordering::Relaxed);
    println!("\n=== Write Phase Results ===");
    println!("Write Duration: {:?}", write_elapsed);
    println!("Total Operations: {}", final_writes);
    println!("Total Bytes: {} MB", final_write_bytes / (1024 * 1024));
    println!("Write Errors: {}", final_errors);
    println!(
        "Write Throughput: {:.0} ops/sec",
        final_writes as f64 / write_elapsed.as_secs_f64()
    );
    println!(
        "Write Bandwidth: {:.2} MB/sec",
        (final_write_bytes as f64 / (1024.0 * 1024.0)) / write_elapsed.as_secs_f64()
    );
    println!();
    let total_elapsed = benchmark_start.elapsed();
    println!("\n=== Final Summary ===");
    println!("Total Benchmark Duration: {:?}", total_elapsed);
    assert!(
        final_writes > 1000,
        "Write throughput too low: {} ops",
        final_writes
    );
    assert!(
        final_errors < final_writes / 10,
        "Too many write errors: {} out of {}",
        final_errors,
        final_writes
    );
    println!("Multi-threaded benchmark completed successfully!");
    // Surface monitor failures (e.g. CSV I/O errors) instead of silently discarding them.
    monitor_handle
        .join()
        .expect("throughput monitor thread panicked");
    println!("Throughput data saved to: {}", csv_path);
    cleanup_wal();
}