#![warn(missing_docs)]

//! Metal backend surface for RingKernel.
//!
//! The native implementation modules (`device`, `kernel`, `memory`,
//! `runtime`) are only compiled on macOS with the `metal` cargo feature
//! enabled; every other target instead gets a stub `MetalRuntime`
//! (see the `stub` module below) so the type is always nameable.

// Backend implementation modules — macOS + `metal` feature only.
#[cfg(all(target_os = "macos", feature = "metal"))]
mod device;
#[cfg(all(target_os = "macos", feature = "metal"))]
mod kernel;
#[cfg(all(target_os = "macos", feature = "metal"))]
mod memory;
#[cfg(all(target_os = "macos", feature = "metal"))]
mod runtime;

// Public re-exports of the native backend types.
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use device::MetalDevice;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use kernel::{
    HaloExchangeConfig, HaloExchangeStats, MetalHaloExchange, MetalHaloMessage,
    MetalK2KInboxHeader, MetalK2KRouteEntry, MetalK2KRoutingTable, MetalKernel,
};
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use memory::MetalBuffer;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub use runtime::MetalRuntime;
54
// Fallback for targets where the Metal backend is compiled out.
// `unavailable_backend!` presumably expands to a stub `MetalRuntime` that
// reports `Backend::Metal` as unavailable at construction time — see the
// macro definition in `ringkernel_core` for the exact contract.
#[cfg(not(all(target_os = "macos", feature = "metal")))]
mod stub {
    ringkernel_core::unavailable_backend!(
        MetalRuntime,
        ringkernel_core::runtime::Backend::Metal,
        "Metal"
    );
}

// Re-export the stub under the same path as the native runtime so callers
// can reference `MetalRuntime` unconditionally on any target.
#[cfg(not(all(target_os = "macos", feature = "metal")))]
pub use stub::MetalRuntime;
67
/// Returns `true` when a Metal GPU device can actually be obtained.
///
/// On builds without the Metal backend (any non-macOS target, or macOS
/// without the `metal` feature) this is a compile-time constant `false`.
/// On Metal-enabled builds it asks the OS for the system-default device;
/// `None` means no usable GPU is present.
///
/// Exactly one of the two `cfg`-gated blocks survives compilation, and it
/// sits in tail position, so the function always evaluates to a `bool`.
pub fn is_metal_available() -> bool {
    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    {
        // Backend compiled out: never available.
        false
    }
    #[cfg(all(target_os = "macos", feature = "metal"))]
    {
        // Probe the default GPU; this does not create any Metal state
        // beyond the device handle, which is dropped immediately.
        metal::Device::system_default().is_some()
    }
}
79
/// Metal Shading Language source template for a resident ring kernel.
///
/// Declares the device-side `ControlBlock` and `MessageHeader` layouts plus
/// the `ring_kernel_main` entry point. User kernel code is expected to be
/// spliced in at the `// USER_KERNEL_CODE` marker before the source is
/// handed to the Metal compiler; as shipped, the skeleton only checks the
/// active/terminate flags and bumps `messages_processed`.
///
/// NOTE(review): the embedded comments claim `ControlBlock` is 128 bytes and
/// `MessageHeader` is 256 bytes, but the members as written sum to 120 and
/// 232 bytes respectively (`_reserved[24]` / `_reserved3[128]` would be
/// needed to hit the stated sizes). Verify the `_reserved` tail lengths
/// against the host-side Rust structs before relying on either figure.
pub const RING_KERNEL_MSL_TEMPLATE: &str = r#"
//
// RingKernel Metal Shading Language Template
// Generated by ringkernel-metal
//

#include <metal_stdlib>
using namespace metal;

// Control block structure (128 bytes)
struct ControlBlock {
    atomic_uint is_active;
    atomic_uint should_terminate;
    atomic_uint has_terminated;
    uint _pad1;

    atomic_ulong messages_processed;
    atomic_ulong messages_in_flight;

    atomic_ulong input_head;
    atomic_ulong input_tail;
    atomic_ulong output_head;
    atomic_ulong output_tail;

    uint input_capacity;
    uint output_capacity;
    uint input_mask;
    uint output_mask;

    // HLC state
    atomic_ulong hlc_physical;
    atomic_ulong hlc_logical;

    atomic_uint last_error;
    atomic_uint error_count;

    uchar _reserved[16];
};

// Message header structure (256 bytes)
struct MessageHeader {
    ulong magic;
    uint version;
    uint flags;
    ulong message_id;
    ulong correlation_id;
    ulong source_kernel;
    ulong dest_kernel;
    ulong message_type;
    uchar priority;
    uchar _reserved1[7];
    ulong payload_size;
    uint checksum;
    uint _reserved2;
    // HLC timestamp (24 bytes)
    ulong ts_physical;
    ulong ts_logical;
    ulong ts_node_id;
    // Deadline
    ulong deadline_physical;
    ulong deadline_logical;
    ulong deadline_node_id;
    uchar _reserved3[104];
};

