1#![warn(missing_docs)]
33
34#[cfg(all(target_os = "macos", feature = "metal"))]
35mod device;
36#[cfg(all(target_os = "macos", feature = "metal"))]
37mod kernel;
38#[cfg(all(target_os = "macos", feature = "metal"))]
39mod memory;
40#[cfg(all(target_os = "macos", feature = "metal"))]
41mod runtime;
42
43#[cfg(all(target_os = "macos", feature = "metal"))]
44pub use device::MetalDevice;
45#[cfg(all(target_os = "macos", feature = "metal"))]
46pub use kernel::{
47 HaloExchangeConfig, HaloExchangeStats, MetalHaloExchange, MetalHaloMessage,
48 MetalK2KInboxHeader, MetalK2KRouteEntry, MetalK2KRoutingTable, MetalKernel,
49};
50#[cfg(all(target_os = "macos", feature = "metal"))]
51pub use memory::MetalBuffer;
52#[cfg(all(target_os = "macos", feature = "metal"))]
53pub use runtime::MetalRuntime;
54
/// Fallback implementation used when the Metal backend is compiled out
/// (non-macOS targets, or builds without the `metal` cargo feature).
/// Every entry point reports the backend as unavailable instead of
/// failing at link time, so downstream code can probe for Metal uniformly.
#[cfg(not(all(target_os = "macos", feature = "metal")))]
mod stub {
    use async_trait::async_trait;
    use ringkernel_core::error::{Result, RingKernelError};
    use ringkernel_core::runtime::{
        Backend, KernelHandle, KernelId, LaunchOptions, RingKernelRuntime, RuntimeMetrics,
    };

    /// Stub standing in for the real `MetalRuntime`; construction always
    /// fails (see [`MetalRuntime::new`]), so no instance ever exists at
    /// runtime on non-Metal builds.
    pub struct MetalRuntime;

    impl MetalRuntime {
        /// Always returns `Err(BackendUnavailable)`: Metal requires macOS
        /// and the `metal` cargo feature, neither of which is present in
        /// this build configuration.
        pub async fn new() -> Result<Self> {
            Err(RingKernelError::BackendUnavailable(
                "Metal not available (requires macOS with metal feature)".to_string(),
            ))
        }
    }

    #[async_trait]
    impl RingKernelRuntime for MetalRuntime {
        // Still identifies as the Metal backend so callers can tell what
        // they asked for, even though it is unusable.
        fn backend(&self) -> Backend {
            Backend::Metal
        }

        // No backend is usable from the stub, whichever one is queried.
        fn is_backend_available(&self, _backend: Backend) -> bool {
            false
        }

        // Launching always fails: there is no device to launch on.
        async fn launch(&self, _kernel_id: &str, _options: LaunchOptions) -> Result<KernelHandle> {
            Err(RingKernelError::BackendUnavailable("Metal".to_string()))
        }

        // No kernel can ever have been launched, so lookups always miss.
        fn get_kernel(&self, _kernel_id: &KernelId) -> Option<KernelHandle> {
            None
        }

        // Always empty for the same reason.
        fn list_kernels(&self) -> Vec<KernelId> {
            vec![]
        }

        // Default (zeroed) metrics; nothing ever runs.
        fn metrics(&self) -> RuntimeMetrics {
            RuntimeMetrics::default()
        }

        // Nothing to tear down; succeeds trivially.
        async fn shutdown(&self) -> Result<()> {
            Ok(())
        }
    }
}
107
108#[cfg(not(all(target_os = "macos", feature = "metal")))]
109pub use stub::MetalRuntime;
110
/// Returns `true` when a Metal device is actually usable at runtime.
///
/// Compile-time gates come first: on anything other than macOS with the
/// `metal` feature enabled this is a constant `false`. On a full Metal
/// build it additionally checks that the system exposes a default GPU.
pub fn is_metal_available() -> bool {
    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    let available = false;

    #[cfg(all(target_os = "macos", feature = "metal"))]
    let available = metal::Device::system_default().is_some();

    available
}
122
/// MSL template for a resident RingKernel compute kernel.
///
/// The generated source declares the 128-byte `ControlBlock` and the
/// 256-byte `MessageHeader` layouts shared with the host, plus the
/// `ring_kernel_main` entry point. User kernel code is spliced in at the
/// `// USER_KERNEL_CODE` marker before compilation.
///
/// NOTE(review): the Metal Shading Language only accepts
/// `memory_order_relaxed` on atomic operations (acquire/release are not
/// valid MSL and fail to compile), so the flag loads/stores below use
/// relaxed ordering; host/device visibility must come from command-buffer
/// boundaries. The `atomic_ulong` fields assume a GPU family / MSL
/// version with 64-bit atomics — TODO confirm against the minimum
/// deployment target.
pub const RING_KERNEL_MSL_TEMPLATE: &str = r#"
//
// RingKernel Metal Shading Language Template
// Generated by ringkernel-metal
//

