use tch::{Kind, TchError, Tensor};
/// Sum-reduction of a tensor over a single dimension.
///
/// Provides the usual tch pair of methods: an `f_`-prefixed fallible
/// version returning a `Result`, and a panicking convenience version.
pub trait SumDim {
    /// Sums `self` over dimension `dim`, returning an error instead of panicking.
    ///
    /// * `dim` – the dimension to reduce.
    /// * `keep_dim` – forwarded as libtorch's `keepdim` flag (presumably keeps
    ///   the reduced dimension with size 1 when `true`; see libtorch docs).
    /// * `kind` – the element kind requested for the result.
    ///
    /// # Errors
    ///
    /// Returns the [`TchError`] reported by the underlying tch call.
    fn f_sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Result<Tensor, TchError>;

    /// Panicking counterpart of [`SumDim::f_sum_dim`], taking the same arguments.
    ///
    /// # Panics
    ///
    /// Implementations are expected to panic where `f_sum_dim` would return
    /// an error.
    fn sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Tensor;
}
impl SumDim for Tensor {
    /// Delegates to `f_sum_dim_intlist` with a one-element dimension list
    /// containing only `dim`; `keep_dim` and `kind` are passed through unchanged.
    fn f_sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Result<Tensor, TchError> {
        // A single reduction axis expressed as the slice form tch expects.
        let single_axis: &[i64] = &[dim];
        self.f_sum_dim_intlist(Some(single_axis), keep_dim, kind)
    }

    /// Calls the fallible `f_sum_dim` and unwraps the result, panicking on error.
    fn sum_dim(&self, dim: i64, keep_dim: bool, kind: Kind) -> Tensor {
        let outcome = self.f_sum_dim(dim, keep_dim, kind);
        outcome.unwrap()
    }
}