apple-mpsgraph 0.2.7

Safe Rust bindings for Apple's MetalPerformanceShadersGraph framework on macOS, backed by a Swift bridge.
Documentation
#![allow(clippy::too_many_lines)]

use apple_metal::MetalDevice;
use apple_mpsgraph::{
    data_type, random_distribution, rnn_activation, BinaryArithmeticOp, CompilationDescriptor,
    FeedDescription, GRUDescriptor, Graph, LSTMDescriptor, RandomOpDescriptor, ShapedType,
    SingleGateRNNDescriptor, TensorData, UnaryArithmeticOp, WhileBeforeResult,
};

/// Serializes `values` into one flat byte buffer, using native endianness.
///
/// Each element contributes `size_of::<i32>()` bytes, in input order. An empty
/// slice yields an empty buffer.
fn i32_bytes(values: &[i32]) -> Vec<u8> {
    let mut buffer = Vec::with_capacity(values.len() * core::mem::size_of::<i32>());
    for value in values {
        buffer.extend_from_slice(&value.to_ne_bytes());
    }
    buffer
}

/// Decodes the raw bytes of `data` as a sequence of native-endian `i32` values.
///
/// # Panics
///
/// Panics if reading the bytes fails. Also panics if the byte count is not a
/// multiple of `size_of::<i32>()`. A trailing partial element indicates the
/// tensor was not 32-bit integer data (or was read incorrectly). Silently
/// dropping it, as `chunks_exact` alone would, could let a wrong-dtype result
/// pass a test.
fn read_i32(data: &TensorData) -> Vec<i32> {
    const WIDTH: usize = core::mem::size_of::<i32>();
    let bytes = data.read_bytes().expect("read bytes");
    assert_eq!(
        bytes.len() % WIDTH,
        0,
        "tensor byte length {} is not a multiple of {}",
        bytes.len(),
        WIDTH
    );
    bytes
        .chunks_exact(WIDTH)
        .map(|chunk| i32::from_ne_bytes(chunk.try_into().expect("i32 chunk")))
        .collect()
}