// Kernel entry point
kernel void ring_kernel_main(
    device ControlBlock* control [[buffer(0)]],
    device uchar* input_queue [[buffer(1)]],
    device uchar* output_queue [[buffer(2)]],
    device uchar* shared_state [[buffer(3)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]],
    uint threads_per_group [[threads_per_threadgroup]]
) {
    // Check if kernel should process
    uint is_active = atomic_load_explicit(&control->is_active, memory_order_acquire);
    if (is_active == 0) {
        return;
    }

    // Check termination
    uint should_term = atomic_load_explicit(&control->should_terminate, memory_order_acquire);
    if (should_term != 0) {
        if (thread_id == 0 && threadgroup_id == 0) {
            atomic_store_explicit(&control->has_terminated, 1, memory_order_release);
        }
        return;
    }

    // User kernel code will be inserted here
    // USER_KERNEL_CODE

    // Update message counter
    if (thread_id == 0 && threadgroup_id == 0) {
        atomic_fetch_add_explicit(&control->messages_processed, 1, memory_order_relaxed);
    }
}
"#;
180
/// Metal Shading Language source template for kernel-to-kernel (K2K) halo
/// exchange: lock-protected per-threadgroup inboxes, a static routing table,
/// and a pair of kernels (`k2k_halo_exchange` sends the four 2-D edge halos,
/// `k2k_halo_apply` drains the local inbox and writes ghost cells).
///
/// NOTE(review): several points in the embedded MSL look inconsistent and
/// should be verified against the host-side code before use:
/// - `k2k_send_halo` writes message slot `message_count` with a per-message
///   stride of `32 + w*h*d*4`, while `k2k_recv_halo` reads slot `head` with a
///   fixed 4064-byte stride and never wraps `head` modulo `max_messages` —
///   the two indexings only agree for fixed-size messages, and the inbox is
///   not a correct ring buffer as written.
/// - `column_halo` and `recv_buffer` are `threadgroup` arrays but are passed
///   to parameters declared `device float*`; Metal does not implicitly
///   convert between those address spaces, so this likely fails to compile.
/// - `recv_buffer` holds at most 256 floats, bounding `tw * halo_size`; the
///   column gather also copies only one column even when `halo_size > 1`.
/// - Direction convention: a message sent northward (dir=0) is applied by the
///   receiver to its *top* ghost rows, yet the sender is the receiver's
///   *south* neighbor — confirm whether dir means "travel direction" or
///   "which ghost edge", as the send and apply sides appear to disagree.
pub const K2K_HALO_EXCHANGE_MSL_TEMPLATE: &str = r#"
//
// RingKernel Metal K2K Halo Exchange Template
// Generated by ringkernel-metal
//

#include <metal_stdlib>
using namespace metal;

// K2K Inbox Header (64 bytes)
struct K2KInboxHeader {
    atomic_uint message_count;
    uint max_messages;
    atomic_uint head;
    atomic_uint tail;
    uint last_source;
    atomic_uint lock;
    atomic_uint sequence;
    uint _reserved[9];
};

// K2K Route Entry (32 bytes)
struct K2KRouteEntry {
    uint dest_threadgroup;
    uint inbox_offset;
    uint is_active;
    uint hops;
    uint bandwidth_hint;
    uint priority;
    uint _reserved[2];
};

// K2K Routing Table
struct K2KRoutingTable {
    uint self_id;
    uint route_count;
    uint grid_dim_x;
    uint grid_dim_y;
    uint grid_dim_z;
    uint _reserved[3];
    K2KRouteEntry routes[26]; // Max neighbors for 3D Moore neighborhood
};

// Halo Message Header (32 bytes)
struct HaloMessageHeader {
    uint source;
    uint direction;
    uint width;
    uint height;
    uint depth;
    uint element_size;
    uint sequence;
    uint flags;
};

// Try to acquire inbox lock
bool k2k_try_lock(device K2KInboxHeader* inbox) {
    uint expected = 0;
    return atomic_compare_exchange_weak_explicit(
        &inbox->lock, &expected, 1,
        memory_order_acquire, memory_order_relaxed
    );
}

// Release inbox lock
void k2k_unlock(device K2KInboxHeader* inbox) {
    atomic_store_explicit(&inbox->lock, 0, memory_order_release);
}