#include <metal_stdlib>
using namespace metal;

// Control block structure (128 bytes)
struct ControlBlock {
    atomic_uint is_active;
    atomic_uint should_terminate;
    atomic_uint has_terminated;
    uint _pad1;

    atomic_ulong messages_processed;
    atomic_ulong messages_in_flight;

    atomic_ulong input_head;
    atomic_ulong input_tail;
    atomic_ulong output_head;
    atomic_ulong output_tail;

    uint input_capacity;
    uint output_capacity;
    uint input_mask;
    uint output_mask;

    // HLC state
    atomic_ulong hlc_physical;
    atomic_ulong hlc_logical;

    atomic_uint last_error;
    atomic_uint error_count;

    uchar _reserved[16];
};

// Message header structure (256 bytes)
struct MessageHeader {
    ulong magic;
    uint version;
    uint flags;
    ulong message_id;
    ulong correlation_id;
    ulong source_kernel;
    ulong dest_kernel;
    ulong message_type;
    uchar priority;
    uchar _reserved1[7];
    ulong payload_size;
    uint checksum;
    uint _reserved2;
    // HLC timestamp (24 bytes)
    ulong ts_physical;
    ulong ts_logical;
    ulong ts_node_id;
    // Deadline
    ulong deadline_physical;
    ulong deadline_logical;
    ulong deadline_node_id;
    uchar _reserved3[104];
};

// Kernel entry point
kernel void ring_kernel_main(
    device ControlBlock* control [[buffer(0)]],
    device uchar* input_queue [[buffer(1)]],
    device uchar* output_queue [[buffer(2)]],
    device uchar* shared_state [[buffer(3)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]],
    uint threads_per_group [[threads_per_threadgroup]]
) {
    // Check if kernel should process.
    // MSL supports only memory_order_relaxed; host visibility is
    // provided by command-buffer boundaries, not by these atomics.
    uint is_active = atomic_load_explicit(&control->is_active, memory_order_relaxed);
    if (is_active == 0) {
        return;
    }

    // Check termination
    uint should_term = atomic_load_explicit(&control->should_terminate, memory_order_relaxed);
    if (should_term != 0) {
        if (thread_id == 0 && threadgroup_id == 0) {
            atomic_store_explicit(&control->has_terminated, 1, memory_order_relaxed);
        }
        return;
    }

    // User kernel code will be inserted here
    // USER_KERNEL_CODE

    // Update message counter
    if (thread_id == 0 && threadgroup_id == 0) {
        atomic_fetch_add_explicit(&control->messages_processed, 1, memory_order_relaxed);
    }
}
"#;
223
/// MSL template for kernel-to-kernel (K2K) halo exchange between
/// threadgroups of a 2D tile decomposition.
///
/// Provides lock-protected inbox send/receive helpers plus two kernels:
/// `k2k_halo_exchange` (scatter edge halos to the four 2D neighbors) and
/// `k2k_halo_apply` (drain this tile's inbox and fill ghost cells).
///
/// Fixes relative to the earlier revision:
/// - MSL only accepts `memory_order_relaxed`, so all atomics use it;
///   mutual exclusion comes from the per-inbox spinlock.
/// - `k2k_send_halo` is templated over the source pointer so both
///   `device` row halos and `threadgroup` gathered columns compile.
/// - `k2k_recv_halo` takes `threadgroup` output pointers to match its
///   only call site.
/// - Send and receive now agree on a fixed 4064-byte slot stride and a
///   head/tail ring, instead of mixing variable and fixed strides.
/// - `k2k_halo_apply` writes each message to the ghost edge *opposite*
///   the send direction (data sent northward arrives from the south
///   neighbor, etc.).
///
/// NOTE(review): the template assumes the host lays inboxes out at
/// `self_id * 4096`, yet one header + one slot is 64 + 4064 = 4128
/// bytes — confirm the Rust-side allocator's inbox size/stride matches.
/// The 256-float threadgroup buffers also bound usable halo sizes.
pub const K2K_HALO_EXCHANGE_MSL_TEMPLATE: &str = r#"
//
// RingKernel Metal K2K Halo Exchange Template
// Generated by ringkernel-metal
//
// MSL supports only memory_order_relaxed; cross-threadgroup ordering
// relies on the per-inbox lock and kernel dispatch boundaries.

