pub struct GRUDescriptor { /* private fields */ }
Descriptor for a gated recurrent unit (GRU) layer. It mirrors the MPSGraph framework counterpart for this type (`MPSGraphGRUDescriptor`).
Implementations
impl GRUDescriptor
pub fn new() -> Option<Self>
Creates a new GRU descriptor through the MPSGraph framework counterpart. Returns `None` if the underlying framework object could not be created.
Examples found in repository:
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(
44 &axis_tensor,
45 &updates,
46 &along_indices,
47 Some("gather_axis_tensor"),
48 )
49 .expect("gather along axis tensor");
50
51 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52 .expect("random descriptor");
53 descriptor.set_min(0.0).expect("random min");
54 descriptor.set_max(1.0).expect("random max");
55 let random = graph
56 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57 .expect("random tensor");
58 let dropout = graph
59 .dropout(&updates, 1.0, Some("dropout"))
60 .expect("dropout");
61
62 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63 single_gate_descriptor
64 .set_activation(rnn_activation::RELU)
65 .expect("single gate activation");
66 let single_gate_source = graph
67 .constant_f32_slice(&[0.5], &[1, 1, 1])
68 .expect("single gate source");
69 let single_gate_recurrent = graph
70 .constant_f32_slice(&[0.0], &[1, 1])
71 .expect("single gate recurrent");
72 let single_gate = graph
73 .single_gate_rnn(
74 &single_gate_source,
75 &single_gate_recurrent,
76 None,
77 None,
78 None,
79 None,
80 &single_gate_descriptor,
81 Some("single_gate"),
82 )
83 .expect("single gate rnn");
84
85 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86 lstm_descriptor
87 .set_produce_cell(true)
88 .expect("set produce cell");
89 let lstm_source = graph
90 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91 .expect("lstm source");
92 let lstm_recurrent = graph
93 .constant_f32_slice(&[0.0; 4], &[4, 1])
94 .expect("lstm recurrent");
95 let lstm = graph
96 .lstm(
97 &lstm_source,
98 &lstm_recurrent,
99 None,
100 None,
101 None,
102 None,
103 None,
104 None,
105 &lstm_descriptor,
106 Some("lstm"),
107 )
108 .expect("lstm");
109
110 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111 gru_descriptor.set_training(true).expect("set gru training");
112 gru_descriptor
113 .set_reset_after(true)
114 .expect("set gru reset_after");
115 let gru_source = graph
116 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117 .expect("gru source");
118 let gru_recurrent = graph
119 .constant_f32_slice(&[0.0; 3], &[3, 1])
120 .expect("gru recurrent");
121 let gru_secondary_bias = graph
122 .constant_f32_slice(&[0.0], &[1])
123 .expect("gru secondary bias");
124 let gru = graph
125 .gru(
126 &gru_source,
127 &gru_recurrent,
128 None,
129 None,
130 None,
131 None,
132 Some(&gru_secondary_bias),
133 &gru_descriptor,
134 Some("gru"),
135 )
136 .expect("gru");
137
138 let results = graph
139 .run(
140 &[],
141 &[
142 &gather,
143 &gather_nd,
144 &gather_axis,
145 &gather_axis_tensor,
146 &random,
147 &dropout,
148 &single_gate[0],
149 &lstm[0],
150 &lstm[1],
151 &gru[0],
152 &gru[1],
153 ],
154 )
155 .expect("run graph");
156
157 println!("gather: {:?}", results[0].read_f32().expect("gather"));
158 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159 println!(
160 "gather_axis: {:?}",
161 results[2].read_f32().expect("gather_axis")
162 );
163 println!(
164 "gather_axis_tensor: {:?}",
165 results[3].read_f32().expect("gather_axis_tensor")
166 );
167 println!("random: {:?}", results[4].read_f32().expect("random"));
168 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169 println!(
170 "single_gate: {:?}",
171 results[6].read_f32().expect("single_gate")
172 );
173 println!(
174 "lstm state: {:?}",
175 results[7].read_f32().expect("lstm state")
176 );
177 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179 println!(
180 "gru training: {:?}",
181 results[10].read_f32().expect("gru training")
182 );
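183}
pub fn set_reverse(&self, value: bool) -> Result<()>
Sets whether the input sequence is fed to the layer in reverse time order. This is the `reverse` property of the MPSGraph counterpart. Per the MPSGraph documentation, it is ignored when the layer is bidirectional and defaults to `false`. Returns `Err` if the framework call fails.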
pub fn bidirectional(&self) -> bool
Returns whether the layer is bidirectional, that is, whether the input sequence is traversed in both time directions. This is the `bidirectional` property of the MPSGraph counterpart.
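pub fn set_bidirectional(&self, value: bool) -> Result<()>
Sets whether the layer is bidirectional. Per the MPSGraph documentation, a bidirectional layer traverses the input in both time directions and concatenates the two results; the default is `false`. Returns `Err` if the framework call fails.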
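pub fn set_training(&self, value: bool) -> Result<()>
Enables training support. When `true`, the GRU operation also produces a training-state tensor as a secondary output; the example below reads it as `gru[1]`. The MPSGraph default is `false`. Returns `Err` if the framework call fails.
Examples found in repository: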
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(
44 &axis_tensor,
45 &updates,
46 &along_indices,
47 Some("gather_axis_tensor"),
48 )
49 .expect("gather along axis tensor");
50
51 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52 .expect("random descriptor");
53 descriptor.set_min(0.0).expect("random min");
54 descriptor.set_max(1.0).expect("random max");
55 let random = graph
56 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57 .expect("random tensor");
58 let dropout = graph
59 .dropout(&updates, 1.0, Some("dropout"))
60 .expect("dropout");
61
62 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63 single_gate_descriptor
64 .set_activation(rnn_activation::RELU)
65 .expect("single gate activation");
66 let single_gate_source = graph
67 .constant_f32_slice(&[0.5], &[1, 1, 1])
68 .expect("single gate source");
69 let single_gate_recurrent = graph
70 .constant_f32_slice(&[0.0], &[1, 1])
71 .expect("single gate recurrent");
72 let single_gate = graph
73 .single_gate_rnn(
74 &single_gate_source,
75 &single_gate_recurrent,
76 None,
77 None,
78 None,
79 None,
80 &single_gate_descriptor,
81 Some("single_gate"),
82 )
83 .expect("single gate rnn");
84
85 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86 lstm_descriptor
87 .set_produce_cell(true)
88 .expect("set produce cell");
89 let lstm_source = graph
90 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91 .expect("lstm source");
92 let lstm_recurrent = graph
93 .constant_f32_slice(&[0.0; 4], &[4, 1])
94 .expect("lstm recurrent");
95 let lstm = graph
96 .lstm(
97 &lstm_source,
98 &lstm_recurrent,
99 None,
100 None,
101 None,
102 None,
103 None,
104 None,
105 &lstm_descriptor,
106 Some("lstm"),
107 )
108 .expect("lstm");
109
110 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111 gru_descriptor.set_training(true).expect("set gru training");
112 gru_descriptor
113 .set_reset_after(true)
114 .expect("set gru reset_after");
115 let gru_source = graph
116 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117 .expect("gru source");
118 let gru_recurrent = graph
119 .constant_f32_slice(&[0.0; 3], &[3, 1])
120 .expect("gru recurrent");
121 let gru_secondary_bias = graph
122 .constant_f32_slice(&[0.0], &[1])
123 .expect("gru secondary bias");
124 let gru = graph
125 .gru(
126 &gru_source,
127 &gru_recurrent,
128 None,
129 None,
130 None,
131 None,
132 Some(&gru_secondary_bias),
133 &gru_descriptor,
134 Some("gru"),
135 )
136 .expect("gru");
137
138 let results = graph
139 .run(
140 &[],
141 &[
142 &gather,
143 &gather_nd,
144 &gather_axis,
145 &gather_axis_tensor,
146 &random,
147 &dropout,
148 &single_gate[0],
149 &lstm[0],
150 &lstm[1],
151 &gru[0],
152 &gru[1],
153 ],
154 )
155 .expect("run graph");
156
157 println!("gather: {:?}", results[0].read_f32().expect("gather"));
158 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159 println!(
160 "gather_axis: {:?}",
161 results[2].read_f32().expect("gather_axis")
162 );
163 println!(
164 "gather_axis_tensor: {:?}",
165 results[3].read_f32().expect("gather_axis_tensor")
166 );
167 println!("random: {:?}", results[4].read_f32().expect("random"));
168 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169 println!(
170 "single_gate: {:?}",
171 results[6].read_f32().expect("single_gate")
172 );
173 println!(
174 "lstm state: {:?}",
175 results[7].read_f32().expect("lstm state")
176 );
177 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179 println!(
180 "gru training: {:?}",
181 results[10].read_f32().expect("gru training")
182 );
183}
pub fn reset_gate_first(&self) -> bool
Returns whether the gates use the ordering `[r, z, o]` (reset gate first) instead of the default `[z, r, o]`. This is the `resetGateFirst` property of the MPSGraph counterpart.
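pub fn set_reset_gate_first(&self, value: bool) -> Result<()>
Sets whether the gates use the ordering `[r, z, o]` (`true`) or the default `[z, r, o]` (`false`). This choice determines how the stacked gate weights and biases passed to the GRU operation are interpreted. Returns `Err` if the framework call fails.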
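pub fn reset_after(&self) -> bool
Returns which of the two reset-gate variants the layer uses; this is the `resetAfter` property of the MPSGraph counterpart. When `true`, the reset gate is applied after the previous state is multiplied by the recurrent weights.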
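pub fn set_reset_after(&self, value: bool) -> Result<()>
Selects the reset-gate variant. When `true`, the reset gate is applied after the previous state is multiplied by the recurrent weights. Per the MPSGraph documentation, this is the only variant that uses the GRU operation's secondary bias; the example below passes `Some(&gru_secondary_bias)`. The default is `false`. Returns `Err` if the framework call fails.
Examples found in repository: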
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(
44 &axis_tensor,
45 &updates,
46 &along_indices,
47 Some("gather_axis_tensor"),
48 )
49 .expect("gather along axis tensor");
50
51 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52 .expect("random descriptor");
53 descriptor.set_min(0.0).expect("random min");
54 descriptor.set_max(1.0).expect("random max");
55 let random = graph
56 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57 .expect("random tensor");
58 let dropout = graph
59 .dropout(&updates, 1.0, Some("dropout"))
60 .expect("dropout");
61
62 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63 single_gate_descriptor
64 .set_activation(rnn_activation::RELU)
65 .expect("single gate activation");
66 let single_gate_source = graph
67 .constant_f32_slice(&[0.5], &[1, 1, 1])
68 .expect("single gate source");
69 let single_gate_recurrent = graph
70 .constant_f32_slice(&[0.0], &[1, 1])
71 .expect("single gate recurrent");
72 let single_gate = graph
73 .single_gate_rnn(
74 &single_gate_source,
75 &single_gate_recurrent,
76 None,
77 None,
78 None,
79 None,
80 &single_gate_descriptor,
81 Some("single_gate"),
82 )
83 .expect("single gate rnn");
84
85 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86 lstm_descriptor
87 .set_produce_cell(true)
88 .expect("set produce cell");
89 let lstm_source = graph
90 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91 .expect("lstm source");
92 let lstm_recurrent = graph
93 .constant_f32_slice(&[0.0; 4], &[4, 1])
94 .expect("lstm recurrent");
95 let lstm = graph
96 .lstm(
97 &lstm_source,
98 &lstm_recurrent,
99 None,
100 None,
101 None,
102 None,
103 None,
104 None,
105 &lstm_descriptor,
106 Some("lstm"),
107 )
108 .expect("lstm");
109
110 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111 gru_descriptor.set_training(true).expect("set gru training");
112 gru_descriptor
113 .set_reset_after(true)
114 .expect("set gru reset_after");
115 let gru_source = graph
116 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117 .expect("gru source");
118 let gru_recurrent = graph
119 .constant_f32_slice(&[0.0; 3], &[3, 1])
120 .expect("gru recurrent");
121 let gru_secondary_bias = graph
122 .constant_f32_slice(&[0.0], &[1])
123 .expect("gru secondary bias");
124 let gru = graph
125 .gru(
126 &gru_source,
127 &gru_recurrent,
128 None,
129 None,
130 None,
131 None,
132 Some(&gru_secondary_bias),
133 &gru_descriptor,
134 Some("gru"),
135 )
136 .expect("gru");
137
138 let results = graph
139 .run(
140 &[],
141 &[
142 &gather,
143 &gather_nd,
144 &gather_axis,
145 &gather_axis_tensor,
146 &random,
147 &dropout,
148 &single_gate[0],
149 &lstm[0],
150 &lstm[1],
151 &gru[0],
152 &gru[1],
153 ],
154 )
155 .expect("run graph");
156
157 println!("gather: {:?}", results[0].read_f32().expect("gather"));
158 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159 println!(
160 "gather_axis: {:?}",
161 results[2].read_f32().expect("gather_axis")
162 );
163 println!(
164 "gather_axis_tensor: {:?}",
165 results[3].read_f32().expect("gather_axis_tensor")
166 );
167 println!("random: {:?}", results[4].read_f32().expect("random"));
168 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169 println!(
170 "single_gate: {:?}",
171 results[6].read_f32().expect("single_gate")
172 );
173 println!(
174 "lstm state: {:?}",
175 results[7].read_f32().expect("lstm state")
176 );
177 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179 println!(
180 "gru training: {:?}",
181 results[10].read_f32().expect("gru training")
182 );
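183}
pub fn set_flip_z(&self, value: bool) -> Result<()>
Selects how the final state is computed. Per the MPSGraph documentation:
- When `true`, the layer computes `h[t] = z[t]*h[t-1] + (1 - z[t])*o[t]`.
- When `false` (the default), it computes `h[t] = (1 - z[t])*h[t-1] + z[t]*o[t]`.

Returns `Err` if the framework call fails.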
pub fn update_gate_activation(&self) -> usize
Returns the activation function of the update gate `z` as a raw activation constant, presumably one of the `rnn_activation` values such as `rnn_activation::RELU`. The MPSGraph default is sigmoid.
pub fn set_update_gate_activation(&self, value: usize) -> Result<()>
Sets the activation function of the update gate `z`, presumably to one of the `rnn_activation` constants. Returns `Err` if the framework call fails.
pub fn reset_gate_activation(&self) -> usize
Returns the activation function of the reset gate `r` as a raw activation constant, presumably one of the `rnn_activation` values. The MPSGraph default is sigmoid.
pub fn set_reset_gate_activation(&self, value: usize) -> Result<()>
Sets the activation function of the reset gate `r`, presumably to one of the `rnn_activation` constants. Returns `Err` if the framework call fails.
pub fn output_gate_activation(&self) -> usize
Returns the activation function of the output gate `o` as a raw activation constant, presumably one of the `rnn_activation` values. The MPSGraph default is tanh.
pub fn set_output_gate_activation(&self, value: usize) -> Result<()>
Sets the activation function of the output gate `o`, presumably to one of the `rnn_activation` constants. Returns `Err` if the framework call fails.