use super::{ExecutableCompletionHandler, ExecutableScheduledHandler, ExecutionStage};
use crate::TensorData;
use block2::DynBlock;
use metal::SharedEvent;
use metal::foreign_types::ForeignType;
use objc2::rc::{Allocated, Retained};
use objc2::runtime::NSObject;
use objc2::{extern_class, extern_conformance, extern_methods, msg_send};
use objc2_foundation::{CopyingHelper, NSArray, NSCopying, NSError, NSObjectProtocol};
use std::ptr::NonNull;
use crate::GraphObject;
extern_class!(
/// Binding for the Objective-C class `MPSGraphExecutableExecutionDescriptor`.
///
/// Declared with `GraphObject` as its immediate superclass and `NSObject` above it.
/// Instances are obtained as `Retained<Self>` (see `new` / `init`).
#[unsafe(super(GraphObject, NSObject))]
#[derive(Debug, PartialEq, Eq, Hash)]
#[name = "MPSGraphExecutableExecutionDescriptor"]
pub struct ExecutableExecutionDescriptor;
);
// Declares that the Objective-C class conforms to `NSCopying`.
extern_conformance!(
unsafe impl NSCopying for ExecutableExecutionDescriptor {}
);
// SAFETY (assumed): `copy` on this class is expected to yield another
// `MPSGraphExecutableExecutionDescriptor`, so the copy result type is declared as `Self`.
// NOTE(review): confirm against Apple's headers that copies are never a different class.
unsafe impl CopyingHelper for ExecutableExecutionDescriptor {
// Type returned by `copy()` on this class.
type Result = Self;
}
// Declares that the Objective-C class conforms to the `NSObject` protocol.
extern_conformance!(
unsafe impl NSObjectProtocol for ExecutableExecutionDescriptor {}
);
// Methods bound directly to Objective-C selectors through `extern_methods!`.
impl ExecutableExecutionDescriptor {
extern_methods!(
/// Initializes an already-allocated descriptor by sending `init`.
#[unsafe(method(init))]
#[unsafe(method_family = init)]
pub fn init(this: Allocated<Self>) -> Retained<Self>;
/// Creates a new descriptor by sending `new` to the class.
#[unsafe(method(new))]
#[unsafe(method_family = new)]
pub fn new() -> Retained<Self>;
/// Returns the current value of the `waitUntilCompleted` property.
#[unsafe(method(waitUntilCompleted))]
#[unsafe(method_family = none)]
pub fn wait_until_completed(&self) -> bool;
/// Sets the `waitUntilCompleted` property.
///
/// NOTE(review): per Apple's MPSGraph documentation this presumably makes the executable
/// run call block until the submitted work completes — confirm before relying on it.
#[unsafe(method(setWaitUntilCompleted:))]
#[unsafe(method_family = none)]
pub fn set_wait_until_completed(&self, wait_until_completed: bool);
);
}
// Hand-written bindings for the event and handler API. These go through `msg_send!` directly
// because the arguments are `metal` crate objects and `block2` blocks.
impl ExecutableExecutionDescriptor {
    /// Sends `waitForEvent:value:` with the raw Metal shared event and `value`.
    ///
    /// NOTE(review): per Apple's documentation this presumably delays execution until `event`
    /// reaches `value` — confirm against the SDK headers.
    pub fn wait_for_event(&self, event: &SharedEvent, value: u64) {
        let raw_event = event.as_ptr().cast::<std::ffi::c_void>();
        // SAFETY: `raw_event` points at the object owned by `event`, which the borrow keeps
        // alive for the whole message send.
        unsafe { msg_send![self, waitForEvent: raw_event, value: value] }
    }

    /// Sends `signalEvent:atExecutionEvent:value:` with the raw Metal shared event, the
    /// execution stage (converted to `u64`) and `value`.
    pub fn signal_event(&self, event: &SharedEvent, execution_stage: ExecutionStage, value: u64) {
        let raw_event = event.as_ptr().cast::<std::ffi::c_void>();
        let stage = execution_stage as u64;
        // SAFETY: `raw_event` stays valid for the duration of the call (borrowed from `event`).
        unsafe {
            msg_send![self, signalEvent: raw_event, atExecutionEvent: stage, value: value]
        }
    }

    /// Reads the `completionHandler` property and wraps the returned block pointer with
    /// `ExecutableCompletionHandler::copy`.
    ///
    /// NOTE(review): the property may presumably be nil when no handler was set; how that is
    /// handled depends on `ExecutableCompletionHandler::copy` — confirm.
    pub fn completion_handler(&self) -> ExecutableCompletionHandler {
        // SAFETY: the getter returns a block pointer with the declared signature, which is
        // handed straight to `copy`.
        unsafe {
            let raw_block: *mut DynBlock<dyn Fn(NonNull<NSArray<TensorData>>, *mut NSError)> =
                msg_send![self, completionHandler];
            ExecutableCompletionHandler::copy(raw_block)
        }
    }

    /// Passes `completion_handler`'s block to `setCompletionHandler:`.
    ///
    /// The Rust-side handle is dropped when this method returns.
    pub fn set_completion_handler(&self, completion_handler: ExecutableCompletionHandler) {
        let block = &*completion_handler;
        // SAFETY: `block` borrows from `completion_handler`, which outlives the message send.
        unsafe { msg_send![self, setCompletionHandler: block] }
    }

