#![cfg(all(feature = "gpu", target_os = "macos"))]
#![allow(
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::doc_markdown,
clippy::borrow_as_ptr,
clippy::needless_borrow
)]
use std::ffi::c_void;
use std::os::raw::c_int;
use std::time::Instant;
use metal::foreign_types::ForeignTypeRef;
use metal::objc::runtime::Object;
use metal::objc::{msg_send, sel, sel_impl};
use metal::{
Buffer, CommandBufferRef, CommandQueue, ComputePipelineDescriptor, ComputePipelineState,
Device, Library, MTLResourceOptions, MTLSize, NSUInteger,
};
/// `QOS_CLASS_USER_INTERACTIVE` from `<sys/qos.h>`.
const QOS_CLASS_USER_INTERACTIVE: u32 = 0x21;

/// `kCFStringEncodingUTF8` from CoreFoundation's `CFString.h`.
const K_CF_STRING_ENCODING_UTF8: u32 = 0x0800_0100;

// libSystem: per-thread quality-of-service control (Apple-only `_np` extension).
#[link(name = "c")]
extern "C" {
    /// Sets the calling thread's QoS class; returns 0 on success, non-zero on failure.
    fn pthread_set_qos_class_self_np(qos: u32, rel_priority: c_int) -> c_int;
}

// IOKit power-management assertions.
#[link(name = "IOKit", kind = "framework")]
extern "C" {
    /// Creates a power assertion. `kind` and `name` are `CFStringRef`s; the new assertion ID is
    /// written through `out_id`. Returns an `IOReturn` (0 = success).
    fn IOPMAssertionCreateWithName(
        kind: *const c_void,
        level: u32,
        name: *const c_void,
        out_id: *mut u32,
    ) -> i32;
    /// Releases an assertion by ID. Returns an `IOReturn` (0 = success).
    fn IOPMAssertionRelease(id: u32) -> i32;
}

// CoreFoundation string construction.
#[link(name = "CoreFoundation", kind = "framework")]
extern "C" {
    /// Creates a `CFStringRef` from a NUL-terminated C string; the caller owns the result
    /// (null on failure).
    fn CFStringCreateWithCString(
        alloc: *const c_void,
        c_str: *const std::os::raw::c_char,
        encoding: u32,
    ) -> *const c_void;
}
/// Best-effort bump of the *calling* thread to the user-interactive QoS class.
///
/// On failure the non-zero return code is reported on stderr and the thread keeps its current
/// QoS; this never panics.
pub fn raise_thread_qos_user_interactive() {
    // SAFETY: plain integer arguments; the call only affects the calling thread's scheduling.
    let status = unsafe { pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0) };
    match status {
        0 => {}
        rc => eprintln!(
            "difflib-fast: pthread_set_qos_class_self_np failed (rc={rc}); proceeding at default QoS"
        ),
    }
}
/// Creates an IOKit power assertion of type `PreventUserIdleSystemSleep`, named
/// `difflib-fast.gpu.boost`, at level `kIOPMAssertionLevelOn` (255).
///
/// Returns the assertion ID on success. Returns 0 if either CoreFoundation string could not be
/// created, or if IOKit rejected the assertion; the IOKit rejection is reported on stderr.
/// Pass the result to [`release_high_perf_assertion`], which treats 0 as "nothing to release".
///
/// NOTE(review): this assertion type keeps the system from idle-sleeping while held. Whether it
/// affects GPU clocks is not established by this code.
#[must_use]
pub fn hold_high_perf_assertion() -> u32 {
    #[link(name = "CoreFoundation", kind = "framework")]
    extern "C" {
        fn CFRelease(cf: *const c_void);
    }
    /// `kIOPMAssertionLevelOn`.
    const ASSERTION_LEVEL_ON: u32 = 255;

    let kind = b"PreventUserIdleSystemSleep\0";
    let name = b"difflib-fast.gpu.boost\0";
    // SAFETY: both byte strings are 'static, NUL-terminated and valid UTF-8; a null allocator
    // selects the default CoreFoundation allocator.
    let (kind_cs, name_cs) = unsafe {
        (
            CFStringCreateWithCString(
                std::ptr::null(),
                kind.as_ptr().cast(),
                K_CF_STRING_ENCODING_UTF8,
            ),
            CFStringCreateWithCString(
                std::ptr::null(),
                name.as_ptr().cast(),
                K_CF_STRING_ENCODING_UTF8,
            ),
        )
    };
    // `CFStringCreate*` follows CoreFoundation's Create rule: we own both strings and must
    // release them on every path (the previous version leaked them on every call). IOKit
    // retains whatever it needs to keep, so releasing after the create call is safe.
    let release = |cf: *const c_void| {
        if !cf.is_null() {
            // SAFETY: `cf` is a non-null CFString created above and is released exactly once.
            unsafe { CFRelease(cf) };
        }
    };

    let mut id: u32 = 0;
    let rc = if kind_cs.is_null() || name_cs.is_null() {
        None
    } else {
        // SAFETY: both CFStrings are valid for the duration of the call and `id` is a live
        // out-parameter.
        Some(unsafe { IOPMAssertionCreateWithName(kind_cs, ASSERTION_LEVEL_ON, name_cs, &mut id) })
    };
    release(kind_cs);
    release(name_cs);

    match rc {
        None => 0,
        Some(0) => id,
        Some(rc) => {
            eprintln!("difflib-fast: IOPMAssertionCreate failed (rc={rc}); proceeding without boost");
            0
        }
    }
}
/// Releases an assertion obtained from [`hold_high_perf_assertion`].
///
/// An `id` of 0 (what that function returns on failure) is ignored. The IOKit return code is
/// deliberately discarded: there is nothing useful to do if the release fails.
pub fn release_high_perf_assertion(id: u32) {
    if id != 0 {
        // SAFETY: the call takes a plain integer assertion ID.
        let _ = unsafe { IOPMAssertionRelease(id) };
    }
}
/// Scope guard that raises the current thread's QoS and holds an IOKit power assertion.
///
/// Dropping the guard releases the power assertion. The thread QoS change made by
/// [`BoostGuard::acquire`] is *not* undone on drop.
pub struct BoostGuard {
    /// ID from [`hold_high_perf_assertion`]; 0 when the assertion could not be created.
    assertion_id: u32,
}

impl BoostGuard {
    /// Raises the calling thread to user-interactive QoS, then takes the power assertion.
    /// Both steps are best-effort and never fail.
    #[must_use]
    pub fn acquire() -> Self {
        raise_thread_qos_user_interactive();
        Self {
            assertion_id: hold_high_perf_assertion(),
        }
    }
}

