use super::mlx_types::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::error::{CoreError, Result};
use trustformers_core::Tensor;
use trustformers_core::TrustformersError;
impl MlxEngine {
/// Creates a new MLX engine for the given configuration.
///
/// Detects the capabilities of the configured Apple Silicon device and
/// validates the configuration against them before constructing the engine.
///
/// # Errors
/// Returns a configuration error when the requested resources exceed what
/// the detected device provides.
pub fn new(config: MlxConfig) -> Result<Self> {
    let capabilities = Self::detect_device_capabilities(&config.device)?;
    Self::validate_config(&config, &capabilities)?;
    // Clone the memory config before `config` is moved into the struct.
    let pool = UnifiedMemoryPool::new(config.memory_config.clone());
    Ok(Self {
        config,
        device_capabilities: capabilities,
        performance_metrics: MlxPerformanceMetrics::default(),
        compiled_models: HashMap::new(),
        memory_pool: pool,
    })
}
/// Returns the static hardware capability profile for a known Apple Silicon
/// device.
///
/// Values come from a hard-coded lookup table, one tuple per chip:
/// (performance cores, efficiency cores, GPU cores, Neural Engine TOPS,
/// unified memory GB, memory bandwidth GB/s, AMX support).
/// NOTE(review): memory/bandwidth figures appear to be base-model
/// configurations — confirm against the intended SKUs.
pub fn detect_device_capabilities(device: &AppleSiliconDevice) -> Result<DeviceCapabilities> {
    let (perf_cores, eff_cores, gpu_cores, ne_tops, memory_gb, bandwidth_gbps, amx) =
        match device {
            AppleSiliconDevice::M1 => (4, 4, 7, 15.8, 8.0, 68.0, true),
            AppleSiliconDevice::M1Pro => (8, 2, 14, 15.8, 16.0, 200.0, true),
            AppleSiliconDevice::M1Max => (8, 2, 24, 15.8, 32.0, 400.0, true),
            AppleSiliconDevice::M1Ultra => (16, 4, 48, 31.6, 64.0, 800.0, true),
            AppleSiliconDevice::M2 => (4, 4, 8, 15.8, 8.0, 100.0, true),
            AppleSiliconDevice::M2Pro => (8, 4, 16, 15.8, 16.0, 200.0, true),
            AppleSiliconDevice::M2Max => (8, 4, 30, 15.8, 32.0, 400.0, true),
            AppleSiliconDevice::M2Ultra => (16, 8, 60, 31.6, 64.0, 800.0, true),
            AppleSiliconDevice::M3 => (4, 4, 10, 18.0, 8.0, 100.0, true),
            AppleSiliconDevice::M3Pro => (6, 6, 18, 18.0, 18.0, 150.0, true),
            AppleSiliconDevice::M3Max => (8, 4, 30, 18.0, 36.0, 300.0, true),
            AppleSiliconDevice::M4 => (4, 6, 10, 38.0, 16.0, 120.0, true),
            AppleSiliconDevice::M4Pro => (8, 4, 20, 38.0, 24.0, 273.0, true),
            AppleSiliconDevice::M4Max => (10, 4, 32, 38.0, 36.0, 546.0, true),
            // A-series phone chips lack AMX support.
            AppleSiliconDevice::A17Pro => (2, 4, 6, 35.0, 8.0, 64.0, false),
            AppleSiliconDevice::A18 => (2, 4, 5, 35.0, 8.0, 68.0, false),
            AppleSiliconDevice::A18Pro => (2, 4, 6, 35.0, 8.0, 68.0, false),
        };
    Ok(DeviceCapabilities {
        performance_cores: perf_cores,
        efficiency_cores: eff_cores,
        gpu_cores,
        neural_engine_tops: ne_tops,
        unified_memory_gb: memory_gb,
        memory_bandwidth_gbps: bandwidth_gbps,
        amx_support: amx,
        // Toolchain versions are pinned constants here, not probed at runtime.
        metal_version: "3.2".to_string(),
        mlx_version: "0.15.0".to_string(),
    })
}
/// Rejects configurations that request more resources (unified memory, CPU
/// performance cores, GPU cores) than the detected device provides.
fn validate_config(config: &MlxConfig, capabilities: &DeviceCapabilities) -> Result<()> {
    // Shared error constructor: all violations are reported as config errors
    // attributed to this function.
    let reject = |msg: &str| -> Result<()> {
        Err(TrustformersError::config_error(msg, "validate_config").into())
    };
    if config.memory_config.max_memory_gb > capabilities.unified_memory_gb {
        return reject("Requested memory exceeds device capability");
    }
    if config.compute_units.cpu_config.performance_cores > capabilities.performance_cores {
        return reject("Requested performance cores exceed device capability");
    }
    if config.compute_units.gpu_config.gpu_cores > capabilities.gpu_cores {
        return reject("Requested GPU cores exceed device capability");
    }
    Ok(())
}
/// Optimizes and registers a model graph under `model_id`.
///
/// The raw graph is a list of (operation, input node indices, parameters)
/// triples. Compilation wall time is recorded in the engine's performance
/// metrics; the id the model was registered under is returned.
pub fn compile_model(
    &mut self,
    model_id: String,
    model_graph: Vec<(MlxOperation, Vec<usize>, HashMap<String, f32>)>,
    optimization_level: OptimizationLevel,
) -> Result<String> {
    let started_at = std::time::Instant::now();
    // Derive everything the compiled-model record needs.
    let optimized_graph = self.optimize_graph(model_graph)?;
    let memory_requirements = self.calculate_memory_requirements(&optimized_graph)?;
    let performance_profile = self.create_performance_profile(&optimized_graph)?;
    let entry = CompiledMlxModel {
        model_id: model_id.clone(),
        compilation_metadata: CompilationMetadata {
            compilation_time: std::time::SystemTime::now(),
            mlx_version: self.device_capabilities.mlx_version.clone(),
            optimization_level,
            target_device: self.config.device,
            compilation_options: HashMap::new(),
        },
        optimized_graph,
        memory_requirements,
        performance_profile,
    };
    self.compiled_models.insert(model_id.clone(), entry);
    self.performance_metrics.compilation_time_ms = started_at.elapsed().as_millis() as f64;
    Ok(model_id)
}
/// Runs a previously compiled model against `inputs` and refreshes the
/// engine's performance metrics with the measured execution time.
///
/// # Errors
/// Fails when `model_id` has not been compiled, or when execution itself
/// fails.
pub fn execute_model(&mut self, model_id: &str, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    // Clone what we need out of the compiled-model entry so the immutable
    // borrow of `self.compiled_models` ends before the `&mut self` calls.
    let (graph, profile) = {
        let model = self
            .compiled_models
            .get(model_id)
            .ok_or_else(|| TrustformersError::runtime_error("Model not found".to_string()))?;
        (model.optimized_graph.clone(), model.performance_profile.clone())
    };
    let started_at = std::time::Instant::now();
    self.allocate_execution_memory(&graph)?;
    let outputs = self.execute_optimized_graph(&graph, inputs)?;
    self.update_performance_metrics(started_at.elapsed(), &profile);
    Ok(outputs)
}
/// Lowers a raw (operation, input-node-indices, parameters) list into an
/// `OptimizedGraph` with assigned compute units, then runs the enabled
/// graph-optimization passes and derives an execution order and memory plan.
fn optimize_graph(
    &self,
    model_graph: Vec<(MlxOperation, Vec<usize>, HashMap<String, f32>)>,
) -> Result<OptimizedGraph> {
    let mut nodes = Vec::new();
    let mut edges = Vec::new();
    // Monotonic counter used to mint a fresh tensor id per input/output slot.
    let mut tensor_id_counter = 0u64;
    for (i, (operation, inputs, parameters)) in model_graph.iter().enumerate() {
        // NOTE(review): each input slot gets a brand-new tensor id, so a
        // producer's output id never matches its consumer's input id —
        // confirm this is the intended wiring model.
        let input_tensors: Vec<TensorId> = inputs
            .iter()
            .map(|_| {
                tensor_id_counter += 1;
                tensor_id_counter - 1
            })
            .collect();
        // Exactly one freshly minted output tensor per node.
        let output_tensors = vec![{
            tensor_id_counter += 1;
            tensor_id_counter - 1
        }];
        let compute_unit = self.assign_compute_unit(operation);
        nodes.push(GraphNode {
            id: i,
            operation: *operation,
            inputs: input_tensors.clone(),
            outputs: output_tensors.clone(),
            parameters: parameters.clone(),
            compute_unit,
        });
        // The j-th declared input index is interpreted as the source node id
        // for the corresponding edge.
        for (j, &input_tensor) in input_tensors.iter().enumerate() {
            if let Some(source_node) = inputs.get(j) {
                edges.push(GraphEdge {
                    source: *source_node,
                    destination: i,
                    tensor_id: input_tensor,
                    data_type: self.config.precision_config.default_precision,
                });
            }
        }
    }
    // Optional optimization passes, gated by configuration flags.
    if self.config.graph_optimization.operator_fusion {
        self.apply_operator_fusion(&mut nodes, &mut edges)?;
    }
    if self.config.graph_optimization.dead_code_elimination {
        self.apply_dead_code_elimination(&mut nodes, &mut edges)?;
    }
    let execution_order = self.create_execution_order(&nodes, &edges)?;
    let memory_layout = self.create_memory_layout(&nodes, &edges)?;
    Ok(OptimizedGraph {
        nodes,
        edges,
        execution_order,
        memory_layout,
    })
}
/// Picks a compute unit for an operation based on the configured workload
/// distribution strategy.
///
/// Element-wise and reduction work always goes to the CPU, regardless of
/// strategy; everything else follows the strategy's preference, falling back
/// to the CPU.
fn assign_compute_unit(&self, operation: &MlxOperation) -> AssignedComputeUnit {
    let strategy = &self.config.compute_units.distribution_strategy;
    match operation {
        // Cheap per-element work stays on the CPU under every strategy.
        MlxOperation::ElementWise | MlxOperation::Reduction => AssignedComputeUnit::CPU,
        MlxOperation::MatMul | MlxOperation::Convolution
            if matches!(strategy, WorkloadDistributionStrategy::NeuralEngineFirst) =>
        {
            AssignedComputeUnit::NeuralEngine
        },
        MlxOperation::MatMul | MlxOperation::Attention
            if matches!(strategy, WorkloadDistributionStrategy::GpuFirst) =>
        {
            AssignedComputeUnit::GPU
        },
        // Heterogeneous strategy spreads the remaining operations across units.
        _ if matches!(strategy, WorkloadDistributionStrategy::Heterogeneous) => {
            AssignedComputeUnit::Hybrid
        },
        _ => AssignedComputeUnit::CPU,
    }
}
fn apply_operator_fusion(
&self,
nodes: &mut Vec<GraphNode>,
edges: &mut Vec<GraphEdge>,
) -> Result<()> {
let mut fusion_candidates = Vec::new();
for i in 0..nodes.len() - 1 {
if matches!(
nodes[i].operation,
MlxOperation::ElementWise | MlxOperation::Activation
) && matches!(
nodes[i + 1].operation,
MlxOperation::ElementWise | MlxOperation::Activation
) {
fusion_candidates.push((i, i + 1));
}
}
for (first, second) in fusion_candidates.iter().rev() {
if *second < nodes.len() {
let mut merged_params = nodes[*first].parameters.clone();
merged_params.extend(nodes[*second].parameters.clone());
nodes[*first].parameters = merged_params;
nodes[*first].outputs = nodes[*second].outputs.clone();
nodes.remove(*second);
for edge in edges.iter_mut() {
if edge.source > *second {
edge.source -= 1;
}
if edge.destination > *second {
edge.destination -= 1;
} else if edge.destination == *second {
edge.destination = *first;
}
}
}
}
Ok(())
}
/// Removes nodes that cannot reach any output node, then compacts node ids
/// and edge endpoints over the surviving set.
///
/// NOTE(review): `node.id` is assumed to equal the node's index in `nodes`
/// (true for graphs built by `optimize_graph`) — confirm for other callers.
fn apply_dead_code_elimination(
    &self,
    nodes: &mut Vec<GraphNode>,
    edges: &mut Vec<GraphEdge>,
) -> Result<()> {
    // Seed reachability with sink nodes (no outputs) and designated output
    // operations.
    let mut reachable = vec![false; nodes.len()];
    for node in &*nodes {
        if node.outputs.is_empty() || self.is_output_node(node) {
            reachable[node.id] = true;
        }
    }
    // Propagate reachability backwards along edges to a fixed point.
    let mut changed = true;
    while changed {
        changed = false;
        for edge in &*edges {
            if reachable[edge.destination] && !reachable[edge.source] {
                reachable[edge.source] = true;
                changed = true;
            }
        }
    }
    // Build a dense remapping old-id -> new-id over the survivors.
    let mut id_mapping = HashMap::new();
    let mut new_id = 0;
    for (old_id, &is_reachable) in reachable.iter().enumerate() {
        if is_reachable {
            id_mapping.insert(old_id, new_id);
            new_id += 1;
        }
    }
    // Keep only reachable nodes, rewriting their ids to the dense range.
    let mut filtered_nodes = Vec::new();
    for (i, node) in nodes.iter().enumerate() {
        if reachable[i] {
            let mut updated_node = node.clone();
            updated_node.id = *id_mapping.get(&i).expect("Operation failed");
            filtered_nodes.push(updated_node);
        }
    }
    // Keep only edges whose endpoints both survived, remapped to new ids.
    let mut filtered_edges = Vec::new();
    for edge in &*edges {
        if reachable[edge.source] && reachable[edge.destination] {
            let mut updated_edge = edge.clone();
            updated_edge.source = *id_mapping.get(&edge.source).expect("Operation failed");
            updated_edge.destination =
                *id_mapping.get(&edge.destination).expect("Operation failed");
            filtered_edges.push(updated_edge);
        }
    }
    *nodes = filtered_nodes;
    *edges = filtered_edges;
    Ok(())
}
/// Treats Softmax and LayerNorm nodes as graph outputs for dead-code
/// elimination and result collection.
fn is_output_node(&self, node: &GraphNode) -> bool {
    match node.operation {
        MlxOperation::Softmax | MlxOperation::LayerNorm => true,
        _ => false,
    }
}
/// Produces a topological ordering of the graph via Kahn's algorithm.
///
/// # Errors
/// Returns a runtime error when the graph contains a cycle (not every node
/// can be scheduled).
fn create_execution_order(
    &self,
    nodes: &[GraphNode],
    edges: &[GraphEdge],
) -> Result<Vec<usize>> {
    // Count incoming edges per node and build the adjacency list.
    let mut in_degree = vec![0usize; nodes.len()];
    let mut adj_list: HashMap<usize, Vec<usize>> = HashMap::new();
    for edge in edges {
        in_degree[edge.destination] += 1;
        adj_list.entry(edge.source).or_default().push(edge.destination);
    }
    // Seed the queue with all roots (in-degree zero). The original line here
    // contained a garbled `°ree` token that did not compile; this is the
    // intended destructuring of the iterator item.
    let mut queue = std::collections::VecDeque::new();
    let mut execution_order = Vec::with_capacity(nodes.len());
    for (i, &degree) in in_degree.iter().enumerate() {
        if degree == 0 {
            queue.push_back(i);
        }
    }
    while let Some(node_id) = queue.pop_front() {
        execution_order.push(node_id);
        if let Some(neighbors) = adj_list.get(&node_id) {
            for &neighbor in neighbors {
                in_degree[neighbor] -= 1;
                if in_degree[neighbor] == 0 {
                    queue.push_back(neighbor);
                }
            }
        }
    }
    // If not every node was scheduled, a cycle kept some in-degrees nonzero.
    if execution_order.len() != nodes.len() {
        return Err(
            TrustformersError::runtime_error("Graph contains cycles".to_string()).into(),
        );
    }
    Ok(execution_order)
}
/// Lays out a 64-byte-aligned, linearly packed memory plan for every tensor
/// the graph's nodes reference.
///
/// Each tensor is allocated once, the first time it is seen; the identical
/// input/output allocation loops of the original are deduplicated by
/// chaining the two id lists (inputs first, then outputs, preserving the
/// original visit order and therefore the original offsets).
fn create_memory_layout(
    &self,
    nodes: &[GraphNode],
    _edges: &[GraphEdge],
) -> Result<MemoryLayout> {
    let mut tensor_allocations = HashMap::new();
    let mut current_offset = 0;
    let mut alignment_requirements = HashMap::new();
    for node in nodes {
        for &tensor_id in node.inputs.iter().chain(node.outputs.iter()) {
            // Allocate only on first sight of a tensor id.
            if let std::collections::hash_map::Entry::Vacant(e) =
                tensor_allocations.entry(tensor_id)
            {
                // Size is a per-operation heuristic, not an exact shape.
                let size_bytes = self.estimate_tensor_size(&node.operation);
                let alignment = 64;
                // Round the running offset up to the next alignment boundary.
                current_offset = (current_offset + alignment - 1) & !(alignment - 1);
                e.insert(MemoryAllocation {
                    offset: current_offset,
                    size_bytes,
                    memory_type: MemoryType::SharedMemory,
                    lifetime: MemoryLifetime::Temporary,
                });
                alignment_requirements.insert(tensor_id, alignment);
                current_offset += size_bytes;
            }
        }
    }
    Ok(MemoryLayout {
        tensor_allocations,
        total_memory_bytes: current_offset,
        alignment_requirements,
    })
}
/// Rough per-operation tensor footprint in bytes: a fixed per-operation
/// element count scaled by the configured default precision's element width.
fn estimate_tensor_size(&self, operation: &MlxOperation) -> usize {
    let element_size = match self.config.precision_config.default_precision {
        MlxPrecision::Float16 | MlxPrecision::BFloat16 => 2,
        // Int4 is accounted with byte granularity here.
        MlxPrecision::Int8 | MlxPrecision::Int4 => 1,
        MlxPrecision::Float32 => 4,
        _ => 4,
    };
    let element_count = match operation {
        MlxOperation::MatMul => 1024 * 1024,
        MlxOperation::Convolution => 512 * 512,
        MlxOperation::Attention => 2048 * 768,
        _ => 256 * 256,
    };
    element_count * element_size
}
/// Derives memory requirements from the graph's planned layout: the layout
/// size is the minimum, with 1.5x recommended and 2x peak headroom.
fn calculate_memory_requirements(&self, graph: &OptimizedGraph) -> Result<MemoryRequirements> {
    const BYTES_PER_GB: f32 = 1024.0 * 1024.0 * 1024.0;
    let base_memory_gb = graph.memory_layout.total_memory_bytes as f32 / BYTES_PER_GB;
    Ok(MemoryRequirements {
        minimum_memory_gb: base_memory_gb,
        recommended_memory_gb: base_memory_gb * 1.5,
        peak_memory_gb: base_memory_gb * 2.0,
        // Fixed assumption about allocator fragmentation overhead.
        fragmentation_factor: 1.2,
    })
}
/// Estimates latency and throughput for a compiled graph from heuristic
/// per-operation latencies, plus fixed placeholder power/thermal/accuracy
/// figures.
///
/// The original kept a `total_ops` counter that was incremented but never
/// read (silenced with a misplaced `#[allow(dead_code)]` on a `let`); it has
/// been removed.
fn create_performance_profile(
    &self,
    graph: &OptimizedGraph,
) -> Result<ModelPerformanceProfile> {
    let mut estimated_latency_ms = 0.0;
    for node in &graph.nodes {
        // Heuristic per-operation latency (milliseconds), keyed on the
        // operation and the compute unit it was assigned to.
        estimated_latency_ms += match (node.operation, node.compute_unit) {
            (MlxOperation::MatMul, AssignedComputeUnit::NeuralEngine) => 0.5,
            (MlxOperation::MatMul, AssignedComputeUnit::GPU) => 1.2,
            (MlxOperation::MatMul, AssignedComputeUnit::CPU) => 5.0,
            (MlxOperation::Convolution, AssignedComputeUnit::NeuralEngine) => 0.8,
            (MlxOperation::Convolution, AssignedComputeUnit::GPU) => 2.0,
            (MlxOperation::Attention, AssignedComputeUnit::GPU) => 3.0,
            (MlxOperation::Attention, AssignedComputeUnit::CPU) => 10.0,
            _ => 1.0,
        };
    }
    // Throughput is the inverse of total latency; guard the empty graph.
    let expected_throughput =
        if estimated_latency_ms > 0.0 { 1000.0 / estimated_latency_ms } else { 0.0 };
    let mut accuracy_metrics = HashMap::new();
    accuracy_metrics.insert("estimated_accuracy".to_string(), 0.95);
    Ok(ModelPerformanceProfile {
        expected_latency_ms: estimated_latency_ms,
        expected_throughput,
        // Fixed placeholder figures pending real power/thermal telemetry.
        power_consumption_watts: 15.0,
        thermal_impact: 0.3,
        accuracy_metrics,
    })
}
/// Registers every tensor allocation from the graph's memory layout with the
/// engine's unified memory pool.
fn allocate_execution_memory(&mut self, graph: &OptimizedGraph) -> Result<()> {
    let layout = &graph.memory_layout;
    for (&tensor_id, allocation) in layout.tensor_allocations.iter() {
        self.memory_pool.allocate(tensor_id, allocation.clone())?;
    }
    Ok(())
}
/// Executes the graph's nodes in topological order and collects the results.
///
/// Inputs are seeded into the tensor table under ids 0..inputs.len(). After
/// execution, tensors produced by designated output nodes (see
/// `is_output_node`) are returned; when there are none, the last tensor
/// produced is returned instead. The original fallback used
/// `HashMap::values().last()`, whose iteration order is unspecified, so the
/// fallback result was nondeterministic; the last-produced tensor id is now
/// tracked explicitly.
fn execute_optimized_graph(
    &self,
    graph: &OptimizedGraph,
    inputs: &[Tensor],
) -> Result<Vec<Tensor>> {
    let mut tensor_values: HashMap<TensorId, Tensor> = HashMap::new();
    for (i, input) in inputs.iter().enumerate() {
        tensor_values.insert(i as TensorId, input.clone());
    }
    // Deterministic fallback: remember the most recently produced tensor.
    let mut last_produced: Option<TensorId> = None;
    for &node_id in &graph.execution_order {
        let node = &graph.nodes[node_id];
        // Gather this node's inputs, failing if any are missing.
        let input_tensors: Result<Vec<Tensor>> = node
            .inputs
            .iter()
            .map(|&tensor_id| {
                tensor_values
                    .get(&tensor_id)
                    .cloned()
                    .ok_or_else(|| {
                        TrustformersError::runtime_error("Missing input tensor".to_string())
                    })
                    .map_err(|e: TrustformersError| e.into())
            })
            .collect();
        let input_tensors = input_tensors?;
        let output_tensors = self.execute_node_operation(node, &input_tensors)?;
        // Store outputs; extra outputs beyond the declared ids are dropped.
        for (i, output_tensor) in output_tensors.into_iter().enumerate() {
            if let Some(&output_tensor_id) = node.outputs.get(i) {
                tensor_values.insert(output_tensor_id, output_tensor);
                last_produced = Some(output_tensor_id);
            }
        }
    }
    // Collect tensors produced by designated output nodes.
    let mut outputs = Vec::new();
    for node in &graph.nodes {
        if self.is_output_node(node) {
            for &output_tensor_id in &node.outputs {
                if let Some(tensor) = tensor_values.get(&output_tensor_id) {
                    outputs.push(tensor.clone());
                }
            }
        }
    }
    // Fallback: no explicit output node — return the last produced tensor,
    // or the last provided input when no node produced anything.
    if outputs.is_empty() {
        if let Some(tensor) = last_produced.and_then(|id| tensor_values.get(&id)) {
            outputs.push(tensor.clone());
        } else if let Some(last_input) = inputs.last() {
            outputs.push(last_input.clone());
        }
    }
    Ok(outputs)
}
/// Dispatches a graph node to the concrete kernel for its operation.
fn execute_node_operation(&self, node: &GraphNode, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
    let params = &node.parameters;
    match node.operation {
        MlxOperation::MatMul => self.execute_matmul(inputs, params),
        MlxOperation::Convolution => self.execute_convolution(inputs, params),
        MlxOperation::Attention => self.execute_attention(inputs, params),
        MlxOperation::LayerNorm => self.execute_layer_norm(inputs, params),
        MlxOperation::BatchNorm => self.execute_batch_norm(inputs, params),
        MlxOperation::Activation => self.execute_activation(inputs, params),
        MlxOperation::Embedding => self.execute_embedding(inputs, params),
        MlxOperation::Softmax => self.execute_softmax(inputs, params),
        MlxOperation::Reduction => self.execute_reduction(inputs, params),
        MlxOperation::ElementWise => self.execute_elementwise(inputs, params),
    }
}
/// Naive dense 2D matrix multiplication: C = A x B.
///
/// # Errors
/// Fails unless exactly two 2D tensors with a matching inner dimension are
/// supplied.
fn execute_matmul(
    &self,
    inputs: &[Tensor],
    _parameters: &HashMap<String, f32>,
) -> Result<Vec<Tensor>> {
    let (a, b) = match inputs {
        [a, b] => (a, b),
        _ => {
            return Err(TrustformersError::runtime_error(
                "MatMul requires exactly 2 input tensors".to_string(),
            )
            .into())
        },
    };
    let a_data = a.data()?;
    let b_data = b.data()?;
    let a_shape = a.shape();
    let b_shape = b.shape();
    if a_shape.len() != 2 || b_shape.len() != 2 {
        return Err(
            TrustformersError::runtime_error("MatMul requires 2D tensors".to_string()).into(),
        );
    }
    let (rows, inner) = (a_shape[0], a_shape[1]);
    let (inner_b, cols) = (b_shape[0], b_shape[1]);
    if inner != inner_b {
        return Err(TrustformersError::runtime_error(
            "Matrix dimensions incompatible".to_string(),
        )
        .into());
    }
    // Classic row-major triple loop; the inner dot product is expressed as
    // an iterator sum.
    let mut out = vec![0.0f32; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            out[r * cols + c] = (0..inner)
                .map(|k| a_data[r * inner + k] * b_data[k * cols + c])
                .sum();
        }
    }
    Ok(vec![Tensor::from_vec(out, &[rows, cols])?])
}
/// Placeholder convolution kernel: validates its inputs and currently passes
/// the input tensor through unchanged.
///
/// NOTE(review): the kernel tensor (inputs[1]) is accepted but not applied —
/// actual convolution is still to be implemented. The original bound `input`
/// and `kernel` to locals that triggered unused-variable warnings.
fn execute_convolution(
    &self,
    inputs: &[Tensor],
    _parameters: &HashMap<String, f32>,
) -> Result<Vec<Tensor>> {
    if inputs.len() != 2 {
        return Err(TrustformersError::runtime_error(
            "Convolution requires input and kernel tensors".to_string(),
        )
        .into());
    }
    // Pass-through: return the input (inputs[0]) unchanged for now.
    Ok(vec![inputs[0].clone()])
}
/// Simplified attention stand-in: scales the first input element-wise by the
/// optional `scale` parameter (default 1.0). No QKV projection or softmax
/// weighting is performed.
fn execute_attention(
    &self,
    inputs: &[Tensor],
    parameters: &HashMap<String, f32>,
) -> Result<Vec<Tensor>> {
    let input = inputs.first().ok_or_else(|| {
        TrustformersError::runtime_error(
            "Attention requires at least one input tensor".to_string(),
        )
    })?;
    let scale = parameters.get("scale").copied().unwrap_or(1.0);
    let data = input.data()?;
    let shape = input.shape();
    let scaled: Vec<f32> = data.iter().map(|&x| x * scale).collect();
    Ok(vec![Tensor::from_vec(scaled, &shape)?])
}
/// Normalizes the input over its last dimension: (x - mean) / sqrt(var + eps).
///
/// The optional `epsilon` parameter defaults to 1e-5; no learned scale/shift
/// is applied. Requires at least a 2D input.
fn execute_layer_norm(
    &self,
    inputs: &[Tensor],
    parameters: &HashMap<String, f32>,
) -> Result<Vec<Tensor>> {
    let input = inputs.first().ok_or_else(|| {
        TrustformersError::runtime_error("LayerNorm requires input tensor".to_string())
    })?;
    let epsilon = parameters.get("epsilon").copied().unwrap_or(1e-5);
    let data = input.data()?;
    let shape = input.shape();
    if shape.len() < 2 {
        return Err(TrustformersError::runtime_error(
            "LayerNorm requires at least 2D input".to_string(),
        )
        .into());
    }
    let feature_len = shape[shape.len() - 1];
    // Normalize each full row over the last dimension independently.
    let mut normalized = Vec::with_capacity(data.len());
    for row in data.chunks_exact(feature_len) {
        let mean = row.iter().sum::<f32>() / feature_len as f32;
        let variance =
            row.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / feature_len as f32;
        let inv_std = 1.0 / (variance + epsilon).sqrt();
        normalized.extend(row.iter().map(|&x| (x - mean) * inv_std));
    }
    // Any trailing partial row is left as zeros, matching the previous
    // zero-initialized-buffer behavior.
    normalized.resize(data.len(), 0.0);
    Ok(vec![Tensor::from_vec(normalized, &shape)?])
}
/// Placeholder batch normalization: returns the input tensor unchanged.
fn execute_batch_norm(
    &self,
    inputs: &[Tensor],
    _parameters: &HashMap<String, f32>,
) -> Result<Vec<Tensor>> {
    match inputs.first() {
        Some(input) => Ok(vec![input.clone()]),
        None => Err(TrustformersError::runtime_error(
            "BatchNorm requires input tensor".to_string(),
        )
        .into()),
    }
}
/// Applies an activation selected by the `type` parameter:
/// 0 => ReLU, 1 => GELU (tanh approximation), anything else => identity.
fn execute_activation(
    &self,
    inputs: &[Tensor],
    parameters: &HashMap<String, f32>,
) -> Result<Vec<Tensor>> {
    let input = inputs.first().ok_or_else(|| {
        TrustformersError::runtime_error("Activation requires input tensor".to_string())
    })?;
    let activation_type = parameters.get("type").copied().unwrap_or(0.0) as i32;
    let data = input.data()?;
    let shape = input.shape();
    let activated: Vec<f32> = match activation_type {
        // ReLU: clamp negatives to zero.
        0 => data.iter().map(|&x| x.max(0.0)).collect(),
        // GELU via tanh approximation; 0.797_884_6 ~ sqrt(2/pi).
        1 => data
            .iter()
            .map(|&x| x * 0.5 * (1.0 + (x * 0.797_884_6).tanh()))
            .collect(),
        // Unknown type: identity pass-through.
        _ => data.to_vec(),
    };
    Ok(vec![Tensor::from_vec(activated, &shape)?])
}
/// Placeholder embedding lookup: expects (indices, table) inputs and
/// currently returns the embedding table unchanged.
fn execute_embedding(
    &self,
    inputs: &[Tensor],
    _parameters: &HashMap<String, f32>,
) -> Result<Vec<Tensor>> {
    match inputs {
        // NOTE: no gather is performed yet; the table is passed through.
        [_, table, ..] => Ok(vec![table.clone()]),
        _ => Err(TrustformersError::runtime_error(
            "Embedding requires indices and embedding table".to_string(),
        )
        .into()),
    }
}
/// Numerically stable softmax over the last dimension, using the standard
/// max-subtraction trick before exponentiating.
fn execute_softmax(
    &self,
    inputs: &[Tensor],
    _parameters: &HashMap<String, f32>,
) -> Result<Vec<Tensor>> {
    let input = inputs.first().ok_or_else(|| {
        TrustformersError::runtime_error("Softmax requires input tensor".to_string())
    })?;
    let data = input.data()?;
    let shape = input.shape();
    // A rank-0 (scalar) tensor is returned unchanged.
    if shape.is_empty() {
        return Ok(vec![input.clone()]);
    }
    let row_len = shape[shape.len() - 1];
    let batches = data.len() / row_len;
    let mut probs = vec![0.0f32; data.len()];
    for b in 0..batches {
        let (lo, hi) = (b * row_len, (b + 1) * row_len);
        // Subtract the row max before exponentiating to avoid overflow.
        let row_max = data[lo..hi].iter().fold(f32::NEG_INFINITY, |m, &x| m.max(x));
        let mut denom = 0.0f32;
        for i in lo..hi {
            let e = (data[i] - row_max).exp();
            probs[i] = e;
            denom += e;
        }
        for p in &mut probs[lo..hi] {
            *p /= denom;
        }
    }
    Ok(vec![Tensor::from_vec(probs, &shape)?])
}
/// Reduces the entire input to a single scalar tensor of shape [1].
///
/// The `type` parameter selects the reduction: 0 => sum, 1 => mean,
/// 2 => max, anything else => sum.
fn execute_reduction(
    &self,
    inputs: &[Tensor],
    parameters: &HashMap<String, f32>,
) -> Result<Vec<Tensor>> {
    let input = inputs.first().ok_or_else(|| {
        TrustformersError::runtime_error("Reduction requires input tensor".to_string())
    })?;
    let reduction_type = parameters.get("type").copied().unwrap_or(0.0) as i32;
    let data = input.data()?;
    let total: f32 = data.iter().sum();
    let reduced = match reduction_type {
        1 => total / data.len() as f32,
        2 => data.iter().fold(f32::NEG_INFINITY, |m, &x| m.max(x)),
        // Type 0 and any unrecognized type both mean "sum".
        _ => total,
    };
    Ok(vec![Tensor::from_vec(vec![reduced], &[1])?])
}
/// Binary element-wise arithmetic on two same-shape tensors.
///
/// The `type` parameter selects the operator: 0 => add, 1 => multiply,
/// 2 => subtract, anything else => add.
fn execute_elementwise(
    &self,
    inputs: &[Tensor],
    parameters: &HashMap<String, f32>,
) -> Result<Vec<Tensor>> {
    if inputs.len() < 2 {
        return Err(TrustformersError::runtime_error(
            "ElementWise requires at least 2 input tensors".to_string(),
        )
        .into());
    }
    let (a, b) = (&inputs[0], &inputs[1]);
    let op_type = parameters.get("type").copied().unwrap_or(0.0) as i32;
    if a.shape() != b.shape() {
        return Err(TrustformersError::runtime_error(
            "ElementWise inputs must have same shape".to_string(),
        )
        .into());
    }
    let a_data = a.data()?;
    let b_data = b.data()?;
    let shape = a.shape();
    // Select the operator once, then apply it pairwise.
    let combine: fn(f32, f32) -> f32 = match op_type {
        1 => |x, y| x * y,
        2 => |x, y| x - y,
        // Type 0 and any unrecognized type both mean "add".
        _ => |x, y| x + y,
    };
    let result: Vec<f32> = a_data
        .iter()
        .zip(b_data.iter())
        .map(|(&x, &y)| combine(x, y))
        .collect();
    Ok(vec![Tensor::from_vec(result, &shape)?])
}
/// Refreshes the engine's performance counters after a model execution.
///
/// Utilization figures are fixed placeholder estimates; power/thermal values
/// come from the compiled model's performance profile. The original divided
/// by the raw duration, producing +inf ops/sec for a zero-length duration —
/// now guarded.
fn update_performance_metrics(
    &mut self,
    execution_time: std::time::Duration,
    profile: &ModelPerformanceProfile,
) {
    // Guard against a zero-length duration producing +inf ops/sec.
    let secs = execution_time.as_secs_f64();
    self.performance_metrics.ops_per_second = if secs > 0.0 { 1.0 / secs } else { 0.0 };
    // Assume ~80% of the device's theoretical peak bandwidth is achieved.
    self.performance_metrics.memory_bandwidth_gbps =
        self.device_capabilities.memory_bandwidth_gbps * 0.8;
    // Placeholder utilization estimates pending real hardware counters.
    self.performance_metrics.cpu_utilization = 75.0;
    self.performance_metrics.gpu_utilization = 85.0;
    self.performance_metrics.neural_engine_utilization = 90.0;
    self.performance_metrics.power_consumption_watts = profile.power_consumption_watts;
    self.performance_metrics.memory_usage_gb = self.memory_pool.get_total_allocated_gb();
    self.performance_metrics.thermal_state = profile.thermal_impact;
}
/// Returns the engine's most recently updated performance metrics.
pub fn get_performance_metrics(&self) -> &MlxPerformanceMetrics {
    &self.performance_metrics
}
/// Returns the capability profile detected for the configured device.
pub fn get_device_capabilities(&self) -> &DeviceCapabilities {
    &self.device_capabilities
}
/// Renders a human-readable, multi-section report covering the detected
/// hardware capabilities, the current performance metrics, and the active
/// engine configuration.
pub fn export_performance_report(&self) -> String {
    format!(
        "MLX Framework Performance Report\n\
         =================================\n\
         Device: {:?}\n\
         MLX Version: {}\n\
         Metal Version: {}\n\n\
         Hardware Capabilities:\n\
         - Performance cores: {}\n\
         - Efficiency cores: {}\n\
         - GPU cores: {}\n\
         - Neural Engine TOPS: {:.1}\n\
         - Unified memory: {:.1} GB\n\
         - Memory bandwidth: {:.1} GB/s\n\
         - AMX support: {}\n\n\
         Performance Metrics:\n\
         - Operations per second: {:.0}\n\
         - Memory bandwidth utilization: {:.1} GB/s\n\
         - CPU utilization: {:.1}%\n\
         - GPU utilization: {:.1}%\n\
         - Neural Engine utilization: {:.1}%\n\
         - Power consumption: {:.1} W\n\
         - Compilation time: {:.1} ms\n\
         - Memory usage: {:.2} GB\n\
         - Thermal state: {:.1}%\n\n\
         Configuration:\n\
         - Compilation strategy: {:?}\n\
         - Default precision: {:?}\n\
         - Memory pool strategy: {:?}\n\
         - Workload distribution: {:?}\n\
         - Graph optimization enabled: {}\n\
         - Zero-copy operations: {}",
        self.config.device,
        self.device_capabilities.mlx_version,
        self.device_capabilities.metal_version,
        self.device_capabilities.performance_cores,
        self.device_capabilities.efficiency_cores,
        self.device_capabilities.gpu_cores,
        self.device_capabilities.neural_engine_tops,
        self.device_capabilities.unified_memory_gb,
        self.device_capabilities.memory_bandwidth_gbps,
        self.device_capabilities.amx_support,
        self.performance_metrics.ops_per_second,
        self.performance_metrics.memory_bandwidth_gbps,
        self.performance_metrics.cpu_utilization,
        self.performance_metrics.gpu_utilization,
        self.performance_metrics.neural_engine_utilization,
        self.performance_metrics.power_consumption_watts,
        self.performance_metrics.compilation_time_ms,
        self.performance_metrics.memory_usage_gb,
        // Thermal state is stored as a 0..1 fraction; report as a percentage.
        self.performance_metrics.thermal_state * 100.0,
        self.config.compilation_strategy,
        self.config.precision_config.default_precision,
        self.config.memory_config.pool_strategy,
        self.config.compute_units.distribution_strategy,
        self.config.graph_optimization.operator_fusion,
        self.config.memory_config.zero_copy_enabled
    )
}
}
impl UnifiedMemoryPool {
    /// Creates an empty pool for the given unified-memory configuration.
    fn new(config: UnifiedMemoryConfig) -> Self {
        Self {
            config,
            allocated_memory: HashMap::new(),
            total_allocated_bytes: 0,
            peak_allocated_bytes: 0,
            allocation_count: 0,
        }
    }
    /// Records an allocation for `tensor_id`.
    ///
    /// Re-registering an already-known tensor replaces the stored allocation
    /// without touching the counters. The original cloned the allocation just
    /// to read `size_bytes` afterwards; the size is now read up front and the
    /// value moved into the map without a clone.
    fn allocate(&mut self, tensor_id: TensorId, allocation: MemoryAllocation) -> Result<()> {
        let size = allocation.size_bytes;
        if self.allocated_memory.insert(tensor_id, allocation).is_none() {
            // First sight of this tensor: grow the usage counters and track
            // the new high-water mark.
            self.total_allocated_bytes += size;
            self.peak_allocated_bytes =
                self.peak_allocated_bytes.max(self.total_allocated_bytes);
            self.allocation_count += 1;
        }
        Ok(())
    }
    /// Total bytes currently accounted for, expressed in GiB.
    fn get_total_allocated_gb(&self) -> f32 {
        self.total_allocated_bytes as f32 / (1024.0 * 1024.0 * 1024.0)
    }
}
impl Default for MlxPerformanceMetrics {
    /// All-zero baseline metrics for a freshly constructed engine; real
    /// values are filled in after compilation/execution.
    fn default() -> Self {
        Self {
            ops_per_second: 0.0,
            memory_bandwidth_gbps: 0.0,
            cpu_utilization: 0.0,
            gpu_utilization: 0.0,
            neural_engine_utilization: 0.0,
            power_consumption_watts: 0.0,
            compilation_time_ms: 0.0,
            memory_usage_gb: 0.0,
            thermal_state: 0.0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Engine construction should succeed for a modest configuration; the
    // hard assertion is limited to Apple Silicon targets.
    #[test]
    fn test_mlx_engine_creation() {
        let mut config = MlxConfig::default();
        config.compute_units.cpu_config.performance_cores = 4;
        config.compute_units.cpu_config.efficiency_cores = 2;
        config.compute_units.gpu_config.gpu_cores = 8;
        config.memory_config.max_memory_gb = 8.0;
        let engine = MlxEngine::new(config);
        if let Err(ref e) = engine {
            println!("MLX Engine creation failed: {:?}", e);
        }
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        assert!(engine.is_ok());
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        {
            let _ = engine;
        }
    }
    // Spot-checks the M4 row of the static capability table.
    #[test]
    fn test_device_capabilities_detection() {
        let capabilities = MlxEngine::detect_device_capabilities(&AppleSiliconDevice::M4);
        assert!(capabilities.is_ok());
        let caps = capabilities.expect("Operation failed");
        assert_eq!(caps.performance_cores, 4);
        assert_eq!(caps.efficiency_cores, 6);
        assert!(caps.amx_support);
    }
    // Compiles a two-node graph and verifies it is registered under its id.
    #[test]
    fn test_model_compilation() {
        let mut config = MlxConfig::default();
        config.compute_units.cpu_config.performance_cores = 4;
        config.compute_units.cpu_config.efficiency_cores = 2;
        config.compute_units.gpu_config.gpu_cores = 8;
        config.memory_config.max_memory_gb = 8.0;
        let engine_result = MlxEngine::new(config);
        if engine_result.is_err() {
            println!("Skipping model compilation test - MLX not available");
            return;
        }
        let mut engine = engine_result.expect("Operation failed");
        let model_graph = vec![
            (MlxOperation::MatMul, vec![0, 1], HashMap::new()),
            (MlxOperation::Activation, vec![1], {
                let mut params = HashMap::new();
                params.insert("type".to_string(), 0.0); params
            }),
        ];
        let result =
            engine.compile_model("test_model".to_string(), model_graph, OptimizationLevel::O2);
        assert!(result.is_ok());
        assert_eq!(result.expect("Operation failed"), "test_model");
        assert!(engine.compiled_models.contains_key("test_model"));
    }
    // End-to-end: compile a single-MatMul graph, run it on two 2x2 tensors,
    // and check that some output is produced.
    #[test]
    fn test_model_execution() {
        let mut config = MlxConfig::default();
        config.compute_units.cpu_config.performance_cores = 4;
        config.compute_units.cpu_config.efficiency_cores = 2;
        config.compute_units.gpu_config.gpu_cores = 8;
        config.memory_config.max_memory_gb = 8.0;
        let engine_result = MlxEngine::new(config);
        if engine_result.is_err() {
            println!("Skipping model execution test - MLX not available");
            return;
        }
        let mut engine = engine_result.expect("Operation failed");
        let model_graph = vec![(MlxOperation::MatMul, vec![0, 0], HashMap::new())];
        engine
            .compile_model("test_model".to_string(), model_graph, OptimizationLevel::O1)
            .expect("Operation failed");
        let input1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("Operation failed");
        let input2 = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).expect("Operation failed");
        let result = engine.execute_model("test_model", &[input1, input2]);
        assert!(result.is_ok());
        let outputs = result.expect("Operation failed");
        assert!(!outputs.is_empty());
    }
    // An absurd memory request must be rejected by config validation.
    #[test]
    fn test_config_validation() {
        let mut config = MlxConfig::default();
        config.memory_config.max_memory_gb = 1000.0;
        let engine = MlxEngine::new(config);
        assert!(engine.is_err());
    }
    // Metrics start at their zero defaults before any execution.
    #[test]
    fn test_performance_metrics() {
        let mut config = MlxConfig::default();
        config.compute_units.cpu_config.performance_cores = 4;
        config.compute_units.cpu_config.efficiency_cores = 2;
        config.compute_units.gpu_config.gpu_cores = 8;
        config.memory_config.max_memory_gb = 8.0;
        let engine_result = MlxEngine::new(config);
        if engine_result.is_err() {
            println!("Skipping performance metrics test - MLX not available");
            return;
        }
        let engine = engine_result.expect("Operation failed");
        let metrics = engine.get_performance_metrics();
        assert_eq!(metrics.ops_per_second, 0.0);
        assert_eq!(metrics.memory_usage_gb, 0.0);
    }
    // Relative ordering across chip generations, and AMX presence/absence.
    #[test]
    fn test_apple_silicon_variants() {
        let m1_caps = MlxEngine::detect_device_capabilities(&AppleSiliconDevice::M1)
            .expect("Operation failed");
        let m4_caps = MlxEngine::detect_device_capabilities(&AppleSiliconDevice::M4)
            .expect("Operation failed");
        let a18_caps = MlxEngine::detect_device_capabilities(&AppleSiliconDevice::A18Pro)
            .expect("Operation failed");
        assert!(m4_caps.neural_engine_tops > m1_caps.neural_engine_tops);
        assert!(m4_caps.memory_bandwidth_gbps > m1_caps.memory_bandwidth_gbps);
        assert!(!a18_caps.amx_support);
        assert!(m4_caps.amx_support);
    }
    // The rendered report must contain all three major sections.
    #[test]
    fn test_performance_report() {
        let mut config = MlxConfig::default();
        config.compute_units.cpu_config.performance_cores = 4;
        config.compute_units.cpu_config.efficiency_cores = 2;
        config.compute_units.gpu_config.gpu_cores = 8;
        config.memory_config.max_memory_gb = 8.0;
        let engine_result = MlxEngine::new(config);
        if engine_result.is_err() {
            println!("Skipping performance report test - MLX not available");
            return;
        }
        let engine = engine_result.expect("Operation failed");
        let report = engine.export_performance_report();
        assert!(report.contains("MLX Framework Performance Report"));
        assert!(report.contains("Hardware Capabilities"));
        assert!(report.contains("Performance Metrics"));
    }
}