//! Metal backend for Apple platforms: configuration presets plus Metal-backed
//! text classification and generation pipelines. The Metal framework objects
//! are currently stub types and inference returns placeholder outputs.
use crate::core::traits::Tokenizer;
use crate::error::{Result, TrustformersError};
use crate::pipeline::{ClassificationOutput, GenerationOutput, Pipeline, PipelineOutput};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use trustformers_core::tensor::Tensor;
/// Metal execution backend: owns the device and command queue and, once shaders
/// are loaded, the shader library and compute pipeline state.
#[derive(Debug, Clone)]
pub struct MetalBackend {
device: MetalDevice,
command_queue: MetalCommandQueue,
compute_pipeline: Option<MetalComputePipelineState>,
library: Option<MetalLibrary>,
}
// Placeholder handles standing in for Metal framework objects.
#[derive(Debug, Clone)]
pub struct MetalDevice;
#[derive(Debug, Clone)]
pub struct MetalCommandQueue;
#[derive(Debug, Clone)]
pub struct MetalComputePipelineState;
#[derive(Debug, Clone)]
pub struct MetalLibrary;
#[derive(Debug, Clone)]
pub struct MetalBuffer;
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MetalPrecisionMode {
FP32,
FP16,
INT8,
Auto,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MetalDeviceType {
IntegratedGPU,
DiscreteGPU,
NeuralEngine,
CPU,
Auto,
}
#[derive(Debug, Clone, Copy)]
pub enum MetalOptimizationLevel {
O1,
O2,
O3,
Ofast,
}
#[derive(Debug, Clone, Copy)]
pub enum MetalMemoryStrategy {
Shared,
Private,
Managed,
Auto,
}
/// Configuration for the Metal backend and the pipelines built on top of it.
#[derive(Debug, Clone)]
pub struct MetalBackendConfig {
pub model_path: PathBuf,
pub metal_library_path: Option<PathBuf>,
pub device_type: MetalDeviceType,
pub precision_mode: MetalPrecisionMode,
pub max_batch_size: usize,
pub memory_strategy: MetalMemoryStrategy,
pub optimization_level: MetalOptimizationLevel,
pub enable_neural_engine: bool,
pub enable_mps: bool,
pub buffer_allocation_size: usize,
pub enable_profiling: bool,
pub profile_output_path: Option<PathBuf>,
pub enable_fast_math: bool,
pub enable_threadgroup_optimization: bool,
}
impl Default for MetalBackendConfig {
fn default() -> Self {
Self {
model_path: PathBuf::new(),
metal_library_path: None,
device_type: MetalDeviceType::Auto,
precision_mode: MetalPrecisionMode::Auto,
max_batch_size: 1,
memory_strategy: MetalMemoryStrategy::Auto,
optimization_level: MetalOptimizationLevel::O2,
enable_neural_engine: true,
enable_mps: true,
buffer_allocation_size: 256 * 1024 * 1024,
enable_profiling: false,
profile_output_path: None,
enable_fast_math: true,
enable_threadgroup_optimization: true,
}
}
}
impl MetalBackend {
pub fn new(config: MetalBackendConfig) -> Result<Self> {
let device = Self::create_device(config.device_type)?;
let command_queue = Self::create_command_queue(&device)?;
Ok(Self {
device,
command_queue,
compute_pipeline: None,
library: None,
})
}
fn create_device(_device_type: MetalDeviceType) -> Result<MetalDevice> {
// Placeholder: a real implementation would select a Metal device matching the
// requested device type.
Ok(MetalDevice)
}
fn create_command_queue(_device: &MetalDevice) -> Result<MetalCommandQueue> {
// Placeholder: a real implementation would create a command queue on the device.
Ok(MetalCommandQueue)
}
/// Loads a compiled Metal shader library and builds the compute pipeline state.
/// Placeholder: the path is currently ignored and stub objects are stored.
pub fn load_shaders(&mut self, _shader_path: &Path) -> Result<()> {
self.library = Some(MetalLibrary);
self.compute_pipeline = Some(MetalComputePipelineState);
Ok(())
}
/// Reports the capabilities of the selected device.
/// Placeholder: values are hard-coded rather than queried from the hardware.
pub fn get_device_capabilities(&self) -> MetalDeviceCapabilities {
MetalDeviceCapabilities {
supports_neural_engine: true,
supports_mps: true,
supports_fp16: true,
supports_int8: true,
max_threads_per_threadgroup: 1024,
max_buffer_size: 4 * 1024 * 1024 * 1024, // 4 GiB
unified_memory: true,
}
}
/// Runs inference and returns named output tensors.
/// Placeholder: the inputs are currently ignored and a fixed `[1, 512]` logits
/// tensor filled with 0.5 is returned.
pub fn run_inference(
&self,
_inputs: HashMap<String, Tensor>,
) -> Result<HashMap<String, Tensor>> {
let mut outputs = HashMap::new();
let output_shape = vec![1, 512];
let output_data: Vec<f32> = vec![0.5; output_shape.iter().product()];
let output_tensor = Tensor::from_vec(output_data, &output_shape)?;
outputs.insert("logits".to_string(), output_tensor);
Ok(outputs)
}
/// Creates a Metal buffer for a tensor. Placeholder: not yet wired into
/// `run_inference`; a stub buffer is returned.
#[allow(dead_code)]
fn create_buffers(&self, _tensor: &Tensor) -> Result<MetalBuffer> {
Ok(MetalBuffer)
}
/// Compiles the model at the given path for execution on the Metal device.
/// Placeholder: currently a no-op.
pub fn compile_model(&mut self, _model_path: &Path) -> Result<()> {
Ok(())
}
}
/// Capabilities reported for the selected Metal device.
#[derive(Debug, Clone)]
pub struct MetalDeviceCapabilities {
pub supports_neural_engine: bool,
pub supports_mps: bool,
pub supports_fp16: bool,
pub supports_int8: bool,
pub max_threads_per_threadgroup: usize,
pub max_buffer_size: usize,
pub unified_memory: bool,
}
/// Text classification pipeline that tokenizes input text and classifies it on
/// the Metal backend.
pub struct MetalTextClassificationPipeline<T: Tokenizer> {
tokenizer: T,
backend: MetalBackend,
config: MetalBackendConfig,
}
impl<T: Tokenizer + Clone> MetalTextClassificationPipeline<T> {
pub fn new(tokenizer: T, config: MetalBackendConfig) -> Result<Self> {
let mut backend = MetalBackend::new(config.clone())?;
backend.compile_model(&config.model_path)?;
Ok(Self {
tokenizer,
backend,
config,
})
}
pub fn device_capabilities(&self) -> MetalDeviceCapabilities {
self.backend.get_device_capabilities()
}
}
impl<T: Tokenizer + Clone> Pipeline for MetalTextClassificationPipeline<T> {
type Input = String;
type Output = PipelineOutput;
fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
let tokenized = self.tokenizer.encode(&input)?;
let input_ids = tokenized.input_ids;
let attention_mask = tokenized.attention_mask;
let mut inputs = HashMap::new();
inputs.insert(
"input_ids".to_string(),
Tensor::from_vec(
input_ids.iter().map(|&x| x as f32).collect(),
&[1, input_ids.len()],
)?,
);
inputs.insert(
"attention_mask".to_string(),
Tensor::from_vec(
attention_mask.iter().map(|&x| x as f32).collect(),
&[1, attention_mask.len()],
)?,
);
let outputs = self.backend.run_inference(inputs)?;
if let Some(logits_tensor) = outputs.get("logits") {
let logits = logits_tensor.data()?;
// Numerically stable softmax: subtract the max logit before exponentiating.
let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let exp_logits: Vec<f32> = logits.iter().map(|x| (x - max_logit).exp()).collect();
let sum_exp: f32 = exp_logits.iter().sum();
let probabilities: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();
let mut results = Vec::new();
for (i, &prob) in probabilities.iter().enumerate() {
results.push(ClassificationOutput {
label: format!("LABEL_{}", i),
score: prob,
});
}
results
.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
Ok(PipelineOutput::Classification(results))
} else {
Err(TrustformersError::invalid_input_simple(
"No logits output from Metal backend".to_string(),
))
}
}
}
/// Text generation pipeline that tokenizes a prompt and generates a continuation
/// on the Metal backend.
pub struct MetalTextGenerationPipeline<T: Tokenizer> {
tokenizer: T,
backend: MetalBackend,
config: MetalBackendConfig,
}
impl<T: Tokenizer + Clone> MetalTextGenerationPipeline<T> {
pub fn new(tokenizer: T, config: MetalBackendConfig) -> Result<Self> {
let mut backend = MetalBackend::new(config.clone())?;
backend.compile_model(&config.model_path)?;
Ok(Self {
tokenizer,
backend,
config,
})
}
}
impl<T: Tokenizer + Clone> Pipeline for MetalTextGenerationPipeline<T> {
type Input = String;
type Output = PipelineOutput;
fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
let tokenized = self.tokenizer.encode(&input)?;
let input_ids = tokenized.input_ids;
let mut inputs = HashMap::new();
inputs.insert(
"input_ids".to_string(),
Tensor::from_vec(
input_ids.iter().map(|&x| x as f32).collect(),
&[1, input_ids.len()],
)?,
);
let outputs = self.backend.run_inference(inputs)?;
if let Some(logits_tensor) = outputs.get("logits") {
let logits = logits_tensor.data()?;
// Greedy decoding: pick the highest-scoring token (single-step placeholder, not
// a full autoregressive generation loop).
let next_token_id = logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(index, _)| index as u32)
.unwrap_or(0);
let generated_text = self.tokenizer.decode(&[next_token_id])?;
Ok(PipelineOutput::Generation(GenerationOutput {
generated_text: input + &generated_text,
sequences: Some(vec![vec![next_token_id]]),
scores: Some(logits.clone()),
}))
} else {
Err(TrustformersError::invalid_input_simple(
"No logits output from Metal backend".to_string(),
))
}
}
}
impl MetalBackendConfig {
/// Preset for Apple Silicon Macs: FP16 on the integrated GPU with shared (unified) memory.
pub fn for_apple_silicon() -> Self {
Self {
device_type: MetalDeviceType::IntegratedGPU,
precision_mode: MetalPrecisionMode::FP16,
memory_strategy: MetalMemoryStrategy::Shared,
enable_neural_engine: true,
enable_mps: true,
optimization_level: MetalOptimizationLevel::O3,
enable_fast_math: true,
enable_threadgroup_optimization: true,
..Default::default()
}
}
/// Preset for iOS devices: FP16, single-item batches, and a smaller 128 MiB buffer pool.
pub fn for_ios() -> Self {
Self {
device_type: MetalDeviceType::IntegratedGPU,
precision_mode: MetalPrecisionMode::FP16,
memory_strategy: MetalMemoryStrategy::Shared,
enable_neural_engine: true,
enable_mps: true,
optimization_level: MetalOptimizationLevel::O2,
max_batch_size: 1,
buffer_allocation_size: 128 * 1024 * 1024,
enable_fast_math: true,
..Default::default()
}
}
/// Preset for Intel Macs with discrete GPUs: FP32, private memory, no Neural Engine or MPS.
pub fn for_intel_mac() -> Self {
Self {
device_type: MetalDeviceType::DiscreteGPU,
precision_mode: MetalPrecisionMode::FP32,
memory_strategy: MetalMemoryStrategy::Private,
enable_neural_engine: false,
enable_mps: false,
optimization_level: MetalOptimizationLevel::O2,
enable_fast_math: false,
..Default::default()
}
}
/// Preset for maximum throughput: INT8 precision, `Ofast` optimization, and fast math enabled.
pub fn for_maximum_performance() -> Self {
Self {
device_type: MetalDeviceType::IntegratedGPU,
precision_mode: MetalPrecisionMode::INT8,
memory_strategy: MetalMemoryStrategy::Shared,
enable_neural_engine: true,
enable_mps: true,
optimization_level: MetalOptimizationLevel::Ofast,
enable_fast_math: true,
enable_threadgroup_optimization: true,
..Default::default()
}
}
}
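/// Builds a Metal-backed text classification pipeline, falling back to the
/// Apple Silicon preset when no config is supplied.
///
/// Minimal usage sketch (hypothetical: `MyTokenizer` stands in for any type
/// implementing `Tokenizer + Clone`):
///
/// ```ignore
/// let pipeline = create_metal_text_classification_pipeline(MyTokenizer::new(), None)?;
/// if let PipelineOutput::Classification(labels) = pipeline.__call__("Great movie!".to_string())? {
///     println!("top label: {} ({:.3})", labels[0].label, labels[0].score);
/// }
/// ```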
pub fn create_metal_text_classification_pipeline<T: Tokenizer + Clone>(
tokenizer: T,
config: Option<MetalBackendConfig>,
) -> Result<MetalTextClassificationPipeline<T>> {
let config = config.unwrap_or_else(MetalBackendConfig::for_apple_silicon);
MetalTextClassificationPipeline::new(tokenizer, config)
}
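/// Builds a Metal-backed text generation pipeline, falling back to the Apple
/// Silicon preset when no config is supplied.
///
/// Minimal usage sketch (hypothetical `MyTokenizer`, as above):
///
/// ```ignore
/// let pipeline = create_metal_text_generation_pipeline(MyTokenizer::new(), None)?;
/// if let PipelineOutput::Generation(out) = pipeline.__call__("Once upon a".to_string())? {
///     println!("{}", out.generated_text);
/// }
/// ```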
pub fn create_metal_text_generation_pipeline<T: Tokenizer + Clone>(
tokenizer: T,
config: Option<MetalBackendConfig>,
) -> Result<MetalTextGenerationPipeline<T>> {
let config = config.unwrap_or_else(MetalBackendConfig::for_apple_silicon);
MetalTextGenerationPipeline::new(tokenizer, config)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_metal_backend_creation() {
let config = MetalBackendConfig::for_apple_silicon();
let backend = MetalBackend::new(config);
assert!(backend.is_ok());
}
#[test]
fn test_device_capabilities() {
let config = MetalBackendConfig::for_apple_silicon();
let backend = MetalBackend::new(config).expect("backend creation failed");
let capabilities = backend.get_device_capabilities();
assert!(capabilities.supports_neural_engine);
assert!(capabilities.supports_mps);
assert!(capabilities.supports_fp16);
assert!(capabilities.unified_memory);
}
#[test]
fn test_configuration_presets() {
let apple_silicon_config = MetalBackendConfig::for_apple_silicon();
assert_eq!(
apple_silicon_config.device_type,
MetalDeviceType::IntegratedGPU
);
assert_eq!(
apple_silicon_config.precision_mode,
MetalPrecisionMode::FP16
);
assert!(apple_silicon_config.enable_neural_engine);
let ios_config = MetalBackendConfig::for_ios();
assert_eq!(ios_config.max_batch_size, 1);
assert_eq!(ios_config.buffer_allocation_size, 128 * 1024 * 1024);
let intel_config = MetalBackendConfig::for_intel_mac();
assert_eq!(intel_config.device_type, MetalDeviceType::DiscreteGPU);
assert_eq!(intel_config.precision_mode, MetalPrecisionMode::FP32);
assert!(!intel_config.enable_neural_engine);
}
#[test]
fn test_default_config_values() {
let cfg = MetalBackendConfig::default();
assert_eq!(cfg.max_batch_size, 1);
assert!(cfg.enable_neural_engine);
assert!(cfg.enable_mps);
assert!(cfg.enable_fast_math);
assert!(cfg.enable_threadgroup_optimization);
assert!(!cfg.enable_profiling);
assert!(cfg.profile_output_path.is_none());
assert!(cfg.metal_library_path.is_none());
}
#[test]
fn test_default_config_buffer_allocation_256mb() {
let cfg = MetalBackendConfig::default();
assert_eq!(cfg.buffer_allocation_size, 256 * 1024 * 1024);
}
#[test]
fn test_device_type_variants_are_distinct() {
assert_ne!(
MetalDeviceType::IntegratedGPU as i32,
MetalDeviceType::DiscreteGPU as i32
);
assert_ne!(
MetalDeviceType::NeuralEngine as i32,
MetalDeviceType::CPU as i32
);
assert_ne!(MetalDeviceType::Auto as i32, MetalDeviceType::CPU as i32);
}
#[test]
fn test_precision_mode_variants() {
let fp32 = MetalPrecisionMode::FP32;
let fp16 = MetalPrecisionMode::FP16;
let int8 = MetalPrecisionMode::INT8;
let auto = MetalPrecisionMode::Auto;
assert_ne!(fp32, fp16);
assert_ne!(fp16, int8);
assert_ne!(int8, auto);
}
#[test]
fn test_memory_strategy_variants() {
let strategies = [
MetalMemoryStrategy::Shared,
MetalMemoryStrategy::Private,
MetalMemoryStrategy::Managed,
MetalMemoryStrategy::Auto,
];
assert_eq!(strategies.len(), 4);
}
#[test]
fn test_optimization_level_variants() {
let levels = [
MetalOptimizationLevel::O1,
MetalOptimizationLevel::O2,
MetalOptimizationLevel::O3,
MetalOptimizationLevel::Ofast,
];
assert_eq!(levels.len(), 4);
}
#[test]
fn test_run_inference_returns_logits() {
let config = MetalBackendConfig::for_apple_silicon();
let backend = MetalBackend::new(config).expect("backend creation failed");
let inputs = HashMap::new();
let outputs = backend.run_inference(inputs).expect("inference failed");
assert!(
outputs.contains_key("logits"),
"output must contain 'logits' key"
);
}
#[test]
fn test_run_inference_output_shape_product() {
let config = MetalBackendConfig::for_apple_silicon();
let backend = MetalBackend::new(config).expect("backend creation failed");
let inputs = HashMap::new();
let outputs = backend.run_inference(inputs).expect("inference failed");
let logits = outputs.get("logits").expect("logits key must be present");
let data = logits.data().expect("tensor data must be accessible");
assert_eq!(
data.len(),
512,
"output shape should be [1, 512] → 512 elements"
);
}
#[test]
fn test_load_shaders_sets_pipeline_state() {
let config = MetalBackendConfig::for_apple_silicon();
let mut backend = MetalBackend::new(config).expect("backend creation failed");
let dummy_path = std::path::Path::new("/tmp/dummy.metallib");
let result = backend.load_shaders(dummy_path);
assert!(result.is_ok(), "load_shaders should succeed (mock impl)");
assert!(
backend.compute_pipeline.is_some(),
"compute_pipeline must be set after load_shaders"
);
assert!(
backend.library.is_some(),
"library must be set after load_shaders"
);
}
#[test]
fn test_compile_model_ok() {
let config = MetalBackendConfig::for_apple_silicon();
let mut backend = MetalBackend::new(config).expect("backend creation failed");
let dummy_path = std::path::Path::new("/tmp/model.onnx");
let result = backend.compile_model(dummy_path);
assert!(result.is_ok(), "compile_model should succeed (mock impl)");
}
#[test]
fn test_maximum_performance_preset() {
let cfg = MetalBackendConfig::for_maximum_performance();
assert_eq!(cfg.precision_mode, MetalPrecisionMode::INT8);
assert!(cfg.enable_neural_engine);
assert!(cfg.enable_mps);
assert!(cfg.enable_fast_math);
assert!(cfg.enable_threadgroup_optimization);
}
#[test]
fn test_device_capabilities_max_threads_positive() {
let config = MetalBackendConfig::default();
let backend = MetalBackend::new(config).expect("backend creation failed");
let cap = backend.get_device_capabilities();
assert!(cap.max_threads_per_threadgroup > 0);
}
#[test]
fn test_device_capabilities_max_buffer_size_positive() {
let config = MetalBackendConfig::default();
let backend = MetalBackend::new(config).expect("backend creation failed");
let cap = backend.get_device_capabilities();
assert!(cap.max_buffer_size > 0);
}
#[test]
fn test_device_capabilities_int8_support() {
let config = MetalBackendConfig::for_apple_silicon();
let backend = MetalBackend::new(config).expect("backend creation failed");
let cap = backend.get_device_capabilities();
assert!(cap.supports_int8);
}
#[test]
fn test_ios_config_128mb_buffer() {
let cfg = MetalBackendConfig::for_ios();
assert_eq!(cfg.buffer_allocation_size, 128 * 1024 * 1024);
}
#[test]
fn test_ios_config_o2_optimization() {
let cfg = MetalBackendConfig::for_ios();
assert!(matches!(cfg.optimization_level, MetalOptimizationLevel::O2));
}
#[test]
fn test_intel_mac_no_mps() {
let cfg = MetalBackendConfig::for_intel_mac();
assert!(!cfg.enable_mps);
assert!(!cfg.enable_fast_math);
}
#[test]
fn test_intel_mac_private_memory() {
let cfg = MetalBackendConfig::for_intel_mac();
assert!(matches!(cfg.memory_strategy, MetalMemoryStrategy::Private));
}
#[test]
fn test_profiling_disabled_by_default_in_presets() {
for cfg in [
MetalBackendConfig::for_apple_silicon(),
MetalBackendConfig::for_ios(),
MetalBackendConfig::for_intel_mac(),
MetalBackendConfig::for_maximum_performance(),
] {
assert!(
!cfg.enable_profiling,
"profiling should be disabled by default"
);
}
}
#[test]
fn test_inference_output_values_are_finite() {
let config = MetalBackendConfig::default();
let backend = MetalBackend::new(config).expect("backend ok");
let outputs = backend.run_inference(HashMap::new()).expect("inference ok");
let logits = outputs.get("logits").expect("logits");
let data = logits.data().expect("data");
for v in &data {
assert!(v.is_finite(), "every logit output must be finite");
}
}
}