onnxruntime_ng/tensor/
ort_owned_tensor.rs

1//! Module containing tensor with memory owned by the ONNX Runtime
2
3use std::{fmt::Debug, ops::Deref};
4
5use ndarray::{Array, ArrayView};
6use tracing::debug;
7
8use onnxruntime_sys_ng as sys;
9
10use crate::{
11    error::status_to_result, g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor,
12    OrtError, Result, TypeToTensorElementDataType,
13};
14
15/// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference.
16///
17/// This tensor type is returned by the [`Session::run()`](../session/struct.Session.html#method.run) method.
18/// It is not meant to be created directly.
19///
20/// The tensor hosts an [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)
21/// of the data on the C side. This allows manipulation on the Rust side using `ndarray` without copying the data.
22///
23/// `OrtOwnedTensor` implements the [`std::deref::Deref`](#impl-Deref) trait for ergonomic access to
24/// the underlying [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
25#[derive(Debug)]
26pub struct OrtOwnedTensor<'t, 'm, T, D>
27where
28    T: TypeToTensorElementDataType + Debug + Clone,
29    D: ndarray::Dimension,
30    'm: 't, // 'm outlives 't
31{
32    pub(crate) tensor_ptr: *mut sys::OrtValue,
33    array_view: ArrayView<'t, T, D>,
34    memory_info: &'m MemoryInfo,
35}
36
37impl<'t, 'm, T, D> Deref for OrtOwnedTensor<'t, 'm, T, D>
38where
39    T: TypeToTensorElementDataType + Debug + Clone,
40    D: ndarray::Dimension,
41{
42    type Target = ArrayView<'t, T, D>;
43
44    fn deref(&self) -> &Self::Target {
45        &self.array_view
46    }
47}
48
49impl<'t, 'm, T, D> OrtOwnedTensor<'t, 'm, T, D>
50where
51    T: TypeToTensorElementDataType + Debug + Clone,
52    D: ndarray::Dimension,
53{
54    /// Apply a softmax on the specified axis
55    pub fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
56    where
57        D: ndarray::RemoveAxis,
58        T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
59    {
60        self.array_view.softmax(axis)
61    }
62}
63
64#[derive(Debug)]
65pub(crate) struct OrtOwnedTensorExtractor<'m, D>
66where
67    D: ndarray::Dimension,
68{
69    pub(crate) tensor_ptr: *mut sys::OrtValue,
70    memory_info: &'m MemoryInfo,
71    shape: D,
72}
73
74impl<'m, D> OrtOwnedTensorExtractor<'m, D>
75where
76    D: ndarray::Dimension,
77{
78    pub(crate) fn new(memory_info: &'m MemoryInfo, shape: D) -> OrtOwnedTensorExtractor<'m, D> {
79        OrtOwnedTensorExtractor {
80            tensor_ptr: std::ptr::null_mut(),
81            memory_info,
82            shape,
83        }
84    }
85
86    pub(crate) fn extract<'t, T>(self) -> Result<OrtOwnedTensor<'t, 'm, T, D>>
87    where
88        T: TypeToTensorElementDataType + Debug + Clone,
89    {
90        // Note: Both tensor and array will point to the same data, nothing is copied.
91        // As such, there is no need too free the pointer used to create the ArrayView.
92
93        assert_ne!(self.tensor_ptr, std::ptr::null_mut());
94
95        let mut is_tensor = 0;
96        let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) };
97        status_to_result(status).map_err(OrtError::IsTensor)?;
98        (is_tensor == 1)
99            .then(|| ())
100            .ok_or(OrtError::IsTensorCheck)?;
101
102        // Get pointer to output tensor float values
103        let mut output_array_ptr: *mut T = std::ptr::null_mut();
104        let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
105        let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void =
106            output_array_ptr_ptr as *mut *mut std::ffi::c_void;
107        let status = unsafe {
108            g_ort().GetTensorMutableData.unwrap()(self.tensor_ptr, output_array_ptr_ptr_void)
109        };
110        status_to_result(status).map_err(OrtError::IsTensor)?;
111        assert_ne!(output_array_ptr, std::ptr::null_mut());
112
113        let array_view = unsafe { ArrayView::from_shape_ptr(self.shape, output_array_ptr) };
114
115        Ok(OrtOwnedTensor {
116            tensor_ptr: self.tensor_ptr,
117            array_view,
118            memory_info: self.memory_info,
119        })
120    }
121}
122
123impl<'t, 'm, T, D> Drop for OrtOwnedTensor<'t, 'm, T, D>
124where
125    T: TypeToTensorElementDataType + Debug + Clone,
126    D: ndarray::Dimension,
127    'm: 't, // 'm outlives 't
128{
129    #[tracing::instrument]
130    fn drop(&mut self) {
131        debug!("Dropping OrtOwnedTensor.");
132        unsafe { g_ort().ReleaseValue.unwrap()(self.tensor_ptr) }
133
134        self.tensor_ptr = std::ptr::null_mut();
135    }
136}