use metal::{
BufferRef, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineStateRef, MTLSize,
};
/// Declares a struct holding a fixed set of compute pipelines, plus a `fetch`
/// constructor that resolves each one by kernel name.
///
/// Every `field => "kernel"` entry becomes a `pub field: ::metal::ComputePipelineState`
/// whose value is a clone of `MetalContext::pipeline("kernel")`. Attributes written
/// on the struct or on individual entries are forwarded to the generated items.
/// `fetch` evaluates the entries in the order written and returns the first
/// `MetalError` it encounters.
macro_rules! pipeline_bundle {
    (
        $(#[$meta:meta])*
        $v:vis struct $bundle:ident {
            $(
                $(#[$fmeta:meta])*
                $name:ident => $kernel_name:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        $v struct $bundle {
            $(
                $(#[$fmeta])*
                pub $name: ::metal::ComputePipelineState,
            )+
        }

        impl $bundle {
            /// Looks up every pipeline of this bundle in `metal`, stopping at the
            /// first lookup that fails.
            pub fn fetch(
                metal: &mut $crate::riir::backend::gpu::metal::MetalContext,
            ) -> ::core::result::Result<Self, $crate::riir::backend::gpu::metal::MetalError> {
                let bundle = Self {
                    $( $name: metal.pipeline($kernel_name)?.clone(), )+
                };
                ::core::result::Result::Ok(bundle)
            }
        }
    };
}
pub(crate) use pipeline_bundle;
/// Thin wrapper around a borrowed Metal compute command encoder.
///
/// Remembers whether this wrapper has already issued `end_encoding`, so that
/// an explicit `end()` followed by the implicit end on drop does not end the
/// encoder twice.
pub(crate) struct ComputeEncoder<'e> {
    /// `true` once this wrapper has called `end_encoding` on `enc`.
    ended: bool,
    /// The wrapped encoder handle.
    enc: &'e ComputeCommandEncoderRef,
}
impl<'e> ComputeEncoder<'e> {
    /// Wraps an existing compute command encoder.
    ///
    /// The wrapper calls `end_encoding` on `enc` when [`end`](Self::end) is
    /// called or when the wrapper is dropped. Wrapping an encoder whose encoding
    /// was already ended elsewhere therefore results in a second `end_encoding`.
    #[must_use = "dropping a ComputeEncoder immediately ends encoding"]
    pub fn wrap(enc: &'e ComputeCommandEncoderRef) -> Self {
        Self { enc, ended: false }
    }

    /// Creates a new compute command encoder on `cmdbuf` and wraps it.
    #[must_use = "dropping a ComputeEncoder immediately ends encoding"]
    pub fn begin(cmdbuf: &'e CommandBufferRef) -> Self {
        Self::wrap(cmdbuf.new_compute_command_encoder())
    }

    /// Debug-build check that nothing is recorded after `end()`.
    ///
    /// Once `end_encoding` has been sent, further set/dispatch calls on the
    /// encoder are API misuse; `ended` lets us catch that cheaply.
    fn debug_assert_open(&self) {
        debug_assert!(!self.ended, "ComputeEncoder used after end()");
    }

    /// Sets the compute pipeline state used by subsequent dispatches.
    pub fn pipeline(&mut self, state: &ComputePipelineStateRef) -> &mut Self {
        self.debug_assert_open();
        self.enc.set_compute_pipeline_state(state);
        self
    }

    /// Binds `buffer` at buffer index `index`, starting `offset` bytes into it.
    pub fn buffer(&mut self, index: u64, buffer: &BufferRef, offset: u64) -> &mut Self {
        self.debug_assert_open();
        self.enc.set_buffer(index, Some(buffer), offset);
        self
    }

    /// Copies the bytes of `value` inline into buffer index `index` via `set_bytes`.
    ///
    /// Meant for small constants: values larger than 4096 bytes trip a debug
    /// assertion; use [`buffer`](Self::buffer) for larger data instead.
    pub fn bytes<T: Copy + 'static>(&mut self, index: u64, value: &T) -> &mut Self {
        self.debug_assert_open();
        let len = std::mem::size_of::<T>();
        debug_assert!(
            len <= 4096,
            "set_bytes is for small inline constants, not buffers",
        );
        self.enc.set_bytes(
            index,
            len as u64,
            (value as *const T).cast::<std::ffi::c_void>(),
        );
        self
    }

    /// Encodes a dispatch of `threadgroups` threadgroups, each containing
    /// `threads_per_threadgroup` threads.
    ///
    /// The first argument is a threadgroup *count*, not a total thread count:
    /// it is forwarded to `dispatch_thread_groups`, not `dispatch_threads`.
    pub fn dispatch(
        &mut self,
        threadgroups: MTLSize,
        threads_per_threadgroup: MTLSize,
    ) -> &mut Self {
        self.debug_assert_open();
        self.enc
            .dispatch_thread_groups(threadgroups, threads_per_threadgroup);
        self
    }

    /// Ends encoding. Idempotent: only the first call reaches `end_encoding`,
    /// so calling it before the wrapper is dropped is always safe.
    pub fn end(&mut self) {
        if !self.ended {
            self.enc.end_encoding();
            self.ended = true;
        }
    }
}
impl Drop for ComputeEncoder<'_> {
fn drop(&mut self) {
self.end();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::riir::backend::gpu::metal::MetalContext;

    /// Builds a Metal context, or logs and returns `None` when initialisation
    /// fails, so such hosts skip these tests instead of failing them.
    fn metal_or_skip() -> Option<MetalContext> {
        MetalContext::new()
            .map_err(|e| eprintln!("[encoder] skipping: Metal init failed: {e:?}"))
            .ok()
    }

    #[test]
    fn end_then_drop_is_single_end() {
        let mut metal = match metal_or_skip() {
            Some(ctx) => ctx,
            None => return,
        };
        let cmdbuf = metal.queue().new_command_buffer();
        let mut ce = ComputeEncoder::begin(cmdbuf);
        assert!(!ce.ended, "fresh encoder must not be ended");

        ce.end();
        assert!(ce.ended, "end() must mark the encoder ended");

        // A second explicit end, and then the implicit one in Drop, must both be no-ops.
        ce.end();
        assert!(ce.ended, "redundant end() must stay a no-op");
        drop(ce);
    }

    #[test]
    fn drop_without_explicit_end_still_ends() {
        let mut metal = match metal_or_skip() {
            Some(ctx) => ctx,
            None => return,
        };
        let cmdbuf = metal.queue().new_command_buffer();
        // The inner scope makes Drop run (and end encoding) at its closing brace.
        {
            let _ce = ComputeEncoder::begin(cmdbuf);
        }
    }
}