use crate::tensor::MlxTensor;
/// Element-wise arithmetic, matrix multiplication, axis reductions and unary
/// math for `f32` tensors.
///
/// Every method delegates to an `mlx_rs` operation and wraps the resulting
/// array in a new [`MlxTensor`] tagged with `self`'s device. No method returns
/// a `Result`: if the underlying `mlx_rs` call reports an error, the method
/// panics with a short description of the operation that failed.
impl MlxTensor<f32> {
    /// Element-wise addition `self + rhs` via `mlx_rs::ops::add`.
    ///
    /// Shape compatibility (including any broadcasting) is decided by `mlx_rs`.
    /// The result carries `self.device`; `rhs`'s device is not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `mlx_rs::ops::add` returns an error.
    pub fn add(&self, rhs: &MlxTensor<f32>) -> MlxTensor<f32> {
        let array = mlx_rs::ops::add(&self.array, &rhs.array).expect("Failed to add arrays");
        MlxTensor::new(array, self.device)
    }

    /// Element-wise subtraction `self - rhs` via `mlx_rs::ops::subtract`.
    ///
    /// The result carries `self.device`; `rhs`'s device is not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `mlx_rs::ops::subtract` returns an error.
    pub fn sub(&self, rhs: &MlxTensor<f32>) -> MlxTensor<f32> {
        let array =
            mlx_rs::ops::subtract(&self.array, &rhs.array).expect("Failed to subtract arrays");
        MlxTensor::new(array, self.device)
    }

    /// Element-wise multiplication `self * rhs` via `mlx_rs::ops::multiply`.
    ///
    /// This is not matrix multiplication; see [`MlxTensor::matmul`] for that.
    /// The result carries `self.device`; `rhs`'s device is not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `mlx_rs::ops::multiply` returns an error.
    pub fn mul(&self, rhs: &MlxTensor<f32>) -> MlxTensor<f32> {
        let array =
            mlx_rs::ops::multiply(&self.array, &rhs.array).expect("Failed to multiply arrays");
        MlxTensor::new(array, self.device)
    }

    /// Element-wise division `self / rhs` via `mlx_rs::ops::divide`.
    ///
    /// The result carries `self.device`; `rhs`'s device is not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `mlx_rs::ops::divide` returns an error.
    pub fn div(&self, rhs: &MlxTensor<f32>) -> MlxTensor<f32> {
        let array = mlx_rs::ops::divide(&self.array, &rhs.array).expect("Failed to divide arrays");
        MlxTensor::new(array, self.device)
    }

    /// Matrix product of `self` and `rhs` via the array's `matmul` method.
    ///
    /// The result carries `self.device`; `rhs`'s device is not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `matmul` returns an error.
    pub fn matmul(&self, rhs: &MlxTensor<f32>) -> MlxTensor<f32> {
        let array = self.array.matmul(&rhs.array).expect("Failed to matmul arrays");
        MlxTensor::new(array, self.device)
    }

    /// Sums the tensor along dimension `dim` via `mlx_rs::ops::sum_axis`.
    ///
    /// `keepdims` is forwarded unchanged and controls whether the reduced
    /// dimension is retained in the output.
    ///
    /// # Panics
    ///
    /// Panics if `dim` does not fit in an `i32`, or if
    /// `mlx_rs::ops::sum_axis` returns an error (e.g. `dim` out of range for
    /// this tensor — assumed, confirm against `mlx_rs` docs).
    pub fn sum_dim(&self, dim: usize, keepdims: bool) -> MlxTensor<f32> {
        // Checked conversion: a plain `as` cast would wrap huge `usize` values
        // into negative `i32`s, which could be read as an axis counted from
        // the end and silently reduce over the wrong dimension.
        let axis = i32::try_from(dim).expect("Dimension index does not fit in i32");
        let array =
            mlx_rs::ops::sum_axis(&self.array, axis, keepdims).expect("Failed to sum array");
        MlxTensor::new(array, self.device)
    }

    /// Averages the tensor along dimension `dim` via `mlx_rs::ops::mean_axis`.
    ///
    /// Unlike [`MlxTensor::sum_dim`], this always passes `keepdims = true`, so
    /// the reduced dimension is retained in the output.
    ///
    /// # Panics
    ///
    /// Panics if `dim` does not fit in an `i32`, or if
    /// `mlx_rs::ops::mean_axis` returns an error.
    pub fn mean_dim(&self, dim: usize) -> MlxTensor<f32> {
        // Checked for the same reason as in `sum_dim`: avoid `as` wrap-around.
        let axis = i32::try_from(dim).expect("Dimension index does not fit in i32");
        let array = mlx_rs::ops::mean_axis(&self.array, axis, true).expect("Failed to mean array");
        MlxTensor::new(array, self.device)
    }

    /// Element-wise natural exponential via `mlx_rs::ops::exp`.
    ///
    /// # Panics
    ///
    /// Panics if `mlx_rs::ops::exp` returns an error.
    pub fn exp(&self) -> MlxTensor<f32> {
        let array = mlx_rs::ops::exp(&self.array).expect("Failed to exp array");
        MlxTensor::new(array, self.device)
    }

    /// Element-wise natural logarithm via `mlx_rs::ops::log`.
    ///
    /// # Panics
    ///
    /// Panics if `mlx_rs::ops::log` returns an error.
    pub fn log(&self) -> MlxTensor<f32> {
        let array = mlx_rs::ops::log(&self.array).expect("Failed to log array");
        MlxTensor::new(array, self.device)
    }

    /// Element-wise square root via `mlx_rs::ops::sqrt`.
    ///
    /// # Panics
    ///
    /// Panics if `mlx_rs::ops::sqrt` returns an error.
    pub fn sqrt(&self) -> MlxTensor<f32> {
        let array = mlx_rs::ops::sqrt(&self.array).expect("Failed to sqrt array");
        MlxTensor::new(array, self.device)
    }

    /// Element-wise absolute value via `mlx_rs::ops::abs`.
    ///
    /// # Panics
    ///
    /// Panics if `mlx_rs::ops::abs` returns an error.
    pub fn abs(&self) -> MlxTensor<f32> {
        let array = mlx_rs::ops::abs(&self.array).expect("Failed to abs array");
        MlxTensor::new(array, self.device)
    }

    /// Element-wise negation `-self` via `mlx_rs::ops::negative`.
    ///
    /// # Panics
    ///
    /// Panics if `mlx_rs::ops::negative` returns an error.
    pub fn neg(&self) -> MlxTensor<f32> {
        let array = mlx_rs::ops::negative(&self.array).expect("Failed to neg array");
        MlxTensor::new(array, self.device)
    }
}