pub struct RandomOpDescriptor { /* private fields */ }
Safe owning wrapper around an `MPSGraphRandomOpDescriptor`.
Implementations
impl RandomOpDescriptor
pub fn new(distribution: u64, data_type: u32) -> Option<Self>
Examples found in repository?
examples/07_gather_random_rnn.rs (line 51)
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(
44 &axis_tensor,
45 &updates,
46 &along_indices,
47 Some("gather_axis_tensor"),
48 )
49 .expect("gather along axis tensor");
50
51 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52 .expect("random descriptor");
53 descriptor.set_min(0.0).expect("random min");
54 descriptor.set_max(1.0).expect("random max");
55 let random = graph
56 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57 .expect("random tensor");
58 let dropout = graph
59 .dropout(&updates, 1.0, Some("dropout"))
60 .expect("dropout");
61
62 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63 single_gate_descriptor
64 .set_activation(rnn_activation::RELU)
65 .expect("single gate activation");
66 let single_gate_source = graph
67 .constant_f32_slice(&[0.5], &[1, 1, 1])
68 .expect("single gate source");
69 let single_gate_recurrent = graph
70 .constant_f32_slice(&[0.0], &[1, 1])
71 .expect("single gate recurrent");
72 let single_gate = graph
73 .single_gate_rnn(
74 &single_gate_source,
75 &single_gate_recurrent,
76 None,
77 None,
78 None,
79 None,
80 &single_gate_descriptor,
81 Some("single_gate"),
82 )
83 .expect("single gate rnn");
84
85 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86 lstm_descriptor
87 .set_produce_cell(true)
88 .expect("set produce cell");
89 let lstm_source = graph
90 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91 .expect("lstm source");
92 let lstm_recurrent = graph
93 .constant_f32_slice(&[0.0; 4], &[4, 1])
94 .expect("lstm recurrent");
95 let lstm = graph
96 .lstm(
97 &lstm_source,
98 &lstm_recurrent,
99 None,
100 None,
101 None,
102 None,
103 None,
104 None,
105 &lstm_descriptor,
106 Some("lstm"),
107 )
108 .expect("lstm");
109
110 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111 gru_descriptor.set_training(true).expect("set gru training");
112 gru_descriptor
113 .set_reset_after(true)
114 .expect("set gru reset_after");
115 let gru_source = graph
116 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117 .expect("gru source");
118 let gru_recurrent = graph
119 .constant_f32_slice(&[0.0; 3], &[3, 1])
120 .expect("gru recurrent");
121 let gru_secondary_bias = graph
122 .constant_f32_slice(&[0.0], &[1])
123 .expect("gru secondary bias");
124 let gru = graph
125 .gru(
126 &gru_source,
127 &gru_recurrent,
128 None,
129 None,
130 None,
131 None,
132 Some(&gru_secondary_bias),
133 &gru_descriptor,
134 Some("gru"),
135 )
136 .expect("gru");
137
138 let results = graph
139 .run(
140 &[],
141 &[
142 &gather,
143 &gather_nd,
144 &gather_axis,
145 &gather_axis_tensor,
146 &random,
147 &dropout,
148 &single_gate[0],
149 &lstm[0],
150 &lstm[1],
151 &gru[0],
152 &gru[1],
153 ],
154 )
155 .expect("run graph");
156
157 println!("gather: {:?}", results[0].read_f32().expect("gather"));
158 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159 println!(
160 "gather_axis: {:?}",
161 results[2].read_f32().expect("gather_axis")
162 );
163 println!(
164 "gather_axis_tensor: {:?}",
165 results[3].read_f32().expect("gather_axis_tensor")
166 );
167 println!("random: {:?}", results[4].read_f32().expect("random"));
168 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169 println!(
170 "single_gate: {:?}",
171 results[6].read_f32().expect("single_gate")
172 );
173 println!(
174 "lstm state: {:?}",
175 results[7].read_f32().expect("lstm state")
176 );
177 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179 println!(
180 "gru training: {:?}",
181 results[10].read_f32().expect("gru training")
182 );
183}
pub fn distribution(&self) -> u64
pub fn set_distribution(&self, value: u64) -> Result<()>
pub fn data_type(&self) -> u32
pub fn set_data_type(&self, value: u32) -> Result<()>
pub fn min(&self) -> f32
pub fn set_min(&self, value: f32) -> Result<()>
Examples found in repository?
examples/07_gather_random_rnn.rs (line 53)
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(
44 &axis_tensor,
45 &updates,
46 &along_indices,
47 Some("gather_axis_tensor"),
48 )
49 .expect("gather along axis tensor");
50
51 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52 .expect("random descriptor");
53 descriptor.set_min(0.0).expect("random min");
54 descriptor.set_max(1.0).expect("random max");
55 let random = graph
56 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57 .expect("random tensor");
58 let dropout = graph
59 .dropout(&updates, 1.0, Some("dropout"))
60 .expect("dropout");
61
62 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63 single_gate_descriptor
64 .set_activation(rnn_activation::RELU)
65 .expect("single gate activation");
66 let single_gate_source = graph
67 .constant_f32_slice(&[0.5], &[1, 1, 1])
68 .expect("single gate source");
69 let single_gate_recurrent = graph
70 .constant_f32_slice(&[0.0], &[1, 1])
71 .expect("single gate recurrent");
72 let single_gate = graph
73 .single_gate_rnn(
74 &single_gate_source,
75 &single_gate_recurrent,
76 None,
77 None,
78 None,
79 None,
80 &single_gate_descriptor,
81 Some("single_gate"),
82 )
83 .expect("single gate rnn");
84
85 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86 lstm_descriptor
87 .set_produce_cell(true)
88 .expect("set produce cell");
89 let lstm_source = graph
90 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91 .expect("lstm source");
92 let lstm_recurrent = graph
93 .constant_f32_slice(&[0.0; 4], &[4, 1])
94 .expect("lstm recurrent");
95 let lstm = graph
96 .lstm(
97 &lstm_source,
98 &lstm_recurrent,
99 None,
100 None,
101 None,
102 None,
103 None,
104 None,
105 &lstm_descriptor,
106 Some("lstm"),
107 )
108 .expect("lstm");
109
110 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111 gru_descriptor.set_training(true).expect("set gru training");
112 gru_descriptor
113 .set_reset_after(true)
114 .expect("set gru reset_after");
115 let gru_source = graph
116 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117 .expect("gru source");
118 let gru_recurrent = graph
119 .constant_f32_slice(&[0.0; 3], &[3, 1])
120 .expect("gru recurrent");
121 let gru_secondary_bias = graph
122 .constant_f32_slice(&[0.0], &[1])
123 .expect("gru secondary bias");
124 let gru = graph
125 .gru(
126 &gru_source,
127 &gru_recurrent,
128 None,
129 None,
130 None,
131 None,
132 Some(&gru_secondary_bias),
133 &gru_descriptor,
134 Some("gru"),
135 )
136 .expect("gru");
137
138 let results = graph
139 .run(
140 &[],
141 &[
142 &gather,
143 &gather_nd,
144 &gather_axis,
145 &gather_axis_tensor,
146 &random,
147 &dropout,
148 &single_gate[0],
149 &lstm[0],
150 &lstm[1],
151 &gru[0],
152 &gru[1],
153 ],
154 )
155 .expect("run graph");
156
157 println!("gather: {:?}", results[0].read_f32().expect("gather"));
158 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159 println!(
160 "gather_axis: {:?}",
161 results[2].read_f32().expect("gather_axis")
162 );
163 println!(
164 "gather_axis_tensor: {:?}",
165 results[3].read_f32().expect("gather_axis_tensor")
166 );
167 println!("random: {:?}", results[4].read_f32().expect("random"));
168 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169 println!(
170 "single_gate: {:?}",
171 results[6].read_f32().expect("single_gate")
172 );
173 println!(
174 "lstm state: {:?}",
175 results[7].read_f32().expect("lstm state")
176 );
177 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179 println!(
180 "gru training: {:?}",
181 results[10].read_f32().expect("gru training")
182 );
183}
pub fn max(&self) -> f32
pub fn set_max(&self, value: f32) -> Result<()>
Examples found in repository?
examples/07_gather_random_rnn.rs (line 54)
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(
44 &axis_tensor,
45 &updates,
46 &along_indices,
47 Some("gather_axis_tensor"),
48 )
49 .expect("gather along axis tensor");
50
51 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52 .expect("random descriptor");
53 descriptor.set_min(0.0).expect("random min");
54 descriptor.set_max(1.0).expect("random max");
55 let random = graph
56 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57 .expect("random tensor");
58 let dropout = graph
59 .dropout(&updates, 1.0, Some("dropout"))
60 .expect("dropout");
61
62 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63 single_gate_descriptor
64 .set_activation(rnn_activation::RELU)
65 .expect("single gate activation");
66 let single_gate_source = graph
67 .constant_f32_slice(&[0.5], &[1, 1, 1])
68 .expect("single gate source");
69 let single_gate_recurrent = graph
70 .constant_f32_slice(&[0.0], &[1, 1])
71 .expect("single gate recurrent");
72 let single_gate = graph
73 .single_gate_rnn(
74 &single_gate_source,
75 &single_gate_recurrent,
76 None,
77 None,
78 None,
79 None,
80 &single_gate_descriptor,
81 Some("single_gate"),
82 )
83 .expect("single gate rnn");
84
85 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86 lstm_descriptor
87 .set_produce_cell(true)
88 .expect("set produce cell");
89 let lstm_source = graph
90 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91 .expect("lstm source");
92 let lstm_recurrent = graph
93 .constant_f32_slice(&[0.0; 4], &[4, 1])
94 .expect("lstm recurrent");
95 let lstm = graph
96 .lstm(
97 &lstm_source,
98 &lstm_recurrent,
99 None,
100 None,
101 None,
102 None,
103 None,
104 None,
105 &lstm_descriptor,
106 Some("lstm"),
107 )
108 .expect("lstm");
109
110 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111 gru_descriptor.set_training(true).expect("set gru training");
112 gru_descriptor
113 .set_reset_after(true)
114 .expect("set gru reset_after");
115 let gru_source = graph
116 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117 .expect("gru source");
118 let gru_recurrent = graph
119 .constant_f32_slice(&[0.0; 3], &[3, 1])
120 .expect("gru recurrent");
121 let gru_secondary_bias = graph
122 .constant_f32_slice(&[0.0], &[1])
123 .expect("gru secondary bias");
124 let gru = graph
125 .gru(
126 &gru_source,
127 &gru_recurrent,
128 None,
129 None,
130 None,
131 None,
132 Some(&gru_secondary_bias),
133 &gru_descriptor,
134 Some("gru"),
135 )
136 .expect("gru");
137
138 let results = graph
139 .run(
140 &[],
141 &[
142 &gather,
143 &gather_nd,
144 &gather_axis,
145 &gather_axis_tensor,
146 &random,
147 &dropout,
148 &single_gate[0],
149 &lstm[0],
150 &lstm[1],
151 &gru[0],
152 &gru[1],
153 ],
154 )
155 .expect("run graph");
156
157 println!("gather: {:?}", results[0].read_f32().expect("gather"));
158 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159 println!(
160 "gather_axis: {:?}",
161 results[2].read_f32().expect("gather_axis")
162 );
163 println!(
164 "gather_axis_tensor: {:?}",
165 results[3].read_f32().expect("gather_axis_tensor")
166 );
167 println!("random: {:?}", results[4].read_f32().expect("random"));
168 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169 println!(
170 "single_gate: {:?}",
171 results[6].read_f32().expect("single_gate")
172 );
173 println!(
174 "lstm state: {:?}",
175 results[7].read_f32().expect("lstm state")
176 );
177 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179 println!(
180 "gru training: {:?}",
181 results[10].read_f32().expect("gru training")
182 );
183}
pub fn min_integer(&self) -> isize
pub fn set_min_integer(&self, value: isize) -> Result<()>
pub fn max_integer(&self) -> isize
pub fn set_max_integer(&self, value: isize) -> Result<()>
pub fn mean(&self) -> f32
pub fn set_mean(&self, value: f32) -> Result<()>
pub fn standard_deviation(&self) -> f32
pub fn set_standard_deviation(&self, value: f32) -> Result<()>
pub fn sampling_method(&self) -> u64
pub fn set_sampling_method(&self, value: u64) -> Result<()>
Trait Implementations
impl Drop for RandomOpDescriptor
impl Send for RandomOpDescriptor
impl Sync for RandomOpDescriptor
Auto Trait Implementations
impl Freeze for RandomOpDescriptor
impl RefUnwindSafe for RandomOpDescriptor
impl Unpin for RandomOpDescriptor
impl UnsafeUnpin for RandomOpDescriptor
impl UnwindSafe for RandomOpDescriptor
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.