// ringkernel_metal/lib.rs
//! Metal Backend for RingKernel
//!
//! This crate provides Apple Metal GPU support for RingKernel.
//! Supports macOS, iOS, and Apple Silicon.
//!
//! # Features
//!
//! - Event-driven kernel execution (Metal compute shaders)
//! - MSL (Metal Shading Language) support
//! - Apple Silicon optimization
//! - Unified memory architecture support
//!
//! # Limitations
//!
//! - No true persistent kernels (Metal doesn't support cooperative groups)
//! - macOS/iOS only
//!
//! # Example
//!
//! ```ignore
//! use ringkernel_metal::MetalRuntime;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let runtime = MetalRuntime::new().await?;
//!     let kernel = runtime.launch("compute", Default::default()).await?;
//!     kernel.activate().await?;
//!     Ok(())
//! }
//! ```

#![warn(missing_docs)]
// Real Metal implementation: only compiled on macOS with the `metal`
// feature enabled. Everywhere else, the `stub` module below takes over.
#[cfg(all(target_os = "macos", feature = "metal"))]
mod device;
#[cfg(all(target_os = "macos", feature = "metal"))]
mod kernel;
#[cfg(all(target_os = "macos", feature = "metal"))]
mod memory;
#[cfg(all(target_os = "macos", feature = "metal"))]
mod runtime;

// Public re-exports of the Metal-backed types.
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use device::MetalDevice;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use kernel::{
    HaloExchangeConfig, HaloExchangeStats, MetalHaloExchange, MetalHaloMessage,
    MetalK2KInboxHeader, MetalK2KRouteEntry, MetalK2KRoutingTable, MetalKernel,
};
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use memory::MetalBuffer;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use runtime::MetalRuntime;
54
55// Stub implementation when Metal is not available
56#[cfg(not(all(target_os = "macos", feature = "metal")))]
57mod stub {
58    use async_trait::async_trait;
59    use ringkernel_core::error::{Result, RingKernelError};
60    use ringkernel_core::runtime::{
61        Backend, KernelHandle, KernelId, LaunchOptions, RingKernelRuntime, RuntimeMetrics,
62    };
63
64    /// Stub Metal runtime when not on macOS or Metal feature disabled.
65    pub struct MetalRuntime;
66
67    impl MetalRuntime {
68        /// Create fails when Metal is not available.
69        pub async fn new() -> Result<Self> {
70            Err(RingKernelError::BackendUnavailable(
71                "Metal not available (requires macOS with metal feature)".to_string(),
72            ))
73        }
74    }
75
76    #[async_trait]
77    impl RingKernelRuntime for MetalRuntime {
78        fn backend(&self) -> Backend {
79            Backend::Metal
80        }
81
82        fn is_backend_available(&self, _backend: Backend) -> bool {
83            false
84        }
85
86        async fn launch(&self, _kernel_id: &str, _options: LaunchOptions) -> Result<KernelHandle> {
87            Err(RingKernelError::BackendUnavailable("Metal".to_string()))
88        }
89
90        fn get_kernel(&self, _kernel_id: &KernelId) -> Option<KernelHandle> {
91            None
92        }
93
94        fn list_kernels(&self) -> Vec<KernelId> {
95            vec![]
96        }
97
98        fn metrics(&self) -> RuntimeMetrics {
99            RuntimeMetrics::default()
100        }
101
102        async fn shutdown(&self) -> Result<()> {
103            Ok(())
104        }
105    }
106}
107
108#[cfg(not(all(target_os = "macos", feature = "metal")))]
109pub use stub::MetalRuntime;
110
/// Check if Metal is available at runtime.
///
/// Returns `true` only when compiled on macOS with the `metal` feature
/// *and* a default system Metal device can be obtained; in every other
/// configuration this compiles to a constant `false`.
pub fn is_metal_available() -> bool {
    #[cfg(all(target_os = "macos", feature = "metal"))]
    {
        metal::Device::system_default().is_some()
    }
    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    {
        false
    }
}
122
/// MSL (Metal Shading Language) kernel template.
///
/// The placeholder comment `// USER_KERNEL_CODE` marks where generated
/// user kernel code is spliced in before compilation.
pub const RING_KERNEL_MSL_TEMPLATE: &str = r#"
//
// RingKernel Metal Shading Language Template
// Generated by ringkernel-metal
//

#include <metal_stdlib>
using namespace metal;

// Control block structure (128 bytes)
struct ControlBlock {
    atomic_uint is_active;
    atomic_uint should_terminate;
    atomic_uint has_terminated;
    uint _pad1;

    atomic_ulong messages_processed;
    atomic_ulong messages_in_flight;