impl Drop for BoostGuard {
    fn drop(&mut self) {
        let id = self.assertion_id;
        release_high_perf_assertion(id);
    }
}
/// Best-effort: sends `setReducedCPUPriority:NO` to the command queue if it responds to that
/// selector; otherwise does nothing.
///
/// NOTE(review): the selector is probed with `respondsToSelector:` because this code does not
/// assume it is part of the public `MTLCommandQueue` API — confirm before relying on it.
fn set_queue_high_priority(queue: &CommandQueue) {
    // SAFETY: `raw_queue` is the live Objective-C object behind `queue`; the setter is only sent
    // after `respondsToSelector:` confirms it exists.
    unsafe {
        let raw_queue: *mut Object = queue.as_ptr().cast();
        let setter = sel!(setReducedCPUPriority:);
        let supported: bool = msg_send![raw_queue, respondsToSelector: setter];
        if !supported {
            return;
        }
        let _: () = msg_send![raw_queue, setReducedCPUPriority: false];
    }
}
/// Reads a command buffer's timestamps.
///
/// Returns `(kernel_start, kernel_end, gpu_start, gpu_end)`, taken from the Metal properties
/// `kernelStartTime`, `kernelEndTime`, `GPUStartTime` and `GPUEndTime` respectively.
///
/// NOTE(review): these values are only meaningful after the command buffer has completed.
fn gpu_command_buffer_times(cmd: &CommandBufferRef) -> (f64, f64, f64, f64) {
    let raw = cmd.as_ptr();
    // SAFETY: `raw` is the live MTLCommandBuffer behind `cmd`; each selector is a read-only
    // f64 (CFTimeInterval) property getter.
    unsafe {
        let gpu_begin: f64 = msg_send![raw, GPUStartTime];
        let gpu_finish: f64 = msg_send![raw, GPUEndTime];
        let kernel_begin: f64 = msg_send![raw, kernelStartTime];
        let kernel_finish: f64 = msg_send![raw, kernelEndTime];
        (kernel_begin, kernel_finish, gpu_begin, gpu_finish)
    }
}
/// Metal Shading Language source for every compute kernel in this module, compiled at runtime
/// with `new_library_with_source`.
///
/// Kernels (entry-point names are looked up by name when building pipelines):
/// - `smoke_elementwise_add` — `out[id] = a[id] + b[id]` for `id < n`; a pipeline smoke test.
/// - `matching_stats_by_b_partial` — one threadgroup per entry of `active_b_idx`. It first copies
///   `min(n_nodes, k_hot_nodes)` SAM nodes and the SAM's 128-entry root table into threadgroup
///   buffer 0, then processes that SAM's pairs strided by `tg_size`. The host must size
///   threadgroup buffer 0 to at least `k_hot_nodes * 16 + 128 * 4` bytes (nodes first, then
///   the root table).
/// - `matching_stats_batched` — one thread per pair (`tid < n_pairs`), all reads from device
///   memory.
///
/// Both matching-stats kernels:
/// - read nodes as `uint4 (len(link[state]), link, edge_lo, edge_hi)` with SAM-local indices;
/// - read edges as `uint` packed `(char << 24) | target`, binary-searched within
///   `[edge_lo, edge_hi)`;
/// - use the root table for bytes `< 128` at state 0;
/// - write one `fmatch`/`fstate` value per byte of the a-string at `out_offsets[t] + i`.
///
/// NOTE(review): the in-shader comment above `matching_stats_batched` says edges are "ulong each",
/// but the kernel binds them as `device const uint*` (32-bit packed). That comment also refers to
/// a `matching_stats_one_pair` kernel that is not present in this source.
///
/// NOTE(review): `matching_stats_by_b_partial` reads the root node from `nodes_tg[0]` for bytes
/// `>= 128` at state 0. This assumes `k_hot_nodes >= 1`; with `k_hot_nodes == 0` that slot is
/// never populated and overlaps `root_tg`.
const KERNELS: &str = "
#include <metal_stdlib>
using namespace metal;
// Stage-4a-1 smoke test: pure element-wise add over u32 arrays. Used to verify the Metal pipeline
// is wired correctly end-to-end (buffer upload -> dispatch -> readback) before we wire in the
// SAM matching-stats kernel. Writes are well-defined per-thread, so this is a sound correctness gate.
kernel void smoke_elementwise_add(
device const uint* a [[buffer(0)]],
device const uint* b [[buffer(1)]],
device uint* out [[buffer(2)]],
constant uint& n [[buffer(3)]],
uint id [[thread_position_in_grid]]
) {
if (id >= n) return;
out[id] = a[id] + b[id];
}
// Stage-4a-11: partial-cache variant. ONLY the first `K_HOT_NODES` states (low-len, near-root,
// most-visited per matching_stats traffic distribution) live in threadgroup memory; states with
// `state >= K_HOT_NODES` fall through to global `sam_nodes_g`. Drops the full-SAM-in-TG cap of
// `matching_stats_by_b` (which forced a CPU fallback for SAMs > 32 KB) while still giving the
// hot path TG-memory-speed (≈1 cycle vs 30 cycles for L1 vs 200 cycles for RAM).
//
// Theory of expected gain on canonical Python:
// - K=256 covers the low-`len` band; instrumented runs show ~60-90% of byte visits land here
// (after each suffix-link backtrack the walker resets to a shallow state).
// - Hot byte = ~1 cycle node read; cold byte = ~9 cycles. At 80% hot ratio average drops
// from ~9 cycles/byte (all-global baseline) to ~2.6 cycles/byte → ~3.5× kernel speedup,
// putting GPU compute in the ~27 ms range for 100 k mypy pairs and CPU wall at ~5-7× CPU.
//
// Edges stay in global memory (already largely L1-resident — each state's edge range is small
// and contiguous). Caching them too would push K down due to the 32 KB threadgroup cap.
kernel void matching_stats_by_b_partial(
device const uint* pair_a_idx_sorted [[buffer(0)]],
device const uint* pair_b_offsets [[buffer(1)]],
device const uint* active_b_idx [[buffer(2)]],
device const uchar* a_data [[buffer(3)]],
device const uint* a_offsets [[buffer(4)]],
device const uint4* sam_nodes_g [[buffer(5)]],
device const uint* sam_node_offs [[buffer(6)]],
device const uint* sam_edges_g [[buffer(7)]],
device const uint* sam_edge_offs [[buffer(8)]],
device const int* sam_root_g [[buffer(9)]],
device uint* fmatch_out [[buffer(10)]],
device uint* fstate_out [[buffer(11)]],
device const uint* out_offsets [[buffer(12)]],
constant uint& k_hot_nodes [[buffer(13)]],
threadgroup uchar* tg_mem [[threadgroup(0)]],
uint tg_pos [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint tg_size [[threads_per_threadgroup]]
) {
uint b_idx = active_b_idx[tg_pos];
uint node_lo = sam_node_offs[b_idx];
uint node_hi = sam_node_offs[b_idx + 1u];
uint n_nodes = node_hi - node_lo;
uint sam_node_base = node_lo;
uint sam_edge_base = sam_edge_offs[b_idx];
uint sam_root_base = b_idx * 128u;
// Cache first min(n_nodes, K_HOT) state nodes + root_next in TG memory. MEASURED: caching
// edges too gave NO additional win on canonical Python (HA, mypy) — edges within a state's
// contiguous range are already L1-resident, while caching them in TG memory cost arena
// bytes that reduced occupancy. Keep edges in global; cache only nodes.
uint k_hot = (n_nodes < k_hot_nodes) ? n_nodes : k_hot_nodes;
threadgroup uint4* nodes_tg = (threadgroup uint4*)(tg_mem);
threadgroup int* root_tg = (threadgroup int*) (tg_mem + k_hot_nodes * 16u);
for (uint i = lid; i < k_hot; i += tg_size) {
nodes_tg[i] = sam_nodes_g[node_lo + i];
}
for (uint i = lid; i < 128u; i += tg_size) {
root_tg[i] = sam_root_g[sam_root_base + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
uint pair_lo = pair_b_offsets[tg_pos];
uint pair_hi = pair_b_offsets[tg_pos + 1u];
uint n_my = pair_hi - pair_lo;
for (uint pos = lid; pos < n_my; pos += tg_size) {
uint t = pair_lo + pos;
uint a_idx = pair_a_idx_sorted[t];
uint a_lo = a_offsets[a_idx];
uint a_len = a_offsets[a_idx + 1u] - a_lo;
uint out_base = out_offsets[t];
uint state = 0u;
uint matched = 0u;
for (uint i = 0u; i < a_len; i++) {
uint c = (uint)a_data[a_lo + i];
for (;;) {
int nx = -1;
uint4 cur_nd;
bool have_cur_nd = false;
if (state == 0u) {
if (c < 128u) {
nx = root_tg[c];
} else {
// Root state is always in TG (state 0 < k_hot trivially).
uint4 nd = nodes_tg[0];
uint elo = nd.z;
uint ehi = nd.w;
while (elo < ehi) {
uint mid = elo + (ehi - elo) / 2u;
uint e = sam_edges_g[sam_edge_base + mid];
uint mc = e >> 24;
if (mc == c) { nx = (int)(e & 0xFFFFFFu); break; }
if (mc < c) { elo = mid + 1u; } else { ehi = mid; }
}
}
} else {
// Hot path: state < k_hot → nodes_tg (≈1 cycle TG memory latency).
// Cold path: state >= k_hot → global memory (≈30+ cycles).
// Ternary on the SELECT side — both addresses are computed but only one
// load fires per warp lane (M3 select-merge keeps it from doubling traffic).
cur_nd = (state < k_hot) ? nodes_tg[state] : sam_nodes_g[sam_node_base + state];
have_cur_nd = true;
uint elo = cur_nd.z;
uint ehi = cur_nd.w;
while (elo < ehi) {
uint mid = elo + (ehi - elo) / 2u;
uint e = sam_edges_g[sam_edge_base + mid];
uint mc = e >> 24;
if (mc == c) { nx = (int)(e & 0xFFFFFFu); break; }
if (mc < c) { elo = mid + 1u; } else { ehi = mid; }
}
}
if (nx >= 0) {
state = (uint)nx;
matched += 1u;
break;
}
if (state == 0u) {
matched = 0u;
break;
}
// Link backtrack — `cur_nd` is loaded above for state>0.
uint4 nd = have_cur_nd
? cur_nd
: ((state < k_hot) ? nodes_tg[state] : sam_nodes_g[sam_node_base + state]);
state = nd.y;
matched = nd.x;
}
fmatch_out[out_base + i] = matched;
fstate_out[out_base + i] = state;
}
}
}
// Stage-4a-3: same matching-stats walk as matching_stats_one_pair, but BATCHED — one thread per
// pair, K pairs processed by a single dispatch_threads call. The pairs share the corpus buffers:
//
// pair_a_idx[t], pair_b_idx[t] — which a-string and which SAM thread t handles
// a_data[a_offsets[i]..a_offsets[i+1]] — string i's bytes
// sam_nodes[sam_node_offs[j]..] — SAM j's nodes (uint4 each, units of uint4)
// sam_edges[sam_edge_offs[j]..] — SAM j's edges (ulong each, units of ulong)
// sam_root_next[j*128..(j+1)*128] — SAM j's root direct ASCII table
// fmatch_out[out_offsets[t]..out_offsets[t+1]] — thread t's per-position fmatch
// fstate_out[out_offsets[t]..out_offsets[t+1]] — thread t's per-position fstate
//
// edge_lo/edge_hi in nodes are LOCAL indices into the SAM's edge range (the SAM never sees the
// global concatenated buffer); the kernel reads sam_edges[sam_edge_base + mid] where mid is the
// SAM-local index. Same applies to suffix-link targets (state indices) — they're local. This is
// why we don't need to rewrite any field during concatenation.
// One uint4 load per byte. MEASURED: hot/cold split (separate uint + uint2 buffers) is SLOWER —
// IR-level CSE keeps the uint4 layout to a single load instruction, while split forced two
// separate loads (~+60% i32-load count). uint4 wins on M3 due to wider memory ops + better
// instruction amortization. Layout: nd = (link_len_of_state, link, edge_lo, edge_hi).
kernel void matching_stats_batched(
device const uint* pair_a_idx [[buffer(0)]],
device const uint* pair_b_idx [[buffer(1)]],
device const uchar* a_data [[buffer(2)]],
device const uint* a_offsets [[buffer(3)]],
device const uint4* sam_nodes [[buffer(4)]],
device const uint* sam_node_offs [[buffer(5)]],
device const uint* sam_edges [[buffer(6)]],
device const uint* sam_edge_offs [[buffer(7)]],
device const int* sam_root_next [[buffer(8)]],
device uint* fmatch_out [[buffer(9)]],
device uint* fstate_out [[buffer(10)]],
device const uint* out_offsets [[buffer(11)]],
constant uint& n_pairs [[buffer(12)]],
uint tid [[thread_position_in_grid]]
) {
if (tid >= n_pairs) return;
uint a_idx = pair_a_idx[tid];
uint b_idx = pair_b_idx[tid];
uint a_lo = a_offsets[a_idx];
uint a_hi = a_offsets[a_idx + 1u];
uint a_len = a_hi - a_lo;
uint sam_node_base = sam_node_offs[b_idx];
uint sam_edge_base = sam_edge_offs[b_idx];
uint sam_root_base = b_idx * 128u;
uint out_base = out_offsets[tid];
uint state = 0u;
uint matched = 0u;
for (uint i = 0u; i < a_len; i++) {
uint c = (uint)a_data[a_lo + i];
for (;;) {
int nx = -1;
uint4 cur_nd;
bool have_cur_nd = false;
if (state == 0u) {
if (c < 128u) {
nx = sam_root_next[sam_root_base + c];
} else {
uint4 nd = sam_nodes[sam_node_base + 0u];
uint elo = nd.z;
uint ehi = nd.w;
while (elo < ehi) {
uint mid = elo + (ehi - elo) / 2u;
uint e = sam_edges[sam_edge_base + mid];
uint mc = e >> 24;
if (mc == c) { nx = (int)(e & 0xFFFFFFu); break; }
if (mc < c) { elo = mid + 1u; } else { ehi = mid; }
}
}
} else {
cur_nd = sam_nodes[sam_node_base + state];
have_cur_nd = true;
uint elo = cur_nd.z;
uint ehi = cur_nd.w;
while (elo < ehi) {
uint mid = elo + (ehi - elo) / 2u;
uint e = sam_edges[sam_edge_base + mid];
uint mc = e >> 24;
if (mc == c) { nx = (int)(e & 0xFFFFFFu); break; }
if (mc < c) { elo = mid + 1u; } else { ehi = mid; }
}
}
if (nx >= 0) {
state = (uint)nx;
matched += 1u;
break;
}
if (state == 0u) {
matched = 0u;
break;
}
// nd.x is precomputed by CorpusGpu::build to be len(link[state]) — read directly,
// skipping a second sam_nodes load.
uint4 nd = have_cur_nd ? cur_nd : sam_nodes[sam_node_base + state];
state = nd.y;
matched = nd.x;
}
fmatch_out[out_base + i] = matched;
fstate_out[out_base + i] = state;
}
}
";
/// Handle to the system-default Metal device, one command queue, and the compiled compute
/// pipelines for every kernel in `KERNELS`.
pub struct Gpu {
    device: Device,
    queue: CommandQueue,
    // Held but never read (leading underscore); presumably kept so the compiled library lives
    // as long as the pipelines built from it.
    _library: Library,
    // Pipeline for the `smoke_elementwise_add` kernel.
    smoke_pipeline: ComputePipelineState,
    // Pipeline for `matching_stats_batched` (one thread per pair).
    matching_stats_batched_pipeline: ComputePipelineState,
    // Pipeline for `matching_stats_by_b_partial` (one threadgroup per distinct SAM).
    matching_stats_by_b_partial_pipeline: ComputePipelineState,
}
// SAFETY: NOTE(review) — these impls assume the wrapped Metal objects (device, command queue,
// library, pipeline states) may be used and retained/released from any thread. `Gpu` itself
// exposes only `&self` methods and holds no Rust-side mutable state. Confirm against Apple's
// Metal threading guidance before relying on concurrent dispatch.
unsafe impl Send for Gpu {}
unsafe impl Sync for Gpu {}
impl Gpu {
    /// Opens the system-default Metal device, compiles `KERNELS`, builds one compute pipeline
    /// per kernel, and runs a small warm-up dispatch.
    ///
    /// Returns `None` if there is no Metal device, the kernel source fails to compile, or any
    /// pipeline cannot be created. Compile and pipeline failures are reported on stderr.
    #[must_use]
    pub fn new() -> Option<Self> {
        let device = Device::system_default()?;
        let queue = device.new_command_queue();
        set_queue_high_priority(&queue);
        let options = metal::CompileOptions::new();
        // Every kernel in KERNELS is integer-only, so fast-math cannot change results.
        options.set_fast_math_enabled(true);
        let library = match device.new_library_with_source(KERNELS, &options) {
            Ok(lib) => lib,
            Err(err) => {
                eprintln!("difflib-fast: Metal kernel compile failed: {err}");
                return None;
            }
        };
        // Report *why* a pipeline could not be built instead of silently returning `None`.
        let pipeline = |fn_name: &str| match make_pipeline(&device, &library, fn_name) {
            Ok(state) => Some(state),
            Err(err) => {
                eprintln!("difflib-fast: Metal pipeline creation failed: {err}");
                None
            }
        };
        let smoke_pipeline = pipeline("smoke_elementwise_add")?;
        let matching_stats_batched_pipeline = pipeline("matching_stats_batched")?;
        let matching_stats_by_b_partial_pipeline = pipeline("matching_stats_by_b_partial")?;
        let gpu = Gpu {
            device,
            queue,
            _library: library,
            smoke_pipeline,
            matching_stats_batched_pipeline,
            matching_stats_by_b_partial_pipeline,
        };
        gpu.warm_up();
        Some(gpu)
    }

    /// Runs one small smoke dispatch so first-use driver costs are not charged to real work.
    fn warm_up(&self) {
        let a: [u32; 1024] = [0; 1024];
        let _ = self.smoke_elementwise_add(&a, &a);
    }

    /// Human-readable name of the Metal device.
    #[must_use]
    pub fn device_name(&self) -> String {
        self.device.name().to_string()
    }

    /// Computes `a[i] + b[i]` on the GPU (u32 wrapping per Metal `uint` arithmetic).
    ///
    /// # Panics
    /// If `a.len() != b.len()` or the length exceeds `u32::MAX`.
    #[must_use]
    pub fn smoke_elementwise_add(&self, a: &[u32], b: &[u32]) -> Vec<u32> {
        assert_eq!(a.len(), b.len(), "smoke_elementwise_add: inputs must match length");
        let n = a.len();
        if n == 0 {
            return Vec::new();
        }
        let n_u32 = u32::try_from(n).expect("smoke_elementwise_add: input longer than u32::MAX");
        let buf_a = self.upload_u32(a);
        let buf_b = self.upload_u32(b);
        let buf_out = self.empty_u32_buffer(n);
        let buf_n = self.upload_buf(std::slice::from_ref(&n_u32));

        let cmd = self.queue.new_command_buffer();
        let encoder = cmd.new_compute_command_encoder();
        encoder.set_compute_pipeline_state(&self.smoke_pipeline);
        encoder.set_buffer(0, Some(&buf_a), 0);
        encoder.set_buffer(1, Some(&buf_b), 0);
        encoder.set_buffer(2, Some(&buf_out), 0);
        encoder.set_buffer(3, Some(&buf_n), 0);
        let max_t = self.smoke_pipeline.max_total_threads_per_threadgroup() as usize;
        let tg = max_t.min(n);
        encoder.dispatch_threads(MTLSize::new(n as u64, 1, 1), MTLSize::new(tg as u64, 1, 1));
        encoder.end_encoding();
        cmd.commit();
        cmd.wait_until_completed();

        // SAFETY: `buf_out` is a shared-storage buffer of at least `n` u32s and the command
        // buffer that writes it has completed.
        let out = unsafe { std::slice::from_raw_parts(buf_out.contents().cast::<u32>(), n) };
        out.to_vec()
    }

    /// Uploads a u32 slice into a new shared-storage buffer (see [`Gpu::upload_buf`]).
    fn upload_u32(&self, data: &[u32]) -> Buffer {
        self.upload_buf(data)
    }

    /// Allocates an uninitialised-by-us shared buffer with room for `n` u32s.
    ///
    /// At least one element is always allocated: a zero-length request may yield a buffer whose
    /// `contents()` is null, which would make later slice construction undefined behaviour.
    fn empty_u32_buffer(&self, n: usize) -> Buffer {
        let bytes = (n.max(1) * std::mem::size_of::<u32>()) as NSUInteger;
        self.device.new_buffer(bytes, MTLResourceOptions::StorageModeShared)
    }

    /// Copies `data` into a new shared-storage buffer.
    ///
    /// An empty slice gets a 1-byte placeholder buffer instead. Its pointer is dangling, so it
    /// must not be handed to Metal as a copy source (the previous version did exactly that,
    /// with length 1).
    fn upload_buf<T: Copy>(&self, data: &[T]) -> Buffer {
        let bytes = std::mem::size_of_val(data) as NSUInteger;
        if bytes == 0 {
            return self.device.new_buffer(1, MTLResourceOptions::StorageModeShared);
        }
        self.device.new_buffer_with_data(
            data.as_ptr().cast::<c_void>(),
            bytes,
            MTLResourceOptions::StorageModeShared,
        )
    }

    /// [`Gpu::matching_stats_batched_flat_with_timings`] without the timings.
    #[must_use]
    pub fn matching_stats_batched_flat(
        &self,
        corpus: &CorpusGpu,
        pairs: &[(u32, u32)],
    ) -> MatchingStatsFlat {
        self.matching_stats_batched_flat_with_timings(corpus, pairs).0
    }

    /// Computes matching statistics for every `(a_idx, b_idx)` in `pairs` with the
    /// `matching_stats_batched` kernel (one GPU thread per pair).
    ///
    /// Pairs are processed in sorted order `(b_idx, a_len, a_idx)`. In the result, slot `s`
    /// holds the outputs for `pairs[result.pair_orig_idx[s]]`. The second tuple element holds
    /// stage timings in nanoseconds: `[build, upload, encode+commit, wait, result assembly]`.
    /// Per-dispatch GPU timing is printed to stderr.
    ///
    /// The `BENCH_TG` environment variable overrides the threadgroup width. The width is always
    /// clamped to the pipeline's maximum.
    ///
    /// # Panics
    /// If any index is out of range for `corpus`, or the pair count / total output length does
    /// not fit in `u32`.
    #[must_use]
    #[allow(clippy::similar_names, clippy::missing_panics_doc, clippy::too_many_lines)]
    pub fn matching_stats_batched_flat_with_timings(
        &self,
        corpus: &CorpusGpu,
        pairs: &[(u32, u32)],
    ) -> (MatchingStatsFlat, [u128; 5]) {
        let n_pairs = pairs.len();
        if n_pairs == 0 {
            let empty = self.empty_u32_buffer(1);
            return (
                MatchingStatsFlat {
                    out_offsets: vec![0],
                    pair_orig_idx: Vec::new(),
                    fstate_buf: empty.clone(),
                    fmatch_buf: empty,
                    total_out: 0,
                },
                [0; 5],
            );
        }

        // Stage 1: validate, sort, and lay out per-pair output ranges.
        let t1 = Instant::now();
        // Validate before sorting: the sort key indexes `a_offsets_cpu` by `a_idx`, and an
        // unchecked `b_idx` would make the kernel read outside the SAM offset tables.
        for &(a_idx, b_idx) in pairs {
            assert!((a_idx as usize) < corpus.n_strings, "a_idx out of range");
            assert!((b_idx as usize) < corpus.n_sams, "b_idx out of range");
        }
        let n_pairs_u32 = u32::try_from(n_pairs)
            .expect("matching_stats_batched: pair count exceeds u32::MAX");
        let a_len = |a: u32| {
            let a = a as usize;
            corpus.a_offsets_cpu[a + 1] - corpus.a_offsets_cpu[a]
        };
        let mut order: Vec<u32> = (0..n_pairs_u32).collect();
        order.sort_by_key(|&t| {
            let (a, b) = pairs[t as usize];
            (b, a_len(a), a)
        });
        let pair_a_idx: Vec<u32> = order.iter().map(|&t| pairs[t as usize].0).collect();
        let pair_b_idx: Vec<u32> = order.iter().map(|&t| pairs[t as usize].1).collect();
        let mut out_offsets: Vec<u32> = Vec::with_capacity(n_pairs + 1);
        out_offsets.push(0);
        let mut total_out: u32 = 0;
        for &a_idx in &pair_a_idx {
            total_out = total_out
                .checked_add(a_len(a_idx))
                .expect("matching_stats_batched output too large");
            out_offsets.push(total_out);
        }
        let stage_build_pairs = t1.elapsed().as_nanos();

        // Stage 2: upload per-dispatch inputs and allocate outputs.
        let t2 = Instant::now();
        let buf_pair_a = self.upload_buf(&pair_a_idx);
        let buf_pair_b = self.upload_buf(&pair_b_idx);
        let buf_out_offsets = self.upload_buf(&out_offsets);
        let buf_fmatch = self.empty_u32_buffer(total_out as usize);
        let buf_fstate = self.empty_u32_buffer(total_out as usize);
        let buf_n_pairs = self.upload_buf(std::slice::from_ref(&n_pairs_u32));
        let stage_upload = t2.elapsed().as_nanos();

        // Stage 3: encode and commit.
        let t3 = Instant::now();
        let cmd = self.queue.new_command_buffer();
        let encoder = cmd.new_compute_command_encoder();
        encoder.set_compute_pipeline_state(&self.matching_stats_batched_pipeline);
        encoder.set_buffer(0, Some(&buf_pair_a), 0);
        encoder.set_buffer(1, Some(&buf_pair_b), 0);
        encoder.set_buffer(2, Some(&corpus.a_data_buf), 0);
        encoder.set_buffer(3, Some(&corpus.a_offsets_buf), 0);
        encoder.set_buffer(4, Some(&corpus.sam_nodes_buf), 0);
        encoder.set_buffer(5, Some(&corpus.sam_node_offsets_buf), 0);
        encoder.set_buffer(6, Some(&corpus.sam_edges_buf), 0);
        encoder.set_buffer(7, Some(&corpus.sam_edge_offsets_buf), 0);
        encoder.set_buffer(8, Some(&corpus.sam_root_next_buf), 0);
        encoder.set_buffer(9, Some(&buf_fmatch), 0);
        encoder.set_buffer(10, Some(&buf_fstate), 0);
        encoder.set_buffer(11, Some(&buf_out_offsets), 0);
        encoder.set_buffer(12, Some(&buf_n_pairs), 0);
        let max_t = self
            .matching_stats_batched_pipeline
            .max_total_threads_per_threadgroup() as usize;
        let requested_tg: usize = std::env::var("BENCH_TG")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(max_t);
        // Round small batches up to 32 threads, but apply the pipeline limit *last* so the
        // threadgroup can never exceed max_total_threads_per_threadgroup.
        let tg = requested_tg.min(n_pairs).max(32).min(max_t);
        encoder.dispatch_threads(
            MTLSize::new(n_pairs as u64, 1, 1),
            MTLSize::new(tg as u64, 1, 1),
        );
        encoder.end_encoding();
        cmd.commit();
        let stage_dispatch = t3.elapsed().as_nanos();

        // Stage 4: wait for the GPU.
        let t4 = Instant::now();
        cmd.wait_until_completed();
        let stage_wait = t4.elapsed().as_nanos();
        let (kernel_start, kernel_end, gpu_start, gpu_end) = gpu_command_buffer_times(&cmd);
        eprintln!(
            "  [gpu_times] kernel: {:.3} ms, gpu: {:.3} ms (gpu_start={:.6} end={:.6})",
            (kernel_end - kernel_start) * 1000.0,
            (gpu_end - gpu_start) * 1000.0,
            gpu_start,
            gpu_end,
        );

        // Stage 5: wrap the (already shared) output buffers; no copy is made.
        let t5 = Instant::now();
        let flat = MatchingStatsFlat {
            out_offsets,
            pair_orig_idx: order,
            fstate_buf: buf_fstate,
            fmatch_buf: buf_fmatch,
            total_out: total_out as usize,
        };
        let stage_readback = t5.elapsed().as_nanos();
        (flat, [stage_build_pairs, stage_upload, stage_dispatch, stage_wait, stage_readback])
    }
}
impl Gpu {
    /// Computes matching statistics for every `(a_idx, b_idx)` in `pairs` with the
    /// `matching_stats_by_b_partial` kernel.
    ///
    /// Pairs are sorted by `(b_idx, a_len, a_idx)` and grouped by SAM. Each group is processed by
    /// one threadgroup, which caches up to `k_hot_nodes` of the SAM's nodes plus its 128-entry
    /// root table in threadgroup memory.
    ///
    /// In the result, slot `s` holds the outputs for `pairs[result.pair_orig_idx[s]]`. The second
    /// tuple element holds stage timings in nanoseconds:
    /// `[build, upload, encode+commit, wait, result assembly]`.
    ///
    /// `corpus.k_hot_nodes_build` is clamped to at least 1, because the kernel reads the root
    /// from threadgroup slot 0. It is also clamped to what fits in the device's threadgroup
    /// memory next to the root table.
    ///
    /// # Panics
    /// If any index is out of range for `corpus`, or the pair count / total output length does
    /// not fit in `u32`.
    #[must_use]
    #[allow(clippy::similar_names, clippy::missing_panics_doc, clippy::too_many_lines)]
    pub fn matching_stats_by_b_partial_flat_with_timings(
        &self,
        corpus: &CorpusGpu,
        pairs: &[(u32, u32)],
    ) -> (MatchingStatsFlat, [u128; 5]) {
        // Threadgroup arena layout, mirrored from the kernel: `k_hot_nodes` uint4 nodes followed
        // by the 128-entry int root table.
        const NODE_BYTES: usize = 16;
        const ROOT_TABLE_BYTES: usize = 128 * 4;

        let n_pairs = pairs.len();
        if n_pairs == 0 {
            let empty = self.empty_u32_buffer(1);
            return (
                MatchingStatsFlat {
                    out_offsets: vec![0],
                    pair_orig_idx: Vec::new(),
                    fstate_buf: empty.clone(),
                    fmatch_buf: empty,
                    total_out: 0,
                },
                [0; 5],
            );
        }
        let max_tg_mem = self.device.max_threadgroup_memory_length() as usize;
        let max_k = u32::try_from(max_tg_mem.saturating_sub(ROOT_TABLE_BYTES) / NODE_BYTES)
            .unwrap_or(u32::MAX)
            .max(1);
        let k_hot_nodes: u32 = corpus.k_hot_nodes_build.clamp(1, max_k);

        // Stage 1: validate, sort, group by SAM, and lay out per-pair output ranges.
        let t1 = Instant::now();
        // Same validation as the batched path: an unchecked `b_idx` would make the kernel read
        // outside the SAM offset tables.
        for &(a_idx, b_idx) in pairs {
            assert!((a_idx as usize) < corpus.n_strings, "a_idx out of range");
            assert!((b_idx as usize) < corpus.n_sams, "b_idx out of range");
        }
        let n_pairs_u32 = u32::try_from(n_pairs)
            .expect("matching_stats_by_b_partial: pair count exceeds u32::MAX");
        let a_len = |a: u32| {
            let a = a as usize;
            corpus.a_offsets_cpu[a + 1] - corpus.a_offsets_cpu[a]
        };
        let mut order: Vec<u32> = (0..n_pairs_u32).collect();
        order.sort_by_key(|&t| {
            let (a, b) = pairs[t as usize];
            (b, a_len(a), a)
        });
        // `active_b_idx[g]` is the g-th distinct SAM in sorted order; its slots are
        // `pair_b_offsets[g]..pair_b_offsets[g + 1]`.
        let mut active_b_idx: Vec<u32> = Vec::new();
        let mut pair_b_offsets: Vec<u32> = vec![0];
        let mut pair_a_idx_sorted: Vec<u32> = Vec::with_capacity(n_pairs);
        let mut out_offsets: Vec<u32> = Vec::with_capacity(n_pairs + 1);
        out_offsets.push(0);
        let mut total_out: u32 = 0;
        for (slot, &t_idx) in (0u32..).zip(&order) {
            let (a_idx, b_idx) = pairs[t_idx as usize];
            if active_b_idx.last() != Some(&b_idx) {
                if !active_b_idx.is_empty() {
                    pair_b_offsets.push(slot);
                }
                active_b_idx.push(b_idx);
            }
            pair_a_idx_sorted.push(a_idx);
            total_out = total_out.checked_add(a_len(a_idx)).expect("output too large");
            out_offsets.push(total_out);
        }
        pair_b_offsets.push(n_pairs_u32);
        let n_active_b = active_b_idx.len();
        let stage_build = t1.elapsed().as_nanos();

        // Stage 2: upload per-dispatch inputs and allocate outputs.
        let t2 = Instant::now();
        let buf_pair_a_sorted = self.upload_buf(&pair_a_idx_sorted);
        let buf_pair_b_offsets = self.upload_buf(&pair_b_offsets);
        let buf_active_b = self.upload_buf(&active_b_idx);
        let buf_out_offsets = self.upload_buf(&out_offsets);
        let buf_fmatch = self.empty_u32_buffer(total_out as usize);
        let buf_fstate = self.empty_u32_buffer(total_out as usize);
        let buf_k_hot_nodes = self.upload_buf(std::slice::from_ref(&k_hot_nodes));
        let stage_upload = t2.elapsed().as_nanos();

        // Stage 3: encode and commit, one threadgroup per distinct SAM.
        let t3 = Instant::now();
        let cmd = self.queue.new_command_buffer();
        let encoder = cmd.new_compute_command_encoder();
        encoder.set_compute_pipeline_state(&self.matching_stats_by_b_partial_pipeline);
        encoder.set_buffer(0, Some(&buf_pair_a_sorted), 0);
        encoder.set_buffer(1, Some(&buf_pair_b_offsets), 0);
        encoder.set_buffer(2, Some(&buf_active_b), 0);
        encoder.set_buffer(3, Some(&corpus.a_data_buf), 0);
        encoder.set_buffer(4, Some(&corpus.a_offsets_buf), 0);
        encoder.set_buffer(5, Some(&corpus.sam_nodes_buf), 0);
        encoder.set_buffer(6, Some(&corpus.sam_node_offsets_buf), 0);
        encoder.set_buffer(7, Some(&corpus.sam_edges_buf), 0);
        encoder.set_buffer(8, Some(&corpus.sam_edge_offsets_buf), 0);
        encoder.set_buffer(9, Some(&corpus.sam_root_next_buf), 0);
        encoder.set_buffer(10, Some(&buf_fmatch), 0);
        encoder.set_buffer(11, Some(&buf_fstate), 0);
        encoder.set_buffer(12, Some(&buf_out_offsets), 0);
        encoder.set_buffer(13, Some(&buf_k_hot_nodes), 0);
        let tg_mem_bytes = (k_hot_nodes as usize) * NODE_BYTES + ROOT_TABLE_BYTES;
        encoder.set_threadgroup_memory_length(0, tg_mem_bytes as NSUInteger);
        let pipeline_max =
            self.matching_stats_by_b_partial_pipeline.max_total_threads_per_threadgroup() as usize;
        let max_pairs_in_a_group =
            pair_b_offsets.windows(2).map(|w| (w[1] - w[0]) as usize).max().unwrap_or(1);
        let tg = 128.min(pipeline_max).min(max_pairs_in_a_group.max(32).next_power_of_two());
        encoder.dispatch_thread_groups(
            MTLSize::new(n_active_b as u64, 1, 1),
            MTLSize::new(tg as u64, 1, 1),
        );
        encoder.end_encoding();
        cmd.commit();
        let stage_dispatch = t3.elapsed().as_nanos();

        // Stage 4: wait for the GPU.
        let t4 = Instant::now();
        cmd.wait_until_completed();
        let stage_wait = t4.elapsed().as_nanos();
        let (ks, ke, gs, ge) = gpu_command_buffer_times(&cmd);
        eprintln!(
            "  [by_b_partial K={k_hot_nodes} tg={tg} mem={tg_mem_bytes}B] kernel: {:.3} ms, gpu: {:.3} ms (n_active_b={n_active_b})",
            (ke - ks) * 1000.0,
            (ge - gs) * 1000.0,
        );

        // Stage 5: wrap the (already shared) output buffers; no copy is made.
        let t5 = Instant::now();
        let flat = MatchingStatsFlat {
            out_offsets,
            pair_orig_idx: order,
            fstate_buf: buf_fstate,
            fmatch_buf: buf_fmatch,
            total_out: total_out as usize,
        };
        let stage_readback = t5.elapsed().as_nanos();
        (flat, [stage_build, stage_upload, stage_dispatch, stage_wait, stage_readback])
    }
}
/// Flattened matching-statistics results for a batch of `(a, b)` pairs. The data stays in the
/// shared-storage Metal buffers the kernel wrote.
///
/// Results are stored in *sorted slot order*, not the caller's pair order. Slot `s` holds the
/// outputs for the caller's `pairs[pair_orig_idx[s]]`, occupying positions
/// `out_offsets[s]..out_offsets[s + 1]` of `fstate_all()` / `fmatch_all()`.
pub struct MatchingStatsFlat {
    /// Prefix sums of per-slot output lengths: `n_pairs() + 1` entries, starting at 0.
    pub out_offsets: Vec<u32>,
    /// For each slot, the index of the corresponding pair in the caller's `pairs` slice.
    pub pair_orig_idx: Vec<u32>,
    // Per-position final SAM state, `total_out` u32s.
    fstate_buf: Buffer,
    // Per-position match length, `total_out` u32s.
    fmatch_buf: Buffer,
    // Total number of output positions across all slots (`out_offsets.last()`).
    total_out: usize,
}
// SAFETY: NOTE(review) — the constructors build this struct only after `wait_until_completed`,
// so the GPU no longer writes the buffers and all later access is read-only. These impls also
// assume Metal buffers may be retained/released from any thread.
unsafe impl Send for MatchingStatsFlat {}
unsafe impl Sync for MatchingStatsFlat {}
impl MatchingStatsFlat {
#[must_use]
pub fn fstate_all(&self) -> &[u32] {
unsafe {
std::slice::from_raw_parts(self.fstate_buf.contents().cast::<u32>(), self.total_out)
}
}
#[must_use]
pub fn fmatch_all(&self) -> &[u32] {
unsafe {
std::slice::from_raw_parts(self.fmatch_buf.contents().cast::<u32>(), self.total_out)
}
}
#[must_use]
pub fn pair(&self, t: usize) -> (&[u32], &[u32]) {
let lo = self.out_offsets[t] as usize;
let hi = self.out_offsets[t + 1] as usize;
(&self.fstate_all()[lo..hi], &self.fmatch_all()[lo..hi])
}
#[must_use]
pub fn n_pairs(&self) -> usize {
self.out_offsets.len() - 1
}
}
/// Corpus of a-strings and their suffix automata (SAMs), concatenated by `CorpusGpu::build`
/// into flat shared-storage Metal buffers so a single dispatch can reference any
/// `(a_idx, b_idx)` pair.
///
/// Indices stored inside the node/edge buffers (suffix links, edge ranges, edge targets) are
/// SAM-local. The kernels add the SAM's node/edge base offset when reading.
pub struct CorpusGpu {
    // Number of a-strings (`strings.len()` at build time).
    n_strings: usize,
    // Number of SAMs; `build` asserts this equals `n_strings`.
    n_sams: usize,
    // Host copy of the string offsets: `n_strings + 1` prefix sums into `a_data_buf`. Used on
    // the CPU to sort pairs and size outputs.
    a_offsets_cpu: Vec<u32>,
    // Concatenated bytes of every a-string.
    a_data_buf: Buffer,
    // GPU copy of `a_offsets_cpu`.
    a_offsets_buf: Buffer,
    // One `[u32; 4]` per SAM state: `[nodes[link][0] (0 for the root), link, edge_lo, edge_hi]`.
    sam_nodes_buf: Buffer,
    // `n_sams + 1` prefix sums into `sam_nodes_buf`, in nodes.
    sam_node_offsets_buf: Buffer,
    // One u32 per edge, packed as `(char << 24) | target_state`.
    sam_edges_buf: Buffer,
    // `n_sams + 1` prefix sums into `sam_edges_buf`, in edges.
    sam_edge_offsets_buf: Buffer,
    // 128 `i32` root transitions per SAM; SAM `j`'s table starts at `j * 128`.
    sam_root_next_buf: Buffer,
    // Requested hot-node cache size for `matching_stats_by_b_partial`
    // (env `DFGPU_K_HOT_NODES_BUILD`, default 128).
    k_hot_nodes_build: u32,
}
impl CorpusGpu {
    /// Packs `strings` and their SAMs (`sams[i]` built from `strings[i]`) into GPU buffers on
    /// `gpu`.
    ///
    /// # Panics
    /// - if `strings.len() != sams.len()`;
    /// - if any SAM state has more than 255 outgoing edges, any edge label is `>= 128`, or any
    ///   edge target is `>= 2^24` (the packed edge format);
    /// - if any SAM's root table is not 128 entries;
    /// - if the concatenated strings, nodes or edges exceed the `u32` offset range.
    #[must_use]
    pub fn build(gpu: &Gpu, strings: &[&[u8]], sams: &[crate::gestalt::Sam]) -> Self {
        /// Converts a running buffer length to the `u32` offset width the kernels use.
        /// Panics instead of silently wrapping, which would corrupt GPU indexing.
        fn offset_u32(len: usize, what: &str) -> u32 {
            u32::try_from(len)
                .unwrap_or_else(|_| panic!("CorpusGpu: {what} length {len} exceeds u32 offsets"))
        }

        assert_eq!(strings.len(), sams.len(), "CorpusGpu: must have one SAM per input string");

        // a-strings: concatenated bytes plus prefix-sum offsets.
        let total_str_bytes: usize = strings.iter().map(|s| s.len()).sum();
        let mut a_data: Vec<u8> = Vec::with_capacity(total_str_bytes);
        let mut a_offsets_cpu: Vec<u32> = Vec::with_capacity(strings.len() + 1);
        a_offsets_cpu.push(0);
        for s in strings {
            a_data.extend_from_slice(s);
            a_offsets_cpu.push(offset_u32(a_data.len(), "string data"));
        }

        // SAM nodes: replace each node's first field with the length of its suffix-link target,
        // so the kernel's link backtrack needs a single node load.
        let total_nodes: usize = sams.iter().map(|s| s.nodes().len()).sum();
        let mut sam_nodes: Vec<[u32; 4]> = Vec::with_capacity(total_nodes);
        let mut sam_node_offsets: Vec<u32> = Vec::with_capacity(sams.len() + 1);
        sam_node_offsets.push(0);
        for sam in sams {
            let nodes = sam.nodes();
            for (state, &node) in nodes.iter().enumerate() {
                let link = node[1] as usize;
                let link_len = if state == 0 { 0 } else { nodes[link][0] };
                let edge_count = node[3] - node[2];
                assert!(edge_count <= 255, "edge count {edge_count} exceeds u8 — bump packing");
                sam_nodes.push([link_len, node[1], node[2], node[3]]);
            }
            sam_node_offsets.push(offset_u32(sam_nodes.len(), "SAM node"));
        }

        // SAM edges: repack (char, target) into one u32 as (char << 24) | target.
        let total_edges: usize = sams.iter().map(|s| s.edges_packed().len()).sum();
        let mut sam_edges: Vec<u32> = Vec::with_capacity(total_edges);
        let mut sam_edge_offsets: Vec<u32> = Vec::with_capacity(sams.len() + 1);
        sam_edge_offsets.push(0);
        for sam in sams {
            for &e in sam.edges_packed() {
                let c = (e >> 32) as u32;
                let target = (e & 0xffff_ffff) as u32;
                assert!(c < 128, "ASCII corpus only — non-ASCII edge char");
                assert!(target < (1 << 24), "SAM exceeds 16M states — bump packing width");
                sam_edges.push((c << 24) | target);
            }
            sam_edge_offsets.push(offset_u32(sam_edges.len(), "SAM edge"));
        }

        // Root tables: 128 direct ASCII transitions per SAM.
        let mut sam_root_next: Vec<i32> = Vec::with_capacity(sams.len() * 128);
        for sam in sams {
            let rn = sam.root_next_table();
            assert_eq!(rn.len(), 128, "SAM root_next must be 128 entries");
            sam_root_next.extend_from_slice(rn);
        }

        // At least one hot node: `matching_stats_by_b_partial` reads the root from threadgroup
        // slot 0 for bytes >= 128, which is only populated when k_hot_nodes >= 1.
        let k_hot_nodes_build: u32 = std::env::var("DFGPU_K_HOT_NODES_BUILD")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(128)
            .max(1);

        CorpusGpu {
            n_strings: strings.len(),
            n_sams: sams.len(),
            a_data_buf: gpu.upload_buf(&a_data),
            a_offsets_buf: gpu.upload_buf(&a_offsets_cpu),
            sam_nodes_buf: gpu.upload_buf(&sam_nodes),
            sam_node_offsets_buf: gpu.upload_buf(&sam_node_offsets),
            sam_edges_buf: gpu.upload_buf(&sam_edges),
            sam_edge_offsets_buf: gpu.upload_buf(&sam_edge_offsets),
            sam_root_next_buf: gpu.upload_buf(&sam_root_next),
            k_hot_nodes_build,
            a_offsets_cpu,
        }
    }

    /// Number of SAMs (and therefore valid `b_idx` values) in this corpus.
    #[must_use]
    pub fn n_sams(&self) -> usize {
        self.n_sams
    }
}
// SAFETY: NOTE(review) — after `CorpusGpu::build` returns, nothing mutates the buffers through
// this type: it has no `&mut self` methods, and the kernels bind these buffers as `const`. These
// impls also assume Metal buffers may be retained/released from any thread.
unsafe impl Send for CorpusGpu {}
unsafe impl Sync for CorpusGpu {}
/// Looks up kernel `fn_name` in `library` and builds a compute pipeline for it.
///
/// # Errors
/// Returns a message naming `fn_name` if the function is missing from `library` or the
/// pipeline cannot be created.
fn make_pipeline(
    device: &Device,
    library: &Library,
    fn_name: &str,
) -> Result<ComputePipelineState, String> {
    let func = library
        .get_function(fn_name, None)
        .map_err(|e| format!("get_function({fn_name}): {e}"))?;
    // Build from the descriptor configured here. The previous version configured a descriptor
    // and then discarded it, creating the pipeline directly from the function instead.
    let desc = ComputePipelineDescriptor::new();
    desc.set_compute_function(Some(&func));
    device
        .new_compute_pipeline_state(&desc)
        .map_err(|e| format!("pipeline({fn_name}): {e}"))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gpu_acquires_metal_device() {
        let Some(gpu) = Gpu::new() else {
            eprintln!("no Metal device on this machine — skipping GPU tests");
            return;
        };
        let name = gpu.device_name();
        eprintln!("Metal device: {name}");
        assert!(!name.is_empty(), "device name must be non-empty");
    }

    #[test]
    fn smoke_elementwise_add_correct() {
        let Some(gpu) = Gpu::new() else { return };
        let a: Vec<u32> = (0..1024).collect();
        let b: Vec<u32> = (0..1024).map(|x| x * 2).collect();
        let got = gpu.smoke_elementwise_add(&a, &b);
        let want: Vec<u32> = a.iter().zip(b.iter()).map(|(x, y)| x + y).collect();
        assert_eq!(got, want, "GPU elementwise add disagrees with CPU");
    }

    #[test]
    fn smoke_handles_short_inputs() {
        let Some(gpu) = Gpu::new() else { return };
        assert_eq!(gpu.smoke_elementwise_add(&[], &[]), Vec::<u32>::new());
        assert_eq!(gpu.smoke_elementwise_add(&[1], &[2]), vec![3]);
        let big = vec![5u32; 100_000];
        let got = gpu.smoke_elementwise_add(&big, &big);
        assert!(got.iter().all(|&x| x == 10));
    }

    /// Loads up to 8 non-empty ASCII strings from the canonical mypy corpus.
    ///
    /// Returns `None` (the caller then skips) if the corpus file is absent or yields fewer than
    /// two strings.
    fn load_corpus_strings() -> Option<Vec<String>> {
        let data = std::fs::read_to_string("benchmarks/corpora/mypy.canon.bin").ok()?;
        let strings: Vec<String> = data
            .split('\0')
            .filter(|s| !s.is_empty() && s.is_ascii())
            .take(8)
            .map(str::to_owned)
            .collect();
        (strings.len() >= 2).then_some(strings)
    }

    /// Builds a corpus from the mypy sample and runs `dispatch` over every ordered pair
    /// `(i, j)` with `i != j`. Checks every output slot against the CPU matching-stats
    /// reference, and checks that an empty pair list yields an empty result.
    ///
    /// Silently returns if there is no Metal device or no corpus file.
    #[allow(clippy::similar_names)]
    fn assert_gpu_matches_cpu<F>(dispatch: F)
    where
        F: Fn(&Gpu, &CorpusGpu, &[(u32, u32)]) -> MatchingStatsFlat,
    {
        let Some(gpu) = Gpu::new() else { return };
        let Some(strings) = load_corpus_strings() else { return };
        let strings_bytes: Vec<&[u8]> = strings.iter().map(String::as_bytes).collect();
        let strings_chars: Vec<Vec<char>> = strings.iter().map(|s| s.chars().collect()).collect();
        let sams: Vec<crate::gestalt::Sam> =
            strings_chars.iter().map(|c| crate::gestalt::build_sam(c)).collect();
        let corpus = CorpusGpu::build(&gpu, &strings_bytes, &sams);

        let empty = dispatch(&gpu, &corpus, &[]);
        assert_eq!(empty.n_pairs(), 0, "empty pair list must yield no slots");
        assert!(empty.fstate_all().is_empty() && empty.fmatch_all().is_empty());

        let n = strings.len();
        let pairs: Vec<(u32, u32)> = (0..n)
            .flat_map(|i| (0..n).filter(move |&j| j != i).map(move |j| (i as u32, j as u32)))
            .collect();
        let flat = dispatch(&gpu, &corpus, &pairs);
        assert_eq!(flat.n_pairs(), pairs.len(), "one output slot per input pair");

        let fstate_all = flat.fstate_all();
        let fmatch_all = flat.fmatch_all();
        for slot in 0..pairs.len() {
            let orig = flat.pair_orig_idx[slot] as usize;
            let (a_idx, b_idx) = pairs[orig];
            let lo = flat.out_offsets[slot] as usize;
            let hi = flat.out_offsets[slot + 1] as usize;
            let fstate_gpu = &fstate_all[lo..hi];
            let fmatch_gpu = &fmatch_all[lo..hi];
            let mut fstate_cpu = Vec::new();
            let mut fmatch_cpu = Vec::new();
            crate::gestalt::matching_stats_for_test(
                &strings_chars[a_idx as usize],
                &sams[b_idx as usize],
                &mut fstate_cpu,
                &mut fmatch_cpu,
            );
            assert_eq!(
                fstate_gpu, &fstate_cpu[..],
                "fstate diverges on pair (a={a_idx}, b={b_idx})"
            );
            assert_eq!(
                fmatch_gpu, &fmatch_cpu[..],
                "fmatch diverges on pair (a={a_idx}, b={b_idx})"
            );
        }
    }

    #[test]
    fn batched_flat_matches_cpu_on_real_corpus() {
        assert_gpu_matches_cpu(|gpu, corpus, pairs| gpu.matching_stats_batched_flat(corpus, pairs));
    }

    #[test]
    fn by_b_partial_flat_matches_cpu_on_real_corpus() {
        assert_gpu_matches_cpu(|gpu, corpus, pairs| {
            gpu.matching_stats_by_b_partial_flat_with_timings(corpus, pairs).0
        });
    }
}