use std::{fmt::Debug, ops::Deref, ptr};
use ndarray::ArrayView;
use super::{TensorData, TensorDataToType};
use crate::{ortsys, sys};
/// A tensor whose contents are stored as a [`TensorData`] borrowing for lifetime `'t`.
///
/// Read the contents through [`OrtOwnedTensor::view`], which hands back a [`ViewHolder`]
/// that dereferences to an `ndarray` view.
#[derive(Debug)]
pub struct OrtOwnedTensor<'t, T: TensorDataToType, D: ndarray::Dimension> {
    /// Backing storage: either the `TensorPtr` or the `Strings` variant of [`TensorData`].
    pub(crate) data: TensorData<'t, T, D>,
}
impl<'t, T, D> OrtOwnedTensor<'t, T, D>
where
T: TensorDataToType,
D: ndarray::Dimension + 't
{
pub fn view<'s>(&'s self) -> ViewHolder<'s, T, D>
where
't: 's {
ViewHolder::new(&self.data)
}
}
/// A borrowed, read-only view over an [`OrtOwnedTensor`]'s data.
///
/// Dereferences to an [`ArrayView`], so the usual `ndarray` read APIs are available on it.
pub struct ViewHolder<'s, T: TensorDataToType, D: ndarray::Dimension> {
    /// The view handed out through `Deref`; borrows the tensor data for `'s`.
    array_view: ndarray::ArrayView<'s, T, D>,
}
impl<'s, T, D> ViewHolder<'s, T, D>
where
    T: TensorDataToType,
    D: ndarray::Dimension,
{
    /// Builds a holder viewing `data` for the borrow lifetime `'s`.
    ///
    /// Both storage variants produce an `ndarray` view: the `TensorPtr` variant re-borrows its
    /// stored view, and the `Strings` variant takes a view of its `strings` array.
    fn new<'t>(data: &'s TensorData<'t, T, D>) -> ViewHolder<'s, T, D>
    where
        't: 's,
    {
        let view = match data {
            TensorData::TensorPtr { array_view: stored, .. } => stored.view(),
            TensorData::Strings { strings } => strings.view(),
        };
        ViewHolder { array_view: view }
    }
}
/// Lets a [`ViewHolder`] be used anywhere an `ArrayView` is expected.
impl<'s, T, D> Deref for ViewHolder<'s, T, D>
where
    T: TensorDataToType,
    D: ndarray::Dimension,
{
    type Target = ArrayView<'s, T, D>;

    /// Returns the wrapped view by reference.
    fn deref(&self) -> &Self::Target {
        let view = &self.array_view;
        view
    }
}
/// Holds a raw `OrtValue` pointer; its `Drop` impl passes the pointer to `ReleaseValue`.
#[derive(Debug)]
pub struct TensorPointerHolder {
    /// Raw handle released on drop. NOTE(review): assumed to be exclusively owned by this holder.
    pub(crate) tensor_ptr: *mut sys::OrtValue,
}
impl Drop for TensorPointerHolder {
    /// Releases the held `OrtValue` and leaves `tensor_ptr` null.
    #[tracing::instrument]
    fn drop(&mut self) {
        // Take the handle out and null the field in one step, so the struct never keeps a
        // pointer to a value that is about to be released.
        let owned = std::mem::replace(&mut self.tensor_ptr, ptr::null_mut());
        // SAFETY (review): relies on `owned` being a pointer this holder owns and that has not
        // been released elsewhere; this is the only release call visible for it.
        ortsys![unsafe ReleaseValue(owned)];
    }
}