use anyhow::anyhow;
use rten_tensor::{Tensor, TensorView};
pub trait Model {
fn input_shape(&self) -> anyhow::Result<Vec<rten::Dimension>>;
fn run(
&self,
input: TensorView<f32>,
opts: Option<rten::RunOptions>,
) -> anyhow::Result<Tensor<f32>>;
}
impl Model for rten::Model {
fn input_shape(&self) -> anyhow::Result<Vec<rten::Dimension>> {
let input_id = self
.input_ids()
.first()
.copied()
.ok_or(anyhow!("model has no inputs"))?;
let input_shape = self
.node_info(input_id)
.and_then(|info| info.shape())
.ok_or(anyhow!("model does not specify expected input shape"))?;
Ok(input_shape)
}
fn run(
&self,
input: TensorView<f32>,
opts: Option<rten::RunOptions>,
) -> anyhow::Result<Tensor<f32>> {
let output = self.run_one(input.into(), opts)?.try_into()?;
Ok(output)
}
}