#[cfg(not(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx))))]
/// Fallback entry point for targets where the MLX-backed repro is compiled out.
///
/// Reports that the tool is unsupported here and exits with status 1.
fn main() {
    // Same wording the real binary would need users to see on unsupported hosts.
    const UNSUPPORTED: &str = "mlx_safetensors_repro requires macOS Apple Silicon";
    eprintln!("{UNSUPPORTED}");
    std::process::exit(1)
}
/// Loads a `.safetensors` file through MLX and prints a short summary of its contents.
///
/// Usage: `mlx_safetensors_repro <path-to-safetensors> [cpu|gpu]`. When the device
/// argument is omitted, the `CAR_MLX_DEVICE` environment variable is consulted instead.
/// An unrecognized or unavailable device is reported on stderr and the default device
/// is kept.
///
/// Exit codes: 2 when the path argument is missing, 3 when loading fails.
#[cfg(all(target_os = "macos", target_arch = "aarch64", not(car_skip_mlx)))]
fn main() {
    use std::path::PathBuf;

    use mlx_rs::Array;

    let mut args = std::env::args().skip(1);
    let path = match args.next() {
        Some(path) => PathBuf::from(path),
        None => {
            eprintln!("usage: mlx_safetensors_repro <path-to-safetensors> [cpu|gpu]");
            std::process::exit(2);
        }
    };

    // The explicit CLI argument takes precedence over the environment variable.
    let device = args.next().or_else(|| std::env::var("CAR_MLX_DEVICE").ok());
    match device.as_deref() {
        Some("cpu") => mlx_rs::Device::set_default(&mlx_rs::Device::cpu()),
        #[cfg(feature = "mlx-metal")]
        Some("gpu") => mlx_rs::Device::set_default(&mlx_rs::Device::gpu()),
        // Without the Metal feature a GPU request cannot be honored; say so instead of
        // silently running on the default device.
        #[cfg(not(feature = "mlx-metal"))]
        Some("gpu") => eprintln!(
            "warning: gpu device requested but the `mlx-metal` feature is disabled; \
             using the default device"
        ),
        None | Some("") => {}
        Some(other) => eprintln!(
            "warning: unknown device {other:?} (expected `cpu` or `gpu`); \
             using the default device"
        ),
    }

    println!("loading {}", path.display());
    let tensors = match Array::load_safetensors(&path) {
        Ok(tensors) => tensors,
        Err(err) => {
            eprintln!("load failed: {err}");
            std::process::exit(3);
        }
    };
    println!("loaded {} tensors", tensors.len());

    // Sort by name so the preview is deterministic regardless of map iteration order.
    // Borrow the entries directly: no key clones and no re-lookup that cannot fail.
    let mut entries: Vec<_> = tensors.iter().collect();
    entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    for (name, tensor) in entries.into_iter().take(10) {
        println!("{name} {:?}", tensor.shape());
    }
}