1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
use objc2::{
extern_class, extern_conformance, extern_methods,
rc::{Allocated, Retained},
runtime::NSObject,
};
use objc2_foundation::{CopyingHelper, NSCopying, NSObjectProtocol};
use super::{MTLTensorDataType, MTLTensorExtents, MTLTensorUsage};
use crate::{MTLCPUCacheMode, MTLHazardTrackingMode, MTLResourceOptions, MTLStorageMode};
extern_class!(
    /// Describes how new tensor instances should be configured when they are created.
    ///
    /// This is an Objective-C class whose superclass is `NSObject`; create one with
    /// [`MTLTensorDescriptor::new`] and configure it through its property setters.
    ///
    /// [Apple's documentation](https://developer.apple.com/documentation/metal/mtltensordescriptor?language=objc)
    #[unsafe(super(NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct MTLTensorDescriptor;
);
// The descriptor conforms to the base `NSObject` protocol.
extern_conformance!(
    unsafe impl NSObjectProtocol for MTLTensorDescriptor {}
);

// The descriptor can be duplicated through `-copy` (`NSCopying`).
extern_conformance!(
    unsafe impl NSCopying for MTLTensorDescriptor {}
);

// SAFETY: `-copy` on a tensor descriptor is expected to hand back another
// `MTLTensorDescriptor`, so `Self` is declared as the copy's result type.
// NOTE(review): re-confirm against the Metal headers if the class hierarchy changes.
unsafe impl CopyingHelper for MTLTensorDescriptor {
    type Result = Self;
}
impl MTLTensorDescriptor {
    extern_methods!(
        /// An array of sizes, in elements, one for each dimension of the tensors you create with
        /// this descriptor.
        ///
        /// The default value of this property is a rank-one extents with size one.
        #[unsafe(method(dimensions))]
        #[unsafe(method_family = none)]
        pub fn dimensions(&self) -> Retained<MTLTensorExtents>;

        /// Setter for [`dimensions`][Self::dimensions].
        #[unsafe(method(setDimensions:))]
        #[unsafe(method_family = none)]
        pub fn set_dimensions(&self, dimensions: &MTLTensorExtents);

        /// An array of strides, in elements, one for each dimension in the tensors you create
        /// with this descriptor, if applicable.
        ///
        /// This property only applies to tensors you create from a buffer; otherwise it is
        /// `None`. You are responsible for ensuring that `strides` meets the following
        /// requirements:
        /// - Elements of `strides` are in monotonically non-decreasing order.
        /// - The first element of `strides` is one.
        /// - For any `i` larger than zero, `strides[i]` is greater than or equal to
        ///   `strides[i-1] * dimensions[i-1]`.
        /// - If `usage` contains `MTLTensorUsage::MACHINE_LEARNING`, the second element of
        ///   `strides` is aligned to 64 bytes, and for any `i` larger than one, `strides[i]`
        ///   is equal to `strides[i-1] * dimensions[i-1]`.
        #[unsafe(method(strides))]
        #[unsafe(method_family = none)]
        pub fn strides(&self) -> Option<Retained<MTLTensorExtents>>;

        /// Setter for [`strides`][Self::strides].
        ///
        /// Pass `None` to clear the strides.
        #[unsafe(method(setStrides:))]
        #[unsafe(method_family = none)]
        pub fn set_strides(&self, strides: Option<&MTLTensorExtents>);

        /// A data format for the tensors you create with this descriptor.
        ///
        /// The default value of this property is `MTLTensorDataType::Float32`.
        #[unsafe(method(dataType))]
        #[unsafe(method_family = none)]
        pub fn data_type(&self) -> MTLTensorDataType;

        /// Setter for [`data_type`][Self::data_type].
        #[unsafe(method(setDataType:))]
        #[unsafe(method_family = none)]
        pub fn set_data_type(&self, data_type: MTLTensorDataType);

        /// A set of contexts in which you can use tensors you create with this descriptor.
        ///
        /// The default value for this property is the bitwise OR
        /// `MTLTensorUsage::RENDER | MTLTensorUsage::COMPUTE`.
        #[unsafe(method(usage))]
        #[unsafe(method_family = none)]
        pub fn usage(&self) -> MTLTensorUsage;

        /// Setter for [`usage`][Self::usage].
        #[unsafe(method(setUsage:))]
        #[unsafe(method_family = none)]
        pub fn set_usage(&self, usage: MTLTensorUsage);

        /// A packed set of the [`storage_mode`][Self::storage_mode],
        /// [`cpu_cache_mode`][Self::cpu_cache_mode] and
        /// [`hazard_tracking_mode`][Self::hazard_tracking_mode] properties.
        #[unsafe(method(resourceOptions))]
        #[unsafe(method_family = none)]
        pub fn resource_options(&self) -> MTLResourceOptions;

        /// Setter for [`resource_options`][Self::resource_options].
        #[unsafe(method(setResourceOptions:))]
        #[unsafe(method_family = none)]
        pub fn set_resource_options(&self, resource_options: MTLResourceOptions);

        /// A value that configures the cache mode of CPU mapping of tensors you create with this
        /// descriptor.
        ///
        /// The default value of this property is `MTLCPUCacheMode::DefaultCache`.
        #[unsafe(method(cpuCacheMode))]
        #[unsafe(method_family = none)]
        pub fn cpu_cache_mode(&self) -> MTLCPUCacheMode;

        /// Setter for [`cpu_cache_mode`][Self::cpu_cache_mode].
        #[unsafe(method(setCpuCacheMode:))]
        #[unsafe(method_family = none)]
        pub fn set_cpu_cache_mode(&self, cpu_cache_mode: MTLCPUCacheMode);

        /// A value that configures the memory location and access permissions of tensors you
        /// create with this descriptor.
        ///
        /// The default value of this property is `MTLStorageMode::Shared`.
        #[unsafe(method(storageMode))]
        #[unsafe(method_family = none)]
        pub fn storage_mode(&self) -> MTLStorageMode;

        /// Setter for [`storage_mode`][Self::storage_mode].
        #[unsafe(method(setStorageMode:))]
        #[unsafe(method_family = none)]
        pub fn set_storage_mode(&self, storage_mode: MTLStorageMode);

        /// A value that configures the hazard tracking of tensors you create with this
        /// descriptor.
        ///
        /// The default value of this property is `MTLHazardTrackingMode::Default`.
        #[unsafe(method(hazardTrackingMode))]
        #[unsafe(method_family = none)]
        pub fn hazard_tracking_mode(&self) -> MTLHazardTrackingMode;

        /// Setter for [`hazard_tracking_mode`][Self::hazard_tracking_mode].
        #[unsafe(method(setHazardTrackingMode:))]
        #[unsafe(method_family = none)]
        pub fn set_hazard_tracking_mode(&self, hazard_tracking_mode: MTLHazardTrackingMode);
    );
}
/// Constructors inherited from the superclass `NSObject`.
impl MTLTensorDescriptor {
    extern_methods!(
        /// Allocates and initializes a descriptor in a single call (`+new`).
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub fn new() -> Retained<Self>;

        /// Initializes a descriptor that has already been allocated (`-init`).
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub fn init(this: Allocated<Self>) -> Retained<Self>;
    );
}
// Only compiled for `cargo test`; previously the module was built into every configuration and
// `#[allow(unused)]` hid the resulting warnings (including an unnecessary `unsafe` block).
#[cfg(test)]
mod tests {
    // `super::*` already brings `MTLTensorExtents` and the other imported types into scope.
    use super::*;

