tract 0.22.1

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
use crate::Parameters;
use readings_probe::Probe;
use std::time::{Duration, Instant};
use tract_hir::internal::*;
use tract_libcli::capture_gpu_trace;
use tract_libcli::model::Model;
use tract_libcli::profile::BenchLimits;
use tract_libcli::tensor::RunTensors;
use tract_libcli::tensor::get_or_make_inputs;
use tract_libcli::terminal;

/// Runs a single inference turn on `state` and returns the time spent in `state.run`.
///
/// When the model carries a `pulse.delay` property the run is performed directly.
/// Otherwise the op states are initialised from a clone of `inputs.state_initializers`
/// before the timed run and reset afterwards; neither step is part of the returned
/// duration.
///
/// # Errors
/// Propagates any failure from state initialisation, the run itself, or the state reset.
///
/// # Panics
/// Panics if `inputs.sources` is empty, as only the first entry is fed to the model.
fn profile_single_turn<'m>(
    state: &mut TypedSimpleState<&'m TypedModel, Arc<TypedRunnableModel<&'m TypedModel>>>,
    inputs: &RunTensors,
) -> TractResult<Duration> {
    let has_pulse_delay = state.model().properties().contains_key("pulse.delay");

    if !has_pulse_delay {
        state.init_states(&mut inputs.state_initializers.clone())?;
    }

    // Only the call to `run` (including the clone of its input) sits between the two
    // clock reads.
    let timer = Instant::now();
    state.run(inputs.sources[0].clone())?;
    let elapsed = timer.elapsed();

    if !has_pulse_delay {
        state.reset_op_states()?;
    }
    Ok(elapsed)
}

/// Registers a criterion benchmark named `run` in the `net` group, where every
/// iteration performs one `profile_single_turn` on the model held by `params`.
///
/// # Errors
/// Fails if `params.tract_model` is not a `TypedModel`, or if building the state, the
/// run parameters or the inputs fails.
///
/// Note that the `TractResult` produced by each turn is handed back to criterion's
/// `iter` rather than propagated with `?`, so a failing turn does not make this
/// function return an error.
pub fn criterion(
    params: &Parameters,
    matches: &clap::ArgMatches,
    sub_matches: &clap::ArgMatches,
) -> TractResult<()> {
    let typed_model =
        params.tract_model.downcast_ref::<TypedModel>().context("Can only bench TypedModel")?;
    let mut state = make_state(params, matches, sub_matches)?;

    let mut harness = criterion::Criterion::default();
    let mut net_group = harness.benchmark_group("net");

    let run_params = crate::tensor::run_params_from_subcommand(params, sub_matches)?;
    let inputs = get_or_make_inputs(typed_model, &run_params)?;

    // The closure takes ownership of both the state and the inputs.
    net_group.bench_function("run", move |bencher| {
        bencher.iter(|| profile_single_turn(&mut state, &inputs))
    });
    Ok(())
}

