1use crate::array::Array;
2use anyhow::{Result, anyhow};
3use once_cell::sync::OnceCell;
4
5#[cfg(target_os = "macos")]
6use metal::*;
7#[cfg(target_os = "macos")]
8use objc::rc::autoreleasepool;
9#[cfg(target_os = "macos")]
10use std::sync::Mutex;
11#[cfg(target_os = "macos")]
12use std::collections::HashMap;
13
/// Marker type for the Metal GPU backend. It carries no state of its own —
/// all real resources live in the lazily-initialized global `MetalContext`.
#[derive(Debug, Clone)]
pub struct MetalBackend {}

impl MetalBackend {
    /// Create a new (stateless) backend handle.
    pub fn new() -> Self {
        MetalBackend {}
    }

    /// Report whether a Metal device can be obtained on this machine.
    #[cfg(target_os = "macos")]
    pub fn is_available() -> bool {
        Device::system_default().is_some()
    }

    /// Metal never exists off macOS, so the probe is a compile-time `false`.
    #[cfg(not(target_os = "macos"))]
    pub fn is_available() -> bool {
        false
    }
}
32
/// Process-wide Metal state shared by every kernel launch: the device, its
/// command queue, the pre-compiled pipelines, and the buffer pool.
#[cfg(target_os = "macos")]
struct MetalContext {
    // The system-default GPU device.
    device: Device,
    // Command queue all command buffers are created from.
    queue: CommandQueue,
    // Device limit (threadgroup width) used to clamp dispatch sizes.
    max_threads_per_threadgroup: u64,
    // float4-vectorized elementwise kernel; compiled up front but not yet
    // used by the host dispatch path (hence the leading underscore).
    _elementwise_vec4_pipeline: ComputePipelineState,
    // One-float-per-thread elementwise kernel (the active elementwise path).
    elementwise_scalar_pipeline: ComputePipelineState,
    // Matmul pipelines keyed by (tile_size, variant); only (16, 0) is used today.
    matmul_pipeline_cache: Mutex<HashMap<(u32, u32), ComputePipelineState>>,
    // Threadgroup tree-reduction (sum) kernel.
    reduction_pipeline: ComputePipelineState,
    // Recycled GPU buffers, bucketed by power-of-two byte size.
    buffer_pool: Mutex<BufferPool>,
}
50
/// A simple free-list of Metal buffers, bucketed by power-of-two byte size,
/// so repeated kernel launches can reuse GPU allocations instead of creating
/// fresh buffers every time.
#[cfg(target_os = "macos")]
struct BufferPool {
    // Idle buffers keyed by their (power-of-two) allocation size in bytes.
    free_buffers: HashMap<usize, Vec<Buffer>>,
    // Requests whose bucket exceeds this byte count bypass the pool entirely.
    max_cached_size: usize,
}
56
#[cfg(target_os = "macos")]
impl BufferPool {
    /// Maximum number of idle buffers retained per size bucket. Without a cap
    /// the pool could grow without bound whenever many buffers were in flight
    /// simultaneously and then all returned.
    const MAX_BUFFERS_PER_BUCKET: usize = 8;

    fn new() -> Self {
        Self {
            free_buffers: HashMap::new(),
            // Buffers up to 100 MiB are pooled; larger ones are ad hoc.
            max_cached_size: 100 * 1024 * 1024,
        }
    }

    /// Hand out a buffer of at least `size` bytes, recycling a pooled one
    /// when possible. Pooled sizes are rounded up to a power-of-two bucket.
    ///
    /// NOTE(review): a recycled buffer is returned regardless of the `mode`
    /// it was originally created with; every current caller passes
    /// `StorageModeShared`, so this is safe today — revisit if other
    /// resource options are ever requested.
    fn get_or_create(&mut self, device: &Device, size: usize, mode: MTLResourceOptions) -> Buffer {
        let bucket_size = size.next_power_of_two();

        if bucket_size <= self.max_cached_size {
            if let Some(buffers) = self.free_buffers.get_mut(&bucket_size) {
                if let Some(buffer) = buffers.pop() {
                    return buffer;
                }
            }
            device.new_buffer(bucket_size as u64, mode)
        } else {
            // Oversized requests are never cached (see return_buffer), so
            // allocate exactly `size` bytes instead of wasting up to 2x
            // with power-of-two rounding.
            device.new_buffer(size as u64, mode)
        }
    }