/// End-to-end check of graph calls and control-flow ops.
///
/// The test proceeds in four steps:
///
/// - Compile a separate callee graph that computes `x + x`.
/// - Build a caller graph that invokes the callee by the symbol "double".
/// - Add if/then/else, control-dependency, for-loop and while-loop ops to the caller.
/// - Compile the caller with a `CompilationDescriptor` that binds "double" to
///   the callee executable, then run it with `input = [3.0, 4.0]` and a
///   predicate byte of `1`.
///
/// Panics (via `expect`) if no Metal device is available.
#[test]
fn call_and_control_flow_execute() {
    let device = MetalDevice::system_default().expect("no Metal device available");
    let queue = device
        .new_command_queue()
        .expect("failed to create command queue");

    // Callee graph: a 2-element f32 input added to itself, i.e. it doubles its input.
    let callee_graph = Graph::new().expect("callee graph");
    let callee_input = callee_graph
        .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
        .expect("callee placeholder");
    let callee_output = callee_graph
        .addition(&callee_input, &callee_input, Some("callee_double"))
        .expect("callee output");
    // Compiled on its own so it can be registered as a callable further down.
    let callee_executable = callee_graph
        .compile(
            &device,
            &[FeedDescription::new(
                &callee_input,
                &[2],
                data_type::FLOAT32,
            )],
            &[&callee_output],
        )
        .expect("callee executable");

    // Caller graph: a 2-element f32 `input`, a scalar bool `predicate`, and a
    // constant `bias` of [1.0, 1.0].
    let graph = Graph::new().expect("graph");
    let input = graph
        .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
        .expect("input placeholder");
    let predicate = graph
        .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
        .expect("predicate placeholder");
    let bias = graph
        .constant_f32_slice(&[1.0, 1.0], &[2])
        .expect("bias constant");

    // Call the callee by symbol name. The result type is supplied explicitly
    // here; the symbol is only resolved at compile time via the descriptor below.
    let call_output_type =
        ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("call output type");
    let call_results = graph
        .call("double", &[&input], &[&call_output_type], Some("call"))
        .expect("call op");
    assert_eq!(call_results.len(), 1);

    // Then-branch: input + bias. Else-branch: input - bias.
    let if_results = graph
        .if_then_else(
            &predicate,
            || vec![graph.addition(&input, &bias, None).expect("then add")],
            || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
            Some("branch"),
        )
        .expect("if/then/else");
    assert_eq!(if_results.len(), 1);

    // Identity of the call output, created under a control dependency on the call operation.
    let call_operation = call_results[0].operation().expect("call operation");
    let dependent_results = graph
        .control_dependency(
            &[&call_operation],
            || {
                vec![graph
                    .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
                    .expect("dependent identity")]
            },
            Some("dependency"),
        )
        .expect("control dependency");
    assert_eq!(dependent_results.len(), 1);

    // INT32 scalar constants for the loops. `constant_scalar` takes an f64
    // value plus the target data type.
    let number_of_iterations = graph
        .constant_scalar(4.0, data_type::INT32)
        .expect("iteration count");
    let zero = graph
        .constant_scalar(0.0, data_type::INT32)
        .expect("zero constant");
    let one = graph
        .constant_scalar(1.0, data_type::INT32)
        .expect("one constant");
    let limit = graph
        .constant_scalar(3.0, data_type::INT32)
        .expect("limit constant");

    // For-loop: starts at 0 and adds 1 per iteration for 4 iterations. The
    // assertion below expects 4.
    let for_results = graph
        .for_loop_iterations(
            &number_of_iterations,
            &[&zero],
            |_index, args| {
                vec![graph
                    .addition(&args[0], &one, None)
                    .expect("for-loop accumulation")]
            },
            Some("for_loop"),
        )
        .expect("for loop");
    assert_eq!(for_results.len(), 1);

    // While-loop: the "before" closure computes `value < limit` and forwards
    // the value unchanged; the "after" closure increments it. Starting from 0
    // with limit 3, the assertion below expects the loop to yield 3.
    let while_results = graph
        .while_loop(
            &[&zero],
            |inputs| {
                let condition = graph
                    .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
                    .expect("while predicate");
                let passthrough = graph
                    .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
                    .expect("while passthrough");
                WhileBeforeResult {
                    predicate: condition,
                    results: vec![passthrough],
                }
            },
            |inputs| {
                vec![graph
                    .addition(&inputs[0], &one, None)
                    .expect("while increment")]
            },
            Some("while_loop"),
        )
        .expect("while loop");
    assert_eq!(while_results.len(), 1);

    // Bind the symbol "double" used by the `call` op to the compiled callee,
    // then compile the caller with its two feeds and five targets.
    let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
    compile_descriptor
        .set_callable("double", Some(&callee_executable))
        .expect("set callable");
    let executable = graph
        .compile_with_descriptor(
            Some(&device),
            &[
                FeedDescription::new(&input, &[2], data_type::FLOAT32),
                FeedDescription::new(&predicate, &[], data_type::BOOL),
            ],
            &[
                &call_results[0],
                &if_results[0],
                &dependent_results[0],
                &for_results[0],
                &while_results[0],
            ],
            Some(&compile_descriptor),
        )
        .expect("compile caller");

    // Inputs are passed in feed order: input = [3.0, 4.0], predicate = byte 1.
    let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
    let predicate_data =
        TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
    let results = executable
        .run(&queue, &[&input_data, &predicate_data])
        .expect("run executable");

    // Results follow target order.
    // call: 2 * [3, 4]
    assert_eq!(results[0].read_f32().expect("call output"), vec![6.0, 8.0]);
    // if: then-branch taken, [3, 4] + [1, 1]
    assert_eq!(results[1].read_f32().expect("if output"), vec![4.0, 5.0]);
    // dependency: identity of the call output
    assert_eq!(
        results[2].read_f32().expect("dependency output"),
        vec![6.0, 8.0]
    );
    // for-loop counter after 4 iterations
    assert_eq!(read_i32(&results[3]), vec![4]);
    // while-loop value when `value < 3` first fails
    assert_eq!(read_i32(&results[4]), vec![3]);
}

