shimmy 1.8.2

Lightweight sub-5MB Ollama alternative with native SafeTensors support. No Python dependencies, 2x faster loading. Now with GitHub Spec-Kit integration for systematic development.
Documentation
// MLX Engine for Apple Silicon GPU acceleration
// Provides native Metal performance on Apple devices

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use std::path::Path;
use std::process::Command;

use super::{GenOptions, InferenceEngine, LoadedModel, ModelSpec};

/// MLX-based inference engine for Apple Silicon
pub struct MLXEngine {
    /// Whether MLX is available on this system
    mlx_available: bool,
}

impl MLXEngine {
    pub fn new() -> Self {
        Self {
            mlx_available: Self::check_mlx_availability(),
        }
    }

    /// Check if MLX is available on the current system
    fn check_mlx_availability() -> bool {
        // Check if we're on macOS with Apple Silicon
        #[cfg(target_os = "macos")]
        {
            // Check if we're on Apple Silicon (ARM64)
            std::env::consts::ARCH == "aarch64"
        }
        #[cfg(not(target_os = "macos"))]
        {
            false
        }
    }

    /// Check if MLX hardware is supported (Apple Silicon + macOS)
    pub fn is_hardware_supported() -> bool {
        Self::check_mlx_availability()
    }

    /// Public method to check if MLX is available
    #[allow(dead_code)]
    pub fn is_available(&self) -> bool {
        self.mlx_available
    }

    /// Check if MLX Python packages are available
    pub fn check_mlx_python_available() -> bool {
        // Try to run a simple MLX command to verify installation
        Command::new("python3")
            .args(["-c", "import mlx.core; print('MLX available')"])
            .output()
            .map(|output| output.status.success())
            .unwrap_or(false)
    }

    /// Detect if a model is suitable for MLX
    fn is_mlx_compatible(spec: &ModelSpec) -> bool {
        let path_str = spec.base_path.to_string_lossy();

        // MLX typically works with:
        // - Converted MLX models (.npz files)
        // - HuggingFace models that can be converted
        // - Specific model architectures (Llama, Mistral, etc.)

        if let Some(ext) = spec.base_path.extension().and_then(|s| s.to_str()) {
            if ext == "npz" || ext == "mlx" {
                return true;
            }
        }

        // Check for known compatible model families
        let model_name = spec.name.to_lowercase();
        model_name.contains("llama")
            || model_name.contains("mistral")
            || model_name.contains("phi")
            || model_name.contains("qwen")
            || path_str.contains("huggingface")
    }
}

impl Default for MLXEngine {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl InferenceEngine for MLXEngine {
    async fn load(&self, spec: &ModelSpec) -> Result<Box<dyn LoadedModel>> {
        if !self.mlx_available {
            return Err(anyhow!(
                "MLX not available on this system. MLX requires macOS with Apple Silicon."
            ));
        }

        if !Self::is_mlx_compatible(spec) {
            return Err(anyhow!(
                "Model {} is not compatible with MLX engine",
                spec.name
            ));
        }

        tracing::info!("Loading model {} with MLX engine", spec.name);

        // Create MLX model instance
        let model = MLXModel::new(spec).await?;

        Ok(Box::new(model))
    }
}

/// MLX-loaded model instance
struct MLXModel {
    name: String,
    #[allow(dead_code)]
    model_path: std::path::PathBuf,
    _ctx_len: usize,
}

impl MLXModel {
    async fn new(spec: &ModelSpec) -> Result<Self> {
        // In a real implementation, this would:
        // 1. Load the MLX model using Python bindings or native MLX
        // 2. Set up the model for inference
        // 3. Configure memory and GPU settings

        tracing::info!("Initializing MLX model at {:?}", spec.base_path);

        // Validate model file exists
        if !spec.base_path.exists() {
            return Err(anyhow!("Model file not found: {:?}", spec.base_path));
        }

        Ok(Self {
            name: spec.name.clone(),
            model_path: spec.base_path.clone(),
            _ctx_len: spec.ctx_len,
        })
    }

