// hpt_traits/ops/cumulative.rs

1use hpt_common::error::base::TensorError;
2
3/// A trait for cumulative operations
4pub trait CumulativeOps: Sized {
5    /// Computes the cumulative sum of the elements along a specified axis.
6    ///
7    /// This method calculates the cumulative sum of the elements in the tensor along the given `axis`.
8    /// The cumulative sum of an element at position `i` is the sum of all elements from the start of the axis
9    /// up to and including position `i`. If no axis is specified, the cumulative sum is computed over a flattened
10    /// version of the tensor.
11    ///
12    /// # Arguments
13    ///
14    /// * `axis` - An optional axis along which to compute the cumulative sum. If `None`, the tensor is flattened,
15    ///   and the cumulative sum is computed over all elements.
16    ///
17    /// # Returns
18    ///
19    /// This function returns a `Result` containing a new tensor with the cumulative sum computed along the specified axis.
20    #[track_caller]
21    fn cumsum<A: Into<Option<i64>>>(&self, axis: A) -> Result<Self, TensorError>;
22    /// Computes the cumulative product of the elements along a specified axis.
23    ///
24    /// This method calculates the cumulative product of the elements in the tensor along the given `axis`.
25    /// The cumulative product of an element at position `i` is the product of all elements from the start of the axis
26    /// up to and including position `i`. If no axis is specified, the cumulative product is computed over a flattened
27    /// version of the tensor.
28    ///
29    /// # Arguments
30    ///
31    /// * `axis` - An optional axis along which to compute the cumulative product. If `None`, the tensor is flattened,
32    ///   and the cumulative product is computed over all elements.
33    ///
34    /// # Returns
35    ///
36    /// This function returns a `Result` containing a new tensor with the cumulative product computed along the specified axis.
37    #[track_caller]
38    fn cumprod<A: Into<Option<i64>>>(&self, axis: A) -> Result<Self, TensorError>;
39}