use crate::Parameters;
use crate::bench::{bench, make_state};
use readings_probe::Probe;
use std::collections::HashSet;
use std::time::{Duration, Instant};
use tract_hir::internal::*;
use tract_libcli::profile::BenchLimits;
use tract_libcli::tensor::get_or_make_inputs;
/// Infers the symbols that parameterize the dynamic shapes of an LLM model.
///
/// The "token" shape is the shape of the first integer-typed model input. The
/// "KV" shape is the shape of the first float-typed model input or, when the
/// model has no float input, the initial tensor fact reported by the first
/// node whose op state exposes one.
///
/// Returns `(b, s, p)` where:
/// * `b` is a symbol present in both shapes, or `None` if they share none
///   (presumably the batch size -- callers pin it to 1);
/// * `s` is a symbol present only in the token shape (presumably the number
///   of tokens fed in one run);
/// * `p` is a symbol present only in the KV shape (presumably the length of
///   the already-processed context).
///
/// When several symbols qualify for a slot, an arbitrary one is picked.
///
/// # Errors
///
/// Fails if a model input has no resolvable fact, if no integer input exists,
/// if instantiating an op state fails, or if either shape has no symbol of
/// its own. These last cases used to panic; they are now reported as errors
/// so callers can attach context.
pub fn figure_out_b_s_p(model: &TypedModel) -> TractResult<(Option<Symbol>, Symbol, Symbol)> {
    // Resolve every input fact up front so lookup failures propagate with `?`
    // instead of panicking inside a search closure.
    let input_facts =
        model.inputs.iter().map(|i| model.outlet_fact(*i)).collect::<TractResult<Vec<_>>>()?;

    let token_fact = input_facts
        .iter()
        .find(|fact| fact.datum_type.is_integer())
        .context("No token input found")?;
    let tokens_symbols = token_fact.shape.volume().symbols();

    let kv_symbols =
        if let Some(kv_fact) = input_facts.iter().find(|fact| fact.datum_type.is_float()) {
            kv_fact.shape.volume().symbols()
        } else {
            // No float input: probe op states for an initial tensor fact and
            // use the first one found. Stays empty if no node provides one.
            let mut dummy_session_state = SessionState::default();
            let mut symbols = HashSet::new();
            for node in &model.nodes {
                if let Some(fact) = node
                    .op
                    .state(&mut dummy_session_state, 0)?
                    .and_then(|state| state.init_tensor_fact())
                {
                    symbols = fact.shape.volume().symbols();
                    break;
                }
            }
            symbols
        };

    // tokens \ (tokens ∩ kv) == tokens \ kv, so the differences can be taken
    // directly against the other set without materializing the intersection.
    let b = tokens_symbols.intersection(&kv_symbols).next().cloned();
    let s = tokens_symbols
        .difference(&kv_symbols)
        .next()
        .cloned()
        .context("No symbol found that is specific to the token input shape")?;
    let p = kv_symbols
        .difference(&tokens_symbols)
        .next()
        .cloned()
        .context("No symbol found that is specific to the KV shape")?;
    Ok((b, s, p))
}
/// Entry point of the LLM benchmark: runs [`bench_pp`] with a size of 512,
/// then [`bench_tg`] with a size of 128.
///
/// # Errors
///
/// Returns the first error raised by either benchmark; `bench_tg` is not run
/// if `bench_pp` fails.
pub fn handle(
    params: &Parameters,
    matches: &clap::ArgMatches,
    sub_matches: &clap::ArgMatches,
    limits: &BenchLimits,
    probe: Option<&Probe>,
) -> TractResult<()> {
    // Workload sizes handed to the two benchmark stages.
    const PP_SIZE: usize = 512;
    const TG_SIZE: usize = 128;

    bench_pp(params, matches, sub_matches, limits, PP_SIZE, probe)?;
    bench_tg(params, matches, sub_matches, limits, TG_SIZE, probe)
}
/// Runs the `PP` benchmark and prints its throughput in tokens per second.
///
/// The symbols found by [`figure_out_b_s_p`] are bound as follows: `b` (when
/// present) to 1 and `p` to 0; `s` is 6 for the warmup inputs and `pp` for
/// the measured inputs. The reported rate is `pp` divided by the duration
/// returned by `bench`.
///
/// # Errors
///
/// Fails if the run parameters cannot be built, if the model is not a
/// `TypedModel`, if the symbols cannot be inferred, or if input generation,
/// warmup or the benchmark itself fails.
pub fn bench_pp(
    params: &Parameters,
    matches: &clap::ArgMatches,
    sub_matches: &clap::ArgMatches,
    limits: &BenchLimits,
    pp: usize,
    probe: Option<&Probe>,
) -> TractResult<()> {
    // Value bound to the `s` symbol while warming up.
    const WARMUP_S: i64 = 6;

    let mut run_params = crate::tensor::run_params_from_subcommand(params, sub_matches)?;
    run_params.allow_random_input = true;

    let typed_model =
        params.tract_model.downcast_ref::<TypedModel>().context("Can only bench TypedModel")?;
    let mut state = make_state(params, matches, sub_matches)?;

    let (b_sym, s_sym, p_sym) =
        figure_out_b_s_p(typed_model).context("Could not find out LLM symbolic parameters")?;
    match b_sym.as_ref() {
        Some(sym) => run_params.symbols.set(sym, 1),
        None => {}
    }
    run_params.symbols.set(&p_sym, 0);

    // Warm up on a short input before building the full-size one.
    run_params.symbols.set(&s_sym, WARMUP_S);
    let warmup_inputs = get_or_make_inputs(typed_model, &run_params)?;
    limits.warmup(typed_model, &warmup_inputs)?;

    run_params.symbols.set(&s_sym, pp as i64);
    let measured_inputs = get_or_make_inputs(typed_model, &run_params)?;
    let dur = bench(&mut state, sub_matches, measured_inputs, limits, probe)?.1;

    let tokens = pp as f64 / dur.as_secs_f64();
    println!("PP{pp}: {tokens:.1} tokens/sec");
    Ok(())
}
/// Runs the `TG` benchmark and prints its throughput in tokens per second.
///
/// The symbols found by [`figure_out_b_s_p`] are bound as follows: `b` (when
/// present) and `s` to 1; `p` is 1 for the warmup inputs and then takes each
/// value in `0..tg`, one model run per value. Only the time spent taking the
/// first source out of the inputs and running the state on it is accumulated;
/// input generation and probe logging are excluded. Op states are reset once
/// all runs are done.
///
/// # Errors
///
/// Fails if the run parameters cannot be built, if the model is not a
/// `TypedModel`, if the symbols cannot be inferred, or if input generation,
/// warmup, probe logging, a model run or the final state reset fails.
pub fn bench_tg(
    params: &Parameters,
    matches: &clap::ArgMatches,
    sub_matches: &clap::ArgMatches,
    limits: &BenchLimits,
    tg: usize,
    probe: Option<&Probe>,
) -> TractResult<()> {
    let mut run_params = crate::tensor::run_params_from_subcommand(params, sub_matches)?;
    run_params.allow_random_input = true;

    let typed_model =
        params.tract_model.downcast_ref::<TypedModel>().context("Can only bench TypedModel")?;
    let mut state = make_state(params, matches, sub_matches)?;

    let (b_sym, s_sym, p_sym) =
        figure_out_b_s_p(typed_model).context("Could not find out LLM symbolic parameters")?;
    match b_sym.as_ref() {
        Some(sym) => run_params.symbols.set(sym, 1),
        None => {}
    }
    run_params.symbols.set(&s_sym, 1);
    run_params.symbols.set(&p_sym, 1);

    let warmup_inputs = get_or_make_inputs(typed_model, &run_params)?;
    limits.warmup(typed_model, &warmup_inputs)?;

    let mut elapsed = Duration::default();
    for t in 0..tg {
        if let Some(recorder) = probe {
            recorder.log_event(&format!("Starting token {t}"))?;
        }

        // Fresh inputs for this step, generated outside the timed region.
        run_params.symbols.set(&p_sym, t as i64);
        let mut step_inputs = get_or_make_inputs(typed_model, &run_params)?;

        let clock = Instant::now();
        state.run(step_inputs.sources.remove(0))?;
        elapsed += clock.elapsed();
    }
    state.reset_op_states()?;

    let tokens = tg as f64 / elapsed.as_secs_f64();
    println!("TG{tg}: {tokens:.1} tokens/sec");
    Ok(())
}