cidre/mlc/types.rs
1use crate::{blocks, define_opts, mlc, ns};
2
3pub type GraphCompletionHandler =
4 blocks::SyncBlock<fn(Option<&mlc::Tensor>, Option<&ns::Error>, ns::TimeInterval)>;
5
/// Element data type of a tensor (`MLCDataType`).
///
/// The discriminants are not contiguous (2 and 6 are skipped); they are kept
/// exactly as declared so the enum can cross the FFI boundary as an `i32`.
#[doc(alias = "MLCDataType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum DType {
    /// The invalid data type.
    Invalid = 0,

    /// The 32-bit floating-point data type.
    F32 = 1,

    /// The 16-bit floating-point data type.
    F16 = 3,

    /// The Boolean data type.
    Bool = 4,

    /// The 64-bit integer data type.
    I64 = 5,

    /// The 32-bit integer data type.
    I32 = 7,

    /// The 8-bit integer data type.
    I8 = 8,

    /// The 8-bit unsigned integer data type.
    U8 = 9,
}
32
/// Kind of random initializer (`MLCRandomInitializerType`).
#[doc(alias = "MLCRandomInitializerType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum RandomInitializerType {
    /// The invalid random initializer type.
    Invalid = 0,

    /// The uniform random initializer type.
    Uniform = 1,

    /// The Glorot uniform random initializer type.
    GlorotUniform = 2,

    /// The Xavier random initializer type.
    Xavier = 3,
}
47
/// Kind of compute device (`MLCDeviceType`).
#[doc(alias = "MLCDeviceType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum DeviceType {
    /// The CPU device.
    Cpu = 0,

    /// The GPU device.
    Gpu = 1,

    /// The any device type. When selected, the framework will automatically
    /// use the appropriate devices to achieve the best performance.
    Any = 2,

    /// The Apple Neural Engine device.
    ///
    /// When selected, the framework will use the Neural Engine to execute all
    /// layers that can be executed on it. Layers that cannot be executed on
    /// the ANE will run on the CPU or GPU.
    ///
    /// The Neural Engine device must be explicitly selected:
    /// `DeviceType::Any` (`MLCDeviceTypeAny`) will not select it. In
    /// addition, this device can be used with inference graphs only; it
    /// cannot be used with a training graph or an inference graph that shares
    /// layers with a training graph.
    Ane = 3,
}
68
/// Elementwise arithmetic operation (`MLCArithmeticOperation`).
///
/// Discriminants run contiguously from 0 to 29 and are passed across the FFI
/// boundary as an `i32`.
#[doc(alias = "MLCArithmeticOperation")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum ArithmeticOp {
    /// An operation that calculates the elementwise sum of its two inputs.
    Add = 0,

    /// An operation that calculates the elementwise difference of its two inputs.
    Subtract = 1,

    /// An operation that calculates the elementwise product of its two inputs.
    Multiply = 2,

    /// An operation that calculates the elementwise division of its two inputs.
    Divide = 3,

    /// An operation that calculates the elementwise floor of its inputs.
    Floor = 4,

    /// An operation that calculates the elementwise round of its inputs.
    Round = 5,

    /// An operation that calculates the elementwise ceiling of its inputs.
    Ceil = 6,

    /// An operation that calculates the elementwise square root of its inputs.
    Sqrt = 7,

    /// An operation that calculates the elementwise reciprocal of the square root of its inputs.
    RSqrt = 8,

    /// An operation that calculates the elementwise sine of its inputs.
    Sin = 9,

    /// An operation that calculates the elementwise cosine of its inputs.
    Cos = 10,

    /// An operation that calculates the elementwise tangent of its inputs.
    Tan = 11,

    /// An operation that calculates the elementwise inverse sine of its inputs.
    ASin = 12,

    /// An operation that calculates the elementwise inverse cosine of its inputs.
    ACos = 13,

    /// An operation that calculates the elementwise inverse tangent of its inputs.
    ATan = 14,

    /// An operation that calculates the elementwise hyperbolic sine of its inputs.
    SinH = 15,

    /// An operation that calculates the elementwise hyperbolic cosine of its inputs.
    CosH = 16,

    /// An operation that calculates the elementwise hyperbolic tangent of its inputs.
    TanH = 17,

    /// An operation that calculates the elementwise inverse hyperbolic sine of its inputs.
    ASinH = 18,

    /// An operation that calculates the elementwise inverse hyperbolic cosine of its inputs.
    ACosH = 19,

    /// An operation that calculates the elementwise inverse hyperbolic tangent of its inputs.
    ATanH = 20,

    /// An operation that calculates the elementwise first input raised to the power of its
    /// second input.
    Pow = 21,

    /// An operation that calculates the elementwise result of e raised to the power of its input.
    Exp = 22,

    /// An operation that calculates the elementwise result of 2 raised to the power of its input.
    Exp2 = 23,

    /// An operation that calculates the elementwise natural logarithm of its input.
    Log = 24,

    /// An operation that calculates the elementwise base 2 logarithm of its input.
    Log2 = 25,

    /// An operation that calculates the elementwise product of its two inputs.
    /// Returns 0 if `y` in `x * y` is zero, even if `x` is NaN or INF.
    MultiplyNoNaN = 26,

    /// An operation that calculates the elementwise division of its two inputs.
    /// Returns 0 if the denominator is 0.
    DivideNoNaN = 27,

    /// An operation that calculates the elementwise min of two inputs.
    Min = 28,

    /// An operation that calculates the elementwise max of two inputs.
    Max = 29,
}
163
164impl ArithmeticOp {
165 #[inline]
166 pub fn debug_desc(self) -> &'static ns::String {
167 unsafe { MLCArithmeticOperationDebugDescription(self) }
168 }
169}
170
/// Loss function selector (`MLCLossType`), passed over FFI as an `i32`.
#[doc(alias = "MLCLossType")]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum LossType {
    /// Mean absolute error.
    MeanAbsoluteError = 0,
    /// Mean squared error.
    MeanSquaredError = 1,
    /// Softmax cross entropy.
    SoftmaxCrossEntropy = 2,
    /// Sigmoid cross entropy.
    SigmoidCrossEntropy = 3,
    /// Categorical cross entropy.
    CategoricalCrossEntropy = 4,
    /// Hinge loss.
    Hinge = 5,
    /// Huber loss.
    Huber = 6,
    /// Cosine distance.
    CosineDistance = 7,
    /// Log loss.
    Log = 8,
}
203
204impl LossType {
205 #[inline]
206 pub fn debug_desc(self) -> &'static ns::String {
207 unsafe { MLCLossTypeDebugDescription(self) }
208 }
209}
210
/// An activation type that you specify for an activation descriptor.
///
/// In the formulas below `a` and `b` are the parameters named in each
/// variant's pseudo-code. Discriminants run contiguously from 0 to 20 and are
/// passed across the FFI boundary as an `i32`.
#[doc(alias = "MLCActivationType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum ActivationType {
    /// The "none" activation type.
    None = 0,

    /// The ReLU activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x >= 0 ? x : a * x
    /// ```
    ReLU = 1,

    /// The linear activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = a * x + b
    /// ```
    Linear = 2,

    /// The sigmoid activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = 1 / (1 + e⁻ˣ)
    /// ```
    Sigmoid = 3,

    /// The hard sigmoid activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = clamp((x * a) + b, 0, 1)
    /// ```
    HardSigmoid = 4,

    /// The hyperbolic tangent (TanH) activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = a * tanh(b * x)
    /// ```
    TanH = 5,

    /// The absolute activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = fabs(x)
    /// ```
    Absolute = 6,

    /// The parametric soft plus activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = a * log(1 + e^(b * x))
    /// ```
    SoftPlus = 7,

    /// The parametric soft sign activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x / (1 + abs(x))
    /// ```
    SoftSign = 8,

    /// The parametric ELU activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x >= 0 ? x : a * (exp(x) - 1)
    /// ```
    ELU = 9,

    /// The ReLUN activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = min((x >= 0 ? x : a * x), b)
    /// ```
    ReLUN = 10,

    /// The log sigmoid activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = log(1 / (1 + exp(-x)))
    /// ```
    LogSigmoid = 11,

    /// The SELU activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = scale * (max(0, x) + min(0, α * (exp(x) − 1)))
    /// ```
    /// where:
    /// ```pseudo
    /// α = 1.6732632423543772848170429916717
    /// scale = 1.0507009873554804934193349852946
    /// ```
    SELU = 12,

    /// The CELU activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = max(0, x) + min(0, a * (exp(x / a) − 1))
    /// ```
    CELU = 13,

    /// The hard shrink activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x, if x > a or x < −a, else 0
    /// ```
    HardShrink = 14,

    /// The soft shrink activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x - a, if x > a, x + a, if x < −a, else 0
    /// ```
    SoftShrink = 15,

    /// The hyperbolic tangent (TanH) shrink activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x - tanh(x)
    /// ```
    TanHShrink = 16,

    /// The threshold activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x, if x > a, else b
    /// ```
    Threshold = 17,

    /// The GELU activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x * CDF(x)
    /// ```
    GELU = 18,

    /// The hardswish activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = 0, if x <= -3
    /// f(x) = x, if x >= +3
    /// f(x) = x * (x + 3)/6, otherwise
    /// ```
    HardSwish = 19,

    /// The clamp activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = min(max(x, a), b)
    /// ```
    Clamp = 20,
}
379
380impl ActivationType {
381 /// Returns a textual description of the arithmetic operation, suitable for debugging
382 pub fn debug_desc(self) -> &'static ns::String {
383 unsafe { MLCActivationTypeDebugDescription(self) }
384 }
385}
386
/// Kind of convolution (`MLCConvolutionType`).
#[doc(alias = "MLCConvolutionType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum ConvolutionType {
    /// The standard convolution type.
    Standard = 0,

    /// The transposed convolution type.
    Transposed = 1,

    /// The depthwise convolution type.
    Depthwise = 2,
}
399
400impl ConvolutionType {
401 pub fn debug_desc(self) -> &'static ns::String {
402 unsafe { MLCConvolutionTypeDebugDescription(self) }
403 }
404}
405
/// Padding policy (`MLCPaddingPolicy`).
#[doc(alias = "MLCPaddingPolicy")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum PaddingPolicy {
    /// The "same" padding policy.
    Same = 0,

    /// The "valid" padding policy.
    Valid = 1,

    /// The choice to use explicitly specified padding sizes.
    UsePaddingSize = 2,
}
416
417impl PaddingPolicy {
418 #[inline]
419 pub fn debug_desc(self) -> &'static ns::String {
420 unsafe { MLCPaddingPolicyDebugDescription(self) }
421 }
422}
423
/// Kind of padding (`MLCPaddingType`).
#[doc(alias = "MLCPaddingType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum PaddingType {
    /// The zero padding type.
    Zero = 0,

    /// The reflect padding type.
    Reflect = 1,

    /// The symmetric padding type.
    Symmetric = 2,

    /// The constant padding type.
    Constant = 3,
}
436
437impl PaddingType {
438 pub fn debug_desc(self) -> &'static ns::String {
439 unsafe { MLCPaddingTypeDebugDescription(self) }
440 }
441}
442
/// Kind of pooling (`MLCPoolingType`).
///
/// Discriminants start at 1; no variant is declared for 0.
#[doc(alias = "MLCPoolingType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum PoolingType {
    /// The max pooling type.
    Max = 1,

    /// The average pooling type.
    Average = 2,

    /// The L2-norm pooling type.
    L2Norm = 3,
}
453
454impl PoolingType {
455 pub fn debug_desc(self) -> &'static ns::String {
456 unsafe { MLCPoolingTypeDebugDescription(self) }
457 }
458}
459
/// Kind of reduction (`MLCReductionType`).
#[doc(alias = "MLCReductionType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum ReductionType {
    /// No reduction.
    None = 0,

    /// The sum reduction.
    Sum = 1,

    /// The mean reduction.
    Mean = 2,

    /// The max reduction.
    Max = 3,

    /// The min reduction.
    Min = 4,

    /// The argmax reduction.
    ArgMax = 5,

    /// The argmin reduction.
    ArgMin = 6,

    /// The L1-norm reduction.
    L1Norm = 7,

    /// `Any(X) = X_0 || X_1 || ... X_n`
    Any = 8,

    /// `All(X) = X_0 && X_1 && ... X_n`
    All = 9,
}
484
485impl ReductionType {
486 #[inline]
487 pub fn debug_desc(self) -> &'static ns::String {
488 unsafe { MLCReductionTypeDebugDescription(self) }
489 }
490}
491
/// Kind of regularization (`MLCRegularizationType`).
#[doc(alias = "MLCRegularizationType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum RegularizationType {
    /// No regularization.
    None = 0,

    /// The L1 regularization.
    L1 = 1,

    /// The L2 regularization.
    L2 = 2,
}
504
/// Sampling mode (`MLCSampleMode`).
#[doc(alias = "MLCSampleMode")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum SampleMode {
    /// The nearest sample mode.
    Nearest = 0,

    /// The linear sample mode.
    Linear = 1,
}
513
514impl SampleMode {
515 pub fn debug_desc(self) -> &'static ns::String {
516 unsafe { MLCSampleModeDebugDescription(self) }
517 }
518}
519
/// Softmax operation (`MLCSoftmaxOperation`).
#[doc(alias = "MLCSoftmaxOperation")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum SoftmaxOp {
    /// The standard softmax operation.
    Softmax = 0,

    /// The log softmax operation.
    LogSoftmax = 1,
}
528
529impl SoftmaxOp {
530 #[inline]
531 pub fn debug_desc(self) -> &'static ns::String {
532 unsafe { MLCSoftmaxOperationDebugDescription(self) }
533 }
534}
535
/// Result mode of an LSTM layer (`MLCLSTMResultMode`).
///
/// Unlike most enums in this file, this one is `#[repr(usize)]`.
#[doc(alias = "MLCLSTMResultMode")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(usize)]
pub enum LSTMResultMode {
    /// The output result mode. When selected for an LSTM layer, the layer
    /// will produce a single result tensor representing the final output of
    /// the LSTM.
    Output = 0,

