use crate::Parameters;
use float_ord::FloatOrd;
use readings_probe::Probe;
use std::time::{Duration, Instant};
use tract_core::num_traits::Zero;
use tract_core::tract_data::itertools::Itertools;
use tract_hir::internal::*;
use tract_libcli::profile::BenchLimits;
use tract_libcli::tensor::get_or_make_inputs;
#[cfg(feature = "transformers")]
use tract_transformers::figure_out_causal_llm_b_s_p;
/// Entry point of the LLM benchmark: a 512-token prompt-processing run followed by a 128-token
/// generation run, each printing its own throughput.
///
/// # Errors
/// Returns the first error raised by either benchmark; generation is not benchmarked when prompt
/// processing fails.
pub fn handle(
    params: &Parameters,
    matches: &clap::ArgMatches,
    sub_matches: &clap::ArgMatches,
    limits: &BenchLimits,
    probe: Option<&Probe>,
) -> TractResult<()> {
    const PROMPT_TOKENS: usize = 512;
    const GENERATED_TOKENS: usize = 128;
    bench_pp(params, matches, sub_matches, limits, PROMPT_TOKENS, probe)?;
    bench_tg(params, matches, sub_matches, limits, GENERATED_TOKENS, probe)
}
pub fn bench_pp(
params: &Parameters,
_matches: &clap::ArgMatches,
sub_matches: &clap::ArgMatches,
limits: &BenchLimits,
pp: usize,
_probe: Option<&Probe>,
) -> TractResult<()> {
let mut run_params = crate::tensor::run_params_from_subcommand(params, sub_matches)?;
run_params.allow_random_input = true;
let model = params.req_typed_model();
let (b, s, p) = tract_transformers::figure_out_causal_llm_b_s_p(&model)
.context("Could not find out LLM symbolic parameters")?;
if let Some(b) = b {
run_params.symbols.set(&b, 1);
}
ensure!(s.is_some() && p.is_some(), "Could not find LLM symbols in model");
run_params.symbols.set(&p.unwrap(), 0);
run_params.symbols.set(&s.unwrap(), pp as i64);
let inputs = get_or_make_inputs(¶ms.tract_model, &run_params)?;
limits.warmup(¶ms.req_runnable()?, &inputs)?;
let inputs = get_or_make_inputs(¶ms.tract_model, &run_params)?;
let (iters, dur) = limits.bench(¶ms.req_runnable()?, &inputs)?;
let tokens = pp as f64 / dur.as_secs_f64() * iters as f64;
println!("PP{pp}: {tokens:.1} tokens/sec");
Ok(())
}
pub fn bench_tg(
params: &Parameters,
_matches: &clap::ArgMatches,
sub_matches: &clap::ArgMatches,
limits: &BenchLimits,
tg: usize,
probe: Option<&Probe>,
) -> TractResult<()> {
let mut run_params = crate::tensor::run_params_from_subcommand(params, sub_matches)?;
run_params.allow_random_input = true;
let model = params.req_typed_model();
let (b, s, p) = figure_out_causal_llm_b_s_p(&model)
.context("Could not find out LLM symbolic parameters")?;
if let Some(b) = b {
run_params.symbols.set(&b, 1);
}
ensure!(s.is_some() && p.is_some(), "Could not find LLM symbols in model");
run_params.symbols.set(&s.unwrap(), 1);
let p = p.unwrap();
if !limits.warmup_loops.is_zero() || !limits.warmup_time.is_zero() {
let mut iters = 0;
let max_loops =
if limits.warmup_loops.is_zero() { usize::MAX } else { limits.warmup_loops };
let max_time =
if limits.warmup_time.is_zero() { Duration::MAX } else { limits.warmup_time };
let start_warmup = Instant::now();
info!("TG warming before profiling...");
while iters < max_loops && start_warmup.elapsed() < max_time {
let mut state = params.req_runnable()?.spawn()?;
for t in 0..tg {
run_params.symbols.set(&p, t as i64);
let mut inputs = get_or_make_inputs(¶ms.tract_model, &run_params)?;
state.run(inputs.sources.remove(0))?;
}
iters += 1;
}
info!("Done warming up.");
}
let mut tot_dur = Duration::default();
let mut state = params.req_runnable()?.spawn()?;
for t in 0..tg {
if let Some(p) = probe {
p.log_event(&format!("Starting token {t}"))?;
}
run_params.symbols.set(&p, t as i64);
let mut inputs = get_or_make_inputs(¶ms.tract_model, &run_params)?;
let start = Instant::now();
state.run(inputs.sources.remove(0))?;
tot_dur += start.elapsed();
}
let tokens = tg as f64 / tot_dur.as_secs_f64();
println!("TG{tg}: {tokens:.1} tokens/sec");
Ok(())
}
/// Compares the rankings induced by two logits tensors with truncated Rank-Biased Overlap (RBO).
///
/// Both tensors are cast to `f32` and their element indices are sorted by decreasing value. The
/// sort is stable, so ties keep ascending index order. With
/// `k = min(depth, len(test), len(reference))`, the score is
/// `(1 - p) * Σ_{d=1..k} p^(d-1) * |top_d(test) ∩ top_d(reference)| / d`,
/// i.e. truncated RBO without extrapolation. `p` is the persistence parameter and is not
/// validated here (RBO conventionally uses `0 < p < 1`). The score is zero when `k == 0`.
/// Top-1 agreement and top-5 overlap are logged at debug level.
///
/// # Errors
/// Fails if either tensor cannot be cast to `f32` or viewed as a plain `f32` slice.
pub fn top_logits_rbo(test: &Tensor, reference: &Tensor, p: f64, depth: usize) -> TractResult<f64> {
    use std::collections::HashSet;

    /// Indices of `t`'s elements, ordered by decreasing value.
    fn descending_ranking(t: &Tensor) -> TractResult<Vec<usize>> {
        Ok(t.cast_to::<f32>()?
            .try_as_plain()
            .context("Logits tensor has no plain storage")?
            .as_slice::<f32>()
            .context("Logits tensor can not be viewed as f32")?
            .iter()
            .copied()
            .enumerate()
            .sorted_by_key(|(_, f)| FloatOrd(-*f))
            .map(|(ix, _)| ix)
            .collect_vec())
    }

    let a = descending_ranking(test)?;
    let b = descending_ranking(reference)?;
    let k = depth.min(a.len()).min(b.len());

    let mut seen_a: HashSet<usize> = HashSet::with_capacity(k);
    let mut seen_b: HashSet<usize> = HashSet::with_capacity(k);
    // |top_d(a) ∩ top_d(b)|, maintained incrementally rather than recounting the intersection at
    // every depth (O(k²)). Each ranking holds distinct indices, so the only candidates for new
    // common elements at depth d are a[d-1] and b[d-1].
    let mut common = 0usize;
    let mut rbo = 0.0;
    for d in 1..=k {
        let (x, y) = (a[d - 1], b[d - 1]);
        seen_a.insert(x);
        seen_b.insert(y);
        if seen_b.contains(&x) {
            common += 1;
        }
        // When x == y the element was already counted just above.
        if x != y && seen_a.contains(&y) {
            common += 1;
        }
        let overlap = common as f64 / d as f64;
        rbo += p.powi((d as i32) - 1) * overlap;
    }

    // Diagnostics only; `first()` avoids panicking on empty tensors.
    let top1_match = matches!((a.first(), b.first()), (Some(x), Some(y)) if x == y);
    let top_n = 5.min(k);
    let top5_overlap = {
        let sa: HashSet<usize> = a[..top_n].iter().copied().collect();
        let sb: HashSet<usize> = b[..top_n].iter().copied().collect();
        sa.intersection(&sb).count()
    };
    debug!("RBO detail: top1_match={top1_match} top5_overlap={top5_overlap}/{top_n}");
    Ok((1.0 - p) * rbo)
}