Skip to main content

TensorData

Struct TensorData 

Source
pub struct TensorData { /* private fields */ }
Expand description

Safe owner for an Objective-C MPSGraphTensorData.

Implementations§

Source§

impl TensorData

Source

pub fn from_bytes( device: &MetalDevice, bytes: &[u8], shape: &[usize], data_type: u32, ) -> Option<Self>

Build tensor data by copying CPU bytes onto the given Metal device.

Examples found in repository?
examples/06_control_flow_call.rs (line 144)
17fn main() {
18    let device = MetalDevice::system_default().expect("no Metal device available");
19    let queue = device
20        .new_command_queue()
21        .expect("failed to create command queue");
22
23    let callee_graph = Graph::new().expect("callee graph");
24    let callee_input = callee_graph
25        .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26        .expect("callee placeholder");
27    let callee_output = callee_graph
28        .addition(&callee_input, &callee_input, Some("callee_double"))
29        .expect("callee output");
30    let callee_executable = callee_graph
31        .compile(
32            &device,
33            &[FeedDescription::new(
34                &callee_input,
35                &[2],
36                data_type::FLOAT32,
37            )],
38            &[&callee_output],
39        )
40        .expect("callee executable");
41
42    let graph = Graph::new().expect("graph");
43    let input = graph
44        .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45        .expect("input placeholder");
46    let predicate = graph
47        .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48        .expect("predicate placeholder");
49    let bias = graph
50        .constant_f32_slice(&[1.0, 1.0], &[2])
51        .expect("bias constant");
52
53    let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54    let call_results = graph
55        .call("double", &[&input], &[&output_type], Some("call"))
56        .expect("call op");
57    let if_results = graph
58        .if_then_else(
59            &predicate,
60            || vec![graph.addition(&input, &bias, None).expect("then add")],
61            || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62            Some("branch"),
63        )
64        .expect("if/then/else");
65
66    let call_operation = call_results[0].operation().expect("call operation");
67    let dependency = graph
68        .control_dependency(
69            &[&call_operation],
70            || {
71                vec![graph
72                    .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73                    .expect("identity")]
74            },
75            Some("dependency"),
76        )
77        .expect("control dependency");
78
79    let number_of_iterations = graph
80        .constant_scalar(4.0, data_type::INT32)
81        .expect("iteration count");
82    let zero = graph
83        .constant_scalar(0.0, data_type::INT32)
84        .expect("zero constant");
85    let one = graph
86        .constant_scalar(1.0, data_type::INT32)
87        .expect("one constant");
88    let limit = graph
89        .constant_scalar(3.0, data_type::INT32)
90        .expect("limit constant");
91
92    let for_results = graph
93        .for_loop_iterations(
94            &number_of_iterations,
95            &[&zero],
96            |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97            Some("for_loop"),
98        )
99        .expect("for loop");
100    let while_results = graph
101        .while_loop(
102            &[&zero],
103            |inputs| {
104                let condition = graph
105                    .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106                    .expect("while predicate");
107                let passthrough = graph
108                    .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109                    .expect("while passthrough");
110                WhileBeforeResult {
111                    predicate: condition,
112                    results: vec![passthrough],
113                }
114            },
115            |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116            Some("while_loop"),
117        )
118        .expect("while loop");
119
120    let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121    compile_descriptor
122        .set_callable("double", Some(&callee_executable))
123        .expect("set callable");
124    let executable = graph
125        .compile_with_descriptor(
126            Some(&device),
127            &[
128                FeedDescription::new(&input, &[2], data_type::FLOAT32),
129                FeedDescription::new(&predicate, &[], data_type::BOOL),
130            ],
131            &[
132                &call_results[0],
133                &if_results[0],
134                &dependency[0],
135                &for_results[0],
136                &while_results[0],
137            ],
138            Some(&compile_descriptor),
139        )
140        .expect("compile executable");
141
142    let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143    let predicate_data =
144        TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145    let results = executable
146        .run(&queue, &[&input_data, &predicate_data])
147        .expect("run executable");
148
149    println!(
150        "call output: {:?}",
151        results[0].read_f32().expect("call output")
152    );
153    println!("if output: {:?}", results[1].read_f32().expect("if output"));
154    println!(
155        "dependency output: {:?}",
156        results[2].read_f32().expect("dependency output")
157    );
158    println!("for output: {:?}", read_i32(&results[3]));
159    println!("while output: {:?}", read_i32(&results[4]));
160}
Source