    /// The output and states result mode. When selected for an LSTM layer,
    /// the layer will produce three result tensors representing the final
    /// output of the LSTM, the last hidden state, and the cell state,
    /// respectively.
    OutputAndStates = 1,
}
545
546impl LSTMResultMode {
547 pub fn debug_desc(self) -> &'static ns::String {
548 unsafe { MLCLSTMResultModeDebugDescription(self) }
549 }
550}
551
/// Comparison or logical operation (`MLCComparisonOperation`).
///
/// NOTE(review): this enum is `#[repr(usize)]` while most enums in this file
/// are `#[repr(i32)]` — confirm the underlying integer type against the
/// ML Compute header.
#[doc(alias = "MLCComparisonOperation")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(usize)]
pub enum ComparisonOp {
    /// Equality comparison.
    Equal = 0,

    /// Inequality comparison.
    NotEqual = 1,

    /// Less-than comparison.
    Less = 2,

    /// Greater-than comparison.
    Greater = 3,

    /// Less-than-or-equal comparison.
    LessOrEqual = 4,

    /// Greater-than-or-equal comparison.
    GreaterOrEqual = 5,

    /// Logical AND.
    LogicalAND = 6,

    /// Logical OR.
    LogicalOR = 7,

    /// Logical NOT.
    LogicalNOT = 8,

