pub struct FeedDescription<'a> {
pub tensor: &'a Tensor,
pub shape: &'a [usize],
pub data_type: u32,
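}

Feed metadata used to compile a graph into an executable. Each entry pairs a graph tensor (in the repository examples, a placeholder) with the shape and data type the executable will be compiled to accept for it.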
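Fields
tensor: &'a Tensor — the graph tensor being fed; the repository examples pass a placeholder created with `Graph::placeholder`.
shape: &'a [usize] — the shape of the data fed into `tensor`, e.g. `&[2, 3]` for a matrix or `&[]` for a scalar.
data_type: u32 — a `data_type` constant such as `data_type::FLOAT32` or `data_type::BOOL`.
Implementations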
impl<'a> FeedDescription<'a>
pub const fn new(tensor: &'a Tensor, shape: &'a [usize], data_type: u32) -> Self
Builds a `FeedDescription` from a tensor, the shape of the data fed into it, and its data type.
Examples found in repository?
examples/04_descriptor_compile.rs (line 28)
7fn main() {
8 let device = MetalDevice::system_default().expect("no Metal device available");
9 let graph = Graph::new().expect("graph");
10 let input = graph
11 .placeholder(Some(&[4]), data_type::FLOAT32, Some("input"))
12 .expect("placeholder");
13 let output = graph
14 .unary_arithmetic(UnaryArithmeticOp::Absolute, &input, Some("abs"))
15 .expect("absolute");
16
17 let descriptor = CompilationDescriptor::new().expect("compilation descriptor");
18 descriptor
19 .set_optimization_level(optimization::LEVEL1)
20 .expect("set optimization level");
21 descriptor
22 .set_wait_for_compilation_completion(true)
23 .expect("set wait");
24
25 let executable = graph
26 .compile_with_descriptor(
27 Some(&device),
28 &[FeedDescription::new(&input, &[4], data_type::FLOAT32)],
29 &[&output],
30 Some(&descriptor),
31 )
32 .expect("compile");
33 let input_type = ShapedType::new(Some(&[4]), data_type::FLOAT32).expect("shaped type");
34 let output_types = executable
35 .output_types(Some(&device), &[&input_type], Some(&descriptor))
36 .expect("output types");
37
38 println!("feed tensors: {}", executable.feed_tensors().len());
39 println!("target tensors: {}", executable.target_tensors().len());
40 println!("output type: {:?}", output_types[0].shape());
41}
More examples
examples/02_compile_matmul.rs (line 25)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let queue = device
7 .new_command_queue()
8 .expect("failed to create command queue");
9 let graph = Graph::new().expect("failed to create MPSGraph");
10
11 let left = graph
12 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("left"))
13 .expect("failed to create left placeholder");
14 let right = graph
15 .placeholder(Some(&[3, 2]), data_type::FLOAT32, Some("right"))
16 .expect("failed to create right placeholder");
17 let output = graph
18 .matrix_multiplication(&left, &right, Some("matmul"))
19 .expect("failed to create matrix multiplication op");
20
21 let executable = graph
22 .compile(
23 &device,
24 &[
25 FeedDescription::new(&left, &[2, 3], data_type::FLOAT32),
26 FeedDescription::new(&right, &[3, 2], data_type::FLOAT32),
27 ],
28 &[&output],
29 )
30 .expect("failed to compile executable");
31
32 let left_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
33 .expect("failed to create left tensor data");
34 let right_data =
35 TensorData::from_f32_slice(&device, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2])
36 .expect("failed to create right tensor data");
37
38 let results = executable
39 .run(&queue, &[&left_data, &right_data])
40 .expect("failed to run executable");
41 let values = results[0].read_f32().expect("failed to read tensor output");
42 let expected = [58.0_f32, 64.0, 139.0, 154.0];
43 for (actual, expected_value) in values.iter().zip(expected) {
44 assert!(
45 (actual - expected_value).abs() < 1.0e-4,
46 "unexpected matrix multiply result: {values:?}"
47 );
48 }
49
50 println!("compile+matmul smoke passed: {values:?}");
51}
examples/06_control_flow_call.rs (line 32)
16fn main() {
17 let device = MetalDevice::system_default().expect("no Metal device available");
18 let queue = device
19 .new_command_queue()
20 .expect("failed to create command queue");
21
22 let callee_graph = Graph::new().expect("callee graph");
23 let callee_input = callee_graph
24 .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
25 .expect("callee placeholder");
26 let callee_output = callee_graph
27 .addition(&callee_input, &callee_input, Some("callee_double"))
28 .expect("callee output");
29 let callee_executable = callee_graph
30 .compile(
31 &device,
32 &[FeedDescription::new(&callee_input, &[2], data_type::FLOAT32)],
33 &[&callee_output],
34 )
35 .expect("callee executable");
36
37 let graph = Graph::new().expect("graph");
38 let input = graph
39 .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
40 .expect("input placeholder");
41 let predicate = graph
42 .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
43 .expect("predicate placeholder");
44 let bias = graph.constant_f32_slice(&[1.0, 1.0], &[2]).expect("bias constant");
45
46 let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
47 let call_results = graph
48 .call("double", &[&input], &[&output_type], Some("call"))
49 .expect("call op");
50 let if_results = graph
51 .if_then_else(
52 &predicate,
53 || vec![graph.addition(&input, &bias, None).expect("then add")],
54 || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
55 Some("branch"),
56 )
57 .expect("if/then/else");
58
59 let call_operation = call_results[0].operation().expect("call operation");
60 let dependency = graph
61 .control_dependency(&[&call_operation], || {
62 vec![graph
63 .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
64 .expect("identity")]
65 }, Some("dependency"))
66 .expect("control dependency");
67
68 let number_of_iterations = graph
69 .constant_scalar(4.0, data_type::INT32)
70 .expect("iteration count");
71 let zero = graph
72 .constant_scalar(0.0, data_type::INT32)
73 .expect("zero constant");
74 let one = graph.constant_scalar(1.0, data_type::INT32).expect("one constant");
75 let limit = graph
76 .constant_scalar(3.0, data_type::INT32)
77 .expect("limit constant");
78
79 let for_results = graph
80 .for_loop_iterations(&number_of_iterations, &[&zero], |_index, args| {
81 vec![graph.addition(&args[0], &one, None).expect("for-loop add")]
82 }, Some("for_loop"))
83 .expect("for loop");
84 let while_results = graph
85 .while_loop(
86 &[&zero],
87 |inputs| {
88 let condition = graph
89 .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
90 .expect("while predicate");
91 let passthrough = graph
92 .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
93 .expect("while passthrough");
94 WhileBeforeResult {
95 predicate: condition,
96 results: vec![passthrough],
97 }
98 },
99 |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
100 Some("while_loop"),
101 )
102 .expect("while loop");
103
104 let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
105 compile_descriptor
106 .set_callable("double", Some(&callee_executable))
107 .expect("set callable");
108 let executable = graph
109 .compile_with_descriptor(
110 Some(&device),
111 &[
112 FeedDescription::new(&input, &[2], data_type::FLOAT32),
113 FeedDescription::new(&predicate, &[], data_type::BOOL),
114 ],
115 &[
116 &call_results[0],
117 &if_results[0],
118 &dependency[0],
119 &for_results[0],
120 &while_results[0],
121 ],
122 Some(&compile_descriptor),
123 )
124 .expect("compile executable");
125
126 let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
127 let predicate_data = TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL)
128 .expect("predicate data");
129 let results = executable
130 .run(&queue, &[&input_data, &predicate_data])
131 .expect("run executable");
132
133 println!("call output: {:?}", results[0].read_f32().expect("call output"));
134 println!("if output: {:?}", results[1].read_f32().expect("if output"));
135 println!("dependency output: {:?}", results[2].read_f32().expect("dependency output"));
136 println!("for output: {:?}", read_i32(&results[3]));
137 println!("while output: {:?}", read_i32(&results[4]));
138}
Trait Implementations
impl<'a> Clone for FeedDescription<'a>
fn clone(&self) -> FeedDescription<'a>
Returns a duplicate of the value. Read more
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`. Read more (stable since 1.0.0; const: unstable)
impl<'a> Copy for FeedDescription<'a>
Auto Trait Implementations
impl<'a> Freeze for FeedDescription<'a>
impl<'a> RefUnwindSafe for FeedDescription<'a>
impl<'a> Send for FeedDescription<'a>
impl<'a> Sync for FeedDescription<'a>
impl<'a> Unpin for FeedDescription<'a>
impl<'a> UnsafeUnpin for FeedDescription<'a>
impl<'a> UnwindSafe for FeedDescription<'a>
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more