/// Exercises the gather family and the random/dropout ops on constant data,
/// runs the graph with no feeds, and checks every requested output.
///
/// `updates` is the 2x3 matrix `[[10, 20, 30], [40, 50, 60]]`; every gather
/// variant reads from it.
#[test]
fn gather_and_random_execute() {
    let graph = Graph::new().expect("graph");
    let updates = graph
        .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
        .expect("updates");
    // Columns 2 and 0 (used with axis 1 below).
    let gather_indices = graph
        .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
        .expect("gather indices");
    // Two coordinate pairs: (0, 1) and (1, 0).
    let gather_nd_indices = graph
        .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
        .expect("gather nd indices");
    // Per-element column indices, same shape as `updates`.
    let along_indices = graph
        .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
        .expect("gather along indices");
    // Axis 1 supplied as a tensor, for the tensor-axis variant.
    let axis_tensor = graph
        .constant_scalar(1.0, data_type::INT32)
        .expect("axis tensor");

    let gather = graph
        .gather(&updates, &gather_indices, 1, 0, Some("gather"))
        .expect("gather");
    let gather_nd = graph
        .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
        .expect("gather nd");
    let gather_along_axis = graph
        .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
        .expect("gather along axis");
    let gather_along_axis_tensor = graph
        .gather_along_axis_tensor(
            &axis_tensor,
            &updates,
            &along_indices,
            Some("gather_axis_tensor"),
        )
        .expect("gather along axis tensor");

    // Uniform distribution configured with min 0.0 and max 1.0.
    let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
        .expect("random descriptor");
    descriptor.set_min(0.0).expect("random min");
    descriptor.set_max(1.0).expect("random max");
    // Two independent ops with the same descriptor and seed; the assertions
    // below expect identical values.
    let random_a = graph
        .random_tensor_seed(&[4], &descriptor, 7, Some("random_a"))
        .expect("random a");
    let random_b = graph
        .random_tensor_seed(&[4], &descriptor, 7, Some("random_b"))
        .expect("random b");
    let state = graph
        .random_philox_state_seed(13, Some("random_state"))
        .expect("state tensor");
    // `random_tensor_state` returns a tuple; only field `.0` is requested below
    // (presumably the random values; the other field is likely the advanced state).
    let random_state = graph
        .random_tensor_state(&[2], &descriptor, &state, Some("random_state_tensor"))
        .expect("random state tensor");
    // With rate 1.0 the assertion below expects every element to be zeroed.
    let dropout = graph
        .dropout(&updates, 1.0, Some("dropout"))
        .expect("dropout");

    let results = graph
        .run(
            &[],
            &[
                &gather,
                &gather_nd,
                &gather_along_axis,
                &gather_along_axis_tensor,
                &random_a,
                &random_b,
                &random_state.0,
                &dropout,
            ],
        )
        .expect("run graph");
    // One result per requested target, in request order. Check the count up
    // front so a short result list fails with a clear message instead of an
    // index-out-of-bounds panic.
    assert_eq!(results.len(), 8, "expected one result per requested target");

    // Rows [30, 10] and [60, 40].
    assert_eq!(
        results[0].read_f32().expect("gather"),
        vec![30.0, 10.0, 60.0, 40.0]
    );
    // updates[0][1] and updates[1][0].
    assert_eq!(results[1].read_f32().expect("gather nd"), vec![20.0, 40.0]);
    assert_eq!(
        results[2].read_f32().expect("gather axis"),
        vec![30.0, 20.0, 10.0, 40.0, 50.0, 60.0]
    );
    assert_eq!(
        results[3].read_f32().expect("gather axis tensor"),
        vec![30.0, 20.0, 10.0, 40.0, 50.0, 60.0]
    );

    // Decode each random output once and reuse it.
    let random_a_values = results[4].read_f32().expect("random a");
    let random_b_values = results[5].read_f32().expect("random b");
    assert_eq!(random_a_values, random_b_values);
    // Pin the element count: without it, the range check below would pass
    // vacuously on an empty result.
    assert_eq!(random_a_values.len(), 4);
    assert!(random_a_values
        .iter()
        .all(|value| (0.0..1.0).contains(value)));

    assert_eq!(results[6].shape(), vec![2]);
    assert_eq!(results[7].read_f32().expect("dropout"), vec![0.0; 6]);
}