    atomic_ulong input_head;
    atomic_ulong input_tail;
    atomic_ulong output_head;
    atomic_ulong output_tail;

    uint input_capacity;
    uint output_capacity;
    uint input_mask;
    uint output_mask;

    // HLC state
    atomic_ulong hlc_physical;
    atomic_ulong hlc_logical;

    atomic_uint last_error;
    atomic_uint error_count;

    uchar _reserved[16];
};

// Message header structure (256 bytes)
struct MessageHeader {
    ulong magic;
    uint version;
    uint flags;
    ulong message_id;
    ulong correlation_id;
    ulong source_kernel;
    ulong dest_kernel;
    ulong message_type;
    uchar priority;
    uchar _reserved1[7];
    ulong payload_size;
    uint checksum;
    uint _reserved2;
    // HLC timestamp (24 bytes)
    ulong ts_physical;
    ulong ts_logical;
    ulong ts_node_id;
    // Deadline
    ulong deadline_physical;
    ulong deadline_logical;
    ulong deadline_node_id;
    uchar _reserved3[104];
};

// Kernel entry point
kernel void ring_kernel_main(
    device ControlBlock* control [[buffer(0)]],
    device uchar* input_queue [[buffer(1)]],
    device uchar* output_queue [[buffer(2)]],
    device uchar* shared_state [[buffer(3)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]],
    uint threads_per_group [[threads_per_threadgroup]]
) {
    // Check if kernel should process
    uint is_active = atomic_load_explicit(&control->is_active, memory_order_acquire);
    if (is_active == 0) {
        return;
    }

    // Check termination
    uint should_term = atomic_load_explicit(&control->should_terminate, memory_order_acquire);
    if (should_term != 0) {
        if (thread_id == 0 && threadgroup_id == 0) {
            atomic_store_explicit(&control->has_terminated, 1, memory_order_release);
        }
        return;
    }

    // User kernel code will be inserted here
    // USER_KERNEL_CODE

    // Update message counter
    if (thread_id == 0 && threadgroup_id == 0) {
        atomic_fetch_add_explicit(&control->messages_processed, 1, memory_order_relaxed);
    }
}
"#;
223
/// MSL (Metal Shading Language) K2K Halo Exchange template.
///
/// This template provides kernel-to-kernel communication for stencil computations.
/// Each threadgroup can exchange halo data with its neighbors.
///
/// NOTE(review): the embedded MSL passes `threadgroup float*` buffers
/// (`column_halo`, `recv_buffer`) where `k2k_send_halo`/`k2k_recv_halo`
/// declare `device float*` parameters — confirm this compiles against the
/// targeted Metal language version before relying on the column paths.
pub const K2K_HALO_EXCHANGE_MSL_TEMPLATE: &str = r#"
//
// RingKernel Metal K2K Halo Exchange Template
// Generated by ringkernel-metal
//

#include <metal_stdlib>
using namespace metal;

// K2K Inbox Header (64 bytes)
struct K2KInboxHeader {
    atomic_uint message_count;
    uint max_messages;
    atomic_uint head;
    atomic_uint tail;
    uint last_source;
    atomic_uint lock;
    atomic_uint sequence;
    uint _reserved[9];
};

// K2K Route Entry (32 bytes)
struct K2KRouteEntry {
    uint dest_threadgroup;
    uint inbox_offset;
    uint is_active;
    uint hops;
    uint bandwidth_hint;
    uint priority;
    uint _reserved[2];
};

// K2K Routing Table
struct K2KRoutingTable {
    uint self_id;
    uint route_count;
    uint grid_dim_x;
    uint grid_dim_y;
    uint grid_dim_z;
    uint _reserved[3];
    K2KRouteEntry routes[26];  // Max neighbors for 3D Moore neighborhood
};

// Halo Message Header (32 bytes)
struct HaloMessageHeader {
    uint source;
    uint direction;
    uint width;
    uint height;
    uint depth;
    uint element_size;
    uint sequence;
    uint flags;
};

// Try to acquire inbox lock
bool k2k_try_lock(device K2KInboxHeader* inbox) {
    uint expected = 0;
    return atomic_compare_exchange_weak_explicit(
        &inbox->lock, &expected, 1,
        memory_order_acquire, memory_order_relaxed
    );
}

// Release inbox lock
void k2k_unlock(device K2KInboxHeader* inbox) {
    atomic_store_explicit(&inbox->lock, 0, memory_order_release);
}

