1#![allow(dead_code)]
2pub mod common;
3pub mod executor;
4pub mod onnxparser;
5pub mod operators;
6pub mod parallel;
7pub mod protograph;
8pub mod utils;
9
10use crate::common::MAX_OPSET_VERSION;
11use crate::executor::execute_model;
12use crate::onnxparser::onnx;
13use crate::utils::read_model;
14use common::{Args, BoxResult, VerbosityLevel};
15use once_cell::sync::Lazy;
16use utils::OutputInfo;
17use std::collections::HashMap;
18use std::ffi::CString;
19use std::path::{Path, PathBuf};
20
21static mut LAST_ERROR: Lazy<CString> = Lazy::new(|| CString::new(*b"").unwrap());
22
/// Verbosity selector accepted by `run_model`.
///
/// The enum is `#[repr(i64)]` and `run_model` forwards the discriminant
/// unchanged (`verbosity as i64`), so the numeric values below are part of the
/// C ABI. The values are not contiguous: `Intermediate` is `4`, not `3`.
///
/// What each level actually prints is decided by the consumer of the resulting
/// `Args` value, which is not defined in this file.
#[repr(i64)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Verbosity {
    /// Forwarded as `-1`.
    Silent = -1,
    /// Forwarded as `0`.
    Minimal = 0,
    /// Forwarded as `1`.
    Informational = 1,
    /// Forwarded as `2`.
    Results = 2,
    /// Forwarded as `4`.
    Intermediate = 4,
}
31
/// Graph output selector accepted by `run_model`.
///
/// `run_model` turns the variant into the format string handed to
/// `Args::from_parts` and enables graph output for every variant other than
/// `None`. The `#[repr(i64)]` discriminants are part of the C ABI.
#[repr(i64)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphFormat {
    /// No graph output; mapped to the empty format string.
    None = 0,
    /// Mapped to the format string `"json"`.
    Json = 1,
    /// Mapped to the format string `"dot"`.
    Dot = 2,
}
43
/// Execution mode selector accepted by `run_model`.
///
/// `run_model` passes `true` as the fail-fast flag of `Args::from_parts` for
/// every variant other than `Continue`. The `#[repr(i64)]` discriminants are
/// part of the C ABI.
#[repr(i64)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionMode {
    /// Fail-fast flag is `false`.
    Continue = 0,
    /// Fail-fast flag is `true`.
    FailFast = 1,
}
53
54#[no_mangle]
55pub unsafe extern "C" fn read_onnx_model(
63 model_path: *const std::os::raw::c_char,
64) -> *mut onnx::ModelProto {
65 let model_path = unsafe { std::ffi::CStr::from_ptr(model_path) };
66 let model_path = match model_path.to_str() {
67 Ok(s) => s,
68 Err(e) => {
69 let e = match CString::new(e.to_string()) {
70 Ok(s) => s,
71 Err(_) => {
72 return std::ptr::null_mut();
73 }
74 };
75 *LAST_ERROR = e;
76 return std::ptr::null_mut();
77 }
78 };
79 let model = match read_model(Path::new(model_path)) {
80 Ok(m) => m,
81 Err(e) => {
82 let e = match CString::new(e.to_string()) {
83 Ok(s) => s,
84 Err(_) => {
85 return std::ptr::null_mut();
86 }
87 };
88 *LAST_ERROR = e;
89 return std::ptr::null_mut();
90 }
91 };
92 Box::into_raw(Box::new(model))
93}
94
95#[no_mangle]
96pub unsafe extern "C" fn free_onnx_model(model: *mut onnx::ModelProto) {
104 if model.is_null() {
105 return;
106 }
107 unsafe {
108 drop(Box::from_raw(model));
109 }
110}
111
112#[no_mangle]
113pub unsafe extern "C" fn get_opset_version(model: *const onnx::ModelProto) -> i64 {
123 if model.is_null() {
124 *LAST_ERROR = CString::new("NULL pointer passed to get_opset_version").unwrap();
125 return MAX_OPSET_VERSION;
126 }
127 unsafe {
128 if let Some(v) = (*model).opset_import.first() {
129 if let Some(v) = v.version {
130 v
131 } else {
132 MAX_OPSET_VERSION
133 }
134 } else {
135 MAX_OPSET_VERSION
136 }
137 }
138}
139
140pub struct Model {
141 path: PathBuf,
142 verbose: VerbosityLevel,
143 graph: bool,
144 graph_format: String,
145 failfast: bool,
146}
147
148impl Model {
149 pub fn path(&mut self, path: &str) -> &mut Self {
150 self.path = PathBuf::from(path);
151 self
152 }
153 pub fn verbose(&mut self, verbose: VerbosityLevel) -> &mut Self {
154 self.verbose = verbose;
155 self
156 }
157 pub fn graph(&mut self, graph: bool) -> &mut Self {
158 self.graph = graph;
159 self
160 }
161 pub fn graph_format(&mut self, graph_format: &str) -> &mut Self {
162 self.graph_format = graph_format.to_owned();
163 self
164 }
165 pub fn run(&self) -> BoxResult<HashMap<String, OutputInfo>> {
166 let args = Args::new(
167 self.path.clone(),
168 self.verbose,
169 self.graph,
170 self.graph_format.clone(),
171 self.failfast,
172 );
173 crate::execute_model(&args)
174 }
175}
176
177impl Default for Model {
178 fn default() -> Self {
179 Self {
180 path: PathBuf::new(),
181 verbose: VerbosityLevel::Minimal,
182 graph: false,
183 graph_format: "".to_owned(),
184 failfast: false,
185 }
186 }
187}
188
189#[no_mangle]
190pub unsafe extern "C" fn run_model(
206 model_path: *const std::os::raw::c_char,
207 verbosity: Verbosity,
208 graph_format: GraphFormat,
209 failfast: ExecutionMode,
210) -> bool {
211 let model_path = unsafe { std::ffi::CStr::from_ptr(model_path) };
212 let model_path = match model_path.to_str() {
213 Ok(s) => s,
214 Err(e) => {
215 let e = match CString::new(e.to_string()) {
216 Ok(s) => s,
217 Err(_) => {
218 return false;
219 }
220 };
221 *LAST_ERROR = e;
222 return false;
223 }
224 };
225 let gf = match graph_format {
226 GraphFormat::None => "".to_owned(),
227 GraphFormat::Json => "json".to_owned(),
228 GraphFormat::Dot => "dot".to_owned(),
229 };
230 let args = Args::from_parts(
231 model_path.into(),
232 verbosity as i64,
233 graph_format != GraphFormat::None,
234 gf,
235 failfast != ExecutionMode::Continue,
236 );
237 match crate::execute_model(&args) {
238 Ok(_) => true,
239 Err(e) => {
240 let e = match CString::new(e.to_string()) {
241 Ok(s) => s,
242 Err(_) => {
243 return false;
244 }
245 };
246 *LAST_ERROR = e;
247 false
248 }
249 }
250}
251
252#[no_mangle]
253pub unsafe extern "C" fn last_error() -> *const std::os::raw::c_char {
263 if LAST_ERROR.is_empty() {
264 return std::ptr::null();
265 }
266 LAST_ERROR.as_ptr()
267}