use std::borrow::Cow;
use derive_more::Display;
use thiserror::Error;
use crate::array::ArrayBytesRaw;
use super::DataType;
/// Error type for tensor-related operations.
///
/// The enum is `#[non_exhaustive]`, so downstream `match`es must include a
/// wildcard arm; new variants may be added without a breaking change.
#[derive(Clone, Debug, Display, Error)]
#[non_exhaustive]
pub enum TensorError {
    /// The requested operation cannot handle the contained [`DataType`].
    #[display("Data type {_0:?} is not supported for this operation.")]
    UnsupportedDataType(DataType),
}
/// An owned tensor: a raw byte buffer paired with its [`DataType`] and shape.
///
/// No consistency check between the byte length, the data type, and the shape
/// is performed when a `Tensor` is constructed; callers are responsible for
/// supplying matching parts.
/// NOTE(review): the element ordering of `bytes` (e.g. row-major) is not
/// established by this type — confirm against the code that produces tensors.
#[derive(Clone, Debug)]
pub struct Tensor {
    /// The raw element bytes. The `'static` lifetime means the buffer is either
    /// owned or a `'static` borrow, so the tensor carries no external lifetime.
    bytes: ArrayBytesRaw<'static>,
    /// The data type describing how `bytes` should be interpreted.
    data_type: DataType,
    /// The extent of each dimension.
    shape: Vec<u64>,
}
impl Tensor {
    /// Builds a tensor from its raw bytes, data type, and shape.
    ///
    /// `bytes` may be any value convertible into [`ArrayBytesRaw<'static>`].
    /// The parts are stored as given; their mutual consistency is not checked.
    #[must_use]
    pub fn new(
        bytes: impl Into<ArrayBytesRaw<'static>>,
        data_type: DataType,
        shape: Vec<u64>,
    ) -> Self {
        let bytes: ArrayBytesRaw<'static> = bytes.into();
        Self {
            bytes,
            data_type,
            shape,
        }
    }

    /// Returns a view of the tensor's raw bytes.
    #[must_use]
    pub fn bytes(&self) -> &[u8] {
        &*self.bytes
    }

    /// Returns the tensor's data type.
    #[must_use]
    pub fn data_type(&self) -> &DataType {
        &self.data_type
    }

    /// Returns the extent of each dimension.
    #[must_use]
    pub fn shape(&self) -> &[u64] {
        self.shape.as_slice()
    }

    /// Consumes the tensor, handing back ownership of its bytes, data type,
    /// and shape.
    #[must_use]
    pub fn into_parts(self) -> (Cow<'static, [u8]>, DataType, Vec<u64>) {
        let Self {
            bytes,
            data_type,
            shape,
        } = self;
        (bytes, data_type, shape)
    }

    /// Borrows the bytes, data type, and shape all at once.
    #[must_use]
    pub fn as_parts(&self) -> (&[u8], &DataType, &[u64]) {
        (self.bytes(), self.data_type(), self.shape())
    }
}