#include <metal_stdlib>
using namespace metal;

// Fixed size of one inbox message slot: 32-byte header plus up to
// 4032 bytes (1008 floats) of payload. Sender and receiver must use
// the same stride so slots can be located without parsing headers.
constant uint K2K_MSG_SLOT_BYTES = 4064;
constant uint K2K_MSG_HEADER_BYTES = 32;

// K2K Inbox Header (64 bytes)
struct K2KInboxHeader {
    atomic_uint message_count;
    uint max_messages;
    atomic_uint head;
    atomic_uint tail;
    uint last_source;
    atomic_uint lock;
    atomic_uint sequence;
    uint _reserved[9];
};

// K2K Route Entry (32 bytes)
struct K2KRouteEntry {
    uint dest_threadgroup;
    uint inbox_offset;
    uint is_active;
    uint hops;
    uint bandwidth_hint;
    uint priority;
    uint _reserved[2];
};

// K2K Routing Table
struct K2KRoutingTable {
    uint self_id;
    uint route_count;
    uint grid_dim_x;
    uint grid_dim_y;
    uint grid_dim_z;
    uint _reserved[3];
    K2KRouteEntry routes[26]; // Max neighbors for 3D Moore neighborhood
};

// Halo Message Header (32 bytes)
struct HaloMessageHeader {
    uint source;
    uint direction;
    uint width;
    uint height;
    uint depth;
    uint element_size;
    uint sequence;
    uint flags;
};

// Try to acquire the inbox spinlock. Relaxed CAS is the strongest
// ordering MSL offers; mutual exclusion still holds.
bool k2k_try_lock(device K2KInboxHeader* inbox) {
    uint expected = 0;
    return atomic_compare_exchange_weak_explicit(
        &inbox->lock, &expected, 1,
        memory_order_relaxed, memory_order_relaxed
    );
}

// Release inbox lock
void k2k_unlock(device K2KInboxHeader* inbox) {
    atomic_store_explicit(&inbox->lock, 0, memory_order_relaxed);
}

// Send halo data to a neighbor's inbox. Templated over the source
// pointer type so callers can pass either device memory (row halos)
// or threadgroup memory (gathered column halos).
template <typename HaloPtr>
bool k2k_send_halo(
    device K2KRoutingTable* routing,
    device uchar* inbox_buffer,
    uint dest_id,
    HaloPtr halo_data,
    uint width,
    uint height,
    uint depth,
    uint direction,
    uint thread_id
) {
    // Only thread 0 performs the send
    if (thread_id != 0) return true;

    // Reject payloads that do not fit in a fixed-size message slot.
    uint payload_size = width * height * depth;
    if (K2K_MSG_HEADER_BYTES + payload_size * 4 > K2K_MSG_SLOT_BYTES) {
        return false;
    }

    // Find route to destination
    for (uint i = 0; i < routing->route_count; i++) {
        if (routing->routes[i].dest_threadgroup == dest_id &&
            routing->routes[i].is_active != 0) {

            uint offset = routing->routes[i].inbox_offset;
            device K2KInboxHeader* inbox = (device K2KInboxHeader*)(inbox_buffer + offset);

            // Try to acquire lock
            if (!k2k_try_lock(inbox)) {
                return false; // Inbox busy
            }

            // Check if inbox has space
            uint count = atomic_load_explicit(&inbox->message_count, memory_order_relaxed);
            if (count >= inbox->max_messages) {
                k2k_unlock(inbox);
                return false; // Inbox full
            }

            // Write into the slot at the ring tail, using the fixed
            // slot stride the receiver also uses.
            uint tail = atomic_load_explicit(&inbox->tail, memory_order_relaxed);
            uint slot = tail % inbox->max_messages;
            uint msg_offset = offset + 64 + slot * K2K_MSG_SLOT_BYTES;
            device HaloMessageHeader* msg = (device HaloMessageHeader*)(inbox_buffer + msg_offset);
            msg->source = routing->self_id;
            msg->direction = direction;
            msg->width = width;
            msg->height = height;
            msg->depth = depth;
            msg->element_size = 4;
            msg->sequence = atomic_fetch_add_explicit(&inbox->sequence, 1, memory_order_relaxed);
            msg->flags = 0;

            // Copy halo data
            device float* payload = (device float*)(inbox_buffer + msg_offset + K2K_MSG_HEADER_BYTES);
            for (uint j = 0; j < payload_size; j++) {
                payload[j] = halo_data[j];
            }

            // Publish under the lock: advance tail, then count.
            atomic_store_explicit(&inbox->tail, tail + 1, memory_order_relaxed);
            atomic_fetch_add_explicit(&inbox->message_count, 1, memory_order_relaxed);
            inbox->last_source = routing->self_id;

            k2k_unlock(inbox);
            return true;
        }
    }

    return false; // No route found
}

