use crate::device::MlxDevice;
use crate::element::MlxElement;
use mlx_rs::Array;
use std::marker::PhantomData;
/// An MLX [`Array`] paired with the [`MlxDevice`] it is tagged with, typed by element `E`.
///
/// The element type `E` exists only at the type level (through `PhantomData`); nothing in
/// this struct checks that `array`'s actual dtype matches `E`.
/// NOTE(review): confirm that constructors elsewhere guarantee the dtype agrees with `E`.
#[derive(Debug, Clone)]
pub struct MlxTensor<E: MlxElement> {
/// The underlying MLX array holding this tensor's data.
pub(crate) array: Array,
/// The device this tensor is tagged with; carried along to results of operations.
pub(crate) device: MlxDevice,
/// Zero-sized marker that records the element type `E` without storing a value.
pub(crate) _element: PhantomData<E>,
}
impl<E: MlxElement> MlxTensor<E> {
pub fn new(array: Array, device: MlxDevice) -> Self {
Self {
array,
device,
_element: PhantomData,
}
}
pub fn array(&self) -> &Array {
&self.array
}
pub fn device(&self) -> &MlxDevice {
&self.device
}
pub fn shape(&self) -> Vec<usize> {
self.array.shape().iter().map(|&s| s as usize).collect()
}
pub fn ndim(&self) -> usize {
self.array.ndim()
}
pub fn numel(&self) -> usize {
self.array.size()
}
pub fn eval(&self) -> Result<(), mlx_rs::error::Exception> {
self.array.eval()
}
}
impl MlxTensor<f32> {
pub fn zeros(shape: &[i32], device: MlxDevice) -> Self {
let mlx_device = device.to_mlx_device();
mlx_rs::Device::set_default(&mlx_device);
let array = Array::zeros::<f32>(shape).expect("Failed to create zeros array");
Self::new(array, device)
}
pub fn ones(shape: &[i32], device: MlxDevice) -> Self {
let mlx_device = device.to_mlx_device();
mlx_rs::Device::set_default(&mlx_device);
let array = Array::ones::<f32>(shape).expect("Failed to create ones array");
Self::new(array, device)
}
pub fn from_slice(data: &[f32], shape: &[i32], device: MlxDevice) -> Self {
let mlx_device = device.to_mlx_device();
mlx_rs::Device::set_default(&mlx_device);
let array = Array::from_slice(data, shape);
Self::new(array, device)
}
pub fn add(&self, other: &Self) -> Self {
let array = mlx_rs::ops::add(&self.array, &other.array).expect("Failed to add");
Self::new(array, self.device.clone())
}
pub fn sub(&self, other: &Self) -> Self {
let array = mlx_rs::ops::subtract(&self.array, &other.array).expect("Failed to subtract");
Self::new(array, self.device.clone())
}
pub fn mul(&self, other: &Self) -> Self {
let array = mlx_rs::ops::multiply(&self.array, &other.array).expect("Failed to multiply");
Self::new(array, self.device.clone())
}
pub fn div(&self, other: &Self) -> Self {
let array = mlx_rs::ops::divide(&self.array, &other.array).expect("Failed to divide");
Self::new(array, self.device.clone())
}
pub fn matmul(&self, other: &Self) -> Self {
let array = mlx_rs::ops::matmul(&self.array, &other.array).expect("Failed to matmul");
Self::new(array, self.device.clone())
}
pub fn relu(&self) -> Self {
let zero = Array::from_f32(0.0);
let array = mlx_rs::ops::maximum(&self.array, &zero).expect("Failed to relu");
Self::new(array, self.device.clone())
}
pub fn sigmoid(&self) -> Self {
let array = mlx_rs::ops::sigmoid(&self.array).expect("Failed to sigmoid");
Self::new(array, self.device.clone())
}
pub fn tanh_act(&self) -> Self {
let array = mlx_rs::ops::tanh(&self.array).expect("Failed to tanh");
Self::new(array, self.device.clone())
}
pub fn softmax(&self) -> Self {
let array = mlx_rs::ops::softmax(&self.array, None).expect("Failed to softmax");
Self::new(array, self.device.clone())
}
pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Self {
let array = mlx_rs::ops::sum_axis(&self.array, dim, keepdim).expect("Failed to sum");
Self::new(array, self.device.clone())
}
pub fn mean_dim(&self, dim: i32) -> Self {
let array = mlx_rs::ops::mean_axis(&self.array, dim, true).expect("Failed to mean");
Self::new(array, self.device.clone())
}
pub fn exp(&self) -> Self {
let array = mlx_rs::ops::exp(&self.array).expect("Failed to exp");
Self::new(array, self.device.clone())
}
pub fn log(&self) -> Self {
let array = mlx_rs::ops::log(&self.array).expect("Failed to log");
Self::new(array, self.device.clone())
}
pub fn sqrt(&self) -> Self {
let array = mlx_rs::ops::sqrt(&self.array).expect("Failed to sqrt");
Self::new(array, self.device.clone())
}
pub fn abs(&self) -> Self {
let array = mlx_rs::ops::abs(&self.array).expect("Failed to abs");
Self::new(array, self.device.clone())
}
pub fn neg(&self) -> Self {
let array = mlx_rs::ops::negative(&self.array).expect("Failed to neg");
Self::new(array, self.device.clone())
}
}
// SAFETY: NOTE(review): these impls assert that `MlxTensor` may be moved (`Send`) and shared
// by reference (`Sync`) across threads, overriding whatever the compiler would infer from the
// fields. Nothing in this file shows that `mlx_rs::Array` or `MlxDevice` uphold this.
// Soundness therefore rests on the MLX array handle being safe to use from several threads,
// including concurrent `&self` calls such as `eval` and the `mlx_rs::ops::*` wrappers.
// `E` appears only in `PhantomData<E>`, and no `E: Send`/`E: Sync` bound is required here.
// Confirm this against the mlx-rs / MLX thread-safety guarantees before relying on it.
unsafe impl<E: MlxElement> Send for MlxTensor<E> {}
unsafe impl<E: MlxElement> Sync for MlxTensor<E> {}