// apple_mpsgraph/ops.rs

use crate::ffi;
use crate::graph::Tensor;
use crate::types::collect_owned_tensors;
use core::ffi::{c_char, c_void};
use std::ffi::CString;

/// Converts an optional Rust string into an owned, NUL-terminated C string.
///
/// Returns `None` when `name` is `None`. It also returns `None` when the text
/// contains an interior NUL byte: `CString::new` fails in that case and the
/// error is deliberately discarded, so such a name is treated as "no name"
/// instead of being reported to the caller.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    CString::new(name?).ok()
}

/// Returns the raw C-string pointer for an optional name, or null when absent.
///
/// The returned pointer borrows from `value`: it is only valid while the
/// `CString` inside `value` is alive and unmodified, so callers must keep
/// `value` in scope for as long as the pointer is in use.
// The parameter is `&Option<CString>` (not `Option<&CString>`) so call sites can
// pass `&name` directly for the local produced by `optional_cstring`.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(name) => name.as_ptr(),
        None => core::ptr::null(),
    }
}

18fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
19    if ptr.is_null() {
20        None
21    } else {
22        Some(Tensor::from_raw(ptr))
23    }
24}
25
26fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
27    let mut values = collect_owned_tensors(box_handle);
28    if values.len() != 2 {
29        return None;
30    }
31    let second = values.pop()?;
32    let first = values.pop()?;
33    Some((first, second))
34}
35
/// Selector for the single-input operation requested through
/// `Graph::unary_arithmetic`.
///
/// The enum is `#[repr(u32)]` with explicit discriminants; the selected variant
/// is cast with `as u32` and passed to `ffi::mpsgraph_graph_arithmetic_unary`.
// NOTE(review): the numbering presumably has to match the operation table on the
// native side of the FFI — confirm before renumbering or inserting variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum UnaryArithmeticOp {
    Identity = 0,
    Exponent = 1,
    ExponentBase2 = 2,
    ExponentBase10 = 3,
    Logarithm = 4,
    LogarithmBase2 = 5,
    LogarithmBase10 = 6,
    Square = 7,
    SquareRoot = 8,
    Reciprocal = 9,
    Absolute = 10,
    Negative = 11,
    Sign = 12,
    SignBit = 13,
    Ceil = 14,
    Floor = 15,
    Round = 16,
    Rint = 17,
    Sin = 18,
    Cos = 19,
    Tan = 20,
    Sinh = 21,
    Cosh = 22,
    Tanh = 23,
    Asin = 24,
    Acos = 25,
    Atan = 26,
    Asinh = 27,
    Acosh = 28,
    Atanh = 29,
    IsNaN = 30,
    IsInfinite = 31,
}

/// Selector for the two-input operation requested through
/// `Graph::binary_arithmetic`.
///
/// The enum is `#[repr(u32)]` with explicit discriminants; the selected variant
/// is cast with `as u32` and passed to `ffi::mpsgraph_graph_arithmetic_binary`
/// together with the primary and secondary tensors.
// NOTE(review): the numbering presumably has to match the operation table on the
// native side of the FFI — confirm before renumbering or inserting variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum BinaryArithmeticOp {
    Addition = 0,
    Subtraction = 1,
    Multiplication = 2,
    Division = 3,
    DivisionNoNaN = 4,
    Power = 5,
    Minimum = 6,
    Maximum = 7,
    Equal = 8,
    NotEqual = 9,
    GreaterThan = 10,
    GreaterThanOrEqualTo = 11,
    LessThan = 12,
    LessThanOrEqualTo = 13,
    LogicalAnd = 14,
    LogicalOr = 15,
    Atan2 = 16,
    FloorModulo = 17,
}

/// Selector for the reduction requested through `Graph::reduce_axis`, which
/// takes a single `axis`.
///
/// The selected variant is cast with `as u32` and passed to
/// `ffi::mpsgraph_graph_reduction_axis`.
// NOTE(review): the numbering presumably has to match the native side of the
// FFI — confirm before renumbering or inserting variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum ReductionAxisOp {
    Sum = 0,
    Maximum = 1,
    Minimum = 2,
    Product = 3,
}

/// Selector for the reduction requested through `Graph::reduce_axes`, which
/// takes a slice of axes.
///
/// The selected variant is cast with `as u32` and passed to
/// `ffi::mpsgraph_graph_reduction_axes`. The variants and values mirror
/// `ReductionAxisOp`, but this is a distinct type.
// NOTE(review): the numbering presumably has to match the native side of the
// FFI — confirm before renumbering or inserting variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum ReductionAxesOp {
    Sum = 0,
    Maximum = 1,
    Minimum = 2,
    Product = 3,
}