// Receive one halo message from this threadgroup's own inbox into a
// threadgroup-shared buffer. Returns false when the inbox is empty or
// currently locked. dest_buffer must hold at least width*height*depth
// floats of the incoming message.
bool k2k_recv_halo(
    device K2KRoutingTable* routing,
    device uchar* inbox_buffer,
    threadgroup float* dest_buffer,
    threadgroup uint* source_out,
    threadgroup uint* direction_out,
    uint thread_id
) {
    // Only thread 0 performs the receive
    if (thread_id != 0) return false;

    uint offset = routing->self_id * 4096; // Assume 4KB per inbox (must match host layout)
    device K2KInboxHeader* inbox = (device K2KInboxHeader*)(inbox_buffer + offset);

    // Try to acquire lock
    if (!k2k_try_lock(inbox)) {
        return false;
    }

    // Check if inbox has messages
    uint count = atomic_load_explicit(&inbox->message_count, memory_order_relaxed);
    if (count == 0) {
        k2k_unlock(inbox);
        return false;
    }

    // Read the oldest message (FIFO) from the slot at the ring head,
    // using the same fixed slot stride as the sender.
    uint head = atomic_load_explicit(&inbox->head, memory_order_relaxed);
    uint slot = head % inbox->max_messages;
    uint msg_offset = offset + 64 + slot * K2K_MSG_SLOT_BYTES;
    device HaloMessageHeader* msg = (device HaloMessageHeader*)(inbox_buffer + msg_offset);

    *source_out = msg->source;
    *direction_out = msg->direction;

    // Copy halo data
    device float* payload = (device float*)(inbox_buffer + msg_offset + K2K_MSG_HEADER_BYTES);
    uint payload_size = msg->width * msg->height * msg->depth;
    for (uint j = 0; j < payload_size; j++) {
        dest_buffer[j] = payload[j];
    }

    // Consume under the lock: advance head, then drop the count.
    atomic_store_explicit(&inbox->head, head + 1, memory_order_relaxed);
    atomic_fetch_sub_explicit(&inbox->message_count, 1, memory_order_relaxed);

    k2k_unlock(inbox);
    return true;
}

