onnxruntime_ng/tensor/
ort_tensor.rs

1//! Module containing tensor with memory owned by Rust
2
3use std::{ffi, fmt::Debug, ops::Deref};
4
5use ndarray::Array;
6use tracing::{debug, error};
7
8use onnxruntime_sys_ng as sys;
9
10use crate::{
11    error::{assert_not_null_pointer, call_ort, status_to_result},
12    g_ort,
13    memory::MemoryInfo,
14    tensor::ndarray_tensor::NdArrayTensor,
15    OrtError, Result, TensorElementDataType, TypeToTensorElementDataType,
16};
17
18/// Owned tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
19///
20/// This tensor bounds the ONNX Runtime to `ndarray`; it is used to copy an
21/// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) to the runtime's memory.
22///
23/// **NOTE**: The type is not meant to be used directly, use an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
24/// instead.
25#[derive(Debug)]
26pub struct OrtTensor<'t, T, D>
27where
28    T: TypeToTensorElementDataType + Debug + Clone,
29    D: ndarray::Dimension,
30{
31    pub(crate) c_ptr: *mut sys::OrtValue,
32    array: Array<T, D>,
33    memory_info: &'t MemoryInfo,
34}
35
36impl<'t, T, D> OrtTensor<'t, T, D>
37where
38    T: TypeToTensorElementDataType + Debug + Clone,
39    D: ndarray::Dimension,
40{
41    pub(crate) fn from_array<'m>(
42        memory_info: &'m MemoryInfo,
43        allocator_ptr: *mut sys::OrtAllocator,
44        mut array: Array<T, D>,
45    ) -> Result<OrtTensor<'t, T, D>>
46    where
47        'm: 't, // 'm outlives 't
48    {
49        // where onnxruntime will write the tensor data to
50        let mut tensor_ptr: *mut sys::OrtValue = std::ptr::null_mut();
51        let tensor_ptr_ptr: *mut *mut sys::OrtValue = &mut tensor_ptr;
52
53        let shape: Vec<i64> = array.shape().iter().map(|d: &usize| *d as i64).collect();
54        let shape_ptr: *const i64 = shape.as_ptr();
55        let shape_len = array.shape().len();
56
57        match T::tensor_element_data_type() {
58            TensorElementDataType::Float
59            | TensorElementDataType::Uint8
60            | TensorElementDataType::Int8
61            | TensorElementDataType::Uint16
62            | TensorElementDataType::Int16
63            | TensorElementDataType::Int32
64            | TensorElementDataType::Int64
65            | TensorElementDataType::Double
66            | TensorElementDataType::Uint32
67            | TensorElementDataType::Uint64 => {
68                // primitive data is already suitably laid out in memory; provide it to
69                // onnxruntime as is
70                let tensor_values_ptr: *mut std::ffi::c_void =
71                    array.as_mut_ptr() as *mut std::ffi::c_void;
72                assert_not_null_pointer(tensor_values_ptr, "TensorValues")?;
73
74                unsafe {
75                    call_ort(|ort| {
76                        ort.CreateTensorWithDataAsOrtValue.unwrap()(
77                            memory_info.ptr,
78                            tensor_values_ptr,
79                            array.len() * std::mem::size_of::<T>(),
80                            shape_ptr,
81                            shape_len,
82                            T::tensor_element_data_type().into(),
83                            tensor_ptr_ptr,
84                        )
85                    })
86                }
87                .map_err(OrtError::CreateTensorWithData)?;
88                assert_not_null_pointer(tensor_ptr, "Tensor")?;
89
90                let mut is_tensor = 0;
91                let status = unsafe { g_ort().IsTensor.unwrap()(tensor_ptr, &mut is_tensor) };
92                status_to_result(status).map_err(OrtError::IsTensor)?;
93            }
94            TensorElementDataType::String => {
95                // create tensor without data -- data is filled in later
96                unsafe {
97                    call_ort(|ort| {
98                        ort.CreateTensorAsOrtValue.unwrap()(
99                            allocator_ptr,
100                            shape_ptr,
101                            shape_len,
102                            T::tensor_element_data_type().into(),
103                            tensor_ptr_ptr,
104                        )
105                    })
106                }
107                .map_err(OrtError::CreateTensor)?;
108
109                // create null-terminated copies of each string, as per `FillStringTensor` docs
110                let null_terminated_copies: Vec<ffi::CString> = array
111                    .iter()
112                    .map(|elt| {
113                        let slice = elt
114                            .try_utf8_bytes()
115                            .expect("String data type must provide utf8 bytes");
116                        ffi::CString::new(slice)
117                    })
118                    .collect::<std::result::Result<Vec<_>, _>>()
119                    .map_err(OrtError::CStringNulError)?;
120
121                let string_pointers = null_terminated_copies
122                    .iter()
123                    .map(|cstring| cstring.as_ptr())
124                    .collect::<Vec<_>>();
125
126                unsafe {
127                    call_ort(|ort| {
128                        ort.FillStringTensor.unwrap()(
129                            tensor_ptr,
130                            string_pointers.as_ptr(),
131                            string_pointers.len(),
132                        )
133                    })
134                }
135                .map_err(OrtError::FillStringTensor)?;
136            }
137        }
138
139        assert_not_null_pointer(tensor_ptr, "Tensor")?;
140
141        Ok(OrtTensor {
142            c_ptr: tensor_ptr,
143            array,
144            memory_info,
145        })
146    }
147}
148
149impl<'t, T, D> Deref for OrtTensor<'t, T, D>
150where
151    T: TypeToTensorElementDataType + Debug + Clone,
152    D: ndarray::Dimension,
153{
154    type Target = Array<T, D>;
155
156    fn deref(&self) -> &Self::Target {
157        &self.array
158    }
159}
160
161impl<'t, T, D> Drop for OrtTensor<'t, T, D>
162where
163    T: TypeToTensorElementDataType + Debug + Clone,
164    D: ndarray::Dimension,
165{
166    #[tracing::instrument]
167    fn drop(&mut self) {
168        // We need to let the C part free
169        debug!("Dropping Tensor.");
170        if self.c_ptr.is_null() {
171            error!("Null pointer, not calling free.");
172        } else {
173            unsafe { g_ort().ReleaseValue.unwrap()(self.c_ptr) }
174        }
175
176        self.c_ptr = std::ptr::null_mut();
177    }
178}
179
180impl<'t, T, D> OrtTensor<'t, T, D>
181where
182    T: TypeToTensorElementDataType + Debug + Clone,
183    D: ndarray::Dimension,
184{
185    /// Apply a softmax on the specified axis
186    pub fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
187    where
188        D: ndarray::RemoveAxis,
189        T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
190    {
191        self.array.softmax(axis)
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AllocatorType, MemType};
    use ndarray::{arr0, arr1, arr2, arr3};
    use std::ptr;
    use test_env_log::test;

    /// Arena-allocator memory info (default memory type) shared by the tests.
    fn arena_memory_info() -> MemoryInfo {
        MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap()
    }

    /// ONNX Runtime's default allocator, needed to build string tensors.
    fn ort_default_allocator() -> *mut sys::OrtAllocator {
        let mut allocator: *mut sys::OrtAllocator = ptr::null_mut();
        // This default non-arena allocator doesn't need to be deallocated.
        unsafe { call_ort(|ort| ort.GetAllocatorWithDefaultOptions.unwrap()(&mut allocator)) }
            .unwrap();
        allocator
    }

    #[test]
    fn orttensor_from_array_0d_i32() {
        let info = arena_memory_info();
        let tensor = OrtTensor::from_array(&info, ptr::null_mut(), arr0::<i32>(123)).unwrap();
        let scalar_shape: &[usize] = &[];
        assert_eq!(tensor.shape(), scalar_shape);
    }

    #[test]
    fn orttensor_from_array_1d_i32() {
        let info = arena_memory_info();
        let input = arr1(&[1_i32, 2, 3, 4, 5, 6]);
        let tensor = OrtTensor::from_array(&info, ptr::null_mut(), input).unwrap();
        let wanted: &[usize] = &[6];
        assert_eq!(tensor.shape(), wanted);
    }

    #[test]
    fn orttensor_from_array_2d_i32() {
        let info = arena_memory_info();
        let input = arr2(&[[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]);
        let tensor = OrtTensor::from_array(&info, ptr::null_mut(), input).unwrap();
        assert_eq!(tensor.shape(), &[2, 6]);
    }

    #[test]
    fn orttensor_from_array_3d_i32() {
        let info = arena_memory_info();
        let input = arr3(&[
            [[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]],
            [[13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24]],
            [[25, 26, 27, 28, 29, 30], [31, 32, 33, 34, 35, 36]],
        ]);
        let tensor = OrtTensor::from_array(&info, ptr::null_mut(), input).unwrap();
        assert_eq!(tensor.shape(), &[3, 2, 6]);
    }

    #[test]
    fn orttensor_from_array_1d_string() {
        let info = arena_memory_info();
        let input = arr1(&[
            String::from("foo"),
            String::from("bar"),
            String::from("baz"),
        ]);
        let tensor = OrtTensor::from_array(&info, ort_default_allocator(), input).unwrap();
        assert_eq!(tensor.shape(), &[3]);
    }

    #[test]
    fn orttensor_from_array_3d_str() {
        let info = arena_memory_info();
        let input = arr3(&[
            [["1", "2", "3"], ["4", "5", "6"]],
            [["7", "8", "9"], ["10", "11", "12"]],
        ]);
        let tensor = OrtTensor::from_array(&info, ort_default_allocator(), input).unwrap();
        assert_eq!(tensor.shape(), &[2, 2, 3]);
    }
}