apple-mpsgraph 0.2.7

Safe Rust bindings for Apple's MetalPerformanceShadersGraph framework on macOS, backed by a Swift bridge.
Documentation
#![allow(clippy::too_many_lines)]

use apple_metal::MetalDevice;
use apple_mpsgraph::{
    data_type, BinaryArithmeticOp, CompilationDescriptor, FeedDescription, Graph, ShapedType,
    TensorData, UnaryArithmeticOp, WhileBeforeResult,
};

/// Copies a tensor back to the host and reinterprets its bytes as `i32` values.
///
/// Bytes are consumed four at a time in native byte order. Any trailing bytes
/// that do not make up a whole `i32` are ignored.
///
/// # Panics
///
/// Panics with `"read bytes"` if the tensor contents cannot be read back.
fn read_i32(data: &TensorData) -> Vec<i32> {
    const WIDTH: usize = core::mem::size_of::<i32>();

    let raw = data.read_bytes().expect("read bytes");
    let mut values = Vec::with_capacity(raw.len() / WIDTH);
    for word in raw.chunks_exact(WIDTH) {
        // `chunks_exact` only yields WIDTH-byte slices, so this conversion cannot fail.
        let array: [u8; WIDTH] = word.try_into().expect("i32 chunk");
        values.push(i32::from_ne_bytes(array));
    }
    values
}

fn main() {
    let device = MetalDevice::system_default().expect("no Metal device available");
    let queue = device
        .new_command_queue()
        .expect("failed to create command queue");

    let callee_graph = Graph::new().expect("callee graph");
    let callee_input = callee_graph
        .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
        .expect("callee placeholder");
    let callee_output = callee_graph
        .addition(&callee_input, &callee_input, Some("callee_double"))
        .expect("callee output");
    let callee_executable = callee_graph
        .compile(
            &device,
            &[FeedDescription::new(
                &callee_input,
                &[2],
                data_type::FLOAT32,
            )],
            &[&callee_output],
        )
        .expect("callee executable");

    let graph = Graph::new().expect("graph");
    let input = graph
        .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
        .expect("input placeholder");
    let predicate = graph
        .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
        .expect("predicate placeholder");
    let bias = graph
        .constant_f32_slice(&[1.0, 1.0], &[2])
        .expect("bias constant");

    let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
    let call_results = graph
        .call("double", &[&input], &[&output_type], Some("call"))
        .expect("call op");
    let if_results = graph
        .if_then_else(
            &predicate,
            || vec![graph.addition(&input, &bias, None).expect("then add")],
            || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
            Some("branch"),
        )
        .expect("if/then/else");

    let call_operation = call_results[0].operation().expect("call operation");
    let dependency = graph
        .control_dependency(
            &[&call_operation],
            || {
                vec![graph
                    .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
                    .expect("identity")]
            },
            Some("dependency"),
        )
        .expect("control dependency");

    let number_of_iterations = graph
        .constant_scalar(4.0, data_type::INT32)
        .expect("iteration count");
    let zero = graph
        .constant_scalar(0.0, data_type::INT32)
        .expect("zero constant");
    let one = graph
        .constant_scalar(1.0, data_type::INT32)
        .expect("one constant");
    let limit = graph
        .constant_scalar(3.0, data_type::INT32)
        .expect("limit constant");

    let for_results = graph
        .for_loop_iterations(
            &number_of_iterations,
            &[&zero],
            |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
            Some("for_loop"),
        )
        .expect("for loop");
    let while_results = graph
        .while_loop(
            &[&zero],
            |inputs| {
                let condition = graph
                    .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
                    .expect("while predicate");
                let passthrough = graph
                    .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
                    .expect("while passthrough");
                WhileBeforeResult {
                    predicate: condition,
                    results: vec![passthrough],
                }
            },
            |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
            Some("while_loop"),
        )
        .expect("while loop");

    let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
    compile_descriptor
        .set_callable("double", Some(&callee_executable))
        .expect("set callable");
    let executable = graph
        .compile_with_descriptor(
            Some(&device),
            &[
                FeedDescription::new(&input, &[2], data_type::FLOAT32),
                FeedDescription::new(&predicate, &[], data_type::BOOL),
            ],
            &[
                &call_results[0],
                &if_results[0],
                &dependency[0],
                &for_results[0],
                &while_results[0],
            ],
            Some(&compile_descriptor),
        )
        .expect("compile executable");

    let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
    let predicate_data =
        TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
    let results = executable
        .run(&queue, &[&input_data, &predicate_data])
        .expect("run executable");

    println!(
        "call output: {:?}",
        results[0].read_f32().expect("call output")
    );
    println!("if output: {:?}", results[1].read_f32().expect("if output"));
    println!(
        "dependency output: {:?}",
        results[2].read_f32().expect("dependency output")
    );
    println!("for output: {:?}", read_i32(&results[3]));
    println!("while output: {:?}", read_i32(&results[4]));
}