syntaxdot-tch-ext 0.5.0

tch path extension for partitioning parameters in groups
Documentation
//! Convenience functions for `Tensor`.
//!
//! The `Tensor` API can be a bit unwieldy since it is partly
//! autogenerated. This module provides some additional methods
//! that are more convenient to use.

use tch::{Kind, TchError, Tensor};

/// Sum a tensor over a single dimension.
///
/// Convenience wrapper for the common case of reducing exactly one
/// dimension, so callers do not have to build a dimension list and an
/// `Option` themselves.
pub trait SumDim {
    /// Sum over a dimension (fallible).
    ///
    /// * `dim` - the dimension to reduce.
    /// * `keep_dim` - whether the reduced dimension is kept (with size 1)
    ///   in the output, like PyTorch's `keepdim`.
    /// * `kind` - the element kind (dtype) requested for the result.
    ///
    /// # Errors
    ///
    /// Returns the [`TchError`] reported by the underlying libtorch call,
    /// e.g. for an invalid `dim`.
    fn f_sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Result<Tensor, TchError>;

    /// Sum over a dimension.
    ///
    /// Infallible counterpart of [`SumDim::f_sum_dim`]. The default
    /// implementation delegates to it, so implementors only need to
    /// provide `f_sum_dim`.
    ///
    /// # Panics
    ///
    /// Panics if [`SumDim::f_sum_dim`] returns an error.
    fn sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Tensor {
        self.f_sum_dim(dim, keep_dim, kind).unwrap()
    }
}

impl SumDim for Tensor {
    /// Fallible single-dimension sum.
    ///
    /// Forwards to `f_sum_dim_intlist`, passing `dim` as a one-element
    /// dimension list together with `keep_dim` and `kind`.
    fn f_sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Result<Tensor, TchError> {
        // Borrow `dim` directly as a length-1 slice instead of building a temporary array.
        let dims: &[i64] = std::slice::from_ref(&dim);
        self.f_sum_dim_intlist(Some(dims), keep_dim, kind)
    }

    /// Infallible single-dimension sum.
    ///
    /// Panics (via `unwrap`) if `f_sum_dim` returns an error.
    fn sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Tensor {
        let summed = self.f_sum_dim(dim, keep_dim, kind);
        summed.unwrap()
    }
}