use std::ffi::{CStr, CString};
use after_effects::log;
use objc::{class, msg_send, runtime::Object, sel, sel_impl};
use std::os::raw::c_void;
use std::time::{Duration, Instant};
/// Builds an `NSString` from a Rust string slice via `+[NSString stringWithUTF8String:]`.
///
/// A C string cannot carry an interior NUL, and `stringWithUTF8String:` stops
/// reading at the first NUL byte anyway, so only the prefix of `s` before the
/// first `'\0'` is converted. This keeps the function from panicking on
/// unexpected input.
///
/// # Safety
///
/// Sends an Objective-C message; the caller must ensure the Objective-C
/// runtime and Foundation's `NSString` class are available. The returned
/// object comes from a class factory method and is not retained here
/// (presumably autoreleased per Cocoa convention — retain it if it must
/// outlive the current autorelease pool).
pub unsafe fn nsstring_utf8(s: &str) -> *mut Object {
    // Keep only the part before an interior NUL so `CString::new` cannot fail.
    let prefix = s.find('\0').map_or(s, |nul| &s[..nul]);
    let c = CString::new(prefix).unwrap_or_default();
    // SAFETY: `c` is a valid NUL-terminated buffer that outlives the message send.
    let ns: *mut Object = unsafe { msg_send![class!(NSString), stringWithUTF8String: c.as_ptr()] };
    ns
}
/// Logs basic properties of an `MTLBuffer` for diagnostics.
///
/// A null `raw` produces an error-level log line. Otherwise the buffer is
/// queried for `length`, `storageMode` and `contents`, and the values are
/// written as one info-level line prefixed with `tag`.
///
/// # Safety
///
/// When non-null, `raw` must point to a live Objective-C object that responds
/// to `length`, `storageMode` and `contents`.
pub unsafe fn log_buffer_info(tag: &str, raw: *mut core::ffi::c_void) {
    if !raw.is_null() {
        let buffer = raw.cast::<Object>();
        let length: u64 = msg_send![buffer, length];
        let storage_mode: u64 = msg_send![buffer, storageMode];
        let contents: *mut core::ffi::c_void = msg_send![buffer, contents];
        log::info!("[metal] {tag}: MTLBuffer={raw:?}, length={length}, storageMode={storage_mode}, contents={contents:?}");
    } else {
        log::error!("[metal] {tag}: null");
    }
}
/// Renders an `NSError` as a human-readable string.
///
/// The result has the form `"<domain> (<code>): <description>"`, followed by
/// optional `FailureReason:` and `Suggestion:` lines when the error provides
/// non-empty values for them. Missing domain or description strings are
/// replaced by `<unknown-domain>` / `<no-description>`.
///
/// Returns `None` when `err` is null.
///
/// # Safety
///
/// When non-null, `err` must point to a live object that responds to the
/// `NSError` selectors `domain`, `code`, `localizedDescription`,
/// `localizedFailureReason` and `localizedRecoverySuggestion`.
pub unsafe fn ns_error(err: *mut Object) -> Option<String> {
    /// Copies an `NSString`'s UTF-8 contents into an owned `String`.
    ///
    /// Returns `None` when the object, or the pointer it hands back from
    /// `UTF8String`, is null. The receiver is never messaged while null.
    unsafe fn copy_nsstring(ns: *mut Object) -> Option<String> {
        if ns.is_null() {
            return None;
        }
        // SAFETY: `ns` is non-null and, per the caller's contract, an NSString.
        let utf8: *const std::os::raw::c_char = unsafe { msg_send![ns, UTF8String] };
        if utf8.is_null() {
            return None;
        }
        // SAFETY: `utf8` is a non-null, NUL-terminated C string returned by `UTF8String`;
        // it is copied into an owned `String` before this function returns.
        let text = unsafe { CStr::from_ptr(utf8) };
        Some(text.to_string_lossy().into_owned())
    }

    if err.is_null() {
        return None;
    }

    // SAFETY (all sends below): `err` is non-null and valid per this function's contract;
    // every returned string object is null-checked inside `copy_nsstring` before use.
    let domain: *mut Object = unsafe { msg_send![err, domain] };
    let domain_str = unsafe { copy_nsstring(domain) }.unwrap_or_else(|| "<unknown-domain>".into());

    let code: i64 = unsafe { msg_send![err, code] };

    let desc: *mut Object = unsafe { msg_send![err, localizedDescription] };
    let desc_str = unsafe { copy_nsstring(desc) }.unwrap_or_else(|| "<no-description>".into());

    let fail: *mut Object = unsafe { msg_send![err, localizedFailureReason] };
    let fail_str = unsafe { copy_nsstring(fail) }.unwrap_or_default();

    let sugg: *mut Object = unsafe { msg_send![err, localizedRecoverySuggestion] };
    let sugg_str = unsafe { copy_nsstring(sugg) }.unwrap_or_default();

    let mut msg = format!("{domain_str} ({code}): {desc_str}");
    if !fail_str.is_empty() {
        msg.push_str("\nFailureReason: ");
        msg.push_str(&fail_str);
    }
    if !sugg_str.is_empty() {
        msg.push_str("\nSuggestion: ");
        msg.push_str(&sugg_str);
    }
    Some(msg)
}
pub mod buffer;
pub mod fence;
pub mod frame_scope;
pub mod pipeline;
use crate::types::{Configuration, FrameParams};
/// Largest parameter block, in bytes, that `run` accepts for a single
/// `setBytes:length:atIndex:` call. `run` checks the sizes of `FrameParams`
/// and of the caller's user-parameter type against this bound.
const SET_BYTES_LIMIT: usize = 4096;
pub fn run<UP>(config: &Configuration, user_params: UP, shader_src: &[u8], entry: &'static str) -> Result<(), &'static str> {
use objc::rc::autoreleasepool;
autoreleasepool(|| {
if config.device_handle.is_null() || config.command_queue_handle.is_null() {
log::error!("[Metal] device or command queue handle is null");
return Err("Invalid device or command queue handle");
}
if config.dest_data.is_null() {
log::error!("[Metal] dest_data is null");
return Err("null dest buffer");
}
let has_outgoing = config.outgoing_data.map_or(false, |p| !p.is_null());
let has_incoming = config.incoming_data.map_or(false, |p| !p.is_null());
if !has_outgoing && !has_incoming {
log::error!("[Metal] both outgoing and incoming are null/missing");
return Err("no input buffers");
}
let device = config.device_handle as *mut Object;
let queue = config.command_queue_handle as *mut Object;
let pipeline = unsafe { crate::gpu::pipeline::load_kernel(device, shader_src, entry) }?;
if pipeline.is_null() {
log::error!("[Metal] pipeline state is null");
return Err("null pipeline state");
}
let frame_params = FrameParams::from_config(config);
let outgoing_ptr = config.outgoing_data.unwrap_or(std::ptr::null_mut());
let incoming_ptr = config.incoming_data.unwrap_or(std::ptr::null_mut());
let frame_params_size = std::mem::size_of::<FrameParams>();
let user_param_size = std::mem::size_of::<UP>();
debug_assert!(frame_params_size <= SET_BYTES_LIMIT && user_param_size <= SET_BYTES_LIMIT);
#[cfg(debug_assertions)]
log::debug!(
"[Metal] '{entry}' bufs: dispatch={}x{} dst_pitch_px={} | outgoing={}x{} out_pitch_px={} mip_levels={} outDesc.mipCount={} | dstDesc={}x{} dstDesc.pitch={} | outgoing_ptr={:?} incoming_ptr={:?} dst_ptr={:?}",
config.width,
config.height,
config.dest_pitch_px,
config.outgoing_width,
config.outgoing_height,
config.outgoing_pitch_px,
config.outgoing_mip_levels,
frame_params.out_desc.mip_level_count,
frame_params.dst_desc.width,
frame_params.dst_desc.height,
frame_params.dst_desc.pitch_bytes,
outgoing_ptr,
incoming_ptr,
config.dest_data,
);
let tew: usize = unsafe { msg_send![pipeline, threadExecutionWidth] };
let max_threads: usize = unsafe { msg_send![pipeline, maxTotalThreadsPerThreadgroup] };
let tg_w = tew.max(1);
let tg_h = (max_threads / tg_w).clamp(1, 16);
let groups_x = (config.width as usize).div_ceil(tg_w);
let groups_y = (config.height as usize).div_ceil(tg_h);
let tg = crate::types::MTLSize {
width: groups_x,
height: groups_y,
depth: 1,
};
let tp = crate::types::MTLSize {
width: tg_w,
height: tg_h,
depth: 1,
};
if frame_scope::is_active() {
let cmd = frame_scope::command_buffer();
let enc: *mut Object = unsafe { msg_send![cmd, computeCommandEncoder] };
if enc.is_null() {
log::error!("[Metal] failed to create compute encoder");
return Err("compute encoder creation failed");
}
unsafe {
encode_pass(enc, pipeline, outgoing_ptr, incoming_ptr, config.dest_data, &frame_params, &user_params, tg, tp);
}
frame_scope::note_pass();
return Ok(());
}
const MAX_ATTEMPTS: u32 = 2;
let mut attempt: u32 = 0;
let gpu_ms = loop {
attempt += 1;
let cmd: *mut Object = unsafe { msg_send![queue, commandBuffer] };
if cmd.is_null() {
log::error!("[Metal] failed to create command buffer");
return Err("command buffer creation failed");
}
let enc: *mut Object = unsafe { msg_send![cmd, computeCommandEncoder] };
if enc.is_null() {
log::error!("[Metal] failed to create compute encoder");
return Err("compute encoder creation failed");
}
unsafe {
encode_pass(enc, pipeline, outgoing_ptr, incoming_ptr, config.dest_data, &frame_params, &user_params, tg, tp);
}
#[cfg(debug_assertions)]
let cpu_start = Instant::now();
unsafe {
let _: () = msg_send![cmd, commit];
let _: () = msg_send![cmd, waitUntilCompleted];
}
let status: u64 = unsafe { msg_send![cmd, status] };
if status == 5 {
let error: *mut Object = unsafe { msg_send![cmd, error] };
let msg = unsafe { ns_error(error) };
let is_watchdog = msg
.as_ref()
.is_some_and(|m| m.contains("Impacting Interactivity") || m.contains("kIOGPUCommandBufferCallbackError"));
if is_watchdog && attempt < MAX_ATTEMPTS {
log::warn!(
"[Metal] '{entry}' hit GPU watchdog (attempt {attempt}/{MAX_ATTEMPTS}) — cooling down 50ms and retrying"
);
std::thread::sleep(Duration::from_millis(50));
continue;
}
if let Some(m) = msg {
log::error!("[Metal] command buffer error: {m}");
}
return Err("GPU execution error");
}
if attempt > 1 {
log::info!("[Metal] '{entry}' recovered after watchdog retry (attempt {attempt})");
}
let gpu_start: f64 = unsafe { msg_send![cmd, GPUStartTime] };
let gpu_end: f64 = unsafe { msg_send![cmd, GPUEndTime] };
let gpu_ms = (gpu_end - gpu_start) * 1000.0;
#[cfg(debug_assertions)]
{
let cpu_elapsed = cpu_start.elapsed();
let generation = config.render_generation;
log::info!("[Metal] `{entry}` gen={generation}: gpu={gpu_ms:.3}ms, cpu={cpu_elapsed:?}");
}
break gpu_ms;
};
crate::timing::record(entry, crate::types::Backend::Metal, (gpu_ms * 1_000_000.0) as u64);
Ok(())
})
}
/// Records one compute dispatch on `enc` and ends encoding.
///
/// Binding layout: `outgoing`, `incoming` and `dest` go to buffer indices
/// 0, 1 and 2 (offset 0); `frame_params` and `user_params` are copied inline
/// with `setBytes` at indices 3 and 4. The dispatch uses `tg` threadgroups of
/// `tp` threads each.
///
/// # Safety
///
/// `enc` and `pipeline` must be live Objective-C objects accepting the
/// compute-encoder messages sent here, and each buffer pointer must be either
/// null or a live buffer object.
#[allow(clippy::too_many_arguments)]
unsafe fn encode_pass<UP>(
    enc: *mut Object,
    pipeline: *mut Object,
    outgoing: *mut c_void,
    incoming: *mut c_void,
    dest: *mut c_void,
    frame_params: &FrameParams,
    user_params: &UP,
    tg: crate::types::MTLSize,
    tp: crate::types::MTLSize,
) {
    // Inline parameter blocks: raw byte views plus their sizes.
    let frame_bytes = (frame_params as *const FrameParams).cast::<c_void>();
    let frame_len = std::mem::size_of::<FrameParams>();
    let user_bytes = (user_params as *const UP).cast::<c_void>();
    let user_len = std::mem::size_of::<UP>();

    unsafe {
        let _: () = msg_send![enc, setComputePipelineState: pipeline];

        // Buffer slots 0..=2, bound in order at offset 0.
        for (slot, raw) in [outgoing, incoming, dest].into_iter().enumerate() {
            let buffer = raw as *mut Object;
            let _: () = msg_send![enc, setBuffer: buffer offset: 0usize atIndex: slot];
        }

        let _: () = msg_send![enc, setBytes: frame_bytes length: frame_len atIndex: 3usize];
        let _: () = msg_send![enc, setBytes: user_bytes length: user_len atIndex: 4usize];

        let _: () = msg_send![enc, dispatchThreadgroups: tg threadsPerThreadgroup: tp];
        let _: () = msg_send![enc, endEncoding];
    }
}