#![allow(unused_variables)]
use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
#[cfg(all(target_os = "macos", feature = "metal"))]
use mpsgraph as mps;
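/// GPU compute backend built on Apple's Metal API.
///
/// All device state is gated behind `all(target_os = "macos", feature = "metal")`;
/// on other targets the struct is an empty placeholder whose operations return
/// hardware errors.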
pub struct MetalImpl {
#[cfg(all(target_os = "macos", feature = "metal"))]
device: metal::Device,
#[cfg(all(target_os = "macos", feature = "metal"))]
command_queue: metal::CommandQueue,
#[cfg(all(target_os = "macos", feature = "metal"))]
library: metal::Library,
#[cfg(not(all(target_os = "macos", feature = "metal")))]
_placeholder: (),
}
impl MetalImpl {
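/// Creates a new Metal backend instance.
///
/// On macOS with the `metal` feature enabled this selects the system default
/// GPU, verifies Metal Performance Shaders support, and compiles the compute
/// kernel library; on all other targets it returns a placeholder that reports
/// errors from every operation.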
pub fn new() -> Result<Self> {
#[cfg(all(target_os = "macos", feature = "metal"))]
{
Self::new_with_metal()
}
#[cfg(not(all(target_os = "macos", feature = "metal")))]
{
Ok(Self { _placeholder: () })
}
}
#[cfg(all(target_os = "macos", feature = "metal"))]
fn new_with_metal() -> Result<Self> {
let device = metal::Device::system_default().ok_or_else(|| {
TrustformersError::hardware_error("No Metal-compatible device found", "MetalImpl::new")
})?;
if !Self::supports_mps(&device) {
return Err(TrustformersError::hardware_error(
"Metal Performance Shaders not supported on this device",
"MetalImpl::new",
));
}
let command_queue = device.new_command_queue();
// The MPSGraph instance is not yet wired into any operation; the leading
// underscore silences the unused-variable warning.
let _mps_graph = mps::Graph::new();
let library = Self::create_kernel_library(&device)?;
log::info!(
"Metal backend initialized successfully on device: {}",
device.name()
);
log::info!(
"Unified memory size: {} GB",
device.recommended_max_working_set_size() / (1024 * 1024 * 1024)
);
Ok(Self {
device,
command_queue,
library,
})
}
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(deprecated)]
fn supports_mps(device: &metal::Device) -> bool {
device.supports_feature_set(metal::MTLFeatureSet::macOS_GPUFamily2_v1)
|| device.supports_feature_set(metal::MTLFeatureSet::iOS_GPUFamily4_v1)
}
#[cfg(all(target_os = "macos", feature = "metal"))]
fn create_kernel_library(device: &metal::Device) -> Result<metal::Library> {
let kernel_source = r#"
#include <metal_stdlib>
using namespace metal;
// Naive matrix multiplication kernel: one thread per output element
kernel void matrix_multiply_f32(
device const float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
device float* result [[buffer(2)]],
constant uint& M [[buffer(3)]],
constant uint& N [[buffer(4)]],
constant uint& K [[buffer(5)]],
uint2 thread_position [[thread_position_in_grid]]
) {
uint row = thread_position.y;
uint col = thread_position.x;
if (row >= M || col >= N) return;
float sum = 0.0;
for (uint i = 0; i < K; ++i) {
sum += a[row * K + i] * b[i * N + col];
}
result[row * N + col] = sum;
}
// Element-wise addition; shapes must match (no broadcasting)
kernel void add_tensors_f32(
device const float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
device float* result [[buffer(2)]],
constant uint& size [[buffer(3)]],
uint index [[thread_position_in_grid]]
) {
if (index >= size) return;
result[index] = a[index] + b[index];
}
// ReLU activation function
kernel void relu_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& size [[buffer(2)]],
uint index [[thread_position_in_grid]]
) {
if (index >= size) return;
output[index] = max(0.0f, input[index]);
}
// GELU activation function
// GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
kernel void gelu_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& size [[buffer(2)]],
uint index [[thread_position_in_grid]]
) {
if (index >= size) return;
float x = input[index];
float x_cubed = x * x * x;
// sqrt(2/π) ≈ 0.7978845608
float inner = 0.7978845608f * (x + 0.044715f * x_cubed);
output[index] = 0.5f * x * (1.0f + tanh(inner));
}
// Row-wise softmax for attention mechanisms
// Note: each thread redundantly recomputes the row max and sum, so the
// kernel does O(seq_len) work per output element
kernel void softmax_f32(
device const float* input [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& batch_size [[buffer(2)]],
constant uint& seq_len [[buffer(3)]],
uint2 thread_position [[thread_position_in_grid]]
) {
uint batch = thread_position.y;
uint seq = thread_position.x;
if (batch >= batch_size || seq >= seq_len) return;
uint offset = batch * seq_len;
// Find maximum value for numerical stability
float max_val = input[offset];
for (uint i = 0; i < seq_len; ++i) {
max_val = max(max_val, input[offset + i]);
}
// Compute exponentials and sum
float sum = 0.0;
for (uint i = 0; i < seq_len; ++i) {
sum += exp(input[offset + i] - max_val);
}
// Normalize
output[offset + seq] = exp(input[offset + seq] - max_val) / sum;
}
// Fused attention kernel for transformer models
// Implements: output = softmax(Q @ K^T / sqrt(d_k)) @ V in a single pass
// Note: this is a simplified, non-tiled variant; each thread materializes the
// full score row for its (batch, query) pair in thread-local memory
kernel void flash_attention_f32(
device const float* query [[buffer(0)]], // [batch, seq_q, dim]
device const float* key [[buffer(1)]], // [batch, seq_k, dim]
device const float* value [[buffer(2)]], // [batch, seq_k, dim_v]
device float* output [[buffer(3)]], // [batch, seq_q, dim_v]
constant uint& batch_size [[buffer(4)]],
constant uint& seq_q [[buffer(5)]],
constant uint& seq_k [[buffer(6)]],
constant uint& dim [[buffer(7)]],
constant uint& dim_v [[buffer(8)]],
constant float& scale [[buffer(9)]],
uint3 thread_position [[thread_position_in_grid]]
) {
uint b = thread_position.z; // batch
uint q_idx = thread_position.y; // query position
uint v_idx = thread_position.x; // value dimension
if (b >= batch_size || q_idx >= seq_q || v_idx >= dim_v) return;
// Calculate attention scores: Q[q_idx] @ K^T
// Q[q_idx] shape: [dim], K^T shape: [dim, seq_k]
float max_score = -INFINITY;
float scores[1024]; // seq_k <= 1024 is enforced by the host-side check
uint q_offset = (b * seq_q + q_idx) * dim;
for (uint k_idx = 0; k_idx < seq_k; ++k_idx) {
uint k_offset = (b * seq_k + k_idx) * dim;
// Compute dot product Q[q_idx] · K[k_idx]
float score = 0.0;
for (uint d = 0; d < dim; ++d) {
score += query[q_offset + d] * key[k_offset + d];
}
// Scale by 1/sqrt(d_k)
score *= scale;
scores[k_idx] = score;
// Track maximum for numerical stability
max_score = max(max_score, score);
}
// Compute softmax: exp(score - max_score) / sum
float sum_exp = 0.0;
for (uint k_idx = 0; k_idx < seq_k; ++k_idx) {
scores[k_idx] = exp(scores[k_idx] - max_score);
sum_exp += scores[k_idx];
}
// Normalize attention weights
for (uint k_idx = 0; k_idx < seq_k; ++k_idx) {
scores[k_idx] /= sum_exp;
}
// Apply attention to values: attention_weights @ V
// Result: output[q_idx, v_idx] = sum_k(attention[k] * V[k, v_idx])
float result = 0.0;
for (uint k_idx = 0; k_idx < seq_k; ++k_idx) {
uint v_offset = (b * seq_k + k_idx) * dim_v + v_idx;
result += scores[k_idx] * value[v_offset];
}
// Write output
uint out_offset = (b * seq_q + q_idx) * dim_v + v_idx;
output[out_offset] = result;
}
"#;
let compile_options = metal::CompileOptions::new();
device.new_library_with_source(kernel_source, &compile_options).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to compile Metal kernels: {}", e),
"MetalImpl::create_kernel_library",
)
})
}
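/// Multiplies two 2D tensors on the GPU: `a` is `[M, K]`, `b` is `[K, N]`,
/// and the result is `[M, N]`.
///
/// The kernel assigns one thread per output element and dispatches 16x16
/// threadgroups, e.g. a 100x100 output uses `ceil(100 / 16) = 7` groups per axis.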
#[cfg(all(target_os = "macos", feature = "metal"))]
pub fn matrix_multiply(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
if a.shape().len() != 2 || b.shape().len() != 2 {
return Err(TrustformersError::tensor_op_error(
"Matrix multiplication currently supports exactly 2D tensors; the kernel does not handle batched inputs",
"MetalImpl::matrix_multiply",
));
}
let a_shape = a.shape();
let b_shape = b.shape();
let a_rows = a_shape[a_shape.len() - 2];
let a_cols = a_shape[a_shape.len() - 1];
let b_rows = b_shape[b_shape.len() - 2];
let b_cols = b_shape[b_shape.len() - 1];
if a_cols != b_rows {
return Err(
crate::errors::shape_mismatch(vec![a_rows, a_cols], vec![b_rows, b_cols])
.with_operation("MetalImpl::matrix_multiply"),
);
}
let buffer_a = a.to_metal_buffer(&self.device)?;
let buffer_b = b.to_metal_buffer(&self.device)?;
let result_size = a_rows * b_cols;
let result_buffer = self.device.new_buffer(
(result_size * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let function = self.library.get_function("matrix_multiply_f32", None).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to get matrix_multiply_f32 kernel: {}", e),
"MetalImpl::matrix_multiply",
)
})?;
let pipeline =
self.device.new_compute_pipeline_state_with_function(&function).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create compute pipeline: {}", e),
"MetalImpl::matrix_multiply",
)
})?;
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_buffer(0, Some(&buffer_a), 0);
encoder.set_buffer(1, Some(&buffer_b), 0);
encoder.set_buffer(2, Some(&result_buffer), 0);
let m = a_rows as u32;
let n = b_cols as u32;
let k = a_cols as u32;
encoder.set_bytes(
3,
std::mem::size_of::<u32>() as u64,
&m as *const u32 as *const std::ffi::c_void,
);
encoder.set_bytes(
4,
std::mem::size_of::<u32>() as u64,
&n as *const u32 as *const std::ffi::c_void,
);
encoder.set_bytes(
5,
std::mem::size_of::<u32>() as u64,
&k as *const u32 as *const std::ffi::c_void,
);
// One thread per output element, dispatched in 16x16 threadgroups
let thread_group_size = metal::MTLSize::new(16, 16, 1);
let thread_groups = metal::MTLSize::new(
b_cols.div_ceil(16) as u64,
a_rows.div_ceil(16) as u64,
1,
);
encoder.dispatch_thread_groups(thread_groups, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let result_shape = vec![a_rows, b_cols];
Tensor::from_metal_buffer(&result_buffer, &result_shape)
}
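/// Element-wise addition of two tensors with identical shapes.
///
/// One thread per element, dispatched in groups of 256; broadcasting is not
/// supported.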
#[cfg(all(target_os = "macos", feature = "metal"))]
pub fn add_tensors(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
if a.shape() != b.shape() {
return Err(TrustformersError::tensor_op_error(
"Tensor shapes must match for addition",
"MetalImpl::add_tensors",
));
}
let size = a.len();
let result_shape = a.shape().to_vec();
let buffer_a = a.to_metal_buffer(&self.device)?;
let buffer_b = b.to_metal_buffer(&self.device)?;
let result_buffer = self.device.new_buffer(
(size * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let function = self.library.get_function("add_tensors_f32", None).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to get add_tensors kernel: {}", e),
"MetalImpl::add_tensors",
)
})?;
let pipeline =
self.device.new_compute_pipeline_state_with_function(&function).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create compute pipeline: {}", e),
"MetalImpl::add_tensors",
)
})?;
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_buffer(0, Some(&buffer_a), 0);
encoder.set_buffer(1, Some(&buffer_b), 0);
encoder.set_buffer(2, Some(&result_buffer), 0);
let size_u32 = size as u32;
encoder.set_bytes(
3,
std::mem::size_of::<u32>() as u64,
&size_u32 as *const u32 as *const std::ffi::c_void,
);
let thread_group_size = metal::MTLSize::new(256, 1, 1);
let thread_groups = metal::MTLSize::new(size.div_ceil(256) as u64, 1, 1);
encoder.dispatch_thread_groups(thread_groups, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
Tensor::from_metal_buffer(&result_buffer, &result_shape)
}
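/// Applies the tanh approximation of GELU element-wise:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.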
#[cfg(all(target_os = "macos", feature = "metal"))]
pub fn gelu(&self, input: &Tensor) -> Result<Tensor> {
let size = input.len();
let result_shape = input.shape().to_vec();
let input_buffer = input.to_metal_buffer(&self.device)?;
let output_buffer = self.device.new_buffer(
(size * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let function = self.library.get_function("gelu_f32", None).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to get gelu kernel: {}", e),
"MetalImpl::gelu",
)
})?;
let pipeline =
self.device.new_compute_pipeline_state_with_function(&function).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create compute pipeline: {}", e),
"MetalImpl::gelu",
)
})?;
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_buffer(0, Some(&input_buffer), 0);
encoder.set_buffer(1, Some(&output_buffer), 0);
let size_u32 = size as u32;
encoder.set_bytes(
2,
std::mem::size_of::<u32>() as u64,
&size_u32 as *const u32 as *const std::ffi::c_void,
);
let thread_group_size = metal::MTLSize::new(256, 1, 1);
let thread_groups = metal::MTLSize::new(size.div_ceil(256) as u64, 1, 1);
encoder.dispatch_thread_groups(thread_groups, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
Tensor::from_metal_buffer(&output_buffer, &result_shape)
}
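/// Computes fused scaled dot-product attention:
/// `output = softmax(Q @ K^T / sqrt(d_k)) @ V`.
///
/// Expects 3D tensors `[batch, seq_len, dim]` and writes the result into
/// `output`. The scale is derived from the query dimension, and the kernel
/// limits `seq_k` to 1024; longer sequences need chunked processing.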
#[cfg(all(target_os = "macos", feature = "metal"))]
pub fn flash_attention(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
output: &mut Tensor,
) -> Result<()> {
let q_shape = query.shape();
let k_shape = key.shape();
let v_shape = value.shape();
if q_shape.len() != 3 || k_shape.len() != 3 || v_shape.len() != 3 {
return Err(TrustformersError::tensor_op_error(
"Flash attention expects exactly 3D tensors [batch, seq_len, dim]",
"MetalImpl::flash_attention",
));
}
let batch_size = q_shape[0];
let seq_len_q = q_shape[1];
let dim_q = q_shape[2];
let seq_len_k = k_shape[1];
let dim_k = k_shape[2];
let dim_v = v_shape[2];
if dim_q != dim_k {
return Err(crate::errors::shape_mismatch(
vec![batch_size, seq_len_q, dim_q],
vec![batch_size, seq_len_k, dim_k],
)
.with_operation("MetalImpl::flash_attention"));
}
if seq_len_k > 1024 {
return Err(TrustformersError::tensor_op_error(
"Flash attention kernel supports maximum sequence length of 1024. For longer sequences, use chunked processing",
"MetalImpl::flash_attention",
));
}
let buffer_q = query.to_metal_buffer(&self.device)?;
let buffer_k = key.to_metal_buffer(&self.device)?;
let buffer_v = value.to_metal_buffer(&self.device)?;
let output_size = batch_size * seq_len_q * dim_v;
let output_buffer = self.device.new_buffer(
(output_size * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
let function = self.library.get_function("flash_attention_f32", None).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to get flash_attention_f32 kernel: {}", e),
"MetalImpl::flash_attention",
)
})?;
let pipeline =
self.device.new_compute_pipeline_state_with_function(&function).map_err(|e| {
TrustformersError::hardware_error(
&format!("Failed to create compute pipeline: {}", e),
"MetalImpl::flash_attention",
)
})?;
let command_buffer = self.command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_buffer(0, Some(&buffer_q), 0);
encoder.set_buffer(1, Some(&buffer_k), 0);
encoder.set_buffer(2, Some(&buffer_v), 0);
encoder.set_buffer(3, Some(&output_buffer), 0);
let batch_u32 = batch_size as u32;
let seq_q_u32 = seq_len_q as u32;
let seq_k_u32 = seq_len_k as u32;
let dim_u32 = dim_q as u32;
let dim_v_u32 = dim_v as u32;
let scale = 1.0_f32 / (dim_q as f32).sqrt();
encoder.set_bytes(
4,
std::mem::size_of::<u32>() as u64,
&batch_u32 as *const u32 as *const std::ffi::c_void,
);
encoder.set_bytes(
5,
std::mem::size_of::<u32>() as u64,
&seq_q_u32 as *const u32 as *const std::ffi::c_void,
);
encoder.set_bytes(
6,
std::mem::size_of::<u32>() as u64,
&seq_k_u32 as *const u32 as *const std::ffi::c_void,
);
encoder.set_bytes(
7,
std::mem::size_of::<u32>() as u64,
&dim_u32 as *const u32 as *const std::ffi::c_void,
);
encoder.set_bytes(
8,
std::mem::size_of::<u32>() as u64,
&dim_v_u32 as *const u32 as *const std::ffi::c_void,
);
encoder.set_bytes(
9,
std::mem::size_of::<f32>() as u64,
&scale as *const f32 as *const std::ffi::c_void,
);
// One thread per (value dim, query position, batch) output element, in 8x8x1 threadgroups
let thread_group_size = metal::MTLSize::new(8, 8, 1);
let thread_groups = metal::MTLSize::new(
dim_v.div_ceil(8) as u64,
seq_len_q.div_ceil(8) as u64,
batch_size as u64,
);
encoder.dispatch_thread_groups(thread_groups, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let result_shape = vec![batch_size, seq_len_q, dim_v];
let result_tensor = Tensor::from_metal_buffer(&output_buffer, &result_shape)?;
*output = result_tensor;
Ok(())
}
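/// Returns a human-readable summary of the selected device, or a note that
/// the Metal backend was not compiled into this build.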
pub fn device_info(&self) -> Result<String> {
#[cfg(all(target_os = "macos", feature = "metal"))]
{
Ok(format!(
"Metal Device: {}\nUnified Memory: {} GB\nMax Threads Per Group: {}",
self.device.name(),
self.device.recommended_max_working_set_size() / (1024 * 1024 * 1024),
self.device.max_threads_per_threadgroup().width
))
}
#[cfg(not(all(target_os = "macos", feature = "metal")))]
{
Ok("Metal backend not compiled".to_string())
}
}
}
#[cfg(not(all(target_os = "macos", feature = "metal")))]
impl MetalImpl {
pub fn matrix_multiply(&self, _a: &Tensor, _b: &Tensor) -> Result<Tensor> {
Err(TrustformersError::hardware_error(
"Metal backend not compiled in this build",
"MetalImpl::matrix_multiply",
))
}
pub fn add_tensors(&self, _a: &Tensor, _b: &Tensor) -> Result<Tensor> {
Err(TrustformersError::hardware_error(
"Metal backend not compiled in this build",
"MetalImpl::add_tensors",
))
}
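// Stub matching the Metal-enabled `gelu` so the method exists on every target.
pub fn gelu(&self, _input: &Tensor) -> Result<Tensor> {
Err(TrustformersError::hardware_error(
"Metal backend not compiled in this build",
"MetalImpl::gelu",
))
}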
pub fn flash_attention(
&self,
_query: &Tensor,
_key: &Tensor,
_value: &Tensor,
_output: &mut Tensor,
) -> Result<()> {
Err(TrustformersError::hardware_error(
"Metal backend not compiled in this build",
"MetalImpl::flash_attention",
))
}
}
#[cfg(not(all(target_os = "macos", feature = "metal")))]
mod metal {
pub struct Device;
pub struct CommandQueue;
pub struct Library;
pub struct MTLResourceOptions;
pub struct MTLSize;
impl Device {
pub fn system_default() -> Option<Self> {
None
}
pub fn name(&self) -> &str {
"Mock"
}
pub fn recommended_max_working_set_size(&self) -> u64 {
0
}
pub fn supports_feature_set(&self, _: MTLFeatureSet) -> bool {
false
}
pub fn new_command_queue(&self) -> CommandQueue {
CommandQueue
}
pub fn new_library_with_source(
&self,
_: &str,
_: &CompileOptions,
) -> Result<Library, String> {
Err("Not supported".to_string())
}
pub fn new_buffer(&self, _: u64, _: MTLResourceOptions) -> Buffer {
Buffer
}
pub fn new_compute_pipeline_state_with_function(
&self,
_: &Function,
) -> Result<ComputePipelineState, String> {
Err("Not supported".to_string())
}
pub fn max_threads_per_threadgroup(&self) -> MTLSize {
MTLSize::new(0, 0, 0)
}
}
impl CommandQueue {
pub fn new_command_buffer(&self) -> CommandBuffer {
CommandBuffer
}
}
pub struct CommandBuffer;
impl CommandBuffer {
pub fn new_compute_command_encoder(&self) -> ComputeCommandEncoder {
ComputeCommandEncoder
}
pub fn commit(&self) {}
pub fn wait_until_completed(&self) {}
}
pub struct ComputeCommandEncoder;
impl ComputeCommandEncoder {
pub fn set_compute_pipeline_state(&self, _: &ComputePipelineState) {}
pub fn set_buffer(&self, _: u32, _: Option<&Buffer>, _: u64) {}
pub fn set_bytes(&self, _: u32, _: u64, _: *const std::ffi::c_void) {}
pub fn dispatch_thread_groups(&self, _: MTLSize, _: MTLSize) {}
pub fn end_encoding(&self) {}
}
pub struct Buffer;
pub struct Function;
pub struct ComputePipelineState;
pub struct CompileOptions;
#[allow(non_camel_case_types)]
pub enum MTLFeatureSet {
macOS_GPUFamily2_v1,
iOS_GPUFamily4_v1,
}
impl CompileOptions {
pub fn new() -> Self {
CompileOptions
}
}
impl MTLSize {
pub fn new(_: usize, _: usize, _: usize) -> Self {
MTLSize
}
pub fn width(&self) -> usize {
0
}
}
}
#[cfg(not(all(target_os = "macos", feature = "metal")))]
mod mps {
use super::metal;
use std::collections::HashMap;
pub struct MPSGraph;
pub struct MPSGraphTensorData;
pub enum MPSDataType {
Float32,
}
impl MPSGraph {
pub fn new() -> Self {
MPSGraph
}
pub fn placeholder(&self, _: &[u64], _: MPSDataType, _: &str) -> MPSGraphTensorData {
MPSGraphTensorData
}
pub fn matrix_multiplication(
&self,
_: &MPSGraphTensorData,
_: &MPSGraphTensorData,
_: &str,
) -> MPSGraphTensorData {
MPSGraphTensorData
}
pub fn run_async(
&self,
_: &metal::CommandBuffer,
_: &HashMap<String, metal::Buffer>,
_: &[(String, MPSGraphTensorData)],
) -> Result<HashMap<String, metal::Buffer>, crate::errors::TrustformersError> {
Err(crate::errors::TrustformersError::hardware_error(
"Not supported",
"mock",
))
}
}
impl MPSGraphTensorData {
pub fn new(_: &metal::Device, _: &[u64], _: MPSDataType) -> Self {
MPSGraphTensorData
}
}
}
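/// Conversion helpers between `Tensor` and Metal GPU buffers. Buffers use
/// shared storage, so no explicit host/device copies are needed on Apple
/// Silicon's unified memory.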
trait TensorMetalExt {
fn to_metal_buffer(&self, device: &metal::Device) -> Result<metal::Buffer>;
fn from_metal_buffer(buffer: &metal::Buffer, shape: &[usize]) -> Result<Tensor>;
}
impl TensorMetalExt for Tensor {
#[cfg(all(target_os = "macos", feature = "metal"))]
fn to_metal_buffer(&self, device: &metal::Device) -> Result<metal::Buffer> {
let data = self.data()?;
let buffer = device.new_buffer_with_data(
data.as_ptr() as *const std::ffi::c_void,
(data.len() * std::mem::size_of::<f32>()) as u64,
metal::MTLResourceOptions::StorageModeShared,
);
Ok(buffer)
}
#[cfg(not(all(target_os = "macos", feature = "metal")))]
fn to_metal_buffer(&self, _device: &metal::Device) -> Result<metal::Buffer> {
Err(TrustformersError::hardware_error(
"Metal backend not compiled",
"TensorMetalExt::to_metal_buffer",
))
}
#[cfg(all(target_os = "macos", feature = "metal"))]
fn from_metal_buffer(buffer: &metal::Buffer, shape: &[usize]) -> Result<Tensor> {
let data_ptr = buffer.contents() as *const f32;
let len = shape.iter().product::<usize>();
let data = unsafe { std::slice::from_raw_parts(data_ptr, len) };
Tensor::from_slice(data, shape)
}
#[cfg(not(all(target_os = "macos", feature = "metal")))]
fn from_metal_buffer(_buffer: &metal::Buffer, _shape: &[usize]) -> Result<Tensor> {
Err(TrustformersError::hardware_error(
"Metal backend not compiled",
"TensorMetalExt::from_metal_buffer",
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Tensor;
#[test]
fn test_metal_impl_creation() {
let result = MetalImpl::new();
assert!(result.is_ok());
}
#[test]
fn test_device_info() {
let metal_impl = MetalImpl::new().expect("operation failed in test");
let info = metal_impl.device_info().expect("operation failed in test");
assert!(!info.is_empty());
}
#[cfg(all(target_os = "macos", feature = "metal"))]
#[test]
fn test_matrix_multiply() {
let metal_impl = MetalImpl::new().expect("operation failed in test");
let a =
Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor operation failed");
let b =
Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2]).expect("tensor operation failed");
let result = metal_impl.matrix_multiply(&a, &b);
assert!(result.is_ok());
let result_tensor = result.expect("tensor operation failed");
assert_eq!(result_tensor.shape(), &[2, 2]);
}
#[cfg(all(target_os = "macos", feature = "metal"))]
#[test]
fn test_add_tensors() {
let metal_impl = MetalImpl::new().expect("operation failed in test");
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]).expect("tensor operation failed");
let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], &[4]).expect("tensor operation failed");
let result = metal_impl.add_tensors(&a, &b);
assert!(result.is_ok());
let result_tensor = result.expect("tensor operation failed");
assert_eq!(result_tensor.shape(), &[4]);
let expected = [6.0, 8.0, 10.0, 12.0];
let actual = result_tensor.data_f32().expect("tensor operation failed");
for (a, e) in actual.iter().zip(expected.iter()) {
assert!((a - e).abs() < 1e-6);
}
}
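#[cfg(all(target_os = "macos", feature = "metal"))]
#[test]
fn test_gelu() {
let metal_impl = MetalImpl::new().expect("operation failed in test");
let input = Tensor::from_slice(&[-1.0, 0.0, 1.0], &[3]).expect("tensor operation failed");
let result = metal_impl.gelu(&input).expect("tensor operation failed");
assert_eq!(result.shape(), &[3]);
// Reference values of the tanh approximation: GELU(-1) ~ -0.1588,
// GELU(0) = 0, GELU(1) ~ 0.8412.
let expected = [-0.1588, 0.0, 0.8412];
let actual = result.data_f32().expect("tensor operation failed");
for (a, e) in actual.iter().zip(expected.iter()) {
assert!((a - e).abs() < 1e-3);
}
}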
}