1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Converts an optional debug name into an owned, NUL-terminated C string.
///
/// Returns `None` when no name is supplied. It also returns `None` when the name
/// contains an interior NUL byte, because `CString::new` rejects it; such a name
/// is silently dropped rather than reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    match CString::new(text) {
        Ok(owned) => Some(owned),
        Err(_) => None,
    }
}
12
/// Returns the raw pointer of an optional C string, or null when it is absent.
///
/// The pointer borrows from `value` and is only valid while `value` is alive
/// and unmodified.
// Takes `&Option<CString>` (not `Option<&CString>`) so call sites can pass
// `&name` straight from `optional_cstring`; that is why the lint is allowed.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
19 if ptr.is_null() {
20 None
21 } else {
22 Some(Tensor::from_raw(ptr))
23 }
24}
25
26fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
27 let mut values = collect_owned_tensors(box_handle);
28 if values.len() != 2 {
29 return None;
30 }
31 let second = values.pop()?;
32 let first = values.pop()?;
33 Some((first, second))
34}
35
/// Distribution codes accepted by [`RandomOpDescriptor::new`] and
/// [`RandomOpDescriptor::set_distribution`].
///
/// The values are passed unchanged to the FFI layer. They are presumably meant
/// to mirror `MPSGraphRandomDistribution` — confirm against the FFI shim.
pub mod random_distribution {
    /// Uniform distribution code (`0`).
    pub const UNIFORM: u64 = 0;
    /// Normal distribution code (`1`).
    pub const NORMAL: u64 = 1;
    /// Truncated-normal distribution code (`2`).
    pub const TRUNCATED_NORMAL: u64 = 2;
}
42
/// Sampling-method codes accepted by
/// [`RandomOpDescriptor::set_sampling_method`].
///
/// The values are passed unchanged to the FFI layer. They are presumably meant
/// to mirror `MPSGraphRandomNormalSamplingMethod` — confirm against the FFI shim.
pub mod random_normal_sampling_method {
    /// Inverse-CDF sampling code (`0`).
    pub const INV_CDF: u64 = 0;
    /// Box–Muller sampling code (`1`).
    pub const BOX_MULLER: u64 = 1;
}
48
49pub struct RandomOpDescriptor {
51 ptr: *mut c_void,
52}
53
54unsafe impl Send for RandomOpDescriptor {}
55unsafe impl Sync for RandomOpDescriptor {}
56
57impl Drop for RandomOpDescriptor {
58 fn drop(&mut self) {
59 if !self.ptr.is_null() {
60 unsafe { ffi::mpsgraph_object_release(self.ptr) };
62 self.ptr = ptr::null_mut();
63 }
64 }
65}
66
67impl RandomOpDescriptor {
68 #[must_use]
69 pub fn new(distribution: u64, data_type: u32) -> Option<Self> {
70 let ptr = unsafe { ffi::mpsgraph_random_op_descriptor_new(distribution, data_type) };
72 if ptr.is_null() {
73 None
74 } else {
75 Some(Self { ptr })
76 }
77 }
78
79 #[must_use]
80 pub(crate) const fn as_ptr(&self) -> *mut c_void {
81 self.ptr
82 }
83
84 #[must_use]
85 pub fn distribution(&self) -> u64 {
86 unsafe { ffi::mpsgraph_random_op_descriptor_distribution(self.ptr) }
88 }
89
90 pub fn set_distribution(&self, value: u64) -> Result<()> {
91 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_distribution(self.ptr, value) };
93 if ok {
94 Ok(())
95 } else {
96 Err(Error::OperationFailed("failed to set random distribution"))
97 }
98 }
99
100 #[must_use]
101 pub fn data_type(&self) -> u32 {
102 unsafe { ffi::mpsgraph_random_op_descriptor_data_type(self.ptr) }
104 }
105
106 pub fn set_data_type(&self, value: u32) -> Result<()> {
107 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_data_type(self.ptr, value) };
109 if ok {
110 Ok(())
111 } else {
112 Err(Error::OperationFailed("failed to set random data type"))
113 }
114 }
115
116 #[must_use]
117 pub fn min(&self) -> f32 {
118 unsafe { ffi::mpsgraph_random_op_descriptor_min(self.ptr) }
120 }
121
122 pub fn set_min(&self, value: f32) -> Result<()> {
123 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min(self.ptr, value) };
125 if ok {
126 Ok(())
127 } else {
128 Err(Error::OperationFailed("failed to set random min"))
129 }
130 }
131
132 #[must_use]
133 pub fn max(&self) -> f32 {
134 unsafe { ffi::mpsgraph_random_op_descriptor_max(self.ptr) }
136 }
137
138 pub fn set_max(&self, value: f32) -> Result<()> {
139 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max(self.ptr, value) };
141 if ok {
142 Ok(())
143 } else {
144 Err(Error::OperationFailed("failed to set random max"))
145 }
146 }
147
148 #[must_use]
149 pub fn min_integer(&self) -> isize {
150 unsafe { ffi::mpsgraph_random_op_descriptor_min_integer(self.ptr) }
152 }
153
154 pub fn set_min_integer(&self, value: isize) -> Result<()> {
155 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min_integer(self.ptr, value) };
157 if ok {
158 Ok(())
159 } else {
160 Err(Error::OperationFailed("failed to set random minInteger"))
161 }
162 }
163
164 #[must_use]
165 pub fn max_integer(&self) -> isize {
166 unsafe { ffi::mpsgraph_random_op_descriptor_max_integer(self.ptr) }
168 }
169
170 pub fn set_max_integer(&self, value: isize) -> Result<()> {
171 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max_integer(self.ptr, value) };
173 if ok {
174 Ok(())
175 } else {
176 Err(Error::OperationFailed("failed to set random maxInteger"))
177 }
178 }
179
180 #[must_use]
181 pub fn mean(&self) -> f32 {
182 unsafe { ffi::mpsgraph_random_op_descriptor_mean(self.ptr) }
184 }
185
186 pub fn set_mean(&self, value: f32) -> Result<()> {
187 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_mean(self.ptr, value) };
189 if ok {
190 Ok(())
191 } else {
192 Err(Error::OperationFailed("failed to set random mean"))
193 }
194 }
195
196 #[must_use]
197 pub fn standard_deviation(&self) -> f32 {
198 unsafe { ffi::mpsgraph_random_op_descriptor_standard_deviation(self.ptr) }
200 }
201
202 pub fn set_standard_deviation(&self, value: f32) -> Result<()> {
203 let ok =
205 unsafe { ffi::mpsgraph_random_op_descriptor_set_standard_deviation(self.ptr, value) };
206 if ok {
207 Ok(())
208 } else {
209 Err(Error::OperationFailed(
210 "failed to set random standardDeviation",
211 ))
212 }
213 }
214
215 #[must_use]
216 pub fn sampling_method(&self) -> u64 {
217 unsafe { ffi::mpsgraph_random_op_descriptor_sampling_method(self.ptr) }
219 }
220
221 pub fn set_sampling_method(&self, value: u64) -> Result<()> {
222 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_sampling_method(self.ptr, value) };
224 if ok {
225 Ok(())
226 } else {
227 Err(Error::OperationFailed(
228 "failed to set random sampling method",
229 ))
230 }
231 }
232}
233
234impl crate::graph::Graph {
235 #[must_use]
236 pub fn random_philox_state_seed(&self, seed: usize, name: Option<&str>) -> Option<Tensor> {
237 let name = optional_cstring(name);
238 let ptr = unsafe {
240 ffi::mpsgraph_graph_random_philox_state_seed(self.as_ptr(), seed, cstring_ptr(&name))
241 };
242 wrap_tensor(ptr)
243 }
244
245 #[must_use]
246 pub fn random_philox_state_counter(
247 &self,
248 counter_low: usize,
249 counter_high: usize,
250 key: usize,
251 name: Option<&str>,
252 ) -> Option<Tensor> {
253 let name = optional_cstring(name);
254 let ptr = unsafe {
256 ffi::mpsgraph_graph_random_philox_state_counter(
257 self.as_ptr(),
258 counter_low,
259 counter_high,
260 key,
261 cstring_ptr(&name),
262 )
263 };
264 wrap_tensor(ptr)
265 }
266
267 #[must_use]
268 pub fn random_tensor(
269 &self,
270 shape: &[usize],
271 descriptor: &RandomOpDescriptor,
272 name: Option<&str>,
273 ) -> Option<Tensor> {
274 let name = optional_cstring(name);
275 let shape_ptr = if shape.is_empty() {
276 ptr::null()
277 } else {
278 shape.as_ptr()
279 };
280 let ptr = unsafe {
282 ffi::mpsgraph_graph_random_tensor(
283 self.as_ptr(),
284 shape_ptr,
285 shape.len(),
286 descriptor.as_ptr(),
287 cstring_ptr(&name),
288 )
289 };
290 wrap_tensor(ptr)
291 }
292
293 #[must_use]
294 pub fn random_tensor_shape_tensor(
295 &self,
296 shape_tensor: &Tensor,
297 descriptor: &RandomOpDescriptor,
298 name: Option<&str>,
299 ) -> Option<Tensor> {
300 let name = optional_cstring(name);
301 let ptr = unsafe {
303 ffi::mpsgraph_graph_random_tensor_shape_tensor(
304 self.as_ptr(),
305 shape_tensor.as_ptr(),
306 descriptor.as_ptr(),
307 cstring_ptr(&name),
308 )
309 };
310 wrap_tensor(ptr)
311 }
312
313 #[must_use]
314 pub fn random_tensor_seed(
315 &self,
316 shape: &[usize],
317 descriptor: &RandomOpDescriptor,
318 seed: usize,
319 name: Option<&str>,
320 ) -> Option<Tensor> {
321 let name = optional_cstring(name);
322 let shape_ptr = if shape.is_empty() {
323 ptr::null()
324 } else {
325 shape.as_ptr()
326 };
327 let ptr = unsafe {
329 ffi::mpsgraph_graph_random_tensor_seed(
330 self.as_ptr(),
331 shape_ptr,
332 shape.len(),
333 descriptor.as_ptr(),
334 seed,
335 cstring_ptr(&name),
336 )
337 };
338 wrap_tensor(ptr)
339 }
340
341 #[must_use]
342 pub fn random_tensor_shape_tensor_seed(
343 &self,
344 shape_tensor: &Tensor,
345 descriptor: &RandomOpDescriptor,
346 seed: usize,
347 name: Option<&str>,
348 ) -> Option<Tensor> {
349 let name = optional_cstring(name);
350 let ptr = unsafe {
352 ffi::mpsgraph_graph_random_tensor_shape_tensor_seed(
353 self.as_ptr(),
354 shape_tensor.as_ptr(),
355 descriptor.as_ptr(),
356 seed,
357 cstring_ptr(&name),
358 )
359 };
360 wrap_tensor(ptr)
361 }
362
363 #[must_use]
364 pub fn random_tensor_state(
365 &self,
366 shape: &[usize],
367 descriptor: &RandomOpDescriptor,
368 state: &Tensor,
369 name: Option<&str>,
370 ) -> Option<(Tensor, Tensor)> {
371 let name = optional_cstring(name);
372 let shape_ptr = if shape.is_empty() {
373 ptr::null()
374 } else {
375 shape.as_ptr()
376 };
377 let box_handle = unsafe {
379 ffi::mpsgraph_graph_random_tensor_state(
380 self.as_ptr(),
381 shape_ptr,
382 shape.len(),
383 descriptor.as_ptr(),
384 state.as_ptr(),
385 cstring_ptr(&name),
386 )
387 };
388 wrap_tensor_pair(box_handle)
389 }
390
391 #[must_use]
392 pub fn random_tensor_shape_tensor_state(
393 &self,
394 shape_tensor: &Tensor,
395 descriptor: &RandomOpDescriptor,
396 state: &Tensor,
397 name: Option<&str>,
398 ) -> Option<(Tensor, Tensor)> {
399 let name = optional_cstring(name);
400 let box_handle = unsafe {
402 ffi::mpsgraph_graph_random_tensor_shape_tensor_state(
403 self.as_ptr(),
404 shape_tensor.as_ptr(),
405 descriptor.as_ptr(),
406 state.as_ptr(),
407 cstring_ptr(&name),
408 )
409 };
410 wrap_tensor_pair(box_handle)
411 }
412
413 #[must_use]
414 pub fn dropout(&self, tensor: &Tensor, rate: f64, name: Option<&str>) -> Option<Tensor> {
415 let name = optional_cstring(name);
416 let ptr = unsafe {
418 ffi::mpsgraph_graph_dropout(self.as_ptr(), tensor.as_ptr(), rate, cstring_ptr(&name))
419 };
420 wrap_tensor(ptr)
421 }
422
423 #[must_use]
424 pub fn dropout_tensor(
425 &self,
426 tensor: &Tensor,
427 rate_tensor: &Tensor,
428 name: Option<&str>,
429 ) -> Option<Tensor> {
430 let name = optional_cstring(name);
431 let ptr = unsafe {
433 ffi::mpsgraph_graph_dropout_tensor(
434 self.as_ptr(),
435 tensor.as_ptr(),
436 rate_tensor.as_ptr(),
437 cstring_ptr(&name),
438 )
439 };
440 wrap_tensor(ptr)
441 }
442}