1use crate::{TensorOrScalar, grad::BinaryOp, AutogradMetaT, CmpOp, Error, FloatDType, NumDType, Shape, Storage, UnaryOp, WithDType};
2use super::Tensor;
3use paste::paste;
4
5impl<T: WithDType> Tensor<T> {
10 fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> crate::Result<&Shape> {
11 let lhs = self.shape();
12 let rhs = rhs.shape();
13 if lhs != rhs {
14 Err(Error::ShapeMismatchBinaryOp {
15 lhs: lhs.clone(),
16 rhs: rhs.clone(),
17 op,
18 })?
19 } else {
20 Ok(lhs)
21 }
22 }
23}
24
25impl<T: WithDType> Tensor<T> {
26 fn compute_binary_scalar_rhs_op<U, F>(lhs: &Tensor<T>, rhs: T, mut f: F, _op_name: &'static str) -> crate::Result<(Storage<U>, Shape)>
27 where
28 U: WithDType,
29 F: FnMut(T, T) -> U
30 {
31 let shape = lhs.shape();
32 let lhs_storage = lhs.storage_read()?;
33 let lhs_layout = lhs.layout();
34
35 let lhs = lhs_storage.data();
36
37 let output: Vec<_> = lhs_layout.storage_indices()
38 .map(|lhs_index| f(lhs[lhs_index], rhs))
39 .collect();
40
41 let storage = Storage::<U>::new(output);
42 Ok((storage, shape.clone()))
43 }
44
45 fn compute_binary_scalar_lhs_op<U, F>(lhs: T, rhs: &Tensor<T>, mut f: F, _op_name: &'static str) -> crate::Result<(Storage<U>, Shape)>
46 where
47 U: WithDType,
48 F: FnMut(T, T) -> U
49 {
50 let shape = rhs.shape();
51 let rhs_storage = rhs.storage_read()?;
52 let rhs_layout = rhs.layout();
53
54 let rhs = rhs_storage.data();
55
56 let output: Vec<_> = rhs_layout.storage_indices()
57 .map(|index| f(lhs, rhs[index]))
58 .collect();
59
60 let storage = Storage::<U>::new(output);
61 Ok((storage, shape.clone()))
62 }
63
64 fn compute_binary_op<U, F>(lhs: &Tensor<T>, rhs: &Tensor<T>, mut f: F, op_name: &'static str) -> crate::Result<(Storage<U>, Shape)>
65 where
66 U: WithDType,
67 F: FnMut(T, T) -> U
68 {
69 let shape = Tensor::<T>::same_shape_binary_op(lhs, rhs, op_name)?;
70 let lhs_storage = lhs.storage_read()?;
71 let rhs_storage = rhs.storage_read()?;
72 let lhs_layout = lhs.layout();
73 let rhs_layout = rhs.layout();
74
75 assert_eq!(lhs_layout.dims(), rhs_layout.dims(), "lhs dims != rhs dim2");
76
77 let lhs = lhs_storage.data();
78 let rhs = rhs_storage.data();
79
80 let output: Vec<_> = lhs_layout.storage_indices().zip(rhs_layout.storage_indices())
81 .map(|(lhs_index, rhs_index)| f(lhs[lhs_index], rhs[rhs_index]))
82 .collect();
83
84 let storage = Storage::<U>::new(output);
85 Ok((storage, shape.clone()))
86 }
87
88 fn binary_op<U, F>(lhs: &Tensor<T>, rhs: &Tensor<T>, f: F, meta: U::AutogradMeta, op_name: &'static str) -> crate::Result<Tensor<U>>
89 where
90 U: WithDType,
91 F: FnMut(T, T) -> U
92 {
93 let (storage, shape) = Self::compute_binary_op(lhs, rhs, f, op_name)?;
94 Ok(Tensor::<U>::from_storage(storage, shape, meta))
95 }
96
97 fn binary_scalar_rhs_op<U, F>(lhs: &Tensor<T>, rhs: T, f: F, meta: U::AutogradMeta, op_name: &'static str) -> crate::Result<Tensor<U>>
98 where
99 U: WithDType,
100 F: FnMut(T, T) -> U
101 {
102 let (storage, shape) = Self::compute_binary_scalar_rhs_op(lhs, rhs, f, op_name)?;
103 Ok(Tensor::<U>::from_storage(storage, shape, meta))
104 }
105
106 fn binary_scalar_lhs_op<U, F>(lhs: T, rhs: &Tensor<T>, f: F, meta: U::AutogradMeta, op_name: &'static str) -> crate::Result<Tensor<U>>
107 where
108 U: WithDType,
109 F: FnMut(T, T) -> U
110 {
111 let (storage, shape) = Self::compute_binary_scalar_lhs_op(lhs, rhs, f, op_name)?;
112 Ok(Tensor::<U>::from_storage(storage, shape, meta))
113 }
114}
115
/// Generates the four public entry points of one element-wise binary operation `op`:
/// `<op>_tensor`, `<op>_scalar`, `scalar_<op>` and the dispatching `<op>`.
///
/// The element function is `T::<op>` and the autograd tag is `BinaryOp::<Op>`
/// (the camel-cased operation name).
macro_rules! binary_op_impl {
    ($op:ident) => {
        paste! {
            /// Tensor-with-tensor form; both operands must have the same shape.
            pub fn [< $op _tensor >](&self, rhs: &Self) -> crate::Result<Self> {
                let grad_meta = T::AutogradMeta::on_binary_op(self, rhs, BinaryOp::[< $op:camel >]);
                Self::binary_op(self, rhs, T::$op, grad_meta, stringify!([< $op _tensor >]))
            }

            /// Tensor-with-scalar form; the scalar is the right-hand operand.
            pub fn [< $op _scalar >](&self, rhs: T) -> crate::Result<Self> {
                let grad_meta = T::AutogradMeta::on_binary_scalar_rhs_op(self, rhs, BinaryOp::[< $op:camel >]);
                Self::binary_scalar_rhs_op(self, rhs, T::$op, grad_meta, stringify!([< $op _scalar >]))
            }

            /// Scalar-with-tensor form; the scalar is the left-hand operand.
            pub fn [< scalar_ $op >](lhs: T, rhs: &Tensor<T>) -> crate::Result<Tensor<T>> {
                let grad_meta = T::AutogradMeta::on_binary_scalar_lhs_op(lhs, rhs, BinaryOp::[< $op:camel >]);
                Self::binary_scalar_lhs_op(lhs, rhs, T::$op, grad_meta, stringify!([< scalar_ $op >]))
            }

            /// Dispatches to the tensor or scalar form depending on what `rhs` converts into.
            pub fn $op(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Self> {
                let operand = rhs.into();
                match operand {
                    TensorOrScalar::Scalar(scalar) => self.[< $op _scalar >](scalar),
                    TensorOrScalar::Tensor(tensor) => self.[< $op _tensor >](&tensor),
                }
            }
        }
    };
}
143
144impl<T: NumDType> Tensor<T> {
145 binary_op_impl!(add);
146 binary_op_impl!(mul);
147 binary_op_impl!(sub);
148 binary_op_impl!(div);
149 binary_op_impl!(minimum);
150 binary_op_impl!(maximum);
151
152 pub fn clamp(&self, min: T, max: T) -> crate::Result<Self> {
153 self.maximum(min)?.minimum(max)
154 }
155}
156
157impl<T: NumDType> Tensor<T> {
158 fn binary_op_inplace<F>(lhs: &Tensor<T>, rhs: &Tensor<T>, mut f: F, op_name: &'static str) -> crate::Result<()>
159 where
160 F: FnMut(T, T) -> T
161 {
162 let _ = Tensor::<T>::same_shape_binary_op(lhs, rhs, op_name)?;
163
164 let mut lhs_storage = lhs.storage_write()?;
165 let rhs_storage = rhs.storage_read()?;
166 let lhs_layout = lhs.layout();
167 let rhs_layout = rhs.layout();
168
169 assert_eq!(lhs_layout.dims(), rhs_layout.dims(), "lhs dims != rhs dim2");
170
171 let lhs = lhs_storage.data_mut();
172 let rhs = rhs_storage.data();
173
174 lhs_layout.storage_indices().zip(rhs_layout.storage_indices())
175 .for_each(|(lhs_index, rhs_index)| lhs[lhs_index] = f(lhs[lhs_index], rhs[rhs_index]));
176
177 Ok(())
178 }
179
180 fn binary_op_scalar_inplace<F>(lhs: &Tensor<T>, rhs: T, mut f: F, _op_name: &'static str) -> crate::Result<()>
181 where
182 F: FnMut(T, T) -> T
183 {
184 let mut lhs_storage = lhs.storage_write()?;
185 let lhs_layout = lhs.layout();
186
187
188 let lhs = lhs_storage.data_mut();
189
190 lhs_layout.storage_indices()
191 .for_each(|lhs_index| lhs[lhs_index] = f(lhs[lhs_index], rhs));
192
193 Ok(())
194 }
195}
196
/// Generates the in-place method `<op>_` for the element-wise binary operation `op`.
///
/// The generated method accepts either a tensor or a scalar right-hand side, writes the
/// result into `self`'s storage, and returns `self.clone()`.
macro_rules! binary_inplace_op_impl {
    ($fn_name:ident) => {
        paste! {
            /// In-place element-wise operation; `rhs` may be a tensor of the same shape or a scalar.
            ///
            /// # Errors
            /// Fails when a tensor `rhs` has a different shape, or when the required storage
            /// handles cannot be obtained.
            pub fn [< $fn_name _ >](&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Self> {
                let rhs = rhs.into();
                match rhs {
                    TensorOrScalar::Scalar(rhs) => Self::binary_op_scalar_inplace(self, rhs, T::$fn_name, stringify!([< $fn_name _scalar_ >]))?,
                    // The label ends up in shape-mismatch errors, so it must name the
                    // tensor variant rather than the scalar one.
                    TensorOrScalar::Tensor(rhs) => Self::binary_op_inplace(self, &rhs, T::$fn_name, stringify!([< $fn_name _tensor_ >]))?,
                }
                Ok(self.clone())
            }
        }
    };
}
211
// In-place arithmetic methods (`add_`, `sub_`, `mul_`, `div_`) generated by
// `binary_inplace_op_impl!`.
// NOTE(review): the reason for `#[allow(unused)]` is not visible in this file — presumably
// the generated methods are not called anywhere yet; confirm it is still needed.
#[allow(unused)]
impl<T: NumDType> Tensor<T> {
    binary_inplace_op_impl!(add);
    binary_inplace_op_impl!(sub);
    binary_inplace_op_impl!(mul);
    binary_inplace_op_impl!(div);
}
219
220impl<T: NumDType> Tensor<T> {
221 pub fn eq(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
222 self.cmp(rhs, CmpOp::Eq)
223 }
224
225 pub fn ne(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
226 self.cmp(rhs, CmpOp::Ne)
227 }
228
229 pub fn le(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
230 self.cmp(rhs, CmpOp::Le)
231 }
232
233 pub fn ge(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
234 self.cmp(rhs, CmpOp::Ge)
235 }
236
237 pub fn lt(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
238 self.cmp(rhs, CmpOp::Lt)
239 }
240
241 pub fn gt(&self, rhs: impl Into<TensorOrScalar<T>>) -> crate::Result<Tensor<bool>> {
242 self.cmp(rhs, CmpOp::Gt)
243 }
244
245 pub fn cmp(&self, rhs: impl Into<TensorOrScalar<T>>, op: CmpOp) -> crate::Result<Tensor<bool>> {
246 match rhs.into() {
247 TensorOrScalar::Tensor(rhs) => {
248 match op {
249 CmpOp::Eq => Self::binary_op(self, &rhs, |a, b| a == b, Default::default(), "eq"),
250 CmpOp::Ne => Self::binary_op(self, &rhs, |a, b| a != b, Default::default(), "nq"),
251 CmpOp::Le => Self::binary_op(self, &rhs, |a, b| a <= b, Default::default(), "le"),
252 CmpOp::Ge => Self::binary_op(self, &rhs, |a, b| a >= b, Default::default(), "ge"),
253 CmpOp::Lt => Self::binary_op(self, &rhs, |a, b| a < b, Default::default(), "lt"),
254 CmpOp::Gt => Self::binary_op(self, &rhs, |a, b| a > b, Default::default(), "gt"),
255 }
256 }
257 TensorOrScalar::Scalar(rhs) => {
258 match op {
259 CmpOp::Eq => Self::binary_scalar_rhs_op(self, rhs, |a, b| a == b, Default::default(), "eq"),
260 CmpOp::Ne => Self::binary_scalar_rhs_op(self, rhs, |a, b| a != b, Default::default(), "nq"),
261 CmpOp::Le => Self::binary_scalar_rhs_op(self, rhs, |a, b| a <= b, Default::default(), "le"),
262 CmpOp::Ge => Self::binary_scalar_rhs_op(self, rhs, |a, b| a >= b, Default::default(), "ge"),
263 CmpOp::Lt => Self::binary_scalar_rhs_op(self, rhs, |a, b| a < b, Default::default(), "lt"),
264 CmpOp::Gt => Self::binary_scalar_rhs_op(self, rhs, |a, b| a > b, Default::default(), "gt"),
265 }
266 }
267 }
268 }
269}
270
271impl Tensor<bool> {
272 pub fn and(&self, rhs: impl Into<TensorOrScalar<bool>>) -> crate::Result<Tensor<bool>> {
273 match rhs.into() {
274 TensorOrScalar::Tensor(rhs) => Self::binary_op(self, &rhs, |a, b| a & b, Default::default(), "and"),
275 TensorOrScalar::Scalar(rhs) => Self::binary_scalar_rhs_op(self, rhs, |a, b| a & b, Default::default(), "and"),
276 }
277 }
278
279 pub fn or(&self, rhs: impl Into<TensorOrScalar<bool>>) -> crate::Result<Tensor<bool>> {
280 match rhs.into() {
281 TensorOrScalar::Tensor(rhs) => Self::binary_op(self, &rhs, |a, b| a | b, Default::default(), "or"),
282 TensorOrScalar::Scalar(rhs) => Self::binary_scalar_rhs_op(self, rhs, |a, b| a | b, Default::default(), "or"),
283 }
284 }
285
286 pub fn xor(&self, rhs: impl Into<TensorOrScalar<bool>>) -> crate::Result<Tensor<bool>> {
287 match rhs.into() {
288 TensorOrScalar::Tensor(rhs) => Self::binary_op(self, &rhs, |a, b| a ^ b, Default::default(), "xor"),
289 TensorOrScalar::Scalar(rhs) => Self::binary_scalar_rhs_op(self, rhs, |a, b| a ^ b, Default::default(), "xor"),
290 }
291 }
292
293 pub fn not(&self) -> crate::Result<Tensor<bool>> {
294 if self.element_count() == 0 {
295 return Ok(self.clone());
296 }
297 let storage = self.compute_unary_op(|v| !v)?;
298 Ok(Self::from_storage(storage, self.shape(), Default::default()))
299 }
300}
301
302impl<T: WithDType> Tensor<T> {
307 fn compute_unary_op<U, F>(&self, mut f: F) -> crate::Result<Storage<U>>
308 where
309 U: WithDType,
310 F: FnMut(T) -> U
311 {
312 let storage = self.storage_read()?;
313 let vec = storage.data();
314 let mut output = vec![];
315 for index in self.layout().storage_indices() {
316 output.push( f(vec[index]) );
317 }
318
319 Ok(Storage::new(output))
320 }
321
322 fn unary_assign_op<F>(&self, mut f: F) -> crate::Result<()>
323 where
324 F: FnMut(T) -> T
325 {
326 let mut storage = self.storage_write()?;
327 let vec = storage.data_mut();
328 for index in self.layout().storage_indices() {
329 vec[index] = f(vec[index]);
330 }
331 Ok(())
332 }
333}
334
335impl<T: NumDType> Tensor<T> {
336 pub fn affine(&self, mul: T, add: T) -> crate::Result<Self> {
337 if self.element_count() == 0 {
338 return Ok(self.clone());
339 }
340 let storage = self.compute_unary_op(|v| v * mul + add)?;
341 Ok(Self::from_storage(storage, self.shape(), Default::default()))
342 }
343
344 pub fn affine_assign(&self, mul: T, add: T) -> crate::Result<()> {
345 if self.element_count() == 0 {
346 return Ok(());
347 }
348 self.unary_assign_op(|v| v * mul + add)
349 }
350}
351
/// Generates a public element-wise unary method named `op` on a float tensor.
///
/// The element function is `F::<op>` and the autograd tag is `UnaryOp::<Op>`
/// (the camel-cased operation name).
macro_rules! float_unary_op_impl {
    ($op:ident) => {
        paste! {
            /// Element-wise float operation; a tensor with zero elements is returned as a clone.
            pub fn $op(&self) -> crate::Result<Self> {
                match self.element_count() {
                    0 => Ok(self.clone()),
                    _ => {
                        let values = self.compute_unary_op(F::$op)?;
                        let grad_meta = F::AutogradMeta::on_unray_op(self, UnaryOp::[< $op:camel >]);
                        Ok(Self::from_storage(values, self.shape(), grad_meta))
                    }
                }
            }
        }
    };
}
366
367impl<T: WithDType> Tensor<T> {
368 pub fn map<F, O>(&self, f: F) -> crate::Result<Tensor<O>>
369 where
370 O: WithDType,
371 F: Fn(T) -> O,
372 {
373 let storage = self.compute_unary_op(f)?;
374 Ok(Tensor::from_storage(storage, self.shape(), Default::default()))
375 }
376
377 pub fn map_assign<F>(&self, f: F) -> crate::Result<()>
378 where
379 F: Fn(T) -> T,
380 {
381 if self.element_count() == 0 {
382 return Ok(());
383 }
384 self.unary_assign_op(f)
385 }
386}
387
388impl<T: NumDType + Neg<Output = T>> Tensor<T> {
389 pub fn neg(&self) -> crate::Result<Self> {
390 if self.element_count() == 0 {
391 return Ok(self.clone());
392 }
393 let storage = self.compute_unary_op(Neg::neg)?;
394 let meta = T::AutogradMeta::on_unray_op(self, UnaryOp::Neg);
395 Ok(Self::from_storage(storage, self.shape(), meta))
396 }
397}
398
399impl<F: FloatDType> Tensor<F> {
400 float_unary_op_impl!(floor);
401 float_unary_op_impl!(ceil);
402 float_unary_op_impl!(round);
403
404 float_unary_op_impl!(exp);
405 float_unary_op_impl!(ln);
406
407 float_unary_op_impl!(sin);
408 float_unary_op_impl!(cos);
409 float_unary_op_impl!(tanh);
410
411 float_unary_op_impl!(sqrt);
412 float_unary_op_impl!(sqr);
413 float_unary_op_impl!(abs);
414 float_unary_op_impl!(recip);
417 float_unary_op_impl!(gelu);
418 float_unary_op_impl!(gelu_erf);
419 float_unary_op_impl!(erf);
420 float_unary_op_impl!(relu);
421 float_unary_op_impl!(silu);
422 float_unary_op_impl!(sigmoid);
423
424 pub fn leaky_relu(&self, negative_slope: F) -> crate::Result<Self> {
425 if self.element_count() == 0 {
426 return Ok(self.clone());
427 }
428 let f = |v: F| F::leaky_relu(v, negative_slope);
429 let storage = self.compute_unary_op(f)?;
430 let meta = F::AutogradMeta::on_unray_op(self, UnaryOp::LeakyRelu(negative_slope));
431 Ok(Self::from_storage(storage, self.shape(), meta))
432 }
433}
434
435impl<F: FloatDType> Tensor<F> {
436 pub fn pow(&self, e: F) -> crate::Result<Self> {
437 if self.element_count() == 0 {
438 return Ok(self.clone());
439 }
440 let f = |v: F| v.powf(e);
441 let storage = self.compute_unary_op(f)?;
442 let meta = F::AutogradMeta::on_pow_op(self, e);
443 Ok(Self::from_storage(storage, self.shape(), meta))
444 }
445}
446
447use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Sub};
448
449impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Add<R> for &Tensor<T> {
454 type Output = Tensor<T>;
455 fn add(self, rhs: R) -> Self::Output {
456 Tensor::add(self, rhs).expect("Tensor::add failed")
457 }
458}
459
460impl<'a, T: NumDType, R> Add<R> for Tensor<T>
461where R: Into<TensorOrScalar<T>>
462{
463 type Output = Tensor<T>;
464 fn add(self, rhs: R) -> Self::Output {
465 Tensor::add(&self, rhs).expect("Tensor::add failed")
466 }
467}
468
469impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Sub<R> for &Tensor<T> {
474 type Output = Tensor<T>;
475 fn sub(self, rhs: R) -> Self::Output {
476 Tensor::sub(self, rhs).expect("Tensor::sub failed")
477 }
478}
479
480impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Sub<R> for Tensor<T> {
481 type Output = Tensor<T>;
482 fn sub(self, rhs: R) -> Self::Output {
483 Tensor::sub(&self, rhs).expect("Tensor::sub failed")
484 }
485}
486
487impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Mul<R> for &Tensor<T> {
492 type Output = Tensor<T>;
493 fn mul(self, rhs: R) -> Self::Output {
494 Tensor::mul(self, rhs).expect("Tensor::mul failed")
495 }
496}
497
498impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Mul<R> for Tensor<T> {
499 type Output = Tensor<T>;
500 fn mul(self, rhs: R) -> Self::Output {
501 Tensor::mul(&self, rhs).expect("Tensor::mul failed")
502 }
503}
504
505impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Div<R> for &Tensor<T> {
510 type Output = Tensor<T>;
511 fn div(self, rhs: R) -> Self::Output {
512 Tensor::div(self, rhs).expect("Tensor::div failed")
513 }
514}
515
516impl<'a, T: NumDType, R: Into<TensorOrScalar<T>>> Div<R> for Tensor<T> {
517 type Output = Tensor<T>;
518 fn div(self, rhs: R) -> Self::Output {
519 Tensor::div(&self, rhs).expect("Tensor::div failed")
520 }
521}
522
523impl<'a, R: Into<TensorOrScalar<bool>>> BitAnd<R> for &Tensor<bool> {
528 type Output = Tensor<bool>;
529 fn bitand(self, rhs: R) -> Self::Output {
530 self.and(rhs).expect("Tensor::and failed")
531 }
532}
533
534impl<'a, R: Into<TensorOrScalar<bool>>> BitAnd<R> for Tensor<bool> {
535 type Output = Tensor<bool>;
536 fn bitand(self, rhs: R) -> Self::Output {
537 self.and(rhs).expect("Tensor::and failed")
538 }
539}
540
541impl<'a, R: Into<TensorOrScalar<bool>>> BitOr<R> for &Tensor<bool> {
542 type Output = Tensor<bool>;
543 fn bitor(self, rhs: R) -> Self::Output {
544 self.or(rhs).expect("Tensor::or failed")
545 }
546}
547
548impl<'a, R: Into<TensorOrScalar<bool>>> BitOr<R> for Tensor<bool> {
549 type Output = Tensor<bool>;
550 fn bitor(self, rhs: R) -> Self::Output {
551 self.or(rhs).expect("Tensor::or failed")
552 }
553}
554
555impl<'a, R: Into<TensorOrScalar<bool>>> BitXor<R> for &Tensor<bool> {
556 type Output = Tensor<bool>;
557 fn bitxor(self, rhs: R) -> Self::Output {
558 self.xor(rhs).expect("Tensor::xor failed")
559 }
560}
561
562impl<'a, R: Into<TensorOrScalar<bool>>> BitXor<R> for Tensor<bool> {
563 type Output = Tensor<bool>;
564 fn bitxor(self, rhs: R) -> Self::Output {
565 self.xor(rhs).expect("Tensor::xor failed")
566 }
567}
568
569macro_rules! impl_scalar_tensor_binary {
574 ($($t:ty),*) => {
575 $(
576 impl Add<Tensor<$t>> for $t {
577 type Output = Tensor<$t>;
578
579 fn add(self, rhs: Tensor<$t>) -> Self::Output {
580 Tensor::add(&rhs, self).expect("Tensor::add failed")
581 }
582 }
583
584 impl Add<&Tensor<$t>> for $t {
585 type Output = Tensor<$t>;
586
587 fn add(self, rhs: &Tensor<$t>) -> Self::Output {
588 Tensor::add(rhs, self).expect("Tensor::add failed")
589 }
590 }
591
592 impl Mul<Tensor<$t>> for $t {
593 type Output = Tensor<$t>;
594
595 fn mul(self, rhs: Tensor<$t>) -> Self::Output {
596 Tensor::mul(&rhs, self).expect("Tensor::mul failed")
597 }
598 }
599
600 impl Mul<&Tensor<$t>> for $t {
601 type Output = Tensor<$t>;
602
603 fn mul(self, rhs: &Tensor<$t>) -> Self::Output {
604 Tensor::mul(rhs, self).expect("Tensor::mul failed")
605 }
606 }
607
608 impl Sub<&Tensor<$t>> for $t {
609 type Output = Tensor<$t>;
610
611 fn sub(self, rhs: &Tensor<$t>) -> Self::Output {
612 Tensor::scalar_sub(self, rhs).expect("Tensor::scalar_sub failed")
613 }
614 }
615
616 impl Sub<Tensor<$t>> for $t {
617 type Output = Tensor<$t>;
618
619 fn sub(self, rhs: Tensor<$t>) -> Self::Output {
620 Tensor::scalar_sub(self, &rhs).expect("Tensor::scalar_sub failed")
621 }
622 }
623
624 impl Div<&Tensor<$t>> for $t {
625 type Output = Tensor<$t>;
626
627 fn div(self, rhs: &Tensor<$t>) -> Self::Output {
628 Tensor::scalar_div(self, rhs).expect("Tensor::scalar_div failed")
629 }
630 }
631
632 impl Div<Tensor<$t>> for $t {
633 type Output = Tensor<$t>;
634
635 fn div(self, rhs: Tensor<$t>) -> Self::Output {
636 Tensor::scalar_div(self, &rhs).expect("Tensor::scalar_div failed")
637 }
638 }
639 )*
640 };
641}
642
643impl_scalar_tensor_binary!(f32, f64, u8, i32, u32);
644
/// Unit tests for the element-wise tensor operations defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    /// `ln(exp(x))` should reproduce `x`.
    #[test]
    fn test_exp_log() -> crate::Result<()> {
        let a = Tensor::new(&[0.0f32, 1.0, 2.0])?;
        let exp_a = a.exp()?;
        let log_a = exp_a.ln()?;
        assert!(a.allclose(&log_a, 1e-5, 1e-8)?);
        Ok(())
    }

    /// `sin` and `cos` at 0 and pi/2.
    #[test]
    fn test_trig() -> crate::Result<()> {
        let a = Tensor::new(&[0.0f32, std::f32::consts::FRAC_PI_2])?;
        let sin_a = a.sin()?;
        let cos_a = a.cos()?;

        let expected_sin = Tensor::new(&[0.0f32, 1.0])?;
        let expected_cos = Tensor::new(&[1.0f32, 0.0])?;

        assert!(sin_a.allclose(&expected_sin, 1e-5, 1e-8)?);
        // cos(pi/2) is not exactly zero in f32, hence the looser last tolerance argument.
        assert!(cos_a.allclose(&expected_cos, 1e-5, 8e-8)?);

        Ok(())
    }

    #[test]
    fn test_abs_neg() -> crate::Result<()> {
        let a = Tensor::new(&[-1.0f32, 0.0, 2.0])?;
        let abs_a = a.abs()?;
        let neg_a = a.neg()?;

        let expected_abs = Tensor::new(&[1.0f32, 0.0, 2.0])?;
        let expected_neg = Tensor::new(&[1.0f32, 0.0, -2.0])?;

        assert!(abs_a.allclose(&expected_abs, 1e-6, 1e-6)?);
        assert!(neg_a.allclose(&expected_neg, 1e-6, 1e-6)?);

        Ok(())
    }

    #[test]
    fn test_floor_ceil_round() -> crate::Result<()> {
        let a = Tensor::new(&[1.2f32, 2.7, -1.3])?;
        let floor_a = a.floor()?;
        let ceil_a = a.ceil()?;
        let round_a = a.round()?;

        let expected_floor = Tensor::new(&[1.0f32, 2.0, -2.0])?;
        let expected_ceil = Tensor::new(&[2.0f32, 3.0, -1.0])?;
        let expected_round = Tensor::new(&[1.0f32, 3.0, -1.0])?;

        assert!(floor_a.allclose(&expected_floor, 1e-6, 1e-6)?);
        assert!(ceil_a.allclose(&expected_ceil, 1e-6, 1e-6)?);
        assert!(round_a.allclose(&expected_round, 1e-6, 1e-6)?);

        Ok(())
    }

    /// Despite the name, this test only exercises `recip`.
    #[test]
    fn test_floor_recip() -> crate::Result<()> {
        let a = Tensor::new(&[1.2f32, 2.7, -1.3])?;
        let recip_a = a.recip()?;
        let expected = Tensor::new(&[1.2f32.recip(), 2.7f32.recip(), -1.3f32.recip()])?;

        assert!(recip_a.allclose(&expected, 1e-6, 1e-6)?);

        Ok(())
    }

    #[test]
    fn test_add_basic() -> crate::Result<()> {
        let a = Tensor::new(&[1.0f32, 2.0, 3.0])?;
        let b = Tensor::new(&[4.0f32, 5.0, 6.0])?;
        let c = Tensor::add(&a, &b)?;
        let expected = Tensor::new(&[5.0f32, 7.0, 9.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);

        Ok(())
    }

    /// Same as `test_add_basic`, but with operands created through `Tensor::new_var`.
    #[test]
    fn test_add_basic_variable() -> crate::Result<()> {
        let a = Tensor::new_var(&[1.0f32, 2.0, 3.0])?;
        let b = Tensor::new_var(&[4.0f32, 5.0, 6.0])?;
        let c = Tensor::add(&a, &b)?;
        let expected = Tensor::new(&[5.0f32, 7.0, 9.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);

        Ok(())
    }

    #[test]
    fn test_sub_basic() -> crate::Result<()> {
        let a = Tensor::new(&[10.0f32, 20.0, 30.0])?;
        let b = Tensor::new(&[1.0f32, 2.0, 3.0])?;
        let c = Tensor::sub(&a, &b)?;
        let expected = Tensor::new(&[9.0f32, 18.0, 27.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);

        Ok(())
    }

    #[test]
    fn test_mul_basic() -> crate::Result<()> {
        let a = Tensor::new(&[1.0f32, 2.0, 3.0])?;
        let b = Tensor::new(&[2.0f32, 3.0, 4.0])?;
        let c = Tensor::mul(&a, &b)?;
        let expected = Tensor::new(&[2.0f32, 6.0, 12.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);

        Ok(())
    }

    #[test]
    fn test_div_basic() -> crate::Result<()> {
        let a = Tensor::new(&[4.0f32, 9.0, 16.0])?;
        let b = Tensor::new(&[2.0f32, 3.0, 4.0])?;
        let c = Tensor::div(&a, &b)?;
        let expected = Tensor::new(&[2.0f32, 3.0, 4.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);

        Ok(())
    }

    #[test]
    fn test_min_max_basic() -> crate::Result<()> {
        let a = Tensor::new(&[1.0f32, 5.0, 3.0])?;
        let b = Tensor::new(&[2.0f32, 4.0, 6.0])?;
        let min_res = Tensor::minimum(&a, &b)?;
        let max_res = Tensor::maximum(&a, &b)?;
        let expected_min = Tensor::new(&[1.0f32, 4.0, 3.0])?;
        let expected_max = Tensor::new(&[2.0f32, 5.0, 6.0])?;
        assert!(min_res.allclose(&expected_min, 1e-6, 1e-6)?);
        assert!(max_res.allclose(&expected_max, 1e-6, 1e-6)?);

        Ok(())
    }

    #[test]
    fn test_comparisons() -> crate::Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = Tensor::new(&[1, 0, 3])?;

        assert_eq!(a.eq(&b)?.to_vec()?, [true, false, true]);
        assert_eq!(a.ne(&b)?.to_vec()?, [false, true, false]);
        assert_eq!(a.lt(&b)?.to_vec()?, [false, false, false]);
        assert_eq!(a.le(&b)?.to_vec()?, [true, false, true]);
        assert_eq!(a.gt(&b)?.to_vec()?, [false, true, false]);
        assert_eq!(a.ge(&b)?.to_vec()?, [true, true, true]);

        Ok(())
    }

    #[test]
    fn test_add_mul_2d_3d() -> crate::Result<()> {
        let a = Tensor::new(&[[1.0f32, 2.0], [3.0, 4.0]])?;
        let b = Tensor::new(&[[5.0f32, 6.0], [7.0, 8.0]])?;
        let c = Tensor::add(&a, &b)?;
        let expected = Tensor::new(&[[6., 8.], [10., 12.]])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);

        let a3 = Tensor::new(&[
            [[1., 2.], [3., 4.]],
            [[5., 6.], [7., 8.]],
        ])?;
        let b3 = Tensor::new(&[
            [[2., 0.5], [1., 2.]],
            [[0.5, 2.], [1.5, 1.]],
        ])?;
        let c3 = Tensor::mul(&a3, &b3)?;
        let expected3 = Tensor::new(&[
            [[2., 1.], [3., 8.]],
            [[2.5, 12.], [10.5, 8.]],
        ])?;
        assert!(c3.allclose(&expected3, 1e-6, 1e-6)?);

        Ok(())
    }

    #[test]
    fn test_div_high_dim() -> crate::Result<()> {
        let a = Tensor::full((2, 2, 2, 2), 8.0f32)?;
        let b = Tensor::full((2, 2, 2, 2), 2.0f32)?;
        let c = Tensor::div(&a, &b)?;
        let expected = Tensor::full((2, 2, 2, 2), 4.0f32)?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);

        Ok(())
    }

    #[test]
    fn test_affine_and_affine_assign() -> crate::Result<()> {
        let a = Tensor::<f64>::ones((3, 3))?;
        let b = a.affine(3., 2.)?;
        let expected = Tensor::new(&[[5., 5., 5.], [5., 5., 5.], [5., 5., 5.]])?;
        assert!(b.allclose(&expected, 1e-6, 1e-6)?);

        let a2 = Tensor::<f64>::ones((3, 3))?;
        a2.affine_assign(3., 2.)?;
        assert!(a2.allclose(&expected, 1e-6, 1e-6)?);
        Ok(())
    }

    #[test]
    fn test_add_scalar() -> crate::Result<()> {
        let a = Tensor::new(&[1.0f32, 2.0, 3.0])?;
        let b = 10.0f32;
        let c = Tensor::add(&a, b)?;
        let expected = Tensor::new(&[11.0f32, 12.0, 13.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
        Ok(())
    }

    #[test]
    fn test_sub_scalar() -> crate::Result<()> {
        let a = Tensor::new(&[10.0f32, 20.0, 30.0])?;
        let b = 5.0f32;
        let c = Tensor::sub(&a, b)?;
        let expected = Tensor::new(&[5.0f32, 15.0, 25.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
        Ok(())
    }

    #[test]
    fn test_mul_scalar() -> crate::Result<()> {
        let a = Tensor::new(&[1.0f32, 2.0, 3.0])?;
        let b = 2.0f32;
        let c = Tensor::mul(&a, b)?;
        let expected = Tensor::new(&[2.0f32, 4.0, 6.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
        Ok(())
    }

    #[test]
    fn test_div_scalar() -> crate::Result<()> {
        let a = Tensor::new(&[4.0f32, 9.0, 16.0])?;
        let b = 2.0f32;
        let c = Tensor::div(&a, b)?;
        let expected = Tensor::new(&[2.0f32, 4.5, 8.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
        Ok(())
    }

    #[test]
    fn test_minimum_scalar() -> crate::Result<()> {
        let a = Tensor::new(&[1.0f32, 5.0, 3.0])?;
        let b = 4.0f32;
        let c = Tensor::minimum(&a, b)?;
        let expected = Tensor::new(&[1.0f32, 4.0, 3.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
        Ok(())
    }

    #[test]
    fn test_maximum_scalar() -> crate::Result<()> {
        let a = Tensor::new(&[1.0f32, 5.0, 3.0])?;
        let b = 4.0f32;
        let c = Tensor::maximum(&a, b)?;
        let expected = Tensor::new(&[4.0f32, 5.0, 4.0])?;
        assert!(c.allclose(&expected, 1e-6, 1e-6)?);
        Ok(())
    }

    #[test]
    fn test_eq_ne_scalar() -> crate::Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = 2;

        let eq_res = a.eq(b)?;
        let expected_eq = Tensor::new(&[false, true, false])?;
        assert_eq!(eq_res.to_vec()?, expected_eq.to_vec()?);

        let ne_res = a.ne(b)?;
        let expected_ne = Tensor::new(&[true, false, true])?;
        assert_eq!(ne_res.to_vec()?, expected_ne.to_vec()?);
        Ok(())
    }

    #[test]
    fn test_lt_le_gt_ge_scalar() -> crate::Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = 2;

        let lt_res = a.lt(b)?;
        assert_eq!(lt_res.to_vec()?, [true, false, false]);

        let le_res = a.le(b)?;
        assert_eq!(le_res.to_vec()?, [true, true, false]);

        let gt_res = a.gt(b)?;
        assert_eq!(gt_res.to_vec()?, [false, false, true]);

        let ge_res = a.ge(b)?;
        assert_eq!(ge_res.to_vec()?, [false, true, true]);

        Ok(())
    }

    #[test]
    fn test_eq_ne_tensor() -> crate::Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = Tensor::new(&[1, 0, 3])?;

        let eq_res = a.eq(&b)?;
        assert_eq!(eq_res.to_vec()?, [true, false, true]);

        let ne_res = a.ne(&b)?;
        assert_eq!(ne_res.to_vec()?, [false, true, false]);

        Ok(())
    }

    #[test]
    fn test_lt_le_gt_ge_tensor() -> crate::Result<()> {
        let a = Tensor::new(&[1, 2, 3])?;
        let b = Tensor::new(&[2, 2, 1])?;

        let lt_res = a.lt(&b)?;
        assert_eq!(lt_res.to_vec()?, [true, false, false]);

        let le_res = a.le(&b)?;
        assert_eq!(le_res.to_vec()?, [true, true, false]);

        let gt_res = a.gt(&b)?;
        assert_eq!(gt_res.to_vec()?, [false, false, true]);

        let ge_res = a.ge(&b)?;
        assert_eq!(ge_res.to_vec()?, [false, true, true]);

        Ok(())
    }

    #[test]
    fn test_comparison_2d() -> crate::Result<()> {
        let a = Tensor::new(&[[1, 2], [3, 4]])?;
        let b = Tensor::new(&[[2, 2], [1, 5]])?;

        let eq_res = a.eq(&b)?;
        let expected_eq = Tensor::new(&[[false, true], [false, false]])?;
        assert_eq!(eq_res.to_vec()?, expected_eq.to_vec()?);

        let gt_res = a.gt(&b)?;
        let expected_gt = Tensor::new(&[[false, false], [true, false]])?;
        assert_eq!(gt_res.to_vec()?, expected_gt.to_vec()?);

        let le_res = a.le(3)?;
        let expected_le = Tensor::new(&[[true, true], [true, false]])?;
        assert_eq!(le_res.to_vec()?, expected_le.to_vec()?);

        Ok(())
    }

    /// Exercises the `+` operator for every owned/borrowed operand combination and checks
    /// the result each time.
    #[test]
    fn test_std_ops() -> crate::Result<()> {
        let expected = Tensor::new(&[[3., 4.], [4., 9.]])?;

        // owned + owned
        let a = Tensor::new(&[[1., 2.], [3., 4.]])?;
        let b = Tensor::new(&[[2., 2.], [1., 5.]])?;
        let sum = a + b;
        assert!(sum.allclose(&expected, 1e-6, 1e-6)?);

        // borrowed + borrowed
        let a = Tensor::new(&[[1., 2.], [3., 4.]])?;
        let b = Tensor::new(&[[2., 2.], [1., 5.]])?;
        let sum = &a + &b;
        assert!(sum.allclose(&expected, 1e-6, 1e-6)?);

        // borrowed + owned
        let a = Tensor::new(&[[1., 2.], [3., 4.]])?;
        let b = Tensor::new(&[[2., 2.], [1., 5.]])?;
        let sum = &a + b;
        assert!(sum.allclose(&expected, 1e-6, 1e-6)?);

        // owned + borrowed
        let a = Tensor::new(&[[1., 2.], [3., 4.]])?;
        let b = Tensor::new(&[[2., 2.], [1., 5.]])?;
        let sum = a + &b;
        assert!(sum.allclose(&expected, 1e-6, 1e-6)?);

        Ok(())
    }
}