zyx/tensor.rs

//! Tensor
//!
//! Tensors are at the core of all machine learning.

use crate::dtype::DType;
use crate::scalar::{Scalar, Float};
use crate::shape::{to_axis, IntoAxes, IntoPadding, IntoShape};
use core::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display};
use std::iter::repeat;
use std::ops::{
    Add, BitAnd, BitOr, BitXor, Bound, Div, Mul, Neg, Not, Range, RangeBounds, RangeFrom,
    RangeFull, RangeInclusive, RangeTo, RangeToInclusive, Sub,
};
use std::path::Path;

use crate::runtime::ZyxError;
use crate::RT;

#[cfg(feature = "half")]
use half::{bf16, f16};

#[cfg(feature = "complex")]
use num_complex::Complex;

pub(crate) type TensorId = usize;

/// A tensor represents a multi-dimensional array of values. This is the primary data structure in the library.
/// The `Tensor` struct contains an internal identifier (`id`) that uniquely identifies each tensor.
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct Tensor {
    id: TensorId,
}
35
36impl Clone for Tensor {
37    fn clone(&self) -> Self {
38        RT.lock().retain(self.id);
39        Tensor { id: self.id }
40    }
41}
42
43impl Drop for Tensor {
44    fn drop(&mut self) {
45        //std::println!("dropping");
46        RT.lock().release(self.id).unwrap();
47    }
48}
49
50impl Tensor {
51    /// Shape of tensor
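    ///
    /// A minimal usage sketch (shape and dtype here are illustrative):
    /// ```rust
    /// use zyx::{Tensor, DType};
    /// let x = Tensor::zeros([2, 3], DType::F32);
    /// assert_eq!(x.shape(), [2, 3]);
    /// assert_eq!(x.numel(), 6);
    /// assert_eq!(x.rank(), 2);
    /// ```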
52    #[must_use]
53    pub fn shape(&self) -> Vec<usize> {
54        RT.lock().shape(self.id).to_vec()
55    }
56
57    /// Number of scalar elements stored in self
58    #[must_use]
59    pub fn numel(&self) -> usize {
60        self.shape().iter().product()
61    }
62
63    /// Rank of self. Rank means number of dimensions/axes.
64    #[must_use]
65    pub fn rank(&self) -> usize {
66        self.shape().len()
67    }
68
69    /// Datatype of self. See [DType](crate::DType) for available datatypes.
70    #[must_use]
71    pub fn dtype(&self) -> DType {
72        RT.lock().dtype(self.id)
73    }
74
75    /// Is zyx in training mode?
76    #[must_use]
77    pub fn training() -> bool {
78        RT.lock().training
79    }
80
81    /// Set training mode
82    pub fn set_training(training: bool) {
83        RT.lock().training = training;
84    }
85
    /// Immediately evaluates the passed tensors.
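    ///
    /// A minimal sketch of the call pattern (tensors and dtype are illustrative):
    /// ```rust
    /// use zyx::{Tensor, DType};
    /// let x = Tensor::randn([8, 8], DType::F32).unwrap();
    /// let y = x.exp();
    /// Tensor::realize([&x, &y]).unwrap();
    /// ```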
87    pub fn realize<'a>(tensors: impl IntoIterator<Item = &'a Tensor>) -> Result<(), ZyxError> {
88        RT.lock()
89            .realize(tensors.into_iter().map(|t| t.id).collect())
90    }
91
92    /// Returns gradients of self derived w.r.t. sources
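    ///
    /// A minimal sketch of the call pattern (values are illustrative):
    /// ```rust
    /// use zyx::Tensor;
    /// let x = Tensor::from([2f32, 3.]);
    /// let y = (&x + &x).sum([]).unwrap();
    /// let grads = y.backward([&x]);
    /// // grads[0] holds the gradient of y w.r.t. x, here 2 for every element
    /// assert!(grads[0].is_some());
    /// ```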
93    #[must_use]
94    pub fn backward<'a>(
95        &self,
96        sources: impl IntoIterator<Item = &'a Tensor>,
97    ) -> Vec<Option<Tensor>> {
98        let sources: Vec<TensorId> = sources.into_iter().map(|t| t.id).collect();
99        let grads: BTreeMap<TensorId, TensorId> = RT
100            .lock()
101            .backward(self.id, sources.iter().copied().collect());
102        sources
103            .into_iter()
104            .map(|x: TensorId| grads.get(&x).copied())
105            .map(|id: Option<TensorId>| id.map(|id| Tensor { id }))
106            .collect()
107    }
108
    /// Detaches tensor from the graph.
    /// This function returns a new tensor with the same data as the previous one,
    /// but drops its backpropagation graph. This is useful for recurrent networks:
    /// ```rust
    /// use zyx::{Tensor, DType};
    /// let mut x = Tensor::randn([8, 8], DType::F32).unwrap();
    /// let z = Tensor::randn([8], DType::F32).unwrap();
    /// for _ in 0..100 {
    ///     // Without detach the graph would grow bigger with every iteration
    ///     x = x.detach().unwrap() + z.clone();
    /// }
    /// ```
120    #[must_use]
121    pub fn detach(self) -> Result<Tensor, ZyxError> {
122        // TODO remove realization from here
123        let shape = self.shape();
124        let dtype = self.dtype();
125        match dtype {
126            #[cfg(feature = "half")]
127            DType::F16 => {
128                let data: Vec<f16> = self.try_into()?;
129                Tensor::from(data).reshape(shape)
130            }
131            #[cfg(feature = "half")]
132            DType::BF16 => {
133                let data: Vec<bf16> = self.try_into()?;
134                Tensor::from(data).reshape(shape)
135            }
136            DType::F32 => {
137                let data: Vec<f32> = self.try_into()?;
138                Tensor::from(data).reshape(shape)
139            }
140            DType::F64 => {
141                let data: Vec<f64> = self.try_into()?;
142                Tensor::from(data).reshape(shape)
143            }
144            #[cfg(feature = "complex")]
145            DType::CF32 => {
146                let data: Vec<Complex<f32>> = self.try_into()?;
147                Tensor::from(data).reshape(shape)
148            }
149            #[cfg(feature = "complex")]
150            DType::CF64 => {
151                let data: Vec<Complex<f64>> = self.try_into()?;
152                Tensor::from(data).reshape(shape)
153            }
154            DType::U8 => {
155                let data: Vec<u8> = self.try_into()?;
156                Tensor::from(data).reshape(shape)
157            }
158            DType::I8 => {
159                let data: Vec<i8> = self.try_into()?;
160                Tensor::from(data).reshape(shape)
161            }
162            DType::I16 => {
163                let data: Vec<i16> = self.try_into()?;
164                Tensor::from(data).reshape(shape)
165            }
166            DType::I32 => {
167                let data: Vec<i32> = self.try_into()?;
168                Tensor::from(data).reshape(shape)
169            }
170            DType::I64 => {
171                let data: Vec<i64> = self.try_into()?;
172                Tensor::from(data).reshape(shape)
173            }
174            DType::Bool => {
175                let data: Vec<bool> = self.try_into()?;
176                Tensor::from(data).reshape(shape)
177            }
178        }
179    }
180
    /// Create a debug guard at the beginning of a block to debug that block.
    /// Once the guard is dropped, debug gets reset to the global state,
    /// the one set by the ZYX_DEBUG env variable.
    /// ZYX_DEBUG is a bitmask:
185    /// 0000 0001 DEBUG_DEV
186    /// 0000 0010 DEBUG_PERF
187    /// 0000 0100 DEBUG_SCHED
188    /// 0000 1000 DEBUG_IR
189    /// 0001 0000 DEBUG_ASM
190    /// For more look at ENV_VARS.md
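    ///
    /// A small usage sketch (the value 1 corresponds to DEBUG_DEV in the bitmask above):
    /// ```rust
    /// use zyx::Tensor;
    /// {
    ///     let _guard = Tensor::debug_guard(1);
    ///     // code in this scope runs with DEBUG_DEV enabled
    /// } // guard dropped here, debug resets to the ZYX_DEBUG state
    /// ```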
191    #[must_use]
192    pub fn debug_guard(debug: u32) -> DebugGuard {
193        let mut rt = RT.lock();
194        let guard = DebugGuard { debug: rt.debug };
195        rt.debug = debug;
196        guard
197    }
198
    /// Writes the graph of operations between tensors as a png image with the given filename.
    /// Expects the dot program to be in the path; otherwise only the dot graph file is created,
    /// without converting it to png.
202    pub fn plot_graph<'a>(
203        tensors: impl IntoIterator<Item = &'a Tensor>,
204        name: &str,
205    ) -> Result<(), std::io::Error> {
206        use std::format;
207        let graph = RT
208            .lock()
209            .plot_dot_graph(&tensors.into_iter().map(|t| t.id).collect());
210        std::fs::write(format!("{name}.dot"), graph)?;
211        let output = std::process::Command::new("dot")
212            .arg("-Tpng")
213            .arg(format!("{name}.dot"))
214            .arg("-o")
215            .arg(format!("{name}.png"))
216            .output();
217        if let Err(err) = output {
218            println!("Graph png could not be created: {err}");
219        } else {
220            let _ = std::fs::remove_file(format!("{name}.dot"));
221        }
222        Ok(())
223    }
224
225    /// Manually sets the seed for the random number generator.
226    /// This function is only available if the `rand` feature is enabled.
227    #[cfg(feature = "rand")]
228    pub fn manual_seed(seed: u64) {
229        RT.lock().manual_seed(seed);
230    }
231
    /// Create tensor with random values in the range 0f..1f for float dtypes,
    /// or 0..MAX for integer dtypes.
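    ///
    /// A minimal sketch (shape and dtype are illustrative):
    /// ```rust
    /// use zyx::{Tensor, DType};
    /// let x = Tensor::rand([2, 3], DType::F32).unwrap();
    /// assert_eq!(x.shape(), [2, 3]);
    /// ```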
234    #[cfg(feature = "rand")]
235    #[must_use]
236    pub fn rand(shape: impl IntoShape, dtype: DType) -> Result<Tensor, ZyxError> {
237        const SEED: u64 = 69420;
238        use std::i32;
239
240        use rand::distributions::Uniform;
241        use rand::rngs::SmallRng;
242        use rand::Rng;
243        use rand::SeedableRng;
244        let shape: Vec<usize> = shape.into_shape().collect();
245        let n = shape.iter().product();
246        if dtype.is_float() {
247            // TODO later use threefry
248            let mut rt = RT.lock();
249            rt.rng.get_or_init(|| SmallRng::seed_from_u64(SEED));
250            let Some(rng) = rt.rng.get_mut() else {
251                panic!()
252            };
253            match dtype {
254                DType::F32 => {
255                    let range = Uniform::new(0., 1.);
256                    let data: Vec<f32> = (0..n).map(|_| rng.sample(&range)).collect();
257                    Ok(Tensor {
258                        id: rt.variable(shape, &data)?,
259                    })
260                }
261                DType::F64 => {
262                    let range = Uniform::new(0., 1.);
263                    let data: Vec<f64> = (0..n).map(|_| rng.sample(&range)).collect();
264                    Ok(Tensor {
265                        id: rt.variable(shape, &data)?,
266                    })
267                }
268                _ => panic!(),
269            }
270        } else {
271            let mut rt = RT.lock();
272            rt.rng.get_or_init(|| SmallRng::seed_from_u64(SEED));
273            let Some(rng) = rt.rng.get_mut() else {
274                panic!()
275            };
276            match dtype {
277                DType::U8 => {
278                    let range = Uniform::new(0, u8::MAX);
279                    let data: Vec<u8> = (0..n).map(|_| rng.sample(&range)).collect();
280                    Ok(Tensor {
281                        id: rt.variable(shape, &data)?,
282                    })
283                }
284                DType::I8 => {
285                    let range = Uniform::new(0, i8::MAX);
286                    let data: Vec<i8> = (0..n).map(|_| rng.sample(&range)).collect();
287                    Ok(Tensor {
288                        id: rt.variable(shape, &data)?,
289                    })
290                }
291                DType::I16 => {
292                    let range = Uniform::new(0, i16::MAX);
293                    let data: Vec<i16> = (0..n).map(|_| rng.sample(&range)).collect();
294                    Ok(Tensor {
295                        id: rt.variable(shape, &data)?,
296                    })
297                }
298                DType::I32 => {
299                    let range = Uniform::new(0, i32::MAX);
300                    let data: Vec<i32> = (0..n).map(|_| rng.sample(&range)).collect();
301                    Ok(Tensor {
302                        id: rt.variable(shape, &data)?,
303                    })
304                }
305                DType::I64 => {
306                    let range = Uniform::new(0, i64::MAX);
307                    let data: Vec<i64> = (0..n).map(|_| rng.sample(&range)).collect();
308                    Ok(Tensor {
309                        id: rt.variable(shape, &data)?,
310                    })
311                }
312                _ => panic!(),
313            }
314        }
315        /*# threefry
316        if (num := math.ceil(((num_ := prod(shape)) * dtype.itemsize) / 4)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
317        if not had_counter: Tensor._rng_counter.assign(Tensor._rng_counter + num)
318        counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device))
319        counts2 = counts1 + math.ceil(num / 2)*/
320
321        /*# threefry random bits
322        x = counts2.cast(dtypes.uint64) << 32 | counts1.cast(dtypes.uint64)
323        x = F.Threefry.apply(*x._broadcasted(Tensor._seed))
324        counts1, counts2 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
325        bits = counts1.cat(counts2)[:num]
326
327        # bitcast to uint with same number of bits
328        _, nmant = dtypes.finfo(dtype)
329        uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
330        bits = bits.bitcast(uint_dtype)
331        # only randomize the mantissa bits and set the exponent to 1
332        one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
333        bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
334
335        # bitcast back to the original dtype
336        out = bits.bitcast(dtype)[:num_].sub(1).reshape(shape)
337        out.requires_grad = kwargs.get("requires_grad")
338        return out.contiguous()*/
339    }
340
341    // Initializers
    /// Create tensor sampled from the standard normal distribution.
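    ///
    /// A minimal sketch (shape and dtype are illustrative):
    /// ```rust
    /// use zyx::{Tensor, DType};
    /// let x = Tensor::randn([3, 3], DType::F32).unwrap();
    /// assert_eq!(x.shape(), [3, 3]);
    /// ```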
343    #[cfg(feature = "rand")]
344    #[must_use]
345    pub fn randn(shape: impl IntoShape, dtype: DType) -> Result<Tensor, ZyxError> {
346        // https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
347        let shape: Vec<usize> = [2].into_iter().chain(shape.into_shape()).collect();
348        let src = Tensor::rand(shape, dtype)?;
349        let mut x = src.get(0)?;
350        x = x.mul(Tensor::constant(2f32 * std::f32::consts::PI));
351        //panic!();
352        x = x.cos();
353        let mut y = Tensor::constant(1f32) - src.get(1)?;
354        //println!("{y} minus");
355        y = y.ln().mul(Tensor::constant(-2f32)).sqrt();
356        //println!("{y}");
357        Ok(x.mul(y).cast(dtype))
358    }
359
360    /// Create tensor sampled from uniform distribution
361    /// Start of the range must be less than the end of the range.
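    ///
    /// A minimal sketch (shape and range are illustrative):
    /// ```rust
    /// use zyx::Tensor;
    /// let x = Tensor::uniform([2, 2], -1f32..1f32).unwrap();
    /// assert_eq!(x.shape(), [2, 2]);
    /// ```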
362    #[cfg(feature = "rand")]
363    #[must_use]
364    pub fn uniform<T: Scalar>(
365        shape: impl IntoShape,
366        range: impl core::ops::RangeBounds<T>,
367    ) -> Result<Tensor, ZyxError> {
368        use core::ops::Bound;
369        let low = match range.start_bound() {
370            Bound::Included(value) => *value,
371            Bound::Excluded(value) => *value,
372            Bound::Unbounded => T::min_value(),
373        };
374        let high = match range.end_bound() {
375            Bound::Included(value) => *value,
376            Bound::Excluded(value) => *value,
377            Bound::Unbounded => T::max_value(),
378        };
379        Ok(Tensor::rand(shape, T::dtype())? * high.sub(low) + low)
380    }
381
382    /// Create tensor sampled from kaiming uniform distribution.
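    ///
    /// A minimal sketch (`a = 0.01`, a typical leaky ReLU slope, is illustrative):
    /// ```rust
    /// use zyx::Tensor;
    /// let w = Tensor::kaiming_uniform([128, 64], 0.01f32).unwrap();
    /// assert_eq!(w.shape(), [128, 64]);
    /// ```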
383    #[cfg(feature = "rand")]
384    #[must_use]
385    pub fn kaiming_uniform<T: Scalar>(shape: impl IntoShape, a: T) -> Result<Tensor, ZyxError> {
386        let n = T::from_i64(shape.clone().into_shape().skip(1).product::<usize>() as i64);
387        let one = T::one();
388        let x = Scalar::add(one, Scalar::mul(a, a));
389        let two = Scalar::add(one, one);
390        let three = Scalar::add(two, one);
391        let x = Scalar::div(two, x).sqrt();
392        let bound = Scalar::mul(three.sqrt(), Scalar::div(x, n));
393        return Tensor::uniform(shape, bound.neg()..bound);
394    }
395
396    /// Create tensor filled with zeros.
397    #[must_use]
398    pub fn zeros(shape: impl IntoShape, dtype: DType) -> Tensor {
399        return Tensor {
400            id: RT.lock().zeros(shape.into_shape().collect(), dtype),
401        };
402    }
403
404    /// Create tensor filled with ones.
405    #[must_use]
406    pub fn ones(shape: impl IntoShape, dtype: DType) -> Tensor {
407        return Tensor {
408            id: RT.lock().ones(shape.into_shape().collect(), dtype),
409        };
410    }
411
412    /// Create tensor filled with value.
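    ///
    /// A minimal sketch (shape and fill value are illustrative):
    /// ```rust
    /// use zyx::Tensor;
    /// let x = Tensor::full([2, 3], 7i32).unwrap();
    /// assert_eq!(x, [[7, 7, 7], [7, 7, 7]]);
    /// ```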
413    #[must_use]
414    pub fn full(shape: impl IntoShape, value: impl Scalar) -> Result<Tensor, ZyxError> {
415        return Ok(Tensor {
416            id: RT.lock().full(shape.into_shape().collect(), value)?,
417        });
418    }
419
420    /// Create square tensor with ones on the main diagonal and all other values set to zero.
421    #[must_use]
422    pub fn eye(n: usize, dtype: DType) -> Tensor {
423        Tensor::ones(vec![n, 1], dtype)
424            .pad_zeros([(0, n as isize)])
425            .unwrap()
426            .reshape([n + 1, n])
427            .unwrap()
428            .get((..-1, ..)).unwrap()
429    }
430
    /// Arange method, creates a range from `start` to `stop` (exclusive) with the given `step`.
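    ///
    /// A small sketch (arguments are illustrative):
    /// ```rust
    /// use zyx::Tensor;
    /// let x = Tensor::arange(0i32, 5, 1).unwrap();
    /// assert_eq!(x, [0, 1, 2, 3, 4]);
    /// ```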
432    #[must_use]
433    pub fn arange<T: Scalar>(start: T, stop: T, step: T) -> Result<Tensor, ZyxError> {
434        // if (stop-start)/step <= 0: return Tensor([], dtype=dtype, **kwargs)
435        // return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
436        let n: i64 = stop.sub(start).div(step).cast();
437        let n = n as usize;
438        //println!("Shape {n}");
439        let m = start.sub(step);
440        let x = Tensor::full(n, step)?;
441        //println!("{x}");
442        let x = x.cumsum(0)?;
443        Ok(x + m)
444    }
445
    /// Create a constant that will be baked into compiled kernels.
    /// Using a different value in the graph in place of this constant will force
    /// recompilation of one or more kernels.
    /// For performance reasons, use this if the value does not
    /// change during the run of the program or if there are only a few repeating variations.
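    ///
    /// A minimal sketch (values are illustrative):
    /// ```rust
    /// use zyx::Tensor;
    /// let scale = Tensor::constant(0.5f32);
    /// let x = Tensor::from([1f32, 2., 3.]);
    /// let y = &x * scale; // the 0.5 gets baked into the compiled kernel
    /// ```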
451    #[must_use]
452    pub fn constant(value: impl Scalar) -> Tensor {
453        Tensor {
454            id: RT.lock().constant(value),
455        }
456    }
457
458    // unary
459    /// Computes the absolute value of each element in self.
460    #[must_use]
461    pub fn abs(&self) -> Tensor {
462        self.relu() + (-self).relu()
463    }
464
465    /// Casts self to [dtype](crate::DType).
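    ///
    /// A minimal sketch:
    /// ```rust
    /// use zyx::{Tensor, DType};
    /// let x = Tensor::from([1i32, 2, 3]);
    /// let y = x.cast(DType::F32);
    /// assert_eq!(y.dtype(), DType::F32);
    /// ```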
466    #[must_use]
467    pub fn cast(&self, dtype: DType) -> Tensor {
468        return Tensor {
469            id: RT.lock().cast(self.id, dtype),
470        };
471    }
472
    /// Applies CELU element-wise: `celu(x) = max(0, x) + min(0, α * (exp(x/α) - 1))`.
474    #[must_use]
475    pub fn celu(&self, alpha: impl Scalar) -> Tensor {
476        return self.relu() - (-((self / alpha).exp() - 1) * alpha).relu();
477    }
478
479    /// Returns a new tensor with the cosine of the elements of self.
480    #[must_use]
481    pub fn cos(&self) -> Tensor {
482        let x = self.float_cast();
483        let x = Tensor {
484            id: RT.lock().cos(x.id),
485        };
486        x
487    }
488
489    /// `cosh(x) = (exp(x) + exp(-x)) / 2`.
490    #[must_use]
491    pub fn cosh(&self) -> Tensor {
492        // (e^x + e^-x) / 2
493        let nx = self.neg();
494        let enx = nx.exp();
495        let ex = self.exp();
496        (ex + enx) / 2
497    }
498
499    /// Applies dropout to the tensor with a given probability.
500    ///
501    /// This function randomly sets elements of the input tensor to zero based on the provided probability.
502    /// The output tensor has the same shape as the input tensor. Elements are preserved with probability `1 - probability`
503    /// and set to zero with probability `probability`.
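    ///
    /// A minimal sketch (probability is illustrative):
    /// ```rust
    /// use zyx::{Tensor, DType};
    /// let x = Tensor::ones([4, 4], DType::F32);
    /// let y = x.dropout(0.5f32).unwrap();
    /// assert_eq!(y.shape(), [4, 4]);
    /// ```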
504    #[cfg(feature = "rand")]
505    #[must_use]
506    pub fn dropout<P: Scalar + Float>(&self, probability: P) -> Result<Tensor, ZyxError> {
507        // TODO fix this for training (dropout in training is just scaling)
508        Ok(Tensor::from(probability).cmplt(Tensor::rand(self.shape(), P::dtype())?)? * self)
509    }
510
511    /// Applies the Exponential Linear Unit function element-wise.
512    ///
513    /// The ELU function is defined as:
    /// ```text
    /// f(x) = x             if x > 0
    ///        α(e^x - 1)    otherwise
    /// ```
518    /// where `α` is a given scaling factor. This function helps mitigate the "dying ReLU" problem.
519    #[must_use]
520    pub fn elu(&self, alpha: impl Scalar) -> Tensor {
521        self.relu() - (Tensor::ones(1, self.dtype()) - self.exp()).relu() * alpha
522    }
523
    /// Returns a new tensor with 2 raised to the power of each element of self.
525    #[must_use]
526    pub fn exp2(&self) -> Tensor {
527        let x = self.float_cast();
528        let x = Tensor {
529            id: RT.lock().exp2(x.id),
530        };
531        x
532    }
533
534    /// Computes the exponential of each element in the input tensor using base e.
535    ///
536    /// This function returns a new tensor that is computed by taking the exponential of each
537    /// element in the input tensor. The output will have the same shape as the input tensor,
538    /// and its elements will be calculated as `e^input_element`.
539    ///
    /// **Parameters:**
    ///
    /// * self: The input tensor.
    ///
    /// **Returns:** A new tensor with the same shape as the input, but with each element computed
    /// as `e^input_element`.
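    ///
    /// A minimal numeric sketch (values are approximate):
    /// ```rust
    /// use zyx::Tensor;
    /// let x = Tensor::from([0f32, 1.]);
    /// let y = x.exp(); // approximately [1.0, 2.71828]
    /// ```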
543    #[must_use]
544    pub fn exp(&self) -> Tensor {
545        let c: Tensor = Tensor::constant(std::f64::consts::E.log2());
546        (self * c.cast(self.dtype())).exp2()
547    }
548
549    /// Returns a new tensor with the Gelu activation function applied to each element of self.
550    ///
551    /// The Gelu activation function is defined as:
552    /// `gelu(x) = x * 0.5 * (1 + tanh(sqrt(2 / π) * (x + x^3 * 0.044715)))`.
553    #[must_use]
554    pub fn gelu(&self) -> Result<Tensor, ZyxError> {
555        Ok(self * 0.5f32
556            * (((self + self.pow(3f32)? * 0.044_715f32) * (2f32 / core::f32::consts::PI).sqrt())
557                .tanh()
558                + 1f32))
559    }
560
561    /// Applies the Leaky ReLU activation function element-wise.
562    ///
563    /// This function computes the Leaky ReLU of each element in the input tensor. If the element is greater than
564    /// or equal to zero, it returns the element itself; otherwise, it returns `neg_slope * element`.
565    ///
566    /// **Parameters:**
567    ///
568    /// * self: The input tensor.
569    /// * neg_slope: The negative slope coefficient (`α` in the formula) for the Leaky ReLU function.
570    ///
571    /// **Returns:**
572    ///
573    /// A new tensor with the same shape as the input, but with each element computed as `max(0., x) + neg_slope * min(0., x)`.
574    #[must_use]
575    pub fn leaky_relu(&self, neg_slope: impl Scalar) -> Tensor {
576        self.relu() - (self * (-Tensor::from(neg_slope))).relu()
577    }
578
579    /// Computes the base-2 logarithm of each element in the input tensor.
580    ///
581    /// This function returns a new tensor that is computed by taking the base-2 logarithm of each
582    /// element in the input tensor. The output will have the same shape as the input tensor,
583    /// and its elements will be calculated as `log2(input_element)`.
584    ///
    /// **Parameters:**
    ///
    /// * self: The input tensor.
    ///
    /// **Returns:** A new tensor with the same shape as the input, but with each element computed
    /// as `log2(input_element)`.
588    #[must_use]
589    pub fn log2(&self) -> Tensor {
590        let x = self.float_cast();
591        return Tensor {
592            id: RT.lock().log2(x.id),
593        };
594    }
595
596    /// Computes the natural logarithm (ln) of each element in the input tensor.
597    ///
598    /// This function returns a new tensor that is computed by taking the natural logarithm of each
599    /// element in the input tensor. The output will have the same shape as the input tensor,
600    /// and its elements will be calculated as `ln(input_element)`.
601    ///
602    /// **Parameters:**
603    ///
604    /// * self: The input tensor.
605    ///
606    /// **Returns:**
607    ///
608    /// A new tensor with the same shape as the input, but with each element computed as `ln(input_element)`.
609    #[must_use]
610    pub fn ln(&self) -> Tensor {
611        let x = self.float_cast();
612        let c: Tensor = Tensor::constant(1f64 / std::f64::consts::E.log2());
613        x.log2() * c.cast(x.dtype())
614    }
615
616    /// Computes the multiplicative inverse of each element in the input tensor.
617    ///
618    /// This function returns a new tensor with the same shape as the input, where each element is the multiplicative inverse (i.e., reciprocal) of the corresponding element in the input tensor.
619    ///
620    /// **Parameters:**
621    ///
622    /// * self: The input tensor.
623    ///
624    /// **Returns:** A new tensor with the same shape as the input, where each element is the multiplicative inverse (reciprocal) of the corresponding element in the input tensor.
625    #[must_use]
626    pub fn inv(&self) -> Tensor {
627        return Tensor {
628            id: RT.lock().inv(self.id),
629        };
630    }
631
632    /// Computes the Mish activation function for each element in the input tensor.
633    ///
634    /// The Mish activation function is a continuous, non-monotonic function that behaves like ReLU for positive inputs and like sigmoid for negative inputs. It is defined as `x * tanh(softplus(x))`.
635    ///
636    /// **Parameters:**
637    ///
638    /// * self: The input tensor.
639    ///
640    /// **Returns:** A new tensor with the same shape as the input, but with each element computed as `Mish(input_element)`.
641    #[must_use]
642    pub fn mish(&self) -> Tensor {
643        self * self.softplus(1, 20).tanh()
644    }
645
646    /// Computes the quick GELU activation function for each element in the input tensor.
647    ///
648    /// The QuickGELU activation function is an approximation of the Gaussian Error Linear Unit (GELU) function that uses a sigmoid function to compute the approximation. It is defined as `x * sigmoid(1.702 * x)`.
649    ///
650    /// **Parameters:**
651    ///
652    /// * self: The input tensor.
653    ///
654    /// **Returns:** A new tensor with the same shape as the input, but with each element computed as `QuickGELU(input_element)`.
655    #[must_use]
656    pub fn quick_gelu(&self) -> Tensor {
657        self * (1.702f32 * self).sigmoid()
658    }
659
660    /// Computes the multiplicative inverse of each element in the input tensor using a faster implementation.
661    ///
662    /// This function returns a new tensor with the same shape as the input, where each element is the multiplicative inverse (i.e., reciprocal) of the corresponding element in the input tensor. This implementation uses `1.0 / self` which is generally faster than calling the `inv()` method directly.
663    ///
664    /// **Parameters:**
665    ///
666    /// * self: The input tensor.
667    ///
668    /// **Returns:** A new tensor with the same shape as the input, where each element is the multiplicative inverse (reciprocal) of the corresponding element in the input tensor using a faster implementation.
669    #[must_use]
670    pub fn reciprocal(&self) -> Tensor {
671        return Tensor {
672            id: RT.lock().reciprocal(self.id),
673        };
674    }
675
676    /// Applies the Rectified Linear Unit (ReLU) activation function to each element in the input tensor.
677    ///
678    /// The ReLU function returns `max(0, x)`, i.e., it replaces negative values with zero and leaves positive values unchanged. This makes it a popular choice for use in hidden layers of neural networks due to its simplicity and effectiveness.
679    ///
680    /// **Parameters:**
681    ///
682    /// * self: The input tensor.
683    ///
684    /// **Returns:** A new tensor with the same shape as the input, but with each element computed as `max(0, input_element)`.
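    ///
    /// A minimal sketch:
    /// ```rust
    /// use zyx::Tensor;
    /// let x = Tensor::from([-1f32, 0., 2.]);
    /// assert_eq!(x.relu(), [0f32, 0., 2.]);
    /// ```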
685    #[must_use]
686    pub fn relu(&self) -> Tensor {
687        return Tensor {
688            id: RT.lock().relu(self.id),
689        };
690    }
691
692    /// Computes the reciprocal square root of each element in the input tensor.
693    ///
694    /// This function returns a new tensor with the same shape as the input, where each element is the reciprocal square root (i.e., `1 / sqrt(x)`) of the corresponding element in the input tensor. This operation can be useful for scaling and stabilizing certain types of computations.
695    ///
696    /// **Parameters:**
697    ///
698    /// * self: The input tensor.
699    ///
700    /// **Returns:** A new tensor with the same shape as the input, where each element is the reciprocal square root (i.e., `1 / sqrt(x)`) of the corresponding element in the input tensor.
701    #[must_use]
702    pub fn rsqrt(&self) -> Tensor {
703        self.reciprocal().sqrt()
704    }
705
706    /// Applies the Self-Normalized Linear Unit (Selu) activation function to each element in the input tensor.
707    ///
    /// The SELU activation function is designed to keep the mean and variance of activations approximately constant across layers, which lets self-normalizing networks train without explicit normalization. It behaves like a scaled ELU, making it a good choice for certain types of problems.
709    ///
710    /// **Parameters:**
711    ///
712    /// * self: The input tensor.
713    ///
714    /// **Returns:** A new tensor with the same shape as the input, but with each element computed as `Selu(input_element)`.
715    #[must_use]
716    pub fn selu(&self) -> Tensor {
717        1.0507009873554804934193349852946f32
718            * (self.relu()
719                - (1.6732632423543772848170429916717f32
720                    * (Tensor::ones(1, self.dtype()) - self.exp()))
721                .relu())
722    }
723
724    /// Applies the sigmoid activation function to each element in the input tensor.
725    ///
726    /// The sigmoid function returns `1 / (1 + exp(-x))`, i.e., it maps any real-valued input onto a value between 0 and 1. This function is commonly used for binary classification problems or as an activation function in neural networks.
727    ///
728    /// **Parameters:**
729    ///
730    /// * self: The input tensor.
731    ///
732    /// **Returns:** A new tensor with the same shape as the input, but with each element computed as `sigmoid(input_element)`.
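    ///
    /// A minimal numeric sketch (the value is approximate):
    /// ```rust
    /// use zyx::Tensor;
    /// let x = Tensor::from([0f32]);
    /// let y = x.sigmoid(); // approximately [0.5]
    /// ```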
733    #[must_use]
734    pub fn sigmoid(&self) -> Tensor {
735        let one = Tensor::ones(1, self.dtype());
736        let exp_x = self.exp();
737        return &exp_x / (&one + &exp_x);
738    }
739
740    /// Applies the sine function to each element in the input tensor.
741    ///
742    /// This function returns a new tensor with the same shape as the input, where each element is the sine of the corresponding element in the input tensor. The sine function is useful for various mathematical and scientific computations involving angles or periodic phenomena.
743    ///
744    /// **Parameters:**
745    ///
746    /// * self: The input tensor.
747    ///
748    /// **Returns:** A new tensor with the same shape as the input, where each element is the sine of the corresponding element in the input tensor.
749    #[must_use]
750    pub fn sin(&self) -> Tensor {
751        let x = self.float_cast();
752        let x = Tensor {
753            id: RT.lock().sin(x.id),
754        };
755        x
756    }
757
758    /// Applies the hyperbolic sine function to each element in the input tensor.
759    ///
760    /// The hyperbolic sine function returns `(e^x - e^-x) / 2`, i.e., it maps any real-valued input onto a value that grows exponentially. This function is useful for computations involving exponential growth or decay, such as in physics and engineering applications.
761    ///
762    /// **Parameters:**
763    ///
764    /// * self: The input tensor.
765    ///
766    /// **Returns:** A new tensor with the same shape as the input, but with each element computed as `sinh(input_element)`.
767    #[must_use]
768    pub fn sinh(&self) -> Tensor {
769        // (e^x - e^-x) / 2
770        let nx = self.neg();
771        let enx = nx.exp();
772        let ex = self.exp();
773        (ex - enx) / 2
774    }
775
776    /// Applies the softplus function to each element in the input tensor with a given beta and threshold.
777    ///
778    /// The softplus function returns `log(exp(x) + 1)` for inputs greater than the threshold, and x otherwise. This function is useful for bounding outputs between zero and infinity when applying the ReLU function.
779    ///
780    /// **Parameters:**
781    ///
782    /// * self: The input tensor.
783    /// * beta: A scalar multiplier applied to each element of the input tensor before comparison with the threshold.
784    /// * threshold: The threshold value below which the input is returned unchanged, and above which the softplus function is applied.
785    ///
786    /// **Returns:** A new tensor with the same shape as the input, where each element is computed according to the softplus function with the given beta and threshold.
787    #[must_use]
788    pub fn softplus(&self, beta: impl Scalar, threshold: impl Scalar) -> Tensor {
789        let x = self * beta;
790        x.cmplt(threshold).unwrap().where_(((x).exp() + 1).ln() * beta.reciprocal(), x).unwrap()
791    }
792
793    /// Applies the square root function to each element in the input tensor.
794    ///
795    /// This function returns a new tensor with the same shape as the input, where each element is the square root of the corresponding element in the input tensor. The square root function is useful for various mathematical computations involving squares or square roots.
796    ///
797    /// **Parameters:**
798    ///
799    /// * self: The input tensor.
800    ///
801    /// **Returns:** A new tensor with the same shape as the input, where each element is the square root of the corresponding element in the input tensor.
802    #[must_use]
803    pub fn sqrt(&self) -> Tensor {
804        let x = self.float_cast();
805        let x = Tensor {
806            id: RT.lock().sqrt(x.id),
807        };
808        x
809    }
810
811    /// Applies the Swish activation function to each element in the input tensor.
812    ///
813    /// The Swish function returns `x * sigmoid(x)`, where `sigmoid(x) = 1 / (1 + exp(-x))`. This function is useful for various deep learning applications, as it has been shown to improve convergence speed and generalization performance compared to other activation functions like ReLU.
814    ///
815    /// **Parameters:**
816    ///
817    /// * self: The input tensor.
818    ///
819    /// **Returns:** A new tensor with the same shape as the input, where each element is computed according to the Swish function.
820    #[must_use]
821    pub fn swish(&self) -> Tensor {
822        self * self.sigmoid()
823    }
824
825    /// Applies the tangent function to each element in the input tensor.
826    ///
827    /// The tangent function returns the sine of the input divided by the cosine of the input. This function is useful for various mathematical computations involving angles and trigonometry.
828    ///
829    /// **Parameters:**
830    ///
831    /// * self: The input tensor.
832    ///
833    /// **Returns:** A new tensor with the same shape as the input, where each element is computed according to the tangent function.
834    #[must_use]
835    pub fn tan(&self) -> Tensor {
836        self.sin() / self.cos()
837    }
838
839    /// Returns the hyperbolic tangent of each element in the tensor.
840    ///
    /// The hyperbolic tangent is calculated as `(exp(2x) - 1) / (exp(2x) + 1)`, where `exp` is the exponential function and `x` is an element of the input tensor. This function applies the hyperbolic tangent element-wise to the input tensor.
842    ///
843    /// # Examples
844    ///
845    /// ```rust
846    /// use zyx::Tensor;
847    ///
848    /// let t = Tensor::from(vec![0.5, 1.0]);
    /// let y = t.tanh(); // approximately [0.4621, 0.7616]
850    /// ```
851    ///
852    /// # Panics
853    ///
854    /// This function will panic if the input tensor is empty.
855    #[must_use]
856    pub fn tanh(&self) -> Tensor {
857        let x = (self + self).sigmoid();
858        (&x + &x) - Tensor::constant(1).cast(self.dtype())
859    }
860
861    // movement
862    /// Expands this tensor by adding singleton dimensions at the front until its rank matches that of the target shape.
863    ///
864    /// If the target shape has a higher rank than the current tensor, singleton dimensions are added to the front of the tensor's shape.
865    /// If any dimension in the target shape does not match the corresponding dimension in the expanded tensor's shape,
    /// an error is returned unless the existing dimension is 1 (in which case it is broadcast).
867    ///
868    /// # Examples
869    ///
870    /// ```
    /// use zyx::{Tensor, DType};
    /// let t = Tensor::zeros([2, 3], DType::F32);
    /// assert_eq!(t.expand([4, 2, 3]).unwrap().shape(), [4, 2, 3]);
873    /// ```
874    #[must_use]
875    pub fn expand(&self, shape: impl IntoShape) -> Result<Tensor, ZyxError> {
876        let mut sh = self.shape();
877        let shape: Vec<usize> = shape.into_shape().collect();
878        //println!("Expand to {shape:?}");
879        if shape.rank() < sh.rank() {
880            return Err(ZyxError::ShapeError(format!("Cannot expand {:?} into {:?}", self.shape(), shape)));
881        }
882        if shape.rank() > sh.rank() {
883            let mut i = sh.len();
884            for d in shape.iter().copied().rev() {
885                if i == 0 {
886                    // Adding dimensions to the front of the shape
887                    sh.insert(i, 1);
888                } else {
889                    i -= 1;
890                }
891                if d != sh[i] {
892                    if sh[i] != 1 {
893                        return Err(ZyxError::ShapeError(format!("Cannot expand {:?} into {:?}", self.shape(), shape)));
894                    }
895                }
896            }
897            let x = self.reshape(sh).unwrap();
898            let id = RT.lock().expand(x.id, shape);
899            drop(x);
900            return Ok(Tensor { id })
901        };
902        Ok(Tensor { id: RT.lock().expand(self.id, shape) })
903    }
904
905    /// Permutes the axes of this tensor.
906    ///
    /// This function rearranges the dimensions of the tensor according to the provided axes. The axes must be a permutation of the original axes, i.e., they must contain each index once and only once. If the axes have a different length than the rank of the tensor, a shape error is returned.
908    ///
909    /// # Examples
910    ///
911    /// ```rust
    /// use zyx::{Tensor, DType};
    /// let t = Tensor::rand([3, 4], DType::F32).unwrap();
    /// let permuted_t = t.permute([1, 0]).unwrap(); // results in a tensor with shape (4, 3)
916    /// ```
917    ///
    /// # Errors
    ///
    /// Returns a shape error if the length of `axes` is not equal to the rank of this tensor.
921    #[must_use]
922    pub fn permute(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
923        let rank = self.rank();
924        let axes: Vec<usize> = axes.into_axes(rank).collect();
925        if rank != axes.len() {
926            return Err(ZyxError::ShapeError(format!("Axes has rank {}, but tensor has rank {}. It must be the same for permute.", axes.len(), rank)));
927        }
928        Ok(Tensor { id: RT.lock().permute(self.id, axes) })
929    }
930
931    /// Creates a new tensor by padding zeros around this tensor based on the specified padding configuration.
932    ///
933    /// # Examples
934    ///
935    /// ```
936    /// use zyx::Tensor;
937    ///
938    /// let t = Tensor::from([1, 2, 3]);
    /// let padded = t.pad_zeros([(1, 1)]).unwrap();
    /// assert_eq!(padded, [0, 1, 2, 3, 0]);
    ///
    /// let padded = t.pad_zeros([(1, 2)]).unwrap();
    /// assert_eq!(padded.shape(), [6]);
944    /// ```
945    ///
    /// # Errors
    ///
    /// Returns a shape error if the padding configuration is invalid.
949    #[must_use]
950    pub fn pad_zeros(&self, padding: impl IntoPadding) -> Result<Tensor, ZyxError> {
951        let padding = padding.into_padding();
952        for (i, &(l, r)) in padding.iter().enumerate() {
953            let shape = self.shape();
954            let rank = shape.len();
955            let mut total = 0;
956            if l < 0 {
957                total -= l;
958            }
959            if r < 0 {
960                total -= r;
961            }
962            if (total as usize) >= shape[rank-i-1] {
963                return Err(ZyxError::ShapeError(format!("Invalid padding {padding:?} on shape {shape:?}")));
964            }
965        }
966        Ok(Tensor { id: RT.lock().pad_zeros(self.id, padding) })
967    }
968
969    /// Constant padding
970    ///
971    /// This can both add and remove values from tensor. Negative padding removes values, positive padding
972    /// adds values.
973    ///
974    /// Pad last dimension by (1, 2)
975    /// ```rust
976    /// use zyx::Tensor;
977    /// let x = Tensor::from([[2, 3],
978    ///                       [4, 1]]);
979    /// let z = x.pad([(1, 2)], 0);
980    /// std::println!("{}", z);
981    /// assert_eq!(z, [[0, 2, 3, 0, 0],
982    ///                [0, 4, 1, 0, 0]]);
983    /// ```
984    /// Pad last dimension by (2, -1) and second last dimension by (1, 1)
985    /// ```rust
986    /// # use zyx::Tensor;
987    /// # let x = Tensor::from([[2, 3],
988    /// #                       [4, 1]]);
989    /// let z = x.pad([(2, -1), (1, 1)], 0);
990    /// println!("z: {z}");
991    /// assert_eq!(z, [[0, 0, 0],
992    ///                [0, 0, 2],
993    ///                [0, 0, 4],
994    ///                [0, 0, 0]]);
995    /// ```
996    ///
997    /// # Panics
998    /// T must be of the same dtype as Tensor's dtype, otherwise this function panics.
999    #[must_use]
1000    pub fn pad(
1001        &self,
1002        padding: impl IntoPadding,
1003        value: impl Into<Tensor>,
1004    ) -> Result<Tensor, ZyxError> {
1005        let dtype = self.dtype();
1006        let value: Tensor = value.into();
1007        let padding = padding.into_padding();
1008        let sh = self.shape();
1009        if value.dtype() != dtype {
1010            return Err(ZyxError::DTypeError(format!("Cannot pad tensor with dtype {} with value of dtype {}", dtype, value.dtype())));
1011        }
        if !(padding.len() <= sh.rank() && padding.iter().zip(sh.iter().rev()).all(|((lp, rp), d)| if *lp < 0 { ((-*lp) as usize) <= *d } else { true } && if *rp < 0 { ((-*rp) as usize) <= *d } else { true })) {
1013            return Err(ZyxError::ShapeError(format!("Cannot pad tensor with shape {sh:?} with padding {padding:?}")));
1014        }
1015        let t0 = self.pad_zeros(padding.clone());
1016        if value.numel() == 1
1017            && match dtype {
1018                #[cfg(feature = "half")]
1019                DType::BF16 => {
1020                    let x: bf16 = value.clone().try_into()?;
1021                    x == bf16::ZERO
1022                }
1023                #[cfg(feature = "half")]
1024                DType::F16 => {
1025                    let x: f16 = value.clone().try_into()?;
1026                    x == f16::ZERO
1027                }
1028                DType::F32 => {
1029                    let x: f32 = value.clone().try_into()?;
1030                    x == 0.
1031                }
1032                DType::F64 => {
1033                    let x: f64 = value.clone().try_into()?;
1034                    x == 0.
1035                }
1036                #[cfg(feature = "complex")]
1037                DType::CF32 => {
1038                    let x: Complex<f32> = value.clone().try_into()?;
1039                    x == Complex::new(0., 0.)
1040                }
1041                #[cfg(feature = "complex")]
1042                DType::CF64 => {
1043                    let x: Complex<f64> = value.clone().try_into()?;
1044                    x == Complex::new(0., 0.)
1045                }
1046                DType::U8 => {
1047                    let x: u8 = value.clone().try_into()?;
1048                    x == 0
1049                }
1050                DType::I8 => {
1051                    let x: i8 = value.clone().try_into()?;
1052                    x == 0
1053                }
1054                DType::I16 => {
1055                    let x: i16 = value.clone().try_into()?;
1056                    x == 0
1057                }
1058                DType::I32 => {
1059                    let x: i32 = value.clone().try_into()?;
1060                    x == 0
1061                }
1062                DType::I64 => {
1063                    let x: i64 = value.clone().try_into()?;
1064                    x == 0
1065                }
1066                DType::Bool => {
1067                    let x: bool = value.clone().try_into()?;
1068                    x == false
1069                }
1070            }
1071        {
1072            t0
1073        } else {
1074            let ones = Tensor::ones(sh.clone(), dtype);
1075            let zeros = Tensor::zeros(sh, self.dtype());
1076            Ok(t0? + ones.pad_zeros(padding)?.where_(zeros, value)?)
1077        }
1078    }
1079
1080    /// Applies a new shape to this tensor while preserving its total number of elements.
1081    ///
1082    /// # Examples
1083    ///
1084    /// ```rust
1085    /// use zyx::Tensor;
1086    /// let t = Tensor::from([1, 2, 3, 4]);
    /// assert_eq!(t.reshape([2, 2]).unwrap(), [[1, 2], [3, 4]]);
1088    /// ```
1089    ///
    /// # Errors
    ///
    /// Returns a shape error if the product of the new shape is not equal to the number of elements in this tensor.
1093    #[must_use]
1094    pub fn reshape(&self, shape: impl IntoShape) -> Result<Tensor, ZyxError> {
1095        let shape: Vec<usize> = shape.into_shape().collect();
1096        if shape.iter().product::<usize>() != self.numel() {
1097            return Err(ZyxError::ShapeError(format!("Invalid reshape {:?} into {:?}", self.shape(), shape)));
1098        };
1099        Ok(Tensor { id: RT.lock().reshape(self.id, shape) })
1100    }
1101
1102    /// An alias to reshape
1103    #[must_use]
1104    pub fn view(&self, shape: impl IntoShape) -> Result<Tensor, ZyxError> {
1105        self.reshape(shape)
1106    }
1107
1108    /// Transpose last two dimensions of this tensor.
1109    /// If self.rank() == 1, returns tensor with shape `[self.shape()[0], 1]` (column tensor)
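    ///
    /// A minimal sketch:
    /// ```rust
    /// use zyx::Tensor;
    /// let x = Tensor::from([[1, 2, 3], [4, 5, 6]]);
    /// assert_eq!(x.t().shape(), [3, 2]);
    /// ```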
1110    #[must_use]
1111    pub fn t(&self) -> Tensor {
1112        let mut rank = self.rank();
1113        let x = if rank == 1 {
1114            let n = self.numel();
1115            rank = 2;
1116            self.reshape([1, n]).unwrap()
1117        } else {
1118            self.clone()
1119        };
1120        let mut axes: Vec<isize> = (0..rank as isize).collect();
1121        axes.swap(rank - 1, rank - 2);
1122        x.permute(axes).unwrap()
1123    }
1124
1125    /// Transpose two arbitrary dimensions
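    ///
    /// A minimal sketch (shape is illustrative):
    /// ```rust
    /// use zyx::{Tensor, DType};
    /// let x = Tensor::zeros([2, 3, 4], DType::F32);
    /// assert_eq!(x.transpose(0, 2).unwrap().shape(), [4, 3, 2]);
    /// ```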
1126    #[must_use]
1127    pub fn transpose(&self, dim0: isize, dim1: isize) -> Result<Tensor, ZyxError> {
1128        let rank = self.rank();
1129        if dim0 < 0 {
1130            if (-dim0) as usize >= rank {
1131                return Err(ZyxError::ShapeError(format!("Cannot transpose dimensions {dim0} and {dim1}, {dim0} is greater than rank {rank}")));
1132            }
1133        } else {
1134            if dim0 as usize >= rank {
1135                return Err(ZyxError::ShapeError(format!("Cannot transpose dimensions {dim0} and {dim1}, {dim0} is greater than rank {rank}")));
1136            }
1137        }
1138        if dim1 < 0 {
1139            if (-dim1) as usize >= rank {
1140                return Err(ZyxError::ShapeError(format!("Cannot transpose dimensions {dim0} and {dim1}, {dim1} is greater than rank {rank}")));
1141            }
1142        } else {
1143            if dim1 as usize >= rank {
1144                return Err(ZyxError::ShapeError(format!("Cannot transpose dimensions {dim0} and {dim1}, {dim1} is greater than rank {rank}")));
1145            }
1146        }
1147        let mut axes: Vec<isize> = (0..rank as isize).collect();
1148        axes.swap(to_axis(dim0, rank), to_axis(dim1, rank));
1149        self.permute(axes)
1150    }
1151
1152    // reduce
1153    /// Computes the natural logarithm of the softmax of the input tensor along the specified axes.
1154    ///
1155    /// This function first subtracts the maximum value along the given axes from the input tensor,
1156    /// then computes the exponential of the result, sums over the specified axes using `sum_kd`,
1157    /// and finally takes the natural logarithm of the sum before returning it.
1158    ///
1159    /// # Arguments
1160    ///
1161    /// * `self` - The input tensor to compute the softmax and natural logarithm of.
1162    /// * `axes` - A trait implementing `IntoAxes`, specifying along which axes the softmax should be computed.
1163    ///
1164    /// # Examples
1165    ///
1166    /// ```
1167    /// use zyx::Tensor;
1168    /// let x = Tensor::from([2f32, 3., 4.]);
    /// let y = x.ln_softmax([]).unwrap();
1170    /// println!("{y}");
1171    /// ```
1172    ///
1173    /// # Returns
1174    ///
1175    /// The resulting tensor after computing the natural logarithm of the softmax of `self`.
1176    ///
1177    /// # Panics
1178    ///
1179    /// This function will panic if any of the specified axes are out-of-bounds for the input tensor.
1180    pub fn ln_softmax(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1181        let m = self - self.max_kd(axes.clone())?;
1182        Ok(&m - m.exp().sum_kd(axes)?.ln())
1183    }
1184
1185    /// Returns a new tensor containing the maximum value along the specified axes.
1186    ///
1187    /// # Arguments
1188    ///
1189    /// * `axes` - The axes along which to compute the maximum. This can be any type that implements `IntoAxes`.
1190    ///
1191    /// # Examples
1192    ///
1193    /// ```
    /// use zyx::Tensor;
    /// let arr = Tensor::from([[1, 2], [3, 4]]);
    /// assert_eq!(arr.max(0).unwrap(), [3, 4]);
    /// assert_eq!(arr.max(1).unwrap(), [2, 4]);
1198    /// ```
1199    ///
    /// # Errors
    ///
    /// Returns a shape error if the axes contain duplicates.
1203    #[must_use]
1204    pub fn max(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1205        let rank = self.rank();
1206        let axes: Vec<usize> = axes.into_axes(rank).collect();
1207        let mut unique = BTreeSet::new();
1208        for a in &axes {
1209            if !unique.insert(a) {
1210                return Err(ZyxError::ShapeError("Axes contain duplicates.".into()));
1211            }
1212        }
1213        Ok(Tensor { id: RT.lock().max_reduce(self.id, axes) })
1214    }
1215
1216    /// Returns the maximum value along the specified axes.
1217    ///
1218    /// This function computes the maximum value of each slice determined by the `axes`.
1219    /// It first calculates the maximum along the specified axes using the `max` method,
1220    /// and then reshapes the result to have the same number of dimensions as the input tensor.
1221    ///
1222    /// # Examples
1223    ///
1224    /// ```
1225    /// use zyx::Tensor;
1226    ///
1227    /// let a = Tensor::from([1, 2, 3, 4]);
    /// assert_eq!(a.max_kd([0]).unwrap(), [4]);
1229    /// ```
1230    ///
1231    #[must_use]
1232    pub fn max_kd(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1233        self.max(axes.clone())?.reshape(self.reduce_kd_shape(axes))
1234    }
1235
1236    /// Calculates the mean of a tensor along specified axes.
1237    ///
1238    /// This function computes the sum of all elements in the tensor along the specified axes and then divides by the product of their sizes.
1239    ///
1240    /// # Examples
1241    ///
1242    /// ```
    /// use zyx::Tensor;
    ///
    /// let arr = Tensor::from([[1f32, 2.], [3., 4.]]);
    /// assert_eq!(arr.mean(0).unwrap(), [2f32, 3.]);
1247    /// ```
1248    ///
1249    /// # Panics
1250    ///
1251    /// This function panics if the tensor is empty.
1252    #[must_use]
1253    pub fn mean(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1254        let shape = self.shape();
1255        Ok(self.sum(axes.clone())?
1256            / axes
1257                .into_axes(shape.rank())
1258                .map(|a| shape[a])
1259                .product::<usize>() as i64)
1260    }
1261
1262    /// Calculates the mean of this tensor along the specified axes and reshapes it using `reduce_kd_shape`.
1263    ///
1264    /// This function first calculates the mean of the input tensor along the specified axes using the `mean`
1265    /// method. It then reshapes the resulting tensor using `reduce_kd_shape` to match the output shape expected
1266    /// by the caller.
1267    ///
1268    /// # Examples
1269    ///
1270    /// ```
1271    /// use zyx::Tensor;
1272    ///
    /// let a = Tensor::from([1f32, 2., 3., 4.]);
    /// assert_eq!(a.mean_kd(0).unwrap(), [2.5f32]);
1275    /// ```
1276    ///
1277    /// # Panics
1278    ///
1279    /// This function panics if the input tensor is empty.
1280    #[must_use]
1281    pub fn mean_kd(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1282        self.mean(axes.clone())?.reshape(self.reduce_kd_shape(axes))
1283    }
1284
1285    /// Calculates the product of elements along specified axes.
1286    ///
1287    /// This function first applies the natural logarithm element-wise (`ln()`), then sums along the specified axes,
1288    /// and finally exponentiates the result element-wise (`exp()`).
1289    ///
1290    /// # Examples
1291    ///
1292    /// ```
1293    /// use zyx::Tensor;
1294    ///
1295    /// let arr = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
    /// let p = arr.product(0).unwrap(); // approximately [3., 8.]
1297    /// ```
1298    #[must_use]
1299    pub fn product(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1300        Ok(self.ln().sum(axes)?.exp())
1301    }
1302
1303    /// Calculates the standard deviation of the input tensor along specified axes.
1304    ///
1305    /// This function calculates the standard deviation by first computing the mean along the specified axes,
1306    /// then subtracting that mean from each element, squaring the result, and finally taking the square root
1307    /// of the average of those squared differences. If no axes are provided, it computes the standard deviation
1308    /// over all elements in the tensor.
1309    ///
1310    /// # Examples
1311    ///
1312    /// ```
1313    /// use zyx::Tensor;
1314    ///
1315    /// let a = Tensor::from([[1., 2., 3.], [4., 5., 6.]]);
    /// let s = a.std([]).unwrap(); // reduce across all axes
1317    /// ```
1318    ///
1319    /// # Panics
1320    ///
1321    /// This function will panic if the input tensor is empty.
1322    ///
1323    #[must_use]
1324    pub fn std(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1325        Ok(self.var(axes)?.sqrt())
1326    }
1327
1328    /// Creates a new tensor by applying standard deviation along specified axes.
1329    ///
1330    /// This function first computes the standard deviation of the input tensor along the specified axes,
1331    /// and then reshapes the result to match the shape of the original tensor after reduction along those axes.
1332    ///
1333    /// # Examples
1334    ///
1335    /// ```
1336    /// use zyx::{Tensor, DType};
1337    ///
1338    /// let t = Tensor::rand([3, 4], DType::F32).unwrap();
    /// let std_kd = t.std_kd([0, 1]).unwrap();
    /// assert_eq!(std_kd.shape(), [1, 1]);
1341    /// ```
1342    ///
1343    /// # Panics
1344    ///
1345    /// This function panics if the input tensor has no elements.
1346    #[must_use]
1347    pub fn std_kd(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1348        self.std(axes.clone())?.reshape(self.reduce_kd_shape(axes))
1349    }
1350
1351    /// Sum reduce. Removes tensor dimensions.
1352    /// Equivalent to pytorch sum(axes, keepdim=False)
1353    /// If you want to keep reduce dimensions, see [sum_kd](Tensor::sum_kd)
1354    /// Passing empty axes executes reduce across all dimensions and result will have shape `[1]`
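    ///
    /// A minimal sketch:
    /// ```rust
    /// use zyx::Tensor;
    /// let x = Tensor::from([[1, 2], [3, 4]]);
    /// assert_eq!(x.sum(0).unwrap(), [4, 6]);
    /// assert_eq!(x.sum(1).unwrap(), [3, 7]);
    /// ```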
1355    #[must_use]
1356    pub fn sum(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1357        // TODO handle axes out of range error
1358        let rank = self.rank();
1359        let axes: Vec<usize> = axes.into_axes(rank).collect();
1360        {
1361            // We can add checks for axes being less than rank and axes not containing duplicates
1362            let mut unique = BTreeSet::new();
1363            for a in &axes {
1364                if !unique.insert(a) {
1365                    return Err(ZyxError::ShapeError("Axes contains duplicates.".into()));
1366                }
1367                // This is checked by into_axes function
1368                //assert!(a < rank, "Axes are too high");
1369            }
1370        }
1371        Ok(Tensor { id: RT.lock().sum_reduce(self.id, axes) })
1372    }
1373
1374    // Probably just have sum_kd, max_kd that keep tensor dimensions
1375    /// Like [sum](Tensor::sum) but keeps reduce dimensions, setting them to 1.
1376    /// Equivalent to pytorch sum(axes, keepdim=True)
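    /// A minimal sketch:
    /// ```
    /// use zyx::Tensor;
    /// let t = Tensor::from([[1, 2], [3, 4]]);
    /// assert_eq!(t.sum_kd(1).unwrap(), [[3], [7]]);
    /// ```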
1377    #[must_use]
1378    pub fn sum_kd(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1379        self.sum(axes.clone())?.reshape(self.reduce_kd_shape(axes))
1380    }
1381
1382    /// Cumulative sum along axis.
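    /// A small sketch of the expected behavior (prefix sums along the given axis):
    /// ```
    /// use zyx::Tensor;
    /// let t = Tensor::from([1, 2, 3, 4]);
    /// assert_eq!(t.cumsum(0).unwrap(), [1, 3, 6, 10]);
    /// ```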
1383    #[must_use]
1384    pub fn cumsum(&self, axis: isize) -> Result<Tensor, ZyxError> {
1385        let axis = to_axis(axis, self.rank());
1386        let pl_sz = (self.shape()[axis] - 1) as isize;
1387        let k = self.shape()[axis];
1388        let axis = axis as isize;
1389        let mut x = self.transpose(axis, -1)?;
1390        x = x.pad_zeros([(pl_sz, 0)])?;
1391        //println!("{x:?} padded");
1392        x = x.pool(k, 1, 1)?;
1393        //println!("{x:?} pooled");
1394        x = x.sum(-1)?;
1395        //println!("{x:?} summed");
1396        x = x.transpose(axis, -1)?;
1397        //println!("{x:?} transposed");
1398        Ok(x)
1399    }
1400
1401    /// Calculates the softmax of this tensor along the specified axes.
1402    ///
1403    /// # Arguments
1404    ///
1405    /// * `axes`: The axes along which to calculate the softmax.
1406    ///
1407    /// # Returns
1408    ///
1409    /// * A new tensor containing the result of the softmax operation.
1410    ///
1411    /// # Examples
1412    ///
1413    /// ```
1414    /// use zyx::Tensor;
1415    ///
1416    /// let t = Tensor::from(vec![1.0, 2.0, 3.0]);
1417    /// let sm = t.softmax(0);
1418    /// assert_eq!(sm, [0.0900305748, 0.2447281546, 0.6652412706]);
1419    /// ```
1420    ///
1421    /// # Panics
1422    ///
1423    /// This function will panic if the input tensor is empty.
1424    #[must_use]
1425    pub fn softmax(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1426        let e = (self - self.max_kd(axes.clone())?).exp();
1427        Ok(&e / e.sum_kd(axes)?)
1428    }
1429
1430    /// Calculates the variance of this tensor along the specified axes.
1431    ///
1432    /// This function first computes the mean of the tensor along the provided axes,
1433    /// then subtracts this mean from each element in the tensor, squares the result,
1434    /// and finally averages these squared differences along the same axes to obtain the (population) variance.
1435    ///
1436    /// # Arguments
1437    ///
1438    /// * `axes` - The axes along which to compute the mean and variance. This can be a single axis or a tuple of axes.
1439    ///
1440    /// # Returns
1441    ///
1442    /// * A new tensor containing the variance values computed for each axis.
1443    ///
1444    /// # Examples
1445    ///
1446    /// ```
1447    /// use zyx::Tensor;
1448    ///
1449    /// let arr = Tensor::from([[1., 2.], [3., 4.]]);
1450    /// let var = arr.var(0).unwrap(); // variance along axis 0
1451    /// assert_eq!(var, [1.0, 1.0]);
1452    ///
1453    /// let var = arr.var(1).unwrap(); // variance along axis 1
1454    /// assert_eq!(var, [0.25, 0.25]);
1455    /// ```
1456    #[must_use]
1457    pub fn var(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1458        Ok((self - self.mean_kd(axes.clone())?).pow(2)?.mean(axes)?)
1459    }
1460
1461    /// Like [var](Tensor::var), but keeps the reduced dimensions, setting them to 1.
1462    ///
1463    /// This function first calculates the variance along the specified axes using `var()`,
1464    /// and then reshapes the result so that the reduced axes are kept with size 1.
1466    ///
1467    /// # Arguments
1468    ///
1469    /// * `axes`: The axes to reduce over. If not provided, reduces over all axes.
1470    ///
1471    /// # Returns
1472    ///
1473    /// A new tensor containing the variance along the specified axes.
1474    ///
1475    /// # Examples
1476    ///
1477    /// ```
1478    /// use zyx::Tensor;
1479    ///
1480    /// let a = Tensor::from([[1., 2., 3.], [4., 5., 6.]]);
1481    /// assert_eq!(a.var_kd(0).unwrap(), [[2.25, 2.25, 2.25]]);
1482    /// ```
1483    #[must_use]
1484    pub fn var_kd(&self, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1485        self.var(axes.clone())?.reshape(self.reduce_kd_shape(axes))
1486    }
1487
1488    // index
1489    /// Indexing function. Selects a sub-tensor using ranges and/or single indices.
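    /// A small indexing sketch, following the tuple-index style used by [diagonal](Tensor::diagonal);
    /// note that single indices keep their dimension (they act as ranges of length one):
    /// ```
    /// use zyx::Tensor;
    /// let t = Tensor::from([[1, 2, 3], [4, 5, 6]]);
    /// assert_eq!(t.get((0, 1..3)).unwrap(), [[2, 3]]);
    /// ```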
1490    #[must_use]
1491    pub fn get(&self, index: impl IntoIndex) -> Result<Tensor, ZyxError> {
1492        let shape = self.shape();
1493        let padding: Vec<(isize, isize)> = index
1494            .into_index()
1495            .into_iter()
1496            .zip(shape.iter())
1497            .map(|(r, d)| {
1498                (
1499                    if r.start >= 0 {
1500                        -r.start
1501                    } else {
1502                        -r.start - *d as isize
1503                    },
1504                    if r.end == isize::MAX {
1505                        0
1506                    } else if r.end > 0 {
1507                        -(*d as isize - r.end)
1508                    } else {
1509                        r.end
1510                    },
1511                )
1512            })
1513            .collect();
1514        let n = shape.rank() - padding.len();
1515        let padding: Vec<(isize, isize)> = padding
1516            .into_iter()
1517            .chain(core::iter::repeat((0, 0)).take(n))
1518            .collect::<Vec<(isize, isize)>>()
1519            .into_iter()
1520            .rev()
1521            .collect();
1522        //std::println!("Get padding: {padding:?}");
1523        self.pad_zeros(padding)
1524    }
1525
1526    /// Returns a tensor containing only the diagonal elements of this tensor.
1527    ///
1528    /// The diagonal is obtained by flattening the input tensor, padding the flattened data with `n` zeros (where `n` is the size
1529    /// of the last dimension), reshaping it into an `[n, n + 1]` matrix, and taking its first column, which holds the diagonal.
1530    ///
1531    /// # Returns
1532    ///
1533    /// * A new tensor containing only the diagonal elements of this tensor.
1534    ///
1535    /// # Examples
1536    ///
1537    /// ```
1538    /// use zyx::Tensor;
1539    ///
1540    /// let arr = Tensor::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape([3, 3]).unwrap();
1541    /// assert_eq!(arr.diagonal(), [[1], [5], [9]]); // diagonal elements are 1, 5 and 9
1542    /// ```
1543    ///
1544    /// # Panics
1545    ///
1546    /// This function panics if the input tensor is not a square matrix, since the internal reshape fails.
1547    #[must_use]
1548    pub fn diagonal(&self) -> Tensor {
1549        let n = *self.shape().last().expect("Shape is in an invalid state. Internal bug.");
1550        self.flatten(..)
1551            .unwrap()
1552            .pad_zeros([(0, n as isize)])
1553            .unwrap()
1554            .reshape([n, n + 1])
1555            .unwrap()
1556            .get((.., 0))
1557            .unwrap()
1558    }
1559
1560    // binary
1561    /// Compares this tensor with another tensor element-wise.
1562    ///
1563    /// Returns a new tensor of boolean values indicating where `self` is less than `rhs`.
1564    ///
1565    /// # Examples
1566    ///
1567    /// ```
1568    /// use zyx::Tensor;
1569    ///
1570    /// let a = Tensor::from([1.0, 2.0, 3.0]);
1571    /// let b = Tensor::from([4.0, 5.0, 6.0]);
1572    /// assert_eq!(a.cmplt(b), [1., 1., 1.]);
1573    /// ```
1574    ///
1575    /// # Panics
1576    ///
1577    /// This function panics if the tensors have different shapes.
1578    #[must_use]
1579    pub fn cmplt(&self, rhs: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
1580        let (x, y) = Tensor::broadcast(self, rhs)?;
1581        Ok(Tensor {
1582            id: RT.lock().cmplt(x.id, y.id),
1583        })
1584    }
1585
1586    /// Elementwise maximum between two tensors.
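    /// A quick sketch:
    /// ```
    /// use zyx::Tensor;
    /// let a = Tensor::from([1, 4, 2]);
    /// let b = Tensor::from([3, 2, 5]);
    /// assert_eq!(a.maximum(b).unwrap(), [3, 4, 5]);
    /// ```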
1587    #[must_use]
1588    pub fn maximum(&self, rhs: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
1589        let (x, y) = Tensor::broadcast(self, rhs)?;
1590        Ok(Tensor {
1591            id: RT.lock().maximum(x.id, y.id),
1592        })
1593    }
1594
1595    /// Matrix multiplication (dot). Contracts the last axis of `self` with the second to last axis of `rhs`.
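    /// A 2D sketch of the expected result, compared against nested arrays as in the other examples:
    /// ```
    /// use zyx::Tensor;
    /// let a = Tensor::from([[1, 2], [3, 4]]);
    /// let b = Tensor::from([[5, 6], [7, 8]]);
    /// assert_eq!(a.dot(b).unwrap(), [[19, 22], [43, 50]]);
    /// ```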
1596    #[must_use]
1597    pub fn dot(&self, rhs: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
1598        let rhs = rhs.into();
1599        let org_y_shape = rhs.shape();
1600        let y = rhs.t();
1601        let xshape = self.shape();
1602        let yshape = y.shape();
1603        let xrank = xshape.rank();
1604        let yrank = yshape.rank();
1605        if xshape[xrank - 1] != yshape[yrank - 1] {
1606            //yshape[-(yrank.min(2) as i64)],
1607            return Err(ZyxError::ShapeError(format!("Cannot dot tensors with shapes {xshape:?} and {org_y_shape:?}")));
1608        }
1609        let x_shape = xshape[..xrank - 1]
1610            .iter()
1611            .copied()
1612            .chain([1])
1613            .chain([xshape[xrank - 1]])
1614            .collect::<Vec<usize>>();
1615        let y_shape = yshape[0..yrank - 2]
1616            .iter()
1617            .copied()
1618            .chain([1])
1619            .chain(yshape[yrank - yrank.min(2)..yrank].iter().copied())
1620            .collect::<Vec<usize>>();
1621        //std::println!("{x_shape:?}");
1622        //std::println!("{y_shape:?}");
1623        (self.reshape(x_shape)? * y.reshape(y_shape)?)
1624            .sum(-1)?
1625            .reshape(
1626                xshape[0..xshape.len() - 1]
1627                    .iter()
1628                    .copied()
1629                    .chain([yshape[yshape.len() - 2]])
1630                    .collect::<Vec<usize>>(),
1631            )
1632    }
1633
1634    /// Matmul is just an alias to [dot](Tensor::dot)
1635    #[must_use]
1636    pub fn matmul(&self, rhs: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
1637        self.dot(rhs)
1638    }
1639
1640    /// Returns a new tensor where each element is the result of raising the corresponding element in `self` to the power of `exponent`.
1641    ///
1642    /// # Examples
1643    ///
1644    /// ```
1645    /// use zyx::Tensor;
1646    ///
1647    /// let arr = Tensor::from([1.0, 2.0]);
1648    /// assert_eq!(arr.pow(2.0), [1.0, 4.0]);
1649    /// ```
1650    ///
1651    /// # Panics
1652    ///
1653    /// This function will panic if the exponent tensor contains any invalid or non-finite values.
1658    #[must_use]
1659    pub fn pow(&self, exponent: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
1660        let (x, y) = Tensor::broadcast(self, exponent)?;
1661        Ok(Tensor {
1662            id: RT.lock().pow(x.id, y.id),
1663        })
1664    }
1665
1666    /// Returns ones where self is true and zeros where it is false.
1667    #[must_use]
1668    pub fn nonzero(&self) -> Tensor {
1669        Tensor {
1670            id: RT.lock().nonzero(self.id),
1671        }
1672    }
1673
1674    // ternary
1675    /// Where operation. Elementwise, selects values from `if_true` where `self` is true (nonzero) and from `if_false` elsewhere.
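    /// A small sketch with an integer mask, assuming scalars convert into tensors as in the other examples (nonzero counts as true):
    /// ```
    /// use zyx::Tensor;
    /// let mask = Tensor::from([1, 0, 1]);
    /// assert_eq!(mask.where_(5, 2).unwrap(), [5, 2, 5]);
    /// ```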
1676    #[must_use]
1677    pub fn where_(&self, if_true: impl Into<Tensor>, if_false: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
1678        let (x, y) = Tensor::broadcast(self, if_true)?;
1679        let (x, z) = Tensor::broadcast(x, if_false)?;
1680        let (y, z) = Tensor::broadcast(y, z)?;
1681        let x_nonzero = x.nonzero();
1682        Ok(&x_nonzero * y + !x_nonzero * z)
1683    }
1684
1685    // loss functions
1686    /// Calculates the cross-entropy loss for this tensor.
1687    ///
1688    /// This function takes a target tensor and axes as input. It first calculates the log-softmax of the input tensor along the specified axes,
1689    /// then multiplies it elementwise by the target tensor. Negate and sum the result to obtain the usual scalar cross-entropy loss.
1690    ///
1691    /// # Examples
1692    ///
1693    /// ```
1694    /// use zyx::Tensor;
1695    /// let input = Tensor::from([0.5, 0.2, 0.3]);
1696    /// let target = Tensor::from([1., 0., 0.]);
1697    /// let loss = input.cross_entropy_loss(target, ()).unwrap(); // ≈ [-0.9398, 0.0, 0.0]
1698    /// ```
1699    ///
1700    /// # Panics
1701    ///
1702    /// This function will panic if the input tensor and target tensor have different shapes.
1703    #[must_use]
1704    pub fn cross_entropy_loss(&self, target: impl Into<Tensor>, axes: impl IntoAxes) -> Result<Tensor, ZyxError> {
1705        Ok(self.ln_softmax(axes)? * target)
1706    }
1707
1708    /// Calculates the L1 loss between `self` and the target tensor.
1709    ///
1710    /// # Arguments
1711    ///
1712    /// * `target`: The target tensor to compare against. It will be converted into a `Tensor`.
1713    ///
1714    /// # Returns
1715    ///
1716    /// A new `Tensor` containing the absolute difference between `self` and the target tensor.
1717    ///
1718    /// # Examples
1719    ///
1720    /// ```
1721    /// use zyx::Tensor;
1722    ///
1723    /// let self_tensor = Tensor::from(&[1.0, 2.0, 3.0]);
1724    /// let target_tensor = Tensor::from(&[2.0, 3.0, 4.0]);
1725    ///
1726    /// assert_eq!(self_tensor.l1_loss(target_tensor), Tensor::from(&[1.0, 1.0, 1.0]));
1727    /// ```
1728    #[must_use]
1729    pub fn l1_loss(&self, target: impl Into<Tensor>) -> Tensor {
1730        (self - target).abs()
1731    }
1732
1733    /// Calculates the Mean Squared Error (MSE) loss.
1734    ///
1735    /// # Arguments
1736    ///
1737    /// * `target`: The target tensor to compare against the input tensor (`self`).
1738    ///
1739    /// # Returns
1740    ///
1741    /// * A new tensor containing the MSE loss values.
1742    ///
1743    /// # Example
1744    ///
1745    /// ```
1746    /// use zyx::Tensor;
1747    ///
1748    /// let input = Tensor::from([2.0, 3.0]);
1749    /// let target = Tensor::from([4.0, 5.0]);
1750    ///
1751    /// assert_eq!(input.mse_loss(target).unwrap(), [4.0, 4.0]);
1752    /// ```
1753    ///
1754    /// # Panics
1755    ///
1756    /// This function will panic if the input tensor and target tensor have different shapes.
1757    pub fn mse_loss(&self, target: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
1758        (self - target).pow(2)
1759    }
1760
1761    /// Calculates the cosine similarity between this tensor and another.
1762    ///
1763    /// # Arguments
1764    ///
1765    /// * `rhs`: The other tensor to compare against. It will be converted into a `Tensor`.
1766    /// * `eps`: A tolerance value for numerical stability, which will also be converted into a `Tensor`.
1767    ///
1768    /// # Returns
1769    ///
1770    /// A new `Tensor` containing the cosine similarity values.
1771    ///
1772    /// # Example
1773    ///
1774    /// ```
1775    /// use zyx::Tensor;
1776    ///
1777    /// let tensor1 = Tensor::from([1.0, 2.0, 3.0]);
1778    /// let tensor2 = Tensor::from([4.0, 5.0, 6.0]);
1779    /// let eps = Tensor::from([1e-9]);
1780    ///
1781    /// let similarity = tensor1.cosine_similarity(tensor2, eps);
1782    /// ```
1783    ///
1784    /// # Panics
1785    ///
1786    /// This function panics if the input tensors have different shapes.
1787    #[must_use]
1788    pub fn cosine_similarity(&self, rhs: impl Into<Tensor>, eps: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
1789        let rhs: Tensor = rhs.into();
1790        let eps: Tensor = eps.into();
1791        let x = self.pow(2)?.sqrt() * rhs.pow(2)?.sqrt();
1792        Ok(self * rhs / x.cmplt(&eps)?.where_(eps, x)?)
1793    }
1794
1795    // misc
1796    /// Flatten. Joins the axes in the given range into one dimension.
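    /// A shape-only sketch:
    /// ```
    /// use zyx::{Tensor, DType};
    /// let t = Tensor::zeros([2, 3, 4], DType::F32);
    /// assert_eq!(t.flatten(..).unwrap().shape(), [24]);
    /// assert_eq!(t.flatten(1..).unwrap().shape(), [2, 12]);
    /// ```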
1797    #[must_use]
1798    pub fn flatten(&self, axes: impl RangeBounds<isize>) -> Result<Tensor, ZyxError> {
1799        let shape = self.shape();
1800        let rank = shape.len();
1801        let start_dim = to_axis(
1802            match axes.start_bound() {
1803                Bound::Included(dim) => *dim,
1804                Bound::Excluded(dim) => *dim + 1,
1805                Bound::Unbounded => 0,
1806            },
1807            rank,
1808        );
1809        // end_dim is used as an exclusive bound below, so inclusive/unbounded bounds are adjusted accordingly
1810        let end_dim = match axes.end_bound() {
1811            Bound::Included(dim) => to_axis(*dim, rank) + 1,
1812            Bound::Excluded(dim) => to_axis(*dim - 1, rank) + 1,
1813            Bound::Unbounded => rank,
1814        };
1817        let dim = shape[start_dim..end_dim].iter().product();
1818        let new_shape: Vec<usize> = shape[..start_dim]
1819            .iter()
1820            .copied()
1821            .chain([dim])
1822            .chain(shape[end_dim..].iter().copied())
1823            .collect();
1824        self.reshape(new_shape)
1825    }
1826
1827    /// Concatenates a list of tensors along a specified dimension.
1828    ///
1829    /// # Arguments
1830    ///
1831    /// * `tensors`: An iterator of tensor references to concatenate.
1832    /// * `dim`: The dimension along which to concatenate. If negative, it is interpreted as counting from the end.
1833    ///
1834    /// # Returns
1835    ///
1836    /// A new tensor containing the concatenated input tensors.
1837    ///
1838    /// # Panics
1839    ///
1840    /// This function panics if any two tensors have different shapes except at the specified dimension.
1841    ///
1842    /// # Examples
1843    ///
1844    /// ```
1845    /// use zyx::Tensor;
1846    ///
1847    /// let a = Tensor::from([[1, 2], [3, 4]]);
1848    /// let b = Tensor::from([[5, 6], [7, 8]]);
1849    /// let c = Tensor::cat([&a, &b], 0);
1850    /// assert_eq!(c, [[1, 2], [3, 4], [5, 6], [7, 8]]);
1851    /// ```
1852    ///
1853    #[must_use]
1854    pub fn cat<'a>(tensors: impl IntoIterator<Item = &'a Tensor>, dim: isize) -> Result<Tensor, ZyxError> {
1855        let tensors: Vec<&Tensor> = tensors.into_iter().collect();
1856        if tensors.len() < 2 {
1857            return Err(ZyxError::ShapeError("Cat requires two or more tensors.".into()));
1858        }
1859        let shape = tensors[0].shape();
1860        let rank = shape.rank();
1861        let dim = if dim < 0 { dim + rank as isize } else { dim } as usize;
1862        // Dimension check
1863        for tensor in &tensors {
1864            for (i, (d1, d2)) in shape.iter().zip(tensor.shape().iter()).enumerate() {
1865                if i != dim {
1866                    if *d1 != *d2 {
1867                        return Err(ZyxError::ShapeError("Cannot concatenate these tensors.".into()));
1868                    }
1869                }
1870            }
1871        }
1872        let mut offset = 0isize;
1873        let mut offset2 = tensors.iter().fold(0, |acc, t| acc + t.shape()[dim] as isize);
1874        let mut shape = tensors[0].shape();
1875        shape[dim] = offset2 as usize;
1876        let mut res = None;
1877        for tensor in tensors {
1878            let d = tensor.shape()[dim] as isize;
1879            offset2 -= d;
1880            let padding: Vec<(isize, isize)> = core::iter::repeat((0isize, 0isize))
1881                    .take(rank - dim - 1)
1882                    .chain([(offset, offset2)]).collect();
1883            let t = tensor.pad_zeros(padding)?;
1884            if let Some(r) = res {
1885                res = Some(r + t);
1886            } else {
1887                res = Some(t);
1888            }
1889            offset += d;
1890        }
1891        Ok(res.unwrap())
1892    }
1893
1894    /// Expands the dimensionality of a tensor by inserting singleton dimensions.
1895    ///
1896    /// # Arguments
1897    ///
1898    /// * `dim`: The dimension to insert the singleton dimension at. If negative, it is counted from the end.
1899    ///
1900    /// # Returns
1901    ///
1902    /// A new tensor with expanded dimensionality.
1903    ///
1904    /// # Examples
1905    ///
1906    /// ```
1907    /// use zyx::{Tensor, DType};
1908    ///
1909    /// let t = Tensor::zeros([2, 3], DType::I8);
1910    /// assert_eq!(t.unsqueeze(1).unwrap().shape(), [2, 1, 3]);
1911    /// assert_eq!(t.unsqueeze(-1).unwrap().shape(), [2, 3, 1]);
1912    /// ```
1913    #[must_use]
1914    pub fn unsqueeze(&self, dim: isize) -> Result<Tensor, ZyxError> {
1915        let shape = self.shape();
1916        if dim < 0 {
1917            let rank = shape.len();
1918            let dim = (-dim) as usize;
1919            let dim = rank - dim + 1;
1920            self.reshape(
1921                shape[..dim]
1922                    .iter()
1923                    .copied()
1924                    .chain([1])
1925                    .chain(shape[dim..].iter().copied())
1926                    .collect::<Vec<usize>>(),
1927            )
1928        } else {
1929            let dim = dim as usize;
1930            self.reshape(
1931                shape[..dim]
1932                    .iter()
1933                    .copied()
1934                    .chain([1])
1935                    .chain(shape[dim..].iter().copied())
1936                    .collect::<Vec<usize>>(),
1937            )
1938        }
1939    }
1940
1941    /// Creates a new tensor by stacking the input tensors along the specified dimension.
1942    ///
1943    /// # Arguments
1944    ///
1945    /// * `tensors`: An iterator of tensor references to stack.
1946    /// * `dim`: The dimension along which to stack the tensors.
1947    ///
1948    /// # Returns
1949    ///
1950    /// A new tensor containing the stacked tensors.
1951    ///
1952    /// # Examples
1953    ///
1954    /// ```
1955    /// use zyx::Tensor;
1956    /// let a = Tensor::from([[1, 2], [3, 4]]);
1957    /// let b = Tensor::from([[5, 6], [7, 8]]);
1958    /// assert_eq!(Tensor::stack([&a, &b], 0).unwrap(), [[[1, 2],
1959    ///                                                   [3, 4]],
1960    ///                                                  [[5, 6],
1961    ///                                                   [7, 8]]]);
1962    /// ```
1963    ///
1964    /// # Panics
1965    ///
1966    /// This function will panic if the tensors have different shapes along the stacking dimension.
1967    ///
1968    /// # See also
1969    ///
1970    /// [`unsqueeze`](Tensor::unsqueeze), [`cat`](Tensor::cat)
1971    #[must_use]
1972    pub fn stack<'a>(tensors: impl IntoIterator<Item = &'a Tensor>, dim: isize) -> Result<Tensor, ZyxError> {
1973        // TODO handle dim correctly
1974        let tensors: Vec<Tensor> = tensors.into_iter().map(|t| t.unsqueeze(dim).unwrap()).collect();
1975        Tensor::cat(&tensors, dim)
1976    }
1977
1978    /// Split tensor into multiple tensors along the given dim/axis. The sizes must sum to the size of that dimension.
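    /// A small sketch of splitting along an axis:
    /// ```
    /// use zyx::Tensor;
    /// let t = Tensor::from([1, 2, 3, 4, 5]);
    /// let parts = t.split([2, 3], 0).unwrap();
    /// assert_eq!(parts[0], [1, 2]);
    /// assert_eq!(parts[1], [3, 4, 5]);
    /// ```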
1979    #[must_use]
1980    pub fn split(&self, sizes: impl IntoShape, dim: isize) -> Result<Vec<Tensor>, ZyxError> {
1981        // assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
1982        // dim = self._resolve_dim(dim)
1983        // if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
1984        // assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
1985        // return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
1986        let sizes: Vec<usize> = sizes.into_shape().collect();
1987        let shape = self.shape();
1988        let rank = shape.rank();
1989        let dim: usize = if dim < 0 { dim + rank as isize } else { dim } as usize;
1990        if sizes.iter().sum::<usize>() != shape[dim] {
1991            return Err(ZyxError::ShapeError(format!("Sizes must sum exactly to {}, but got {:?}, which sums to {}", shape[dim], sizes, sizes.iter().sum::<usize>())));
1992        }
1993
1994        let mut res = Vec::new();
1995        let mut acc_size = 0;
1996        for size in sizes {
1997            let size = size as isize;
1998            let mut index = Vec::new();
1999            for i in 0..dim {
2000                index.push(0..shape[i] as isize);
2001            }
2002            index.push(acc_size..acc_size + size);
2003            //println!("Index {index:?}");
2004            res.push(self.get(index)?);
2005            acc_size += size;
2006        }
2007        Ok(res)
2008    }
2009
2010    /// Masked fill. Where `mask` is true, elements of `self` are replaced by `value`.
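    /// A small sketch with an integer mask, assuming scalars convert into tensors (nonzero counts as true):
    /// ```
    /// use zyx::Tensor;
    /// let t = Tensor::from([1, 2, 3]);
    /// let mask = Tensor::from([0, 1, 0]);
    /// assert_eq!(t.masked_fill(mask, 9).unwrap(), [1, 9, 3]);
    /// ```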
2011    #[must_use]
2012    pub fn masked_fill(&self, mask: impl Into<Tensor>, value: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
2013        mask.into().where_(value, self)
2014    }
2015
2016    /*#[must_use]
2017    fn tri(n: usize, dtype: DType) -> Tensor {
2018        // if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs)
2019        // if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs)
2020        // s = r+c-1
2021        // # build a (s, s) upper triangle
2022        // t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s)))
2023        // return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c]
2024        Tensor::ones([n * n / 2], dtype).pad_zeros([(0, n * n / 2)])
2025    }*/
2026
2027    // Returns upper triangular part of the input tensor, other elements are set to zero
2028    /*#[must_use]
2029    pub fn triu(&self, diagonal: isize) -> Tensor {
2030        todo!()
2031    }*/
2032
2033    /// Pooling function with kernel size, stride and dilation
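    /// A 1D sketch of the sliding-window behavior this function implements;
    /// the window dimension is appended after the corresponding output dimension:
    /// ```
    /// use zyx::Tensor;
    /// let t = Tensor::from([1, 2, 3, 4]);
    /// // kernel 2, stride 1, dilation 1 -> windows [1,2], [2,3], [3,4]
    /// assert_eq!(t.pool(2, 1, 1).unwrap(), [[1, 2], [2, 3], [3, 4]]);
    /// ```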
2034    #[must_use]
2035    pub fn pool(
2036        &self,
2037        kernel_size: impl IntoShape,
2038        stride: impl IntoShape,
2039        dilation: impl IntoShape,
2040    ) -> Result<Tensor, ZyxError> {
2041        // What a complex function ...
2042        let k_: Vec<usize> = kernel_size.into_shape().collect();
2043        let stride: Vec<usize> = stride.into_shape().collect();
2044        let dilation: Vec<usize> = dilation.into_shape().collect();
2045
2046        let shape = self.shape();
2047        let rank = shape.len();
2048
2049        let s_: Vec<usize> = if stride.len() == 1 {
2050            repeat(stride[0]).take(k_.len()).collect()
2051        } else {
2052            stride
2053        };
2054        let d_: Vec<usize> = if dilation.len() == 1 {
2055            repeat(dilation[0]).take(k_.len()).collect()
2056        } else {
2057            dilation
2058        };
2059        let i_ = &shape[rank - k_.len()..];
2060        let o_: Vec<usize> = i_
2061            .iter()
2062            .cloned()
2063            .zip(d_.iter().cloned())
2064            .zip(k_.iter().cloned())
2065            .zip(s_.iter().cloned())
2066            .map(|(((i, d), k), s)| (i - d * (k - 1)).div_ceil(s))
2067            .collect();
2068        //println!("s_ {s_:?}, d_ {d_:?}, i_ {i_:?} o_ {o_:?}");
2069        let repeats: Vec<usize> = repeat(1)
2070            .take(rank - k_.len())
2071            .chain(
2072                k_.iter()
2073                    .copied()
2074                    .zip(i_.iter().copied())
2075                    .zip(d_.iter().copied())
2076                    .map(|((k, i), d)| (k * (i + d)).div_ceil(i)),
2077            )
2078            .collect();
2079        //println!("repeats {repeats:?}");
2080        let pad_b: Vec<Range<isize>> = shape[..rank - k_.len()]
2081            .iter()
2082            .map(|&d| 0..d as isize)
2083            .collect();
2084        let sh_b: Vec<usize> = shape[..rank - k_.len()].into();
2085        let mut xup = self.repeat(repeats)?;
2086
2087        // dilation
2088        //println!("{xup:?} before padding");
2089        let padding: Vec<Range<isize>> = pad_b
2090            .iter()
2091            .cloned()
2092            .chain(
2093                k_.iter()
2094                    .copied()
2095                    .zip(i_.iter().copied())
2096                    .zip(d_.iter().copied())
2097                    .map(|((k, i), d)| (0..(k * (i + d)) as isize)),
2098            )
2099            .collect();
2100        //println!("Padding {padding:?}");
2101        xup = xup.get(padding)?;
2102        //println!("{xup} padded");
2103        let sh: Vec<usize> = sh_b
2104            .iter()
2105            .copied()
2106            .chain(
2107                k_.iter()
2108                    .copied()
2109                    .zip(i_.iter().copied())
2110                    .zip(d_.iter().copied())
2111                    .map(|((k, i), d)| [k, i + d])
2112                    .flatten(),
2113            )
2114            .collect();
2115        //println!("Reshape {sh:?}");
2116        xup = xup.reshape(sh)?;
2117
2118        // stride
2119        // padding = noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))
2120        // xup = xup.shrink(padding)
2121        let padding: Vec<Range<isize>> = pad_b
2122            .iter()
2123            .cloned()
2124            .chain(
2125                k_.iter()
2126                    .copied()
2127                    .zip(o_.iter().copied())
2128                    .zip(s_.iter().copied())
2129                    .map(|((k, o), s)| [(0..k as isize), (0..(o * s) as isize)])
2130                    .flatten(),
2131            )
2132            .collect();
2133        xup = xup.get(padding)?;
2134        // sh = noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_))
2135        // xup = xup.reshape(sh)
2136        let sh: Vec<usize> = sh_b
2137            .iter()
2138            .copied()
2139            .chain(
2140                k_.iter()
2141                    .copied()
2142                    .zip(o_.iter().copied())
2143                    .zip(s_.iter().copied())
2144                    .map(|((k, o), s)| [k, o, s])
2145                    .flatten(),
2146            )
2147            .collect();
2148        xup = xup.reshape(sh)?;
2149        // padding = noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))
2150        // xup = xup.shrink(padding)
2151        let padding: Vec<Range<isize>> = pad_b
2152            .iter()
2153            .cloned()
2154            .chain(
2155                k_.iter()
2156                    .copied()
2157                    .zip(o_.iter().copied())
2158                    .map(|(k, o)| [(0..k as isize), (0..o as isize), (0..1)])
2159                    .flatten(),
2160            )
2161            .collect();
2162        xup = xup.get(padding)?;
2163        // sh = noop_ + flatten((k,o) for k,o in zip(k_, o_))
2164        // xup = xup.reshape(sh)
2165        let sh: Vec<usize> = sh_b
2166            .iter()
2167            .copied()
2168            .chain(
2169                k_.iter()
2170                    .copied()
2171                    .zip(o_.iter().copied())
2172                    .map(|(k, o)| [k, o])
2173                    .flatten(),
2174            )
2175            .collect();
2176        xup = xup.reshape(sh)?;
2177
2178        // xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
2179        let axes: Vec<isize> = (0..rank - k_.len())
2180            .chain((0..i_.len()).map(|i| rank - k_.len() + i * 2 + 1))
2181            .chain((0..i_.len()).map(|i| rank - k_.len() + i * 2))
2182            .map(|i| i as isize)
2183            .collect();
2184        xup = xup.permute(axes)?;
2185
2186        Ok(xup)
2187    }
2188
2189    /// Creates a new tensor by repeating the input tensor along its dimensions.
2190    ///
2191    /// The `repeats` parameter specifies how many times to repeat each dimension of the tensor. Its length must be at least
2192    /// the rank of the tensor; extra leading entries add new leading dimensions.
2193    ///
2194    /// # Examples
2195    ///
2196    /// ```
2197    /// use zyx::Tensor;
2198    ///
2199    /// let arr = Tensor::from(vec![1, 2, 3]);
2200    /// assert_eq!(arr.repeat([2]).unwrap(), [1, 2, 3, 1, 2, 3]);
2201    /// ```
2202    ///
2203    /// # Errors
2204    ///
2205    /// Returns a `ShapeError` if `repeats` has fewer entries than the rank of the tensor.
2206    ///
2207    /// Returns a new tensor with the repeated values.
2208    #[must_use]
2209    pub fn repeat(&self, repeats: impl IntoShape) -> Result<Tensor, ZyxError> {
2210        let repeats: Vec<usize> = repeats.into_shape().collect();
2211        let shape = self.shape();
2212        let rank = shape.len();
2213        if repeats.len() < rank {
2214            return Err(ZyxError::ShapeError("Number of repeats must be greater than or equal to the rank of the tensor.".into()));
2215        }
2216
2217        let base_shape: Vec<usize> = repeat(1)
2218            .take(repeats.len() - rank)
2219            .chain(shape.iter().copied())
2220            .collect();
2221        let new_shape: Vec<usize> = repeat(1)
2222            .take(repeats.len() - rank)
2223            .chain(shape.into_iter())
2224            .flat_map(|d| [1, d])
2225            .collect();
2226        let expand_shape: Vec<usize> = repeats
2227            .iter()
2228            .copied()
2229            .zip(base_shape.iter().copied())
2230            .flat_map(|(r, d)| [r, d])
2231            .collect();
2232        let final_shape: Vec<usize> = repeats
2233            .iter()
2234            .copied()
2235            .zip(base_shape.iter().copied())
2236            .map(|(r, d)| r * d)
2237            .collect();
2238
2239        //println!("base_shape {base_shape:?} {new_shape:?} {expand_shape:?} {final_shape:?}");
2240
2241        let mut x = self.reshape(new_shape)?;
2242        x = x.expand(expand_shape)?;
2243        x = x.reshape(final_shape)?;
2244        Ok(x)
2245    }
2246
2247    /*#[must_use]
2248    pub fn conv(&self) -> Tensor {
2249        todo!()
2250    }*/
2251
2252    // io
2253    /// Load module from a safetensors file at the given path
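    /// A usage sketch; `Vec<Tensor>` implements `FromIterator<Tensor>`, and the file name here is just a placeholder:
    /// ```no_run
    /// use zyx::Tensor;
    /// let tensors: Vec<Tensor> = Tensor::load("model.safetensors").unwrap();
    /// ```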
2254    pub fn load<Module: FromIterator<Tensor>>(path: impl AsRef<Path>) -> Result<Module, ZyxError> {
2255        let debug_print: bool = RT.lock().debug_dev();
2256        use std::io::Read;
2257        let mut f = std::fs::File::open(path)?;
2258        let mut header_len = [0u8; 8];
2259        f.read_exact(&mut header_len)?;
2260        let n = usize::try_from(u64::from_le_bytes(header_len)).map_err(|e| {
2261            ZyxError::ParseError(format!(
2262                "Failed to parse header len in safetensors file. {e}"
2263            ))
2264        })?;
2265        let mut header = vec![0u8; n];
2266        f.read_exact(&mut header)?;
2267        let header = core::str::from_utf8(&header)
2268            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
2269        let mut text = String::with_capacity(10);
2270        let mut begin_str = false;
2271        let mut i = 0;
2272        let mut tensors = Vec::new();
2273        let mut dtype = DType::F32;
2274        let mut shape = vec![1];
2275        for x in header.chars() {
2276            if ['"', '[', ']'].contains(&x) {
2277                if begin_str {
2278                    //std::println!("{text}");
2279                    if i % 7 == 0 {
2280                        //params[i / 7].set_label(&text);
2281                    } else if i % 7 == 2 {
2282                        dtype = DType::from_safetensors(&text)?;
2283                    } else if i % 7 == 4 {
2284                        shape = text
2285                            .split(',')
2286                            .map(|d| {
2287                                d.parse::<usize>().map_err(|err| {
2288                                    ZyxError::ParseError(format!(
2289                                        "Cannot parse safetensors shape: {err}"
2290                                    ))
2291                                })
2292                            })
2293                            .collect::<Result<_, ZyxError>>()?;
2294                    } else if i % 7 == 6 {
2295                        // TODO assert offsets
2296                        //println!("Offsets: {text}");
2297                        let offsets = text
2298                            .split(',')
2299                            .map(|offset| {
2300                                offset.parse::<usize>().map_err(|err| {
2301                                    ZyxError::ParseError(format!(
2302                                        "Could not parse safetensors offset: {err}"
2303                                    ))
2304                                })
2305                            })
2306                            .collect::<Result<Vec<usize>, ZyxError>>()?;
2307                        //println!("Offsets: {offsets:?}");
2308                        let bytes = shape.iter().product::<usize>() * dtype.byte_size();
2309                        if offsets[1] - offsets[0] != bytes {
2310                            return Err(ZyxError::ParseError(
2311                                "Safetensors shapes and offsets are incorrect.".into(),
2312                            ));
2313                        }
2314                        let mut buf = vec![0u8; bytes];
2315                        if debug_print {
2316                            print!("Loading tensor with shape {shape:?}, {dtype:?} ...");
2317                        }
2318                        f.read_exact(&mut buf)?;
2319                        if debug_print {
2320                            println!(" DONE");
2321                        }
2322                        tensors.push(match dtype {
2323                            DType::F32 => {
2324                                let vec: Vec<f32> = buf
2325                                    .chunks_exact(dtype.byte_size())
2326                                    .map(|x| f32::from_le_bytes([x[0], x[1], x[2], x[3]]))
2327                                    .collect();
2328                                Tensor::from(vec).reshape(&shape)?
2329                            }
2330                            DType::F64 => {
2331                                let vec: Vec<f64> = buf
2332                                    .chunks_exact(dtype.byte_size())
2333                                    .map(|x| {
2334                                        f64::from_le_bytes([
2335                                            x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7],
2336                                        ])
2337                                    })
2338                                    .collect();
2339                                Tensor::from(vec).reshape(&shape)?
2340                            }
2341                            DType::I32 => {
2342                                let vec: Vec<i32> = buf
2343                                    .chunks_exact(dtype.byte_size())
2344                                    .map(|x| i32::from_le_bytes([x[0], x[1], x[2], x[3]]))
2345                                    .collect();
2346                                Tensor::from(vec).reshape(&shape)?
2347                            }
2348                            _ => todo!(),
2349                        });
2350                    }
2351                    i += 1;
2352                    text.clear();
2353                    begin_str = false;
2354                } else {
2355                    text.clear();
2356                    begin_str = true;
2357                }
2358            } else {
2359                text.push(x);
2360            }
2361        }
2362        Ok(Module::from_iter(tensors))
2363    }
2364
2365    /// All tensor elements as contiguous le_bytes vector in row major order
2366    pub fn to_le_bytes(&self) -> Result<Vec<u8>, ZyxError> {
2367        Ok(match self.dtype() {
2368            DType::F32 => {
2369                let data: Vec<f32> = self.clone().try_into()?;
2370                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
2371            }
2372            DType::F64 => {
2373                let data: Vec<f64> = self.clone().try_into()?;
2374                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
2375            }
2376            DType::U8 => {
2377                let data: Vec<u8> = self.clone().try_into()?;
2378                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
2379            }
2380            DType::I8 => {
2381                let data: Vec<i8> = self.clone().try_into()?;
2382                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
2383            }
2384            DType::I16 => {
2385                let data: Vec<i16> = self.clone().try_into()?;
2386                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
2387            }
2388            DType::I32 => {
2389                let data: Vec<i32> = self.clone().try_into()?;
2390                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
2391            }
2392            DType::I64 => {
2393                let data: Vec<i64> = self.clone().try_into()?;
2394                data.into_iter().flat_map(|x| x.to_le_bytes()).collect()
2395            }
2396            DType::Bool => {
2397                let data: Vec<bool> = self.clone().try_into()?;
2398                data.into_iter().map(u8::from).collect()
2399            }
2400        })
2401    }
2402
2403    /// Load tensor from le_bytes in row major order
2404    pub fn from_le_bytes(&self, bytes: &[u8]) -> Result<(), ZyxError> {
2405        let _ = bytes;
2406        todo!()
2407    }
2408}
2409
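/// Guard that restores the previously stored runtime debug level when dropped.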
2410pub struct DebugGuard {
2411    debug: u32,
2412}
2413
2414impl Drop for DebugGuard {
2415    fn drop(&mut self) {
2416        RT.lock().debug = self.debug;
2417    }
2418}
2419
2420impl Tensor {
2421    /// If self is not float, then cast it to float
2422    #[must_use]
2423    fn float_cast(&self) -> Tensor {
2424        let dtype = self.dtype();
2425        if !dtype.is_float() {
2426            return match dtype.byte_size() {
2427                #[cfg(feature = "half")]
2428                1 | 2 => self.cast(DType::F16),
2429                #[cfg(feature = "half")]
2430                4 => self.cast(DType::F32),
2431                #[cfg(not(feature = "half"))]
2432                1 | 2 | 4 => self.cast(DType::F32),
2433                8 => self.cast(DType::F64),
2434                _ => panic!(),
2435            };
2436        }
2437        self.clone()
2438    }
2439
2440    /// Broadcasts to synchronize shapes and casts to synchronize dtypes.
2441    /// This does both automatic expand AND automatic casting between dtypes.
2442    // TODO Both of these can be disabled by changing a setting in the backend.
2443    #[must_use]
2444    fn broadcast(x: impl Into<Tensor>, y: impl Into<Tensor>) -> Result<(Tensor, Tensor), ZyxError> {
2445        let mut x = x.into();
2446        let mut y = y.into();
2447        /*assert_eq!(
2448            graph.dtype(xid),
2449            graph.dtype(yid),
2450            "{op} parameters {xid} and {yid} have different dtypes: {} and {}",
2451            graph.dtype(xid),
2452            graph.dtype(yid)
2453        );*/
2454        // Now we just do implicit conversions. Not exactly rust style, but it's convenient.
2455        // We can later add option for backend to disable these implicit conversions.
2456        match (x.dtype(), y.dtype()) {
2457            (DType::F32, DType::I32) => y = y.cast(DType::F32),
2458            (DType::F32, DType::F64) => x = x.cast(DType::F64),
2459            (DType::I32, DType::F32) => x = x.cast(DType::F32),
2460            (DType::I32, DType::F64) => x = x.cast(DType::F64),
2461            (DType::F64, DType::F32) => y = y.cast(DType::F64),
2462            (DType::F64, DType::I32) => y = y.cast(DType::F64),
2463            _ => {}
2464        }
2465        let mut x_shape = x.shape();
2466        let mut y_shape = y.shape();
2467
2468        for (&x, &y) in x_shape.iter().rev().zip(y_shape.iter().rev()) {
2469            if x != y {
2470                if x != 1 && y != 1 {
2471                    return Err(ZyxError::ShapeError(format!("Left and right tensor shapes can not be broadcasted: {x_shape:?} and {y_shape:?}")));
2472                }
2473                //assert!( *x == 1 || *y == 1, "Left and right tensor shapes can not be broadcasted: {x_shape:?} and {y_shape:?}");
2474            }
2475        }
2476
2477        let rx = x_shape.rank();
2478        let ry = y_shape.rank();
2479        match rx.cmp(&ry) {
2480            Ordering::Less => {
2481                x_shape = core::iter::repeat(1)
2482                    .take(ry - rx)
2483                    .chain(x_shape.into_iter())
2484                    .collect();
2485            }
2486            Ordering::Greater => {
2487                y_shape = core::iter::repeat(1)
2488                    .take(rx - ry)
2489                    .chain(y_shape.into_iter())
2490                    .collect();
2491            }
2492            Ordering::Equal => {}
2493        }
2494        let mut eshape = Vec::new();
2495        for (x, y) in x_shape.iter().zip(y_shape.iter()) {
2496            eshape.push(*x.max(y));
2497        }
2498        x = x.reshape(&x_shape)?;
2499        if x_shape != eshape {
2500            x = x.expand(&eshape)?;
2501        }
2502        //println!("Second broadcast operand {y}");
2503        y = y.reshape(&y_shape)?;
2504        //println!("{x_shape:?}, {y_shape:?}, {eshape:?}");
2505        //println!("After reshape second broadcast operand {y}");
2506        //Tensor::plot_graph([], "graph");
2507        if y_shape != eshape {
2508            y = y.expand(&eshape)?;
2509        }
2510        //println!("Second broadcast operand {y}");
2511        //println!("Broadcasted to {eshape:?}");
2512        //println!("y shape {:?}", y.shape());
2513        return Ok((x, y));
2514    }
2515
2516    // Calculate shape for reduce which keeps reduced dims set to 1
2517    fn reduce_kd_shape(&self, axes: impl IntoAxes) -> Vec<usize> {
2518        let mut shape = self.shape();
2519        for a in axes.clone().into_axes(shape.len()) {
2520            shape[a] = 1;
2521        }
2522        shape
2523    }
2524
2525    pub(super) fn id(&self) -> TensorId {
2526        self.id
2527    }
2528}
2529
2530#[cfg(feature = "half")]
2531impl TryFrom<Tensor> for bf16 {
2532    type Error = ZyxError;
2533    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2534        let mut data = [bf16::from_f32(0.0)];
2535        RT.lock().load(value.id, &mut data)?;
2536        Ok(data[0])
2539    }
2540}
2541
2542#[cfg(feature = "half")]
2543impl TryFrom<Tensor> for f16 {
2544    type Error = ZyxError;
2545    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2546        let mut data = [f16::from_f32(0.0)];
2547        RT.lock().load(value.id, &mut data)?;
2548        Ok(data[0])
2551    }
2552}
2553
2554impl TryFrom<Tensor> for f32 {
2555    type Error = ZyxError;
2556    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2557        let mut data = [0.];
2558        RT.lock().load(value.id, &mut data)?;
2559        Ok(data[0])
2560    }
2561}
2562
2563impl TryFrom<Tensor> for f64 {
2564    type Error = ZyxError;
2565    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2566        let mut data = [0.];
2567        RT.lock().load(value.id, &mut data)?;
2568        Ok(data[0])
2569    }
2570}
2571
2572#[cfg(feature = "complex")]
2573impl TryFrom<Tensor> for Complex<f32> {
2574    type Error = ZyxError;
2575    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2576        let mut data = [Complex::new(0.0f32, 0.0)];
2577        RT.lock().load(value.id, &mut data)?;
2578        Ok(data[0])
2581    }
2582}
2583
2584#[cfg(feature = "complex")]
2585impl TryFrom<Tensor> for Complex<f64> {
2586    type Error = ZyxError;
2587    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2588        let mut data = [Complex::new(0.0f64, 0.0)];
2589        RT.lock().load(value.id, &mut data)?;
2590        Ok(data[0])
2593    }
2594}
2595
2596impl TryFrom<Tensor> for u8 {
2597    type Error = ZyxError;
2598    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2599        let mut data = [0];
2600        RT.lock().load(value.id, &mut data)?;
2601        Ok(data[0])
2602    }
2603}
2604
2605impl TryFrom<Tensor> for i8 {
2606    type Error = ZyxError;
2607    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2608        let mut data = [0];
2609        RT.lock().load(value.id, &mut data)?;
2610        Ok(data[0])
2611    }
2612}
2613
2614impl TryFrom<Tensor> for i16 {
2615    type Error = ZyxError;
2616    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2617        let mut data = [0];
2618        RT.lock().load(value.id, &mut data)?;
2619        Ok(data[0])
2620    }
2621}
2622
2623impl TryFrom<Tensor> for i32 {
2624    type Error = ZyxError;
2625    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2626        let mut data = [0];
2627        RT.lock().load(value.id, &mut data)?;
2628        Ok(data[0])
2629    }
2630}
2631
2632impl TryFrom<Tensor> for i64 {
2633    type Error = ZyxError;
2634    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2635        let mut data = [0];
2636        RT.lock().load(value.id, &mut data)?;
2637        Ok(data[0])
2638    }
2639}
2640
2641impl TryFrom<Tensor> for bool {
2642    type Error = ZyxError;
2643    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2644        let mut data = [false];
2645        RT.lock().load(value.id, &mut data)?;
2646        Ok(data[0])
2647    }
2648}
2649
2650impl<T: Scalar> TryFrom<Tensor> for Vec<T> {
2651    type Error = ZyxError;
2652    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2653        let numel = value.numel();
2654        let mut data = vec![T::zero(); numel];
2656        RT.lock().load(value.id, &mut data)?;
2657        Ok(data)
2658    }
2659}
2660
2661impl<T: Scalar, const D0: usize> TryFrom<Tensor> for [T; D0] {
2662    type Error = ZyxError;
2663    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2664        let mut data = [T::zero(); D0];
2665        RT.lock().load(value.id, &mut data)?;
2666        Ok(data)
2667    }
2668}
2669
2670impl<T: Scalar, const D0: usize, const D1: usize> TryFrom<Tensor> for [[T; D1]; D0] {
2671    type Error = ZyxError;
2672    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2673        let mut data = [[T::zero(); D1]; D0];
2674        RT.lock().load(value.id, data.as_flattened_mut())?;
2675        Ok(data)
2676    }
2677}
2678
2679impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize> TryFrom<Tensor>
2680    for [[[T; D2]; D1]; D0]
2681{
2682    type Error = ZyxError;
2683    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2684        let mut data = [[[T::zero(); D2]; D1]; D0];
2685        RT.lock()
2686            .load(value.id, data.as_flattened_mut().as_flattened_mut())?;
2687        Ok(data)
2688    }
2689}
2690
2691impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize, const D3: usize> TryFrom<Tensor>
2692    for [[[[T; D3]; D2]; D1]; D0]
2693{
2694    type Error = ZyxError;
2695    fn try_from(value: Tensor) -> Result<Self, Self::Error> {
2696        let mut data = [[[[T::zero(); D3]; D2]; D1]; D0];
2697        RT.lock().load(
2698            value.id,
2699            data.as_flattened_mut()
2700                .as_flattened_mut()
2701                .as_flattened_mut(),
2702        )?;
2703        Ok(data)
2704    }
2705}
2706
2707impl Debug for Tensor {
2708    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2709        f.write_fmt(format_args!("{self}"))
2710        //f.write_fmt(format_args!("Tensor {{ id = {:?} }}", self.id))
2711    }
2712}
2713
2714impl Display for Tensor {
2715    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2716        // TODO don't print the whole tensor if it is too big
2717        let precision = f.precision().unwrap_or(3);
2722        let x = self.clone();
2723        let res = match self.dtype() {
2724            #[cfg(feature = "half")]
2725            DType::BF16 => {
2726                let data: Result<Vec<bf16>, _> = x.try_into();
2727                match data {
2728                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2729                    Err(e) => format!("bf16 tensor failed to realize {e:?}"),
2730                }
2731            }
2732            #[cfg(feature = "half")]
2733            DType::F16 => {
2734                let data: Result<Vec<f16>, _> = x.try_into();
2735                match data {
2736                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2737                    Err(e) => format!("f16 tensor failed to realize {e:?}"),
2738                }
2739            }
2740            DType::F32 => {
2741                let data: Result<Vec<f32>, _> = x.try_into();
2742                match data {
2743                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2744                    Err(e) => format!("f32 tensor failed to realize {e:?}"),
2745                }
2746            }
2747            DType::F64 => {
2748                let data: Result<Vec<f64>, _> = x.try_into();
2749                match data {
2750                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2751                    Err(e) => format!("f64 tensor failed to realize {e:?}"),
2752                }
2753            }
2754            #[cfg(feature = "complex")]
2755            DType::CF32 => {
2756                let data: Result<Vec<Complex<f32>>, _> = x.try_into();
2757                match data {
2758                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2759                    Err(e) => format!("cf32 tensor failed to realize {e:?}"),
2760                }
2761            }
2762            #[cfg(feature = "complex")]
2763            DType::CF64 => {
2764                let data: Result<Vec<Complex<f64>>, _> = x.try_into();
2765                match data {
2766                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2767                    Err(e) => format!("cf64 tensor failed to realize {e:?}"),
2768                }
2769            }
2770            DType::U8 => {
2771                let data: Result<Vec<u8>, _> = x.try_into();
2772                match data {
2773                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2774                    Err(e) => format!("u8 tensor failed to realize {e:?}"),
2775                }
2776            }
2777            DType::I8 => {
2778                let data: Result<Vec<i8>, _> = x.try_into();
2779                match data {
2780                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2781                    Err(e) => format!("i8 tensor failed to realize {e:?}"),
2782                }
2783            }
2784            DType::I16 => {
2785                let data: Result<Vec<i16>, _> = x.try_into();
2786                match data {
2787                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2788                    Err(e) => format!("i16 tensor failed to realize {e:?}"),
2789                }
2790            }
2791            DType::I32 => {
2792                let data: Result<Vec<i32>, _> = x.try_into();
2793                match data {
2794                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2795                    Err(e) => format!("i32 tensor failed to realize {e:?}"),
2796                }
2797            }
2798            DType::I64 => {
2799                let data: Result<Vec<i64>, _> = x.try_into();
2800                match data {
2801                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2802                    Err(e) => format!("i64 tensor failed to realize {e:?}"),
2803                }
2804            }
2805            DType::Bool => {
2806                let data: Result<Vec<bool>, _> = x.try_into();
2807                match data {
2808                    Ok(data) => tensor_to_string(&data, &self.shape(), precision, f.width()),
2809                    Err(e) => format!("bool tensor failed to realize {e:?}"),
2810                }
2811            }
2812        };
2813        f.write_fmt(format_args!(
2814            "Tensor {:?} {}\n{res}",
2815            self.shape(),
2816            self.dtype()
2817        ))
2818    }
2819}
2820
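/// Render the realized `data` buffer with the given `shape` as nested, bracketed
/// rows. `precision` sets the formatting precision (fractional digits for floats);
/// `width`, when given, overrides the per-element width that is otherwise computed
/// from the widest value in `data`.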
2821fn tensor_to_string<T: core::fmt::Display>(
2822    data: &[T],
2823    shape: &[usize],
2824    precision: usize,
2825    width: Option<usize>,
2826) -> String {
2827    use core::fmt::Write;
2828    let n: usize = shape.iter().product();
2829    let rank = shape.len();
2830    let mut res = String::new();
2831    if data.is_empty() {
2832        return "[]".into();
2833    }
2834    // get maximal width of single value
2835    let mut w = 0;
2836    if let Some(width) = width {
2837        w = width;
2838    } else {
2839        for x in data {
2840            let l = format!("{x:>.precision$}").len();
2841            if l > w {
2842                w = l;
2843            }
2844        }
2845    }
2846    let d0 = shape[rank - 1];
2847    for (i, x) in data.iter().enumerate() {
2848        {
2849            let mut var = 1;
2850            let mut r = rank;
2851            while r > 0 {
2852                if i % (n / var) == 0 {
2853                    res += &(" ".repeat(rank - r) + "[".repeat(r - 1).as_str());
2854                    break;
2855                }
2856                var *= shape[rank - r];
2857                r -= 1;
2858            }
2859        }
2860        let _ = write!(res, "{x:>w$.precision$}");
2861        if (i + 1) % d0 != 0usize {
2862            res += "  ";
2863        }
2864        {
2865            let mut var = 1;
2866            let mut r = rank;
2867            while r > 0 {
2868                if (i + 1) % (n / var) == 0 {
2869                    res += &"]".repeat(r - 1);
2870                    break;
2871                }
2872                var *= shape[rank - r];
2873                r -= 1;
2874            }
2875        }
2876        if (i + 1) % d0 == 0usize && i != n - 1 {
2877            res += "\n";
2878        }
2879    }
2880    res
2881}
2882
2883/// Conversion into a `Range<isize>`, used for indexing.
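///
/// A minimal sketch of the conversions (marked `ignore` because the exact
/// re-export path of `IntoRange` is assumed here):
/// ```rust,ignore
/// use zyx::IntoRange;
/// assert_eq!((1..3isize).into_range(), 1..3);
/// assert_eq!((1..=3isize).into_range(), 1..4);
/// assert_eq!((..3isize).into_range(), 0..3);
/// assert_eq!((..).into_range(), 0..isize::MAX);
/// assert_eq!(5isize.into_range(), 5..6);
/// ```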
2884pub trait IntoRange: Clone {
2885    /// Convert self to a `Range<isize>`; if self is a scalar x, it is converted to x..x+1.
2886    fn into_range(self) -> Range<isize>;
2887}
2888
2889impl IntoRange for RangeFull {
2890    fn into_range(self) -> Range<isize> {
2891        0..isize::MAX
2892    }
2893}
2894
2895impl IntoRange for RangeFrom<isize> {
2896    fn into_range(self) -> Range<isize> {
2897        self.start..isize::MAX
2898    }
2899}
2900
2901impl IntoRange for RangeTo<isize> {
2902    fn into_range(self) -> Range<isize> {
2903        0..self.end
2904    }
2905}
2906
2907impl IntoRange for RangeInclusive<isize> {
2908    fn into_range(self) -> Range<isize> {
2909        *self.start()..*self.end() + 1
2910    }
2911}
2912
2913impl IntoRange for RangeToInclusive<isize> {
2914    fn into_range(self) -> Range<isize> {
2915        0..self.end + 1
2916    }
2917}
2918
2919impl IntoRange for Range<isize> {
2920    fn into_range(self) -> Range<isize> {
2921        self
2922    }
2923}
2924
2925impl IntoRange for isize {
2926    fn into_range(self) -> Range<isize> {
2927        self..self + 1
2928    }
2929}
2930
2931/// Implemented for objects that can be used to index tensors.
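///
/// Scalars, ranges and tuples of up to eight of them map to one `Range<isize>`
/// per axis. A minimal sketch (marked `ignore`; the re-export path is assumed):
/// ```rust,ignore
/// use zyx::IntoIndex;
/// let idx: Vec<_> = (0isize, 1..3isize, ..).into_index().into_iter().collect();
/// assert_eq!(idx, vec![0..1, 1..3, 0..isize::MAX]);
/// ```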
2932pub trait IntoIndex {
2933    /// Convert self to tensor index.
2934    fn into_index(self) -> impl IntoIterator<Item = Range<isize>>;
2935}
2936
2937impl IntoIndex for Vec<Range<isize>> {
2938    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
2939        self.into_iter()
2940    }
2941}
2942
2943impl<I: IntoRange> IntoIndex for &[I] {
2944    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
2945        self.iter().cloned().map(IntoRange::into_range)
2946    }
2947}
2948
2949impl<I0: IntoRange> IntoIndex for I0 {
2950    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
2951        [self.into_range()].into_iter()
2952    }
2953}
2954
2955impl<I0: IntoRange, I1: IntoRange> IntoIndex for (I0, I1) {
2956    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
2957        [self.0.into_range(), self.1.into_range()].into_iter()
2958    }
2959}
2960
2961impl<I0: IntoRange, I1: IntoRange, I2: IntoRange> IntoIndex for (I0, I1, I2) {
2962    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
2963        [
2964            self.0.into_range(),
2965            self.1.into_range(),
2966            self.2.into_range(),
2967        ]
2968        .into_iter()
2969    }
2970}
2971
2972impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange> IntoIndex for (I0, I1, I2, I3) {
2973    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
2974        [
2975            self.0.into_range(),
2976            self.1.into_range(),
2977            self.2.into_range(),
2978            self.3.into_range(),
2979        ]
2980        .into_iter()
2981    }
2982}
2983
2984impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange, I4: IntoRange> IntoIndex
2985    for (I0, I1, I2, I3, I4)
2986{
2987    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
2988        [
2989            self.0.into_range(),
2990            self.1.into_range(),
2991            self.2.into_range(),
2992            self.3.into_range(),
2993            self.4.into_range(),
2994        ]
2995        .into_iter()
2996    }
2997}
2998
2999impl<I0: IntoRange, I1: IntoRange, I2: IntoRange, I3: IntoRange, I4: IntoRange, I5: IntoRange>
3000    IntoIndex for (I0, I1, I2, I3, I4, I5)
3001{
3002    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
3003        [
3004            self.0.into_range(),
3005            self.1.into_range(),
3006            self.2.into_range(),
3007            self.3.into_range(),
3008            self.4.into_range(),
3009            self.5.into_range(),
3010        ]
3011        .into_iter()
3012    }
3013}
3014
3015impl<
3016        I0: IntoRange,
3017        I1: IntoRange,
3018        I2: IntoRange,
3019        I3: IntoRange,
3020        I4: IntoRange,
3021        I5: IntoRange,
3022        I6: IntoRange,
3023    > IntoIndex for (I0, I1, I2, I3, I4, I5, I6)
3024{
3025    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
3026        [
3027            self.0.into_range(),
3028            self.1.into_range(),
3029            self.2.into_range(),
3030            self.3.into_range(),
3031            self.4.into_range(),
3032            self.5.into_range(),
3033            self.6.into_range(),
3034        ]
3035        .into_iter()
3036    }
3037}
3038
3039impl<
3040        I0: IntoRange,
3041        I1: IntoRange,
3042        I2: IntoRange,
3043        I3: IntoRange,
3044        I4: IntoRange,
3045        I5: IntoRange,
3046        I6: IntoRange,
3047        I7: IntoRange,
3048    > IntoIndex for (I0, I1, I2, I3, I4, I5, I6, I7)
3049{
3050    fn into_index(self) -> impl IntoIterator<Item = Range<isize>> {
3051        [
3052            self.0.into_range(),
3053            self.1.into_range(),
3054            self.2.into_range(),
3055            self.3.into_range(),
3056            self.4.into_range(),
3057            self.5.into_range(),
3058            self.6.into_range(),
3059            self.7.into_range(),
3060        ]
3061        .into_iter()
3062    }
3063}
3064
3065impl From<&Tensor> for Tensor {
3066    fn from(value: &Tensor) -> Self {
3067        value.clone()
3068    }
3069}
3070
3071impl<T: Scalar> From<T> for Tensor {
3072    fn from(value: T) -> Self {
3073        return Tensor {
3074            id: RT.lock().variable(vec![1], &[value]).unwrap(),
3075        };
3076    }
3077}
3078
3079impl<T: Scalar> From<Vec<T>> for Tensor {
3080    fn from(data: Vec<T>) -> Self {
3081        return Tensor {
3082            id: RT.lock().variable(vec![data.len()], &data).unwrap(),
3083        };
3084    }
3085}
3086
3087impl<T: Scalar> From<&Vec<T>> for Tensor {
3088    fn from(data: &Vec<T>) -> Self {
3089        return Tensor {
3090            id: RT.lock().variable(vec![data.len()], &data).unwrap(),
3091        };
3092    }
3093}
3094
3095impl<T: Scalar> From<&[T]> for Tensor {
3096    fn from(data: &[T]) -> Self {
3097        let n = data.len();
3098        return Tensor {
3099            id: RT.lock().variable(vec![n], data).unwrap(),
3100        };
3101    }
3102}
3103
3104impl<T: Scalar, const D0: usize> From<[T; D0]> for Tensor {
3105    fn from(data: [T; D0]) -> Self {
3106        return Tensor {
3107            id: RT.lock().variable(vec![D0], &data).unwrap(),
3108        };
3109    }
3110}
3111
3112impl<T: Scalar, const D0: usize, const D1: usize> From<[[T; D1]; D0]> for Tensor {
3113    fn from(data: [[T; D1]; D0]) -> Self {
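        // SAFETY: nested arrays have no padding between rows, so `[[T; D1]; D0]`
        // is laid out as D0 * D1 contiguous `T`s and can be viewed as one flat slice.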
3114        let data = unsafe { core::slice::from_raw_parts(data[0].as_ptr(), D0 * D1) };
3115        return Tensor {
3116            id: RT.lock().variable(vec![D0, D1], data).unwrap(),
3117        };
3118    }
3119}
3120
3121impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize> From<[[[T; D2]; D1]; D0]>
3122    for Tensor
3123{
3124    fn from(data: [[[T; D2]; D1]; D0]) -> Self {
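        // SAFETY: `[[[T; D2]; D1]; D0]` is laid out as D0 * D1 * D2 contiguous `T`s,
        // so it can be viewed as one flat slice.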
3125        let data = unsafe { core::slice::from_raw_parts(data[0][0].as_ptr(), D0 * D1 * D2) };
3126        return Tensor {
3127            id: RT.lock().variable(vec![D0, D1, D2], data).unwrap(),
3128        };
3129    }
3130}
3131
3132impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize, const D3: usize>
3133    From<[[[[T; D3]; D2]; D1]; D0]> for Tensor
3134{
3135    fn from(data: [[[[T; D3]; D2]; D1]; D0]) -> Self {
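        // SAFETY: `[[[[T; D3]; D2]; D1]; D0]` is laid out as D0 * D1 * D2 * D3
        // contiguous `T`s, so it can be viewed as one flat slice.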
3136        let data =
3137            unsafe { core::slice::from_raw_parts(data[0][0][0].as_ptr(), D0 * D1 * D2 * D3) };
3138        return Tensor {
3139            id: RT.lock().variable(vec![D0, D1, D2, D3], data).unwrap(),
3140        };
3141    }
3142}
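// Illustrative: scalars become shape [1] tensors; `Vec<T>`, `&[T]` and `[T; D0]`
// become shape [len] tensors; nested arrays infer their shape from the const
// generics, e.g. `Tensor::from([[1, 2, 3], [4, 5, 6]])` has shape [2, 3].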
3143
3144impl PartialEq<f32> for Tensor {
3145    fn eq(&self, other: &f32) -> bool {
3146        if let Ok(data) = self.clone().try_into() {
3147            let data: f32 = data;
3148            &data == other
3149        } else {
3150            false
3151        }
3152    }
3153}
3154
3155impl PartialEq<i32> for Tensor {
3156    fn eq(&self, other: &i32) -> bool {
3157        if let Ok(data) = self.clone().try_into() {
3158            let data: i32 = data;
3159            &data == other
3160        } else {
3161            false
3162        }
3163    }
3164}
3165
3166impl<T: Scalar, const D0: usize> PartialEq<[T; D0]> for Tensor {
3167    fn eq(&self, other: &[T; D0]) -> bool {
3168        if self.shape() != [D0] {
3169            return false
3170        }
3171        if let Ok(data) = self.clone().try_into() {
3172            let data: [T; D0] = data;
3173            &data == other
3174        } else {
3175            false
3176        }
3177    }
3178}
3179
3180impl<T: Scalar, const D0: usize, const D1: usize> PartialEq<[[T; D1]; D0]> for Tensor {
3181    fn eq(&self, other: &[[T; D1]; D0]) -> bool {
3182        if self.shape() != [D0, D1] {
3183            return false
3184        }
3185        if let Ok(data) = self.clone().try_into() {
3186            let data: [[T; D1]; D0] = data;
3187            &data == other
3188        } else {
3189            false
3190        }
3191    }
3192}
3193
3194impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize> PartialEq<[[[T; D2]; D1]; D0]>
3195    for Tensor
3196{
3197    fn eq(&self, other: &[[[T; D2]; D1]; D0]) -> bool {
3198        if self.shape() != [D0, D1, D2] {
3199            return false
3200        }
3201        if let Ok(data) = self.clone().try_into() {
3202            let data: [[[T; D2]; D1]; D0] = data;
3203            &data == other
3204        } else {
3205            false
3206        }
3207    }
3208}
3209
3210impl<T: Scalar, const D0: usize, const D1: usize, const D2: usize, const D3: usize>
3211    PartialEq<[[[[T; D3]; D2]; D1]; D0]> for Tensor
3212{
3213    fn eq(&self, other: &[[[[T; D3]; D2]; D1]; D0]) -> bool {
3214        if self.shape() != [D0, D1, D2, D3] {
3215            return false
3216        }
3217        if let Ok(data) = self.clone().try_into() {
3218            let data: [[[[T; D3]; D2]; D1]; D0] = data;
3219            &data == other
3220        } else {
3221            false
3222        }
3223    }
3224}
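// Illustrative: `Tensor::from([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]` first compares
// the shapes, then tries to convert the tensor back into the array type and compares
// element-wise; a failed conversion simply compares as not equal.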
3225
3226impl<IT: Into<Tensor>> Add<IT> for Tensor {
3227    type Output = Tensor;
3228    fn add(self, rhs: IT) -> Self::Output {
3229        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3230        // We have to bind the result to a temporary variable; otherwise Rust
3231        // drops a Tensor before dropping the MutexGuard, and because
3232        // Tensor::drop locks the runtime again, that would deadlock.
3234        let tensor = Tensor {
3235            id: RT.lock().add(x.id, y.id),
3236        };
3237        return tensor;
3238    }
3239}
3240
3241impl<IT: Into<Tensor>> Add<IT> for &Tensor {
3242    type Output = Tensor;
3243    fn add(self, rhs: IT) -> Self::Output {
3244        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3245        // We have to bind the result to a temporary variable; otherwise Rust
3246        // drops a Tensor before dropping the MutexGuard, and because
3247        // Tensor::drop locks the runtime again, that would deadlock.
3249        let tensor = Tensor {
3250            id: RT.lock().add(x.id, y.id),
3251        };
3252        return tensor;
3253    }
3254}
3255
3256impl<IT: Into<Tensor>> Sub<IT> for Tensor {
3257    type Output = Tensor;
3258    fn sub(self, rhs: IT) -> Self::Output {
3259        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3260        // We have to bind the result to a temporary variable; otherwise Rust
3261        // drops a Tensor before dropping the MutexGuard, and because
3262        // Tensor::drop locks the runtime again, that would deadlock.
3264        let tensor = Tensor {
3265            id: RT.lock().sub(x.id, y.id),
3266        };
3267        return tensor;
3268    }
3269}
3270
3271impl<IT: Into<Tensor>> Sub<IT> for &Tensor {
3272    type Output = Tensor;
3273    fn sub(self, rhs: IT) -> Self::Output {
3274        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3275        // We have to bind the result to a temporary variable; otherwise Rust
3276        // drops a Tensor before dropping the MutexGuard, and because
3277        // Tensor::drop locks the runtime again, that would deadlock.
3279        let tensor = Tensor {
3280            id: RT.lock().sub(x.id, y.id),
3281        };
3282        return tensor;
3283    }
3284}
3285
3286impl<IT: Into<Tensor>> Mul<IT> for Tensor {
3287    type Output = Tensor;
3288    fn mul(self, rhs: IT) -> Self::Output {
3289        let rhs = rhs.into();
3290        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3291        // We have to bind the result to a temporary variable; otherwise Rust
3292        // drops a Tensor before dropping the MutexGuard, and because
3293        // Tensor::drop locks the runtime again, that would deadlock.
3296        let tensor = Tensor {
3297            id: RT.lock().mul(x.id, y.id),
3298        };
3299        return tensor;
3300    }
3301}
3302
3303impl<IT: Into<Tensor>> Mul<IT> for &Tensor {
3304    type Output = Tensor;
3305    fn mul(self, rhs: IT) -> Self::Output {
3306        let rhs = rhs.into();
3307        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3308        // We have to bind the result to a temporary variable; otherwise Rust
3309        // drops a Tensor before dropping the MutexGuard, and because
3310        // Tensor::drop locks the runtime again, that would deadlock.
3312        let tensor = Tensor {
3313            id: RT.lock().mul(x.id, y.id),
3314        };
3315        return tensor;
3316    }
3317}
3318
3319impl<IT: Into<Tensor>> Div<IT> for Tensor {
3320    type Output = Tensor;
3321    fn div(self, rhs: IT) -> Self::Output {
3322        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3323        let tensor = Tensor {
3324            id: RT.lock().div(x.id, y.id),
3325        };
3326        return tensor;
3327    }
3328}
3329
3330impl<IT: Into<Tensor>> Div<IT> for &Tensor {
3331    type Output = Tensor;
3332    fn div(self, rhs: IT) -> Self::Output {
3333        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3334        let tensor = Tensor {
3335            id: RT.lock().div(x.id, y.id),
3336        };
3337        return tensor;
3338    }
3339}
3340
3341impl<IT: Into<Tensor>> BitOr<IT> for Tensor {
3342    type Output = Tensor;
3343    fn bitor(self, rhs: IT) -> Self::Output {
3344        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3345        let tensor = Tensor {
3346            id: RT.lock().bitor(x.id, y.id),
3347        };
3348        return tensor;
3349    }
3350}
3351
3352impl<IT: Into<Tensor>> BitOr<IT> for &Tensor {
3353    type Output = Tensor;
3354    fn bitor(self, rhs: IT) -> Self::Output {
3355        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3356        let tensor = Tensor {
3357            id: RT.lock().bitor(x.id, y.id),
3358        };
3359        return tensor;
3360    }
3361}
3362
3363impl<IT: Into<Tensor>> BitXor<IT> for Tensor {
3364    type Output = Tensor;
3365    fn bitxor(self, rhs: IT) -> Self::Output {
3366        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3367        let tensor = Tensor {
3368            id: RT.lock().bitxor(x.id, y.id),
3369        };
3370        return tensor;
3371    }
3372}
3373
3374impl<IT: Into<Tensor>> BitXor<IT> for &Tensor {
3375    type Output = Tensor;
3376    fn bitxor(self, rhs: IT) -> Self::Output {
3377        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3378        let tensor = Tensor {
3379            id: RT.lock().bitxor(x.id, y.id),
3380        };
3381        return tensor;
3382    }
3383}
3384
3385impl<IT: Into<Tensor>> BitAnd<IT> for Tensor {
3386    type Output = Tensor;
3387    fn bitand(self, rhs: IT) -> Self::Output {
3388        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3389        let tensor = Tensor {
3390            id: RT.lock().bitand(x.id, y.id),
3391        };
3392        return tensor;
3393    }
3394}
3395
3396impl<IT: Into<Tensor>> BitAnd<IT> for &Tensor {
3397    type Output = Tensor;
3398    fn bitand(self, rhs: IT) -> Self::Output {
3399        let (x, y) = Tensor::broadcast(self, rhs).unwrap();
3400        let tensor = Tensor {
3401            id: RT.lock().bitand(x.id, y.id),
3402        };
3403        return tensor;
3404    }
3405}
3406
3407impl Neg for Tensor {
3408    type Output = Tensor;
3409    fn neg(self) -> Self::Output {
3410        Tensor {
3411            id: RT.lock().neg(self.id),
3412        }
3413    }
3414}
3415
3416impl Neg for &Tensor {
3417    type Output = Tensor;
3418    fn neg(self) -> Self::Output {
3419        Tensor {
3420            id: RT.lock().neg(self.id),
3421        }
3422    }
3423}
3424
3425impl Not for Tensor {
3426    type Output = Tensor;
3427    fn not(self) -> Self::Output {
3428        Tensor {
3429            id: RT.lock().not(self.id),
3430        }
3431    }
3432}
3433
3434impl Not for &Tensor {
3435    type Output = Tensor;
3436    fn not(self) -> Self::Output {
3437        Tensor {
3438            id: RT.lock().not(self.id),
3439        }
3440    }
3441}
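// Illustrative: with the impls above and the macro below, `&x + 1.0f32`,
// `1.0f32 + &x`, `-&x` and `!&x` all produce new tensors without consuming `x`.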
3442
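// Implements `scalar <op> Tensor` and `scalar <op> &Tensor` for a scalar type by
// promoting the scalar to a Tensor and delegating to the tensor operator, which
// preserves operand order for non-commutative ops such as Sub and Div.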
3443macro_rules! impl_trait {
3444    ($trait:ident for $type:ty, $fn_name:ident) => {
3445        impl $trait<Tensor> for $type {
3446            type Output = Tensor;
3447            fn $fn_name(self, rhs: Tensor) -> Self::Output {
3448                // Promote the scalar to a Tensor so the operand order is preserved
                // (calling `rhs.$fn_name(self)` would swap operands, wrong for Sub/Div).
                Tensor::from(self).$fn_name(rhs)
3449            }
3450        }
3451
3452        impl $trait<&Tensor> for $type {
3453            type Output = Tensor;
3454            fn $fn_name(self, rhs: &Tensor) -> Self::Output {
3455                Tensor::from(self).$fn_name(rhs)
3456            }
3457        }
3458    };
3459}
3460
3461#[cfg(feature = "half")]
3462impl_trait!(Add for bf16, add);
3463#[cfg(feature = "half")]
3464impl_trait!(Add for f16, add);
3465impl_trait!(Add for f32, add);
3466impl_trait!(Add for f64, add);
3467#[cfg(feature = "complex")]
3468impl_trait!(Add for Complex<f32>, add);
3469#[cfg(feature = "complex")]
3470impl_trait!(Add for Complex<f64>, add);
3471impl_trait!(Add for u8, add);
3472impl_trait!(Add for i8, add);
3473impl_trait!(Add for i16, add);
3474impl_trait!(Add for i32, add);
3475impl_trait!(Add for i64, add);
3476impl_trait!(Add for bool, add);
3477
3478#[cfg(feature = "half")]
3479impl_trait!(Sub for bf16, sub);
3480#[cfg(feature = "half")]
3481impl_trait!(Sub for f16, sub);
3482impl_trait!(Sub for f32, sub);
3483impl_trait!(Sub for f64, sub);
3484#[cfg(feature = "complex")]
3485impl_trait!(Sub for Complex<f32>, sub);
3486#[cfg(feature = "complex")]
3487impl_trait!(Sub for Complex<f64>, sub);
3488impl_trait!(Sub for u8, sub);
3489impl_trait!(Sub for i8, sub);
3490impl_trait!(Sub for i16, sub);
3491impl_trait!(Sub for i32, sub);
3492impl_trait!(Sub for i64, sub);
3493impl_trait!(Sub for bool, sub);
3494
3495#[cfg(feature = "half")]
3496impl_trait!(Mul for bf16, mul);
3497#[cfg(feature = "half")]
3498impl_trait!(Mul for f16, mul);
3499impl_trait!(Mul for f32, mul);
3500impl_trait!(Mul for f64, mul);
3501#[cfg(feature = "complex")]
3502impl_trait!(Mul for Complex<f32>, mul);
3503#[cfg(feature = "complex")]
3504impl_trait!(Mul for Complex<f64>, mul);
3505impl_trait!(Mul for u8, mul);
3506impl_trait!(Mul for i8, mul);
3507impl_trait!(Mul for i16, mul);
3508impl_trait!(Mul for i32, mul);
3509impl_trait!(Mul for i64, mul);
3510impl_trait!(Mul for bool, mul);
3511
3512#[cfg(feature = "half")]
3513impl_trait!(Div for bf16, div);
3514#[cfg(feature = "half")]
3515impl_trait!(Div for f16, div);
3516impl_trait!(Div for f32, div);
3517impl_trait!(Div for f64, div);
3518#[cfg(feature = "complex")]
3519impl_trait!(Div for Complex<f32>, div);
3520#[cfg(feature = "complex")]
3521impl_trait!(Div for Complex<f64>, div);
3522impl_trait!(Div for u8, div);
3523impl_trait!(Div for i8, div);
3524impl_trait!(Div for i16, div);
3525impl_trait!(Div for i32, div);
3526impl_trait!(Div for i64, div);
3527impl_trait!(Div for bool, div);
3528
3529#[cfg(feature = "half")]
3530impl_trait!(BitXor for bf16, bitxor);
3531#[cfg(feature = "half")]
3532impl_trait!(BitXor for f16, bitxor);
3533impl_trait!(BitXor for f32, bitxor);
3534impl_trait!(BitXor for f64, bitxor);
3535#[cfg(feature = "complex")]
3536impl_trait!(BitXor for Complex<f32>, bitxor);
3537#[cfg(feature = "complex")]
3538impl_trait!(BitXor for Complex<f64>, bitxor);
3539impl_trait!(BitXor for u8, bitxor);
3540impl_trait!(BitXor for i8, bitxor);
3541impl_trait!(BitXor for i16, bitxor);
3542impl_trait!(BitXor for i32, bitxor);
3543impl_trait!(BitXor for i64, bitxor);
3544impl_trait!(BitXor for bool, bitxor);
3545
3546#[cfg(feature = "half")]
3547impl_trait!(BitOr for bf16, bitor);
3548#[cfg(feature = "half")]
3549impl_trait!(BitOr for f16, bitor);
3550impl_trait!(BitOr for f32, bitor);
3551impl_trait!(BitOr for f64, bitor);
3552#[cfg(feature = "complex")]
3553impl_trait!(BitOr for Complex<f32>, bitor);
3554#[cfg(feature = "complex")]
3555impl_trait!(BitOr for Complex<f64>, bitor);
3556impl_trait!(BitOr for u8, bitor);
3557impl_trait!(BitOr for i8, bitor);
3558impl_trait!(BitOr for i16, bitor);
3559impl_trait!(BitOr for i32, bitor);
3560impl_trait!(BitOr for i64, bitor);
3561impl_trait!(BitOr for bool, bitor);
3562
3563#[cfg(feature = "half")]
3564impl_trait!(BitAnd for bf16, bitand);
3565#[cfg(feature = "half")]
3566impl_trait!(BitAnd for f16, bitand);
3567impl_trait!(BitAnd for f32, bitand);
3568impl_trait!(BitAnd for f64, bitand);
3569#[cfg(feature = "complex")]
3570impl_trait!(BitAnd for Complex<f32>, bitand);
3571#[cfg(feature = "complex")]
3572impl_trait!(BitAnd for Complex<f64>, bitand);
3573impl_trait!(BitAnd for u8, bitand);
3574impl_trait!(BitAnd for i8, bitand);
3575impl_trait!(BitAnd for i16, bitand);
3576impl_trait!(BitAnd for i32, bitand);
3577impl_trait!(BitAnd for i64, bitand);
3578impl_trait!(BitAnd for bool, bitand);