objc2_metal_performance_shaders_graph/generated/
MPSGraph.rs

//! This file has been automatically generated by `objc2`'s `header-translator`.
//! DO NOT EDIT
3use core::ffi::*;
4use core::ptr::NonNull;
5#[cfg(feature = "dispatch2")]
6use dispatch2::*;
7use objc2::__framework_prelude::*;
8use objc2_foundation::*;
9use objc2_metal::*;
10#[cfg(feature = "objc2-metal-performance-shaders")]
11use objc2_metal_performance_shaders::*;
12
13use crate::*;
14
/// The options available to a graph.
///
/// This is a `#[repr(transparent)]` newtype over `u64`; the known values are
/// exposed as associated constants.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphoptions?language=objc)
// NS_ENUM
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct MPSGraphOptions(pub u64);
impl MPSGraphOptions {
    /// No options.
    #[doc(alias = "MPSGraphOptionsNone")]
    pub const None: Self = Self(0);
    /// The graph synchronizes results to the CPU using a blit encoder if on a discrete GPU at the end of execution.
    #[doc(alias = "MPSGraphOptionsSynchronizeResults")]
    pub const SynchronizeResults: Self = Self(1);
    /// The framework prints more logging info.
    #[doc(alias = "MPSGraphOptionsVerbose")]
    pub const Verbose: Self = Self(2);
    /// The framework uses these options as default if not overridden.
    ///
    /// Has the same raw value as [`SynchronizeResults`][Self::SynchronizeResults].
    #[doc(alias = "MPSGraphOptionsDefault")]
    pub const Default: Self = Self(MPSGraphOptions::SynchronizeResults.0);
}
36
37unsafe impl Encode for MPSGraphOptions {
38    const ENCODING: Encoding = u64::ENCODING;
39}
40
41unsafe impl RefEncode for MPSGraphOptions {
42    const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
43}
44
/// The optimization levels to trade compilation time for even more runtime performance by running more passes.
///
/// This is a `#[repr(transparent)]` newtype over `u64`; the known values are
/// exposed as associated constants.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphoptimization?language=objc)
// NS_ENUM
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct MPSGraphOptimization(pub u64);
impl MPSGraphOptimization {
    /// Graph performs core optimizations only.
    #[doc(alias = "MPSGraphOptimizationLevel0")]
    pub const Level0: Self = Self(0);
    /// Graph performs additional optimizations, like using the placement pass to dispatch across different HW blocks like the NeuralEngine and CPU along with the GPU.
    #[doc(alias = "MPSGraphOptimizationLevel1")]
    pub const Level1: Self = Self(1);
}
60
61unsafe impl Encode for MPSGraphOptimization {
62    const ENCODING: Encoding = u64::ENCODING;
63}
64
65unsafe impl RefEncode for MPSGraphOptimization {
66    const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
67}
68
/// The optimization profile used as a heuristic as the graph compiler optimizes the network.
///
/// This is a `#[repr(transparent)]` newtype over `u64`; the known values are
/// exposed as associated constants.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphoptimizationprofile?language=objc)
// NS_ENUM
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct MPSGraphOptimizationProfile(pub u64);
impl MPSGraphOptimizationProfile {
    /// Default, graph optimized for performance.
    #[doc(alias = "MPSGraphOptimizationProfilePerformance")]
    pub const Performance: Self = Self(0);
    /// Graph optimized for power efficiency.
    #[doc(alias = "MPSGraphOptimizationProfilePowerEfficiency")]
    pub const PowerEfficiency: Self = Self(1);
}
84
85unsafe impl Encode for MPSGraphOptimizationProfile {
86    const ENCODING: Encoding = u64::ENCODING;
87}
88
89unsafe impl RefEncode for MPSGraphOptimizationProfile {
90    const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
91}
92
/// Execution events that can be used with shared events.
///
/// This is a `#[repr(transparent)]` newtype over `u64`; the known values are
/// exposed as associated constants.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphexecutionstage?language=objc)
// NS_ENUM
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct MPSGraphExecutionStage(pub u64);
impl MPSGraphExecutionStage {
    /// Stage when execution of the graph completes.
    #[doc(alias = "MPSGraphExecutionStageCompleted")]
    pub const Completed: Self = Self(0);
}
105
106unsafe impl Encode for MPSGraphExecutionStage {
107    const ENCODING: Encoding = u64::ENCODING;
108}
109
110unsafe impl RefEncode for MPSGraphExecutionStage {
111    const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
112}
113
114/// MPSGraph could use these reduced precision paths to deliver faster math, but it is not guaranteed.
115///
116/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphreducedprecisionfastmath?language=objc)
117// NS_OPTIONS
118#[repr(transparent)]
119#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
120pub struct MPSGraphReducedPrecisionFastMath(pub NSUInteger);
121bitflags::bitflags! {
122    impl MPSGraphReducedPrecisionFastMath: NSUInteger {
123/// Full precision math with maximum accuracy.
124        #[doc(alias = "MPSGraphReducedPrecisionFastMathNone")]
125        const None = 0;
126/// Execute winograd transform intermediate as FP16.
127        #[doc(alias = "MPSGraphReducedPrecisionFastMathAllowFP16Conv2DWinogradTransformIntermediate")]
128        const AllowFP16Conv2DWinogradTransformIntermediate = 1<<1;
129/// Curated list allowing intermediates for multi-pass GPU kernels to be FP16.
130        #[doc(alias = "MPSGraphReducedPrecisionFastMathAllowFP16Intermediates")]
131        const AllowFP16Intermediates = MPSGraphReducedPrecisionFastMath::AllowFP16Conv2DWinogradTransformIntermediate.0;
132/// Default selection.
133        #[doc(alias = "MPSGraphReducedPrecisionFastMathDefault")]
134        const Default = MPSGraphReducedPrecisionFastMath::None.0;
135    }
136}
137
138unsafe impl Encode for MPSGraphReducedPrecisionFastMath {
139    const ENCODING: Encoding = NSUInteger::ENCODING;
140}
141
142unsafe impl RefEncode for MPSGraphReducedPrecisionFastMath {
143    const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
144}
145
/// A dictionary of tensors and corresponding tensor data.
///
/// Alias for `NSDictionary<MPSGraphTensor, MPSGraphTensorData>`.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphtensordatadictionary?language=objc)
#[cfg(all(
    feature = "MPSGraphCore",
    feature = "MPSGraphTensor",
    feature = "MPSGraphTensorData"
))]
pub type MPSGraphTensorDataDictionary = NSDictionary<MPSGraphTensor, MPSGraphTensorData>;
155
/// A dictionary of tensors and corresponding shapes for them.
///
/// Alias for `NSDictionary<MPSGraphTensor, MPSGraphShapedType>`.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphtensorshapedtypedictionary?language=objc)
#[cfg(all(feature = "MPSGraphCore", feature = "MPSGraphTensor"))]
pub type MPSGraphTensorShapedTypeDictionary = NSDictionary<MPSGraphTensor, MPSGraphShapedType>;
161
/// A notification that appears when graph execution finishes.
///
/// The alias is a raw (nullable) pointer to a block taking the results
/// dictionary and an error pointer.
///
/// - Parameters:
///   - resultsDictionary: If no error, the results dictionary produced by the graph operation.
///   - error: If an error occurs, more information might be found here.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompletionhandler?language=objc)
#[cfg(all(
    feature = "MPSGraphCore",
    feature = "MPSGraphTensor",
    feature = "MPSGraphTensorData",
    feature = "block2"
))]
pub type MPSGraphCompletionHandler =
    *mut block2::DynBlock<dyn Fn(NonNull<MPSGraphTensorDataDictionary>, *mut NSError)>;
