Skip to main content

06_control_flow_call/
06_control_flow_call.rs

1#![allow(clippy::too_many_lines)]
2
3use apple_metal::MetalDevice;
4use apple_mpsgraph::{
5    data_type, BinaryArithmeticOp, CompilationDescriptor, FeedDescription, Graph, ShapedType,
6    TensorData, UnaryArithmeticOp, WhileBeforeResult,
7};
8
9fn read_i32(data: &TensorData) -> Vec<i32> {
10    let bytes = data.read_bytes().expect("read bytes");
11    bytes
12        .chunks_exact(core::mem::size_of::<i32>())
13        .map(|chunk| i32::from_ne_bytes(chunk.try_into().expect("i32 chunk")))
14        .collect()
15}
16
17fn main() {
18    let device = MetalDevice::system_default().expect("no Metal device available");
19    let queue = device
20        .new_command_queue()
21        .expect("failed to create command queue");
22
23    let callee_graph = Graph::new().expect("callee graph");
24    let callee_input = callee_graph
25        .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26        .expect("callee placeholder");
27    let callee_output = callee_graph
28        .addition(&callee_input, &callee_input, Some("callee_double"))
29        .expect("callee output");
30    let callee_executable = callee_graph
31        .compile(
32            &device,
33            &[FeedDescription::new(
34                &callee_input,
35                &[2],
36                data_type::FLOAT32,
37            )],
38            &[&callee_output],
39        )
40        .expect("callee executable");
41
42    let graph = Graph::new().expect("graph");
43    let input = graph
44        .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45        .expect("input placeholder");
46    let predicate = graph
47        .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48        .expect("predicate placeholder");
49    let bias = graph
50        .constant_f32_slice(&[1.0, 1.0], &[2])
51        .expect("bias constant");
52
53    let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54    let call_results = graph
55        .call("double", &[&input], &[&output_type], Some("call"))
56        .expect("call op");
57    let if_results = graph
58        .if_then_else(
59            &predicate,
60            || vec![graph.addition(&input, &bias, None).expect("then add")],
61            || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62            Some("branch"),
63        )
64        .expect("if/then/else");
65
66    let call_operation = call_results[0].operation().expect("call operation");
67    let dependency = graph
68        .control_dependency(
69            &[&call_operation],
70            || {
71                vec![graph
72                    .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73                    .expect("identity")]
74            },
75            Some("dependency"),
76        )
77        .expect("control dependency");
78
79    let number_of_iterations = graph
80        .constant_scalar(4.0, data_type::INT32)
81        .expect("iteration count");
82    let zero = graph
83        .constant_scalar(0.0, data_type::INT32)
84        .expect("zero constant");
85    let one = graph
86        .constant_scalar(1.0, data_type::INT32)
87        .expect("one constant");
88    let limit = graph
89        .constant_scalar(3.0, data_type::INT32)
90        .expect("limit constant");
91
92    let for_results = graph
93        .for_loop_iterations(
94            &number_of_iterations,
95            &[&zero],
96            |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97            Some("for_loop"),
98        )
99        .expect("for loop");
100    let while_results = graph
101        .while_loop(
102            &[&zero],
103            |inputs| {
104                let condition = graph
105                    .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106                    .expect("while predicate");
107                let passthrough = graph
108                    .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109                    .expect("while passthrough");
110                WhileBeforeResult {
111                    predicate: condition,
112                    results: vec![passthrough],
113                }
114            },
115            |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116            Some("while_loop"),
117        )
118        .expect("while loop");
119
120    let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121    compile_descriptor
122        .set_callable("double", Some(&callee_executable))
123        .expect("set callable");
124    let executable = graph
125        .compile_with_descriptor(
126            Some(&device),
127            &[
128                FeedDescription::new(&input, &[2], data_type::FLOAT32),
129                FeedDescription::new(&predicate, &[], data_type::BOOL),
130            ],
131            &[
132                &call_results[0],
133                &if_results[0],
134                &dependency[0],
135                &for_results[0],
136                &while_results[0],
137            ],
138            Some(&compile_descriptor),
139        )
140        .expect("compile executable");
141
142    let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143    let predicate_data =
144        TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145    let results = executable
146        .run(&queue, &[&input_data, &predicate_data])
147        .expect("run executable");
148
149    println!(
150        "call output: {:?}",
151        results[0].read_f32().expect("call output")
152    );
153    println!("if output: {:?}", results[1].read_f32().expect("if output"));
154    println!(
155        "dependency output: {:?}",
156        results[2].read_f32().expect("dependency output")
157    );
158    println!("for output: {:?}", read_i32(&results[3]));
159    println!("while output: {:?}", read_i32(&results[4]));
160}