1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::{null, null_mut};
4
5use tract_api::*;
6use tract_proxy_sys as sys;
7
8use anyhow::{Context, Result};
9
10mod ndarray_interop;
11pub use ndarray_interop::__ndarray_interop;
12
/// Runs a tract C API call and turns its status code into an `anyhow::Result<()>`.
///
/// The expression is evaluated inside an `unsafe` block so callers can pass raw FFI
/// calls directly. When it returns `TRACT_RESULT_KO`, the message exposed by
/// `tract_get_last_error` becomes the error text. A null message pointer is reported
/// with a generic message instead of being handed to `CStr::from_ptr`, which would be
/// undefined behaviour.
macro_rules! check {
    ($expr:expr) => {
        unsafe {
            if $expr == sys::TRACT_RESULT_TRACT_RESULT_KO {
                let msg = sys::tract_get_last_error();
                if msg.is_null() {
                    Err(anyhow::anyhow!("tract call failed without reporting an error message"))
                } else {
                    Err(anyhow::anyhow!(CStr::from_ptr(msg).to_string_lossy().into_owned()))
                }
            } else {
                Ok(())
            }
        }
    };
}
25
/// Declares `$name` as a public owning handle around a raw `*mut sys::$raw` pointer,
/// optionally followed by extra tuple fields of the listed types.
///
/// The generated type derives `Debug`. When dropped, it passes the address of its
/// pointer field to `sys::$release`; the status code returned by that call is ignored.
macro_rules! wrapper {
    ($name:ident, $raw:ident, $release:ident $(, $extra:ty)*) => {
        #[derive(Debug)]
        pub struct $name(*mut sys::$raw $(, $extra)*);

        impl Drop for $name {
            fn drop(&mut self) {
                let field = &mut self.0;
                unsafe {
                    sys::$release(field);
                }
            }
        }
    };
}
40
/// Implements `Clone` for a handle type produced by `wrapper!`, delegating the copy
/// to the C-side function `sys::$clone_fn`.
///
/// # Panics
/// `Clone::clone` cannot return an error. If `$clone_fn` reports a failure, this
/// panics with the tract error instead of silently returning a handle around a
/// possibly-null pointer.
macro_rules! wrapper_clone {
    ($new_type:ident, $clone_fn:ident) => {
        impl Clone for $new_type {
            fn clone(&self) -> Self {
                let mut clone = null_mut();
                check!(sys::$clone_fn(self.0, &mut clone))
                    .expect(concat!(stringify!($clone_fn), " failed"));
                $new_type(clone)
            }
        }
    };
}
54
55pub fn nnef() -> Result<Nnef> {
56 let mut nnef = null_mut();
57 check!(sys::tract_nnef_create(&mut nnef))?;
58 Ok(Nnef(nnef))
59}
60
61pub fn onnx() -> Result<Onnx> {
62 let mut onnx = null_mut();
63 check!(sys::tract_onnx_create(&mut onnx))?;
64 Ok(Onnx(onnx))
65}
66
67pub fn version() -> &'static str {
68 unsafe { CStr::from_ptr(sys::tract_version()).to_str().unwrap() }
69}
70
71wrapper!(Nnef, TractNnef, tract_nnef_destroy);
72impl NnefInterface for Nnef {
73 type Model = Model;
74 fn load(&self, path: impl AsRef<Path>) -> Result<Model> {
75 let path = path.as_ref();
76 let path = CString::new(
77 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
78 )?;
79 let mut model = null_mut();
80 check!(sys::tract_nnef_load(self.0, path.as_ptr(), &mut model))?;
81 Ok(Model(model))
82 }
83
84 fn load_buffer(&self, data: &[u8]) -> Result<Model> {
85 let mut model = null_mut();
86 check!(sys::tract_nnef_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
87 Ok(Model(model))
88 }
89
90 fn enable_tract_core(&mut self) -> Result<()> {
91 check!(sys::tract_nnef_enable_tract_core(self.0))
92 }
93
94 fn enable_tract_extra(&mut self) -> Result<()> {
95 check!(sys::tract_nnef_enable_tract_extra(self.0))
96 }
97
98 fn enable_tract_transformers(&mut self) -> Result<()> {
99 check!(sys::tract_nnef_enable_tract_transformers(self.0))
100 }
101
102 fn enable_onnx(&mut self) -> Result<()> {
103 check!(sys::tract_nnef_enable_onnx(self.0))
104 }
105
106 fn enable_pulse(&mut self) -> Result<()> {
107 check!(sys::tract_nnef_enable_pulse(self.0))
108 }
109
110 fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
111 check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
112 }
113
114 fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
115 let path = path.as_ref();
116 let path = CString::new(
117 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
118 )?;
119 check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
120 Ok(())
121 }
122
123 fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
124 let path = path.as_ref();
125 let path = CString::new(
126 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
127 )?;
128 check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
129 Ok(())
130 }
131
132 fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
133 let path = path.as_ref();
134 let path = CString::new(
135 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
136 )?;
137 check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
138 Ok(())
139 }
140}
141
142wrapper!(Onnx, TractOnnx, tract_onnx_destroy);
144
145impl OnnxInterface for Onnx {
146 type InferenceModel = InferenceModel;
147 fn load(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
148 let path = path.as_ref();
149 let path = CString::new(
150 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
151 )?;
152 let mut model = null_mut();
153 check!(sys::tract_onnx_load(self.0, path.as_ptr(), &mut model))?;
154 Ok(InferenceModel(model))
155 }
156
157 fn load_buffer(&self, data: &[u8]) -> Result<InferenceModel> {
158 let mut model = null_mut();
159 check!(sys::tract_onnx_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
160 Ok(InferenceModel(model))
161 }
162}
163
164wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);
166impl InferenceModelInterface for InferenceModel {
167 type Model = Model;
168 type InferenceFact = InferenceFact;
169 fn input_count(&self) -> Result<usize> {
170 let mut count = 0;
171 check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
172 Ok(count)
173 }
174
175 fn output_count(&self) -> Result<usize> {
176 let mut count = 0;
177 check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
178 Ok(count)
179 }
180
181 fn input_name(&self, id: usize) -> Result<String> {
182 let mut ptr = null_mut();
183 check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
184 unsafe {
185 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
186 sys::tract_free_cstring(ptr);
187 Ok(ret)
188 }
189 }
190
191 fn output_name(&self, id: usize) -> Result<String> {
192 let mut ptr = null_mut();
193 check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
194 unsafe {
195 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
196 sys::tract_free_cstring(ptr);
197 Ok(ret)
198 }
199 }
200
201 fn input_fact(&self, id: usize) -> Result<InferenceFact> {
202 let mut ptr = null_mut();
203 check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
204 Ok(InferenceFact(ptr))
205 }
206
207 fn set_input_fact(
208 &mut self,
209 id: usize,
210 fact: impl AsFact<Self, Self::InferenceFact>,
211 ) -> Result<()> {
212 let fact = fact.as_fact(self)?;
213 check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))?;
214 Ok(())
215 }
216
217 fn output_fact(&self, id: usize) -> Result<InferenceFact> {
218 let mut ptr = null_mut();
219 check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
220 Ok(InferenceFact(ptr))
221 }
222
223 fn set_output_fact(
224 &mut self,
225 id: usize,
226 fact: impl AsFact<InferenceModel, InferenceFact>,
227 ) -> Result<()> {
228 let fact = fact.as_fact(self)?;
229 check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))?;
230 Ok(())
231 }
232
233 fn analyse(&mut self) -> Result<()> {
234 check!(sys::tract_inference_model_analyse(self.0))?;
235 Ok(())
236 }
237
238 fn into_model(mut self) -> Result<Self::Model> {
239 let mut ptr = null_mut();
240 check!(sys::tract_inference_model_into_model(&mut self.0, &mut ptr))?;
241 Ok(Model(ptr))
242 }
243}
244
245wrapper!(Model, TractModel, tract_model_destroy);
247
248impl ModelInterface for Model {
249 type Fact = Fact;
250 type Tensor = Tensor;
251 type Runnable = Runnable;
252 fn input_count(&self) -> Result<usize> {
253 let mut count = 0;
254 check!(sys::tract_model_input_count(self.0, &mut count))?;
255 Ok(count)
256 }
257
258 fn output_count(&self) -> Result<usize> {
259 let mut count = 0;
260 check!(sys::tract_model_output_count(self.0, &mut count))?;
261 Ok(count)
262 }
263
264 fn input_name(&self, id: usize) -> Result<String> {
265 let mut ptr = null_mut();
266 check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
267 unsafe {
268 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
269 sys::tract_free_cstring(ptr);
270 Ok(ret)
271 }
272 }
273
274 fn output_name(&self, id: usize) -> Result<String> {
275 let mut ptr = null_mut();
276 check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
277 unsafe {
278 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
279 sys::tract_free_cstring(ptr);
280 Ok(ret)
281 }
282 }
283
284 fn input_fact(&self, id: usize) -> Result<Fact> {
285 let mut ptr = null_mut();
286 check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
287 Ok(Fact(ptr))
288 }
289
290 fn output_fact(&self, id: usize) -> Result<Fact> {
291 let mut ptr = null_mut();
292 check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
293 Ok(Fact(ptr))
294 }
295
296 fn into_runnable(self) -> Result<Runnable> {
297 let mut model = self;
298 let mut runnable = null_mut();
299 check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
300 Ok(Runnable(runnable))
301 }
302
303 fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()> {
304 let transform = spec.into().to_transform_string();
305 let t = CString::new(transform)?;
306 check!(sys::tract_model_transform(self.0, t.as_ptr()))?;
307 Ok(())
308 }
309
310 fn property_keys(&self) -> Result<Vec<String>> {
311 let mut len = 0;
312 check!(sys::tract_model_property_count(self.0, &mut len))?;
313 let mut keys = vec![null_mut(); len];
314 check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
315 unsafe {
316 keys.into_iter()
317 .map(|pc| {
318 let s = CStr::from_ptr(pc).to_str()?.to_owned();
319 sys::tract_free_cstring(pc);
320 Ok(s)
321 })
322 .collect()
323 }
324 }
325
326 fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
327 let mut v = null_mut();
328 let name = CString::new(name.as_ref())?;
329 check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
330 Ok(Tensor(v))
331 }
332
333 fn parse_fact(&self, spec: &str) -> Result<Self::Fact> {
334 let spec = CString::new(spec)?;
335 let mut ptr = null_mut();
336 check!(sys::tract_model_parse_fact(self.0, spec.as_ptr(), &mut ptr))?;
337 Ok(Fact(ptr))
338 }
339}
340
341wrapper!(Runtime, TractRuntime, tract_runtime_release);
343
344pub fn runtime_for_name(name: &str) -> Result<Runtime> {
345 let mut rt = null_mut();
346 let name = CString::new(name)?;
347 check!(sys::tract_runtime_for_name(name.as_ptr(), &mut rt))?;
348 Ok(Runtime(rt))
349}
350
351impl RuntimeInterface for Runtime {
352 type Runnable = Runnable;
353
354 type Model = Model;
355
356 fn name(&self) -> Result<String> {
357 let mut ptr = null_mut();
358 check!(sys::tract_runtime_name(self.0, &mut ptr))?;
359 unsafe {
360 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
361 sys::tract_free_cstring(ptr);
362 Ok(ret)
363 }
364 }
365
366 fn prepare(&self, model: Self::Model) -> Result<Self::Runnable> {
367 let mut model = model;
368 let mut runnable = null_mut();
369 check!(sys::tract_runtime_prepare(self.0, &mut model.0, &mut runnable))?;
370 Ok(Runnable(runnable))
371 }
372}
373
374wrapper!(Runnable, TractRunnable, tract_runnable_release);
376unsafe impl Send for Runnable {}
377unsafe impl Sync for Runnable {}
378
379impl RunnableInterface for Runnable {
380 type Tensor = Tensor;
381 type State = State;
382 type Fact = Fact;
383
384 fn run(&self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
385 StateInterface::run(&mut self.spawn_state()?, inputs.into_inputs()?)
386 }
387
388 fn spawn_state(&self) -> Result<State> {
389 let mut state = null_mut();
390 check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
391 Ok(State(state))
392 }
393
394 fn input_count(&self) -> Result<usize> {
395 let mut count = 0;
396 check!(sys::tract_runnable_input_count(self.0, &mut count))?;
397 Ok(count)
398 }
399
400 fn output_count(&self) -> Result<usize> {
401 let mut count = 0;
402 check!(sys::tract_runnable_output_count(self.0, &mut count))?;
403 Ok(count)
404 }
405
406 fn input_fact(&self, id: usize) -> Result<Self::Fact> {
407 let mut ptr = null_mut();
408 check!(sys::tract_runnable_input_fact(self.0, id, &mut ptr))?;
409 Ok(Fact(ptr))
410 }
411
412 fn output_fact(&self, id: usize) -> Result<Self::Fact> {
413 let mut ptr = null_mut();
414 check!(sys::tract_runnable_output_fact(self.0, id, &mut ptr))?;
415 Ok(Fact(ptr))
416 }
417
418 fn property_keys(&self) -> Result<Vec<String>> {
419 let mut len = 0;
420 check!(sys::tract_runnable_property_count(self.0, &mut len))?;
421 let mut keys = vec![null_mut(); len];
422 check!(sys::tract_runnable_property_names(self.0, keys.as_mut_ptr()))?;
423 unsafe {
424 keys.into_iter()
425 .map(|pc| {
426 let s = CStr::from_ptr(pc).to_str()?.to_owned();
427 sys::tract_free_cstring(pc);
428 Ok(s)
429 })
430 .collect()
431 }
432 }
433
434 fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
435 let mut v = null_mut();
436 let name = CString::new(name.as_ref())?;
437 check!(sys::tract_runnable_property(self.0, name.as_ptr(), &mut v))?;
438 Ok(Tensor(v))
439 }
440
441 fn cost_json(&self) -> Result<String> {
442 let input: Option<Vec<Tensor>> = None;
443 self.profile_json(input)
444 }
445
446 fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
447 where
448 I: IntoIterator<Item = IV>,
449 IV: TryInto<Self::Tensor, Error = IE>,
450 IE: Into<anyhow::Error>,
451 {
452 let inputs = if let Some(inputs) = inputs {
453 let inputs = inputs
454 .into_iter()
455 .map(|i| i.try_into().map_err(|e| e.into()))
456 .collect::<Result<Vec<Tensor>>>()?;
457 anyhow::ensure!(self.input_count()? == inputs.len());
458 Some(inputs)
459 } else {
460 None
461 };
462 let mut iptrs: Option<Vec<*mut sys::TractTensor>> =
463 inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
464 let mut json: *mut i8 = null_mut();
465 let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
466
467 check!(sys::tract_runnable_profile_json(self.0, values, &mut json))?;
468 anyhow::ensure!(!json.is_null());
469 unsafe {
470 let s = CStr::from_ptr(json).to_owned();
471 sys::tract_free_cstring(json);
472 Ok(s.to_str()?.to_owned())
473 }
474 }
475}
476
477pub struct State(*mut sys::TractState);
479
480impl Drop for State {
481 fn drop(&mut self) {
482 unsafe {
483 sys::tract_state_destroy(&mut self.0);
484 }
485 }
486}
487
488impl std::fmt::Debug for State {
489 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490 write!(f, "State({:?})", self.0)
491 }
492}
493
494impl Clone for State {
495 fn clone(&self) -> Self {
496 let mut clone = null_mut();
497 unsafe {
498 sys::tract_state_clone(self.0, &mut clone);
499 }
500 State(clone)
501 }
502}
503
504unsafe impl Send for State {}
506
507impl StateInterface for State {
508 type Tensor = Tensor;
509 type Fact = Fact;
510
511 fn run(&mut self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
512 let inputs = inputs.into_inputs()?;
513 let mut outputs = vec![null_mut(); self.output_count()?];
514 let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
515 check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
516 let outputs = outputs.into_iter().map(Tensor).collect();
517 Ok(outputs)
518 }
519
520 fn input_count(&self) -> Result<usize> {
521 let mut count = 0;
522 check!(sys::tract_state_input_count(self.0, &mut count))?;
523 Ok(count)
524 }
525
526 fn output_count(&self) -> Result<usize> {
527 let mut count = 0;
528 check!(sys::tract_state_output_count(self.0, &mut count))?;
529 Ok(count)
530 }
531}
532
533wrapper!(Tensor, TractTensor, tract_tensor_destroy);
535wrapper_clone!(Tensor, tract_tensor_clone);
536unsafe impl Send for Tensor {}
537unsafe impl Sync for Tensor {}
538
539impl TensorInterface for Tensor {
540 fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
541 anyhow::ensure!(data.len() == shape.iter().product::<usize>() * dt.size_of());
542 let mut value = null_mut();
543 check!(sys::tract_tensor_from_bytes(
544 dt as _,
545 shape.len(),
546 shape.as_ptr(),
547 data.as_ptr() as _,
548 &mut value
549 ))?;
550 Ok(Tensor(value))
551 }
552
553 fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
554 let mut rank = 0;
555 let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
556 let mut shape = null();
557 let mut data = null();
558 check!(sys::tract_tensor_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
559 unsafe {
560 let dt: DatumType = std::mem::transmute(dt);
561 let shape = std::slice::from_raw_parts(shape, rank);
562 let len: usize = shape.iter().product();
563 let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
564 Ok((dt, shape, data))
565 }
566 }
567
568 fn datum_type(&self) -> Result<DatumType> {
569 let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
570 check!(sys::tract_tensor_as_bytes(
571 self.0,
572 &mut dt,
573 std::ptr::null_mut(),
574 std::ptr::null_mut(),
575 std::ptr::null_mut()
576 ))?;
577 unsafe {
578 let dt: DatumType = std::mem::transmute(dt);
579 Ok(dt)
580 }
581 }
582
583 fn convert_to(&self, to: DatumType) -> Result<Self> {
584 let mut new = null_mut();
585 check!(sys::tract_tensor_convert_to(self.0, to as _, &mut new))?;
586 Ok(Tensor(new))
587 }
588}
589
590impl PartialEq for Tensor {
591 fn eq(&self, other: &Self) -> bool {
592 let Ok((me_dt, me_shape, me_data)) = self.as_bytes() else { return false };
593 let Ok((other_dt, other_shape, other_data)) = other.as_bytes() else { return false };
594 me_dt == other_dt && me_shape == other_shape && me_data == other_data
595 }
596}
597
598wrapper!(Fact, TractFact, tract_fact_destroy);
600wrapper_clone!(Fact, tract_fact_clone);
601
602impl Fact {
603 fn new(model: &Model, spec: impl ToString) -> Result<Fact> {
604 let cstr = CString::new(spec.to_string())?;
605 let mut fact = null_mut();
606 check!(sys::tract_model_parse_fact(model.0, cstr.as_ptr(), &mut fact))?;
607 Ok(Fact(fact))
608 }
609
610 fn dump(&self) -> Result<String> {
611 let mut ptr = null_mut();
612 check!(sys::tract_fact_dump(self.0, &mut ptr))?;
613 unsafe {
614 let s = CStr::from_ptr(ptr).to_owned();
615 sys::tract_free_cstring(ptr);
616 Ok(s.to_str()?.to_owned())
617 }
618 }
619}
620
621impl FactInterface for Fact {
622 type Dim = Dim;
623
624 fn datum_type(&self) -> Result<DatumType> {
625 let mut dt = 0u32;
626 check!(sys::tract_fact_datum_type(self.0, &mut dt as *const u32 as _))?;
627 Ok(unsafe { std::mem::transmute::<u32, DatumType>(dt) })
628 }
629
630 fn rank(&self) -> Result<usize> {
631 let mut rank = 0;
632 check!(sys::tract_fact_rank(self.0, &mut rank))?;
633 Ok(rank)
634 }
635
636 fn dim(&self, axis: usize) -> Result<Self::Dim> {
637 let mut ptr = null_mut();
638 check!(sys::tract_fact_dim(self.0, axis, &mut ptr))?;
639 Ok(Dim(ptr))
640 }
641}
642
643impl std::fmt::Display for Fact {
644 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
645 match self.dump() {
646 Ok(s) => f.write_str(&s),
647 Err(_) => Err(std::fmt::Error),
648 }
649 }
650}
651
652wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
654wrapper_clone!(InferenceFact, tract_inference_fact_clone);
655
656impl InferenceFact {
657 fn new(model: &InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
658 let cstr = CString::new(spec.to_string())?;
659 let mut fact = null_mut();
660 check!(sys::tract_inference_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
661 Ok(InferenceFact(fact))
662 }
663
664 fn dump(&self) -> Result<String> {
665 let mut ptr = null_mut();
666 check!(sys::tract_inference_fact_dump(self.0, &mut ptr))?;
667 unsafe {
668 let s = CStr::from_ptr(ptr).to_owned();
669 sys::tract_free_cstring(ptr);
670 Ok(s.to_str()?.to_owned())
671 }
672 }
673}
674
675impl InferenceFactInterface for InferenceFact {
676 fn empty() -> Result<InferenceFact> {
677 let mut fact = null_mut();
678 check!(sys::tract_inference_fact_empty(&mut fact))?;
679 Ok(InferenceFact(fact))
680 }
681}
682
683impl Default for InferenceFact {
684 fn default() -> Self {
685 Self::empty().unwrap()
686 }
687}
688
689impl std::fmt::Display for InferenceFact {
690 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
691 match self.dump() {
692 Ok(s) => f.write_str(&s),
693 Err(_) => Err(std::fmt::Error),
694 }
695 }
696}
697
698as_inference_fact_impl!(InferenceModel, InferenceFact);
699as_fact_impl!(Model, Fact);
700
701wrapper!(Dim, TractDim, tract_dim_destroy);
703wrapper_clone!(Dim, tract_dim_clone);
704
705impl Dim {
706 fn dump(&self) -> Result<String> {
707 let mut ptr = null_mut();
708 check!(sys::tract_dim_dump(self.0, &mut ptr))?;
709 unsafe {
710 let s = CStr::from_ptr(ptr).to_owned();
711 sys::tract_free_cstring(ptr);
712 Ok(s.to_str()?.to_owned())
713 }
714 }
715}
716
717impl DimInterface for Dim {
718 fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self> {
719 let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
720 let c_strings: Vec<CString> =
721 names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
722 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
723 let mut ptr = null_mut();
724 check!(sys::tract_dim_eval(self.0, ptrs.len(), ptrs.as_ptr(), values.as_ptr(), &mut ptr))?;
725 Ok(Dim(ptr))
726 }
727
728 fn to_int64(&self) -> Result<i64> {
729 let mut i = 0;
730 check!(sys::tract_dim_to_int64(self.0, &mut i))?;
731 Ok(i)
732 }
733}
734
735impl std::fmt::Display for Dim {
736 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
737 match self.dump() {
738 Ok(s) => f.write_str(&s),
739 Err(_) => Err(std::fmt::Error),
740 }
741 }
742}
743
#[cfg(test)]
mod tests {
    use super::*;

    /// A cloned tensor compares equal to its source. Both handles are dropped at the
    /// end of the test, which must not double-free.
    #[test]
    fn clone_tensor_no_double_free() {
        let original = Tensor::from_slice::<f32>(&[2, 2], &[1.0, 2.0, 3.0, 4.0]).unwrap();
        let copy = original.clone();
        assert_eq!(original, copy);
    }
}