1#![allow(clippy::too_many_lines)]
2
3use apple_mpsgraph::{
4 data_type, random_distribution, rnn_activation, Graph, GRUDescriptor, LSTMDescriptor,
5 RandomOpDescriptor, SingleGateRNNDescriptor,
6};
7
/// Flattens `values` into a byte buffer, each element encoded with the host's
/// native byte order.
///
/// The result is the raw payload handed to `Graph::constant_bytes` together with
/// `data_type::INT32`; an empty slice yields an empty buffer.
fn i32_bytes(values: &[i32]) -> Vec<u8> {
    let mut buffer = Vec::with_capacity(values.len() * std::mem::size_of::<i32>());
    for element in values {
        buffer.extend_from_slice(&element.to_ne_bytes());
    }
    buffer
}
14
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(&axis_tensor, &updates, &along_indices, Some("gather_axis_tensor"))
44 .expect("gather along axis tensor");
45
46 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
47 .expect("random descriptor");
48 descriptor.set_min(0.0).expect("random min");
49 descriptor.set_max(1.0).expect("random max");
50 let random = graph
51 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
52 .expect("random tensor");
53 let dropout = graph.dropout(&updates, 1.0, Some("dropout")).expect("dropout");
54
55 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
56 single_gate_descriptor
57 .set_activation(rnn_activation::RELU)
58 .expect("single gate activation");
59 let single_gate_source = graph
60 .constant_f32_slice(&[0.5], &[1, 1, 1])
61 .expect("single gate source");
62 let single_gate_recurrent = graph
63 .constant_f32_slice(&[0.0], &[1, 1])
64 .expect("single gate recurrent");
65 let single_gate = graph
66 .single_gate_rnn(
67 &single_gate_source,
68 &single_gate_recurrent,
69 None,
70 None,
71 None,
72 None,
73 &single_gate_descriptor,
74 Some("single_gate"),
75 )
76 .expect("single gate rnn");
77
78 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
79 lstm_descriptor
80 .set_produce_cell(true)
81 .expect("set produce cell");
82 let lstm_source = graph
83 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
84 .expect("lstm source");
85 let lstm_recurrent = graph
86 .constant_f32_slice(&[0.0; 4], &[4, 1])
87 .expect("lstm recurrent");
88 let lstm = graph
89 .lstm(
90 &lstm_source,
91 &lstm_recurrent,
92 None,
93 None,
94 None,
95 None,
96 None,
97 None,
98 &lstm_descriptor,
99 Some("lstm"),
100 )
101 .expect("lstm");
102
103 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
104 gru_descriptor.set_training(true).expect("set gru training");
105 gru_descriptor
106 .set_reset_after(true)
107 .expect("set gru reset_after");
108 let gru_source = graph
109 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
110 .expect("gru source");
111 let gru_recurrent = graph
112 .constant_f32_slice(&[0.0; 3], &[3, 1])
113 .expect("gru recurrent");
114 let gru_secondary_bias = graph
115 .constant_f32_slice(&[0.0], &[1])
116 .expect("gru secondary bias");
117 let gru = graph
118 .gru(
119 &gru_source,
120 &gru_recurrent,
121 None,
122 None,
123 None,
124 None,
125 Some(&gru_secondary_bias),
126 &gru_descriptor,
127 Some("gru"),
128 )
129 .expect("gru");
130
131 let results = graph
132 .run(
133 &[],
134 &[
135 &gather,
136 &gather_nd,
137 &gather_axis,
138 &gather_axis_tensor,
139 &random,
140 &dropout,
141 &single_gate[0],
142 &lstm[0],
143 &lstm[1],
144 &gru[0],
145 &gru[1],
146 ],
147 )
148 .expect("run graph");
149
150 println!("gather: {:?}", results[0].read_f32().expect("gather"));
151 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
152 println!("gather_axis: {:?}", results[2].read_f32().expect("gather_axis"));
153 println!(
154 "gather_axis_tensor: {:?}",
155 results[3].read_f32().expect("gather_axis_tensor")
156 );
157 println!("random: {:?}", results[4].read_f32().expect("random"));
158 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
159 println!("single_gate: {:?}", results[6].read_f32().expect("single_gate"));
160 println!("lstm state: {:?}", results[7].read_f32().expect("lstm state"));
161 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
162 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
163 println!("gru training: {:?}", results[10].read_f32().expect("gru training"));
164}