Skip to main content

CompilationDescriptor

Struct CompilationDescriptor 

Source
pub struct CompilationDescriptor { /* private fields */ }
Expand description

Safe, owning wrapper around an `MPSGraphCompilationDescriptor`; the underlying object is released when this value is dropped.

Implementations§

Source§

impl CompilationDescriptor

Source

pub fn new() -> Option<Self>

Examples found in repository?
examples/04_descriptor_compile.rs (line 17)
7fn main() {
8    let device = MetalDevice::system_default().expect("no Metal device available");
9    let graph = Graph::new().expect("graph");
10    let input = graph
11        .placeholder(Some(&[4]), data_type::FLOAT32, Some("input"))
12        .expect("placeholder");
13    let output = graph
14        .unary_arithmetic(UnaryArithmeticOp::Absolute, &input, Some("abs"))
15        .expect("absolute");
16
17    let descriptor = CompilationDescriptor::new().expect("compilation descriptor");
18    descriptor
19        .set_optimization_level(optimization::LEVEL1)
20        .expect("set optimization level");
21    descriptor
22        .set_wait_for_compilation_completion(true)
23        .expect("set wait");
24
25    let executable = graph
26        .compile_with_descriptor(
27            Some(&device),
28            &[FeedDescription::new(&input, &[4], data_type::FLOAT32)],
29            &[&output],
30            Some(&descriptor),
31        )
32        .expect("compile");
33    let input_type = ShapedType::new(Some(&[4]), data_type::FLOAT32).expect("shaped type");
34    let output_types = executable
35        .output_types(Some(&device), &[&input_type], Some(&descriptor))
36        .expect("output types");
37
38    println!("feed tensors: {}", executable.feed_tensors().len());
39    println!("target tensors: {}", executable.target_tensors().len());
40    println!("output type: {:?}", output_types[0].shape());
41}
More examples
Hide additional examples
examples/06_control_flow_call.rs (line 120)
17fn main() {
18    let device = MetalDevice::system_default().expect("no Metal device available");
19    let queue = device
20        .new_command_queue()
21        .expect("failed to create command queue");
22
23    let callee_graph = Graph::new().expect("callee graph");
24    let callee_input = callee_graph
25        .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26        .expect("callee placeholder");
27    let callee_output = callee_graph
28        .addition(&callee_input, &callee_input, Some("callee_double"))
29        .expect("callee output");
30    let callee_executable = callee_graph
31        .compile(
32            &device,
33            &[FeedDescription::new(
34                &callee_input,
35                &[2],
36                data_type::FLOAT32,
37            )],
38            &[&callee_output],
39        )
40        .expect("callee executable");
41
42    let graph = Graph::new().expect("graph");
43    let input = graph
44        .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45        .expect("input placeholder");
46    let predicate = graph
47        .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48        .expect("predicate placeholder");
49    let bias = graph
50        .constant_f32_slice(&[1.0, 1.0], &[2])
51        .expect("bias constant");
52
53    let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54    let call_results = graph
55        .call("double", &[&input], &[&output_type], Some("call"))
56        .expect("call op");
57    let if_results = graph
58        .if_then_else(
59            &predicate,
60            || vec![graph.addition(&input, &bias, None).expect("then add")],
61            || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62            Some("branch"),
63        )
64        .expect("if/then/else");
65
66    let call_operation = call_results[0].operation().expect("call operation");
67    let dependency = graph
68        .control_dependency(
69            &[&call_operation],
70            || {
71                vec![graph
72                    .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73                    .expect("identity")]
74            },
75            Some("dependency"),
76        )
77        .expect("control dependency");
78
79    let number_of_iterations = graph
80        .constant_scalar(4.0, data_type::INT32)
81        .expect("iteration count");
82    let zero = graph
83        .constant_scalar(0.0, data_type::INT32)
84        .expect("zero constant");
85    let one = graph
86        .constant_scalar(1.0, data_type::INT32)
87        .expect("one constant");
88    let limit = graph
89        .constant_scalar(3.0, data_type::INT32)
90        .expect("limit constant");
91
92    let for_results = graph
93        .for_loop_iterations(
94            &number_of_iterations,
95            &[&zero],
96            |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97            Some("for_loop"),
98        )
99        .expect("for loop");
100    let while_results = graph
101        .while_loop(
102            &[&zero],
103            |inputs| {
104                let condition = graph
105                    .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106                    .expect("while predicate");
107                let passthrough = graph
108                    .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109                    .expect("while passthrough");
110                WhileBeforeResult {
111                    predicate: condition,
112                    results: vec![passthrough],
113                }
114            },
115            |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116            Some("while_loop"),
117        )
118        .expect("while loop");
119
120    let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121    compile_descriptor
122        .set_callable("double", Some(&callee_executable))
123        .expect("set callable");
124    let executable = graph
125        .compile_with_descriptor(
126            Some(&device),
127            &[
128                FeedDescription::new(&input, &[2], data_type::FLOAT32),
129                FeedDescription::new(&predicate, &[], data_type::BOOL),
130            ],
131            &[
132                &call_results[0],
133                &if_results[0],
134                &dependency[0],
135                &for_results[0],
136                &while_results[0],
137            ],
138            Some(&compile_descriptor),
139        )
140        .expect("compile executable");
141
142    let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143    let predicate_data =
144        TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145    let results = executable
146        .run(&queue, &[&input_data, &predicate_data])
147        .expect("run executable");
148
149    println!(
150        "call output: {:?}",
151        results[0].read_f32().expect("call output")
152    );
153    println!("if output: {:?}", results[1].read_f32().expect("if output"));
154    println!(
155        "dependency output: {:?}",
156        results[2].read_f32().expect("dependency output")
157    );
158    println!("for output: {:?}", read_i32(&results[3]));
159    println!("while output: {:?}", read_i32(&results[4]));
160}
Source