// Send halo data to neighbor
bool k2k_send_halo(
    device K2KRoutingTable* routing,
    device uchar* inbox_buffer,
    uint dest_id,
    device float* halo_data,
    uint width,
    uint height,
    uint depth,
    uint direction,
    uint thread_id
) {
    // Only thread 0 performs the send
    if (thread_id != 0) return true;

    // Find route to destination
    for (uint i = 0; i < routing->route_count; i++) {
        if (routing->routes[i].dest_threadgroup == dest_id &&
            routing->routes[i].is_active != 0) {

            uint offset = routing->routes[i].inbox_offset;
            device K2KInboxHeader* inbox = (device K2KInboxHeader*)(inbox_buffer + offset);

            // Try to acquire lock
            if (!k2k_try_lock(inbox)) {
                return false; // Inbox busy
            }

            // Check if inbox has space
            uint count = atomic_load_explicit(&inbox->message_count, memory_order_acquire);
            if (count >= inbox->max_messages) {
                k2k_unlock(inbox);
                return false; // Inbox full
            }

            // Write message header
            uint msg_offset = offset + 64 + count * (32 + width * height * depth * 4);
            device HaloMessageHeader* msg = (device HaloMessageHeader*)(inbox_buffer + msg_offset);
            msg->source = routing->self_id;
            msg->direction = direction;
            msg->width = width;
            msg->height = height;
            msg->depth = depth;
            msg->element_size = 4;
            msg->sequence = atomic_fetch_add_explicit(&inbox->sequence, 1, memory_order_relaxed);
            msg->flags = 0;

            // Copy halo data
            device float* payload = (device float*)(inbox_buffer + msg_offset + 32);
            uint payload_size = width * height * depth;
            for (uint j = 0; j < payload_size; j++) {
                payload[j] = halo_data[j];
            }

            // Update message count
            atomic_fetch_add_explicit(&inbox->message_count, 1, memory_order_release);
            inbox->last_source = routing->self_id;

            k2k_unlock(inbox);
            return true;
        }
    }

    return false; // No route found
}

// Receive halo data from neighbors
bool k2k_recv_halo(
    device K2KRoutingTable* routing,
    device uchar* inbox_buffer,
    device float* dest_buffer,
    uint* source_out,
    uint* direction_out,
    uint thread_id
) {
    // Only thread 0 performs the receive
    if (thread_id != 0) return false;

    uint offset = routing->self_id * 4096; // Assume 4KB per inbox
    device K2KInboxHeader* inbox = (device K2KInboxHeader*)(inbox_buffer + offset);

    // Try to acquire lock
    if (!k2k_try_lock(inbox)) {
        return false;
    }

    // Check if inbox has messages
    uint count = atomic_load_explicit(&inbox->message_count, memory_order_acquire);
    if (count == 0) {
        k2k_unlock(inbox);
        return false;
    }

    // Read oldest message (FIFO)
    uint head = atomic_load_explicit(&inbox->head, memory_order_acquire);
    uint msg_offset = offset + 64 + head * 4064; // 32 header + max 4032 payload
    device HaloMessageHeader* msg = (device HaloMessageHeader*)(inbox_buffer + msg_offset);

    *source_out = msg->source;
    *direction_out = msg->direction;

    // Copy halo data
    device float* payload = (device float*)(inbox_buffer + msg_offset + 32);
    uint payload_size = msg->width * msg->height * msg->depth;
    for (uint j = 0; j < payload_size; j++) {
        dest_buffer[j] = payload[j];
    }

    // Update head and count
    atomic_fetch_add_explicit(&inbox->head, 1, memory_order_relaxed);
    atomic_fetch_sub_explicit(&inbox->message_count, 1, memory_order_release);

    k2k_unlock(inbox);
    return true;
}

