acme_tensor/
tensor.rs

1/*
2    Appellation: tensor <mod>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::actions::iter::{Iter, IterMut};
6use crate::error::{TensorError, TensorResult};
7use crate::ops::{BackpropOp, TensorExpr};
8use crate::prelude::{TensorId, TensorKind};
9use crate::shape::{IntoShape, IntoStride, Layout, Rank, Shape, Stride};
10
11#[cfg(not(feature = "std"))]
12use alloc::vec::{self, Vec};
13use core::iter::Map;
14use core::ops::{Index, IndexMut};
15use core::slice::Iter as SliceIter;
16#[cfg(feature = "std")]
17use std::vec;
18
19pub(crate) fn create<T>(
20    kind: impl Into<TensorKind>,
21    op: impl Into<BackpropOp<T>>,
22    shape: impl IntoShape,
23    data: Vec<T>,
24) -> TensorBase<T> {
25    TensorBase {
26        id: TensorId::new(),
27        data,
28        kind: kind.into(),
29        layout: Layout::contiguous(shape),
30        op: op.into(),
31    }
32}
33#[allow(dead_code)]
34pub(crate) fn from_scalar_with_op<T>(
35    kind: impl Into<TensorKind>,
36    op: TensorExpr<T>,
37    data: T,
38) -> TensorBase<T> {
39    create(
40        kind.into(),
41        BackpropOp::new(op),
42        Shape::scalar(),
43        vec![data],
44    )
45}
46
47pub(crate) fn from_vec_with_kind<T>(
48    kind: impl Into<TensorKind>,
49    shape: impl IntoShape,
50    data: Vec<T>,
51) -> TensorBase<T> {
52    create(kind, BackpropOp::none(), shape, data)
53}
54
55pub(crate) fn from_vec_with_op<T>(
56    kind: impl Into<TensorKind>,
57    op: TensorExpr<T>,
58    shape: impl IntoShape,
59    data: Vec<T>,
60) -> TensorBase<T> {
61    create(kind.into(), BackpropOp::new(op), shape, data)
62}
63
64#[derive(Clone, Debug, Hash)]
65#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
66pub struct TensorBase<T = f64> {
67    pub(crate) id: TensorId,
68    pub(crate) data: Vec<T>,
69    pub(crate) kind: TensorKind,
70    pub(crate) layout: Layout,
71    pub(crate) op: BackpropOp<T>,
72}
73
74impl<T> TensorBase<T> {
75    /// Create a new tensor from an iterator.
76    pub fn from_iter<I>(iter: I) -> Self
77    where
78        I: IntoIterator<Item = T>,
79    {
80        Self::from_vec(Vec::from_iter(iter))
81    }
82    pub unsafe fn from_raw_parts(
83        ptr: *mut T,
84        shape: impl IntoShape,
85        stride: impl IntoStride,
86    ) -> Self {
87        let shape = shape.into_shape();
88        let stride = stride.into_stride();
89
90        let data = Vec::from_raw_parts(ptr, shape.size(), shape.size());
91        Self {
92            id: TensorId::new(),
93            kind: TensorKind::default(),
94            layout: Layout::new(0, shape, stride),
95            data,
96            op: BackpropOp::none(),
97        }
98    }
99    /// Create a new tensor from a scalar value.
100    pub fn from_scalar(value: T) -> Self {
101        Self {
102            id: TensorId::new(),
103            data: vec![value],
104            kind: TensorKind::default(),
105            layout: Layout::contiguous(()),
106            op: None.into(),
107        }
108    }
109    /// Create a new tensor from an iterator, with a particular shape.
110    pub fn from_shape_iter<I>(shape: impl IntoShape, iter: I) -> Self
111    where
112        I: IntoIterator<Item = T>,
113    {
114        Self::from_shape_vec(shape, Vec::from_iter(iter))
115    }
116    pub unsafe fn from_shape_ptr(shape: impl IntoShape, ptr: *mut T) -> Self {
117        let layout = Layout::contiguous(shape);
118        let data = Vec::from_raw_parts(ptr, layout.size(), layout.size());
119        Self {
120            id: TensorId::new(),
121            kind: TensorKind::default(),
122            layout: layout.clone(),
123            data,
124            op: BackpropOp::none(),
125        }
126    }
127    /// Create a new tensor from a [Vec], with a specified [shape](Shape).
128    pub fn from_shape_vec(shape: impl IntoShape, data: Vec<T>) -> Self {
129        Self {
130            id: TensorId::new(),
131            data,
132            kind: TensorKind::default(),
133            layout: Layout::contiguous(shape),
134            op: BackpropOp::none(),
135        }
136    }
137    /// Create a new, one-dimensional tensor from a [Vec].
138    pub fn from_vec(data: Vec<T>) -> Self {
139        let shape = Shape::from(data.len());
140        Self {
141            id: TensorId::new(),
142            data,
143            kind: TensorKind::default(),
144            layout: Layout::contiguous(shape),
145            op: BackpropOp::none(),
146        }
147    }
148    /// Return a mutable pointer to the tensor's data.
149    pub fn as_mut_ptr(&mut self) -> *mut T {
150        self.data_mut().as_mut_ptr()
151    }
152    /// Return a pointer to the tensor's data.
153    pub fn as_ptr(&self) -> *const T {
154        self.data().as_ptr()
155    }
156    /// Return a reference to the tensor's data.
157    pub fn as_slice(&self) -> &[T] {
158        &self.data
159    }
160    /// Return a mutable reference to the tensor's data.
161    pub fn as_mut_slice(&mut self) -> &mut [T] {
162        &mut self.data
163    }
164    /// Assign the values of another tensor to this tensor.
165    pub fn assign(&mut self, other: &Self)
166    where
167        T: Clone,
168    {
169        self.data_mut()
170            .iter_mut()
171            .zip(other.data())
172            .for_each(|(a, b)| *a = b.clone());
173    }
174
175    pub fn boxed(self) -> Box<Self> {
176        Box::new(self)
177    }
178    /// Detach the computational graph from the tensor
179    pub fn detach(&self) -> Self
180    where
181        T: Clone,
182    {
183        if self.op.is_none() && !self.is_variable() {
184            self.clone()
185        } else {
186            Self {
187                id: self.id,
188                kind: self.kind,
189                layout: self.layout.clone(),
190                op: BackpropOp::none(),
191                data: self.data.clone(),
192            }
193        }
194    }
195    /// Returns a reference to the first element of the tensor.
196    pub fn first(&self) -> Option<&T> {
197        let pos = vec![0; *self.rank()];
198        self.get(pos)
199    }
200    /// Returns a mutable reference to the first element of the tensor.
201    pub fn first_mut(&mut self) -> Option<&mut T> {
202        let pos = vec![0; *self.rank()];
203        self.get_mut(pos)
204    }
205    /// Returns the data at the specified index.
206    pub fn get(&self, index: impl AsRef<[usize]>) -> Option<&T> {
207        let i = self.layout.index(index);
208        self.data().get(i)
209    }
210    /// Returns a mutable reference to the data at the specified index.
211    pub fn get_mut(&mut self, index: impl AsRef<[usize]>) -> Option<&mut T> {
212        let i = self.layout.index(index);
213        self.data_mut().get_mut(i)
214    }
215    /// Returns the unique identifier of the tensor.
216    pub const fn id(&self) -> TensorId {
217        self.id
218    }
219
220    pub unsafe fn into_scalar(self) -> T
221    where
222        T: Clone,
223    {
224        debug_assert!(self.is_scalar(), "Tensor is not scalar");
225        self.data.first().unwrap().clone()
226    }
227    /// Returns true if the tensor is contiguous.
228    pub fn is_contiguous(&self) -> bool {
229        self.layout().is_contiguous()
230    }
231    /// Returns true if the tensor is empty.
232    pub fn is_empty(&self) -> bool {
233        self.data().is_empty()
234    }
235    /// A function to check if the tensor is a scalar
236    pub fn is_scalar(&self) -> bool {
237        *self.rank() == 0
238    }
239    /// Returns true if the tensor is a square matrix.
240    pub fn is_square(&self) -> bool {
241        self.shape().is_square()
242    }
243    /// A function to check if the tensor is a variable
244    pub const fn is_variable(&self) -> bool {
245        self.kind().is_variable()
246    }
247    /// Creates an immutable iterator over the elements in the tensor.
248    pub fn iter(&self) -> Iter<'_, T> {
249        Iter::new(self.view())
250    }
251    /// Create a mutable iterator over the elements in the tensor.
252    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
253        IterMut::new(self)
254    }
255    /// Get the kind of the tensor
256    pub const fn kind(&self) -> TensorKind {
257        self.kind
258    }
259    /// Get a reference to the last element of the tensor
260    pub fn last(&self) -> Option<&T> {
261        let pos = self.shape().get_final_position();
262        self.get(pos)
263    }
264    /// Get a mutable reference to the last element of the tensor
265    pub fn last_mut(&mut self) -> Option<&mut T> {
266        let pos = self.shape().get_final_position();
267        self.get_mut(pos)
268    }
269    /// Get a reference to the [Layout] of the tensor
270    pub const fn layout(&self) -> &Layout {
271        &self.layout
272    }
273    /// Get the number of columns in the tensor
274    pub fn ncols(&self) -> usize {
275        self.shape().ncols()
276    }
277    /// Get the number of rows in the tensor
278    pub fn nrows(&self) -> usize {
279        self.shape().nrows()
280    }
281    /// Get a reference to the operation of the tensor
282    pub const fn op(&self) -> &BackpropOp<T> {
283        &self.op
284    }
285    /// Get a reference to the operation of the tensor
286    pub fn op_view(&self) -> BackpropOp<&T> {
287        self.op().view()
288    }
289    /// Get an owned reference to the [Rank] of the tensor
290    pub fn rank(&self) -> Rank {
291        self.shape().rank()
292    }
293    /// Set the value of the tensor at the specified index
294    pub fn set(&mut self, index: impl AsRef<[usize]>, value: T) {
295        let i = self.layout().index(index);
296        self.data_mut()[i] = value;
297    }
298    /// An owned reference of the tensors [Shape]
299    pub fn shape(&self) -> &Shape {
300        self.layout().shape()
301    }
302    /// Returns the number of elements in the tensor.
303    pub fn size(&self) -> usize {
304        self.layout().size()
305    }
306    /// Get a reference to the stride of the tensor
307    pub fn strides(&self) -> &Stride {
308        self.layout().strides()
309    }
310    /// Turn the tensor into a scalar
311    /// If the tensor has a rank greater than 0, this will return an error
312    pub fn to_scalar(&self) -> TensorResult<&T> {
313        if !self.is_scalar() {
314            return Err(TensorError::NotScalar);
315        }
316        Ok(self.first().unwrap())
317    }
318    /// Turn the tensor into a one-dimensional vector
319    pub fn to_vec(&self) -> Vec<T>
320    where
321        T: Clone,
322    {
323        self.data().to_vec()
324    }
325    /// Changes the kind of tensor to a variable
326    pub fn variable(mut self) -> Self {
327        self.kind = TensorKind::Variable;
328        self
329    }
330    /// Set the layout of the tensor
331    pub fn with_layout(self, layout: Layout) -> Self {
332        if layout.size() != self.size() {
333            panic!("Size mismatch");
334        }
335        unsafe { self.with_layout_unchecked(layout) }
336    }
337    /// Set the layout of the tensor without checking for compatibility
338    ///
339    /// # Safety
340    ///
341    /// This function is unsafe because it does not check if the layout is compatible with the tensor.
342    pub unsafe fn with_layout_unchecked(mut self, layout: Layout) -> Self {
343        self.layout = layout;
344        self
345    }
346
347    pub fn with_op(mut self, op: BackpropOp<T>) -> Self {
348        self.op = op;
349        self
350    }
351
352    pub fn with_shape_c(mut self, shape: impl IntoShape) -> Self {
353        self.layout = self.layout.with_shape_c(shape);
354        self
355    }
356}
357
358impl<'a, T> TensorBase<&'a T> {
359    // pub fn as_tensor(&self) -> TensorBase<T> where T: Copy {
360    //     let store = self.data.iter().copied().collect();
361    //     TensorBase {
362    //         id: self.id,
363    //         kind: self.kind,
364    //         layout: self.layout.clone(),
365    //         op: self.op.clone(),
366    //         data: store,
367    //     }
368    // }
369}
370
371impl<T> TensorBase<T> {
372    pub fn view_from_scalar(scalar: &T) -> TensorBase<&T> {
373        TensorBase {
374            id: TensorId::new(),
375            kind: TensorKind::default(),
376            layout: Layout::scalar(),
377            op: BackpropOp::none(),
378            data: vec![scalar],
379        }
380    }
381    pub fn to_owned(&self) -> TensorBase<T>
382    where
383        T: Clone,
384    {
385        self.clone()
386    }
387
388    pub fn view(&self) -> TensorBase<&T> {
389        TensorBase {
390            id: self.id(),
391            kind: self.kind(),
392            layout: self.layout().clone(),
393            op: self.op().view(),
394            data: self.data().iter().collect(),
395        }
396    }
397
398    pub fn view_mut(&mut self) -> TensorBase<&mut T> {
399        TensorBase {
400            id: self.id(),
401            kind: self.kind(),
402            layout: self.layout().clone(),
403            op: self.op.view_mut(),
404            data: self.data.iter_mut().collect(),
405        }
406    }
407}
408// Inernal Methods
409#[allow(dead_code)]
410impl<T> TensorBase<T> {
411    pub(crate) fn data(&self) -> &Vec<T> {
412        &self.data
413    }
414
415    pub(crate) fn data_mut(&mut self) -> &mut Vec<T> {
416        &mut self.data
417    }
418
419    pub(crate) fn get_by_index(&self, index: usize) -> Option<&T> {
420        self.data.get(index)
421    }
422
423    pub(crate) fn get_mut_by_index(&mut self, index: usize) -> Option<&mut T> {
424        self.data.get_mut(index)
425    }
426
427    pub(crate) fn map<'a, F>(&'a self, f: F) -> Map<SliceIter<'a, T>, F>
428    where
429        F: FnMut(&'a T) -> T,
430        T: 'a + Clone,
431    {
432        self.data.iter().map(f)
433    }
434
435    pub(crate) fn mapv<F>(&self, f: F) -> TensorBase<T>
436    where
437        F: Fn(T) -> T,
438        T: Copy,
439    {
440        let store = self.data.iter().copied().map(f).collect();
441        TensorBase {
442            id: TensorId::new(),
443            kind: self.kind,
444            layout: self.layout.clone(),
445            op: self.op.clone(),
446            data: store,
447        }
448    }
449
450    pub(crate) fn map_binary<F>(&self, other: &TensorBase<T>, op: F) -> TensorBase<T>
451    where
452        F: acme::prelude::BinOp<T, T, Output = T>,
453        T: Copy,
454    {
455        let store = self
456            .iter()
457            .zip(other.iter())
458            .map(|(a, b)| op.eval(*a, *b))
459            .collect();
460        TensorBase {
461            id: TensorId::new(),
462            kind: self.kind,
463            layout: self.layout.clone(),
464            op: self.op.clone(),
465            data: store,
466        }
467    }
468}
469
// Reinterpret a borrowed view (`TensorBase<&T>`) as if it were an owned `TensorBase<T>`.
//
// FIXME(soundness): this cast is unsound in general and needs a redesign (probably removal
// of this impl). In `TensorBase<&T>`:
// - `data` is a `Vec<&T>`: its buffer holds pointers, not `T` values, so elements read
//   through the returned `&TensorBase<T>` are pointer bits interpreted as `T`;
// - `op` is a `BackpropOp<&T>`, which need not have the same layout as `BackpropOp<T>`;
// - `TensorBase` has no `#[repr(C)]`, so field order/offsets are not guaranteed to match
//   between the two instantiations.
// No sound body exists for this signature, because there is no owned `TensorBase<T>` to
// borrow from.
impl<'a, T> AsRef<TensorBase<T>> for TensorBase<&'a T> {
    /// Returns `self` pointer-cast to `&TensorBase<T>` (see the soundness note above).
    fn as_ref(&self) -> &TensorBase<T> {
        unsafe { &*(self as *const TensorBase<&'a T> as *const TensorBase<T>) }
    }
}
475
476impl<Idx, T> Index<Idx> for TensorBase<T>
477where
478    Idx: AsRef<[usize]>,
479{
480    type Output = T;
481
482    fn index(&self, index: Idx) -> &Self::Output {
483        let i = self.layout().index(index);
484        &self.data[i]
485    }
486}
487
488impl<Idx, T> IndexMut<Idx> for TensorBase<T>
489where
490    Idx: AsRef<[usize]>,
491{
492    fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
493        let i = self.layout().index(index);
494        &mut self.data[i]
495    }
496}
497
498impl<T> Eq for TensorBase<T> where T: Eq {}
499
500impl<T> Ord for TensorBase<T>
501where
502    T: Ord,
503{
504    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
505        self.data.cmp(&other.data)
506    }
507}
508
509impl<T> PartialEq for TensorBase<T>
510where
511    T: PartialEq,
512{
513    fn eq(&self, other: &Self) -> bool {
514        self.layout == other.layout && self.data == other.data
515    }
516}
517
518impl<S, T> PartialEq<S> for TensorBase<T>
519where
520    S: AsRef<[T]>,
521    T: PartialEq,
522{
523    fn eq(&self, other: &S) -> bool {
524        &self.data == other.as_ref()
525    }
526}
527
528impl<T> PartialOrd for TensorBase<T>
529where
530    T: PartialOrd,
531{
532    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
533        self.data.partial_cmp(&other.data)
534    }
535}