    /// Return a buffer to the pool for reuse. Oversized buffers and buffers
    /// beyond the per-bucket cap are simply dropped (i.e. released).
    fn return_buffer(&mut self, buffer: Buffer, size: usize) {
        let bucket_size = size.next_power_of_two();

        if bucket_size <= self.max_cached_size {
            let bucket = self.free_buffers.entry(bucket_size).or_default();
            // Cap retained buffers so the free list stays bounded.
            if bucket.len() < Self::MAX_BUFFERS_PER_BUCKET {
                bucket.push(buffer);
            }
        }
    }
}
91
/// Lazily-initialized global Metal context. The full `Result` is stored so a
/// failed initialization is remembered and reported on every later call
/// instead of being retried.
#[cfg(target_os = "macos")]
static METAL_DEVICE: OnceCell<Result<MetalContext, anyhow::Error>> = OnceCell::new();
94
/// Fetch the process-wide Metal context, initializing it on first use.
///
/// Initialization runs at most once; if it failed, the cached error is
/// re-reported on every subsequent call.
#[cfg(target_os = "macos")]
fn get_metal_device() -> Result<&'static MetalContext> {
    let init = || {
        autoreleasepool(|| {
            let dev = Device::system_default()
                .ok_or_else(|| anyhow!("No Metal device available"))?;
            let cmd_queue = dev.new_command_queue();

            // Device limit used later to clamp threadgroup widths.
            let max_tpg = dev.max_threads_per_threadgroup().width;

            // Compile all static pipelines up front so launch paths are cheap.
            let vec4 = compile_elementwise_vec4_pipeline(&dev)?;
            let scalar = compile_elementwise_scalar_pipeline(&dev)?;
            let reduce = compile_reduction_pipeline(&dev)?;

            Ok(MetalContext {
                device: dev,
                queue: cmd_queue,
                max_threads_per_threadgroup: max_tpg,
                _elementwise_vec4_pipeline: vec4,
                elementwise_scalar_pipeline: scalar,
                matmul_pipeline_cache: Mutex::new(HashMap::new()),
                reduction_pipeline: reduce,
                buffer_pool: Mutex::new(BufferPool::new()),
            })
        })
    };

    METAL_DEVICE
        .get_or_init(init)
        .as_ref()
        .map_err(|e| anyhow!("Metal init failed: {:?}", e))
}
129
/// Compile the vectorized (`float4`, four elements per thread) elementwise
/// kernel.
///
/// The `OP_*` opcodes must stay in sync with `kind_to_u32` on the host side.
/// This pipeline is currently compiled but not wired into the host dispatch
/// path (`run_elementwise_metal_optimized` uses the scalar pipeline).
///
/// NOTE(review): the kernel only processes `size / 4` full float4 lanes, so
/// it silently drops any `size % 4` tail elements — callers must only use it
/// for lengths divisible by 4, or add a scalar tail pass.
///
/// Returns the compiled pipeline, or an error from shader compilation.
#[cfg(target_os = "macos")]
fn compile_elementwise_vec4_pipeline(device: &Device) -> Result<ComputePipelineState> {
    let shader_src = r#"
#include <metal_stdlib>
using namespace metal;

constant uint OP_ADD = 0; constant uint OP_MUL = 1; constant uint OP_SUB = 2; constant uint OP_DIV = 3;
constant uint OP_SQRT = 4; constant uint OP_SIN = 5; constant uint OP_COS = 6; constant uint OP_POW = 7;
constant uint OP_ABS = 8; constant uint OP_EXP = 9; constant uint OP_LOG = 10; constant uint OP_TAN = 11;
constant uint OP_ASIN = 12; constant uint OP_ACOS = 13; constant uint OP_ATAN = 14; constant uint OP_RELU = 15;
constant uint OP_LEAKY_RELU = 16; constant uint OP_SIGMOID = 17; constant uint OP_TANH = 18; constant uint OP_SOFTPLUS = 19;
constant uint OP_NEG = 20;

struct Params { uint size; uint op_kind; };

kernel void elementwise_vec4(
    device const float4* a [[buffer(0)]],
    device const float4* b [[buffer(1)]],
    device float4* out [[buffer(2)]],
    constant Params& params [[buffer(3)]],
    uint idx [[thread_position_in_grid]]
) {
    if (idx >= params.size / 4) return;
    float4 a_val = a[idx];
    float4 b_val = b[idx];
    float4 result;

    switch(params.op_kind) {
        case OP_ADD: result = a_val + b_val; break;
        case OP_MUL: result = a_val * b_val; break;
        case OP_SUB: result = a_val - b_val; break;
        case OP_DIV: result = a_val / b_val; break;
        case OP_SQRT: result = fast::sqrt(a_val); break;
        case OP_SIN: result = fast::sin(a_val); break;
        case OP_COS: result = fast::cos(a_val); break;
        case OP_POW: result = fast::pow(a_val, b_val); break;
        case OP_ABS: result = fast::fabs(a_val); break;
        case OP_EXP: result = fast::exp(a_val); break;
        case OP_LOG: result = fast::log(a_val); break;
        case OP_TAN: result = fast::tan(a_val); break;
        case OP_ASIN: result = asin(a_val); break;
        case OP_ACOS: result = acos(a_val); break;
        case OP_ATAN: result = atan(a_val); break;
        case OP_RELU: result = fast::max(a_val, float4(0.0f)); break;
        case OP_LEAKY_RELU: result = select(float4(0.01f) * a_val, a_val, a_val > float4(0.0f)); break;
        case OP_SIGMOID: result = 1.0f / (1.0f + fast::exp(-a_val)); break;
        case OP_TANH: result = fast::tanh(a_val); break;
        case OP_SOFTPLUS: result = fast::log(1.0f + fast::exp(a_val)); break;
        // Bug fix: opcode 20 (Neg, see kind_to_u32) previously fell through
        // to the identity default and returned the input unchanged.
        case OP_NEG: result = -a_val; break;
        default: result = a_val; break;
    }
    out[idx] = result;
}
"#;

    let library = device.new_library_with_source(shader_src, &CompileOptions::new())
        .map_err(|e| anyhow!("Failed to compile vec4 shader: {}", e))?;
    let kernel = library.get_function("elementwise_vec4", None)
        .map_err(|e| anyhow!("Failed to get vec4 kernel: {}", e))?;
    device.new_compute_pipeline_state_with_function(&kernel)
        .map_err(|e| anyhow!("Failed to create vec4 pipeline: {}", e))
}
194
/// Compile the scalar (one-float-per-thread) elementwise kernel used by
/// `run_elementwise_metal_optimized`.
///
/// The `OP_*` opcodes must stay in sync with `kind_to_u32` on the host side.
/// Binary ops combine `a` and `b`; unary ops use only `a`, but the kernel
/// still loads `b[idx]`, so a buffer of at least `size` floats must be bound
/// at slot 1 in every case.
///
/// Returns the compiled pipeline, or an error from shader compilation.
#[cfg(target_os = "macos")]
fn compile_elementwise_scalar_pipeline(device: &Device) -> Result<ComputePipelineState> {
    let shader_src = r#"
#include <metal_stdlib>
using namespace metal;

constant uint OP_ADD = 0; constant uint OP_MUL = 1; constant uint OP_SUB = 2; constant uint OP_DIV = 3;
constant uint OP_SQRT = 4; constant uint OP_SIN = 5; constant uint OP_COS = 6; constant uint OP_POW = 7;
constant uint OP_ABS = 8; constant uint OP_EXP = 9; constant uint OP_LOG = 10; constant uint OP_TAN = 11;
constant uint OP_ASIN = 12; constant uint OP_ACOS = 13; constant uint OP_ATAN = 14; constant uint OP_RELU = 15;
constant uint OP_LEAKY_RELU = 16; constant uint OP_SIGMOID = 17; constant uint OP_TANH = 18; constant uint OP_SOFTPLUS = 19;
constant uint OP_NEG = 20;

struct Params { uint size; uint op_kind; };

kernel void elementwise_scalar(
    device const float* a [[buffer(0)]],
    device const float* b [[buffer(1)]],
    device float* out [[buffer(2)]],
    constant Params& params [[buffer(3)]],
    uint idx [[thread_position_in_grid]]
) {
    if (idx >= params.size) return;
    float a_val = a[idx];
    float b_val = b[idx];
    float result;

    switch(params.op_kind) {
        case OP_ADD: result = a_val + b_val; break;
        case OP_MUL: result = a_val * b_val; break;
        case OP_SUB: result = a_val - b_val; break;
        case OP_DIV: result = a_val / b_val; break;
        case OP_SQRT: result = fast::sqrt(a_val); break;
        case OP_SIN: result = fast::sin(a_val); break;
        case OP_COS: result = fast::cos(a_val); break;
        case OP_POW: result = fast::pow(a_val, b_val); break;
        case OP_ABS: result = fast::fabs(a_val); break;
        case OP_EXP: result = fast::exp(a_val); break;
        case OP_LOG: result = fast::log(a_val); break;
        case OP_TAN: result = fast::tan(a_val); break;
        case OP_ASIN: result = asin(a_val); break;
        case OP_ACOS: result = acos(a_val); break;
        case OP_ATAN: result = atan(a_val); break;
        case OP_RELU: result = fast::max(a_val, 0.0f); break;
        case OP_LEAKY_RELU: result = (a_val > 0.0f) ? a_val : 0.01f * a_val; break;
        case OP_SIGMOID: result = 1.0f / (1.0f + fast::exp(-a_val)); break;
        case OP_TANH: result = fast::tanh(a_val); break;
        case OP_SOFTPLUS: result = fast::log(1.0f + fast::exp(a_val)); break;
        // Bug fix: opcode 20 (Neg, see kind_to_u32) previously fell through
        // to the identity default and returned the input unchanged.
        case OP_NEG: result = -a_val; break;
        default: result = a_val; break;
    }
    out[idx] = result;
}
"#;

    let library = device.new_library_with_source(shader_src, &CompileOptions::new())
        .map_err(|e| anyhow!("Failed to compile scalar shader: {}", e))?;
    let kernel = library.get_function("elementwise_scalar", None)
        .map_err(|e| anyhow!("Failed to get scalar kernel: {}", e))?;
    device.new_compute_pipeline_state_with_function(&kernel)
        .map_err(|e| anyhow!("Failed to create scalar pipeline: {}", e))
}
255
/// Compile the sum-reduction kernel.
///
/// Each 256-thread threadgroup loads up to 256 input elements (zero-padding
/// past the end), tree-reduces them in threadgroup memory, and writes one
/// partial sum per group. The host runs the kernel repeatedly over the
/// partials until a single value remains (see `run_reduction_metal_optimized`).
#[cfg(target_os = "macos")]
fn compile_reduction_pipeline(device: &Device) -> Result<ComputePipelineState> {
    // NOTE: the shader's WG_SIZE must match the host-side WG_SIZE constant
    // (256) used to size dispatches in `run_reduction_metal_optimized`.
    let shader_src = r#"
#include <metal_stdlib>
using namespace metal;

constant uint WG_SIZE = 256;

struct Params {
    uint size;
};

kernel void reduction_sum(
    device const float* data [[buffer(0)]],
    device float* partials [[buffer(1)]],
    constant Params& params [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]]
) {
    threadgroup float shared[WG_SIZE];

    // Load data
    float value = 0.0f;
    if (gid < params.size) {
        value = data[gid];
    }
    shared[lid] = value;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction in threadgroup memory
    for (uint s = WG_SIZE / 2; s > 0; s >>= 1) {
        if (lid < s) {
            shared[lid] += shared[lid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // First thread writes result
    if (lid == 0) {
        partials[group_id] = shared[0];
    }
}
"#;

    let library = device.new_library_with_source(shader_src, &CompileOptions::new())
        .map_err(|e| anyhow!("Failed to compile reduction shader: {}", e))?;

    let kernel = library.get_function("reduction_sum", None)
        .map_err(|e| anyhow!("Failed to get reduction kernel: {}", e))?;

    device.new_compute_pipeline_state_with_function(&kernel)
        .map_err(|e| anyhow!("Failed to create reduction pipeline: {}", e))
}
312
/// Compile a matmul pipeline for the given `tile_size`.
///
/// Despite the name, this function does no caching itself — the caller
/// (`run_matmul_metal_optimized`) owns the pipeline cache.
///
/// NOTE(review): `tile_size` is interpolated into the shader as the `TILE`
/// constant, but the kernel body never references it — the implementation is
/// a naive one-output-element-per-thread dot-product loop, not a tiled
/// algorithm. Today the parameter only distinguishes cache entries.
#[cfg(target_os = "macos")]
fn get_or_compile_matmul_pipeline(device: &Device, tile_size: u32) -> Result<ComputePipelineState> {
    let shader_src = format!(r#"
#include <metal_stdlib>
using namespace metal;

constant uint TILE = {tile};

struct Params {{
    uint m;
    uint n;
    uint k;
}};

kernel void matmul_tiled(
    device const float* a [[buffer(0)]],
    device const float* b [[buffer(1)]],
    device float* out [[buffer(2)]],
    constant Params& params [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]]
) {{
    // Simple safe implementation: each thread computes one output element
    uint row = gid.y;
    uint col = gid.x;

    if (row >= params.m || col >= params.n) {{
        return;
    }}

    float sum = 0.0f;
    for (uint k = 0; k < params.k; k++) {{
        sum = fast::fma(a[row * params.k + k], b[k * params.n + col], sum);
    }}

    out[row * params.n + col] = sum;
}}
"#, tile = tile_size);

    let library = device.new_library_with_source(&shader_src, &CompileOptions::new())
        .map_err(|e| anyhow!("Failed to compile matmul shader: {}", e))?;

    let kernel = library.get_function("matmul_tiled", None)
        .map_err(|e| anyhow!("Failed to get matmul kernel: {}", e))?;

    device.new_compute_pipeline_state_with_function(&kernel)
        .map_err(|e| anyhow!("Failed to create matmul pipeline: {}", e))
}
360
361pub fn is_available_cached() -> bool {
363 static PROBE: OnceCell<bool> = OnceCell::new();
364 *PROBE.get_or_init(|| MetalBackend::is_available())
365}
366
367pub fn elementwise_metal(a: &Array, b: &Array, kind: crate::llo::ElementwiseKind) -> Result<Array> {
373 #[cfg(target_os = "macos")]
374 {
375 run_elementwise_metal_optimized(a, b, kind)
376 }
377 #[cfg(not(target_os = "macos"))]
378 {
379 let _ = (a, b, kind);
380 Err(anyhow!("Metal backend only available on macOS"))
381 }
382}
383
384pub fn matmul_metal(a: &Array, b: &Array) -> Result<Array> {
386 #[cfg(target_os = "macos")]
387 {
388 run_matmul_metal_optimized(a, b)
389 }
390 #[cfg(not(target_os = "macos"))]
391 {
392 let _ = (a, b);
393 Err(anyhow!("Metal backend only available on macOS"))
394 }
395}
396
397pub fn reduction_metal(a: &Array, axis: Option<usize>) -> Result<Array> {
399 #[cfg(target_os = "macos")]
400 {
401 run_reduction_metal_optimized(a, axis)
402 }
403 #[cfg(not(target_os = "macos"))]
404 {
405 let _ = (a, axis);
406 Err(anyhow!("Metal backend only available on macOS"))
407 }
408}
409
/// Map an `ElementwiseKind` to the numeric opcode baked into the elementwise
/// Metal shaders (the shader-side `OP_*` constants); the two numbering
/// schemes must stay in sync.
#[cfg(target_os = "macos")]
fn kind_to_u32(kind: crate::llo::ElementwiseKind) -> u32 {
    use crate::llo::ElementwiseKind as K;
    match kind {
        K::Add => 0,
        K::Mul => 1,
        K::Sub => 2,
        K::Div => 3,
        K::Sqrt => 4,
        K::Sin => 5,
        K::Cos => 6,
        K::Pow => 7,
        K::Abs => 8,
        K::Exp => 9,
        K::Log => 10,
        K::Tan => 11,
        K::Asin => 12,
        K::Acos => 13,
        K::Atan => 14,
        K::Relu => 15,
        K::LeakyRelu => 16,
        K::Sigmoid => 17,
        K::Tanh => 18,
        K::Softplus => 19,
        K::Neg => 20,
    }
}
426
/// Execute one elementwise kernel launch on the GPU and return the result as
/// a new `Array` with `a`'s shape.
///
/// Uploads `a` and `b` into pooled shared-storage buffers, dispatches the
/// scalar (one-float-per-thread) pipeline, blocks until completion, then
/// returns the buffers to the pool.
///
/// NOTE(review): the kernel unconditionally loads `b[idx]` for every index,
/// even for unary opcodes, and no length check is performed here — callers
/// must supply a `b` with at least `a.len()` elements (e.g. `a` itself for
/// unary ops). TODO confirm against call sites.
#[cfg(target_os = "macos")]
fn run_elementwise_metal_optimized(a: &Array, b: &Array, kind: crate::llo::ElementwiseKind) -> Result<Array> {
    let ctx = get_metal_device()?;
    let len = a.len();

    let command_buffer = ctx.queue.new_command_buffer();

    // Reinterpret the f32 input slices as raw bytes for the uploads below.
    // SAFETY: any f32 slice is validly readable as bytes; length is in bytes.
    let a_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(a.data.as_ptr() as *const u8, a.data.len() * std::mem::size_of::<f32>())
    };
    let b_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(b.data.as_ptr() as *const u8, b.data.len() * std::mem::size_of::<f32>())
    };

    // Hold the pool lock only while acquiring and filling the buffers.
    let mut pool = ctx.buffer_pool.lock().unwrap();

    let a_buf = {
        let buf = pool.get_or_create(&ctx.device, a_bytes.len(), MTLResourceOptions::StorageModeShared);
        // SAFETY: the pooled buffer is at least `a_bytes.len()` bytes (the
        // pool rounds sizes up, never down) and is CPU-visible (shared mode).
        unsafe {
            std::ptr::copy_nonoverlapping(
                a_bytes.as_ptr(),
                buf.contents() as *mut u8,
                a_bytes.len()
            );
        }
        buf
    };

    let b_buf = {
        let buf = pool.get_or_create(&ctx.device, b_bytes.len(), MTLResourceOptions::StorageModeShared);
        // SAFETY: same invariants as the `a_buf` copy above.
        unsafe {
            std::ptr::copy_nonoverlapping(
                b_bytes.as_ptr(),
                buf.contents() as *mut u8,
                b_bytes.len()
            );
        }
        buf
    };

    // Output buffer: one f32 per element of `a`.
    let out_buf = pool.get_or_create(
        &ctx.device,
        len * std::mem::size_of::<f32>(),
        MTLResourceOptions::StorageModeShared
    );

    // Release the pool lock before the (potentially slow) GPU round-trip.
    drop(pool); let encoder = command_buffer.new_compute_command_encoder();
    let op_kind = kind_to_u32(kind);

    // Layout must match the shader's `struct Params { uint size; uint op_kind; }`.
    let params = [len as u32, op_kind];
    // SAFETY: a [u32] slice is validly readable as bytes.
    let params_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(params.as_ptr() as *const u8, params.len() * std::mem::size_of::<u32>())
    };
    let params_buf = ctx.device.new_buffer_with_data(
        params_bytes.as_ptr() as *const _,
        params_bytes.len() as u64,
        MTLResourceOptions::StorageModeShared,
    );

    // Slot indices 0..=3 mirror the [[buffer(N)]] bindings in the shader.
    encoder.set_compute_pipeline_state(&ctx.elementwise_scalar_pipeline);
    encoder.set_buffer(0, Some(&a_buf), 0);
    encoder.set_buffer(1, Some(&b_buf), 0);
    encoder.set_buffer(2, Some(&out_buf), 0);
    encoder.set_buffer(3, Some(&params_buf), 0);

    // One thread per element; the kernel bounds-checks `idx` against `size`.
    let thread_count = MTLSize::new(len as u64, 1, 1);
    let thread_group_size = MTLSize::new(ctx.max_threads_per_threadgroup.min(256), 1, 1);
    encoder.dispatch_threads(thread_count, thread_group_size);

    encoder.end_encoding();

    // Synchronous execution: block until the GPU finishes.
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Copy the results out of the shared buffer before recycling it.
    // SAFETY: the buffer holds at least `len` f32s written by the kernel.
    let out_ptr = out_buf.contents() as *const f32;
    let out_slice = unsafe { std::slice::from_raw_parts(out_ptr, len) };
    let result = out_slice.to_vec();

    // Recycle all three pooled buffers for the next launch.
    let mut pool = ctx.buffer_pool.lock().unwrap();
    pool.return_buffer(a_buf, a_bytes.len());
    pool.return_buffer(b_buf, b_bytes.len());
    pool.return_buffer(out_buf, len * std::mem::size_of::<f32>());

    Ok(Array::new(a.shape.clone(), result))
}
518
/// Multiply two 2-D row-major matrices on the GPU, returning the `m x n`
/// product of `a` (`m x k`) and `b` (`k x n`) as a new `Array`.
///
/// NOTE(review): shapes are trusted — `a.shape`/`b.shape` are indexed
/// directly (this panics if either is not at least 2-D), and
/// `b.shape[0] == a.shape[1]` is never verified; a mismatch would make the
/// kernel read out of bounds on the GPU. TODO: consider validating here.
#[cfg(target_os = "macos")]
fn run_matmul_metal_optimized(a: &Array, b: &Array) -> Result<Array> {
    let ctx = get_metal_device()?;

    let m = a.shape[0] as u32;
    let k = a.shape[1] as u32;
    let n = b.shape[1] as u32;
    let len = (m * n) as usize;

    // Fetch (or lazily compile) the matmul pipeline. The cache key and tile
    // argument are currently fixed at (16, 0) / 16.
    let pipeline = {
        let mut cache = ctx.matmul_pipeline_cache.lock().unwrap();
        if let Some(p) = cache.get(&(16, 0)) {
            p.clone()
        } else {
            let p = get_or_compile_matmul_pipeline(&ctx.device, 16)?;
            cache.insert((16, 0), p.clone());
            p
        }
    };

    // Reinterpret the f32 inputs as raw bytes for upload.
    // SAFETY: any f32 slice is validly readable as bytes; length is in bytes.
    let a_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(a.data.as_ptr() as *const u8, a.data.len() * std::mem::size_of::<f32>())
    };
    let b_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(b.data.as_ptr() as *const u8, b.data.len() * std::mem::size_of::<f32>())
    };

    // Hold the pool lock only while acquiring and filling the buffers.
    let mut pool = ctx.buffer_pool.lock().unwrap();

    let a_buf = {
        let buf = pool.get_or_create(&ctx.device, a_bytes.len(), MTLResourceOptions::StorageModeShared);
        // SAFETY: pooled buffer is at least `a_bytes.len()` bytes and CPU-visible.
        unsafe {
            std::ptr::copy_nonoverlapping(
                a_bytes.as_ptr(),
                buf.contents() as *mut u8,
                a_bytes.len()
            );
        }
        buf
    };

    let b_buf = {
        let buf = pool.get_or_create(&ctx.device, b_bytes.len(), MTLResourceOptions::StorageModeShared);
        // SAFETY: same invariants as the `a_buf` copy above.
        unsafe {
            std::ptr::copy_nonoverlapping(
                b_bytes.as_ptr(),
                buf.contents() as *mut u8,
                b_bytes.len()
            );
        }
        buf
    };

    // Output buffer: m * n f32 results.
    let out_buf = pool.get_or_create(
        &ctx.device,
        len * std::mem::size_of::<f32>(),
        MTLResourceOptions::StorageModeShared,
    );

    // Release the pool lock before the GPU round-trip.
    drop(pool); let command_buffer = ctx.queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();

    // Layout must match the shader's `struct Params { uint m; uint n; uint k; }`.
    let params = [m, n, k];
    // SAFETY: a [u32] slice is validly readable as bytes.
    let params_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(params.as_ptr() as *const u8, params.len() * std::mem::size_of::<u32>())
    };
    let params_buf = ctx.device.new_buffer_with_data(
        params_bytes.as_ptr() as *const _,
        params_bytes.len() as u64,
        MTLResourceOptions::StorageModeShared,
    );

    // Slot indices 0..=3 mirror the [[buffer(N)]] bindings in the shader.
    encoder.set_compute_pipeline_state(&pipeline);
    encoder.set_buffer(0, Some(&a_buf), 0);
    encoder.set_buffer(1, Some(&b_buf), 0);
    encoder.set_buffer(2, Some(&out_buf), 0);
    encoder.set_buffer(3, Some(&params_buf), 0);

    // One thread per output element (x = column, y = row); the kernel
    // bounds-checks row/col against m/n.
    let thread_group_size = MTLSize::new(16, 16, 1);
    let grid_size = MTLSize::new(n as u64, m as u64, 1);
    encoder.dispatch_threads(grid_size, thread_group_size);

    encoder.end_encoding();

    // Synchronous execution: block until the GPU finishes.
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Copy the results out of the shared buffer before recycling it.
    // SAFETY: the buffer holds at least `len` f32s written by the kernel.
    let out_ptr = out_buf.contents() as *const f32;
    let out_slice = unsafe { std::slice::from_raw_parts(out_ptr, len) };
    let result = out_slice.to_vec();

    // Recycle all three pooled buffers for the next launch.
    let mut pool = ctx.buffer_pool.lock().unwrap();
    pool.return_buffer(a_buf, a_bytes.len());
    pool.return_buffer(b_buf, b_bytes.len());
    pool.return_buffer(out_buf, len * std::mem::size_of::<f32>());

    Ok(Array::new(vec![m as usize, n as usize], result))
}
624
/// Sum all elements of `a` on the GPU, returning a 1-element `Array`.
///
/// The kernel produces one partial sum per 256-thread group, so the host
/// dispatches it repeatedly — ping-ponging between two temp buffers — until
/// a single value remains. Only full reduction is supported; any `axis`
/// request is rejected.
///
/// NOTE(review): all passes are encoded on one compute encoder; this relies
/// on the encoder's default serial dispatch semantics so each pass observes
/// the previous pass's writes — confirm if a concurrent dispatch type is
/// ever adopted.
#[cfg(target_os = "macos")]
fn run_reduction_metal_optimized(a: &Array, axis: Option<usize>) -> Result<Array> {
    let ctx = get_metal_device()?;

    if axis.is_some() {
        return Err(anyhow!("axis-based reduction not implemented in Metal prototype"));
    }

    let size = a.len() as u32;
    // Empty input: the sum of nothing is 0.
    if size == 0 {
        return Ok(Array::new(vec![1], vec![0.0]));
    }

    // Must match the WG_SIZE constant baked into the reduction shader.
    const WG_SIZE: u32 = 256;

    // Reinterpret the f32 input as raw bytes for upload.
    // SAFETY: any f32 slice is validly readable as bytes; length is in bytes.
    let data_bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(a.data.as_ptr() as *const u8, a.data.len() * std::mem::size_of::<f32>())
    };

    // Hold the pool lock only while acquiring and filling the buffers.
    let mut pool = ctx.buffer_pool.lock().unwrap();

    let in_buf = {
        let buf = pool.get_or_create(&ctx.device, data_bytes.len(), MTLResourceOptions::StorageModeShared);
        // SAFETY: pooled buffer is at least `data_bytes.len()` bytes and CPU-visible.
        unsafe {
            std::ptr::copy_nonoverlapping(
                data_bytes.as_ptr(),
                buf.contents() as *mut u8,
                data_bytes.len()
            );
        }
        buf
    };

    // Upper bound on elements alive in any pass, used to size the ping-pong
    // buffers. NOTE(review): max_groups starts at `size`, so each temp buffer
    // is sized for the entire input; ceil(size / WG_SIZE) floats would
    // suffice for the partials — harmless but oversized.
    let mut max_groups = size;
    let mut temp_size = max_groups;
    while temp_size > 1 {
        temp_size = (temp_size + WG_SIZE - 1) / WG_SIZE;
        max_groups = max_groups.max(temp_size);
    }

    let temp_buf1 = pool.get_or_create(&ctx.device, max_groups as usize * std::mem::size_of::<f32>(), MTLResourceOptions::StorageModeShared);
    let temp_buf2 = pool.get_or_create(&ctx.device, max_groups as usize * std::mem::size_of::<f32>(), MTLResourceOptions::StorageModeShared);

    // Release the pool lock before the GPU round-trip.
    drop(pool); let command_buffer = ctx.queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();

    let mut current_size = size;
    let mut iteration = 0;

    // Encode one reduction pass per loop turn until a pass emits a single
    // partial sum (groups == 1).
    loop {
        let groups = ((current_size + WG_SIZE - 1) / WG_SIZE) as u32;
        let is_final = groups == 1;

        // Per-pass element count; must match the shader's `struct Params`.
        let params = [current_size];
        // SAFETY: a [u32] slice is validly readable as bytes.
        let params_bytes: &[u8] = unsafe {
            std::slice::from_raw_parts(params.as_ptr() as *const u8, params.len() * std::mem::size_of::<u32>())
        };
        let params_buf = ctx.device.new_buffer_with_data(
            params_bytes.as_ptr() as *const _,
            params_bytes.len() as u64,
            MTLResourceOptions::StorageModeShared,
        );

        encoder.set_compute_pipeline_state(&ctx.reduction_pipeline);

        // Pass 0 reads the input; later passes ping-pong between the temps:
        // odd iterations read temp1/write temp2, even read temp2/write temp1.
        let (input_buf, output_buf) = if iteration == 0 {
            (&in_buf, &temp_buf1)
        } else if iteration % 2 == 1 {
            (&temp_buf1, &temp_buf2)
        } else {
            (&temp_buf2, &temp_buf1)
        };

        encoder.set_buffer(0, Some(input_buf), 0);
        encoder.set_buffer(1, Some(output_buf), 0);
        encoder.set_buffer(2, Some(&params_buf), 0);

        // Round the dispatch up to whole threadgroups; the kernel zero-pads
        // threads past `current_size`.
        let thread_count = MTLSize::new((groups * WG_SIZE) as u64, 1, 1);
        let thread_group_size = MTLSize::new(WG_SIZE as u64, 1, 1);
        encoder.dispatch_threads(thread_count, thread_group_size);

        if is_final {
            encoder.end_encoding();
            break;
        }

        iteration += 1;
        current_size = groups;
    }

    // Synchronous execution: block until the GPU finishes.
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // `iteration` is the index of the final pass: an even final pass
    // (including 0) wrote temp_buf1, an odd one wrote temp_buf2.
    let final_buf = if iteration == 0 {
        &temp_buf1
    } else if iteration % 2 == 1 {
        &temp_buf2
    } else {
        &temp_buf1
    };
    // SAFETY: the final pass wrote exactly one f32 at offset 0.
    let out_ptr = final_buf.contents() as *const f32;
    let final_value = unsafe { *out_ptr };

    // Recycle the pooled buffers for the next launch.
    let mut pool = ctx.buffer_pool.lock().unwrap();
    pool.return_buffer(in_buf, data_bytes.len());
    pool.return_buffer(temp_buf1, max_groups as usize * std::mem::size_of::<f32>());
    pool.return_buffer(temp_buf2, max_groups as usize * std::mem::size_of::<f32>());

    Ok(Array::new(vec![1], vec![final_value]))
}