1mod construct;
2mod indexer;
3mod iter;
4pub mod display;
5mod shape;
6mod arith;
7mod matmul;
8mod reduce;
9mod broadcast;
10mod convert;
11mod boolean;
12
13pub use construct::ToTensor;
14use std::{borrow::Borrow, hash::Hash, sync::Arc};
15pub use indexer::{Slice, IndexOp};
16use crate::{AutogradInfo, Error, FloatDType, Op, Storage};
17use super::{DType, Dim, DimCoordinates, DimNCoordinates, Layout, NumDType, Shape, StorageArc, StorageIndices, StorageMut, StorageRef, WithDType};
18pub use iter::*;
19pub use indexer::*;
20
/// A tensor handle: a reference-counted pointer to shared tensor state
/// (`TensorImpl`). `Clone` copies the `Arc`, not the underlying data, so
/// handles are cheap to duplicate and clones alias the same storage.
#[derive(Clone)]
pub struct Tensor<T: WithDType>(Arc<TensorImpl<T>>);
23
/// Process-unique identifier for a tensor, minted from a global atomic
/// counter in [`TensorId::new`]. The derived `Hash`/`Eq` delegate to the
/// inner `usize`, which the `Borrow<usize>` impl below relies on.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TensorId(usize);
26
/// Shared innards of a [`Tensor`], held behind an `Arc`.
struct TensorImpl<T: WithDType> {
    // Unique identity; `Tensor` equality and hashing use only this field.
    id: TensorId,
    // `None` marks a "meta" tensor with no backing data (see `is_meta`).
    storage: Option<StorageArc<T>>,
    // Describes how storage is viewed: shape, start offset, contiguity
    // (see the `Layout` accessors used throughout this file).
    layout: Layout,
    // Autograd bookkeeping; the concrete type is chosen by the dtype `T`.
    meta: T::AutogradMeta,
}
33
34impl TensorId {
35 pub fn new() -> Self {
36 use std::sync::atomic;
37 static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
38 Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
39 }
40
41 pub fn value(&self) -> usize {
42 self.0
43 }
44}
45
/// Lets a `TensorId` be looked up by its raw `usize` in keyed collections
/// (e.g. `HashMap<TensorId, _>::get(&raw)`). The `Borrow` contract — that
/// `Hash`/`Eq` agree between `TensorId` and `usize` — holds because the
/// derives on `TensorId` delegate to the single inner `usize` field.
impl Borrow<usize> for TensorId {
    fn borrow(&self) -> &usize {
        &self.0
    }
}
51
52impl<T: WithDType> Hash for Tensor<T> {
53 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
54 self.0.id.0.hash(state);
55 }
56}
57
58impl<T: WithDType> PartialEq for Tensor<T> {
59 fn eq(&self, other: &Self) -> bool {
60 self.0.id.0 == other.0.id.0
61 }
62}
63
/// `Eq` is sound here: id-based equality is reflexive, symmetric, and transitive.
impl<T: WithDType> Eq for Tensor<T> {}
65
66impl<T: WithDType> Tensor<T> {
67 pub fn is_scalar(&self) -> bool {
68 self.shape().is_scalar()
69 }
70
71 pub fn check_scalar(&self) -> crate::Result<()> {
72 if !self.is_scalar() {
73 Err(Error::NotScalar)?
74 } else {
75 Ok(())
76 }
77 }
78
79 pub fn to_scalar(&self) -> crate::Result<T> {
80 self.check_scalar()?;
81 let v = self.storage_read()?.get_unchecked(self.layout().start_offset());
82 Ok(v)
83 }
84
85 pub fn set_scalar(&self, val: T) -> crate::Result<()> {
86 self.check_scalar()?;
87 self.storage_write()?.set_unchecked(self.layout().start_offset(), val);
88 Ok(())
89 }
90
91 pub fn storage_ref<'a>(&'a self, start_offset: usize) -> crate::Result<StorageRef<'a, T>> {
92 self.0.storage.as_ref()
93 .ok_or(crate::Error::MetaTensor)
94 .map(|s| s.get_ref(start_offset))
95 }
96
97 pub fn storage_mut<'a>(&'a self, start_offset: usize) -> crate::Result<StorageMut<'a, T>> {
98 self.0.storage.as_ref()
99 .ok_or(crate::Error::MetaTensor)
100 .map(|s| s.get_mut(start_offset))
101 }
102
103 pub fn storage_ptr(&self, start_offset: usize) -> crate::Result<*mut T> {
104 self.0.storage.as_ref()
105 .ok_or(crate::Error::MetaTensor)
106 .map(|s| s.get_ptr(start_offset))
107 }
108
109 pub fn is_meta(&self) -> bool {
110 self.0.storage.is_none()
111 }
112}
113
114impl<T: WithDType> Tensor<T> {
115 pub fn id(&self) -> TensorId {
116 self.0.id
117 }
118
119 pub fn shape(&self) -> &Shape {
120 self.0.layout.shape()
121 }
122
123 pub fn dtype(&self) -> DType {
124 T::DTYPE
125 }
126
127 pub fn layout(&self) -> &Layout {
128 &self.0.layout
129 }
130
131 pub fn dims(&self) -> &[usize] {
132 self.shape().dims()
133 }
134
135 pub fn dim<D: Dim>(&self, dim: D) -> crate::Result<usize> {
136 let dim = dim.to_index(self.shape(), "dim")?;
137 Ok(self.dims()[dim])
138 }
139
140 pub fn storage_read(&self) -> crate::Result<std::sync::RwLockReadGuard<'_, Storage<T>>> {
141 self.0.storage.as_ref()
142 .ok_or(crate::Error::MetaTensor)
143 .map(|s| s.read())
144 }
145
146 pub fn storage_write(&self) -> crate::Result<std::sync::RwLockWriteGuard<'_, Storage<T>>> {
147 self.0.storage.as_ref()
148 .ok_or(crate::Error::MetaTensor)
149 .map(|s| s.write())
150 }
151
152 pub fn element_count(&self) -> usize {
153 self.shape().element_count()
154 }
155
156 pub fn is_contiguous(&self) -> bool {
157 self.layout().is_contiguous()
158 }
159
160 pub fn rank(&self) -> usize {
161 self.shape().rank()
162 }
163
164 pub fn to_vec(&self) -> crate::Result<Vec<T>> {
165 self.iter().map(|i| i.collect())
166 }
167
168 pub fn storage_indices(&self) -> StorageIndices {
177 self.layout().storage_indices()
178 }
179
180 pub fn dim_coordinates(&self) -> DimCoordinates {
189 self.shape().dim_coordinates()
190 }
191
192 pub fn dims_coordinates<const N: usize>(&self) -> crate::Result<DimNCoordinates<N>> {
193 self.shape().dims_coordinates::<N>()
194 }
195
196 pub fn dim2_coordinates(&self) -> crate::Result<DimNCoordinates<2>> {
197 self.shape().dim2_coordinates()
198 }
199
200 pub fn dim3_coordinates(&self) -> crate::Result<DimNCoordinates<3>> {
201 self.shape().dim3_coordinates()
202 }
203
204 pub fn dim4_coordinates(&self) -> crate::Result<DimNCoordinates<4>> {
205 self.shape().dim4_coordinates()
206 }
207
208 pub fn dim5_coordinates(&self) -> crate::Result<DimNCoordinates<5>> {
209 self.shape().dim5_coordinates()
210 }
211}
212
213impl<T: NumDType> Tensor<T> {
214 pub fn allclose(&self, other: &Self, rtol: f64, atol: f64) -> crate::Result<bool> {
215 if self.shape() != other.shape() {
216 return Ok(false);
217 }
218 Ok(
219 self.iter()?.zip(other.iter()?).all(|(a, b)| a.close(b, rtol, atol))
220 )
221 }
222}
223
224impl<T: FloatDType> Tensor<T> {
225 pub fn detach(&self) -> Self {
226 if !self.requires_grad() {
227 self.clone()
228 } else {
229 Self(Arc::new(TensorImpl {
230 id: TensorId::new(),
231 storage: self.0.storage.clone(),
232 layout: self.layout().clone(),
233 meta: AutogradInfo::val(),
234 }))
235 }
236 }
237
238 #[inline]
239 pub fn requires_grad(&self) -> bool {
240 self.0.meta.requires_grad()
241 }
242
243 #[inline]
244 pub fn set_requires_grad(&self, mode: bool) {
245 self.0.meta.set_requires_grad(mode);
246 }
247
248 #[inline]
249 pub fn op(&self) -> Option<&Op<T>> {
250 self.0.meta.op()
251 }
252
253 #[inline]
254 pub fn is_leaf(&self) -> bool {
255 self.0.meta.is_leaf()
256 }
257}