pub struct ShapedType { /* private fields */ }
Owned wrapper for MPSGraphShapedType.
Implementations
impl ShapedType
pub fn as_graph_type(&self) -> GraphType
Returns this shaped type as a `GraphType` by calling the MPSGraph framework counterpart of `as_graph_type`.
impl ShapedType
pub fn new(shape: Option<&[isize]>, data_type: u32) -> Option<Self>
Creates a shaped type from an optional shape and a raw `MPSDataType` value. Returns `None` if the shaped type cannot be created.
Examples found in repository:
examples/04_descriptor_compile.rs (line 33)
7fn main() {
8 let device = MetalDevice::system_default().expect("no Metal device available");
9 let graph = Graph::new().expect("graph");
10 let input = graph
11 .placeholder(Some(&[4]), data_type::FLOAT32, Some("input"))
12 .expect("placeholder");
13 let output = graph
14 .unary_arithmetic(UnaryArithmeticOp::Absolute, &input, Some("abs"))
15 .expect("absolute");
16
17 let descriptor = CompilationDescriptor::new().expect("compilation descriptor");
18 descriptor
19 .set_optimization_level(optimization::LEVEL1)
20 .expect("set optimization level");
21 descriptor
22 .set_wait_for_compilation_completion(true)
23 .expect("set wait");
24
25 let executable = graph
26 .compile_with_descriptor(
27 Some(&device),
28 &[FeedDescription::new(&input, &[4], data_type::FLOAT32)],
29 &[&output],
30 Some(&descriptor),
31 )
32 .expect("compile");
33 let input_type = ShapedType::new(Some(&[4]), data_type::FLOAT32).expect("shaped type");
34 let output_types = executable
35 .output_types(Some(&device), &[&input_type], Some(&descriptor))
36 .expect("output types");
37
38 println!("feed tensors: {}", executable.feed_tensors().len());
39 println!("target tensors: {}", executable.target_tensors().len());
40 println!("output type: {:?}", output_types[0].shape());
41}
More examples:
examples/06_control_flow_call.rs (line 53)
17fn main() {
18 let device = MetalDevice::system_default().expect("no Metal device available");
19 let queue = device
20 .new_command_queue()
21 .expect("failed to create command queue");
22
23 let callee_graph = Graph::new().expect("callee graph");
24 let callee_input = callee_graph
25 .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26 .expect("callee placeholder");
27 let callee_output = callee_graph
28 .addition(&callee_input, &callee_input, Some("callee_double"))
29 .expect("callee output");
30 let callee_executable = callee_graph
31 .compile(
32 &device,
33 &[FeedDescription::new(
34 &callee_input,
35 &[2],
36 data_type::FLOAT32,
37 )],
38 &[&callee_output],
39 )
40 .expect("callee executable");
41
42 let graph = Graph::new().expect("graph");
43 let input = graph
44 .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45 .expect("input placeholder");
46 let predicate = graph
47 .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48 .expect("predicate placeholder");
49 let bias = graph
50 .constant_f32_slice(&[1.0, 1.0], &[2])
51 .expect("bias constant");
52
53 let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54 let call_results = graph
55 .call("double", &[&input], &[&output_type], Some("call"))
56 .expect("call op");
57 let if_results = graph
58 .if_then_else(
59 &predicate,
60 || vec![graph.addition(&input, &bias, None).expect("then add")],
61 || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62 Some("branch"),
63 )
64 .expect("if/then/else");
65
66 let call_operation = call_results[0].operation().expect("call operation");
67 let dependency = graph
68 .control_dependency(
69 &[&call_operation],
70 || {
71 vec![graph
72 .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73 .expect("identity")]
74 },
75 Some("dependency"),
76 )
77 .expect("control dependency");
78
79 let number_of_iterations = graph
80 .constant_scalar(4.0, data_type::INT32)
81 .expect("iteration count");
82 let zero = graph
83 .constant_scalar(0.0, data_type::INT32)
84 .expect("zero constant");
85 let one = graph
86 .constant_scalar(1.0, data_type::INT32)
87 .expect("one constant");
88 let limit = graph
89 .constant_scalar(3.0, data_type::INT32)
90 .expect("limit constant");
91
92 let for_results = graph
93 .for_loop_iterations(
94 &number_of_iterations,
95 &[&zero],
96 |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97 Some("for_loop"),
98 )
99 .expect("for loop");
100 let while_results = graph
101 .while_loop(
102 &[&zero],
103 |inputs| {
104 let condition = graph
105 .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106 .expect("while predicate");
107 let passthrough = graph
108 .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109 .expect("while passthrough");
110 WhileBeforeResult {
111 predicate: condition,
112 results: vec![passthrough],
113 }
114 },
115 |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116 Some("while_loop"),
117 )
118 .expect("while loop");
119
120 let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121 compile_descriptor
122 .set_callable("double", Some(&callee_executable))
123 .expect("set callable");
124 let executable = graph
125 .compile_with_descriptor(
126 Some(&device),
127 &[
128 FeedDescription::new(&input, &[2], data_type::FLOAT32),
129 FeedDescription::new(&predicate, &[], data_type::BOOL),
130 ],
131 &[
132 &call_results[0],
133 &if_results[0],
134 &dependency[0],
135 &for_results[0],
136 &while_results[0],
137 ],
138 Some(&compile_descriptor),
139 )
140 .expect("compile executable");
141
142 let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143 let predicate_data =
144 TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145 let results = executable
146 .run(&queue, &[&input_data, &predicate_data])
147 .expect("run executable");
148
149 println!(
150 "call output: {:?}",
151 results[0].read_f32().expect("call output")
152 );
153 println!("if output: {:?}", results[1].read_f32().expect("if output"));
154 println!(
155 "dependency output: {:?}",
156 results[2].read_f32().expect("dependency output")
157 );
158 println!("for output: {:?}", read_i32(&results[3]));
159 println!("while output: {:?}", read_i32(&results[4]));
160}

