use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use trusty_common::embedder::rss;
use trusty_common::embedder::{CandleEmbedder, Embedder, FastEmbedder};
/// Bytes in one binary gigabyte (1024³); divisor used when reporting RSS in GB.
const BYTES_PER_GB: f64 = 1024.0 * 1024.0 * 1024.0;
/// Bytes in one binary megabyte (1024²); divisor used when reporting RSS in MB.
const BYTES_PER_MB: f64 = 1024.0 * 1024.0;
/// Assumed number of tokens per synthetic text. Multiplied by the text count to
/// produce the *estimated* token throughput; it is not a tokenizer measurement.
const TOKENS_PER_TEXT_ESTIMATE: u32 = 100;
/// Measurements gathered from benchmarking a single embedder backend.
///
/// Derives `Debug` and `Clone` so a result can be dumped for diagnostics or
/// kept around after being reported.
#[derive(Debug, Clone)]
struct BackendResult {
    /// Human-readable backend name used in progress lines and the final report.
    label: &'static str,
    /// Largest RSS sample observed over the run, in bytes.
    peak_rss_bytes: u64,
    /// RSS sampled before the first batch, in bytes.
    start_rss_bytes: u64,
    /// RSS sampled after the last batch, in bytes.
    end_rss_bytes: u64,
    /// Wall time of each individual `embed_batch` call, in submission order.
    latencies: Vec<Duration>,
    /// Wall time of the whole batch loop (embedding plus per-batch bookkeeping).
    total_wall: Duration,
    /// Total number of texts submitted (`batches * batch_size`).
    total_texts: usize,
}
impl BackendResult {
    /// Estimated tokens embedded per second of total wall time.
    ///
    /// The token count is `total_texts * TOKENS_PER_TEXT_ESTIMATE`. The wall time
    /// is clamped to at least one nanosecond so an instantaneous (or empty) run
    /// never divides by zero.
    fn throughput_tokens_per_sec(&self) -> f64 {
        let wall_secs = self.total_wall.as_secs_f64().max(1e-9);
        let estimated_tokens = (self.total_texts as f64) * f64::from(TOKENS_PER_TEXT_ESTIMATE);
        estimated_tokens / wall_secs
    }

    /// Median per-batch latency.
    fn p50(&self) -> Duration {
        percentile(self.latencies.as_slice(), 0.50)
    }

