stonnx_api/
lib.rs

1#![allow(dead_code)]
2pub mod common;
3pub mod executor;
4pub mod onnxparser;
5pub mod operators;
6pub mod parallel;
7pub mod protograph;
8pub mod utils;
9
10use crate::common::MAX_OPSET_VERSION;
11use crate::executor::execute_model;
12use crate::onnxparser::onnx;
13use crate::utils::read_model;
14use common::{Args, BoxResult, VerbosityLevel};
15use once_cell::sync::Lazy;
16use utils::OutputInfo;
17use std::collections::HashMap;
18use std::ffi::CString;
19use std::path::{Path, PathBuf};
20
21static mut LAST_ERROR: Lazy<CString> = Lazy::new(|| CString::new(*b"").unwrap());
22
/// Verbosity level accepted by `run_model`.
///
/// The enum is `#[repr(i64)]`, so C callers pass it as a 64-bit integer and must use
/// one of the discriminants declared below. There is no level `3`. Any other value is
/// undefined behaviour on the Rust side. `run_model` forwards the discriminant to the
/// executor as an `i64`.
#[repr(i64)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Verbosity {
    Silent = -1,
    Minimal = 0,
    Informational = 1,
    Results = 2,
    Intermediate = 4,
}
31
/// Graph output format accepted by `run_model`.
///
/// The enum is `#[repr(i64)]`, so C callers pass it as a 64-bit integer and must use
/// one of the discriminants declared below. Any other value is undefined behaviour on
/// the Rust side.
#[repr(i64)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphFormat {
    /// No graph output
    None = 0,
    /// Graph output in JSON format, saved to graph.json
    Json = 1,
    /// Graph output in DOT format, saved to graph.dot
    Dot = 2,
}
43
/// Execution mode accepted by `run_model`.
///
/// The enum is `#[repr(i64)]`, so C callers pass it as a 64-bit integer. Note that
/// `FailFast` is `1` and `Continue` is `0`. Any other value is undefined behaviour on
/// the Rust side.
#[repr(i64)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionMode {
    /// Fails immediately if it encounters an operator that is not implemented
    FailFast = 1,
    /// Does not reject unimplemented operators up front: execution starts anyway and
    /// panics only when it actually reaches such an operator
    Continue = 0,
}
53
54#[no_mangle]
55/// Reads an ONNX model from a file
56///
57/// Returns a pointer to a model, null if error, check last_error
58///
59/// # Safety
60///
61/// Should take a valid path as a C string
62pub unsafe extern "C" fn read_onnx_model(
63    model_path: *const std::os::raw::c_char,
64) -> *mut onnx::ModelProto {
65    let model_path = unsafe { std::ffi::CStr::from_ptr(model_path) };
66    let model_path = match model_path.to_str() {
67        Ok(s) => s,
68        Err(e) => {
69            let e = match CString::new(e.to_string()) {
70                Ok(s) => s,
71                Err(_) => {
72                    return std::ptr::null_mut();
73                }
74            };
75            *LAST_ERROR = e;
76            return std::ptr::null_mut();
77        }
78    };
79    let model = match read_model(Path::new(model_path)) {
80        Ok(m) => m,
81        Err(e) => {
82            let e = match CString::new(e.to_string()) {
83                Ok(s) => s,
84                Err(_) => {
85                    return std::ptr::null_mut();
86                }
87            };
88            *LAST_ERROR = e;
89            return std::ptr::null_mut();
90        }
91    };
92    Box::into_raw(Box::new(model))
93}
94
95#[no_mangle]
96/// Frees a model
97///
98/// Returns nothing, does nothing if given a null pointer
99///
100/// # Safety
101///
102/// Should take a valid pointer to a model
103pub unsafe extern "C" fn free_onnx_model(model: *mut onnx::ModelProto) {
104    if model.is_null() {
105        return;
106    }
107    unsafe {
108        drop(Box::from_raw(model));
109    }
110}
111
112#[no_mangle]
113/// Returns the opset version of a model
114///
115/// Returns MAX_OPSET_VERSION if no opset version is found
116///
117/// Returns MAX_OPSET_VERSION if given a null pointer and sets last_error
118///
119/// # Safety
120///
121/// Should take a valid pointer to a model
122pub unsafe extern "C" fn get_opset_version(model: *const onnx::ModelProto) -> i64 {
123    if model.is_null() {
124        *LAST_ERROR = CString::new("NULL pointer passed to get_opset_version").unwrap();
125        return MAX_OPSET_VERSION;
126    }
127    unsafe {
128        if let Some(v) = (*model).opset_import.first() {
129            if let Some(v) = v.version {
130                v
131            } else {
132                MAX_OPSET_VERSION
133            }
134        } else {
135            MAX_OPSET_VERSION
136        }
137    }
138}
139
140pub struct Model {
141    path: PathBuf,
142    verbose: VerbosityLevel,
143    graph: bool,
144    graph_format: String,
145    failfast: bool,
146} 
147
148impl Model {
149    pub fn path(&mut self, path: &str) -> &mut Self {
150        self.path = PathBuf::from(path);
151        self
152    }
153    pub fn verbose(&mut self, verbose: VerbosityLevel) -> &mut Self {
154        self.verbose = verbose;
155        self
156    }
157    pub fn graph(&mut self, graph: bool) -> &mut Self {
158        self.graph = graph;
159        self
160    }
161    pub fn graph_format(&mut self, graph_format: &str) -> &mut Self {
162        self.graph_format = graph_format.to_owned();
163        self
164    }
165    pub fn run(&self) -> BoxResult<HashMap<String, OutputInfo>> {
166        let args = Args::new(
167            self.path.clone(),
168            self.verbose,
169            self.graph,
170            self.graph_format.clone(),
171            self.failfast,
172        );
173        crate::execute_model(&args)
174    }
175}
176
177impl Default for Model {
178    fn default() -> Self {
179        Self {
180            path: PathBuf::new(),
181            verbose: VerbosityLevel::Minimal,
182            graph: false,
183            graph_format: "".to_owned(),
184            failfast: false,
185        }
186    }
187}
188
189#[no_mangle]
190/// Runs a model given a path to a model directory, verbosity level (0-4), graph format (json / dot), and execution mode (failfast / continue)
191///
192/// Returns true if successful, false if not
193///
194/// Sets last_error if an error occurs
195///
196/// # Safety
197///
198/// Should take a valid path to a model directory as a C string
199///
200/// Should take a valid verbosity level
201///
202/// Should take a valid graph format
203///
204/// Should take a valid execution mode
205pub unsafe extern "C" fn run_model(
206    model_path: *const std::os::raw::c_char,
207    verbosity: Verbosity,
208    graph_format: GraphFormat,
209    failfast: ExecutionMode,
210) -> bool {
211    let model_path = unsafe { std::ffi::CStr::from_ptr(model_path) };
212    let model_path = match model_path.to_str() {
213        Ok(s) => s,
214        Err(e) => {
215            let e = match CString::new(e.to_string()) {
216                Ok(s) => s,
217                Err(_) => {
218                    return false;
219                }
220            };
221            *LAST_ERROR = e;
222            return false;
223        }
224    };
225    let gf = match graph_format {
226        GraphFormat::None => "".to_owned(),
227        GraphFormat::Json => "json".to_owned(),
228        GraphFormat::Dot => "dot".to_owned(),
229    };
230    let args = Args::from_parts(
231        model_path.into(),
232        verbosity as i64,
233        graph_format != GraphFormat::None,
234        gf,
235        failfast != ExecutionMode::Continue,
236    );
237    match crate::execute_model(&args) {
238        Ok(_) => true,
239        Err(e) => {
240            let e = match CString::new(e.to_string()) {
241                Ok(s) => s,
242                Err(_) => {
243                    return false;
244                }
245            };
246            *LAST_ERROR = e;
247            false
248        }
249    }
250}
251
252#[no_mangle]
253/// Returns a pointer to a C string containing the last error
254///
255/// Returns a null pointer if no error is present
256///
257/// # Safety
258///
259/// Safe, returns a pointer to a C string, null if no error
260///
261/// Valid until the next call to run_model
262pub unsafe extern "C" fn last_error() -> *const std::os::raw::c_char {
263    if LAST_ERROR.is_empty() {
264        return std::ptr::null();
265    }
266    LAST_ERROR.as_ptr()
267}