pub struct ShapedType { /* private fields */ }
Owned wrapper around an `MPSGraphShapedType`, which pairs an optional tensor shape with an element data type.
Implementations
impl ShapedType
pub fn new(shape: Option<&[isize]>, data_type: u32) -> Option<Self>
Create a shaped type from an optional shape and a raw `MPSDataType` value. Passing `None` as the shape produces an unranked type. Returns `None` if construction fails.
Examples found in repository:
examples/04_descriptor_compile.rs (line 33)
7fn main() {
8 let device = MetalDevice::system_default().expect("no Metal device available");
9 let graph = Graph::new().expect("graph");
10 let input = graph
11 .placeholder(Some(&[4]), data_type::FLOAT32, Some("input"))
12 .expect("placeholder");
13 let output = graph
14 .unary_arithmetic(UnaryArithmeticOp::Absolute, &input, Some("abs"))
15 .expect("absolute");
16
17 let descriptor = CompilationDescriptor::new().expect("compilation descriptor");
18 descriptor
19 .set_optimization_level(optimization::LEVEL1)
20 .expect("set optimization level");
21 descriptor
22 .set_wait_for_compilation_completion(true)
23 .expect("set wait");
24
25 let executable = graph
26 .compile_with_descriptor(
27 Some(&device),
28 &[FeedDescription::new(&input, &[4], data_type::FLOAT32)],
29 &[&output],
30 Some(&descriptor),
31 )
32 .expect("compile");
33 let input_type = ShapedType::new(Some(&[4]), data_type::FLOAT32).expect("shaped type");
34 let output_types = executable
35 .output_types(Some(&device), &[&input_type], Some(&descriptor))
36 .expect("output types");
37
38 println!("feed tensors: {}", executable.feed_tensors().len());
39 println!("target tensors: {}", executable.target_tensors().len());
40 println!("output type: {:?}", output_types[0].shape());
41}
More examples
examples/06_control_flow_call.rs (line 46)
16fn main() {
17 let device = MetalDevice::system_default().expect("no Metal device available");
18 let queue = device
19 .new_command_queue()
20 .expect("failed to create command queue");
21
22 let callee_graph = Graph::new().expect("callee graph");
23 let callee_input = callee_graph
24 .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
25 .expect("callee placeholder");
26 let callee_output = callee_graph
27 .addition(&callee_input, &callee_input, Some("callee_double"))
28 .expect("callee output");
29 let callee_executable = callee_graph
30 .compile(
31 &device,
32 &[FeedDescription::new(&callee_input, &[2], data_type::FLOAT32)],
33 &[&callee_output],
34 )
35 .expect("callee executable");
36
37 let graph = Graph::new().expect("graph");
38 let input = graph
39 .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
40 .expect("input placeholder");
41 let predicate = graph
42 .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
43 .expect("predicate placeholder");
44 let bias = graph.constant_f32_slice(&[1.0, 1.0], &[2]).expect("bias constant");
45
46 let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
47 let call_results = graph
48 .call("double", &[&input], &[&output_type], Some("call"))
49 .expect("call op");
50 let if_results = graph
51 .if_then_else(
52 &predicate,
53 || vec![graph.addition(&input, &bias, None).expect("then add")],
54 || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
55 Some("branch"),
56 )
57 .expect("if/then/else");
58
59 let call_operation = call_results[0].operation().expect("call operation");
60 let dependency = graph
61 .control_dependency(&[&call_operation], || {
62 vec![graph
63 .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
64 .expect("identity")]
65 }, Some("dependency"))
66 .expect("control dependency");
67
68 let number_of_iterations = graph
69 .constant_scalar(4.0, data_type::INT32)
70 .expect("iteration count");
71 let zero = graph
72 .constant_scalar(0.0, data_type::INT32)
73 .expect("zero constant");
74 let one = graph.constant_scalar(1.0, data_type::INT32).expect("one constant");
75 let limit = graph
76 .constant_scalar(3.0, data_type::INT32)
77 .expect("limit constant");
78
79 let for_results = graph
80 .for_loop_iterations(&number_of_iterations, &[&zero], |_index, args| {
81 vec![graph.addition(&args[0], &one, None).expect("for-loop add")]
82 }, Some("for_loop"))
83 .expect("for loop");
84 let while_results = graph
85 .while_loop(
86 &[&zero],
87 |inputs| {
88 let condition = graph
89 .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
90 .expect("while predicate");
91 let passthrough = graph
92 .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
93 .expect("while passthrough");
94 WhileBeforeResult {
95 predicate: condition,
96 results: vec![passthrough],
97 }
98 },
99 |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
100 Some("while_loop"),
101 )
102 .expect("while loop");
103
104 let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
105 compile_descriptor
106 .set_callable("double", Some(&callee_executable))
107 .expect("set callable");
108 let executable = graph
109 .compile_with_descriptor(
110 Some(&device),
111 &[
112 FeedDescription::new(&input, &[2], data_type::FLOAT32),
113 FeedDescription::new(&predicate, &[], data_type::BOOL),
114 ],
115 &[
116 &call_results[0],
117 &if_results[0],
118 &dependency[0],
119 &for_results[0],
120 &while_results[0],
121 ],
122 Some(&compile_descriptor),
123 )
124 .expect("compile executable");
125
126 let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
127 let predicate_data = TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL)
128 .expect("predicate data");
129 let results = executable
130 .run(&queue, &[&input_data, &predicate_data])
131 .expect("run executable");
132
133 println!("call output: {:?}", results[0].read_f32().expect("call output"));
134 println!("if output: {:?}", results[1].read_f32().expect("if output"));
135 println!("dependency output: {:?}", results[2].read_f32().expect("dependency output"));
136 println!("for output: {:?}", read_i32(&results[3]));
137 println!("while output: {:?}", read_i32(&results[4]));
138}
pub fn shape(&self) -> Option<Vec<isize>>
Return the tensor shape, or `None` if the shape is unranked.
Examples found in repository:
examples/04_descriptor_compile.rs (line 40)
7fn main() {
8 let device = MetalDevice::system_default().expect("no Metal device available");
9 let graph = Graph::new().expect("graph");
10 let input = graph
11 .placeholder(Some(&[4]), data_type::FLOAT32, Some("input"))
12 .expect("placeholder");
13 let output = graph
14 .unary_arithmetic(UnaryArithmeticOp::Absolute, &input, Some("abs"))
15 .expect("absolute");
16
17 let descriptor = CompilationDescriptor::new().expect("compilation descriptor");
18 descriptor
19 .set_optimization_level(optimization::LEVEL1)
20 .expect("set optimization level");
21 descriptor
22 .set_wait_for_compilation_completion(true)
23 .expect("set wait");
24
25 let executable = graph
26 .compile_with_descriptor(
27 Some(&device),
28 &[FeedDescription::new(&input, &[4], data_type::FLOAT32)],
29 &[&output],
30 Some(&descriptor),
31 )
32 .expect("compile");
33 let input_type = ShapedType::new(Some(&[4]), data_type::FLOAT32).expect("shaped type");
34 let output_types = executable
35 .output_types(Some(&device), &[&input_type], Some(&descriptor))
36 .expect("output types");
37
38 println!("feed tensors: {}", executable.feed_tensors().len());
39 println!("target tensors: {}", executable.target_tensors().len());
40 println!("output type: {:?}", output_types[0].shape());
41}
pub fn set_shape(&self, shape: Option<&[isize]>) -> Result<()>
Replace the shape metadata for this shaped type. Passing `None` makes the shape unranked.
pub fn set_data_type(&self, data_type: u32) -> Result<()>
Replace the element data type for this shaped type, given as a raw `MPSDataType` value.
Trait Implementations
Auto Trait Implementations
impl Freeze for ShapedType
impl RefUnwindSafe for ShapedType
impl Unpin for ShapedType
impl UnsafeUnpin for ShapedType
impl UnwindSafe for ShapedType
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.