Skip to main content

apple_mpsgraph/
random.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional op name into an owned C string for the FFI layer.
///
/// Yields `None` when no name was given, and also when `name` contains an
/// interior NUL byte (which `CString::new` rejects) — such names are silently
/// dropped rather than reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
12
/// Borrows the raw pointer of an optional C string, or returns null when absent.
///
/// The returned pointer is only valid while `value` is alive and unmodified, so
/// callers must keep the `Option<CString>` in a local for the whole FFI call.
// Takes `&Option<CString>` (rather than `Option<&CStr>`) so call sites can pass
// `&name` for the local produced by `optional_cstring` without converting it.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
19    if ptr.is_null() {
20        None
21    } else {
22        Some(Tensor::from_raw(ptr))
23    }
24}
25
26fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
27    let mut values = collect_owned_tensors(box_handle);
28    if values.len() != 2 {
29        return None;
30    }
31    let second = values.pop()?;
32    let first = values.pop()?;
33    Some((first, second))
34}
35
/// Raw values of the `MPSGraphRandomDistribution` enumeration.
///
/// Pass one of these to `RandomOpDescriptor::new` or
/// `RandomOpDescriptor::set_distribution`.
pub mod random_distribution {
    /// `MPSGraphRandomDistributionUniform`: samples uniformly from a range.
    pub const UNIFORM: u64 = 0;
    /// `MPSGraphRandomDistributionNormal`: samples from a normal (Gaussian)
    /// distribution.
    pub const NORMAL: u64 = 1;
    /// `MPSGraphRandomDistributionTruncatedNormal`: samples from a normal
    /// distribution truncated to a range.
    pub const TRUNCATED_NORMAL: u64 = 2;
}
45
/// Raw values of the `MPSGraphRandomNormalSamplingMethod` enumeration.
///
/// Selects how normally distributed samples are generated; pass one of these to
/// `RandomOpDescriptor::set_sampling_method`.
pub mod random_normal_sampling_method {
    /// `MPSGraphRandomNormalSamplingInvCDF`: inverse-CDF sampling.
    pub const INV_CDF: u64 = 0;
    /// `MPSGraphRandomNormalSamplingBoxMuller`: Box–Muller transform.
    pub const BOX_MULLER: u64 = 1;
}
53
/// Owning wrapper around an `MPSGraphRandomOpDescriptor` handle.
///
/// A value is produced by `RandomOpDescriptor::new` and its handle is released
/// once when the wrapper is dropped.
pub struct RandomOpDescriptor {
    /// +1 retained Swift/ObjC object pointer; non-null for values built by `new`.
    ptr: *mut c_void,
}

