// numrs/backend/metal/mod.rs

1use crate::array::Array;
2use anyhow::{Result, anyhow};
3use once_cell::sync::OnceCell;
4
5#[cfg(target_os = "macos")]
6use metal::*;
7#[cfg(target_os = "macos")]
8use objc::rc::autoreleasepool;
9#[cfg(target_os = "macos")]
10use std::sync::Mutex;
11#[cfg(target_os = "macos")]
12use std::collections::HashMap;
13
/// Thin handle for the Metal GPU backend.
///
/// Carries no state of its own; the device, command queue, and pipeline
/// caches live in the lazily-initialised process-wide `MetalContext`.
#[derive(Debug, Clone)]
pub struct MetalBackend {}

impl MetalBackend {
    /// Create a new backend handle (zero-cost; no device work happens here).
    pub fn new() -> Self { Self {} }

    /// Check if Metal is available (macOS only).
    ///
    /// On macOS this probes for a system-default Metal device; on every
    /// other platform it is a compile-time `false`.
    pub fn is_available() -> bool {
        #[cfg(target_os = "macos")]
        {
            Device::system_default().is_some()
        }
        #[cfg(not(target_os = "macos"))]
        {
            false
        }
    }
}

// `new()` takes no arguments, so provide `Default` as well (clippy:
// `new_without_default`); useful for `..Default::default()` call sites.
impl Default for MetalBackend {
    fn default() -> Self {
        Self::new()
    }
}
32
33// ============================================================================
34// Optimized Metal Context with Caching
35// ============================================================================
36
/// Process-wide Metal state, created once and cached in `METAL_DEVICE`.
///
/// Holds the device and command queue plus pre-compiled compute pipelines,
/// so shader compilation happens exactly once per process.
#[cfg(target_os = "macos")]
struct MetalContext {
    device: Device,
    queue: CommandQueue,
    // Device capability: `max_threads_per_threadgroup().width`, used to cap
    // threadgroup sizes at dispatch time.
    max_threads_per_threadgroup: u64,
    // Cached pipelines, compiled at context creation.
    // Compiled but not used by the current dispatch path (the scalar pipeline
    // is used instead) — hence the leading underscore.
    _elementwise_vec4_pipeline: ComputePipelineState,
    elementwise_scalar_pipeline: ComputePipelineState,
    // Matmul pipelines compiled on demand; the current caller only ever uses
    // the key (16, 0).
    matmul_pipeline_cache: Mutex<HashMap<(u32, u32), ComputePipelineState>>,
    reduction_pipeline: ComputePipelineState,
    // Pool for recycling Metal buffers across dispatches.
    buffer_pool: Mutex<BufferPool>,
}
50
/// Simple size-bucketed pool of reusable Metal buffers.
#[cfg(target_os = "macos")]
struct BufferPool {
    // Free buffers keyed by bucket size in bytes (requests are rounded up to
    // the next power of two for better reuse).
    free_buffers: HashMap<usize, Vec<Buffer>>,
    // Largest individual buffer size (bytes) eligible for pooling.
    max_cached_size: usize,
}
56
#[cfg(target_os = "macos")]
impl BufferPool {
    /// Create an empty pool.
    fn new() -> Self {
        Self {
            free_buffers: HashMap::new(),
            // Buffers larger than this (bytes) bypass the pool entirely:
            // they are allocated exactly-sized and dropped on return.
            max_cached_size: 100 * 1024 * 1024, // 100 MB
        }
    }

    /// Fetch a pooled buffer of at least `size` bytes, or allocate a new one.
    ///
    /// Poolable requests are rounded up to the next power of two so buffers
    /// can be reused across requests of similar size. Requests above
    /// `max_cached_size` are allocated at the exact requested size, since
    /// they will never be cached anyway — rounding them up could waste
    /// nearly 2x the memory for no reuse benefit.
    fn get_or_create(&mut self, device: &Device, size: usize, mode: MTLResourceOptions) -> Buffer {
        let bucket_size = size.next_power_of_two();

        if bucket_size <= self.max_cached_size {
            // Try to reuse a previously returned buffer from this bucket.
            if let Some(buffers) = self.free_buffers.get_mut(&bucket_size) {
                if let Some(buffer) = buffers.pop() {
                    return buffer;
                }
            }
            return device.new_buffer(bucket_size as u64, mode);
        }

        // Too large to ever pool: allocate exactly what was asked for.
        device.new_buffer(size as u64, mode)
    }

