use std::fs::File;
use crate::TractResult;
use crate::{Model, Parameters};
use ansi_term::Color::*;
use ndarray_npy::NpzWriter;
use tract_core::tract_data::itertools::izip;
use tract_hir::internal::*;
use tract_libcli::tensor::RunParams;
#[cfg(feature = "pulse")]
use tract_pulse::internal::*;
/// Adds `tensor` to the `.npz` archive `npz` as an array called `name`.
///
/// `F16` tensors are cast to `f32` before being stored. Quantized tensors (`QI8`, `QU8`,
/// `QI32`) are viewed as their plain integer type (`i8`, `u8`, `i32`), so only the
/// integer values are written, not the quantization parameters carried by the datum
/// type. Any other datum type is skipped: a warning is logged and `Ok(())` is returned.
///
/// # Errors
///
/// Fails if the tensor cannot be cast to, or viewed as, the chosen element type, or if
/// adding the array to the archive fails.
fn npz_add_tensor(npz: &mut NpzWriter<File>, name: String, tensor: &Tensor) -> TractResult<()> {
    // Stores the tensor viewed with element type `$t`, without any conversion.
    macro_rules! store_as {
        ($t:ty) => {
            npz.add_array(name, &tensor.to_array_view::<$t>()?)?
        };
    }
    match tensor.datum_type() {
        DatumType::F16 => {
            let widened = tensor.cast_to::<f32>()?;
            npz.add_array(name, &widened.to_array_view::<f32>()?)?
        }
        DatumType::Bool => store_as!(bool),
        DatumType::U8 | DatumType::QU8(_) => store_as!(u8),
        DatumType::U16 => store_as!(u16),
        DatumType::U32 => store_as!(u32),
        DatumType::U64 => store_as!(u64),
        DatumType::I8 | DatumType::QI8(_) => store_as!(i8),
        DatumType::I16 => store_as!(i16),
        DatumType::I32 | DatumType::QI32(_) => store_as!(i32),
        DatumType::I64 => store_as!(i64),
        DatumType::F32 => store_as!(f32),
        DatumType::F64 => store_as!(f64),
        _ => warn!("Not writing {}, {:?}, unsupported type", name, tensor),
    }
    Ok(())
}
/// Entry point of the `run` subcommand.
///
/// Runs the loaded model (via `run_regular`), then, depending on the sub-command flags:
/// * `--dump`: prints every value of every output, turn by turn;
/// * `--save-outputs FILE`: writes the outputs to a compressed `.npz` archive, named after
///   the output outlet label (or `output_N`), prefixed by `turn_N/` when an output has
///   more than one turn;
/// * `--assert-output-count N`: checks the number of model outputs;
///
/// and finally runs the assertions carried by `params` (output values, output facts of
/// the first turn, operator counts).
///
/// # Errors
///
/// Fails if the model run fails, if the output archive cannot be created or written, or
/// if any requested assertion does not hold.
pub fn handle(
    params: &Parameters,
    matches: &clap::ArgMatches,
    sub_matches: &clap::ArgMatches,
) -> TractResult<()> {
    let run_params = crate::tensor::run_params_from_subcommand(params, sub_matches)?;
    let dump = sub_matches.is_present("dump");
    // `outputs[ix][turn]` is the value of model output `ix` at turn `turn`.
    let outputs = dispatch_model!(&*params.tract_model, |m| run_regular(
        m,
        &run_params,
        matches,
        sub_matches
    ))?;

    if dump {
        for (ix, turns) in outputs.iter().enumerate() {
            for (turn, output) in turns.iter().enumerate() {
                println!("output #{}, turn #{}\n{}\n", ix, turn, output.dump(true)?);
            }
        }
    }

    if let Some(file_path) = sub_matches.value_of("save-outputs") {
        let file =
            std::fs::File::create(file_path).with_context(|| format!("Creating {file_path}"))?;
        let mut npz = ndarray_npy::NpzWriter::new_compressed(file);
        for (ix, turns) in outputs.iter().enumerate() {
            let name = params
                .tract_model
                .outlet_label(params.tract_model.output_outlets()[ix])
                .map(|name| name.to_string())
                .unwrap_or_else(|| format!("output_{ix}"));
            if turns.len() == 1 {
                npz_add_tensor(&mut npz, name, &turns[0])?;
            } else {
                for (turn, output) in turns.iter().enumerate() {
                    npz_add_tensor(&mut npz, format!("turn_{turn}/{name}"), output)?;
                }
            }
        }
        // Finish explicitly: errors raised while the writer finishes the archive on drop
        // are ignored, which could leave a truncated file without any report.
        npz.finish().with_context(|| format!("Writing {file_path}"))?;
    }

    if let Some(count) = sub_matches.value_of("assert-output-count") {
        let count = count
            .parse::<usize>()
            .with_context(|| format!("Parsing --assert-output-count value {count:?}"))?;
        if count != outputs.len() {
            bail!(
                "Wrong number of outputs, command line expected {}, found {:?}",
                count,
                outputs.len()
            );
        }
    }

    if params.assertions.assert_outputs {
        crate::utils::check_outputs(&outputs, params)?;
    }

    if let Some(facts) = &params.assertions.assert_output_facts {
        // Only the first turn of each output is compared against the expected facts.
        let outputs = outputs
            .iter()
            .enumerate()
            .map(|(ix, turns)| -> TractResult<InferenceFact> {
                let first = turns.first().with_context(|| format!("Output #{ix} has no value"))?;
                Ok(first.datum_type().fact(first.shape()).into())
            })
            .collect::<TractResult<Vec<_>>>()?;
        crate::utils::check_inferred(&outputs, facts)?;
    }

    if let Some(asserts) = &params.assertions.assert_op_count {
        for (name, expected) in asserts {
            let count = crate::utils::count_op(&*params.tract_model, name)?;
            if count != *expected {
                bail!("Wrong number of {} operators: expected {}, got {}", name, expected, count);
            }
        }
    }
    Ok(())
}
/// Runs the model once per turn of inputs and collects every output value.
///
/// Inputs are obtained from `tract_libcli::tensor::retrieve_or_make_inputs` according to
/// `run_params`; the model is run once for each input set it returns, reusing the same
/// state. Symbols can be fixed beforehand with `--set S=12` (repeatable).
///
/// Per-node behaviours driven by `sub_matches`:
/// * `--steps`: prints each node's inputs and outputs on stderr;
/// * `--save-steps FILE`: saves each node's outputs to a compressed `.npz` archive, named
///   `node` for output 0 and `node:ix` otherwise, prefixed by `turn_N/` when more than one
///   turn is run;
/// * `--assert-sane-floats`: fails on the first non-finite `f32`/`f16` value produced.
///
/// Returns `results[output][turn]`.
///
/// # Errors
///
/// Fails on a malformed `--set`, on I/O errors with the step archive, on any evaluation
/// error, or when `--assert-sane-floats` finds a non-finite value.
fn run_regular(
    tract: &dyn Model,
    run_params: &RunParams,
    _matches: &clap::ArgMatches,
    sub_matches: &clap::ArgMatches,
) -> TractResult<TVec<Vec<TValue>>> {
    let steps = sub_matches.is_present("steps");
    let assert_sane_floats = sub_matches.is_present("assert-sane-floats");
    let mut npz = if let Some(path) = sub_matches.value_of("save-steps") {
        let file = std::fs::File::create(path).with_context(|| format!("Creating {path}"))?;
        Some(ndarray_npy::NpzWriter::new_compressed(file))
    } else {
        None
    };
    dispatch_model!(tract, |m| {
        let plan = SimplePlan::new(m)?;
        let mut state = SimpleState::new(plan)?;
        if let Some(assignments) = sub_matches.values_of("set") {
            for assignment in assignments {
                // Split on the first `=` only: "S=1=2" must fail the integer parse rather
                // than be silently read as S=1.
                let (sym, value) = assignment.split_once('=').context("--set expect S=12 form")?;
                let sym = state.model().symbol_table.sym(sym).to_owned();
                let value: i64 = value.parse().context("Can not parse symbol value in set")?;
                state.session_state.resolved_symbols =
                    state.session_state.resolved_symbols.with(&sym, value);
            }
        }
        let inputs = tract_libcli::tensor::retrieve_or_make_inputs(tract, run_params)?;
        // results[output][turn]
        let mut results = tvec!(vec!(); state.model().outputs.len());
        // Step names only get a `turn_N/` prefix when there is more than one turn.
        let multiturn = inputs.len() > 1;
        for (turn, inputs) in inputs.into_iter().enumerate() {
            let turn_results =
                state.run_plan_with_eval(inputs, |session_state, state, node, input| {
                    if steps {
                        for (ix, i) in input.iter().enumerate() {
                            eprintln!(
                                "{} {}{}{:?}",
                                White.bold().paint(node.to_string()),
                                ix,
                                Blue.bold().paint("<< "),
                                i
                            );
                        }
                    }
                    let r = tract_core::plan::eval(session_state, state, node, input)?;
                    if steps {
                        for (ix, o) in r.iter().enumerate() {
                            eprintln!(
                                "{} {}{}{:?}",
                                White.bold().paint(node.to_string()),
                                ix,
                                Yellow.bold().paint(">> "),
                                o
                            );
                        }
                    }
                    if let Some(npz) = npz.as_mut() {
                        for (ix, t) in r.iter().enumerate() {
                            let name = if ix == 0 {
                                node.name.to_string()
                            } else {
                                format!("{}:{}", node.name, ix)
                            };
                            let name =
                                if multiturn { format!("turn_{turn}/{name}") } else { name };
                            npz_add_tensor(npz, name, t)?;
                        }
                    }
                    if assert_sane_floats {
                        for (ix, o) in r.iter().enumerate() {
                            if let Ok(floats) = o.as_slice::<f32>() {
                                if let Some(pos) = floats.iter().position(|f| !f.is_finite()) {
                                    eprintln!("{floats:?}");
                                    tract_core::anyhow::bail!(
                                        "Found {} in output {} of {}",
                                        floats[pos],
                                        ix,
                                        node
                                    );
                                }
                            } else if let Ok(floats) = o.as_slice::<f16>() {
                                if let Some(pos) = floats.iter().position(|f| !f.is_finite()) {
                                    eprintln!("{floats:?}");
                                    tract_core::anyhow::bail!(
                                        "Found {} in output {} of {}",
                                        floats[pos],
                                        ix,
                                        node
                                    );
                                }
                            }
                        }
                    }
                    Ok(r)
                })?;
            izip!(&mut results, turn_results).for_each(|(r, tr)| r.push(tr));
        }
        // Finish the step archive explicitly: errors raised while the writer finishes it
        // on drop are ignored, which could leave a truncated file without any report.
        if let Some(npz) = npz.take() {
            npz.finish().context("Finishing --save-steps archive")?;
        }
        Ok(results)
    })
}