Skip to main content

05_concat_split/
05_concat_split.rs

use apple_metal::MetalDevice;
use apple_mpsgraph::{data_type, Feed, Graph, TensorData};

4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let graph = Graph::new().expect("graph");
7    let input = graph
8        .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
9        .expect("placeholder");
10    let concat = graph
11        .concat_pair(&input, &input, 1, Some("concat"))
12        .expect("concat");
13    let split = graph.split_num(&concat, 2, 1, Some("split"));
14    let stacked = graph
15        .stack(&[&split[0], &split[1]], 0, Some("stack"))
16        .expect("stack");
17
18    let input_data =
19        TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor data");
20    let results = graph
21        .run(&[Feed::new(&input, &input_data)], &[&stacked])
22        .expect("run");
23
24    println!(
25        "stacked tensor bytes: {}",
26        results[0].byte_len().expect("byte len")
27    );
28}