objc2_metal_performance_shaders_graph/generated/MPSGraphMatrixMultiplicationOps.rs
1//! This file has been automatically generated by `objc2`'s `header-translator`.
2//! DO NOT EDIT
3use core::ffi::*;
4use core::ptr::NonNull;
5use objc2::__framework_prelude::*;
6use objc2_foundation::*;
7#[cfg(feature = "objc2-metal-performance-shaders")]
8use objc2_metal_performance_shaders::*;
9
10use crate::*;
11
/// MPSGraphMatrixMultiplicationOps.
#[cfg(all(feature = "MPSGraph", feature = "MPSGraphCore"))]
impl MPSGraph {
    extern_methods!(
        #[cfg(feature = "MPSGraphTensor")]
        /// Computes the matrix multiplication of two input tensors with support for broadcasting.
        ///
        /// - Parameters:
        /// - primaryTensor: The left-hand side tensor.
        /// - secondaryTensor: The right-hand side tensor.
        /// - name: The name for the operation.
        /// - Returns: A valid tensor containing the product of the input matrices.
        #[unsafe(method(matrixMultiplicationWithPrimaryTensor:secondaryTensor:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn matrixMultiplicationWithPrimaryTensor_secondaryTensor_name(
            &self,
            primary_tensor: &MPSGraphTensor,
            secondary_tensor: &MPSGraphTensor,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;

        #[cfg(all(
            feature = "MPSGraphTensor",
            feature = "objc2-metal-performance-shaders"
        ))]
        /// Computes the Hamming distance of two input tensors with support for broadcasting.
        ///
        /// The Hamming distance is computed between two sets of vectors; the last dimension(s) of
        /// each input tensor are treated as a vector.
        ///
        /// Note: the Objective-C selector itself begins with a capital `H`
        /// (`HammingDistanceWithPrimaryTensor:...`), which is why this method name does too.
        ///
        /// - Parameters:
        /// - primaryTensor: The first input tensor.
        /// - secondaryTensor: The second input tensor.
        /// - resultDataType: The data type of the returned tensor. Must be either
        ///   ``MPSDataTypeUInt32`` or ``MPSDataTypeUInt16``.
        /// - name: The name for the operation.
        /// - Returns: A valid tensor containing the Hamming distance between the input tensors.
        #[unsafe(method(HammingDistanceWithPrimaryTensor:secondaryTensor:resultDataType:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn HammingDistanceWithPrimaryTensor_secondaryTensor_resultDataType_name(
            &self,
            primary_tensor: &MPSGraphTensor,
            secondary_tensor: &MPSGraphTensor,
            result_data_type: MPSDataType,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;

        #[cfg(feature = "MPSGraphTensor")]
        /// Creates a scaled dot product attention (SDPA) operation and returns the result tensor.
        ///
        /// The SDPA operation computes attention as `softmax(scale * QK^T + M)V`.
        /// It takes a queryTensor Q with shape [B, Hq, Nq, F] and a keyTensor K with shape
        /// [B, Hq, Nkv, F]. Q's H dimension may be expanded to satisfy the matmul QK^T.
        /// The shape of maskTensor M must be broadcast-compatible with QK^T so that
        /// (QK^T + M) is defined. The valueTensor V, with shape [B, Hv, Nkv, F], must satisfy
        /// the matmul (QK^T + M)V.
        ///
        /// - Parameters:
        /// - queryTensor: A tensor that represents the query projection.
        /// - keyTensor: A tensor that represents the key projection.
        /// - valueTensor: A tensor that represents the value projection.
        /// - maskTensor: An optional tensor containing a mask that is added to the scaled
        ///   product of the query and key matrices (QK^T). If the mask tensor is nil, QK^T is
        ///   not element-wise masked.
        /// - scale: A scale that is applied to the result of the query and key matrix
        ///   multiplication (QK^T).
        /// - name: The name for the operation.
        /// - Returns: A valid MPSGraphTensor object.
        #[unsafe(method(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn scaledDotProductAttentionWithQueryTensor_keyTensor_valueTensor_maskTensor_scale_name(
            &self,
            query_tensor: &MPSGraphTensor,
            key_tensor: &MPSGraphTensor,
            value_tensor: &MPSGraphTensor,
            mask_tensor: Option<&MPSGraphTensor>,
            scale: c_float,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;

        #[cfg(feature = "MPSGraphTensor")]
        /// Creates a scaled dot product attention (SDPA) operation without a mask and returns
        /// the result tensor.
        ///
        /// This is the unmasked form of the SDPA operation, computing `softmax(scale * QK^T)V`.
        ///
        /// - Parameters:
        /// - queryTensor: A tensor that represents the query projection.
        /// - keyTensor: A tensor that represents the key projection.
        /// - valueTensor: A tensor that represents the value projection.
        /// - scale: A scale that is applied to the result of the query and key matrix
        ///   multiplication (QK^T).
        /// - name: The name for the operation.
        /// - Returns: A valid MPSGraphTensor object.
        #[unsafe(method(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:scale:name:))]
        #[unsafe(method_family = none)]
        pub unsafe fn scaledDotProductAttentionWithQueryTensor_keyTensor_valueTensor_scale_name(
            &self,
            query_tensor: &MPSGraphTensor,
            key_tensor: &MPSGraphTensor,
            value_tensor: &MPSGraphTensor,
            scale: c_float,
            name: Option<&NSString>,
        ) -> Retained<MPSGraphTensor>;
    );
}
109}