114impl crate::graph::Graph {
115    #[must_use]
116    pub fn unary_arithmetic(
117        &self,
118        op: UnaryArithmeticOp,
119        tensor: &Tensor,
120        name: Option<&str>,
121    ) -> Option<Tensor> {
122        let name = optional_cstring(name);
123        // SAFETY: all handles remain valid for the duration of the call.
124        let ptr = unsafe {
125            ffi::mpsgraph_graph_arithmetic_unary(
126                self.as_ptr(),
127                op as u32,
128                tensor.as_ptr(),
129                cstring_ptr(&name),
130            )
131        };
132        wrap_tensor(ptr)
133    }
134
135    #[must_use]
136    pub fn binary_arithmetic(
137        &self,
138        op: BinaryArithmeticOp,
139        primary: &Tensor,
140        secondary: &Tensor,
141        name: Option<&str>,
142    ) -> Option<Tensor> {
143        let name = optional_cstring(name);
144        // SAFETY: all handles remain valid for the duration of the call.
145        let ptr = unsafe {
146            ffi::mpsgraph_graph_arithmetic_binary(
147                self.as_ptr(),
148                op as u32,
149                primary.as_ptr(),
150                secondary.as_ptr(),
151                cstring_ptr(&name),
152            )
153        };
154        wrap_tensor(ptr)
155    }
156
157    #[must_use]
158    pub fn select(
159        &self,
160        predicate: &Tensor,
161        true_tensor: &Tensor,
162        false_tensor: &Tensor,
163        name: Option<&str>,
164    ) -> Option<Tensor> {
165        let name = optional_cstring(name);
166        // SAFETY: all handles remain valid for the duration of the call.
167        let ptr = unsafe {
168            ffi::mpsgraph_graph_select(
169                self.as_ptr(),
170                predicate.as_ptr(),
171                true_tensor.as_ptr(),
172                false_tensor.as_ptr(),
173                cstring_ptr(&name),
174            )
175        };
176        wrap_tensor(ptr)
177    }
178
179    #[must_use]
180    pub fn relu_gradient(
181        &self,
182        gradient: &Tensor,
183        source: &Tensor,
184        name: Option<&str>,
185    ) -> Option<Tensor> {
186        let name = optional_cstring(name);
187        // SAFETY: all handles remain valid for the duration of the call.
188        let ptr = unsafe {
189            ffi::mpsgraph_graph_relu_gradient(
190                self.as_ptr(),
191                gradient.as_ptr(),
192                source.as_ptr(),
193                cstring_ptr(&name),
194            )
195        };
196        wrap_tensor(ptr)
197    }
198
199    #[must_use]
200    pub fn sigmoid_gradient(
201        &self,
202        gradient: &Tensor,
203        source: &Tensor,
204        name: Option<&str>,
205    ) -> Option<Tensor> {
206        let name = optional_cstring(name);
207        // SAFETY: all handles remain valid for the duration of the call.
208        let ptr = unsafe {
209            ffi::mpsgraph_graph_sigmoid_gradient(
210                self.as_ptr(),
211                gradient.as_ptr(),
212                source.as_ptr(),
213                cstring_ptr(&name),
214            )
215        };
216        wrap_tensor(ptr)
217    }
218
219    #[must_use]
220    pub fn softmax_gradient(
221        &self,
222        gradient: &Tensor,
223        source: &Tensor,
224        axis: isize,
225        name: Option<&str>,
226    ) -> Option<Tensor> {
227        let name = optional_cstring(name);
228        // SAFETY: all handles remain valid for the duration of the call.
229        let ptr = unsafe {
230            ffi::mpsgraph_graph_softmax_gradient(
231                self.as_ptr(),
232                gradient.as_ptr(),
233                source.as_ptr(),
234                axis,
235                cstring_ptr(&name),
236            )
237        };
238        wrap_tensor(ptr)
239    }
240
241    #[must_use]
242    pub fn leaky_relu(&self, tensor: &Tensor, alpha: f64, name: Option<&str>) -> Option<Tensor> {
243        let name = optional_cstring(name);
244        // SAFETY: all handles remain valid for the duration of the call.
245        let ptr = unsafe {
246            ffi::mpsgraph_graph_leaky_relu_scalar(
247                self.as_ptr(),
248                tensor.as_ptr(),
249                alpha,
250                cstring_ptr(&name),
251            )
252        };
253        wrap_tensor(ptr)
254    }
255
256    #[must_use]
257    pub fn leaky_relu_tensor(
258        &self,
259        tensor: &Tensor,
260        alpha_tensor: &Tensor,
261        name: Option<&str>,
262    ) -> Option<Tensor> {
263        let name = optional_cstring(name);
264        // SAFETY: all handles remain valid for the duration of the call.
265        let ptr = unsafe {
266            ffi::mpsgraph_graph_leaky_relu_tensor(
267                self.as_ptr(),
268                tensor.as_ptr(),
269                alpha_tensor.as_ptr(),
270                cstring_ptr(&name),
271            )
272        };
273        wrap_tensor(ptr)
274    }
275
276    #[must_use]
277    pub fn leaky_relu_gradient(
278        &self,
279        gradient: &Tensor,
280        source: &Tensor,
281        alpha_tensor: &Tensor,
282        name: Option<&str>,
283    ) -> Option<Tensor> {
284        let name = optional_cstring(name);
285        // SAFETY: all handles remain valid for the duration of the call.
286        let ptr = unsafe {
287            ffi::mpsgraph_graph_leaky_relu_gradient(
288                self.as_ptr(),
289                gradient.as_ptr(),
290                source.as_ptr(),
291                alpha_tensor.as_ptr(),
292                cstring_ptr(&name),
293            )
294        };
295        wrap_tensor(ptr)
296    }
297
298    #[must_use]
299    pub fn reduce_axis(
300        &self,
301        op: ReductionAxisOp,
302        tensor: &Tensor,
303        axis: isize,
304        name: Option<&str>,
305    ) -> Option<Tensor> {
306        let name = optional_cstring(name);
307        // SAFETY: all handles remain valid for the duration of the call.
308        let ptr = unsafe {
309            ffi::mpsgraph_graph_reduction_axis(
310                self.as_ptr(),
311                op as u32,
312                tensor.as_ptr(),
313                axis,
314                cstring_ptr(&name),
315            )
316        };
317        wrap_tensor(ptr)
318    }
319
320    #[must_use]
321    pub fn reduce_axes(
322        &self,
323        op: ReductionAxesOp,
324        tensor: &Tensor,
325        axes: &[usize],
326        name: Option<&str>,
327    ) -> Option<Tensor> {
328        let name = optional_cstring(name);
329        // SAFETY: all handles remain valid for the duration of the call.
330        let ptr = unsafe {
331            ffi::mpsgraph_graph_reduction_axes(
332                self.as_ptr(),
333                op as u32,
334                tensor.as_ptr(),
335                axes.as_ptr(),
336                axes.len(),
337                cstring_ptr(&name),
338            )
339        };
340        wrap_tensor(ptr)
341    }
342
343    #[must_use]
344    pub fn concat_pair(
345        &self,
346        first: &Tensor,
347        second: &Tensor,
348        dimension: isize,
349        name: Option<&str>,
350    ) -> Option<Tensor> {
351        let name = optional_cstring(name);
352        // SAFETY: all handles remain valid for the duration of the call.
353        let ptr = unsafe {
354            ffi::mpsgraph_graph_concat_pair(
355                self.as_ptr(),
356                first.as_ptr(),
357                second.as_ptr(),
358                dimension,
359                cstring_ptr(&name),
360            )
361        };
362        wrap_tensor(ptr)
363    }
364
365    #[must_use]
366    pub fn concat_tensors(
367        &self,
368        tensors: &[&Tensor],
369        dimension: isize,
370        interleave: bool,
371        name: Option<&str>,
372    ) -> Option<Tensor> {
373        let name = optional_cstring(name);
374        let handles = tensors
375            .iter()
376            .map(|tensor| tensor.as_ptr())
377            .collect::<Vec<_>>();
378        // SAFETY: all handles remain valid for the duration of the call.
379        let ptr = unsafe {
380            ffi::mpsgraph_graph_concat_tensors(
381                self.as_ptr(),
382                handles.as_ptr(),
383                handles.len(),
384                dimension,
385                interleave,
386                cstring_ptr(&name),
387            )
388        };
389        wrap_tensor(ptr)
390    }
391
392    #[must_use]
393    pub fn split_sizes(
394        &self,
395        tensor: &Tensor,
396        split_sizes: &[usize],
397        axis: isize,
398        name: Option<&str>,
399    ) -> Vec<Tensor> {
400        let name = optional_cstring(name);
401        // SAFETY: all handles remain valid for the duration of the call.
402        let box_handle = unsafe {
403            ffi::mpsgraph_graph_split_sizes(
404                self.as_ptr(),
405                tensor.as_ptr(),
406                split_sizes.as_ptr(),
407                split_sizes.len(),
408                axis,
409                cstring_ptr(&name),
410            )
411        };
412        collect_owned_tensors(box_handle)
413    }
414
415    #[must_use]
416    pub fn split_sizes_tensor(
417        &self,
418        tensor: &Tensor,
419        split_sizes_tensor: &Tensor,
420        axis: isize,
421        name: Option<&str>,
422    ) -> Vec<Tensor> {
423        let name = optional_cstring(name);
424        // SAFETY: all handles remain valid for the duration of the call.
425        let box_handle = unsafe {
426            ffi::mpsgraph_graph_split_sizes_tensor(
427                self.as_ptr(),
428                tensor.as_ptr(),
429                split_sizes_tensor.as_ptr(),
430                axis,
431                cstring_ptr(&name),
432            )
433        };
434        collect_owned_tensors(box_handle)
435    }
436
437    #[must_use]
438    pub fn split_num(
439        &self,
440        tensor: &Tensor,
441        num_splits: usize,
442        axis: isize,
443        name: Option<&str>,
444    ) -> Vec<Tensor> {
445        let name = optional_cstring(name);
446        // SAFETY: all handles remain valid for the duration of the call.
447        let box_handle = unsafe {
448            ffi::mpsgraph_graph_split_num(
449                self.as_ptr(),
450                tensor.as_ptr(),
451                num_splits,
452                axis,
453                cstring_ptr(&name),
454            )
455        };
456        collect_owned_tensors(box_handle)
457    }
458
459    #[must_use]
460    pub fn stack(&self, tensors: &[&Tensor], axis: isize, name: Option<&str>) -> Option<Tensor> {
461        let name = optional_cstring(name);
462        let handles = tensors
463            .iter()
464            .map(|tensor| tensor.as_ptr())
465            .collect::<Vec<_>>();
466        // SAFETY: all handles remain valid for the duration of the call.
467        let ptr = unsafe {
468            ffi::mpsgraph_graph_stack(
469                self.as_ptr(),
470                handles.as_ptr(),
471                handles.len(),
472                axis,
473                cstring_ptr(&name),
474            )
475        };
476        wrap_tensor(ptr)
477    }
478
479    #[must_use]
480    pub fn pad(
481        &self,
482        tensor: &Tensor,
483        padding_mode: isize,
484        left_padding: &[isize],
485        right_padding: &[isize],
486        constant_value: f64,
487        name: Option<&str>,
488    ) -> Option<Tensor> {
489        let name = optional_cstring(name);
490        // SAFETY: all handles remain valid for the duration of the call.
491        let ptr = unsafe {
492            ffi::mpsgraph_graph_pad(
493                self.as_ptr(),
494                tensor.as_ptr(),
495                padding_mode,
496                left_padding.as_ptr(),
497                left_padding.len(),
498                right_padding.as_ptr(),
499                right_padding.len(),
500                constant_value,
501                cstring_ptr(&name),
502            )
503        };
504        wrap_tensor(ptr)
505    }
506
507    #[must_use]
508    pub fn top_k(&self, source: &Tensor, k: usize, name: Option<&str>) -> Option<(Tensor, Tensor)> {
509        let name = optional_cstring(name);
510        // SAFETY: all handles remain valid for the duration of the call.
511        let box_handle = unsafe {
512            ffi::mpsgraph_graph_top_k(self.as_ptr(), source.as_ptr(), k, cstring_ptr(&name))
513        };
514        wrap_tensor_pair(box_handle)
515    }
516
517    #[must_use]
518    pub fn top_k_tensor(
519        &self,
520        source: &Tensor,
521        k_tensor: &Tensor,
522        name: Option<&str>,
523    ) -> Option<(Tensor, Tensor)> {
524        let name = optional_cstring(name);
525        // SAFETY: all handles remain valid for the duration of the call.
526        let box_handle = unsafe {
527            ffi::mpsgraph_graph_top_k_tensor(
528                self.as_ptr(),
529                source.as_ptr(),
530                k_tensor.as_ptr(),
531                cstring_ptr(&name),
532            )
533        };
534        wrap_tensor_pair(box_handle)
535    }
536}