1use crate::actions::iter::{Iter, IterMut};
6use crate::error::{TensorError, TensorResult};
7use crate::ops::{BackpropOp, TensorExpr};
8use crate::prelude::{TensorId, TensorKind};
9use crate::shape::{IntoShape, IntoStride, Layout, Rank, Shape, Stride};
10
11#[cfg(not(feature = "std"))]
12use alloc::vec::{self, Vec};
13use core::iter::Map;
14use core::ops::{Index, IndexMut};
15use core::slice::Iter as SliceIter;
16#[cfg(feature = "std")]
17use std::vec;
18
19pub(crate) fn create<T>(
20 kind: impl Into<TensorKind>,
21 op: impl Into<BackpropOp<T>>,
22 shape: impl IntoShape,
23 data: Vec<T>,
24) -> TensorBase<T> {
25 TensorBase {
26 id: TensorId::new(),
27 data,
28 kind: kind.into(),
29 layout: Layout::contiguous(shape),
30 op: op.into(),
31 }
32}
33#[allow(dead_code)]
34pub(crate) fn from_scalar_with_op<T>(
35 kind: impl Into<TensorKind>,
36 op: TensorExpr<T>,
37 data: T,
38) -> TensorBase<T> {
39 create(
40 kind.into(),
41 BackpropOp::new(op),
42 Shape::scalar(),
43 vec![data],
44 )
45}
46
47pub(crate) fn from_vec_with_kind<T>(
48 kind: impl Into<TensorKind>,
49 shape: impl IntoShape,
50 data: Vec<T>,
51) -> TensorBase<T> {
52 create(kind, BackpropOp::none(), shape, data)
53}
54
55pub(crate) fn from_vec_with_op<T>(
56 kind: impl Into<TensorKind>,
57 op: TensorExpr<T>,
58 shape: impl IntoShape,
59 data: Vec<T>,
60) -> TensorBase<T> {
61 create(kind.into(), BackpropOp::new(op), shape, data)
62}
63
64#[derive(Clone, Debug, Hash)]
65#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
66pub struct TensorBase<T = f64> {
67 pub(crate) id: TensorId,
68 pub(crate) data: Vec<T>,
69 pub(crate) kind: TensorKind,
70 pub(crate) layout: Layout,
71 pub(crate) op: BackpropOp<T>,
72}
73
74impl<T> TensorBase<T> {
75 pub fn from_iter<I>(iter: I) -> Self
77 where
78 I: IntoIterator<Item = T>,
79 {
80 Self::from_vec(Vec::from_iter(iter))
81 }
82 pub unsafe fn from_raw_parts(
83 ptr: *mut T,
84 shape: impl IntoShape,
85 stride: impl IntoStride,
86 ) -> Self {
87 let shape = shape.into_shape();
88 let stride = stride.into_stride();
89
90 let data = Vec::from_raw_parts(ptr, shape.size(), shape.size());
91 Self {
92 id: TensorId::new(),
93 kind: TensorKind::default(),
94 layout: Layout::new(0, shape, stride),
95 data,
96 op: BackpropOp::none(),
97 }
98 }
99 pub fn from_scalar(value: T) -> Self {
101 Self {
102 id: TensorId::new(),
103 data: vec![value],
104 kind: TensorKind::default(),
105 layout: Layout::contiguous(()),
106 op: None.into(),
107 }
108 }
109 pub fn from_shape_iter<I>(shape: impl IntoShape, iter: I) -> Self
111 where
112 I: IntoIterator<Item = T>,
113 {
114 Self::from_shape_vec(shape, Vec::from_iter(iter))
115 }
116 pub unsafe fn from_shape_ptr(shape: impl IntoShape, ptr: *mut T) -> Self {
117 let layout = Layout::contiguous(shape);
118 let data = Vec::from_raw_parts(ptr, layout.size(), layout.size());
119 Self {
120 id: TensorId::new(),
121 kind: TensorKind::default(),
122 layout: layout.clone(),
123 data,
124 op: BackpropOp::none(),
125 }
126 }
127 pub fn from_shape_vec(shape: impl IntoShape, data: Vec<T>) -> Self {
129 Self {
130 id: TensorId::new(),
131 data,
132 kind: TensorKind::default(),
133 layout: Layout::contiguous(shape),
134 op: BackpropOp::none(),
135 }
136 }
137 pub fn from_vec(data: Vec<T>) -> Self {
139 let shape = Shape::from(data.len());
140 Self {
141 id: TensorId::new(),
142 data,
143 kind: TensorKind::default(),
144 layout: Layout::contiguous(shape),
145 op: BackpropOp::none(),
146 }
147 }
148 pub fn as_mut_ptr(&mut self) -> *mut T {
150 self.data_mut().as_mut_ptr()
151 }
152 pub fn as_ptr(&self) -> *const T {
154 self.data().as_ptr()
155 }
156 pub fn as_slice(&self) -> &[T] {
158 &self.data
159 }
160 pub fn as_mut_slice(&mut self) -> &mut [T] {
162 &mut self.data
163 }
164 pub fn assign(&mut self, other: &Self)
166 where
167 T: Clone,
168 {
169 self.data_mut()
170 .iter_mut()
171 .zip(other.data())
172 .for_each(|(a, b)| *a = b.clone());
173 }
174
175 pub fn boxed(self) -> Box<Self> {
176 Box::new(self)
177 }
178 pub fn detach(&self) -> Self
180 where
181 T: Clone,
182 {
183 if self.op.is_none() && !self.is_variable() {
184 self.clone()
185 } else {
186 Self {
187 id: self.id,
188 kind: self.kind,
189 layout: self.layout.clone(),
190 op: BackpropOp::none(),
191 data: self.data.clone(),
192 }
193 }
194 }
195 pub fn first(&self) -> Option<&T> {
197 let pos = vec![0; *self.rank()];
198 self.get(pos)
199 }
200 pub fn first_mut(&mut self) -> Option<&mut T> {
202 let pos = vec![0; *self.rank()];
203 self.get_mut(pos)
204 }
205 pub fn get(&self, index: impl AsRef<[usize]>) -> Option<&T> {
207 let i = self.layout.index(index);
208 self.data().get(i)
209 }
210 pub fn get_mut(&mut self, index: impl AsRef<[usize]>) -> Option<&mut T> {
212 let i = self.layout.index(index);
213 self.data_mut().get_mut(i)
214 }
215 pub const fn id(&self) -> TensorId {
217 self.id
218 }
219
220 pub unsafe fn into_scalar(self) -> T
221 where
222 T: Clone,
223 {
224 debug_assert!(self.is_scalar(), "Tensor is not scalar");
225 self.data.first().unwrap().clone()
226 }
227 pub fn is_contiguous(&self) -> bool {
229 self.layout().is_contiguous()
230 }
231 pub fn is_empty(&self) -> bool {
233 self.data().is_empty()
234 }
235 pub fn is_scalar(&self) -> bool {
237 *self.rank() == 0
238 }
239 pub fn is_square(&self) -> bool {
241 self.shape().is_square()
242 }
243 pub const fn is_variable(&self) -> bool {
245 self.kind().is_variable()
246 }
247 pub fn iter(&self) -> Iter<'_, T> {
249 Iter::new(self.view())
250 }
251 pub fn iter_mut(&mut self) -> IterMut<'_, T> {
253 IterMut::new(self)
254 }
255 pub const fn kind(&self) -> TensorKind {
257 self.kind
258 }
259 pub fn last(&self) -> Option<&T> {
261 let pos = self.shape().get_final_position();
262 self.get(pos)
263 }
264 pub fn last_mut(&mut self) -> Option<&mut T> {
266 let pos = self.shape().get_final_position();
267 self.get_mut(pos)
268 }
269 pub const fn layout(&self) -> &Layout {
271 &self.layout
272 }
273 pub fn ncols(&self) -> usize {
275 self.shape().ncols()
276 }
277 pub fn nrows(&self) -> usize {
279 self.shape().nrows()
280 }
281 pub const fn op(&self) -> &BackpropOp<T> {
283 &self.op
284 }
285 pub fn op_view(&self) -> BackpropOp<&T> {
287 self.op().view()
288 }
289 pub fn rank(&self) -> Rank {
291 self.shape().rank()
292 }
293 pub fn set(&mut self, index: impl AsRef<[usize]>, value: T) {
295 let i = self.layout().index(index);
296 self.data_mut()[i] = value;
297 }
298 pub fn shape(&self) -> &Shape {
300 self.layout().shape()
301 }
302 pub fn size(&self) -> usize {
304 self.layout().size()
305 }
306 pub fn strides(&self) -> &Stride {
308 self.layout().strides()
309 }
310 pub fn to_scalar(&self) -> TensorResult<&T> {
313 if !self.is_scalar() {
314 return Err(TensorError::NotScalar);
315 }
316 Ok(self.first().unwrap())
317 }
318 pub fn to_vec(&self) -> Vec<T>
320 where
321 T: Clone,
322 {
323 self.data().to_vec()
324 }
325 pub fn variable(mut self) -> Self {
327 self.kind = TensorKind::Variable;
328 self
329 }
330 pub fn with_layout(self, layout: Layout) -> Self {
332 if layout.size() != self.size() {
333 panic!("Size mismatch");
334 }
335 unsafe { self.with_layout_unchecked(layout) }
336 }
337 pub unsafe fn with_layout_unchecked(mut self, layout: Layout) -> Self {
343 self.layout = layout;
344 self
345 }
346
347 pub fn with_op(mut self, op: BackpropOp<T>) -> Self {
348 self.op = op;
349 self
350 }
351
352 pub fn with_shape_c(mut self, shape: impl IntoShape) -> Self {
353 self.layout = self.layout.with_shape_c(shape);
354 self
355 }
356}
357
// Inherent impl for tensors whose elements are shared references; it
// currently defines no methods.
impl<'a, T> TensorBase<&'a T> {
}
370
371impl<T> TensorBase<T> {
372 pub fn view_from_scalar(scalar: &T) -> TensorBase<&T> {
373 TensorBase {
374 id: TensorId::new(),
375 kind: TensorKind::default(),
376 layout: Layout::scalar(),
377 op: BackpropOp::none(),
378 data: vec![scalar],
379 }
380 }
381 pub fn to_owned(&self) -> TensorBase<T>
382 where
383 T: Clone,
384 {
385 self.clone()
386 }
387
388 pub fn view(&self) -> TensorBase<&T> {
389 TensorBase {
390 id: self.id(),
391 kind: self.kind(),
392 layout: self.layout().clone(),
393 op: self.op().view(),
394 data: self.data().iter().collect(),
395 }
396 }
397
398 pub fn view_mut(&mut self) -> TensorBase<&mut T> {
399 TensorBase {
400 id: self.id(),
401 kind: self.kind(),
402 layout: self.layout().clone(),
403 op: self.op.view_mut(),
404 data: self.data.iter_mut().collect(),
405 }
406 }
407}
408#[allow(dead_code)]
410impl<T> TensorBase<T> {
411 pub(crate) fn data(&self) -> &Vec<T> {
412 &self.data
413 }
414
415 pub(crate) fn data_mut(&mut self) -> &mut Vec<T> {
416 &mut self.data
417 }
418
419 pub(crate) fn get_by_index(&self, index: usize) -> Option<&T> {
420 self.data.get(index)
421 }
422
423 pub(crate) fn get_mut_by_index(&mut self, index: usize) -> Option<&mut T> {
424 self.data.get_mut(index)
425 }
426
427 pub(crate) fn map<'a, F>(&'a self, f: F) -> Map<SliceIter<'a, T>, F>
428 where
429 F: FnMut(&'a T) -> T,
430 T: 'a + Clone,
431 {
432 self.data.iter().map(f)
433 }
434
435 pub(crate) fn mapv<F>(&self, f: F) -> TensorBase<T>
436 where
437 F: Fn(T) -> T,
438 T: Copy,
439 {
440 let store = self.data.iter().copied().map(f).collect();
441 TensorBase {
442 id: TensorId::new(),
443 kind: self.kind,
444 layout: self.layout.clone(),
445 op: self.op.clone(),
446 data: store,
447 }
448 }
449
450 pub(crate) fn map_binary<F>(&self, other: &TensorBase<T>, op: F) -> TensorBase<T>
451 where
452 F: acme::prelude::BinOp<T, T, Output = T>,
453 T: Copy,
454 {
455 let store = self
456 .iter()
457 .zip(other.iter())
458 .map(|(a, b)| op.eval(*a, *b))
459 .collect();
460 TensorBase {
461 id: TensorId::new(),
462 kind: self.kind,
463 layout: self.layout.clone(),
464 op: self.op.clone(),
465 data: store,
466 }
467 }
468}
469
470impl<'a, T> AsRef<TensorBase<T>> for TensorBase<&'a T> {
471 fn as_ref(&self) -> &TensorBase<T> {
472 unsafe { &*(self as *const TensorBase<&'a T> as *const TensorBase<T>) }
473 }
474}
475
476impl<Idx, T> Index<Idx> for TensorBase<T>
477where
478 Idx: AsRef<[usize]>,
479{
480 type Output = T;
481
482 fn index(&self, index: Idx) -> &Self::Output {
483 let i = self.layout().index(index);
484 &self.data[i]
485 }
486}
487
488impl<Idx, T> IndexMut<Idx> for TensorBase<T>
489where
490 Idx: AsRef<[usize]>,
491{
492 fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
493 let i = self.layout().index(index);
494 &mut self.data[i]
495 }
496}
497
498impl<T> Eq for TensorBase<T> where T: Eq {}
499
500impl<T> Ord for TensorBase<T>
501where
502 T: Ord,
503{
504 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
505 self.data.cmp(&other.data)
506 }
507}
508
509impl<T> PartialEq for TensorBase<T>
510where
511 T: PartialEq,
512{
513 fn eq(&self, other: &Self) -> bool {
514 self.layout == other.layout && self.data == other.data
515 }
516}
517
518impl<S, T> PartialEq<S> for TensorBase<T>
519where
520 S: AsRef<[T]>,
521 T: PartialEq,
522{
523 fn eq(&self, other: &S) -> bool {
524 &self.data == other.as_ref()
525 }
526}
527
528impl<T> PartialOrd for TensorBase<T>
529where
530 T: PartialOrd,
531{
532 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
533 self.data.partial_cmp(&other.data)
534 }
535}