pub fn from_f32_slice( device: &MetalDevice, values: &[f32], shape: &[usize], ) -> Option<Self>

Build tensor data from a contiguous f32 slice.

Examples found in repository?
examples/05_concat_split.rs (line 19)
4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let graph = Graph::new().expect("graph");
7    let input = graph
8        .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
9        .expect("placeholder");
10    let concat = graph
11        .concat_pair(&input, &input, 1, Some("concat"))
12        .expect("concat");
13    let split = graph.split_num(&concat, 2, 1, Some("split"));
14    let stacked = graph
15        .stack(&[&split[0], &split[1]], 0, Some("stack"))
16        .expect("stack");
17
18    let input_data =
19        TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor data");
20    let results = graph
21        .run(&[Feed::new(&input, &input_data)], &[&stacked])
22        .expect("run");
23
24    println!(
25        "stacked tensor bytes: {}",
26        results[0].byte_len().expect("byte len")
27    );
28}
More examples
Hide additional examples
examples/03_arithmetic_topk.rs (line 18)
4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let graph = Graph::new().expect("graph");
7    let input = graph
8        .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("input"))
9        .expect("placeholder");
10    let squared = graph
11        .unary_arithmetic(UnaryArithmeticOp::Square, &input, Some("square"))
12        .expect("square");
13    let row_sum = graph
14        .reduce_axes(ReductionAxesOp::Sum, &squared, &[1], Some("row_sum"))
15        .expect("reduce");
16    let topk = graph.top_k(&input, 2, Some("topk")).expect("topk");
17
18    let input_data = TensorData::from_f32_slice(&device, &[1.0, 3.0, 2.0, 4.0, 6.0, 5.0], &[2, 3])
19        .expect("tensor data");
20    let results = graph
21        .run(&[Feed::new(&input, &input_data)], &[&row_sum, &topk.0])
22        .expect("run");
23
24    println!("row sums: {:?}", results[0].read_f32().expect("row sums"));
25    println!(
26        "top-k values: {:?}",
27        results[1].read_f32().expect("topk values")
28    );
29}
examples/01_add_relu.rs (line 21)
4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let graph = Graph::new().expect("failed to create MPSGraph");
7
8    let input = graph
9        .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
10        .expect("failed to create placeholder");
11    let bias = graph
12        .constant_scalar(1.0, data_type::FLOAT32)
13        .expect("failed to create scalar constant");
14    let added = graph
15        .addition(&input, &bias, Some("add"))
16        .expect("failed to create addition op");
17    let output = graph
18        .relu(&added, Some("relu"))
19        .expect("failed to create relu op");
20
21    let input_data = TensorData::from_f32_slice(&device, &[1.0, -2.0, 3.0, -4.0], &[2, 2])
22        .expect("failed to create tensor data");
23    let results = graph
24        .run(&[Feed::new(&input, &input_data)], &[&output])
25        .expect("failed to execute graph");
26    let values = results[0].read_f32().expect("failed to read tensor output");
27
28    assert_eq!(values, vec![2.0, 0.0, 4.0, 0.0]);
29    println!("add+relu smoke passed: {values:?}");
30}
examples/02_compile_matmul.rs (line 32)
4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let queue = device
7        .new_command_queue()
8        .expect("failed to create command queue");
9    let graph = Graph::new().expect("failed to create MPSGraph");
10
11    let left = graph
12        .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("left"))
13        .expect("failed to create left placeholder");
14    let right = graph
15        .placeholder(Some(&[3, 2]), data_type::FLOAT32, Some("right"))
16        .expect("failed to create right placeholder");
17    let output = graph
18        .matrix_multiplication(&left, &right, Some("matmul"))
19        .expect("failed to create matrix multiplication op");
20
21    let executable = graph
22        .compile(
23            &device,
24            &[
25                FeedDescription::new(&left, &[2, 3], data_type::FLOAT32),
26                FeedDescription::new(&right, &[3, 2], data_type::FLOAT32),
27            ],
28            &[&output],
29        )
30        .expect("failed to compile executable");
31
32    let left_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
33        .expect("failed to create left tensor data");
34    let right_data =
35        TensorData::from_f32_slice(&device, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2])
36            .expect("failed to create right tensor data");
37
38    let results = executable
39        .run(&queue, &[&left_data, &right_data])
40        .expect("failed to run executable");
41    let values = results[0].read_f32().expect("failed to read tensor output");
42    let expected = [58.0_f32, 64.0, 139.0, 154.0];
43    for (actual, expected_value) in values.iter().zip(expected) {
44        assert!(
45            (actual - expected_value).abs() < 1.0e-4,
46            "unexpected matrix multiply result: {values:?}"
47        );
48    }
49
50    println!("compile+matmul smoke passed: {values:?}");
51}
examples/06_control_flow_call.rs (line 142)
17fn main() {
18    let device = MetalDevice::system_default().expect("no Metal device available");
19    let queue = device
20        .new_command_queue()
21        .expect("failed to create command queue");
22
23    let callee_graph = Graph::new().expect("callee graph");
24    let callee_input = callee_graph
25        .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26        .expect("callee placeholder");
27    let callee_output = callee_graph
28        .addition(&callee_input, &callee_input, Some("callee_double"))
29        .expect("callee output");
30    let callee_executable = callee_graph
31        .compile(
32            &device,
33            &[FeedDescription::new(
34                &callee_input,
35                &[2],
36                data_type::FLOAT32,
37            )],
38            &[&callee_output],
39        )
40        .expect("callee executable");
41
42    let graph = Graph::new().expect("graph");
43    let input = graph
44        .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45        .expect("input placeholder");
46    let predicate = graph
47        .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48        .expect("predicate placeholder");
49    let bias = graph
50        .constant_f32_slice(&[1.0, 1.0], &[2])
51        .expect("bias constant");
52
53    let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54    let call_results = graph
55        .call("double", &[&input], &[&output_type], Some("call"))
56        .expect("call op");
57    let if_results = graph
58        .if_then_else(
59            &predicate,
60            || vec![graph.addition(&input, &bias, None).expect("then add")],
61            || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62            Some("branch"),
63        )
64        .expect("if/then/else");
65
66    let call_operation = call_results[0].operation().expect("call operation");
67    let dependency = graph
68        .control_dependency(
69            &[&call_operation],
70            || {
71                vec![graph
72                    .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73                    .expect("identity")]
74            },
75            Some("dependency"),
76        )
77        .expect("control dependency");
78
79    let number_of_iterations = graph
80        .constant_scalar(4.0, data_type::INT32)
81        .expect("iteration count");
82    let zero = graph
83        .constant_scalar(0.0, data_type::INT32)
84        .expect("zero constant");
85    let one = graph
86        .constant_scalar(1.0, data_type::INT32)
87        .expect("one constant");
88    let limit = graph
89        .constant_scalar(3.0, data_type::INT32)
90        .expect("limit constant");
91
92    let for_results = graph
93        .for_loop_iterations(
94            &number_of_iterations,
95            &[&zero],
96            |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97            Some("for_loop"),
98        )
99        .expect("for loop");
100    let while_results = graph
101        .while_loop(
102            &[&zero],
103            |inputs| {
104                let condition = graph
105                    .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106                    .expect("while predicate");
107                let passthrough = graph
108                    .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109                    .expect("while passthrough");
110                WhileBeforeResult {
111                    predicate: condition,
112                    results: vec![passthrough],
113                }
114            },
115            |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116            Some("while_loop"),
117        )
118        .expect("while loop");
119
120    let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121    compile_descriptor
122        .set_callable("double", Some(&callee_executable))
123        .expect("set callable");
124    let executable = graph
125        .compile_with_descriptor(
126            Some(&device),
127            &[
128                FeedDescription::new(&input, &[2], data_type::FLOAT32),
129                FeedDescription::new(&predicate, &[], data_type::BOOL),
130            ],
131            &[
132                &call_results[0],
133                &if_results[0],
134                &dependency[0],
135                &for_results[0],
136                &while_results[0],
137            ],
138            Some(&compile_descriptor),
139        )
140        .expect("compile executable");
141
142    let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143    let predicate_data =
144        TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145    let results = executable
146        .run(&queue, &[&input_data, &predicate_data])
147        .expect("run executable");
148
149    println!(
150        "call output: {:?}",
151        results[0].read_f32().expect("call output")
152    );
153    println!("if output: {:?}", results[1].read_f32().expect("if output"));
154    println!(
155        "dependency output: {:?}",
156        results[2].read_f32().expect("dependency output")
157    );
158    println!("for output: {:?}", read_i32(&results[3]));
159    println!("while output: {:?}", read_i32(&results[4]));
160}
Source