/// Checks that the RNN descriptor setters round-trip, then builds single-gate
/// RNN, LSTM and GRU ops over tiny constant inputs and verifies the values
/// produced when the graph is run with no feeds.
///
/// Inputs per op:
///
/// - Single-gate RNN: source `0.5`, recurrent weight `0.0`, ReLU activation.
/// - LSTM and GRU: all-zero sources and recurrent weights.
#[test]
fn rnn_descriptors_and_execute() {
    let g = Graph::new().expect("graph");

    // Single-gate RNN with ReLU activation.
    let sg_desc = SingleGateRNNDescriptor::new().expect("single gate descriptor");
    sg_desc
        .set_activation(rnn_activation::RELU)
        .expect("set single gate activation");
    assert_eq!(sg_desc.activation(), rnn_activation::RELU);

    let sg_src = g
        .constant_f32_slice(&[0.5], &[1, 1, 1])
        .expect("single gate source");
    let sg_rec = g
        .constant_f32_slice(&[0.0], &[1, 1])
        .expect("single gate recurrent");
    let sg_out = g
        .single_gate_rnn(
            &sg_src,
            &sg_rec,
            None,
            None,
            None,
            None,
            &sg_desc,
            Some("single_gate"),
        )
        .expect("single gate rnn");
    assert_eq!(sg_out.len(), 1);

    // LSTM that also produces its cell state, hence two outputs.
    let lstm_desc = LSTMDescriptor::new().expect("lstm descriptor");
    lstm_desc.set_produce_cell(true).expect("set produce cell");
    assert!(lstm_desc.produce_cell());

    let lstm_src = g
        .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
        .expect("lstm source");
    let lstm_rec = g
        .constant_f32_slice(&[0.0; 4], &[4, 1])
        .expect("lstm recurrent");
    let lstm_out = g
        .lstm(
            &lstm_src, &lstm_rec, None, None, None, None, None, None, &lstm_desc,
            Some("lstm"),
        )
        .expect("lstm");
    assert_eq!(lstm_out.len(), 2);

    // GRU in reset-after + training mode, hence output plus training state.
    let gru_desc = GRUDescriptor::new().expect("gru descriptor");
    gru_desc.set_reset_after(true).expect("set reset after");
    gru_desc.set_training(true).expect("set training");
    assert!(gru_desc.reset_after());
    assert!(gru_desc.training());

    let gru_src = g
        .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
        .expect("gru source");
    let gru_rec = g
        .constant_f32_slice(&[0.0; 3], &[3, 1])
        .expect("gru recurrent");
    let gru_bias2 = g
        .constant_f32_slice(&[0.0], &[1])
        .expect("gru secondary bias");
    let gru_out = g
        .gru(
            &gru_src,
            &gru_rec,
            None,
            None,
            None,
            None,
            Some(&gru_bias2),
            &gru_desc,
            Some("gru"),
        )
        .expect("gru");
    assert_eq!(gru_out.len(), 2);

    let outputs = g
        .run(
            &[],
            &[&sg_out[0], &lstm_out[0], &lstm_out[1], &gru_out[0], &gru_out[1]],
        )
        .expect("run rnn graph");

    // Outputs follow the target order above.
    assert_eq!(outputs[0].read_f32().expect("single gate output"), [0.5]);
    assert_eq!(outputs[1].read_f32().expect("lstm output"), [0.0]);
    assert_eq!(outputs[2].read_f32().expect("lstm cell"), [0.0]);
    assert_eq!(outputs[3].read_f32().expect("gru output"), [0.0]);
    assert_eq!(
        outputs[4].read_f32().expect("gru training state"),
        [0.0; 4]
    );
}