    /// Builds an `MTLTensorExtents` of rank `vals.len()` from the values in `vals`.
    ///
    /// # Panics
    ///
    /// Panics if `new_with_rank_values` fails to produce an extents object.
    fn make_extents(vals: &[isize]) -> Retained<MTLTensorExtents> {
        MTLTensorExtents::new_with_rank_values(vals.len(), Some(vals)).expect("init extents")
    }

    /// Sets every property on a fresh descriptor and checks that the getters return it.
    ///
    /// NOTE(review): this talks to the real Objective-C class, so it presumably needs an OS
    /// version whose Metal framework provides `MTLTensorDescriptor`.
    #[test]
    fn tensor_descriptor_round_trip() {
        // `new` is declared as a safe method, so no `unsafe` block is required.
        let desc = MTLTensorDescriptor::new();

        // dimensions
        let dims_in = make_extents(&[2, 3, 4]);
        desc.set_dimensions(&dims_in);
        let dims_out = desc.dimensions();
        assert_eq!(dims_out.rank(), 3);
        assert_eq!(dims_out.extent_at_dimension_index(0), 2);
        assert_eq!(dims_out.extent_at_dimension_index(1), 3);
        assert_eq!(dims_out.extent_at_dimension_index(2), 4);

        // strides
        let strides_in = make_extents(&[1, 2, 6]);
        desc.set_strides(Some(&strides_in));
        let strides_out = desc.strides().expect("strides set");
        assert_eq!(strides_out.rank(), 3);
        assert_eq!(strides_out.extent_at_dimension_index(0), 1);
        assert_eq!(strides_out.extent_at_dimension_index(1), 2);
        assert_eq!(strides_out.extent_at_dimension_index(2), 6);

        // data type
        desc.set_data_type(MTLTensorDataType::Float16);
        assert_eq!(desc.data_type(), MTLTensorDataType::Float16);

        // usage
        let usage = MTLTensorUsage::COMPUTE | MTLTensorUsage::RENDER;
        desc.set_usage(usage);
        let usage_out = desc.usage();
        assert!(usage_out.contains(MTLTensorUsage::COMPUTE));
        assert!(usage_out.contains(MTLTensorUsage::RENDER));

        // cpu cache mode
        desc.set_cpu_cache_mode(MTLCPUCacheMode::WriteCombined);
        assert_eq!(desc.cpu_cache_mode(), MTLCPUCacheMode::WriteCombined);

        // storage mode
        desc.set_storage_mode(MTLStorageMode::Private);
        assert_eq!(desc.storage_mode(), MTLStorageMode::Private);

        // hazard tracking
        desc.set_hazard_tracking_mode(MTLHazardTrackingMode::Untracked);
        assert_eq!(desc.hazard_tracking_mode(), MTLHazardTrackingMode::Untracked);

        // The packed resource options should reflect the three modes set above.
        let ro = desc.resource_options();
        let expected = MTLResourceOptions::CPU_CACHE_MODE_WRITE_COMBINED
            | MTLResourceOptions::STORAGE_MODE_PRIVATE
            | MTLResourceOptions::HAZARD_TRACKING_MODE_UNTRACKED;
        assert!(ro.contains(expected));
    }
}