use ndarray::{ArrayD, ArrayViewD};
use std::path::Path;
/// Errors reported by the inference API declared in this file.
///
/// `Debug` (derived) is meant for diagnostics and keeps the Rust variant
/// syntax; `Display` produces a plain sentence suitable for showing to a user
/// and prints `String` payloads verbatim, without `Debug` quoting or escaping.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MnnError {
    /// A parameter was rejected; the payload is a free-form detail message.
    InvalidParameter(String),
    /// Memory could not be obtained.
    OutOfMemory,
    /// A failure occurred while running; the payload is a free-form detail
    /// message.
    RuntimeError(String),
    /// The requested operation is not supported.
    Unsupported,
    /// A model could not be loaded; the payload is a free-form detail message.
    ModelLoadFailed(String),
    /// A null pointer was encountered where a valid one was required.
    NullPointer,
    /// A shape did not match the one that was required.
    ShapeMismatch {
        /// The shape that was required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },
}

impl std::fmt::Display for MnnError {
    /// Writes a human-readable description of the error.
    ///
    /// Payload strings are written as-is so that paths and messages are not
    /// mangled by `Debug` escaping (for example doubled backslashes).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidParameter(detail) => write!(f, "invalid parameter: {}", detail),
            Self::OutOfMemory => f.write_str("out of memory"),
            Self::RuntimeError(detail) => write!(f, "runtime error: {}", detail),
            Self::Unsupported => f.write_str("unsupported operation"),
            Self::ModelLoadFailed(detail) => write!(f, "failed to load model: {}", detail),
            Self::NullPointer => f.write_str("unexpected null pointer"),
            // Shapes are short integer lists, so the `Debug` list syntax
            // (`[1, 3, 224, 224]`) is already the readable form.
            Self::ShapeMismatch { expected, got } => {
                write!(f, "shape mismatch: expected {:?}, got {:?}", expected, got)
            }
        }
    }
}

impl std::error::Error for MnnError {}
/// Result alias used by the fallible functions in this file; the error type
/// is fixed to [`MnnError`].
pub type Result<T> = std::result::Result<T, MnnError>;
/// Backend selector stored in [`InferenceConfig::backend`].
///
/// `CPU` is the [`Default`] variant. How each variant is mapped onto the
/// underlying runtime is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Backend {
/// Default backend.
#[default]
CPU,
// NOTE(review): the remaining variant names match well-known GPU/accelerator
// APIs; confirm availability per platform against the runtime implementation.
Metal,
OpenCL,
OpenGL,
Vulkan,
CUDA,
CoreML,
}
/// Precision setting stored in [`InferenceConfig::precision_mode`].
///
/// `Normal` is the [`Default`] variant. The numeric effect of each mode is
/// not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PrecisionMode {
/// Default precision mode.
#[default]
Normal,
Low,
High,
// NOTE(review): presumably trades precision for a smaller memory footprint —
// confirm against the runtime implementation.
LowMemory,
}
/// Data-format setting stored in [`InferenceConfig::data_format`].
///
/// `NCHW` is the [`Default`] variant.
// NOTE(review): the names suggest tensor dimension ordering
// (batch/channel/height/width) — confirm against the runtime implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DataFormat {
/// Default data format.
#[default]
NCHW,
NHWC,
}
/// Settings accepted by [`SharedRuntime::new`] and by the
/// [`InferenceEngine`] constructors that take a configuration.
///
/// All fields are public; the `with_*` builder methods offer a chained way
/// to set most of them.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
/// Requested thread count. Stored as given; nothing in this file validates
/// its range. The `Default` impl sets it to 4.
pub thread_count: i32,
/// Requested precision mode.
pub precision_mode: PrecisionMode,
/// Requested backend.
pub backend: Backend,
/// Cache flag; the `Default` impl sets it to `true`. What is cached is not
/// visible in this file.
pub use_cache: bool,
/// Requested data format.
pub data_format: DataFormat,
}
impl Default for InferenceConfig {
fn default() -> Self {
Self {
thread_count: 4,
precision_mode: PrecisionMode::Normal,
backend: Backend::CPU,
use_cache: true,
data_format: DataFormat::NCHW,
}
}
}
impl InferenceConfig {
    /// Creates a configuration populated with the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with `thread_count` replaced by `threads`.
    ///
    /// The value is stored as given; no range validation happens here.
    /// The method consumes `self`, so the returned value must be used —
    /// calling it as a bare statement changes nothing.
    #[must_use]
    pub fn with_threads(mut self, threads: i32) -> Self {
        self.thread_count = threads;
        self
    }

    /// Returns the configuration with `precision_mode` replaced by
    /// `precision`.
    ///
    /// The method consumes `self`, so the returned value must be used.
    #[must_use]
    pub fn with_precision(mut self, precision: PrecisionMode) -> Self {
        self.precision_mode = precision;
        self
    }

    /// Returns the configuration with `backend` replaced by the given value.
    ///
    /// The method consumes `self`, so the returned value must be used.
    #[must_use]
    pub fn with_backend(mut self, backend: Backend) -> Self {
        self.backend = backend;
        self
    }

