Skip to main content

03_arithmetic_topk/
03_arithmetic_topk.rs

use apple_metal::MetalDevice;
use apple_mpsgraph::{data_type, Feed, Graph, ReductionAxesOp, TensorData, UnaryArithmeticOp};

4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let graph = Graph::new().expect("graph");
7    let input = graph
8        .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("input"))
9        .expect("placeholder");
10    let squared = graph
11        .unary_arithmetic(UnaryArithmeticOp::Square, &input, Some("square"))
12        .expect("square");
13    let row_sum = graph
14        .reduce_axes(ReductionAxesOp::Sum, &squared, &[1], Some("row_sum"))
15        .expect("reduce");
16    let topk = graph.top_k(&input, 2, Some("topk")).expect("topk");
17
18    let input_data = TensorData::from_f32_slice(&device, &[1.0, 3.0, 2.0, 4.0, 6.0, 5.0], &[2, 3])
19        .expect("tensor data");
20    let results = graph
21        .run(&[Feed::new(&input, &input_data)], &[&row_sum, &topk.0])
22        .expect("run");
23
24    println!("row sums: {:?}", results[0].read_f32().expect("row sums"));
25    println!(
26        "top-k values: {:?}",
27        results[1].read_f32().expect("topk values")
28    );
29}