Skip to main content

cidre/mlc/
types.rs

1use crate::{blocks, define_opts, mlc, ns};
2
3pub type GraphCompletionHandler =
4    blocks::SyncBlock<fn(Option<&mlc::Tensor>, Option<&ns::Error>, ns::TimeInterval)>;
5
/// A tensor data type.
///
/// NOTE(review): the discriminants are not contiguous (2 and 6 are unused). They presumably
/// mirror the raw values of the framework's `MLCDataType`, so do not renumber them.
#[doc(alias = "MLCDataType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum DType {
    /// An invalid data type.
    Invalid = 0,

    /// The 32-bit floating-point data type.
    F32 = 1,

    /// The 16-bit floating-point data type.
    F16 = 3,

    /// The Boolean data type.
    Bool = 4,

    /// The 64-bit integer data type.
    I64 = 5,

    /// The 32-bit integer data type.
    I32 = 7,

    /// The 8-bit integer data type.
    I8 = 8,

    /// The 8-bit unsigned integer data type.
    U8 = 9,
}
32
/// The type of a random initializer.
#[doc(alias = "MLCRandomInitializerType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum RandomInitializerType {
    /// An invalid random initializer type.
    Invalid = 0,

    /// The uniform random initializer type.
    Uniform = 1,

    /// The Glorot uniform random initializer type.
    GlorotUniform = 2,

    /// The Xavier random initializer type.
    Xavier = 3,
}
47
/// The type of device to run on.
#[doc(alias = "MLCDeviceType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum DeviceType {
    /// The CPU device.
    Cpu = 0,

    /// The GPU device.
    Gpu = 1,

    /// Any device type.
    ///
    /// When selected, the framework will automatically use the appropriate devices to achieve
    /// the best performance.
    Any = 2,

    /// The Apple Neural Engine (ANE) device.
    ///
    /// When selected, the framework will use the Neural Engine to execute all layers that can be
    /// executed on it. Layers that cannot be executed on the ANE will run on the CPU or GPU.
    ///
    /// The Neural Engine device must be explicitly selected: [`DeviceType::Any`] will not select
    /// it. In addition, this device can be used with inference graphs only. It cannot be used
    /// with a training graph or with an inference graph that shares layers with a training graph.
    Ane = 3,
}
68
/// An elementwise arithmetic operation.
#[doc(alias = "MLCArithmeticOperation")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum ArithmeticOp {
    /// An operation that calculates the elementwise sum of its two inputs.
    Add = 0,

    /// An operation that calculates the elementwise difference of its two inputs.
    Subtract = 1,

    /// An operation that calculates the elementwise product of its two inputs.
    Multiply = 2,

    /// An operation that calculates the elementwise division of its two inputs.
    Divide = 3,

    /// An operation that calculates the elementwise floor of its input.
    Floor = 4,

    /// An operation that calculates the elementwise round of its input.
    Round = 5,

    /// An operation that calculates the elementwise ceiling of its input.
    Ceil = 6,

    /// An operation that calculates the elementwise square root of its input.
    Sqrt = 7,

    /// An operation that calculates the elementwise reciprocal of the square root of its input.
    RSqrt = 8,

    /// An operation that calculates the elementwise sine of its input.
    Sin = 9,

    /// An operation that calculates the elementwise cosine of its input.
    Cos = 10,

    /// An operation that calculates the elementwise tangent of its input.
    Tan = 11,

    /// An operation that calculates the elementwise inverse sine of its input.
    ASin = 12,

    /// An operation that calculates the elementwise inverse cosine of its input.
    ACos = 13,

    /// An operation that calculates the elementwise inverse tangent of its input.
    ATan = 14,

    /// An operation that calculates the elementwise hyperbolic sine of its input.
    SinH = 15,

    /// An operation that calculates the elementwise hyperbolic cosine of its input.
    CosH = 16,

    /// An operation that calculates the elementwise hyperbolic tangent of its input.
    TanH = 17,

    /// An operation that calculates the elementwise inverse hyperbolic sine of its input.
    ASinH = 18,

    /// An operation that calculates the elementwise inverse hyperbolic cosine of its input.
    ACosH = 19,

    /// An operation that calculates the elementwise inverse hyperbolic tangent of its input.
    ATanH = 20,

    /// An operation that calculates, elementwise, its first input raised to the power of its
    /// second input.
    Pow = 21,

    /// An operation that calculates the elementwise result of e raised to the power of its input.
    Exp = 22,

    /// An operation that calculates the elementwise result of 2 raised to the power of its input.
    Exp2 = 23,

    /// An operation that calculates the elementwise natural logarithm of its input.
    Log = 24,

    /// An operation that calculates the elementwise base 2 logarithm of its input.
    Log2 = 25,

    /// An operation that calculates the elementwise product of its two inputs.
    ///
    /// Returns 0 if `y` in `x * y` is zero, even if `x` is NaN or INF.
    MultiplyNoNaN = 26,

    /// An operation that calculates the elementwise division of its two inputs.
    ///
    /// Returns 0 if the denominator is 0.
    DivideNoNaN = 27,

    /// An operation that calculates the elementwise min of its two inputs.
    Min = 28,

    /// An operation that calculates the elementwise max of its two inputs.
    Max = 29,
}
163
164impl ArithmeticOp {
165    #[inline]
166    pub fn debug_desc(self) -> &'static ns::String {
167        unsafe { MLCArithmeticOperationDebugDescription(self) }
168    }
169}
170
/// The kind of loss function to use.
#[repr(i32)]
#[doc(alias = "MLCLossType")]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LossType {
    /// Mean absolute error loss.
    MeanAbsoluteError = 0,

    /// Mean squared error loss.
    MeanSquaredError = 1,

    /// Softmax cross-entropy loss.
    SoftmaxCrossEntropy = 2,

    /// Sigmoid cross-entropy loss.
    SigmoidCrossEntropy = 3,

    /// Categorical cross-entropy loss.
    CategoricalCrossEntropy = 4,

    /// Hinge loss.
    Hinge = 5,

    /// Huber loss.
    Huber = 6,

    /// Cosine distance loss.
    CosineDistance = 7,

    /// Log loss.
    Log = 8,
}
203
204impl LossType {
205    #[inline]
206    pub fn debug_desc(self) -> &'static ns::String {
207        unsafe { MLCLossTypeDebugDescription(self) }
208    }
209}
210
/// An activation type that you specify for an activation descriptor.
#[doc(alias = "MLCActivationType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum ActivationType {
    /// No activation.
    None = 0,

    /// The ReLU activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x >= 0 ? x : a * x
    /// ```
    ReLU = 1,

    /// The linear activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = a * x + b
    /// ```
    Linear = 2,

    /// The sigmoid activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = 1 / (1 + e⁻ˣ)
    /// ```
    Sigmoid = 3,

    /// The hard sigmoid activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = clamp((x * a) + b, 0, 1)
    /// ```
    HardSigmoid = 4,

    /// The hyperbolic tangent (TanH) activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = a * tanh(b * x)
    /// ```
    TanH = 5,

    /// The absolute activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = fabs(x)
    /// ```
    Absolute = 6,

    /// The parametric soft plus activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = a * log(1 + e^(b * x))
    /// ```
    SoftPlus = 7,

    /// The parametric soft sign activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x / (1 + abs(x))
    /// ```
    SoftSign = 8,

    /// The parametric ELU activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x >= 0 ? x : a * (exp(x) - 1)
    /// ```
    ELU = 9,

    /// The ReLUN activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = min((x >= 0 ? x : a * x), b)
    /// ```
    ReLUN = 10,

    /// The log sigmoid activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = log(1 / (1 + exp(-x)))
    /// ```
    LogSigmoid = 11,

    /// The SELU activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = scale * (max(0, x) + min(0, α * (exp(x) − 1)))
    /// ```
    /// where:
    /// ```pseudo
    /// α = 1.6732632423543772848170429916717
    /// scale = 1.0507009873554804934193349852946
    /// ```
    SELU = 12,

    /// The CELU activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = max(0, x) + min(0, a * (exp(x / a) − 1))
    /// ```
    CELU = 13,

    /// The hard shrink activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x, if x > a or x < −a
    /// f(x) = 0, otherwise
    /// ```
    HardShrink = 14,

    /// The soft shrink activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x - a, if x > a
    /// f(x) = x + a, if x < −a
    /// f(x) = 0,     otherwise
    /// ```
    SoftShrink = 15,

    /// The hyperbolic tangent (TanH) shrink activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x - tanh(x)
    /// ```
    TanHShrink = 16,

    /// The threshold activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x, if x > a
    /// f(x) = b, otherwise
    /// ```
    Threshold = 17,

    /// The GELU activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = x * CDF(x)
    /// ```
    GELU = 18,

    /// The hardswish activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = 0, if x <= -3
    /// f(x) = x, if x >= +3
    /// f(x) = x * (x + 3)/6, otherwise
    /// ```
    HardSwish = 19,

    /// The clamp activation type.
    ///
    /// This activation type implements the following function:
    /// ```pseudo
    /// f(x) = min(max(x, a), b)
    /// ```
    Clamp = 20,
}
379
380impl ActivationType {
381    /// Returns a textual description of the arithmetic operation, suitable for debugging
382    pub fn debug_desc(self) -> &'static ns::String {
383        unsafe { MLCActivationTypeDebugDescription(self) }
384    }
385}
386
/// The type of a convolution.
#[doc(alias = "MLCConvolutionType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum ConvolutionType {
    /// The standard convolution type.
    Standard = 0,

    /// The transposed convolution type.
    Transposed = 1,

    /// The depthwise convolution type.
    Depthwise = 2,
}
399
400impl ConvolutionType {
401    pub fn debug_desc(self) -> &'static ns::String {
402        unsafe { MLCConvolutionTypeDebugDescription(self) }
403    }
404}
405
/// The padding policy of a layer.
#[doc(alias = "MLCPaddingPolicy")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum PaddingPolicy {
    /// The "same" padding policy.
    Same = 0,

    /// The "valid" padding policy.
    Valid = 1,

    /// The choice to use explicitly specified padding sizes.
    UsePaddingSize = 2,
}
416
417impl PaddingPolicy {
418    #[inline]
419    pub fn debug_desc(self) -> &'static ns::String {
420        unsafe { MLCPaddingPolicyDebugDescription(self) }
421    }
422}
423
/// The type of padding applied by a padding layer.
#[doc(alias = "MLCPaddingType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum PaddingType {
    /// The zero padding type.
    Zero = 0,

    /// The reflect padding type.
    Reflect = 1,

    /// The symmetric padding type.
    Symmetric = 2,

    /// The constant padding type.
    Constant = 3,
}
436
437impl PaddingType {
438    pub fn debug_desc(self) -> &'static ns::String {
439        unsafe { MLCPaddingTypeDebugDescription(self) }
440    }
441}
442
/// The type of a pooling operation.
///
/// Note that the discriminants start at 1, not 0.
#[doc(alias = "MLCPoolingType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum PoolingType {
    /// The max pooling type.
    Max = 1,

    /// The average pooling type.
    Average = 2,

    /// The L2-norm pooling type.
    L2Norm = 3,
}
453
454impl PoolingType {
455    pub fn debug_desc(self) -> &'static ns::String {
456        unsafe { MLCPoolingTypeDebugDescription(self) }
457    }
458}
459
/// The type of a reduction operation.
#[doc(alias = "MLCReductionType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum ReductionType {
    /// No reduction.
    None = 0,

    /// The sum reduction.
    Sum = 1,

    /// The mean reduction.
    Mean = 2,

    /// The max reduction.
    Max = 3,

    /// The min reduction.
    Min = 4,

    /// The argmax reduction.
    ArgMax = 5,

    /// The argmin reduction.
    ArgMin = 6,

    /// The L1-norm reduction.
    L1Norm = 7,

    /// The logical OR reduction: `Any(X) = X_0 || X_1 || ... || X_n`.
    Any = 8,

    /// The logical AND reduction: `All(X) = X_0 && X_1 && ... && X_n`.
    All = 9,
}
484
485impl ReductionType {
486    #[inline]
487    pub fn debug_desc(self) -> &'static ns::String {
488        unsafe { MLCReductionTypeDebugDescription(self) }
489    }
490}
491
/// The type of regularization.
#[doc(alias = "MLCRegularizationType")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum RegularizationType {
    /// No regularization.
    None = 0,

    /// The L1 regularization.
    L1 = 1,

    /// The L2 regularization.
    L2 = 2,
}
504
/// The sampling mode of an upsample layer or similar resampling operation.
#[doc(alias = "MLCSampleMode")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum SampleMode {
    /// The nearest sample mode.
    Nearest = 0,

    /// The linear sample mode.
    Linear = 1,
}
513
514impl SampleMode {
515    pub fn debug_desc(self) -> &'static ns::String {
516        unsafe { MLCSampleModeDebugDescription(self) }
517    }
518}
519
/// The type of a softmax operation.
#[doc(alias = "MLCSoftmaxOperation")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(i32)]
pub enum SoftmaxOp {
    /// The standard softmax operation.
    Softmax = 0,

