1use crate::error::{Error, Result};
2use crate::ffi;
3use crate::graph::Tensor;
4use crate::types::collect_owned_tensors;
5use core::ffi::{c_char, c_void};
6use core::ptr;
7use std::ffi::CString;
8
/// Turns an optional operation name into an owned, NUL-terminated C string.
///
/// Yields `None` both when no name was supplied and when `name` contains an
/// interior NUL byte (which `CString::new` rejects) — in the latter case the
/// name is silently discarded rather than reported.
fn optional_cstring(name: Option<&str>) -> Option<CString> {
    let text = name?;
    CString::new(text).ok()
}
12
/// Borrows a raw C-string pointer from an optional owned name, or null when absent.
///
/// The returned pointer is only valid while `value` is alive and unmodified.
// `&Option<CString>` (instead of `Option<&CString>`) is deliberate: every caller
// owns an `Option<CString>` local and can simply pass `&name`.
#[allow(clippy::ref_option)]
fn cstring_ptr(value: &Option<CString>) -> *const c_char {
    match value {
        Some(owned) => owned.as_ptr(),
        None => ptr::null(),
    }
}
17
18fn wrap_tensor(ptr: *mut c_void) -> Option<Tensor> {
19 if ptr.is_null() {
20 None
21 } else {
22 Some(Tensor::from_raw(ptr))
23 }
24}
25
26fn wrap_tensor_pair(box_handle: *mut c_void) -> Option<(Tensor, Tensor)> {
27 let mut values = collect_owned_tensors(box_handle);
28 if values.len() != 2 {
29 return None;
30 }
31 let second = values.pop()?;
32 let first = values.pop()?;
33 Some((first, second))
34}
35
/// Distribution identifiers for the `distribution` argument of
/// `RandomOpDescriptor::new` and `RandomOpDescriptor::set_distribution`.
///
/// NOTE(review): these values presumably mirror the raw values of Apple's
/// `MPSGraphRandomDistribution` enum — confirm against the MPSGraph headers.
pub mod random_distribution {
    /// Uniform distribution.
    pub const UNIFORM: u64 = 0;
    /// Normal (Gaussian) distribution.
    pub const NORMAL: u64 = 1;
    /// Truncated normal distribution.
    pub const TRUNCATED_NORMAL: u64 = 2;
}
42
/// Sampling-method identifiers for `RandomOpDescriptor::set_sampling_method`.
///
/// NOTE(review): these values presumably mirror the raw values of Apple's
/// `MPSGraphRandomNormalSamplingMethod` enum — confirm against the MPSGraph headers.
pub mod random_normal_sampling_method {
    /// Inverse-CDF sampling.
    pub const INV_CDF: u64 = 0;
    /// Box–Muller sampling.
    pub const BOX_MULLER: u64 = 1;
}
48
/// Owning wrapper around an opaque random-op descriptor handle.
///
/// Instances are created by `RandomOpDescriptor::new` (via
/// `ffi::mpsgraph_random_op_descriptor_new`) and the handle is released with
/// `ffi::mpsgraph_object_release` when the wrapper is dropped.
pub struct RandomOpDescriptor {
    // Opaque FFI handle; `new` only constructs the wrapper when this is non-null.
    ptr: *mut c_void,
}
53
// SAFETY(review): these impls assert that the underlying descriptor object may be
// moved to, and shared between, threads. Nothing in this file establishes that —
// confirm against the MPSGraph / FFI shim documentation.
// NOTE(review): the setters on `RandomOpDescriptor` take `&self` yet mutate the
// underlying object, so `Sync` lets two threads write the same descriptor at once.
// Unless the FFI layer serializes those writes this is a data race; consider making
// the setters take `&mut self` (a breaking change) or dropping `Sync`.
unsafe impl Send for RandomOpDescriptor {}
unsafe impl Sync for RandomOpDescriptor {}
56
57impl Drop for RandomOpDescriptor {
58 fn drop(&mut self) {
59 if !self.ptr.is_null() {
60 unsafe { ffi::mpsgraph_object_release(self.ptr) };
62 self.ptr = ptr::null_mut();
63 }
64 }
65}
66
67impl RandomOpDescriptor {
68 #[must_use]
69 pub fn new(distribution: u64, data_type: u32) -> Option<Self> {
70 let ptr = unsafe { ffi::mpsgraph_random_op_descriptor_new(distribution, data_type) };
72 if ptr.is_null() {
73 None
74 } else {
75 Some(Self { ptr })
76 }
77 }
78
79 #[must_use]
80 pub(crate) const fn as_ptr(&self) -> *mut c_void {
81 self.ptr
82 }
83
84 #[must_use]
85 pub fn distribution(&self) -> u64 {
86 unsafe { ffi::mpsgraph_random_op_descriptor_distribution(self.ptr) }
88 }
89
90 pub fn set_distribution(&self, value: u64) -> Result<()> {
91 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_distribution(self.ptr, value) };
93 if ok {
94 Ok(())
95 } else {
96 Err(Error::OperationFailed("failed to set random distribution"))
97 }
98 }
99
100 #[must_use]
101 pub fn data_type(&self) -> u32 {
102 unsafe { ffi::mpsgraph_random_op_descriptor_data_type(self.ptr) }
104 }
105
106 pub fn set_data_type(&self, value: u32) -> Result<()> {
107 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_data_type(self.ptr, value) };
109 if ok {
110 Ok(())
111 } else {
112 Err(Error::OperationFailed("failed to set random data type"))
113 }
114 }
115
116 #[must_use]
117 pub fn min(&self) -> f32 {
118 unsafe { ffi::mpsgraph_random_op_descriptor_min(self.ptr) }
120 }
121
122 pub fn set_min(&self, value: f32) -> Result<()> {
123 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min(self.ptr, value) };
125 if ok {
126 Ok(())
127 } else {
128 Err(Error::OperationFailed("failed to set random min"))
129 }
130 }
131
132 #[must_use]
133 pub fn max(&self) -> f32 {
134 unsafe { ffi::mpsgraph_random_op_descriptor_max(self.ptr) }
136 }
137
138 pub fn set_max(&self, value: f32) -> Result<()> {
139 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max(self.ptr, value) };
141 if ok {
142 Ok(())
143 } else {
144 Err(Error::OperationFailed("failed to set random max"))
145 }
146 }
147
148 #[must_use]
149 pub fn min_integer(&self) -> isize {
150 unsafe { ffi::mpsgraph_random_op_descriptor_min_integer(self.ptr) }
152 }
153
154 pub fn set_min_integer(&self, value: isize) -> Result<()> {
155 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_min_integer(self.ptr, value) };
157 if ok {
158 Ok(())
159 } else {
160 Err(Error::OperationFailed("failed to set random minInteger"))
161 }
162 }
163
164 #[must_use]
165 pub fn max_integer(&self) -> isize {
166 unsafe { ffi::mpsgraph_random_op_descriptor_max_integer(self.ptr) }
168 }
169
170 pub fn set_max_integer(&self, value: isize) -> Result<()> {
171 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_max_integer(self.ptr, value) };
173 if ok {
174 Ok(())
175 } else {
176 Err(Error::OperationFailed("failed to set random maxInteger"))
177 }
178 }
179
180 #[must_use]
181 pub fn mean(&self) -> f32 {
182 unsafe { ffi::mpsgraph_random_op_descriptor_mean(self.ptr) }
184 }
185
186 pub fn set_mean(&self, value: f32) -> Result<()> {
187 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_mean(self.ptr, value) };
189 if ok {
190 Ok(())
191 } else {
192 Err(Error::OperationFailed("failed to set random mean"))
193 }
194 }
195
196 #[must_use]
197 pub fn standard_deviation(&self) -> f32 {
198 unsafe { ffi::mpsgraph_random_op_descriptor_standard_deviation(self.ptr) }
200 }
201
202 pub fn set_standard_deviation(&self, value: f32) -> Result<()> {
203 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_standard_deviation(self.ptr, value) };
205 if ok {
206 Ok(())
207 } else {
208 Err(Error::OperationFailed("failed to set random standardDeviation"))
209 }
210 }
211
212 #[must_use]
213 pub fn sampling_method(&self) -> u64 {
214 unsafe { ffi::mpsgraph_random_op_descriptor_sampling_method(self.ptr) }
216 }
217
218 pub fn set_sampling_method(&self, value: u64) -> Result<()> {
219 let ok = unsafe { ffi::mpsgraph_random_op_descriptor_set_sampling_method(self.ptr, value) };
221 if ok {
222 Ok(())
223 } else {
224 Err(Error::OperationFailed("failed to set random sampling method"))
225 }
226 }
227}
228
229impl crate::graph::Graph {
230 #[must_use]
231 pub fn random_philox_state_seed(&self, seed: usize, name: Option<&str>) -> Option<Tensor> {
232 let name = optional_cstring(name);
233 let ptr = unsafe { ffi::mpsgraph_graph_random_philox_state_seed(self.as_ptr(), seed, cstring_ptr(&name)) };
235 wrap_tensor(ptr)
236 }
237
238 #[must_use]
239 pub fn random_philox_state_counter(
240 &self,
241 counter_low: usize,
242 counter_high: usize,
243 key: usize,
244 name: Option<&str>,
245 ) -> Option<Tensor> {
246 let name = optional_cstring(name);
247 let ptr = unsafe {
249 ffi::mpsgraph_graph_random_philox_state_counter(
250 self.as_ptr(),
251 counter_low,
252 counter_high,
253 key,
254 cstring_ptr(&name),
255 )
256 };
257 wrap_tensor(ptr)
258 }
259
260 #[must_use]
261 pub fn random_tensor(
262 &self,
263 shape: &[usize],
264 descriptor: &RandomOpDescriptor,
265 name: Option<&str>,
266 ) -> Option<Tensor> {
267 let name = optional_cstring(name);
268 let shape_ptr = if shape.is_empty() { ptr::null() } else { shape.as_ptr() };
269 let ptr = unsafe {
271 ffi::mpsgraph_graph_random_tensor(
272 self.as_ptr(),
273 shape_ptr,
274 shape.len(),
275 descriptor.as_ptr(),
276 cstring_ptr(&name),
277 )
278 };
279 wrap_tensor(ptr)
280 }
281
282 #[must_use]
283 pub fn random_tensor_shape_tensor(
284 &self,
285 shape_tensor: &Tensor,
286 descriptor: &RandomOpDescriptor,
287 name: Option<&str>,
288 ) -> Option<Tensor> {
289 let name = optional_cstring(name);
290 let ptr = unsafe {
292 ffi::mpsgraph_graph_random_tensor_shape_tensor(
293 self.as_ptr(),
294 shape_tensor.as_ptr(),
295 descriptor.as_ptr(),
296 cstring_ptr(&name),
297 )
298 };
299 wrap_tensor(ptr)
300 }
301
302 #[must_use]
303 pub fn random_tensor_seed(
304 &self,
305 shape: &[usize],
306 descriptor: &RandomOpDescriptor,
307 seed: usize,
308 name: Option<&str>,
309 ) -> Option<Tensor> {
310 let name = optional_cstring(name);
311 let shape_ptr = if shape.is_empty() { ptr::null() } else { shape.as_ptr() };
312 let ptr = unsafe {
314 ffi::mpsgraph_graph_random_tensor_seed(
315 self.as_ptr(),
316 shape_ptr,
317 shape.len(),
318 descriptor.as_ptr(),
319 seed,
320 cstring_ptr(&name),
321 )
322 };
323 wrap_tensor(ptr)
324 }
325
326 #[must_use]
327 pub fn random_tensor_shape_tensor_seed(
328 &self,
329 shape_tensor: &Tensor,
330 descriptor: &RandomOpDescriptor,
331 seed: usize,
332 name: Option<&str>,
333 ) -> Option<Tensor> {
334 let name = optional_cstring(name);
335 let ptr = unsafe {
337 ffi::mpsgraph_graph_random_tensor_shape_tensor_seed(
338 self.as_ptr(),
339 shape_tensor.as_ptr(),
340 descriptor.as_ptr(),
341 seed,
342 cstring_ptr(&name),
343 )
344 };
345 wrap_tensor(ptr)
346 }
347
348 #[must_use]
349 pub fn random_tensor_state(
350 &self,
351 shape: &[usize],
352 descriptor: &RandomOpDescriptor,
353 state: &Tensor,
354 name: Option<&str>,
355 ) -> Option<(Tensor, Tensor)> {
356 let name = optional_cstring(name);
357 let shape_ptr = if shape.is_empty() { ptr::null() } else { shape.as_ptr() };
358 let box_handle = unsafe {
360 ffi::mpsgraph_graph_random_tensor_state(
361 self.as_ptr(),
362 shape_ptr,
363 shape.len(),
364 descriptor.as_ptr(),
365 state.as_ptr(),
366 cstring_ptr(&name),
367 )
368 };
369 wrap_tensor_pair(box_handle)
370 }
371
372 #[must_use]
373 pub fn random_tensor_shape_tensor_state(
374 &self,
375 shape_tensor: &Tensor,
376 descriptor: &RandomOpDescriptor,
377 state: &Tensor,
378 name: Option<&str>,
379 ) -> Option<(Tensor, Tensor)> {
380 let name = optional_cstring(name);
381 let box_handle = unsafe {
383 ffi::mpsgraph_graph_random_tensor_shape_tensor_state(
384 self.as_ptr(),
385 shape_tensor.as_ptr(),
386 descriptor.as_ptr(),
387 state.as_ptr(),
388 cstring_ptr(&name),
389 )
390 };
391 wrap_tensor_pair(box_handle)
392 }
393
394 #[must_use]
395 pub fn dropout(&self, tensor: &Tensor, rate: f64, name: Option<&str>) -> Option<Tensor> {
396 let name = optional_cstring(name);
397 let ptr = unsafe {
399 ffi::mpsgraph_graph_dropout(self.as_ptr(), tensor.as_ptr(), rate, cstring_ptr(&name))
400 };
401 wrap_tensor(ptr)
402 }
403
404 #[must_use]
405 pub fn dropout_tensor(
406 &self,
407 tensor: &Tensor,
408 rate_tensor: &Tensor,
409 name: Option<&str>,
410 ) -> Option<Tensor> {
411 let name = optional_cstring(name);
412 let ptr = unsafe {
414 ffi::mpsgraph_graph_dropout_tensor(
415 self.as_ptr(),
416 tensor.as_ptr(),
417 rate_tensor.as_ptr(),
418 cstring_ptr(&name),
419 )
420 };
421 wrap_tensor(ptr)
422 }
423}