    /// 99th-percentile per-batch latency.
    fn p99(&self) -> Duration {
        percentile(self.latencies.as_slice(), 0.99)
    }
}
/// Returns the `p`-quantile of `values` using the nearest-rank method.
///
/// `p` is a fraction (`0.50` for the median). The 1-based rank is
/// `ceil(len * p)`, clamped into `1..=len`, so `p <= 0` yields the minimum and
/// `p >= 1` yields the maximum. An empty slice yields `Duration::ZERO`.
/// The input is not modified; a sorted copy is made internally.
fn percentile(values: &[Duration], p: f64) -> Duration {
    let count = match values.len() {
        0 => return Duration::ZERO,
        n => n,
    };

    // 1-based nearest rank; the float-to-int cast saturates (and maps NaN to 0).
    let rank = (count as f64 * p).ceil() as usize;
    let index = rank.saturating_sub(1).min(count - 1);

    let mut ordered: Vec<Duration> = values.iter().copied().collect();
    ordered.sort_unstable();
    ordered[index]
}
/// Builds `count` synthetic input texts for the benchmark.
///
/// Texts cycle through a fixed set of eight code/doc-like templates; each one
/// gets a ` // chunk #<index>` suffix so every text in the batch is distinct.
/// The output depends only on `count`.
fn synthetic_texts(count: usize) -> Vec<String> {
    const TEMPLATES: [&str; 8] = [
        "fn authenticate_user(token: &str) -> Result<UserId, AuthError> { /* validate JWT, look up session, return userid */ }",
        "Why: shared embedding abstraction lets memory and search share one backend.\nWhat: async trait with embed_batch primitive.\nTest: covered by FastEmbedder and MockEmbedder.",
        "impl Embedder for FastEmbedder { async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> { /* ... */ } }",
        "struct KnowledgeGraph { store: Arc<dyn KgStore>, embedder: Arc<dyn Embedder>, } impl KnowledgeGraph { pub fn new(...) -> Self { /* ... */ } }",
        "// CoreML EP with MLComputeUnits=ALL allocates from the unified-memory GPU pool. Issue #24 inflated RSS to ~72 GB during indexing.",
        "pub fn current_rss_bytes() -> u64 { let mut sys = System::new(); sys.refresh_processes(); sys.process(Pid::from_u32(std::process::id())).map(|p| p.memory()).unwrap_or(0) }",
        "use anyhow::{Context, Result}; use tokio::sync::Mutex; use std::sync::Arc; pub struct Daemon { inner: Arc<Mutex<State>> }",
        "GET /v1/users/{id} returns 200 with user payload or 404 if not found. Idempotent. Cacheable for 60s via Cache-Control header.",
    ];

    let mut texts = Vec::with_capacity(count);
    for (index, template) in TEMPLATES.iter().cycle().take(count).enumerate() {
        texts.push(format!("{template} // chunk #{index}"));
    }
    texts
}
/// Runs `batches` embedding calls of `batch_size` texts each against `embedder`
/// and records latency, wall time and RSS samples.
///
/// Progress is written to stderr every 10 batches and on the final batch.
///
/// # Errors
///
/// Fails if `embed_batch` returns an error, or if it returns a number of
/// vectors different from `batch_size`.
async fn run_one_backend(
    label: &'static str,
    embedder: &dyn Embedder,
    batches: usize,
    batch_size: usize,
) -> Result<BackendResult> {
    // The synthetic batch depends only on `batch_size`, so build it once before
    // the clock starts. Rebuilding identical strings inside the loop would charge
    // fixture allocation to `total_wall` and deflate the throughput estimate.
    let texts = synthetic_texts(batch_size);

    let start_rss = rss::current_rss_bytes();
    let mut peak_rss = start_rss;
    let mut latencies = Vec::with_capacity(batches);
    let total_start = Instant::now();

    for i in 0..batches {
        let batch_start = Instant::now();
        let vectors = embedder
            .embed_batch(&texts)
            .await
            .with_context(|| format!("{label}: embed_batch failed on iteration {i}"))?;
        let elapsed = batch_start.elapsed();
        latencies.push(elapsed);

        // A backend that drops or duplicates inputs would make the throughput
        // numbers meaningless, so treat a count mismatch as a hard failure.
        anyhow::ensure!(
            vectors.len() == batch_size,
            "{label}: expected {} vectors, got {}",
            batch_size,
            vectors.len()
        );

        let rss_now = rss::current_rss_bytes();
        peak_rss = peak_rss.max(rss_now);

        if (i + 1) % 10 == 0 || i + 1 == batches {
            eprintln!(
                " {label}: batch {}/{} ({:.0} ms, RSS now {:.2} GB, peak {:.2} GB)",
                i + 1,
                batches,
                elapsed.as_millis() as f64,
                rss_now as f64 / BYTES_PER_GB,
                peak_rss as f64 / BYTES_PER_GB,
            );
        }
    }

    // One last sample after the loop so the reported peak is never below the end value.
    let end_rss = rss::current_rss_bytes();
    peak_rss = peak_rss.max(end_rss);

    Ok(BackendResult {
        label,
        peak_rss_bytes: peak_rss,
        start_rss_bytes: start_rss,
        end_rss_bytes: end_rss,
        latencies,
        total_wall: total_start.elapsed(),
        total_texts: batches * batch_size,
    })
}
/// Prints the human-readable report section for one backend to stdout:
/// start/end/peak RSS (GB and MB), estimated throughput, p50/p99 latency and
/// total wall time, followed by a blank line.
fn print_backend(r: &BackendResult) {
    let as_gb = |bytes: u64| bytes as f64 / BYTES_PER_GB;
    let as_mb = |bytes: u64| bytes as f64 / BYTES_PER_MB;

    let median = r.p50();
    let tail = r.p99();

    println!("{}:", r.label);
    println!(
        " Start RSS: {:.2} GB ({:.0} MB)",
        as_gb(r.start_rss_bytes),
        as_mb(r.start_rss_bytes),
    );
    println!(
        " End RSS: {:.2} GB ({:.0} MB)",
        as_gb(r.end_rss_bytes),
        as_mb(r.end_rss_bytes),
    );
    println!(
        " Peak RSS: {:.2} GB ({:.0} MB)",
        as_gb(r.peak_rss_bytes),
        as_mb(r.peak_rss_bytes),
    );
    println!(
        " Throughput: {:>8.0} tokens/sec (estimated)",
        r.throughput_tokens_per_sec(),
    );
    println!(
        " Latency p50: {:>5} ms p99: {:>5} ms",
        median.as_millis(),
        tail.as_millis(),
    );
    println!(" Total wall: {:.2} s", r.total_wall.as_secs_f64());
    println!();
}
/// Evaluates the GO/NO-GO criteria for the Candle backend.
///
/// Criteria (all must hold for GO):
/// 1. `candle.peak_rss_bytes` is strictly below `rss_limit_bytes`.
/// 2. If a `baseline` is supplied, the slowdown `baseline_tps / candle_tps` is
///    at most `max_slowdown`. Without a baseline this criterion is skipped.
///
/// Returns `(go, reasons)`, where `reasons` holds one human-readable line per
/// criterion, in the order above.
fn decide_verdict(
    candle: &BackendResult,
    baseline: Option<&BackendResult>,
    rss_limit_bytes: u64,
    max_slowdown: f64,
) -> (bool, Vec<String>) {
    let mut reasons = Vec::new();

    let peak_gb = candle.peak_rss_bytes as f64 / BYTES_PER_GB;
    let limit_gb = rss_limit_bytes as f64 / BYTES_PER_GB;
    let rss_ok = candle.peak_rss_bytes < rss_limit_bytes;
    reasons.push(if rss_ok {
        format!("candle peak RSS {:.2} GB < {:.2} GB limit", peak_gb, limit_gb)
    } else {
        format!("candle peak RSS {:.2} GB >= {:.2} GB limit", peak_gb, limit_gb)
    });

    let throughput_ok = match baseline {
        None => {
            reasons.push("baseline skipped — throughput criterion not evaluated".to_string());
            true
        }
        Some(b) => {
            let c_tps = candle.throughput_tokens_per_sec();
            let b_tps = b.throughput_tokens_per_sec();
            if c_tps <= 0.0 || b_tps <= 0.0 {
                reasons.push("invalid throughput reading (<=0)".to_string());
                false
            } else {
                // slowdown > 1 means candle is slower than the baseline.
                let slowdown = b_tps / c_tps;
                // Phrased as "passes only if <=" so a NaN threshold fails closed
                // instead of silently passing the gate.
                let within_budget = slowdown <= max_slowdown;
                reasons.push(if within_budget {
                    // Report the ratio explicitly as a slowdown: saying "candle is
                    // N× FastEmbedder" would invert its meaning.
                    format!(
                        "candle throughput {:.0} tok/s vs FastEmbedder {:.0} tok/s: slowdown {:.2}× (<= {:.2}×)",
                        c_tps, b_tps, slowdown, max_slowdown,
                    )
                } else {
                    format!(
                        "candle throughput {:.0} tok/s is {:.2}× slower than FastEmbedder {:.0} tok/s (> {:.2}×)",
                        c_tps, slowdown, b_tps, max_slowdown,
                    )
                });
                within_budget
            }
        }
    };

    (rss_ok && throughput_ok, reasons)
}
/// Reads `key` from the environment as a `usize`, falling back to `default`.
///
/// Surrounding whitespace is ignored. If the variable is unset (or not valid
/// Unicode) the default is returned silently. If it is set but does not parse,
/// a warning is written to stderr before falling back, so a mistyped override
/// does not silently run the benchmark with default settings.
fn env_usize(key: &str, default: usize) -> usize {
    let raw = match std::env::var(key) {
        Ok(raw) => raw,
        Err(_) => return default,
    };
    match raw.trim().parse::<usize>() {
        Ok(value) => value,
        Err(err) => {
            eprintln!("warning: ignoring {key}={raw:?} ({err}); using default {default}");
            default
        }
    }
}
/// Reads `key` from the environment as a finite `f64`, falling back to `default`.
///
/// Surrounding whitespace is ignored. If the variable is unset (or not valid
/// Unicode) the default is returned silently. Values that do not parse, or that
/// parse to `NaN`/`±inf`, are rejected with a warning on stderr: `f64` parsing
/// accepts those spellings, and a non-finite threshold would make later
/// comparisons meaningless.
fn env_f64(key: &str, default: f64) -> f64 {
    let raw = match std::env::var(key) {
        Ok(raw) => raw,
        Err(_) => return default,
    };
    match raw.trim().parse::<f64>() {
        Ok(value) if value.is_finite() => value,
        _ => {
            eprintln!(
                "warning: ignoring {key}={raw:?} (expected a finite number); using default {default}"
            );
            default
        }
    }
}
/// Benchmark entry point.
///
/// Reads its knobs from `TRUSTY_BENCH_*` environment variables, benchmarks the
/// Candle embedder (and, unless `TRUSTY_BENCH_SKIP_BASELINE` is set, the
/// FastEmbedder baseline), prints a report to stdout and a GO/NO-GO verdict.
/// Progress goes to stderr. A NO-GO verdict terminates the process with exit
/// status 1.
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<()> {
    // Configuration (each value has a built-in default).
    let batches = env_usize("TRUSTY_BENCH_BATCHES", 100);
    let batch_size = env_usize("TRUSTY_BENCH_BATCH_SIZE", 1000);
    let rss_limit_gb = env_f64("TRUSTY_BENCH_RSS_LIMIT_GB", 8.0);
    let max_slowdown = env_f64("TRUSTY_BENCH_THROUGHPUT_X", 2.0);
    // Presence alone enables the skip; the variable's value is not inspected.
    let skip_baseline = std::env::var("TRUSTY_BENCH_SKIP_BASELINE").is_ok();

    let rss_limit_bytes = (rss_limit_gb * BYTES_PER_GB) as u64;
    let total_chunks = batches * batch_size;

    eprintln!(
        "candle_metal_bench: {} batches × {} texts = {} total chunks",
        batches, batch_size, total_chunks,
    );
    eprintln!(
        "candle_metal_bench: GO criteria: peak RSS < {:.1} GB AND throughput within {:.2}× of FastEmbedder",
        rss_limit_gb, max_slowdown,
    );
    let rss_before_init = rss::current_rss_bytes();
    eprintln!(
        "candle_metal_bench: process RSS before any embedder init: {:.2} GB",
        rss_before_init as f64 / BYTES_PER_GB,
    );

    // Candle run. The embedder lives only inside this scope, so it is dropped
    // before the baseline embedder is constructed.
    let use_metal = cfg!(target_os = "macos");
    let candle_result = {
        eprintln!("candle_metal_bench: building CandleEmbedder (use_metal={use_metal})");
        let candle = CandleEmbedder::new(use_metal)
            .context("failed to construct CandleEmbedder — see error for missing model files")?;
        eprintln!(
            "candle_metal_bench: CandleEmbedder device = {:?}",
            candle.device()
        );
        run_one_backend("Candle (Metal/CPU)", &candle, batches, batch_size).await?
    };

    // Optional FastEmbedder baseline.
    let baseline_result = if !skip_baseline {
        eprintln!("candle_metal_bench: building FastEmbedder baseline");
        let fast = FastEmbedder::new()
            .await
            .context("failed to construct FastEmbedder baseline")?;
        let measured =
            run_one_backend("FastEmbedder (baseline)", &fast, batches, batch_size).await?;
        drop(fast);
        Some(measured)
    } else {
        eprintln!(
            "candle_metal_bench: TRUSTY_BENCH_SKIP_BASELINE set — skipping FastEmbedder baseline"
        );
        None
    };

    // Report.
    println!();
    println!("=== Candle Metal Validation ===");
    println!(
        "Batches: {batches} × {batch_size} texts = {} chunks",
        total_chunks
    );
    println!();
    print_backend(&candle_result);
    if let Some(b) = baseline_result.as_ref() {
        print_backend(b);
    }

    let (go, reasons) = decide_verdict(
        &candle_result,
        baseline_result.as_ref(),
        rss_limit_bytes,
        max_slowdown,
    );

    println!("Criteria (GO requires ALL):");
    println!(" • candle peak RSS < {:.1} GB", rss_limit_gb);
    println!(
        " • candle throughput >= 1/{:.2} × FastEmbedder throughput",
        max_slowdown
    );
    println!();
    reasons.iter().for_each(|r| println!(" - {r}"));
    println!();

    if !go {
        println!("VERDICT: NO-GO");
        // Non-zero exit status lets scripts and CI gate on the verdict.
        std::process::exit(1);
    }
    println!("VERDICT: GO");
    Ok(())
}