use crate::dtype::DType;
use crate::scalar::{Scalar, Float};
use crate::shape::{to_axis, IntoAxes, IntoPadding, IntoShape};
use core::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display};
use std::iter::repeat;
use std::ops::{
    Add, BitAnd, BitOr, BitXor, Bound, Div, Mul, Neg, Not, Range, RangeBounds, RangeFrom,
    RangeFull, RangeInclusive, RangeTo, RangeToInclusive, Sub,
};
use std::path::Path;

use crate::runtime::ZyxError;
use crate::RT;

#[cfg(feature = "half")]
use half::{bf16, f16};

#[cfg(feature = "complex")]
use num_complex::Complex;

pub(crate) type TensorId = usize;

#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct Tensor {
    id: TensorId,
}

impl Clone for Tensor {
    fn clone(&self) -> Self {
        RT.lock().retain(self.id);
        Tensor { id: self.id }
    }
}

impl Drop for Tensor {
    fn drop(&mut self) {
        RT.lock().release(self.id).unwrap();
    }
}

impl Tensor {
    #[must_use]
    pub fn shape(&self) -> Vec<usize> {
        RT.lock().shape(self.id).to_vec()
    }

    #[must_use]
    pub fn numel(&self) -> usize {
        self.shape().iter().product()
    }

    #[must_use]
    pub fn rank(&self) -> usize {
        self.shape().len()
    }

    #[must_use]
    pub fn dtype(&self) -> DType {
        RT.lock().dtype(self.id)
    }

    #[must_use]
    pub fn training() -> bool {
        RT.lock().training
    }

    pub fn set_training(training: bool) {
        RT.lock().training = training;
    }

    pub fn realize<'a>(tensors: impl IntoIterator<Item = &'a Tensor>) -> Result<(), ZyxError> {
        RT.lock()
            .realize(tensors.into_iter().map(|t| t.id).collect())
    }

    #[must_use]
    pub fn backward<'a>(
        &self,
        sources: impl IntoIterator<Item = &'a Tensor>,
    ) -> Vec<Option<Tensor>> {
        let sources: Vec<TensorId> = sources.into_iter().map(|t| t.id).collect();
        let grads: BTreeMap<TensorId, TensorId> = RT
            .lock()
            .backward(self.id, sources.iter().copied().collect());
        sources
            .into_iter()
            .map(|x: TensorId| grads.get(&x).copied())
            .map(|id: Option<TensorId>| id.map(|id| Tensor { id }))
            .collect()
    }

    #[must_use]
    pub fn detach(self) -> Result<Tensor, ZyxError> {
        let shape = self.shape();
        let dtype = self.dtype();
        match dtype {
            #[cfg(feature = "half")]
            DType::F16 => {
                let data: Vec<f16> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            #[cfg(feature = "half")]
            DType::BF16 => {
                let data: Vec<bf16> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            DType::F32 => {
                let data: Vec<f32> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            DType::F64 => {
                let data: Vec<f64> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            #[cfg(feature = "complex")]
            DType::CF32 => {
                let data: Vec<Complex<f32>> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            #[cfg(feature = "complex")]
            DType::CF64 => {
                let data: Vec<Complex<f64>> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            DType::U8 => {
                let data: Vec<u8> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            DType::I8 => {
                let data: Vec<i8> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            DType::I16 => {
                let data: Vec<i16> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            DType::I32 => {
                let data: Vec<i32> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            DType::I64 => {
                let data: Vec<i64> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
            DType::Bool => {
                let data: Vec<bool> = self.try_into()?;
                Tensor::from(data).reshape(shape)
            }
        }
    }

    #[must_use]
    pub fn debug_guard(debug: u32) -> DebugGuard {
        let mut rt = RT.lock();
        let guard = DebugGuard { debug: rt.debug };
        rt.debug = debug;
        guard
    }

    pub fn plot_graph<'a>(
        tensors: impl IntoIterator<Item = &'a Tensor>,
        name: &str,
    ) -> Result<(), std::io::Error> {
        use std::format;
        let graph = RT
            .lock()
            .plot_dot_graph(&tensors.into_iter().map(|t| t.id).collect());
        std::fs::write(format!("{name}.dot"), graph)?;
        let output = std::process::Command::new("dot")
            .arg("-Tpng")
            .arg(format!("{name}.dot"))
            .arg("-o")
            .arg(format!("{name}.png"))
            .output();
        if let Err(err) = output {
            println!("Graph png could not be created: {err}");
        } else {
            let _ = std::fs::remove_file(format!("{name}.dot"));
        }
        Ok(())
    }

    #[cfg(feature = "rand")]
    pub fn manual_seed(seed: u64) {
        RT.lock().manual_seed(seed);
    }

    #[cfg(feature = "rand")]
    #[must_use]
    pub fn rand(shape: impl IntoShape, dtype: DType) -> Result<Tensor, ZyxError> {
        const SEED: u64 = 69420;
        use rand::distributions::Uniform;
        use rand::rngs::SmallRng;
        use rand::Rng;
        use rand::SeedableRng;
        let shape: Vec<usize> = shape.into_shape().collect();
        let n = shape.iter().product();
        if dtype.is_float() {
            let mut rt = RT.lock();
            rt.rng.get_or_init(|| SmallRng::seed_from_u64(SEED));
            let Some(rng) = rt.rng.get_mut() else {
                panic!()
            };
            match dtype {
                DType::F32 => {
                    let range = Uniform::new(0., 1.);
                    let data: Vec<f32> = (0..n).map(|_| rng.sample(&range)).collect();
                    Ok(Tensor {
                        id: rt.variable(shape, &data)?,
                    })
                }
                DType::F64 => {
                    let range = Uniform::new(0., 1.);
                    let data: Vec<f64> = (0..n).map(|_| rng.sample(&range)).collect();
                    Ok(Tensor {
                        id: rt.variable(shape, &data)?,
                    })
                }
                _ => panic!(),
            }
        } else {
            let mut rt = RT.lock();
            rt.rng.get_or_init(|| SmallRng::seed_from_u64(SEED));
            let Some(rng) = rt.rng.get_mut() else {
                panic!()
            };
            match dtype {
                DType::U8 => {
                    let range = Uniform::new(0, u8::MAX);
                    let data: Vec<u8> = (0..n).map(|_| rng.sample(&range)).collect();
                    Ok(Tensor {
                        id: rt.variable(shape, &data)?,
                    })
                }
                DType::I8 => {
                    let range = Uniform::new(0, i8::MAX);
                    let data: Vec<i8> = (0..n).map(|_| rng.sample(&range)).collect();
                    Ok(Tensor {
                        id: rt.variable(shape, &data)?,
                    })
                }
                DType::I16 => {
                    let range = Uniform::new(0, i16::MAX);
                    let data: Vec<i16> = (0..n).map(|_| rng.sample(&range)).collect();
                    Ok(Tensor {
                        id: rt.variable(shape, &data)?,
                    })
                }
                DType::I32 => {
                    let range = Uniform::new(0, i32::MAX);
                    let data: Vec<i32> = (0..n).map(|_| rng.sample(&range)).collect();
                    Ok(Tensor {
                        id: rt.variable(shape, &data)?,
                    })
                }
                DType::I64 => {
                    let range = Uniform::new(0, i64::MAX);
                    let data: Vec<i64> = (0..n).map(|_| rng.sample(&range)).collect();
                    Ok(Tensor {
                        id: rt.variable(shape, &data)?,
                    })
                }
                _ => panic!(),
            }
        }
    }

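    /// Sample from the standard normal distribution via the Box-Muller
    /// transform: two uniform tensors `u1`, `u2` are combined as
    /// `cos(2 * pi * u1) * sqrt(-2 * ln(1 - u2))`. Requires a float `dtype`.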
    #[cfg(feature = "rand")]
    #[must_use]
    pub fn randn(shape: impl IntoShape, dtype: DType) -> Result<Tensor, ZyxError> {
        let shape: Vec<usize> = [2].into_iter().chain(shape.into_shape()).collect();
        let src = Tensor::rand(shape, dtype)?;
        let mut x = src.get(0)?;
        x = x.mul(Tensor::constant(2f32 * std::f32::consts::PI));
        x = x.cos();
        let mut y = Tensor::constant(1f32) - src.get(1)?;
        y = y.ln().mul(Tensor::constant(-2f32)).sqrt();
        Ok(x.mul(y).cast(dtype))
    }

    #[cfg(feature = "rand")]
    #[must_use]
    pub fn uniform<T: Scalar>(
        shape: impl IntoShape,
        range: impl core::ops::RangeBounds<T>,
    ) -> Result<Tensor, ZyxError> {
        use core::ops::Bound;
        let low = match range.start_bound() {
            Bound::Included(value) => *value,
            Bound::Excluded(value) => *value,
            Bound::Unbounded => T::min_value(),
        };
        let high = match range.end_bound() {
            Bound::Included(value) => *value,
            Bound::Excluded(value) => *value,
            Bound::Unbounded => T::max_value(),
        };
        Ok(Tensor::rand(shape, T::dtype())? * high.sub(low) + low)
    }

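    /// Kaiming (He) uniform initialization: samples from `-bound..bound`
    /// with `bound = sqrt(3) * sqrt(2 / (1 + a^2)) / n`, where `n` is the
    /// product of all dimensions after the first.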
    #[cfg(feature = "rand")]
    #[must_use]
    pub fn kaiming_uniform<T: Scalar>(shape: impl IntoShape, a: T) -> Result<Tensor, ZyxError> {
        let n = T::from_i64(shape.clone().into_shape().skip(1).product::<usize>() as i64);
        let one = T::one();
        let x = Scalar::add(one, Scalar::mul(a, a));
        let two = Scalar::add(one, one);
        let three = Scalar::add(two, one);
        let x = Scalar::div(two, x).sqrt();
        let bound = Scalar::mul(three.sqrt(), Scalar::div(x, n));
        Tensor::uniform(shape, bound.neg()..bound)
    }

    #[must_use]
    pub fn zeros(shape: impl IntoShape, dtype: DType) -> Tensor {
        Tensor {
            id: RT.lock().zeros(shape.into_shape().collect(), dtype),
        }
    }

    #[must_use]
    pub fn ones(shape: impl IntoShape, dtype: DType) -> Tensor {
        Tensor {
            id: RT.lock().ones(shape.into_shape().collect(), dtype),
        }
    }

    #[must_use]
    pub fn full(shape: impl IntoShape, value: impl Scalar) -> Result<Tensor, ZyxError> {
        Ok(Tensor {
            id: RT.lock().full(shape.into_shape().collect(), value)?,
        })
    }

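    /// Identity matrix built without explicit indexing: an `[n, 1]` column
    /// of ones is right-padded with `n` zeros, reshaped to `[n + 1, n]`
    /// (which shifts each one right by one column per row), and the all-zero
    /// last row is sliced off.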
    #[must_use]
    pub fn eye(n: usize, dtype: DType) -> Tensor {
        Tensor::ones(vec![n, 1], dtype)
            .pad_zeros([(0, n as isize)])
            .unwrap()
            .reshape([n + 1, n])
            .unwrap()
            .get((..-1, ..))
            .unwrap()
    }

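    /// Values `start, start + step, ...` up to and excluding `stop`, built
    /// as a cumulative sum of `step` shifted by `start - step`.
    /// A usage sketch (not a doctest, since it needs a realized backend):
    /// ```ignore
    /// let t = Tensor::arange(0, 5, 1)?; // [0, 1, 2, 3, 4]
    /// ```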
    #[must_use]
    pub fn arange<T: Scalar>(start: T, stop: T, step: T) -> Result<Tensor, ZyxError> {
        let n: i64 = stop.sub(start).div(step).cast();
        let n = n as usize;
        let m = start.sub(step);
        let x = Tensor::full(n, step)?;
        let x = x.cumsum(0)?;
        Ok(x + m)
    }

    #[must_use]
    pub fn constant(value: impl Scalar) -> Tensor {
        Tensor {
            id: RT.lock().constant(value),
        }
    }

    #[must_use]
    pub fn abs(&self) -> Tensor {
        self.relu() + (-self).relu()
    }

    #[must_use]
    pub fn cast(&self, dtype: DType) -> Tensor {
        Tensor {
            id: RT.lock().cast(self.id, dtype),
        }
    }

    #[must_use]
    pub fn celu(&self, alpha: impl Scalar) -> Tensor {
        self.relu() - (-((self / alpha).exp() - 1) * alpha).relu()
    }

    #[must_use]
    pub fn cos(&self) -> Tensor {
        let x = self.float_cast();
        Tensor {
            id: RT.lock().cos(x.id),
        }
    }

    #[must_use]
    pub fn cosh(&self) -> Tensor {
        let nx = self.neg();
        let enx = nx.exp();
        let ex = self.exp();
        (ex + enx) / 2
    }

    #[cfg(feature = "rand")]
    #[must_use]
    pub fn dropout<P: Scalar + Float>(&self, probability: P) -> Result<Tensor, ZyxError> {
        Ok(Tensor::from(probability).cmplt(Tensor::rand(self.shape(), P::dtype())?)? * self)
    }

    #[must_use]
    pub fn elu(&self, alpha: impl Scalar) -> Tensor {
        self.relu() - (Tensor::ones(1, self.dtype()) - self.exp()).relu() * alpha
    }

    #[must_use]
    pub fn exp2(&self) -> Tensor {
        let x = self.float_cast();
        Tensor {
            id: RT.lock().exp2(x.id),
        }
    }

    #[must_use]
    pub fn exp(&self) -> Tensor {
        let c: Tensor = Tensor::constant(std::f64::consts::E.log2());
        (self * c.cast(self.dtype())).exp2()
    }

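    /// GELU with the usual tanh approximation:
    /// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.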
    #[must_use]
    pub fn gelu(&self) -> Result<Tensor, ZyxError> {
        Ok(self * 0.5f32
            * (((self + self.pow(3f32)? * 0.044_715f32) * (2f32 / core::f32::consts::PI).sqrt())
                .tanh()
                + 1f32))
    }

    #[must_use]
    pub fn leaky_relu(&self, neg_slope: impl Scalar) -> Tensor {
        self.relu() - (self * (-Tensor::from(neg_slope))).relu()
    }

    #[must_use]
    pub fn log2(&self) -> Tensor {
        let x = self.float_cast();
        Tensor {
            id: RT.lock().log2(x.id),
        }
    }

    #[must_use]
    pub fn ln(&self) -> Tensor {
        let x = self.float_cast();
        let c: Tensor = Tensor::constant(1f64 / std::f64::consts::E.log2());
        x.log2() * c.cast(x.dtype())
    }

    #[must_use]
    pub fn inv(&self) -> Tensor {
        Tensor {
            id: RT.lock().inv(self.id),
        }
    }

    #[must_use]
    pub fn mish(&self) -> Tensor {
        self * self.softplus(1, 20).tanh()
    }

    #[must_use]
    pub fn quick_gelu(&self) -> Tensor {
        self * (1.702f32 * self).sigmoid()
    }

    #[must_use]
    pub fn reciprocal(&self) -> Tensor {
        Tensor {
            id: RT.lock().reciprocal(self.id),
        }
    }

    #[must_use]
    pub fn relu(&self) -> Tensor {
        Tensor {
            id: RT.lock().relu(self.id),
        }
    }

    #[must_use]
    pub fn rsqrt(&self) -> Tensor {
        self.reciprocal().sqrt()
    }

    #[must_use]
    pub fn selu(&self) -> Tensor {
        1.0507009873554804934193349852946f32
            * (self.relu()
                - (1.6732632423543772848170429916717f32
                    * (Tensor::ones(1, self.dtype()) - self.exp()))
                .relu())
    }

    #[must_use]
    pub fn sigmoid(&self) -> Tensor {
        let one = Tensor::ones(1, self.dtype());
        let exp_x = self.exp();
        &exp_x / (&one + &exp_x)
    }

    #[must_use]
    pub fn sin(&self) -> Tensor {
        let x = self.float_cast();
        Tensor {
            id: RT.lock().sin(x.id),
        }
    }

    #[must_use]
    pub fn sinh(&self) -> Tensor {
        let nx = self.neg();
        let enx = nx.exp();
        let ex = self.exp();
        (ex - enx) / 2
    }

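    /// Softplus: `ln(1 + exp(x * beta)) / beta` where `x * beta` is below
    /// `threshold`, and `x * beta` itself above it, avoiding overflow in
    /// `exp` for large inputs.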
    #[must_use]
    pub fn softplus(&self, beta: impl Scalar, threshold: impl Scalar) -> Tensor {
        let x = self * beta;
        x.cmplt(threshold)
            .unwrap()
            .where_((x.exp() + 1).ln() * beta.reciprocal(), x)
            .unwrap()
    }

    #[must_use]
    pub fn sqrt(&self) -> Tensor {
        let x = self.float_cast();
        Tensor {
            id: RT.lock().sqrt(x.id),
        }
    }

    #[must_use]
    pub fn swish(&self) -> Tensor {
        self * self.sigmoid()
    }

    #[must_use]
    pub fn tan(&self) -> Tensor {
        self.sin() / self.cos()
    }

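    /// Tanh through the sigmoid identity `tanh(x) = 2 * sigmoid(2x) - 1`.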
    #[must_use]
    pub fn tanh(&self) -> Tensor {
        let x = (self + self).sigmoid();
        (&x + &x) - Tensor::constant(1).cast(self.dtype())
    }

    #[must_use]
    pub fn expand(&self, shape: impl IntoShape) -> Result<Tensor, ZyxError> {
        let mut sh = self.shape();
        let shape: Vec<usize> = shape.into_shape().collect();
        if shape.rank() < sh.rank() {
            return Err(ZyxError::ShapeError(format!(
                "Cannot expand {:?} into {:?}",
                self.shape(),
                shape
            )));
        }
        if shape.rank() > sh.rank() {
            // Left-pad the source shape with ones, checking compatibility
            // dimension by dimension from the back.
            let mut i = sh.len();
            for d in shape.iter().copied().rev() {
                if i == 0 {
                    sh.insert(i, 1);
                } else {
                    i -= 1;
                }
                if d != sh[i] && sh[i] != 1 {
                    return Err(ZyxError::ShapeError(format!(
                        "Cannot expand {:?} into {:?}",
                        self.shape(),
                        shape
                    )));
                }
            }
            let x = self.reshape(sh)?;
            let id = RT.lock().expand(x.id, shape);
            drop(x);
            return Ok(Tensor { id });
        }
        Ok(Tensor {
            id: RT.lock().expand(self.id, shape),
        })
    }

    #[must_use]
    pub fn permute(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        let rank = self.rank();
        let axes: Vec<usize> = axes.into_axes(rank).collect();
        if rank != axes.len() {
            return Err(ZyxError::ShapeError(format!(
                "Axes have rank {}, but the tensor has rank {}; they must match for permute.",
                axes.len(),
                rank
            )));
        }
        Ok(Tensor {
            id: RT.lock().permute(self.id, axes),
        })
    }

    #[must_use]
    pub fn pad_zeros(&self, padding: impl IntoPadding) -> Result<Tensor, ZyxError> {
        let padding = padding.into_padding();
        let shape = self.shape();
        let rank = shape.len();
        for (i, &(l, r)) in padding.iter().enumerate() {
            let mut total = 0;
            if l < 0 {
                total -= l;
            }
            if r < 0 {
                total -= r;
            }
            if (total as usize) >= shape[rank - i - 1] {
                return Err(ZyxError::ShapeError(format!(
                    "Invalid padding {padding:?} on shape {shape:?}"
                )));
            }
        }
        Ok(Tensor {
            id: RT.lock().pad_zeros(self.id, padding),
        })
    }

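    /// Padding with an arbitrary value, layered on `pad_zeros`: when the pad
    /// value is zero this reduces to `pad_zeros`; otherwise a mask of ones is
    /// zero-padded and `where_` writes the value into the padded region.
    /// Padding is applied to the last dimensions first.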
    #[must_use]
    pub fn pad(
        &self,
        padding: impl IntoPadding,
        value: impl Into<Tensor>,
    ) -> Result<Tensor, ZyxError> {
        let dtype = self.dtype();
        let value: Tensor = value.into();
        let padding = padding.into_padding();
        let sh = self.shape();
        if value.dtype() != dtype {
            return Err(ZyxError::DTypeError(format!(
                "Cannot pad tensor with dtype {} with value of dtype {}",
                dtype,
                value.dtype()
            )));
        }
        // Reject paddings longer than the shape, and negative paddings that
        // would remove more than a whole dimension. The entire condition is
        // negated; the previous `!padding.len() <= ...` applied `!` to the
        // length alone.
        if !(padding.len() <= sh.rank()
            && padding.iter().zip(sh.iter().rev()).all(|((lp, rp), d)| {
                (if *lp < 0 { ((-*lp) as usize) <= *d } else { true })
                    && (if *rp < 0 { ((-*rp) as usize) <= *d } else { true })
            }))
        {
            return Err(ZyxError::ShapeError(format!(
                "Cannot pad tensor with shape {sh:?} with padding {padding:?}"
            )));
        }
        let t0 = self.pad_zeros(padding.clone());
        if value.numel() == 1
            && match dtype {
                #[cfg(feature = "half")]
                DType::BF16 => {
                    let x: bf16 = value.clone().try_into()?;
                    x == bf16::ZERO
                }
                #[cfg(feature = "half")]
                DType::F16 => {
                    let x: f16 = value.clone().try_into()?;
                    x == f16::ZERO
                }
                DType::F32 => {
                    let x: f32 = value.clone().try_into()?;
                    x == 0.
                }
                DType::F64 => {
                    let x: f64 = value.clone().try_into()?;
                    x == 0.
                }
                #[cfg(feature = "complex")]
                DType::CF32 => {
                    let x: Complex<f32> = value.clone().try_into()?;
                    x == Complex::new(0., 0.)
                }
                #[cfg(feature = "complex")]
                DType::CF64 => {
                    let x: Complex<f64> = value.clone().try_into()?;
                    x == Complex::new(0., 0.)
                }
                DType::U8 => {
                    let x: u8 = value.clone().try_into()?;
                    x == 0
                }
                DType::I8 => {
                    let x: i8 = value.clone().try_into()?;
                    x == 0
                }
                DType::I16 => {
                    let x: i16 = value.clone().try_into()?;
                    x == 0
                }
                DType::I32 => {
                    let x: i32 = value.clone().try_into()?;
                    x == 0
                }
                DType::I64 => {
                    let x: i64 = value.clone().try_into()?;
                    x == 0
                }
                DType::Bool => {
                    let x: bool = value.clone().try_into()?;
                    !x
                }
            }
        {
            t0
        } else {
            let ones = Tensor::ones(sh.clone(), dtype);
            let zeros = Tensor::zeros(sh, self.dtype());
            Ok(t0? + ones.pad_zeros(padding)?.where_(zeros, value)?)
        }
    }

    #[must_use]
    pub fn reshape(&self, shape: impl IntoShape) -> Result<Tensor, ZyxError> {
        let shape: Vec<usize> = shape.into_shape().collect();
        if shape.iter().product::<usize>() != self.numel() {
            return Err(ZyxError::ShapeError(format!(
                "Invalid reshape {:?} into {:?}",
                self.shape(),
                shape
            )));
        }
        Ok(Tensor {
            id: RT.lock().reshape(self.id, shape),
        })
    }

    #[must_use]
    pub fn view(&self, shape: impl IntoShape) -> Result<Tensor, ZyxError> {
        self.reshape(shape)
    }

    #[must_use]
    pub fn t(&self) -> Tensor {
        let mut rank = self.rank();
        let x = if rank == 1 {
            let n = self.numel();
            rank = 2;
            self.reshape([1, n]).unwrap()
        } else {
            self.clone()
        };
        let mut axes: Vec<isize> = (0..rank as isize).collect();
        axes.swap(rank - 1, rank - 2);
        x.permute(axes).unwrap()
    }

    #[must_use]
    pub fn transpose(&self, dim0: isize, dim1: isize) -> Result<Tensor, ZyxError> {
        let rank = self.rank();
        if (if dim0 < 0 { -dim0 } else { dim0 }) as usize >= rank {
            return Err(ZyxError::ShapeError(format!(
                "Cannot transpose dimensions {dim0} and {dim1}, {dim0} is out of range for rank {rank}"
            )));
        }
        if (if dim1 < 0 { -dim1 } else { dim1 }) as usize >= rank {
            return Err(ZyxError::ShapeError(format!(
                "Cannot transpose dimensions {dim0} and {dim1}, {dim1} is out of range for rank {rank}"
            )));
        }
        let mut axes: Vec<isize> = (0..rank as isize).collect();
        axes.swap(to_axis(dim0, rank), to_axis(dim1, rank));
        self.permute(axes)
    }

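    /// Log-softmax over `axes`: `x - max(x) - ln(sum(exp(x - max(x))))`.
    /// Subtracting the running maximum keeps `exp` from overflowing.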
    pub fn ln_softmax(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        let m = self - self.max_kd(axes.clone())?;
        Ok(&m - m.exp().sum_kd(axes)?.ln())
    }

    #[must_use]
    pub fn max(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        let rank = self.rank();
        let axes: Vec<usize> = axes.into_axes(rank).collect();
        let mut unique = BTreeSet::new();
        for a in &axes {
            if !unique.insert(a) {
                return Err(ZyxError::ShapeError("Axes contain duplicates.".into()));
            }
        }
        Ok(Tensor {
            id: RT.lock().max_reduce(self.id, axes),
        })
    }

    #[must_use]
    pub fn max_kd(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        self.max(axes.clone())?.reshape(self.reduce_kd_shape(axes))
    }

    #[must_use]
    pub fn mean(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        let shape = self.shape();
        Ok(self.sum(axes.clone())?
            / axes
                .into_axes(shape.rank())
                .map(|a| shape[a])
                .product::<usize>() as i64)
    }

    #[must_use]
    pub fn mean_kd(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        self.mean(axes.clone())?.reshape(self.reduce_kd_shape(axes))
    }

    #[must_use]
    pub fn product(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        // prod(x) = exp(sum(ln(x))); note this assumes positive elements.
        Ok(self.ln().sum(axes)?.exp())
    }

    #[must_use]
    pub fn std(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        Ok(self.var(axes)?.sqrt())
    }

    #[must_use]
    pub fn std_kd(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        self.std(axes.clone())?.reshape(self.reduce_kd_shape(axes))
    }

    #[must_use]
    pub fn sum(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        let rank = self.rank();
        let axes: Vec<usize> = axes.into_axes(rank).collect();
        let mut unique = BTreeSet::new();
        for a in &axes {
            if !unique.insert(a) {
                return Err(ZyxError::ShapeError("Axes contain duplicates.".into()));
            }
        }
        Ok(Tensor {
            id: RT.lock().sum_reduce(self.id, axes),
        })
    }

    #[must_use]
    pub fn sum_kd(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        self.sum(axes.clone())?.reshape(self.reduce_kd_shape(axes))
    }

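    /// Cumulative sum along `axis`: the axis is moved last, left-padded with
    /// `len - 1` zeros, and pooled with window `len` and stride 1, so each
    /// output element sums a prefix of the input.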
    #[must_use]
    pub fn cumsum(&self, axis: isize) -> Result<Tensor, ZyxError> {
        let axis = to_axis(axis, self.rank());
        let pl_sz = (self.shape()[axis] - 1) as isize;
        let k = self.shape()[axis];
        let axis = axis as isize;
        let mut x = self.transpose(axis, -1)?;
        x = x.pad_zeros([(pl_sz, 0)])?;
        x = x.pool(k, 1, 1)?;
        x = x.sum(-1)?;
        x = x.transpose(axis, -1)?;
        Ok(x)
    }

    #[must_use]
    pub fn softmax(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        let e = (self - self.max_kd(axes.clone())?).exp();
        Ok(&e / e.sum_kd(axes)?)
    }

    #[must_use]
    pub fn var(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        // Population variance: mean of squared deviations. mean_kd keeps the
        // reduced dimensions so the subtraction broadcasts against self.
        (self - self.mean_kd(axes.clone())?).pow(2)?.mean(axes)
    }

    #[must_use]
    pub fn var_kd(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
        self.var(axes.clone())?.reshape(self.reduce_kd_shape(axes))
    }

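    /// Range indexing implemented as negative padding: every requested range
    /// is converted into pad amounts that trim the tensor on both sides;
    /// omitted trailing dimensions are kept whole.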
    #[must_use]
    pub fn get(&self, index: impl IntoIndex) -> Result<Tensor, ZyxError> {
        let shape = self.shape();
        let padding: Vec<(isize, isize)> = index
            .into_index()
            .into_iter()
            .zip(shape.iter())
            .map(|(r, d)| {
                (
                    if r.start >= 0 {
                        -r.start
                    } else {
                        -r.start - *d as isize
                    },
                    if r.end == isize::MAX {
                        0
                    } else if r.end > 0 {
                        -(*d as isize - r.end)
                    } else {
                        r.end
                    },
                )
            })
            .collect();
        let n = shape.rank() - padding.len();
        let padding: Vec<(isize, isize)> = padding
            .into_iter()
            .chain(core::iter::repeat((0, 0)).take(n))
            .collect::<Vec<(isize, isize)>>()
            .into_iter()
            .rev()
            .collect();
        self.pad_zeros(padding)
    }

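    /// Main diagonal of a square matrix: flattening, right-padding by `n`
    /// zeros and reshaping to `[n, n + 1]` moves the diagonal into column 0,
    /// which is then sliced out.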
    #[must_use]
    pub fn diagonal(&self) -> Tensor {
        let n = *self
            .shape()
            .last()
            .expect("Shape in invalid state. Internal bug.");
        self.flatten(..)
            .unwrap()
            .pad_zeros([(0, n as isize)])
            .unwrap()
            .reshape([n, n + 1])
            .unwrap()
            .get((.., 0))
            .unwrap()
    }

    #[must_use]
    pub fn cmplt(&self, rhs: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let (x, y) = Tensor::broadcast(self, rhs)?;
        Ok(Tensor {
            id: RT.lock().cmplt(x.id, y.id),
        })
    }

    #[must_use]
    pub fn maximum(&self, rhs: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let (x, y) = Tensor::broadcast(self, rhs)?;
        Ok(Tensor {
            id: RT.lock().maximum(x.id, y.id),
        })
    }

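    /// Matrix multiplication via broadcasting: the right operand is
    /// transposed, the operands are reshaped to `[..., m, 1, k]` and
    /// `[..., 1, n, k]`, multiplied, and summed over the shared last axis.
    /// ```ignore
    /// let a = Tensor::from([[1, 2], [3, 4]]);
    /// let b = Tensor::from([[5, 6], [7, 8]]);
    /// let c = a.dot(&b)?; // shape [2, 2]
    /// ```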
    #[must_use]
    pub fn dot(&self, rhs: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let rhs = rhs.into();
        let org_y_shape = rhs.shape();
        let y = rhs.t();
        let xshape = self.shape();
        let yshape = y.shape();
        let xrank = xshape.rank();
        let yrank = yshape.rank();
        if xshape[xrank - 1] != yshape[yrank - 1] {
            return Err(ZyxError::ShapeError(format!(
                "Cannot dot tensors with shapes {xshape:?} and {org_y_shape:?}"
            )));
        }
        let x_shape = xshape[..xrank - 1]
            .iter()
            .copied()
            .chain([1])
            .chain([xshape[xrank - 1]])
            .collect::<Vec<usize>>();
        let y_shape = yshape[0..yrank - 2]
            .iter()
            .copied()
            .chain([1])
            .chain(yshape[yrank - yrank.min(2)..yrank].iter().copied())
            .collect::<Vec<usize>>();
        (self.reshape(x_shape)? * y.reshape(y_shape)?)
            .sum(-1)?
            .reshape(
                xshape[0..xshape.len() - 1]
                    .iter()
                    .copied()
                    .chain([yshape[yshape.len() - 2]])
                    .collect::<Vec<usize>>(),
            )
    }

    #[must_use]
    pub fn matmul(&self, rhs: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        self.dot(rhs)
    }

    #[must_use]
    pub fn pow(&self, exponent: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let (x, y) = Tensor::broadcast(self, exponent)?;
        Ok(Tensor {
            id: RT.lock().pow(x.id, y.id),
        })
    }

    #[must_use]
    pub fn nonzero(&self) -> Tensor {
        Tensor {
            id: RT.lock().nonzero(self.id),
        }
    }

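    /// Elementwise select expressed arithmetically: `self` is normalized to
    /// a 0/1 mask and the result is `mask * if_true + !mask * if_false`,
    /// after broadcasting all three tensors together.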
    #[must_use]
    pub fn where_(
        &self,
        if_true: impl Into<Tensor>,
        if_false: impl Into<Tensor>,
    ) -> Result<Tensor, ZyxError> {
        let (x, y) = Tensor::broadcast(self, if_true)?;
        let (x, z) = Tensor::broadcast(x, if_false)?;
        let (y, z) = Tensor::broadcast(y, z)?;
        let x_nonzero = x.nonzero();
        Ok(&x_nonzero * y + !x_nonzero * z)
    }

    #[must_use]
    pub fn cross_entropy_loss(
        &self,
        target: impl Into<Tensor>,
        axes: impl IntoAxes,
    ) -> Result<Tensor, ZyxError> {
        Ok(self.ln_softmax(axes)? * target)
    }

    #[must_use]
    pub fn l1_loss(&self, target: impl Into<Tensor>) -> Tensor {
        (self - target).abs()
    }

    pub fn mse_loss(&self, target: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        (self - target).pow(2)
    }

    #[must_use]
    pub fn cosine_similarity(
        &self,
        rhs: impl Into<Tensor>,
        eps: impl Into<Tensor>,
    ) -> Result<Tensor, ZyxError> {
        let rhs: Tensor = rhs.into();
        let eps: Tensor = eps.into();
        let x = self.pow(2)?.sqrt() * rhs.pow(2)?.sqrt();
        Ok(self * rhs / x.cmplt(&eps)?.where_(eps, x)?)
    }

    #[must_use]
    pub fn flatten(&self, axes: impl RangeBounds<isize>) -> Result<Tensor, ZyxError> {
        let shape = self.shape();
        let rank = shape.len();
        let start_dim = to_axis(
            match axes.start_bound() {
                Bound::Included(dim) => *dim,
                Bound::Excluded(dim) => *dim + 1,
                Bound::Unbounded => 0,
            },
            rank,
        );
        // The end is exclusive: an included bound maps to its axis + 1 and an
        // unbounded end flattens through the last dimension, so `flatten(..)`
        // yields a 1D tensor.
        let end_dim = match axes.end_bound() {
            Bound::Included(dim) => to_axis(*dim, rank) + 1,
            Bound::Excluded(dim) => to_axis(*dim, rank),
            Bound::Unbounded => rank,
        };
        let dim = shape[start_dim..end_dim].iter().product();
        let new_shape: Vec<usize> = shape[..start_dim]
            .iter()
            .copied()
            .chain([dim])
            .chain(shape[end_dim..].iter().copied())
            .collect();
        self.reshape(new_shape)
    }

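    /// Concatenation along `dim` without a dedicated kernel: each tensor is
    /// zero-padded so it occupies its own slice of the output extent along
    /// `dim`, and the padded tensors are summed.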
    #[must_use]
    pub fn cat<'a>(
        tensors: impl IntoIterator<Item = &'a Tensor>,
        dim: isize,
    ) -> Result<Tensor, ZyxError> {
        let tensors: Vec<&Tensor> = tensors.into_iter().collect();
        if tensors.len() < 2 {
            return Err(ZyxError::ShapeError("Cat requires two or more tensors.".into()));
        }
        let shape = tensors[0].shape();
        let rank = shape.rank();
        let dim = if dim < 0 { dim + rank as isize } else { dim } as usize;
        for tensor in &tensors {
            for (i, (d1, d2)) in shape.iter().zip(tensor.shape().iter()).enumerate() {
                if i != dim && *d1 != *d2 {
                    return Err(ZyxError::ShapeError("Cannot concatenate these tensors.".into()));
                }
            }
        }
        let mut offset = 0isize;
        let mut offset2 = tensors.iter().fold(0, |acc, t| acc + t.shape()[dim] as isize);
        let mut res = None;
        for tensor in tensors {
            let d = tensor.shape()[dim] as isize;
            offset2 -= d;
            let padding: Vec<(isize, isize)> = core::iter::repeat((0isize, 0isize))
                .take(rank - dim - 1)
                .chain([(offset, offset2)])
                .collect();
            let t = tensor.pad_zeros(padding)?;
            if let Some(r) = res {
                res = Some(r + t);
            } else {
                res = Some(t);
            }
            offset += d;
        }
        Ok(res.unwrap())
    }

    #[must_use]
    pub fn unsqueeze(&self, dim: isize) -> Result<Tensor, ZyxError> {
        let shape = self.shape();
        let dim = if dim < 0 {
            shape.len() - (-dim) as usize + 1
        } else {
            dim as usize
        };
        self.reshape(
            shape[..dim]
                .iter()
                .copied()
                .chain([1])
                .chain(shape[dim..].iter().copied())
                .collect::<Vec<usize>>(),
        )
    }

    #[must_use]
    pub fn stack<'a>(
        tensors: impl IntoIterator<Item = &'a Tensor>,
        dim: isize,
    ) -> Result<Tensor, ZyxError> {
        let tensors: Vec<Tensor> = tensors
            .into_iter()
            .map(|t| t.unsqueeze(dim))
            .collect::<Result<_, _>>()?;
        Tensor::cat(&tensors, dim)
    }

    #[must_use]
    pub fn split(&self, sizes: impl IntoShape, dim: isize) -> Result<Vec<Tensor>, ZyxError> {
        let sizes: Vec<usize> = sizes.into_shape().collect();
        let shape = self.shape();
        let rank = shape.rank();
        let dim: usize = if dim < 0 { dim + rank as isize } else { dim } as usize;
        if sizes.iter().sum::<usize>() != shape[dim] {
            return Err(ZyxError::ShapeError(format!(
                "Sizes must sum exactly to {}, but got {:?}, which sums to {}",
                shape[dim],
                sizes,
                sizes.iter().sum::<usize>()
            )));
        }

        let mut res = Vec::new();
        let mut acc_size = 0;
        for size in sizes {
            let size = size as isize;
            let mut index = Vec::new();
            for i in 0..dim {
                index.push(0..shape[i] as isize);
            }
            index.push(acc_size..acc_size + size);
            res.push(self.get(index)?);
            acc_size += size;
        }
        Ok(res)
    }

    #[must_use]
    pub fn masked_fill(
        &self,
        mask: impl Into<Tensor>,
        value: impl Into<Tensor>,
    ) -> Result<Tensor, ZyxError> {
        mask.into().where_(value, self)
    }

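    /// Sliding-window pooling (the same repeat/slice/reshape construction
    /// that tinygrad uses for pooling): the input is tiled with `repeat`,
    /// then slices and reshapes carve out windows of `kernel_size` at the
    /// given `stride` and `dilation`. The window dimensions end up last in
    /// the output shape.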
    #[must_use]
    pub fn pool(
        &self,
        kernel_size: impl IntoShape,
        stride: impl IntoShape,
        dilation: impl IntoShape,
    ) -> Result<Tensor, ZyxError> {
        let k_: Vec<usize> = kernel_size.into_shape().collect();
        let stride: Vec<usize> = stride.into_shape().collect();
        let dilation: Vec<usize> = dilation.into_shape().collect();

        let shape = self.shape();
        let rank = shape.len();

        let s_: Vec<usize> = if stride.len() == 1 {
            repeat(stride[0]).take(k_.len()).collect()
        } else {
            stride
        };
        let d_: Vec<usize> = if dilation.len() == 1 {
            repeat(dilation[0]).take(k_.len()).collect()
        } else {
            dilation
        };
        let i_ = &shape[rank - k_.len()..];
        let o_: Vec<usize> = i_
            .iter()
            .cloned()
            .zip(d_.iter().cloned())
            .zip(k_.iter().cloned())
            .zip(s_.iter().cloned())
            .map(|(((i, d), k), s)| (i - d * (k - 1)).div_ceil(s))
            .collect();
        let repeats: Vec<usize> = repeat(1)
            .take(rank - k_.len())
            .chain(
                k_.iter()
                    .copied()
                    .zip(i_.iter().copied())
                    .zip(d_.iter().copied())
                    .map(|((k, i), d)| (k * (i + d)).div_ceil(i)),
            )
            .collect();
        let pad_b: Vec<Range<isize>> = shape[..rank - k_.len()]
            .iter()
            .map(|&d| 0..d as isize)
            .collect();
        let sh_b: Vec<usize> = shape[..rank - k_.len()].into();
        let mut xup = self.repeat(repeats)?;

        let padding: Vec<Range<isize>> = pad_b
            .iter()
            .cloned()
            .chain(
                k_.iter()
                    .copied()
                    .zip(i_.iter().copied())
                    .zip(d_.iter().copied())
                    .map(|((k, i), d)| 0..(k * (i + d)) as isize),
            )
            .collect();
        xup = xup.get(padding)?;
        let sh: Vec<usize> = sh_b
            .iter()
            .copied()
            .chain(
                k_.iter()
                    .copied()
                    .zip(i_.iter().copied())
                    .zip(d_.iter().copied())
                    .flat_map(|((k, i), d)| [k, i + d]),
            )
            .collect();
        xup = xup.reshape(sh)?;

        let padding: Vec<Range<isize>> = pad_b
            .iter()
            .cloned()
            .chain(
                k_.iter()
                    .copied()
                    .zip(o_.iter().copied())
                    .zip(s_.iter().copied())
                    .flat_map(|((k, o), s)| [0..k as isize, 0..(o * s) as isize]),
            )
            .collect();
        xup = xup.get(padding)?;
        let sh: Vec<usize> = sh_b
            .iter()
            .copied()
            .chain(
                k_.iter()
                    .copied()
                    .zip(o_.iter().copied())
                    .zip(s_.iter().copied())
                    .flat_map(|((k, o), s)| [k, o, s]),
            )
            .collect();
        xup = xup.reshape(sh)?;
        let padding: Vec<Range<isize>> = pad_b
            .iter()
            .cloned()
            .chain(
                k_.iter()
                    .copied()
                    .zip(o_.iter().copied())
                    .flat_map(|(k, o)| [0..k as isize, 0..o as isize, 0..1]),
            )
            .collect();
        xup = xup.get(padding)?;
        let sh: Vec<usize> = sh_b
            .iter()
            .copied()
            .chain(
                k_.iter()
                    .copied()
                    .zip(o_.iter().copied())
                    .flat_map(|(k, o)| [k, o]),
            )
            .collect();
        xup = xup.reshape(sh)?;

        let axes: Vec<isize> = (0..rank - k_.len())
            .chain((0..i_.len()).map(|i| rank - k_.len() + i * 2 + 1))
            .chain((0..i_.len()).map(|i| rank - k_.len() + i * 2))
            .map(|i| i as isize)
            .collect();
        xup = xup.permute(axes)?;

        Ok(xup)
    }

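    /// Tiling in the style of `torch.repeat`: each dimension `d` is reshaped
    /// to `[1, d]`, the leading ones are expanded to the repeat counts, and
    /// the result collapses to `[r0 * d0, r1 * d1, ...]`.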
    #[must_use]
    pub fn repeat(&self, repeats: impl IntoShape) -> Result<Tensor, ZyxError> {
        let repeats: Vec<usize> = repeats.into_shape().collect();
        let shape = self.shape();
        let rank = shape.len();
        if repeats.len() < rank {
            return Err(ZyxError::ShapeError(
                "Number of repeats must be greater than or equal to the rank of the tensor.".into(),
            ));
        }

        let base_shape: Vec<usize> = repeat(1)
            .take(repeats.len() - rank)
            .chain(shape.iter().copied())
            .collect();
        let new_shape: Vec<usize> = repeat(1)
            .take(repeats.len() - rank)
            .chain(shape.into_iter())
            .flat_map(|d| [1, d])
            .collect();
        let expand_shape: Vec<usize> = repeats
            .iter()
            .copied()
            .zip(base_shape.iter().copied())
            .flat_map(|(r, d)| [r, d])
            .collect();
        let final_shape: Vec<usize> = repeats
            .iter()
            .copied()
            .zip(base_shape.iter().copied())
            .map(|(r, d)| r * d)
            .collect();

        let mut x = self.reshape(new_shape)?;
        x = x.expand(expand_shape)?;
        x = x.reshape(final_shape)?;
        Ok(x)
    }

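    /// Loads tensors from a safetensors file: an 8-byte little-endian header
    /// length, a JSON header (scanned with a small hand-rolled state machine
    /// rather than a JSON parser), then raw tensor data read sequentially
    /// and checked against the declared offsets.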
    pub fn load<Module: FromIterator<Tensor>>(path: impl AsRef<Path>) -> Result<Module, ZyxError> {
        let debug_print: bool = RT.lock().debug_dev();
        use std::io::Read;
        let mut f = std::fs::File::open(path)?;
        let mut header_len = [0u8; 8];
        f.read_exact(&mut header_len)?;
        let n = usize::try_from(u64::from_le_bytes(header_len)).map_err(|e| {
            ZyxError::ParseError(format!(
                "Failed to parse header len in safetensors file. {e}"
            ))
        })?;
        let mut header = vec![0u8; n];
        f.read_exact(&mut header)?;
        let header = core::str::from_utf8(&header)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
        let mut text = String::with_capacity(10);
        let mut begin_str = false;
        let mut i = 0;
        let mut tensors = Vec::new();
        let mut dtype = DType::F32;
        let mut shape = vec![1];
        for x in header.chars() {
            if ['"', '[', ']'].contains(&x) {
                if begin_str {
                    if i % 7 == 0 {
                    } else if i % 7 == 2 {
                        dtype = DType::from_safetensors(&text)?;
                    } else if i % 7 == 4 {
                        shape = text
                            .split(',')
                            .map(|d| {
                                d.parse::<usize>().map_err(|err| {
                                    ZyxError::ParseError(format!(
                                        "Cannot parse safetensors shape: {err}"
                                    ))
                                })
                            })
                            .collect::<Result<_, ZyxError>>()?;
                    } else if i % 7 == 6 {
                        let offsets = text
                            .split(',')
                            .map(|offset| {
                                offset.parse::<usize>().map_err(|err| {
                                    ZyxError::ParseError(format!(
                                        "Could not parse safetensors offset: {err}"
                                    ))
                                })
                            })
                            .collect::<Result<Vec<usize>, ZyxError>>()?;
                        let bytes = shape.iter().product::<usize>() * dtype.byte_size();
                        if offsets[1] - offsets[0] != bytes {
                            return Err(ZyxError::ParseError(
                                "Safetensors shapes and offsets are incorrect.".into(),
                            ));
                        }
                        let mut buf = vec![0u8; bytes];
                        if debug_print {
                            print!("Loading tensor with shape {shape:?}, {dtype:?} ...");
                        }
                        f.read_exact(&mut buf)?;
                        if debug_print {
                            println!(" DONE");
                        }
                        tensors.push(match dtype {
                            DType::F32 => {
                                let vec: Vec<f32> = buf
                                    .chunks_exact(dtype.byte_size())
                                    .map(|x| f32::from_le_bytes([x[0], x[1], x[2], x[3]]))
                                    .collect();
                                Tensor::from(vec).reshape(&shape)?
                            }
                            DType::F64 => {
                                let vec: Vec<f64> = buf
                                    .chunks_exact(dtype.byte_size())
                                    .map(|x| {
                                        f64::from_le_bytes([
                                            x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7],
                                        ])
                                    })
                                    .collect();
                                Tensor::from(vec).reshape(&shape)?
                            }
                            DType::I32 => {
                                let vec: Vec<i32> = buf
                                    .chunks_exact(dtype.byte_size())
                                    .map(|x| i32::from_le_bytes([x[0], x[1], x[2], x[3]]))
                                    .collect();
                                Tensor::from(vec).reshape(&shape)?
                            }
                            _ => todo!(),
                        });
                    }
                    i += 1;
                    text.clear();
                    begin_str = false;
                } else {
                    text.clear();
                    begin_str = true;
                }
            } else {
                text.push(x);
            }
        }
        Ok(Module::from_iter(tensors))
    }

    pub fn to_le_bytes(&self) -> Result<Vec<u8>, ZyxError> {
        Ok(match self.dtype() {
            DType::F32 => {
                let data: Vec<f32> = self.clone().try_into()?;
                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
            }
            DType::F64 => {
                let data: Vec<f64> = self.clone().try_into()?;
                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
            }
            DType::U8 => {
                let data: Vec<u8> = self.clone().try_into()?;
                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
            }
            DType::I8 => {
                let data: Vec<i8> = self.clone().try_into()?;
                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
            }
            DType::I16 => {
                let data: Vec<i16> = self.clone().try_into()?;
                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
            }
            DType::I32 => {
                let data: Vec<i32> = self.clone().try_into()?;
                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
            }
            DType::I64 => {
                let data: Vec<i64> = self.clone().try_into()?;
                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
            }
            DType::Bool => {
                let data: Vec<bool> = self.clone().try_into()?;
                // bool is a single byte (0 or 1); convert explicitly instead
                // of transmuting the Vec, whose layout is not guaranteed.
                data.into_iter().map(u8::from).collect()
            }
        })
    }

    pub fn from_le_bytes(&self, bytes: &[u8]) -> Result<(), ZyxError> {
        let _ = bytes;
        todo!()
    }
}

pub struct DebugGuard {
    debug: u32,
}

impl Drop for DebugGuard {
    fn drop(&mut self) {
        RT.lock().debug = self.debug;
    }
}

impl Tensor {
    #[must_use]
    fn float_cast(&self) -> Tensor {
        let dtype = self.dtype();
        if !dtype.is_float() {
            return match dtype.byte_size() {
                #[cfg(feature = "half")]
                1 | 2 => self.cast(DType::F16),
                #[cfg(feature = "half")]
                4 => self.cast(DType::F32),
                #[cfg(not(feature = "half"))]
                1 | 2 | 4 => self.cast(DType::F32),
                8 => self.cast(DType::F64),
                _ => panic!(),
            };
        }
        self.clone()
    }

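    /// NumPy-style broadcasting: dtypes are unified toward the wider type,
    /// shapes are compared from the trailing dimension (each pair must match
    /// or contain a 1), the shorter shape is left-padded with ones, and both
    /// tensors are expanded to the elementwise maximum shape.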
    #[must_use]
    fn broadcast(x: impl Into<Tensor>, y: impl Into<Tensor>) -> Result<(Tensor, Tensor), ZyxError> {
        let mut x = x.into();
        let mut y = y.into();
        match (x.dtype(), y.dtype()) {
            (DType::F32, DType::I32) => y = y.cast(DType::F32),
            (DType::F32, DType::F64) => x = x.cast(DType::F64),
            (DType::I32, DType::F32) => x = x.cast(DType::F32),
            (DType::I32, DType::F64) => x = x.cast(DType::F64),
            (DType::F64, DType::F32) => y = y.cast(DType::F64),
            (DType::F64, DType::I32) => y = y.cast(DType::F64),
            _ => {}
        }
        let mut x_shape = x.shape();
        let mut y_shape = y.shape();

        for (&x, &y) in x_shape.iter().rev().zip(y_shape.iter().rev()) {
            if x != y && x != 1 && y != 1 {
                return Err(ZyxError::ShapeError(format!(
                    "Left and right tensor shapes can not be broadcasted: {x_shape:?} and {y_shape:?}"
                )));
            }
        }

        let rx = x_shape.rank();
        let ry = y_shape.rank();
        match rx.cmp(&ry) {
            Ordering::Less => {
                x_shape = core::iter::repeat(1)
                    .take(ry - rx)
                    .chain(x_shape.into_iter())
                    .collect();
            }
            Ordering::Greater => {
                y_shape = core::iter::repeat(1)
                    .take(rx - ry)
                    .chain(y_shape.into_iter())
                    .collect();
            }
            Ordering::Equal => {}
        }
        let mut eshape = Vec::new();
        for (x, y) in x_shape.iter().zip(y_shape.iter()) {
            eshape.push(*x.max(y));
        }
        x = x.reshape(&x_shape)?;
        if x_shape != eshape {
            x = x.expand(&eshape)?;
        }
        y = y.reshape(&y_shape)?;
        if y_shape != eshape {
            y = y.expand(&eshape)?;
        }
        Ok((x, y))
    }

    fn reduce_kd_shape(&self, axes: impl IntoAxes) -> Vec<usize> {
        let mut shape = self.shape();
        for a in axes.clone().into_axes(shape.len()) {
            shape[a] = 1;
        }
        shape
    }

    pub(super) fn id(&self) -> TensorId {
        self.id
    }
}

#[cfg(feature = "half")]
impl TryFrom<Tensor> for bf16 {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        RT.lock()
            .load(value.id)?
            .first()
            .copied()
            .ok_or(ZyxError::EmptyTensor)
    }
}

#[cfg(feature = "half")]
impl TryFrom<Tensor> for f16 {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        RT.lock()
            .load(value.id)?
            .first()
            .copied()
            .ok_or(ZyxError::EmptyTensor)
    }
}

impl TryFrom<Tensor> for f32 {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [0.];
        RT.lock().load(value.id, &mut data)?;
        Ok(data[0])
    }
}

impl TryFrom<Tensor> for f64 {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [0.];
        RT.lock().load(value.id, &mut data)?;
        Ok(data[0])
    }
}

#[cfg(feature = "complex")]
impl TryFrom<Tensor> for Complex<f32> {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        RT.lock()
            .load(value.id)?
            .first()
            .copied()
            .ok_or(ZyxError::EmptyTensor)
    }
}

#[cfg(feature = "complex")]
impl TryFrom<Tensor> for Complex<f64> {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        RT.lock()
            .load(value.id)?
            .first()
            .copied()
            .ok_or(ZyxError::EmptyTensor)
    }
}

impl TryFrom<Tensor> for u8 {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [0];
        RT.lock().load(value.id, &mut data)?;
        Ok(data[0])
    }
}

impl TryFrom<Tensor> for i8 {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [0];
        RT.lock().load(value.id, &mut data)?;
        Ok(data[0])
    }
}

impl TryFrom<Tensor> for i16 {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [0];
        RT.lock().load(value.id, &mut data)?;
        Ok(data[0])
    }
}

impl TryFrom<Tensor> for i32 {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [0];
        RT.lock().load(value.id, &mut data)?;
        Ok(data[0])
    }
}

impl TryFrom<Tensor> for i64 {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [0];
        RT.lock().load(value.id, &mut data)?;
        Ok(data[0])
    }
}

impl TryFrom<Tensor> for bool {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [false];
        RT.lock().load(value.id, &mut data)?;
        Ok(data[0])
    }
}

impl<T: Scalar> TryFrom<Tensor> for Vec<T> {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let numel = value.numel();
        let mut data = Vec::with_capacity(numel);
        unsafe { data.set_len(numel) };
        RT.lock().load(value.id, &mut data)?;
        Ok(data)
    }
}

impl<T: Scalar, const D0: usize> TryFrom<Tensor> for [T; D0] {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [T::zero(); D0];
        RT.lock().load(value.id, &mut data)?;
        Ok(data)
    }
}

impl<T: Scalar, const D0: usize, const D1: usize> TryFrom<Tensor> for [[T; D1]; D0] {
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [[T::zero(); D1]; D0];
        RT.lock().load(value.id, data.as_flattened_mut())?;
        Ok(data)
    }
}

impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize> TryFrom<Tensor>
    for [[[T; D2]; D1]; D0]
{
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [[[T::zero(); D2]; D1]; D0];
        RT.lock()
            .load(value.id, data.as_flattened_mut().as_flattened_mut())?;
        Ok(data)
    }
}

impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize, const D3: usize> TryFrom<Tensor>
    for [[[[T; D3]; D2]; D1]; D0]
{
    type Error = ZyxError;
    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
        let mut data = [[[[T::zero(); D3]; D2]; D1]; D0];
        RT.lock().load(
            value.id,
            data.as_flattened_mut()
                .as_flattened_mut()
                .as_flattened_mut(),
        )?;
        Ok(data)
    }
}

impl Debug for Tensor {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{self}"))
    }
}

impl Display for Tensor {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let precision = f.precision().unwrap_or(3);
        let x = self.clone();
        let res = match self.dtype() {
            #[cfg(feature = "half")]
            DType::BF16 => {
                let data: Result<Vec<bf16>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("bf16 tensor failed to realize {e:?}"),
                }
            }
            #[cfg(feature = "half")]
            DType::F16 => {
                let data: Result<Vec<f16>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("f16 tensor failed to realize {e:?}"),
                }
            }
            DType::F32 => {
                let data: Result<Vec<f32>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("f32 tensor failed to realize {e:?}"),
                }
            }
            DType::F64 => {
                let data: Result<Vec<f64>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("f64 tensor failed to realize {e:?}"),
                }
            }
            #[cfg(feature = "complex")]
            DType::CF32 => {
                let data: Result<Vec<Complex<f32>>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("cf32 tensor failed to realize {e:?}"),
                }
            }
            #[cfg(feature = "complex")]
            DType::CF64 => {
                let data: Result<Vec<Complex<f64>>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("cf64 tensor failed to realize {e:?}"),
                }
            }
            DType::U8 => {
                let data: Result<Vec<u8>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("u8 tensor failed to realize {e:?}"),
                }
            }
            DType::I8 => {
                let data: Result<Vec<i8>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("i8 tensor failed to realize {e:?}"),
                }
            }
            DType::I16 => {
                let data: Result<Vec<i16>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("i16 tensor failed to realize {e:?}"),
                }
            }
            DType::I32 => {
                let data: Result<Vec<i32>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("i32 tensor failed to realize {e:?}"),
                }
            }
            DType::I64 => {
                let data: Result<Vec<i64>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("i64 tensor failed to realize {e:?}"),
                }
            }
            DType::Bool => {
                let data: Result<Vec<bool>, _> = x.try_into();
                match data {
                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
                    Err(e) => format!("bool tensor failed to realize {e:?}"),
                }
            }
        };
        f.write_fmt(format_args!(
            "Tensor {:?} {}\n{res}",
            self.shape(),
            self.dtype()
        ))
    }
}

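/// Formats a flat data slice as nested brackets: bracket runs are derived
/// from divisibility of the element index by suffix products of the shape;
/// the column width defaults to the widest formatted element.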
fn tensor_to_string<T: core::fmt::Display>(
    data: &[T],
    shape: &[usize],
    precision: usize,
    width: Option<usize>,
) -> String {
    use core::fmt::Write;
    let n: usize = shape.iter().product();
    let rank = shape.len();
    let mut res = String::new();
    if data.is_empty() {
        return "[]".into();
    }
    let mut w = 0;
    if let Some(width) = width {
        w = width;
    } else {
        for x in data {
            let l = format!("{x:>.precision$}").len();
            if l > w {
                w = l;
            }
        }
    }
    let d0 = shape[rank - 1];
    for (i, x) in data.iter().enumerate() {
        {
            let mut var = 1;
            let mut r = rank;
            while r > 0 {
                if i % (n / var) == 0 {
                    res += &(" ".repeat(rank - r) + "[".repeat(r - 1).as_str());
                    break;
                }
                var *= shape[rank - r];
                r -= 1;
            }
        }
        let _ = write!(res, "{x:>w$.precision$}");
        if (i + 1) % d0 != 0 {
            res += " ";
        }
        {
            let mut var = 1;
            let mut r = rank;
            while r > 0 {
                if (i + 1) % (n / var) == 0 {
                    res += &"]".repeat(r - 1);
                    break;
                }
                var *= shape[rank - r];
                r -= 1;
            }
        }
        if (i + 1) % d0 == 0 && i != n - 1 {
            res += "\n";
        }
    }
    res
}

pub trait IntoRange: Clone {
    fn into_range(self) -> Range<isize>;
}

impl IntoRange for RangeFull {
    fn into_range(self) -> Range<isize> {
        0..isize::MAX
    }
}

impl IntoRange for RangeFrom<isize> {
    fn into_range(self) -> Range<isize> {
        self.start..isize::MAX
    }
}

impl IntoRange for RangeTo<isize> {
    fn into_range(self) -> Range<isize> {
        0..self.end
    }
}

impl IntoRange for RangeInclusive<isize> {
    fn into_range(self) -> Range<isize> {
        *self.start()..*self.end() + 1
    }
}

impl IntoRange for RangeToInclusive<isize> {
    fn into_range(self) -> Range<isize> {
        0..self.end + 1
    }
}

impl IntoRange for Range<isize> {
    fn into_range(self) -> Range<isize> {
        self
    }
}

impl IntoRange for isize {
    fn into_range(self) -> Range<isize> {
        self..self + 1
    }
}

2931pub trait IntoIndex {
2933 fn into_index(self) -> impl IntoIterator<Item = Range<isize>>;
2935}

impl IntoIndex for Vec<Range<isize>> {
    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
        self.into_iter()
    }
}

impl<I: IntoRange> IntoIndex for &[I] {
    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
        self.iter().cloned().map(IntoRange::into_range)
    }
}

impl<I0: IntoRange> IntoIndex for I0 {
    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
        [self.into_range()].into_iter()
    }
}

impl<I0: IntoRange, I1: IntoRange> IntoIndex for (I0, I1) {
    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
        [self.0.into_range(), self.1.into_range()].into_iter()
    }
}

impl<I0: IntoRange, I1: IntoRange, I2: IntoRange> IntoIndex for (I0, I1, I2) {
    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
        [
            self.0.into_range(),
            self.1.into_range(),
            self.2.into_range(),
        ]
        .into_iter()
    }
}

impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange> IntoIndex for (I0, I1, I2, I3) {
    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
        [
            self.0.into_range(),
            self.1.into_range(),
            self.2.into_range(),
            self.3.into_range(),
        ]
        .into_iter()
    }
}

impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange, I4: IntoRange> IntoIndex
    for (I0, I1, I2, I3, I4)
{
    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
        [
            self.0.into_range(),
            self.1.into_range(),
            self.2.into_range(),
            self.3.into_range(),
            self.4.into_range(),
        ]
        .into_iter()
    }
}

impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange, I4: IntoRange, I5: IntoRange>
    IntoIndex for (I0, I1, I2, I3, I4, I5)
{
    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
        [
            self.0.into_range(),
            self.1.into_range(),
            self.2.into_range(),
            self.3.into_range(),
            self.4.into_range(),
            self.5.into_range(),
        ]
        .into_iter()
    }
}

impl<
        I0: IntoRange,
        I1: IntoRange,
        I2: IntoRange,
        I3: IntoRange,
        I4: IntoRange,
        I5: IntoRange,
        I6: IntoRange,
    > IntoIndex for (I0, I1, I2, I3, I4, I5, I6)
{
    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
        [
            self.0.into_range(),
            self.1.into_range(),
            self.2.into_range(),
            self.3.into_range(),
            self.4.into_range(),
            self.5.into_range(),
            self.6.into_range(),
        ]
        .into_iter()
    }
}

impl<
        I0: IntoRange,
        I1: IntoRange,
        I2: IntoRange,
        I3: IntoRange,
        I4: IntoRange,
        I5: IntoRange,
        I6: IntoRange,
        I7: IntoRange,
    > IntoIndex for (I0, I1, I2, I3, I4, I5, I6, I7)
{
    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
        [
            self.0.into_range(),
            self.1.into_range(),
            self.2.into_range(),
            self.3.into_range(),
            self.4.into_range(),
            self.5.into_range(),
            self.6.into_range(),
            self.7.into_range(),
        ]
        .into_iter()
    }
}
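
// Illustrative check (not part of the original source): tuples of mixed `IntoRange`
// types flatten into one `Range<isize>` per axis, in order.
#[cfg(test)]
mod into_index_sketch {
    use super::*;

    #[test]
    fn mixed_tuple_index() {
        let ranges: Vec<Range<isize>> = (1isize, 0isize..2, ..=1isize)
            .into_index()
            .into_iter()
            .collect();
        assert_eq!(ranges, vec![1..2, 0..2, 0..2]);
    }
}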

impl From<&Tensor> for Tensor {
    fn from(value: &Tensor) -> Self {
        value.clone()
    }
}

impl<T: Scalar> From<T> for Tensor {
    fn from(value: T) -> Self {
        Tensor {
            id: RT.lock().variable(vec![1], &[value]).unwrap(),
        }
    }
}

impl<T: Scalar> From<Vec<T>> for Tensor {
    fn from(data: Vec<T>) -> Self {
        Tensor {
            id: RT.lock().variable(vec![data.len()], &data).unwrap(),
        }
    }
}

impl<T: Scalar> From<&Vec<T>> for Tensor {
    fn from(data: &Vec<T>) -> Self {
        Tensor {
            id: RT.lock().variable(vec![data.len()], data).unwrap(),
        }
    }
}

impl<T: Scalar> From<&[T]> for Tensor {
    fn from(data: &[T]) -> Self {
        Tensor {
            id: RT.lock().variable(vec![data.len()], data).unwrap(),
        }
    }
}

impl<T: Scalar, const D0: usize> From<[T; D0]> for Tensor {
    fn from(data: [T; D0]) -> Self {
        Tensor {
            id: RT.lock().variable(vec![D0], &data).unwrap(),
        }
    }
}

impl<T: Scalar, const D0: usize, const D1: usize> From<[[T; D1]; D0]> for Tensor {
    fn from(data: [[T; D1]; D0]) -> Self {
        // SAFETY: nested arrays are guaranteed to be laid out contiguously, so the
        // whole value can be viewed as a flat slice of D0 * D1 elements. Casting
        // `data.as_ptr()` (rather than indexing `data[0]`) also avoids a panic when
        // the outer dimension is zero.
        let data = unsafe { core::slice::from_raw_parts(data.as_ptr().cast::<T>(), D0 * D1) };
        Tensor {
            id: RT.lock().variable(vec![D0, D1], data).unwrap(),
        }
    }
}

impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize> From<[[[T; D2]; D1]; D0]>
    for Tensor
{
    fn from(data: [[[T; D2]; D1]; D0]) -> Self {
        // SAFETY: contiguous nested-array layout; see the 2-D impl above.
        let data =
            unsafe { core::slice::from_raw_parts(data.as_ptr().cast::<T>(), D0 * D1 * D2) };
        Tensor {
            id: RT.lock().variable(vec![D0, D1, D2], data).unwrap(),
        }
    }
}

impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize, const D3: usize>
    From<[[[[T; D3]; D2]; D1]; D0]> for Tensor
{
    fn from(data: [[[[T; D3]; D2]; D1]; D0]) -> Self {
        // SAFETY: contiguous nested-array layout; see the 2-D impl above.
        let data =
            unsafe { core::slice::from_raw_parts(data.as_ptr().cast::<T>(), D0 * D1 * D2 * D3) };
        Tensor {
            id: RT.lock().variable(vec![D0, D1, D2, D3], data).unwrap(),
        }
    }
}
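
// Illustrative check (not part of the original source): a minimal sketch of the
// `From` constructors above. Assumes the default runtime backend initializes without
// extra configuration.
#[cfg(test)]
mod from_impls_sketch {
    use super::*;

    #[test]
    fn shapes_from_inputs() {
        // A nested array becomes a 2-D tensor with shape taken from the const generics.
        let t = Tensor::from([[1.0f32, 2.0], [3.0, 4.0]]);
        assert_eq!(t.shape(), vec![2, 2]);

        // A Vec becomes a 1-D tensor of its length.
        let v = Tensor::from(vec![1i32, 2, 3]);
        assert_eq!(v.shape(), vec![3]);
    }
}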

impl PartialEq<f32> for Tensor {
    fn eq(&self, other: &f32) -> bool {
        if let Ok(data) = TryInto::<f32>::try_into(self.clone()) {
            &data == other
        } else {
            false
        }
    }
}

impl PartialEq<i32> for Tensor {
    fn eq(&self, other: &i32) -> bool {
        if let Ok(data) = TryInto::<i32>::try_into(self.clone()) {
            &data == other
        } else {
            false
        }
    }
}

impl<T: Scalar, const D0: usize> PartialEq<[T; D0]> for Tensor {
    fn eq(&self, other: &[T; D0]) -> bool {
        if self.shape() != [D0] {
            return false;
        }
        if let Ok(data) = TryInto::<[T; D0]>::try_into(self.clone()) {
            &data == other
        } else {
            false
        }
    }
}

impl<T: Scalar, const D0: usize, const D1: usize> PartialEq<[[T; D1]; D0]> for Tensor {
    fn eq(&self, other: &[[T; D1]; D0]) -> bool {
        if self.shape() != [D0, D1] {
            return false;
        }
        if let Ok(data) = TryInto::<[[T; D1]; D0]>::try_into(self.clone()) {
            &data == other
        } else {
            false
        }
    }
}

impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize> PartialEq<[[[T; D2]; D1]; D0]>
    for Tensor
{
    fn eq(&self, other: &[[[T; D2]; D1]; D0]) -> bool {
        if self.shape() != [D0, D1, D2] {
            return false;
        }
        if let Ok(data) = TryInto::<[[[T; D2]; D1]; D0]>::try_into(self.clone()) {
            &data == other
        } else {
            false
        }
    }
}

impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize, const D3: usize>
    PartialEq<[[[[T; D3]; D2]; D1]; D0]> for Tensor
{
    fn eq(&self, other: &[[[[T; D3]; D2]; D1]; D0]) -> bool {
        if self.shape() != [D0, D1, D2, D3] {
            return false;
        }
        if let Ok(data) = TryInto::<[[[[T; D3]; D2]; D1]; D0]>::try_into(self.clone()) {
            &data == other
        } else {
            false
        }
    }
}
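
// Illustrative check (not part of the original source): element-wise comparison
// against plain arrays, including the shape guard. Assumes a usable default backend.
#[cfg(test)]
mod partial_eq_sketch {
    use super::*;

    #[test]
    fn array_comparisons() {
        let t = Tensor::from([[1i32, 2], [3, 4]]);
        assert!(t == [[1, 2], [3, 4]]);
        // Shape mismatch short-circuits to `false` before any data is read back.
        assert!(!(t == [1, 2, 3, 4]));
    }
}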

// Binary operators broadcast both operands to a common shape before dispatching
// to the runtime.
impl<IT: Into<Tensor>> Add<IT> for Tensor {
    type Output = Tensor;
    fn add(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().add(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> Add<IT> for &Tensor {
    type Output = Tensor;
    fn add(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().add(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> Sub<IT> for Tensor {
    type Output = Tensor;
    fn sub(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().sub(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> Sub<IT> for &Tensor {
    type Output = Tensor;
    fn sub(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().sub(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> Mul<IT> for Tensor {
    type Output = Tensor;
    fn mul(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().mul(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> Mul<IT> for &Tensor {
    type Output = Tensor;
    fn mul(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().mul(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> Div<IT> for Tensor {
    type Output = Tensor;
    fn div(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().div(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> Div<IT> for &Tensor {
    type Output = Tensor;
    fn div(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().div(x.id, y.id),
        }
    }
}
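
// Illustrative check (not part of the original source): scalar operands broadcast
// against tensors through the `Into<Tensor>` bound. Assumes a usable default backend.
#[cfg(test)]
mod binary_ops_sketch {
    use super::*;

    #[test]
    fn broadcasting_operators() {
        let x = Tensor::from([[1.0f32, 2.0], [3.0, 4.0]]);
        // Scalar rhs is lifted to a tensor and broadcast.
        let y = &x + 1.0f32;
        assert!(y == [[2.0f32, 3.0], [4.0, 5.0]]);
        // Tensor rhs by reference works through `From<&Tensor>`.
        let z = &x * &x;
        assert!(z == [[1.0f32, 4.0], [9.0, 16.0]]);
    }
}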

impl<IT: Into<Tensor>> BitOr<IT> for Tensor {
    type Output = Tensor;
    fn bitor(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().bitor(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> BitOr<IT> for &Tensor {
    type Output = Tensor;
    fn bitor(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().bitor(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> BitXor<IT> for Tensor {
    type Output = Tensor;
    fn bitxor(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().bitxor(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> BitXor<IT> for &Tensor {
    type Output = Tensor;
    fn bitxor(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().bitxor(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> BitAnd<IT> for Tensor {
    type Output = Tensor;
    fn bitand(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().bitand(x.id, y.id),
        }
    }
}

impl<IT: Into<Tensor>> BitAnd<IT> for &Tensor {
    type Output = Tensor;
    fn bitand(self, rhs: IT) -> Self::Output {
        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
        Tensor {
            id: RT.lock().bitand(x.id, y.id),
        }
    }
}
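
// Illustrative check (not part of the original source): bitwise operators follow
// the same broadcasting path as the arithmetic ones. Assumes the backend supports
// boolean element-wise ops.
#[cfg(test)]
mod bitwise_ops_sketch {
    use super::*;

    #[test]
    fn bool_tensors() {
        let a = Tensor::from([true, false, true]);
        let b = Tensor::from([true, true, false]);
        assert!((&a & &b) == [true, false, false]);
        assert!((&a | &b) == [true, true, true]);
        assert!((&a ^ &b) == [false, true, true]);
    }
}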

impl Neg for Tensor {
    type Output = Tensor;
    fn neg(self) -> Self::Output {
        Tensor {
            id: RT.lock().neg(self.id),
        }
    }
}

impl Neg for &Tensor {
    type Output = Tensor;
    fn neg(self) -> Self::Output {
        Tensor {
            id: RT.lock().neg(self.id),
        }
    }
}

impl Not for Tensor {
    type Output = Tensor;
    fn not(self) -> Self::Output {
        Tensor {
            id: RT.lock().not(self.id),
        }
    }
}

impl Not for &Tensor {
    type Output = Tensor;
    fn not(self) -> Self::Output {
        Tensor {
            id: RT.lock().not(self.id),
        }
    }
}
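
// Illustrative check (not part of the original source): unary negation and logical
// not, on owned tensors and references. Assumes a usable default backend.
#[cfg(test)]
mod unary_ops_sketch {
    use super::*;

    #[test]
    fn neg_and_not() {
        let x = Tensor::from([1.0f32, -2.0]);
        assert!(-&x == [-1.0f32, 2.0]);
        let m = Tensor::from([true, false]);
        assert!(!&m == [false, true]);
    }
}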

// Scalar-on-the-left operators. The scalar is converted into a `Tensor` first so
// that operand order is preserved: with the previous `rhs.$fn_name(self)` delegation,
// non-commutative ops such as `Sub` and `Div` computed `tensor op scalar` instead of
// `scalar op tensor`.
macro_rules! impl_trait {
    ($trait:ident for $type:ty, $fn_name:ident) => {
        impl $trait<Tensor> for $type {
            type Output = Tensor;
            fn $fn_name(self, rhs: Tensor) -> Self::Output {
                Tensor::from(self).$fn_name(rhs)
            }
        }

        impl $trait<&Tensor> for $type {
            type Output = Tensor;
            fn $fn_name(self, rhs: &Tensor) -> Self::Output {
                Tensor::from(self).$fn_name(rhs)
            }
        }
    };
}

#[cfg(feature = "half")]
impl_trait!(Add for bf16, add);
#[cfg(feature = "half")]
impl_trait!(Add for f16, add);
impl_trait!(Add for f32, add);
impl_trait!(Add for f64, add);
#[cfg(feature = "complex")]
impl_trait!(Add for Complex<f32>, add);
#[cfg(feature = "complex")]
impl_trait!(Add for Complex<f64>, add);
impl_trait!(Add for u8, add);
impl_trait!(Add for i8, add);
impl_trait!(Add for i16, add);
impl_trait!(Add for i32, add);
impl_trait!(Add for i64, add);
impl_trait!(Add for bool, add);

#[cfg(feature = "half")]
impl_trait!(Sub for bf16, sub);
#[cfg(feature = "half")]
impl_trait!(Sub for f16, sub);
impl_trait!(Sub for f32, sub);
impl_trait!(Sub for f64, sub);
#[cfg(feature = "complex")]
impl_trait!(Sub for Complex<f32>, sub);
#[cfg(feature = "complex")]
impl_trait!(Sub for Complex<f64>, sub);
impl_trait!(Sub for u8, sub);
impl_trait!(Sub for i8, sub);
impl_trait!(Sub for i16, sub);
impl_trait!(Sub for i32, sub);
impl_trait!(Sub for i64, sub);
impl_trait!(Sub for bool, sub);

#[cfg(feature = "half")]
impl_trait!(Mul for bf16, mul);
#[cfg(feature = "half")]
impl_trait!(Mul for f16, mul);
impl_trait!(Mul for f32, mul);
impl_trait!(Mul for f64, mul);
#[cfg(feature = "complex")]
impl_trait!(Mul for Complex<f32>, mul);
#[cfg(feature = "complex")]
impl_trait!(Mul for Complex<f64>, mul);
impl_trait!(Mul for u8, mul);
impl_trait!(Mul for i8, mul);
impl_trait!(Mul for i16, mul);
impl_trait!(Mul for i32, mul);
impl_trait!(Mul for i64, mul);
impl_trait!(Mul for bool, mul);

#[cfg(feature = "half")]
impl_trait!(Div for bf16, div);
#[cfg(feature = "half")]
impl_trait!(Div for f16, div);
impl_trait!(Div for f32, div);
impl_trait!(Div for f64, div);
#[cfg(feature = "complex")]
impl_trait!(Div for Complex<f32>, div);
#[cfg(feature = "complex")]
impl_trait!(Div for Complex<f64>, div);
impl_trait!(Div for u8, div);
impl_trait!(Div for i8, div);
impl_trait!(Div for i16, div);
impl_trait!(Div for i32, div);
impl_trait!(Div for i64, div);
impl_trait!(Div for bool, div);

#[cfg(feature = "half")]
impl_trait!(BitXor for bf16, bitxor);
#[cfg(feature = "half")]
impl_trait!(BitXor for f16, bitxor);
impl_trait!(BitXor for f32, bitxor);
impl_trait!(BitXor for f64, bitxor);
#[cfg(feature = "complex")]
impl_trait!(BitXor for Complex<f32>, bitxor);
#[cfg(feature = "complex")]
impl_trait!(BitXor for Complex<f64>, bitxor);
impl_trait!(BitXor for u8, bitxor);
impl_trait!(BitXor for i8, bitxor);
impl_trait!(BitXor for i16, bitxor);
impl_trait!(BitXor for i32, bitxor);
impl_trait!(BitXor for i64, bitxor);
impl_trait!(BitXor for bool, bitxor);

#[cfg(feature = "half")]
impl_trait!(BitOr for bf16, bitor);
#[cfg(feature = "half")]
impl_trait!(BitOr for f16, bitor);
impl_trait!(BitOr for f32, bitor);
impl_trait!(BitOr for f64, bitor);
#[cfg(feature = "complex")]
impl_trait!(BitOr for Complex<f32>, bitor);
#[cfg(feature = "complex")]
impl_trait!(BitOr for Complex<f64>, bitor);
impl_trait!(BitOr for u8, bitor);
impl_trait!(BitOr for i8, bitor);
impl_trait!(BitOr for i16, bitor);
impl_trait!(BitOr for i32, bitor);
impl_trait!(BitOr for i64, bitor);
impl_trait!(BitOr for bool, bitor);

#[cfg(feature = "half")]
impl_trait!(BitAnd for bf16, bitand);
#[cfg(feature = "half")]
impl_trait!(BitAnd for f16, bitand);
impl_trait!(BitAnd for f32, bitand);
impl_trait!(BitAnd for f64, bitand);
#[cfg(feature = "complex")]
impl_trait!(BitAnd for Complex<f32>, bitand);
#[cfg(feature = "complex")]
impl_trait!(BitAnd for Complex<f64>, bitand);
impl_trait!(BitAnd for u8, bitand);
impl_trait!(BitAnd for i8, bitand);
impl_trait!(BitAnd for i16, bitand);
impl_trait!(BitAnd for i32, bitand);
impl_trait!(BitAnd for i64, bitand);
impl_trait!(BitAnd for bool, bitand);
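
// Illustrative check (not part of the original source): with the macro above, the
// scalar stays on the left of non-commutative operators. Assumes a usable default
// backend.
#[cfg(test)]
mod scalar_lhs_sketch {
    use super::*;

    #[test]
    fn scalar_minus_tensor() {
        let x = Tensor::from([1.0f32, 2.0]);
        // 1.0 - x, not x - 1.0.
        assert!(1.0f32 - &x == [0.0f32, -1.0]);
        assert!(2.0f32 * &x == [2.0f32, 4.0]);
    }
}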