Skip to main content

RandomOpDescriptor

Struct RandomOpDescriptor 

Source
pub struct RandomOpDescriptor { /* private fields */ }
Expand description

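Safe, owning wrapper around an `MPSGraphRandomOpDescriptor`, which configures the distribution, data type, and parameters of MPSGraph random operations (for example `random_tensor_seed`). The underlying object is released when this value is dropped.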

Implementations§

Source§

impl RandomOpDescriptor

Source

pub fn new(distribution: u64, data_type: u32) -> Option<Self>

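Creates a descriptor for the given distribution (a `random_distribution` constant such as `random_distribution::UNIFORM`) and data type (a `data_type` constant such as `data_type::FLOAT32`) by calling the MPSGraph framework counterpart of `new`. Returns `None` if the descriptor could not be created.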

Examples found in repository?
examples/07_gather_random_rnn.rs (line 51)
15fn main() {
16    let graph = Graph::new().expect("graph");
17    let updates = graph
18        .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19        .expect("updates");
20    let gather_indices = graph
21        .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22        .expect("gather indices");
23    let gather_nd_indices = graph
24        .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25        .expect("gather nd indices");
26    let along_indices = graph
27        .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28        .expect("gather along indices");
29    let axis_tensor = graph
30        .constant_scalar(1.0, data_type::INT32)
31        .expect("axis tensor");
32
33    let gather = graph
34        .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35        .expect("gather");
36    let gather_nd = graph
37        .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38        .expect("gather nd");
39    let gather_axis = graph
40        .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41        .expect("gather along axis");
42    let gather_axis_tensor = graph
43        .gather_along_axis_tensor(
44            &axis_tensor,
45            &updates,
46            &along_indices,
47            Some("gather_axis_tensor"),
48        )
49        .expect("gather along axis tensor");
50
51    let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52        .expect("random descriptor");
53    descriptor.set_min(0.0).expect("random min");
54    descriptor.set_max(1.0).expect("random max");
55    let random = graph
56        .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57        .expect("random tensor");
58    let dropout = graph
59        .dropout(&updates, 1.0, Some("dropout"))
60        .expect("dropout");
61
62    let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63    single_gate_descriptor
64        .set_activation(rnn_activation::RELU)
65        .expect("single gate activation");
66    let single_gate_source = graph
67        .constant_f32_slice(&[0.5], &[1, 1, 1])
68        .expect("single gate source");
69    let single_gate_recurrent = graph
70        .constant_f32_slice(&[0.0], &[1, 1])
71        .expect("single gate recurrent");
72    let single_gate = graph
73        .single_gate_rnn(
74            &single_gate_source,
75            &single_gate_recurrent,
76            None,
77            None,
78            None,
79            None,
80            &single_gate_descriptor,
81            Some("single_gate"),
82        )
83        .expect("single gate rnn");
84
85    let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86    lstm_descriptor
87        .set_produce_cell(true)
88        .expect("set produce cell");
89    let lstm_source = graph
90        .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91        .expect("lstm source");
92    let lstm_recurrent = graph
93        .constant_f32_slice(&[0.0; 4], &[4, 1])
94        .expect("lstm recurrent");
95    let lstm = graph
96        .lstm(
97            &lstm_source,
98            &lstm_recurrent,
99            None,
100            None,
101            None,
102            None,
103            None,
104            None,
105            &lstm_descriptor,
106            Some("lstm"),
107        )
108        .expect("lstm");
109
110    let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111    gru_descriptor.set_training(true).expect("set gru training");
112    gru_descriptor
113        .set_reset_after(true)
114        .expect("set gru reset_after");
115    let gru_source = graph
116        .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117        .expect("gru source");
118    let gru_recurrent = graph
119        .constant_f32_slice(&[0.0; 3], &[3, 1])
120        .expect("gru recurrent");
121    let gru_secondary_bias = graph
122        .constant_f32_slice(&[0.0], &[1])
123        .expect("gru secondary bias");
124    let gru = graph
125        .gru(
126            &gru_source,
127            &gru_recurrent,
128            None,
129            None,
130            None,
131            None,
132            Some(&gru_secondary_bias),
133            &gru_descriptor,
134            Some("gru"),
135        )
136        .expect("gru");
137
138    let results = graph
139        .run(
140            &[],
141            &[
142                &gather,
143                &gather_nd,
144                &gather_axis,
145                &gather_axis_tensor,
146                &random,
147                &dropout,
148                &single_gate[0],
149                &lstm[0],
150                &lstm[1],
151                &gru[0],
152                &gru[1],
153            ],
154        )
155        .expect("run graph");
156
157    println!("gather: {:?}", results[0].read_f32().expect("gather"));
158    println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159    println!(
160        "gather_axis: {:?}",
161        results[2].read_f32().expect("gather_axis")
162    );
163    println!(
164        "gather_axis_tensor: {:?}",
165        results[3].read_f32().expect("gather_axis_tensor")
166    );
167    println!("random: {:?}", results[4].read_f32().expect("random"));
168    println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169    println!(
170        "single_gate: {:?}",
171        results[6].read_f32().expect("single_gate")
172    );
173    println!(
174        "lstm state: {:?}",
175        results[7].read_f32().expect("lstm state")
176    );
177    println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178    println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179    println!(
180        "gru training: {:?}",
181        results[10].read_f32().expect("gru training")
182    );
183}
Source

pub fn distribution(&self) -> u64

Returns the descriptor's distribution (a `random_distribution` value), as reported by the MPSGraph framework.

Source

pub fn set_distribution(&self, value: u64) -> Result<()>

Sets the descriptor's distribution (a `random_distribution` value) through the MPSGraph framework. Returns an error if the update fails.

Source

pub fn data_type(&self) -> u32

Returns the data type of the values the descriptor generates (a `data_type` value), as reported by the MPSGraph framework.

Source

pub fn set_data_type(&self, value: u32) -> Result<()>

Sets the data type of the values the descriptor generates (a `data_type` value) through the MPSGraph framework. Returns an error if the update fails.

Source

pub fn min(&self) -> f32

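Returns the floating-point lower bound of the sampling range (used with uniform distributions over floating-point types), as reported by the MPSGraph framework.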

Source

pub fn set_min(&self, value: f32) -> Result<()>

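Sets the floating-point lower bound of the sampling range (used with uniform distributions over floating-point types) through the MPSGraph framework. Returns an error if the update fails.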

Examples found in repository?
examples/07_gather_random_rnn.rs (line 53)
15fn main() {
16    let graph = Graph::new().expect("graph");
17    let updates = graph
18        .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19        .expect("updates");
20    let gather_indices = graph
21        .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22        .expect("gather indices");
23    let gather_nd_indices = graph
24        .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25        .expect("gather nd indices");
26    let along_indices = graph
27        .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28        .expect("gather along indices");
29    let axis_tensor = graph
30        .constant_scalar(1.0, data_type::INT32)
31        .expect("axis tensor");
32
33    let gather = graph
34        .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35        .expect("gather");
36    let gather_nd = graph
37        .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38        .expect("gather nd");
39    let gather_axis = graph
40        .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41        .expect("gather along axis");
42    let gather_axis_tensor = graph
43        .gather_along_axis_tensor(
44            &axis_tensor,
45            &updates,
46            &along_indices,
47            Some("gather_axis_tensor"),
48        )
49        .expect("gather along axis tensor");
50
51    let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52        .expect("random descriptor");
53    descriptor.set_min(0.0).expect("random min");
54    descriptor.set_max(1.0).expect("random max");
55    let random = graph
56        .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57        .expect("random tensor");
58    let dropout = graph
59        .dropout(&updates, 1.0, Some("dropout"))
60        .expect("dropout");
61
62    let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63    single_gate_descriptor
64        .set_activation(rnn_activation::RELU)
65        .expect("single gate activation");
66    let single_gate_source = graph
67        .constant_f32_slice(&[0.5], &[1, 1, 1])
68        .expect("single gate source");
69    let single_gate_recurrent = graph
70        .constant_f32_slice(&[0.0], &[1, 1])
71        .expect("single gate recurrent");
72    let single_gate = graph
73        .single_gate_rnn(
74            &single_gate_source,
75            &single_gate_recurrent,
76            None,
77            None,
78            None,
79            None,
80            &single_gate_descriptor,
81            Some("single_gate"),
82        )
83        .expect("single gate rnn");
84
85    let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86    lstm_descriptor
87        .set_produce_cell(true)
88        .expect("set produce cell");
89    let lstm_source = graph
90        .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91        .expect("lstm source");
92    let lstm_recurrent = graph
93        .constant_f32_slice(&[0.0; 4], &[4, 1])
94        .expect("lstm recurrent");
95    let lstm = graph
96        .lstm(
97            &lstm_source,
98            &lstm_recurrent,
99            None,
100            None,
101            None,
102            None,
103            None,
104            None,
105            &lstm_descriptor,
106            Some("lstm"),
107        )
108        .expect("lstm");
109
110    let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111    gru_descriptor.set_training(true).expect("set gru training");
112    gru_descriptor
113        .set_reset_after(true)
114        .expect("set gru reset_after");
115    let gru_source = graph
116        .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117        .expect("gru source");
118    let gru_recurrent = graph
119        .constant_f32_slice(&[0.0; 3], &[3, 1])
120        .expect("gru recurrent");
121    let gru_secondary_bias = graph
122        .constant_f32_slice(&[0.0], &[1])
123        .expect("gru secondary bias");
124    let gru = graph
125        .gru(
126            &gru_source,
127            &gru_recurrent,
128            None,
129            None,
130            None,
131            None,
132            Some(&gru_secondary_bias),
133            &gru_descriptor,
134            Some("gru"),
135        )
136        .expect("gru");
137
138    let results = graph
139        .run(
140            &[],
141            &[
142                &gather,
143                &gather_nd,
144                &gather_axis,
145                &gather_axis_tensor,
146                &random,
147                &dropout,
148                &single_gate[0],
149                &lstm[0],
150                &lstm[1],
151                &gru[0],
152                &gru[1],
153            ],
154        )
155        .expect("run graph");
156
157    println!("gather: {:?}", results[0].read_f32().expect("gather"));
158    println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159    println!(
160        "gather_axis: {:?}",
161        results[2].read_f32().expect("gather_axis")
162    );
163    println!(
164        "gather_axis_tensor: {:?}",
165        results[3].read_f32().expect("gather_axis_tensor")
166    );
167    println!("random: {:?}", results[4].read_f32().expect("random"));
168    println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169    println!(
170        "single_gate: {:?}",
171        results[6].read_f32().expect("single_gate")
172    );
173    println!(
174        "lstm state: {:?}",
175        results[7].read_f32().expect("lstm state")
176    );
177    println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178    println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179    println!(
180        "gru training: {:?}",
181        results[10].read_f32().expect("gru training")
182    );
183}
Source

pub fn max(&self) -> f32

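Returns the floating-point upper bound of the sampling range (used with uniform distributions over floating-point types), as reported by the MPSGraph framework.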

Source

pub fn set_max(&self, value: f32) -> Result<()>

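Sets the floating-point upper bound of the sampling range (used with uniform distributions over floating-point types) through the MPSGraph framework. Returns an error if the update fails.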

Examples found in repository?
examples/07_gather_random_rnn.rs (line 54)
15fn main() {
16    let graph = Graph::new().expect("graph");
17    let updates = graph
18        .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19        .expect("updates");
20    let gather_indices = graph
21        .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22        .expect("gather indices");
23    let gather_nd_indices = graph
24        .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25        .expect("gather nd indices");
26    let along_indices = graph
27        .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28        .expect("gather along indices");
29    let axis_tensor = graph
30        .constant_scalar(1.0, data_type::INT32)
31        .expect("axis tensor");
32
33    let gather = graph
34        .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35        .expect("gather");
36    let gather_nd = graph
37        .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38        .expect("gather nd");
39    let gather_axis = graph
40        .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41        .expect("gather along axis");
42    let gather_axis_tensor = graph
43        .gather_along_axis_tensor(
44            &axis_tensor,
45            &updates,
46            &along_indices,
47            Some("gather_axis_tensor"),
48        )
49        .expect("gather along axis tensor");
50
51    let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52        .expect("random descriptor");
53    descriptor.set_min(0.0).expect("random min");
54    descriptor.set_max(1.0).expect("random max");
55    let random = graph
56        .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57        .expect("random tensor");
58    let dropout = graph
59        .dropout(&updates, 1.0, Some("dropout"))
60        .expect("dropout");
61
62    let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63    single_gate_descriptor
64        .set_activation(rnn_activation::RELU)
65        .expect("single gate activation");
66    let single_gate_source = graph
67        .constant_f32_slice(&[0.5], &[1, 1, 1])
68        .expect("single gate source");
69    let single_gate_recurrent = graph
70        .constant_f32_slice(&[0.0], &[1, 1])
71        .expect("single gate recurrent");
72    let single_gate = graph
73        .single_gate_rnn(
74            &single_gate_source,
75            &single_gate_recurrent,
76            None,
77            None,
78            None,
79            None,
80            &single_gate_descriptor,
81            Some("single_gate"),
82        )
83        .expect("single gate rnn");
84
85    let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86    lstm_descriptor
87        .set_produce_cell(true)
88        .expect("set produce cell");
89    let lstm_source = graph
90        .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91        .expect("lstm source");
92    let lstm_recurrent = graph
93        .constant_f32_slice(&[0.0; 4], &[4, 1])
94        .expect("lstm recurrent");
95    let lstm = graph
96        .lstm(
97            &lstm_source,
98            &lstm_recurrent,
99            None,
100            None,
101            None,
102            None,
103            None,
104            None,
105            &lstm_descriptor,
106            Some("lstm"),
107        )
108        .expect("lstm");
109
110    let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111    gru_descriptor.set_training(true).expect("set gru training");
112    gru_descriptor
113        .set_reset_after(true)
114        .expect("set gru reset_after");
115    let gru_source = graph
116        .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117        .expect("gru source");
118    let gru_recurrent = graph
119        .constant_f32_slice(&[0.0; 3], &[3, 1])
120        .expect("gru recurrent");
121    let gru_secondary_bias = graph
122        .constant_f32_slice(&[0.0], &[1])
123        .expect("gru secondary bias");
124    let gru = graph
125        .gru(
126            &gru_source,
127            &gru_recurrent,
128            None,
129            None,
130            None,
131            None,
132            Some(&gru_secondary_bias),
133            &gru_descriptor,
134            Some("gru"),
135        )
136        .expect("gru");
137
138    let results = graph
139        .run(
140            &[],
141            &[
142                &gather,
143                &gather_nd,
144                &gather_axis,
145                &gather_axis_tensor,
146                &random,
147                &dropout,
148                &single_gate[0],
149                &lstm[0],
150                &lstm[1],
151                &gru[0],
152                &gru[1],
153            ],
154        )
155        .expect("run graph");
156
157    println!("gather: {:?}", results[0].read_f32().expect("gather"));
158    println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159    println!(
160        "gather_axis: {:?}",
161        results[2].read_f32().expect("gather_axis")
162    );
163    println!(
164        "gather_axis_tensor: {:?}",
165        results[3].read_f32().expect("gather_axis_tensor")
166    );
167    println!("random: {:?}", results[4].read_f32().expect("random"));
168    println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169    println!(
170        "single_gate: {:?}",
171        results[6].read_f32().expect("single_gate")
172    );
173    println!(
174        "lstm state: {:?}",
175        results[7].read_f32().expect("lstm state")
176    );
177    println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178    println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179    println!(
180        "gru training: {:?}",
181        results[10].read_f32().expect("gru training")
182    );
183}
Source

pub fn min_integer(&self) -> isize

Returns the integer lower bound of the sampling range (used with uniform distributions over integer types), as reported by the MPSGraph framework.

Source

pub fn set_min_integer(&self, value: isize) -> Result<()>

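Sets the integer lower bound of the sampling range (used with uniform distributions over integer types) through the MPSGraph framework. Returns an error if the update fails.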

Source

pub fn max_integer(&self) -> isize

Returns the integer upper bound of the sampling range (used with uniform distributions over integer types), as reported by the MPSGraph framework.

Source

pub fn set_max_integer(&self, value: isize) -> Result<()>

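Sets the integer upper bound of the sampling range (used with uniform distributions over integer types) through the MPSGraph framework. Returns an error if the update fails.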

Source

pub fn mean(&self) -> f32

Returns the mean used by normal distributions, as reported by the MPSGraph framework.

Source

pub fn set_mean(&self, value: f32) -> Result<()>

Sets the mean used by normal distributions through the MPSGraph framework. Returns an error if the update fails.

Source

pub fn standard_deviation(&self) -> f32

Returns the standard deviation used by normal distributions, as reported by the MPSGraph framework.

Source

pub fn set_standard_deviation(&self, value: f32) -> Result<()>

Sets the standard deviation used by normal distributions through the MPSGraph framework. Returns an error if the update fails.

Source

pub fn sampling_method(&self) -> u64

Returns the sampling method used to draw values from normal distributions, as reported by the MPSGraph framework.

Source

pub fn set_sampling_method(&self, value: u64) -> Result<()>

Sets the sampling method used to draw values from normal distributions through the MPSGraph framework. Returns an error if the update fails.

Trait Implementations§

Source§

impl Drop for RandomOpDescriptor

Source§

fn drop(&mut self)

Executes the destructor for this type. Read more
Source§

fn pin_drop(self: Pin<&mut Self>)

🔬This is a nightly-only experimental API. (pin_ergonomics)
Execute the destructor for this type, but different to Drop::drop, it requires self to be pinned. Read more
Source§

impl Send for RandomOpDescriptor

Source§

impl Sync for RandomOpDescriptor

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.