Skip to main content

apple_mpsgraph/
random.rs

1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional operation label into an owned C string for the FFI layer.
///
/// Returns `None` when no label is given, and also when the label contains an interior NUL
/// byte (which `CString` cannot represent) — in that case the op is simply left unnamed.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    match CString::new(text) {
        Ok(converted) => Some(converted),
        Err(_) => None,
    }
}
12
/// Borrows a raw C-string pointer from an optional label, or null when there is no label.
///
/// The returned pointer is only valid while `value` is alive and unmodified, so callers must
/// keep the `Option<CString>` bound for the full duration of the FFI call.
// Taking `&Option<CString>` (instead of the lint's preferred `Option<&CString>`) lets every call
// site pass `&name` directly; the lint is silenced for that ergonomic reason.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
19    if ptr.is_null() {
20        None
21    } else {
22        Some(Tensor::from_raw(ptr))
23    }
24}
25
26fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
27    let mut values = collect_owned_tensors(box_handle);
28    if values.len() != 2 {
29        return None;
30    }
31    let second = values.pop()?;
32    let first = values.pop()?;
33    Some((first, second))
34}
35
/// `MPSGraphRandomDistribution` constants.
///
/// Raw values accepted by [`RandomOpDescriptor::new`] and `set_distribution`.
pub mod random_distribution {
    /// Uniform distribution.
    pub const UNIFORM: u64 = 0;
    /// Normal (Gaussian) distribution.
    pub const NORMAL: u64 = 1;
    /// Truncated normal distribution.
    pub const TRUNCATED_NORMAL: u64 = 2;
}

/// `MPSGraphRandomNormalSamplingMethod` constants.
///
/// Raw values accepted by `RandomOpDescriptor::set_sampling_method`.
pub mod random_normal_sampling_method {
    /// Inverse-CDF sampling.
    pub const INV_CDF: u64 = 0;
    /// Box–Muller transform sampling.
    pub const BOX_MULLER: u64 = 1;
}
48
/// Safe owner for `MPSGraphRandomOpDescriptor`.
///
/// Holds a +1 retained Swift/ObjC object pointer and releases it when dropped. Instances are
/// only handed out by [`RandomOpDescriptor::new`], which rejects null handles.
pub struct RandomOpDescriptor {
    // Retained descriptor handle owned by this wrapper; non-null for every instance from `new`.
    ptr: *mut c_void,
}

// SAFETY: the wrapper exclusively owns its retained handle, so sending it to another thread
// transfers that ownership (assumes the shim's release may be called from any thread).
unsafe impl Send for RandomOpDescriptor {}

