use crate::{
write_tensor_recursive, AnyTensor, DataType, Result, Shape, Tensor, TensorInner, TensorType,
};
use core::fmt;
use fmt::{Debug, Formatter};
use libc::c_int;
use std::{fmt::Display, ops::Deref};
use tensorflow_sys as tf;
/// A tensor whose element storage is exposed only through shared (`&[T]`) access.
///
/// Elements are reachable through the `Deref<Target = [T]>` impl. The shape is kept
/// alongside as a list of per-dimension sizes.
#[derive(Clone, Eq)]
pub struct ReadonlyTensor<T: TensorType> {
    // Element storage; dereferences to the flat `[T]` buffer.
    pub(super) inner: T::InnerType,
    // Size of each dimension, outermost first (`get_index` treats the last entry as the
    // fastest-varying one).
    pub(super) dims: Vec<u64>,
}
impl<T: TensorType> AnyTensor for ReadonlyTensor<T> {
    /// Reports the TensorFlow data type associated with the element type `T`.
    fn data_type(&self) -> DataType {
        <T as TensorType>::data_type()
    }

    /// Obtains a raw `TF_Tensor` pointer for this tensor's storage and shape, delegating
    /// to the storage's `as_mut_ptr`.
    fn inner(&self) -> Result<*mut tf::TF_Tensor> {
        let shape = &self.dims;
        self.inner.as_mut_ptr(shape)
    }
}
impl<T: TensorType> Deref for ReadonlyTensor<T> {
    type Target = [T];

    /// Exposes the flat element buffer as an immutable slice.
    #[inline]
    fn deref(&self) -> &[T] {
        &*self.inner
    }
}
impl<T: TensorType> Display for ReadonlyTensor<T> {
    /// Renders the tensor's contents via `write_tensor_recursive`.
    ///
    /// If the `TF_RUST_DISPLAY_MAX` environment variable is set and parses as an `i64`,
    /// that value seeds the counter handed to the writer. Otherwise `-1` is used.
    /// NOTE(review): presumably this limits how many elements are printed, with `-1`
    /// meaning "no limit"; the exact meaning lives in `write_tensor_recursive`, so confirm
    /// there.
    fn fmt(&self, f: &mut Formatter<'_>) -> ::std::fmt::Result {
        let mut budget: i64 = std::env::var("TF_RUST_DISPLAY_MAX")
            .ok()
            .and_then(|raw| raw.parse().ok())
            .unwrap_or(-1);
        write_tensor_recursive(f, self, self.dims(), &mut budget)
    }
}
impl<T: TensorType> Debug for ReadonlyTensor<T> {
    /// Debug-formats the tensor through the crate-wide `format_tensor` helper, labelled
    /// `ReadonlyTensor` and passed its shape.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let shape = self.dims();
        crate::format_tensor(self, "ReadonlyTensor", shape, f)
    }
}
impl<T: TensorType + PartialEq> PartialEq for ReadonlyTensor<T> {
    /// Two read-only tensors are equal when their shapes match and their flat element
    /// buffers compare equal element by element.
    fn eq(&self, other: &Self) -> bool {
        // Cheap shape check first; element data is only compared for same-shaped tensors.
        if self.dims != other.dims {
            return false;
        }
        self[..] == other[..]
    }
}
impl<T: TensorType + PartialEq> PartialEq<Tensor<T>> for ReadonlyTensor<T> {
    /// A read-only tensor equals a mutable [`Tensor`] when both have the same shape and
    /// identical element buffers.
    fn eq(&self, other: &Tensor<T>) -> bool {
        if self.dims != other.dims {
            return false;
        }
        // Deref-coerce both sides to plain slices for the element-wise comparison.
        let mine: &[T] = self;
        let theirs: &[T] = other;
        mine == theirs
    }
}
impl<T: TensorType> ReadonlyTensor<T> {
    /// Returns a clone of the element at the given multi-dimensional index.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`ReadonlyTensor::get_index`].
    pub fn get(&self, indices: &[u64]) -> T {
        let index = self.get_index(indices);
        self[index].clone()
    }

    /// Converts a multi-dimensional index into an offset into the flat element buffer.
    ///
    /// The layout is row-major: the last index has stride 1, and each earlier index is
    /// scaled by the product of all later dimension sizes.
    ///
    /// # Panics
    ///
    /// Panics if `indices.len()` differs from the tensor's rank, or if any component is
    /// not strictly less than the size of its dimension.
    pub fn get_index(&self, indices: &[u64]) -> usize {
        assert_eq!(
            self.dims.len(),
            indices.len(),
            "index has {} components but the tensor has rank {}",
            indices.len(),
            self.dims.len()
        );
        let mut index: u64 = 0;
        let mut stride: u64 = 1;
        // Walk from the innermost (last) dimension outward, growing the stride as we go.
        for (&dim, &i) in self.dims.iter().zip(indices).rev() {
            assert!(
                i < dim,
                "index {} is out of range for a dimension of size {}",
                i,
                dim
            );
            index += i * stride;
            stride *= dim;
        }
        // The offset is bounded by the element count of an in-memory buffer, so it fits
        // in `usize`.
        index as usize
    }

    /// Returns the size of each dimension, outermost first.
    pub fn dims(&self) -> &[u64] {
        &self.dims
    }

    /// Returns the tensor's shape as a [`Shape`] built from its dimension sizes.
    pub fn shape(&self) -> Shape {
        Shape::from(&self.dims[..])
    }

    /// Wraps a raw `TF_Tensor`, reading its shape via `TF_NumDims` / `TF_Dim`.
    ///
    /// Returns `None` when `T::InnerType::from_tf_tensor` does not accept the tensor.
    ///
    /// # Safety
    ///
    /// `tensor` must point to a valid, live `TF_Tensor`. NOTE(review): whether ownership of
    /// the tensor is transferred is decided by `T::InnerType::from_tf_tensor`; confirm
    /// there.
    pub(super) unsafe fn from_tf_tensor(tensor: *mut tf::TF_Tensor) -> Option<Self> {
        let num_dims: c_int = tf::TF_NumDims(tensor);
        debug_assert!(num_dims >= 0, "TF_NumDims returned a negative rank");
        // Iterate over the reported rank itself, not over `Vec::capacity()`: capacity is
        // only guaranteed to be *at least* the requested size. Looping to it could pass
        // out-of-range dimension indices to `TF_Dim`.
        let dims: Vec<u64> = (0..num_dims)
            .map(|i| tf::TF_Dim(tensor, i) as u64)
            .collect();
        Some(Self {
            inner: T::InnerType::from_tf_tensor(tensor)?,
            dims,
        })
    }

    /// Converts this read-only tensor into a mutable [`Tensor`]. The new tensor takes
    /// over the same storage and shape.
    ///
    /// # Safety
    ///
    /// NOTE(review): the invariant guarded by this `unsafe` is not stated in this file.
    /// Presumably the underlying storage may be shared with other holders (e.g.
    /// TensorFlow itself). If so, the caller must ensure that mutating it through the
    /// returned `Tensor` cannot be observed elsewhere — confirm.
    pub unsafe fn into_tensor(self) -> Tensor<T> {
        Tensor {
            inner: self.inner,
            dims: self.dims,
        }
    }
}