objc2_metal_performance_shaders_graph/generated/
MPSGraphGatherOps.rs

1//! This file has been automatically generated by `objc2`'s `header-translator`.
2//! DO NOT EDIT
3use core::ptr::NonNull;
4use objc2::__framework_prelude::*;
5use objc2_foundation::*;
6
7use crate::*;
8
/// GatherNDOps.
#[cfg(all(feature = "MPSGraph", feature = "MPSGraphCore"))]
impl MPSGraph {
    extern_methods!(
        #[cfg(feature = "MPSGraphTensor")]
        /// Creates a GatherND operation and returns the result tensor.
        ///
        /// Gathers the slices in `updates_tensor` to the result tensor along the indices in
        /// `indices_tensor`.
        /// The gather is defined as
        /// ```md
        /// B = batchDims
        /// U = updates.rank - B
        /// P = res.rank - B
        /// Q = inds.rank - B
        /// K = inds.shape[-1]
        /// index_slice = indices[i_{b0},...,i_{bB},i_{0},...,i_{Q-1}]
        /// res[i_{b0},...,i_{bB},i_{0},...,i_{Q-1}] = updates[i_{b0},...,i_{bB},index_slice[0],...,index_slice[K-1]]
        /// ```
        /// The tensors have the following shape requirements
        /// ```md
        /// U > 0; P > 0; Q > 0
        /// K <= U
        /// P = (U-K) + Q-1
        /// indices.shape[0:Q-1] = res.shape[0:Q-1]
        /// res.shape[Q:P] = updates.shape[K:U]
        /// ```
        ///
        /// - Parameters:
        ///   - `updates_tensor`: Tensor containing the slices that are gathered into the
        ///     result tensor.
        ///   - `indices_tensor`: Tensor containing the indices into `updates_tensor` at which
        ///     slices are read.
        ///   - `batch_dimensions`: The number of batch dimensions.
        ///   - `name`: The name for the operation.
        /// - Returns: A valid `MPSGraphTensor` object.
        ///
        /// # Safety
        ///
        /// This binding is generated as `unsafe`. The caller is responsible for
        /// satisfying the shape requirements listed above. NOTE(review): the complete set of
        /// preconditions that make this call sound is not stated in this file. Confirm it
        /// against Apple's `MPSGraph` documentation.
        #[unsafe(method(gatherNDWithUpdatesTensor:indicesTensor:batchDimensions:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn gatherNDWithUpdatesTensor_indicesTensor_batchDimensions_name(
            &self,
            updates_tensor: &MPSGraphTensor,
            indices_tensor: &MPSGraphTensor,
            batch_dimensions: NSUInteger,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;
    );
}
55
/// GatherOps.
#[cfg(all(feature = "MPSGraph", feature = "MPSGraphCore"))]
impl MPSGraph {
    extern_methods!(
        #[cfg(feature = "MPSGraphTensor")]
        /// Creates a Gather operation and returns the result tensor.
        ///
        /// Gathers the values in `updates_tensor` to the result tensor along the indices in
        /// `indices_tensor`.
        /// The gather is defined as
        /// ```md
        /// B = batchDims
        /// U = updates.rank
        /// P = res.rank
        /// Q = inds.rank
        /// res[p_{0},...p_{axis-1}, i_{B},...,i_{Q}, ...,p_{axis+1},...,p_{U-1}] =
        /// updates[p_{0},...p_{axis-1}, indices[p_{0},...,p_{B-1},i_{B},...,i_{Q}, ...,p_{axis+1},...,p_{U-1}]
        /// ```
        /// The tensors have the following shape requirements
        /// ```md
        /// P = Q-B + U-1
        /// indices.shape[0:B] = updates.shape[0:B] = res.shape[0:B]
        /// res.shape[0:axis] = updates.shape[0:axis]
        /// res.shape[axis:axis+Q-B] = indices.shape[B:]
        /// res.shape[axis+1+Q-B:] = updates.shape[axis+1:]
        /// ```
        ///
        /// - Parameters:
        ///   - `updates_tensor`: Tensor containing the values that are gathered into the
        ///     result tensor.
        ///   - `indices_tensor`: Tensor containing the indices into `updates_tensor` at which
        ///     values are read.
        ///   - `axis`: The dimension on which to perform the gather.
        ///   - `batch_dimensions`: The number of batch dimensions.
        ///   - `name`: The name for the operation.
        /// - Returns: A valid `MPSGraphTensor` object.
        ///
        /// # Safety
        ///
        /// This binding is generated as `unsafe`. The caller is responsible for
        /// satisfying the shape requirements listed above. NOTE(review): the complete set of
        /// preconditions that make this call sound is not stated in this file. Confirm it
        /// against Apple's `MPSGraph` documentation.
        #[unsafe(method(gatherWithUpdatesTensor:indicesTensor:axis:batchDimensions:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn gatherWithUpdatesTensor_indicesTensor_axis_batchDimensions_name(
            &self,
            updates_tensor: &MPSGraphTensor,
            indices_tensor: &MPSGraphTensor,
            axis: NSUInteger,
            batch_dimensions: NSUInteger,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;
    );
}
101
/// MPSGraphGatherAlongAxisOps.
#[cfg(all(feature = "MPSGraph", feature = "MPSGraphCore"))]
impl MPSGraph {
    extern_methods!(
        #[cfg(feature = "MPSGraphTensor")]
        /// Creates a GatherAlongAxis operation and returns the result tensor.
        ///
        /// Gather values from `updates_tensor` along the specified `axis` at indices in
        /// `indices_tensor`.
        /// The shape of `updates_tensor` and `indices_tensor` must match except at `axis`.
        /// The shape of the result tensor is equal to the shape of `indices_tensor`.
        /// If an index is out of bounds of the `updates_tensor` along `axis` a 0 is inserted.
        ///
        /// - Parameters:
        ///   - `axis`: The axis to gather from. Negative values wrap around.
        ///   - `updates_tensor`: The input tensor to gather values from.
        ///   - `indices_tensor`: Int32 or Int64 tensor used to index `updates_tensor`.
        ///   - `name`: The name for the operation.
        /// - Returns: A valid `MPSGraphTensor` object.
        ///
        /// # Safety
        ///
        /// This binding is generated as `unsafe`. The caller is responsible for
        /// satisfying the shape and dtype requirements described above. NOTE(review): the
        /// complete set of preconditions that make this call sound is not stated in this
        /// file. Confirm it against Apple's `MPSGraph` documentation.
        #[unsafe(method(gatherAlongAxis:withUpdatesTensor:indicesTensor:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn gatherAlongAxis_withUpdatesTensor_indicesTensor_name(
            &self,
            axis: NSInteger,
            updates_tensor: &MPSGraphTensor,
            indices_tensor: &MPSGraphTensor,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;

        #[cfg(feature = "MPSGraphTensor")]
        /// Creates a GatherAlongAxis operation and returns the result tensor.
        ///
        /// Gather values from `updates_tensor` along the specified axis (given by
        /// `axis_tensor`) at indices in `indices_tensor`.
        /// The shape of `updates_tensor` and `indices_tensor` must match except at the axis.
        /// The shape of the result tensor is equal to the shape of `indices_tensor`.
        /// If an index is out of bounds of the `updates_tensor` along the axis a 0 is inserted.
        ///
        /// - Parameters:
        ///   - `axis_tensor`: Scalar Int32 tensor. The axis to gather from. Negative values
        ///     wrap around.
        ///   - `updates_tensor`: The input tensor to gather values from.
        ///   - `indices_tensor`: Int32 or Int64 tensor used to index `updates_tensor`.
        ///   - `name`: The name for the operation.
        /// - Returns: A valid `MPSGraphTensor` object.
        ///
        /// # Safety
        ///
        /// This binding is generated as `unsafe`. The caller is responsible for
        /// satisfying the shape and dtype requirements described above. NOTE(review): the
        /// complete set of preconditions that make this call sound is not stated in this
        /// file. Confirm it against Apple's `MPSGraph` documentation.
        #[unsafe(method(gatherAlongAxisTensor:withUpdatesTensor:indicesTensor:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn gatherAlongAxisTensor_withUpdatesTensor_indicesTensor_name(
            &self,
            axis_tensor: &MPSGraphTensor,
            updates_tensor: &MPSGraphTensor,
            indices_tensor: &MPSGraphTensor,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;
    );
}