pub struct SingleGateRNNDescriptor { /* private fields */ }
Implementations§
Source§impl SingleGateRNNDescriptor
impl SingleGateRNNDescriptor
Source pub fn new() -> Option<Self>
pub fn new() -> Option<Self>
Examples found in repository?
examples/07_gather_random_rnn.rs (line 55)
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(&axis_tensor, &updates, &along_indices, Some("gather_axis_tensor"))
44 .expect("gather along axis tensor");
45
46 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
47 .expect("random descriptor");
48 descriptor.set_min(0.0).expect("random min");
49 descriptor.set_max(1.0).expect("random max");
50 let random = graph
51 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
52 .expect("random tensor");
53 let dropout = graph.dropout(&updates, 1.0, Some("dropout")).expect("dropout");
54
55 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
56 single_gate_descriptor
57 .set_activation(rnn_activation::RELU)
58 .expect("single gate activation");
59 let single_gate_source = graph
60 .constant_f32_slice(&[0.5], &[1, 1, 1])
61 .expect("single gate source");
62 let single_gate_recurrent = graph
63 .constant_f32_slice(&[0.0], &[1, 1])
64 .expect("single gate recurrent");
65 let single_gate = graph
66 .single_gate_rnn(
67 &single_gate_source,
68 &single_gate_recurrent,
69 None,
70 None,
71 None,
72 None,
73 &single_gate_descriptor,
74 Some("single_gate"),
75 )
76 .expect("single gate rnn");
77
78 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
79 lstm_descriptor
80 .set_produce_cell(true)
81 .expect("set produce cell");
82 let lstm_source = graph
83 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
84 .expect("lstm source");
85 let lstm_recurrent = graph
86 .constant_f32_slice(&[0.0; 4], &[4, 1])
87 .expect("lstm recurrent");
88 let lstm = graph
89 .lstm(
90 &lstm_source,
91 &lstm_recurrent,
92 None,
93 None,
94 None,
95 None,
96 None,
97 None,
98 &lstm_descriptor,
99 Some("lstm"),
100 )
101 .expect("lstm");
102
103 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
104 gru_descriptor.set_training(true).expect("set gru training");
105 gru_descriptor
106 .set_reset_after(true)
107 .expect("set gru reset_after");
108 let gru_source = graph
109 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
110 .expect("gru source");
111 let gru_recurrent = graph
112 .constant_f32_slice(&[0.0; 3], &[3, 1])
113 .expect("gru recurrent");
114 let gru_secondary_bias = graph
115 .constant_f32_slice(&[0.0], &[1])
116 .expect("gru secondary bias");
117 let gru = graph
118 .gru(
119 &gru_source,
120 &gru_recurrent,
121 None,
122 None,
123 None,
124 None,
125 Some(&gru_secondary_bias),
126 &gru_descriptor,
127 Some("gru"),
128 )
129 .expect("gru");
130
131 let results = graph
132 .run(
133 &[],
134 &[
135 &gather,
136 &gather_nd,
137 &gather_axis,
138 &gather_axis_tensor,
139 &random,
140 &dropout,
141 &single_gate[0],
142 &lstm[0],
143 &lstm[1],
144 &gru[0],
145 &gru[1],
146 ],
147 )
148 .expect("run graph");
149
150 println!("gather: {:?}", results[0].read_f32().expect("gather"));
151 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
152 println!("gather_axis: {:?}", results[2].read_f32().expect("gather_axis"));
153 println!(
154 "gather_axis_tensor: {:?}",
155 results[3].read_f32().expect("gather_axis_tensor")
156 );
157 println!("random: {:?}", results[4].read_f32().expect("random"));
158 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
159 println!("single_gate: {:?}", results[6].read_f32().expect("single_gate"));
160 println!("lstm state: {:?}", results[7].read_f32().expect("lstm state"));
161 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
162 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
163 println!("gru training: {:?}", results[10].read_f32().expect("gru training"));
164}
pub fn reverse(&self) -> bool
pub fn set_reverse(&self, value: bool) -> Result<()>
pub fn bidirectional(&self) -> bool
pub fn set_bidirectional(&self, value: bool) -> Result<()>
pub fn training(&self) -> bool
pub fn set_training(&self, value: bool) -> Result<()>
pub fn activation(&self) -> usize
Source pub fn set_activation(&self, value: usize) -> Result<()>
pub fn set_activation(&self, value: usize) -> Result<()>
Examples found in repository?
examples/07_gather_random_rnn.rs (line 57)
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(&axis_tensor, &updates, &along_indices, Some("gather_axis_tensor"))
44 .expect("gather along axis tensor");
45
46 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
47 .expect("random descriptor");
48 descriptor.set_min(0.0).expect("random min");
49 descriptor.set_max(1.0).expect("random max");
50 let random = graph
51 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
52 .expect("random tensor");
53 let dropout = graph.dropout(&updates, 1.0, Some("dropout")).expect("dropout");
54
55 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
56 single_gate_descriptor
57 .set_activation(rnn_activation::RELU)
58 .expect("single gate activation");
59 let single_gate_source = graph
60 .constant_f32_slice(&[0.5], &[1, 1, 1])
61 .expect("single gate source");
62 let single_gate_recurrent = graph
63 .constant_f32_slice(&[0.0], &[1, 1])
64 .expect("single gate recurrent");
65 let single_gate = graph
66 .single_gate_rnn(
67 &single_gate_source,
68 &single_gate_recurrent,
69 None,
70 None,
71 None,
72 None,
73 &single_gate_descriptor,
74 Some("single_gate"),
75 )
76 .expect("single gate rnn");
77
78 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
79 lstm_descriptor
80 .set_produce_cell(true)
81 .expect("set produce cell");
82 let lstm_source = graph
83 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
84 .expect("lstm source");
85 let lstm_recurrent = graph
86 .constant_f32_slice(&[0.0; 4], &[4, 1])
87 .expect("lstm recurrent");
88 let lstm = graph
89 .lstm(
90 &lstm_source,
91 &lstm_recurrent,
92 None,
93 None,
94 None,
95 None,
96 None,
97 None,
98 &lstm_descriptor,
99 Some("lstm"),
100 )
101 .expect("lstm");
102
103 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
104 gru_descriptor.set_training(true).expect("set gru training");
105 gru_descriptor
106 .set_reset_after(true)
107 .expect("set gru reset_after");
108 let gru_source = graph
109 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
110 .expect("gru source");
111 let gru_recurrent = graph
112 .constant_f32_slice(&[0.0; 3], &[3, 1])
113 .expect("gru recurrent");
114 let gru_secondary_bias = graph
115 .constant_f32_slice(&[0.0], &[1])
116 .expect("gru secondary bias");
117 let gru = graph
118 .gru(
119 &gru_source,
120 &gru_recurrent,
121 None,
122 None,
123 None,
124 None,
125 Some(&gru_secondary_bias),
126 &gru_descriptor,
127 Some("gru"),
128 )
129 .expect("gru");
130
131 let results = graph
132 .run(
133 &[],
134 &[
135 &gather,
136 &gather_nd,
137 &gather_axis,
138 &gather_axis_tensor,
139 &random,
140 &dropout,
141 &single_gate[0],
142 &lstm[0],
143 &lstm[1],
144 &gru[0],
145 &gru[1],
146 ],
147 )
148 .expect("run graph");
149
150 println!("gather: {:?}", results[0].read_f32().expect("gather"));
151 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
152 println!("gather_axis: {:?}", results[2].read_f32().expect("gather_axis"));
153 println!(
154 "gather_axis_tensor: {:?}",
155 results[3].read_f32().expect("gather_axis_tensor")
156 );
157 println!("random: {:?}", results[4].read_f32().expect("random"));
158 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
159 println!("single_gate: {:?}", results[6].read_f32().expect("single_gate"));
160 println!("lstm state: {:?}", results[7].read_f32().expect("lstm state"));
161 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
162 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
163 println!("gru training: {:?}", results[10].read_f32().expect("gru training"));
164}
Trait Implementations§
Source§impl Drop for SingleGateRNNDescriptor
impl Drop for SingleGateRNNDescriptor
impl Send for SingleGateRNNDescriptor
impl Sync for SingleGateRNNDescriptor
Auto Trait Implementations§
impl Freeze for SingleGateRNNDescriptor
impl RefUnwindSafe for SingleGateRNNDescriptor
impl Unpin for SingleGateRNNDescriptor
impl UnsafeUnpin for SingleGateRNNDescriptor
impl UnwindSafe for SingleGateRNNDescriptor
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more