177
/// A notification that appears when graph execution schedules.
///
/// The alias is a raw (nullable) pointer to a block taking the results
/// dictionary and an error pointer.
///
/// - Parameters:
///   - resultsDictionary: If no error, the results dictionary produced by the graph operation. If Graph has not yet allocated, the results will be `NSNull`.
///   - error: If an error occurs, more information might be found here.
///
/// NOTE(review): argument 1 is typed as `NonNull<MPSGraphTensorDataDictionary>`
/// here; confirm how the `NSNull` case described above is surfaced to the block.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphscheduledhandler?language=objc)
#[cfg(all(
    feature = "MPSGraphCore",
    feature = "MPSGraphTensor",
    feature = "MPSGraphTensorData",
    feature = "block2"
))]
pub type MPSGraphScheduledHandler =
    *mut block2::DynBlock<dyn Fn(NonNull<MPSGraphTensorDataDictionary>, *mut NSError)>;
193
/// A notification that appears when compilation finishes.
///
/// The alias is a raw (nullable) pointer to a block taking the executable and
/// an error pointer.
///
/// - Parameters:
///   - executable: If no error, the executable produced by the compilation.
///   - error: If an error occurs, more information might be found here.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompilationcompletionhandler?language=objc)
#[cfg(all(
    feature = "MPSGraphCore",
    feature = "MPSGraphExecutable",
    feature = "block2"
))]
pub type MPSGraphCompilationCompletionHandler =
    *mut block2::DynBlock<dyn Fn(NonNull<MPSGraphExecutable>, *mut NSError)>;
