use std::{collections::HashMap, num::ParseFloatError, path::PathBuf, str::FromStr};
use async_trait::async_trait;
use structopt::StructOpt;
use thiserror::Error;
use wonnx::{
    onnx::ModelProto,
    utils::{OutputTensor, Shape, TensorConversionError},
    SessionError, WonnxError,
};
use wonnx_preprocessing::{
    text::{EncodedText, PreprocessingError},
    Tensor,
};
#[cfg(feature = "cpu")]
use tract_onnx::prelude::*;
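/// Options for the subcommands that inspect a model (`info` and `graph`).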
#[derive(Debug, StructOpt)]
pub struct InfoOptions {
    /// Path to the model file (.onnx)
    #[structopt(parse(from_os_str))]
    pub model: PathBuf,
}
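/// Backend used to execute the model: the wonnx GPU backend, or (when compiled
/// with the `cpu` feature) the tract CPU backend.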
#[derive(Debug, StructOpt)]
pub enum Backend {
    Gpu,
    #[cfg(feature = "cpu")]
    Cpu,
}
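/// All errors the nnx CLI can produce.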
#[derive(Error, Debug)]
pub enum NNXError {
    #[error("invalid backend selected: {0}")]
    InvalidBackend(String),
    #[error("input shape is invalid")]
    InvalidInputShape,
    #[error("output not found: {0}")]
    OutputNotFound(String),
    #[error("input not found: {0}")]
    InputNotFound(String),
    #[error("backend error: {0}")]
    BackendFailed(#[from] WonnxError),
    #[error("backend execution error: {0}")]
    BackendExecutionFailed(#[from] SessionError),
    #[cfg(feature = "cpu")]
    #[error("cpu backend error: {0}")]
    CPUBackendFailed(#[from] TractError),
    #[cfg(feature = "cpu")]
    #[error("comparison failed: {0}")]
    Comparison(String),
    #[error("preprocessing failed: {0}")]
    PreprocessingFailed(#[from] PreprocessingError),
    #[error("invalid number: {0}")]
    InvalidNumber(#[from] ParseFloatError),
    #[error("tensor error: {0}")]
    TensorConversionError(#[from] TensorConversionError),
    #[error("I/O error: {0}")]
    IOError(#[from] std::io::Error),
}
impl FromStr for Backend {
    type Err = NNXError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "gpu" => Ok(Backend::Gpu),
            #[cfg(feature = "cpu")]
            "cpu" => Ok(Backend::Cpu),
            _ => Err(NNXError::InvalidBackend(s.to_string())),
        }
    }
}
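/// Splits a `key=value` argument at the first `=` and parses both halves; used
/// by the `-i`, `-t`, `-m` and `-r` flags (e.g. `-t input=hello` becomes
/// `("input", "hello")`).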
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn std::error::Error>>
where
    T: std::str::FromStr,
    T::Err: std::error::Error + 'static,
    U: std::str::FromStr,
    U::Err: std::error::Error + 'static,
{
    let pos = s
        .find('=')
        .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{}`", s))?;
    Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
}
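/// Options for the `infer` subcommand.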
#[derive(Debug, StructOpt)]
pub struct InferOptions {
    /// Path to the model file (.onnx)
    #[structopt(parse(from_os_str))]
    pub model: PathBuf,
    /// Backend to use for inference
    #[structopt(long, default_value = "gpu")]
    pub backend: Backend,
    /// Image file to use as the value for a model input, specified as `input_name=path`
    #[structopt(short = "i", parse(try_from_str = parse_key_val), number_of_values = 1)]
    pub input_images: Vec<(String, PathBuf)>,
    /// Number of top results to show
    #[structopt(long)]
    pub top: Option<usize>,
    /// Print output probabilities
    #[structopt(long)]
    pub probabilities: bool,
    /// Name(s) of the output(s) to fetch
    #[structopt(long)]
    pub output_name: Vec<String>,
    /// Path to the tokenizer configuration (tokenizer.json) used for text inputs
    #[structopt(
        long = "tokenizer",
        parse(from_os_str),
        default_value = "./tokenizer.json"
    )]
    pub tokenizer: PathBuf,
    /// Question to answer (for question-answering models)
    #[structopt(short = "q", long = "question")]
    pub question: Option<String>,
    /// Context to answer questions about (for question-answering models)
    #[structopt(short = "c", long = "context")]
    pub context: Option<String>,
    /// Name of the model input that receives the tokenized question and context
    #[structopt(long = "qa-tokens-input", default_value = "input_ids:0")]
    pub qa_tokens_input: String,
    /// Name of the model input that receives the attention mask
    #[structopt(long = "qa-mask-input", default_value = "input_mask:0")]
    pub qa_mask_input: String,
    /// Name of the model input that receives the segment (token type) IDs
    #[structopt(long = "qa-segment-input", default_value = "segment_ids:0")]
    pub qa_segment_input: String,
    /// Name of the model output containing the answer start scores
    #[structopt(long = "qa-answer-start-output", default_value = "unstack:0")]
    pub qa_answer_start: String,
    /// Name of the model output containing the answer end scores
    #[structopt(long = "qa-answer-end-output", default_value = "unstack:1")]
    pub qa_answer_end: String,
    /// Print the answer extracted from the context
    #[structopt(long = "qa-answer")]
    pub qa_answer: bool,
    /// Text to tokenize and use as the value for a model input, specified as `input_name=text`
    #[structopt(short = "t", parse(try_from_str = parse_key_val), number_of_values = 1)]
    pub text: Vec<(String, String)>,
    /// Text whose attention mask is used as the value for a model input, specified as `input_name=text`
    #[structopt(short = "m", parse(try_from_str = parse_key_val), number_of_values = 1)]
    pub text_mask: Vec<(String, String)>,
    /// Raw value to use for a model input, specified as `input_name=value`
    #[structopt(short = "r", parse(try_from_str = parse_key_val), number_of_values = 1)]
    pub raw: Vec<(String, String)>,
    /// Path to a file containing class labels
    #[structopt(short, long, parse(from_os_str))]
    pub labels: Option<PathBuf>,
    /// Fall back to the CPU backend if the GPU backend fails
    #[cfg(feature = "cpu")]
    #[structopt(long)]
    pub fallback: bool,
    /// Run inference on both backends and compare the outputs
    #[cfg(feature = "cpu")]
    #[structopt(long, conflicts_with = "backend")]
    pub compare: bool,
    /// Measure and report inference timing
    #[structopt(long)]
    pub benchmark: bool,
}
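/// The available subcommands.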
#[derive(Debug, StructOpt)]
#[allow(clippy::large_enum_variant)]
pub enum Command {
    /// List available compute devices
    Devices,
    /// Run inference with a model
    Infer(InferOptions),
    /// Show information about a model
    Info(InfoOptions),
    /// Show the graph of a model
    Graph(InfoOptions),
}
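/// Top-level command line options.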
#[derive(Debug, StructOpt)]
#[structopt(
    name = "nnx: Neural Network Execute",
    about = "GPU-accelerated ONNX inference through wonnx from the command line"
)]
pub struct Opt {
    #[structopt(subcommand)]
    pub cmd: Command,
}
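/// Something that can run a model: given a set of output names and named input
/// tensors, produces the requested output tensors.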
#[async_trait]
pub trait Inferer {
    async fn infer(
        &self,
        outputs: &[String],
        inputs: &HashMap<String, Tensor>,
        model: &ModelProto,
    ) -> Result<HashMap<String, OutputTensor>, NNXError>;
}
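/// Preprocessed inputs for a single inference run.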
pub struct InferenceInput {
    /// Input tensors, keyed by input name
    pub inputs: HashMap<String, Tensor>,
    /// Shape of each input tensor, keyed by input name
    pub input_shapes: HashMap<String, Shape>,
    /// Tokenizer output for question answering, when a question and context were supplied
    pub qa_encoding: Option<EncodedText>,
}
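// A minimal sketch of how `parse_key_val` and `Backend::from_str` behave; this
// test module is illustrative and not part of the original file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_key_value_pairs() {
        // A pair splits at the first `=`; everything after it stays verbatim.
        let (k, v): (String, String) = parse_key_val("input_ids=1,2,3").unwrap();
        assert_eq!(k, "input_ids");
        assert_eq!(v, "1,2,3");
        // Arguments without `=` are rejected.
        assert!(parse_key_val::<String, String>("no-equals").is_err());
    }

    #[test]
    fn parses_backend_names() {
        assert!(matches!(Backend::from_str("gpu"), Ok(Backend::Gpu)));
        assert!(matches!(
            Backend::from_str("tpu"),
            Err(NNXError::InvalidBackend(_))
        ));
    }
}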