pub fn disable_type_inference(&self) -> Result<()>

Source

pub fn optimization_level(&self) -> u64

Source

pub fn set_optimization_level(&self, value: u64) -> Result<()>

Examples found in repository?
examples/04_descriptor_compile.rs (line 19)
7fn main() {
8    let device = MetalDevice::system_default().expect("no Metal device available");
9    let graph = Graph::new().expect("graph");
10    let input = graph
11        .placeholder(Some(&[4]), data_type::FLOAT32, Some("input"))
12        .expect("placeholder");
13    let output = graph
14        .unary_arithmetic(UnaryArithmeticOp::Absolute, &input, Some("abs"))
15        .expect("absolute");
16
17    let descriptor = CompilationDescriptor::new().expect("compilation descriptor");
18    descriptor
19        .set_optimization_level(optimization::LEVEL1)
20        .expect("set optimization level");
21    descriptor
22        .set_wait_for_compilation_completion(true)
23        .expect("set wait");
24
25    let executable = graph
26        .compile_with_descriptor(
27            Some(&device),
28            &[FeedDescription::new(&input, &[4], data_type::FLOAT32)],
29            &[&output],
30            Some(&descriptor),
31        )
32        .expect("compile");
33    let input_type = ShapedType::new(Some(&[4]), data_type::FLOAT32).expect("shaped type");
34    let output_types = executable
35        .output_types(Some(&device), &[&input_type], Some(&descriptor))
36        .expect("output types");
37
38    println!("feed tensors: {}", executable.feed_tensors().len());
39    println!("target tensors: {}", executable.target_tensors().len());
40    println!("output type: {:?}", output_types[0].shape());
41}
Source

pub fn wait_for_compilation_completion(&self) -> bool

Source

pub fn set_wait_for_compilation_completion(&self, value: bool) -> Result<()>

Examples found in repository?
examples/04_descriptor_compile.rs (line 22)
7fn main() {
8    let device = MetalDevice::system_default().expect("no Metal device available");
9    let graph = Graph::new().expect("graph");
10    let input = graph
11        .placeholder(Some(&[4]), data_type::FLOAT32, Some("input"))
12        .expect("placeholder");
13    let output = graph
14        .unary_arithmetic(UnaryArithmeticOp::Absolute, &input, Some("abs"))
15        .expect("absolute");
16
17    let descriptor = CompilationDescriptor::new().expect("compilation descriptor");
18    descriptor
19        .set_optimization_level(optimization::LEVEL1)
20        .expect("set optimization level");
21    descriptor
22        .set_wait_for_compilation_completion(true)
23        .expect("set wait");
24
25    let executable = graph
26        .compile_with_descriptor(
27            Some(&device),
28            &[FeedDescription::new(&input, &[4], data_type::FLOAT32)],
29            &[&output],
30            Some(&descriptor),
31        )
32        .expect("compile");
33    let input_type = ShapedType::new(Some(&[4]), data_type::FLOAT32).expect("shaped type");
34    let output_types = executable
35        .output_types(Some(&device), &[&input_type], Some(&descriptor))
36        .expect("output types");
37
38    println!("feed tensors: {}", executable.feed_tensors().len());
39    println!("target tensors: {}", executable.target_tensors().len());
40    println!("output type: {:?}", output_types[0].shape());
41}
Source

pub fn optimization_profile(&self) -> u64

Source

pub fn set_optimization_profile(&self, value: u64) -> Result<()>

Source

pub fn reduced_precision_fast_math(&self) -> usize

Source

pub fn set_reduced_precision_fast_math(&self, value: usize) -> Result<()>

Source

