#![allow(clippy::len_zero)]
#[macro_use]
extern crate log;
#[macro_use]
mod macros;
#[allow(unused_imports)]
use tract_itertools::Itertools;
use tract_core::internal::*;
use tract_hir::internal::*;
use tract_libcli::annotations::Annotations;
use tract_libcli::display_params::DisplayParams;
use tract_libcli::model::Model;
use tract_libcli::profile::BenchLimits;
use readings_probe::*;
mod bench;
mod compare;
mod cost;
mod dump;
mod errors {}
mod params;
mod run;
#[cfg(feature = "pulse")]
mod stream_check;
mod tensor;
mod utils;
use params::*;
readings_probe::instrumented_allocator!();
fn info_usage(stage: &str, probe: Option<&Probe>) {
if let Some(mon) = probe {
let _ = mon.log_event(stage);
}
if log::log_enabled!(log::Level::Info) {
let usage = readings_probe::get_os_readings().unwrap();
info!(
"Resource usage {}: vsz:{} rsz:{} rszmax:{}",
stage, usage.virtual_size, usage.resident_size, usage.resident_size_max
);
}
}
pub const STAGES: &[&str] = &[
"load",
"analyse",
"incorporate",
"type",
"declutter",
"pulse",
"pulse-to-type",
"pulse-declutter",
"set",
"set-declutter",
"nnef-cycle",
"nnef-cycle-declutter",
"before-optimize",
"optimize",
];
fn main() -> tract_core::anyhow::Result<()> {
use clap::*;
let mut app = command!()
.setting(AppSettings::DeriveDisplayOrder)
.allow_hyphen_values(true)
.arg(arg!(--readings "Start readings instrumentation"))
.arg(arg!(--"readings-heartbeat" [MS] "Heartbeat for readings background collector").default_value("5"))
.arg(arg!(verbose: -v ... "Sets the level of verbosity."))
.arg(arg!([model] "Sets the model to use"))
.arg(arg!(-f --format [format]
"Hint the model format ('kaldi', 'onnx', 'nnef' or 'tf') instead of guess from extension."))
.arg(Arg::new("input").long("input").short('i').multiple_occurrences(true).takes_value(true).long_help(
"Set input shape and type (@file.pb or @file.npz:thing.npy or 3x4xi32)."))
.arg(Arg::new("constantize").long("constantize").multiple_occurrences(true).takes_value(true).long_help(
"Transorm an input into a Constant"))
.arg(arg!(--"input-bundle" [input_bundle] "Path to an input container (.npz). This sets input facts and tensor values.").hide(true))
.arg(arg!(--"allow-random-input" "Will use random generated input").hide(true))
.arg(arg!(--"input-facts-from-bundle" [input_bundle] "Path to an input container (.npz). This only sets input facts."))
.arg(arg!(--"kaldi-adjust-final-offset" [frames] "Adjust value of final offset in network (for reproducibility)"))
.arg(arg!(--"kaldi-downsample" [frames] "Add a subsampling to output on axis 0"))
.arg(arg!(--"kaldi-left-context" [frames] "Add lines of left context to input (dupping first time frame)"))
.arg(arg!(--"kaldi-right-context" [frames] "Add lines of right context to input (dupping last time frame)"))
.arg(arg!(--"onnx-test-data-set" [data_set] "Use onnx-test data-set as input (expect test_data_set_N dir with input_X.pb, etc. inside)"))
.arg(arg!(--"onnx-ignore-output-shapes" "Ignore output shapes from model (workaround for pytorch export bug with mask axes)"))
.arg(arg!(--"onnx-ignore-output-types" "Ignore output shapes from types (workaround for tdim conflicting with integer types)"))
.arg(arg!(--"input-node" [node] ... "Override input nodes names (auto-detects otherwise)."))
.arg(arg!(--"output-node" [node] ... "Override output nodes name (auto-detects otherwise)."))
.arg(arg!(--"label-wires" "Propagate node labels to wires"))
.arg(arg!(--"tf-initializer-output-node" [node] "Set an initializer node"))
.arg(arg!(--"override-fact" [fact] "Override a fact."))
.arg(arg!(--"analyse-fail-fast" "Stop analyse at first error."))
.arg(arg!(--recursive "Apply to sub graphes"))
.arg(arg!(--proto "Keep proto model around after parse"))
.arg(arg!(--determinize "Enforce a seed in random operator"))
.arg(arg!(--partial "Before analyse, eliminate dead branches"))
.arg(arg!(--pass [STAGE] "Pass to stop preprocessing after.").possible_values(STAGES))
.arg(arg!(--"declutter-step" [STEP] "Stop decluttering process after application of patch number N"))
.arg(arg!(--"optimize-step" [STEP] "Stop optimizing process after application of patch number N"))
.arg(arg!(--"extract-decluttered-sub" [SUB] "Zoom on a subgraph after decluttering by parent node name"))
.arg(arg!(--"half-floats" "Convert the decluttered network from f32 to f16"))
.arg(Arg::new("set").long("set").multiple_occurrences(true).takes_value(true)
.long_help("Set a symbol to a concrete value after decluttering"))
.arg(arg!(--"allow-float-casts" "Allow casting between f16, f32 and f64 around model").hide(true))
.arg(arg!(--"nnef-cycle" "Perform NNEF dump and reload before optimizing"))
.arg(arg!(--"nnef-tract-core" "Allow usage of tract-core extension in NNEF dump and load"))
.arg(arg!(--"nnef-tract-onnx" "Allow usage of tract-onnx extension in NNEF dump and load"))
.arg(arg!(--"nnef-tract-pulse" "Allow usage of tract-pulse extension in NNEF dump and load"))
.arg(arg!(--"nnef-extended-identifier" "Allow usage of the i\"...\" syntax to escape identifier names"))
.arg(arg!(-O --optimize "Optimize before running"))
.arg(arg!(--pulse [PULSE] "Translate to pulse network"))
.arg(arg!(--"machine-friendly" "Machine friendly output"))
.arg(arg!(--"list-ops" "List all known operators"));
let compare = clap::Command::new("compare")
.long_about("Compares the output of tract and tensorflow on randomly generated input.")
.arg(
Arg::new("stage")
.long("stage")
.takes_value(true)
.possible_values(STAGES)
.help("Loading pipeline stage to compare with"),
)
.arg(Arg::new("tf").long("tf").takes_value(false).help("Compare against tensorflow"))
.arg(Arg::new("twice").long("twice").takes_value(false).help("Run twice and compare"))
.arg(Arg::new("npz").long("npz").takes_value(true).help("NPZ file to compare against"))
.arg(
Arg::new("pbdir")
.long("pbdir")
.takes_value(true)
.help("protobuf directory file to compare against (like ONNX tests)"),
)
.group(
ArgGroup::new("reference")
.args(&["npz", "pbdir", "stage", "tf", "twice"])
.required(true),
)
.arg(
Arg::new("cumulative")
.long("cumulative")
.takes_value(false)
.help("Do not reset with reference values at each node"),
)
.arg(
Arg::new("resilient")
.long("resilient")
.takes_value(false)
.help("Try nodes one per one to mitigate crashes"),
);
let compare = run_options(compare);
app = app.subcommand(output_options(compare));
let bench =
clap::Command::new("bench").long_about("Benchmarks tract on randomly generated input.");
let bench = run_options(bench);
let bench = output_options(bench);
let bench = benchlimits_options(bench);
let bench = assertions_options(bench);
app = app.subcommand(bench);
let criterion = clap::Command::new("criterion")
.long_about("Benchmarks tract on randomly generated input using criterion.");
let criterion = run_options(criterion);
app = app.subcommand(criterion);
app = app.subcommand(dump_subcommand());
let run = clap::Command::new("run")
.long_about("Run the graph")
.arg(Arg::new("dump").long("dump").help("Show output"))
.arg(
Arg::new("save-outputs")
.long("save-outputs")
.takes_value(true)
.help("Save the outputs into a npz file"),
)
.arg(Arg::new("steps").long("steps").help("Show all inputs and outputs"))
.arg(
Arg::new("set")
.long("set")
.takes_value(true)
.multiple_occurrences(true)
.number_of_values(1)
.help("Set a symbol value before running the model (--set S=12)"),
)
.arg(
Arg::new("save-steps")
.long("save-steps")
.takes_value(true)
.help("Save intermediary values as a npz file"),
)
.arg(
Arg::new("assert-sane-floats")
.long("assert-sane-floats")
.help("Check float for NaN and infinites at each step"),
);
let run = run_options(run);
let run = output_options(run);
let run = assertions_options(run);
app = app.subcommand(run);
let optimize = clap::Command::new("optimize").about("Optimize the graph");
app = app.subcommand(output_options(optimize));
let stream_check = clap::Command::new("stream-check")
.long_about("Compare output of streamed and regular exec");
app = app.subcommand(output_options(stream_check));
let matches = app.get_matches();
let probe = if matches.is_present("readings") {
let file = std::fs::File::create("readings.out").unwrap();
let mut probe = Probe::new(file).unwrap();
probe.register_i64("progress").unwrap();
let heartbeat = matches.value_of("readings-heartbeat").unwrap().parse::<f32>().unwrap();
probe.spawn_heartbeat(std::time::Duration::from_secs_f32(heartbeat / 1000.0)).unwrap();
Some(probe)
} else {
None
};
if ::std::env::var("TRACT_LOG").is_err() {
let level = match matches.occurrences_of("verbose") {
0 => "cli=warn,tract=warn",
1 => "cli=info,tract=info",
2 => "cli=debug,tract=debug",
_ => "cli=trace,tract=trace",
};
::std::env::set_var("TRACT_LOG", level);
}
let env = env_logger::Env::default().filter_or("TRACT_LOG", "warn");
env_logger::Builder::from_env(env).format_timestamp_nanos().init();
info_usage("init", probe.as_ref());
if let Err(e) = handle(matches, probe.as_ref()) {
error!("{:?}", e);
std::process::exit(1);
}
info_usage("done", probe.as_ref());
Ok(())
}
#[allow(clippy::let_and_return)]
fn dump_subcommand<'a>() -> clap::Command<'a> {
use clap::*;
let dump = clap::Command::new("dump")
.long_about("Dumps the graph in human readable form.")
.arg(
Arg::new("axes")
.long("axes")
.help("Compute and display axis tracking")
)
.arg(
Arg::new("axes-names")
.takes_value(true)
.number_of_values(1)
.multiple_occurrences(true)
.long("axes-names")
.help("Gave meaningful names to axes: [node_name=]axis0,axis1,..,axisN (apply to first input if no node_name is provided)")
)
.arg(
Arg::new("assert-cost")
.takes_value(true)
.long("assert-cost")
.help("Checks computed against the provided value (form: \"FMA(F32)=2060448 DIV(F32)=24576\")")
)
.arg(
Arg::new("nnef-override-output-name")
.takes_value(true)
.number_of_values(1)
.long("nnef-override-output-name")
.help("Rename output before dumping network")
)
.arg(
Arg::new("nnef-dir")
.takes_value(true)
.long("nnef-dir")
.help("Dump the network in NNEF format (as a directory)"),
)
.arg(
Arg::new("nnef-tar")
.takes_value(true)
.long("nnef-tar")
.help("Dump the network in NNEF format (as a tar file)"),
)
.arg(
Arg::new("nnef")
.takes_value(true)
.long("nnef")
.help("Dump the network in NNEF format (as a tar.gz file)"),
)
.arg(
Arg::new("nnef-graph")
.takes_value(true)
.long("nnef-graph")
.help("Dump the network definition (without the weights) as a graph.nnef-like file"),
)
.arg(
Arg::new("inner")
.takes_value(true)
.number_of_values(1)
.multiple_occurrences(true)
.long("inner")
.help("Navigate to a sub-model"),
);
let dump = run_options(dump);
let dump = output_options(dump);
let dump = assertions_options(dump);
let dump = benchlimits_options(dump);
dump
}
fn assertions_options(command: clap::Command) -> clap::Command {
use clap::*;
command
.arg(
Arg::new("assert-output")
.takes_value(true)
.multiple_occurrences(true)
.number_of_values(1)
.long("assert-output")
.help("Fact to check the ouput tensor against (@filename, or 3x4xf32)"),
)
.arg(
Arg::new("assert-output-bundle")
.takes_value(true)
.long("assert-output-bundle")
.help("Checks values against these tensor (.npz)"),
)
.arg(
Arg::new("assert-output-fact")
.takes_value(true)
.long("assert-output-fact")
.help("Infered shape and datum type must match exactly this"),
)
.arg(
Arg::new("assert-output-count")
.takes_value(true)
.long("assert-output-count")
.help("Check the number of outputs found."),
)
.arg(
Arg::new("assert-op-count")
.takes_value(true)
.forbid_empty_values(true)
.number_of_values(2)
.value_names(&["operator", "count"])
.multiple_occurrences(true)
.long("assert-op-count")
.help("Specified operator must appear exactly the specified number of times. This argument can appear multiple times."),
)
}
fn benchlimits_options(command: clap::Command) -> clap::Command {
use clap::*;
command.args(&[
arg!(-n --"max-iters" [max_iters] "Sets the maximum number of iterations for each node [default: 100_000]."),
arg!(--"max-time" [max_time] "Sets the maximum execution time for each node (in ms) [default: 5000].") ])
}
fn run_options(command: clap::Command) -> clap::Command {
use clap::*;
command
.arg(
Arg::new("input-from-bundle")
.long("input-from-bundle")
.takes_value(true)
.help("Path to an input container (.npz). This sets tensor values."),
)
.arg(
Arg::new("allow-random-input")
.long("allow-random-input")
.help("Will use random generated input"),
)
.arg(
Arg::new("random-range")
.long("random-range")
.multiple_occurrences(true)
.takes_value(true)
.help("Constraint random values to a given range (example: input=1.0..10.0)"),
)
.arg(
Arg::new("allow-float-casts")
.long("allow-float-casts")
.help("Allow casting between f16, f32 and f64 around model"),
)
}
fn output_options(command: clap::Command) -> clap::Command {
use clap::*;
command
.args(&[
arg!(--"natural-order" "dump nodes in id order instead of evaluation order"),
arg!(-q --quiet "don't dump"),
])
.arg(Arg::new("debug-op").long("debug-op").help("show debug dump for each op"))
.arg(Arg::new("node-id").long("node-id").takes_value(true).help("Select a node to dump"))
.arg(
Arg::new("successors")
.long("successors")
.takes_value(true)
.help("Show successors of node"),
)
.arg(Arg::new("op-name").long("op-name").takes_value(true).help("Select one op to dump"))
.arg(
Arg::new("node-name")
.long("node-name")
.takes_value(true)
.help("Select one node to dump"),
)
.arg(Arg::new("const").long("const").help("also display consts nodes"))
.arg(Arg::new("info").long("info").help("show op inner information"))
.arg(Arg::new("io-long").long("io-long").help("show full i/o information"))
.arg(Arg::new("io-none").long("io-none").help("hide i/o information"))
.arg(Arg::new("json").long("json").help("dump performance info as json"))
.arg(Arg::new("outlet-labels").long("outlet-labels").help("display outlet labels"))
.arg(Arg::new("cost").long("cost").help("Include const information"))
.arg(Arg::new("profile").long("profile").help("Include results for profile run"))
.arg(
Arg::new("invariants")
.takes_value(false)
.long("invariants")
.help("Display operators invariants"),
)
}
fn handle(matches: clap::ArgMatches, probe: Option<&Probe>) -> TractResult<()> {
if matches.is_present("list-ops") {
#[cfg(feature = "onnx")]
{
let onnx = tract_onnx::onnx();
let names = onnx.op_register.0.keys().sorted().into_iter().join(", ");
println!("Onnx:\n");
println!("{names}");
println!("\n");
}
#[cfg(feature = "tf")]
{
let tf = tract_tensorflow::tensorflow();
let names = tf.op_register.0.keys().sorted().into_iter().join(", ");
println!("Tensorflow:\n");
println!("{names}");
println!("\n");
}
return Ok(());
}
let builder_result = Parameters::from_clap(&matches, probe);
#[allow(unused_mut)]
let mut params = match builder_result {
Ok(params) => params,
Err(e) => {
if let Some(params::ModelBuildingError(ref broken_model, _)) = e.downcast_ref() {
let mut broken_model: Box<dyn Model> =
tract_hir::tract_core::dyn_clone::clone(broken_model);
let annotations = Annotations::from_model(broken_model.as_ref())?;
let display_params = if let Some(("dump", sm)) = matches.subcommand() {
display_params_from_clap(&matches, sm)?
} else {
DisplayParams::default()
};
if broken_model.output_outlets().len() == 0 {
broken_model.auto_outputs()?;
}
tract_libcli::terminal::render(
broken_model.as_ref(),
&annotations,
&display_params,
)?;
}
Err(e)?
}
};
let mut need_optimisations = false;
match matches.subcommand() {
Some(("bench", m)) => {
need_optimisations = true;
bench::handle(¶ms, &matches, m, ¶ms::bench_limits_from_clap(m)?, probe)
}
Some(("criterion", m)) => {
need_optimisations = true;
bench::criterion(¶ms, &matches, m)
}
Some(("compare", m)) => {
compare::handle(&mut params, &matches, m, display_params_from_clap(&matches, m)?)
}
Some(("run", m)) => run::handle(¶ms, &matches, m),
#[cfg(feature = "pulse")]
Some(("stream-check", m)) => {
stream_check::handle(¶ms, &display_params_from_clap(&matches, m)?)
}
None => dump::handle(
¶ms,
&DisplayParams::default(),
&matches,
&dump_subcommand().get_matches_from::<_, &'static str>([]),
&BenchLimits::default(),
vec![],
),
Some(("dump", m)) => {
need_optimisations = m.is_present("profile");
let inner = m
.values_of("inner")
.map(|ss| ss.map(|s| s.to_string()).collect())
.unwrap_or_default();
dump::handle(
¶ms,
&display_params_from_clap(&matches, m)?,
&matches,
m,
¶ms::bench_limits_from_clap(m)?,
inner,
)
}
Some((s, _)) => bail!("Unknown subcommand {}.", s),
}?;
if need_optimisations {
let style = ansi_term::Style::new().fg(ansi_term::Color::Red).bold();
if !matches.is_present("optimize") {
warn!("{}", style.paint("Profiling an un-optimized network. Consider adding -O."));
}
if cfg!(debug_assertions) {
warn!("{}", style.paint("Profiling a debug build of tract!"));
}
}
Ok(())
}
fn nnef(matches: &clap::ArgMatches) -> tract_nnef::internal::Nnef {
let mut fw = tract_nnef::nnef();
if matches.is_present("nnef-tract-onnx") {
#[cfg(feature = "onnx")]
{
use tract_onnx::WithOnnx;
fw = fw.with_onnx();
}
#[cfg(not(feature = "onnx"))]
{
panic!("tract is build without ONNX support")
}
}
if matches.is_present("nnef-tract-pulse") {
#[cfg(feature = "pulse-opl")]
{
use tract_pulse::WithPulse;
fw = fw.with_pulse();
}
#[cfg(not(feature = "pulse-opl"))]
{
panic!("tract is build without pulse-opl support")
}
}
if matches.is_present("nnef-tract-core") {
fw = fw.with_tract_core();
}
if matches.is_present("nnef-extended-identifier") {
fw.allow_extended_identifier_syntax(true);
}
fw
}