use crate::tensor::Tensor;
use crate::error::{GhostError, Result};

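/// Handle to a Metal GPU device. The fields mirror properties exposed by
/// `MTLDevice` (name, low-power flag, GPU-family support).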
pub struct MetalDevice {
    pub device_id: usize,
    pub name: String,
    pub is_low_power: bool,
    pub supports_family_apple7: bool,
}

impl MetalDevice {
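    /// Creates a handle to the Metal device with the given index. Fails
    /// when the crate is built without the `metal` feature.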
    pub fn new(device_id: usize) -> Result<Self> {
        #[cfg(feature = "metal")]
        {
            Ok(MetalDevice {
                device_id,
                name: "Apple GPU".to_string(),
                is_low_power: false,
                supports_family_apple7: true,
            })
        }
        #[cfg(not(feature = "metal"))]
        {
            let _ = device_id;
            Err(GhostError::DeviceError(
                "Metal support not compiled. Enable the 'metal' feature.".to_string(),
            ))
        }
    }

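    /// Returns the number of visible Metal devices: 1 on macOS builds with
    /// the `metal` feature enabled, 0 otherwise.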
    pub fn device_count() -> Result<usize> {
        #[cfg(all(feature = "metal", target_os = "macos"))]
        {
            Ok(1)
        }
        #[cfg(not(all(feature = "metal", target_os = "macos")))]
        {
            Ok(0)
        }
    }

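    /// Reports whether this device is paired with an Apple Neural Engine.
    /// Currently a compile-time answer derived from the `metal` feature.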
    pub fn supports_neural_engine(&self) -> bool {
        #[cfg(feature = "metal")]
        {
            true
        }
        #[cfg(not(feature = "metal"))]
        {
            false
        }
    }
}

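/// A device-resident buffer of `size` bytes allocated on `device_id`.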
pub struct MetalBuffer {
    size: usize,
    #[allow(dead_code)] // not read yet in this stub implementation
    device_id: usize,
}

impl MetalBuffer {
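    /// Allocates a buffer of `size` bytes on the given device.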
    pub fn allocate(size: usize, device_id: usize) -> Result<Self> {
        #[cfg(feature = "metal")]
        {
            Ok(MetalBuffer { size, device_id })
        }
        #[cfg(not(feature = "metal"))]
        {
            let _ = (size, device_id);
            Err(GhostError::DeviceError("Metal not available".to_string()))
        }
    }

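    /// Copies `data` from host memory into the buffer, failing if the
    /// buffer is too small to hold it.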
    pub fn copy_from_host(&mut self, data: &[f32]) -> Result<()> {
        #[cfg(feature = "metal")]
        {
            if data.len() * std::mem::size_of::<f32>() > self.size {
                return Err(GhostError::DeviceError("Buffer too small".to_string()));
            }
            // Stub: a real implementation would memcpy into the MTLBuffer contents.
            Ok(())
        }
        #[cfg(not(feature = "metal"))]
        {
            let _ = data;
            Err(GhostError::DeviceError("Metal not available".to_string()))
        }
    }

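    /// Copies the buffer contents back into the host slice `data`, failing
    /// if more bytes are requested than the buffer holds.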
    pub fn copy_to_host(&self, data: &mut [f32]) -> Result<()> {
        #[cfg(feature = "metal")]
        {
            if data.len() * std::mem::size_of::<f32>() > self.size {
                return Err(GhostError::DeviceError("Buffer too small".to_string()));
            }
            // Stub: a real implementation would memcpy out of the MTLBuffer contents.
            Ok(())
        }
        #[cfg(not(feature = "metal"))]
        {
            let _ = data;
            Err(GhostError::DeviceError("Metal not available".to_string()))
        }
    }
}

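/// A compute pipeline identified by kernel name, plus the thread-group
/// dimensions used when dispatching it.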
pub struct MetalPipeline {
    #[allow(dead_code)] // not read yet in this stub implementation
    name: String,
    thread_group_size: (usize, usize, usize),
}

impl MetalPipeline {
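    /// Creates a pipeline for the named kernel with a default 256x1x1
    /// thread group.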
    pub fn new(name: &str) -> Self {
        MetalPipeline {
            name: name.to_string(),
            thread_group_size: (256, 1, 1),
        }
    }

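    /// Builder-style override of the thread-group dimensions.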
    pub fn thread_group_size(mut self, x: usize, y: usize, z: usize) -> Self {
        self.thread_group_size = (x, y, z);
        self
    }

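    /// Dispatches the kernel over the given grid of thread groups.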
    pub fn dispatch(&self, grid_size: (usize, usize, usize)) -> Result<()> {
        #[cfg(feature = "metal")]
        {
            // Stub: a real implementation would encode a compute command,
            // bind buffers, and dispatch `grid_size` thread groups.
            let _ = grid_size;
            Ok(())
        }
        #[cfg(not(feature = "metal"))]
        {
            let _ = grid_size;
            Err(GhostError::DeviceError("Metal not available".to_string()))
        }
    }
}

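/// Tensor operations intended for a Metal Performance Shaders-style backend,
/// with CPU/NEON fallbacks when the `metal` feature is disabled.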
pub mod mps {
    use super::*;

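    /// Matrix multiplication `C = A x B` for 2D tensors of shape [M, K]
    /// and [K, N], producing [M, N].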
    pub fn matmul_mps(a: &Tensor, b: &Tensor, device_id: usize) -> Result<Tensor> {
        let dims_a = a.dims();
        let dims_b = b.dims();

        if dims_a.len() != 2 || dims_b.len() != 2 {
            return Err(GhostError::InvalidShape("matmul requires 2D tensors".to_string()));
        }

        let (m, k) = (dims_a[0], dims_a[1]);
        let (k2, n) = (dims_b[0], dims_b[1]);

        if k != k2 {
            return Err(GhostError::ShapeMismatch {
                expected: vec![k],
                got: vec![k2],
            });
        }

        #[cfg(feature = "metal")]
        {
            let size_a = m * k * std::mem::size_of::<f32>();
            let size_b = k * n * std::mem::size_of::<f32>();
            let size_c = m * n * std::mem::size_of::<f32>();

            let mut buf_a = MetalBuffer::allocate(size_a, device_id)?;
            let mut buf_b = MetalBuffer::allocate(size_b, device_id)?;
            let buf_c = MetalBuffer::allocate(size_c, device_id)?;

            buf_a.copy_from_host(&a.data_f32())?;
            buf_b.copy_from_host(&b.data_f32())?;

            let pipeline = MetalPipeline::new("matmul_kernel")
                .thread_group_size(16, 16, 1);

            // Ceiling division: one 16x16 thread group per output tile.
            let grid_x = (n + 15) / 16;
            let grid_y = (m + 15) / 16;
            pipeline.dispatch((grid_x, grid_y, 1))?;

            let mut result_data = vec![0.0f32; m * n];
            buf_c.copy_to_host(&mut result_data)?;

            Tensor::from_slice(&result_data, &[m, n])
        }
        #[cfg(not(feature = "metal"))]
        {
            let _ = device_id;
            #[cfg(target_arch = "aarch64")]
            {
                let a_data = a.data_f32();
                let b_data = b.data_f32();
                let mut result = vec![0.0f32; m * n];
                crate::neon::matmul_neon(&a_data, &b_data, &mut result, m, n, k);
                Tensor::from_slice(&result, &[m, n])
            }
            #[cfg(not(target_arch = "aarch64"))]
            {
                a.matmul(b)
            }
        }
    }

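    /// 2D convolution over NCHW tensors. Not yet implemented on the Metal
    /// path; callers should use the CPU implementation for now.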
    pub fn conv2d_mps(
        input: &Tensor,
        kernel: &Tensor,
        stride: (usize, usize),
        padding: (usize, usize),
        device_id: usize,
    ) -> Result<Tensor> {
        let input_dims = input.dims();
        let kernel_dims = kernel.dims();

        if input_dims.len() != 4 || kernel_dims.len() != 4 {
            return Err(GhostError::InvalidShape("conv2d requires 4D tensors [N,C,H,W]".to_string()));
        }

        #[cfg(feature = "metal")]
        {
            let _ = (input, kernel, stride, padding, device_id);
            Err(GhostError::NotImplemented("Metal conv2d - use CPU fallback".to_string()))
        }
        #[cfg(not(feature = "metal"))]
        {
            let _ = (input, kernel, stride, padding, device_id);
            Err(GhostError::DeviceError("Metal not available".to_string()))
        }
    }

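    /// Element-wise ReLU, dispatched to the `relu_kernel` compute kernel,
    /// or to a CPU/NEON fallback without the `metal` feature.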
    pub fn relu_mps(input: &Tensor, device_id: usize) -> Result<Tensor> {
        #[cfg(feature = "metal")]
        {
            let data = input.data_f32();
            let size = data.len();
            let buf_size = size * std::mem::size_of::<f32>();
            let mut buf = MetalBuffer::allocate(buf_size, device_id)?;
            buf.copy_from_host(&data)?;

            let pipeline = MetalPipeline::new("relu_kernel")
                .thread_group_size(256, 1, 1);

            // Ceiling division: one 256-wide thread group per chunk of elements.
            let grid_size = (size + 255) / 256;
            pipeline.dispatch((grid_size, 1, 1))?;

            let mut result = vec![0.0f32; size];
            buf.copy_to_host(&mut result)?;

            Tensor::from_slice(&result, input.dims())
        }
        #[cfg(not(feature = "metal"))]
        {
            let _ = device_id;
            #[cfg(target_arch = "aarch64")]
            {
                Ok(input.relu_neon())
            }
            #[cfg(not(target_arch = "aarch64"))]
            {
                Ok(input.relu())
            }
        }
    }

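    /// Batch normalization. Not yet implemented on any backend.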
    pub fn batch_norm_mps(input: &Tensor, device_id: usize) -> Result<Tensor> {
        #[cfg(feature = "metal")]
        {
            let _ = (input, device_id);
            Err(GhostError::NotImplemented("Metal batch norm".to_string()))
        }
        #[cfg(not(feature = "metal"))]
        {
            let _ = (input, device_id);
            Err(GhostError::DeviceError("Metal not available".to_string()))
        }
    }
}

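/// Metal Shading Language source for the compute kernels dispatched above.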
#[cfg(feature = "metal")]
pub const METAL_KERNEL_SOURCE: &str = r#"
#include <metal_stdlib>
using namespace metal;

// Vector addition kernel
kernel void vector_add(
    device const float* a [[buffer(0)]],
    device const float* b [[buffer(1)]],
    device float* c [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    c[id] = a[id] + b[id];
}

// ReLU activation kernel
kernel void relu_kernel(
    device float* data [[buffer(0)]],
    uint id [[thread_position_in_grid]]
) {
    data[id] = max(0.0f, data[id]);
}

// Matrix multiplication kernel
kernel void matmul_kernel(
    device const float* A [[buffer(0)]],
    device const float* B [[buffer(1)]],
    device float* C [[buffer(2)]],
    constant uint& M [[buffer(3)]],
    constant uint& N [[buffer(4)]],
    constant uint& K [[buffer(5)]],
    uint2 gid [[thread_position_in_grid]]
) {
    uint row = gid.y;
    uint col = gid.x;

    if (row < M && col < N) {
        float sum = 0.0f;
        for (uint k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}

// Softmax kernel (naive: every thread recomputes the max and the sum,
// which is fine for small vectors but wasteful for large ones)
kernel void softmax_kernel(
    device const float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& size [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    // Find max for numerical stability
    float max_val = input[0];
    for (uint i = 1; i < size; i++) {
        max_val = max(max_val, input[i]);
    }

    // Compute exp and sum
    float sum = 0.0f;
    for (uint i = 0; i < size; i++) {
        sum += exp(input[i] - max_val);
    }

    // Normalize
    output[id] = exp(input[id] - max_val) / sum;
}
"#;

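/// Minimal interface to the Apple Neural Engine (ANE).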
#[cfg(feature = "metal")]
pub mod neural_engine {
    use super::*;

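    /// Reports whether an Apple Neural Engine is likely present. This is a
    /// compile-time heuristic: aarch64 targets built with the `metal`
    /// feature are assumed to be Apple silicon with an ANE.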
    pub fn is_available() -> bool {
        #[cfg(target_arch = "aarch64")]
        {
            true
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            false
        }
    }

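    /// Runs a compiled model on the Neural Engine. Inference itself is not
    /// implemented yet; this only validates availability.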
    pub fn run_inference(model: &str, input: &Tensor) -> Result<Tensor> {
        if !is_available() {
            return Err(GhostError::DeviceError("Neural Engine not available".to_string()));
        }

        let _ = (model, input);
        Err(GhostError::NotImplemented("Neural Engine inference".to_string()))
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metal_device_count() {
        let count = MetalDevice::device_count().unwrap_or(0);
        // `count` is a usize, so `count >= 0` would be a useless comparison;
        // the current stub reports at most one device.
        assert!(count <= 1);
    }

    #[test]
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn test_metal_device_creation() {
        if let Ok(device) = MetalDevice::new(0) {
            assert_eq!(device.device_id, 0);
            assert!(!device.name.is_empty());
        }
    }

    #[test]
    fn test_metal_pipeline() {
        let pipeline = MetalPipeline::new("test_kernel")
            .thread_group_size(256, 1, 1);

        assert_eq!(pipeline.thread_group_size, (256, 1, 1));
    }

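    // A minimal shape-mismatch sketch, assuming `Tensor::from_slice(&[f32],
    // &[usize]) -> Result<Tensor>` behaves as it is used in `matmul_mps`
    // above: mismatched inner dimensions must surface as an error on every
    // backend, Metal or fallback, since the check runs before dispatch.
    #[test]
    fn test_matmul_mps_shape_mismatch() {
        let a = Tensor::from_slice(&[0.0f32; 6], &[2, 3]).unwrap();
        let b = Tensor::from_slice(&[0.0f32; 8], &[4, 2]).unwrap();
        assert!(mps::matmul_mps(&a, &b, 0).is_err());
    }
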
    #[test]
    #[cfg(feature = "metal")]
    fn test_neural_engine_availability() {
        let available = neural_engine::is_available();
        #[cfg(target_arch = "aarch64")]
        assert!(available);
        #[cfg(not(target_arch = "aarch64"))]
        assert!(!available);
    }
}