// ringkernel_metal/lib.rs
//! Metal Backend for RingKernel
//!
//! This crate provides Apple Metal GPU support for RingKernel.
//! Supports macOS, iOS, and Apple Silicon.
//!
//! # Features
//!
//! - Event-driven kernel execution (Metal compute shaders)
//! - MSL (Metal Shading Language) support
//! - Apple Silicon optimization
//! - Unified memory architecture support
//!
//! # Limitations
//!
//! - No true persistent kernels (Metal doesn't support cooperative groups)
//! - macOS/iOS only
//!
//! # Example
//!
//! ```ignore
//! use ringkernel_metal::MetalRuntime;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let runtime = MetalRuntime::new().await?;
//!     let kernel = runtime.launch("compute", Default::default()).await?;
//!     kernel.activate().await?;
//!     Ok(())
//! }
//! ```
#![warn(missing_docs)]

// The real backend is only compiled on macOS with the `metal` feature
// enabled; on every other target/feature combination the `stub` module
// further down supplies a placeholder `MetalRuntime` instead.
#[cfg(all(target_os = "macos", feature = "metal"))]
mod device;
#[cfg(all(target_os = "macos", feature = "metal"))]
mod kernel;
#[cfg(all(target_os = "macos", feature = "metal"))]
mod memory;
#[cfg(all(target_os = "macos", feature = "metal"))]
mod runtime;

#[cfg(all(target_os = "macos", feature = "metal"))]
pub use device::MetalDevice;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use kernel::{
    HaloExchangeConfig, HaloExchangeStats, MetalHaloExchange, MetalHaloMessage,
    MetalK2KInboxHeader, MetalK2KRouteEntry, MetalK2KRoutingTable, MetalKernel,
};
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use memory::MetalBuffer;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use runtime::MetalRuntime;

55// Stub implementation when Metal is not available
56#[cfg(not(all(target_os = "macos", feature = "metal")))]
57mod stub {
58    ringkernel_core::unavailable_backend!(
59        MetalRuntime,
60        ringkernel_core::runtime::Backend::Metal,
61        "Metal"
62    );
63}
64
65#[cfg(not(all(target_os = "macos", feature = "metal")))]
66pub use stub::MetalRuntime;
67
/// Check if Metal is available at runtime.
///
/// Returns `true` only when this crate was built on macOS with the `metal`
/// feature enabled *and* the system reports a default Metal device via
/// `metal::Device::system_default()`. In every other build configuration
/// this is a compile-time constant `false`.
pub fn is_metal_available() -> bool {
    #[cfg(all(target_os = "macos", feature = "metal"))]
    {
        metal::Device::system_default().is_some()
    }
    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    {
        false
    }
}

/// MSL (Metal Shading Language) kernel template.
///
/// The `// USER_KERNEL_CODE` marker is the splice point for user kernel code
/// before the shader is compiled.
///
/// NOTE(review): the byte-size annotations in the template ("128 bytes",
/// "256 bytes") are inherited from the original source — verify them against
/// the Rust-side struct layouts before relying on them.
pub const RING_KERNEL_MSL_TEMPLATE: &str = r#"
//
// RingKernel Metal Shading Language Template
// Generated by ringkernel-metal
//

#include <metal_stdlib>
using namespace metal;

// Control block structure (128 bytes)
struct ControlBlock {
    atomic_uint is_active;
    atomic_uint should_terminate;
    atomic_uint has_terminated;
    uint _pad1;

    atomic_ulong messages_processed;
    atomic_ulong messages_in_flight;

    atomic_ulong input_head;
    atomic_ulong input_tail;
    atomic_ulong output_head;
    atomic_ulong output_tail;

    uint input_capacity;
    uint output_capacity;
    uint input_mask;
    uint output_mask;

    // HLC state
    atomic_ulong hlc_physical;
    atomic_ulong hlc_logical;

    atomic_uint last_error;
    atomic_uint error_count;

    uchar _reserved[16];
};

// Message header structure (256 bytes)
struct MessageHeader {
    ulong magic;
    uint version;
    uint flags;
    ulong message_id;
    ulong correlation_id;
    ulong source_kernel;
    ulong dest_kernel;
    ulong message_type;
    uchar priority;
    uchar _reserved1[7];
    ulong payload_size;
    uint checksum;
    uint _reserved2;
    // HLC timestamp (24 bytes)
    ulong ts_physical;
    ulong ts_logical;
    ulong ts_node_id;
    // Deadline
    ulong deadline_physical;
    ulong deadline_logical;
    ulong deadline_node_id;
    uchar _reserved3[104];
};

