tract_proxy/
lib.rs

1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::{null, null_mut};
4
5use tract_api::*;
6use tract_proxy_sys as sys;
7
8use anyhow::{Context, Result};
9use ndarray::*;
10
/// Runs an FFI call that yields a tract result code and maps it to `anyhow::Result<()>`.
///
/// The call is evaluated inside an `unsafe` block. When it reports
/// `TRACT_RESULT_KO`, the text from `tract_get_last_error` is decoded
/// (lossily, as UTF-8) and returned as the error message.
macro_rules! check {
    ($call:expr) => {
        unsafe {
            let status = $call;
            if status != sys::TRACT_RESULT_TRACT_RESULT_KO {
                Ok(())
            } else {
                let message = CStr::from_ptr(sys::tract_get_last_error());
                Err(anyhow::anyhow!(message.to_string_lossy().to_string()))
            }
        }
    };
}
23
/// Declares a public tuple newtype owning a raw `*mut sys::$ffi` handle.
///
/// When dropped, the handle is handed to `sys::$destroy`. Any extra types
/// listed after the destructor name become additional tuple fields that
/// follow the pointer.
///
/// NOTE(review): the derived `Clone` copies the raw pointer bit-for-bit. A
/// clone and its original would then both pass the same handle to `$destroy`
/// when dropped. Unless the C side tolerates that, this looks like a
/// double-free hazard — confirm before relying on `Clone`.
macro_rules! wrapper {
    ($name:ident, $ffi:ident, $destroy:ident $(, $extra:ty )*) => {
        #[derive(Debug, Clone)]
        pub struct $name(*mut sys::$ffi $(, $extra)*);

        impl Drop for $name {
            fn drop(&mut self) {
                // The destructor receives the address of the handle field, not a copy of it.
                unsafe { sys::$destroy(&mut self.0) };
            }
        }
    };
}
38
39pub fn nnef() -> Result<Nnef> {
40    let mut nnef = null_mut();
41    check!(sys::tract_nnef_create(&mut nnef))?;
42    Ok(Nnef(nnef))
43}
44
45pub fn onnx() -> Result<Onnx> {
46    let mut onnx = null_mut();
47    check!(sys::tract_onnx_create(&mut onnx))?;
48    Ok(Onnx(onnx))
49}
50
51pub fn version() -> &'static str {
52    unsafe { CStr::from_ptr(sys::tract_version()).to_str().unwrap() }
53}
54
55wrapper!(Nnef, TractNnef, tract_nnef_destroy);
56impl NnefInterface for Nnef {
57    type Model = Model;
58    fn model_for_path(&self, path: impl AsRef<Path>) -> Result<Model> {
59        let path = path.as_ref();
60        let path = CString::new(
61            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
62        )?;
63        let mut model = null_mut();
64        check!(sys::tract_nnef_model_for_path(self.0, path.as_ptr(), &mut model))?;
65        Ok(Model(model))
66    }
67
68    fn transform_model(&self, model: &mut Self::Model, transform_spec: &str) -> Result<()> {
69        let t = CString::new(transform_spec)?;
70        check!(sys::tract_nnef_transform_model(self.0, model.0, t.as_ptr()))
71    }
72
73    fn enable_tract_core(&mut self) -> Result<()> {
74        check!(sys::tract_nnef_enable_tract_core(self.0))
75    }
76
77    fn enable_tract_extra(&mut self) -> Result<()> {
78        check!(sys::tract_nnef_enable_tract_extra(self.0))
79    }
80
81    fn enable_tract_transformers(&mut self) -> Result<()> {
82        check!(sys::tract_nnef_enable_tract_transformers(self.0))
83    }
84
85    fn enable_onnx(&mut self) -> Result<()> {
86        check!(sys::tract_nnef_enable_onnx(self.0))
87    }
88
89    fn enable_pulse(&mut self) -> Result<()> {
90        check!(sys::tract_nnef_enable_pulse(self.0))
91    }
92
93    fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
94        check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
95    }
96
97    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
98        let path = path.as_ref();
99        let path = CString::new(
100            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
101        )?;
102        check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
103        Ok(())
104    }
105
106    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
107        let path = path.as_ref();
108        let path = CString::new(
109            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
110        )?;
111        check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
112        Ok(())
113    }
114
115    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
116        let path = path.as_ref();
117        let path = CString::new(
118            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
119        )?;
120        check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
121        Ok(())
122    }
123}
124
125// ONNX
126wrapper!(Onnx, TractOnnx, tract_onnx_destroy);
127
128impl OnnxInterface for Onnx {
129    type InferenceModel = InferenceModel;
130    fn model_for_path(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
131        let path = path.as_ref();
132        let path = CString::new(
133            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
134        )?;
135        let mut model = null_mut();
136        check!(sys::tract_onnx_model_for_path(self.0, path.as_ptr(), &mut model))?;
137        Ok(InferenceModel(model))
138    }
139}
140
141// INFERENCE MODEL
142wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);
143impl InferenceModelInterface for InferenceModel {
144    type Model = Model;
145    type InferenceFact = InferenceFact;
146    fn set_output_names(
147        &mut self,
148        outputs: impl IntoIterator<Item = impl AsRef<str>>,
149    ) -> Result<()> {
150        let c_strings: Vec<CString> =
151            outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
152        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
153        check!(sys::tract_inference_model_set_output_names(
154            self.0,
155            c_strings.len(),
156            ptrs.as_ptr()
157        ))?;
158        Ok(())
159    }
160
161    fn input_count(&self) -> Result<usize> {
162        let mut count = 0;
163        check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
164        Ok(count)
165    }
166
167    fn output_count(&self) -> Result<usize> {
168        let mut count = 0;
169        check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
170        Ok(count)
171    }
172
173    fn input_name(&self, id: usize) -> Result<String> {
174        let mut ptr = null_mut();
175        check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
176        unsafe {
177            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
178            sys::tract_free_cstring(ptr);
179            Ok(ret)
180        }
181    }
182
183    fn output_name(&self, id: usize) -> Result<String> {
184        let mut ptr = null_mut();
185        check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
186        unsafe {
187            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
188            sys::tract_free_cstring(ptr);
189            Ok(ret)
190        }
191    }
192
193    fn input_fact(&self, id: usize) -> Result<InferenceFact> {
194        let mut ptr = null_mut();
195        check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
196        Ok(InferenceFact(ptr))
197    }
198
199    fn set_input_fact(
200        &mut self,
201        id: usize,
202        fact: impl AsFact<Self, Self::InferenceFact>,
203    ) -> Result<()> {
204        let fact = fact.as_fact(self)?;
205        check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))?;
206        Ok(())
207    }
208
209    fn output_fact(&self, id: usize) -> Result<InferenceFact> {
210        let mut ptr = null_mut();
211        check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
212        Ok(InferenceFact(ptr))
213    }
214
215    fn set_output_fact(
216        &mut self,
217        id: usize,
218        fact: impl AsFact<InferenceModel, InferenceFact>,
219    ) -> Result<()> {
220        let fact = fact.as_fact(self)?;
221        check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))?;
222        Ok(())
223    }
224
225    fn analyse(&mut self) -> Result<()> {
226        check!(sys::tract_inference_model_analyse(self.0))?;
227        Ok(())
228    }
229
230    fn into_typed(mut self) -> Result<Self::Model> {
231        let mut ptr = null_mut();
232        check!(sys::tract_inference_model_into_typed(&mut self.0, &mut ptr))?;
233        Ok(Model(ptr))
234    }
235
236    fn into_optimized(mut self) -> Result<Self::Model> {
237        let mut ptr = null_mut();
238        check!(sys::tract_inference_model_into_optimized(&mut self.0, &mut ptr))?;
239        Ok(Model(ptr))
240    }
241}
242
243// MODEL
244wrapper!(Model, TractModel, tract_model_destroy);
245
246impl ModelInterface for Model {
247    type Fact = Fact;
248    type Value = Value;
249    type Runnable = Runnable;
250    fn input_count(&self) -> Result<usize> {
251        let mut count = 0;
252        check!(sys::tract_model_input_count(self.0, &mut count))?;
253        Ok(count)
254    }
255
256    fn output_count(&self) -> Result<usize> {
257        let mut count = 0;
258        check!(sys::tract_model_output_count(self.0, &mut count))?;
259        Ok(count)
260    }
261
262    fn input_name(&self, id: usize) -> Result<String> {
263        let mut ptr = null_mut();
264        check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
265        unsafe {
266            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
267            sys::tract_free_cstring(ptr);
268            Ok(ret)
269        }
270    }
271
272    fn output_name(&self, id: usize) -> Result<String> {
273        let mut ptr = null_mut();
274        check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
275        unsafe {
276            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
277            sys::tract_free_cstring(ptr);
278            Ok(ret)
279        }
280    }
281
282    fn set_output_names(
283        &mut self,
284        outputs: impl IntoIterator<Item = impl AsRef<str>>,
285    ) -> Result<()> {
286        let c_strings: Vec<CString> =
287            outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
288        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
289        check!(sys::tract_model_set_output_names(self.0, c_strings.len(), ptrs.as_ptr()))?;
290        Ok(())
291    }
292
293    fn input_fact(&self, id: usize) -> Result<Fact> {
294        let mut ptr = null_mut();
295        check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
296        Ok(Fact(ptr))
297    }
298
299    fn output_fact(&self, id: usize) -> Result<Fact> {
300        let mut ptr = null_mut();
301        check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
302        Ok(Fact(ptr))
303    }
304
305    fn declutter(&mut self) -> Result<()> {
306        check!(sys::tract_model_declutter(self.0))?;
307        Ok(())
308    }
309
310    fn optimize(&mut self) -> Result<()> {
311        check!(sys::tract_model_optimize(self.0))?;
312        Ok(())
313    }
314
315    fn into_decluttered(self) -> Result<Model> {
316        check!(sys::tract_model_declutter(self.0))?;
317        Ok(self)
318    }
319
320    fn into_optimized(self) -> Result<Model> {
321        check!(sys::tract_model_optimize(self.0))?;
322        Ok(self)
323    }
324
325    fn into_runnable(self) -> Result<Runnable> {
326        let mut model = self;
327        let mut runnable = null_mut();
328        check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
329        Ok(Runnable(runnable))
330    }
331
332    fn concretize_symbols(
333        &mut self,
334        values: impl IntoIterator<Item = (impl AsRef<str>, i64)>,
335    ) -> Result<()> {
336        let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
337        let c_strings: Vec<CString> =
338            names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
339        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
340        check!(sys::tract_model_concretize_symbols(
341            self.0,
342            ptrs.len(),
343            ptrs.as_ptr(),
344            values.as_ptr()
345        ))?;
346        Ok(())
347    }
348
349    fn transform(&mut self, transform: &str) -> Result<()> {
350        let t = CString::new(transform)?;
351        check!(sys::tract_model_transform(self.0, t.as_ptr()))?;
352        Ok(())
353    }
354
355    fn pulse(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
356        let name = CString::new(name.as_ref())?;
357        let value = CString::new(value.as_ref())?;
358        check!(sys::tract_model_pulse_simple(&mut self.0, name.as_ptr(), value.as_ptr()))?;
359        Ok(())
360    }
361
362    fn cost_json(&self) -> Result<String> {
363        let input: Option<Vec<Value>> = None;
364        self.profile_json(input)
365    }
366
367    fn profile_json<I, V, E>(&self, inputs: Option<I>) -> Result<String>
368    where
369        I: IntoIterator<Item = V>,
370        V: TryInto<Value, Error = E>,
371        E: Into<anyhow::Error>,
372    {
373        let inputs = if let Some(inputs) = inputs {
374            let inputs = inputs
375                .into_iter()
376                .map(|i| i.try_into().map_err(|e| e.into()))
377                .collect::<Result<Vec<Value>>>()?;
378            anyhow::ensure!(self.input_count()? == inputs.len());
379            Some(inputs)
380        } else {
381            None
382        };
383        let mut iptrs: Option<Vec<*mut sys::TractValue>> =
384            inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
385        let mut json: *mut i8 = null_mut();
386        let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
387        check!(sys::tract_model_profile_json(self.0, values, &mut json))?;
388        anyhow::ensure!(!json.is_null());
389        unsafe {
390            let s = CStr::from_ptr(json).to_owned();
391            sys::tract_free_cstring(json);
392            Ok(s.to_str()?.to_owned())
393        }
394    }
395
396    fn property_keys(&self) -> Result<Vec<String>> {
397        let mut len = 0;
398        check!(sys::tract_model_property_count(self.0, &mut len))?;
399        let mut keys = vec![null_mut(); len];
400        check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
401        unsafe {
402            keys.into_iter()
403                .map(|pc| {
404                    let s = CStr::from_ptr(pc).to_str()?.to_owned();
405                    sys::tract_free_cstring(pc);
406                    Ok(s)
407                })
408                .collect()
409        }
410    }
411
412    fn property(&self, name: impl AsRef<str>) -> Result<Value> {
413        let mut v = null_mut();
414        let name = CString::new(name.as_ref())?;
415        check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
416        Ok(Value(v))
417    }
418}
419
420// RUNNABLE
421wrapper!(Runnable, TractRunnable, tract_runnable_release);
422
423impl RunnableInterface for Runnable {
424    type Value = Value;
425    type State = State;
426
427    fn run<I, V, E>(&self, inputs: I) -> Result<Vec<Value>>
428    where
429        I: IntoIterator<Item = V>,
430        V: TryInto<Value, Error = E>,
431        E: Into<anyhow::Error>,
432    {
433        self.spawn_state()?.run(inputs)
434    }
435
436    fn spawn_state(&self) -> Result<State> {
437        let mut state = null_mut();
438        check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
439        Ok(State(state))
440    }
441
442    fn input_count(&self) -> Result<usize> {
443        let mut count = 0;
444        check!(sys::tract_runnable_input_count(self.0, &mut count))?;
445        Ok(count)
446    }
447
448    fn output_count(&self) -> Result<usize> {
449        let mut count = 0;
450        check!(sys::tract_runnable_output_count(self.0, &mut count))?;
451        Ok(count)
452    }
453}
454
455// STATE
456wrapper!(State, TractState, tract_state_destroy);
457
458impl StateInterface for State {
459    type Value = Value;
460    fn run<I, V, E>(&mut self, inputs: I) -> Result<Vec<Value>>
461    where
462        I: IntoIterator<Item = V>,
463        V: TryInto<Value, Error = E>,
464        E: Into<anyhow::Error>,
465    {
466        let inputs = inputs
467            .into_iter()
468            .map(|i| i.try_into().map_err(|e| e.into()))
469            .collect::<Result<Vec<Value>>>()?;
470        let mut outputs = vec![null_mut(); self.output_count()?];
471        let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
472        check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
473        let outputs = outputs.into_iter().map(Value).collect();
474        Ok(outputs)
475    }
476
477    fn input_count(&self) -> Result<usize> {
478        let mut count = 0;
479        check!(sys::tract_state_input_count(self.0, &mut count))?;
480        Ok(count)
481    }
482
483    fn output_count(&self) -> Result<usize> {
484        let mut count = 0;
485        check!(sys::tract_state_output_count(self.0, &mut count))?;
486        Ok(count)
487    }
488}
489
490// VALUE
491wrapper!(Value, TractValue, tract_value_destroy);
492
493impl ValueInterface for Value {
494    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
495        anyhow::ensure!(data.len() == shape.iter().product::<usize>() * dt.size_of());
496        let mut value = null_mut();
497        check!(sys::tract_value_from_bytes(
498            dt as _,
499            shape.len(),
500            shape.as_ptr(),
501            data.as_ptr() as _,
502            &mut value
503        ))?;
504        Ok(Value(value))
505    }
506
507    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
508        let mut rank = 0;
509        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
510        let mut shape = null();
511        let mut data = null();
512        check!(sys::tract_value_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
513        unsafe {
514            let dt: DatumType = std::mem::transmute(dt);
515            let shape = std::slice::from_raw_parts(shape, rank);
516            let len: usize = shape.iter().product();
517            let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
518            Ok((dt, shape, data))
519        }
520    }
521}
522
523value_from_to_ndarray!();
524
525// FACT
526wrapper!(Fact, TractFact, tract_fact_destroy);
527
528impl Fact {
529    fn new(model: &mut Model, spec: impl ToString) -> Result<Fact> {
530        let cstr = CString::new(spec.to_string())?;
531        let mut fact = null_mut();
532        check!(sys::tract_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
533        Ok(Fact(fact))
534    }
535
536    fn dump(&self) -> Result<String> {
537        let mut ptr = null_mut();
538        check!(sys::tract_fact_dump(self.0, &mut ptr))?;
539        unsafe {
540            let s = CStr::from_ptr(ptr).to_owned();
541            sys::tract_free_cstring(ptr);
542            Ok(s.to_str()?.to_owned())
543        }
544    }
545}
546
547impl FactInterface for Fact {}
548
549impl std::fmt::Display for Fact {
550    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
551        match self.dump() {
552            Ok(s) => f.write_str(&s),
553            Err(_) => Err(std::fmt::Error),
554        }
555    }
556}
557
558// INFERENCE FACT
559wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
560
561impl InferenceFact {
562    fn new(model: &mut InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
563        let cstr = CString::new(spec.to_string())?;
564        let mut fact = null_mut();
565        check!(sys::tract_inference_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
566        Ok(InferenceFact(fact))
567    }
568
569    fn dump(&self) -> Result<String> {
570        let mut ptr = null_mut();
571        check!(sys::tract_inference_fact_dump(self.0, &mut ptr))?;
572        unsafe {
573            let s = CStr::from_ptr(ptr).to_owned();
574            sys::tract_free_cstring(ptr);
575            Ok(s.to_str()?.to_owned())
576        }
577    }
578}
579
580impl InferenceFactInterface for InferenceFact {
581    fn empty() -> Result<InferenceFact> {
582        let mut fact = null_mut();
583        check!(sys::tract_inference_fact_empty(&mut fact))?;
584        Ok(InferenceFact(fact))
585    }
586}
587
588impl Default for InferenceFact {
589    fn default() -> Self {
590        Self::empty().unwrap()
591    }
592}
593
594impl std::fmt::Display for InferenceFact {
595    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596        match self.dump() {
597            Ok(s) => f.write_str(&s),
598            Err(_) => Err(std::fmt::Error),
599        }
600    }
601}
602
603as_inference_fact_impl!(InferenceModel, InferenceFact);
604as_fact_impl!(Model, Fact);