Skip to main content

lumen_core/tensor/
mod.rs

1mod construct;
2mod indexer;
3mod iter;
4pub mod display;
5mod shape;
6mod arith;
7mod matmul;
8mod reduce;
9mod broadcast;
10mod convert;
11mod boolean;
12
13pub use construct::ToTensor;
14use std::{borrow::Borrow, hash::Hash, sync::Arc};
15pub use indexer::{Slice, IndexOp};
16use crate::{AutogradInfo, Error, FloatDType, Op, Storage};
17use super::{DType, Dim, DimCoordinates, DimNCoordinates, Layout, NumDType, Shape, StorageArc, StorageIndices, StorageMut, StorageRef, WithDType};
18pub use iter::*;
19pub use indexer::*;
20
21#[derive(Clone)]
22pub struct Tensor<T: WithDType>(Arc<TensorImpl<T>>);
23
24#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
25pub struct TensorId(usize);
26
27struct TensorImpl<T: WithDType> {
28    id: TensorId,
29    storage: Option<StorageArc<T>>,
30    layout: Layout,
31    meta: T::AutogradMeta,
32}
33
34impl TensorId {
35    pub fn new() -> Self {
36        use std::sync::atomic;
37        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
38        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
39    }
40
41    pub fn value(&self) -> usize {
42        self.0
43    }
44}
45
46impl Borrow<usize> for TensorId {
47    fn borrow(&self) -> &usize {
48        &self.0
49    }
50}
51
52impl<T: WithDType> Hash for Tensor<T> {
53    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
54        self.0.id.0.hash(state);
55    }
56} 
57
58impl<T: WithDType> PartialEq for Tensor<T> {
59    fn eq(&self, other: &Self) -> bool {
60        self.0.id.0 == other.0.id.0
61    }
62}
63
64impl<T: WithDType> Eq for Tensor<T> {}
65
66impl<T: WithDType> Tensor<T> {
67    pub fn is_scalar(&self) -> bool {
68        self.shape().is_scalar()
69    }
70
71    pub fn check_scalar(&self) -> crate::Result<()> {
72        if !self.is_scalar() {
73            Err(Error::NotScalar)?
74        } else {
75            Ok(())
76        }
77    }
78
79    pub fn to_scalar(&self) -> crate::Result<T> {
80        self.check_scalar()?;
81        let v = self.storage_read()?.get_unchecked(self.layout().start_offset());
82        Ok(v)
83    }
84
85    pub fn set_scalar(&self, val: T) -> crate::Result<()> {
86        self.check_scalar()?;
87        self.storage_write()?.set_unchecked(self.layout().start_offset(), val);
88        Ok(())
89    }
90
91    pub fn storage_ref<'a>(&'a self, start_offset: usize) -> crate::Result<StorageRef<'a, T>> {
92        self.0.storage.as_ref()
93            .ok_or(crate::Error::MetaTensor)
94            .map(|s| s.get_ref(start_offset))
95    }
96
97    pub fn storage_mut<'a>(&'a self, start_offset: usize) -> crate::Result<StorageMut<'a, T>> {
98        self.0.storage.as_ref()
99            .ok_or(crate::Error::MetaTensor)
100            .map(|s| s.get_mut(start_offset))
101    }
102
103    pub fn storage_ptr(&self, start_offset: usize) -> crate::Result<*mut T> {
104        self.0.storage.as_ref()
105            .ok_or(crate::Error::MetaTensor)
106            .map(|s| s.get_ptr(start_offset))
107    }
108
109    pub fn is_meta(&self) -> bool {
110        self.0.storage.is_none()
111    }
112}
113
114impl<T: WithDType> Tensor<T> {
115    pub fn id(&self) -> TensorId {
116        self.0.id
117    }
118
119    pub fn shape(&self) -> &Shape {
120        self.0.layout.shape()
121    }
122
123    pub fn dtype(&self) -> DType {
124        T::DTYPE
125    }
126
127    pub fn layout(&self) -> &Layout {
128        &self.0.layout
129    }
130
131    pub fn dims(&self) -> &[usize] {
132        self.shape().dims()
133    }
134
135    pub fn dim<D: Dim>(&self, dim: D) -> crate::Result<usize> {
136        let dim = dim.to_index(self.shape(), "dim")?;
137        Ok(self.dims()[dim])
138    }
139
140    pub fn storage_read(&self) -> crate::Result<std::sync::RwLockReadGuard<'_, Storage<T>>> {
141        self.0.storage.as_ref()
142            .ok_or(crate::Error::MetaTensor)
143            .map(|s| s.read())
144    }
145
146    pub fn storage_write(&self) -> crate::Result<std::sync::RwLockWriteGuard<'_, Storage<T>>> {
147        self.0.storage.as_ref()
148            .ok_or(crate::Error::MetaTensor)
149            .map(|s| s.write())
150    }
151    
152    pub fn element_count(&self) -> usize {
153        self.shape().element_count()
154    }
155
156    pub fn is_contiguous(&self) -> bool {
157        self.layout().is_contiguous()
158    }
159
160    pub fn rank(&self) -> usize {
161        self.shape().rank()
162    }
163
164    pub fn to_vec(&self) -> crate::Result<Vec<T>> {
165        self.iter().map(|i| i.collect())
166    }
167
168    /// Returns an iterator over **storage indices**.
169    ///
170    /// This iterator yields the linear (flat) indices as they are laid out
171    /// in the underlying storage buffer. The order depends on the memory
172    /// layout (e.g., row-major / column-major / with strides).
173    ///
174    /// Example for shape = (2, 2) in row-major layout:
175    /// yields: `0, 1, 2, 3`
176    pub fn storage_indices(&self) -> StorageIndices {
177        self.layout().storage_indices()
178    }
179
180    /// Returns an iterator over **dimension coordinates**.
181    ///
182    /// This iterator yields the multi-dimensional coordinates
183    /// (e.g., `[i, j, k, ...]`) of each element in the array, independent
184    /// of the physical storage layout.
185    ///
186    /// Example for shape = (2, 2):
187    /// yields: `[0, 0], [0, 1], [1, 0], [1, 1]`
188    pub fn dim_coordinates(&self) -> DimCoordinates {
189        self.shape().dim_coordinates()
190    }
191
192    pub fn dims_coordinates<const N: usize>(&self) -> crate::Result<DimNCoordinates<N>> {
193        self.shape().dims_coordinates::<N>()
194    }
195
196    pub fn dim2_coordinates(&self) -> crate::Result<DimNCoordinates<2>> {
197        self.shape().dim2_coordinates()
198    }
199
200    pub fn dim3_coordinates(&self) -> crate::Result<DimNCoordinates<3>> {
201        self.shape().dim3_coordinates()
202    }
203
204    pub fn dim4_coordinates(&self) -> crate::Result<DimNCoordinates<4>> {
205        self.shape().dim4_coordinates()
206    }
207
208    pub fn dim5_coordinates(&self) -> crate::Result<DimNCoordinates<5>> {
209        self.shape().dim5_coordinates()
210    }
211}
212
213impl<T: NumDType> Tensor<T> {
214    pub fn allclose(&self, other: &Self, rtol: f64, atol: f64) -> crate::Result<bool> {
215        if self.shape() != other.shape() {
216            return Ok(false);
217        }
218        Ok(
219            self.iter()?.zip(other.iter()?).all(|(a, b)| a.close(b, rtol, atol))
220        )
221    }
222}
223
224impl<T: FloatDType> Tensor<T> {
225    pub fn detach(&self) -> Self {
226        if !self.requires_grad() {
227            self.clone()
228        } else {
229            Self(Arc::new(TensorImpl { 
230                id: TensorId::new(), 
231                storage: self.0.storage.clone(), 
232                layout: self.layout().clone(), 
233                meta: AutogradInfo::val(), 
234            }))
235        }
236    }
237
238    #[inline]
239    pub fn requires_grad(&self) -> bool {
240        self.0.meta.requires_grad()
241    }
242    
243    #[inline]
244    pub fn set_requires_grad(&self, mode: bool) {
245        self.0.meta.set_requires_grad(mode);
246    }
247
248    #[inline]
249    pub fn op(&self) -> Option<&Op<T>> {
250        self.0.meta.op()
251    }
252
253    #[inline]
254    pub fn is_leaf(&self) -> bool {
255        self.0.meta.is_leaf()
256    }
257}