#![allow(clippy::too_many_lines)]
use apple_metal::MetalDevice;
use apple_mpsgraph::{
data_type, random_distribution, rnn_activation, BinaryArithmeticOp, CompilationDescriptor,
FeedDescription, GRUDescriptor, Graph, LSTMDescriptor, RandomOpDescriptor, ShapedType,
SingleGateRNNDescriptor, TensorData, UnaryArithmeticOp, WhileBeforeResult,
};
/// Serializes `values` into one contiguous byte buffer.
///
/// Each element contributes `size_of::<i32>()` bytes in the host's native
/// byte order, and elements appear in the same order as in the input slice.
/// An empty slice yields an empty buffer.
fn i32_bytes(values: &[i32]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(values.len() * core::mem::size_of::<i32>());
    for &word in values {
        // Native endianness on purpose: the buffer is handed to the graph as raw tensor memory.
        encoded.extend_from_slice(&word.to_ne_bytes());
    }
    encoded
}
/// Reads the raw bytes of `data` and reinterprets them as native-endian `i32` values.
///
/// Bytes are consumed in groups of `size_of::<i32>()`; any trailing bytes that do
/// not fill a whole group are ignored.
///
/// # Panics
///
/// Panics with "read bytes" if the tensor's bytes cannot be read.
fn read_i32(data: &TensorData) -> Vec<i32> {
    const WIDTH: usize = core::mem::size_of::<i32>();
    let raw = data.read_bytes().expect("read bytes");
    let mut decoded = Vec::new();
    for word in raw.chunks_exact(WIDTH) {
        // chunks_exact guarantees WIDTH bytes per chunk, so the conversion cannot fail.
        decoded.push(i32::from_ne_bytes(word.try_into().expect("i32 chunk")));
    }
    decoded
}
/// End-to-end check of the call and control-flow graph ops.
///
/// A callee graph is compiled and registered as the callable "double"; a caller graph
/// that combines `call`, `if_then_else`, `control_dependency`, `for_loop_iterations`
/// and `while_loop` is then compiled with that callable, run once, and each of its
/// five outputs is compared against a hand-computed value.
///
/// Panics (failing the test) if no Metal device is available or any step reports an error.
#[test]
fn call_and_control_flow_execute() {
let device = MetalDevice::system_default().expect("no Metal device available");
let queue = device
.new_command_queue()
.expect("failed to create command queue");
// Callee graph: a 2-element FLOAT32 placeholder added to itself, compiled on its own.
let callee_graph = Graph::new().expect("callee graph");
let callee_input = callee_graph
.placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
.expect("callee placeholder");
let callee_output = callee_graph
.addition(&callee_input, &callee_input, Some("callee_double"))
.expect("callee output");
let callee_executable = callee_graph
.compile(
&device,
&[FeedDescription::new(
&callee_input,
&[2],
data_type::FLOAT32,
)],
&[&callee_output],
)
.expect("callee executable");
// Caller graph inputs: a 2-element FLOAT32 vector and a BOOL predicate declared with an empty shape.
let graph = Graph::new().expect("graph");
let input = graph
.placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
.expect("input placeholder");
let predicate = graph
.placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
.expect("predicate placeholder");
// Constant [1.0, 1.0] used by both branches of the conditional below.
let bias = graph
.constant_f32_slice(&[1.0, 1.0], &[2])
.expect("bias constant");
// Call op referring to the symbol "double"; the same name is bound to
// callee_executable through set_callable before the caller is compiled.
let call_output_type =
ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("call output type");
let call_results = graph
.call("double", &[&input], &[&call_output_type], Some("call"))
.expect("call op");
assert_eq!(call_results.len(), 1);
// Conditional: the first closure builds input + bias, the second builds input - bias.
let if_results = graph
.if_then_else(
&predicate,
|| vec![graph.addition(&input, &bias, None).expect("then add")],
|| vec![graph.subtraction(&input, &bias, None).expect("else sub")],
Some("branch"),
)
.expect("if/then/else");
assert_eq!(if_results.len(), 1);
// Identity of the call result, built inside a control-dependency scope on the
// operation that produced that result.
let call_operation = call_results[0].operation().expect("call operation");
let dependent_results = graph
.control_dependency(
&[&call_operation],
|| {
vec![graph
.unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
.expect("dependent identity")]
},
Some("dependency"),
)
.expect("control dependency");
assert_eq!(dependent_results.len(), 1);
// INT32 scalar constants for the loops; constant_scalar is given the value as a
// float literal together with the target data type.
let number_of_iterations = graph
.constant_scalar(4.0, data_type::INT32)
.expect("iteration count");
let zero = graph
.constant_scalar(0.0, data_type::INT32)
.expect("zero constant");
let one = graph
.constant_scalar(1.0, data_type::INT32)
.expect("one constant");
let limit = graph
.constant_scalar(3.0, data_type::INT32)
.expect("limit constant");
// For loop: the carried value starts at `zero` and the body adds `one`; with an
// iteration count of 4 the test expects a final value of 4.
let for_results = graph
.for_loop_iterations(
&number_of_iterations,
&[&zero],
|_index, args| {
vec![graph
.addition(&args[0], &one, None)
.expect("for-loop accumulation")]
},
Some("for_loop"),
)
.expect("for loop");
assert_eq!(for_results.len(), 1);
// While loop: the first closure yields the predicate `value < limit` plus the value
// passed through unchanged; the second closure adds `one`. Starting from `zero` with
// a limit of 3, the test expects a final value of 3.
let while_results = graph
.while_loop(
&[&zero],
|inputs| {
let condition = graph
.binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
.expect("while predicate");
let passthrough = graph
.unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
.expect("while passthrough");
WhileBeforeResult {
predicate: condition,
results: vec![passthrough],
}
},
|inputs| {
vec![graph
.addition(&inputs[0], &one, None)
.expect("while increment")]
},
Some("while_loop"),
)
.expect("while loop");
assert_eq!(while_results.len(), 1);
// Register the callee under the name used by the call op, then compile the caller.
let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
compile_descriptor
.set_callable("double", Some(&callee_executable))
.expect("set callable");
// The asserts at the end index `results` in the order the targets are listed here:
// call, if/else, control dependency, for loop, while loop.
let executable = graph
.compile_with_descriptor(
Some(&device),
&[
FeedDescription::new(&input, &[2], data_type::FLOAT32),
FeedDescription::new(&predicate, &[], data_type::BOOL),
],
&[
&call_results[0],
&if_results[0],
&dependent_results[0],
&for_results[0],
&while_results[0],
],
Some(&compile_descriptor),
)
.expect("compile caller");
// Run with input = [3.0, 4.0] and the predicate byte set to 1; the inputs are passed
// in the same order as the feed descriptions above.
let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
let predicate_data =
TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
let results = executable
.run(&queue, &[&input_data, &predicate_data])
.expect("run executable");
// Call output: the input added to itself.
assert_eq!(results[0].read_f32().expect("call output"), vec![6.0, 8.0]);
// Conditional output: input + bias, i.e. the first (then) closure's result.
assert_eq!(results[1].read_f32().expect("if output"), vec![4.0, 5.0]);
// Control-dependency output: identical to the call output.
assert_eq!(
results[2].read_f32().expect("dependency output"),
vec![6.0, 8.0]
);
// Loop outputs are INT32, so they are decoded from raw bytes.
assert_eq!(read_i32(&results[3]), vec![4]);
assert_eq!(read_i32(&results[4]), vec![3]);
}
/// Builds a constant-only graph exercising the gather variants, seeded random
/// tensors, a state-driven random tensor and dropout, runs it with no feeds, and
/// checks each output.
///
/// Panics (failing the test) if any graph call reports an error or an assertion fails.
#[test]
fn gather_and_random_execute() {
let graph = Graph::new().expect("graph");
// Source tensor of shape [2, 3]; the expected values below treat it as the rows
// [10, 20, 30] and [40, 50, 60].
let updates = graph
.constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
.expect("updates");
// INT32 index tensors are supplied as raw native-endian bytes.
let gather_indices = graph
.constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
.expect("gather indices");
let gather_nd_indices = graph
.constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
.expect("gather nd indices");
let along_indices = graph
.constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
.expect("gather along indices");
// The axis (1) expressed as an INT32 scalar tensor for the tensor-axis variant.
let axis_tensor = graph
.constant_scalar(1.0, data_type::INT32)
.expect("axis tensor");
// NOTE(review): the integer arguments `1, 0` are presumably axis and batch
// dimensions — confirm against the `gather` signature.
let gather = graph
.gather(&updates, &gather_indices, 1, 0, Some("gather"))
.expect("gather");
// NOTE(review): the integer argument `0` is presumably the batch-dimension count — confirm.
let gather_nd = graph
.gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
.expect("gather nd");
// The two along-axis variants receive the same data and indices, once with the axis
// as an integer and once as a tensor; the test expects identical results from both.
let gather_along_axis = graph
.gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
.expect("gather along axis");
let gather_along_axis_tensor = graph
.gather_along_axis_tensor(
&axis_tensor,
&updates,
&along_indices,
Some("gather_axis_tensor"),
)
.expect("gather along axis tensor");
// Uniform FLOAT32 descriptor with min 0.0 and max 1.0, shared by all random ops below.
let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
.expect("random descriptor");
descriptor.set_min(0.0).expect("random min");
descriptor.set_max(1.0).expect("random max");
// Two random tensors built with the same shape, descriptor and seed (7); the test
// asserts that they evaluate to identical values.
let random_a = graph
.random_tensor_seed(&[4], &descriptor, 7, Some("random_a"))
.expect("random a");
let random_b = graph
.random_tensor_seed(&[4], &descriptor, 7, Some("random_b"))
.expect("random b");
// State-driven variant: a state tensor created from seed 13 feeds random_tensor_state.
let state = graph
.random_philox_state_seed(13, Some("random_state"))
.expect("state tensor");
let random_state = graph
.random_tensor_state(&[2], &descriptor, &state, Some("random_state_tensor"))
.expect("random state tensor");
// NOTE(review): the `1.0` argument is presumably the drop rate — confirm. The test
// expects every output element to be zero.
let dropout = graph
.dropout(&updates, 1.0, Some("dropout"))
.expect("dropout");
// No feeds are needed because every input is a constant. The asserts below index
// `results` in the order the targets are listed here.
let results = graph
.run(
&[],
&[
&gather,
&gather_nd,
&gather_along_axis,
&gather_along_axis_tensor,
&random_a,
&random_b,
// Only element 0 of the returned tuple is fetched; presumably the generated
// tensor, with the other element being the updated state — TODO confirm.
&random_state.0,
&dropout,
],
)
.expect("run graph");
// Expected: columns 2 and 0 of each row.
assert_eq!(
results[0].read_f32().expect("gather"),
vec![30.0, 10.0, 60.0, 40.0]
);
// Expected: the elements at coordinates (0, 1) and (1, 0).
assert_eq!(results[1].read_f32().expect("gather nd"), vec![20.0, 40.0]);
// Expected: row 0 reordered by [2, 1, 0], row 1 reordered by [0, 1, 2].
assert_eq!(
results[2].read_f32().expect("gather axis"),
vec![30.0, 20.0, 10.0, 40.0, 50.0, 60.0]
);
assert_eq!(
results[3].read_f32().expect("gather axis tensor"),
vec![30.0, 20.0, 10.0, 40.0, 50.0, 60.0]
);
// Same seed must give the same values.
assert_eq!(
results[4].read_f32().expect("random a"),
results[5].read_f32().expect("random b")
);
// Every sample must lie in the half-open range [0.0, 1.0).
let random_values = results[4].read_f32().expect("random values");
assert!(random_values
.iter()
.all(|value| *value >= 0.0 && *value < 1.0));
// Only the shape of the state-driven random tensor is checked, not its values.
assert_eq!(results[6].shape(), vec![2]);
assert_eq!(results[7].read_f32().expect("dropout"), vec![0.0; 6]);
}
/// Checks setter/getter round-trips on the single-gate RNN, LSTM and GRU descriptors,
/// then builds one op of each kind from constant inputs, runs the graph with no
/// feeds, and compares the outputs against fixed expected values.
///
/// Panics (failing the test) if any graph call reports an error or an assertion fails.
#[test]
fn rnn_descriptors_and_execute() {
let graph = Graph::new().expect("graph");
// Single-gate RNN: set the activation to RELU and read it back.
let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
single_gate_descriptor
.set_activation(rnn_activation::RELU)
.expect("set single gate activation");
assert_eq!(single_gate_descriptor.activation(), rnn_activation::RELU);
// Source is a single 0.5 value with shape [1, 1, 1]; the recurrent tensor is a single
// 0.0 with shape [1, 1].
let single_gate_source = graph
.constant_f32_slice(&[0.5], &[1, 1, 1])
.expect("single gate source");
let single_gate_recurrent = graph
.constant_f32_slice(&[0.0], &[1, 1])
.expect("single gate recurrent");
// All four optional arguments are left as None.
let single_gate_results = graph
.single_gate_rnn(
&single_gate_source,
&single_gate_recurrent,
None,
None,
None,
None,
&single_gate_descriptor,
Some("single_gate"),
)
.expect("single gate rnn");
assert_eq!(single_gate_results.len(), 1);
// LSTM: enable produce_cell and read it back; with it set, the test expects two results.
let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
lstm_descriptor
.set_produce_cell(true)
.expect("set produce cell");
assert!(lstm_descriptor.produce_cell());
// All-zero source of shape [1, 1, 4] and all-zero recurrent tensor of shape [4, 1].
let lstm_source = graph
.constant_f32_slice(&[0.0; 4], &[1, 1, 4])
.expect("lstm source");
let lstm_recurrent = graph
.constant_f32_slice(&[0.0; 4], &[4, 1])
.expect("lstm recurrent");
// All six optional arguments are left as None.
let lstm_results = graph
.lstm(
&lstm_source,
&lstm_recurrent,
None,
None,
None,
None,
None,
None,
&lstm_descriptor,
Some("lstm"),
)
.expect("lstm");
assert_eq!(lstm_results.len(), 2);
// GRU: enable reset_after and training and read both back; the test expects two results.
let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
gru_descriptor
.set_reset_after(true)
.expect("set reset after");
gru_descriptor.set_training(true).expect("set training");
assert!(gru_descriptor.reset_after());
assert!(gru_descriptor.training());
// All-zero source of shape [1, 1, 3] and all-zero recurrent tensor of shape [3, 1].
let gru_source = graph
.constant_f32_slice(&[0.0; 3], &[1, 1, 3])
.expect("gru source");
let gru_recurrent = graph
.constant_f32_slice(&[0.0; 3], &[3, 1])
.expect("gru recurrent");
// A single-element zero tensor passed as the only optional argument that is not None.
// NOTE(review): presumably required because reset_after is enabled — confirm.
let gru_secondary_bias = graph
.constant_f32_slice(&[0.0], &[1])
.expect("gru secondary bias");
let gru_results = graph
.gru(
&gru_source,
&gru_recurrent,
None,
None,
None,
None,
Some(&gru_secondary_bias),
&gru_descriptor,
Some("gru"),
)
.expect("gru");
assert_eq!(gru_results.len(), 2);
// No feeds are needed because every input is a constant. The asserts below index
// `results` in the order the targets are listed here.
let results = graph
.run(
&[],
&[
&single_gate_results[0],
&lstm_results[0],
&lstm_results[1],
&gru_results[0],
&gru_results[1],
],
)
.expect("run rnn graph");
// Single gate: the 0.5 source value is expected to come through unchanged.
assert_eq!(
results[0].read_f32().expect("single gate output"),
vec![0.5]
);
// All-zero LSTM and GRU inputs are expected to produce all-zero outputs.
assert_eq!(results[1].read_f32().expect("lstm output"), vec![0.0]);
assert_eq!(results[2].read_f32().expect("lstm cell"), vec![0.0]);
assert_eq!(results[3].read_f32().expect("gru output"), vec![0.0]);
// The second GRU result is read as the training state and is expected to hold four zeros.
assert_eq!(
results[4].read_f32().expect("gru training state"),
vec![0.0, 0.0, 0.0, 0.0]
);
}