// 04_descriptor_compile/04_descriptor_compile.rs

1use apple_metal::MetalDevice;
2use apple_mpsgraph::{
3    data_type, optimization, CompilationDescriptor, FeedDescription, Graph, ShapedType,
4    UnaryArithmeticOp,
5};
6
7fn main() {
8    let device = MetalDevice::system_default().expect("no Metal device available");
9    let graph = Graph::new().expect("graph");
10    let input = graph
11        .placeholder(Some(&[4]), data_type::FLOAT32, Some("input"))
12        .expect("placeholder");
13    let output = graph
14        .unary_arithmetic(UnaryArithmeticOp::Absolute, &input, Some("abs"))
15        .expect("absolute");
16
17    let descriptor = CompilationDescriptor::new().expect("compilation descriptor");
18    descriptor
19        .set_optimization_level(optimization::LEVEL1)
20        .expect("set optimization level");
21    descriptor
22        .set_wait_for_compilation_completion(true)
23        .expect("set wait");
24
25    let executable = graph
26        .compile_with_descriptor(
27            Some(&device),
28            &[FeedDescription::new(&input, &[4], data_type::FLOAT32)],
29            &[&output],
30            Some(&descriptor),
31        )
32        .expect("compile");
33    let input_type = ShapedType::new(Some(&[4]), data_type::FLOAT32).expect("shaped type");
34    let output_types = executable
35        .output_types(Some(&device), &[&input_type], Some(&descriptor))
36        .expect("output types");
37
38    println!("feed tensors: {}", executable.feed_tensors().len());
39    println!("target tensors: {}", executable.target_tensors().len());
40    println!("output type: {:?}", output_types[0].shape());
41}