// Kernel entry point
kernel void ring_kernel_main(
    device ControlBlock* control [[buffer(0)]],
    device uchar* input_queue [[buffer(1)]],
    device uchar* output_queue [[buffer(2)]],
    device uchar* shared_state [[buffer(3)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]],
    uint threads_per_group [[threads_per_threadgroup]]
) {
    // Check if kernel should process
    uint is_active = atomic_load_explicit(&control->is_active, memory_order_acquire);
    if (is_active == 0) {
        return;
    }

    // Check termination
    uint should_term = atomic_load_explicit(&control->should_terminate, memory_order_acquire);
    if (should_term != 0) {
        if (thread_id == 0 && threadgroup_id == 0) {
            atomic_store_explicit(&control->has_terminated, 1, memory_order_release);
        }
        return;
    }

    // User kernel code will be inserted here
    // USER_KERNEL_CODE

    // Update message counter
    if (thread_id == 0 && threadgroup_id == 0) {
        atomic_fetch_add_explicit(&control->messages_processed, 1, memory_order_relaxed);
    }
}
"#;

/// MSL (Metal Shading Language) K2K Halo Exchange template.
///
/// This template provides kernel-to-kernel communication for stencil
/// computations. Each threadgroup can exchange halo data with its neighbors.
///
/// NOTE(review): the template passes `threadgroup`-space pointers
/// (`column_halo`, `recv_buffer`, `&msg_source`, `&msg_direction`) to helper
/// functions whose parameters are declared `device`/thread address space —
/// confirm this compiles under the targeted MSL version before shipping.
pub const K2K_HALO_EXCHANGE_MSL_TEMPLATE: &str = r#"
//
// RingKernel Metal K2K Halo Exchange Template
// Generated by ringkernel-metal
//

#include <metal_stdlib>
using namespace metal;

// K2K Inbox Header (64 bytes)
struct K2KInboxHeader {
    atomic_uint message_count;
    uint max_messages;
    atomic_uint head;
    atomic_uint tail;
    uint last_source;
    atomic_uint lock;
    atomic_uint sequence;
    uint _reserved[9];
};

// K2K Route Entry (32 bytes)
struct K2KRouteEntry {
    uint dest_threadgroup;
    uint inbox_offset;
    uint is_active;
    uint hops;
    uint bandwidth_hint;
    uint priority;
    uint _reserved[2];
};

// K2K Routing Table
struct K2KRoutingTable {
    uint self_id;
    uint route_count;
    uint grid_dim_x;
    uint grid_dim_y;
    uint grid_dim_z;
    uint _reserved[3];
    K2KRouteEntry routes[26];  // Max neighbors for 3D Moore neighborhood
};

// Halo Message Header (32 bytes)
struct HaloMessageHeader {
    uint source;
    uint direction;
    uint width;
    uint height;
    uint depth;
    uint element_size;
    uint sequence;
    uint flags;
};

// Try to acquire inbox lock
bool k2k_try_lock(device K2KInboxHeader* inbox) {
    uint expected = 0;
    return atomic_compare_exchange_weak_explicit(
        &inbox->lock, &expected, 1,
        memory_order_acquire, memory_order_relaxed
    );
}

// Release inbox lock
void k2k_unlock(device K2KInboxHeader* inbox) {
    atomic_store_explicit(&inbox->lock, 0, memory_order_release);
}

