mod device;
mod element;
mod tensor;
mod backend;
mod ops;
pub use backend::{Mlx, MlxTensorPrimitive, MlxQuantizedTensorPrimitive};
pub use device::MlxDevice;
pub use element::MlxElement;
pub use tensor::MlxTensor;
/// Re-export of the underlying `mlx_rs` crate.
///
/// Every public item of `mlx_rs` is reachable through this module
/// (presumably so that downstream code can name MLX items without adding
/// its own direct `mlx_rs` dependency).
pub mod mlx {
    pub use mlx_rs::*;
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_tensor::{backend::Backend, ops::ModuleOps, Shape, Tensor, TensorData};

    /// Builds a rank-`D` float tensor on the default device holding the
    /// values `0.0, 1.0, ..., 15.0` arranged in the requested shape.
    ///
    /// Every shape used by the pooling tests below has exactly 16 elements.
    fn ramp_tensor<const D: usize>(dims: [usize; D]) -> Tensor<Mlx, D> {
        let values: Vec<f32> = (0..16).map(|v| v as f32).collect();
        let device = MlxDevice::default();
        Tensor::from_data(TensorData::new(values, Shape::new(dims)), &device)
    }

    /// Both `MlxDevice` variants used in this module can be constructed.
    #[test]
    fn test_device_creation() {
        let (_gpu, _cpu) = (MlxDevice::Gpu, MlxDevice::Cpu);
    }

    /// `MlxTensor::ones` reports the shape it was asked to create.
    #[test]
    fn test_tensor_creation_raw() {
        let ones = MlxTensor::<f32>::ones(&[2, 3], MlxDevice::Gpu);
        assert_eq!(ones.shape(), vec![2, 3]);
    }

    /// `add` on two `[2, 3]` raw tensors yields a `[2, 3]` result.
    #[test]
    fn test_tensor_operations_raw() {
        let lhs = MlxTensor::<f32>::ones(&[2, 3], MlxDevice::Gpu);
        let rhs = MlxTensor::<f32>::ones(&[2, 3], MlxDevice::Gpu);
        let total = lhs.add(&rhs);
        assert_eq!(total.shape(), vec![2, 3]);
    }

    /// `matmul` of a `[2, 3]` raw tensor with a `[3, 4]` one yields `[2, 4]`.
    #[test]
    fn test_matmul_raw() {
        let lhs = MlxTensor::<f32>::ones(&[2, 3], MlxDevice::Gpu);
        let rhs = MlxTensor::<f32>::ones(&[3, 4], MlxDevice::Gpu);
        let product = lhs.matmul(&rhs);
        assert_eq!(product.shape(), vec![2, 4]);
    }

    /// A Burn tensor built from four values on the default device is `[4]`.
    #[test]
    fn test_burn_backend_tensor_creation() {
        let values = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let tensor = Tensor::<Mlx, 1>::from_data(values, &MlxDevice::default());
        assert_eq!(tensor.dims(), [4]);
    }

    /// `+`, `-`, `*` and `/` on two `[2, 2]` Burn tensors all keep the shape.
    #[test]
    fn test_burn_backend_arithmetic() {
        let device = MlxDevice::default();
        let lhs = Tensor::<Mlx, 2>::from_data([[1.0f32, 2.0], [3.0, 4.0]], &device);
        let rhs = Tensor::<Mlx, 2>::from_data([[5.0f32, 6.0], [7.0, 8.0]], &device);

        // Evaluated in order: sum, difference, product, quotient.
        let results = [
            lhs.clone() + rhs.clone(),
            lhs.clone() - rhs.clone(),
            lhs.clone() * rhs.clone(),
            lhs / rhs,
        ];
        for result in &results {
            assert_eq!(result.dims(), [2, 2]);
        }
    }

    /// Burn-level `matmul` of `[2, 3]` by `[3, 2]` yields `[2, 2]`.
    #[test]
    fn test_burn_backend_matmul() {
        let device = MlxDevice::default();
        let lhs = Tensor::<Mlx, 2>::from_data([[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
        let rhs = Tensor::<Mlx, 2>::from_data([[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]], &device);
        assert_eq!(lhs.matmul(rhs).dims(), [2, 2]);
    }

    /// `relu`, `sigmoid` and `softmax` (along dim 0) keep a `[4]` input's shape.
    #[test]
    fn test_burn_backend_activations() {
        use burn_tensor::activation::{relu, sigmoid, softmax};

        let device = MlxDevice::default();
        let x = Tensor::<Mlx, 1>::from_data([-1.0f32, 0.0, 1.0, 2.0], &device);

        let outputs = [relu(x.clone()), sigmoid(x.clone()), softmax(x.clone(), 0)];
        for output in &outputs {
            assert_eq!(output.dims(), [4]);
        }
    }

    /// `avg_pool2d` on a `[1, 1, 4, 4]` input produces a `[1, 1, 2, 2]` output.
    #[test]
    fn test_avg_pool2d() {
        let input = ramp_tensor([1, 1, 4, 4]).into_primitive().tensor();
        // Arguments after the input, per `ModuleOps::avg_pool2d`:
        // kernel size, stride, padding, count_include_pad.
        let pooled = Mlx::avg_pool2d(input, [2, 2], [2, 2], [0, 0], true);
        assert_eq!(pooled.shape(), vec![1, 1, 2, 2]);
    }

    /// `max_pool2d` on a `[1, 1, 4, 4]` input produces a `[1, 1, 2, 2]` output.
    #[test]
    fn test_max_pool2d() {
        let input = ramp_tensor([1, 1, 4, 4]).into_primitive().tensor();
        // Arguments after the input, per `ModuleOps::max_pool2d`:
        // kernel size, stride, padding, dilation.
        let pooled = Mlx::max_pool2d(input, [2, 2], [2, 2], [0, 0], [1, 1]);
        assert_eq!(pooled.shape(), vec![1, 1, 2, 2]);
    }

    /// `max_pool2d_with_indices` returns an output and an index tensor that
    /// are both `[1, 1, 2, 2]` for a `[1, 1, 4, 4]` input.
    #[test]
    fn test_max_pool2d_with_indices() {
        let input = ramp_tensor([1, 1, 4, 4]).into_primitive().tensor();
        // Same argument layout as `max_pool2d`.
        let pooled = Mlx::max_pool2d_with_indices(input, [2, 2], [2, 2], [0, 0], [1, 1]);
        assert_eq!(pooled.output.shape(), vec![1, 1, 2, 2]);
        assert_eq!(pooled.indices.shape(), vec![1, 1, 2, 2]);
    }

    /// `avg_pool1d` on a `[1, 2, 8]` input produces a `[1, 2, 4]` output.
    #[test]
    fn test_avg_pool1d() {
        let input = ramp_tensor([1, 2, 8]).into_primitive().tensor();
        // Arguments after the input, per `ModuleOps::avg_pool1d`:
        // kernel size, stride, padding, count_include_pad.
        let pooled = Mlx::avg_pool1d(input, 2, 2, 0, true);
        assert_eq!(pooled.shape(), vec![1, 2, 4]);
    }

    /// `max_pool1d` on a `[1, 2, 8]` input produces a `[1, 2, 4]` output.
    #[test]
    fn test_max_pool1d() {
        let input = ramp_tensor([1, 2, 8]).into_primitive().tensor();
        // Arguments after the input, per `ModuleOps::max_pool1d`:
        // kernel size, stride, padding, dilation.
        let pooled = Mlx::max_pool1d(input, 2, 2, 0, 1);
        assert_eq!(pooled.shape(), vec![1, 2, 4]);
    }
}