    /// Returns the configuration with `data_format` replaced by `format`.
    ///
    /// The method consumes `self`, so the returned value must be used.
    #[must_use]
    pub fn with_data_format(mut self, format: DataFormat) -> Self {
        self.data_format = format;
        self
    }
}
/// Runtime handle that can be passed to
/// [`InferenceEngine::from_buffer_with_runtime`].
///
/// In this file the type carries no data: its only field is a private unit
/// value, which prevents construction by struct literal outside this module
/// and leaves [`SharedRuntime::new`] as the only constructor.
pub struct SharedRuntime {
// Zero-sized private marker; exists only to keep the struct non-constructible
// from outside.
_private: (),
}
impl SharedRuntime {
    /// Stub constructor kept so the signature shows up in generated docs.
    ///
    /// # Panics
    ///
    /// Always panics through `unimplemented!`; the configuration argument is
    /// ignored and no value is ever returned.
    pub fn new(_config: &InferenceConfig) -> Result<Self> {
        const REASON: &str =
            "This feature is only available at runtime, not available during documentation build";
        unimplemented!("{}", REASON)
    }
}
/// Inference engine handle.
///
/// In this file the struct only stores two shapes, which are exposed through
/// [`InferenceEngine::input_shape`] and [`InferenceEngine::output_shape`].
/// Every constructor visible here is an `unimplemented!` stub.
pub struct InferenceEngine {
// Returned as a slice by `input_shape()`.
_input_shape: Vec<usize>,
// Returned as a slice by `output_shape()`.
_output_shape: Vec<usize>,
}
impl InferenceEngine {
    /// Panic message shared by every stubbed method, so that reaching any of
    /// them explains why nothing is implemented here.
    const DOCS_STUB_MSG: &'static str =
        "This feature is only available at runtime, not available during documentation build";

    /// Stub for constructing an engine from a model file.
    ///
    /// Both arguments are ignored.
    ///
    /// # Panics
    ///
    /// Always panics through `unimplemented!`.
    pub fn from_file(
        _model_path: impl AsRef<Path>,
        _config: Option<InferenceConfig>,
    ) -> Result<Self> {
        unimplemented!("{}", Self::DOCS_STUB_MSG)
    }

    /// Stub for constructing an engine from an in-memory model.
    ///
    /// Both arguments are ignored.
    ///
    /// # Panics
    ///
    /// Always panics through `unimplemented!`.
    pub fn from_buffer(_data: &[u8], _config: Option<InferenceConfig>) -> Result<Self> {
        unimplemented!("{}", Self::DOCS_STUB_MSG)
    }

    /// Stub for constructing an engine from an in-memory model using an
    /// existing [`SharedRuntime`].
    ///
    /// Both arguments are ignored.
    ///
    /// # Panics
    ///
    /// Always panics through `unimplemented!`.
    pub fn from_buffer_with_runtime(
        _model_buffer: &[u8],
        _runtime: &SharedRuntime,
    ) -> Result<Self> {
        unimplemented!("{}", Self::DOCS_STUB_MSG)
    }

    /// Returns the stored input shape as a borrowed slice.
    pub fn input_shape(&self) -> &[usize] {
        &self._input_shape
    }

    /// Returns the stored output shape as a borrowed slice.
    pub fn output_shape(&self) -> &[usize] {
        &self._output_shape
    }

    /// Stub for running inference on `_input`.
    ///
    /// # Panics
    ///
    /// Always panics through `unimplemented!`, with the same explanatory
    /// message as the constructors.
    pub fn infer(&self, _input: ArrayViewD<f32>) -> Result<ArrayD<f32>> {
        unimplemented!("{}", Self::DOCS_STUB_MSG)
    }

    /// Stub with the same signature as [`InferenceEngine::infer`].
    ///
    /// # Panics
    ///
    /// Always panics through `unimplemented!`.
    // NOTE(review): presumably the variant that accepts inputs whose shape
    // differs from `input_shape()` — confirm against the runtime implementation.
    pub fn infer_dynamic(&self, _input: ArrayViewD<f32>) -> Result<ArrayD<f32>> {
        unimplemented!("{}", Self::DOCS_STUB_MSG)
    }

    /// Stub with the same signature as [`InferenceEngine::infer_dynamic`].
    ///
    /// # Panics
    ///
    /// Always panics through `unimplemented!`.
    pub fn run_dynamic(&self, _input: ArrayViewD<f32>) -> Result<ArrayD<f32>> {
        unimplemented!("{}", Self::DOCS_STUB_MSG)
    }

    /// Stub for the slice-based entry point: takes input data with an explicit
    /// shape plus a caller-provided output buffer, and is declared to return a
    /// shape.
    ///
    /// All arguments are ignored and `_output_data` is never written.
    ///
    /// # Panics
    ///
    /// Always panics through `unimplemented!`.
    pub fn run_dynamic_raw(
        &self,
        _input_data: &[f32],
        _input_shape: &[usize],
        _output_data: &mut [f32],
    ) -> Result<Vec<usize>> {
        unimplemented!("{}", Self::DOCS_STUB_MSG)
    }
}
/// Returns the version string reported by this build.
///
/// This implementation always yields the fixed placeholder
/// `"unknown (docs.rs build)"`.
pub fn get_version() -> String {
    String::from("unknown (docs.rs build)")
}