tract-cli 0.23.0-dev.3

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
use tract_core::internal::*;
use tract_libcli::tensor::RunParams;
#[cfg(feature = "transformers")]
use tract_transformers::figure_out_causal_llm_b_s_p;

use crate::params::Parameters;

/// Builds the [`RunParams`] for a run-like subcommand.
///
/// Starts from the tensor values already collected in `params` and adds:
/// * tensors loaded from every `--input-from-npz` bundle and from the `--input-from-nnef`
///   directory;
/// * the `allow-random-input` / `allow-float-casts` flags (enabled by either the global or the
///   subcommand variant);
/// * symbol values from `--pp` / `--tg` (only with the `transformers` feature) and `--set`;
/// * the optional `--prompt-chunk-size`.
///
/// `--pp` / `--tg` are applied before `--set`, so an explicit `--set` on the same symbol wins.
///
/// # Errors
///
/// Fails if an input file cannot be parsed, if a symbol value or the prompt chunk size is not a
/// valid number, if a `--set` entry is not of the `S=12` form, if `--tg` is below 1, or if
/// `--pp` / `--tg` cannot locate the causal LLM symbols in the model.
pub fn run_params_from_subcommand(
    params: &Parameters,
    sub_matches: &clap::ArgMatches,
) -> TractResult<RunParams> {
    let mut tv = params.tensors_values.clone();

    if let Some(bundle) = sub_matches.get_many::<String>("input-from-npz") {
        for input in bundle {
            for tensor in Parameters::parse_npz(input.as_str(), true, false)? {
                tv.add(tensor);
            }
        }
    }

    if let Some(dir) = sub_matches.get_one::<String>("input-from-nnef") {
        for tensor in Parameters::parse_nnef_tensors(dir, true, false)? {
            tv.add(tensor);
        }
    }

    let mut symbols = SymbolValues::default();

    // --pp / --tg force random inputs to be allowed. Without the `transformers` feature those
    // options are not handled here at all.
    #[cfg(feature = "transformers")]
    let llm_mode_requested = apply_llm_pp_tg_modes(params, sub_matches, &mut symbols)?;
    #[cfg(not(feature = "transformers"))]
    let llm_mode_requested = false;

    // We also support the global arg variants for backward compatibility
    let allow_random_input = params.allow_random_input
        || sub_matches.get_flag("allow-random-input")
        || llm_mode_requested;
    let allow_float_casts =
        params.allow_float_casts || sub_matches.get_flag("allow-float-casts");

    // Applied after --pp / --tg so explicit assignments take precedence.
    if let Some(assignments) = sub_matches.get_many::<String>("set") {
        for set in assignments {
            let set = set.as_str();
            let (sym, value) = set.split_once('=').context("--set expect S=12 form")?;
            let sym = params.tract_model.get_or_intern_symbol(sym);
            let value: i64 = value
                .parse()
                .with_context(|| format!("Can not parse symbol value in set {set}"))?;
            symbols.set(&sym, value);
        }
    }

    // Report a malformed value instead of silently dropping the option.
    let prompt_chunk_size = sub_matches
        .get_one::<String>("prompt-chunk-size")
        .map(|raw| {
            raw.parse().with_context(|| format!("Can not parse --prompt-chunk-size {raw}"))
        })
        .transpose()?;

    Ok(RunParams {
        tensors_values: tv,
        allow_random_input,
        allow_float_casts,
        symbols,
        prompt_chunk_size,
    })
}

/// Applies `--pp` and `--tg` by pinning the causal LLM symbols found in the model.
///
/// `--pp N` sets the past symbol to 0 and the sequence symbol to N. `--tg N` sets the past
/// symbol to N - 1 and the sequence symbol to 1. If both options are given, `--tg` is applied
/// last and its values win. Returns whether either option was present.
#[cfg(feature = "transformers")]
fn apply_llm_pp_tg_modes(
    params: &Parameters,
    sub_matches: &clap::ArgMatches,
    symbols: &mut SymbolValues,
) -> TractResult<bool> {
    let mut requested = false;

    if let Some(pp) = sub_matches.get_one::<String>("pp") {
        let value: i64 =
            pp.parse().with_context(|| format!("Can not parse symbol value in --pp {pp}"))?;
        set_causal_llm_symbols(params, symbols, "PP", 0, value)?;
        requested = true;
    }

    if let Some(tg) = sub_matches.get_one::<String>("tg") {
        let value: i64 =
            tg.parse().with_context(|| format!("Can not parse symbol value in --tg {tg}"))?;
        // The past symbol receives `value - 1`; anything below 1 would make it negative.
        ensure!(value >= 1, "--tg expects a value of at least 1, got {value}");
        set_causal_llm_symbols(params, symbols, "TG", value - 1, 1)?;
        requested = true;
    }

    Ok(requested)
}

/// Sets the `b` (to 1, when the model has one), `p` and `s` symbols reported by
/// `figure_out_causal_llm_b_s_p` (presumably batch, past and sequence lengths).
/// `mode` is only used to name the caller in the error message.
#[cfg(feature = "transformers")]
fn set_causal_llm_symbols(
    params: &Parameters,
    symbols: &mut SymbolValues,
    mode: &str,
    past: i64,
    sequence: i64,
) -> TractResult<()> {
    let Some(typed_model) = params.tract_model.downcast_ref::<TypedModel>() else {
        bail!("{mode} mode can only be used with a TypedModel");
    };
    let (b, s, p) = figure_out_causal_llm_b_s_p(typed_model)?;
    if let Some(b) = b {
        symbols.set(&b, 1);
    }
    let (Some(s), Some(p)) = (s, p) else {
        bail!("Could not find LLM symbols in model");
    };
    symbols.set(&p, past);
    symbols.set(&s, sequence);
    Ok(())
}