1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
use std::{fmt::Debug, ops::Deref, ptr};
use ndarray::ArrayView;
use super::{TensorData, TensorDataToType};
use crate::{ortsys, sys};
/// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference.
///
/// This tensor type is returned by the [`Session::run()`](../session/struct.Session.html#method.run) method.
/// It is not meant to be created directly.
///
/// For numeric tensors, the tensor hosts an
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html) of the data on
/// the C side. This allows manipulation on the Rust side using `ndarray` without copying the data.
/// String tensors are the exception: their contents are held in an `ndarray` array stored inside the
/// tensor rather than viewed in place.
///
/// `OrtOwnedTensor` does not implement [`std::ops::Deref`] itself. Call [`view()`](Self::view) to
/// obtain a [`ViewHolder`], which implements [`std::ops::Deref`] with the underlying
/// [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html) as its target.
#[derive(Debug)]
pub struct OrtOwnedTensor<'t, T, D>
where
    T: TensorDataToType,
    D: ndarray::Dimension,
{
    /// Backing storage: either a view over the C-side tensor pointer or, for strings, an owned array.
    pub(crate) data: TensorData<'t, T, D>,
}
impl<'t, T, D> OrtOwnedTensor<'t, T, D>
where
T: TensorDataToType,
D: ndarray::Dimension + 't
{
/// Produce a [`ViewHolder`] for the underlying data.
pub fn view<'s>(&'s self) -> ViewHolder<'s, T, D>
where
't: 's // tensor ptr can outlive the TensorData
{
ViewHolder::new(&self.data)
}
}
/// An intermediate step on the way to an [`ArrayView`].
///
/// Dereferences to the held [`ArrayView`], so `ndarray` methods can be called on it directly.
// Since Deref has to produce a reference, and the referent can't be a local in deref(), it must
// be a field in a struct. This struct exists only to hold that field.
// Its lifetime 's is bound to the TensorData its view was created around, not the underlying tensor
// pointer, since in the case of strings the data is the Array in the TensorData, not the pointer.
// Debug is derived (only available when T and D are Debug) so holders can be logged like the
// `OrtOwnedTensor` they come from.
#[derive(Debug)]
pub struct ViewHolder<'s, T, D>
where
    T: TensorDataToType,
    D: ndarray::Dimension,
{
    array_view: ndarray::ArrayView<'s, T, D>,
}
impl<'s, T, D> ViewHolder<'s, T, D>
where
    T: TensorDataToType,
    D: ndarray::Dimension,
{
    /// Builds a holder whose view borrows from `data` for the lifetime `'s`.
    ///
    /// For tensor-pointer data, the existing view is re-viewed. For string data, a view is taken
    /// over the array stored inside `data`.
    fn new<'t>(data: &'s TensorData<'t, T, D>) -> ViewHolder<'s, T, D>
    where
        't: 's, // the underlying tensor pointer lives at least as long as the TensorData borrow
    {
        // Re-viewing an existing view is cheap. The string view has to be taken here rather than
        // at the call site: a struct cannot hold a field that borrows from another of its own
        // fields, which is why this separate holder struct exists.
        let view = match data {
            TensorData::TensorPtr { array_view: existing, .. } => existing.view(),
            TensorData::Strings { strings } => strings.view(),
        };
        Self { array_view: view }
    }
}
impl<'t, T, D> Deref for ViewHolder<'t, T, D>
where
T: TensorDataToType,
D: ndarray::Dimension
{
type Target = ArrayView<'t, T, D>;
fn deref(&self) -> &Self::Target {
&self.array_view
}
}
/// Keeps a raw tensor pointer alive until the holder is dropped, then releases it.
///
/// Holding the pointer separately makes it possible to create an [`OrtOwnedTensor`] from a
/// [`DynOrtTensor`] without consuming `self`. Consuming it would prevent retrying the extraction
/// and would interact awkwardly with the outputs `Vec`. It also means `OrtOwnedTensor` does not
/// need to keep a reference to the `DynOrtTensor`, which would be inconvenient.
#[derive(Debug)]
pub struct TensorPointerHolder {
    /// The `OrtValue` this holder releases (via `ReleaseValue`) in its `Drop` impl.
    pub(crate) tensor_ptr: *mut sys::OrtValue,
}
impl Drop for TensorPointerHolder {
    /// Returns the held `OrtValue` to ONNX Runtime via `ReleaseValue`, leaving a null pointer
    /// in the field.
    #[tracing::instrument]
    fn drop(&mut self) {
        // Take the pointer out and leave null behind, so the holder never retains a pointer
        // to a released value.
        let raw = std::mem::replace(&mut self.tensor_ptr, ptr::null_mut());
        // SAFETY (assumed): this holder is the sole owner of `raw`; constructors must guarantee
        // that. It is released exactly once here because the field has already been nulled.
        ortsys![unsafe ReleaseValue(raw)];
    }
}