pub struct FeedDescription<'a> {
pub tensor: &'a Tensor,
pub shape: &'a [usize],
pub data_type: u32,
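}
Expand description
Feed metadata used when compiling a graph into an executable. Each value names one fed tensor together with the concrete shape and data type the graph is compiled against; a slice of these is passed to `Graph::compile` / `Graph::compile_with_descriptor`.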
Fields§
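tensor: &'a Tensor
The graph tensor being fed (in the repository examples, a placeholder created with `Graph::placeholder`). Mirrors the corresponding MPSGraph feed property.
shape: &'a [usize]
The concrete shape the fed tensor is compiled against. An empty slice describes a scalar, as with the boolean predicate in `examples/06_control_flow_call.rs`.
data_type: u32
The data-type code of the fed tensor, for example `data_type::FLOAT32` or `data_type::BOOL`.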
Implementations§
Source§impl<'a> FeedDescription<'a>
impl<'a> FeedDescription<'a>
Sourcepub const fn new(tensor: &'a Tensor, shape: &'a [usize], data_type: u32) -> Self
pub const fn new(tensor: &'a Tensor, shape: &'a [usize], data_type: u32) -> Self
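Creates a feed description that pairs `tensor` with the `shape` and `data_type` it should be compiled against. Because it is a `const fn`, it can also be evaluated in constant contexts.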
Examples found in repository?
examples/04_descriptor_compile.rs (line 28)
7fn main() {
8 let device = MetalDevice::system_default().expect("no Metal device available");
9 let graph = Graph::new().expect("graph");
10 let input = graph
11 .placeholder(Some(&[4]), data_type::FLOAT32, Some("input"))
12 .expect("placeholder");
13 let output = graph
14 .unary_arithmetic(UnaryArithmeticOp::Absolute, &input, Some("abs"))
15 .expect("absolute");
16
17 let descriptor = CompilationDescriptor::new().expect("compilation descriptor");
18 descriptor
19 .set_optimization_level(optimization::LEVEL1)
20 .expect("set optimization level");
21 descriptor
22 .set_wait_for_compilation_completion(true)
23 .expect("set wait");
24
25 let executable = graph
26 .compile_with_descriptor(
27 Some(&device),
28 &[FeedDescription::new(&input, &[4], data_type::FLOAT32)],
29 &[&output],
30 Some(&descriptor),
31 )
32 .expect("compile");
33 let input_type = ShapedType::new(Some(&[4]), data_type::FLOAT32).expect("shaped type");
34 let output_types = executable
35 .output_types(Some(&device), &[&input_type], Some(&descriptor))
36 .expect("output types");
37
38 println!("feed tensors: {}", executable.feed_tensors().len());
39 println!("target tensors: {}", executable.target_tensors().len());
40 println!("output type: {:?}", output_types[0].shape());
41}
More examples
examples/02_compile_matmul.rs (line 25)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let queue = device
7 .new_command_queue()
8 .expect("failed to create command queue");
9 let graph = Graph::new().expect("failed to create MPSGraph");
10
11 let left = graph
12 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("left"))
13 .expect("failed to create left placeholder");
14 let right = graph
15 .placeholder(Some(&[3, 2]), data_type::FLOAT32, Some("right"))
16 .expect("failed to create right placeholder");
17 let output = graph
18 .matrix_multiplication(&left, &right, Some("matmul"))
19 .expect("failed to create matrix multiplication op");
20
21 let executable = graph
22 .compile(
23 &device,
24 &[
25 FeedDescription::new(&left, &[2, 3], data_type::FLOAT32),
26 FeedDescription::new(&right, &[3, 2], data_type::FLOAT32),
27 ],
28 &[&output],
29 )
30 .expect("failed to compile executable");
31
32 let left_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
33 .expect("failed to create left tensor data");
34 let right_data =
35 TensorData::from_f32_slice(&device, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2])
36 .expect("failed to create right tensor data");
37
38 let results = executable
39 .run(&queue, &[&left_data, &right_data])
40 .expect("failed to run executable");
41 let values = results[0].read_f32().expect("failed to read tensor output");
42 let expected = [58.0_f32, 64.0, 139.0, 154.0];
43 for (actual, expected_value) in values.iter().zip(expected) {
44 assert!(
45 (actual - expected_value).abs() < 1.0e-4,
46 "unexpected matrix multiply result: {values:?}"
47 );
48 }
49
50 println!("compile+matmul smoke passed: {values:?}");
51}
examples/06_control_flow_call.rs (lines 33-37)
17fn main() {
18 let device = MetalDevice::system_default().expect("no Metal device available");
19 let queue = device
20 .new_command_queue()
21 .expect("failed to create command queue");
22
23 let callee_graph = Graph::new().expect("callee graph");
24 let callee_input = callee_graph
25 .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26 .expect("callee placeholder");
27 let callee_output = callee_graph
28 .addition(&callee_input, &callee_input, Some("callee_double"))
29 .expect("callee output");
30 let callee_executable = callee_graph
31 .compile(
32 &device,
33 &[FeedDescription::new(
34 &callee_input,
35 &[2],
36 data_type::FLOAT32,
37 )],
38 &[&callee_output],
39 )
40 .expect("callee executable");
41
42 let graph = Graph::new().expect("graph");
43 let input = graph
44 .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45 .expect("input placeholder");
46 let predicate = graph
47 .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48 .expect("predicate placeholder");
49 let bias = graph
50 .constant_f32_slice(&[1.0, 1.0], &[2])
51 .expect("bias constant");
52
53 let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54 let call_results = graph
55 .call("double", &[&input], &[&output_type], Some("call"))
56 .expect("call op");
57 let if_results = graph
58 .if_then_else(
59 &predicate,
60 || vec![graph.addition(&input, &bias, None).expect("then add")],
61 || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62 Some("branch"),
63 )
64 .expect("if/then/else");
65
66 let call_operation = call_results[0].operation().expect("call operation");
67 let dependency = graph
68 .control_dependency(
69 &[&call_operation],
70 || {
71 vec![graph
72 .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73 .expect("identity")]
74 },
75 Some("dependency"),
76 )
77 .expect("control dependency");
78
79 let number_of_iterations = graph
80 .constant_scalar(4.0, data_type::INT32)
81 .expect("iteration count");
82 let zero = graph
83 .constant_scalar(0.0, data_type::INT32)
84 .expect("zero constant");
85 let one = graph
86 .constant_scalar(1.0, data_type::INT32)
87 .expect("one constant");
88 let limit = graph
89 .constant_scalar(3.0, data_type::INT32)
90 .expect("limit constant");
91
92 let for_results = graph
93 .for_loop_iterations(
94 &number_of_iterations,
95 &[&zero],
96 |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97 Some("for_loop"),
98 )
99 .expect("for loop");
100 let while_results = graph
101 .while_loop(
102 &[&zero],
103 |inputs| {
104 let condition = graph
105 .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106 .expect("while predicate");
107 let passthrough = graph
108 .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109 .expect("while passthrough");
110 WhileBeforeResult {
111 predicate: condition,
112 results: vec![passthrough],
113 }
114 },
115 |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116 Some("while_loop"),
117 )
118 .expect("while loop");
119
120 let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121 compile_descriptor
122 .set_callable("double", Some(&callee_executable))
123 .expect("set callable");
124 let executable = graph
125 .compile_with_descriptor(
126 Some(&device),
127 &[
128 FeedDescription::new(&input, &[2], data_type::FLOAT32),
129 FeedDescription::new(&predicate, &[], data_type::BOOL),
130 ],
131 &[
132 &call_results[0],
133 &if_results[0],
134 &dependency[0],
135 &for_results[0],
136 &while_results[0],
137 ],
138 Some(&compile_descriptor),
139 )
140 .expect("compile executable");
141
142 let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143 let predicate_data =
144 TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145 let results = executable
146 .run(&queue, &[&input_data, &predicate_data])
147 .expect("run executable");
148
149 println!(
150 "call output: {:?}",
151 results[0].read_f32().expect("call output")
152 );
153 println!("if output: {:?}", results[1].read_f32().expect("if output"));
154 println!(
155 "dependency output: {:?}",
156 results[2].read_f32().expect("dependency output")
157 );
158 println!("for output: {:?}", read_i32(&results[3]));
159 println!("while output: {:?}", read_i32(&results[4]));
160}
Trait Implementations§
Source§impl<'a> Clone for FeedDescription<'a>
impl<'a> Clone for FeedDescription<'a>
Source§fn clone(&self) -> FeedDescription<'a>
fn clone(&self) -> FeedDescription<'a>
Returns a duplicate of the value. Read more
1.0.0 (const: unstable) · Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`. Read more
impl<'a> Copy for FeedDescription<'a>
Auto Trait Implementations§
impl<'a> Freeze for FeedDescription<'a>
impl<'a> RefUnwindSafe for FeedDescription<'a>
impl<'a> Send for FeedDescription<'a>
impl<'a> Sync for FeedDescription<'a>
impl<'a> Unpin for FeedDescription<'a>
impl<'a> UnsafeUnpin for FeedDescription<'a>
impl<'a> UnwindSafe for FeedDescription<'a>
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more