tract-cli 0.23.0-dev.3

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
use std::iter::once;

use crate::params::Parameters;
use tract_hir::internal::*;
use tract_itertools::Itertools;
use tract_libcli::model::Model;

/// Compares the outputs of a node in tract and tensorflow.
pub fn check_outputs(got: &[Vec<TValue>], params: &Parameters) -> TractResult<()> {
    let mut error = None;
    // iter over all possible tract model outputs
    for (ix, output) in params.tract_model.output_outlets().iter().enumerate() {
        // get either name from outlet_label or from node_name
        let mut lookup_names: Vec<&str> = params
            .tract_model
            .outlet_label(*output)
            .into_iter()
            .chain(once(params.tract_model.node_name(output.node)))
            .collect_vec();
        lookup_names.dedup();
        let exp = lookup_names.iter().find_map(|name| params.tensors_values.by_name(name));
        if exp.is_none() {
            if params.assertions.allow_missing_outputs {
                debug!("Missing reference output in bundle for {}", lookup_names.join(" or "));
                continue;
            } else {
                bail!("Missing reference output in bundle for {}", lookup_names.join(" or "));
            }
        }
        let exp = exp.unwrap();
        debug!("Output {ix}, expects {exp:?}");
        let mut exp: TValue = exp.values.as_ref().with_context(|| {
            format!("Output {lookup_names:?}: found reference info without value: {exp:?}")
        })?[0]
            .clone();

        let props = params.tract_model.properties();

        let got: TValue = if got[ix].len() > 1 && props.get("pulse.output_axes").is_some() {
            let axis = props.get("pulse.output_axes").unwrap().try_as_dense()?.as_slice::<i64>()?
                [ix] as usize;
            let delay =
                props.get("pulse.delay").unwrap().try_as_dense()?.as_slice::<i64>()?[ix] as usize;
            let stacked = Tensor::stack_tensors(axis, &got[ix])?;
            stacked.slice(axis, delay, delay + exp.shape()[axis])?.into()
        } else {
            // This handles LLM prompt-chunking output
            got[ix].last().unwrap().clone()
        };
        if (params.allow_float_casts
            && exp.datum_type() == f32::datum_type()
            && got.datum_type() == f16::datum_type())
            || exp.datum_type().unquantized() == got.datum_type().unquantized()
        {
            exp = exp.cast_to_dt(got.datum_type())?.into_owned().into_tvalue();
        }
        #[allow(unused)]
        let result: TractResult<()> = if let Some(min_rbo) = params.assertions.assert_llm_rbo {
            #[cfg(not(feature = "transformers"))]
            {
                bail!("transformers feature is required for RBO metric")
            }
            #[cfg(feature = "transformers")]
            {
                let rbo = crate::llm::top_logits_rbo(
                    &exp,
                    &got,
                    params.assertions.assert_llm_rbo_p,
                    params.assertions.assert_llm_rbo_depth,
                )?;
                let rbo = (rbo * 100.0).floor() / 100.0;
                info!("LLM RBO: {rbo:.2}");
                if rbo >= min_rbo {
                    Ok(())
                } else {
                    TractResult::Err(anyhow!(
                        "RBO criteria not met: rbo={rbo:.2}, min required {min_rbo}"
                    ))
                }
            }
        } else {
            exp.close_enough(&got, params.assertions.approximation)
        };
        if let Err(e) = result.context(format!("Checking output {ix}")) {
            if error.is_some() {
                error!("{e:?}");
            } else {
                error = Some(e);
            }
        } else {
            info!("Checked output #{ix}, ok.");
        }
    }

    if let Some(e) = error { Err(e) } else { Ok(()) }
}

/// Compares the outputs of a node in tract and tensorflow.
pub fn check_inferred(got: &[InferenceFact], expected: &[InferenceFact]) -> TractResult<()> {
    if got.len() != expected.len() {
        bail!("Number of output differ: got:{}, expected:{}", got.len(), expected.len())
    }

    for (got, exp) in got.iter().zip(expected.iter()) {
        if exp.datum_type != got.datum_type {
            bail!("Failed to infer datum type: expected {:?}, got {:?}", exp, got);
        }
        if exp.shape != got.shape {
            bail!("Failed to infer shape: expected {:?}, got {:?}", exp, got);
        }
    }

    Ok(())
}

pub fn clarify_tvalue(t: &TValue) -> TractResult<TValue> {
    if t.datum_type().is_opaque() && t.volume() == 1 {
        if let Some(clarified) = t.try_as_dense()?.to_scalar::<Opaque>()?.clarify_to_tensor()? {
            return Ok(clarified.into_tvalue());
        }
    }
    Ok((*t).clone())
}

pub fn clarify_tvalues(values: &TVec<TValue>) -> TractResult<TVec<TValue>> {
    values.iter().map(clarify_tvalue).collect()
}

pub fn clarify_typed_fact<'a>(fact: impl Into<Cow<'a, TypedFact>>) -> Cow<'a, TypedFact> {
    let fact = fact.into();
    if fact.datum_type == DatumType::Opaque {
        fact.opaque_fact
            .as_ref()
            .and_then(|it| it.clarify_dt_shape())
            .map(|(dt, s)| Cow::Owned(TypedFact::dt_shape(dt, s)))
            .unwrap_or_else(|| fact)
    } else {
        fact
    }
}

pub fn count_op(model: &dyn Model, name: &str) -> TractResult<usize> {
    Ok(model
        .eval_order()
        .context("Cannot assert op count without an eval order")?
        .into_iter()
        .map(|i| {
            if model.node_op_name(i) == name {
                1
            } else {
                model.nested_models(i).into_iter().flat_map(|(_, m)| count_op(m, name)).sum()
            }
        })
        .sum())
}