// Halo exchange kernel - sends halo data to all neighbors
kernel void k2k_halo_exchange(
    device K2KRoutingTable* routing [[buffer(0)]],
    device uchar* inbox_buffer [[buffer(1)]],
    device float* local_data [[buffer(2)]],
    constant uint& tile_width [[buffer(3)]],
    constant uint& tile_height [[buffer(4)]],
    constant uint& halo_size [[buffer(5)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]]
) {
    // Extract halos from local data and send to neighbors
    // Direction: 0=North, 1=South, 2=West, 3=East, 4=Up, 5=Down
    // Note: Only thread 0 performs sends; other threads synchronize

    uint tw = tile_width;
    uint th = tile_height;
    uint self_id = routing->self_id;

    // Calculate neighbor IDs based on grid position
    uint gx = routing->grid_dim_x;
    uint gy = routing->grid_dim_y;
    uint x = self_id % gx;
    uint y = self_id / gx;

    // Threadgroup-local halo buffer for column gather
    threadgroup float column_halo[256]; // Max halo height

    // Send North halo (top row) to neighbor above (y-1)
    if (y > 0) {
        uint north_neighbor = self_id - gx;
        device float* north_halo = local_data + halo_size * tw; // First interior row
        k2k_send_halo(routing, inbox_buffer, north_neighbor,
                      north_halo, tw, halo_size, 1, 0, thread_id); // dir=0 (North)
    }

    // Send South halo (bottom row) to neighbor below (y+1)
    if (y < gy - 1) {
        uint south_neighbor = self_id + gx;
        device float* south_halo = local_data + (th - 2 * halo_size) * tw; // Last interior row
        k2k_send_halo(routing, inbox_buffer, south_neighbor,
                      south_halo, tw, halo_size, 1, 1, thread_id); // dir=1 (South)
    }

    // Send West halo (left column) to neighbor left (x-1)
    // Gather column data to threadgroup memory first
    if (x > 0) {
        uint west_neighbor = self_id - 1;
        // Gather left interior column
        if (thread_id < th && thread_id < 256) {
            column_halo[thread_id] = local_data[thread_id * tw + halo_size];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Thread 0 sends the gathered column
        k2k_send_halo(routing, inbox_buffer, west_neighbor,
                      column_halo, halo_size, th, 1, 2, thread_id); // dir=2 (West)
    }

    // Send East halo (right column) to neighbor right (x+1)
    if (x < gx - 1) {
        uint east_neighbor = self_id + 1;
        // Gather right interior column
        if (thread_id < th && thread_id < 256) {
            column_halo[thread_id] = local_data[thread_id * tw + (tw - 2 * halo_size)];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Thread 0 sends the gathered column
        k2k_send_halo(routing, inbox_buffer, east_neighbor,
                      column_halo, halo_size, th, 1, 3, thread_id); // dir=3 (East)
    }

    threadgroup_barrier(mem_flags::mem_device);
}

// Halo apply kernel - receives halo data and applies to ghost cells
kernel void k2k_halo_apply(
    device K2KRoutingTable* routing [[buffer(0)]],
    device uchar* inbox_buffer [[buffer(1)]],
    device float* local_data [[buffer(2)]],
    constant uint& tile_width [[buffer(3)]],
    constant uint& tile_height [[buffer(4)]],
    constant uint& halo_size [[buffer(5)]],
    uint thread_id [[thread_position_in_threadgroup]],
    uint threadgroup_id [[threadgroup_position_in_grid]]
) {
    // Receive halo data from neighbors and apply to local ghost cells
    // Thread 0 receives messages, then all threads cooperate to apply them

    uint tw = tile_width;
    uint th = tile_height;

    // Threadgroup-shared receive buffer
    threadgroup float recv_buffer[256]; // Max halo size
    threadgroup uint msg_source;
    threadgroup uint msg_direction;
    threadgroup bool has_message;

    // Keep receiving until inbox is empty
    while (true) {
        // Thread 0 attempts to receive
        if (thread_id == 0) {
            has_message = k2k_recv_halo(routing, inbox_buffer, recv_buffer, &msg_source, &msg_direction, 0);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (!has_message) break;

        // All threads cooperate to apply the received halo
        switch (msg_direction) {
            case 0: {
                // From North - apply to top ghost row (row 0)
                // Received data is a row of width tw
                if (thread_id < tw) {
                    for (uint h = 0; h < halo_size; h++) {
                        local_data[h * tw + thread_id] = recv_buffer[h * tw + thread_id];
                    }
                }
                break;
            }
            case 1: {
                // From South - apply to bottom ghost row (row th-halo_size to th-1)
                if (thread_id < tw) {
                    for (uint h = 0; h < halo_size; h++) {
                        uint row = th - halo_size + h;
                        local_data[row * tw + thread_id] = recv_buffer[h * tw + thread_id];
                    }
                }
                break;
            }
            case 2: {
                // From West - apply to left ghost column (col 0)
                if (thread_id < th) {
                    for (uint h = 0; h < halo_size; h++) {
                        local_data[thread_id * tw + h] = recv_buffer[thread_id];
                    }
                }
                break;
            }
            case 3: {
                // From East - apply to right ghost column (col tw-halo_size to tw-1)
                if (thread_id < th) {
                    for (uint h = 0; h < halo_size; h++) {
                        uint col = tw - halo_size + h;
                        local_data[thread_id * tw + col] = recv_buffer[thread_id];
                    }
                }
                break;
            }
            case 4: {
                // From Up - apply to top ghost plane (3D)
                // Would need depth dimension; placeholder for 3D support
                break;
            }
            case 5: {
                // From Down - apply to bottom ghost plane (3D)
                // Would need depth dimension; placeholder for 3D support
                break;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    threadgroup_barrier(mem_flags::mem_device);
}
"#;