// Send halo data to neighbor
bool k2k_send_halo(
    device K2KRoutingTable* routing,
    device uchar* inbox_buffer,
    uint dest_id,
    device float* halo_data,
    uint width,
    uint height,
    uint depth,
    uint direction,
    uint thread_id
) {
    // Only thread 0 performs the send
    if (thread_id != 0) return true;

    // Find route to destination
    for (uint i = 0; i < routing->route_count; i++) {
        if (routing->routes[i].dest_threadgroup == dest_id &&
            routing->routes[i].is_active != 0) {

            uint offset = routing->routes[i].inbox_offset;
            device K2KInboxHeader* inbox = (device K2KInboxHeader*)(inbox_buffer + offset);

            // Try to acquire lock
            if (!k2k_try_lock(inbox)) {
                return false;  // Inbox busy
            }

            // Check if inbox has space
            uint count = atomic_load_explicit(&inbox->message_count, memory_order_acquire);
            if (count >= inbox->max_messages) {
                k2k_unlock(inbox);
                return false;  // Inbox full
            }

            // Write message header
            uint msg_offset = offset + 64 + count * (32 + width * height * depth * 4);
            device HaloMessageHeader* msg = (device HaloMessageHeader*)(inbox_buffer + msg_offset);
            msg->source = routing->self_id;
            msg->direction = direction;
            msg->width = width;
            msg->height = height;
            msg->depth = depth;
            msg->element_size = 4;
            msg->sequence = atomic_fetch_add_explicit(&inbox->sequence, 1, memory_order_relaxed);
            msg->flags = 0;

            // Copy halo data
            device float* payload = (device float*)(inbox_buffer + msg_offset + 32);
            uint payload_size = width * height * depth;
            for (uint j = 0; j < payload_size; j++) {
                payload[j] = halo_data[j];
            }

            // Update message count
            atomic_fetch_add_explicit(&inbox->message_count, 1, memory_order_release);
            inbox->last_source = routing->self_id;

            k2k_unlock(inbox);
            return true;
        }
    }

    return false;  // No route found
}

// Receive halo data from neighbors
bool k2k_recv_halo(
    device K2KRoutingTable* routing,
    device uchar* inbox_buffer,
    device float* dest_buffer,
    uint* source_out,
    uint* direction_out,
    uint thread_id
) {
    // Only thread 0 performs the receive
    if (thread_id != 0) return false;

    uint offset = routing->self_id * 4096;  // Assume 4KB per inbox
    device K2KInboxHeader* inbox = (device K2KInboxHeader*)(inbox_buffer + offset);

    // Try to acquire lock
    if (!k2k_try_lock(inbox)) {
        return false;
    }

    // Check if inbox has messages
    uint count = atomic_load_explicit(&inbox->message_count, memory_order_acquire);
    if (count == 0) {
        k2k_unlock(inbox);
        return false;
    }

    // Read oldest message (FIFO)
    uint head = atomic_load_explicit(&inbox->head, memory_order_acquire);
    uint msg_offset = offset + 64 + head * 4064;  // 32 header + max 4032 payload
    device HaloMessageHeader* msg = (device HaloMessageHeader*)(inbox_buffer + msg_offset);

    *source_out = msg->source;
    *direction_out = msg->direction;

    // Copy halo data
    device float* payload = (device float*)(inbox_buffer + msg_offset + 32);
    uint payload_size = msg->width * msg->height * msg->depth;
    for (uint j = 0; j < payload_size; j++) {
        dest_buffer[j] = payload[j];
    }

    // Update head and count
    atomic_fetch_add_explicit(&inbox->head, 1, memory_order_relaxed);
    atomic_fetch_sub_explicit(&inbox->message_count, 1, memory_order_release);

    k2k_unlock(inbox);
    return true;
}