// Send halo data to neighbor
bool k2k_send_halo(
    device K2KRoutingTable* routing,
    device uchar* inbox_buffer,
    uint dest_id,
    device float* halo_data,
    uint width,
    uint height,
    uint depth,
    uint direction,
    uint thread_id
) {
    // Only thread 0 performs the send
    if (thread_id != 0) return true;

    // Find route to destination
    for (uint i = 0; i < routing->route_count; i++) {
        if (routing->routes[i].dest_threadgroup == dest_id &&
            routing->routes[i].is_active != 0) {

            uint offset = routing->routes[i].inbox_offset;
            device K2KInboxHeader* inbox = (device K2KInboxHeader*)(inbox_buffer + offset);

            // Try to acquire lock
            if (!k2k_try_lock(inbox)) {
                return false;  // Inbox busy
            }

            // Check if inbox has space
            uint count = atomic_load_explicit(&inbox->message_count, memory_order_acquire);
            if (count >= inbox->max_messages) {
                k2k_unlock(inbox);
                return false;  // Inbox full
            }

            // Write message header
            uint msg_offset = offset + 64 + count * (32 + width * height * depth * 4);
            device HaloMessageHeader* msg = (device HaloMessageHeader*)(inbox_buffer + msg_offset);
            msg->source = routing->self_id;
            msg->direction = direction;
            msg->width = width;
            msg->height = height;
            msg->depth = depth;
            msg->element_size = 4;
            msg->sequence = atomic_fetch_add_explicit(&inbox->sequence, 1, memory_order_relaxed);
            msg->flags = 0;

            // Copy halo data
            device float* payload = (device float*)(inbox_buffer + msg_offset + 32);
            uint payload_size = width * height * depth;
            for (uint j = 0; j < payload_size; j++) {
                payload[j] = halo_data[j];
            }

            // Update message count
            atomic_fetch_add_explicit(&inbox->message_count, 1, memory_order_release);
            inbox->last_source = routing->self_id;

            k2k_unlock(inbox);
            return true;
        }
    }

    return false;  // No route found
}

// Receive halo data from neighbors
bool k2k_recv_halo(
    device K2KRoutingTable* routing,
    device uchar* inbox_buffer,
    device float* dest_buffer,
    uint* source_out,
    uint* direction_out,
    uint thread_id
) {
    // Only thread 0 performs the receive
    if (thread_id != 0) return false;

    uint offset = routing->self_id * 4096;  // Assume 4KB per inbox
    device K2KInboxHeader* inbox = (device K2KInboxHeader*)(inbox_buffer + offset);

    // Try to acquire lock
    if (!k2k_try_lock(inbox)) {
        return false;
    }

    // Check if inbox has messages
    uint count = atomic_load_explicit(&inbox->message_count, memory_order_acquire);
    if (count == 0) {
        k2k_unlock(inbox);
        return false;
    }

    // Read oldest message (FIFO)
    uint head = atomic_load_explicit(&inbox->head, memory_order_acquire);
    uint msg_offset = offset + 64 + head * 4064;  // 32 header + max 4032 payload
    device HaloMessageHeader* msg = (device HaloMessageHeader*)(inbox_buffer + msg_offset);

    *source_out = msg->source;
    *direction_out = msg->direction;

    // Copy halo data
    device float* payload = (device float*)(inbox_buffer + msg_offset + 32);
    uint payload_size = msg->width * msg->height * msg->depth;
    for (uint j = 0; j < payload_size; j++) {
        dest_buffer[j] = payload[j];
    }

    // Update head and count
    atomic_fetch_add_explicit(&inbox->head, 1, memory_order_relaxed);
    atomic_fetch_sub_explicit(&inbox->message_count, 1, memory_order_release);

    k2k_unlock(inbox);
    return true;
}

