// lumen_core/tensor/arith.rs

1use crate::{TensorOrScalar, grad::BinaryOp, AutogradMetaT, CmpOp, Error, FloatDType, NumDType, Shape, Storage, UnaryOp, WithDType};
2use super::Tensor;
3use paste::paste;
4
5//////////////////////////////////////////////////////////////////////////////
6///        Binary(Assign) Op with Tensor and Tensor / scalar
7//////////////////////////////////////////////////////////////////////////////
8
9impl<T: WithDType> Tensor<T> {
10    fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> crate::Result<&Shape> {
11        let lhs = self.shape();
12        let rhs = rhs.shape();
13        if lhs != rhs {
14            Err(Error::ShapeMismatchBinaryOp {
15                lhs: lhs.clone(),
16                rhs: rhs.clone(),
17                op,
18            })?
19        } else {
20            Ok(lhs)
21        }
22    }
23}
24
25impl<T: WithDType> Tensor<T> {
26    fn compute_binary_scalar_rhs_op<U, F>(lhs: &Tensor<T>, rhs: T, mut f: F, _op_name: &'static str) -> crate::Result<(Storage<U>, Shape)>
27        where 
28            U: WithDType, 
29            F: FnMut(T, T) -> U 
30    {
31        let shape = lhs.shape();
32        let lhs_storage = lhs.storage_read()?;
33        let lhs_layout = lhs.layout();
34
35        let lhs = lhs_storage.data();
36        
37        let output: Vec<_> = lhs_layout.storage_indices()
38            .map(|lhs_index| f(lhs[lhs_index], rhs))
39            .collect();
40        
41        let storage = Storage::<U>::new(output);
42        Ok((storage, shape.clone()))
43    }
44
45    fn compute_binary_scalar_lhs_op<U, F>(lhs: T, rhs: &Tensor<T>, mut f: F, _op_name: &'static str) -> crate::Result<(Storage<U>, Shape)>
46        where 
47            U: WithDType, 
48            F: FnMut(T, T) -> U 
49    {
50        let shape = rhs.shape();
51        let rhs_storage = rhs.storage_read()?;
52        let rhs_layout = rhs.layout();
53
54        let rhs = rhs_storage.data();
55        
56        let output: Vec<_> = rhs_layout.storage_indices()
57            .map(|index| f(lhs, rhs[index]))
58            .collect();
59        
60        let storage = Storage::<U>::new(output);
61        Ok((storage, shape.clone()))
62    }
63
64    fn compute_binary_op<U, F>(lhs: &Tensor<T>, rhs: &Tensor<T>, mut f: F, op_name: &'static str) -> crate::Result<(Storage<U>, Shape)>
65        where 
66            U: WithDType, 
67            F: FnMut(T, T) -> U 
68    {
69        let shape = Tensor::<T>::same_shape_binary_op(lhs, rhs, op_name)?;
70        let lhs_storage = lhs.storage_read()?;
71        let rhs_storage = rhs.storage_read()?;
72        let lhs_layout = lhs.layout();
73        let rhs_layout = rhs.layout();
74
75        assert_eq!(lhs_layout.dims(), rhs_layout.dims(), "lhs dims != rhs dim2");
76
77        let lhs = lhs_storage.data();
78        let rhs = rhs_storage.data();
79        
80        let output: Vec<_> = lhs_layout.storage_indices().zip(rhs_layout.storage_indices())
81            .map(|(lhs_index, rhs_index)| f(lhs[lhs_index], rhs[rhs_index]))
82            .collect();
83        
84        let storage = Storage::<U>::new(output);
85        Ok((storage, shape.clone()))
86    }
87
88    fn binary_op<U, F>(lhs: &Tensor<T>, rhs: &Tensor<T>, f: F, meta: U::AutogradMeta, op_name: &'static str) -> crate::Result<Tensor<U>>
89    where 
90        U: WithDType, 
91        F: FnMut(T, T) -> U 
92    {
93        let (storage, shape) = Self::compute_binary_op(lhs, rhs, f, op_name)?;
94        Ok(Tensor::<U>::from_storage(storage, shape, meta))
95    }
96
97    fn binary_scalar_rhs_op<U, F>(lhs: &Tensor<T>, rhs: T, f: F, meta: U::AutogradMeta, op_name: &'static str) -> crate::Result<Tensor<U>>
98    where 
99        U: WithDType, 
100        F: FnMut(T, T) -> U 
101    {
102        let (storage, shape) = Self::compute_binary_scalar_rhs_op(lhs, rhs, f, op_name)?;
103        Ok(Tensor::<U>::from_storage(storage, shape, meta))
104    }
105
106    fn binary_scalar_lhs_op<U, F>(lhs: T, rhs: &Tensor<T>, f: F, meta: U::AutogradMeta, op_name: &'static str) -> crate::Result<Tensor<U>>
107    where 
108        U: WithDType, 
109        F: FnMut(T, T) -> U 
110    {
111        let (storage, shape) = Self::compute_binary_scalar_lhs_op(lhs, rhs, f, op_name)?;
112        Ok(Tensor::<U>::from_storage(storage, shape, meta))
113    }
114}
115
/// Generates the public entry points for one element-wise binary operation
/// `$fn_name` (e.g. `add`): `<op>_tensor`, `<op>_scalar`, `scalar_<op>`, and a
/// dispatching `<op>` that accepts either a tensor or a scalar right-hand side.
/// `paste!` builds the method names and the `BinaryOp` variant (CamelCase).
macro_rules! binary_op_impl {
    ($fn_name:ident) => {
        paste! {
            // Tensor ⊕ tensor, recording autograd metadata for both operands.
            pub fn [< $fn_name _tensor >](&self, rhs: &Self) -> crate::Result<Self> {
                let meta = T::AutogradMeta::on_binary_op(self, rhs, BinaryOp::  [< $fn_name:camel >]);
                Self::binary_op(self, rhs, T::$fn_name, meta, stringify!([< $fn_name _tensor >]))
            }

            // Tensor ⊕ scalar (scalar on the right).
            pub fn [< $fn_name _scalar >](&self, rhs: T) -> crate::Result<Self> {
                let meta = T::AutogradMeta::on_binary_scalar_rhs_op(self, rhs, BinaryOp::  [< $fn_name:camel >]);
                Self::binary_scalar_rhs_op(self, rhs, T::$fn_name, meta, stringify!([< $fn_name _scalar >]))
            }

            // Scalar ⊕ tensor (scalar on the left); an associated fn, not a method.
            pub fn [< scalar_ $fn_name >](lhs: T, rhs: &Tensor<T>) -> crate::Result<Tensor<T>> {
                let meta = T::AutogradMeta::on_binary_scalar_lhs_op(lhs, rhs, BinaryOp::  [< $fn_name:camel >]);
                Self::binary_scalar_lhs_op(lhs, rhs, T::$fn_name, meta, stringify!([< scalar_ $fn_name >]))
            }

            // Generic entry point: dispatch on tensor vs. scalar right-hand side.
            pub fn $fn_name(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Self> {
                match rhs.into() {
                    TensorOrScalar::Tensor(t) => self.[< $fn_name _tensor >](&t),
                    TensorOrScalar::Scalar(s) => self.[< $fn_name _scalar >](s)
                }
            }
        }
    };
}
143
impl<T: NumDType> Tensor<T> {
    // Element-wise arithmetic and min/max entry points; each invocation
    // expands to `<op>_tensor`, `<op>_scalar`, `scalar_<op>`, and `<op>`.
    binary_op_impl!(add);
    binary_op_impl!(mul);
    binary_op_impl!(sub);
    binary_op_impl!(div);
    binary_op_impl!(minimum);
    binary_op_impl!(maximum);

    /// Clamps every element into `[min, max]` via `maximum(min)` then
    /// `minimum(max)`.
    /// NOTE(review): `min > max` is not rejected; in that case every element
    /// collapses to `max` — confirm this matches the intended contract.
    pub fn clamp(&self, min: T, max: T) -> crate::Result<Self> {
        self.maximum(min)?.minimum(max)
    }
}
156
157impl<T: NumDType> Tensor<T> {
158    fn binary_op_inplace<F>(lhs: &Tensor<T>, rhs: &Tensor<T>, mut f: F, op_name: &'static str) -> crate::Result<()> 
159    where 
160        F: FnMut(T, T) -> T
161    {
162        let _ = Tensor::<T>::same_shape_binary_op(lhs, rhs, op_name)?;
163
164        let mut lhs_storage = lhs.storage_write()?;
165        let rhs_storage = rhs.storage_read()?;
166        let lhs_layout = lhs.layout();
167        let rhs_layout = rhs.layout();
168
169        assert_eq!(lhs_layout.dims(), rhs_layout.dims(), "lhs dims != rhs dim2");
170
171        let lhs = lhs_storage.data_mut();
172        let rhs = rhs_storage.data();
173        
174        lhs_layout.storage_indices().zip(rhs_layout.storage_indices())
175            .for_each(|(lhs_index, rhs_index)| lhs[lhs_index] = f(lhs[lhs_index], rhs[rhs_index]));
176        
177        Ok(())
178    }
179
180    fn binary_op_scalar_inplace<F>(lhs: &Tensor<T>, rhs: T, mut f: F, _op_name: &'static str) -> crate::Result<()> 
181    where 
182        F: FnMut(T, T) -> T
183    {
184        let mut lhs_storage = lhs.storage_write()?;
185        let lhs_layout = lhs.layout();
186
187
188        let lhs = lhs_storage.data_mut();
189        
190        lhs_layout.storage_indices()
191            .for_each(|lhs_index| lhs[lhs_index] = f(lhs[lhs_index], rhs));
192        
193        Ok(())
194    }
195}
196
/// Generates the in-place variant `<op>_` (e.g. `add_`) that accepts either a
/// tensor or a scalar right-hand side, mutates `self`'s storage, and returns a
/// clone of the (shared-storage) handle so calls can be chained.
macro_rules! binary_inplace_op_impl {
    ($fn_name:ident) => {
        paste! {
            pub fn [< $fn_name _ >](&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Self> {
                let rhs = rhs.into();
                match rhs {
                    TensorOrScalar::Scalar(rhs) => Self::binary_op_scalar_inplace(self, rhs, T::$fn_name, stringify!([< $fn_name _scalar_ >]))?,
                    // Op name labels the tensor branch (was mislabelled `_scalar`).
                    TensorOrScalar::Tensor(rhs) => Self::binary_op_inplace(self, &rhs, T::$fn_name, stringify!([< $fn_name _tensor_ >]))?,
                }
                Ok(self.clone())
            }
        }
    };
}
211
// In-place arithmetic entry points (`add_`, `sub_`, `mul_`, `div_`).
// `#[allow(unused)]` silences dead-code warnings for the generated methods.
#[allow(unused)]
impl<T: NumDType> Tensor<T> {
    binary_inplace_op_impl!(add);
    binary_inplace_op_impl!(sub);
    binary_inplace_op_impl!(mul);
    binary_inplace_op_impl!(div);
}
219
220impl<T: NumDType> Tensor<T> {
221    pub fn eq(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
222        self.cmp(rhs, CmpOp::Eq)
223    }
224
225    pub fn ne(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
226        self.cmp(rhs, CmpOp::Ne)
227    }
228
229    pub fn le(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
230        self.cmp(rhs, CmpOp::Le)
231    }
232
233    pub fn ge(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
234        self.cmp(rhs, CmpOp::Ge)
235    }
236
237    pub fn lt(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
238        self.cmp(rhs, CmpOp::Lt)
239    }
240
241    pub fn gt(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
242        self.cmp(rhs, CmpOp::Gt)
243    }
244
245    pub fn cmp(&self, rhs: impl Into<TensorOrScalar<T>>, op: CmpOp) -> crate::Result<Tensor<bool>> {
246        match rhs.into() {
247            TensorOrScalar::Tensor(rhs) => {
248                match op {
249                    CmpOp::Eq => Self::binary_op(self, &rhs, |a, b| a == b, Default::default(), "eq"),
250                    CmpOp::Ne => Self::binary_op(self, &rhs, |a, b| a != b, Default::default(), "nq"),
251                    CmpOp::Le => Self::binary_op(self, &rhs, |a, b| a <= b, Default::default(), "le"),
252                    CmpOp::Ge => Self::binary_op(self, &rhs, |a, b| a >= b, Default::default(), "ge"),
253                    CmpOp::Lt => Self::binary_op(self, &rhs, |a, b| a <  b, Default::default(), "lt"),
254                    CmpOp::Gt => Self::binary_op(self, &rhs, |a, b| a >  b, Default::default(), "gt"),
255                }
256            }
257            TensorOrScalar::Scalar(rhs) => {
258                match op {
259                    CmpOp::Eq => Self::binary_scalar_rhs_op(self, rhs, |a, b| a == b, Default::default(), "eq"),
260                    CmpOp::Ne => Self::binary_scalar_rhs_op(self, rhs, |a, b| a != b, Default::default(), "nq"),
261                    CmpOp::Le => Self::binary_scalar_rhs_op(self, rhs, |a, b| a <= b, Default::default(), "le"),
262                    CmpOp::Ge => Self::binary_scalar_rhs_op(self, rhs, |a, b| a >= b, Default::default(), "ge"),
263                    CmpOp::Lt => Self::binary_scalar_rhs_op(self, rhs, |a, b| a <  b, Default::default(), "lt"),
264                    CmpOp::Gt => Self::binary_scalar_rhs_op(self, rhs, |a, b| a >  b, Default::default(), "gt"),
265                }
266            }
267        }
268    } 
269}
270
271impl Tensor<bool> {
272    pub fn and(&self, rhs: impl Into<TensorOrScalar<bool>>) -> crate::Result<Tensor<bool>> {
273        match rhs.into() {
274            TensorOrScalar::Tensor(rhs) => Self::binary_op(self, &rhs, |a, b| a & b, Default::default(), "and"),
275            TensorOrScalar::Scalar(rhs) => Self::binary_scalar_rhs_op(self, rhs, |a, b| a & b, Default::default(), "and"),
276        }
277    }
278
279    pub fn or(&self, rhs: impl Into<TensorOrScalar<bool>>) -> crate::Result<Tensor<bool>> {
280        match rhs.into() {
281            TensorOrScalar::Tensor(rhs) => Self::binary_op(self, &rhs, |a, b| a | b, Default::default(), "or"),
282            TensorOrScalar::Scalar(rhs) => Self::binary_scalar_rhs_op(self, rhs, |a, b| a | b, Default::default(), "or"),
283        }
284    }
285
286    pub fn xor(&self, rhs: impl Into<TensorOrScalar<bool>>) -> crate::Result<Tensor<bool>> {
287        match rhs.into() {
288            TensorOrScalar::Tensor(rhs) => Self::binary_op(self, &rhs, |a, b| a ^ b, Default::default(), "xor"),
289            TensorOrScalar::Scalar(rhs) => Self::binary_scalar_rhs_op(self, rhs, |a, b| a ^ b, Default::default(), "xor"),
290        }
291    }
292
293    pub fn not(&self) -> crate::Result<Tensor<bool>> {
294        if self.element_count() == 0 {
295            return Ok(self.clone());
296        }
297        let storage = self.compute_unary_op(|v| !v)?;
298        Ok(Self::from_storage(storage, self.shape(), Default::default()))
299    }
300}
301
302//////////////////////////////////////////////////////////////////////////////
303///        Unary Op / Unary Assign Op  for Tensor
304//////////////////////////////////////////////////////////////////////////////
305
306impl<T: WithDType> Tensor<T> {
307    fn compute_unary_op<U, F>(&self, mut f: F) -> crate::Result<Storage<U>> 
308    where
309        U: WithDType,
310        F: FnMut(T) -> U
311    {
312        let storage = self.storage_read()?;
313        let vec = storage.data();
314        let mut output = vec![];
315        for index in self.layout().storage_indices() {
316            output.push( f(vec[index]) );
317        }
318        
319        Ok(Storage::new(output))
320    }
321
322    fn unary_assign_op<F>(&self, mut f: F) -> crate::Result<()>
323    where
324        F: FnMut(T) -> T
325    {
326        let mut storage = self.storage_write()?;
327        let vec = storage.data_mut();
328        for index in self.layout().storage_indices() {
329            vec[index] = f(vec[index]);
330        }
331        Ok(())
332    }
333}
334
335impl<T: NumDType> Tensor<T> {
336    pub fn affine(&self, mul: T, add: T) -> crate::Result<Self> {
337        if self.element_count() == 0 {
338            return Ok(self.clone());
339        }
340        let storage = self.compute_unary_op(|v| v * mul + add)?;
341        Ok(Self::from_storage(storage, self.shape(), Default::default()))
342    }
343
344    pub fn affine_assign(&self, mul: T, add: T) -> crate::Result<()> {
345        if self.element_count() == 0 {
346            return Ok(());
347        }
348        self.unary_assign_op(|v| v * mul + add)
349    }
350}
351
/// Generates an element-wise unary method `$fn_name` for float tensors that
/// applies `F::$fn_name` and records `UnaryOp::<CamelCase>` autograd metadata.
macro_rules! float_unary_op_impl {
    ($fn_name:ident) => {
        paste! {
            pub fn $fn_name(&self) -> crate::Result<Self> {
                // Empty tensors short-circuit: nothing to compute or record.
                if self.element_count() == 0 {
                    return Ok(self.clone());
                }
                let storage = self.compute_unary_op(F::$fn_name)?;
                // NOTE(review): `on_unray_op` reads like a typo for `on_unary_op`,
                // but it matches the trait method name used throughout this file.
                let meta = F::AutogradMeta::on_unray_op(self, UnaryOp:: [< $fn_name:camel >]);
                Ok(Self::from_storage(storage, self.shape(), meta))
            }
        }
    };
}
366
367impl<T: WithDType> Tensor<T> {
368    pub fn map<F, O>(&self, f: F) -> crate::Result<Tensor<O>>
369    where 
370        O: WithDType,
371        F: Fn(T) -> O,
372    {
373        let storage = self.compute_unary_op(f)?;
374        Ok(Tensor::from_storage(storage, self.shape(), Default::default()))
375    }
376
377    pub fn map_assign<F>(&self, f: F) -> crate::Result<()>
378    where 
379        F: Fn(T) -> T,
380    {
381        if self.element_count() == 0 {
382            return Ok(());
383        }
384        self.unary_assign_op(f)
385    }
386}
387
388impl<T: NumDType + Neg<Output = T>> Tensor<T> {
389    pub fn neg(&self) -> crate::Result<Self> {
390        if self.element_count() == 0 {
391            return Ok(self.clone());
392        }
393        let storage = self.compute_unary_op(Neg::neg)?;
394        let meta = T::AutogradMeta::on_unray_op(self, UnaryOp::Neg);
395        Ok(Self::from_storage(storage, self.shape(), meta))
396    }
397}
398
impl<F: FloatDType> Tensor<F> {
    // Rounding.
    float_unary_op_impl!(floor);
    float_unary_op_impl!(ceil);
    float_unary_op_impl!(round);

    // Exponential / logarithm.
    float_unary_op_impl!(exp);
    float_unary_op_impl!(ln);

    // Trigonometric / hyperbolic.
    float_unary_op_impl!(sin);
    float_unary_op_impl!(cos);
    float_unary_op_impl!(tanh);

    // Powers and magnitude.
    float_unary_op_impl!(sqrt);
    float_unary_op_impl!(sqr);
    float_unary_op_impl!(abs);
    // float_unary_op_impl!(neg);

    // Activations and related.
    float_unary_op_impl!(recip);
    float_unary_op_impl!(gelu);
    float_unary_op_impl!(gelu_erf);
    float_unary_op_impl!(erf);
    float_unary_op_impl!(relu);
    float_unary_op_impl!(silu);
    float_unary_op_impl!(sigmoid);

    /// Leaky ReLU with the given `negative_slope`; records the slope in the
    /// autograd metadata so the backward pass can replay it.
    /// Empty tensors are returned unchanged.
    pub fn leaky_relu(&self, negative_slope: F) -> crate::Result<Self> {
        if self.element_count() == 0 {
            return Ok(self.clone());
        }
        let f = |v: F| F::leaky_relu(v, negative_slope);
        let storage = self.compute_unary_op(f)?;
        let meta = F::AutogradMeta::on_unray_op(self, UnaryOp::LeakyRelu(negative_slope));
        Ok(Self::from_storage(storage, self.shape(), meta))
    }
}
434
435impl<F: FloatDType> Tensor<F> {
436    pub fn pow(&self, e: F) -> crate::Result<Self> {
437        if self.element_count() == 0 {
438            return Ok(self.clone());
439        }
440        let f = |v: F| v.powf(e); 
441        let storage = self.compute_unary_op(f)?;
442        let meta = F::AutogradMeta::on_pow_op(self, e);
443        Ok(Self::from_storage(storage, self.shape(), meta))
444    }
445}
446
447use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Sub};
448
449//////////////////////////////////////////////////////////////////////////////
450///        Add
451//////////////////////////////////////////////////////////////////////////////
452
453impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Add<R> for &Tensor<T> {
454    type Output = Tensor<T>;
455    fn add(self, rhs: R) -> Self::Output {
456        Tensor::add(self, rhs).expect("Tensor::add failed")
457    }
458}
459
460impl<'a, T: NumDType, R> Add<R> for Tensor<T> 
461where R: Into<TensorOrScalar<T>> 
462{
463    type Output = Tensor<T>;
464    fn add(self, rhs: R) -> Self::Output {
465        Tensor::add(&self, rhs).expect("Tensor::add failed")
466    }
467}
468
469//////////////////////////////////////////////////////////////////////////////
470///        Sub
471//////////////////////////////////////////////////////////////////////////////
472
473impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Sub<R> for &Tensor<T> {
474    type Output = Tensor<T>;
475    fn sub(self, rhs: R) -> Self::Output {
476        Tensor::sub(self, rhs).expect("Tensor::sub failed")
477    }
478}
479
480impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Sub<R> for Tensor<T> {
481    type Output = Tensor<T>;
482    fn sub(self, rhs: R) -> Self::Output {
483        Tensor::sub(&self, rhs).expect("Tensor::sub failed")
484    }
485}
486
487//////////////////////////////////////////////////////////////////////////////
488///        Mul
489//////////////////////////////////////////////////////////////////////////////
490
491impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Mul<R> for &Tensor<T> {
492    type Output = Tensor<T>;
493    fn mul(self, rhs: R) -> Self::Output {
494        Tensor::mul(self, rhs).expect("Tensor::mul failed")
495    }
496}
497
498impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Mul<R> for Tensor<T> {
499    type Output = Tensor<T>;
500    fn mul(self, rhs: R) -> Self::Output {
501        Tensor::mul(&self, rhs).expect("Tensor::mul failed")
502    }
503}
504
505//////////////////////////////////////////////////////////////////////////////
506///        Div
507//////////////////////////////////////////////////////////////////////////////
508
509impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Div<R> for &Tensor<T> {
510    type Output = Tensor<T>;
511    fn div(self, rhs: R) -> Self::Output {
512        Tensor::div(self, rhs).expect("Tensor::div failed")
513    }
514}
515
516impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Div<R> for Tensor<T> {
517    type Output = Tensor<T>;
518    fn div(self, rhs: R) -> Self::Output {
519        Tensor::div(&self, rhs).expect("Tensor::div failed")
520    }
521}
522
523//////////////////////////////////////////////////////////////////////////////
524///        Bool
525//////////////////////////////////////////////////////////////////////////////
526
527impl<'a, R: Into<TensorOrScalar<bool>>> BitAnd<R> for &Tensor<bool> {
528    type Output = Tensor<bool>;
529    fn bitand(self, rhs: R) -> Self::Output {
530        self.and(rhs).expect("Tensor::and failed")
531    }
532}
533
534impl<'a, R: Into<TensorOrScalar<bool>>> BitAnd<R> for Tensor<bool> {
535    type Output = Tensor<bool>;
536    fn bitand(self, rhs: R) -> Self::Output {
537        self.and(rhs).expect("Tensor::and failed")
538    }
539}
540
541impl<'a, R: Into<TensorOrScalar<bool>>> BitOr<R> for &Tensor<bool> {
542    type Output = Tensor<bool>;
543    fn bitor(self, rhs: R) -> Self::Output {
544        self.or(rhs).expect("Tensor::or failed")
545    }
546}
547
548impl<'a, R: Into<TensorOrScalar<bool>>> BitOr<R> for Tensor<bool> {
549    type Output = Tensor<bool>;
550    fn bitor(self, rhs: R) -> Self::Output {
551        self.or(rhs).expect("Tensor::or failed")
552    }
553}
554
555impl<'a, R: Into<TensorOrScalar<bool>>> BitXor<R> for &Tensor<bool> {
556    type Output = Tensor<bool>;
557    fn bitxor(self, rhs: R) -> Self::Output {
558        self.xor(rhs).expect("Tensor::xor failed")
559    }
560}
561
562impl<'a, R: Into<TensorOrScalar<bool>>> BitXor<R> for Tensor<bool> {
563    type Output = Tensor<bool>;
564    fn bitxor(self, rhs: R) -> Self::Output {
565        self.xor(rhs).expect("Tensor::xor failed")
566    }
567}
568
569//////////////////////////////////////////////////////////////////////////////
570///        Scalar & op
571//////////////////////////////////////////////////////////////////////////////
572
/// For each listed primitive scalar type, implements `scalar ⊕ tensor`
/// operator sugar (`Add`/`Mul` commute through the tensor method;
/// `Sub`/`Div` go through the order-sensitive `scalar_sub`/`scalar_div`).
/// NOTE(review): no scalar-lhs impls exist for `minimum`/`maximum` — confirm
/// that asymmetry is intentional.
macro_rules! impl_scalar_tensor_binary {
    ($($t:ty),*) => {
        $(
            // scalar + tensor: addition commutes, so reuse `Tensor::add`.
            impl Add<Tensor<$t>> for $t {
                type Output = Tensor<$t>;

                fn add(self, rhs: Tensor<$t>) -> Self::Output {
                    Tensor::add(&rhs, self).expect("Tensor::add failed")
                }
            }

            impl Add<&Tensor<$t>> for $t {
                type Output = Tensor<$t>;

                fn add(self, rhs: &Tensor<$t>) -> Self::Output {
                    Tensor::add(rhs, self).expect("Tensor::add failed")
                }
            }

            // scalar * tensor: multiplication commutes, so reuse `Tensor::mul`.
            impl Mul<Tensor<$t>> for $t {
                type Output = Tensor<$t>;

                fn mul(self, rhs: Tensor<$t>) -> Self::Output {
                    Tensor::mul(&rhs, self).expect("Tensor::mul failed")
                }
            }

            impl Mul<&Tensor<$t>> for $t {
                type Output = Tensor<$t>;

                fn mul(self, rhs: &Tensor<$t>) -> Self::Output {
                    Tensor::mul(rhs, self).expect("Tensor::mul failed")
                }
            }

            // scalar - tensor: not commutative, so use the scalar-lhs method.
            impl Sub<&Tensor<$t>> for $t {
                type Output = Tensor<$t>;

                fn sub(self, rhs: &Tensor<$t>) -> Self::Output {
                    Tensor::scalar_sub(self, rhs).expect("Tensor::scalar_sub failed")
                }
            }

            impl Sub<Tensor<$t>> for $t {
                type Output = Tensor<$t>;

                fn sub(self, rhs: Tensor<$t>) -> Self::Output {
                    Tensor::scalar_sub(self, &rhs).expect("Tensor::scalar_sub failed")
                }
            }

            // scalar / tensor: not commutative, so use the scalar-lhs method.
            impl Div<&Tensor<$t>> for $t {
                type Output = Tensor<$t>;

                fn div(self, rhs: &Tensor<$t>) -> Self::Output {
                    Tensor::scalar_div(self, rhs).expect("Tensor::scalar_div failed")
                }
            }

            impl Div<Tensor<$t>> for $t {
                type Output = Tensor<$t>;

                fn div(self, rhs: Tensor<$t>) -> Self::Output {
                    Tensor::scalar_div(self, &rhs).expect("Tensor::scalar_div failed")
                }
            }
        )*
    };
}

impl_scalar_tensor_binary!(f32, f64, u8, i32, u32);
644
645#[cfg(test)]
646mod tests {
647    use super::*;
648
649    #[test]
650    fn test_exp_log() -> crate::Result<()> {
651        let a = Tensor::new(&[0.0f32, 1.0, 2.0])?;
652        let exp_a = a.exp()?;
653        let log_a = exp_a.ln()?;
654        assert!(a.allclose(&log_a, 1e-5, 1e-8)?);
655        Ok(())
656    }
657
658    #[test]
659    fn test_trig() -> crate::Result<()> {
660        let a = Tensor::new(&[0.0f32, std::f32::consts::FRAC_PI_2])?;
661        let sin_a = a.sin()?;
662        let cos_a = a.cos()?;
663
664        let expected_sin = Tensor::new(&[0.0f32, 1.0])?;
665        let expected_cos = Tensor::new(&[1.0f32, 0.0])?;
666
667        println!("{:?}", cos_a.iter()?.collect::<Vec<_>>());
668
669        assert!(sin_a.allclose(&expected_sin, 1e-5, 1e-8)?);
670        assert!(cos_a.allclose(&expected_cos, 1e-5, 8e-8)?);
671
672        Ok(())
673    }
674
675    #[test]
676    fn test_abs_neg() -> crate::Result<()> {
677        let a = Tensor::new(&[-1.0f32, 0.0, 2.0])?;
678        let abs_a = a.abs()?;
679        let neg_a = a.neg()?;
680
681        let expected_abs = Tensor::new(&[1.0f32, 0.0, 2.0])?;
682        let expected_neg = Tensor::new(&[1.0f32, 0.0, -2.0])?;
683
684        assert!(abs_a.allclose(&expected_abs, 1e-6, 1e-6)?);
685        assert!(neg_a.allclose(&expected_neg, 1e-6, 1e-6)?);
686
687        Ok(())
688    }
689
690    #[test]
691    fn test_floor_ceil_round() -> crate::Result<()> {
692        let a = Tensor::new(&[1.2f32, 2.7, -1.3])?;
693        let floor_a = a.floor()?;
694        let ceil_a = a.ceil()?;
695        let round_a = a.round()?;
696
697        let expected_floor = Tensor::new(&[1.0f32, 2.0, -2.0])?;
698        let expected_ceil = Tensor::new(&[2.0f32, 3.0, -1.0])?;
699        let expected_round = Tensor::new(&[1.0f32, 3.0, -1.0])?;
700
701        assert!(floor_a.allclose(&expected_floor, 1e-6, 1e-6)?);
702        assert!(ceil_a.allclose(&expected_ceil, 1e-6, 1e-6)?);
703        assert!(round_a.allclose(&expected_round, 1e-6, 1e-6)?);
704
705        Ok(())
706    }
707
708    #[test]
709    fn test_floor_recip() -> crate::Result<()> {
710        let a = Tensor::new(&[1.2f32, 2.7, -1.3])?;
711        let recip_a = a.recip()?;
712        let expected = Tensor::new(&[1.2f32.recip(), 2.7f32.recip(), -1.3f32.recip(),])?;
713
714        assert!(recip_a.allclose(&expected, 1e-6, 1e-6)?);
715
716        Ok(())
717    }
718
719    #[test]
720    fn test_add_basic() -> crate::Result<()> {
721        let a = Tensor::new(&[1.0f32, 2.0, 3.0])?;
722        let b = Tensor::new(&[4.0f32, 5.0, 6.0])?;
723        let c = Tensor::add(&a, &b)?;
724        let expected = Tensor::new(&[5.0f32, 7.0, 9.0])?;
725        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
726
727        Ok(())
728    }
729
730    #[test]
731    fn test_add_basic_variable() -> crate::Result<()> {
732        let a = Tensor::new_var(&[1.0f32, 2.0, 3.0])?;
733        let b = Tensor::new_var(&[4.0f32, 5.0, 6.0])?;
734        let c = Tensor::add(&a, &b)?;
735        let expected = Tensor::new(&[5.0f32, 7.0, 9.0])?;
736        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
737
738        Ok(())
739    }
740
741    #[test]
742    fn test_sub_basic() -> crate::Result<()> {
743        let a = Tensor::new(&[10.0f32, 20.0, 30.0])?;
744        let b = Tensor::new(&[1.0f32, 2.0, 3.0])?;
745        let c = Tensor::sub(&a, &b)?;
746        let expected = Tensor::new(&[9.0f32, 18.0, 27.0])?;
747        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
748
749        Ok(())
750    }
751
752    #[test]
753    fn test_mul_basic() -> crate::Result<()> {
754        let a = Tensor::new(&[1.0f32, 2.0, 3.0])?;
755        let b = Tensor::new(&[2.0f32, 3.0, 4.0])?;
756        let c = Tensor::mul(&a, &b)?;
757        let expected = Tensor::new(&[2.0f32, 6.0, 12.0])?;
758        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
759
760        Ok(())
761    }
762
763    #[test]
764    fn test_div_basic() -> crate::Result<()> {
765        let a = Tensor::new(&[4.0f32, 9.0, 16.0])?;
766        let b = Tensor::new(&[2.0f32, 3.0, 4.0])?;
767        let c = Tensor::div(&a, &b)?;
768        let expected = Tensor::new(&[2.0f32, 3.0, 4.0])?;
769        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
770
771        Ok(())
772    }
773
774    #[test]
775    fn test_min_max_basic() -> crate::Result<()> {
776        let a = Tensor::new(&[1.0f32, 5.0, 3.0])?;
777        let b = Tensor::new(&[2.0f32, 4.0, 6.0])?;
778        let min_res = Tensor::minimum(&a, &b)?;
779        let max_res = Tensor::maximum(&a, &b)?;
780        let expected_min = Tensor::new(&[1.0f32, 4.0, 3.0])?;
781        let expected_max = Tensor::new(&[2.0f32, 5.0, 6.0])?;
782        assert!(min_res.allclose(&expected_min, 1e-6, 1e-6)?);
783        assert!(max_res.allclose(&expected_max, 1e-6, 1e-6)?);
784
785        Ok(())
786    }
787
788    #[test]
789    fn test_comparisons() -> crate::Result<()> {
790        let a = Tensor::new(&[1, 2, 3])?;
791        let b = Tensor::new(&[1, 0, 3])?;
792
793        assert_eq!(a.eq(&b).unwrap().to_vec()?, [true, false, true]);
794        assert_eq!(a.ne(&b).unwrap().to_vec()?, [false, true, false]);
795        assert_eq!(a.lt(&b).unwrap().to_vec()?, [false, false, false]);
796        assert_eq!(a.le(&b).unwrap().to_vec()?, [true, false, true]);
797        assert_eq!(a.gt(&b).unwrap().to_vec()?, [false, true, false]);
798        assert_eq!(a.ge(&b).unwrap().to_vec()?, [true, true, true]);
799
800        Ok(())
801    }
802
803    #[test]
804    fn test_add_mul_2d_3d() -> crate::Result<()> {
805        let a = Tensor::new(&[[1.0f32, 2.0], [3.0, 4.0]])?;
806        let b = Tensor::new(&[[5.0f32, 6.0], [7.0, 8.0]])?;
807        let c = Tensor::add(&a, &b)?;
808        let expected = Tensor::new(&[[6., 8.], [10., 12.]])?;
809        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
810
811        let a3 = Tensor::new(&[
812            [[1., 2.], [3., 4.]],
813            [[5., 6.], [7., 8.]],
814        ])?;
815        let b3 = Tensor::new(&[
816            [[2., 0.5], [1., 2.]],
817            [[0.5, 2.], [1.5, 1.]],
818        ])?;
819        let c3 = Tensor::mul(&a3, &b3)?;
820        let expected3 = Tensor::new(&[
821            [[2., 1.], [3., 8.]],
822            [[2.5, 12.], [10.5, 8.]],
823        ])?;
824        assert!(c3.allclose(&expected3, 1e-6, 1e-6)?);
825
826        Ok(())
827    }
828
829    #[test]
830    fn test_div_high_dim() -> crate::Result<()> {
831        let a = Tensor::full((2, 2, 2, 2), 8.0f32)?;
832        let b = Tensor::full((2, 2, 2, 2), 2.0f32)?;
833        let c = Tensor::div(&a, &b)?;
834        let expected = Tensor::full((2, 2, 2, 2), 4.0f32)?;
835        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
836    
837        Ok(())
838    }
839
840    #[test]
841    fn test_affine_and_affine_assign() -> crate::Result<()> {
842        let a = Tensor::<f64>::ones((3, 3))?;
843        let b = a.affine(3., 2.)?;
844        let expected = Tensor::new(&[[5., 5., 5.],[5.,5.,5.],[5.,5.,5.]])?;
845        assert!(b.allclose(&expected, 1e-6, 1e-6)?);
846
847        let a2 = Tensor::<f64>::ones((3, 3))?;
848        a2.affine_assign(3., 2.)?;
849        assert!(a2.allclose(&expected, 1e-6, 1e-6)?);
850        Ok(())
851    }
852
853    #[test]
854    fn test_add_scalar() -> crate::Result<()> {
855        let a = Tensor::new(&[1.0f32, 2.0, 3.0])?;
856        let b = 10.0f32;
857        let c = Tensor::add(&a, b)?;
858        let expected = Tensor::new(&[11.0f32, 12.0, 13.0])?;
859        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
860        Ok(())
861    }
862
863    #[test]
864    fn test_sub_scalar() -> crate::Result<()> {
865        let a = Tensor::new(&[10.0f32, 20.0, 30.0])?;
866        let b = 5.0f32;
867        let c = Tensor::sub(&a, b)?;
868        let expected = Tensor::new(&[5.0f32, 15.0, 25.0])?;
869        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
870        Ok(())
871    }
872
873    #[test]
874    fn test_mul_scalar() -> crate::Result<()> {
875        let a = Tensor::new(&[1.0f32, 2.0, 3.0])?;
876        let b = 2.0f32;
877        let c = Tensor::mul(&a, b)?;
878        let expected = Tensor::new(&[2.0f32, 4.0, 6.0])?;
879        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
880        Ok(())
881    }
882
883    #[test]
884    fn test_div_scalar() -> crate::Result<()> {
885        let a = Tensor::new(&[4.0f32, 9.0, 16.0])?;
886        let b = 2.0f32;
887        let c = Tensor::div(&a, b)?;
888        let expected = Tensor::new(&[2.0f32, 4.5, 8.0])?;
889        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
890        Ok(())
891    }
892
893    #[test]
894    fn test_minimum_scalar() -> crate::Result<()> {
895        let a = Tensor::new(&[1.0f32, 5.0, 3.0])?;
896        let b = 4.0f32;
897        let c = Tensor::minimum(&a, b)?;
898        let expected = Tensor::new(&[1.0f32, 4.0, 3.0])?;
899        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
900        Ok(())
901    }
902
903    #[test]
904    fn test_maximum_scalar() -> crate::Result<()> {
905        let a = Tensor::new(&[1.0f32, 5.0, 3.0])?;
906        let b = 4.0f32;
907        let c = Tensor::maximum(&a, b)?;
908        let expected = Tensor::new(&[4.0f32, 5.0, 4.0])?;
909        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
910        Ok(())
911    }
912
913    #[test]
914    fn test_eq_ne_scalar() -> crate::Result<()> {
915        let a = Tensor::new(&[1, 2, 3])?;
916        let b = 2;
917
918        // Tensor vs scalar
919        let eq_res = a.eq(b)?;
920        let expected_eq = Tensor::new(&[false, true, false])?;
921        assert_eq!(eq_res.to_vec()?, expected_eq.to_vec()?);
922
923        let ne_res = a.ne(b)?;
924        let expected_ne = Tensor::new(&[true, false, true])?;
925        assert_eq!(ne_res.to_vec()?, expected_ne.to_vec()?);
926        Ok(())
927    }
928
929    #[test]
930    fn test_lt_le_gt_ge_scalar() -> crate::Result<()> {
931        let a = Tensor::new(&[1, 2, 3])?;
932        let b = 2;
933
934        let lt_res = a.lt(b)?;
935        assert_eq!(lt_res.to_vec()?, [true, false, false]);
936
937        let le_res = a.le(b)?;
938        assert_eq!(le_res.to_vec()?, [true, true, false]);
939
940        let gt_res = a.gt(b)?;
941        assert_eq!(gt_res.to_vec()?, [false, false, true]);
942
943        let ge_res = a.ge(b)?;
944        assert_eq!(ge_res.to_vec()?, [false, true, true]);
945
946        Ok(())
947    }
948
949    #[test]
950    fn test_eq_ne_tensor() -> crate::Result<()> {
951        let a = Tensor::new(&[1, 2, 3])?;
952        let b = Tensor::new(&[1, 0, 3])?;
953
954        let eq_res = a.eq(&b)?;
955        assert_eq!(eq_res.to_vec()?, [true, false, true]);
956
957        let ne_res = a.ne(&b)?;
958        assert_eq!(ne_res.to_vec()?, [false, true, false]);
959
960        Ok(())
961    }
962
963    #[test]
964    fn test_lt_le_gt_ge_tensor() -> crate::Result<()> {
965        let a = Tensor::new(&[1, 2, 3])?;
966        let b = Tensor::new(&[2, 2, 1])?;
967
968        let lt_res = a.lt(&b)?;
969        assert_eq!(lt_res.to_vec()?, [true, false, false]);
970
971        let le_res = a.le(&b)?;
972        assert_eq!(le_res.to_vec()?, [true, true, false]);
973
974        let gt_res = a.gt(&b)?;
975        assert_eq!(gt_res.to_vec()?, [false, false, true]);
976
977        let ge_res = a.ge(&b)?;
978        assert_eq!(ge_res.to_vec()?, [false, true, true]);
979
980        Ok(())
981    }
982
983    #[test]
984    fn test_comparison_2d() -> crate::Result<()> {
985        let a = Tensor::new(&[[1, 2], [3, 4]])?;
986        let b = Tensor::new(&[[2, 2], [1, 5]])?;
987
988        let eq_res = a.eq(&b)?;
989        let expected_eq = Tensor::new(&[[false, true], [false, false]])?;
990        assert_eq!(eq_res.to_vec()?, expected_eq.to_vec()?);
991
992        let gt_res = a.gt(&b)?;
993        let expected_gt = Tensor::new(&[[false, false], [true, false]])?;
994        assert_eq!(gt_res.to_vec()?, expected_gt.to_vec()?);
995
996        // Tensor vs scalar
997        let le_res = a.le(3)?;
998        let expected_le = Tensor::new(&[[true, true], [true, false]])?;
999        assert_eq!(le_res.to_vec()?, expected_le.to_vec()?);
1000
1001        Ok(())
1002    }
1003
1004    #[test]
1005    fn test_std_ops() -> crate::Result<()> {
1006        let a = Tensor::new(&[[1., 2.], [3., 4.]])?;
1007        let b = Tensor::new(&[[2., 2.], [1., 5.]])?;
1008        let _ = a + b;
1009
1010        let a = Tensor::new(&[[1., 2.], [3., 4.]])?;
1011        let b = Tensor::new(&[[2., 2.], [1., 5.]])?;
1012        let _ = &a + &b;
1013
1014        let a = Tensor::new(&[[1., 2.], [3., 4.]])?;
1015        let b = Tensor::new(&[[2., 2.], [1., 5.]])?;
1016        let _ = a + b;
1017
1018        let a = Tensor::new(&[[1., 2.], [3., 4.]])?;
1019        let b = Tensor::new(&[[2., 2.], [1., 5.]])?;
1020        let _ = a + b;
1021
1022        Ok(())
1023    }
1024}