Skip to main content

apple_mpsgraph/
random.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional op debug name into a C string for the FFI layer.
///
/// Returns `None` when `name` is `None`. A name containing an interior NUL
/// byte is truncated at the first NUL (the same text a C consumer would have
/// read) rather than being discarded entirely, so the op keeps a usable name.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let name = name?;
    // `split` always yields at least one piece: the text before the first NUL,
    // or the whole string when it contains none.
    let prefix = name.split('\0').next().unwrap_or(name);
    // `prefix` has no NUL bytes, so this conversion cannot fail; `.ok()` keeps
    // the function panic-free regardless.
    CString::new(prefix).ok()
}
12
/// Borrows a raw C-string pointer out of an optional owned name.
///
/// Yields a null pointer for `None`. The returned pointer is only valid while
/// `value` is alive and unmodified, so callers keep the `Option<CString>`
/// bound to a local for the whole duration of the FFI call.
// Takes `&Option<CString>` so call sites can pass the local produced by
// `optional_cstring` directly; that signature is what `ref_option` flags.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
19    if ptr.is_null() {
20        None
21    } else {
22        Some(Tensor::from_raw(ptr))
23    }
24}
25
26fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
27    let mut values = collect_owned_tensors(box_handle);
28    if values.len() != 2 {
29        return None;
30    }
31    let second = values.pop()?;
32    let first = values.pop()?;
33    Some((first, second))
34}
35
/// `MPSGraphRandomDistribution` constants.
///
/// Raw values passed as the `distribution` argument of
/// `RandomOpDescriptor::new` and to `RandomOpDescriptor::set_distribution`.
pub mod random_distribution {
    /// `MPSGraphRandomDistributionUniform`.
    pub const UNIFORM: u64 = 0;
    /// `MPSGraphRandomDistributionNormal`.
    pub const NORMAL: u64 = 1;
    /// `MPSGraphRandomDistributionTruncatedNormal`.
    pub const TRUNCATED_NORMAL: u64 = 2;
}
42
/// `MPSGraphRandomNormalSamplingMethod` constants.
///
/// Raw values passed to `RandomOpDescriptor::set_sampling_method`.
pub mod random_normal_sampling_method {
    /// `MPSGraphRandomNormalSamplingInvCDF` (inverse-CDF transform).
    pub const INV_CDF: u64 = 0;
    /// `MPSGraphRandomNormalSamplingBoxMuller` (Box–Muller transform).
    pub const BOX_MULLER: u64 = 1;
}
48
/// Safe owner for `MPSGraphRandomOpDescriptor`.
///
/// Holds one retained reference to the bridged descriptor object and releases
/// it on drop. Create it with `RandomOpDescriptor::new`, configure it with the
/// `set_*` methods, and pass it by reference to the `Graph::random_tensor*`
/// builders.
pub struct RandomOpDescriptor {
    // Retained descriptor handle; non-null for every value produced by `new`.
    ptr: *mut c_void,
}
53
// SAFETY: the wrapper exclusively owns one retained reference to the descriptor
// object (released in `Drop`) and holds no other Rust-side state.
// NOTE(review): thread-safety of the underlying Swift/ObjC object is assumed,
// not shown here. In particular, the `set_*` methods mutate the descriptor
// through `&self`, so `Sync` permits concurrent setter calls from several
// threads. Confirm the bridge tolerates that, or that callers never do it.
unsafe impl Send for RandomOpDescriptor {}
unsafe impl Sync for RandomOpDescriptor {}
56
57impl Drop for RandomOpDescriptor {
58    fn drop(&mut self) {
59        if !self.ptr.is_null() {
60            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
61            unsafe { ffi::mpsgraph_object_release(self.ptr) };
62            self.ptr = ptr::null_mut();
63        }
64    }
65}
66
67impl RandomOpDescriptor {
68    #[must_use]
69    pub fn new(distribution: u64, data_type: u32) -> Option<Self> {
70        // SAFETY: pure constructor with POD arguments.
71        let ptr = unsafe { ffi::mpsgraph_random_op_descriptor_new(distribution, data_type) };
72        if ptr.is_null() {
73            None
74        } else {
75            Some(Self { ptr })
76        }
77    }
78
79    #[must_use]
80    pub(crate) const fn as_ptr(&self) -> *mut c_void {
81        self.ptr
82    }
83
84    #[must_use]
85    pub fn distribution(&self) -> u64 {
86        // SAFETY: `self.ptr` is a live descriptor handle.
87        unsafe { ffi::mpsgraph_random_op_descriptor_distribution(self.ptr) }
88    }
89
90    pub fn set_distribution(&self, value: u64) -> Result<()> {
91        // SAFETY: `self.ptr` is a live descriptor handle.
92        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_distribution(self.ptr, value) };
93        if ok {
94            Ok(())
95        } else {
96            Err(Error::OperationFailed("failed to set random distribution"))
97        }
98    }
99
100    #[must_use]
101    pub fn data_type(&self) -> u32 {
102        // SAFETY: `self.ptr` is a live descriptor handle.
103        unsafe { ffi::mpsgraph_random_op_descriptor_data_type(self.ptr) }
104    }
105
106    pub fn set_data_type(&self, value: u32) -> Result<()> {
107        // SAFETY: `self.ptr` is a live descriptor handle.
108        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_data_type(self.ptr, value) };
109        if ok {
110            Ok(())
111        } else {
112            Err(Error::OperationFailed("failed to set random data type"))
113        }
114    }
115
116    #[must_use]
117    pub fn min(&self) -> f32 {
118        // SAFETY: `self.ptr` is a live descriptor handle.
119        unsafe { ffi::mpsgraph_random_op_descriptor_min(self.ptr) }
120    }
121
122    pub fn set_min(&self, value: f32) -> Result<()> {
123        // SAFETY: `self.ptr` is a live descriptor handle.
124        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min(self.ptr, value) };
125        if ok {
126            Ok(())
127        } else {
128            Err(Error::OperationFailed("failed to set random min"))
129        }
130    }
131
132    #[must_use]
133    pub fn max(&self) -> f32 {
134        // SAFETY: `self.ptr` is a live descriptor handle.
135        unsafe { ffi::mpsgraph_random_op_descriptor_max(self.ptr) }
136    }
137
138    pub fn set_max(&self, value: f32) -> Result<()> {
139        // SAFETY: `self.ptr` is a live descriptor handle.
140        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max(self.ptr, value) };
141        if ok {
142            Ok(())
143        } else {
144            Err(Error::OperationFailed("failed to set random max"))
145        }
146    }
147
148    #[must_use]
149    pub fn min_integer(&self) -> isize {
150        // SAFETY: `self.ptr` is a live descriptor handle.
151        unsafe { ffi::mpsgraph_random_op_descriptor_min_integer(self.ptr) }
152    }
153
154    pub fn set_min_integer(&self, value: isize) -> Result<()> {
155        // SAFETY: `self.ptr` is a live descriptor handle.
156        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min_integer(self.ptr, value) };
157        if ok {
158            Ok(())
159        } else {
160            Err(Error::OperationFailed("failed to set random minInteger"))
161        }
162    }
163
164    #[must_use]
165    pub fn max_integer(&self) -> isize {
166        // SAFETY: `self.ptr` is a live descriptor handle.
167        unsafe { ffi::mpsgraph_random_op_descriptor_max_integer(self.ptr) }
168    }
169
170    pub fn set_max_integer(&self, value: isize) -> Result<()> {
171        // SAFETY: `self.ptr` is a live descriptor handle.
172        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max_integer(self.ptr, value) };
173        if ok {
174            Ok(())
175        } else {
176            Err(Error::OperationFailed("failed to set random maxInteger"))
177        }
178    }
179
180    #[must_use]
181    pub fn mean(&self) -> f32 {
182        // SAFETY: `self.ptr` is a live descriptor handle.
183        unsafe { ffi::mpsgraph_random_op_descriptor_mean(self.ptr) }
184    }
185
186    pub fn set_mean(&self, value: f32) -> Result<()> {
187        // SAFETY: `self.ptr` is a live descriptor handle.
188        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_mean(self.ptr, value) };
189        if ok {
190            Ok(())
191        } else {
192            Err(Error::OperationFailed("failed to set random mean"))
193        }
194    }
195
196    #[must_use]
197    pub fn standard_deviation(&self) -> f32 {
198        // SAFETY: `self.ptr` is a live descriptor handle.
199        unsafe { ffi::mpsgraph_random_op_descriptor_standard_deviation(self.ptr) }
200    }
201
202    pub fn set_standard_deviation(&self, value: f32) -> Result<()> {
203        // SAFETY: `self.ptr` is a live descriptor handle.
204        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_standard_deviation(self.ptr, value) };
205        if ok {
206            Ok(())
207        } else {
208            Err(Error::OperationFailed("failed to set random standardDeviation"))
209        }
210    }
211
212    #[must_use]
213    pub fn sampling_method(&self) -> u64 {
214        // SAFETY: `self.ptr` is a live descriptor handle.
215        unsafe { ffi::mpsgraph_random_op_descriptor_sampling_method(self.ptr) }
216    }
217
218    pub fn set_sampling_method(&self, value: u64) -> Result<()> {
219        // SAFETY: `self.ptr` is a live descriptor handle.
220        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_sampling_method(self.ptr, value) };
221        if ok {
222            Ok(())
223        } else {
224            Err(Error::OperationFailed("failed to set random sampling method"))
225        }
226    }
227}
228
229impl crate::graph::Graph {
230    #[must_use]
231    pub fn random_philox_state_seed(&self, seed: usize, name: Option<&str>) -> Option<Tensor> {
232        let name = optional_cstring(name);
233        // SAFETY: all handles remain valid for the duration of the call.
234        let ptr = unsafe { ffi::mpsgraph_graph_random_philox_state_seed(self.as_ptr(), seed, cstring_ptr(&name)) };
235        wrap_tensor(ptr)
236    }
237
238    #[must_use]
239    pub fn random_philox_state_counter(
240        &self,
241        counter_low: usize,
242        counter_high: usize,
243        key: usize,
244        name: Option<&str>,
245    ) -> Option<Tensor> {
246        let name = optional_cstring(name);
247        // SAFETY: all handles remain valid for the duration of the call.
248        let ptr = unsafe {
249            ffi::mpsgraph_graph_random_philox_state_counter(
250                self.as_ptr(),
251                counter_low,
252                counter_high,
253                key,
254                cstring_ptr(&name),
255            )
256        };
257        wrap_tensor(ptr)
258    }
259
260    #[must_use]
261    pub fn random_tensor(
262        &self,
263        shape: &[usize],
264        descriptor: &RandomOpDescriptor,
265        name: Option<&str>,
266    ) -> Option<Tensor> {
267        let name = optional_cstring(name);
268        let shape_ptr = if shape.is_empty() { ptr::null() } else { shape.as_ptr() };
269        // SAFETY: all handles remain valid for the duration of the call.
270        let ptr = unsafe {
271            ffi::mpsgraph_graph_random_tensor(
272                self.as_ptr(),
273                shape_ptr,
274                shape.len(),
275                descriptor.as_ptr(),
276                cstring_ptr(&name),
277            )
278        };
279        wrap_tensor(ptr)
280    }
281
282    #[must_use]
283    pub fn random_tensor_shape_tensor(
284        &self,
285        shape_tensor: &Tensor,
286        descriptor: &RandomOpDescriptor,
287        name: Option<&str>,
288    ) -> Option<Tensor> {
289        let name = optional_cstring(name);
290        // SAFETY: all handles remain valid for the duration of the call.
291        let ptr = unsafe {
292            ffi::mpsgraph_graph_random_tensor_shape_tensor(
293                self.as_ptr(),
294                shape_tensor.as_ptr(),
295                descriptor.as_ptr(),
296                cstring_ptr(&name),
297            )
298        };
299        wrap_tensor(ptr)
300    }
301
302    #[must_use]
303    pub fn random_tensor_seed(
304        &self,
305        shape: &[usize],
306        descriptor: &RandomOpDescriptor,
307        seed: usize,
308        name: Option<&str>,
309    ) -> Option<Tensor> {
310        let name = optional_cstring(name);
311        let shape_ptr = if shape.is_empty() { ptr::null() } else { shape.as_ptr() };
312        // SAFETY: all handles remain valid for the duration of the call.
313        let ptr = unsafe {
314            ffi::mpsgraph_graph_random_tensor_seed(
315                self.as_ptr(),
316                shape_ptr,
317                shape.len(),
318                descriptor.as_ptr(),
319                seed,
320                cstring_ptr(&name),
321            )
322        };
323        wrap_tensor(ptr)
324    }
325
326    #[must_use]
327    pub fn random_tensor_shape_tensor_seed(
328        &self,
329        shape_tensor: &Tensor,
330        descriptor: &RandomOpDescriptor,
331        seed: usize,
332        name: Option<&str>,
333    ) -> Option<Tensor> {
334        let name = optional_cstring(name);
335        // SAFETY: all handles remain valid for the duration of the call.
336        let ptr = unsafe {
337            ffi::mpsgraph_graph_random_tensor_shape_tensor_seed(
338                self.as_ptr(),
339                shape_tensor.as_ptr(),
340                descriptor.as_ptr(),
341                seed,
342                cstring_ptr(&name),
343            )
344        };
345        wrap_tensor(ptr)
346    }
347
348    #[must_use]
349    pub fn random_tensor_state(
350        &self,
351        shape: &[usize],
352        descriptor: &RandomOpDescriptor,
353        state: &Tensor,
354        name: Option<&str>,
355    ) -> Option<(Tensor, Tensor)> {
356        let name = optional_cstring(name);
357        let shape_ptr = if shape.is_empty() { ptr::null() } else { shape.as_ptr() };
358        // SAFETY: all handles remain valid for the duration of the call.
359        let box_handle = unsafe {
360            ffi::mpsgraph_graph_random_tensor_state(
361                self.as_ptr(),
362                shape_ptr,
363                shape.len(),
364                descriptor.as_ptr(),
365                state.as_ptr(),
366                cstring_ptr(&name),
367            )
368        };
369        wrap_tensor_pair(box_handle)
370    }
371
372    #[must_use]
373    pub fn random_tensor_shape_tensor_state(
374        &self,
375        shape_tensor: &Tensor,
376        descriptor: &RandomOpDescriptor,
377        state: &Tensor,
378        name: Option<&str>,
379    ) -> Option<(Tensor, Tensor)> {
380        let name = optional_cstring(name);
381        // SAFETY: all handles remain valid for the duration of the call.
382        let box_handle = unsafe {
383            ffi::mpsgraph_graph_random_tensor_shape_tensor_state(
384                self.as_ptr(),
385                shape_tensor.as_ptr(),
386                descriptor.as_ptr(),
387                state.as_ptr(),
388                cstring_ptr(&name),
389            )
390        };
391        wrap_tensor_pair(box_handle)
392    }
393
394    #[must_use]
395    pub fn dropout(&self, tensor: &Tensor, rate: f64, name: Option<&str>) -> Option<Tensor> {
396        let name = optional_cstring(name);
397        // SAFETY: all handles remain valid for the duration of the call.
398        let ptr = unsafe {
399            ffi::mpsgraph_graph_dropout(self.as_ptr(), tensor.as_ptr(), rate, cstring_ptr(&name))
400        };
401        wrap_tensor(ptr)
402    }
403
404    #[must_use]
405    pub fn dropout_tensor(
406        &self,
407        tensor: &Tensor,
408        rate_tensor: &Tensor,
409        name: Option<&str>,
410    ) -> Option<Tensor> {
411        let name = optional_cstring(name);
412        // SAFETY: all handles remain valid for the duration of the call.
413        let ptr = unsafe {
414            ffi::mpsgraph_graph_dropout_tensor(
415                self.as_ptr(),
416                tensor.as_ptr(),
417                rate_tensor.as_ptr(),
418                cstring_ptr(&name),
419            )
420        };
421        wrap_tensor(ptr)
422    }
423}