// Halo exchange kernel - sends halo data to all neighbors
kernel void k2k_halo_exchange(
    device K2KRoutingTable* routing [[buffer(0)]],
    device uchar* inbox_buffer [[buffer(1)]],
    device float* local_data [[buffer(2)]],
    constant uint& tile_width [[buffer(3)]],
    constant uint& tile_height [[buffer(4)]],
    constant uint& halo_size [[buffer(5)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]]
) {
    // Extract halos from local data and send to neighbors
    // Direction: 0=North, 1=South, 2=West, 3=East, 4=Up, 5=Down
    // Note: Only thread 0 performs sends; other threads synchronize

    uint tw = tile_width;
    uint th = tile_height;
    uint self_id = routing->self_id;

    // Calculate neighbor IDs based on grid position
    uint gx = routing->grid_dim_x;
    uint gy = routing->grid_dim_y;
    uint x = self_id % gx;
    uint y = self_id / gx;

    // Threadgroup-local halo buffer for column gather
    threadgroup float column_halo[256];  // Max halo height

    // Send North halo (top row) to neighbor above (y-1)
    if (y > 0) {
        uint north_neighbor = self_id - gx;
        device float* north_halo = local_data + halo_size * tw;  // First interior row
        k2k_send_halo(routing, inbox_buffer, north_neighbor,
                      north_halo, tw, halo_size, 1, 0, thread_id);  // dir=0 (North)
    }

    // Send South halo (bottom row) to neighbor below (y+1)
    if (y < gy - 1) {
        uint south_neighbor = self_id + gx;
        device float* south_halo = local_data + (th - 2 * halo_size) * tw;  // Last interior row
        k2k_send_halo(routing, inbox_buffer, south_neighbor,
                      south_halo, tw, halo_size, 1, 1, thread_id);  // dir=1 (South)
    }

    // Send West halo (left column) to neighbor left (x-1)
    // Gather column data to threadgroup memory first
    if (x > 0) {
        uint west_neighbor = self_id - 1;
        // Gather left interior column
        if (thread_id < th && thread_id < 256) {
            column_halo[thread_id] = local_data[thread_id * tw + halo_size];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Thread 0 sends the gathered column
        k2k_send_halo(routing, inbox_buffer, west_neighbor,
                      column_halo, halo_size, th, 1, 2, thread_id);  // dir=2 (West)
    }

    // Send East halo (right column) to neighbor right (x+1)
    if (x < gx - 1) {
        uint east_neighbor = self_id + 1;
        // Gather right interior column
        if (thread_id < th && thread_id < 256) {
            column_halo[thread_id] = local_data[thread_id * tw + (tw - 2 * halo_size)];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Thread 0 sends the gathered column
        k2k_send_halo(routing, inbox_buffer, east_neighbor,
                      column_halo, halo_size, th, 1, 3, thread_id);  // dir=3 (East)
    }

    threadgroup_barrier(mem_flags::mem_device);
}

// Halo apply kernel - receives halo data and applies to ghost cells
kernel void k2k_halo_apply(
    device K2KRoutingTable* routing [[buffer(0)]],
    device uchar* inbox_buffer [[buffer(1)]],
    device float* local_data [[buffer(2)]],
    constant uint& tile_width [[buffer(3)]],
    constant uint& tile_height [[buffer(4)]],
    constant uint& halo_size [[buffer(5)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]]
) {
    // Receive halo data from neighbors and apply to local ghost cells
    // Thread 0 receives messages, then all threads cooperate to apply them

    uint tw = tile_width;
    uint th = tile_height;

    // Threadgroup-shared receive buffer
    threadgroup float recv_buffer[256];  // Max halo size
    threadgroup uint msg_source;
    threadgroup uint msg_direction;
    threadgroup bool has_message;

    // Keep receiving until inbox is empty
    while (true) {
        // Thread 0 attempts to receive
        if (thread_id == 0) {
            has_message = k2k_recv_halo(routing, inbox_buffer, recv_buffer, &msg_source, &msg_direction, 0);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (!has_message) break;

        // All threads cooperate to apply the received halo
        switch (msg_direction) {
            case 0: {
                // From North - apply to top ghost row (row 0)
                // Received data is a row of width tw
                if (thread_id < tw) {
                    for (uint h = 0; h < halo_size; h++) {
                        local_data[h * tw + thread_id] = recv_buffer[h * tw + thread_id];
                    }
                }
                break;
            }
            case 1: {
                // From South - apply to bottom ghost row (row th-halo_size to th-1)
                if (thread_id < tw) {
                    for (uint h = 0; h < halo_size; h++) {
                        uint row = th - halo_size + h;
                        local_data[row * tw + thread_id] = recv_buffer[h * tw + thread_id];
                    }
                }
                break;
            }
            case 2: {
                // From West - apply to left ghost column (col 0)
                if (thread_id < th) {
                    for (uint h = 0; h < halo_size; h++) {
                        local_data[thread_id * tw + h] = recv_buffer[thread_id];
                    }
                }
                break;
            }
            case 3: {
                // From East - apply to right ghost column (col tw-halo_size to tw-1)
                if (thread_id < th) {
                    for (uint h = 0; h < halo_size; h++) {
                        uint col = tw - halo_size + h;
                        local_data[thread_id * tw + col] = recv_buffer[thread_id];
                    }
                }
                break;
            }
            case 4: {
                // From Up - apply to top ghost plane (3D)
                // Would need depth dimension; placeholder for 3D support
                break;
            }
            case 5: {
                // From Down - apply to bottom ghost plane (3D)
                // Would need depth dimension; placeholder for 3D support
                break;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    threadgroup_barrier(mem_flags::mem_device);
}
"#;