// Halo exchange kernel - sends halo data to all neighbors
kernel void k2k_halo_exchange(
    device K2KRoutingTable* routing [[buffer(0)]],
    device uchar* inbox_buffer [[buffer(1)]],
    device float* local_data [[buffer(2)]],
    constant uint& tile_width [[buffer(3)]],
    constant uint& tile_height [[buffer(4)]],
    constant uint& halo_size [[buffer(5)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]]
) {
    // Extract halos from local data and send to neighbors
    // Direction: 0=North, 1=South, 2=West, 3=East, 4=Up, 5=Down
    // Note: Only thread 0 performs sends; other threads synchronize

    uint tw = tile_width;
    uint th = tile_height;
    uint self_id = routing->self_id;

    // Calculate neighbor IDs based on grid position
    uint gx = routing->grid_dim_x;
    uint gy = routing->grid_dim_y;
    uint x = self_id % gx;
    uint y = self_id / gx;

    // Threadgroup-local halo buffer for column gather
    threadgroup float column_halo[256];  // Max halo height

    // Send North halo (top row) to neighbor above (y-1)
    if (y > 0) {
        uint north_neighbor = self_id - gx;
        device float* north_halo = local_data + halo_size * tw;  // First interior row
        k2k_send_halo(routing, inbox_buffer, north_neighbor,
                      north_halo, tw, halo_size, 1, 0, thread_id);  // dir=0 (North)
    }

    // Send South halo (bottom row) to neighbor below (y+1)
    if (y < gy - 1) {
        uint south_neighbor = self_id + gx;
        device float* south_halo = local_data + (th - 2 * halo_size) * tw;  // Last interior row
        k2k_send_halo(routing, inbox_buffer, south_neighbor,
                      south_halo, tw, halo_size, 1, 1, thread_id);  // dir=1 (South)
    }

    // Send West halo (left column) to neighbor left (x-1)
    // Gather column data to threadgroup memory first
    if (x > 0) {
        uint west_neighbor = self_id - 1;
        // Gather left interior column
        if (thread_id < th && thread_id < 256) {
            column_halo[thread_id] = local_data[thread_id * tw + halo_size];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Thread 0 sends the gathered column
        k2k_send_halo(routing, inbox_buffer, west_neighbor,
                      column_halo, halo_size, th, 1, 2, thread_id);  // dir=2 (West)
    }

    // Send East halo (right column) to neighbor right (x+1)
    if (x < gx - 1) {
        uint east_neighbor = self_id + 1;
        // Gather right interior column
        if (thread_id < th && thread_id < 256) {
            column_halo[thread_id] = local_data[thread_id * tw + (tw - 2 * halo_size)];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Thread 0 sends the gathered column
        k2k_send_halo(routing, inbox_buffer, east_neighbor,
                      column_halo, halo_size, th, 1, 3, thread_id);  // dir=3 (East)
    }

    threadgroup_barrier(mem_flags::mem_device);
}

// Halo apply kernel - receives halo data and applies to ghost cells
kernel void k2k_halo_apply(
    device K2KRoutingTable* routing [[buffer(0)]],
    device uchar* inbox_buffer [[buffer(1)]],
    device float* local_data [[buffer(2)]],
    constant uint& tile_width [[buffer(3)]],
    constant uint& tile_height [[buffer(4)]],
    constant uint& halo_size [[buffer(5)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]]
) {
    // Receive halo data from neighbors and apply to local ghost cells
    // Thread 0 receives messages, then all threads cooperate to apply them

    uint tw = tile_width;
    uint th = tile_height;

    // Threadgroup-shared receive buffer
    threadgroup float recv_buffer[256];  // Max halo size
    threadgroup uint msg_source;
    threadgroup uint msg_direction;
    threadgroup bool has_message;

    // Keep receiving until inbox is empty
    while (true) {
        // Thread 0 attempts to receive
        if (thread_id == 0) {
            has_message = k2k_recv_halo(routing, inbox_buffer, recv_buffer, &msg_source, &msg_direction, 0);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (!has_message) break;

        // All threads cooperate to apply the received halo
        switch (msg_direction) {
            case 0: {
                // From North - apply to top ghost row (row 0)
                // Received data is a row of width tw
                if (thread_id < tw) {
                    for (uint h = 0; h < halo_size; h++) {
                        local_data[h * tw + thread_id] = recv_buffer[h * tw + thread_id];
                    }
                }
                break;
            }
            case 1: {
                // From South - apply to bottom ghost row (row th-halo_size to th-1)
                if (thread_id < tw) {
                    for (uint h = 0; h < halo_size; h++) {
                        uint row = th - halo_size + h;
                        local_data[row * tw + thread_id] = recv_buffer[h * tw + thread_id];
                    }
                }
                break;
            }
            case 2: {
                // From West - apply to left ghost column (col 0)
                if (thread_id < th) {
                    for (uint h = 0; h < halo_size; h++) {
                        local_data[thread_id * tw + h] = recv_buffer[thread_id];
                    }
                }
                break;
            }
            case 3: {
                // From East - apply to right ghost column (col tw-halo_size to tw-1)
                if (thread_id < th) {
                    for (uint h = 0; h < halo_size; h++) {
                        uint col = tw - halo_size + h;
                        local_data[thread_id * tw + col] = recv_buffer[thread_id];
                    }
                }
                break;
            }
            case 4: {
                // From Up - apply to top ghost plane (3D)
                // Would need depth dimension; placeholder for 3D support
                break;
            }
            case 5: {
                // From Down - apply to bottom ghost plane (3D)
                // Would need depth dimension; placeholder for 3D support
                break;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    threadgroup_barrier(mem_flags::mem_device);
}
"#;