pub fn set_callable(&self, symbol_name: &str, executable: Option<&Executable>) -> Result<()>

Examples found in repository?
examples/06_control_flow_call.rs (line 122)
17fn main() {
18    let device = MetalDevice::system_default().expect("no Metal device available");
19    let queue = device
20        .new_command_queue()
21        .expect("failed to create command queue");
22
23    let callee_graph = Graph::new().expect("callee graph");
24    let callee_input = callee_graph
25        .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
26        .expect("callee placeholder");
27    let callee_output = callee_graph
28        .addition(&callee_input, &callee_input, Some("callee_double"))
29        .expect("callee output");
30    let callee_executable = callee_graph
31        .compile(
32            &device,
33            &[FeedDescription::new(
34                &callee_input,
35                &[2],
36                data_type::FLOAT32,
37            )],
38            &[&callee_output],
39        )
40        .expect("callee executable");
41
42    let graph = Graph::new().expect("graph");
43    let input = graph
44        .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
45        .expect("input placeholder");
46    let predicate = graph
47        .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
48        .expect("predicate placeholder");
49    let bias = graph
50        .constant_f32_slice(&[1.0, 1.0], &[2])
51        .expect("bias constant");
52
53    let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
54    let call_results = graph
55        .call("double", &[&input], &[&output_type], Some("call"))
56        .expect("call op");
57    let if_results = graph
58        .if_then_else(
59            &predicate,
60            || vec![graph.addition(&input, &bias, None).expect("then add")],
61            || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
62            Some("branch"),
63        )
64        .expect("if/then/else");
65
66    let call_operation = call_results[0].operation().expect("call operation");
67    let dependency = graph
68        .control_dependency(
69            &[&call_operation],
70            || {
71                vec![graph
72                    .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
73                    .expect("identity")]
74            },
75            Some("dependency"),
76        )
77        .expect("control dependency");
78
79    let number_of_iterations = graph
80        .constant_scalar(4.0, data_type::INT32)
81        .expect("iteration count");
82    let zero = graph
83        .constant_scalar(0.0, data_type::INT32)
84        .expect("zero constant");
85    let one = graph
86        .constant_scalar(1.0, data_type::INT32)
87        .expect("one constant");
88    let limit = graph
89        .constant_scalar(3.0, data_type::INT32)
90        .expect("limit constant");
91
92    let for_results = graph
93        .for_loop_iterations(
94            &number_of_iterations,
95            &[&zero],
96            |_index, args| vec![graph.addition(&args[0], &one, None).expect("for-loop add")],
97            Some("for_loop"),
98        )
99        .expect("for loop");
100    let while_results = graph
101        .while_loop(
102            &[&zero],
103            |inputs| {
104                let condition = graph
105                    .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
106                    .expect("while predicate");
107                let passthrough = graph
108                    .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
109                    .expect("while passthrough");
110                WhileBeforeResult {
111                    predicate: condition,
112                    results: vec![passthrough],
113                }
114            },
115            |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
116            Some("while_loop"),
117        )
118        .expect("while loop");
119
120    let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
121    compile_descriptor
122        .set_callable("double", Some(&callee_executable))
123        .expect("set callable");
124    let executable = graph
125        .compile_with_descriptor(
126            Some(&device),
127            &[
128                FeedDescription::new(&input, &[2], data_type::FLOAT32),
129                FeedDescription::new(&predicate, &[], data_type::BOOL),
130            ],
131            &[
132                &call_results[0],
133                &if_results[0],
134                &dependency[0],
135                &for_results[0],
136                &while_results[0],
137            ],
138            Some(&compile_descriptor),
139        )
140        .expect("compile executable");
141
142    let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
143    let predicate_data =
144        TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL).expect("predicate data");
145    let results = executable
146        .run(&queue, &[&input_data, &predicate_data])
147        .expect("run executable");
148
149    println!(
150        "call output: {:?}",
151        results[0].read_f32().expect("call output")
152    );
153    println!("if output: {:?}", results[1].read_f32().expect("if output"));
154    println!(
155        "dependency output: {:?}",
156        results[2].read_f32().expect("dependency output")
157    );
158    println!("for output: {:?}", read_i32(&results[3]));
159    println!("while output: {:?}", read_i32(&results[4]));
160}

Trait Implementations§

Source§

impl Drop for CompilationDescriptor

Source§

fn drop(&mut self)

Executes the destructor for this type. Read more
Source§

fn pin_drop(self: Pin<&mut Self>)

🔬 This is a nightly-only experimental API (`pin_ergonomics`).
Executes the destructor for this type; unlike `Drop::drop`, it requires `self` to be pinned. Read more
Source§

impl Send for CompilationDescriptor

Source§

impl Sync for CompilationDescriptor

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.