Skip to main content

07_gather_random_rnn/
07_gather_random_rnn.rs

1#![allow(clippy::too_many_lines)]
2
3use apple_mpsgraph::{
4    data_type, random_distribution, rnn_activation, GRUDescriptor, Graph, LSTMDescriptor,
5    RandomOpDescriptor, SingleGateRNNDescriptor,
6};
7
/// Packs `values` into a contiguous byte buffer.
///
/// Each `i32` contributes four bytes in the platform's native byte order
/// (`to_ne_bytes`), and values appear in the same order as in the input slice.
/// An empty slice yields an empty buffer.
fn i32_bytes(values: &[i32]) -> Vec<u8> {
    let mut buffer = Vec::with_capacity(values.len() * std::mem::size_of::<i32>());
    for &value in values {
        buffer.extend_from_slice(&value.to_ne_bytes());
    }
    buffer
}
14
15fn main() {
16    let graph = Graph::new().expect("graph");
17    let updates = graph
18        .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19        .expect("updates");
20    let gather_indices = graph
21        .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22        .expect("gather indices");
23    let gather_nd_indices = graph
24        .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25        .expect("gather nd indices");
26    let along_indices = graph
27        .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28        .expect("gather along indices");
29    let axis_tensor = graph
30        .constant_scalar(1.0, data_type::INT32)
31        .expect("axis tensor");
32
33    let gather = graph
34        .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35        .expect("gather");
36    let gather_nd = graph
37        .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38        .expect("gather nd");
39    let gather_axis = graph
40        .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41        .expect("gather along axis");
42    let gather_axis_tensor = graph
43        .gather_along_axis_tensor(
44            &axis_tensor,
45            &updates,
46            &along_indices,
47            Some("gather_axis_tensor"),
48        )
49        .expect("gather along axis tensor");
50
51    let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52        .expect("random descriptor");
53    descriptor.set_min(0.0).expect("random min");
54    descriptor.set_max(1.0).expect("random max");
55    let random = graph
56        .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57        .expect("random tensor");
58    let dropout = graph
59        .dropout(&updates, 1.0, Some("dropout"))
60        .expect("dropout");
61
62    let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63    single_gate_descriptor
64        .set_activation(rnn_activation::RELU)
65        .expect("single gate activation");
66    let single_gate_source = graph
67        .constant_f32_slice(&[0.5], &[1, 1, 1])
68        .expect("single gate source");
69    let single_gate_recurrent = graph
70        .constant_f32_slice(&[0.0], &[1, 1])
71        .expect("single gate recurrent");
72    let single_gate = graph
73        .single_gate_rnn(
74            &single_gate_source,
75            &single_gate_recurrent,
76            None,
77            None,
78            None,
79            None,
80            &single_gate_descriptor,
81            Some("single_gate"),
82        )
83        .expect("single gate rnn");
84
85    let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86    lstm_descriptor
87        .set_produce_cell(true)
88        .expect("set produce cell");
89    let lstm_source = graph
90        .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91        .expect("lstm source");
92    let lstm_recurrent = graph
93        .constant_f32_slice(&[0.0; 4], &[4, 1])
94        .expect("lstm recurrent");
95    let lstm = graph
96        .lstm(
97            &lstm_source,
98            &lstm_recurrent,
99            None,
100            None,
101            None,
102            None,
103            None,
104            None,
105            &lstm_descriptor,
106            Some("lstm"),
107        )
108        .expect("lstm");
109
110    let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111    gru_descriptor.set_training(true).expect("set gru training");
112    gru_descriptor
113        .set_reset_after(true)
114        .expect("set gru reset_after");
115    let gru_source = graph
116        .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117        .expect("gru source");
118    let gru_recurrent = graph
119        .constant_f32_slice(&[0.0; 3], &[3, 1])
120        .expect("gru recurrent");
121    let gru_secondary_bias = graph
122        .constant_f32_slice(&[0.0], &[1])
123        .expect("gru secondary bias");
124    let gru = graph
125        .gru(
126            &gru_source,
127            &gru_recurrent,
128            None,
129            None,
130            None,
131            None,
132            Some(&gru_secondary_bias),
133            &gru_descriptor,
134            Some("gru"),
135        )
136        .expect("gru");
137
138    let results = graph
139        .run(
140            &[],
141            &[
142                &gather,
143                &gather_nd,
144                &gather_axis,
145                &gather_axis_tensor,
146                &random,
147                &dropout,
148                &single_gate[0],
149                &lstm[0],
150                &lstm[1],
151                &gru[0],
152                &gru[1],
153            ],
154        )
155        .expect("run graph");
156
157    println!("gather: {:?}", results[0].read_f32().expect("gather"));
158    println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159    println!(
160        "gather_axis: {:?}",
161        results[2].read_f32().expect("gather_axis")
162    );
163    println!(
164        "gather_axis_tensor: {:?}",
165        results[3].read_f32().expect("gather_axis_tensor")
166    );
167    println!("random: {:?}", results[4].read_f32().expect("random"));
168    println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169    println!(
170        "single_gate: {:?}",
171        results[6].read_f32().expect("single_gate")
172    );
173    println!(
174        "lstm state: {:?}",
175        results[7].read_f32().expect("lstm state")
176    );
177    println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178    println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179    println!(
180        "gru training: {:?}",
181        results[10].read_f32().expect("gru training")
182    );
183}