Skip to main content

apple_mpsgraph/
ops.rs

1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::collect_owned_tensors;
4use core::ffi::{c_char, c_void};
5use std::ffi::CString;
6
/// Converts an optional operation name into an owned, NUL-terminated C string.
///
/// Returns `None` when `name` is `None`. It also returns `None` when the name contains an
/// interior NUL byte, which a `CString` cannot represent; in that case the name is dropped
/// silently rather than reported to the caller.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
10
/// Returns the raw pointer behind an optional C string, or null when it is absent.
///
/// The pointer borrows from `value`. It is only valid while that `Option<CString>` is
/// alive and unmodified, so the owning binding must stay in scope across the FFI call that
/// reads the pointer.
// Call sites hold an `Option<CString>` and pass `&name`. Accepting `&Option<CString>`
// keeps those sites simple, which is why the `ref_option` lint is silenced here.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => core::ptr::null(),
    }
}
17
18fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
19    if ptr.is_null() {
20        None
21    } else {
22        Some(Tensor::from_raw(ptr))
23    }
24}
25
26fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
27    let mut values = collect_owned_tensors(box_handle);
28    if values.len() != 2 {
29        return None;
30    }
31    let second = values.pop()?;
32    let first = values.pop()?;
33    Some((first, second))
34}
35
/// Mirrors the `MPSGraph` framework counterpart for `UnaryArithmeticOp`.
///
/// Selects the operation applied by `Graph::unary_arithmetic`. The enum is `#[repr(u32)]`,
/// and each variant's explicit discriminant is the value sent to the native shim (`op as u32`).
/// Renumbering a variant therefore changes which native operation is requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum UnaryArithmeticOp {
    /// Mirrors the `MPSGraph` framework case `Identity`.
    Identity = 0,
    /// Mirrors the `MPSGraph` framework case `Exponent`.
    Exponent = 1,
    /// Mirrors the `MPSGraph` framework case `ExponentBase2`.
    ExponentBase2 = 2,
    /// Mirrors the `MPSGraph` framework case `ExponentBase10`.
    ExponentBase10 = 3,
    /// Mirrors the `MPSGraph` framework case `Logarithm`.
    Logarithm = 4,
    /// Mirrors the `MPSGraph` framework case `LogarithmBase2`.
    LogarithmBase2 = 5,
    /// Mirrors the `MPSGraph` framework case `LogarithmBase10`.
    LogarithmBase10 = 6,
    /// Mirrors the `MPSGraph` framework case `Square`.
    Square = 7,
    /// Mirrors the `MPSGraph` framework case `SquareRoot`.
    SquareRoot = 8,
    /// Mirrors the `MPSGraph` framework case `Reciprocal`.
    Reciprocal = 9,
    /// Mirrors the `MPSGraph` framework case `Absolute`.
    Absolute = 10,
    /// Mirrors the `MPSGraph` framework case `Negative`.
    Negative = 11,
    /// Mirrors the `MPSGraph` framework case `Sign`.
    Sign = 12,
    /// Mirrors the `MPSGraph` framework case `SignBit`.
    SignBit = 13,
    /// Mirrors the `MPSGraph` framework case `Ceil`.
    Ceil = 14,
    /// Mirrors the `MPSGraph` framework case `Floor`.
    Floor = 15,
    /// Mirrors the `MPSGraph` framework case `Round`.
    Round = 16,
    /// Mirrors the `MPSGraph` framework case `Rint`.
    Rint = 17,
    /// Mirrors the `MPSGraph` framework case `Sin`.
    Sin = 18,
    /// Mirrors the `MPSGraph` framework case `Cos`.
    Cos = 19,
    /// Mirrors the `MPSGraph` framework case `Tan`.
    Tan = 20,
    /// Mirrors the `MPSGraph` framework case `Sinh`.
    Sinh = 21,
    /// Mirrors the `MPSGraph` framework case `Cosh`.
    Cosh = 22,
    /// Mirrors the `MPSGraph` framework case `Tanh`.
    Tanh = 23,
    /// Mirrors the `MPSGraph` framework case `Asin`.
    Asin = 24,
    /// Mirrors the `MPSGraph` framework case `Acos`.
    Acos = 25,
    /// Mirrors the `MPSGraph` framework case `Atan`.
    Atan = 26,
    /// Mirrors the `MPSGraph` framework case `Asinh`.
    Asinh = 27,
    /// Mirrors the `MPSGraph` framework case `Acosh`.
    Acosh = 28,
    /// Mirrors the `MPSGraph` framework case `Atanh`.
    Atanh = 29,
    /// Mirrors the `MPSGraph` framework case `IsNaN`.
    IsNaN = 30,
    /// Mirrors the `MPSGraph` framework case `IsInfinite`.
    IsInfinite = 31,
}
105
/// Mirrors the `MPSGraph` framework counterpart for `BinaryArithmeticOp`.
///
/// Selects the operation applied by `Graph::binary_arithmetic`. The enum is `#[repr(u32)]`,
/// and each variant's explicit discriminant is the value sent to the native shim (`op as u32`).
/// Renumbering a variant therefore changes which native operation is requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum BinaryArithmeticOp {
    /// Mirrors the `MPSGraph` framework case `Addition`.
    Addition = 0,
    /// Mirrors the `MPSGraph` framework case `Subtraction`.
    Subtraction = 1,
    /// Mirrors the `MPSGraph` framework case `Multiplication`.
    Multiplication = 2,
    /// Mirrors the `MPSGraph` framework case `Division`.
    Division = 3,
    /// Mirrors the `MPSGraph` framework case `DivisionNoNaN`.
    DivisionNoNaN = 4,
    /// Mirrors the `MPSGraph` framework case `Power`.
    Power = 5,
    /// Mirrors the `MPSGraph` framework case `Minimum`.
    Minimum = 6,
    /// Mirrors the `MPSGraph` framework case `Maximum`.
    Maximum = 7,
    /// Mirrors the `MPSGraph` framework case `Equal`.
    Equal = 8,
    /// Mirrors the `MPSGraph` framework case `NotEqual`.
    NotEqual = 9,
    /// Mirrors the `MPSGraph` framework case `GreaterThan`.
    GreaterThan = 10,
    /// Mirrors the `MPSGraph` framework case `GreaterThanOrEqualTo`.
    GreaterThanOrEqualTo = 11,
    /// Mirrors the `MPSGraph` framework case `LessThan`.
    LessThan = 12,
    /// Mirrors the `MPSGraph` framework case `LessThanOrEqualTo`.
    LessThanOrEqualTo = 13,
    /// Mirrors the `MPSGraph` framework case `LogicalAnd`.
    LogicalAnd = 14,
    /// Mirrors the `MPSGraph` framework case `LogicalOr`.
    LogicalOr = 15,
    /// Mirrors the `MPSGraph` framework case `Atan2`.
    Atan2 = 16,
    /// Mirrors the `MPSGraph` framework case `FloorModulo`.
    FloorModulo = 17,
}
147
/// Mirrors the `MPSGraph` framework counterpart for `ReductionAxisOp`.
///
/// Selects the reduction performed by `Graph::reduce_axis` along a single axis. The enum is
/// `#[repr(u32)]`, and each variant's explicit discriminant is the value sent to the native
/// shim (`op as u32`), so the numbering must not change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ReductionAxisOp {
    /// Mirrors the `MPSGraph` framework case `Sum`.
    Sum = 0,
    /// Mirrors the `MPSGraph` framework case `Maximum`.
    Maximum = 1,
    /// Mirrors the `MPSGraph` framework case `Minimum`.
    Minimum = 2,
    /// Mirrors the `MPSGraph` framework case `Product`.
    Product = 3,
}
161
/// Mirrors the `MPSGraph` framework counterpart for `ReductionAxesOp`.
///
/// Selects the reduction performed by `Graph::reduce_axes` over a set of axes. The enum is
/// `#[repr(u32)]`, and each variant's explicit discriminant is the value sent to the native
/// shim (`op as u32`), so the numbering must not change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ReductionAxesOp {
    /// Mirrors the `MPSGraph` framework case `Sum`.
    Sum = 0,
    /// Mirrors the `MPSGraph` framework case `Maximum`.
    Maximum = 1,
    /// Mirrors the `MPSGraph` framework case `Minimum`.
    Minimum = 2,
    /// Mirrors the `MPSGraph` framework case `Product`.
    Product = 3,
}
175
176impl crate::graph::Graph {
177/// Calls the `MPSGraph` framework counterpart for `unary_arithmetic`.
178    #[must_use]
179    pub fn unary_arithmetic(
180        &self,
181        op: UnaryArithmeticOp,
182        tensor: &Tensor,
183        name: Option<&str>,
184    ) -> Option<Tensor> {
185        let name = optional_cstring(name);
186        // SAFETY: all handles remain valid for the duration of the call.
187        let ptr = unsafe {
188            ffi::mpsgraph_graph_arithmetic_unary(
189                self.as_ptr(),
190                op as u32,
191                tensor.as_ptr(),
192                cstring_ptr(&name),
193            )
194        };
195        wrap_tensor(ptr)
196    }
197
198/// Calls the `MPSGraph` framework counterpart for `binary_arithmetic`.
199    #[must_use]
200    pub fn binary_arithmetic(
201        &self,
202        op: BinaryArithmeticOp,
203        primary: &Tensor,
204        secondary: &Tensor,
205        name: Option<&str>,
206    ) -> Option<Tensor> {
207        let name = optional_cstring(name);
208        // SAFETY: all handles remain valid for the duration of the call.
209        let ptr = unsafe {
210            ffi::mpsgraph_graph_arithmetic_binary(
211                self.as_ptr(),
212                op as u32,
213                primary.as_ptr(),
214                secondary.as_ptr(),
215                cstring_ptr(&name),
216            )
217        };
218        wrap_tensor(ptr)
219    }
220
221/// Calls the `MPSGraph` framework counterpart for `select`.
222    #[must_use]
223    pub fn select(
224        &self,
225        predicate: &Tensor,
226        true_tensor: &Tensor,
227        false_tensor: &Tensor,
228        name: Option<&str>,
229    ) -> Option<Tensor> {
230        let name = optional_cstring(name);
231        // SAFETY: all handles remain valid for the duration of the call.
232        let ptr = unsafe {
233            ffi::mpsgraph_graph_select(
234                self.as_ptr(),
235                predicate.as_ptr(),
236                true_tensor.as_ptr(),
237                false_tensor.as_ptr(),
238                cstring_ptr(&name),
239            )
240        };
241        wrap_tensor(ptr)
242    }
243
244/// Calls the `MPSGraph` framework counterpart for `relu_gradient`.
245    #[must_use]
246    pub fn relu_gradient(
247        &self,
248        gradient: &Tensor,
249        source: &Tensor,
250        name: Option<&str>,
251    ) -> Option<Tensor> {
252        let name = optional_cstring(name);
253        // SAFETY: all handles remain valid for the duration of the call.
254        let ptr = unsafe {
255            ffi::mpsgraph_graph_relu_gradient(
256                self.as_ptr(),
257                gradient.as_ptr(),
258                source.as_ptr(),
259                cstring_ptr(&name),
260            )
261        };
262        wrap_tensor(ptr)
263    }
264
265/// Calls the `MPSGraph` framework counterpart for `sigmoid_gradient`.
266    #[must_use]
267    pub fn sigmoid_gradient(
268        &self,
269        gradient: &Tensor,
270        source: &Tensor,
271        name: Option<&str>,
272    ) -> Option<Tensor> {
273        let name = optional_cstring(name);
274        // SAFETY: all handles remain valid for the duration of the call.
275        let ptr = unsafe {
276            ffi::mpsgraph_graph_sigmoid_gradient(
277                self.as_ptr(),
278                gradient.as_ptr(),
279                source.as_ptr(),
280                cstring_ptr(&name),
281            )
282        };
283        wrap_tensor(ptr)
284    }
285
286/// Calls the `MPSGraph` framework counterpart for `softmax_gradient`.
287    #[must_use]
288    pub fn softmax_gradient(
289        &self,
290        gradient: &Tensor,
291        source: &Tensor,
292        axis: isize,
293        name: Option<&str>,
294    ) -> Option<Tensor> {
295        let name = optional_cstring(name);
296        // SAFETY: all handles remain valid for the duration of the call.
297        let ptr = unsafe {
298            ffi::mpsgraph_graph_softmax_gradient(
299                self.as_ptr(),
300                gradient.as_ptr(),
301                source.as_ptr(),
302                axis,
303                cstring_ptr(&name),
304            )
305        };
306        wrap_tensor(ptr)
307    }
308
309/// Calls the `MPSGraph` framework counterpart for `leaky_relu`.
310    #[must_use]
311    pub fn leaky_relu(&self, tensor: &Tensor, alpha: f64, name: Option<&str>) -> Option<Tensor> {
312        let name = optional_cstring(name);
313        // SAFETY: all handles remain valid for the duration of the call.
314        let ptr = unsafe {
315            ffi::mpsgraph_graph_leaky_relu_scalar(
316                self.as_ptr(),
317                tensor.as_ptr(),
318                alpha,
319                cstring_ptr(&name),
320            )
321        };
322        wrap_tensor(ptr)
323    }
324
325/// Calls the `MPSGraph` framework counterpart for `leaky_relu_tensor`.
326    #[must_use]
327    pub fn leaky_relu_tensor(
328        &self,
329        tensor: &Tensor,
330        alpha_tensor: &Tensor,
331        name: Option<&str>,
332    ) -> Option<Tensor> {
333        let name = optional_cstring(name);
334        // SAFETY: all handles remain valid for the duration of the call.
335        let ptr = unsafe {
336            ffi::mpsgraph_graph_leaky_relu_tensor(
337                self.as_ptr(),
338                tensor.as_ptr(),
339                alpha_tensor.as_ptr(),
340                cstring_ptr(&name),
341            )
342        };
343        wrap_tensor(ptr)
344    }
345
346/// Calls the `MPSGraph` framework counterpart for `leaky_relu_gradient`.
347    #[must_use]
348    pub fn leaky_relu_gradient(
349        &self,
350        gradient: &Tensor,
351        source: &Tensor,
352        alpha_tensor: &Tensor,
353        name: Option<&str>,
354    ) -> Option<Tensor> {
355        let name = optional_cstring(name);
356        // SAFETY: all handles remain valid for the duration of the call.
357        let ptr = unsafe {
358            ffi::mpsgraph_graph_leaky_relu_gradient(
359                self.as_ptr(),
360                gradient.as_ptr(),
361                source.as_ptr(),
362                alpha_tensor.as_ptr(),
363                cstring_ptr(&name),
364            )
365        };
366        wrap_tensor(ptr)
367    }
368
369/// Calls the `MPSGraph` framework counterpart for `reduce_axis`.
370    #[must_use]
371    pub fn reduce_axis(
372        &self,
373        op: ReductionAxisOp,
374        tensor: &Tensor,
375        axis: isize,
376        name: Option<&str>,
377    ) -> Option<Tensor> {
378        let name = optional_cstring(name);
379        // SAFETY: all handles remain valid for the duration of the call.
380        let ptr = unsafe {
381            ffi::mpsgraph_graph_reduction_axis(
382                self.as_ptr(),
383                op as u32,
384                tensor.as_ptr(),
385                axis,
386                cstring_ptr(&name),
387            )
388        };
389        wrap_tensor(ptr)
390    }
391
392/// Calls the `MPSGraph` framework counterpart for `reduce_axes`.
393    #[must_use]
394    pub fn reduce_axes(
395        &self,
396        op: ReductionAxesOp,
397        tensor: &Tensor,
398        axes: &[usize],
399        name: Option<&str>,
400    ) -> Option<Tensor> {
401        let name = optional_cstring(name);
402        // SAFETY: all handles remain valid for the duration of the call.
403        let ptr = unsafe {
404            ffi::mpsgraph_graph_reduction_axes(
405                self.as_ptr(),
406                op as u32,
407                tensor.as_ptr(),
408                axes.as_ptr(),
409                axes.len(),
410                cstring_ptr(&name),
411            )
412        };
413        wrap_tensor(ptr)
414    }
415
416/// Calls the `MPSGraph` framework counterpart for `concat_pair`.
417    #[must_use]
418    pub fn concat_pair(
419        &self,
420        first: &Tensor,
421        second: &Tensor,
422        dimension: isize,
423        name: Option<&str>,
424    ) -> Option<Tensor> {
425        let name = optional_cstring(name);
426        // SAFETY: all handles remain valid for the duration of the call.
427        let ptr = unsafe {
428            ffi::mpsgraph_graph_concat_pair(
429                self.as_ptr(),
430                first.as_ptr(),
431                second.as_ptr(),
432                dimension,
433                cstring_ptr(&name),
434            )
435        };
436        wrap_tensor(ptr)
437    }
438
439/// Calls the `MPSGraph` framework counterpart for `concat_tensors`.
440    #[must_use]
441    pub fn concat_tensors(
442        &self,
443        tensors: &[&Tensor],
444        dimension: isize,
445        interleave: bool,
446        name: Option<&str>,
447    ) -> Option<Tensor> {
448        let name = optional_cstring(name);
449        let handles = tensors
450            .iter()
451            .map(|tensor| tensor.as_ptr())
452            .collect::<Vec<_>>();
453        // SAFETY: all handles remain valid for the duration of the call.
454        let ptr = unsafe {
455            ffi::mpsgraph_graph_concat_tensors(
456                self.as_ptr(),
457                handles.as_ptr(),
458                handles.len(),
459                dimension,
460                interleave,
461                cstring_ptr(&name),
462            )
463        };
464        wrap_tensor(ptr)
465    }
466
467/// Calls the `MPSGraph` framework counterpart for `split_sizes`.
468    #[must_use]
469    pub fn split_sizes(
470        &self,
471        tensor: &Tensor,
472        split_sizes: &[usize],
473        axis: isize,
474        name: Option<&str>,
475    ) -> Vec<Tensor> {
476        let name = optional_cstring(name);
477        // SAFETY: all handles remain valid for the duration of the call.
478        let box_handle = unsafe {
479            ffi::mpsgraph_graph_split_sizes(
480                self.as_ptr(),
481                tensor.as_ptr(),
482                split_sizes.as_ptr(),
483                split_sizes.len(),
484                axis,
485                cstring_ptr(&name),
486            )
487        };
488        collect_owned_tensors(box_handle)
489    }
490
491/// Calls the `MPSGraph` framework counterpart for `split_sizes_tensor`.
492    #[must_use]
493    pub fn split_sizes_tensor(
494        &self,
495        tensor: &Tensor,
496        split_sizes_tensor: &Tensor,
497        axis: isize,
498        name: Option<&str>,
499    ) -> Vec<Tensor> {
500        let name = optional_cstring(name);
501        // SAFETY: all handles remain valid for the duration of the call.
502        let box_handle = unsafe {
503            ffi::mpsgraph_graph_split_sizes_tensor(
504                self.as_ptr(),
505                tensor.as_ptr(),
506                split_sizes_tensor.as_ptr(),
507                axis,
508                cstring_ptr(&name),
509            )
510        };
511        collect_owned_tensors(box_handle)
512    }
513
514/// Calls the `MPSGraph` framework counterpart for `split_num`.
515    #[must_use]
516    pub fn split_num(
517        &self,
518        tensor: &Tensor,
519        num_splits: usize,
520        axis: isize,
521        name: Option<&str>,
522    ) -> Vec<Tensor> {
523        let name = optional_cstring(name);
524        // SAFETY: all handles remain valid for the duration of the call.
525        let box_handle = unsafe {
526            ffi::mpsgraph_graph_split_num(
527                self.as_ptr(),
528                tensor.as_ptr(),
529                num_splits,
530                axis,
531                cstring_ptr(&name),
532            )
533        };
534        collect_owned_tensors(box_handle)
535    }
536
537/// Calls the `MPSGraph` framework counterpart for `stack`.
538    #[must_use]
539    pub fn stack(&self, tensors: &[&Tensor], axis: isize, name: Option<&str>) -> Option<Tensor> {
540        let name = optional_cstring(name);
541        let handles = tensors
542            .iter()
543            .map(|tensor| tensor.as_ptr())
544            .collect::<Vec<_>>();
545        // SAFETY: all handles remain valid for the duration of the call.
546        let ptr = unsafe {
547            ffi::mpsgraph_graph_stack(
548                self.as_ptr(),
549                handles.as_ptr(),
550                handles.len(),
551                axis,
552                cstring_ptr(&name),
553            )
554        };
555        wrap_tensor(ptr)
556    }
557
558/// Calls the `MPSGraph` framework counterpart for `pad`.
559    #[must_use]
560    pub fn pad(
561        &self,
562        tensor: &Tensor,
563        padding_mode: isize,
564        left_padding: &[isize],
565        right_padding: &[isize],
566        constant_value: f64,
567        name: Option<&str>,
568    ) -> Option<Tensor> {
569        let name = optional_cstring(name);
570        // SAFETY: all handles remain valid for the duration of the call.
571        let ptr = unsafe {
572            ffi::mpsgraph_graph_pad(
573                self.as_ptr(),
574                tensor.as_ptr(),
575                padding_mode,
576                left_padding.as_ptr(),
577                left_padding.len(),
578                right_padding.as_ptr(),
579                right_padding.len(),
580                constant_value,
581                cstring_ptr(&name),
582            )
583        };
584        wrap_tensor(ptr)
585    }
586
587/// Calls the `MPSGraph` framework counterpart for `top_k`.
588    #[must_use]
589    pub fn top_k(&self, source: &Tensor, k: usize, name: Option<&str>) -> Option<(Tensor, Tensor)> {
590        let name = optional_cstring(name);
591        // SAFETY: all handles remain valid for the duration of the call.
592        let box_handle = unsafe {
593            ffi::mpsgraph_graph_top_k(self.as_ptr(), source.as_ptr(), k, cstring_ptr(&name))
594        };
595        wrap_tensor_pair(box_handle)
596    }
597
598/// Calls the `MPSGraph` framework counterpart for `top_k_tensor`.
599    #[must_use]
600    pub fn top_k_tensor(
601        &self,
602        source: &Tensor,
603        k_tensor: &Tensor,
604        name: Option<&str>,
605    ) -> Option<(Tensor, Tensor)> {
606        let name = optional_cstring(name);
607        // SAFETY: all handles remain valid for the duration of the call.
608        let box_handle = unsafe {
609            ffi::mpsgraph_graph_top_k_tensor(
610                self.as_ptr(),
611                source.as_ptr(),
612                k_tensor.as_ptr(),
613                cstring_ptr(&name),
614            )
615        };
616        wrap_tensor_pair(box_handle)
617    }
618}