pub struct SingleGateRNNDescriptor { /* private fields */ }
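Wrapper around the MPSGraph framework's single-gate RNN descriptor (`MPSGraphSingleGateRNNDescriptor`). It holds the options passed to `Graph::single_gate_rnn`: direction (`reverse`, `bidirectional`), training mode, and activation function.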
Implementations§
impl SingleGateRNNDescriptor
pub fn new() -> Option<Self>
Creates a new descriptor by calling the MPSGraph framework counterpart for `new`. Returns `None` if the underlying framework object could not be created.
Examples found in repository?
examples/07_gather_random_rnn.rs (line 62)
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(
44 &axis_tensor,
45 &updates,
46 &along_indices,
47 Some("gather_axis_tensor"),
48 )
49 .expect("gather along axis tensor");
50
51 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52 .expect("random descriptor");
53 descriptor.set_min(0.0).expect("random min");
54 descriptor.set_max(1.0).expect("random max");
55 let random = graph
56 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57 .expect("random tensor");
58 let dropout = graph
59 .dropout(&updates, 1.0, Some("dropout"))
60 .expect("dropout");
61
62 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63 single_gate_descriptor
64 .set_activation(rnn_activation::RELU)
65 .expect("single gate activation");
66 let single_gate_source = graph
67 .constant_f32_slice(&[0.5], &[1, 1, 1])
68 .expect("single gate source");
69 let single_gate_recurrent = graph
70 .constant_f32_slice(&[0.0], &[1, 1])
71 .expect("single gate recurrent");
72 let single_gate = graph
73 .single_gate_rnn(
74 &single_gate_source,
75 &single_gate_recurrent,
76 None,
77 None,
78 None,
79 None,
80 &single_gate_descriptor,
81 Some("single_gate"),
82 )
83 .expect("single gate rnn");
84
85 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86 lstm_descriptor
87 .set_produce_cell(true)
88 .expect("set produce cell");
89 let lstm_source = graph
90 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91 .expect("lstm source");
92 let lstm_recurrent = graph
93 .constant_f32_slice(&[0.0; 4], &[4, 1])
94 .expect("lstm recurrent");
95 let lstm = graph
96 .lstm(
97 &lstm_source,
98 &lstm_recurrent,
99 None,
100 None,
101 None,
102 None,
103 None,
104 None,
105 &lstm_descriptor,
106 Some("lstm"),
107 )
108 .expect("lstm");
109
110 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111 gru_descriptor.set_training(true).expect("set gru training");
112 gru_descriptor
113 .set_reset_after(true)
114 .expect("set gru reset_after");
115 let gru_source = graph
116 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117 .expect("gru source");
118 let gru_recurrent = graph
119 .constant_f32_slice(&[0.0; 3], &[3, 1])
120 .expect("gru recurrent");
121 let gru_secondary_bias = graph
122 .constant_f32_slice(&[0.0], &[1])
123 .expect("gru secondary bias");
124 let gru = graph
125 .gru(
126 &gru_source,
127 &gru_recurrent,
128 None,
129 None,
130 None,
131 None,
132 Some(&gru_secondary_bias),
133 &gru_descriptor,
134 Some("gru"),
135 )
136 .expect("gru");
137
138 let results = graph
139 .run(
140 &[],
141 &[
142 &gather,
143 &gather_nd,
144 &gather_axis,
145 &gather_axis_tensor,
146 &random,
147 &dropout,
148 &single_gate[0],
149 &lstm[0],
150 &lstm[1],
151 &gru[0],
152 &gru[1],
153 ],
154 )
155 .expect("run graph");
156
157 println!("gather: {:?}", results[0].read_f32().expect("gather"));
158 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159 println!(
160 "gather_axis: {:?}",
161 results[2].read_f32().expect("gather_axis")
162 );
163 println!(
164 "gather_axis_tensor: {:?}",
165 results[3].read_f32().expect("gather_axis_tensor")
166 );
167 println!("random: {:?}", results[4].read_f32().expect("random"));
168 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169 println!(
170 "single_gate: {:?}",
171 results[6].read_f32().expect("single_gate")
172 );
173 println!(
174 "lstm state: {:?}",
175 results[7].read_f32().expect("lstm state")
176 );
177 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179 println!(
180 "gru training: {:?}",
181 results[10].read_f32().expect("gru training")
182 );
183}
pub fn set_reverse(&self, value: bool) -> Result<()>
Sets whether the RNN processes the input sequence in reverse order. This forwards to the MPSGraph descriptor's `reverse` property.
pub fn bidirectional(&self) -> bool
Returns whether the RNN is configured to run in both directions. This reads the MPSGraph descriptor's `bidirectional` property.
pub fn set_bidirectional(&self, value: bool) -> Result<()>
Sets whether the RNN runs in both directions. This forwards to the MPSGraph descriptor's `bidirectional` property.
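pub fn set_training(&self, value: bool) -> Result<()>
Sets whether the RNN layer is configured for training. This forwards to the MPSGraph descriptor's `training` property. Per the MPSGraph documentation, enabling it makes the layer also produce a training-state tensor as a secondary output.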
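pub fn activation(&self) -> usize
Returns the activation function currently configured on the descriptor, as a raw code. This reads the MPSGraph descriptor's `activation` property. The value presumably corresponds to one of the `rnn_activation` constants, such as `rnn_activation::RELU`.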
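pub fn set_activation(&self, value: usize) -> Result<()>
Sets the activation function applied by the RNN. This forwards to the MPSGraph descriptor's `activation` property. Pass one of the `rnn_activation` constants, for example `rnn_activation::RELU` as used in the repository example.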
Examples found in repository?
examples/07_gather_random_rnn.rs (line 64)
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(
44 &axis_tensor,
45 &updates,
46 &along_indices,
47 Some("gather_axis_tensor"),
48 )
49 .expect("gather along axis tensor");
50
51 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52 .expect("random descriptor");
53 descriptor.set_min(0.0).expect("random min");
54 descriptor.set_max(1.0).expect("random max");
55 let random = graph
56 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57 .expect("random tensor");
58 let dropout = graph
59 .dropout(&updates, 1.0, Some("dropout"))
60 .expect("dropout");
61
62 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63 single_gate_descriptor
64 .set_activation(rnn_activation::RELU)
65 .expect("single gate activation");
66 let single_gate_source = graph
67 .constant_f32_slice(&[0.5], &[1, 1, 1])
68 .expect("single gate source");
69 let single_gate_recurrent = graph
70 .constant_f32_slice(&[0.0], &[1, 1])
71 .expect("single gate recurrent");
72 let single_gate = graph
73 .single_gate_rnn(
74 &single_gate_source,
75 &single_gate_recurrent,
76 None,
77 None,
78 None,
79 None,
80 &single_gate_descriptor,
81 Some("single_gate"),
82 )
83 .expect("single gate rnn");
84
85 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86 lstm_descriptor
87 .set_produce_cell(true)
88 .expect("set produce cell");
89 let lstm_source = graph
90 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91 .expect("lstm source");
92 let lstm_recurrent = graph
93 .constant_f32_slice(&[0.0; 4], &[4, 1])
94 .expect("lstm recurrent");
95 let lstm = graph
96 .lstm(
97 &lstm_source,
98 &lstm_recurrent,
99 None,
100 None,
101 None,
102 None,
103 None,
104 None,
105 &lstm_descriptor,
106 Some("lstm"),
107 )
108 .expect("lstm");
109
110 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111 gru_descriptor.set_training(true).expect("set gru training");
112 gru_descriptor
113 .set_reset_after(true)
114 .expect("set gru reset_after");
115 let gru_source = graph
116 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117 .expect("gru source");
118 let gru_recurrent = graph
119 .constant_f32_slice(&[0.0; 3], &[3, 1])
120 .expect("gru recurrent");
121 let gru_secondary_bias = graph
122 .constant_f32_slice(&[0.0], &[1])
123 .expect("gru secondary bias");
124 let gru = graph
125 .gru(
126 &gru_source,
127 &gru_recurrent,
128 None,
129 None,
130 None,
131 None,
132 Some(&gru_secondary_bias),
133 &gru_descriptor,
134 Some("gru"),
135 )
136 .expect("gru");
137
138 let results = graph
139 .run(
140 &[],
141 &[
142 &gather,
143 &gather_nd,
144 &gather_axis,
145 &gather_axis_tensor,
146 &random,
147 &dropout,
148 &single_gate[0],
149 &lstm[0],
150 &lstm[1],
151 &gru[0],
152 &gru[1],
153 ],
154 )
155 .expect("run graph");
156
157 println!("gather: {:?}", results[0].read_f32().expect("gather"));
158 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159 println!(
160 "gather_axis: {:?}",
161 results[2].read_f32().expect("gather_axis")
162 );
163 println!(
164 "gather_axis_tensor: {:?}",
165 results[3].read_f32().expect("gather_axis_tensor")
166 );
167 println!("random: {:?}", results[4].read_f32().expect("random"));
168 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169 println!(
170 "single_gate: {:?}",
171 results[6].read_f32().expect("single_gate")
172 );
173 println!(
174 "lstm state: {:?}",
175 results[7].read_f32().expect("lstm state")
176 );
177 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179 println!(
180 "gru training: {:?}",
181 results[10].read_f32().expect("gru training")
182 );
183}
Trait Implementations§
impl Drop for SingleGateRNNDescriptor
impl Send for SingleGateRNNDescriptor
impl Sync for SingleGateRNNDescriptor
Auto Trait Implementations§
impl Freeze for SingleGateRNNDescriptor
impl RefUnwindSafe for SingleGateRNNDescriptor
impl Unpin for SingleGateRNNDescriptor
impl UnsafeUnpin for SingleGateRNNDescriptor
impl UnwindSafe for SingleGateRNNDescriptor
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.