pub struct TensorData { /* private fields */ }
Safe owner for an Objective-C MPSGraphTensorData.
Implementations
impl TensorData
pub fn from_bytes(
    device: &MetalDevice,
    bytes: &[u8],
    shape: &[usize],
    data_type: u32,
) -> Option<Self>
Build tensor data by copying CPU bytes onto the given Metal device.
Examples found in repository?
examples/06_control_flow_call.rs (line 144)
17fn main() {
18 let device = MetalDevice::system_default().expect("no Metal device available");
19 let queue = device
20 .new_command_queue()
21 .expect("failed to create command queue");
22
23 let callee_graph = Graph::new().expect("callee graph");
24 let callee_input = callee_graph
25 .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26 .expect("callee placeholder");
27 let callee_output = callee_graph
28 .addition(&callee_input, &callee_input, Some("callee_double"))
29 .expect("callee output");
30 let callee_executable = callee_graph
31 .compile(
32 &device,
33 &[FeedDescription::new(
34 &callee_input,
35 &[2],
36 data_type::FLOAT32,
37 )],
38 &[&callee_output],
39 )
40 .expect("callee executable");
41
42 let graph = Graph::new().expect("graph");
43 let input = graph
44 .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45 .expect("input placeholder");
46 let predicate = graph
47 .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48 .expect("predicate placeholder");
49 let bias = graph
50 .constant_f32_slice(&[1.0, 1.0], &[2])
51 .expect("bias constant");
52
53 let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54 let call_results = graph
55 .call("double", &[&input], &[&output_type], Some("call"))
56 .expect("call op");
57 let if_results = graph
58 .if_then_else(
59 &predicate,
60 || vec![graph.addition(&input, &bias, None).expect("then add")],
61 || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62 Some("branch"),
63 )
64 .expect("if/then/else");
65
66 let call_operation = call_results[0].operation().expect("call operation");
67 let dependency = graph
68 .control_dependency(
69 &[&call_operation],
70 || {
71 vec![graph
72 .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73 .expect("identity")]
74 },
75 Some("dependency"),
76 )
77 .expect("control dependency");
78
79 let number_of_iterations = graph
80 .constant_scalar(4.0, data_type::INT32)
81 .expect("iteration count");
82 let zero = graph
83 .constant_scalar(0.0, data_type::INT32)
84 .expect("zero constant");
85 let one = graph
86 .constant_scalar(1.0, data_type::INT32)
87 .expect("one constant");
88 let limit = graph
89 .constant_scalar(3.0, data_type::INT32)
90 .expect("limit constant");
91
92 let for_results = graph
93 .for_loop_iterations(
94 &number_of_iterations,
95 &[&zero],
96 |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97 Some("for_loop"),
98 )
99 .expect("for loop");
100 let while_results = graph
101 .while_loop(
102 &[&zero],
103 |inputs| {
104 let condition = graph
105 .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106 .expect("while predicate");
107 let passthrough = graph
108 .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109 .expect("while passthrough");
110 WhileBeforeResult {
111 predicate: condition,
112 results: vec![passthrough],
113 }
114 },
115 |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116 Some("while_loop"),
117 )
118 .expect("while loop");
119
120 let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121 compile_descriptor
122 .set_callable("double", Some(&callee_executable))
123 .expect("set callable");
124 let executable = graph
125 .compile_with_descriptor(
126 Some(&device),
127 &[
128 FeedDescription::new(&input, &[2], data_type::FLOAT32),
129 FeedDescription::new(&predicate, &[], data_type::BOOL),
130 ],
131 &[
132 &call_results[0],
133 &if_results[0],
134 &dependency[0],
135 &for_results[0],
136 &while_results[0],
137 ],
138 Some(&compile_descriptor),
139 )
140 .expect("compile executable");
141
142 let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143 let predicate_data =
144 TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145 let results = executable
146 .run(&queue, &[&input_data, &predicate_data])
147 .expect("run executable");
148
149 println!(
150 "call output: {:?}",
151 results[0].read_f32().expect("call output")
152 );
153 println!("if output: {:?}", results[1].read_f32().expect("if output"));
154 println!(
155 "dependency output: {:?}",
156 results[2].read_f32().expect("dependency output")
157 );
158 println!("for output: {:?}", read_i32(&results[3]));
159 println!("while output: {:?}", read_i32(&results[4]));
160}
pub fn from_f32_slice(
    device: &MetalDevice,
    values: &[f32],
    shape: &[usize],
) -> Option<Self>
Build tensor data from a contiguous f32 slice.
Examples found in repository?
examples/05_concat_split.rs (line 19)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("graph");
7 let input = graph
8 .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
9 .expect("placeholder");
10 let concat = graph
11 .concat_pair(&input, &input, 1, Some("concat"))
12 .expect("concat");
13 let split = graph.split_num(&concat, 2, 1, Some("split"));
14 let stacked = graph
15 .stack(&[&split[0], &split[1]], 0, Some("stack"))
16 .expect("stack");
17
18 let input_data =
19 TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor data");
20 let results = graph
21 .run(&[Feed::new(&input, &input_data)], &[&stacked])
22 .expect("run");
23
24 println!(
25 "stacked tensor bytes: {}",
26 results[0].byte_len().expect("byte len")
27 );
28}
More examples
examples/03_arithmetic_topk.rs (line 18)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("graph");
7 let input = graph
8 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("input"))
9 .expect("placeholder");
10 let squared = graph
11 .unary_arithmetic(UnaryArithmeticOp::Square, &input, Some("square"))
12 .expect("square");
13 let row_sum = graph
14 .reduce_axes(ReductionAxesOp::Sum, &squared, &[1], Some("row_sum"))
15 .expect("reduce");
16 let topk = graph.top_k(&input, 2, Some("topk")).expect("topk");
17
18 let input_data = TensorData::from_f32_slice(&device, &[1.0, 3.0, 2.0, 4.0, 6.0, 5.0], &[2, 3])
19 .expect("tensor data");
20 let results = graph
21 .run(&[Feed::new(&input, &input_data)], &[&row_sum, &topk.0])
22 .expect("run");
23
24 println!("row sums: {:?}", results[0].read_f32().expect("row sums"));
25 println!(
26 "top-k values: {:?}",
27 results[1].read_f32().expect("topk values")
28 );
29}
examples/01_add_relu.rs (line 21)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("failed to create MPSGraph");
7
8 let input = graph
9 .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
10 .expect("failed to create placeholder");
11 let bias = graph
12 .constant_scalar(1.0, data_type::FLOAT32)
13 .expect("failed to create scalar constant");
14 let added = graph
15 .addition(&input, &bias, Some("add"))
16 .expect("failed to create addition op");
17 let output = graph
18 .relu(&added, Some("relu"))
19 .expect("failed to create relu op");
20
21 let input_data = TensorData::from_f32_slice(&device, &[1.0, -2.0, 3.0, -4.0], &[2, 2])
22 .expect("failed to create tensor data");
23 let results = graph
24 .run(&[Feed::new(&input, &input_data)], &[&output])
25 .expect("failed to execute graph");
26 let values = results[0].read_f32().expect("failed to read tensor output");
27
28 assert_eq!(values, vec![2.0, 0.0, 4.0, 0.0]);
29 println!("add+relu smoke passed: {values:?}");
30}
examples/02_compile_matmul.rs (line 32)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let queue = device
7 .new_command_queue()
8 .expect("failed to create command queue");
9 let graph = Graph::new().expect("failed to create MPSGraph");
10
11 let left = graph
12 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("left"))
13 .expect("failed to create left placeholder");
14 let right = graph
15 .placeholder(Some(&[3, 2]), data_type::FLOAT32, Some("right"))
16 .expect("failed to create right placeholder");
17 let output = graph
18 .matrix_multiplication(&left, &right, Some("matmul"))
19 .expect("failed to create matrix multiplication op");
20
21 let executable = graph
22 .compile(
23 &device,
24 &[
25 FeedDescription::new(&left, &[2, 3], data_type::FLOAT32),
26 FeedDescription::new(&right, &[3, 2], data_type::FLOAT32),
27 ],
28 &[&output],
29 )
30 .expect("failed to compile executable");
31
32 let left_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
33 .expect("failed to create left tensor data");
34 let right_data =
35 TensorData::from_f32_slice(&device, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2])
36 .expect("failed to create right tensor data");
37
38 let results = executable
39 .run(&queue, &[&left_data, &right_data])
40 .expect("failed to run executable");
41 let values = results[0].read_f32().expect("failed to read tensor output");
42 let expected = [58.0_f32, 64.0, 139.0, 154.0];
43 for (actual, expected_value) in values.iter().zip(expected) {
44 assert!(
45 (actual - expected_value).abs() < 1.0e-4,
46 "unexpected matrix multiply result: {values:?}"
47 );
48 }
49
50 println!("compile+matmul smoke passed: {values:?}");
51}
examples/06_control_flow_call.rs (line 142)
17fn main() {
18 let device = MetalDevice::system_default().expect("no Metal device available");
19 let queue = device
20 .new_command_queue()
21 .expect("failed to create command queue");
22
23 let callee_graph = Graph::new().expect("callee graph");
24 let callee_input = callee_graph
25 .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26 .expect("callee placeholder");
27 let callee_output = callee_graph
28 .addition(&callee_input, &callee_input, Some("callee_double"))
29 .expect("callee output");
30 let callee_executable = callee_graph
31 .compile(
32 &device,
33 &[FeedDescription::new(
34 &callee_input,
35 &[2],
36 data_type::FLOAT32,
37 )],
38 &[&callee_output],
39 )
40 .expect("callee executable");
41
42 let graph = Graph::new().expect("graph");
43 let input = graph
44 .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45 .expect("input placeholder");
46 let predicate = graph
47 .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48 .expect("predicate placeholder");
49 let bias = graph
50 .constant_f32_slice(&[1.0, 1.0], &[2])
51 .expect("bias constant");
52
53 let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54 let call_results = graph
55 .call("double", &[&input], &[&output_type], Some("call"))
56 .expect("call op");
57 let if_results = graph
58 .if_then_else(
59 &predicate,
60 || vec![graph.addition(&input, &bias, None).expect("then add")],
61 || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62 Some("branch"),
63 )
64 .expect("if/then/else");
65
66 let call_operation = call_results[0].operation().expect("call operation");
67 let dependency = graph
68 .control_dependency(
69 &[&call_operation],
70 || {
71 vec![graph
72 .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73 .expect("identity")]
74 },
75 Some("dependency"),
76 )
77 .expect("control dependency");
78
79 let number_of_iterations = graph
80 .constant_scalar(4.0, data_type::INT32)
81 .expect("iteration count");
82 let zero = graph
83 .constant_scalar(0.0, data_type::INT32)
84 .expect("zero constant");
85 let one = graph
86 .constant_scalar(1.0, data_type::INT32)
87 .expect("one constant");
88 let limit = graph
89 .constant_scalar(3.0, data_type::INT32)
90 .expect("limit constant");
91
92 let for_results = graph
93 .for_loop_iterations(
94 &number_of_iterations,
95 &[&zero],
96 |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97 Some("for_loop"),
98 )
99 .expect("for loop");
100 let while_results = graph
101 .while_loop(
102 &[&zero],
103 |inputs| {
104 let condition = graph
105 .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106 .expect("while predicate");
107 let passthrough = graph
108 .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109 .expect("while passthrough");
110 WhileBeforeResult {
111 predicate: condition,
112 results: vec![passthrough],
113 }
114 },
115 |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116 Some("while_loop"),
117 )
118 .expect("while loop");
119
120 let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121 compile_descriptor
122 .set_callable("double", Some(&callee_executable))
123 .expect("set callable");
124 let executable = graph
125 .compile_with_descriptor(
126 Some(&device),
127 &[
128 FeedDescription::new(&input, &[2], data_type::FLOAT32),
129 FeedDescription::new(&predicate, &[], data_type::BOOL),
130 ],
131 &[
132 &call_results[0],
133 &if_results[0],
134 &dependency[0],
135 &for_results[0],
136 &while_results[0],
137 ],
138 Some(&compile_descriptor),
139 )
140 .expect("compile executable");
141
142 let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143 let predicate_data =
144 TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145 let results = executable
146 .run(&queue, &[&input_data, &predicate_data])
147 .expect("run executable");
148
149 println!(
150 "call output: {:?}",
151 results[0].read_f32().expect("call output")
152 );
153 println!("if output: {:?}", results[1].read_f32().expect("if output"));
154 println!(
155 "dependency output: {:?}",
156 results[2].read_f32().expect("dependency output")
157 );
158 println!("for output: {:?}", read_i32(&results[3]));
159 println!("while output: {:?}", read_i32(&results[4]));
160}
pub fn from_buffer(
    buffer: &MetalBuffer,
    shape: &[usize],
    data_type: u32,
) -> Option<Self>
Alias an existing MTLBuffer as tensor data.
pub const fn as_ptr(&self) -> *mut c_void
pub fn data_type(&self) -> u32
pub fn shape(&self) -> Vec<usize>
pub fn element_count(&self) -> usize
pub fn byte_len(&self) -> Result<usize>
Examples found in repository?
examples/05_concat_split.rs (line 26)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("graph");
7 let input = graph
8 .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
9 .expect("placeholder");
10 let concat = graph
11 .concat_pair(&input, &input, 1, Some("concat"))
12 .expect("concat");
13 let split = graph.split_num(&concat, 2, 1, Some("split"));
14 let stacked = graph
15 .stack(&[&split[0], &split[1]], 0, Some("stack"))
16 .expect("stack");
17
18 let input_data =
19 TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor data");
20 let results = graph
21 .run(&[Feed::new(&input, &input_data)], &[&stacked])
22 .expect("run");
23
24 println!(
25 "stacked tensor bytes: {}",
26 results[0].byte_len().expect("byte len")
27 );
28}
pub fn read_bytes(&self) -> Result<Vec<u8>>
pub fn read_f32(&self) -> Result<Vec<f32>>
Examples found in repository?
examples/03_arithmetic_topk.rs (line 24)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("graph");
7 let input = graph
8 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("input"))
9 .expect("placeholder");
10 let squared = graph
11 .unary_arithmetic(UnaryArithmeticOp::Square, &input, Some("square"))
12 .expect("square");
13 let row_sum = graph
14 .reduce_axes(ReductionAxesOp::Sum, &squared, &[1], Some("row_sum"))
15 .expect("reduce");
16 let topk = graph.top_k(&input, 2, Some("topk")).expect("topk");
17
18 let input_data = TensorData::from_f32_slice(&device, &[1.0, 3.0, 2.0, 4.0, 6.0, 5.0], &[2, 3])
19 .expect("tensor data");
20 let results = graph
21 .run(&[Feed::new(&input, &input_data)], &[&row_sum, &topk.0])
22 .expect("run");
23
24 println!("row sums: {:?}", results[0].read_f32().expect("row sums"));
25 println!(
26 "top-k values: {:?}",
27 results[1].read_f32().expect("topk values")
28 );
29}
More examples
examples/01_add_relu.rs (line 26)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("failed to create MPSGraph");
7
8 let input = graph
9 .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
10 .expect("failed to create placeholder");
11 let bias = graph
12 .constant_scalar(1.0, data_type::FLOAT32)
13 .expect("failed to create scalar constant");
14 let added = graph
15 .addition(&input, &bias, Some("add"))
16 .expect("failed to create addition op");
17 let output = graph
18 .relu(&added, Some("relu"))
19 .expect("failed to create relu op");
20
21 let input_data = TensorData::from_f32_slice(&device, &[1.0, -2.0, 3.0, -4.0], &[2, 2])
22 .expect("failed to create tensor data");
23 let results = graph
24 .run(&[Feed::new(&input, &input_data)], &[&output])
25 .expect("failed to execute graph");
26 let values = results[0].read_f32().expect("failed to read tensor output");
27
28 assert_eq!(values, vec![2.0, 0.0, 4.0, 0.0]);
29 println!("add+relu smoke passed: {values:?}");
30}
examples/02_compile_matmul.rs (line 41)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let queue = device
7 .new_command_queue()
8 .expect("failed to create command queue");
9 let graph = Graph::new().expect("failed to create MPSGraph");
10
11 let left = graph
12 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("left"))
13 .expect("failed to create left placeholder");
14 let right = graph
15 .placeholder(Some(&[3, 2]), data_type::FLOAT32, Some("right"))
16 .expect("failed to create right placeholder");
17 let output = graph
18 .matrix_multiplication(&left, &right, Some("matmul"))
19 .expect("failed to create matrix multiplication op");
20
21 let executable = graph
22 .compile(
23 &device,
24 &[
25 FeedDescription::new(&left, &[2, 3], data_type::FLOAT32),
26 FeedDescription::new(&right, &[3, 2], data_type::FLOAT32),
27 ],
28 &[&output],
29 )
30 .expect("failed to compile executable");
31
32 let left_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
33 .expect("failed to create left tensor data");
34 let right_data =
35 TensorData::from_f32_slice(&device, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2])
36 .expect("failed to create right tensor data");
37
38 let results = executable
39 .run(&queue, &[&left_data, &right_data])
40 .expect("failed to run executable");
41 let values = results[0].read_f32().expect("failed to read tensor output");
42 let expected = [58.0_f32, 64.0, 139.0, 154.0];
43 for (actual, expected_value) in values.iter().zip(expected) {
44 assert!(
45 (actual - expected_value).abs() < 1.0e-4,
46 "unexpected matrix multiply result: {values:?}"
47 );
48 }
49
50 println!("compile+matmul smoke passed: {values:?}");
51}
examples/06_control_flow_call.rs (line 151)
17fn main() {
18 let device = MetalDevice::system_default().expect("no Metal device available");
19 let queue = device
20 .new_command_queue()
21 .expect("failed to create command queue");
22
23 let callee_graph = Graph::new().expect("callee graph");
24 let callee_input = callee_graph
25 .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26 .expect("callee placeholder");
27 let callee_output = callee_graph
28 .addition(&callee_input, &callee_input, Some("callee_double"))
29 .expect("callee output");
30 let callee_executable = callee_graph
31 .compile(
32 &device,
33 &[FeedDescription::new(
34 &callee_input,
35 &[2],
36 data_type::FLOAT32,
37 )],
38 &[&callee_output],
39 )
40 .expect("callee executable");
41
42 let graph = Graph::new().expect("graph");
43 let input = graph
44 .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45 .expect("input placeholder");
46 let predicate = graph
47 .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48 .expect("predicate placeholder");
49 let bias = graph
50 .constant_f32_slice(&[1.0, 1.0], &[2])
51 .expect("bias constant");
52
53 let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54 let call_results = graph
55 .call("double", &[&input], &[&output_type], Some("call"))
56 .expect("call op");
57 let if_results = graph
58 .if_then_else(
59 &predicate,
60 || vec![graph.addition(&input, &bias, None).expect("then add")],
61 || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62 Some("branch"),
63 )
64 .expect("if/then/else");
65
66 let call_operation = call_results[0].operation().expect("call operation");
67 let dependency = graph
68 .control_dependency(
69 &[&call_operation],
70 || {
71 vec![graph
72 .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73 .expect("identity")]
74 },
75 Some("dependency"),
76 )
77 .expect("control dependency");
78
79 let number_of_iterations = graph
80 .constant_scalar(4.0, data_type::INT32)
81 .expect("iteration count");
82 let zero = graph
83 .constant_scalar(0.0, data_type::INT32)
84 .expect("zero constant");
85 let one = graph
86 .constant_scalar(1.0, data_type::INT32)
87 .expect("one constant");
88 let limit = graph
89 .constant_scalar(3.0, data_type::INT32)
90 .expect("limit constant");
91
92 let for_results = graph
93 .for_loop_iterations(
94 &number_of_iterations,
95 &[&zero],
96 |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97 Some("for_loop"),
98 )
99 .expect("for loop");
100 let while_results = graph
101 .while_loop(
102 &[&zero],
103 |inputs| {
104 let condition = graph
105 .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106 .expect("while predicate");
107 let passthrough = graph
108 .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109 .expect("while passthrough");
110 WhileBeforeResult {
111 predicate: condition,
112 results: vec![passthrough],
113 }
114 },
115 |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116 Some("while_loop"),
117 )
118 .expect("while loop");
119
120 let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121 compile_descriptor
122 .set_callable("double", Some(&callee_executable))
123 .expect("set callable");
124 let executable = graph
125 .compile_with_descriptor(
126 Some(&device),
127 &[
128 FeedDescription::new(&input, &[2], data_type::FLOAT32),
129 FeedDescription::new(&predicate, &[], data_type::BOOL),
130 ],
131 &[
132 &call_results[0],
133 &if_results[0],
134 &dependency[0],
135 &for_results[0],
136 &while_results[0],
137 ],
138 Some(&compile_descriptor),
139 )
140 .expect("compile executable");
141
142 let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143 let predicate_data =
144 TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145 let results = executable
146 .run(&queue, &[&input_data, &predicate_data])
147 .expect("run executable");
148
149 println!(
150 "call output: {:?}",
151 results[0].read_f32().expect("call output")
152 );
153 println!("if output: {:?}", results[1].read_f32().expect("if output"));
154 println!(
155 "dependency output: {:?}",
156 results[2].read_f32().expect("dependency output")
157 );
158 println!("for output: {:?}", read_i32(&results[3]));
159 println!("while output: {:?}", read_i32(&results[4]));
160}
examples/07_gather_random_rnn.rs (line 157)
15fn main() {
16 let graph = Graph::new().expect("graph");
17 let updates = graph
18 .constant_f32_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[2, 3])
19 .expect("updates");
20 let gather_indices = graph
21 .constant_bytes(&i32_bytes(&[2, 0]), &[2], data_type::INT32)
22 .expect("gather indices");
23 let gather_nd_indices = graph
24 .constant_bytes(&i32_bytes(&[0, 1, 1, 0]), &[2, 2], data_type::INT32)
25 .expect("gather nd indices");
26 let along_indices = graph
27 .constant_bytes(&i32_bytes(&[2, 1, 0, 0, 1, 2]), &[2, 3], data_type::INT32)
28 .expect("gather along indices");
29 let axis_tensor = graph
30 .constant_scalar(1.0, data_type::INT32)
31 .expect("axis tensor");
32
33 let gather = graph
34 .gather(&updates, &gather_indices, 1, 0, Some("gather"))
35 .expect("gather");
36 let gather_nd = graph
37 .gather_nd(&updates, &gather_nd_indices, 0, Some("gather_nd"))
38 .expect("gather nd");
39 let gather_axis = graph
40 .gather_along_axis(1, &updates, &along_indices, Some("gather_axis"))
41 .expect("gather along axis");
42 let gather_axis_tensor = graph
43 .gather_along_axis_tensor(
44 &axis_tensor,
45 &updates,
46 &along_indices,
47 Some("gather_axis_tensor"),
48 )
49 .expect("gather along axis tensor");
50
51 let descriptor = RandomOpDescriptor::new(random_distribution::UNIFORM, data_type::FLOAT32)
52 .expect("random descriptor");
53 descriptor.set_min(0.0).expect("random min");
54 descriptor.set_max(1.0).expect("random max");
55 let random = graph
56 .random_tensor_seed(&[4], &descriptor, 7, Some("random"))
57 .expect("random tensor");
58 let dropout = graph
59 .dropout(&updates, 1.0, Some("dropout"))
60 .expect("dropout");
61
62 let single_gate_descriptor = SingleGateRNNDescriptor::new().expect("single gate descriptor");
63 single_gate_descriptor
64 .set_activation(rnn_activation::RELU)
65 .expect("single gate activation");
66 let single_gate_source = graph
67 .constant_f32_slice(&[0.5], &[1, 1, 1])
68 .expect("single gate source");
69 let single_gate_recurrent = graph
70 .constant_f32_slice(&[0.0], &[1, 1])
71 .expect("single gate recurrent");
72 let single_gate = graph
73 .single_gate_rnn(
74 &single_gate_source,
75 &single_gate_recurrent,
76 None,
77 None,
78 None,
79 None,
80 &single_gate_descriptor,
81 Some("single_gate"),
82 )
83 .expect("single gate rnn");
84
85 let lstm_descriptor = LSTMDescriptor::new().expect("lstm descriptor");
86 lstm_descriptor
87 .set_produce_cell(true)
88 .expect("set produce cell");
89 let lstm_source = graph
90 .constant_f32_slice(&[0.0; 4], &[1, 1, 4])
91 .expect("lstm source");
92 let lstm_recurrent = graph
93 .constant_f32_slice(&[0.0; 4], &[4, 1])
94 .expect("lstm recurrent");
95 let lstm = graph
96 .lstm(
97 &lstm_source,
98 &lstm_recurrent,
99 None,
100 None,
101 None,
102 None,
103 None,
104 None,
105 &lstm_descriptor,
106 Some("lstm"),
107 )
108 .expect("lstm");
109
110 let gru_descriptor = GRUDescriptor::new().expect("gru descriptor");
111 gru_descriptor.set_training(true).expect("set gru training");
112 gru_descriptor
113 .set_reset_after(true)
114 .expect("set gru reset_after");
115 let gru_source = graph
116 .constant_f32_slice(&[0.0; 3], &[1, 1, 3])
117 .expect("gru source");
118 let gru_recurrent = graph
119 .constant_f32_slice(&[0.0; 3], &[3, 1])
120 .expect("gru recurrent");
121 let gru_secondary_bias = graph
122 .constant_f32_slice(&[0.0], &[1])
123 .expect("gru secondary bias");
124 let gru = graph
125 .gru(
126 &gru_source,
127 &gru_recurrent,
128 None,
129 None,
130 None,
131 None,
132 Some(&gru_secondary_bias),
133 &gru_descriptor,
134 Some("gru"),
135 )
136 .expect("gru");
137
138 let results = graph
139 .run(
140 &[],
141 &[
142 &gather,
143 &gather_nd,
144 &gather_axis,
145 &gather_axis_tensor,
146 &random,
147 &dropout,
148 &single_gate[0],
149 &lstm[0],
150 &lstm[1],
151 &gru[0],
152 &gru[1],
153 ],
154 )
155 .expect("run graph");
156
157 println!("gather: {:?}", results[0].read_f32().expect("gather"));
158 println!("gather_nd: {:?}", results[1].read_f32().expect("gather_nd"));
159 println!(
160 "gather_axis: {:?}",
161 results[2].read_f32().expect("gather_axis")
162 );
163 println!(
164 "gather_axis_tensor: {:?}",
165 results[3].read_f32().expect("gather_axis_tensor")
166 );
167 println!("random: {:?}", results[4].read_f32().expect("random"));
168 println!("dropout: {:?}", results[5].read_f32().expect("dropout"));
169 println!(
170 "single_gate: {:?}",
171 results[6].read_f32().expect("single_gate")
172 );
173 println!(
174 "lstm state: {:?}",
175 results[7].read_f32().expect("lstm state")
176 );
177 println!("lstm cell: {:?}", results[8].read_f32().expect("lstm cell"));
178 println!("gru state: {:?}", results[9].read_f32().expect("gru state"));
179 println!(
180 "gru training: {:?}",
181 results[10].read_f32().expect("gru training")
182 );
183}
impl TensorData
pub fn graph_device_type(&self) -> Option<u32>
Return the graph-device type that backs this tensor data.
Trait Implementations
Auto Trait Implementations
impl Freeze for TensorData
impl RefUnwindSafe for TensorData
impl Unpin for TensorData
impl UnsafeUnpin for TensorData
impl UnwindSafe for TensorData
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.