use core::{ffi::c_void, ops::Range, ptr::NonNull};
use objc2::{Message, extern_protocol, msg_send, runtime::ProtocolObject};
use objc2_foundation::NSRange;
use crate::util::{opt_ref_slice_as_ptr, ref_slice_as_ptr};
use crate::{
MTLAccelerationStructure, MTLBarrierScope, MTLBuffer, MTLCommandEncoder, MTLCounterSampleBuffer, MTLFence, MTLHeap,
MTLIndirectCommandBuffer, MTLResource, MTLResourceUsage, MTLSamplerState, MTLTexture,
compute_pipeline::MTLComputePipelineState,
intersection_function_table::MTLIntersectionFunctionTable,
types::{MTLRegion, MTLSize},
visible_function_table::MTLVisibleFunctionTable,
};
extern_protocol!(
pub unsafe trait MTLComputeCommandEncoder: MTLCommandEncoder {
#[unsafe(method(dispatchType))]
#[unsafe(method_family = none)]
fn dispatch_type(&self) -> super::MTLDispatchType;
#[unsafe(method(setComputePipelineState:))]
#[unsafe(method_family = none)]
fn set_compute_pipeline_state(
&self,
state: &ProtocolObject<dyn MTLComputePipelineState>,
);
#[unsafe(method(setBytes:length:atIndex:))]
#[unsafe(method_family = none)]
fn set_bytes(
&self,
bytes: NonNull<c_void>,
length: usize,
index: usize,
);
#[unsafe(method(setBuffer:offset:atIndex:))]
#[unsafe(method_family = none)]
fn set_buffer(
&self,
buffer: Option<&ProtocolObject<dyn MTLBuffer>>,
offset: usize,
index: usize,
);
#[unsafe(method(setBufferOffset:atIndex:))]
#[unsafe(method_family = none)]
fn set_buffer_offset(
&self,
offset: usize,
index: usize,
);
#[unsafe(method(setBuffer:offset:attributeStride:atIndex:))]
#[unsafe(method_family = none)]
fn set_buffer_with_attribute_stride(
&self,
buffer: &ProtocolObject<dyn MTLBuffer>,
offset: usize,
stride: usize,
index: usize,
);
#[unsafe(method(setBufferOffset:attributeStride:atIndex:))]
#[unsafe(method_family = none)]
fn set_buffer_offset_with_attribute_stride(
&self,
offset: usize,
stride: usize,
index: usize,
);
#[unsafe(method(setBytes:length:attributeStride:atIndex:))]
#[unsafe(method_family = none)]
fn set_bytes_with_attribute_stride(
&self,
bytes: NonNull<c_void>,
length: usize,
stride: usize,
index: usize,
);
#[unsafe(method(setVisibleFunctionTable:atBufferIndex:))]
#[unsafe(method_family = none)]
fn set_visible_function_table(
&self,
table: Option<&ProtocolObject<dyn MTLVisibleFunctionTable>>,
buffer_index: usize,
);
#[unsafe(method(setIntersectionFunctionTable:atBufferIndex:))]
#[unsafe(method_family = none)]
fn set_intersection_function_table(
&self,
table: Option<&ProtocolObject<dyn MTLIntersectionFunctionTable>>,
buffer_index: usize,
);
#[unsafe(method(setAccelerationStructure:atBufferIndex:))]
#[unsafe(method_family = none)]
fn set_acceleration_structure(
&self,
structure: Option<&ProtocolObject<dyn MTLAccelerationStructure>>,
buffer_index: usize,
);
#[unsafe(method(setTexture:atIndex:))]
#[unsafe(method_family = none)]
fn set_texture(
&self,
texture: Option<&ProtocolObject<dyn MTLTexture>>,
index: usize,
);
#[unsafe(method(setSamplerState:atIndex:))]
#[unsafe(method_family = none)]
fn set_sampler_state(
&self,
sampler: Option<&ProtocolObject<dyn MTLSamplerState>>,
index: usize,
);
#[unsafe(method(setSamplerState:lodMinClamp:lodMaxClamp:atIndex:))]
#[unsafe(method_family = none)]
fn set_sampler_state_with_lod(
&self,
sampler: Option<&ProtocolObject<dyn MTLSamplerState>>,
lod_min_clamp: f32,
lod_max_clamp: f32,
index: usize,
);
#[unsafe(method(setThreadgroupMemoryLength:atIndex:))]
#[unsafe(method_family = none)]
fn set_threadgroup_memory_length(
&self,
length: usize,
index: usize,
);
#[unsafe(method(setImageblockWidth:height:))]
#[unsafe(method_family = none)]
fn set_imageblock_width_height(
&self,
width: usize,
height: usize,
);
#[unsafe(method(setStageInRegion:))]
#[unsafe(method_family = none)]
fn set_stage_in_region(
&self,
region: MTLRegion,
);
#[unsafe(method(setStageInRegionWithIndirectBuffer:indirectBufferOffset:))]
#[unsafe(method_family = none)]
fn set_stage_in_region_indirect(
&self,
indirect_buffer: &ProtocolObject<dyn MTLBuffer>,
indirect_buffer_offset: usize,
);
#[unsafe(method(dispatchThreadgroups:threadsPerThreadgroup:))]
#[unsafe(method_family = none)]
fn dispatch_threadgroups(
&self,
threadgroups_per_grid: MTLSize,
threads_per_threadgroup: MTLSize,
);
#[unsafe(method(dispatchThreadgroupsWithIndirectBuffer:indirectBufferOffset:threadsPerThreadgroup:))]
#[unsafe(method_family = none)]
fn dispatch_threadgroups_indirect(
&self,
indirect_buffer: &ProtocolObject<dyn MTLBuffer>,
indirect_buffer_offset: usize,
threads_per_threadgroup: MTLSize,
);
#[unsafe(method(dispatchThreads:threadsPerThreadgroup:))]
#[unsafe(method_family = none)]
fn dispatch_threads(
&self,
threads_per_grid: MTLSize,
threads_per_threadgroup: MTLSize,
);
#[unsafe(method(updateFence:))]
#[unsafe(method_family = none)]
fn update_fence(
&self,
fence: &ProtocolObject<dyn MTLFence>,
);
#[unsafe(method(waitForFence:))]
#[unsafe(method_family = none)]
fn wait_for_fence(
&self,
fence: &ProtocolObject<dyn MTLFence>,
);
#[unsafe(method(useResource:usage:))]
#[unsafe(method_family = none)]
fn use_resource(
&self,
resource: &ProtocolObject<dyn MTLResource>,
usage: MTLResourceUsage,
);
#[unsafe(method(useHeap:))]
#[unsafe(method_family = none)]
fn use_heap(
&self,
heap: &ProtocolObject<dyn MTLHeap>,
);
#[unsafe(method(executeCommandsInBuffer:indirectBuffer:indirectBufferOffset:))]
#[unsafe(method_family = none)]
fn execute_commands_in_buffer_indirect(
&self,
icb: &ProtocolObject<dyn MTLIndirectCommandBuffer>,
indirect_range_buffer: &ProtocolObject<dyn MTLBuffer>,
indirect_buffer_offset: usize,
);
#[unsafe(method(memoryBarrierWithScope:))]
#[unsafe(method_family = none)]
fn memory_barrier_with_scope(
&self,
scope: MTLBarrierScope,
);
#[unsafe(method(sampleCountersInBuffer:atSampleIndex:withBarrier:))]
#[unsafe(method_family = none)]
fn sample_counters_in_buffer(
&self,
sample_buffer: &ProtocolObject<dyn MTLCounterSampleBuffer>,
sample_index: usize,
barrier: bool,
);
}
);
pub trait MTLComputeCommandEncoderExt: MTLComputeCommandEncoder + Message {
fn set_buffers(
&self,
buffers: &[Option<&ProtocolObject<dyn MTLBuffer>>],
offsets: &[usize],
range: Range<usize>,
) where
Self: Sized,
{
assert_eq!(buffers.len(), offsets.len());
let ptr = opt_ref_slice_as_ptr(buffers);
unsafe {
msg_send![self, setBuffers: ptr, offsets: offsets.as_ptr(), withRange: NSRange::from(range)]
}
}
fn set_buffers_with_attribute_strides(
&self,
buffers: &[Option<&ProtocolObject<dyn MTLBuffer>>],
offsets: &[usize],
strides: &[usize],
range: Range<usize>,
) where
Self: Sized,
{
assert_eq!(buffers.len(), offsets.len());
assert_eq!(buffers.len(), strides.len());
let ptr = opt_ref_slice_as_ptr(buffers);
unsafe {
msg_send![
self,
setBuffers: ptr,
offsets: offsets.as_ptr(),
attributeStrides: strides.as_ptr(),
withRange: NSRange::from(range)
]
}
}
fn set_visible_function_tables(
&self,
tables: &[Option<&ProtocolObject<dyn MTLVisibleFunctionTable>>],
range: Range<usize>,
) where
Self: Sized,
{
let ptr = opt_ref_slice_as_ptr(tables);
unsafe { msg_send![self, setVisibleFunctionTables: ptr, withBufferRange: NSRange::from(range)] }
}
fn set_intersection_function_tables(
&self,
tables: &[Option<&ProtocolObject<dyn MTLIntersectionFunctionTable>>],
range: Range<usize>,
) where
Self: Sized,
{
let ptr = opt_ref_slice_as_ptr(tables);
unsafe { msg_send![self, setIntersectionFunctionTables: ptr, withBufferRange: NSRange::from(range)] }
}
fn set_textures(
&self,
textures: &[Option<&ProtocolObject<dyn MTLTexture>>],
range: Range<usize>,
) where
Self: Sized,
{
let ptr = opt_ref_slice_as_ptr(textures);
unsafe { msg_send![self, setTextures: ptr, withRange: NSRange::from(range)] }
}
fn set_sampler_states(
&self,
samplers: &[Option<&ProtocolObject<dyn MTLSamplerState>>],
range: Range<usize>,
) where
Self: Sized,
{
let ptr = opt_ref_slice_as_ptr(samplers);
unsafe { msg_send![self, setSamplerStates: ptr, withRange: NSRange::from(range)] }
}
fn set_sampler_states_with_lods(
&self,
samplers: &[Option<&ProtocolObject<dyn MTLSamplerState>>],
lod_min_clamps: &[f32],
lod_max_clamps: &[f32],
range: Range<usize>,
) where
Self: Sized,
{
assert_eq!(samplers.len(), lod_min_clamps.len());
assert_eq!(samplers.len(), lod_max_clamps.len());
let ptr = opt_ref_slice_as_ptr(samplers);
unsafe {
msg_send![
self,
setSamplerStates: ptr,
lodMinClamps: lod_min_clamps.as_ptr(),
lodMaxClamps: lod_max_clamps.as_ptr(),
withRange: NSRange::from(range)
]
}
}
fn execute_commands_in_buffer_with_range(
&self,
icb: &ProtocolObject<dyn MTLIndirectCommandBuffer>,
execution_range: Range<usize>,
) where
Self: Sized,
{
unsafe {
let _: () = msg_send![
self,
executeCommandsInBuffer: icb,
withRange: NSRange::from(execution_range)
];
}
}
fn use_resources(
&self,
resources: &[&ProtocolObject<dyn MTLResource>],
usage: MTLResourceUsage,
) where
Self: Sized,
{
let ptr = ref_slice_as_ptr(resources);
unsafe { msg_send![self, useResources: ptr, count: resources.len(), usage: usage] }
}
fn use_heaps(&self, heaps: &[&ProtocolObject<dyn MTLHeap>])
where
Self: Sized,
{
let ptr = ref_slice_as_ptr(heaps);
unsafe { msg_send![self, useHeaps: ptr, count: heaps.len()] }
}
fn memory_barrier_with_resources(&self, resources: &[&ProtocolObject<dyn MTLResource>])
where
Self: Sized,
{
let ptr = ref_slice_as_ptr(resources);
unsafe { msg_send![self, memoryBarrierWithResources: ptr, count: resources.len()] }
}
}
impl<T: MTLComputeCommandEncoder + Message> MTLComputeCommandEncoderExt for T {}