Skip to main content

apple_mpsgraph/
ops.rs

1use crate::ffi;
2use crate::graph::Tensor;
3use crate::types::collect_owned_tensors;
4use core::ffi::{c_char, c_void};
5use std::ffi::CString;
6
/// Converts an optional node name into an owned, NUL-terminated C string.
///
/// Yields `None` when no name is given, and also when the name contains an
/// interior NUL byte: such a name cannot be represented as a C string, so it is
/// dropped rather than reported as an error.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    match CString::new(text) {
        Ok(owned) => Some(owned),
        Err(_) => None,
    }
}
10
/// Borrows a raw pointer to the bytes of an optional C string.
///
/// Produces a null pointer for `None`. The pointer is only valid while `value`
/// is alive and unmodified, so callers keep the `Option<CString>` bound in a
/// local for the whole FFI call that receives it.
// Taking `&Option<CString>` (instead of `Option<&CString>`) keeps every call site
// as the terse `cstring_ptr(&name)`; that is why clippy's `ref_option` is allowed.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => core::ptr::null(),
    }
}
15
16fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
17    if ptr.is_null() {
18        None
19    } else {
20        Some(Tensor::from_raw(ptr))
21    }
22}
23
24fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
25    let mut values = collect_owned_tensors(box_handle);
26    if values.len() != 2 {
27        return None;
28    }
29    let second = values.pop()?;
30    let first = values.pop()?;
31    Some((first, second))
32}
33
/// Operation selector for [`crate::graph::Graph::unary_arithmetic`].
///
/// The value is handed to `ffi::mpsgraph_graph_arithmetic_unary` as a `u32`, so
/// the explicit discriminants below form the bridge contract and existing
/// variants must keep their numbers.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnaryArithmeticOp {
    Identity = 0,
    Exponent = 1,
    ExponentBase2 = 2,
    ExponentBase10 = 3,
    Logarithm = 4,
    LogarithmBase2 = 5,
    LogarithmBase10 = 6,
    Square = 7,
    SquareRoot = 8,
    Reciprocal = 9,
    Absolute = 10,
    Negative = 11,
    Sign = 12,
    SignBit = 13,
    Ceil = 14,
    Floor = 15,
    Round = 16,
    Rint = 17,
    Sin = 18,
    Cos = 19,
    Tan = 20,
    Sinh = 21,
    Cosh = 22,
    Tanh = 23,
    Asin = 24,
    Acos = 25,
    Atan = 26,
    Asinh = 27,
    Acosh = 28,
    Atanh = 29,
    IsNaN = 30,
    IsInfinite = 31,
}
70
/// Operation selector for [`crate::graph::Graph::binary_arithmetic`].
///
/// The value is handed to `ffi::mpsgraph_graph_arithmetic_binary` as a `u32`,
/// so the explicit discriminants below form the bridge contract and existing
/// variants must keep their numbers.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryArithmeticOp {
    Addition = 0,
    Subtraction = 1,
    Multiplication = 2,
    Division = 3,
    DivisionNoNaN = 4,
    Power = 5,
    Minimum = 6,
    Maximum = 7,
    Equal = 8,
    NotEqual = 9,
    GreaterThan = 10,
    GreaterThanOrEqualTo = 11,
    LessThan = 12,
    LessThanOrEqualTo = 13,
    LogicalAnd = 14,
    LogicalOr = 15,
    Atan2 = 16,
    FloorModulo = 17,
}
93
/// Reduction selector for [`crate::graph::Graph::reduce_axis`].
///
/// Passed to `ffi::mpsgraph_graph_reduction_axis` as a `u32`; the explicit
/// discriminants are the bridge contract.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionAxisOp {
    Sum = 0,
    Maximum = 1,
    Minimum = 2,
    Product = 3,
}
102
/// Reduction selector for [`crate::graph::Graph::reduce_axes`].
///
/// Passed to `ffi::mpsgraph_graph_reduction_axes` as a `u32`; the explicit
/// discriminants are the bridge contract.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReductionAxesOp {
    Sum = 0,
    Maximum = 1,
    Minimum = 2,
    Product = 3,
}
111
112impl crate::graph::Graph {
113    #[must_use]
114    pub fn unary_arithmetic(
115        &self,
116        op: UnaryArithmeticOp,
117        tensor: &Tensor,
118        name: Option<&str>,
119    ) -> Option<Tensor> {
120        let name = optional_cstring(name);
121        // SAFETY: all handles remain valid for the duration of the call.
122        let ptr = unsafe {
123            ffi::mpsgraph_graph_arithmetic_unary(
124                self.as_ptr(),
125                op as u32,
126                tensor.as_ptr(),
127                cstring_ptr(&name),
128            )
129        };
130        wrap_tensor(ptr)
131    }
132
133    #[must_use]
134    pub fn binary_arithmetic(
135        &self,
136        op: BinaryArithmeticOp,
137        primary: &Tensor,
138        secondary: &Tensor,
139        name: Option<&str>,
140    ) -> Option<Tensor> {
141        let name = optional_cstring(name);
142        // SAFETY: all handles remain valid for the duration of the call.
143        let ptr = unsafe {
144            ffi::mpsgraph_graph_arithmetic_binary(
145                self.as_ptr(),
146                op as u32,
147                primary.as_ptr(),
148                secondary.as_ptr(),
149                cstring_ptr(&name),
150            )
151        };
152        wrap_tensor(ptr)
153    }
154
155    #[must_use]
156    pub fn select(
157        &self,
158        predicate: &Tensor,
159        true_tensor: &Tensor,
160        false_tensor: &Tensor,
161        name: Option<&str>,
162    ) -> Option<Tensor> {
163        let name = optional_cstring(name);
164        // SAFETY: all handles remain valid for the duration of the call.
165        let ptr = unsafe {
166            ffi::mpsgraph_graph_select(
167                self.as_ptr(),
168                predicate.as_ptr(),
169                true_tensor.as_ptr(),
170                false_tensor.as_ptr(),
171                cstring_ptr(&name),
172            )
173        };
174        wrap_tensor(ptr)
175    }
176
177    #[must_use]
178    pub fn relu_gradient(
179        &self,
180        gradient: &Tensor,
181        source: &Tensor,
182        name: Option<&str>,
183    ) -> Option<Tensor> {
184        let name = optional_cstring(name);
185        // SAFETY: all handles remain valid for the duration of the call.
186        let ptr = unsafe {
187            ffi::mpsgraph_graph_relu_gradient(
188                self.as_ptr(),
189                gradient.as_ptr(),
190                source.as_ptr(),
191                cstring_ptr(&name),
192            )
193        };
194        wrap_tensor(ptr)
195    }
196
197    #[must_use]
198    pub fn sigmoid_gradient(
199        &self,
200        gradient: &Tensor,
201        source: &Tensor,
202        name: Option<&str>,
203    ) -> Option<Tensor> {
204        let name = optional_cstring(name);
205        // SAFETY: all handles remain valid for the duration of the call.
206        let ptr = unsafe {
207            ffi::mpsgraph_graph_sigmoid_gradient(
208                self.as_ptr(),
209                gradient.as_ptr(),
210                source.as_ptr(),
211                cstring_ptr(&name),
212            )
213        };
214        wrap_tensor(ptr)
215    }
216
217    #[must_use]
218    pub fn softmax_gradient(
219        &self,
220        gradient: &Tensor,
221        source: &Tensor,
222        axis: isize,
223        name: Option<&str>,
224    ) -> Option<Tensor> {
225        let name = optional_cstring(name);
226        // SAFETY: all handles remain valid for the duration of the call.
227        let ptr = unsafe {
228            ffi::mpsgraph_graph_softmax_gradient(
229                self.as_ptr(),
230                gradient.as_ptr(),
231                source.as_ptr(),
232                axis,
233                cstring_ptr(&name),
234            )
235        };
236        wrap_tensor(ptr)
237    }
238
239    #[must_use]
240    pub fn leaky_relu(
241        &self,
242        tensor: &Tensor,
243        alpha: f64,
244        name: Option<&str>,
245    ) -> Option<Tensor> {
246        let name = optional_cstring(name);
247        // SAFETY: all handles remain valid for the duration of the call.
248        let ptr = unsafe {
249            ffi::mpsgraph_graph_leaky_relu_scalar(
250                self.as_ptr(),
251                tensor.as_ptr(),
252                alpha,
253                cstring_ptr(&name),
254            )
255        };
256        wrap_tensor(ptr)
257    }
258
259    #[must_use]
260    pub fn leaky_relu_tensor(
261        &self,
262        tensor: &Tensor,
263        alpha_tensor: &Tensor,
264        name: Option<&str>,
265    ) -> Option<Tensor> {
266        let name = optional_cstring(name);
267        // SAFETY: all handles remain valid for the duration of the call.
268        let ptr = unsafe {
269            ffi::mpsgraph_graph_leaky_relu_tensor(
270                self.as_ptr(),
271                tensor.as_ptr(),
272                alpha_tensor.as_ptr(),
273                cstring_ptr(&name),
274            )
275        };
276        wrap_tensor(ptr)
277    }
278
279    #[must_use]
280    pub fn leaky_relu_gradient(
281        &self,
282        gradient: &Tensor,
283        source: &Tensor,
284        alpha_tensor: &Tensor,
285        name: Option<&str>,
286    ) -> Option<Tensor> {
287        let name = optional_cstring(name);
288        // SAFETY: all handles remain valid for the duration of the call.
289        let ptr = unsafe {
290            ffi::mpsgraph_graph_leaky_relu_gradient(
291                self.as_ptr(),
292                gradient.as_ptr(),
293                source.as_ptr(),
294                alpha_tensor.as_ptr(),
295                cstring_ptr(&name),
296            )
297        };
298        wrap_tensor(ptr)
299    }
300
301    #[must_use]
302    pub fn reduce_axis(
303        &self,
304        op: ReductionAxisOp,
305        tensor: &Tensor,
306        axis: isize,
307        name: Option<&str>,
308    ) -> Option<Tensor> {
309        let name = optional_cstring(name);
310        // SAFETY: all handles remain valid for the duration of the call.
311        let ptr = unsafe {
312            ffi::mpsgraph_graph_reduction_axis(
313                self.as_ptr(),
314                op as u32,
315                tensor.as_ptr(),
316                axis,
317                cstring_ptr(&name),
318            )
319        };
320        wrap_tensor(ptr)
321    }
322
323    #[must_use]
324    pub fn reduce_axes(
325        &self,
326        op: ReductionAxesOp,
327        tensor: &Tensor,
328        axes: &[usize],
329        name: Option<&str>,
330    ) -> Option<Tensor> {
331        let name = optional_cstring(name);
332        // SAFETY: all handles remain valid for the duration of the call.
333        let ptr = unsafe {
334            ffi::mpsgraph_graph_reduction_axes(
335                self.as_ptr(),
336                op as u32,
337                tensor.as_ptr(),
338                axes.as_ptr(),
339                axes.len(),
340                cstring_ptr(&name),
341            )
342        };
343        wrap_tensor(ptr)
344    }
345
346    #[must_use]
347    pub fn concat_pair(
348        &self,
349        first: &Tensor,
350        second: &Tensor,
351        dimension: isize,
352        name: Option<&str>,
353    ) -> Option<Tensor> {
354        let name = optional_cstring(name);
355        // SAFETY: all handles remain valid for the duration of the call.
356        let ptr = unsafe {
357            ffi::mpsgraph_graph_concat_pair(
358                self.as_ptr(),
359                first.as_ptr(),
360                second.as_ptr(),
361                dimension,
362                cstring_ptr(&name),
363            )
364        };
365        wrap_tensor(ptr)
366    }
367
368    #[must_use]
369    pub fn concat_tensors(
370        &self,
371        tensors: &[&Tensor],
372        dimension: isize,
373        interleave: bool,
374        name: Option<&str>,
375    ) -> Option<Tensor> {
376        let name = optional_cstring(name);
377        let handles = tensors.iter().map(|tensor| tensor.as_ptr()).collect::<Vec<_>>();
378        // SAFETY: all handles remain valid for the duration of the call.
379        let ptr = unsafe {
380            ffi::mpsgraph_graph_concat_tensors(
381                self.as_ptr(),
382                handles.as_ptr(),
383                handles.len(),
384                dimension,
385                interleave,
386                cstring_ptr(&name),
387            )
388        };
389        wrap_tensor(ptr)
390    }
391
392    #[must_use]
393    pub fn split_sizes(
394        &self,
395        tensor: &Tensor,
396        split_sizes: &[usize],
397        axis: isize,
398        name: Option<&str>,
399    ) -> Vec<Tensor> {
400        let name = optional_cstring(name);
401        // SAFETY: all handles remain valid for the duration of the call.
402        let box_handle = unsafe {
403            ffi::mpsgraph_graph_split_sizes(
404                self.as_ptr(),
405                tensor.as_ptr(),
406                split_sizes.as_ptr(),
407                split_sizes.len(),
408                axis,
409                cstring_ptr(&name),
410            )
411        };
412        collect_owned_tensors(box_handle)
413    }
414
415    #[must_use]
416    pub fn split_sizes_tensor(
417        &self,
418        tensor: &Tensor,
419        split_sizes_tensor: &Tensor,
420        axis: isize,
421        name: Option<&str>,
422    ) -> Vec<Tensor> {
423        let name = optional_cstring(name);
424        // SAFETY: all handles remain valid for the duration of the call.
425        let box_handle = unsafe {
426            ffi::mpsgraph_graph_split_sizes_tensor(
427                self.as_ptr(),
428                tensor.as_ptr(),
429                split_sizes_tensor.as_ptr(),
430                axis,
431                cstring_ptr(&name),
432            )
433        };
434        collect_owned_tensors(box_handle)
435    }
436
437    #[must_use]
438    pub fn split_num(
439        &self,
440        tensor: &Tensor,
441        num_splits: usize,
442        axis: isize,
443        name: Option<&str>,
444    ) -> Vec<Tensor> {
445        let name = optional_cstring(name);
446        // SAFETY: all handles remain valid for the duration of the call.
447        let box_handle = unsafe {
448            ffi::mpsgraph_graph_split_num(
449                self.as_ptr(),
450                tensor.as_ptr(),
451                num_splits,
452                axis,
453                cstring_ptr(&name),
454            )
455        };
456        collect_owned_tensors(box_handle)
457    }
458
459    #[must_use]
460    pub fn stack(&self, tensors: &[&Tensor], axis: isize, name: Option<&str>) -> Option<Tensor> {
461        let name = optional_cstring(name);
462        let handles = tensors.iter().map(|tensor| tensor.as_ptr()).collect::<Vec<_>>();
463        // SAFETY: all handles remain valid for the duration of the call.
464        let ptr = unsafe {
465            ffi::mpsgraph_graph_stack(
466                self.as_ptr(),
467                handles.as_ptr(),
468                handles.len(),
469                axis,
470                cstring_ptr(&name),
471            )
472        };
473        wrap_tensor(ptr)
474    }
475
476    #[must_use]
477    pub fn pad(
478        &self,
479        tensor: &Tensor,
480        padding_mode: isize,
481        left_padding: &[isize],
482        right_padding: &[isize],
483        constant_value: f64,
484        name: Option<&str>,
485    ) -> Option<Tensor> {
486        let name = optional_cstring(name);
487        // SAFETY: all handles remain valid for the duration of the call.
488        let ptr = unsafe {
489            ffi::mpsgraph_graph_pad(
490                self.as_ptr(),
491                tensor.as_ptr(),
492                padding_mode,
493                left_padding.as_ptr(),
494                left_padding.len(),
495                right_padding.as_ptr(),
496                right_padding.len(),
497                constant_value,
498                cstring_ptr(&name),
499            )
500        };
501        wrap_tensor(ptr)
502    }
503
504    #[must_use]
505    pub fn top_k(&self, source: &Tensor, k: usize, name: Option<&str>) -> Option<(Tensor, Tensor)> {
506        let name = optional_cstring(name);
507        // SAFETY: all handles remain valid for the duration of the call.
508        let box_handle = unsafe {
509            ffi::mpsgraph_graph_top_k(self.as_ptr(), source.as_ptr(), k, cstring_ptr(&name))
510        };
511        wrap_tensor_pair(box_handle)
512    }
513
514    #[must_use]
515    pub fn top_k_tensor(
516        &self,
517        source: &Tensor,
518        k_tensor: &Tensor,
519        name: Option<&str>,
520    ) -> Option<(Tensor, Tensor)> {
521        let name = optional_cstring(name);
522        // SAFETY: all handles remain valid for the duration of the call.
523        let box_handle = unsafe {
524            ffi::mpsgraph_graph_top_k_tensor(
525                self.as_ptr(),
526                source.as_ptr(),
527                k_tensor.as_ptr(),
528                cstring_ptr(&name),
529            )
530        };
531        wrap_tensor_pair(box_handle)
532    }
533}