pub fn from_buffer( buffer: &MetalBuffer, shape: &[usize], data_type: u32, ) -> Option<Self>

Alias an existing MTLBuffer as tensor data.

Source

pub const fn as_ptr(&self) -> *mut c_void

Source

pub fn data_type(&self) -> u32

Source

pub fn shape(&self) -> Vec<usize>

Source

pub fn element_count(&self) -> usize

Source

pub fn byte_len(&self) -> Result<usize>

Examples found in repository?
examples/05_concat_split.rs (line 26)
4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let graph = Graph::new().expect("graph");
7    let input = graph
8        .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
9        .expect("placeholder");
10    let concat = graph
11        .concat_pair(&input, &input, 1, Some("concat"))
12        .expect("concat");
13    let split = graph.split_num(&concat, 2, 1, Some("split"));
14    let stacked = graph
15        .stack(&[&split[0], &split[1]], 0, Some("stack"))
16        .expect("stack");
17
18    let input_data =
19        TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor data");
20    let results = graph
21        .run(&[Feed::new(&input, &input_data)], &[&stacked])
22        .expect("run");
23
24    println!(
25        "stacked tensor bytes: {}",
26        results[0].byte_len().expect("byte len")
27    );
28}
Source

pub fn read_bytes(&self) -> Result<Vec<u8>>

