1use std::{ffi, fmt::Debug, ops::Deref};
4
5use ndarray::Array;
6use tracing::{debug, error};
7
8use onnxruntime_sys_ng as sys;
9
10use crate::{
11 error::{assert_not_null_pointer, call_ort, status_to_result},
12 g_ort,
13 memory::MemoryInfo,
14 tensor::ndarray_tensor::NdArrayTensor,
15 OrtError, Result, TensorElementDataType, TypeToTensorElementDataType,
16};
17
18#[derive(Debug)]
26pub struct OrtTensor<'t, T, D>
27where
28 T: TypeToTensorElementDataType + Debug + Clone,
29 D: ndarray::Dimension,
30{
31 pub(crate) c_ptr: *mut sys::OrtValue,
32 array: Array<T, D>,
33 memory_info: &'t MemoryInfo,
34}
35
36impl<'t, T, D> OrtTensor<'t, T, D>
37where
38 T: TypeToTensorElementDataType + Debug + Clone,
39 D: ndarray::Dimension,
40{
41 pub(crate) fn from_array<'m>(
42 memory_info: &'m MemoryInfo,
43 allocator_ptr: *mut sys::OrtAllocator,
44 mut array: Array<T, D>,
45 ) -> Result<OrtTensor<'t, T, D>>
46 where
47 'm: 't, {
49 let mut tensor_ptr: *mut sys::OrtValue = std::ptr::null_mut();
51 let tensor_ptr_ptr: *mut *mut sys::OrtValue = &mut tensor_ptr;
52
53 let shape: Vec<i64> = array.shape().iter().map(|d: &usize| *d as i64).collect();
54 let shape_ptr: *const i64 = shape.as_ptr();
55 let shape_len = array.shape().len();
56
57 match T::tensor_element_data_type() {
58 TensorElementDataType::Float
59 | TensorElementDataType::Uint8
60 | TensorElementDataType::Int8
61 | TensorElementDataType::Uint16
62 | TensorElementDataType::Int16
63 | TensorElementDataType::Int32
64 | TensorElementDataType::Int64
65 | TensorElementDataType::Double
66 | TensorElementDataType::Uint32
67 | TensorElementDataType::Uint64 => {
68 let tensor_values_ptr: *mut std::ffi::c_void =
71 array.as_mut_ptr() as *mut std::ffi::c_void;
72 assert_not_null_pointer(tensor_values_ptr, "TensorValues")?;
73
74 unsafe {
75 call_ort(|ort| {
76 ort.CreateTensorWithDataAsOrtValue.unwrap()(
77 memory_info.ptr,
78 tensor_values_ptr,
79 array.len() * std::mem::size_of::<T>(),
80 shape_ptr,
81 shape_len,
82 T::tensor_element_data_type().into(),
83 tensor_ptr_ptr,
84 )
85 })
86 }
87 .map_err(OrtError::CreateTensorWithData)?;
88 assert_not_null_pointer(tensor_ptr, "Tensor")?;
89
90 let mut is_tensor = 0;
91 let status = unsafe { g_ort().IsTensor.unwrap()(tensor_ptr, &mut is_tensor) };
92 status_to_result(status).map_err(OrtError::IsTensor)?;
93 }
94 TensorElementDataType::String => {
95 unsafe {
97 call_ort(|ort| {
98 ort.CreateTensorAsOrtValue.unwrap()(
99 allocator_ptr,
100 shape_ptr,
101 shape_len,
102 T::tensor_element_data_type().into(),
103 tensor_ptr_ptr,
104 )
105 })
106 }
107 .map_err(OrtError::CreateTensor)?;
108
109 let null_terminated_copies: Vec<ffi::CString> = array
111 .iter()
112 .map(|elt| {
113 let slice = elt
114 .try_utf8_bytes()
115 .expect("String data type must provide utf8 bytes");
116 ffi::CString::new(slice)
117 })
118 .collect::<std::result::Result<Vec<_>, _>>()
119 .map_err(OrtError::CStringNulError)?;
120
121 let string_pointers = null_terminated_copies
122 .iter()
123 .map(|cstring| cstring.as_ptr())
124 .collect::<Vec<_>>();
125
126 unsafe {
127 call_ort(|ort| {
128 ort.FillStringTensor.unwrap()(
129 tensor_ptr,
130 string_pointers.as_ptr(),
131 string_pointers.len(),
132 )
133 })
134 }
135 .map_err(OrtError::FillStringTensor)?;
136 }
137 }
138
139 assert_not_null_pointer(tensor_ptr, "Tensor")?;
140
141 Ok(OrtTensor {
142 c_ptr: tensor_ptr,
143 array,
144 memory_info,
145 })
146 }
147}
148
149impl<'t, T, D> Deref for OrtTensor<'t, T, D>
150where
151 T: TypeToTensorElementDataType + Debug + Clone,
152 D: ndarray::Dimension,
153{
154 type Target = Array<T, D>;
155
156 fn deref(&self) -> &Self::Target {
157 &self.array
158 }
159}
160
161impl<'t, T, D> Drop for OrtTensor<'t, T, D>
162where
163 T: TypeToTensorElementDataType + Debug + Clone,
164 D: ndarray::Dimension,
165{
166 #[tracing::instrument]
167 fn drop(&mut self) {
168 debug!("Dropping Tensor.");
170 if self.c_ptr.is_null() {
171 error!("Null pointer, not calling free.");
172 } else {
173 unsafe { g_ort().ReleaseValue.unwrap()(self.c_ptr) }
174 }
175
176 self.c_ptr = std::ptr::null_mut();
177 }
178}
179
180impl<'t, T, D> OrtTensor<'t, T, D>
181where
182 T: TypeToTensorElementDataType + Debug + Clone,
183 D: ndarray::Dimension,
184{
185 pub fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
187 where
188 D: ndarray::RemoveAxis,
189 T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
190 {
191 self.array.softmax(axis)
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AllocatorType, MemType};
    use ndarray::{arr0, arr1, arr2, arr3};
    use std::ptr;
    use test_env_log::test;

    /// Memory info shared by every test: arena allocator, default memory type.
    fn arena_memory_info() -> MemoryInfo {
        MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap()
    }

    /// ONNX Runtime's default allocator, required to create string tensors.
    fn ort_default_allocator() -> *mut sys::OrtAllocator {
        let mut allocator: *mut sys::OrtAllocator = ptr::null_mut();
        unsafe { call_ort(|ort| ort.GetAllocatorWithDefaultOptions.unwrap()(&mut allocator)) }
            .unwrap();
        allocator
    }

    #[test]
    fn orttensor_from_array_0d_i32() {
        let memory_info = arena_memory_info();
        let tensor =
            OrtTensor::from_array(&memory_info, ptr::null_mut(), arr0::<i32>(123)).unwrap();
        let expected_shape: &[usize] = &[];
        assert_eq!(tensor.shape(), expected_shape);
    }

    #[test]
    fn orttensor_from_array_1d_i32() {
        let memory_info = arena_memory_info();
        let values = arr1(&[1_i32, 2, 3, 4, 5, 6]);
        let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), values).unwrap();
        let expected_shape: &[usize] = &[6];
        assert_eq!(tensor.shape(), expected_shape);
    }

    #[test]
    fn orttensor_from_array_2d_i32() {
        let memory_info = arena_memory_info();
        let values = arr2(&[[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]);
        let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), values).unwrap();
        assert_eq!(tensor.shape(), &[2, 6]);
    }

    #[test]
    fn orttensor_from_array_3d_i32() {
        let memory_info = arena_memory_info();
        let values = arr3(&[
            [[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]],
            [[13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24]],
            [[25, 26, 27, 28, 29, 30], [31, 32, 33, 34, 35, 36]],
        ]);
        let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), values).unwrap();
        assert_eq!(tensor.shape(), &[3, 2, 6]);
    }

    #[test]
    fn orttensor_from_array_1d_string() {
        let memory_info = arena_memory_info();
        let values = arr1(&[
            String::from("foo"),
            String::from("bar"),
            String::from("baz"),
        ]);
        let tensor =
            OrtTensor::from_array(&memory_info, ort_default_allocator(), values).unwrap();
        assert_eq!(tensor.shape(), &[3]);
    }

    #[test]
    fn orttensor_from_array_3d_str() {
        let memory_info = arena_memory_info();
        let values = arr3(&[
            [["1", "2", "3"], ["4", "5", "6"]],
            [["7", "8", "9"], ["10", "11", "12"]],
        ]);
        let tensor =
            OrtTensor::from_array(&memory_info, ort_default_allocator(), values).unwrap();
        assert_eq!(tensor.shape(), &[2, 2, 3]);
    }
}