use crate::config::kwaainet_dir;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Cached throughput record for a single model, stored as a value in the JSON
/// cache map keyed by model name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThroughputEntry {
/// Compute-side throughput; `effective_tps` never returns more than this value.
/// NOTE(review): presumably tokens per second, judging by the name — confirm.
pub compute_tps: f64,
/// Model hidden size. `effective_tps` multiplies it by 16 to get the number of
/// bits moved per unit of throughput; a value of zero disables the network estimate.
pub hidden_size: usize,
}
/// Returns the path of the throughput cache: `throughput_cache.json` inside the
/// directory reported by `kwaainet_dir()`.
pub fn cache_file() -> PathBuf {
    const FILE_NAME: &str = "throughput_cache.json";
    let base_dir = kwaainet_dir();
    base_dir.join(FILE_NAME)
}
/// Records the throughput figures for `model` in the on-disk cache.
///
/// The current cache is loaded (a missing or unparsable file counts as empty),
/// the entry for `model` is inserted or overwritten, and the whole map is written
/// back as pretty-printed JSON.
///
/// The new contents go to a sibling temporary file that is then renamed over the
/// cache file, so an interrupted write cannot leave a truncated cache behind
/// (which would later be read back as an empty cache, dropping every entry).
///
/// # Errors
///
/// Returns an error if the cache directory cannot be created, the map cannot be
/// serialized, or the file cannot be written or renamed into place.
pub fn save(model: &str, compute_tps: f64, hidden_size: usize) -> Result<()> {
    let path = cache_file();
    if let Some(dir) = path.parent() {
        std::fs::create_dir_all(dir)?;
    }

    let mut cache: HashMap<String, ThroughputEntry> = load_cache();
    cache.insert(
        model.to_string(),
        ThroughputEntry {
            compute_tps,
            hidden_size,
        },
    );
    let json = serde_json::to_string_pretty(&cache)?;

    // Write-then-rename: readers see either the old file or the complete new one.
    let tmp_path = path.with_extension("json.tmp");
    let written =
        std::fs::write(&tmp_path, json).and_then(|()| std::fs::rename(&tmp_path, &path));
    if let Err(err) = written {
        // Best-effort cleanup; the original error is what the caller needs to see.
        let _ = std::fs::remove_file(&tmp_path);
        return Err(err.into());
    }
    Ok(())
}
/// Looks up the cached throughput entry for `model`.
///
/// Returns the entry stored under `model` when there is one. Otherwise, if the
/// cache holds exactly one entry (under some other key), that entry is returned
/// as a fallback. In every other case the result is `None`.
pub fn load(model: &str) -> Option<ThroughputEntry> {
    let mut entries = load_cache();
    match entries.remove(model) {
        Some(hit) => Some(hit),
        // No exact match: a lone cached entry is used regardless of its key.
        None if entries.len() == 1 => entries.into_values().next(),
        None => None,
    }
}
/// Reads the whole cache map from disk.
///
/// Any failure — the file cannot be read, or its contents do not parse as the
/// expected JSON map — yields an empty map rather than an error.
fn load_cache() -> HashMap<String, ThroughputEntry> {
    std::fs::read_to_string(cache_file())
        .ok()
        .and_then(|raw| serde_json::from_str::<HashMap<String, ThroughputEntry>>(&raw).ok())
        .unwrap_or_default()
}
/// Multiplier that `effective_tps` applies to the network-limited rate when
/// `using_relay` is true (direct connections use a factor of 1.0 instead).
pub const RELAY_PENALTY: f64 = 0.2;
/// Estimates download bandwidth in bits per second by fetching a fixed-size
/// payload from Cloudflare's speed-test endpoint.
///
/// The timer starts just before the request is sent and stops once the full body
/// has been received. Returns `0.0` when the HTTP client cannot be built, the
/// request or body download fails (including the 15-second timeout), the elapsed
/// time is not positive, or fewer than half of the expected bytes arrived.
pub async fn measure_download_bps() -> f64 {
    /// Runs the timed download; `None` stands for every "no usable measurement" case.
    async fn timed_download() -> Option<f64> {
        const TEST_BYTES: usize = 1_048_576;
        const URL: &str = "https://speed.cloudflare.com/__down?bytes=1048576";

        let http = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(15))
            .build()
            .ok()?;

        let started = std::time::Instant::now();
        let response = http.get(URL).send().await.ok()?;
        let payload = response.bytes().await.ok()?;
        let elapsed = started.elapsed().as_secs_f64();

        let received = payload.len();
        // A short body means the transfer was cut off or something other than
        // the test payload came back, so the timing is not meaningful.
        if elapsed <= 0.0 || received < TEST_BYTES / 2 {
            return None;
        }
        Some((received as f64 * 8.0) / elapsed)
    }

    timed_download().await.unwrap_or(0.0)
}
/// Combines cached compute throughput with measured bandwidth into one figure.
///
/// * `entry` – cached figures for the model.
/// * `download_bps` – measured download bandwidth in bits per second; a value
///   `<= 0.0` means "no measurement".
/// * `using_relay` – when true, the network-limited rate is scaled by
///   [`RELAY_PENALTY`].
///
/// Returns `entry.compute_tps` unchanged when there is no bandwidth measurement
/// or `entry.hidden_size` is zero. Otherwise returns the smaller of
/// `entry.compute_tps` and the network-limited rate
/// `download_bps / (hidden_size * 16)`, relay-scaled if applicable.
pub fn effective_tps(entry: &ThroughputEntry, download_bps: f64, using_relay: bool) -> f64 {
    // Without a usable bandwidth figure or hidden size, only compute can bound the rate.
    if download_bps <= 0.0 || entry.hidden_size == 0 {
        return entry.compute_tps;
    }

    // 16 bits are counted per hidden-state element.
    let bits_per_unit = entry.hidden_size as f64 * 16.0;
    let direct_rate = download_bps / bits_per_unit;
    let network_rate = if using_relay {
        direct_rate * RELAY_PENALTY
    } else {
        direct_rate
    };

    f64::min(entry.compute_tps, network_rate)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hidden size shared by every scenario below.
    const HIDDEN: usize = 4096;

    /// Builds a cache entry with the given compute throughput and the shared hidden size.
    fn entry_with(compute_tps: f64) -> ThroughputEntry {
        ThroughputEntry {
            compute_tps,
            hidden_size: HIDDEN,
        }
    }

    #[test]
    fn test_effective_tps_compute_bound() {
        // 100 Mbps over a relay still allows ~305 units/s, far above the compute limit of 20.
        let tps = effective_tps(&entry_with(20.0), 100_000_000.0, true);
        assert!((tps - 20.0).abs() < 0.01, "expected 20.0, got {tps}");
    }

    #[test]
    fn test_effective_tps_network_bound() {
        // 1 Mbps gives ~15.3 units/s directly; the relay penalty brings it to ~3.05.
        let tps = effective_tps(&entry_with(100.0), 1_000_000.0, true);
        assert!(tps < 5.0, "expected network-bound (<5), got {tps}");
        assert!(tps > 2.0, "expected >2, got {tps}");
    }

    #[test]
    fn test_effective_tps_no_relay() {
        // Same bandwidth as above but without the relay penalty.
        let tps = effective_tps(&entry_with(100.0), 1_000_000.0, false);
        assert!(tps > 14.0 && tps < 16.0, "expected ~15.3, got {tps}");
    }

    #[test]
    fn test_effective_tps_no_network_data() {
        // A zero bandwidth reading means "unknown": compute throughput is returned as-is.
        let entry = entry_with(7.5);
        assert_eq!(effective_tps(&entry, 0.0, true), 7.5);
    }
}