208
/// A dictionary of symbol names and the corresponding executables for them.
///
/// Alias for `NSDictionary<NSString, MPSGraphExecutable>`.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcallablemap?language=objc)
#[cfg(all(feature = "MPSGraphCore", feature = "MPSGraphExecutable"))]
pub type MPSGraphCallableMap = NSDictionary<NSString, MPSGraphExecutable>;
214
215extern_class!(
216    /// A class that consists of all the levers for compiling graphs.
217    ///
218    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompilationdescriptor?language=objc)
219    #[unsafe(super(MPSGraphObject, NSObject))]
220    #[derive(Debug, PartialEq, Eq, Hash)]
221    #[cfg(feature = "MPSGraphCore")]
222    pub struct MPSGraphCompilationDescriptor;
223);
224
// `MPSGraphCompilationDescriptor` is declared as conforming to `NSCopying`.
#[cfg(feature = "MPSGraphCore")]
extern_conformance!(
    unsafe impl NSCopying for MPSGraphCompilationDescriptor {}
);

// Copying a compilation descriptor yields another compilation descriptor.
#[cfg(feature = "MPSGraphCore")]
unsafe impl CopyingHelper for MPSGraphCompilationDescriptor {
    type Result = Self;
}

// `MPSGraphCompilationDescriptor` is declared as conforming to `NSObjectProtocol`.
#[cfg(feature = "MPSGraphCore")]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSGraphCompilationDescriptor {}
);
239
/// Methods and property accessors declared on `MPSGraphCompilationDescriptor`.
#[cfg(feature = "MPSGraphCore")]
impl MPSGraphCompilationDescriptor {
    extern_methods!(
        /// Turns off type inference and relies on type inference during runtime.
        #[unsafe(method(disableTypeInference))]
        #[unsafe(method_family = none)]
        pub unsafe fn disableTypeInference(&self);

        /// The optimization level for the graph execution, default is MPSGraphOptimizationLevel1.
        #[unsafe(method(optimizationLevel))]
        #[unsafe(method_family = none)]
        pub unsafe fn optimizationLevel(&self) -> MPSGraphOptimization;

        /// Setter for [`optimizationLevel`][Self::optimizationLevel].
        #[unsafe(method(setOptimizationLevel:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setOptimizationLevel(&self, optimization_level: MPSGraphOptimization);

        /// Flag that makes the compile or specialize call blocking until the entire compilation is complete, defaults to NO.
        #[unsafe(method(waitForCompilationCompletion))]
        #[unsafe(method_family = none)]
        pub unsafe fn waitForCompilationCompletion(&self) -> bool;

        /// Setter for [`waitForCompilationCompletion`][Self::waitForCompilationCompletion].
        #[unsafe(method(setWaitForCompilationCompletion:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setWaitForCompilationCompletion(&self, wait_for_compilation_completion: bool);

        #[cfg(all(feature = "MPSGraphExecutable", feature = "block2"))]
        /// The handler that the graph calls when the compilation completes.
        ///
        /// Default value is nil.
        ///
        /// # Safety
        ///
        /// - The returned block's argument 1 must be a valid pointer.
        /// - The returned block's argument 2 must be a valid pointer or null.
        #[unsafe(method(compilationCompletionHandler))]
        #[unsafe(method_family = none)]
        pub unsafe fn compilationCompletionHandler(&self) -> MPSGraphCompilationCompletionHandler;

        #[cfg(all(feature = "MPSGraphExecutable", feature = "block2"))]
        /// Setter for [`compilationCompletionHandler`][Self::compilationCompletionHandler].
        ///
        /// # Safety
        ///
        /// `compilation_completion_handler` must be a valid pointer.
        #[unsafe(method(setCompilationCompletionHandler:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setCompilationCompletionHandler(
            &self,
            compilation_completion_handler: MPSGraphCompilationCompletionHandler,
        );

        #[cfg(feature = "dispatch2")]
        /// The dispatch queue used for the compilation.
        ///
        /// Default value is nil.
        ///
        /// NOTE(review): the documented default is nil, yet the return type
        /// here is the non-optional `Retained<DispatchQueue>`; confirm the
        /// property's nullability against the framework header before relying
        /// on the getter when no queue has been set.
        #[unsafe(method(dispatchQueue))]
        #[unsafe(method_family = none)]
        pub unsafe fn dispatchQueue(&self) -> Retained<DispatchQueue>;