// SAFETY: shared references only reach the object through the FFI shim.
// NOTE(review): the property setters take `&self` and mutate the underlying object, so `Sync`
// allows concurrent mutation from several threads. Confirm the shim/ObjC object tolerates this,
// or consider requiring `&mut self` for the setters.
unsafe impl Sync for RandomOpDescriptor {}
56
57impl Drop for RandomOpDescriptor {
58    fn drop(&mut self) {
59        if !self.ptr.is_null() {
60            // SAFETY: `ptr` is a +1 retained Swift/ObjC object pointer owned by this wrapper.
61            unsafe { ffi::mpsgraph_object_release(self.ptr) };
62            self.ptr = ptr::null_mut();
63        }
64    }
65}
66
67impl RandomOpDescriptor {
68    #[must_use]
69    pub fn new(distribution: u64, data_type: u32) -> Option<Self> {
70        // SAFETY: pure constructor with POD arguments.
71        let ptr = unsafe { ffi::mpsgraph_random_op_descriptor_new(distribution, data_type) };
72        if ptr.is_null() {
73            None
74        } else {
75            Some(Self { ptr })
76        }
77    }
78
79    #[must_use]
80    pub(crate) const fn as_ptr(&self) -> *mut c_void {
81        self.ptr
82    }
83
84    #[must_use]
85    pub fn distribution(&self) -> u64 {
86        // SAFETY: `self.ptr` is a live descriptor handle.
87        unsafe { ffi::mpsgraph_random_op_descriptor_distribution(self.ptr) }
88    }
89
90    pub fn set_distribution(&self, value: u64) -> Result<()> {
91        // SAFETY: `self.ptr` is a live descriptor handle.
92        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_distribution(self.ptr, value) };
93        if ok {
94            Ok(())
95        } else {
96            Err(Error::OperationFailed("failed to set random distribution"))
97        }
98    }
99
100    #[must_use]
101    pub fn data_type(&self) -> u32 {
102        // SAFETY: `self.ptr` is a live descriptor handle.
103        unsafe { ffi::mpsgraph_random_op_descriptor_data_type(self.ptr) }
104    }
105
106    pub fn set_data_type(&self, value: u32) -> Result<()> {
107        // SAFETY: `self.ptr` is a live descriptor handle.
108        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_data_type(self.ptr, value) };
109        if ok {
110            Ok(())
111        } else {
112            Err(Error::OperationFailed("failed to set random data type"))
113        }
114    }
115
116    #[must_use]
117    pub fn min(&self) -> f32 {
118        // SAFETY: `self.ptr` is a live descriptor handle.
119        unsafe { ffi::mpsgraph_random_op_descriptor_min(self.ptr) }
120    }
121
122    pub fn set_min(&self, value: f32) -> Result<()> {
123        // SAFETY: `self.ptr` is a live descriptor handle.
124        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min(self.ptr, value) };
125        if ok {
126            Ok(())
127        } else {
128            Err(Error::OperationFailed("failed to set random min"))
129        }
130    }
131
132    #[must_use]
133    pub fn max(&self) -> f32 {
134        // SAFETY: `self.ptr` is a live descriptor handle.
135        unsafe { ffi::mpsgraph_random_op_descriptor_max(self.ptr) }
136    }
137
138    pub fn set_max(&self, value: f32) -> Result<()> {
139        // SAFETY: `self.ptr` is a live descriptor handle.
140        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max(self.ptr, value) };
141        if ok {
142            Ok(())
143        } else {
144            Err(Error::OperationFailed("failed to set random max"))
145        }
146    }
147
148    #[must_use]
149    pub fn min_integer(&self) -> isize {
150        // SAFETY: `self.ptr` is a live descriptor handle.
151        unsafe { ffi::mpsgraph_random_op_descriptor_min_integer(self.ptr) }
152    }
153
154    pub fn set_min_integer(&self, value: isize) -> Result<()> {
155        // SAFETY: `self.ptr` is a live descriptor handle.
156        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min_integer(self.ptr, value) };
157        if ok {
158            Ok(())
159        } else {
160            Err(Error::OperationFailed("failed to set random minInteger"))
161        }
162    }
163
164    #[must_use]
165    pub fn max_integer(&self) -> isize {
166        // SAFETY: `self.ptr` is a live descriptor handle.
167        unsafe { ffi::mpsgraph_random_op_descriptor_max_integer(self.ptr) }
168    }
169
170    pub fn set_max_integer(&self, value: isize) -> Result<()> {
171        // SAFETY: `self.ptr` is a live descriptor handle.
172        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max_integer(self.ptr, value) };
173        if ok {
174            Ok(())
175        } else {
176            Err(Error::OperationFailed("failed to set random maxInteger"))
177        }
178    }
179
180    #[must_use]
181    pub fn mean(&self) -> f32 {
182        // SAFETY: `self.ptr` is a live descriptor handle.
183        unsafe { ffi::mpsgraph_random_op_descriptor_mean(self.ptr) }
184    }
185
186    pub fn set_mean(&self, value: f32) -> Result<()> {
187        // SAFETY: `self.ptr` is a live descriptor handle.
188        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_mean(self.ptr, value) };
189        if ok {
190            Ok(())
191        } else {
192            Err(Error::OperationFailed("failed to set random mean"))
193        }
194    }
195
196    #[must_use]
197    pub fn standard_deviation(&self) -> f32 {
198        // SAFETY: `self.ptr` is a live descriptor handle.
199        unsafe { ffi::mpsgraph_random_op_descriptor_standard_deviation(self.ptr) }
200    }
201
202    pub fn set_standard_deviation(&self, value: f32) -> Result<()> {
203        // SAFETY: `self.ptr` is a live descriptor handle.
204        let ok =
205            unsafe { ffi::mpsgraph_random_op_descriptor_set_standard_deviation(self.ptr, value) };
206        if ok {
207            Ok(())
208        } else {
209            Err(Error::OperationFailed(
210                "failed to set random standardDeviation",
211            ))
212        }
213    }
214
215    #[must_use]
216    pub fn sampling_method(&self) -> u64 {
217        // SAFETY: `self.ptr` is a live descriptor handle.
218        unsafe { ffi::mpsgraph_random_op_descriptor_sampling_method(self.ptr) }
219    }
220
221    pub fn set_sampling_method(&self, value: u64) -> Result<()> {
222        // SAFETY: `self.ptr` is a live descriptor handle.
223        let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_sampling_method(self.ptr, value) };
224        if ok {
225            Ok(())
226        } else {
227            Err(Error::OperationFailed(
228                "failed to set random sampling method",
229            ))
230        }
231    }
232}
233
234impl crate::graph::Graph {
235    #[must_use]
236    pub fn random_philox_state_seed(&self, seed: usize, name: Option<&str>) -> Option<Tensor> {
237        let name = optional_cstring(name);
238        // SAFETY: all handles remain valid for the duration of the call.
239        let ptr = unsafe {
240            ffi::mpsgraph_graph_random_philox_state_seed(self.as_ptr(), seed, cstring_ptr(&name))
241        };
242        wrap_tensor(ptr)
243    }
244
245    #[must_use]
246    pub fn random_philox_state_counter(
247        &self,
248        counter_low: usize,
249        counter_high: usize,
250        key: usize,
251        name: Option<&str>,
252    ) -> Option<Tensor> {
253        let name = optional_cstring(name);
254        // SAFETY: all handles remain valid for the duration of the call.
255        let ptr = unsafe {
256            ffi::mpsgraph_graph_random_philox_state_counter(
257                self.as_ptr(),
258                counter_low,
259                counter_high,
260                key,
261                cstring_ptr(&name),
262            )
263        };
264        wrap_tensor(ptr)
265    }
266
267    #[must_use]
268    pub fn random_tensor(
269        &self,
270        shape: &[usize],
271        descriptor: &RandomOpDescriptor,
272        name: Option<&str>,
273    ) -> Option<Tensor> {
274        let name = optional_cstring(name);
275        let shape_ptr = if shape.is_empty() {
276            ptr::null()
277        } else {
278            shape.as_ptr()
279        };
280        // SAFETY: all handles remain valid for the duration of the call.
281        let ptr = unsafe {
282            ffi::mpsgraph_graph_random_tensor(
283                self.as_ptr(),
284                shape_ptr,
285                shape.len(),
286                descriptor.as_ptr(),
287                cstring_ptr(&name),
288            )
289        };
290        wrap_tensor(ptr)
291    }
292
293    #[must_use]
294    pub fn random_tensor_shape_tensor(
295        &self,
296        shape_tensor: &Tensor,
297        descriptor: &RandomOpDescriptor,
298        name: Option<&str>,
299    ) -> Option<Tensor> {
300        let name = optional_cstring(name);
301        // SAFETY: all handles remain valid for the duration of the call.
302        let ptr = unsafe {
303            ffi::mpsgraph_graph_random_tensor_shape_tensor(
304                self.as_ptr(),
305                shape_tensor.as_ptr(),
306                descriptor.as_ptr(),
307                cstring_ptr(&name),
308            )
309        };
310        wrap_tensor(ptr)
311    }
312
313    #[must_use]
314    pub fn random_tensor_seed(
315        &self,
316        shape: &[usize],
317        descriptor: &RandomOpDescriptor,
318        seed: usize,
319        name: Option<&str>,
320    ) -> Option<Tensor> {
321        let name = optional_cstring(name);
322        let shape_ptr = if shape.is_empty() {
323            ptr::null()
324        } else {
325            shape.as_ptr()
326        };
327        // SAFETY: all handles remain valid for the duration of the call.
328        let ptr = unsafe {
329            ffi::mpsgraph_graph_random_tensor_seed(
330                self.as_ptr(),
331                shape_ptr,
332                shape.len(),
333                descriptor.as_ptr(),
334                seed,
335                cstring_ptr(&name),
336            )
337        };
338        wrap_tensor(ptr)
339    }
340
341    #[must_use]
342    pub fn random_tensor_shape_tensor_seed(
343        &self,
344        shape_tensor: &Tensor,
345        descriptor: &RandomOpDescriptor,
346        seed: usize,
347        name: Option<&str>,
348    ) -> Option<Tensor> {
349        let name = optional_cstring(name);
350        // SAFETY: all handles remain valid for the duration of the call.
351        let ptr = unsafe {
352            ffi::mpsgraph_graph_random_tensor_shape_tensor_seed(
353                self.as_ptr(),
354                shape_tensor.as_ptr(),
355                descriptor.as_ptr(),
356                seed,
357                cstring_ptr(&name),
358            )
359        };
360        wrap_tensor(ptr)
361    }
362
363    #[must_use]
364    pub fn random_tensor_state(
365        &self,
366        shape: &[usize],
367        descriptor: &RandomOpDescriptor,
368        state: &Tensor,
369        name: Option<&str>,
370    ) -> Option<(Tensor, Tensor)> {
371        let name = optional_cstring(name);
372        let shape_ptr = if shape.is_empty() {
373            ptr::null()
374        } else {
375            shape.as_ptr()
376        };
377        // SAFETY: all handles remain valid for the duration of the call.
378        let box_handle = unsafe {
379            ffi::mpsgraph_graph_random_tensor_state(
380                self.as_ptr(),
381                shape_ptr,
382                shape.len(),
383                descriptor.as_ptr(),
384                state.as_ptr(),
385                cstring_ptr(&name),
386            )
387        };
388        wrap_tensor_pair(box_handle)
389    }
390
391    #[must_use]
392    pub fn random_tensor_shape_tensor_state(
393        &self,
394        shape_tensor: &Tensor,
395        descriptor: &RandomOpDescriptor,
396        state: &Tensor,
397        name: Option<&str>,
398    ) -> Option<(Tensor, Tensor)> {
399        let name = optional_cstring(name);
400        // SAFETY: all handles remain valid for the duration of the call.
401        let box_handle = unsafe {
402            ffi::mpsgraph_graph_random_tensor_shape_tensor_state(
403                self.as_ptr(),
404                shape_tensor.as_ptr(),
405                descriptor.as_ptr(),
406                state.as_ptr(),
407                cstring_ptr(&name),
408            )
409        };
410        wrap_tensor_pair(box_handle)
411    }
412
413    #[must_use]
414    pub fn dropout(&self, tensor: &Tensor, rate: f64, name: Option<&str>) -> Option<Tensor> {
415        let name = optional_cstring(name);
416        // SAFETY: all handles remain valid for the duration of the call.
417        let ptr = unsafe {
418            ffi::mpsgraph_graph_dropout(self.as_ptr(), tensor.as_ptr(), rate, cstring_ptr(&name))
419        };
420        wrap_tensor(ptr)
421    }
422
423    #[must_use]
424    pub fn dropout_tensor(
425        &self,
426        tensor: &Tensor,
427        rate_tensor: &Tensor,
428        name: Option<&str>,
429    ) -> Option<Tensor> {
430        let name = optional_cstring(name);
431        // SAFETY: all handles remain valid for the duration of the call.
432        let ptr = unsafe {
433            ffi::mpsgraph_graph_dropout_tensor(
434                self.as_ptr(),
435                tensor.as_ptr(),
436                rate_tensor.as_ptr(),
437                cstring_ptr(&name),
438            )
439        };
440        wrap_tensor(ptr)
441    }
442}