    /// Generate text using MLX
    async fn mlx_generate(&self, prompt: &str, options: &GenOptions) -> Result<String> {
        // In a real implementation, this would call MLX Python bindings
        // or use a native MLX Rust interface

        tracing::debug!(
            "MLX generation for model {}: prompt length = {}",
            self.name,
            prompt.len()
        );

        // Simulate MLX generation with a placeholder
        // Real implementation would:
        // 1. Tokenize the prompt
        // 2. Run inference on Metal GPU
        // 3. Decode tokens back to text
        // 4. Handle streaming if requested

        let response = format!(
            "MLX generated response for prompt: '{}...' (max_tokens: {})",
            &prompt.chars().take(50).collect::<String>(),
            options.max_tokens
        );

        Ok(response)
    }
}

#[async_trait]
impl LoadedModel for MLXModel {
    async fn generate(
        &self,
        prompt: &str,
        opts: GenOptions,
        mut on_token: Option<Box<dyn FnMut(String) + Send>>,
    ) -> Result<String> {
        tracing::info!("MLX generation request for model {}", self.name);

        // Generate response using MLX
        let response = self.mlx_generate(prompt, &opts).await?;

        // If streaming callback provided, simulate token-by-token streaming
        if let Some(ref mut callback) = on_token {
            let words: Vec<&str> = response.split_whitespace().collect();
            for word in words {
                callback(format!("{} ", word));
                // Small delay to simulate realistic streaming
                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
            }
        }

        Ok(response)
    }
}

/// Utility functions for MLX integration
pub mod utils {
    use super::*;

    /// Check if current system supports MLX
    #[allow(dead_code)]
    pub fn is_mlx_supported() -> bool {
        MLXEngine::check_mlx_availability()
    }

    /// Get MLX system information
    #[allow(dead_code)]
    pub fn mlx_info() -> Result<String> {
        if !is_mlx_supported() {
            return Ok("MLX not supported on this system".to_string());
        }

        // In real implementation, would query MLX for:
        // - Metal GPU information
        // - Available memory
        // - MLX version
        // - Supported operations

        Ok("MLX available on Apple Silicon with Metal GPU".to_string())
    }

    /// Convert HuggingFace model to MLX format (placeholder)
    #[allow(dead_code)]
    pub async fn convert_to_mlx(model_path: &Path, output_path: &Path) -> Result<()> {
        // Real implementation would use MLX conversion tools
        tracing::info!(
            "Converting {:?} to MLX format at {:?}",
            model_path,
            output_path
        );

        // Placeholder - real conversion would:
        // 1. Load HuggingFace model
        // 2. Convert weights to MLX format
        // 3. Save as .npz or MLX native format
        // 4. Optimize for Metal GPU

        Err(anyhow!(
            "MLX conversion not yet implemented - placeholder for future development"
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_mlx_availability_check() {
        // Test should work on any platform
        let available = MLXEngine::check_mlx_availability();

        // On non-macOS or non-Apple Silicon, should be false
        #[cfg(not(target_os = "macos"))]
        assert!(
            !available,
            "MLX should not be available on non-macOS systems"
        );

        // On macOS, depends on actual MLX installation
        #[cfg(target_os = "macos")]
        {
            // Just verify the check doesn't panic
            let _ = available;
        }
    }

    #[test]
    fn test_mlx_compatibility_detection() {
        let temp_dir = tempdir().unwrap();

        // Test MLX-specific file extensions
        let mlx_spec = ModelSpec {
            name: "test-mlx".to_string(),
            base_path: temp_dir.path().join("model.npz"),
            lora_path: None,
            template: None,
            ctx_len: 2048,
            n_threads: Some(4),
        };

        assert!(MLXEngine::is_mlx_compatible(&mlx_spec));

        // Test known compatible model names
        let llama_spec = ModelSpec {
            name: "llama-7b".to_string(),
            base_path: temp_dir.path().join("model.bin"),
            lora_path: None,
            template: None,
            ctx_len: 2048,
            n_threads: Some(4),
        };

        assert!(MLXEngine::is_mlx_compatible(&llama_spec));
    }

    #[tokio::test]
    async fn test_mlx_model_creation_fails_gracefully() {
        let temp_dir = tempdir().unwrap();

        let spec = ModelSpec {
            name: "nonexistent".to_string(),
            base_path: temp_dir.path().join("nonexistent.npz"),
            lora_path: None,
            template: None,
            ctx_len: 2048,
            n_threads: Some(4),
        };

        let result = MLXModel::new(&spec).await;
        assert!(result.is_err(), "Should fail when model file doesn't exist");
    }
}