/// Builds the runnable state used for benchmarking the typed model held by `params`.
///
/// Plan options are read from `sub_matches`. When `matches` has either the `metal` or
/// the `cuda` flag, the plan is built with `skip_order_opt_ram` enabled and is given a
/// `DeviceSessionHandler` resolved with symbol `S` set to 1024 and symbol `P` set to 0.
/// Otherwise a plain plan is built from the options as they are.
///
/// # Errors
/// Fails if the model is not a `TypedModel`, if plan construction fails, and, on the GPU
/// path, if symbol `S` or `P` is missing from the model, if the session handler cannot
/// be built, or (on targets other than macOS/iOS) if no CUDA library is found.
pub(crate) fn make_state<'m>(
    params: &'m Parameters,
    matches: &clap::ArgMatches,
    sub_matches: &clap::ArgMatches,
) -> TractResult<TypedSimpleState<&'m TypedModel, Arc<TypedRunnableModel<&'m TypedModel>>>> {
    #[allow(unused_mut)]
    let mut plan_options = crate::plan_options::plan_options_from_subcommand(sub_matches)?;
    let model =
        params.tract_model.downcast_ref::<TypedModel>().context("Can only bench TypedModel")?;

    let wants_gpu = matches.is_present("metal") || matches.is_present("cuda");
    let plan = if wants_gpu {
        // This check is compiled only for targets other than macOS/iOS.
        #[cfg(not(any(target_os = "macos", target_os = "ios")))]
        {
            use tract_cuda::utils::get_cuda_lib;
            if get_cuda_lib().is_none() {
                bail!("GPU bench called on non-GPU model");
            }
        }
        plan_options.skip_order_opt_ram = true;
        let gpu_plan = SimplePlan::new_with_options(model, &plan_options)?;

        let sequence_length = model.symbols.get("S").context("Could not find symbol S in model")?;
        let past_sequence_length =
            model.symbols.get("P").context("Could not find symbol P in model")?;

        // Fixed symbol values handed to the device session handler.
        let mut symbol_values = SymbolValues::default();
        symbol_values.set(&sequence_length, 1024);
        symbol_values.set(&past_sequence_length, 0);

        let session_handler =
            tract_gpu::session_handler::DeviceSessionHandler::from_plan(&gpu_plan, &symbol_values)?;
        gpu_plan.with_session_handler(session_handler)
    } else {
        SimplePlan::new_with_options(model, &plan_options)?
    };

    Ok(SimpleState::new(Arc::new(plan))?)
}

pub(crate) fn bench<'m>(
    state: &mut TypedSimpleState<&'m TypedModel, Arc<TypedRunnableModel<&'m TypedModel>>>,
    sub_matches: &clap::ArgMatches,
    inputs: RunTensors,
    limits: &BenchLimits,
    probe: Option<&Probe>,
) -> TractResult<(usize, Duration)> {
    let mut iters = 0;
    let progress = probe.and_then(|m| m.get_i64("progress"));
    info!("Starting bench itself");
    let mut dur = Duration::default();
    capture_gpu_trace(sub_matches, || -> TractResult<()> {
        while iters < limits.max_loops && dur < limits.max_time {
            if let Some(mon) = probe {
                let _ = mon.log_event(&format!("loop_{iters}"));
            }
            if let Some(p) = &progress {
                p.store(iters as _, std::sync::atomic::Ordering::Relaxed);
            }

            dur += profile_single_turn(state, &inputs)?;

            iters += 1;
        }
        Ok(())
    })?;
    Ok((iters, Duration::from_secs_f64(dur.as_secs_f64() / iters as f64)))
}

/// Entry point of the bench subcommand: builds the state and inputs, calls
/// `limits.warmup`, runs the bench loop and prints the outcome on stdout.
///
/// With `params.machine_friendly` set, only the average duration in seconds is printed
/// (`real: …`); otherwise a human-readable summary is printed. If the `pp` option is
/// present on `sub_matches`, it is parsed as an unsigned integer and a tokens-per-second
/// figure (that value divided by the average duration in seconds) is printed as well.
///
/// # Errors
/// Propagates failures from building the run parameters, the state or the inputs, from
/// the warmup and bench phases, and from parsing the `pp` value.
pub fn handle(
    params: &Parameters,
    matches: &clap::ArgMatches,
    sub_matches: &clap::ArgMatches,
    limits: &BenchLimits,
    probe: Option<&Probe>,
) -> TractResult<()> {
    let run_params = crate::tensor::run_params_from_subcommand(params, sub_matches)?;
    let mut state = make_state(params, matches, sub_matches)?;
    let inputs = get_or_make_inputs(state.model(), &run_params)?;

    limits.warmup(state.model(), &inputs)?;
    let (iters, dur) = bench(&mut state, sub_matches, inputs, limits, probe)?;
    let avg_secs = dur.as_secs_f64();

    if params.machine_friendly {
        println!("real: {}", avg_secs);
    } else {
        println!("Bench ran {} times, {}.", iters, terminal::dur_avg(dur));
    }

    // Optional throughput report, driven by the `pp` option.
    let prompt_len = sub_matches.value_of("pp").map(str::parse::<usize>).transpose()?;
    if let Some(pp) = prompt_len {
        let tokens = pp as f64 / avg_secs;
        println!("PP{pp}: {tokens:.1} tokens/sec");
    }

    Ok(())
}