1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional op name into an owned C string for the FFI layer.
///
/// Returns `None` when no name is supplied. It also returns `None` when `name` contains an
/// interior NUL byte, which `CString::new` rejects. In that case the name is silently dropped
/// rather than surfaced as an error.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    match CString::new(text) {
        Ok(converted) => Some(converted),
        Err(_) => None,
    }
}
12
/// Borrows the raw pointer of an optional C string, or returns null when there is none.
///
/// The pointer is valid only while `value` is alive and unmodified. Call sites in this file keep
/// the `Option<CString>` in a local for the whole FFI call.
// Taking `&Option<CString>` rather than `Option<&CString>` lets call sites pass `&name`
// directly, which is why the `ref_option` lint is silenced here.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(text) => text.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
19 if ptr.is_null() {
20 None
21 } else {
22 Some(Tensor::from_raw(ptr))
23 }
24}
25
26fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
27 let mut values = collect_owned_tensors(box_handle);
28 if values.len() != 2 {
29 return None;
30 }
31 let second = values.pop()?;
32 let first = values.pop()?;
33 Some((first, second))
34}
35
/// Distribution selectors accepted by `RandomOpDescriptor::new` and
/// `RandomOpDescriptor::set_distribution`.
///
/// The numeric values mirror Apple's `MPSGraphRandomDistribution` enum.
pub mod random_distribution {
    /// Uniform distribution.
    pub const UNIFORM: u64 = 0;
    /// Normal (Gaussian) distribution.
    pub const NORMAL: u64 = 1;
    /// Truncated normal distribution.
    pub const TRUNCATED_NORMAL: u64 = 2;
}
45
/// Sampling-method selectors accepted by `RandomOpDescriptor::set_sampling_method`.
///
/// The numeric values mirror Apple's `MPSGraphRandomNormalSamplingMethod` enum.
pub mod random_normal_sampling_method {
    /// Inverse-CDF sampling.
    pub const INV_CDF: u64 = 0;
    /// Box–Muller sampling.
    pub const BOX_MULLER: u64 = 1;
}
53
/// Owning wrapper around a random-op descriptor handle created by the MPSGraph FFI layer.
///
/// Instances are only created through `RandomOpDescriptor::new`, which rejects null handles.
/// The handle is released on drop.
pub struct RandomOpDescriptor {
    // Raw descriptor handle; non-null for every live instance, owned by this wrapper.
    ptr: *mut c_void,
}
58
// SAFETY: NOTE(review): these impls assert that the underlying MPSGraph descriptor object may be
// released from, and accessed on, any thread. The setters on this type take `&self`, so with
// `Sync` two threads can call setters on the same descriptor concurrently. Confirm that the
// Objective-C object tolerates this, or document that callers must synchronise externally.
unsafe impl Send for RandomOpDescriptor {}
unsafe impl Sync for RandomOpDescriptor {}
61
62impl Drop for RandomOpDescriptor {
63 fn drop(&mut self) {
64 if !self.ptr.is_null() {
65 unsafe { ffi::mpsgraph_object_release(self.ptr) };
67 self.ptr = ptr::null_mut();
68 }
69 }
70}
71
72impl RandomOpDescriptor {
73#[must_use]
75 pub fn new(distribution: u64, data_type: u32) -> Option<Self> {
76 let ptr = unsafe { ffi::mpsgraph_random_op_descriptor_new(distribution, data_type) };
78 if ptr.is_null() {
79 None
80 } else {
81 Some(Self { ptr })
82 }
83 }
84
85 #[must_use]
86 pub(crate) const fn as_ptr(&self) -> *mut c_void {
87 self.ptr
88 }
89
90#[must_use]
92 pub fn distribution(&self) -> u64 {
93 unsafe { ffi::mpsgraph_random_op_descriptor_distribution(self.ptr) }
95 }
96
97pub fn set_distribution(&self, value: u64) -> Result<()> {
99 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_distribution(self.ptr, value) };
101 if ok {
102 Ok(())
103 } else {
104 Err(Error::OperationFailed("failed to set random distribution"))
105 }
106 }
107
108#[must_use]
110 pub fn data_type(&self) -> u32 {
111 unsafe { ffi::mpsgraph_random_op_descriptor_data_type(self.ptr) }
113 }
114
115pub fn set_data_type(&self, value: u32) -> Result<()> {
117 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_data_type(self.ptr, value) };
119 if ok {
120 Ok(())
121 } else {
122 Err(Error::OperationFailed("failed to set random data type"))
123 }
124 }
125
126#[must_use]
128 pub fn min(&self) -> f32 {
129 unsafe { ffi::mpsgraph_random_op_descriptor_min(self.ptr) }
131 }
132
133pub fn set_min(&self, value: f32) -> Result<()> {
135 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min(self.ptr, value) };
137 if ok {
138 Ok(())
139 } else {
140 Err(Error::OperationFailed("failed to set random min"))
141 }
142 }
143
144#[must_use]
146 pub fn max(&self) -> f32 {
147 unsafe { ffi::mpsgraph_random_op_descriptor_max(self.ptr) }
149 }
150
151pub fn set_max(&self, value: f32) -> Result<()> {
153 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max(self.ptr, value) };
155 if ok {
156 Ok(())
157 } else {
158 Err(Error::OperationFailed("failed to set random max"))
159 }
160 }
161
162#[must_use]
164 pub fn min_integer(&self) -> isize {
165 unsafe { ffi::mpsgraph_random_op_descriptor_min_integer(self.ptr) }
167 }
168
169pub fn set_min_integer(&self, value: isize) -> Result<()> {
171 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min_integer(self.ptr, value) };
173 if ok {
174 Ok(())
175 } else {
176 Err(Error::OperationFailed("failed to set random minInteger"))
177 }
178 }
179
180#[must_use]
182 pub fn max_integer(&self) -> isize {
183 unsafe { ffi::mpsgraph_random_op_descriptor_max_integer(self.ptr) }
185 }
186
187pub fn set_max_integer(&self, value: isize) -> Result<()> {
189 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max_integer(self.ptr, value) };
191 if ok {
192 Ok(())
193 } else {
194 Err(Error::OperationFailed("failed to set random maxInteger"))
195 }
196 }
197
198#[must_use]
200 pub fn mean(&self) -> f32 {
201 unsafe { ffi::mpsgraph_random_op_descriptor_mean(self.ptr) }
203 }
204
205pub fn set_mean(&self, value: f32) -> Result<()> {
207 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_mean(self.ptr, value) };
209 if ok {
210 Ok(())
211 } else {
212 Err(Error::OperationFailed("failed to set random mean"))
213 }
214 }
215
216#[must_use]
218 pub fn standard_deviation(&self) -> f32 {
219 unsafe { ffi::mpsgraph_random_op_descriptor_standard_deviation(self.ptr) }
221 }
222
223pub fn set_standard_deviation(&self, value: f32) -> Result<()> {
225 let ok =
227 unsafe { ffi::mpsgraph_random_op_descriptor_set_standard_deviation(self.ptr, value) };
228 if ok {
229 Ok(())
230 } else {
231 Err(Error::OperationFailed(
232 "failed to set random standardDeviation",
233 ))
234 }
235 }
236
237#[must_use]
239 pub fn sampling_method(&self) -> u64 {
240 unsafe { ffi::mpsgraph_random_op_descriptor_sampling_method(self.ptr) }
242 }
243
244pub fn set_sampling_method(&self, value: u64) -> Result<()> {
246 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_sampling_method(self.ptr, value) };
248 if ok {
249 Ok(())
250 } else {
251 Err(Error::OperationFailed(
252 "failed to set random sampling method",
253 ))
254 }
255 }
256}
257
258impl crate::graph::Graph {
259#[must_use]
261 pub fn random_philox_state_seed(&self, seed: usize, name: Option<&str>) -> Option<Tensor> {
262 let name = optional_cstring(name);
263 let ptr = unsafe {
265 ffi::mpsgraph_graph_random_philox_state_seed(self.as_ptr(), seed, cstring_ptr(&name))
266 };
267 wrap_tensor(ptr)
268 }
269
270#[must_use]
272 pub fn random_philox_state_counter(
273 &self,
274 counter_low: usize,
275 counter_high: usize,
276 key: usize,
277 name: Option<&str>,
278 ) -> Option<Tensor> {
279 let name = optional_cstring(name);
280 let ptr = unsafe {
282 ffi::mpsgraph_graph_random_philox_state_counter(
283 self.as_ptr(),
284 counter_low,
285 counter_high,
286 key,
287 cstring_ptr(&name),
288 )
289 };
290 wrap_tensor(ptr)
291 }
292
293#[must_use]
295 pub fn random_tensor(
296 &self,
297 shape: &[usize],
298 descriptor: &RandomOpDescriptor,
299 name: Option<&str>,
300 ) -> Option<Tensor> {
301 let name = optional_cstring(name);
302 let shape_ptr = if shape.is_empty() {
303 ptr::null()
304 } else {
305 shape.as_ptr()
306 };
307 let ptr = unsafe {
309 ffi::mpsgraph_graph_random_tensor(
310 self.as_ptr(),
311 shape_ptr,
312 shape.len(),
313 descriptor.as_ptr(),
314 cstring_ptr(&name),
315 )
316 };
317 wrap_tensor(ptr)
318 }
319
320#[must_use]
322 pub fn random_tensor_shape_tensor(
323 &self,
324 shape_tensor: &Tensor,
325 descriptor: &RandomOpDescriptor,
326 name: Option<&str>,
327 ) -> Option<Tensor> {
328 let name = optional_cstring(name);
329 let ptr = unsafe {
331 ffi::mpsgraph_graph_random_tensor_shape_tensor(
332 self.as_ptr(),
333 shape_tensor.as_ptr(),
334 descriptor.as_ptr(),
335 cstring_ptr(&name),
336 )
337 };
338 wrap_tensor(ptr)
339 }
340
341#[must_use]
343 pub fn random_tensor_seed(
344 &self,
345 shape: &[usize],
346 descriptor: &RandomOpDescriptor,
347 seed: usize,
348 name: Option<&str>,
349 ) -> Option<Tensor> {
350 let name = optional_cstring(name);
351 let shape_ptr = if shape.is_empty() {
352 ptr::null()
353 } else {
354 shape.as_ptr()
355 };
356 let ptr = unsafe {
358 ffi::mpsgraph_graph_random_tensor_seed(
359 self.as_ptr(),
360 shape_ptr,
361 shape.len(),
362 descriptor.as_ptr(),
363 seed,
364 cstring_ptr(&name),
365 )
366 };
367 wrap_tensor(ptr)
368 }
369
370#[must_use]
372 pub fn random_tensor_shape_tensor_seed(
373 &self,
374 shape_tensor: &Tensor,
375 descriptor: &RandomOpDescriptor,
376 seed: usize,
377 name: Option<&str>,
378 ) -> Option<Tensor> {
379 let name = optional_cstring(name);
380 let ptr = unsafe {
382 ffi::mpsgraph_graph_random_tensor_shape_tensor_seed(
383 self.as_ptr(),
384 shape_tensor.as_ptr(),
385 descriptor.as_ptr(),
386 seed,
387 cstring_ptr(&name),
388 )
389 };
390 wrap_tensor(ptr)
391 }
392
393#[must_use]
395 pub fn random_tensor_state(
396 &self,
397 shape: &[usize],
398 descriptor: &RandomOpDescriptor,
399 state: &Tensor,
400 name: Option<&str>,
401 ) -> Option<(Tensor, Tensor)> {
402 let name = optional_cstring(name);
403 let shape_ptr = if shape.is_empty() {
404 ptr::null()
405 } else {
406 shape.as_ptr()
407 };
408 let box_handle = unsafe {
410 ffi::mpsgraph_graph_random_tensor_state(
411 self.as_ptr(),
412 shape_ptr,
413 shape.len(),
414 descriptor.as_ptr(),
415 state.as_ptr(),
416 cstring_ptr(&name),
417 )
418 };
419 wrap_tensor_pair(box_handle)
420 }
421
422#[must_use]
424 pub fn random_tensor_shape_tensor_state(
425 &self,
426 shape_tensor: &Tensor,
427 descriptor: &RandomOpDescriptor,
428 state: &Tensor,
429 name: Option<&str>,
430 ) -> Option<(Tensor, Tensor)> {
431 let name = optional_cstring(name);
432 let box_handle = unsafe {
434 ffi::mpsgraph_graph_random_tensor_shape_tensor_state(
435 self.as_ptr(),
436 shape_tensor.as_ptr(),
437 descriptor.as_ptr(),
438 state.as_ptr(),
439 cstring_ptr(&name),
440 )
441 };
442 wrap_tensor_pair(box_handle)
443 }
444
445#[must_use]
447 pub fn dropout(&self, tensor: &Tensor, rate: f64, name: Option<&str>) -> Option<Tensor> {
448 let name = optional_cstring(name);
449 let ptr = unsafe {
451 ffi::mpsgraph_graph_dropout(self.as_ptr(), tensor.as_ptr(), rate, cstring_ptr(&name))
452 };
453 wrap_tensor(ptr)
454 }
455
456#[must_use]
458 pub fn dropout_tensor(
459 &self,
460 tensor: &Tensor,
461 rate_tensor: &Tensor,
462 name: Option<&str>,
463 ) -> Option<Tensor> {
464 let name = optional_cstring(name);
465 let ptr = unsafe {
467 ffi::mpsgraph_graph_dropout_tensor(
468 self.as_ptr(),
469 tensor.as_ptr(),
470 rate_tensor.as_ptr(),
471 cstring_ptr(&name),
472 )
473 };
474 wrap_tensor(ptr)
475 }
476}