use crate::actions::iter::Iter;
use crate::error::{TensorError, TensorResult};
use crate::ops::{BackpropOp, TensorExpr};
use crate::prelude::{TensorId, TensorKind};
use crate::shape::{IntoShape, Layout, Rank, Shape, Stride};
#[cfg(not(feature = "std"))]
use alloc::vec::{self, Vec};
use core::iter::Map;
use core::ops::{Index, IndexMut};
use core::slice::Iter as SliceIter;
#[cfg(feature = "std")]
use std::vec;
/// Central constructor: assembles a `TensorBase` from its parts, assigning a
/// fresh `TensorId` and a contiguous layout derived from `shape`.
///
/// NOTE(review): `data.len()` is not validated against the shape here —
/// presumably callers guarantee the sizes agree; confirm.
pub(crate) fn create<T>(
    kind: impl Into<TensorKind>,
    op: impl Into<BackpropOp<T>>,
    shape: impl IntoShape,
    data: Vec<T>,
) -> TensorBase<T> {
    let layout = Layout::contiguous(shape);
    TensorBase {
        id: TensorId::new(),
        data,
        kind: kind.into(),
        layout,
        op: op.into(),
    }
}
/// Builds a rank-0 (scalar) tensor holding `data`, recording `op` as its
/// backprop provenance.
#[allow(dead_code)]
pub(crate) fn from_scalar_with_op<T>(
    kind: impl Into<TensorKind>,
    op: TensorExpr<T>,
    data: T,
) -> TensorBase<T> {
    // `create` already accepts `impl Into<TensorKind>`, so forward `kind`
    // as-is (the previous `kind.into()` was redundant and inconsistent with
    // `from_vec_with_kind`).
    create(kind, BackpropOp::new(op), Shape::scalar(), vec![data])
}
pub(crate) fn from_vec_with_kind<T>(
kind: impl Into<TensorKind>,
shape: impl IntoShape,
data: Vec<T>,
) -> TensorBase<T> {
create(kind, BackpropOp::none(), shape, data)
}
/// Builds a tensor of the given `kind` and `shape` over an owned buffer,
/// recording `op` as its backprop provenance.
pub(crate) fn from_vec_with_op<T>(
    kind: impl Into<TensorKind>,
    op: TensorExpr<T>,
    shape: impl IntoShape,
    data: Vec<T>,
) -> TensorBase<T> {
    // `create` already accepts `impl Into<TensorKind>`, so forward `kind`
    // as-is (the previous `kind.into()` was redundant and inconsistent with
    // `from_vec_with_kind`).
    create(kind, BackpropOp::new(op), shape, data)
}
/// Dense tensor storage: a flat element buffer paired with a [`Layout`]
/// (shape + strides) that maps multi-dimensional indices into it, plus a
/// unique id, a kind flag, and the backprop op that produced it.
///
/// NOTE(review): `Hash` is derived over every field — including the unique
/// `id` — while the manual `PartialEq` below compares only `layout` and
/// `data`. Equal tensors can therefore hash differently, which breaks the
/// `Eq`/`Hash` contract required by `HashMap`/`HashSet`. Consider a manual
/// `Hash` over `layout` + `data` only.
#[derive(Clone, Debug, Hash)]
pub struct TensorBase<T = f64> {
    // Unique identifier assigned at construction time.
    pub(crate) id: TensorId,
    // Flat element buffer in storage order.
    pub(crate) data: Vec<T>,
    // Distinguishes plain tensors from variables (see `is_variable`).
    pub(crate) kind: TensorKind,
    // Shape/stride bookkeeping used to resolve indices into `data`.
    pub(crate) layout: Layout,
    // Backprop provenance; `BackpropOp::none()` when detached.
    pub(crate) op: BackpropOp<T>,
}
impl<T> TensorBase<T> {
    /// Collects an iterator into a rank-1 tensor (length = number of items).
    pub fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = T>,
    {
        Self::from_vec(Vec::from_iter(iter))
    }
    /// Wraps a single value as a rank-0 (scalar) tensor with no backprop op.
    pub fn from_scalar(value: T) -> Self {
        Self {
            id: TensorId::new(),
            data: vec![value],
            kind: TensorKind::default(),
            // `()` as a shape yields the scalar (rank-0) layout.
            layout: Layout::contiguous(()),
            op: None.into(),
        }
    }
    /// Collects an iterator into a tensor with the given shape.
    pub fn from_shape_iter<I>(shape: impl IntoShape, iter: I) -> Self
    where
        I: IntoIterator<Item = T>,
    {
        Self::from_shape_vec(shape, Vec::from_iter(iter))
    }
    /// Builds a tensor from an owned buffer and an explicit shape.
    ///
    /// NOTE(review): `data.len()` is not checked against the shape's element
    /// count here — confirm downstream indexing tolerates a mismatch.
    pub fn from_shape_vec(shape: impl IntoShape, data: Vec<T>) -> Self {
        Self {
            id: TensorId::new(),
            data,
            kind: TensorKind::default(),
            layout: Layout::contiguous(shape),
            op: BackpropOp::none(),
        }
    }
    /// Builds a rank-1 tensor whose shape is the buffer's length.
    pub fn from_vec(data: Vec<T>) -> Self {
        let shape = Shape::from(data.len());
        Self {
            id: TensorId::new(),
            data,
            kind: TensorKind::default(),
            layout: Layout::contiguous(shape),
            op: BackpropOp::none(),
        }
    }
    /// Raw mutable pointer to the first element of the buffer.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.data_mut().as_mut_ptr()
    }
    /// Raw const pointer to the first element of the buffer.
    pub fn as_ptr(&self) -> *const T {
        self.data().as_ptr()
    }
    /// The buffer as an immutable slice, in storage order.
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }
    /// The buffer as a mutable slice, in storage order.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.data
    }
    /// Clones `other`'s elements into `self`, pairwise in storage order.
    /// If the buffers differ in length, `zip` stops at the shorter one and
    /// only the overlapping prefix is written.
    pub fn assign(&mut self, other: &Self)
    where
        T: Clone,
    {
        self.data_mut()
            .iter_mut()
            .zip(other.data())
            .for_each(|(a, b)| *a = b.clone());
    }
    /// Returns a copy with the backprop op cleared, keeping the same `id`.
    /// When there is nothing to strip (no op and not a variable) a plain
    /// clone is returned instead.
    pub fn detach(&self) -> Self
    where
        T: Clone,
    {
        if self.op.is_none() && !self.is_variable() {
            self.clone()
        } else {
            Self {
                id: self.id,
                kind: self.kind,
                layout: self.layout.clone(),
                op: BackpropOp::none(),
                data: self.data.clone(),
            }
        }
    }
    /// Reference to the element at the all-zeros index (the logical first
    /// element), or `None` if that position cannot be resolved.
    pub fn first(&self) -> Option<&T> {
        let pos = vec![0; *self.rank()];
        self.get(pos)
    }
    /// Mutable variant of `first`.
    pub fn first_mut(&mut self) -> Option<&mut T> {
        let pos = vec![0; *self.rank()];
        self.get_mut(pos)
    }
    /// Looks up an element by multi-dimensional index; the layout resolves
    /// the index to a flat offset into the buffer.
    pub fn get(&self, index: impl AsRef<[usize]>) -> Option<&T> {
        let i = self.layout.index(index);
        self.data().get(i)
    }
    /// Mutable variant of `get`.
    pub fn get_mut(&mut self, index: impl AsRef<[usize]>) -> Option<&mut T> {
        let i = self.layout.index(index);
        self.data_mut().get_mut(i)
    }
    /// The tensor's unique identifier.
    pub const fn id(&self) -> TensorId {
        self.id
    }
    /// Whether the layout describes a contiguous buffer.
    pub fn is_contiguous(&self) -> bool {
        self.layout().is_contiguous()
    }
    /// Whether the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.data().is_empty()
    }
    /// Whether the tensor is rank-0 (a scalar).
    pub fn is_scalar(&self) -> bool {
        *self.rank() == 0
    }
    /// Whether the shape is square (delegates to `Shape::is_square`).
    pub fn is_square(&self) -> bool {
        self.shape().is_square()
    }
    /// Whether this tensor is flagged as a trainable variable.
    pub const fn is_variable(&self) -> bool {
        self.kind().is_variable()
    }
    /// Iterator over the tensor's elements (see `actions::iter::Iter` for
    /// traversal-order semantics).
    pub fn iter(&self) -> Iter<'_, T> {
        Iter::new(self)
    }
    /// The tensor's kind flag.
    pub const fn kind(&self) -> TensorKind {
        self.kind
    }
    /// Reference to the element at the shape's final position, or `None` if
    /// it cannot be resolved.
    pub fn last(&self) -> Option<&T> {
        let pos = self.shape().get_final_position();
        self.get(pos)
    }
    /// Mutable variant of `last`.
    pub fn last_mut(&mut self) -> Option<&mut T> {
        let pos = self.shape().get_final_position();
        self.get_mut(pos)
    }
    /// The layout (shape + strides) describing how indices map to the buffer.
    pub const fn layout(&self) -> &Layout {
        &self.layout
    }
    /// Number of columns, as reported by the shape.
    pub fn ncols(&self) -> usize {
        self.shape().ncols()
    }
    /// Number of rows, as reported by the shape.
    pub fn nrows(&self) -> usize {
        self.shape().nrows()
    }
    /// The recorded backprop op.
    pub const fn op(&self) -> &BackpropOp<T> {
        &self.op
    }
    /// A by-reference view of the backprop op.
    pub fn op_view(&self) -> BackpropOp<&T> {
        self.op().view()
    }
    /// The tensor's rank (number of dimensions).
    pub fn rank(&self) -> Rank {
        self.shape().rank()
    }
    /// Writes `value` at the multi-dimensional index.
    ///
    /// # Panics
    /// Panics if the resolved flat offset is out of bounds for the buffer.
    pub fn set(&mut self, index: impl AsRef<[usize]>, value: T) {
        let i = self.layout().index(index);
        self.data_mut()[i] = value;
    }
    /// The tensor's shape.
    pub fn shape(&self) -> &Shape {
        self.layout().shape()
    }
    /// Total number of elements, as reported by the layout.
    pub fn size(&self) -> usize {
        self.layout().size()
    }
    /// The per-dimension strides.
    pub fn strides(&self) -> &Stride {
        self.layout().strides()
    }
    /// Borrows the single element of a rank-0 tensor.
    ///
    /// # Errors
    /// Returns `TensorError::NotScalar` when the tensor has rank > 0.
    pub fn to_scalar(&self) -> TensorResult<&T> {
        if !self.is_scalar() {
            return Err(TensorError::NotScalar);
        }
        Ok(self.first().unwrap())
    }
    /// Clones the buffer into a plain `Vec`, in storage order.
    pub fn to_vec(&self) -> Vec<T>
    where
        T: Clone,
    {
        self.data().to_vec()
    }
    /// Marks the tensor as a trainable variable (consumes and returns self).
    pub fn variable(mut self) -> Self {
        self.kind = TensorKind::Variable;
        self
    }
    /// Replaces the layout, validating its element count first.
    ///
    /// # Panics
    /// Panics with "Size mismatch" if `layout.size()` differs from the
    /// current element count.
    pub fn with_layout(self, layout: Layout) -> Self {
        if layout.size() != self.size() {
            panic!("Size mismatch");
        }
        unsafe { self.with_layout_unchecked(layout) }
    }
    /// Replaces the layout without validation.
    ///
    /// # Safety
    /// The caller must ensure `layout.size()` matches the buffer's element
    /// count; otherwise later index resolution may read out of bounds.
    pub unsafe fn with_layout_unchecked(mut self, layout: Layout) -> Self {
        self.layout = layout;
        self
    }
    /// Replaces the backprop op wholesale (consumes and returns self).
    pub fn with_op(mut self, op: BackpropOp<T>) -> Self {
        self.op = op;
        self
    }
    /// Reshapes via `Layout::with_shape_c` without validating the new size.
    ///
    /// # Safety
    /// The caller must ensure the new shape's element count matches the
    /// buffer length (presumably `with_shape_c` builds C-order strides —
    /// confirm against `Layout`'s implementation).
    pub unsafe fn with_shape_unchecked(mut self, shape: impl IntoShape) -> Self {
        self.layout = self.layout.with_shape_c(shape);
        self
    }
}
// Placeholder for methods specific to reference-element tensors
// (`TensorBase<&T>`); currently empty.
impl<'a, T> TensorBase<&'a T> {
}
impl<T> TensorBase<T> {
    /// Deep copy: clones the buffer, layout, and op (delegates to `Clone`).
    pub fn to_owned(&self) -> TensorBase<T>
    where
        T: Clone,
    {
        self.clone()
    }
    /// Borrowing view: a tensor with the same id, kind, and layout whose
    /// elements are references into `self` instead of owned values.
    pub fn view<'a>(&'a self) -> TensorBase<&'a T> {
        let data = self.data().iter().collect();
        TensorBase {
            id: self.id(),
            kind: self.kind(),
            layout: self.layout().clone(),
            op: self.op().view(),
            data,
        }
    }
}
#[allow(dead_code)]
impl<T> TensorBase<T> {
    /// Immutable access to the backing buffer.
    pub(crate) fn data(&self) -> &Vec<T> {
        &self.data
    }
    /// Mutable access to the backing buffer.
    pub(crate) fn data_mut(&mut self) -> &mut Vec<T> {
        &mut self.data
    }
    /// Fetches an element by its flat (already linearized) position.
    pub(crate) fn get_by_index(&self, index: usize) -> Option<&T> {
        self.data.get(index)
    }
    /// Mutable variant of `get_by_index`.
    pub(crate) fn get_by_index_mut(&mut self, index: usize) -> Option<&mut T> {
        self.data.get_mut(index)
    }
    /// Lazily maps `f` over borrowed elements in storage order.
    // The previous `T: Clone` bound was unused (the body only borrows), so it
    // has been dropped; loosening a bound is backward-compatible for callers.
    pub(crate) fn map<'a, F>(&'a self, f: F) -> Map<SliceIter<'a, T>, F>
    where
        F: FnMut(&'a T) -> T,
        T: 'a,
    {
        self.data.iter().map(f)
    }
    /// Eagerly maps `f` over copied elements, producing a new tensor with a
    /// fresh id but the same kind, layout, and op.
    pub(crate) fn mapv<F>(&self, f: F) -> TensorBase<T>
    where
        F: Fn(T) -> T,
        T: Copy,
    {
        let store = self.data.iter().copied().map(f).collect();
        TensorBase {
            id: TensorId::new(),
            kind: self.kind,
            layout: self.layout.clone(),
            op: self.op.clone(),
            data: store,
        }
    }
    /// Element-wise combination of `self` and `other` via `op`, producing a
    /// new tensor with a fresh id but `self`'s kind, layout, and op.
    ///
    /// NOTE(review): `zip` truncates at the shorter iterator, so a length
    /// mismatch silently yields a buffer smaller than `self.layout` claims —
    /// confirm callers always pass same-shaped tensors.
    pub(crate) fn map_binary<F>(&self, other: &TensorBase<T>, op: F) -> TensorBase<T>
    where
        F: acme::prelude::BinOp<T, T, Output = T>,
        T: Copy,
    {
        let store = self
            .iter()
            .zip(other.iter())
            .map(|(a, b)| op.eval(*a, *b))
            .collect();
        TensorBase {
            id: TensorId::new(),
            kind: self.kind,
            layout: self.layout.clone(),
            op: self.op.clone(),
            data: store,
        }
    }
}
impl<'a, T> AsRef<TensorBase<T>> for TensorBase<&'a T> {
    fn as_ref(&self) -> &TensorBase<T> {
        // SAFETY / NOTE(review): this cast reinterprets a `TensorBase<&'a T>`
        // (whose `data` is a `Vec<&'a T>` of references and whose `op` is a
        // `BackpropOp<&'a T>`) as a `TensorBase<T>` holding values. `Vec<&T>`
        // and `Vec<T>` are not layout-compatible in general, so any use of
        // the returned reference that touches `data` or `op` is undefined
        // behavior. This impl looks unsound and cannot be fixed without
        // changing its signature — confirm no caller relies on it, and
        // consider removing it in favor of an explicit conversion.
        unsafe { &*(self as *const TensorBase<&'a T> as *const TensorBase<T>) }
    }
}
impl<Idx, T> Index<Idx> for TensorBase<T>
where
    Idx: AsRef<[usize]>,
{
    type Output = T;

    /// Resolves the multi-dimensional index through the layout and returns
    /// the element at the resulting flat offset.
    /// Panics if the offset is out of bounds.
    fn index(&self, index: Idx) -> &Self::Output {
        &self.data[self.layout().index(index)]
    }
}
impl<Idx, T> IndexMut<Idx> for TensorBase<T>
where
    Idx: AsRef<[usize]>,
{
    /// Mutable counterpart of `Index`. The flat offset is computed first so
    /// the immutable borrow of the layout ends before `data` is borrowed
    /// mutably. Panics if the offset is out of bounds.
    fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
        let offset = self.layout().index(index);
        &mut self.data[offset]
    }
}
impl<T> Eq for TensorBase<T> where T: Eq {}
impl<T> Ord for TensorBase<T>
where
    T: Ord,
{
    /// Lexicographic comparison of the raw buffers, in storage order.
    ///
    /// NOTE(review): `layout` is ignored here, but the manual `PartialEq`
    /// compares layout *and* data. Tensors with equal buffers and different
    /// shapes compare `Equal` yet are not `==`, violating the `Ord` contract
    /// (`cmp(a, b) == Equal` must coincide with `a == b`). Confirm whether
    /// layout should participate in ordering.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.data.cmp(&other.data)
    }
}
impl<T> PartialEq for TensorBase<T>
where
    T: PartialEq,
{
    /// Two tensors are equal when both their layout (shape/strides) and raw
    /// buffers match; `id`, `kind`, and `op` are not compared.
    ///
    /// NOTE(review): the derived `Hash` on `TensorBase` hashes every field,
    /// including the unique `id`, so tensors that are `==` here can hash
    /// differently — this breaks `HashMap`/`HashSet` usage. Consider a
    /// manual `Hash` over `layout` + `data` only.
    fn eq(&self, other: &Self) -> bool {
        self.layout == other.layout && self.data == other.data
    }
}
impl<S, T> PartialEq<S> for TensorBase<T>
where
    S: AsRef<[T]>,
    T: PartialEq,
{
    /// Compares the tensor's raw buffer against any slice-like value,
    /// ignoring the layout entirely.
    fn eq(&self, other: &S) -> bool {
        self.data.as_slice() == other.as_ref()
    }
}
impl<T> PartialOrd for TensorBase<T>
where
    T: PartialOrd,
{
    /// Lexicographic comparison of the raw buffers; `layout` is ignored.
    ///
    /// NOTE(review): inconsistent with `PartialEq`, which also compares
    /// layout — `partial_cmp` can report `Equal` for tensors that are not
    /// `==`, violating the `PartialOrd`/`PartialEq` consistency requirement.
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.data.partial_cmp(&other.data)
    }
}