use anyhow::{anyhow, Result};
use async_trait::async_trait;
use std::path::Path;
use std::process::Command;
use super::{GenOptions, InferenceEngine, LoadedModel, ModelSpec};
/// Inference engine that serves models through MLX.
///
/// MLX is only usable on macOS running on Apple Silicon. Whether it was
/// detected is decided once, when the engine is constructed.
#[derive(Debug)]
pub struct MLXEngine {
    /// Result of the host probe performed at construction time.
    mlx_available: bool,
}
impl MLXEngine {
    /// Creates an engine and probes the host once for MLX support.
    ///
    /// On macOS/aarch64 the probe spawns a `python3` subprocess, so prefer
    /// constructing one engine and reusing it.
    pub fn new() -> Self {
        Self {
            mlx_available: Self::check_mlx_availability(),
        }
    }

    /// Returns `true` only when compiled for macOS on `aarch64` *and*
    /// `python3` can import `mlx.core`; always `false` on any other target.
    fn check_mlx_availability() -> bool {
        // `cfg!` expands to a compile-time constant, so on other targets the
        // Python probe is never executed. Referencing it unconditionally
        // (rather than inside a `#[cfg]` block) keeps the helper "used" on
        // every target and avoids a `dead_code` warning off macOS.
        cfg!(all(target_os = "macos", target_arch = "aarch64"))
            && Self::check_mlx_python_available()
    }

    /// Reports the availability detected when this engine was constructed.
    pub fn is_available(&self) -> bool {
        self.mlx_available
    }

    /// Runs `python3 -c "import mlx.core; ..."` and reports whether it exited
    /// successfully. A failure to spawn `python3` counts as "unavailable".
    fn check_mlx_python_available() -> bool {
        Command::new("python3")
            .args(["-c", "import mlx.core; print('MLX available')"])
            .output()
            .map(|output| output.status.success())
            .unwrap_or(false)
    }

    /// Heuristically decides whether `spec` can be served by the MLX engine.
    ///
    /// A spec is accepted when any of these hold:
    /// - its base path has an `npz` or `mlx` extension (ASCII case-insensitive);
    /// - its lower-cased name contains one of the recognised family substrings;
    /// - its base path contains `huggingface`.
    fn is_mlx_compatible(spec: &ModelSpec) -> bool {
        const NATIVE_EXTENSIONS: [&str; 2] = ["npz", "mlx"];
        const KNOWN_FAMILIES: [&str; 4] = ["llama", "mistral", "phi", "qwen"];

        // Compare extensions case-insensitively, consistent with the
        // lower-cased name check below (e.g. `model.NPZ` is accepted).
        let has_native_extension = match spec.base_path.extension().and_then(|ext| ext.to_str()) {
            Some(ext) => NATIVE_EXTENSIONS
                .iter()
                .any(|native| ext.eq_ignore_ascii_case(native)),
            None => false,
        };
        if has_native_extension {
            return true;
        }

        // NOTE(review): plain substring matching is loose — e.g. "phi" also
        // matches unrelated names containing those letters. Confirm intent.
        let model_name = spec.name.to_lowercase();
        KNOWN_FAMILIES
            .iter()
            .any(|family| model_name.contains(*family))
            || spec.base_path.to_string_lossy().contains("huggingface")
    }
}
impl Default for MLXEngine {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl InferenceEngine for MLXEngine {
async fn load(&self, spec: &ModelSpec) -> Result<Box<dyn LoadedModel>> {
if !self.mlx_available {
return Err(anyhow!(
"MLX not available on this system. MLX requires macOS with Apple Silicon."
));
}
if !Self::is_mlx_compatible(spec) {
return Err(anyhow!(
"Model {} is not compatible with MLX engine",
spec.name
));
}
tracing::info!("Loading model {} with MLX engine", spec.name);
let model = MLXModel::new(spec).await?;
Ok(Box::new(model))
}
}
/// A loaded model handle produced by the MLX engine.
struct MLXModel {
    /// Path the model was created from (copied from the spec's `base_path`).
    model_path: std::path::PathBuf,
    /// Context length copied from the spec; currently stored but unused.
    _ctx_len: usize,
    /// Model name, used in log messages.
    name: String,
}
impl MLXModel {
    /// Builds a model handle for `spec` after checking that its base path exists.
    ///
    /// # Errors
    /// Returns an error when `spec.base_path` does not exist.
    async fn new(spec: &ModelSpec) -> Result<Self> {
        let path = &spec.base_path;
        tracing::info!("Initializing MLX model at {:?}", path);
        if path.exists() {
            Ok(Self {
                name: spec.name.clone(),
                model_path: path.clone(),
                _ctx_len: spec.ctx_len,
            })
        } else {
            Err(anyhow!("Model file not found: {:?}", path))
        }
    }