    /// Reads the `scheduledHandler` property and wraps the returned block pointer with
    /// `ExecutableScheduledHandler::copy`.
    ///
    /// NOTE(review): as with `completion_handler`, nil handling is delegated to `copy` — confirm.
    pub fn scheduled_handler(&self) -> ExecutableScheduledHandler {
        // SAFETY: the getter returns a block pointer with the declared signature, which is
        // handed straight to `copy`.
        unsafe {
            let raw_block: *mut DynBlock<dyn Fn(NonNull<NSArray<TensorData>>, *mut NSError)> =
                msg_send![self, scheduledHandler];
            ExecutableScheduledHandler::copy(raw_block)
        }
    }

    /// Passes `scheduled_handler`'s block to `setScheduledHandler:`.
    ///
    /// The Rust-side handle is dropped when this method returns.
    pub fn set_scheduled_handler(&self, scheduled_handler: ExecutableScheduledHandler) {
        let block = &*scheduled_handler;
        // SAFETY: `block` borrows from `scheduled_handler`, which outlives the message send.
        unsafe { msg_send![self, setScheduledHandler: block] }
    }
}
// NOTE(review): none of the selectors below appear in Apple's public MPSGraph headers for
// `MPSGraphExecutableExecutionDescriptor`; they look like private SPI. Sending a selector the
// class does not implement raises an Objective-C exception, so confirm that each one exists on
// every targeted OS version. The argument and return types used here are assumptions.
impl ExecutableExecutionDescriptor {
    /// Sends `setEnableCommitAndContinue:` with `enable`.
    pub fn set_enable_commit_and_continue(&self, enable: bool) {
        // SAFETY: assumed signature `(BOOL) -> void`; see the note on this impl.
        unsafe { msg_send![self, setEnableCommitAndContinue: enable] }
    }

    /// Sends `setSimulateANECompileFailure:` with `enable`.
    pub fn set_simulate_ane_compile_failure(&self, enable: bool) {
        // SAFETY: assumed signature `(BOOL) -> void`; see the note on this impl.
        unsafe { msg_send![self, setSimulateANECompileFailure: enable] }
    }

    /// Sends `setSimulateANELoadModelFailure:` with `enable`.
    pub fn set_simulate_ane_load_model_failure(&self, enable: bool) {
        // SAFETY: assumed signature `(BOOL) -> void`; see the note on this impl.
        unsafe { msg_send![self, setSimulateANELoadModelFailure: enable] }
    }

    /// Sends `setDisableSynchronizeResults:` with `disable`.
    pub fn set_disable_synchronize_results(&self, disable: bool) {
        // SAFETY: assumed signature `(BOOL) -> void`; see the note on this impl.
        unsafe { msg_send![self, setDisableSynchronizeResults: disable] }
    }

    /// Sends `setDisableANECaching:` with `disable`.
    pub fn set_disable_ane_caching(&self, disable: bool) {
        // SAFETY: assumed signature `(BOOL) -> void`; see the note on this impl.
        unsafe { msg_send![self, setDisableANECaching: disable] }
    }

    /// Sends `setDisableANEFallback:` with `disable`.
    pub fn set_disable_ane_fallback(&self, disable: bool) {
        // SAFETY: assumed signature `(BOOL) -> void`; see the note on this impl.
        unsafe { msg_send![self, setDisableANEFallback: disable] }
    }

    /// Sends `setEnableProfilingOpNames:` with `enable`.
    pub fn set_enable_profiling_op_names(&self, enable: bool) {
        // SAFETY: assumed signature `(BOOL) -> void`; see the note on this impl.
        unsafe { msg_send![self, setEnableProfilingOpNames: enable] }
    }

    /// Sends `setBriefProfilingOpNames:` with `enable`.
    pub fn set_brief_profiling_op_names(&self, enable: bool) {
        // SAFETY: assumed signature `(BOOL) -> void`; see the note on this impl.
        unsafe { msg_send![self, setBriefProfilingOpNames: enable] }
    }

    /// Sends `setBreakUpMetalEncoders:` with `enable`.
    pub fn set_break_up_metal_encoders(&self, enable: bool) {
        // SAFETY: assumed signature `(BOOL) -> void`; see the note on this impl.
        unsafe { msg_send![self, setBreakUpMetalEncoders: enable] }
    }

    /// Sends `setGenerateRuntimeExecutionReport:` with `enable`.
    pub fn set_generate_runtime_execution_report(&self, enable: bool) {
        // SAFETY: assumed signature `(BOOL) -> void`; see the note on this impl.
        unsafe { msg_send![self, setGenerateRuntimeExecutionReport: enable] }
    }

    /// Sends `setMaximumNumberOfEncodingThreads:` with `value`.
    ///
    /// NOTE(review): the Objective-C parameter type is presumably `NSUInteger`, which matches
    /// `u64` on 64-bit Apple targets — confirm.
    pub fn set_maximum_number_of_encoding_threads(&self, value: u64) {
        // SAFETY: assumed signature `(NSUInteger) -> void`; see the note on this impl.
        unsafe { msg_send![self, setMaximumNumberOfEncodingThreads: value] }
    }

    /// Returns the value reported by the `numberOfCommitsByMPSGraph` selector.
    pub fn number_of_commits_by_mps_graph(&self) -> u64 {
        // SAFETY: assumed signature `() -> NSUInteger`; see the note on this impl.
        let commits: u64 = unsafe { msg_send![self, numberOfCommitsByMPSGraph] };
        commits
    }
}