    /// The log softmax operation.
    LogSoftmax = 1,
}
528
529impl SoftmaxOp {
530    #[inline]
531    pub fn debug_desc(self) -> &'static ns::String {
532        unsafe { MLCSoftmaxOperationDebugDescription(self) }
533    }
534}
535
/// The result mode of an LSTM layer.
#[doc(alias = "MLCLSTMResultMode")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(usize)]
pub enum LSTMResultMode {
    /// The output result mode.
    ///
    /// When selected for an LSTM layer, the layer will produce a single result tensor
    /// representing the final output of the LSTM.
    Output = 0,

    /// The output and states result mode.
    ///
    /// When selected for an LSTM layer, the layer will produce three result tensors
    /// representing the final output of the LSTM, the last hidden state, and the cell state,
    /// respectively.
    OutputAndStates = 1,
}
545
546impl LSTMResultMode {
547    pub fn debug_desc(self) -> &'static ns::String {
548        unsafe { MLCLSTMResultModeDebugDescription(self) }
549    }
550}
551
/// A comparison or logical operation.
#[doc(alias = "MLCComparisonOperation")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(usize)]
pub enum ComparisonOp {
    Equal = 0,
    NotEqual = 1,
    Less = 2,
    Greater = 3,
    LessOrEqual = 4,
    GreaterOrEqual = 5,
    LogicalAND = 6,
    LogicalOR = 7,
    LogicalNOT = 8,
    LogicalNAND = 9,
    LogicalNOR = 10,
    LogicalXOR = 11,
}
568
569impl ComparisonOp {
570    pub fn debug_desc(self) -> &'static ns::String {
571        unsafe { MLCComparisonOperationDebugDescription(self) }
572    }
573}
574
/// The kind of clipping applied to gradients.
#[repr(usize)]
#[doc(alias = "MLCGradientClippingType")]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GradientClippingType {
    ByValue = 0,

    ByNorm = 1,

    ByGlobalNorm = 2,
}
584
585impl GradientClippingType {
586    pub fn debug_desc(self) -> &'static ns::String {
587        unsafe { MLCGradientClippingTypeDebugDescription(self) }
588    }
589}
590
591define_opts!(pub GraphCompilationOpts(u64));
592
593impl GraphCompilationOpts {
594    pub const DEBUG_LAYERS: Self = Self(0x01);
595    pub const DISABLE_LAYER_FUSION: Self = Self(0x02);
596    pub const LINK_GRAPHS: Self = Self(0x04);
597    pub const COMPUTE_ALL_GRADIENTS: Self = Self(0x08);
598}
599define_opts!(pub ExecutionOpts(u64));
600impl ExecutionOpts {
601    pub const SKIP_WRITING_INPUT_DATA_TO_DEVICE: Self = Self(0x01);
602    pub const SYNCHRONOUS: Self = Self(0x02);
603    pub const PROFILING: Self = Self(0x04);
604    pub const FORWARD_FOR_INFERENCE: Self = Self(0x08);
605    pub const PER_LAYER_PROFILING: Self = Self(0x10);
606}
607
608unsafe extern "C-unwind" {
609    fn MLCActivationTypeDebugDescription(activationType: ActivationType) -> &'static ns::String;
610    fn MLCArithmeticOperationDebugDescription(op: ArithmeticOp) -> &'static ns::String;
611    fn MLCPaddingPolicyDebugDescription(policy: PaddingPolicy) -> &'static ns::String;
612    fn MLCLossTypeDebugDescription(loss_type: LossType) -> &'static ns::String;
613    fn MLCReductionTypeDebugDescription(reduction_type: ReductionType) -> &'static ns::String;
614    fn MLCPaddingTypeDebugDescription(padding_type: PaddingType) -> &'static ns::String;
615    fn MLCConvolutionTypeDebugDescription(convolution_type: ConvolutionType)
616    -> &'static ns::String;
617    fn MLCPoolingTypeDebugDescription(pooling_type: PoolingType) -> &'static ns::String;
618    fn MLCSoftmaxOperationDebugDescription(operation: SoftmaxOp) -> &'static ns::String;
619    fn MLCSampleModeDebugDescription(mode: SampleMode) -> &'static ns::String;
620    fn MLCLSTMResultModeDebugDescription(mode: LSTMResultMode) -> &'static ns::String;
621    fn MLCComparisonOperationDebugDescription(operation: ComparisonOp) -> &'static ns::String;
622    fn MLCGradientClippingTypeDebugDescription(
623        gradient_clipping_type: GradientClippingType,
624    ) -> &'static ns::String;
625}
626
#[cfg(test)]
mod tests {
    use crate::mlc::ActivationType;

    /// The debug description of `ReLU` is the plain string "ReLU".
    #[test]
    fn basics() {
        let relu = ActivationType::ReLU.debug_desc().to_string();
        assert_eq!(relu, "ReLU");
    }
}