use std::marker::PhantomData;
use std::ptr::NonNull;
mod metal_helpers;
use metal_helpers::{alloc_buffer, init_device_and_queue, FnConstant, PsoCache};
use objc2::rc::Retained;
use objc2::runtime::ProtocolObject;
use objc2_foundation::NSString;
use objc2_metal::{
MTLBuffer, MTLCommandBuffer, MTLCommandBufferStatus, MTLCommandEncoder, MTLCommandQueue,
MTLComputeCommandEncoder, MTLDevice, MTLLibrary, MTLSize,
};
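// Seal `SortKey` so only key types with matching Metal kernels can implement it.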
mod private {
pub trait Sealed {}
impl Sealed for u32 {}
impl Sealed for i32 {}
impl Sealed for f32 {}
impl Sealed for u64 {}
impl Sealed for i64 {}
impl Sealed for f64 {}
}
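/// A key type the GPU radix sorter accepts.
///
/// `NEEDS_TRANSFORM` keys are mapped to order-equivalent unsigned integers
/// before sorting and mapped back afterwards. For signed integers the
/// forward and inverse modes coincide (the transform is an involution);
/// floats use distinct forward (1) and inverse (2) modes.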
pub trait SortKey: private::Sealed + Copy + 'static {
const KEY_SIZE: usize;
const NEEDS_TRANSFORM: bool;
const IS_64BIT: bool;
const TRANSFORM_MODE_FORWARD: u32;
const TRANSFORM_MODE_INVERSE: u32;
}
impl SortKey for u32 {
const KEY_SIZE: usize = 4;
const NEEDS_TRANSFORM: bool = false;
const IS_64BIT: bool = false;
const TRANSFORM_MODE_FORWARD: u32 = 0;
const TRANSFORM_MODE_INVERSE: u32 = 0;
}
impl SortKey for i32 {
const KEY_SIZE: usize = 4;
const NEEDS_TRANSFORM: bool = true;
const IS_64BIT: bool = false;
const TRANSFORM_MODE_FORWARD: u32 = 0;
const TRANSFORM_MODE_INVERSE: u32 = 0;
}
impl SortKey for f32 {
const KEY_SIZE: usize = 4;
const NEEDS_TRANSFORM: bool = true;
const IS_64BIT: bool = false;
const TRANSFORM_MODE_FORWARD: u32 = 1;
const TRANSFORM_MODE_INVERSE: u32 = 2;
}
impl SortKey for u64 {
const KEY_SIZE: usize = 8;
const NEEDS_TRANSFORM: bool = false;
const IS_64BIT: bool = true;
const TRANSFORM_MODE_FORWARD: u32 = 0;
const TRANSFORM_MODE_INVERSE: u32 = 0;
}
impl SortKey for i64 {
const KEY_SIZE: usize = 8;
const NEEDS_TRANSFORM: bool = true;
const IS_64BIT: bool = true;
const TRANSFORM_MODE_FORWARD: u32 = 0;
const TRANSFORM_MODE_INVERSE: u32 = 0;
}
impl SortKey for f64 {
const KEY_SIZE: usize = 8;
const NEEDS_TRANSFORM: bool = true;
const IS_64BIT: bool = true;
const TRANSFORM_MODE_FORWARD: u32 = 1;
const TRANSFORM_MODE_INVERSE: u32 = 2;
}
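// Elements per histogram/scatter tile; 64-bit keys use half-size tiles since
// each key occupies twice the bytes (both tile variants cover 16 KiB of keys).
// One threadgroup of 256 threads services each tile.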
const TILE_SIZE: usize = 4096;
const TILE_SIZE_64: usize = 2048;
const THREADS_PER_TG: usize = 256;
#[derive(Debug, thiserror::Error)]
pub enum SortError {
#[error("no Metal GPU device found")]
DeviceNotFound,
#[error("shader compilation failed: {0}")]
ShaderCompilation(String),
#[error("GPU execution failed: {0}")]
GpuExecution(String),
#[error("length mismatch: keys={keys}, values={values}")]
LengthMismatch { keys: usize, values: usize },
}
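// Parameter blocks passed to the kernels via `setBytes`. Their layout must
// match the corresponding structs in the Metal shader source.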
#[repr(C)]
#[derive(Clone, Copy)]
struct SortParams {
element_count: u32,
num_tiles: u32,
shift: u32,
pass: u32,
}
#[repr(C)]
#[derive(Clone, Copy)]
struct InnerParams {
start_shift: u32,
pass_count: u32,
batch_start: u32,
}
#[repr(C)]
#[derive(Clone, Copy)]
struct BucketDesc {
offset: u32,
count: u32,
tile_count: u32,
tile_base: u32,
}
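/// A reusable GPU radix sorter backed by Metal compute kernels.
///
/// Pipeline state objects and scratch buffers are cached across calls, so a
/// single `GpuSorter` amortizes setup cost over many sorts. A minimal usage
/// sketch (requires a Metal device and the metallib baked in at build time):
///
/// ```ignore
/// let mut sorter = GpuSorter::new()?;
/// let mut data: Vec<u32> = vec![42, 7, 19, 3];
/// sorter.sort_u32(&mut data)?;
/// assert_eq!(data, vec![3, 7, 19, 42]);
/// ```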
pub struct GpuSorter {
device: Retained<ProtocolObject<dyn MTLDevice>>,
queue: Retained<ProtocolObject<dyn MTLCommandQueue>>,
library: Retained<ProtocolObject<dyn MTLLibrary>>,
pso_cache: PsoCache,
buf_a: Option<Retained<ProtocolObject<dyn MTLBuffer>>>,
buf_b: Option<Retained<ProtocolObject<dyn MTLBuffer>>>,
buf_msd_hist: Option<Retained<ProtocolObject<dyn MTLBuffer>>>,
buf_counters: Option<Retained<ProtocolObject<dyn MTLBuffer>>>,
buf_bucket_descs: Option<Retained<ProtocolObject<dyn MTLBuffer>>>,
data_buf_capacity: usize,
buf_vals_a: Option<Retained<ProtocolObject<dyn MTLBuffer>>>,
buf_vals_b: Option<Retained<ProtocolObject<dyn MTLBuffer>>>,
vals_buf_capacity: usize,
buf_orig_vals: Option<Retained<ProtocolObject<dyn MTLBuffer>>>,
orig_vals_capacity: usize,
data_buf_capacity_64: usize,
psos_64bit_compiled: bool,
}
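/// A typed, CPU-visible Metal buffer for zero-copy sorting.
///
/// Filling a `SortBuffer` and calling the `*_buffer` sort methods avoids the
/// host round-trips that the slice-based methods pay. Sketch:
///
/// ```ignore
/// let mut buf = sorter.alloc_sort_buffer::<f32>(1024);
/// buf.copy_from_slice(&samples);
/// sorter.sort_f32_buffer(&buf)?;
/// let sorted: &[f32] = buf.as_slice();
/// ```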
pub struct SortBuffer<T: SortKey> {
buffer: Retained<ProtocolObject<dyn MTLBuffer>>,
len: usize,
capacity: usize,
_marker: PhantomData<T>,
}
impl<T: SortKey> SortBuffer<T> {
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn capacity(&self) -> usize {
self.capacity
}
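/// Returns the full `capacity`-length slice, including any elements beyond
/// `len` that have not been written yet; pair with `set_len` after filling.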
pub fn as_mut_slice(&mut self) -> &mut [T] {
unsafe {
std::slice::from_raw_parts_mut(
self.buffer.contents().as_ptr() as *mut T,
self.capacity,
)
}
}
pub fn as_slice(&self) -> &[T] {
unsafe {
std::slice::from_raw_parts(self.buffer.contents().as_ptr() as *const T, self.len)
}
}
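/// Sets the logical length. The caller must have written `len` valid
/// elements (e.g. via `as_mut_slice`); panics if `len` exceeds capacity.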
pub fn set_len(&mut self, len: usize) {
assert!(
len <= self.capacity,
"len {} exceeds capacity {}",
len,
self.capacity
);
self.len = len;
}
pub fn copy_from_slice(&mut self, data: &[T]) {
assert!(
data.len() <= self.capacity,
"data len {} exceeds capacity {}",
data.len(),
self.capacity
);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr(),
self.buffer.contents().as_ptr() as *mut T,
data.len(),
);
}
self.len = data.len();
}
pub fn copy_to_slice(&self, dest: &mut [T]) {
let n = self.len.min(dest.len());
unsafe {
std::ptr::copy_nonoverlapping(
self.buffer.contents().as_ptr() as *const T,
dest.as_mut_ptr(),
n,
);
}
}
pub fn metal_buffer(&self) -> &ProtocolObject<dyn MTLBuffer> {
&self.buffer
}
}
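/// Runs the 32-bit key pipeline (see `encode_sort_pipeline`) in a fresh
/// command buffer and blocks until the GPU finishes. The sorted keys end up
/// back in `buf_a`.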
#[allow(clippy::too_many_arguments)]
fn dispatch_sort(
queue: &ProtocolObject<dyn MTLCommandQueue>,
library: &ProtocolObject<dyn MTLLibrary>,
pso_cache: &mut PsoCache,
buf_a: &ProtocolObject<dyn MTLBuffer>,
buf_b: &ProtocolObject<dyn MTLBuffer>,
buf_msd_hist: &ProtocolObject<dyn MTLBuffer>,
buf_counters: &ProtocolObject<dyn MTLBuffer>,
buf_bucket_descs: &ProtocolObject<dyn MTLBuffer>,
n: usize,
num_tiles: usize,
) -> Result<(), SortError> {
let cmd = queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
// The encoding itself is shared with the other 32-bit entry points.
encode_sort_pipeline(
&enc,
library,
pso_cache,
buf_a,
buf_b,
buf_msd_hist,
buf_counters,
buf_bucket_descs,
n,
num_tiles,
);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
Ok(())
}
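/// Encodes one order-preserving key transform pass over `buf`. The transform
/// mode is baked in as function constant 2, so each mode gets its own
/// specialized PSO.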
fn encode_transform(
encoder: &ProtocolObject<dyn MTLComputeCommandEncoder>,
library: &ProtocolObject<dyn MTLLibrary>,
pso_cache: &mut PsoCache,
kernel_name: &str,
buf: &ProtocolObject<dyn MTLBuffer>,
n: usize,
mode: u32,
) {
let pso = pso_cache.get_or_create_specialized(
library,
kernel_name,
&[(2, FnConstant::U32(mode))],
);
encoder.setComputePipelineState(pso);
let count = n as u32;
unsafe {
encoder.setBuffer_offset_atIndex(Some(buf), 0, 0);
encoder.setBytes_length_atIndex(
NonNull::new(&count as *const u32 as *mut _).unwrap(),
4,
1,
);
}
let grid = MTLSize {
width: n.div_ceil(THREADS_PER_TG) * THREADS_PER_TG,
height: 1,
depth: 1,
};
let tg = MTLSize {
width: THREADS_PER_TG,
height: 1,
depth: 1,
};
encoder.dispatchThreads_threadsPerThreadgroup(grid, tg);
}
fn encode_transform_32(
encoder: &ProtocolObject<dyn MTLComputeCommandEncoder>,
library: &ProtocolObject<dyn MTLLibrary>,
pso_cache: &mut PsoCache,
buf: &ProtocolObject<dyn MTLBuffer>,
n: usize,
mode: u32,
) {
encode_transform(encoder, library, pso_cache, "sort_transform_32", buf, n, mode);
}
fn encode_transform_64(
encoder: &ProtocolObject<dyn MTLComputeCommandEncoder>,
library: &ProtocolObject<dyn MTLLibrary>,
pso_cache: &mut PsoCache,
buf: &ProtocolObject<dyn MTLBuffer>,
n: usize,
mode: u32,
) {
encode_transform(encoder, library, pso_cache, "sort_transform_64", buf, n, mode);
}
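/// Encodes the 32-bit sort: (1) per-tile MSD histograms, (2) a prep pass
/// that turns the histogram into bucket offsets and tile descriptors, (3) an
/// atomic scatter of keys into their MSD buckets in `buf_b`, and (4) fused
/// LSD passes over the low three bytes within each bucket. The even number
/// of ping-pong hops leaves the sorted keys back in `buf_a`.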
#[allow(clippy::too_many_arguments)]
fn encode_sort_pipeline(
encoder: &ProtocolObject<dyn MTLComputeCommandEncoder>,
library: &ProtocolObject<dyn MTLLibrary>,
pso_cache: &mut PsoCache,
buf_a: &ProtocolObject<dyn MTLBuffer>,
buf_b: &ProtocolObject<dyn MTLBuffer>,
buf_msd_hist: &ProtocolObject<dyn MTLBuffer>,
buf_counters: &ProtocolObject<dyn MTLBuffer>,
buf_bucket_descs: &ProtocolObject<dyn MTLBuffer>,
n: usize,
num_tiles: usize,
) {
let params = SortParams {
element_count: n as u32,
num_tiles: num_tiles as u32,
shift: 24,
pass: 0,
};
let tile_size_u32 = TILE_SIZE as u32;
let tg_size = MTLSize {
width: THREADS_PER_TG,
height: 1,
depth: 1,
};
let hist_grid = MTLSize {
width: num_tiles,
height: 1,
depth: 1,
};
let one_tg_grid = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let fused_grid = MTLSize {
width: 256,
height: 1,
depth: 1,
};
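// Stage 1: per-tile histograms of the most-significant byte.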
let pso = pso_cache.get_or_create_specialized(library, "sort_msd_histogram", &[(1, FnConstant::Bool(false))]);
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(buf_a), 0, 0);
encoder.setBuffer_offset_atIndex(Some(buf_msd_hist), 0, 1);
encoder.setBytes_length_atIndex(
NonNull::new(¶ms as *const SortParams as *mut _).unwrap(),
std::mem::size_of::<SortParams>(),
2,
);
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(hist_grid, tg_size);
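// Stage 2: convert the MSD histogram into per-bucket offsets/counters and
// tile descriptors (single threadgroup).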
let pso = pso_cache.get_or_create(library, "sort_msd_prep");
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(buf_msd_hist), 0, 0);
encoder.setBuffer_offset_atIndex(Some(buf_counters), 0, 1);
encoder.setBuffer_offset_atIndex(Some(buf_bucket_descs), 0, 2);
encoder.setBytes_length_atIndex(
NonNull::new(&tile_size_u32 as *const u32 as *mut _).unwrap(),
4,
3,
);
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(one_tg_grid, tg_size);
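// Stage 3: scatter keys from buf_a into their MSD buckets in buf_b via the
// atomic counters.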
let pso = pso_cache.get_or_create_specialized(library, "sort_msd_atomic_scatter", &[(0, FnConstant::Bool(false))]);
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(buf_a), 0, 0);
encoder.setBuffer_offset_atIndex(Some(buf_b), 0, 1);
encoder.setBuffer_offset_atIndex(Some(buf_counters), 0, 2);
encoder.setBytes_length_atIndex(
NonNull::new(¶ms as *const SortParams as *mut _).unwrap(),
std::mem::size_of::<SortParams>(),
3,
);
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(hist_grid, tg_size);
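// Stage 4: fused LSD passes over bytes 0..=2 within each bucket; one
// threadgroup per MSD bucket (fused_grid width 256).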
let inner_params = InnerParams {
start_shift: 0,
pass_count: 3,
batch_start: 0,
};
let pso = pso_cache.get_or_create_specialized(library, "sort_inner_fused", &[(0, FnConstant::Bool(false))]);
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(buf_a), 0, 0);
encoder.setBuffer_offset_atIndex(Some(buf_b), 0, 1);
encoder.setBuffer_offset_atIndex(Some(buf_bucket_descs), 0, 2);
encoder.setBytes_length_atIndex(
NonNull::new(&inner_params as *const InnerParams as *mut _).unwrap(),
std::mem::size_of::<InnerParams>(),
3,
);
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(fused_grid, tg_size);
}
impl GpuSorter {
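/// Creates a sorter on the default Metal device and pre-warms the PSO cache
/// for every 32-bit kernel variant. The precompiled metallib path is baked
/// in at build time through the `SORT_METALLIB_PATH` env var; 64-bit PSOs
/// are compiled lazily on first use.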
pub fn new() -> Result<Self, SortError> {
let (device, queue) = init_device_and_queue();
let metallib_path = env!("SORT_METALLIB_PATH");
let path_ns = NSString::from_str(metallib_path);
#[allow(deprecated)]
let library = device
.newLibraryWithFile_error(&path_ns)
.map_err(|e| SortError::ShaderCompilation(format!("{:?}", e)))?;
let mut pso_cache = PsoCache::new();
pso_cache.get_or_create_specialized(
&library,
"sort_msd_histogram",
&[(1, FnConstant::Bool(false))],
);
pso_cache.get_or_create(&library, "sort_msd_prep");
pso_cache.get_or_create_specialized(
&library,
"sort_msd_atomic_scatter",
&[(0, FnConstant::Bool(false))],
);
pso_cache.get_or_create_specialized(
&library,
"sort_inner_fused",
&[(0, FnConstant::Bool(false))],
);
for mode in 0u32..=2 {
pso_cache.get_or_create_specialized(
&library,
"sort_transform_32",
&[(2, FnConstant::U32(mode))],
);
}
for name in &["sort_init_indices", "sort_gather_values"] {
pso_cache.get_or_create(&library, name);
}
pso_cache.get_or_create_specialized(
&library,
"sort_msd_atomic_scatter",
&[(0, FnConstant::Bool(true))],
);
pso_cache.get_or_create_specialized(
&library,
"sort_inner_fused",
&[(0, FnConstant::Bool(true))],
);
Ok(Self {
device,
queue,
library,
pso_cache,
buf_a: None,
buf_b: None,
buf_msd_hist: None,
buf_counters: None,
buf_bucket_descs: None,
data_buf_capacity: 0,
buf_vals_a: None,
buf_vals_b: None,
vals_buf_capacity: 0,
buf_orig_vals: None,
orig_vals_capacity: 0,
data_buf_capacity_64: 0,
psos_64bit_compiled: false,
})
}
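/// Allocates a key buffer sized for `capacity` elements of `T`. The returned
/// buffer starts logically empty (`len == 0`).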
pub fn alloc_sort_buffer<T: SortKey>(&self, capacity: usize) -> SortBuffer<T> {
let buffer = alloc_buffer(&self.device, capacity * T::KEY_SIZE);
SortBuffer {
buffer,
len: 0,
capacity,
_marker: PhantomData,
}
}
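/// Sorts a `u32` buffer in place on the GPU. The buffer is mutated through
/// its Metal allocation even though it is borrowed immutably here.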
pub fn sort_buffer(&mut self, buf: &SortBuffer<u32>) -> Result<(), SortError> {
let n = buf.len();
if n <= 1 {
return Ok(());
}
self.ensure_scratch_buffers(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
dispatch_sort(
&self.queue,
&self.library,
&mut self.pso_cache,
&buf.buffer,
self.buf_b.as_ref().unwrap(),
self.buf_msd_hist.as_ref().unwrap(),
self.buf_counters.as_ref().unwrap(),
self.buf_bucket_descs.as_ref().unwrap(),
n,
num_tiles,
)
}
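/// Sorts a `u32` slice by copying it to the GPU, running the radix pipeline,
/// and copying the result back.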
pub fn sort_u32(&mut self, data: &mut [u32]) -> Result<(), SortError> {
let n = data.len();
if n <= 1 {
return Ok(());
}
self.ensure_buffers(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
dispatch_sort(
&self.queue,
&self.library,
&mut self.pso_cache,
self.buf_a.as_ref().unwrap(),
self.buf_b.as_ref().unwrap(),
self.buf_msd_hist.as_ref().unwrap(),
self.buf_counters.as_ref().unwrap(),
self.buf_bucket_descs.as_ref().unwrap(),
n,
num_tiles,
)?;
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const u32,
data.as_mut_ptr(),
n,
);
}
Ok(())
}
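/// Sorts an `i32` slice. Keys are transformed to order-equivalent unsigned
/// form before sorting and transformed back afterwards (mode 0 is its own
/// inverse).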
pub fn sort_i32(&mut self, data: &mut [i32]) -> Result<(), SortError> {
let n = data.len();
if n <= 1 {
return Ok(());
}
self.ensure_buffers(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let buf_a = self.buf_a.as_ref().unwrap();
let buf_b = self.buf_b.as_ref().unwrap();
encode_transform_32(&enc, &self.library, &mut self.pso_cache, buf_a, n, 0);
encode_sort_pipeline(
&enc,
&self.library,
&mut self.pso_cache,
buf_a,
buf_b,
self.buf_msd_hist.as_ref().unwrap(),
self.buf_counters.as_ref().unwrap(),
self.buf_bucket_descs.as_ref().unwrap(),
n,
num_tiles,
);
encode_transform_32(&enc, &self.library, &mut self.pso_cache, buf_a, n, 0);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const i32,
data.as_mut_ptr(),
n,
);
}
Ok(())
}
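/// Sorts an `f32` slice. Keys go through the forward float transform
/// (mode 1) so the radix passes see an order-preserving unsigned encoding,
/// then through the inverse transform (mode 2) afterwards.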
pub fn sort_f32(&mut self, data: &mut [f32]) -> Result<(), SortError> {
let n = data.len();
if n <= 1 {
return Ok(());
}
self.ensure_buffers(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let buf_a = self.buf_a.as_ref().unwrap();
let buf_b = self.buf_b.as_ref().unwrap();
encode_transform_32(&enc, &self.library, &mut self.pso_cache, buf_a, n, 1);
encode_sort_pipeline(
&enc,
&self.library,
&mut self.pso_cache,
buf_a,
buf_b,
self.buf_msd_hist.as_ref().unwrap(),
self.buf_counters.as_ref().unwrap(),
self.buf_bucket_descs.as_ref().unwrap(),
n,
num_tiles,
);
encode_transform_32(&enc, &self.library, &mut self.pso_cache, buf_a, n, 2);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const f32,
data.as_mut_ptr(),
n,
);
}
Ok(())
}
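/// In-place GPU variant of `sort_f32` for keys already resident in a
/// `SortBuffer`.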
pub fn sort_f32_buffer(&mut self, buf: &SortBuffer<f32>) -> Result<(), SortError> {
let n = buf.len();
if n <= 1 {
return Ok(());
}
self.ensure_scratch_buffers(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
encode_transform_32(&enc, &self.library, &mut self.pso_cache, &buf.buffer, n, 1);
encode_sort_pipeline(
&enc,
&self.library,
&mut self.pso_cache,
&buf.buffer,
self.buf_b.as_ref().unwrap(),
self.buf_msd_hist.as_ref().unwrap(),
self.buf_counters.as_ref().unwrap(),
self.buf_bucket_descs.as_ref().unwrap(),
n,
num_tiles,
);
encode_transform_32(&enc, &self.library, &mut self.pso_cache, &buf.buffer, n, 2);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
Ok(())
}
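/// In-place GPU variant of `sort_i32` for keys already resident in a
/// `SortBuffer`.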
pub fn sort_i32_buffer(&mut self, buf: &SortBuffer<i32>) -> Result<(), SortError> {
let n = buf.len();
if n <= 1 {
return Ok(());
}
self.ensure_scratch_buffers(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
encode_transform_32(&enc, &self.library, &mut self.pso_cache, &buf.buffer, n, 0);
encode_sort_pipeline(
&enc,
&self.library,
&mut self.pso_cache,
&buf.buffer,
self.buf_b.as_ref().unwrap(),
self.buf_msd_hist.as_ref().unwrap(),
self.buf_counters.as_ref().unwrap(),
self.buf_bucket_descs.as_ref().unwrap(),
n,
num_tiles,
);
encode_transform_32(&enc, &self.library, &mut self.pso_cache, &buf.buffer, n, 0);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
Ok(())
}
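/// Grow-only (re)allocation of the key ping-pong buffers plus the small
/// fixed-size histogram/counter/descriptor buffers.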
fn ensure_buffers(&mut self, n: usize) {
let data_bytes = n * 4;
if self.buf_a.is_none() || data_bytes > self.data_buf_capacity {
self.buf_a = Some(alloc_buffer(&self.device, data_bytes));
self.buf_b = Some(alloc_buffer(&self.device, data_bytes));
self.data_buf_capacity = data_bytes;
}
if self.buf_msd_hist.is_none() {
self.buf_msd_hist = Some(alloc_buffer(&self.device, 256 * 4));
self.buf_counters = Some(alloc_buffer(&self.device, 256 * 4));
self.buf_bucket_descs =
Some(alloc_buffer(&self.device, 256 * std::mem::size_of::<BucketDesc>()));
}
}
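/// Like `ensure_buffers`, but for the buffer-based entry points where the
/// caller owns the key buffer: only the scratch side (`buf_b`) is required,
/// and `buf_a` is reallocated only if it already exists, keeping the pair at
/// matching capacity.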
fn ensure_scratch_buffers(&mut self, n: usize) {
let data_bytes = n * 4;
if self.buf_b.is_none() || data_bytes > self.data_buf_capacity {
self.buf_b = Some(alloc_buffer(&self.device, data_bytes));
if self.buf_a.is_some() {
self.buf_a = Some(alloc_buffer(&self.device, data_bytes));
}
self.data_buf_capacity = data_bytes;
}
if self.buf_msd_hist.is_none() {
self.buf_msd_hist = Some(alloc_buffer(&self.device, 256 * 4));
self.buf_counters = Some(alloc_buffer(&self.device, 256 * 4));
self.buf_bucket_descs =
Some(alloc_buffer(&self.device, 256 * std::mem::size_of::<BucketDesc>()));
}
}
fn ensure_buffers_with_values(&mut self, n: usize) {
self.ensure_buffers(n);
let val_bytes = n * 4;
if self.buf_vals_a.is_none() || val_bytes > self.vals_buf_capacity {
self.buf_vals_a = Some(alloc_buffer(&self.device, val_bytes));
self.buf_vals_b = Some(alloc_buffer(&self.device, val_bytes));
self.vals_buf_capacity = val_bytes;
}
}
fn ensure_buffers_with_values_and_orig(&mut self, n: usize) {
self.ensure_buffers_with_values(n);
let orig_bytes = n * 4;
if self.buf_orig_vals.is_none() || orig_bytes > self.orig_vals_capacity {
self.buf_orig_vals = Some(alloc_buffer(&self.device, orig_bytes));
self.orig_vals_capacity = orig_bytes;
}
}
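/// Lazily compiles the 64-bit kernel specializations (function constant 1
/// selects the 64-bit key path) the first time a 64-bit sort is requested.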
fn ensure_64bit_psos(&mut self) {
if self.psos_64bit_compiled {
return;
}
self.pso_cache.get_or_create_specialized(
&self.library,
"sort_msd_histogram",
&[(1, FnConstant::Bool(true))],
);
self.pso_cache.get_or_create_specialized(
&self.library,
"sort_msd_atomic_scatter",
&[(0, FnConstant::Bool(false)), (1, FnConstant::Bool(true))],
);
self.pso_cache.get_or_create_specialized(
&self.library,
"sort_msd_atomic_scatter",
&[(0, FnConstant::Bool(true)), (1, FnConstant::Bool(true))],
);
self.pso_cache.get_or_create_specialized(
&self.library,
"sort_inner_fused",
&[(0, FnConstant::Bool(false)), (1, FnConstant::Bool(true))],
);
self.pso_cache.get_or_create_specialized(
&self.library,
"sort_inner_fused",
&[(0, FnConstant::Bool(true)), (1, FnConstant::Bool(true))],
);
for mode in 0u32..=2 {
self.pso_cache.get_or_create_specialized(
&self.library,
"sort_transform_64",
&[(2, FnConstant::U32(mode))],
);
}
self.psos_64bit_compiled = true;
}
fn ensure_buffers_64(&mut self, n: usize) {
let data_bytes = n * 8;
if self.buf_a.is_none() || data_bytes > self.data_buf_capacity_64 {
self.buf_a = Some(alloc_buffer(&self.device, data_bytes));
self.buf_b = Some(alloc_buffer(&self.device, data_bytes));
self.data_buf_capacity_64 = data_bytes;
self.data_buf_capacity = data_bytes;
}
if self.buf_msd_hist.is_none() {
self.buf_msd_hist = Some(alloc_buffer(&self.device, 256 * 4));
self.buf_counters = Some(alloc_buffer(&self.device, 256 * 4));
self.buf_bucket_descs =
Some(alloc_buffer(&self.device, 256 * std::mem::size_of::<BucketDesc>()));
}
}
fn ensure_scratch_buffers_64(&mut self, n: usize) {
let data_bytes = n * 8;
if self.buf_b.is_none() || data_bytes > self.data_buf_capacity_64 {
self.buf_b = Some(alloc_buffer(&self.device, data_bytes));
if self.buf_a.is_some() {
self.buf_a = Some(alloc_buffer(&self.device, data_bytes));
}
self.data_buf_capacity_64 = data_bytes;
self.data_buf_capacity = data_bytes;
}
if self.buf_msd_hist.is_none() {
self.buf_msd_hist = Some(alloc_buffer(&self.device, 256 * 4));
self.buf_counters = Some(alloc_buffer(&self.device, 256 * 4));
self.buf_bucket_descs =
Some(alloc_buffer(&self.device, 256 * std::mem::size_of::<BucketDesc>()));
}
}
fn ensure_buffers_64_with_values(&mut self, n: usize) {
self.ensure_buffers_64(n);
let val_bytes = n * 4;
if self.buf_vals_a.is_none() || val_bytes > self.vals_buf_capacity {
self.buf_vals_a = Some(alloc_buffer(&self.device, val_bytes));
self.buf_vals_b = Some(alloc_buffer(&self.device, val_bytes));
self.vals_buf_capacity = val_bytes;
}
}
fn ensure_buffers_64_with_values_and_orig(&mut self, n: usize) {
self.ensure_buffers_64_with_values(n);
let orig_bytes = n * 4;
if self.buf_orig_vals.is_none() || orig_bytes > self.orig_vals_capacity {
self.buf_orig_vals = Some(alloc_buffer(&self.device, orig_bytes));
self.orig_vals_capacity = orig_bytes;
}
}
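/// Encodes the 64-bit sort: an MSD pass over the top byte (shift 56)
/// followed by seven fused inner passes over bytes 0..=6, issued as three
/// dispatches. Eight total ping-pong hops leave the sorted keys in `buf_a`.
/// When `with_values` is set, a parallel `u32` payload stream follows the
/// keys through every pass.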
#[allow(clippy::too_many_arguments)]
fn encode_sort_pipeline_64(
&mut self,
encoder: &ProtocolObject<dyn MTLComputeCommandEncoder>,
buf_a: &ProtocolObject<dyn MTLBuffer>,
buf_b: &ProtocolObject<dyn MTLBuffer>,
vals_a: Option<&ProtocolObject<dyn MTLBuffer>>,
vals_b: Option<&ProtocolObject<dyn MTLBuffer>>,
n: usize,
num_tiles: usize,
with_values: bool,
) {
let params = SortParams {
element_count: n as u32,
num_tiles: num_tiles as u32,
shift: 56,
pass: 0,
};
let tile_size_u32 = TILE_SIZE_64 as u32;
let tg_size = MTLSize {
width: THREADS_PER_TG,
height: 1,
depth: 1,
};
let hist_grid = MTLSize {
width: num_tiles,
height: 1,
depth: 1,
};
let one_tg_grid = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let fused_grid = MTLSize {
width: 256,
height: 1,
depth: 1,
};
let pso = self.pso_cache.get_or_create_specialized(
&self.library,
"sort_msd_histogram",
&[(1, FnConstant::Bool(true))],
);
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(buf_a), 0, 0);
encoder.setBuffer_offset_atIndex(Some(self.buf_msd_hist.as_ref().unwrap()), 0, 1);
encoder.setBytes_length_atIndex(
NonNull::new(¶ms as *const SortParams as *mut _).unwrap(),
std::mem::size_of::<SortParams>(),
2,
);
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(hist_grid, tg_size);
let pso = self.pso_cache.get_or_create(&self.library, "sort_msd_prep");
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(self.buf_msd_hist.as_ref().unwrap()), 0, 0);
encoder.setBuffer_offset_atIndex(Some(self.buf_counters.as_ref().unwrap()), 0, 1);
encoder.setBuffer_offset_atIndex(Some(self.buf_bucket_descs.as_ref().unwrap()), 0, 2);
encoder.setBytes_length_atIndex(
NonNull::new(&tile_size_u32 as *const u32 as *mut _).unwrap(),
4,
3,
);
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(one_tg_grid, tg_size);
let pso = self.pso_cache.get_or_create_specialized(
&self.library,
"sort_msd_atomic_scatter",
&[(0, FnConstant::Bool(with_values)), (1, FnConstant::Bool(true))],
);
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(buf_a), 0, 0);
encoder.setBuffer_offset_atIndex(Some(buf_b), 0, 1);
encoder.setBuffer_offset_atIndex(Some(self.buf_counters.as_ref().unwrap()), 0, 2);
encoder.setBytes_length_atIndex(
NonNull::new(¶ms as *const SortParams as *mut _).unwrap(),
std::mem::size_of::<SortParams>(),
3,
);
if with_values {
encoder.setBuffer_offset_atIndex(vals_a, 0, 4);
encoder.setBuffer_offset_atIndex(vals_b, 0, 5);
}
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(hist_grid, tg_size);
let pso_inner = self.pso_cache.get_or_create_specialized(
&self.library,
"sort_inner_fused",
&[(0, FnConstant::Bool(with_values)), (1, FnConstant::Bool(true))],
);
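// Seven inner byte passes issued as three dispatches (1 + 3 + 3); each
// config swaps the ping-pong bindings so pass parity carries across
// dispatches, ending with sorted keys in buf_a.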
struct InnerConfig<'a> {
start_shift: u32,
pass_count: u32,
buf_0: &'a ProtocolObject<dyn MTLBuffer>,
buf_1: &'a ProtocolObject<dyn MTLBuffer>,
val_4: Option<&'a ProtocolObject<dyn MTLBuffer>>,
val_5: Option<&'a ProtocolObject<dyn MTLBuffer>>,
}
let inner_configs = [
InnerConfig { start_shift: 0, pass_count: 1, buf_0: buf_a, buf_1: buf_b, val_4: vals_a, val_5: vals_b },
InnerConfig { start_shift: 1, pass_count: 3, buf_0: buf_b, buf_1: buf_a, val_4: vals_b, val_5: vals_a },
InnerConfig { start_shift: 4, pass_count: 3, buf_0: buf_a, buf_1: buf_b, val_4: vals_a, val_5: vals_b },
];
for cfg in &inner_configs {
let inner_params = InnerParams {
start_shift: cfg.start_shift,
pass_count: cfg.pass_count,
batch_start: 0,
};
encoder.setComputePipelineState(pso_inner);
unsafe {
encoder.setBuffer_offset_atIndex(Some(cfg.buf_0), 0, 0);
encoder.setBuffer_offset_atIndex(Some(cfg.buf_1), 0, 1);
encoder.setBuffer_offset_atIndex(
Some(self.buf_bucket_descs.as_ref().unwrap()),
0,
2,
);
encoder.setBytes_length_atIndex(
NonNull::new(&inner_params as *const InnerParams as *mut _).unwrap(),
std::mem::size_of::<InnerParams>(),
3,
);
if with_values {
encoder.setBuffer_offset_atIndex(cfg.val_4, 0, 4);
encoder.setBuffer_offset_atIndex(cfg.val_5, 0, 5);
}
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(fused_grid, tg_size);
}
}
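/// Encodes `gathered_vals[i] = original_vals[sorted_indices[i]]`, used to
/// reorder a value payload once the index permutation is known.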
fn encode_gather_values(
&mut self,
encoder: &ProtocolObject<dyn MTLComputeCommandEncoder>,
sorted_indices: &ProtocolObject<dyn MTLBuffer>,
original_vals: &ProtocolObject<dyn MTLBuffer>,
gathered_vals: &ProtocolObject<dyn MTLBuffer>,
n: usize,
) {
let pso = self.pso_cache.get_or_create(&self.library, "sort_gather_values");
encoder.setComputePipelineState(pso);
let count = n as u32;
unsafe {
encoder.setBuffer_offset_atIndex(Some(sorted_indices), 0, 0);
encoder.setBuffer_offset_atIndex(Some(original_vals), 0, 1);
encoder.setBuffer_offset_atIndex(Some(gathered_vals), 0, 2);
encoder.setBytes_length_atIndex(
NonNull::new(&count as *const u32 as *mut _).unwrap(),
4,
3,
);
}
let grid = MTLSize {
width: n.div_ceil(THREADS_PER_TG) * THREADS_PER_TG,
height: 1,
depth: 1,
};
let tg = MTLSize {
width: THREADS_PER_TG,
height: 1,
depth: 1,
};
encoder.dispatchThreads_threadsPerThreadgroup(grid, tg);
}
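/// Same four-stage pipeline as `encode_sort_pipeline`, with an optional
/// `u32` value stream (argsort indices or user payloads) carried through
/// the scatter and fused passes when `with_values` is set (function
/// constant 0).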
#[allow(clippy::too_many_arguments)]
fn encode_sort_pipeline_full(
&mut self,
encoder: &ProtocolObject<dyn MTLComputeCommandEncoder>,
buf_a: &ProtocolObject<dyn MTLBuffer>,
buf_b: &ProtocolObject<dyn MTLBuffer>,
vals_a: Option<&ProtocolObject<dyn MTLBuffer>>,
vals_b: Option<&ProtocolObject<dyn MTLBuffer>>,
n: usize,
num_tiles: usize,
with_values: bool,
) {
let params = SortParams {
element_count: n as u32,
num_tiles: num_tiles as u32,
shift: 24,
pass: 0,
};
let tile_size_u32 = TILE_SIZE as u32;
let tg_size = MTLSize {
width: THREADS_PER_TG,
height: 1,
depth: 1,
};
let hist_grid = MTLSize {
width: num_tiles,
height: 1,
depth: 1,
};
let one_tg_grid = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let fused_grid = MTLSize {
width: 256,
height: 1,
depth: 1,
};
let pso = self.pso_cache.get_or_create_specialized(&self.library, "sort_msd_histogram", &[(1, FnConstant::Bool(false))]);
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(buf_a), 0, 0);
encoder.setBuffer_offset_atIndex(Some(self.buf_msd_hist.as_ref().unwrap()), 0, 1);
encoder.setBytes_length_atIndex(
NonNull::new(¶ms as *const SortParams as *mut _).unwrap(),
std::mem::size_of::<SortParams>(),
2,
);
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(hist_grid, tg_size);
let pso = self.pso_cache.get_or_create(&self.library, "sort_msd_prep");
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(self.buf_msd_hist.as_ref().unwrap()), 0, 0);
encoder.setBuffer_offset_atIndex(Some(self.buf_counters.as_ref().unwrap()), 0, 1);
encoder.setBuffer_offset_atIndex(Some(self.buf_bucket_descs.as_ref().unwrap()), 0, 2);
encoder.setBytes_length_atIndex(
NonNull::new(&tile_size_u32 as *const u32 as *mut _).unwrap(),
4,
3,
);
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(one_tg_grid, tg_size);
let pso = self.pso_cache.get_or_create_specialized(
&self.library,
"sort_msd_atomic_scatter",
&[(0, FnConstant::Bool(with_values))],
);
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(buf_a), 0, 0);
encoder.setBuffer_offset_atIndex(Some(buf_b), 0, 1);
encoder.setBuffer_offset_atIndex(Some(self.buf_counters.as_ref().unwrap()), 0, 2);
encoder.setBytes_length_atIndex(
NonNull::new(¶ms as *const SortParams as *mut _).unwrap(),
std::mem::size_of::<SortParams>(),
3,
);
if with_values {
encoder.setBuffer_offset_atIndex(vals_a, 0, 4);
encoder.setBuffer_offset_atIndex(vals_b, 0, 5);
}
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(hist_grid, tg_size);
let inner_params = InnerParams {
start_shift: 0,
pass_count: 3,
batch_start: 0,
};
let pso = self.pso_cache.get_or_create_specialized(
&self.library,
"sort_inner_fused",
&[(0, FnConstant::Bool(with_values))],
);
encoder.setComputePipelineState(pso);
unsafe {
encoder.setBuffer_offset_atIndex(Some(buf_a), 0, 0);
encoder.setBuffer_offset_atIndex(Some(buf_b), 0, 1);
encoder.setBuffer_offset_atIndex(Some(self.buf_bucket_descs.as_ref().unwrap()), 0, 2);
encoder.setBytes_length_atIndex(
NonNull::new(&inner_params as *const InnerParams as *mut _).unwrap(),
std::mem::size_of::<InnerParams>(),
3,
);
if with_values {
encoder.setBuffer_offset_atIndex(vals_a, 0, 4);
encoder.setBuffer_offset_atIndex(vals_b, 0, 5);
}
}
encoder.dispatchThreadgroups_threadsPerThreadgroup(fused_grid, tg_size);
}
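/// Fills `indices_buf` with the identity permutation `0..n`, the starting
/// payload for argsort.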
fn encode_init_indices(
&mut self,
encoder: &ProtocolObject<dyn MTLComputeCommandEncoder>,
indices_buf: &ProtocolObject<dyn MTLBuffer>,
n: usize,
) {
let pso = self.pso_cache.get_or_create(&self.library, "sort_init_indices");
encoder.setComputePipelineState(pso);
let count = n as u32;
unsafe {
encoder.setBuffer_offset_atIndex(Some(indices_buf), 0, 0);
encoder.setBytes_length_atIndex(
NonNull::new(&count as *const u32 as *mut _).unwrap(),
4,
1,
);
}
let grid = MTLSize {
width: n.div_ceil(THREADS_PER_TG) * THREADS_PER_TG,
height: 1,
depth: 1,
};
let tg = MTLSize {
width: THREADS_PER_TG,
height: 1,
depth: 1,
};
encoder.dispatchThreads_threadsPerThreadgroup(grid, tg);
}
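/// Returns the permutation that sorts `data` without reordering it: the keys
/// are sorted with an index payload and only the indices are read back.
/// Sketch:
///
/// ```ignore
/// let data = vec![30u32, 10, 20];
/// let idx = sorter.argsort_u32(&data)?;
/// assert_eq!(idx, vec![1, 2, 0]);
/// ```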
pub fn argsort_u32(&mut self, data: &[u32]) -> Result<Vec<u32>, SortError> {
let n = data.len();
if n == 0 {
return Ok(vec![]);
}
if n == 1 {
return Ok(vec![0]);
}
self.ensure_buffers_with_values(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
self.encode_sort_pipeline_full(
&enc,
&buf_a,
&buf_b,
Some(&vals_a),
Some(&vals_b),
n,
num_tiles,
true,
);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
let mut result = vec![0u32; n];
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_vals_a.as_ref().unwrap().contents().as_ptr() as *const u32,
result.as_mut_ptr(),
n,
);
}
Ok(result)
}
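/// Signed variant of `argsort_u32`. The key transform is applied forward
/// only, since the transformed keys are discarded.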
pub fn argsort_i32(&mut self, data: &[i32]) -> Result<Vec<u32>, SortError> {
let n = data.len();
if n == 0 {
return Ok(vec![]);
}
if n == 1 {
return Ok(vec![0]);
}
self.ensure_buffers_with_values(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
encode_transform_32(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 0);
self.encode_sort_pipeline_full(
&enc,
&buf_a,
&buf_b,
Some(&vals_a),
Some(&vals_b),
n,
num_tiles,
true,
);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
let mut result = vec![0u32; n];
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_vals_a.as_ref().unwrap().contents().as_ptr() as *const u32,
result.as_mut_ptr(),
n,
);
}
Ok(result)
}
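/// Float variant of `argsort_u32`. The key transform is applied forward
/// only, since the transformed keys are discarded.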
pub fn argsort_f32(&mut self, data: &[f32]) -> Result<Vec<u32>, SortError> {
let n = data.len();
if n == 0 {
return Ok(vec![]);
}
if n == 1 {
return Ok(vec![0]);
}
self.ensure_buffers_with_values(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
encode_transform_32(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 1);
self.encode_sort_pipeline_full(
&enc,
&buf_a,
&buf_b,
Some(&vals_a),
Some(&vals_b),
n,
num_tiles,
true,
);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
let mut result = vec![0u32; n];
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_vals_a.as_ref().unwrap().contents().as_ptr() as *const u32,
result.as_mut_ptr(),
n,
);
}
Ok(result)
}
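/// Sorts `keys` and applies the same permutation to `values`. Internally the
/// keys are sorted with an index payload, then `values` is gathered through
/// the resulting permutation in one extra pass. Sketch:
///
/// ```ignore
/// let mut keys = vec![5u32, 1, 9];
/// let mut vals = vec![50u32, 10, 90];
/// sorter.sort_pairs_u32(&mut keys, &mut vals)?;
/// assert_eq!(keys, vec![1, 5, 9]);
/// assert_eq!(vals, vec![10, 50, 90]);
/// ```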
pub fn sort_pairs_u32(
&mut self,
keys: &mut [u32],
values: &mut [u32],
) -> Result<(), SortError> {
let n = keys.len();
if n != values.len() {
return Err(SortError::LengthMismatch {
keys: n,
values: values.len(),
});
}
if n <= 1 {
return Ok(());
}
self.ensure_buffers_with_values_and_orig(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::copy_nonoverlapping(
keys.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::copy_nonoverlapping(
values.as_ptr() as *const u8,
self.buf_orig_vals.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
let orig_vals = self.buf_orig_vals.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
self.encode_sort_pipeline_full(
&enc,
&buf_a,
&buf_b,
Some(&vals_a),
Some(&vals_b),
n,
num_tiles,
true,
);
self.encode_gather_values(&enc, &vals_a, &orig_vals, &vals_b, n);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const u32,
keys.as_mut_ptr(),
n,
);
std::ptr::copy_nonoverlapping(
self.buf_vals_b.as_ref().unwrap().contents().as_ptr() as *const u32,
values.as_mut_ptr(),
n,
);
}
Ok(())
}
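/// `i32`-keyed variant of `sort_pairs_u32`, with the self-inverse integer
/// transform applied around the sort.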
pub fn sort_pairs_i32(
&mut self,
keys: &mut [i32],
values: &mut [u32],
) -> Result<(), SortError> {
let n = keys.len();
if n != values.len() {
return Err(SortError::LengthMismatch {
keys: n,
values: values.len(),
});
}
if n <= 1 {
return Ok(());
}
self.ensure_buffers_with_values_and_orig(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::copy_nonoverlapping(
keys.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::copy_nonoverlapping(
values.as_ptr() as *const u8,
self.buf_orig_vals.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
let orig_vals = self.buf_orig_vals.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
encode_transform_32(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 0);
self.encode_sort_pipeline_full(
&enc,
&buf_a,
&buf_b,
Some(&vals_a),
Some(&vals_b),
n,
num_tiles,
true,
);
encode_transform_32(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 0);
self.encode_gather_values(&enc, &vals_a, &orig_vals, &vals_b, n);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const i32,
keys.as_mut_ptr(),
n,
);
std::ptr::copy_nonoverlapping(
self.buf_vals_b.as_ref().unwrap().contents().as_ptr() as *const u32,
values.as_mut_ptr(),
n,
);
}
Ok(())
}
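/// `f32`-keyed variant of `sort_pairs_u32`, with the forward/inverse float
/// transforms applied around the sort.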
pub fn sort_pairs_f32(
&mut self,
keys: &mut [f32],
values: &mut [u32],
) -> Result<(), SortError> {
let n = keys.len();
if n != values.len() {
return Err(SortError::LengthMismatch {
keys: n,
values: values.len(),
});
}
if n <= 1 {
return Ok(());
}
self.ensure_buffers_with_values_and_orig(n);
let num_tiles = n.div_ceil(TILE_SIZE);
unsafe {
std::ptr::copy_nonoverlapping(
keys.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::copy_nonoverlapping(
values.as_ptr() as *const u8,
self.buf_orig_vals.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
let orig_vals = self.buf_orig_vals.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
encode_transform_32(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 1);
self.encode_sort_pipeline_full(
&enc,
&buf_a,
&buf_b,
Some(&vals_a),
Some(&vals_b),
n,
num_tiles,
true,
);
encode_transform_32(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 2);
self.encode_gather_values(&enc, &vals_a, &orig_vals, &vals_b, n);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const f32,
keys.as_mut_ptr(),
n,
);
std::ptr::copy_nonoverlapping(
self.buf_vals_b.as_ref().unwrap().contents().as_ptr() as *const u32,
values.as_mut_ptr(),
n,
);
}
Ok(())
}
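/// Sorts a `u64` slice using the 64-bit pipeline: half-size tiles, eight
/// radix passes, and lazily compiled 64-bit PSOs.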
pub fn sort_u64(&mut self, data: &mut [u64]) -> Result<(), SortError> {
let n = data.len();
if n <= 1 {
return Ok(());
}
self.ensure_64bit_psos();
self.ensure_buffers_64(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 8,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
self.encode_sort_pipeline_64(&enc, &buf_a, &buf_b, None, None, n, num_tiles, false);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const u64,
data.as_mut_ptr(),
n,
);
}
Ok(())
}
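/// In-place GPU variant of `sort_u64` for keys already resident in a
/// `SortBuffer`.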
pub fn sort_u64_buffer(&mut self, buf: &SortBuffer<u64>) -> Result<(), SortError> {
let n = buf.len();
if n <= 1 {
return Ok(());
}
self.ensure_64bit_psos();
self.ensure_scratch_buffers_64(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
unsafe {
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let buf_b = self.buf_b.as_ref().unwrap().clone();
self.encode_sort_pipeline_64(&enc, &buf.buffer, &buf_b, None, None, n, num_tiles, false);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
Ok(())
}
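/// 64-bit counterpart of `sort_i32`.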
pub fn sort_i64(&mut self, data: &mut [i64]) -> Result<(), SortError> {
let n = data.len();
if n <= 1 {
return Ok(());
}
self.ensure_64bit_psos();
self.ensure_buffers_64(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 8,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 0);
self.encode_sort_pipeline_64(&enc, &buf_a, &buf_b, None, None, n, num_tiles, false);
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 0);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const i64,
data.as_mut_ptr(),
n,
);
}
Ok(())
}
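/// 64-bit counterpart of `sort_f32`.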
pub fn sort_f64(&mut self, data: &mut [f64]) -> Result<(), SortError> {
let n = data.len();
if n <= 1 {
return Ok(());
}
self.ensure_64bit_psos();
self.ensure_buffers_64(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 8,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 1);
self.encode_sort_pipeline_64(&enc, &buf_a, &buf_b, None, None, n, num_tiles, false);
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 2);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const f64,
data.as_mut_ptr(),
n,
);
}
Ok(())
}
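/// In-place GPU variant of `sort_i64` for keys already resident in a
/// `SortBuffer`.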
pub fn sort_i64_buffer(&mut self, buf: &SortBuffer<i64>) -> Result<(), SortError> {
let n = buf.len();
if n <= 1 {
return Ok(());
}
self.ensure_64bit_psos();
self.ensure_scratch_buffers_64(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
unsafe {
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let buf_b = self.buf_b.as_ref().unwrap().clone();
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf.buffer, n, 0);
self.encode_sort_pipeline_64(&enc, &buf.buffer, &buf_b, None, None, n, num_tiles, false);
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf.buffer, n, 0);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
Ok(())
}
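/// In-place GPU variant of `sort_f64` for keys already resident in a
/// `SortBuffer`.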
pub fn sort_f64_buffer(&mut self, buf: &SortBuffer<f64>) -> Result<(), SortError> {
let n = buf.len();
if n <= 1 {
return Ok(());
}
self.ensure_64bit_psos();
self.ensure_scratch_buffers_64(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
unsafe {
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let buf_b = self.buf_b.as_ref().unwrap().clone();
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf.buffer, n, 1);
self.encode_sort_pipeline_64(&enc, &buf.buffer, &buf_b, None, None, n, num_tiles, false);
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf.buffer, n, 2);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
Ok(())
}
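/// 64-bit counterpart of `argsort_u32`; the index payload stays `u32`.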
pub fn argsort_u64(&mut self, data: &[u64]) -> Result<Vec<u32>, SortError> {
let n = data.len();
if n == 0 {
return Ok(vec![]);
}
if n == 1 {
return Ok(vec![0]);
}
self.ensure_64bit_psos();
self.ensure_buffers_64_with_values(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 8,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
self.encode_sort_pipeline_64(
&enc, &buf_a, &buf_b, Some(&vals_a), Some(&vals_b), n, num_tiles, true,
);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
let mut result = vec![0u32; n];
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_vals_a.as_ref().unwrap().contents().as_ptr() as *const u32,
result.as_mut_ptr(),
n,
);
}
Ok(result)
}
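/// 64-bit counterpart of `argsort_i32`.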
pub fn argsort_i64(&mut self, data: &[i64]) -> Result<Vec<u32>, SortError> {
let n = data.len();
if n == 0 {
return Ok(vec![]);
}
if n == 1 {
return Ok(vec![0]);
}
self.ensure_64bit_psos();
self.ensure_buffers_64_with_values(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 8,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 0);
self.encode_sort_pipeline_64(
&enc, &buf_a, &buf_b, Some(&vals_a), Some(&vals_b), n, num_tiles, true,
);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
let mut result = vec![0u32; n];
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_vals_a.as_ref().unwrap().contents().as_ptr() as *const u32,
result.as_mut_ptr(),
n,
);
}
Ok(result)
}
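/// 64-bit counterpart of `argsort_f32`.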
pub fn argsort_f64(&mut self, data: &[f64]) -> Result<Vec<u32>, SortError> {
let n = data.len();
if n == 0 {
return Ok(vec![]);
}
if n == 1 {
return Ok(vec![0]);
}
self.ensure_64bit_psos();
self.ensure_buffers_64_with_values(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
unsafe {
std::ptr::copy_nonoverlapping(
data.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 8,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 1);
self.encode_sort_pipeline_64(
&enc, &buf_a, &buf_b, Some(&vals_a), Some(&vals_b), n, num_tiles, true,
);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
let mut result = vec![0u32; n];
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_vals_a.as_ref().unwrap().contents().as_ptr() as *const u32,
result.as_mut_ptr(),
n,
);
}
Ok(result)
}
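/// Sorts `keys` ascending in place and applies the same permutation to
/// `values`. Returns `SortError::LengthMismatch` when the slices differ in
/// length; slices of length 0 or 1 are returned unchanged without touching
/// the GPU.
///
/// Illustrative usage (marked `ignore`: it needs a Metal-capable device):
///
/// ```ignore
/// let mut keys = vec![30u64, 10, 20];
/// let mut vals = vec![300u32, 100, 200];
/// sorter.sort_pairs_u64(&mut keys, &mut vals).unwrap();
/// assert_eq!(keys, vec![10, 20, 30]);
/// assert_eq!(vals, vec![100, 200, 300]);
/// ```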
pub fn sort_pairs_u64(
&mut self,
keys: &mut [u64],
values: &mut [u32],
) -> Result<(), SortError> {
let n = keys.len();
if n != values.len() {
return Err(SortError::LengthMismatch {
keys: n,
values: values.len(),
});
}
if n <= 1 {
return Ok(());
}
self.ensure_64bit_psos();
self.ensure_buffers_64_with_values_and_orig(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
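// Stage keys in buf_a, the caller's values in buf_orig_vals, and zero the MSD histogram.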
unsafe {
std::ptr::copy_nonoverlapping(
keys.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 8,
);
std::ptr::copy_nonoverlapping(
values.as_ptr() as *const u8,
self.buf_orig_vals.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
let orig_vals = self.buf_orig_vals.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
self.encode_sort_pipeline_64(
&enc, &buf_a, &buf_b, Some(&vals_a), Some(&vals_b), n, num_tiles, true,
);
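// Gather the original values through the sorted index permutation (vals_a) into vals_b.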
self.encode_gather_values(&enc, &vals_a, &orig_vals, &vals_b, n);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
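// The sorted keys end up in buf_a and the gathered values in vals_b; copy both back.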
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const u64,
keys.as_mut_ptr(),
n,
);
std::ptr::copy_nonoverlapping(
self.buf_vals_b.as_ref().unwrap().contents().as_ptr() as *const u32,
values.as_mut_ptr(),
n,
);
}
Ok(())
}
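/// Sorts `keys` ascending in place and applies the same permutation to
/// `values`. The signed-integer key transform (mode 0) is self-inverse, so it
/// is encoded once before the radix passes and once after to restore the keys.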
pub fn sort_pairs_i64(
&mut self,
keys: &mut [i64],
values: &mut [u32],
) -> Result<(), SortError> {
let n = keys.len();
if n != values.len() {
return Err(SortError::LengthMismatch {
keys: n,
values: values.len(),
});
}
if n <= 1 {
return Ok(());
}
self.ensure_64bit_psos();
self.ensure_buffers_64_with_values_and_orig(n);
let num_tiles = n.div_ceil(TILE_SIZE_64);
unsafe {
std::ptr::copy_nonoverlapping(
keys.as_ptr() as *const u8,
self.buf_a.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 8,
);
std::ptr::copy_nonoverlapping(
values.as_ptr() as *const u8,
self.buf_orig_vals.as_ref().unwrap().contents().as_ptr() as *mut u8,
n * 4,
);
std::ptr::write_bytes(
self.buf_msd_hist.as_ref().unwrap().contents().as_ptr() as *mut u8,
0,
256 * 4,
);
}
let cmd = self.queue.commandBuffer().ok_or_else(|| {
SortError::GpuExecution("failed to create command buffer".to_string())
})?;
let enc = cmd.computeCommandEncoder().ok_or_else(|| {
SortError::GpuExecution("failed to create compute encoder".to_string())
})?;
let vals_a = self.buf_vals_a.as_ref().unwrap().clone();
let vals_b = self.buf_vals_b.as_ref().unwrap().clone();
let buf_a = self.buf_a.as_ref().unwrap().clone();
let buf_b = self.buf_b.as_ref().unwrap().clone();
let orig_vals = self.buf_orig_vals.as_ref().unwrap().clone();
self.encode_init_indices(&enc, &vals_a, n);
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 0);
self.encode_sort_pipeline_64(
&enc, &buf_a, &buf_b, Some(&vals_a), Some(&vals_b), n, num_tiles, true,
);
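// Mode 0 is its own inverse: encode it again to undo the key transform before readback.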
encode_transform_64(&enc, &self.library, &mut self.pso_cache, &buf_a, n, 0);
self.encode_gather_values(&enc, &vals_a, &orig_vals, &vals_b, n);
enc.endEncoding();
cmd.commit();
cmd.waitUntilCompleted();
if cmd.status() == MTLCommandBufferStatus::Error {
return Err(SortError::GpuExecution(format!(
"command buffer error: {:?}",
cmd.error()
)));
}
unsafe {
std::ptr::copy_nonoverlapping(
self.buf_a.as_ref().unwrap().contents().as_ptr() as *const i64,
keys.as_mut_ptr(),
n,
);
std::ptr::copy_nonoverlapping(
self.buf_vals_b.as_ref().unwrap().contents().as_ptr() as *const u32,
values.as_mut_ptr(),
n,
);
}
Ok(())
}
}
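// The tests that construct a `GpuSorter` exercise the real Metal pipeline and
// therefore require a Metal-capable device; the layout and error-display tests
// run anywhere.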
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sort_params_size() {
assert_eq!(std::mem::size_of::<SortParams>(), 16);
}
#[test]
fn test_bucket_desc_size() {
assert_eq!(std::mem::size_of::<BucketDesc>(), 16);
}
#[test]
fn test_sort_error_display() {
let e = SortError::DeviceNotFound;
assert_eq!(e.to_string(), "no Metal GPU device found");
let e = SortError::ShaderCompilation("test".to_string());
assert_eq!(e.to_string(), "shader compilation failed: test");
let e = SortError::GpuExecution("timeout".to_string());
assert_eq!(e.to_string(), "GPU execution failed: timeout");
let e = SortError::LengthMismatch {
keys: 10,
values: 5,
};
assert_eq!(e.to_string(), "length mismatch: keys=10, values=5");
}
#[test]
fn test_sort_key_u32_consts() {
assert_eq!(u32::KEY_SIZE, 4);
assert!(!u32::NEEDS_TRANSFORM);
assert!(!u32::IS_64BIT);
assert_eq!(u32::TRANSFORM_MODE_FORWARD, 0);
assert_eq!(u32::TRANSFORM_MODE_INVERSE, 0);
}
#[test]
fn test_sort_key_i32_consts() {
assert_eq!(i32::KEY_SIZE, 4);
assert!(i32::NEEDS_TRANSFORM);
assert!(!i32::IS_64BIT);
assert_eq!(i32::TRANSFORM_MODE_FORWARD, 0);
assert_eq!(i32::TRANSFORM_MODE_INVERSE, 0);
}
#[test]
fn test_sort_key_f32_consts() {
assert_eq!(f32::KEY_SIZE, 4);
assert!(f32::NEEDS_TRANSFORM);
assert!(!f32::IS_64BIT);
assert_eq!(f32::TRANSFORM_MODE_FORWARD, 1);
assert_eq!(f32::TRANSFORM_MODE_INVERSE, 2);
}
#[test]
fn test_sort_key_u64_consts() {
assert_eq!(u64::KEY_SIZE, 8);
assert!(!u64::NEEDS_TRANSFORM);
assert!(u64::IS_64BIT);
assert_eq!(u64::TRANSFORM_MODE_FORWARD, 0);
assert_eq!(u64::TRANSFORM_MODE_INVERSE, 0);
}
#[test]
fn test_sort_key_i64_consts() {
assert_eq!(i64::KEY_SIZE, 8);
assert!(i64::NEEDS_TRANSFORM);
assert!(i64::IS_64BIT);
assert_eq!(i64::TRANSFORM_MODE_FORWARD, 0);
assert_eq!(i64::TRANSFORM_MODE_INVERSE, 0);
}
#[test]
fn test_sort_key_f64_consts() {
assert_eq!(f64::KEY_SIZE, 8);
assert!(f64::NEEDS_TRANSFORM);
assert!(f64::IS_64BIT);
assert_eq!(f64::TRANSFORM_MODE_FORWARD, 1);
assert_eq!(f64::TRANSFORM_MODE_INVERSE, 2);
}
#[test]
fn test_sort_i32_basic() {
let mut sorter = GpuSorter::new().unwrap();
let mut data = vec![5i32, -3, 0, i32::MIN, i32::MAX, -1, 1];
let mut expected = data.clone();
sorter.sort_i32(&mut data).unwrap();
expected.sort();
assert_eq!(data, expected);
}
#[test]
fn test_sort_f32_basic() {
let mut sorter = GpuSorter::new().unwrap();
let mut data = vec![
1.0f32,
f32::NAN,
-0.0,
0.0,
f32::NEG_INFINITY,
-1.0,
f32::INFINITY,
-f32::NAN,
];
let mut expected = data.clone();
sorter.sort_f32(&mut data).unwrap();
expected.sort_by(f32::total_cmp);
assert_eq!(
data.iter().map(|x| x.to_bits()).collect::<Vec<_>>(),
expected.iter().map(|x| x.to_bits()).collect::<Vec<_>>()
);
}
#[test]
fn test_argsort_u32_basic() {
let mut sorter = GpuSorter::new().unwrap();
let data = vec![30u32, 10, 20];
let indices = sorter.argsort_u32(&data).unwrap();
assert_eq!(indices, vec![1, 2, 0]);
assert_eq!(data, vec![30, 10, 20]);
}
#[test]
fn test_argsort_u32_empty() {
let mut sorter = GpuSorter::new().unwrap();
let data: Vec<u32> = vec![];
let indices = sorter.argsort_u32(&data).unwrap();
assert_eq!(indices, vec![]);
}
#[test]
fn test_argsort_u32_single() {
let mut sorter = GpuSorter::new().unwrap();
let data = vec![42u32];
let indices = sorter.argsort_u32(&data).unwrap();
assert_eq!(indices, vec![0]);
}
#[test]
fn test_argsort_i32_basic() {
let mut sorter = GpuSorter::new().unwrap();
let data = vec![5i32, -3, 0, -1, 1];
let indices = sorter.argsort_i32(&data).unwrap();
assert_eq!(indices, vec![1, 3, 2, 4, 0]);
}
#[test]
fn test_argsort_f32_basic() {
let mut sorter = GpuSorter::new().unwrap();
let data = vec![3.0f32, 1.0, 2.0];
let indices = sorter.argsort_f32(&data).unwrap();
assert_eq!(indices, vec![1, 2, 0]);
assert_eq!(data, vec![3.0, 1.0, 2.0]);
}
#[test]
fn test_argsort_i32_empty() {
let mut sorter = GpuSorter::new().unwrap();
let data: Vec<i32> = vec![];
let indices = sorter.argsort_i32(&data).unwrap();
assert_eq!(indices, vec![]);
}
#[test]
fn test_argsort_f32_empty() {
let mut sorter = GpuSorter::new().unwrap();
let data: Vec<f32> = vec![];
let indices = sorter.argsort_f32(&data).unwrap();
assert_eq!(indices, vec![]);
}
#[test]
fn test_sort_pairs_u32_basic() {
let mut sorter = GpuSorter::new().unwrap();
let mut keys = vec![3u32, 1, 2];
let mut vals = vec![30u32, 10, 20];
sorter.sort_pairs_u32(&mut keys, &mut vals).unwrap();
assert_eq!(keys, vec![1, 2, 3]);
assert_eq!(vals, vec![10, 20, 30]);
}
#[test]
fn test_sort_pairs_length_mismatch() {
let mut sorter = GpuSorter::new().unwrap();
let mut keys = vec![1u32, 2, 3];
let mut vals = vec![10u32, 20];
assert!(sorter.sort_pairs_u32(&mut keys, &mut vals).is_err());
}
#[test]
fn test_sort_pairs_i32_basic() {
let mut sorter = GpuSorter::new().unwrap();
let mut keys = vec![5i32, -3, 0, -1, 1];
// Distinct values so the assertion pins down the full key->value mapping.
let mut vals = vec![50u32, 30, 0, 10, 20];
sorter.sort_pairs_i32(&mut keys, &mut vals).unwrap();
assert_eq!(keys, vec![-3, -1, 0, 1, 5]);
assert_eq!(vals, vec![30, 10, 0, 20, 50]);
}
#[test]
fn test_sort_pairs_f32_basic() {
let mut sorter = GpuSorter::new().unwrap();
let mut keys = vec![3.0f32, 1.0, 2.0];
let mut vals = vec![30u32, 10, 20];
sorter.sort_pairs_f32(&mut keys, &mut vals).unwrap();
assert_eq!(
keys.iter().map(|x| x.to_bits()).collect::<Vec<_>>(),
vec![1.0f32, 2.0, 3.0]
.iter()
.map(|x| x.to_bits())
.collect::<Vec<_>>()
);
assert_eq!(vals, vec![10, 20, 30]);
}
#[test]
fn test_sort_pairs_u32_empty() {
let mut sorter = GpuSorter::new().unwrap();
let mut keys: Vec<u32> = vec![];
let mut vals: Vec<u32> = vec![];
sorter.sort_pairs_u32(&mut keys, &mut vals).unwrap();
assert_eq!(keys, vec![]);
assert_eq!(vals, vec![]);
}
#[test]
fn test_sort_pairs_u32_single() {
let mut sorter = GpuSorter::new().unwrap();
let mut keys = vec![42u32];
let mut vals = vec![100u32];
sorter.sort_pairs_u32(&mut keys, &mut vals).unwrap();
assert_eq!(keys, vec![42]);
assert_eq!(vals, vec![100]);
}
#[test]
fn test_sort_u64_basic() {
let mut sorter = GpuSorter::new().unwrap();
let mut data = vec![u64::MAX, 0u64, 1, u64::MAX - 1, 42];
sorter.sort_u64(&mut data).unwrap();
assert_eq!(data, vec![0, 1, 42, u64::MAX - 1, u64::MAX]);
}
#[test]
fn test_sort_u64_empty() {
let mut sorter = GpuSorter::new().unwrap();
let mut data: Vec<u64> = vec![];
sorter.sort_u64(&mut data).unwrap();
assert_eq!(data, vec![]);
}
#[test]
fn test_sort_u64_single() {
let mut sorter = GpuSorter::new().unwrap();
let mut data = vec![42u64];
sorter.sort_u64(&mut data).unwrap();
assert_eq!(data, vec![42]);
}
#[test]
fn test_sort_i64_basic() {
let mut sorter = GpuSorter::new().unwrap();
let mut data = vec![5i64, -3, 0, i64::MIN, i64::MAX, -1, 1];
let mut expected = data.clone();
sorter.sort_i64(&mut data).unwrap();
expected.sort();
assert_eq!(data, expected);
}
#[test]
fn test_sort_i64_empty() {
let mut sorter = GpuSorter::new().unwrap();
let mut data: Vec<i64> = vec![];
sorter.sort_i64(&mut data).unwrap();
assert_eq!(data, vec![]);
}
#[test]
fn test_sort_i64_single() {
let mut sorter = GpuSorter::new().unwrap();
let mut data = vec![42i64];
sorter.sort_i64(&mut data).unwrap();
assert_eq!(data, vec![42]);
}
#[test]
fn test_sort_f64_basic() {
let mut sorter = GpuSorter::new().unwrap();
let mut data = vec![1.0f64, f64::NAN, -0.0, 0.0, f64::NEG_INFINITY, -1.0, f64::INFINITY];
let mut expected = data.clone();
sorter.sort_f64(&mut data).unwrap();
expected.sort_by(f64::total_cmp);
assert_eq!(
data.iter().map(|x| x.to_bits()).collect::<Vec<_>>(),
expected.iter().map(|x| x.to_bits()).collect::<Vec<_>>()
);
}
#[test]
fn test_sort_f64_empty() {
let mut sorter = GpuSorter::new().unwrap();
let mut data: Vec<f64> = vec![];
sorter.sort_f64(&mut data).unwrap();
assert_eq!(data, vec![]);
}
#[test]
fn test_sort_f64_single() {
let mut sorter = GpuSorter::new().unwrap();
let mut data = vec![42.0f64];
sorter.sort_f64(&mut data).unwrap();
assert_eq!(data, vec![42.0]);
}
#[test]
fn test_argsort_u64_basic() {
let mut sorter = GpuSorter::new().unwrap();
let data = vec![30u64, 10, 20];
let indices = sorter.argsort_u64(&data).unwrap();
assert_eq!(indices, vec![1, 2, 0]);
assert_eq!(data, vec![30, 10, 20]);
}
#[test]
fn test_argsort_u64_empty() {
let mut sorter = GpuSorter::new().unwrap();
let data: Vec<u64> = vec![];
let indices = sorter.argsort_u64(&data).unwrap();
assert_eq!(indices, vec![]);
}
#[test]
fn test_argsort_u64_single() {
let mut sorter = GpuSorter::new().unwrap();
let data = vec![42u64];
let indices = sorter.argsort_u64(&data).unwrap();
assert_eq!(indices, vec![0]);
}
#[test]
fn test_argsort_i64_basic() {
let mut sorter = GpuSorter::new().unwrap();
let data = vec![5i64, -3, 0, -1, 1];
let indices = sorter.argsort_i64(&data).unwrap();
assert_eq!(indices, vec![1, 3, 2, 4, 0]);
}
#[test]
fn test_argsort_f64_basic() {
let mut sorter = GpuSorter::new().unwrap();
let data = vec![3.0f64, 1.0, 2.0];
let indices = sorter.argsort_f64(&data).unwrap();
assert_eq!(indices, vec![1, 2, 0]);
assert_eq!(data, vec![3.0, 1.0, 2.0]);
}
#[test]
fn test_sort_pairs_u64_basic() {
let mut sorter = GpuSorter::new().unwrap();
let mut keys = vec![30u64, 10, 20];
let mut vals = vec![300u32, 100, 200];
sorter.sort_pairs_u64(&mut keys, &mut vals).unwrap();
assert_eq!(keys, vec![10, 20, 30]);
assert_eq!(vals, vec![100, 200, 300]);
}
#[test]
fn test_sort_pairs_i64_basic() {
let mut sorter = GpuSorter::new().unwrap();
let mut keys = vec![5i64, -3, 0];
let mut vals = vec![50u32, 30, 40];
sorter.sort_pairs_i64(&mut keys, &mut vals).unwrap();
assert_eq!(keys, vec![-3, 0, 5]);
assert_eq!(vals, vec![30, 40, 50]);
}
#[test]
fn test_sort_pairs_u64_empty() {
let mut sorter = GpuSorter::new().unwrap();
let mut keys: Vec<u64> = vec![];
let mut vals: Vec<u32> = vec![];
sorter.sort_pairs_u64(&mut keys, &mut vals).unwrap();
assert_eq!(keys, vec![]);
assert_eq!(vals, vec![]);
}
#[test]
fn test_sort_pairs_i64_length_mismatch() {
let mut sorter = GpuSorter::new().unwrap();
let mut keys = vec![1i64, 2, 3];
let mut vals = vec![10u32, 20];
assert!(sorter.sort_pairs_i64(&mut keys, &mut vals).is_err());
}
}