pub fn shape(&self) -> Option<Vec<isize>>
Returns the tensor shape, or `None` if the shape is unranked.
Examples found in repository:
examples/04_descriptor_compile.rs (line 40)
7fn main() {
8 let device = MetalDevice::system_default().expect("no Metal device available");
9 let graph = Graph::new().expect("graph");
10 let input = graph
11 .placeholder(Some(&[4]), data_type::FLOAT32, Some("input"))
12 .expect("placeholder");
13 let output = graph
14 .unary_arithmetic(UnaryArithmeticOp::Absolute, &input, Some("abs"))
15 .expect("absolute");
16
17 let descriptor = CompilationDescriptor::new().expect("compilation descriptor");
18 descriptor
19 .set_optimization_level(optimization::LEVEL1)
20 .expect("set optimization level");
21 descriptor
22 .set_wait_for_compilation_completion(true)
23 .expect("set wait");
24
25 let executable = graph
26 .compile_with_descriptor(
27 Some(&device),
28 &[FeedDescription::new(&input, &[4], data_type::FLOAT32)],
29 &[&output],
30 Some(&descriptor),
31 )
32 .expect("compile");
33 let input_type = ShapedType::new(Some(&[4]), data_type::FLOAT32).expect("shaped type");
34 let output_types = executable
35 .output_types(Some(&device), &[&input_type], Some(&descriptor))
36 .expect("output types");
37
38 println!("feed tensors: {}", executable.feed_tensors().len());
39 println!("target tensors: {}", executable.target_tensors().len());
40 println!("output type: {:?}", output_types[0].shape());
41}

pub fn set_shape(&self, shape: Option<&[isize]>) -> Result<()>
Replaces the shape metadata of this shaped type.
pub fn set_data_type(&self, data_type: u32) -> Result<()>
Replaces the data-type metadata of this shaped type, given a raw `MPSDataType` value.
Trait Implementations
Auto Trait Implementations
impl Freeze for ShapedType
impl RefUnwindSafe for ShapedType
impl Unpin for ShapedType
impl UnsafeUnpin for ShapedType
impl UnwindSafe for ShapedType
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.