#[cfg(feature = "gpu")]
use metal::*;
#[cfg(feature = "gpu")]
const KMEANS_ASSIGN_SHADER: &str = r#"
#include <metal_stdlib>
using namespace metal;

// One thread per point: label each point with the index of the nearest
// centroid by squared Euclidean distance (sqrt is monotone, so skipped).
kernel void kmeans_assign(
    device const float* data      [[buffer(0)]], // N x D, row-major
    device const float* centroids [[buffer(1)]], // K x D, row-major
    device uint* labels           [[buffer(2)]], // N output labels
    constant uint& n              [[buffer(3)]],
    constant uint& k              [[buffer(4)]],
    constant uint& d              [[buffer(5)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= n) return;
    device const float* point = data + tid * d;
    float best_dist = INFINITY;
    uint best_k = 0;
    for (uint c = 0; c < k; c++) {
        device const float* centroid = centroids + c * d;
        float dist = 0.0f;
        for (uint j = 0; j < d; j++) {
            float diff = point[j] - centroid[j];
            dist += diff * diff;
        }
        if (dist < best_dist) {
            best_dist = dist;
            best_k = c;
        }
    }
    labels[tid] = best_k;
}
"#;
#[cfg(feature = "gpu")]
pub(crate) struct GpuAssigner {
    device: Device,
    queue: CommandQueue,
    pipeline: ComputePipelineState,
    data_buf: Buffer,  // N x D points, uploaded once at construction
    label_buf: Buffer, // N u32 labels, read back after every dispatch
    n: usize,
    k: u32,
    d: u32,
    thread_group_size: u64,
}
#[cfg(feature = "gpu")]
impl GpuAssigner {
    pub(crate) fn new(data_flat: &[f32], n: usize, k: usize, d: usize) -> Option<Self> {
        let device = Device::system_default()?;
        let queue = device.new_command_queue();
        let options = CompileOptions::new();
        let library = device
            .new_library_with_source(KMEANS_ASSIGN_SHADER, &options)
            .ok()?;
        let function = library.get_function("kmeans_assign", None).ok()?;
        let pipeline = device
            .new_compute_pipeline_state_with_function(&function)
            .ok()?;
        // Shared storage keeps both buffers CPU-visible, so labels can be
        // read back after a dispatch without an explicit blit.
        let data_buf = device.new_buffer_with_data(
            data_flat.as_ptr() as *const _,
            (data_flat.len() * std::mem::size_of::<f32>()) as u64,
            MTLResourceOptions::StorageModeShared,
        );
        let label_buf = device.new_buffer(
            (n * std::mem::size_of::<u32>()) as u64,
            MTLResourceOptions::StorageModeShared,
        );
        let thread_group_size = pipeline.max_total_threads_per_threadgroup().min(256);
        Some(Self {
            device,
            queue,
            pipeline,
            data_buf,
            label_buf,
            n,
            k: k as u32,
            d: d as u32,
            thread_group_size,
        })
    }
    #[allow(unsafe_code)]
    pub(crate) fn assign(&self, centroids_flat: &[f32]) -> Vec<usize> {
        let cent_buf = self.device.new_buffer_with_data(
            centroids_flat.as_ptr() as *const _,
            (centroids_flat.len() * std::mem::size_of::<f32>()) as u64,
            MTLResourceOptions::StorageModeShared,
        );
        let cmd = self.queue.new_command_buffer();
        let encoder = cmd.new_compute_command_encoder();
        encoder.set_compute_pipeline_state(&self.pipeline);
        encoder.set_buffer(0, Some(&self.data_buf), 0);
        encoder.set_buffer(1, Some(&cent_buf), 0);
        encoder.set_buffer(2, Some(&self.label_buf), 0);
        // The scalar params go through set_bytes rather than as offsets
        // 4 and 8 into one packed buffer: constant-buffer offsets have
        // device-dependent alignment minimums (up to 256 bytes on some
        // macOS GPUs) that those offsets would violate.
        let n = self.n as u32;
        let len = std::mem::size_of::<u32>() as u64;
        encoder.set_bytes(3, len, &n as *const u32 as *const _);
        encoder.set_bytes(4, len, &self.k as *const u32 as *const _);
        encoder.set_bytes(5, len, &self.d as *const u32 as *const _);
        // dispatch_threads launches exactly n threads and lets Metal trim
        // the final threadgroup; the kernel's bounds check is defensive.
        let grid_size = MTLSize::new(self.n as u64, 1, 1);
        let group_size = MTLSize::new(self.thread_group_size, 1, 1);
        encoder.dispatch_threads(grid_size, group_size);
        encoder.end_encoding();
        cmd.commit();
        cmd.wait_until_completed();
        // With shared storage, the label buffer is directly readable once
        // the command buffer has completed.
        let ptr = self.label_buf.contents() as *const u32;
        let labels_u32 = unsafe { std::slice::from_raw_parts(ptr, self.n) };
        labels_u32.iter().map(|&l| l as usize).collect()
    }
}
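
// A minimal usage sketch, written as a test: two well-separated 2-D
// clusters should map to their own centroids. `GpuAssigner::new` returns
// `None` on machines without a Metal device, so the test degrades to a
// no-op there. The data values below are illustrative only.
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    #[test]
    fn assigns_points_to_nearest_centroid() {
        #[rustfmt::skip]
        let data: Vec<f32> = vec![
            0.0, 0.0, // expected cluster 0
            0.2, 0.1, // expected cluster 0
            5.0, 5.0, // expected cluster 1
            4.9, 5.2, // expected cluster 1
        ];
        let Some(assigner) = GpuAssigner::new(&data, 4, 2, 2) else {
            return; // no Metal device available; skip
        };
        let centroids: Vec<f32> = vec![0.0, 0.0, 5.0, 5.0];
        assert_eq!(assigner.assign(&centroids), vec![0, 0, 1, 1]);
    }
}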
#[cfg(feature = "gpu")]
/// Copies the rows of `data` into one contiguous row-major `Vec<f32>`
/// so it can be uploaded as a single Metal buffer.
pub(crate) fn flatten(data: &(impl super::flat::DataRef + ?Sized)) -> Vec<f32> {
    let n = data.n();
    let d = data.d();
    let mut flat = Vec::with_capacity(n * d);
    for i in 0..n {
        flat.extend_from_slice(data.row(i));
    }
    flat
}