burn_tensor/tensor/api/
fmod.rs

1use crate::{Float, Tensor, backend::Backend};
2
3impl<B, const D: usize> Tensor<B, D, Float>
4where
5    B: Backend,
6{
7    /// Computes the floating-point remainder of dividing `self` by `other`.
8    ///
9    /// The result has the same sign as `self` and magnitude less than `other`.
10    /// This is equivalent to the IEEE 754 remainder operation.
11    ///
12    /// # Special Cases (IEEE 754 compliant)
13    ///
14    /// - If `self` is ±∞ and `other` is not NaN, NaN is returned
15    /// - If `other` is ±0 and `self` is not NaN, NaN is returned
16    /// - If `other` is ±∞ and `self` is finite, `self` is returned
17    /// - If either argument is NaN, NaN is returned
18    ///
19    /// # Arguments
20    ///
21    /// * `other` - The divisor tensor. Must have the same shape as `self`.
22    ///
23    /// # Returns
24    ///
25    /// A tensor with the same shape where each element is the floating-point remainder.
26    ///
27    /// # Example
28    ///
29    /// ```rust
30    /// use burn_tensor::backend::Backend;
31    /// use burn_tensor::Tensor;
32    ///
33    /// fn example<B: Backend>() {
34    ///     let device = B::Device::default();
35    ///     let dividend = Tensor::<B, 1>::from_data([5.3, -5.3, 5.3, -5.3], &device);
36    ///     let divisor = Tensor::<B, 1>::from_data([2.0, 2.0, -2.0, -2.0], &device);
37    ///     let result = dividend.fmod(divisor);
38    ///
39    ///     // Result: [1.3, -1.3, 1.3, -1.3]
40    /// }
41    /// ```
42    pub fn fmod(self, other: Self) -> Self {
43        // Normal case: fmod(x, y) = x - y * trunc(x / y)
44        let quotient = self.clone().div(other.clone());
45        let truncated = quotient.trunc();
46        let product = other.clone() * truncated.clone();
47
48        // When divisor is infinity and dividend is finite:
49        // - quotient is 0, truncated is 0
50        // - but 0 * infinity = NaN, which is wrong
51        // We need to handle this case by replacing NaN with 0 when appropriate
52
53        // Check if the product is NaN due to 0 * inf
54        let is_zero_times_inf = truncated.equal_elem(0.0).bool_and(other.is_inf());
55        let zero_tensor = self.clone().mul_scalar(0.0);
56        let corrected_product = product.mask_where(is_zero_times_inf, zero_tensor);
57
58        self - corrected_product
59    }
60
61    /// Computes the floating-point remainder of dividing `self` by a scalar.
62    ///
63    /// The result has the same sign as `self` and magnitude less than the scalar.
64    ///
65    /// # Special Cases (IEEE 754 compliant)
66    ///
67    /// - If `self` is ±∞ and scalar is not NaN, NaN is returned
68    /// - If scalar is ±0 and `self` is not NaN, NaN is returned
69    /// - If scalar is ±∞ and `self` is finite, `self` is returned
70    /// - If either argument is NaN, NaN is returned
71    ///
72    /// # Arguments
73    ///
74    /// * `scalar` - The scalar divisor.
75    ///
76    /// # Returns
77    ///
78    /// A tensor with the same shape where each element is the floating-point remainder.
79    ///
80    /// # Example
81    ///
82    /// ```rust
83    /// use burn_tensor::backend::Backend;
84    /// use burn_tensor::Tensor;
85    ///
86    /// fn example<B: Backend>() {
87    ///     let device = B::Device::default();
88    ///     let tensor = Tensor::<B, 1>::from_data([5.3, -5.3, 7.5, -7.5], &device);
89    ///     let result = tensor.fmod_scalar(2.0);
90    ///
91    ///     // Result: [1.3, -1.3, 1.5, -1.5]
92    /// }
93    /// ```
94    pub fn fmod_scalar(self, scalar: f32) -> Self {
95        // Normal case: fmod(x, y) = x - y * trunc(x / y)
96        let quotient = self.clone().div_scalar(scalar);
97        let truncated = quotient.trunc();
98        let product = truncated.mul_scalar(scalar);
99
100        // Handle the special case where scalar is infinity
101        // When scalar is ±∞ and self is finite, quotient is 0, truncated is 0
102        // but 0 * infinity = NaN, which is wrong - it should be 0
103        if scalar.is_infinite() {
104            // For finite values, fmod(x, ±∞) = x
105            // For infinite values, fmod(±∞, ±∞) = NaN (which is handled by arithmetic)
106            return self;
107        }
108
109        self - product
110    }
111}