    /// Return a buffer to the pool for reuse.
    ///
    /// `size` must be the same value passed to `get_or_create`; buffers whose
    /// bucket exceeds `max_cached_size` are simply dropped.
    fn return_buffer(&mut self, buffer: Buffer, size: usize) {
        let bucket_size = size.next_power_of_two();

        if bucket_size <= self.max_cached_size {
            self.free_buffers.entry(bucket_size).or_default().push(buffer);
        }
    }
}
91
// Initialisation result is cached — including a failure, which permanently
// poisons the cell so we don't retry device/shader setup on every call.
#[cfg(target_os = "macos")]
static METAL_DEVICE: OnceCell<Result<MetalContext, anyhow::Error>> = OnceCell::new();

/// Get (or lazily initialise) the process-wide Metal context.
///
/// Device lookup, queue creation, and shader compilation run exactly once;
/// later calls return the cached context. If initialisation failed, the
/// stored error is re-reported on every call.
#[cfg(target_os = "macos")]
fn get_metal_device() -> Result<&'static MetalContext> {
    // Use the reference returned by `get_or_init` directly instead of a
    // redundant second `get().unwrap()` round-trip.
    let cached = METAL_DEVICE.get_or_init(|| {
        autoreleasepool(|| {
            let device = Device::system_default()
                .ok_or_else(|| anyhow!("No Metal device available"))?;
            let queue = device.new_command_queue();

            // Device capability used to size threadgroups at dispatch time.
            let max_threads_per_threadgroup = device.max_threads_per_threadgroup().width;

            // Compile all shaders ONCE.
            let elementwise_vec4_pipeline = compile_elementwise_vec4_pipeline(&device)?;
            let elementwise_scalar_pipeline = compile_elementwise_scalar_pipeline(&device)?;
            let reduction_pipeline = compile_reduction_pipeline(&device)?;

            Ok(MetalContext {
                device,
                queue,
                max_threads_per_threadgroup,
                _elementwise_vec4_pipeline: elementwise_vec4_pipeline,
                elementwise_scalar_pipeline,
                matmul_pipeline_cache: Mutex::new(HashMap::new()),
                reduction_pipeline,
                buffer_pool: Mutex::new(BufferPool::new()),
            })
        })
    });

    match cached {
        Ok(ctx) => Ok(ctx),
        Err(e) => Err(anyhow!("Metal init failed: {:?}", e)),
    }
}
129
130// ============================================================================
131// Shader Compilation (ONCE per shader type)
132// ============================================================================
133
/// Compile the vectorised (float4) elementwise kernel.
///
/// Compiled at startup but not used by the current dispatch path (see
/// `run_elementwise_metal_optimized`, which uses the scalar pipeline).
/// NOTE: the guard `idx >= params.size / 4` truncates, so a tail of 1-3
/// elements is left unprocessed when `size` is not a multiple of 4 — a
/// dispatch path adopting this kernel must handle the remainder.
///
/// Fix: `kind_to_u32` maps `Neg` to op code 20, which previously fell
/// through to the shader's `default:` branch (identity). `OP_NEG` is now
/// handled explicitly.
#[cfg(target_os = "macos")]
fn compile_elementwise_vec4_pipeline(device: &Device) -> Result<ComputePipelineState> {
    let shader_src = r#"
#include <metal_stdlib>
using namespace metal;

constant uint OP_ADD = 0; constant uint OP_MUL = 1; constant uint OP_SUB = 2; constant uint OP_DIV = 3;
constant uint OP_SQRT = 4; constant uint OP_SIN = 5; constant uint OP_COS = 6; constant uint OP_POW = 7;
constant uint OP_ABS = 8; constant uint OP_EXP = 9; constant uint OP_LOG = 10; constant uint OP_TAN = 11;
constant uint OP_ASIN = 12; constant uint OP_ACOS = 13; constant uint OP_ATAN = 14; constant uint OP_RELU = 15;
constant uint OP_LEAKY_RELU = 16; constant uint OP_SIGMOID = 17; constant uint OP_TANH = 18; constant uint OP_SOFTPLUS = 19;
constant uint OP_NEG = 20;

struct Params { uint size; uint op_kind; };

kernel void elementwise_vec4(
    device const float4* a [[buffer(0)]],
    device const float4* b [[buffer(1)]],
    device float4* out [[buffer(2)]],
    constant Params& params [[buffer(3)]],
    uint idx [[thread_position_in_grid]]
) {
    if (idx >= params.size / 4) return;
    float4 a_val = a[idx];
    float4 b_val = b[idx];
    float4 result;
    
    switch(params.op_kind) {
        case OP_ADD: result = a_val + b_val; break;
        case OP_MUL: result = a_val * b_val; break;
        case OP_SUB: result = a_val - b_val; break;
        case OP_DIV: result = a_val / b_val; break;
        case OP_SQRT: result = fast::sqrt(a_val); break;
        case OP_SIN: result = fast::sin(a_val); break;
        case OP_COS: result = fast::cos(a_val); break;
        case OP_POW: result = fast::pow(a_val, b_val); break;
        case OP_ABS: result = fast::fabs(a_val); break;
        case OP_EXP: result = fast::exp(a_val); break;
        case OP_LOG: result = fast::log(a_val); break;
        case OP_TAN: result = fast::tan(a_val); break;
        case OP_ASIN: result = asin(a_val); break;
        case OP_ACOS: result = acos(a_val); break;
        case OP_ATAN: result = atan(a_val); break;
        case OP_RELU: result = fast::max(a_val, float4(0.0f)); break;
        case OP_LEAKY_RELU: result = select(float4(0.01f) * a_val, a_val, a_val > float4(0.0f)); break;
        case OP_SIGMOID: result = 1.0f / (1.0f + fast::exp(-a_val)); break;
        case OP_TANH: result = fast::tanh(a_val); break;
        case OP_SOFTPLUS: result = fast::log(1.0f + fast::exp(a_val)); break;
        case OP_NEG: result = -a_val; break;
        default: result = a_val; break;
    }
    out[idx] = result;
}
"#;

    let library = device.new_library_with_source(shader_src, &CompileOptions::new())
        .map_err(|e| anyhow!("Failed to compile vec4 shader: {}", e))?;
    let kernel = library.get_function("elementwise_vec4", None)
        .map_err(|e| anyhow!("Failed to get vec4 kernel: {}", e))?;
    device.new_compute_pipeline_state_with_function(&kernel)
        .map_err(|e| anyhow!("Failed to create vec4 pipeline: {}", e))
}
194
/// Compile the scalar elementwise kernel (one thread per element).
///
/// This is the pipeline actually used by `run_elementwise_metal_optimized`.
/// Op codes must stay in sync with `kind_to_u32`.
///
/// Fix: `kind_to_u32` maps `Neg` to op code 20, which previously fell
/// through to the shader's `default:` branch and returned the input
/// unchanged. `OP_NEG` is now handled explicitly.
#[cfg(target_os = "macos")]
fn compile_elementwise_scalar_pipeline(device: &Device) -> Result<ComputePipelineState> {
    let shader_src = r#"
#include <metal_stdlib>
using namespace metal;

constant uint OP_ADD = 0; constant uint OP_MUL = 1; constant uint OP_SUB = 2; constant uint OP_DIV = 3;
constant uint OP_SQRT = 4; constant uint OP_SIN = 5; constant uint OP_COS = 6; constant uint OP_POW = 7;
constant uint OP_ABS = 8; constant uint OP_EXP = 9; constant uint OP_LOG = 10; constant uint OP_TAN = 11;
constant uint OP_ASIN = 12; constant uint OP_ACOS = 13; constant uint OP_ATAN = 14; constant uint OP_RELU = 15;
constant uint OP_LEAKY_RELU = 16; constant uint OP_SIGMOID = 17; constant uint OP_TANH = 18; constant uint OP_SOFTPLUS = 19;
constant uint OP_NEG = 20;

struct Params { uint size; uint op_kind; };

kernel void elementwise_scalar(
    device const float* a [[buffer(0)]],
    device const float* b [[buffer(1)]],
    device float* out [[buffer(2)]],
    constant Params& params [[buffer(3)]],
    uint idx [[thread_position_in_grid]]
) {
    if (idx >= params.size) return;
    float a_val = a[idx];
    float b_val = b[idx];
    float result;
    
    switch(params.op_kind) {
        case OP_ADD: result = a_val + b_val; break;
        case OP_MUL: result = a_val * b_val; break;
        case OP_SUB: result = a_val - b_val; break;
        case OP_DIV: result = a_val / b_val; break;
        case OP_SQRT: result = fast::sqrt(a_val); break;
        case OP_SIN: result = fast::sin(a_val); break;
        case OP_COS: result = fast::cos(a_val); break;
        case OP_POW: result = fast::pow(a_val, b_val); break;
        case OP_ABS: result = fast::fabs(a_val); break;
        case OP_EXP: result = fast::exp(a_val); break;
        case OP_LOG: result = fast::log(a_val); break;
        case OP_TAN: result = fast::tan(a_val); break;
        case OP_ASIN: result = asin(a_val); break;
        case OP_ACOS: result = acos(a_val); break;
        case OP_ATAN: result = atan(a_val); break;
        case OP_RELU: result = fast::max(a_val, 0.0f); break;
        case OP_LEAKY_RELU: result = (a_val > 0.0f) ? a_val : 0.01f * a_val; break;
        case OP_SIGMOID: result = 1.0f / (1.0f + fast::exp(-a_val)); break;
        case OP_TANH: result = fast::tanh(a_val); break;
        case OP_SOFTPLUS: result = fast::log(1.0f + fast::exp(a_val)); break;
        case OP_NEG: result = -a_val; break;
        default: result = a_val; break;
    }
    out[idx] = result;
}
"#;

    let library = device.new_library_with_source(shader_src, &CompileOptions::new())
        .map_err(|e| anyhow!("Failed to compile scalar shader: {}", e))?;
    let kernel = library.get_function("elementwise_scalar", None)
        .map_err(|e| anyhow!("Failed to get scalar kernel: {}", e))?;
    device.new_compute_pipeline_state_with_function(&kernel)
        .map_err(|e| anyhow!("Failed to create scalar pipeline: {}", e))
}
255
/// Compile the single-stage sum-reduction kernel.
///
/// Each 256-thread threadgroup loads its slice of `data` into threadgroup
/// memory, tree-reduces it, and has thread 0 write one partial sum to
/// `partials[group_id]`. Callers run it repeatedly until one value remains
/// (see `run_reduction_metal_optimized`). The threadgroup width dispatched
/// must be exactly WG_SIZE (256) for the tree reduction to be correct.
#[cfg(target_os = "macos")]
fn compile_reduction_pipeline(device: &Device) -> Result<ComputePipelineState> {
    // Optimized reduction using manual threadgroup reduction
    let shader_src = r#"
#include <metal_stdlib>
using namespace metal;

constant uint WG_SIZE = 256;

struct Params {
    uint size;
};

kernel void reduction_sum(
    device const float* data [[buffer(0)]],
    device float* partials [[buffer(1)]],
    constant Params& params [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]]
) {
    threadgroup float shared[WG_SIZE];
    
    // Load data
    float value = 0.0f;
    if (gid < params.size) {
        value = data[gid];
    }
    shared[lid] = value;
    
    threadgroup_barrier(mem_flags::mem_threadgroup);
    
    // Tree reduction in threadgroup memory
    for (uint s = WG_SIZE / 2; s > 0; s >>= 1) {
        if (lid < s) {
            shared[lid] += shared[lid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    
    // First thread writes result
    if (lid == 0) {
        partials[group_id] = shared[0];
    }
}
"#;

    let library = device.new_library_with_source(shader_src, &CompileOptions::new())
        .map_err(|e| anyhow!("Failed to compile reduction shader: {}", e))?;
    
    let kernel = library.get_function("reduction_sum", None)
        .map_err(|e| anyhow!("Failed to get reduction kernel: {}", e))?;
    
    device.new_compute_pipeline_state_with_function(&kernel)
        .map_err(|e| anyhow!("Failed to create reduction pipeline: {}", e))
}
312
/// Compile the matmul kernel for `a (m x k) * b (k x n) -> out (m x n)`,
/// with one thread per output element (row-major layouts).
///
/// NOTE(review): despite the name, this function does not itself cache —
/// the caller caches via `matmul_pipeline_cache`. Also, the generated `TILE`
/// constant is never referenced by the kernel body, so `tile_size` currently
/// has no effect on the compiled code — verify before using it for tuning.
#[cfg(target_os = "macos")]
fn get_or_compile_matmul_pipeline(device: &Device, tile_size: u32) -> Result<ComputePipelineState> {
    let shader_src = format!(r#"
#include <metal_stdlib>
using namespace metal;

constant uint TILE = {tile};

struct Params {{
    uint m;
    uint n;
    uint k;
}};

kernel void matmul_tiled(
    device const float* a [[buffer(0)]],
    device const float* b [[buffer(1)]],
    device float* out [[buffer(2)]],
    constant Params& params [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]]
) {{
    // Simple safe implementation: each thread computes one output element
    uint row = gid.y;
    uint col = gid.x;
    
    if (row >= params.m || col >= params.n) {{
        return;
    }}
    
    float sum = 0.0f;
    for (uint k = 0; k < params.k; k++) {{
        sum = fast::fma(a[row * params.k + k], b[k * params.n + col], sum);
    }}
    
    out[row * params.n + col] = sum;
}}
"#, tile = tile_size);

    let library = device.new_library_with_source(&shader_src, &CompileOptions::new())
        .map_err(|e| anyhow!("Failed to compile matmul shader: {}", e))?;
    
    let kernel = library.get_function("matmul_tiled", None)
        .map_err(|e| anyhow!("Failed to get matmul kernel: {}", e))?;
    
    device.new_compute_pipeline_state_with_function(&kernel)
        .map_err(|e| anyhow!("Failed to create matmul pipeline: {}", e))
}
360
361/// Cached probe helper
362pub fn is_available_cached() -> bool {
363    static PROBE: OnceCell<bool> = OnceCell::new();
364    *PROBE.get_or_init(|| MetalBackend::is_available())
365}
366
367// ============================================================================
368// Public API for Dispatch System
369// ============================================================================
370
/// Elementwise operations on Metal (public API for dispatch)
///
/// On macOS, runs `kind` over `a` and `b` on the GPU and returns a new
/// host-side `Array`. On every other target this is a stub that always
/// returns an error, so callers can fall back to another backend.
pub fn elementwise_metal(a: &Array, b: &Array, kind: crate::llo::ElementwiseKind) -> Result<Array> {
    #[cfg(target_os = "macos")]
    {
        run_elementwise_metal_optimized(a, b, kind)
    }
    #[cfg(not(target_os = "macos"))]
    {
        // Silence unused-parameter warnings on non-macOS builds.
        let _ = (a, b, kind);
        Err(anyhow!("Metal backend only available on macOS"))
    }
}
383
/// Matrix multiplication on Metal (public API for dispatch)
///
/// On macOS, computes `a * b` on the GPU and returns a new host-side
/// `Array`. On every other target this is a stub that always returns an
/// error, so callers can fall back to another backend.
pub fn matmul_metal(a: &Array, b: &Array) -> Result<Array> {
    #[cfg(target_os = "macos")]
    {
        run_matmul_metal_optimized(a, b)
    }
    #[cfg(not(target_os = "macos"))]
    {
        // Silence unused-parameter warnings on non-macOS builds.
        let _ = (a, b);
        Err(anyhow!("Metal backend only available on macOS"))
    }
}
396
/// Reduction operations on Metal (public API for dispatch)
///
/// On macOS, sum-reduces `a` on the GPU (only full reduction; `axis` must be
/// `None`). On every other target this is a stub that always returns an
/// error, so callers can fall back to another backend.
pub fn reduction_metal(a: &Array, axis: Option<usize>) -> Result<Array> {
    #[cfg(target_os = "macos")]
    {
        run_reduction_metal_optimized(a, axis)
    }
    #[cfg(not(target_os = "macos"))]
    {
        // Silence unused-parameter warnings on non-macOS builds.
        let _ = (a, axis);
        Err(anyhow!("Metal backend only available on macOS"))
    }
}
409
410// ============================================================================
411// Optimized Implementation (macOS only)
412// ============================================================================
413
/// Map an `ElementwiseKind` to the numeric op code passed to the GPU shaders.
///
/// These codes must stay in sync with the `OP_*` constants in the elementwise
/// shader sources in this file; a code without a matching shader `case` falls
/// through to the shader's `default:` branch, which returns the input
/// unchanged.
#[cfg(target_os = "macos")]
fn kind_to_u32(kind: crate::llo::ElementwiseKind) -> u32 {
    use crate::llo::ElementwiseKind::*;
    match kind {
        Add => 0, Mul => 1, Sub => 2, Div => 3,
        Sqrt => 4, Sin => 5, Cos => 6, Pow => 7,
        Abs => 8, Exp => 9, Log => 10, Tan => 11,
        Asin => 12, Acos => 13, Atan => 14,
        Relu => 15, LeakyRelu => 16, Sigmoid => 17,
        Tanh => 18, Softplus => 19, Neg => 20,
    }
}
426
/// Run an elementwise op on the GPU and return a new host-side `Array`.
///
/// Uploads both operands into pooled shared-memory buffers, dispatches the
/// scalar pipeline with one thread per element of `a`, blocks until the GPU
/// finishes, then copies the result back and recycles the buffers.
///
/// NOTE(review): the kernel unconditionally reads `b[idx]` for every
/// `idx < a.len()` (even for unary ops), so this assumes `b.data` is at
/// least as long as `a.data` — confirm the dispatch layer guarantees this.
#[cfg(target_os = "macos")]
fn run_elementwise_metal_optimized(a: &Array, b: &Array, kind: crate::llo::ElementwiseKind) -> Result<Array> {
    let ctx = get_metal_device()?;
    let len = a.len();

    let command_buffer = ctx.queue.new_command_buffer();

    // SAFETY: reinterprets the f32 slices as raw bytes. The byte length is
    // derived from the same slice, and f32 has no padding bytes.
    let a_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(a.data.as_ptr() as *const u8, a.data.len() * std::mem::size_of::<f32>())
    };
    let b_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(b.data.as_ptr() as *const u8, b.data.len() * std::mem::size_of::<f32>())
    };

    // Use buffer pool with Shared mode (simple and safe)
    let mut pool = ctx.buffer_pool.lock().unwrap();
    
    let a_buf = {
        let buf = pool.get_or_create(&ctx.device, a_bytes.len(), MTLResourceOptions::StorageModeShared);
        // SAFETY: StorageModeShared buffers are CPU-visible, and the pool
        // always returns a buffer of at least the requested size.
        unsafe {
            std::ptr::copy_nonoverlapping(
                a_bytes.as_ptr(),
                buf.contents() as *mut u8,
                a_bytes.len()
            );
        }
        buf
    };

    let b_buf = {
        let buf = pool.get_or_create(&ctx.device, b_bytes.len(), MTLResourceOptions::StorageModeShared);
        // SAFETY: same invariants as the copy into `a_buf` above.
        unsafe {
            std::ptr::copy_nonoverlapping(
                b_bytes.as_ptr(),
                buf.contents() as *mut u8,
                b_bytes.len()
            );
        }
        buf
    };

    // Output buffer; the kernel writes every element idx < len, so no
    // zero-initialisation of the (possibly recycled) buffer is needed.
    let out_buf = pool.get_or_create(
        &ctx.device,
        len * std::mem::size_of::<f32>(),
        MTLResourceOptions::StorageModeShared
    );

    drop(pool); // Release lock

    // Compute encoder - use ONLY scalar pipeline to avoid alignment issues
    let encoder = command_buffer.new_compute_command_encoder();
    let op_kind = kind_to_u32(kind);

    // Params layout matches the shader's `struct Params { uint size; uint op_kind; }`.
    let params = [len as u32, op_kind];
    // SAFETY: reinterprets the u32 array as bytes; u32 has no padding.
    let params_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(params.as_ptr() as *const u8, params.len() * std::mem::size_of::<u32>())
    };
    let params_buf = ctx.device.new_buffer_with_data(
        params_bytes.as_ptr() as *const _,
        params_bytes.len() as u64,
        MTLResourceOptions::StorageModeShared,
    );

    // Buffer indices must match the [[buffer(N)]] attributes in the shader.
    encoder.set_compute_pipeline_state(&ctx.elementwise_scalar_pipeline);
    encoder.set_buffer(0, Some(&a_buf), 0);
    encoder.set_buffer(1, Some(&b_buf), 0);
    encoder.set_buffer(2, Some(&out_buf), 0);
    encoder.set_buffer(3, Some(&params_buf), 0);

    // One thread per element, capped at 256 threads per group (or the
    // device's limit, whichever is smaller).
    let thread_count = MTLSize::new(len as u64, 1, 1);
    let thread_group_size = MTLSize::new(ctx.max_threads_per_threadgroup.min(256), 1, 1);
    encoder.dispatch_threads(thread_count, thread_group_size);

    encoder.end_encoding();

    // Synchronous execution: block until the GPU is done before reading back.
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Read back results
    let out_ptr = out_buf.contents() as *const f32;
    // SAFETY: out_buf is Shared mode and holds at least `len` f32s, all
    // written by the kernel above.
    let out_slice = unsafe { std::slice::from_raw_parts(out_ptr, len) };
    let result = out_slice.to_vec();

    // Return buffers to pool
    let mut pool = ctx.buffer_pool.lock().unwrap();
    pool.return_buffer(a_buf, a_bytes.len());
    pool.return_buffer(b_buf, b_bytes.len());
    pool.return_buffer(out_buf, len * std::mem::size_of::<f32>());

    Ok(Array::new(a.shape.clone(), result))
}
518
/// GPU matrix multiply: `a (m x k) * b (k x n) -> (m x n)`, row-major.
///
/// Validates shapes, uploads operands into pooled shared-memory buffers,
/// runs the cached matmul pipeline with one thread per output element, and
/// blocks until the result has been copied back to the host.
#[cfg(target_os = "macos")]
fn run_matmul_metal_optimized(a: &Array, b: &Array) -> Result<Array> {
    let ctx = get_metal_device()?;

    // Validate shapes up front: previously a non-2D input panicked on the
    // `shape[1]` index, and mismatched inner dimensions silently produced
    // garbage (the kernel would read past the smaller operand's data).
    if a.shape.len() != 2 || b.shape.len() != 2 {
        return Err(anyhow!(
            "matmul requires 2-D operands, got shapes {:?} and {:?}",
            a.shape, b.shape
        ));
    }
    if a.shape[1] != b.shape[0] {
        return Err(anyhow!(
            "matmul inner dimension mismatch: {:?} x {:?}",
            a.shape, b.shape
        ));
    }

    let m = a.shape[0] as u32;
    let k = a.shape[1] as u32;
    let n = b.shape[1] as u32;
    let len = (m * n) as usize;

    // Pipeline cache lookup. tile_size does not affect the current kernel
    // body, so a single (16, 0) cache entry suffices.
    let pipeline = {
        let mut cache = ctx.matmul_pipeline_cache.lock().unwrap();
        if let Some(p) = cache.get(&(16, 0)) {
            p.clone()
        } else {
            let p = get_or_compile_matmul_pipeline(&ctx.device, 16)?;
            cache.insert((16, 0), p.clone());
            p
        }
    };

    // SAFETY: reinterprets the f32 slices as raw bytes; lengths are derived
    // from the same slices, and f32 has no padding bytes.
    let a_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(a.data.as_ptr() as *const u8, a.data.len() * std::mem::size_of::<f32>())
    };
    let b_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(b.data.as_ptr() as *const u8, b.data.len() * std::mem::size_of::<f32>())
    };

    // Use buffer pool
    let mut pool = ctx.buffer_pool.lock().unwrap();

    let a_buf = {
        let buf = pool.get_or_create(&ctx.device, a_bytes.len(), MTLResourceOptions::StorageModeShared);
        // SAFETY: Shared-mode buffer is CPU-visible and at least
        // `a_bytes.len()` bytes long.
        unsafe {
            std::ptr::copy_nonoverlapping(
                a_bytes.as_ptr(),
                buf.contents() as *mut u8,
                a_bytes.len()
            );
        }
        buf
    };

    let b_buf = {
        let buf = pool.get_or_create(&ctx.device, b_bytes.len(), MTLResourceOptions::StorageModeShared);
        // SAFETY: same invariants as the copy into `a_buf` above.
        unsafe {
            std::ptr::copy_nonoverlapping(
                b_bytes.as_ptr(),
                buf.contents() as *mut u8,
                b_bytes.len()
            );
        }
        buf
    };

    // Output buffer; the kernel writes every (row, col) in m x n.
    let out_buf = pool.get_or_create(
        &ctx.device,
        len * std::mem::size_of::<f32>(),
        MTLResourceOptions::StorageModeShared,
    );

    drop(pool); // Release lock

    let command_buffer = ctx.queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();

    // Params layout matches the shader's `struct Params { uint m; uint n; uint k; }`.
    let params = [m, n, k];
    // SAFETY: reinterprets the u32 array as bytes; u32 has no padding.
    let params_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(params.as_ptr() as *const u8, params.len() * std::mem::size_of::<u32>())
    };
    let params_buf = ctx.device.new_buffer_with_data(
        params_bytes.as_ptr() as *const _,
        params_bytes.len() as u64,
        MTLResourceOptions::StorageModeShared,
    );

    // Buffer indices must match the [[buffer(N)]] attributes in the shader.
    encoder.set_compute_pipeline_state(&pipeline);
    encoder.set_buffer(0, Some(&a_buf), 0);
    encoder.set_buffer(1, Some(&b_buf), 0);
    encoder.set_buffer(2, Some(&out_buf), 0);
    encoder.set_buffer(3, Some(&params_buf), 0);

    // Simple 2D dispatch: one thread per output element (x = col, y = row);
    // the kernel bounds-checks, so a non-multiple grid is fine.
    let thread_group_size = MTLSize::new(16, 16, 1);
    let grid_size = MTLSize::new(n as u64, m as u64, 1);
    encoder.dispatch_threads(grid_size, thread_group_size);

    encoder.end_encoding();

    // Synchronous execution: block until the GPU is done before reading back.
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Read back results
    let out_ptr = out_buf.contents() as *const f32;
    // SAFETY: out_buf is Shared mode and holds at least `len` f32s.
    let out_slice = unsafe { std::slice::from_raw_parts(out_ptr, len) };
    let result = out_slice.to_vec();

    // Return buffers to pool
    let mut pool = ctx.buffer_pool.lock().unwrap();
    pool.return_buffer(a_buf, a_bytes.len());
    pool.return_buffer(b_buf, b_bytes.len());
    pool.return_buffer(out_buf, len * std::mem::size_of::<f32>());

    Ok(Array::new(vec![m as usize, n as usize], result))
}
624
/// Sum-reduce `a` to a single value on the GPU.
///
/// Repeatedly applies the 256-wide threadgroup reduction kernel, ping-ponging
/// partial sums between two temp buffers until one value remains. Axis-based
/// reduction is not implemented and returns an error.
#[cfg(target_os = "macos")]
fn run_reduction_metal_optimized(a: &Array, axis: Option<usize>) -> Result<Array> {
    let ctx = get_metal_device()?;

    if axis.is_some() {
        return Err(anyhow!("axis-based reduction not implemented in Metal prototype"));
    }

    let size = a.len() as u32;
    if size == 0 {
        // Sum of an empty array is 0.
        return Ok(Array::new(vec![1], vec![0.0]));
    }

    // Must match WG_SIZE in the reduction shader and the dispatched
    // threadgroup width below.
    const WG_SIZE: u32 = 256;

    // SAFETY: reinterprets the f32 slice as raw bytes; the length is derived
    // from the same slice, and f32 has no padding bytes.
    let data_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(a.data.as_ptr() as *const u8, a.data.len() * std::mem::size_of::<f32>())
    };

    // Use buffer pool - ALL Shared mode for simplicity
    let mut pool = ctx.buffer_pool.lock().unwrap();

    // Input buffer with data
    let in_buf = {
        let buf = pool.get_or_create(&ctx.device, data_bytes.len(), MTLResourceOptions::StorageModeShared);
        // SAFETY: Shared-mode buffer is CPU-visible and at least
        // `data_bytes.len()` bytes long.
        unsafe {
            std::ptr::copy_nonoverlapping(
                data_bytes.as_ptr(),
                buf.contents() as *mut u8,
                data_bytes.len()
            );
        }
        buf
    };

    // The temp buffers only ever hold per-threadgroup partial sums, and the
    // first stage produces the most: ceil(size / WG_SIZE). Every later stage
    // produces strictly fewer, so this bound covers all iterations. (The
    // previous sizing used `size` elements — as large as the whole input.)
    let max_groups = ((size + WG_SIZE - 1) / WG_SIZE).max(1);

    // Preallocate the ping-pong partial-sum buffers (Shared mode).
    let temp_buf1 = pool.get_or_create(&ctx.device, max_groups as usize * std::mem::size_of::<f32>(), MTLResourceOptions::StorageModeShared);
    let temp_buf2 = pool.get_or_create(&ctx.device, max_groups as usize * std::mem::size_of::<f32>(), MTLResourceOptions::StorageModeShared);

    drop(pool); // Release lock

    // Single command buffer/encoder for all reduction stages.
    let command_buffer = ctx.queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();

    let mut current_size = size;
    let mut iteration = 0;

    loop {
        let groups = (current_size + WG_SIZE - 1) / WG_SIZE;
        let is_final = groups == 1;

        // Params layout matches the shader's `struct Params { uint size; }`.
        let params = [current_size];
        // SAFETY: reinterprets the u32 array as bytes; u32 has no padding.
        let params_bytes: &[u8] = unsafe {
            std::slice::from_raw_parts(params.as_ptr() as *const u8, params.len() * std::mem::size_of::<u32>())
        };
        let params_buf = ctx.device.new_buffer_with_data(
            params_bytes.as_ptr() as *const _,
            params_bytes.len() as u64,
            MTLResourceOptions::StorageModeShared,
        );

        encoder.set_compute_pipeline_state(&ctx.reduction_pipeline);

        // Ping-pong: stage 0 reads the input into temp_buf1; odd stages read
        // temp_buf1 and write temp_buf2; later even stages do the reverse.
        let (input_buf, output_buf) = if iteration == 0 {
            (&in_buf, &temp_buf1)
        } else if iteration % 2 == 1 {
            (&temp_buf1, &temp_buf2)
        } else {
            (&temp_buf2, &temp_buf1)
        };

        encoder.set_buffer(0, Some(input_buf), 0);
        encoder.set_buffer(1, Some(output_buf), 0);
        encoder.set_buffer(2, Some(&params_buf), 0);

        // Dispatch full threadgroups; the kernel guards `gid < size` itself.
        let thread_count = MTLSize::new((groups * WG_SIZE) as u64, 1, 1);
        let thread_group_size = MTLSize::new(WG_SIZE as u64, 1, 1);
        encoder.dispatch_threads(thread_count, thread_group_size);

        if is_final {
            encoder.end_encoding();
            break;
        }

        iteration += 1;
        current_size = groups;
    }

    // Synchronous execution: block until all stages have run.
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // The final stage's OUTPUT buffer holds the result: temp_buf1 after an
    // even-numbered final stage, temp_buf2 after an odd-numbered one.
    let final_buf = if iteration == 0 {
        &temp_buf1
    } else if iteration % 2 == 1 {
        &temp_buf2
    } else {
        &temp_buf1
    };
    let out_ptr = final_buf.contents() as *const f32;
    // SAFETY: Shared-mode buffer holds at least one f32, written by the
    // final reduction stage.
    let final_value = unsafe { *out_ptr };

    // Return buffers to pool
    let mut pool = ctx.buffer_pool.lock().unwrap();
    pool.return_buffer(in_buf, data_bytes.len());
    pool.return_buffer(temp_buf1, max_groups as usize * std::mem::size_of::<f32>());
    pool.return_buffer(temp_buf2, max_groups as usize * std::mem::size_of::<f32>());

    Ok(Array::new(vec![1], vec![final_value]))
}