// SAFETY: NOTE(review) — the wrapper holds only an opaque reference-counted object
// pointer with no thread-affine state visible here; moving it to another thread and
// releasing it there is assumed to be fine for ObjC retain/release. Confirm against
// the bridge implementation.
unsafe impl Send for RandomOpDescriptor {}
// SAFETY: NOTE(review) — the setters (`set_min`, `set_distribution`, ...) take `&self`
// and mutate the underlying ObjC object, so `Sync` permits two threads to write the
// same descriptor concurrently. This is sound only if the bridge/framework serialises
// those property writes; verify, or consider requiring `&mut self` for the setters.
unsafe impl Sync for RandomOpDescriptor {}
61
62impl Drop for RandomOpDescriptor {
63    fn drop(&mut self) {
64        if !self.ptr.is_null() {
65            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
66            unsafe { ffi::mpsgraph_object_release(self.ptr) };
67            self.ptr = ptr::null_mut();
68        }
69    }
70}
71
72impl RandomOpDescriptor {
73/// Calls the `MPSGraph` framework counterpart for `new`.
74    #[must_use]
75    pub fn new(distribution: u64, data_type: u32) -> Option<Self> {
76        // SAFETY: pure constructor with POD arguments.
77        let ptr = unsafe { ffi::mpsgraph_random_op_descriptor_new(distribution, data_type) };
78        if ptr.is_null() {
79            None
80        } else {
81            Some(Self { ptr })
82        }
83    }
84
85    #[must_use]
86    pub(crate) const fn as_ptr(&self) -> *mut c_void {
87        self.ptr
88    }
89
90/// Calls the `MPSGraph` framework counterpart for `distribution`.
91    #[must_use]
92    pub fn distribution(&self) -> u64 {
93        // SAFETY: `self.ptr` is a live descriptor handle.
94        unsafe { ffi::mpsgraph_random_op_descriptor_distribution(self.ptr) }
95    }
96
97/// Calls the `MPSGraph` framework counterpart for `set_distribution`.
98    pub fn set_distribution(&self, value: u64) -> Result<()> {
99        // SAFETY: `self.ptr` is a live descriptor handle.
100        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_distribution(self.ptr, value) };
101        if ok {
102            Ok(())
103        } else {
104            Err(Error::OperationFailed("failed to set random distribution"))
105        }
106    }
107
108/// Calls the `MPSGraph` framework counterpart for `data_type`.
109    #[must_use]
110    pub fn data_type(&self) -> u32 {
111        // SAFETY: `self.ptr` is a live descriptor handle.
112        unsafe { ffi::mpsgraph_random_op_descriptor_data_type(self.ptr) }
113    }
114
115/// Calls the `MPSGraph` framework counterpart for `set_data_type`.
116    pub fn set_data_type(&self, value: u32) -> Result<()> {
117        // SAFETY: `self.ptr` is a live descriptor handle.
118        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_data_type(self.ptr, value) };
119        if ok {
120            Ok(())
121        } else {
122            Err(Error::OperationFailed("failed to set random data type"))
123        }
124    }
125
126/// Calls the `MPSGraph` framework counterpart for `min`.
127    #[must_use]
128    pub fn min(&self) -> f32 {
129        // SAFETY: `self.ptr` is a live descriptor handle.
130        unsafe { ffi::mpsgraph_random_op_descriptor_min(self.ptr) }
131    }
132
133/// Calls the `MPSGraph` framework counterpart for `set_min`.
134    pub fn set_min(&self, value: f32) -> Result<()> {
135        // SAFETY: `self.ptr` is a live descriptor handle.
136        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min(self.ptr, value) };
137        if ok {
138            Ok(())
139        } else {
140            Err(Error::OperationFailed("failed to set random min"))
141        }
142    }
143
144/// Calls the `MPSGraph` framework counterpart for `max`.
145    #[must_use]
146    pub fn max(&self) -> f32 {
147        // SAFETY: `self.ptr` is a live descriptor handle.
148        unsafe { ffi::mpsgraph_random_op_descriptor_max(self.ptr) }
149    }
150
151/// Calls the `MPSGraph` framework counterpart for `set_max`.
152    pub fn set_max(&self, value: f32) -> Result<()> {
153        // SAFETY: `self.ptr` is a live descriptor handle.
154        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max(self.ptr, value) };
155        if ok {
156            Ok(())
157        } else {
158            Err(Error::OperationFailed("failed to set random max"))
159        }
160    }
161
162/// Calls the `MPSGraph` framework counterpart for `min_integer`.
163    #[must_use]
164    pub fn min_integer(&self) -> isize {
165        // SAFETY: `self.ptr` is a live descriptor handle.
166        unsafe { ffi::mpsgraph_random_op_descriptor_min_integer(self.ptr) }
167    }
168
169/// Calls the `MPSGraph` framework counterpart for `set_min_integer`.
170    pub fn set_min_integer(&self, value: isize) -> Result<()> {
171        // SAFETY: `self.ptr` is a live descriptor handle.
172        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min_integer(self.ptr, value) };
173        if ok {
174            Ok(())
175        } else {
176            Err(Error::OperationFailed("failed to set random minInteger"))
177        }
178    }
179
180/// Calls the `MPSGraph` framework counterpart for `max_integer`.
181    #[must_use]
182    pub fn max_integer(&self) -> isize {
183        // SAFETY: `self.ptr` is a live descriptor handle.
184        unsafe { ffi::mpsgraph_random_op_descriptor_max_integer(self.ptr) }
185    }
186
187/// Calls the `MPSGraph` framework counterpart for `set_max_integer`.
188    pub fn set_max_integer(&self, value: isize) -> Result<()> {
189        // SAFETY: `self.ptr` is a live descriptor handle.
190        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max_integer(self.ptr, value) };
191        if ok {
192            Ok(())
193        } else {
194            Err(Error::OperationFailed("failed to set random maxInteger"))
195        }
196    }
197
198/// Calls the `MPSGraph` framework counterpart for `mean`.
199    #[must_use]
200    pub fn mean(&self) -> f32 {
201        // SAFETY: `self.ptr` is a live descriptor handle.
202        unsafe { ffi::mpsgraph_random_op_descriptor_mean(self.ptr) }
203    }
204
205/// Calls the `MPSGraph` framework counterpart for `set_mean`.
206    pub fn set_mean(&self, value: f32) -> Result<()> {
207        // SAFETY: `self.ptr` is a live descriptor handle.
208        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_mean(self.ptr, value) };
209        if ok {
210            Ok(())
211        } else {
212            Err(Error::OperationFailed("failed to set random mean"))
213        }
214    }
215
216/// Calls the `MPSGraph` framework counterpart for `standard_deviation`.
217    #[must_use]
218    pub fn standard_deviation(&self) -> f32 {
219        // SAFETY: `self.ptr` is a live descriptor handle.
220        unsafe { ffi::mpsgraph_random_op_descriptor_standard_deviation(self.ptr) }
221    }
222
223/// Calls the `MPSGraph` framework counterpart for `set_standard_deviation`.
224    pub fn set_standard_deviation(&self, value: f32) -> Result<()> {
225        // SAFETY: `self.ptr` is a live descriptor handle.
226        let ok =
227            unsafe { ffi::mpsgraph_random_op_descriptor_set_standard_deviation(self.ptr, value) };
228        if ok {
229            Ok(())
230        } else {
231            Err(Error::OperationFailed(
232                "failed to set random standardDeviation",
233            ))
234        }
235    }
236
237/// Calls the `MPSGraph` framework counterpart for `sampling_method`.
238    #[must_use]
239    pub fn sampling_method(&self) -> u64 {
240        // SAFETY: `self.ptr` is a live descriptor handle.
241        unsafe { ffi::mpsgraph_random_op_descriptor_sampling_method(self.ptr) }
242    }
243
244/// Calls the `MPSGraph` framework counterpart for `set_sampling_method`.
245    pub fn set_sampling_method(&self, value: u64) -> Result<()> {
246        // SAFETY: `self.ptr` is a live descriptor handle.
247        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_sampling_method(self.ptr, value) };
248        if ok {
249            Ok(())
250        } else {
251            Err(Error::OperationFailed(
252                "failed to set random sampling method",
253            ))
254        }
255    }
256}
257
258impl crate::graph::Graph {
259/// Calls the `MPSGraph` framework counterpart for `random_philox_state_seed`.
260    #[must_use]
261    pub fn random_philox_state_seed(&self, seed: usize, name: Option<&str>) -> Option<Tensor> {
262        let name = optional_cstring(name);
263        // SAFETY: all handles remain valid for the duration of the call.
264        let ptr = unsafe {
265            ffi::mpsgraph_graph_random_philox_state_seed(self.as_ptr(), seed, cstring_ptr(&name))
266        };
267        wrap_tensor(ptr)
268    }
269
270/// Calls the `MPSGraph` framework counterpart for `random_philox_state_counter`.
271    #[must_use]
272    pub fn random_philox_state_counter(
273        &self,
274        counter_low: usize,
275        counter_high: usize,
276        key: usize,
277        name: Option<&str>,
278    ) -> Option<Tensor> {
279        let name = optional_cstring(name);
280        // SAFETY: all handles remain valid for the duration of the call.
281        let ptr = unsafe {
282            ffi::mpsgraph_graph_random_philox_state_counter(
283                self.as_ptr(),
284                counter_low,
285                counter_high,
286                key,
287                cstring_ptr(&name),
288            )
289        };
290        wrap_tensor(ptr)
291    }
292
293/// Calls the `MPSGraph` framework counterpart for `random_tensor`.
294    #[must_use]
295    pub fn random_tensor(
296        &self,
297        shape: &[usize],
298        descriptor: &RandomOpDescriptor,
299        name: Option<&str>,
300    ) -> Option<Tensor> {
301        let name = optional_cstring(name);
302        let shape_ptr = if shape.is_empty() {
303            ptr::null()
304        } else {
305            shape.as_ptr()
306        };
307        // SAFETY: all handles remain valid for the duration of the call.
308        let ptr = unsafe {
309            ffi::mpsgraph_graph_random_tensor(
310                self.as_ptr(),
311                shape_ptr,
312                shape.len(),
313                descriptor.as_ptr(),
314                cstring_ptr(&name),
315            )
316        };
317        wrap_tensor(ptr)
318    }
319
320/// Calls the `MPSGraph` framework counterpart for `random_tensor_shape_tensor`.
321    #[must_use]
322    pub fn random_tensor_shape_tensor(
323        &self,
324        shape_tensor: &Tensor,
325        descriptor: &RandomOpDescriptor,
326        name: Option<&str>,
327    ) -> Option<Tensor> {
328        let name = optional_cstring(name);
329        // SAFETY: all handles remain valid for the duration of the call.
330        let ptr = unsafe {
331            ffi::mpsgraph_graph_random_tensor_shape_tensor(
332                self.as_ptr(),
333                shape_tensor.as_ptr(),
334                descriptor.as_ptr(),
335                cstring_ptr(&name),
336            )
337        };
338        wrap_tensor(ptr)
339    }
340
341/// Calls the `MPSGraph` framework counterpart for `random_tensor_seed`.
342    #[must_use]
343    pub fn random_tensor_seed(
344        &self,
345        shape: &[usize],
346        descriptor: &RandomOpDescriptor,
347        seed: usize,
348        name: Option<&str>,
349    ) -> Option<Tensor> {
350        let name = optional_cstring(name);
351        let shape_ptr = if shape.is_empty() {
352            ptr::null()
353        } else {
354            shape.as_ptr()
355        };
356        // SAFETY: all handles remain valid for the duration of the call.
357        let ptr = unsafe {
358            ffi::mpsgraph_graph_random_tensor_seed(
359                self.as_ptr(),
360                shape_ptr,
361                shape.len(),
362                descriptor.as_ptr(),
363                seed,
364                cstring_ptr(&name),
365            )
366        };
367        wrap_tensor(ptr)
368    }
369
370/// Calls the `MPSGraph` framework counterpart for `random_tensor_shape_tensor_seed`.
371    #[must_use]
372    pub fn random_tensor_shape_tensor_seed(
373        &self,
374        shape_tensor: &Tensor,
375        descriptor: &RandomOpDescriptor,
376        seed: usize,
377        name: Option<&str>,
378    ) -> Option<Tensor> {
379        let name = optional_cstring(name);
380        // SAFETY: all handles remain valid for the duration of the call.
381        let ptr = unsafe {
382            ffi::mpsgraph_graph_random_tensor_shape_tensor_seed(
383                self.as_ptr(),
384                shape_tensor.as_ptr(),
385                descriptor.as_ptr(),
386                seed,
387                cstring_ptr(&name),
388            )
389        };
390        wrap_tensor(ptr)
391    }
392
393/// Calls the `MPSGraph` framework counterpart for `random_tensor_state`.
394    #[must_use]
395    pub fn random_tensor_state(
396        &self,
397        shape: &[usize],
398        descriptor: &RandomOpDescriptor,
399        state: &Tensor,
400        name: Option<&str>,
401    ) -> Option<(Tensor, Tensor)> {
402        let name = optional_cstring(name);
403        let shape_ptr = if shape.is_empty() {
404            ptr::null()
405        } else {
406            shape.as_ptr()
407        };
408        // SAFETY: all handles remain valid for the duration of the call.
409        let box_handle = unsafe {
410            ffi::mpsgraph_graph_random_tensor_state(
411                self.as_ptr(),
412                shape_ptr,
413                shape.len(),
414                descriptor.as_ptr(),
415                state.as_ptr(),
416                cstring_ptr(&name),
417            )
418        };
419        wrap_tensor_pair(box_handle)
420    }
421
422/// Calls the `MPSGraph` framework counterpart for `random_tensor_shape_tensor_state`.
423    #[must_use]
424    pub fn random_tensor_shape_tensor_state(
425        &self,
426        shape_tensor: &Tensor,
427        descriptor: &RandomOpDescriptor,
428        state: &Tensor,
429        name: Option<&str>,
430    ) -> Option<(Tensor, Tensor)> {
431        let name = optional_cstring(name);
432        // SAFETY: all handles remain valid for the duration of the call.
433        let box_handle = unsafe {
434            ffi::mpsgraph_graph_random_tensor_shape_tensor_state(
435                self.as_ptr(),
436                shape_tensor.as_ptr(),
437                descriptor.as_ptr(),
438                state.as_ptr(),
439                cstring_ptr(&name),
440            )
441        };
442        wrap_tensor_pair(box_handle)
443    }
444
445/// Calls the `MPSGraph` framework counterpart for `dropout`.
446    #[must_use]
447    pub fn dropout(&self, tensor: &Tensor, rate: f64, name: Option<&str>) -> Option<Tensor> {
448        let name = optional_cstring(name);
449        // SAFETY: all handles remain valid for the duration of the call.
450        let ptr = unsafe {
451            ffi::mpsgraph_graph_dropout(self.as_ptr(), tensor.as_ptr(), rate, cstring_ptr(&name))
452        };
453        wrap_tensor(ptr)
454    }
455
456/// Calls the `MPSGraph` framework counterpart for `dropout_tensor`.
457    #[must_use]
458    pub fn dropout_tensor(
459        &self,
460        tensor: &Tensor,
461        rate_tensor: &Tensor,
462        name: Option<&str>,
463    ) -> Option<Tensor> {
464        let name = optional_cstring(name);
465        // SAFETY: all handles remain valid for the duration of the call.
466        let ptr = unsafe {
467            ffi::mpsgraph_graph_dropout_tensor(
468                self.as_ptr(),
469                tensor.as_ptr(),
470                rate_tensor.as_ptr(),
471                cstring_ptr(&name),
472            )
473        };
474        wrap_tensor(ptr)
475    }
476}