pub struct Tensor { /* private fields */ }
Implementations§
Source§impl Tensor
impl Tensor
Source pub fn operation(&self) -> Option<Operation>
pub fn operation(&self) -> Option<Operation>
Return the operation that produced this tensor.
Examples found in repository?
examples/06_control_flow_call.rs (line 59)
// Example entry point: builds a graph that exercises a call op, an if/then/else,
// a control dependency, a for loop and a while loop, compiles it, runs it once,
// and prints the five outputs. Every fallible step panics via `expect`.
16fn main() {
// Acquire the system default Metal device and a command queue on it.
17 let device = MetalDevice::system_default().expect("no Metal device available");
18 let queue = device
19 .new_command_queue()
20 .expect("failed to create command queue");
21
// Callee graph: a 2-element FLOAT32 placeholder added to itself, compiled into
// its own executable so it can later be bound to a callable name.
22 let callee_graph = Graph::new().expect("callee graph");
23 let callee_input = callee_graph
24 .placeholder(Some(&[2]), data_type::FLOAT32, Some("callee_input"))
25 .expect("callee placeholder");
26 let callee_output = callee_graph
27 .addition(&callee_input, &callee_input, Some("callee_double"))
28 .expect("callee output");
29 let callee_executable = callee_graph
30 .compile(
31 &device,
32 &[FeedDescription::new(&callee_input, &[2], data_type::FLOAT32)],
33 &[&callee_output],
34 )
35 .expect("callee executable");
36
// Main graph inputs: a 2-element FLOAT32 tensor, a BOOL predicate with an empty
// shape, and a constant bias of [1.0, 1.0].
37 let graph = Graph::new().expect("graph");
38 let input = graph
39 .placeholder(Some(&[2]), data_type::FLOAT32, Some("input"))
40 .expect("input placeholder");
41 let predicate = graph
42 .placeholder(Some(&[]), data_type::BOOL, Some("predicate"))
43 .expect("predicate placeholder");
44 let bias = graph.constant_f32_slice(&[1.0, 1.0], &[2]).expect("bias constant");
45
// Call op targeting the symbol "double"; the same name is bound to the callee
// executable further down via `set_callable`.
46 let output_type = ShapedType::new(Some(&[2]), data_type::FLOAT32).expect("output type");
47 let call_results = graph
48 .call("double", &[&input], &[&output_type], Some("call"))
49 .expect("call op");
// Conditional on `predicate`: the first closure adds the bias, the second subtracts it.
50 let if_results = graph
51 .if_then_else(
52 &predicate,
53 || vec![graph.addition(&input, &bias, None).expect("then add")],
54 || vec![graph.subtraction(&input, &bias, None).expect("else sub")],
55 Some("branch"),
56 )
57 .expect("if/then/else");
58
// `Tensor::operation()` returns an `Option`, so the producing operation of the
// call result is unwrapped here and then passed as the control dependency for
// an identity op over that same call result.
59 let call_operation = call_results[0].operation().expect("call operation");
60 let dependency = graph
61 .control_dependency(&[&call_operation], || {
62 vec![graph
63 .unary_arithmetic(UnaryArithmeticOp::Identity, &call_results[0], None)
64 .expect("identity")]
65 }, Some("dependency"))
66 .expect("control dependency");
67
// Scalar constants typed INT32 (written as 4.0, 0.0, 1.0 and 3.0) used by the
// two loops below.
68 let number_of_iterations = graph
69 .constant_scalar(4.0, data_type::INT32)
70 .expect("iteration count");
71 let zero = graph
72 .constant_scalar(0.0, data_type::INT32)
73 .expect("zero constant");
74 let one = graph.constant_scalar(1.0, data_type::INT32).expect("one constant");
75 let limit = graph
76 .constant_scalar(3.0, data_type::INT32)
77 .expect("limit constant");
78
// For loop: `zero` is the initial argument and the body adds `one` to args[0];
// the loop index passed to the closure is ignored.
79 let for_results = graph
80 .for_loop_iterations(&number_of_iterations, &[&zero], |_index, args| {
81 vec![graph.addition(&args[0], &one, None).expect("for-loop add")]
82 }, Some("for_loop"))
83 .expect("for loop");
// While loop starting from `zero`: the first closure produces the predicate
// (inputs[0] < limit) plus an identity passthrough of inputs[0]; the second
// closure adds `one` to inputs[0].
84 let while_results = graph
85 .while_loop(
86 &[&zero],
87 |inputs| {
88 let condition = graph
89 .binary_arithmetic(BinaryArithmeticOp::LessThan, &inputs[0], &limit, None)
90 .expect("while predicate");
91 let passthrough = graph
92 .unary_arithmetic(UnaryArithmeticOp::Identity, &inputs[0], None)
93 .expect("while passthrough");
94 WhileBeforeResult {
95 predicate: condition,
96 results: vec![passthrough],
97 }
98 },
99 |inputs| vec![graph.addition(&inputs[0], &one, None).expect("while add")],
100 Some("while_loop"),
101 )
102 .expect("while loop");
103
// Bind the callee executable to the name "double", matching the symbol used by
// the call op above.
104 let compile_descriptor = CompilationDescriptor::new().expect("compile descriptor");
105 compile_descriptor
106 .set_callable("double", Some(&callee_executable))
107 .expect("set callable");
// Compile with two feeds (input, predicate) and five targets. The prints at the
// end index `results` in this same target order.
108 let executable = graph
109 .compile_with_descriptor(
110 Some(&device),
111 &[
112 FeedDescription::new(&input, &[2], data_type::FLOAT32),
113 FeedDescription::new(&predicate, &[], data_type::BOOL),
114 ],
115 &[
116 &call_results[0],
117 &if_results[0],
118 &dependency[0],
119 &for_results[0],
120 &while_results[0],
121 ],
122 Some(&compile_descriptor),
123 )
124 .expect("compile executable");
125
// Feed data, in the same order as the feeds above: input = [3.0, 4.0] and a
// predicate built from the single byte 1 typed as BOOL.
126 let input_data = TensorData::from_f32_slice(&device, &[3.0, 4.0], &[2]).expect("input data");
127 let predicate_data = TensorData::from_bytes(&device, &[1_u8], &[], data_type::BOOL)
128 .expect("predicate data");
129 let results = executable
130 .run(&queue, &[&input_data, &predicate_data])
131 .expect("run executable");
132
// The first three outputs are read as f32; the two loop outputs go through
// `read_i32`, which is not defined in the visible part of this listing (the
// listing starts at example line 16).
133 println!("call output: {:?}", results[0].read_f32().expect("call output"));
134 println!("if output: {:?}", results[1].read_f32().expect("if output"));
135 println!("dependency output: {:?}", results[2].read_f32().expect("dependency output"));
136 println!("for output: {:?}", read_i32(&results[3]));
137 println!("while output: {:?}", read_i32(&results[4]));
138}Trait Implementations§
Auto Trait Implementations§
impl Freeze for Tensor
impl RefUnwindSafe for Tensor
impl Unpin for Tensor
impl UnsafeUnpin for Tensor
impl UnwindSafe for Tensor
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more