use fs::File;
use std::fmt::{Debug, Display};
use std::path::PathBuf;
use std::str::FromStr;
use crate::TractResult;
use crate::{Model, Parameters};
use fs_err as fs;
use ndarray_npy::NpzWriter;
use nu_ansi_term::Color::*;
use tract_core::ops::cnn::conv::Im2Col;
use tract_core::ops::matmul::pack::OptMatMulPack;
use tract_core::tract_data::itertools::izip;
use tract_hir::internal::*;
use tract_libcli::tensor::{RunTensors, get_or_make_inputs};
use tract_nnef::tensors::write_tensor;
#[cfg(feature = "pulse")]
use tract_pulse::internal::*;
/// Appends `tensor` to the `npz` archive as an array called `name`.
///
/// The element type handed to the writer is picked from the tensor's datum type:
/// `F16` tensors are cast to `f32` first, quantized types (`QI8`, `QU8`, `QI32`) are
/// written through their `i8` / `u8` / `i32` views, and the remaining supported numeric
/// and boolean types are written as-is. Any other datum type is not written: a warning
/// is logged and the call still succeeds.
///
/// # Errors
///
/// Propagates failures from the cast, from building the array view, and from the
/// archive writer.
fn npz_add_tensor(npz: &mut NpzWriter<File>, name: String, tensor: &Tensor) -> TractResult<()> {
    // Writes the tensor under `name`, viewing its data as elements of the given type.
    macro_rules! store_as {
        ($elem:ty) => {
            npz.add_array(name, &tensor.to_plain_array_view::<$elem>()?)?
        };
    }

    match tensor.datum_type() {
        DatumType::F16 => {
            let widened = tensor.cast_to::<f32>()?;
            npz.add_array(name, &widened.to_plain_array_view::<f32>()?)?
        }
        DatumType::Bool => store_as!(bool),
        DatumType::U8 | DatumType::QU8(_) => store_as!(u8),
        DatumType::U16 => store_as!(u16),
        DatumType::U32 => store_as!(u32),
        DatumType::U64 => store_as!(u64),
        DatumType::I8 | DatumType::QI8(_) => store_as!(i8),
        DatumType::I16 => store_as!(i16),
        DatumType::I32 | DatumType::QI32(_) => store_as!(i32),
        DatumType::I64 => store_as!(i64),
        DatumType::F32 => store_as!(f32),
        DatumType::F64 => store_as!(f64),
        _ => warn!("Not writing {name}, {tensor:?}, unsupported type"),
    }
    Ok(())
}
/// Runs the model and then performs every post-run action requested on the command line.
///
/// After the run (see `run_regular`), the following steps are applied in order, each
/// one only if the corresponding argument or assertion is present:
///
/// 1. `dump`: print every output of every turn to stdout.
/// 2. `save-outputs-nnef`: write each output as a `.dat` tensor file in the given
///    directory. Files are named after the output outlet label, or `output_{ix}` when
///    the outlet has no label. When an output has more than one turn, each turn goes
///    to a `turn_{turn}/` sub-directory.
/// 3. `save-outputs-npz`: write all outputs to one compressed npz archive, using the
///    same naming scheme (without the `.dat` extension).
/// 4. `assert-output-count`: fail if the number of model outputs differs from the
///    expected count.
/// 5. The assertions carried by `params.assertions`: output values, output facts,
///    operator counts and allowed operator patterns.
///
/// # Errors
///
/// Returns an error if the run fails, if a file or directory can not be created or
/// written, if `assert-output-count` is not a valid `usize`, or if any assertion fails.
pub fn handle(
    params: &Parameters,
    matches: &clap::ArgMatches,
    sub_matches: &clap::ArgMatches,
) -> TractResult<()> {
    let dump = sub_matches.get_flag("dump");
    let outputs = run_regular(&params.tract_model, params, matches, sub_matches)?;

    if dump {
        for (ix, output) in outputs.iter().enumerate() {
            for (turn, output) in output.iter().enumerate() {
                println!("output #{}, turn #{}\n{}\n", ix, turn, output.dump(true)?);
            }
        }
    }

    if let Some(file_path) = sub_matches.get_one::<String>("save-outputs-nnef") {
        fs::create_dir_all(file_path).with_context(|| format!("Creating {file_path} directory"))?;
        let dir = PathBuf::from_str(file_path)?;
        for (ix, turns) in outputs.iter().enumerate() {
            let name = params
                .tract_model
                .outlet_label(params.tract_model.output_outlets()[ix])
                .map(|name| format!("{name}.dat"))
                .unwrap_or_else(|| format!("output_{ix}.dat"));
            let single_turn = turns.len() == 1;
            for (turn, output) in turns.iter().enumerate() {
                let path = if single_turn {
                    dir.join(&name)
                } else {
                    dir.join(format!("turn_{turn}/{name}"))
                };
                // The per-turn sub-directory (and any directory implied by the output
                // label) does not exist yet: create it before creating the file.
                if let Some(parent) = path.parent() {
                    fs::create_dir_all(parent)?;
                }
                // The file must be created for writing; opening it read-only would make
                // the tensor write fail.
                let mut f = fs::File::create(path)?;
                write_tensor(&mut f, output)?;
            }
        }
    }

    if let Some(file_path) = sub_matches.get_one::<String>("save-outputs-npz") {
        let file = fs::File::create(file_path).with_context(|| format!("Creating {file_path}"))?;
        let mut npz = ndarray_npy::NpzWriter::new_compressed(file);
        for (ix, turns) in outputs.iter().enumerate() {
            let name = params
                .tract_model
                .outlet_label(params.tract_model.output_outlets()[ix])
                .map(|name| name.to_string())
                .unwrap_or_else(|| format!("output_{ix}"));
            if turns.len() == 1 {
                npz_add_tensor(&mut npz, name, &turns[0])?;
            } else {
                for (turn, output) in turns.iter().enumerate() {
                    let name = format!("turn_{turn}/{name}");
                    npz_add_tensor(&mut npz, name, output)?;
                }
            }
        }
    }

    if let Some(count) = sub_matches.get_one::<String>("assert-output-count") {
        let count = count.parse::<usize>()?;
        if count != outputs.len() {
            bail!(
                "Wrong number of outputs, command line expected {}, found {:?}",
                count,
                outputs.len()
            );
        }
    }

    if params.assertions.assert_outputs {
        crate::utils::check_outputs(&outputs, params)?;
    }

    if let Some(facts) = &params.assertions.assert_output_facts {
        // Facts are checked against the first turn of each output. An output without
        // any turn is reported as an error instead of panicking on an out-of-bounds index.
        let outputs = outputs
            .iter()
            .map(|turns| -> TractResult<InferenceFact> {
                let first =
                    turns.first().context("No value produced for output, can not check facts")?;
                Ok(first.datum_type().fact(first.shape()).into())
            })
            .collect::<TractResult<Vec<InferenceFact>>>()?;
        crate::utils::check_inferred(&outputs, facts)?;
    }

    if let Some(asserts) = &params.assertions.assert_op_count {
        for (name, expected) in asserts {
            let count = crate::utils::count_op(&*params.tract_model, name)?;
            if count != *expected {
                bail!("Wrong number of {} operators: expected {}, got {}", name, expected, count);
            }
        }
    }

    if let Some(patterns) = &params.assertions.assert_op_only {
        crate::utils::check_op_only(&*params.tract_model, patterns)?;
    }

    Ok(())
}
/// Runs every turn of `inputs` through `state` and collects the model outputs.
///
/// The returned vector has one entry per model output; each entry holds that output's
/// value for every turn, in turn order.
///
/// While running, each node evaluation can be instrumented:
/// - `steps`: print the inputs and outputs of every node to stderr.
/// - `npz`: when present, store every node output in the archive, named after the node
///   (`{node name}:{ix}` for outputs past the first), prefixed by `turn_{turn}/` when
///   there is more than one turn.
/// - `check_f16_overflow`: log a warning when an `f32` node output holds a value whose
///   magnitude exceeds `f16::MAX`.
/// - `assert_sane_floats`: fail when an `f32` or `f16` node output holds a non-finite
///   value.
///
/// # Errors
///
/// Returns an error if a node evaluation fails, if the pulse properties of the model
/// can not be read, if writing to the npz archive fails, or if `assert_sane_floats`
/// finds a non-finite value.
fn run_regular_t<F, O>(
state: &mut SimpleState<F, O>,
inputs: RunTensors,
steps: bool,
check_f16_overflow: bool,
assert_sane_floats: bool,
mut npz: Option<NpzWriter<File>>,
) -> TractResult<TVec<Vec<TValue>>>
where
F: Fact + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
// One (initially empty) list of per-turn values for each model output.
let mut results = tvec!(vec!(); state.model().outputs.len());
let multiturn = inputs.sources.len() > 1;
// For multi-turn runs, pair each output labelled `<name>_concat` with the input whose
// node is named `<name>`: pairs are (output index, input index). These outputs are fed
// back as inputs of the following turn (see the end of the turn loop).
let cache_output_to_input: Vec<(usize, usize)> = if multiturn {
let model = state.model();
let mut mapping = Vec::new();
for (out_ix, out_outlet) in model.outputs.iter().enumerate() {
let out_label = model.outlet_label(*out_outlet).unwrap_or("");
if let Some(base) = out_label.strip_suffix("_concat") {
for (in_ix, in_outlet) in model.inputs.iter().enumerate() {
let in_name = &model.nodes[in_outlet.node].name;
if in_name == base {
mapping.push((out_ix, in_ix));
// Only the first matching input is used.
break;
}
}
}
}
mapping
} else {
Vec::new()
};
// Optional binding for the symbol named by the `pulse.streaming_symbol` model property.
// It is only computed when that property, a streaming input length and a first input
// tensor are all available.
let pulse_sym_binding: Option<(Symbol, i64)> =
if let (Some(sym_name), Some(stream_len), Some(first_input)) = (
state
.model()
.properties
.get("pulse.streaming_symbol")
.and_then(|t| t.to_plain_array_view::<String>().ok())
.and_then(|a| a.first().cloned()),
inputs.streaming_input_len,
inputs.sources.first().and_then(|t| t.first()),
) {
let input_axes = state
.model()
.properties
.get("pulse.input_axes")
.context("Expect pulse.input_axes when pulse.streaming_symbol is set")?
.cast_to::<i64>()?;
// Only the first entry of `pulse.input_axes` is considered, and it is applied to the
// first input of the first turn.
let input_axis = input_axes.try_as_plain()?.as_slice::<i64>()?[0] as usize;
let pulse_value = first_input.shape()[input_axis];
// Bind the symbol to `stream_len / pulse_value`, but only when that division is exact.
if pulse_value > 0 && stream_len % pulse_value == 0 {
let sym = state.model().symbols.sym(&sym_name);
Some((sym, (stream_len / pulse_value) as i64))
} else {
None
}
} else {
None
};
let mut sources = inputs.sources;
for turn in 0..sources.len() {
// Move this turn's inputs out of `sources`, leaving an empty vector in their place.
let inputs = std::mem::replace(&mut sources[turn], TVec::new());
// The symbol binding is re-applied at every turn, before running the plan.
if let Some((sym, val)) = &pulse_sym_binding {
state.turn_state.resolved_symbols.set(sym, *val);
}
let turn_results =
state.run_plan_with_eval(inputs, |session_state, state, node, input| {
if steps {
for (ix, i) in input.iter().enumerate() {
eprintln!(
"{} {}{}{:?}",
White.bold().paint(node.to_string()),
ix,
Blue.bold().paint("<< "),
i
);
}
}
let r = tract_core::plan::eval(session_state, state, node, input)?;
// The node outputs only go through `clarify_tvalues` when at least one
// instrumentation option needs to look at them.
if steps || npz.is_some() || check_f16_overflow || assert_sane_floats {
let clarified_r = crate::utils::clarify_tvalues(&r)?;
if steps {
for (ix, o) in clarified_r.iter().enumerate() {
eprintln!(
"{} {}{}{:?}",
White.bold().paint(node.to_string()),
ix,
Yellow.bold().paint(">> "),
o
);
}
}
if let Some(npz) = npz.as_mut() {
for (ix, t) in clarified_r.iter().enumerate() {
// First output is stored under the bare node name, the others as `name:ix`.
let mut name = if ix == 0 {
node.name.to_string()
} else {
format!("{}:{}", node.name, ix)
};
if multiturn {
name = format!("turn_{turn}/{name}");
}
npz_add_tensor(npz, name, t)?;
}
}
if check_f16_overflow {
for (ix, o) in clarified_r.iter().enumerate() {
// Only outputs that can be read as a plain `f32` slice are checked; anything else
// is silently skipped.
if let Ok(plain) = o.try_as_plain() {
if let Ok(f32s) = plain.as_slice::<f32>() {
if f32s.iter().any(|f| f.abs() > f16::MAX.to_f32()) {
warn!("{node}, output {ix} overflows f16");
}
}
}
}
}
if assert_sane_floats {
for (ix, o) in clarified_r.iter().enumerate() {
// Outputs of `Im2Col` and `OptMatMulPack` nodes are exempt from the check.
// NOTE(review): presumably because these are packed buffers that may hold
// arbitrary padding values; confirm.
if node.op_is::<Im2Col>() || node.op_is::<OptMatMulPack>() {
continue;
}
if let Ok(plain) = o.try_as_plain() {
if let Ok(floats) = plain.as_slice::<f32>() {
if let Some(pos) = floats.iter().position(|f| !f.is_finite()) {
// Print the whole tensor to stderr before failing on the first bad value.
eprintln!("{floats:?}");
bail!("Found {} in output {} of {}", floats[pos], ix, node);
}
} else if let Ok(floats) = plain.as_slice::<f16>() {
if let Some(pos) = floats.iter().position(|f| !f.is_finite()) {
eprintln!("{floats:?}");
bail!("Found {} in output {} of {}", floats[pos], ix, node);
}
}
}
}
}
}
// The values returned to the plan are the raw evaluation results, not the clarified ones.
Ok(r)
})?;
// Feed the paired outputs of this turn into the matching inputs of the next turn,
// overwriting whatever value was provided for them.
if turn + 1 < sources.len() && !cache_output_to_input.is_empty() {
for &(out_ix, in_ix) in &cache_output_to_input {
sources[turn + 1][in_ix] = turn_results[out_ix].clone();
}
}
// Append this turn's value to the list of each model output.
izip!(&mut results, turn_results).for_each(|(r, tr)| r.push(tr));
}
Ok(results)
}
fn run_regular(
tract: &Arc<dyn Model>,
params: &Parameters,
_matches: &clap::ArgMatches,
sub_matches: &clap::ArgMatches,
) -> TractResult<TVec<Vec<TValue>>> {
let run_params = crate::tensor::run_params_from_subcommand(params, sub_matches)?;
let steps = sub_matches.get_flag("steps");
let check_f16_overflow = sub_matches.get_flag("check-f16-overflow");
let assert_sane_floats = sub_matches.get_flag("assert-sane-floats");
let npz = if let Some(npz) = sub_matches.get_one::<String>("save-steps") {
let npz = fs::File::create(npz).with_context(|| format!("Creating {npz}"))?;
Some(ndarray_npy::NpzWriter::new_compressed(npz))
} else {
None
};
let inputs = get_or_make_inputs(tract, &run_params)?;
if let Some(runnable) = ¶ms.runnable {
if let Some(plan) = runnable.typed_plan() {
let mut state = plan.spawn()?;
let results = run_regular_t(
&mut state,
inputs,
steps,
check_f16_overflow,
assert_sane_floats,
npz,
)?;
Ok(results)
} else {
todo!("Run handler for abstract runtime/runnable");
}
} else {
dispatch_model!(tract, |m| {
let plan_options = crate::plan_options::plan_options_from_subcommand(sub_matches)?;
let plan = SimplePlan::new_with_options(m, &plan_options)?;
let mut state = plan.spawn()?;
let results = run_regular_t(
&mut state,
inputs,
steps,
check_f16_overflow,
assert_sane_floats,
npz,
)?;
Ok(results)
})
}
}