// Halo exchange kernel - sends halo data to all neighbors
kernel void k2k_halo_exchange(
    device K2KRoutingTable* routing [[buffer(0)]],
    device uchar* inbox_buffer [[buffer(1)]],
    device float* local_data [[buffer(2)]],
    constant uint& tile_width [[buffer(3)]],
    constant uint& tile_height [[buffer(4)]],
    constant uint& halo_size [[buffer(5)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]]
) {
    // Extract halos from local data and send to neighbors
    // Direction: 0=North, 1=South, 2=West, 3=East, 4=Up, 5=Down
    // Note: Only thread 0 performs sends; other threads synchronize.
    // All branch conditions below are uniform per threadgroup, so the
    // barriers inside them are reached by every thread of the group.

    uint tw = tile_width;
    uint th = tile_height;
    uint self_id = routing->self_id;

    // Calculate neighbor IDs based on grid position
    uint gx = routing->grid_dim_x;
    uint gy = routing->grid_dim_y;
    uint x = self_id % gx;
    uint y = self_id / gx;

    // Threadgroup-local halo buffer for column gather
    threadgroup float column_halo[256]; // Max halo height

    // Send North halo (top interior rows) to neighbor above (y-1)
    if (y > 0) {
        uint north_neighbor = self_id - gx;
        device float* north_halo = local_data + halo_size * tw; // First interior row
        k2k_send_halo(routing, inbox_buffer, north_neighbor,
                      north_halo, tw, halo_size, 1, 0, thread_id); // dir=0 (North)
    }

    // Send South halo (bottom interior rows) to neighbor below (y+1)
    if (y < gy - 1) {
        uint south_neighbor = self_id + gx;
        device float* south_halo = local_data + (th - 2 * halo_size) * tw; // Last interior row
        k2k_send_halo(routing, inbox_buffer, south_neighbor,
                      south_halo, tw, halo_size, 1, 1, thread_id); // dir=1 (South)
    }

    // Send West halo (left interior column) to neighbor left (x-1)
    // Gather column data to threadgroup memory first.
    // NOTE: one column per message, so only halo_size == 1 carries the
    // full halo for the column paths.
    if (x > 0) {
        uint west_neighbor = self_id - 1;
        // Gather left interior column
        if (thread_id < th && thread_id < 256) {
            column_halo[thread_id] = local_data[thread_id * tw + halo_size];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Thread 0 sends the gathered column
        k2k_send_halo(routing, inbox_buffer, west_neighbor,
                      column_halo, halo_size, th, 1, 2, thread_id); // dir=2 (West)
    }

    // Send East halo (right interior column) to neighbor right (x+1)
    if (x < gx - 1) {
        uint east_neighbor = self_id + 1;
        // Gather right interior column
        if (thread_id < th && thread_id < 256) {
            column_halo[thread_id] = local_data[thread_id * tw + (tw - 2 * halo_size)];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Thread 0 sends the gathered column
        k2k_send_halo(routing, inbox_buffer, east_neighbor,
                      column_halo, halo_size, th, 1, 3, thread_id); // dir=3 (East)
    }

    threadgroup_barrier(mem_flags::mem_device);
}

// Halo apply kernel - receives halo data and applies to ghost cells.
// A message tagged with send direction D arrived from the neighbor on
// the OPPOSITE side, so it fills the ghost edge opposite to D.
kernel void k2k_halo_apply(
    device K2KRoutingTable* routing [[buffer(0)]],
    device uchar* inbox_buffer [[buffer(1)]],
    device float* local_data [[buffer(2)]],
    constant uint& tile_width [[buffer(3)]],
    constant uint& tile_height [[buffer(4)]],
    constant uint& halo_size [[buffer(5)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]]
) {
    // Receive halo data from neighbors and apply to local ghost cells
    // Thread 0 receives messages, then all threads cooperate to apply them

    uint tw = tile_width;
    uint th = tile_height;

    // Threadgroup-shared receive buffer.
    // NOTE: bounds usable messages to 256 floats; the host must keep
    // tw * halo_size and th within this limit.
    threadgroup float recv_buffer[256]; // Max halo size
    threadgroup uint msg_source;
    threadgroup uint msg_direction;
    threadgroup bool has_message;

    // Keep receiving until inbox is empty
    while (true) {
        // Thread 0 attempts to receive
        if (thread_id == 0) {
            has_message = k2k_recv_halo(routing, inbox_buffer, recv_buffer, &msg_source, &msg_direction, 0);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (!has_message) break;

        // All threads cooperate to apply the received halo
        switch (msg_direction) {
            case 0: {
                // Sent northward, so it came from our SOUTH neighbor's
                // top interior rows: fill the bottom ghost rows.
                if (thread_id < tw) {
                    for (uint h = 0; h < halo_size; h++) {
                        uint row = th - halo_size + h;
                        local_data[row * tw + thread_id] = recv_buffer[h * tw + thread_id];
                    }
                }
                break;
            }
            case 1: {
                // Sent southward, so it came from our NORTH neighbor's
                // bottom interior rows: fill the top ghost rows.
                if (thread_id < tw) {
                    for (uint h = 0; h < halo_size; h++) {
                        local_data[h * tw + thread_id] = recv_buffer[h * tw + thread_id];
                    }
                }
                break;
            }
            case 2: {
                // Sent westward, so it came from our EAST neighbor's
                // left interior column: fill the right ghost column.
                // Column messages carry one column (halo_size == 1).
                if (thread_id < th) {
                    for (uint h = 0; h < halo_size; h++) {
                        uint col = tw - halo_size + h;
                        local_data[thread_id * tw + col] = recv_buffer[thread_id];
                    }
                }
                break;
            }
            case 3: {
                // Sent eastward, so it came from our WEST neighbor's
                // right interior column: fill the left ghost column.
                if (thread_id < th) {
                    for (uint h = 0; h < halo_size; h++) {
                        local_data[thread_id * tw + h] = recv_buffer[thread_id];
                    }
                }
                break;
            }
            case 4: {
                // From Up - apply to top ghost plane (3D)
                // Would need depth dimension; placeholder for 3D support
                break;
            }
            case 5: {
                // From Down - apply to bottom ghost plane (3D)
                // Would need depth dimension; placeholder for 3D support
                break;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    threadgroup_barrier(mem_flags::mem_device);
}
"#;