burn_tensor/tensor/api/fmod.rs
1use crate::{Float, Tensor, backend::Backend};
2
3impl<B, const D: usize> Tensor<B, D, Float>
4where
5 B: Backend,
6{
7 /// Computes the floating-point remainder of dividing `self` by `other`.
8 ///
9 /// The result has the same sign as `self` and magnitude less than `other`.
10 /// This is equivalent to the IEEE 754 remainder operation.
11 ///
12 /// # Special Cases (IEEE 754 compliant)
13 ///
14 /// - If `self` is ±∞ and `other` is not NaN, NaN is returned
15 /// - If `other` is ±0 and `self` is not NaN, NaN is returned
16 /// - If `other` is ±∞ and `self` is finite, `self` is returned
17 /// - If either argument is NaN, NaN is returned
18 ///
19 /// # Arguments
20 ///
21 /// * `other` - The divisor tensor. Must have the same shape as `self`.
22 ///
23 /// # Returns
24 ///
25 /// A tensor with the same shape where each element is the floating-point remainder.
26 ///
27 /// # Example
28 ///
29 /// ```rust
30 /// use burn_tensor::backend::Backend;
31 /// use burn_tensor::Tensor;
32 ///
33 /// fn example<B: Backend>() {
34 /// let device = B::Device::default();
35 /// let dividend = Tensor::<B, 1>::from_data([5.3, -5.3, 5.3, -5.3], &device);
36 /// let divisor = Tensor::<B, 1>::from_data([2.0, 2.0, -2.0, -2.0], &device);
37 /// let result = dividend.fmod(divisor);
38 ///
39 /// // Result: [1.3, -1.3, 1.3, -1.3]
40 /// }
41 /// ```
42 pub fn fmod(self, other: Self) -> Self {
43 // Normal case: fmod(x, y) = x - y * trunc(x / y)
44 let quotient = self.clone().div(other.clone());
45 let truncated = quotient.trunc();
46 let product = other.clone() * truncated.clone();
47
48 // When divisor is infinity and dividend is finite:
49 // - quotient is 0, truncated is 0
50 // - but 0 * infinity = NaN, which is wrong
51 // We need to handle this case by replacing NaN with 0 when appropriate
52
53 // Check if the product is NaN due to 0 * inf
54 let is_zero_times_inf = truncated.equal_elem(0.0).bool_and(other.is_inf());
55 let zero_tensor = self.clone().mul_scalar(0.0);
56 let corrected_product = product.mask_where(is_zero_times_inf, zero_tensor);
57
58 self - corrected_product
59 }
60
61 /// Computes the floating-point remainder of dividing `self` by a scalar.
62 ///
63 /// The result has the same sign as `self` and magnitude less than the scalar.
64 ///
65 /// # Special Cases (IEEE 754 compliant)
66 ///
67 /// - If `self` is ±∞ and scalar is not NaN, NaN is returned
68 /// - If scalar is ±0 and `self` is not NaN, NaN is returned
69 /// - If scalar is ±∞ and `self` is finite, `self` is returned
70 /// - If either argument is NaN, NaN is returned
71 ///
72 /// # Arguments
73 ///
74 /// * `scalar` - The scalar divisor.
75 ///
76 /// # Returns
77 ///
78 /// A tensor with the same shape where each element is the floating-point remainder.
79 ///
80 /// # Example
81 ///
82 /// ```rust
83 /// use burn_tensor::backend::Backend;
84 /// use burn_tensor::Tensor;
85 ///
86 /// fn example<B: Backend>() {
87 /// let device = B::Device::default();
88 /// let tensor = Tensor::<B, 1>::from_data([5.3, -5.3, 7.5, -7.5], &device);
89 /// let result = tensor.fmod_scalar(2.0);
90 ///
91 /// // Result: [1.3, -1.3, 1.5, -1.5]
92 /// }
93 /// ```
94 pub fn fmod_scalar(self, scalar: f32) -> Self {
95 // Normal case: fmod(x, y) = x - y * trunc(x / y)
96 let quotient = self.clone().div_scalar(scalar);
97 let truncated = quotient.trunc();
98 let product = truncated.mul_scalar(scalar);
99
100 // Handle the special case where scalar is infinity
101 // When scalar is ±∞ and self is finite, quotient is 0, truncated is 0
102 // but 0 * infinity = NaN, which is wrong - it should be 0
103 if scalar.is_infinite() {
104 // For finite values, fmod(x, ±∞) = x
105 // For infinite values, fmod(±∞, ±∞) = NaN (which is handled by arithmetic)
106 return self;
107 }
108
109 self - product
110 }
111}