    /// Logical NAND.
    LogicalNAND = 9,

    /// Logical NOR.
    LogicalNOR = 10,

    /// Logical XOR.
    LogicalXOR = 11,
}
568
569impl ComparisonOp {
570 pub fn debug_desc(self) -> &'static ns::String {
571 unsafe { MLCComparisonOperationDebugDescription(self) }
572 }
573}
574
/// How gradients are clipped (`MLCGradientClippingType`).
///
/// NOTE(review): this enum is `#[repr(usize)]` while most enums in this file
/// are `#[repr(i32)]` — confirm the underlying integer type against the
/// ML Compute header.
#[doc(alias = "MLCGradientClippingType")]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(usize)]
pub enum GradientClippingType {
    /// Clipping by value.
    ByValue = 0,
    /// Clipping by norm.
    ByNorm = 1,
    /// Clipping by global norm.
    ByGlobalNorm = 2,
}
584
585impl GradientClippingType {
586 pub fn debug_desc(self) -> &'static ns::String {
587 unsafe { MLCGradientClippingTypeDebugDescription(self) }
588 }
589}
590
591define_opts!(pub GraphCompilationOpts(u64));
592
593impl GraphCompilationOpts {
594 pub const DEBUG_LAYERS: Self = Self(0x01);
595 pub const DISABLE_LAYER_FUSION: Self = Self(0x02);
596 pub const LINK_GRAPHS: Self = Self(0x04);
597 pub const COMPUTE_ALL_GRADIENTS: Self = Self(0x08);
598}
599define_opts!(pub ExecutionOpts(u64));
600impl ExecutionOpts {
601 pub const SKIP_WRITING_INPUT_DATA_TO_DEVICE: Self = Self(0x01);
602 pub const SYNCHRONOUS: Self = Self(0x02);
603 pub const PROFILING: Self = Self(0x04);
604 pub const FORWARD_FOR_INFERENCE: Self = Self(0x08);
605 pub const PER_LAYER_PROFILING: Self = Self(0x10);
606}
607
608unsafe extern "C-unwind" {
609 fn MLCActivationTypeDebugDescription(activationType: ActivationType) -> &'static ns::String;
610 fn MLCArithmeticOperationDebugDescription(op: ArithmeticOp) -> &'static ns::String;
611 fn MLCPaddingPolicyDebugDescription(policy: PaddingPolicy) -> &'static ns::String;
612 fn MLCLossTypeDebugDescription(loss_type: LossType) -> &'static ns::String;
613 fn MLCReductionTypeDebugDescription(reduction_type: ReductionType) -> &'static ns::String;
614 fn MLCPaddingTypeDebugDescription(padding_type: PaddingType) -> &'static ns::String;
615 fn MLCConvolutionTypeDebugDescription(convolution_type: ConvolutionType)
616 -> &'static ns::String;
617 fn MLCPoolingTypeDebugDescription(pooling_type: PoolingType) -> &'static ns::String;
618 fn MLCSoftmaxOperationDebugDescription(operation: SoftmaxOp) -> &'static ns::String;
619 fn MLCSampleModeDebugDescription(mode: SampleMode) -> &'static ns::String;
620 fn MLCLSTMResultModeDebugDescription(mode: LSTMResultMode) -> &'static ns::String;
621 fn MLCComparisonOperationDebugDescription(operation: ComparisonOp) -> &'static ns::String;
622 fn MLCGradientClippingTypeDebugDescription(
623 gradient_clipping_type: GradientClippingType,
624 ) -> &'static ns::String;
625}
626
#[cfg(test)]
mod tests {
    use crate::mlc;

    /// Smoke test for the FFI round trip: the debug description of
    /// `ActivationType::ReLU` must read "ReLU".
    #[test]
    fn basics() {
        let relu_text = mlc::ActivationType::ReLU.debug_desc().to_string();
        assert_eq!(relu_text, "ReLU");
    }
}