Examples found in repository?
examples/06_control_flow_call.rs (line 10)
9fn read_i32(data: &TensorData) -> Vec<i32> {
10    let bytes = data.read_bytes().expect("read bytes");
11    bytes
12        .chunks_exact(core::mem::size_of::<i32>())
13        .map(|chunk| i32::from_ne_bytes(chunk.try_into().expect("i32 chunk")))
14        .collect()
15}
Source

pub fn read_f32(&self) -> Result<Vec<f32>>

Examples found in repository?
examples/03_arithmetic_topk.rs (line 24)
4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let graph = Graph::new().expect("graph");
7    let input = graph
8        .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("input"))
9        .expect("placeholder");
10    let squared = graph
11        .unary_arithmetic(UnaryArithmeticOp::Square, &input, Some("square"))
12        .expect("square");
13    let row_sum = graph
14        .reduce_axes(ReductionAxesOp::Sum, &squared, &[1], Some("row_sum"))
15        .expect("reduce");
16    let topk = graph.top_k(&input, 2, Some("topk")).expect("topk");
17
18    let input_data = TensorData::from_f32_slice(&device, &[1.0, 3.0, 2.0, 4.0, 6.0, 5.0], &[2, 3])
19        .expect("tensor data");
20    let results = graph
21        .run(&[Feed::new(&input, &input_data)], &[&row_sum, &topk.0])
22        .expect("run");
23
24    println!("row sums: {:?}", results[0].read_f32().expect("row sums"));
25    println!(
26        "top-k values: {:?}",
27        results[1].read_f32().expect("topk values")
28    );
29}
More examples
Hide additional examples
examples/01_add_relu.rs (line 26)
4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let graph = Graph::new().expect("failed to create MPSGraph");
7
8    let input = graph
9        .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
10        .expect("failed to create placeholder");
11    let bias = graph
12        .constant_scalar(1.0, data_type::FLOAT32)
13        .expect("failed to create scalar constant");
14    let added = graph
15        .addition(&input, &bias, Some("add"))
16        .expect("failed to create addition op");
17    let output = graph
18        .relu(&added, Some("relu"))
19        .expect("failed to create relu op");
20
21    let input_data = TensorData::from_f32_slice(&device, &[1.0, -2.0, 3.0, -4.0], &[2, 2])
22        .expect("failed to create tensor data");
23    let results = graph
24        .run(&[Feed::new(&input, &input_data)], &[&output])
25        .expect("failed to execute graph");
26    let values = results[0].read_f32().expect("failed to read tensor output");
27
28    assert_eq!(values, vec![2.0, 0.0, 4.0, 0.0]);
29    println!("add+relu smoke passed: {values:?}");
30}
examples/02_compile_matmul.rs (line 41)
4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let queue = device
7        .new_command_queue()
8        .expect("failed to create command queue");
9    let graph = Graph::new().expect("failed to create MPSGraph");
10
11    let left = graph
12        .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("left"))
13        .expect("failed to create left placeholder");
14    let right = graph
15        .placeholder(Some(&[3, 2]), data_type::FLOAT32, Some("right"))
16        .expect("failed to create right placeholder");
17    let output = graph
18        .matrix_multiplication(&left, &right, Some("matmul"))
19        .expect("failed to create matrix multiplication op");
20
21    let executable = graph
22        .compile(
23            &device,
24            &[
25                FeedDescription::new(&left, &[2, 3], data_type::FLOAT32),
26                FeedDescription::new(&right, &[3, 2], data_type::FLOAT32),
27            ],
28            &[&output],
29        )
30        .expect("failed to compile executable");
31
32    let left_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
33        .expect("failed to create left tensor data");
34    let right_data =
35        TensorData::from_f32_slice(&device, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2])
36            .expect("failed to create right tensor data");
37
38    let results = executable
39        .run(&queue, &[&left_data, &right_data])
40        .expect("failed to run executable");
41    let values = results[0].read_f32().expect("failed to read tensor output");
42    let expected = [58.0_f32, 64.0, 139.0, 154.0];
43    for (actual, expected_value) in values.iter().zip(expected) {
44        assert!(
45            (actual - expected_value).abs() < 1.0e-4,
46            "unexpected matrix multiply result: {values:?}"
47        );
48    }
49
50    println!("compile+matmul smoke passed: {values:?}");
51}
examples/06_control_flow_call.rs (line 151)
17fn main() {
18    let device = MetalDevice::system_default().expect("no Metal device available");
19    let queue = device
20        .new_command_queue()
21        .expect("failed to create command queue");
22
23    let callee_graph = Graph::new().expect("callee graph");
24    let callee_input = callee_graph
25        .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26        .expect("callee placeholder");
27    let callee_output = callee_graph
28        .addition(&callee_input, &callee_input, Some("callee_double"))
29        .expect("callee output");
30    let callee_executable = callee_graph
31        .compile(
32            &device,
33            &[FeedDescription::new(
34                &callee_input,
35                &[2],
36                data_type::FLOAT32,
37            )],
38            &[&callee_output],
39        )
40        .expect("callee executable");
41
42    let graph = Graph::new().expect("graph");
43    let input = graph
44        .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45        .expect("input placeholder");
46    let predicate = graph
47        .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48        .expect("predicate placeholder");
49    let bias = graph
50        .constant_f32_slice(&[1.0, 1.0], &[2])
51        .expect("bias constant");
52
53    let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54    let call_results = graph
55        .call("double", &[&input], &[&output_type], Some("call"))
56        .expect("call op");
57    let if_results = graph
58        .if_then_else(
59            &predicate,
60            || vec![graph.addition(&input, &bias, None).expect("then add")],
61            || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62            Some("branch"),
63        )
64        .expect("if/then/else");
65
66    let call_operation = call_results[0].operation().expect("call operation");
67    let dependency = graph
68        .control_dependency(
69            &[&call_operation],
70            || {
71                vec![graph
72                    .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73                    .expect("identity")]
74            },
75            Some("dependency"),
76        )
77        .expect("control dependency");
78
79    let number_of_iterations = graph
80        .constant_scalar(4.0, data_type::INT32)
81        .expect("iteration count");
82    let zero = graph
83        .constant_scalar(0.0, data_type::INT32)
84        .expect("zero constant");
85    let one = graph
86        .constant_scalar(1.0, data_type::INT32)
87        .expect("one constant");
88    let limit = graph
89        .constant_scalar(3.0, data_type::INT32)
90        .expect("limit constant");
91
92    let for_results = graph
93        .for_loop_iterations(
94            &number_of_iterations,
95            &[&zero],
96            |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97            Some("for_loop"),
98        )
99        .expect("for loop");
100    let while_results = graph
101        .while_loop(
102            &[&zero],
103            |inputs| {
104                let condition = graph
105                    .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106                    .expect("while predicate");
107                let passthrough = graph
108                    .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109                    .expect("while passthrough");
110                WhileBeforeResult {
111                    predicate: condition,
112                    results: vec![passthrough],
113                }
114            },
115            |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116            Some("while_loop"),
117        )
118        .expect("while loop");
119
120    let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121    compile_descriptor
122        .set_callable("double", Some(&callee_executable))
123        .expect("set callable");
124    let executable = graph
125        .compile_with_descriptor(
126            Some(&device),
127            &[
128                FeedDescription::new(&input, &[2], data_type::FLOAT32),
129                FeedDescription::new(&predicate, &[], data_type::BOOL),
130            ],
131            &[
132                &call_results[0],
133                &if_results[0],
134                &dependency[0],
135                &for_results[0],
136                &while_results[0],
137            ],
138            Some(&compile_descriptor),
139        )
140        .expect("compile executable");
141
142    let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143    let predicate_data =
144        TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145    let results = executable
146        .run(&queue, &[&input_data, &predicate_data])
147        .expect("run executable");
148
149    println!(
150        "call output: {:?}",
151        results[0].read_f32().expect("call output")
152    );
153    println!("if output: {:?}", results[1].read_f32().expect("if output"));
154    println!(
155        "dependency output: {:?}",
156        results[2].read_f32().expect("dependency output")
157    );
158    println!("for output: {:?}", read_i32(&results[3]));
159    println!("while output: {:?}", read_i32(&results[4]));
160}
examples/07_gather_random_rnn.rs (line 157)
15fn main() {
16    let graph = Graph::new().expect("graph");
17    let updates = graph
18        .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19        .expect("updates");
20    let gather_indices = graph
21        .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22        .expect("gather indices");
23    let gather_nd_indices = graph
24        .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25        .expect("gather nd indices");
26    let along_indices = graph
27        .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28        .expect("gather along indices");
29    let axis_tensor = graph
30        .constant_scalar(1.0, data_type::INT32)
31        .expect("axis tensor");
32
33    let gather = graph
34        .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35        .expect("gather");
36    let gather_nd = graph
37        .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38        .expect("gather nd");
39    let gather_axis = graph
40        .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41        .expect("gather along axis");
42    let gather_axis_tensor = graph
43        .gather_along_axis_tensor(
44            &axis_tensor,
45            &updates,
46            &along_indices,
47            Some("gather_axis_tensor"),
48        )
49        .expect("gather along axis tensor");
50
51    let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52        .expect("random descriptor");
53    descriptor.set_min(0.0).expect("random min");
54    descriptor.set_max(1.0).expect("random max");
55    let random = graph
56        .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57        .expect("random tensor");
58    let dropout = graph
59        .dropout(&updates, 1.0, Some("dropout"))
60        .expect("dropout");
61
62    let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63    single_gate_descriptor
64        .set_activation(rnn_activation::RELU)
65        .expect("single gate activation");
66    let single_gate_source = graph
67        .constant_f32_slice(&[0.5], &[1, 1, 1])
68        .expect("single gate source");
69    let single_gate_recurrent = graph
70        .constant_f32_slice(&[0.0], &[1, 1])
71        .expect("single gate recurrent");
72    let single_gate = graph
73        .single_gate_rnn(
74            &single_gate_source,
75            &single_gate_recurrent,
76            None,
77            None,
78            None,
79            None,
80            &single_gate_descriptor,
81            Some("single_gate"),
82        )
83        .expect("single gate rnn");
84
85    let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86    lstm_descriptor
87        .set_produce_cell(true)
88        .expect("set produce cell");
89    let lstm_source = graph
90        .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91        .expect("lstm source");
92    let lstm_recurrent = graph
93        .constant_f32_slice(&[0.0; 4], &[4, 1])
94        .expect("lstm recurrent");
95    let lstm = graph
96        .lstm(
97            &lstm_source,
98            &lstm_recurrent,
99            None,
100            None,
101            None,
102            None,
103            None,
104            None,
105            &lstm_descriptor,
106            Some("lstm"),
107        )
108        .expect("lstm");
109
110    let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111    gru_descriptor.set_training(true).expect("set gru training");
112    gru_descriptor
113        .set_reset_after(true)
114        .expect("set gru reset_after");
115    let gru_source = graph
116        .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117        .expect("gru source");
118    let gru_recurrent = graph
119        .constant_f32_slice(&[0.0; 3], &[3, 1])
120        .expect("gru recurrent");
121    let gru_secondary_bias = graph
122        .constant_f32_slice(&[0.0], &[1])
123        .expect("gru secondary bias");
124    let gru = graph
125        .gru(
126            &gru_source,
127            &gru_recurrent,
128            None,
129            None,
130            None,
131            None,
132            Some(&gru_secondary_bias),
133            &gru_descriptor,
134            Some("gru"),
135        )
136        .expect("gru");
137
138    let results = graph
139        .run(
140            &[],
141            &[
142                &gather,
143                &gather_nd,
144                &gather_axis,
145                &gather_axis_tensor,
146                &random,
147                &dropout,
148                &single_gate[0],
149                &lstm[0],
150                &lstm[1],
151                &gru[0],
152                &gru[1],
153            ],
154        )
155        .expect("run graph");
156
157    println!("gather: {:?}", results[0].read_f32().expect("gather"));
158    println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159    println!(
160        "gather_axis: {:?}",
161        results[2].read_f32().expect("gather_axis")
162    );
163    println!(
164        "gather_axis_tensor: {:?}",
165        results[3].read_f32().expect("gather_axis_tensor")
166    );
167    println!("random: {:?}", results[4].read_f32().expect("random"));
168    println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169    println!(
170        "single_gate: {:?}",
171        results[6].read_f32().expect("single_gate")
172    );
173    println!(
174        "lstm state: {:?}",
175        results[7].read_f32().expect("lstm state")
176    );
177    println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178    println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179    println!(
180        "gru training: {:?}",
181        results[10].read_f32().expect("gru training")
182    );
183}
Source§

impl TensorData

Source

pub fn graph_device_type(&self) -> Option<u32>

Return the graph-device type that backs this tensor data.

Trait Implementations§

Source§

impl Drop for TensorData

Source§

fn drop(&mut self)

Executes the destructor for this type. Read more
Source§

fn pin_drop(self: Pin<&mut Self>)

🔬 This is a nightly-only experimental API (pin_ergonomics).
Executes the destructor for this type; unlike Drop::drop, it requires self to be pinned. Read more
Source§

impl Send for TensorData

Source§

impl Sync for TensorData

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.