        #[cfg(feature = "dispatch2")]
        /// Setter for [`dispatchQueue`][Self::dispatchQueue].
        ///
        /// # Safety
        ///
        /// `dispatch_queue` possibly has additional threading requirements.
        #[unsafe(method(setDispatchQueue:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setDispatchQueue(&self, dispatch_queue: &DispatchQueue);

        /// The optimization profile for the graph optimization.
        ///
        /// Default is MPSGraphOptimizationProfilePerformance.
        #[deprecated]
        #[unsafe(method(optimizationProfile))]
        #[unsafe(method_family = none)]
        pub unsafe fn optimizationProfile(&self) -> MPSGraphOptimizationProfile;

        /// Setter for [`optimizationProfile`][Self::optimizationProfile].
        #[deprecated]
        #[unsafe(method(setOptimizationProfile:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setOptimizationProfile(
            &self,
            optimization_profile: MPSGraphOptimizationProfile,
        );

        #[cfg(feature = "MPSGraphExecutable")]
        /// The dictionary used during runtime to lookup the ``MPSGraphExecutable`` which correspond to the ``symbolName``.
        #[unsafe(method(callables))]
        #[unsafe(method_family = none)]
        pub unsafe fn callables(&self) -> Option<Retained<MPSGraphCallableMap>>;

        #[cfg(feature = "MPSGraphExecutable")]
        /// Setter for [`callables`][Self::callables].
        #[unsafe(method(setCallables:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setCallables(&self, callables: Option<&MPSGraphCallableMap>);

        /// Across the executable allow reduced precision fast math optimizations.
        #[unsafe(method(reducedPrecisionFastMath))]
        #[unsafe(method_family = none)]
        pub unsafe fn reducedPrecisionFastMath(&self) -> MPSGraphReducedPrecisionFastMath;

        /// Setter for [`reducedPrecisionFastMath`][Self::reducedPrecisionFastMath].
        #[unsafe(method(setReducedPrecisionFastMath:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setReducedPrecisionFastMath(
            &self,
            reduced_precision_fast_math: MPSGraphReducedPrecisionFastMath,
        );
    );
}
355
/// Methods declared on superclass `NSObject`.
#[cfg(feature = "MPSGraphCore")]
impl MPSGraphCompilationDescriptor {
    extern_methods!(
        /// Binds the `init` selector: initializes an allocated instance.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Binds the `new` selector: returns a new instance.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
369
370extern_class!(
371    /// A class that consists of all the levers  to synchronize and schedule graph execution.
372    ///
373    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphexecutiondescriptor?language=objc)
374    #[unsafe(super(MPSGraphObject, NSObject))]
375    #[derive(Debug, PartialEq, Eq, Hash)]
376    #[cfg(feature = "MPSGraphCore")]
377    pub struct MPSGraphExecutionDescriptor;
378);
379
// `MPSGraphExecutionDescriptor` is declared as conforming to `NSObjectProtocol`.
#[cfg(feature = "MPSGraphCore")]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSGraphExecutionDescriptor {}
);
384
/// Methods and property accessors declared on `MPSGraphExecutionDescriptor`.
#[cfg(feature = "MPSGraphCore")]
impl MPSGraphExecutionDescriptor {
    extern_methods!(
        #[cfg(all(
            feature = "MPSGraphTensor",
            feature = "MPSGraphTensorData",
            feature = "block2"
        ))]
        /// The handler that graph calls when it schedules the execution.
        ///
        /// Default value is nil.
        ///
        /// # Safety
        ///
        /// - The returned block's argument 1 must be a valid pointer.
        /// - The returned block's argument 2 must be a valid pointer or null.
        #[unsafe(method(scheduledHandler))]
        #[unsafe(method_family = none)]
        pub unsafe fn scheduledHandler(&self) -> MPSGraphScheduledHandler;

        #[cfg(all(
            feature = "MPSGraphTensor",
            feature = "MPSGraphTensorData",
            feature = "block2"
        ))]
        /// Setter for [`scheduledHandler`][Self::scheduledHandler].
        ///
        /// # Safety
        ///
        /// `scheduled_handler` must be a valid pointer.
        #[unsafe(method(setScheduledHandler:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setScheduledHandler(&self, scheduled_handler: MPSGraphScheduledHandler);

        #[cfg(all(
            feature = "MPSGraphTensor",
            feature = "MPSGraphTensorData",
            feature = "block2"
        ))]
        /// The handler that graph calls at the completion of the execution.
        ///
        /// Default value is nil.
        ///
        /// # Safety
        ///
        /// - The returned block's argument 1 must be a valid pointer.
        /// - The returned block's argument 2 must be a valid pointer or null.
        #[unsafe(method(completionHandler))]
        #[unsafe(method_family = none)]
        pub unsafe fn completionHandler(&self) -> MPSGraphCompletionHandler;

        #[cfg(all(
            feature = "MPSGraphTensor",
            feature = "MPSGraphTensorData",
            feature = "block2"
        ))]
        /// Setter for [`completionHandler`][Self::completionHandler].
        ///
        /// # Safety
        ///
        /// `completion_handler` must be a valid pointer.
        #[unsafe(method(setCompletionHandler:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setCompletionHandler(&self, completion_handler: MPSGraphCompletionHandler);

        /// The flag that blocks the execution call until the entire execution is complete.
        ///
        /// Defaults to NO.
        #[unsafe(method(waitUntilCompleted))]
        #[unsafe(method_family = none)]
        pub unsafe fn waitUntilCompleted(&self) -> bool;

        /// Setter for [`waitUntilCompleted`][Self::waitUntilCompleted].
        #[unsafe(method(setWaitUntilCompleted:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setWaitUntilCompleted(&self, wait_until_completed: bool);

        /// The compilation descriptor for the graph.
        ///
        /// Default value is nil.
        #[unsafe(method(compilationDescriptor))]
        #[unsafe(method_family = none)]
        pub unsafe fn compilationDescriptor(
            &self,
        ) -> Option<Retained<MPSGraphCompilationDescriptor>>;

        /// Setter for [`compilationDescriptor`][Self::compilationDescriptor].
        ///
        /// This is [copied][objc2_foundation::NSCopying::copy] when set.
        #[unsafe(method(setCompilationDescriptor:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setCompilationDescriptor(
            &self,
            compilation_descriptor: Option<&MPSGraphCompilationDescriptor>,
        );

        /// Executable waits on these shared events before scheduling execution on the HW, this does not include encoding which can still continue.
        ///
        /// - Parameters:
        ///   - event: shared event graph waits on.
        ///   - value: value of shared event graph waits on.
        #[unsafe(method(waitForEvent:value:))]
        #[unsafe(method_family = none)]
        pub unsafe fn waitForEvent_value(
            &self,
            event: &ProtocolObject<dyn MTLSharedEvent>,
            value: u64,
        );

        /// Executable signals these shared events at execution stage and immediately proceeds.
        ///
        /// - Parameters:
        ///   - event: shared event to signal.
        ///   - executionStage: execution stage to signal event at.
        ///   - value: value for shared event to wait on.
        #[unsafe(method(signalEvent:atExecutionEvent:value:))]
        #[unsafe(method_family = none)]
        pub unsafe fn signalEvent_atExecutionEvent_value(
            &self,
            event: &ProtocolObject<dyn MTLSharedEvent>,
            execution_stage: MPSGraphExecutionStage,
            value: u64,
        );
    );
}
510
/// Methods declared on superclass `NSObject`.
#[cfg(feature = "MPSGraphCore")]
impl MPSGraphExecutionDescriptor {
    extern_methods!(
        /// Binds the `init` selector: initializes an allocated instance.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Binds the `new` selector: returns a new instance.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
524
525extern_class!(
526    /// The optimized representation of a compute graph of operations and tensors.
527    ///
528    /// An MPSGraph is a symbolic representation of operations to be utilized to execute compute graphs on a device.
529    ///
530    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph?language=objc)
531    #[unsafe(super(MPSGraphObject, NSObject))]
532    #[derive(Debug, PartialEq, Eq, Hash)]
533    #[cfg(feature = "MPSGraphCore")]
534    pub struct MPSGraph;
535);
536
// `MPSGraph` is declared as conforming to `NSObjectProtocol`.
#[cfg(feature = "MPSGraphCore")]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSGraph {}
);
541
542#[cfg(feature = "MPSGraphCore")]
543impl MPSGraph {
544    extern_methods!(
545        /// Options for the graph.
546        ///
547        /// The default value is `MPSGraphOptionsDefault`.
548        #[unsafe(method(options))]
549        #[unsafe(method_family = none)]
550        pub unsafe fn options(&self) -> MPSGraphOptions;
551
552        /// Setter for [`options`][Self::options].
553        #[unsafe(method(setOptions:))]
554        #[unsafe(method_family = none)]
555        pub unsafe fn setOptions(&self, options: MPSGraphOptions);
556
557        /// Creates a new graph to insert nodes in.
558        #[unsafe(method(new))]
559        #[unsafe(method_family = new)]
560        pub unsafe fn new() -> Retained<Self>;
561
562        /// Initialize an MPSGraph to insert nodes in.
563        #[unsafe(method(init))]
564        #[unsafe(method_family = init)]
565        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
566
567        #[cfg(feature = "MPSGraphTensor")]
568        /// Array of all the placeholder tensors.
569        #[unsafe(method(placeholderTensors))]
570        #[unsafe(method_family = none)]
571        pub unsafe fn placeholderTensors(&self) -> Retained<NSArray<MPSGraphTensor>>;
572
573        #[cfg(all(
574            feature = "MPSGraphDevice",
575            feature = "MPSGraphExecutable",
576            feature = "MPSGraphOperation",
577            feature = "MPSGraphTensor"
578        ))]
579        /// Compiles the graph for the given feeds to returns the target tensor values, ensuring all target operations would be executed.
580        ///
581        /// This call blocks until execution has completed. The compilation descriptor helps specialize the executable returned.
582        ///
583        /// - Parameters:
584        /// - device: MPSGraph device to optimize for.
585        /// - feeds: Feeds dictionary for the placeholder tensors.
586        /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
587        /// - targetOperations: Operations to be completed at the end of the run.
588        /// - compilationDescriptor: compilation descriptor to set different compilation parameters.
589        /// - Returns: A valid MPSGraphExecutable object
590        #[unsafe(method(compileWithDevice:feeds:targetTensors:targetOperations:compilationDescriptor:))]
591        #[unsafe(method_family = none)]
592        pub unsafe fn compileWithDevice_feeds_targetTensors_targetOperations_compilationDescriptor(
593            &self,
594            device: Option<&MPSGraphDevice>,
595            feeds: &MPSGraphTensorShapedTypeDictionary,
596            target_tensors: &NSArray<MPSGraphTensor>,
597            target_operations: Option<&NSArray<MPSGraphOperation>>,
598            compilation_descriptor: Option<&MPSGraphCompilationDescriptor>,
599        ) -> Retained<MPSGraphExecutable>;
600
601        #[cfg(all(
602            feature = "MPSGraphOperation",
603            feature = "MPSGraphTensor",
604            feature = "MPSGraphTensorData"
605        ))]
606        /// Runs the graph for the given feeds and returns the target tensor values, ensuring all target operations also executed.
607        ///
608        /// This call blocks until execution has completed.
609        ///
610        /// - Parameters:
611        /// - feeds: Feeds dictionary for the placeholder tensors.
612        /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
613        /// - targetOperations: Operations to be completed at the end of the run.
614        /// - Returns: A valid MPSGraphTensor : MPSGraphTensorData dictionary with results synchronized to the CPU memory.
615        #[unsafe(method(runWithFeeds:targetTensors:targetOperations:))]
616        #[unsafe(method_family = none)]
617        pub unsafe fn runWithFeeds_targetTensors_targetOperations(
618            &self,
619            feeds: &MPSGraphTensorDataDictionary,
620            target_tensors: &NSArray<MPSGraphTensor>,
621            target_operations: Option<&NSArray<MPSGraphOperation>>,
622        ) -> Retained<MPSGraphTensorDataDictionary>;
623
624        #[cfg(all(
625            feature = "MPSGraphOperation",
626            feature = "MPSGraphTensor",
627            feature = "MPSGraphTensorData"
628        ))]
629        /// Runs the graph for the given feeds and returns the target tensor values, ensuring all target operations also executed.
630        ///
631        /// This call blocks until execution has completed.
632        ///
633        /// - Parameters:
634        /// - commandQueue: CommandQueue passed to exectute the graph on.
635        /// - feeds: Feeds dictionary for the placeholder tensors.
636        /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
637        /// - targetOperations: Operations to be completed at the end of the run.
638        /// - Returns: A valid MPSGraphTensor : MPSGraphTensorData dictionary with results synchronized to the CPU memory.
639        #[unsafe(method(runWithMTLCommandQueue:feeds:targetTensors:targetOperations:))]
640        #[unsafe(method_family = none)]
641        pub unsafe fn runWithMTLCommandQueue_feeds_targetTensors_targetOperations(
642            &self,
643            command_queue: &ProtocolObject<dyn MTLCommandQueue>,
644            feeds: &MPSGraphTensorDataDictionary,
645            target_tensors: &NSArray<MPSGraphTensor>,
646            target_operations: Option<&NSArray<MPSGraphOperation>>,
647        ) -> Retained<MPSGraphTensorDataDictionary>;
648
649        #[cfg(all(
650            feature = "MPSGraphOperation",
651            feature = "MPSGraphTensor",
652            feature = "MPSGraphTensorData"
653        ))]
654        /// Runs the graph for the given feeds and returns the target tensor values in the results dictionary provided by the user.
655        ///
656        /// It also ensures all target operations also executed. This call blocks until execution has completed.
657        ///
658        /// - Parameters:
659        /// - commandQueue: CommandQueue passed to exectute the graph on.
660        /// - feeds: Feeds dictionary for the placeholder tensors.
661        /// - targetOperations: Operations to be completed at the end of the run.
662        /// - resultsDictionary: MPSGraphTensors dictionary passed by user, these will be filled with graph output data.
663        #[unsafe(method(runWithMTLCommandQueue:feeds:targetOperations:resultsDictionary:))]
664        #[unsafe(method_family = none)]
665        pub unsafe fn runWithMTLCommandQueue_feeds_targetOperations_resultsDictionary(
666            &self,
667            command_queue: &ProtocolObject<dyn MTLCommandQueue>,
668            feeds: &MPSGraphTensorDataDictionary,
669            target_operations: Option<&NSArray<MPSGraphOperation>>,
670            results_dictionary: &MPSGraphTensorDataDictionary,
671        );
672
673        #[cfg(all(
674            feature = "MPSGraphOperation",
675            feature = "MPSGraphTensor",
676            feature = "MPSGraphTensorData"
677        ))]
678        /// Runs the graph for the given feeds and returns the target tensor values, ensuring all target operations also executed.
679        ///
680        /// This call is asynchronous and will return immediately if a completionHandler is set.
681        ///
682        /// - Parameters:
683        /// - feeds: Feeds dictionary for the placeholder tensors.
684        /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
685        /// - targetOperations: Operations to be completed at the end of the run.
686        /// - executionDescriptor: ExecutionDescriptor to be passed in and used.
687        /// - Returns: A valid MPSGraphTensor : MPSGraphTensorData dictionary with results synchronized to the CPU memory.
688        #[unsafe(method(runAsyncWithFeeds:targetTensors:targetOperations:executionDescriptor:))]
689        #[unsafe(method_family = none)]
690        pub unsafe fn runAsyncWithFeeds_targetTensors_targetOperations_executionDescriptor(
691            &self,
692            feeds: &MPSGraphTensorDataDictionary,
693            target_tensors: &NSArray<MPSGraphTensor>,
694            target_operations: Option<&NSArray<MPSGraphOperation>>,
695            execution_descriptor: Option<&MPSGraphExecutionDescriptor>,
696        ) -> Retained<MPSGraphTensorDataDictionary>;
697
698        #[cfg(all(
699            feature = "MPSGraphOperation",
700            feature = "MPSGraphTensor",
701            feature = "MPSGraphTensorData"
702        ))]
703        /// Runs the graph for the given feeds and returns the target tensor values, ensuring all target operations are also executed.
704        ///
705        /// This call is asynchronous and will return immediately if a completionHandler is set.
706        ///
707        /// - Parameters:
708        /// - commandQueue: Command queue on which to execute the graph.
709        /// - feeds: Feeds dictionary for the placeholder tensors.
710        /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
711        /// - targetOperations: Operations to be completed at the end of the run; may be `None`.
712        /// - executionDescriptor: Execution descriptor to be passed in and used; may be `None`.
713        /// - Returns: A valid MPSGraphTensor : MPSGraphTensorData dictionary with results synchronized to the CPU memory if MPSGraphOptionsSynchronizeResults is set.
714        #[unsafe(method(runAsyncWithMTLCommandQueue:feeds:targetTensors:targetOperations:executionDescriptor:))]
715        #[unsafe(method_family = none)]
716        pub unsafe fn runAsyncWithMTLCommandQueue_feeds_targetTensors_targetOperations_executionDescriptor(
717            &self,
718            command_queue: &ProtocolObject<dyn MTLCommandQueue>,
719            feeds: &MPSGraphTensorDataDictionary,
720            target_tensors: &NSArray<MPSGraphTensor>,
721            target_operations: Option<&NSArray<MPSGraphOperation>>,
722            execution_descriptor: Option<&MPSGraphExecutionDescriptor>,
723        ) -> Retained<MPSGraphTensorDataDictionary>;
724
725        #[cfg(all(
726            feature = "MPSGraphOperation",
727            feature = "MPSGraphTensor",
728            feature = "MPSGraphTensorData"
729        ))]
730        /// Encodes the graph for the given feeds to return the target tensor values in the results dictionary provided by the user.
731        ///
732        /// It ensures all target operations are also executed. This call is asynchronous and will return immediately if a completionHandler is set.
733        ///
734        /// - Parameters:
735        /// - commandQueue: Command queue on which to execute the graph.
736        /// - feeds: Feeds dictionary for the placeholder tensors.
737        /// - targetOperations: Operations to be completed at the end of the run; may be `None`.
738        /// - resultsDictionary: MPSGraphTensors dictionary passed by the user; it will be filled with graph output data.
739        /// - executionDescriptor: Execution descriptor to be passed in and used; may be `None`.
740        #[unsafe(method(runAsyncWithMTLCommandQueue:feeds:targetOperations:resultsDictionary:executionDescriptor:))]
741        #[unsafe(method_family = none)]
742        pub unsafe fn runAsyncWithMTLCommandQueue_feeds_targetOperations_resultsDictionary_executionDescriptor(
743            &self,
744            command_queue: &ProtocolObject<dyn MTLCommandQueue>,
745            feeds: &MPSGraphTensorDataDictionary,
746            target_operations: Option<&NSArray<MPSGraphOperation>>,
747            results_dictionary: &MPSGraphTensorDataDictionary,
748            execution_descriptor: Option<&MPSGraphExecutionDescriptor>,
749        );
750
751        #[cfg(all(
752            feature = "MPSGraphOperation",
753            feature = "MPSGraphTensor",
754            feature = "MPSGraphTensorData",
755            feature = "objc2-metal-performance-shaders"
756        ))]
757        /// Encodes the graph for the given feeds to return the target tensor values, ensuring all target operations are also executed.
758        ///
759        /// This call is asynchronous and will return immediately if a completionHandler is set.
760        ///
761        /// - Parameters:
762        /// - commandBuffer: Command buffer on which to execute the graph. It is an MPSCommandBuffer and commitAndContinue might be called, so do not rely on the underlying MTLCommandBuffer remaining uncommitted.
763        /// - feeds: Feeds dictionary for the placeholder tensors.
764        /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
765        /// - targetOperations: Operations to be completed at the end of the run; may be `None`.
766        /// - executionDescriptor: Execution descriptor to be passed in and used; may be `None`.
767        /// - Returns: A valid MPSGraphTensor : MPSGraphTensorData dictionary with results synchronized to the CPU memory if MPSGraphOptionsSynchronizeResults is set.
768        #[unsafe(method(encodeToCommandBuffer:feeds:targetTensors:targetOperations:executionDescriptor:))]
769        #[unsafe(method_family = none)]
770        pub unsafe fn encodeToCommandBuffer_feeds_targetTensors_targetOperations_executionDescriptor(
771            &self,
772            command_buffer: &MPSCommandBuffer,
773            feeds: &MPSGraphTensorDataDictionary,
774            target_tensors: &NSArray<MPSGraphTensor>,
775            target_operations: Option<&NSArray<MPSGraphOperation>>,
776            execution_descriptor: Option<&MPSGraphExecutionDescriptor>,
777        ) -> Retained<MPSGraphTensorDataDictionary>;
778
779        #[cfg(all(
780            feature = "MPSGraphOperation",
781            feature = "MPSGraphTensor",
782            feature = "MPSGraphTensorData",
783            feature = "objc2-metal-performance-shaders"
784        ))]
785        /// Encodes the graph for the given feeds to return the target tensor values in the results dictionary provided by the user.
786        ///
787        /// It ensures all target operations are also executed. This call is asynchronous and will return immediately if a completionHandler is set.
788        ///
789        /// - Parameters:
790        /// - commandBuffer: Command buffer on which to execute the graph. commitAndContinue might be called, so do not rely on the underlying MTLCommandBuffer remaining uncommitted.
791        /// - feeds: Feeds dictionary for the placeholder tensors.
792        /// - targetOperations: Operations to be completed at the end of the run; may be `None`.
793        /// - resultsDictionary: MPSGraphTensors dictionary passed by the user; it will be filled with graph output data.
794        /// - executionDescriptor: Execution descriptor to be passed in and used; may be `None`.
795        #[unsafe(method(encodeToCommandBuffer:feeds:targetOperations:resultsDictionary:executionDescriptor:))]
796        #[unsafe(method_family = none)]
797        pub unsafe fn encodeToCommandBuffer_feeds_targetOperations_resultsDictionary_executionDescriptor(
798            &self,
799            command_buffer: &MPSCommandBuffer,
800            feeds: &MPSGraphTensorDataDictionary,
801            target_operations: Option<&NSArray<MPSGraphOperation>>,
802            results_dictionary: &MPSGraphTensorDataDictionary,
803            execution_descriptor: Option<&MPSGraphExecutionDescriptor>,
804        );
805    );
806}