    /// Placeholder generation step: no MLX inference is performed yet.
    ///
    /// Returns a synthetic string echoing up to the first 50 characters of
    /// `prompt` (always followed by `...`) and `options.max_tokens`.
    async fn mlx_generate(&self, prompt: &str, options: &GenOptions) -> Result<String> {
        // `prompt.len()` is the size in bytes, not characters.
        tracing::debug!(
            "MLX generation for model {}: prompt length = {}",
            self.name,
            prompt.len()
        );
        let preview: String = prompt.chars().take(50).collect();
        Ok(format!(
            "MLX generated response for prompt: '{}...' (max_tokens: {})",
            preview, options.max_tokens
        ))
    }
}
#[async_trait]
impl LoadedModel for MLXModel {
    /// Generates a response for `prompt`, optionally replaying it as a stream.
    ///
    /// When `on_token` is provided, the finished response is passed to it one
    /// whitespace-separated word at a time (each followed by a single space),
    /// with a 50 ms pause between words. Original runs of whitespace are
    /// therefore not preserved in the streamed pieces. The full response
    /// string is returned in all cases.
    ///
    /// # Errors
    /// Propagates any error from the underlying generation step.
    async fn generate(
        &self,
        prompt: &str,
        opts: GenOptions,
        on_token: Option<Box<dyn FnMut(String) + Send>>,
    ) -> Result<String> {
        tracing::info!("MLX generation request for model {}", self.name);
        let response = self.mlx_generate(prompt, &opts).await?;
        if let Some(mut callback) = on_token {
            // Iterate the words lazily; there is no need to buffer them in a Vec.
            for word in response.split_whitespace() {
                callback(format!("{} ", word));
                // Async sleep so the runtime thread is not blocked between words.
                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
            }
        }
        Ok(response)
    }
}
pub mod utils {
    use super::*;

    /// Reports whether this host can run MLX (macOS on Apple Silicon with a
    /// `python3` able to import `mlx.core`). Re-probes on every call.
    pub fn is_mlx_supported() -> bool {
        MLXEngine::check_mlx_availability()
    }

    /// Returns a short human-readable summary of MLX support on this host.
    ///
    /// Currently always returns `Ok`.
    pub fn mlx_info() -> Result<String> {
        let summary = if is_mlx_supported() {
            "MLX available on Apple Silicon with Metal GPU"
        } else {
            "MLX not supported on this system"
        };
        Ok(summary.to_string())
    }

    /// Converts the model at `model_path` into MLX format at `output_path`.
    ///
    /// # Errors
    /// Not implemented yet: it logs the request and then always returns an error.
    pub async fn convert_to_mlx(model_path: &Path, output_path: &Path) -> Result<()> {
        tracing::info!("Converting {:?} to MLX format at {:?}", model_path, output_path);
        Err(anyhow!(
            "MLX conversion not yet implemented - placeholder for future development"
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a `ModelSpec` with fixed values for the fields these tests ignore.
    fn spec_at(name: &str, base_path: std::path::PathBuf) -> ModelSpec {
        ModelSpec {
            name: name.to_string(),
            base_path,
            lora_path: None,
            template: None,
            ctx_len: 2048,
            n_threads: Some(4),
        }
    }

    #[test]
    fn test_mlx_availability_check() {
        let available = MLXEngine::check_mlx_availability();
        // Off macOS/Apple Silicon the probe must report `false`. On Apple
        // Silicon the result depends on the local Python install, so it is
        // not asserted there.
        if !cfg!(all(target_os = "macos", target_arch = "aarch64")) {
            assert!(
                !available,
                "MLX should only be available on macOS with Apple Silicon"
            );
        }
    }

    #[test]
    fn test_is_available_reflects_probe_result() {
        assert!(MLXEngine { mlx_available: true }.is_available());
        assert!(!MLXEngine { mlx_available: false }.is_available());
    }

    #[test]
    fn test_mlx_compatibility_detection() {
        let temp_dir = tempdir().unwrap();

        let mlx_spec = spec_at("test-mlx", temp_dir.path().join("model.npz"));
        assert!(MLXEngine::is_mlx_compatible(&mlx_spec));

        let llama_spec = spec_at("llama-7b", temp_dir.path().join("model.bin"));
        assert!(MLXEngine::is_mlx_compatible(&llama_spec));

        // Unrecognised name with a non-MLX extension must be rejected.
        let other_spec = spec_at("gpt2", temp_dir.path().join("model.bin"));
        assert!(!MLXEngine::is_mlx_compatible(&other_spec));
    }

    #[tokio::test]
    async fn test_load_fails_when_mlx_unavailable() {
        let temp_dir = tempdir().unwrap();
        let engine = MLXEngine { mlx_available: false };
        let spec = spec_at("llama-7b", temp_dir.path().join("model.npz"));
        let err = engine
            .load(&spec)
            .await
            .err()
            .expect("load should fail when MLX is unavailable");
        assert!(err.to_string().contains("MLX not available"));
    }

    #[tokio::test]
    async fn test_load_rejects_incompatible_model() {
        let temp_dir = tempdir().unwrap();
        let engine = MLXEngine { mlx_available: true };
        let spec = spec_at("gpt2", temp_dir.path().join("model.bin"));
        let err = engine
            .load(&spec)
            .await
            .err()
            .expect("load should reject an incompatible model");
        assert!(err.to_string().contains("not compatible"));
    }

    #[tokio::test]
    async fn test_mlx_model_creation_fails_gracefully() {
        let temp_dir = tempdir().unwrap();
        let spec = spec_at("nonexistent", temp_dir.path().join("nonexistent.npz"));
        let result = MLXModel::new(&spec).await;
        assert!(result.is_err(), "Should fail when model file doesn't exist");
    }

    #[tokio::test]
    async fn test_mlx_model_creation_succeeds_for_existing_file() {
        let temp_dir = tempdir().unwrap();
        let path = temp_dir.path().join("model.npz");
        std::fs::write(&path, b"").unwrap();
        let model = MLXModel::new(&spec_at("test-mlx", path.clone()))
            .await
            .expect("an existing model file should be accepted");
        assert_eq!(model.name, "test-mlx");
        assert_eq!(model.model_path, path);
    }
}