#[cfg(target_os = "macos")]
pub mod metal_acceleration {
use metal::{CommandQueue, ComputePipelineState, Device, MTLResourceOptions, MTLSize};
use ndarray::Array2;
/// Holds the GPU objects that `compute_distances` uses to run the distance kernel.
///
/// NOTE(review): the name suggests k-nearest-neighbour search, but only the
/// distance computation is visible in this module.
pub struct MetalKNN {
/// Device used to allocate the GPU buffers that hold the input data.
device: Device,
/// Queue from which a command buffer is obtained for each distance computation.
command_queue: CommandQueue,
/// Compute pipeline bound to the encoder before the distance dispatch.
distance_kernel: ComputePipelineState,
}
impl MetalKNN {
    /// Runs `distance_kernel` on the GPU with one thread per row of `X` and
    /// returns the kernel's output, one `f32` per row.
    ///
    /// The metric is whatever `distance_kernel` implements; this method only
    /// uploads the inputs, dispatches the kernel and reads the result back.
    ///
    /// # Arguments
    ///
    /// * `X` - Data matrix with one point per row. Any memory layout is
    ///   accepted; non-contiguous arrays are copied into row-major order first.
    /// * `query` - The point to measure against; its length must equal
    ///   `X.ncols()`.
    ///
    /// # Returns
    ///
    /// The contents of the output buffer as produced by
    /// `read_distances_zero_copy`, or an empty vector when `X` has no rows.
    ///
    /// # Panics
    ///
    /// Panics if `query.len() != X.ncols()`, if `X` has rows but zero columns,
    /// or if the column count does not fit in a `u32`.
    #[allow(non_snake_case)] // `X` is the established parameter name of this public method.
    pub fn compute_distances(&self, X: &Array2<f32>, query: &[f32]) -> Vec<f32> {
        // Kernel argument slots.
        // NOTE(review): the kernel source is not visible from this module. Only
        // slot 0 (data) comes from the previous implementation; the other three
        // are assumed and must match the `[[buffer(n)]]` attributes of
        // `distance_kernel` - confirm against the shader.
        const DATA_SLOT: u64 = 0;
        const QUERY_SLOT: u64 = 1;
        const OUTPUT_SLOT: u64 = 2;
        const DIM_SLOT: u64 = 3;
        // Preferred threads per threadgroup; clamped to the pipeline limit below.
        const PREFERRED_THREADGROUP_WIDTH: u64 = 64;

        let (rows, dim) = X.dim();

        // Nothing to measure, and Metal cannot allocate zero-length buffers.
        if rows == 0 {
            return Vec::new();
        }
        assert_eq!(
            query.len(),
            dim,
            "query length must equal the number of columns of X"
        );
        assert!(dim > 0, "cannot compute distances for zero-dimensional points");
        let dim_arg = u32::try_from(dim).expect("column count of X must fit in a u32");

        // Guarantee contiguous row-major memory instead of panicking on views
        // such as transposes; this borrows when `X` is already in standard layout.
        let data = X.as_standard_layout();
        let data_slice = data
            .as_slice()
            .expect("a standard-layout array is always contiguous");

        let data_buffer = self.device.new_buffer_with_data(
            data_slice.as_ptr().cast(),
            std::mem::size_of_val(data_slice) as u64,
            MTLResourceOptions::StorageModeShared,
        );
        let query_buffer = self.device.new_buffer_with_data(
            query.as_ptr().cast(),
            std::mem::size_of_val(query) as u64,
            MTLResourceOptions::StorageModeShared,
        );
        // Dedicated output buffer so results are not read back from the input data.
        let output_buffer = self.device.new_buffer(
            (rows * std::mem::size_of::<f32>()) as u64,
            MTLResourceOptions::StorageModeShared,
        );

        let command_buffer = self.command_queue.new_command_buffer();
        let encoder = command_buffer.new_compute_command_encoder();
        encoder.set_compute_pipeline_state(&self.distance_kernel);
        encoder.set_buffer(DATA_SLOT, Some(&data_buffer), 0);
        encoder.set_buffer(QUERY_SLOT, Some(&query_buffer), 0);
        encoder.set_buffer(OUTPUT_SLOT, Some(&output_buffer), 0);
        encoder.set_bytes(
            DIM_SLOT,
            std::mem::size_of::<u32>() as u64,
            (&dim_arg as *const u32).cast(),
        );

        // Never request more threads per group than this pipeline supports.
        let threadgroup_width = PREFERRED_THREADGROUP_WIDTH
            .min(self.distance_kernel.max_total_threads_per_threadgroup())
            .max(1);
        encoder.dispatch_threads(
            MTLSize::new(rows as u64, 1, 1),
            MTLSize::new(threadgroup_width, 1, 1),
        );
        encoder.end_encoding();

        command_buffer.commit();
        // Block until the GPU has finished so the output buffer is safe to read.
        command_buffer.wait_until_completed();

        // NOTE(review): `read_distances_zero_copy` is defined outside this view;
        // it is assumed to return the full contents of the buffer it is given.
        self.read_distances_zero_copy(&output_buffer)
    }
}
}