onnxruntime_ng/tensor/
ort_owned_tensor.rs1use std::{fmt::Debug, ops::Deref};
4
5use ndarray::{Array, ArrayView};
6use tracing::debug;
7
8use onnxruntime_sys_ng as sys;
9
10use crate::{
11 error::status_to_result, g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor,
12 OrtError, Result, TypeToTensorElementDataType,
13};
14
15#[derive(Debug)]
26pub struct OrtOwnedTensor<'t, 'm, T, D>
27where
28 T: TypeToTensorElementDataType + Debug + Clone,
29 D: ndarray::Dimension,
30 'm: 't, {
32 pub(crate) tensor_ptr: *mut sys::OrtValue,
33 array_view: ArrayView<'t, T, D>,
34 memory_info: &'m MemoryInfo,
35}
36
37impl<'t, 'm, T, D> Deref for OrtOwnedTensor<'t, 'm, T, D>
38where
39 T: TypeToTensorElementDataType + Debug + Clone,
40 D: ndarray::Dimension,
41{
42 type Target = ArrayView<'t, T, D>;
43
44 fn deref(&self) -> &Self::Target {
45 &self.array_view
46 }
47}
48
49impl<'t, 'm, T, D> OrtOwnedTensor<'t, 'm, T, D>
50where
51 T: TypeToTensorElementDataType + Debug + Clone,
52 D: ndarray::Dimension,
53{
54 pub fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
56 where
57 D: ndarray::RemoveAxis,
58 T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
59 {
60 self.array_view.softmax(axis)
61 }
62}
63
64#[derive(Debug)]
65pub(crate) struct OrtOwnedTensorExtractor<'m, D>
66where
67 D: ndarray::Dimension,
68{
69 pub(crate) tensor_ptr: *mut sys::OrtValue,
70 memory_info: &'m MemoryInfo,
71 shape: D,
72}
73
74impl<'m, D> OrtOwnedTensorExtractor<'m, D>
75where
76 D: ndarray::Dimension,
77{
78 pub(crate) fn new(memory_info: &'m MemoryInfo, shape: D) -> OrtOwnedTensorExtractor<'m, D> {
79 OrtOwnedTensorExtractor {
80 tensor_ptr: std::ptr::null_mut(),
81 memory_info,
82 shape,
83 }
84 }
85
86 pub(crate) fn extract<'t, T>(self) -> Result<OrtOwnedTensor<'t, 'm, T, D>>
87 where
88 T: TypeToTensorElementDataType + Debug + Clone,
89 {
90 assert_ne!(self.tensor_ptr, std::ptr::null_mut());
94
95 let mut is_tensor = 0;
96 let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) };
97 status_to_result(status).map_err(OrtError::IsTensor)?;
98 (is_tensor == 1)
99 .then(|| ())
100 .ok_or(OrtError::IsTensorCheck)?;
101
102 let mut output_array_ptr: *mut T = std::ptr::null_mut();
104 let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
105 let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void =
106 output_array_ptr_ptr as *mut *mut std::ffi::c_void;
107 let status = unsafe {
108 g_ort().GetTensorMutableData.unwrap()(self.tensor_ptr, output_array_ptr_ptr_void)
109 };
110 status_to_result(status).map_err(OrtError::IsTensor)?;
111 assert_ne!(output_array_ptr, std::ptr::null_mut());
112
113 let array_view = unsafe { ArrayView::from_shape_ptr(self.shape, output_array_ptr) };
114
115 Ok(OrtOwnedTensor {
116 tensor_ptr: self.tensor_ptr,
117 array_view,
118 memory_info: self.memory_info,
119 })
120 }
121}
122
123impl<'t, 'm, T, D> Drop for OrtOwnedTensor<'t, 'm, T, D>
124where
125 T: TypeToTensorElementDataType + Debug + Clone,
126 D: ndarray::Dimension,
127 'm: 't, {
129 #[tracing::instrument]
130 fn drop(&mut self) {
131 debug!("Dropping OrtOwnedTensor.");
132 unsafe { g_ort().ReleaseValue.unwrap()(self.tensor_ptr) }
133
134 self.tensor_ptr = std::ptr::null_mut();
135 }
136}