objc2_metal_performance_shaders_graph/generated/MPSGraph.rs
1//! This file has been automatically generated by `objc2`'s `header-translator`.
2//! DO NOT EDIT
3use core::ffi::*;
4use core::ptr::NonNull;
5#[cfg(feature = "dispatch2")]
6use dispatch2::*;
7use objc2::__framework_prelude::*;
8use objc2_foundation::*;
9use objc2_metal::*;
10#[cfg(feature = "objc2-metal-performance-shaders")]
11use objc2_metal_performance_shaders::*;
12
13use crate::*;
14
/// Graph-wide options that control how a graph executes and reports.
///
/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphoptions?language=objc)
// NS_ENUM
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct MPSGraphOptions(pub u64);

impl MPSGraphOptions {
    /// No options are set.
    #[doc(alias = "MPSGraphOptionsNone")]
    pub const None: Self = Self(0);

    /// At the end of execution, results are synchronized back to the CPU with a blit
    /// encoder when running on a discrete GPU.
    #[doc(alias = "MPSGraphOptionsSynchronizeResults")]
    pub const SynchronizeResults: Self = Self(1);

    /// The framework emits more logging information.
    #[doc(alias = "MPSGraphOptionsVerbose")]
    pub const Verbose: Self = Self(2);

    /// The options the framework uses unless they are overridden.
    ///
    /// Has exactly the same value as [`MPSGraphOptions::SynchronizeResults`].
    #[doc(alias = "MPSGraphOptionsDefault")]
    pub const Default: Self = Self::SynchronizeResults;
}
35}
36
37unsafe impl Encode for MPSGraphOptions {
38 const ENCODING: Encoding = u64::ENCODING;
39}
40
41unsafe impl RefEncode for MPSGraphOptions {
42 const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
43}
44
45/// The optimization levels to trade compilation time for even more runtime performance by running more passes.
46///
47/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphoptimization?language=objc)
48// NS_ENUM
49#[repr(transparent)]
50#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
51pub struct MPSGraphOptimization(pub u64);
52impl MPSGraphOptimization {
53 /// Graph performs core optimizations only.
54 #[doc(alias = "MPSGraphOptimizationLevel0")]
55 pub const Level0: Self = Self(0);
56 /// Graph performs additional Optimizations, like using the placement pass to dispatch across different HW blocks like the NeuralEngine and CPU along with the GPU.
57 #[doc(alias = "MPSGraphOptimizationLevel1")]
58 pub const Level1: Self = Self(1);
59}
60
61unsafe impl Encode for MPSGraphOptimization {
62 const ENCODING: Encoding = u64::ENCODING;
63}
64
65unsafe impl RefEncode for MPSGraphOptimization {
66 const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
67}
68
69/// The optimization profile used as a heuristic as the graph compiler optimizes the network.
70///
71/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphoptimizationprofile?language=objc)
72// NS_ENUM
73#[repr(transparent)]
74#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
75pub struct MPSGraphOptimizationProfile(pub u64);
76impl MPSGraphOptimizationProfile {
77 /// Default, graph optimized for performance.
78 #[doc(alias = "MPSGraphOptimizationProfilePerformance")]
79 pub const Performance: Self = Self(0);
80 /// Graph optimized for power efficiency.
81 #[doc(alias = "MPSGraphOptimizationProfilePowerEfficiency")]
82 pub const PowerEfficiency: Self = Self(1);
83}
84
85unsafe impl Encode for MPSGraphOptimizationProfile {
86 const ENCODING: Encoding = u64::ENCODING;
87}
88
89unsafe impl RefEncode for MPSGraphOptimizationProfile {
90 const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
91}
92
93/// Execution events that can be used with shared events.
94///
95/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphexecutionstage?language=objc)
96// NS_ENUM
97#[repr(transparent)]
98#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
99pub struct MPSGraphExecutionStage(pub u64);
100impl MPSGraphExecutionStage {
101 /// stage when execution of the graph completes.
102 #[doc(alias = "MPSGraphExecutionStageCompleted")]
103 pub const Completed: Self = Self(0);
104}
105
106unsafe impl Encode for MPSGraphExecutionStage {
107 const ENCODING: Encoding = u64::ENCODING;
108}
109
110unsafe impl RefEncode for MPSGraphExecutionStage {
111 const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
112}
113
114/// MPSGraph could use these reduced precision paths to deliver faster math, but it is not guaranteed.
115///
116/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphreducedprecisionfastmath?language=objc)
117// NS_OPTIONS
118#[repr(transparent)]
119#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
120pub struct MPSGraphReducedPrecisionFastMath(pub NSUInteger);
121bitflags::bitflags! {
122 impl MPSGraphReducedPrecisionFastMath: NSUInteger {
123/// Full precision math with maximum accuracy.
124 #[doc(alias = "MPSGraphReducedPrecisionFastMathNone")]
125 const None = 0;
126/// Execute winograd transform intermediate as FP16.
127 #[doc(alias = "MPSGraphReducedPrecisionFastMathAllowFP16Conv2DWinogradTransformIntermediate")]
128 const AllowFP16Conv2DWinogradTransformIntermediate = 1<<1;
129/// Curated list allowing intermediates for multi-pass GPU kernels to be FP16.
130 #[doc(alias = "MPSGraphReducedPrecisionFastMathAllowFP16Intermediates")]
131 const AllowFP16Intermediates = MPSGraphReducedPrecisionFastMath::AllowFP16Conv2DWinogradTransformIntermediate.0;
132/// Default selection.
133 #[doc(alias = "MPSGraphReducedPrecisionFastMathDefault")]
134 const Default = MPSGraphReducedPrecisionFastMath::None.0;
135 }
136}
137
138unsafe impl Encode for MPSGraphReducedPrecisionFastMath {
139 const ENCODING: Encoding = NSUInteger::ENCODING;
140}
141
142unsafe impl RefEncode for MPSGraphReducedPrecisionFastMath {
143 const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
144}
145
146/// A dictionary of tensors and corresponding tensor data.
147///
148/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphtensordatadictionary?language=objc)
149#[cfg(all(
150 feature = "MPSGraphCore",
151 feature = "MPSGraphTensor",
152 feature = "MPSGraphTensorData"
153))]
154pub type MPSGraphTensorDataDictionary = NSDictionary<MPSGraphTensor, MPSGraphTensorData>;
155
156/// A dictionary of tensors and corresponding shapes for them.
157///
158/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphtensorshapedtypedictionary?language=objc)
159#[cfg(all(feature = "MPSGraphCore", feature = "MPSGraphTensor"))]
160pub type MPSGraphTensorShapedTypeDictionary = NSDictionary<MPSGraphTensor, MPSGraphShapedType>;
161
162/// A notification that appears when graph execution finishes.
163///
164/// - Parameters:
165/// - resultsDictionary: If no error, the results dictionary produced by the graph operation.
166/// - error: If an error occurs, more information might be found here.
167///
168/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompletionhandler?language=objc)
169#[cfg(all(
170 feature = "MPSGraphCore",
171 feature = "MPSGraphTensor",
172 feature = "MPSGraphTensorData",
173 feature = "block2"
174))]
175pub type MPSGraphCompletionHandler =
176 *mut block2::DynBlock<dyn Fn(NonNull<MPSGraphTensorDataDictionary>, *mut NSError)>;
177
178/// A notification that appears when graph execution schedules.
179///
180/// - Parameters:
181/// - resultsDictionary: If no error, the results dictionary produced by the graph operation. If Graph has not yet allocated, the results will be `NSNull`.
182/// - error: If an error occurs, more information might be found here.
183///
184/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphscheduledhandler?language=objc)
185#[cfg(all(
186 feature = "MPSGraphCore",
187 feature = "MPSGraphTensor",
188 feature = "MPSGraphTensorData",
189 feature = "block2"
190))]
191pub type MPSGraphScheduledHandler =
192 *mut block2::DynBlock<dyn Fn(NonNull<MPSGraphTensorDataDictionary>, *mut NSError)>;
193
194/// A notification that appears when compilation finishes.
195///
196/// - Parameters:
197/// - executable: If no error, the executable produced by the compilation.
198/// - error: If an error occurs, more information might be found here.
199///
200/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompilationcompletionhandler?language=objc)
201#[cfg(all(
202 feature = "MPSGraphCore",
203 feature = "MPSGraphExecutable",
204 feature = "block2"
205))]
206pub type MPSGraphCompilationCompletionHandler =
207 *mut block2::DynBlock<dyn Fn(NonNull<MPSGraphExecutable>, *mut NSError)>;
208
209/// A dictionary of symbol names and the corresponding executables for them.
210///
211/// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcallablemap?language=objc)
212#[cfg(all(feature = "MPSGraphCore", feature = "MPSGraphExecutable"))]
213pub type MPSGraphCallableMap = NSDictionary<NSString, MPSGraphExecutable>;
214
215extern_class!(
216 /// A class that consists of all the levers for compiling graphs.
217 ///
218 /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompilationdescriptor?language=objc)
219 #[unsafe(super(MPSGraphObject, NSObject))]
220 #[derive(Debug, PartialEq, Eq, Hash)]
221 #[cfg(feature = "MPSGraphCore")]
222 pub struct MPSGraphCompilationDescriptor;
223);
224
225#[cfg(feature = "MPSGraphCore")]
226extern_conformance!(
227 unsafe impl NSCopying for MPSGraphCompilationDescriptor {}
228);
229
230#[cfg(feature = "MPSGraphCore")]
231unsafe impl CopyingHelper for MPSGraphCompilationDescriptor {
232 type Result = Self;
233}
234
235#[cfg(feature = "MPSGraphCore")]
236extern_conformance!(
237 unsafe impl NSObjectProtocol for MPSGraphCompilationDescriptor {}
238);
239
240#[cfg(feature = "MPSGraphCore")]
241impl MPSGraphCompilationDescriptor {
242 extern_methods!(
243 /// Turns off type inference and relies on type inference during runtime.
244 #[unsafe(method(disableTypeInference))]
245 #[unsafe(method_family = none)]
246 pub unsafe fn disableTypeInference(&self);
247
248 /// The optimization level for the graph execution, default is MPSGraphOptimizationLevel1.
249 #[unsafe(method(optimizationLevel))]
250 #[unsafe(method_family = none)]
251 pub unsafe fn optimizationLevel(&self) -> MPSGraphOptimization;
252
253 /// Setter for [`optimizationLevel`][Self::optimizationLevel].
254 #[unsafe(method(setOptimizationLevel:))]
255 #[unsafe(method_family = none)]
256 pub unsafe fn setOptimizationLevel(&self, optimization_level: MPSGraphOptimization);
257
258 /// Flag that makes the compile or specialize call blocking till the entire compilation is complete, defaults to NO.
259 #[unsafe(method(waitForCompilationCompletion))]
260 #[unsafe(method_family = none)]
261 pub unsafe fn waitForCompilationCompletion(&self) -> bool;
262
263 /// Setter for [`waitForCompilationCompletion`][Self::waitForCompilationCompletion].
264 #[unsafe(method(setWaitForCompilationCompletion:))]
265 #[unsafe(method_family = none)]
266 pub unsafe fn setWaitForCompilationCompletion(&self, wait_for_compilation_completion: bool);
267
268 #[cfg(all(feature = "MPSGraphExecutable", feature = "block2"))]
269 /// The handler that the graph calls when the compilation completes.
270 ///
271 /// Default value is nil.
272 ///
273 /// # Safety
274 ///
275 /// - The returned block's argument 1 must be a valid pointer.
276 /// - The returned block's argument 2 must be a valid pointer or null.
277 #[unsafe(method(compilationCompletionHandler))]
278 #[unsafe(method_family = none)]
279 pub unsafe fn compilationCompletionHandler(&self) -> MPSGraphCompilationCompletionHandler;
280
281 #[cfg(all(feature = "MPSGraphExecutable", feature = "block2"))]
282 /// Setter for [`compilationCompletionHandler`][Self::compilationCompletionHandler].
283 ///
284 /// # Safety
285 ///
286 /// `compilation_completion_handler` must be a valid pointer.
287 #[unsafe(method(setCompilationCompletionHandler:))]
288 #[unsafe(method_family = none)]
289 pub unsafe fn setCompilationCompletionHandler(
290 &self,
291 compilation_completion_handler: MPSGraphCompilationCompletionHandler,
292 );
293
294 #[cfg(feature = "dispatch2")]
295 /// The dispatch queue used for the compilation.
296 ///
297 /// Default value is nil.
298 #[unsafe(method(dispatchQueue))]
299 #[unsafe(method_family = none)]
300 pub unsafe fn dispatchQueue(&self) -> Retained<DispatchQueue>;
301
302 #[cfg(feature = "dispatch2")]
303 /// Setter for [`dispatchQueue`][Self::dispatchQueue].
304 ///
305 /// # Safety
306 ///
307 /// `dispatch_queue` possibly has additional threading requirements.
308 #[unsafe(method(setDispatchQueue:))]
309 #[unsafe(method_family = none)]
310 pub unsafe fn setDispatchQueue(&self, dispatch_queue: &DispatchQueue);
311
312 /// The optimization profile for the graph optimization.
313 ///
314 /// Default is MPSGraphOptimizationProfilePerformance.
315 #[deprecated]
316 #[unsafe(method(optimizationProfile))]
317 #[unsafe(method_family = none)]
318 pub unsafe fn optimizationProfile(&self) -> MPSGraphOptimizationProfile;
319
320 /// Setter for [`optimizationProfile`][Self::optimizationProfile].
321 #[deprecated]
322 #[unsafe(method(setOptimizationProfile:))]
323 #[unsafe(method_family = none)]
324 pub unsafe fn setOptimizationProfile(
325 &self,
326 optimization_profile: MPSGraphOptimizationProfile,
327 );
328
329 #[cfg(feature = "MPSGraphExecutable")]
330 /// The dictionary used during runtime to lookup the ``MPSGraphExecutable`` which correspond to the ``symbolName``.
331 #[unsafe(method(callables))]
332 #[unsafe(method_family = none)]
333 pub unsafe fn callables(&self) -> Option<Retained<MPSGraphCallableMap>>;
334
335 #[cfg(feature = "MPSGraphExecutable")]
336 /// Setter for [`callables`][Self::callables].
337 #[unsafe(method(setCallables:))]
338 #[unsafe(method_family = none)]
339 pub unsafe fn setCallables(&self, callables: Option<&MPSGraphCallableMap>);
340
341 /// Across the executable allow reduced precision fast math optimizations.
342 #[unsafe(method(reducedPrecisionFastMath))]
343 #[unsafe(method_family = none)]
344 pub unsafe fn reducedPrecisionFastMath(&self) -> MPSGraphReducedPrecisionFastMath;
345
346 /// Setter for [`reducedPrecisionFastMath`][Self::reducedPrecisionFastMath].
347 #[unsafe(method(setReducedPrecisionFastMath:))]
348 #[unsafe(method_family = none)]
349 pub unsafe fn setReducedPrecisionFastMath(
350 &self,
351 reduced_precision_fast_math: MPSGraphReducedPrecisionFastMath,
352 );
353 );
354}
355
356/// Methods declared on superclass `NSObject`.
357#[cfg(feature = "MPSGraphCore")]
358impl MPSGraphCompilationDescriptor {
359 extern_methods!(
360 #[unsafe(method(init))]
361 #[unsafe(method_family = init)]
362 pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
363
364 #[unsafe(method(new))]
365 #[unsafe(method_family = new)]
366 pub unsafe fn new() -> Retained<Self>;
367 );
368}
369
370extern_class!(
371 /// A class that consists of all the levers to synchronize and schedule graph execution.
372 ///
373 /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphexecutiondescriptor?language=objc)
374 #[unsafe(super(MPSGraphObject, NSObject))]
375 #[derive(Debug, PartialEq, Eq, Hash)]
376 #[cfg(feature = "MPSGraphCore")]
377 pub struct MPSGraphExecutionDescriptor;
378);
379
380#[cfg(feature = "MPSGraphCore")]
381extern_conformance!(
382 unsafe impl NSObjectProtocol for MPSGraphExecutionDescriptor {}
383);
384
385#[cfg(feature = "MPSGraphCore")]
386impl MPSGraphExecutionDescriptor {
387 extern_methods!(
388 #[cfg(all(
389 feature = "MPSGraphTensor",
390 feature = "MPSGraphTensorData",
391 feature = "block2"
392 ))]
393 /// The handler that graph calls when it schedules the execution.
394 ///
395 /// Default value is nil.
396 ///
397 /// # Safety
398 ///
399 /// - The returned block's argument 1 must be a valid pointer.
400 /// - The returned block's argument 2 must be a valid pointer or null.
401 #[unsafe(method(scheduledHandler))]
402 #[unsafe(method_family = none)]
403 pub unsafe fn scheduledHandler(&self) -> MPSGraphScheduledHandler;
404
405 #[cfg(all(
406 feature = "MPSGraphTensor",
407 feature = "MPSGraphTensorData",
408 feature = "block2"
409 ))]
410 /// Setter for [`scheduledHandler`][Self::scheduledHandler].
411 ///
412 /// # Safety
413 ///
414 /// `scheduled_handler` must be a valid pointer.
415 #[unsafe(method(setScheduledHandler:))]
416 #[unsafe(method_family = none)]
417 pub unsafe fn setScheduledHandler(&self, scheduled_handler: MPSGraphScheduledHandler);
418
419 #[cfg(all(
420 feature = "MPSGraphTensor",
421 feature = "MPSGraphTensorData",
422 feature = "block2"
423 ))]
424 /// The handler that graph calls at the completion of the execution.
425 ///
426 /// Default value is nil.
427 ///
428 /// # Safety
429 ///
430 /// - The returned block's argument 1 must be a valid pointer.
431 /// - The returned block's argument 2 must be a valid pointer or null.
432 #[unsafe(method(completionHandler))]
433 #[unsafe(method_family = none)]
434 pub unsafe fn completionHandler(&self) -> MPSGraphCompletionHandler;
435
436 #[cfg(all(
437 feature = "MPSGraphTensor",
438 feature = "MPSGraphTensorData",
439 feature = "block2"
440 ))]
441 /// Setter for [`completionHandler`][Self::completionHandler].
442 ///
443 /// # Safety
444 ///
445 /// `completion_handler` must be a valid pointer.
446 #[unsafe(method(setCompletionHandler:))]
447 #[unsafe(method_family = none)]
448 pub unsafe fn setCompletionHandler(&self, completion_handler: MPSGraphCompletionHandler);
449
450 /// The flag that blocks the execution call until the entire execution is complete.
451 ///
452 /// Defaults to NO.
453 #[unsafe(method(waitUntilCompleted))]
454 #[unsafe(method_family = none)]
455 pub unsafe fn waitUntilCompleted(&self) -> bool;
456
457 /// Setter for [`waitUntilCompleted`][Self::waitUntilCompleted].
458 #[unsafe(method(setWaitUntilCompleted:))]
459 #[unsafe(method_family = none)]
460 pub unsafe fn setWaitUntilCompleted(&self, wait_until_completed: bool);
461
462 /// The compilation descriptor for the graph.
463 ///
464 /// Default value is nil.
465 #[unsafe(method(compilationDescriptor))]
466 #[unsafe(method_family = none)]
467 pub unsafe fn compilationDescriptor(
468 &self,
469 ) -> Option<Retained<MPSGraphCompilationDescriptor>>;
470
471 /// Setter for [`compilationDescriptor`][Self::compilationDescriptor].
472 ///
473 /// This is [copied][objc2_foundation::NSCopying::copy] when set.
474 #[unsafe(method(setCompilationDescriptor:))]
475 #[unsafe(method_family = none)]
476 pub unsafe fn setCompilationDescriptor(
477 &self,
478 compilation_descriptor: Option<&MPSGraphCompilationDescriptor>,
479 );
480
481 /// Executable waits on these shared events before scheduling execution on the HW, this does not include encoding which can still continue.
482 ///
483 /// - Parameters:
484 /// - event: shared event graph waits on.
485 /// - value: value of shared event graph waits on.
486 #[unsafe(method(waitForEvent:value:))]
487 #[unsafe(method_family = none)]
488 pub unsafe fn waitForEvent_value(
489 &self,
490 event: &ProtocolObject<dyn MTLSharedEvent>,
491 value: u64,
492 );
493
494 /// Executable signals these shared events at execution stage and immediately proceeds.
495 ///
496 /// - Parameters:
497 /// - event: shared event to signal.
498 /// - executionStage: execution stage to signal event at.
499 /// - value: value for shared event to wait on.
500 #[unsafe(method(signalEvent:atExecutionEvent:value:))]
501 #[unsafe(method_family = none)]
502 pub unsafe fn signalEvent_atExecutionEvent_value(
503 &self,
504 event: &ProtocolObject<dyn MTLSharedEvent>,
505 execution_stage: MPSGraphExecutionStage,
506 value: u64,
507 );
508 );
509}
510
511/// Methods declared on superclass `NSObject`.
512#[cfg(feature = "MPSGraphCore")]
513impl MPSGraphExecutionDescriptor {
514 extern_methods!(
515 #[unsafe(method(init))]
516 #[unsafe(method_family = init)]
517 pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
518
519 #[unsafe(method(new))]
520 #[unsafe(method_family = new)]
521 pub unsafe fn new() -> Retained<Self>;
522 );
523}
524
525extern_class!(
526 /// The optimized representation of a compute graph of operations and tensors.
527 ///
528 /// An MPSGraph is a symbolic representation of operations to be utilized to execute compute graphs on a device.
529 ///
530 /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph?language=objc)
531 #[unsafe(super(MPSGraphObject, NSObject))]
532 #[derive(Debug, PartialEq, Eq, Hash)]
533 #[cfg(feature = "MPSGraphCore")]
534 pub struct MPSGraph;
535);
536
537#[cfg(feature = "MPSGraphCore")]
538extern_conformance!(
539 unsafe impl NSObjectProtocol for MPSGraph {}
540);
541
542#[cfg(feature = "MPSGraphCore")]
543impl MPSGraph {
544 extern_methods!(
545 /// Options for the graph.
546 ///
547 /// The default value is `MPSGraphOptionsDefault`.
548 #[unsafe(method(options))]
549 #[unsafe(method_family = none)]
550 pub unsafe fn options(&self) -> MPSGraphOptions;
551
552 /// Setter for [`options`][Self::options].
553 #[unsafe(method(setOptions:))]
554 #[unsafe(method_family = none)]
555 pub unsafe fn setOptions(&self, options: MPSGraphOptions);
556
557 /// Creates a new graph to insert nodes in.
558 #[unsafe(method(new))]
559 #[unsafe(method_family = new)]
560 pub unsafe fn new() -> Retained<Self>;
561
562 /// Initialize an MPSGraph to insert nodes in.
563 #[unsafe(method(init))]
564 #[unsafe(method_family = init)]
565 pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
566
567 #[cfg(feature = "MPSGraphTensor")]
568 /// Array of all the placeholder tensors.
569 #[unsafe(method(placeholderTensors))]
570 #[unsafe(method_family = none)]
571 pub unsafe fn placeholderTensors(&self) -> Retained<NSArray<MPSGraphTensor>>;
572
573 #[cfg(all(
574 feature = "MPSGraphDevice",
575 feature = "MPSGraphExecutable",
576 feature = "MPSGraphOperation",
577 feature = "MPSGraphTensor"
578 ))]
579 /// Compiles the graph for the given feeds to returns the target tensor values, ensuring all target operations would be executed.
580 ///
581 /// This call blocks until execution has completed. The compilation descriptor helps specialize the executable returned.
582 ///
583 /// - Parameters:
584 /// - device: MPSGraph device to optimize for.
585 /// - feeds: Feeds dictionary for the placeholder tensors.
586 /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
587 /// - targetOperations: Operations to be completed at the end of the run.
588 /// - compilationDescriptor: compilation descriptor to set different compilation parameters.
589 /// - Returns: A valid MPSGraphExecutable object
590 #[unsafe(method(compileWithDevice:feeds:targetTensors:targetOperations:compilationDescriptor:))]
591 #[unsafe(method_family = none)]
592 pub unsafe fn compileWithDevice_feeds_targetTensors_targetOperations_compilationDescriptor(
593 &self,
594 device: Option<&MPSGraphDevice>,
595 feeds: &MPSGraphTensorShapedTypeDictionary,
596 target_tensors: &NSArray<MPSGraphTensor>,
597 target_operations: Option<&NSArray<MPSGraphOperation>>,
598 compilation_descriptor: Option<&MPSGraphCompilationDescriptor>,
599 ) -> Retained<MPSGraphExecutable>;
600
601 #[cfg(all(
602 feature = "MPSGraphOperation",
603 feature = "MPSGraphTensor",
604 feature = "MPSGraphTensorData"
605 ))]
606 /// Runs the graph for the given feeds and returns the target tensor values, ensuring all target operations also executed.
607 ///
608 /// This call blocks until execution has completed.
609 ///
610 /// - Parameters:
611 /// - feeds: Feeds dictionary for the placeholder tensors.
612 /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
613 /// - targetOperations: Operations to be completed at the end of the run.
614 /// - Returns: A valid MPSGraphTensor : MPSGraphTensorData dictionary with results synchronized to the CPU memory.
615 #[unsafe(method(runWithFeeds:targetTensors:targetOperations:))]
616 #[unsafe(method_family = none)]
617 pub unsafe fn runWithFeeds_targetTensors_targetOperations(
618 &self,
619 feeds: &MPSGraphTensorDataDictionary,
620 target_tensors: &NSArray<MPSGraphTensor>,
621 target_operations: Option<&NSArray<MPSGraphOperation>>,
622 ) -> Retained<MPSGraphTensorDataDictionary>;
623
624 #[cfg(all(
625 feature = "MPSGraphOperation",
626 feature = "MPSGraphTensor",
627 feature = "MPSGraphTensorData"
628 ))]
629 /// Runs the graph for the given feeds and returns the target tensor values, ensuring all target operations also executed.
630 ///
631 /// This call blocks until execution has completed.
632 ///
633 /// - Parameters:
634 /// - commandQueue: CommandQueue passed to exectute the graph on.
635 /// - feeds: Feeds dictionary for the placeholder tensors.
636 /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
637 /// - targetOperations: Operations to be completed at the end of the run.
638 /// - Returns: A valid MPSGraphTensor : MPSGraphTensorData dictionary with results synchronized to the CPU memory.
639 #[unsafe(method(runWithMTLCommandQueue:feeds:targetTensors:targetOperations:))]
640 #[unsafe(method_family = none)]
641 pub unsafe fn runWithMTLCommandQueue_feeds_targetTensors_targetOperations(
642 &self,
643 command_queue: &ProtocolObject<dyn MTLCommandQueue>,
644 feeds: &MPSGraphTensorDataDictionary,
645 target_tensors: &NSArray<MPSGraphTensor>,
646 target_operations: Option<&NSArray<MPSGraphOperation>>,
647 ) -> Retained<MPSGraphTensorDataDictionary>;
648
649 #[cfg(all(
650 feature = "MPSGraphOperation",
651 feature = "MPSGraphTensor",
652 feature = "MPSGraphTensorData"
653 ))]
654 /// Runs the graph for the given feeds and returns the target tensor values in the results dictionary provided by the user.
655 ///
656 /// It also ensures all target operations also executed. This call blocks until execution has completed.
657 ///
658 /// - Parameters:
659 /// - commandQueue: CommandQueue passed to exectute the graph on.
660 /// - feeds: Feeds dictionary for the placeholder tensors.
661 /// - targetOperations: Operations to be completed at the end of the run.
662 /// - resultsDictionary: MPSGraphTensors dictionary passed by user, these will be filled with graph output data.
663 #[unsafe(method(runWithMTLCommandQueue:feeds:targetOperations:resultsDictionary:))]
664 #[unsafe(method_family = none)]
665 pub unsafe fn runWithMTLCommandQueue_feeds_targetOperations_resultsDictionary(
666 &self,
667 command_queue: &ProtocolObject<dyn MTLCommandQueue>,
668 feeds: &MPSGraphTensorDataDictionary,
669 target_operations: Option<&NSArray<MPSGraphOperation>>,
670 results_dictionary: &MPSGraphTensorDataDictionary,
671 );
672
673 #[cfg(all(
674 feature = "MPSGraphOperation",
675 feature = "MPSGraphTensor",
676 feature = "MPSGraphTensorData"
677 ))]
678 /// Runs the graph for the given feeds and returns the target tensor values, ensuring all target operations also executed.
679 ///
680 /// This call is asynchronous and will return immediately if a completionHandler is set.
681 ///
682 /// - Parameters:
683 /// - feeds: Feeds dictionary for the placeholder tensors.
684 /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
685 /// - targetOperations: Operations to be completed at the end of the run.
686 /// - executionDescriptor: ExecutionDescriptor to be passed in and used.
687 /// - Returns: A valid MPSGraphTensor : MPSGraphTensorData dictionary with results synchronized to the CPU memory.
688 #[unsafe(method(runAsyncWithFeeds:targetTensors:targetOperations:executionDescriptor:))]
689 #[unsafe(method_family = none)]
690 pub unsafe fn runAsyncWithFeeds_targetTensors_targetOperations_executionDescriptor(
691 &self,
692 feeds: &MPSGraphTensorDataDictionary,
693 target_tensors: &NSArray<MPSGraphTensor>,
694 target_operations: Option<&NSArray<MPSGraphOperation>>,
695 execution_descriptor: Option<&MPSGraphExecutionDescriptor>,
696 ) -> Retained<MPSGraphTensorDataDictionary>;
697
698 #[cfg(all(
699 feature = "MPSGraphOperation",
700 feature = "MPSGraphTensor",
701 feature = "MPSGraphTensorData"
702 ))]
703 /// Runs the graph for the given feeds and returns the target tensor values, ensuring all target operations also executed.
704 ///
705 /// This call is asynchronous and will return immediately if a completionHandler is set.
706 ///
707 /// - Parameters:
708 /// - commandQueue: CommandQueue passed to exectute the graph on.
709 /// - feeds: Feeds dictionary for the placeholder tensors.
710 /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
711 /// - targetOperations: Operations to be completed at the end of the run.
712 /// - executionDescriptor: ExecutionDescriptor to be passed in and used.
713 /// - Returns: A valid MPSGraphTensor : MPSGraphTensorData dictionary with results synchronized to the CPU memory if MPSGraphOptionsSynchronizeResults set.
714 #[unsafe(method(runAsyncWithMTLCommandQueue:feeds:targetTensors:targetOperations:executionDescriptor:))]
715 #[unsafe(method_family = none)]
716 pub unsafe fn runAsyncWithMTLCommandQueue_feeds_targetTensors_targetOperations_executionDescriptor(
717 &self,
718 command_queue: &ProtocolObject<dyn MTLCommandQueue>,
719 feeds: &MPSGraphTensorDataDictionary,
720 target_tensors: &NSArray<MPSGraphTensor>,
721 target_operations: Option<&NSArray<MPSGraphOperation>>,
722 execution_descriptor: Option<&MPSGraphExecutionDescriptor>,
723 ) -> Retained<MPSGraphTensorDataDictionary>;
724
725 #[cfg(all(
726 feature = "MPSGraphOperation",
727 feature = "MPSGraphTensor",
728 feature = "MPSGraphTensorData"
729 ))]
730 /// Encodes the graph for the given feeds to returns the target tensor values in the results dictionary provided by the user.
731 ///
732 /// It ensures all target operations also executed. This call is asynchronous and will return immediately if a completionHandler is set.
733 ///
734 /// - Parameters:
735 /// - commandQueue: CommandQueue passed to exectute the graph on.
736 /// - feeds: Feeds dictionary for the placeholder tensors.
737 /// - targetOperations: Operations to be completed at the end of the run.
738 /// - resultsDictionary: MPSGraphTensors dictionary passed by user, these will be filled with graph output data.
739 /// - executionDescriptor: ExecutionDescriptor to be passed in and used.
740 #[unsafe(method(runAsyncWithMTLCommandQueue:feeds:targetOperations:resultsDictionary:executionDescriptor:))]
741 #[unsafe(method_family = none)]
742 pub unsafe fn runAsyncWithMTLCommandQueue_feeds_targetOperations_resultsDictionary_executionDescriptor(
743 &self,
744 command_queue: &ProtocolObject<dyn MTLCommandQueue>,
745 feeds: &MPSGraphTensorDataDictionary,
746 target_operations: Option<&NSArray<MPSGraphOperation>>,
747 results_dictionary: &MPSGraphTensorDataDictionary,
748 execution_descriptor: Option<&MPSGraphExecutionDescriptor>,
749 );
750
751 #[cfg(all(
752 feature = "MPSGraphOperation",
753 feature = "MPSGraphTensor",
754 feature = "MPSGraphTensorData",
755 feature = "objc2-metal-performance-shaders"
756 ))]
757 /// Encodes the graph for the given feeds to returns the target tensor values, ensuring all target operations also executed.
758 ///
759 /// This call is asynchronous and will return immediately if a completionHandler is set.
760 ///
761 /// - Parameters:
762 /// - commandBuffer: commandBuffer passed to exectute the graph on, it is an MPSCommandBuffer, commitAndContinue might be called, please don't rely on underlying MTLCommandBuffer to remain uncommitted.
763 /// - feeds: Feeds dictionary for the placeholder tensors.
764 /// - targetTensors: Tensors for which the caller wishes MPSGraphTensorData to be returned.
765 /// - targetOperations: Operations to be completed at the end of the run.
766 /// - executionDescriptor: ExecutionDescriptor to be passed in and used.
767 /// - Returns: A valid MPSGraphTensor : MPSGraphTensorData dictionary with results synchronized to the CPU memory if MPSGraphOptionsSynchronizeResults set.
768 #[unsafe(method(encodeToCommandBuffer:feeds:targetTensors:targetOperations:executionDescriptor:))]
769 #[unsafe(method_family = none)]
770 pub unsafe fn encodeToCommandBuffer_feeds_targetTensors_targetOperations_executionDescriptor(
771 &self,
772 command_buffer: &MPSCommandBuffer,
773 feeds: &MPSGraphTensorDataDictionary,
774 target_tensors: &NSArray<MPSGraphTensor>,
775 target_operations: Option<&NSArray<MPSGraphOperation>>,
776 execution_descriptor: Option<&MPSGraphExecutionDescriptor>,
777 ) -> Retained<MPSGraphTensorDataDictionary>;
778
779 #[cfg(all(
780 feature = "MPSGraphOperation",
781 feature = "MPSGraphTensor",
782 feature = "MPSGraphTensorData",
783 feature = "objc2-metal-performance-shaders"
784 ))]
785 /// Encodes the graph for the given feeds to returns the target tensor values in the results dictionary provided by the user.
786 ///
787 /// It ensures all target operations also executed. This call is asynchronous and will return immediately if a completionHandler is set.
788 ///
789 /// - Parameters:
790 /// - commandBuffer: commandBuffer passed to execute the graph on, commitAndContinue might be called, please don't rely on underlying MTLCommandBuffer to remain uncommitted.
791 /// - feeds: Feeds dictionary for the placeholder tensors.
792 /// - targetOperations: Operations to be completed at the end of the run.
793 /// - resultsDictionary: MPSGraphTensors dictionary passed by user, these will be filled with graph output data.
794 /// - executionDescriptor: ExecutionDescriptor to be passed in and used.
795 #[unsafe(method(encodeToCommandBuffer:feeds:targetOperations:resultsDictionary:executionDescriptor:))]
796 #[unsafe(method_family = none)]
797 pub unsafe fn encodeToCommandBuffer_feeds_targetOperations_resultsDictionary_executionDescriptor(
798 &self,
799 command_buffer: &MPSCommandBuffer,
800 feeds: &MPSGraphTensorDataDictionary,
801 target_operations: Option<&NSArray<MPSGraphOperation>>,
802 results_dictionary: &MPSGraphTensorDataDictionary,
803 execution_descriptor: Option<&MPSGraphExecutionDescriptor>,
804 );
805 );
806}