objc2_metal_performance_shaders_graph/generated/MPSGraphGatherOps.rs
1//! This file has been automatically generated by `objc2`'s `header-translator`.
2//! DO NOT EDIT
3use core::ptr::NonNull;
4use objc2::__framework_prelude::*;
5use objc2_foundation::*;
6
7use crate::*;
8
/// GatherNDOps.
#[cfg(all(feature = "MPSGraph", feature = "MPSGraphCore"))]
impl MPSGraph {
    extern_methods!(
        #[cfg(feature = "MPSGraphTensor")]
        /// Creates a GatherND operation and returns the result tensor.
        ///
        /// Gathers the slices in updatesTensor to the result tensor along the indices in indicesTensor.
        /// The gather is defined as
        /// ```md
        /// B = batchDims
        /// U = updates.rank - B
        /// P = res.rank - B
        /// Q = inds.rank - B
        /// K = inds.shape[-1]
        /// index_slice = indices[i_{b0},...,i_{bB},i_{0},...,i_{Q-1}]
        /// res[i_{b0},...,i_{bB},i_{0},...,i_{Q-1}] = updates[i_{b0},...,i_{bB},index_slice[0],...,index_slice[K-1]]
        /// ```
        /// The tensors have the following shape requirements
        /// ```md
        /// U > 0; P > 0; Q > 0
        /// K <= U
        /// P = (U-K) + Q-1
        /// indices.shape[0:Q-1] = res.shape[0:Q-1]
        /// res.shape[Q:P] = updates.shape[K:U]
        /// ```
        ///
        /// - Parameters:
        /// - updatesTensor: Tensor containing slices to be inserted into the result tensor.
        /// - indicesTensor: Tensor containing the updates indices to read slices from
        /// - batchDimensions: The number of batch dimensions
        /// - name: The name for the operation.
        /// - Returns: A valid MPSGraphTensor object
        ///
        /// # Safety
        ///
        /// This is a direct binding to the Objective-C selector
        /// `gatherNDWithUpdatesTensor:indicesTensor:batchDimensions:name:`. The shape
        /// requirements listed above must hold. NOTE(review): what MPSGraph does when they
        /// are violated, and whether both tensors must belong to this graph, is not stated
        /// here — confirm against Apple's MPSGraph documentation.
        #[unsafe(method(gatherNDWithUpdatesTensor:indicesTensor:batchDimensions:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn gatherNDWithUpdatesTensor_indicesTensor_batchDimensions_name(
            &self,
            updates_tensor: &MPSGraphTensor,
            indices_tensor: &MPSGraphTensor,
            batch_dimensions: NSUInteger,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;
    );
}
55
/// GatherOps.
#[cfg(all(feature = "MPSGraph", feature = "MPSGraphCore"))]
impl MPSGraph {
    extern_methods!(
        #[cfg(feature = "MPSGraphTensor")]
        /// Creates a Gather operation and returns the result tensor.
        ///
        /// Gathers the values in updatesTensor to the result tensor along the indices in indicesTensor.
        /// The gather is defined as
        /// ```md
        /// B = batchDims
        /// U = updates.rank
        /// P = res.rank
        /// Q = inds.rank
        /// res[p_{0},...p_{axis-1}, i_{B},...,i_{Q}, ...,p_{axis+1},...,p_{U-1}] =
        /// updates[p_{0},...p_{axis-1}, indices[p_{0},...,p_{B-1},i_{B},...,i_{Q}], ...,p_{axis+1},...,p_{U-1}]
        /// ```
        /// The tensors have the following shape requirements
        /// ```md
        /// P = Q-B + U-1
        /// indices.shape[0:B] = updates.shape[0:B] = res.shape[0:B]
        /// res.shape[0:axis] = updates.shape[0:axis]
        /// res.shape[axis:axis+Q-B] = indices.shape[B:]
        /// res.shape[axis+1+Q-B:] = updates.shape[axis+1:]
        /// ```
        ///
        /// - Parameters:
        /// - updatesTensor: Tensor containing slices to be inserted into the result tensor.
        /// - indicesTensor: Tensor containing the updates indices to read slices from
        /// - axis: The dimension on which to perform the gather
        /// - batchDimensions: The number of batch dimensions
        /// - name: The name for the operation.
        /// - Returns: A valid MPSGraphTensor object
        ///
        /// # Safety
        ///
        /// This is a direct binding to the Objective-C selector
        /// `gatherWithUpdatesTensor:indicesTensor:axis:batchDimensions:name:`. The shape
        /// requirements listed above must hold. NOTE(review): what MPSGraph does when they
        /// are violated (e.g. an out-of-range `axis`), and whether both tensors must belong to
        /// this graph, is not stated here — confirm against Apple's MPSGraph documentation.
        #[unsafe(method(gatherWithUpdatesTensor:indicesTensor:axis:batchDimensions:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn gatherWithUpdatesTensor_indicesTensor_axis_batchDimensions_name(
            &self,
            updates_tensor: &MPSGraphTensor,
            indices_tensor: &MPSGraphTensor,
            axis: NSUInteger,
            batch_dimensions: NSUInteger,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;
    );
}
101
/// MPSGraphGatherAlongAxisOps.
#[cfg(all(feature = "MPSGraph", feature = "MPSGraphCore"))]
impl MPSGraph {
    extern_methods!(
        #[cfg(feature = "MPSGraphTensor")]
        /// Creates a GatherAlongAxis operation and returns the result tensor.
        ///
        /// Gather values from `updatesTensor` along the specified `axis` at indices in `indicesTensor`.
        /// The shape of `updatesTensor` and `indicesTensor` must match except at `axis`.
        /// The shape of the result tensor is equal to the shape of `indicesTensor`.
        /// If an index is out of bounds of the `updatesTensor` along `axis`, a 0 is inserted.
        ///
        /// - Parameters:
        /// - axis: The axis to gather from. Negative values wrap around.
        /// - updatesTensor: The input tensor to gather values from.
        /// - indicesTensor: Int32 or Int64 tensor used to index `updatesTensor`.
        /// - name: The name for the operation.
        /// - Returns: A valid MPSGraphTensor object
        ///
        /// # Safety
        ///
        /// This is a direct binding to the Objective-C selector
        /// `gatherAlongAxis:withUpdatesTensor:indicesTensor:name:`. The shape and dtype
        /// requirements listed above must hold. NOTE(review): behavior when they are violated,
        /// and whether the tensors must belong to this graph, is not stated here — confirm
        /// against Apple's MPSGraph documentation.
        #[unsafe(method(gatherAlongAxis:withUpdatesTensor:indicesTensor:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn gatherAlongAxis_withUpdatesTensor_indicesTensor_name(
            &self,
            axis: NSInteger,
            updates_tensor: &MPSGraphTensor,
            indices_tensor: &MPSGraphTensor,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;

        #[cfg(feature = "MPSGraphTensor")]
        /// Creates a GatherAlongAxis operation and returns the result tensor.
        ///
        /// Gather values from `updatesTensor` along the specified `axis` at indices in `indicesTensor`.
        /// The shape of `updatesTensor` and `indicesTensor` must match except at `axis`.
        /// The shape of the result tensor is equal to the shape of `indicesTensor`.
        /// If an index is out of bounds of the `updatesTensor` along `axis`, a 0 is inserted.
        ///
        /// - Parameters:
        /// - axisTensor: Scalar Int32 tensor. The axis to gather from. Negative values wrap around.
        /// - updatesTensor: The input tensor to gather values from.
        /// - indicesTensor: Int32 or Int64 tensor used to index `updatesTensor`.
        /// - name: The name for the operation.
        /// - Returns: A valid MPSGraphTensor object
        ///
        /// # Safety
        ///
        /// This is a direct binding to the Objective-C selector
        /// `gatherAlongAxisTensor:withUpdatesTensor:indicesTensor:name:`. The shape and dtype
        /// requirements listed above must hold. NOTE(review): behavior when they are violated,
        /// and whether the tensors must belong to this graph, is not stated here — confirm
        /// against Apple's MPSGraph documentation.
        #[unsafe(method(gatherAlongAxisTensor:withUpdatesTensor:indicesTensor:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn gatherAlongAxisTensor_withUpdatesTensor_indicesTensor_name(
            &self,
            axis_tensor: &MPSGraphTensor,
            updates_tensor: &MPSGraphTensor,
            indices_tensor: &MPSGraphTensor,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;
    );
}