1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
mod control_flow_dependency_block;
mod for_loop_body_block;
mod if_then_else_block;
mod while_after_block;
mod while_before_block;
pub use control_flow_dependency_block::ControlFlowDependencyBlock;
pub use for_loop_body_block::ForLoopBodyBlock;
pub use if_then_else_block::IfThenElseBlock;
pub use while_after_block::WhileAfterBlock;
pub use while_before_block::WhileBeforeBlock;
use crate::{Graph, Operation, Tensor};
use objc2::{extern_methods, msg_send, rc::Retained};
use objc2_foundation::{NSArray, NSString};
use std::{ops::Deref, ptr::NonNull};
/// MPSGraphControlFlowOps.
///
/// Control-flow builders on [`Graph`]: control dependencies, if/then/else, while loops and
/// for loops. Each method forwards to the matching `MPSGraph` Objective-C selector and copies
/// the returned `NSArray` of tensors into a boxed slice.
impl Graph {
    /// Adds a control-dependency scope to the graph.
    ///
    /// Every operation created inside `dependent_block` is marked with `operations` as its
    /// control dependencies; the tensors returned by `dependent_block` are forwarded as the
    /// result.
    ///
    /// # Arguments
    ///
    /// * `operations` - Operations marked as control dependencies for all ops created inside
    ///   `dependent_block`.
    /// * `dependent_block` - Closure provided by the caller that creates the dependent ops.
    /// * `name` - Optional name of the scope.
    ///
    /// # Returns
    ///
    /// The tensors returned by `dependent_block`, forwarded out of the scope.
    pub fn control_dependency(
        &self,
        operations: &[&Operation],
        dependent_block: ControlFlowDependencyBlock,
        name: Option<&str>,
    ) -> Box<[Retained<Tensor>]> {
        let operations_array = NSArray::from_slice(operations);
        // Kept in a local so the NSString clearly outlives the message send.
        let name = name.map(NSString::from_str);
        // SAFETY: the argument and return types are intended to match
        // `-[MPSGraph controlDependencyWithOperations:dependentBlock:name:]`
        // (operation array, block, nullable name -> tensor array).
        // NOTE(review): confirm against the MPSGraph headers if they change.
        let result: Retained<NSArray<Tensor>> = unsafe {
            msg_send![
                self,
                controlDependencyWithOperations: &*operations_array,
                dependentBlock: &*dependent_block,
                name: name.as_deref(),
            ]
        };
        result.to_vec().into_boxed_slice()
    }

    /// Adds an if-then-else operation to the graph.
    ///
    /// # Arguments
    ///
    /// * `predicate_tensor` - [`Tensor`] that must hold a single scalar value; it decides
    ///   between the then and else branches.
    /// * `then_block` - Operations in this block are executed if the predicate is true.
    /// * `else_block` - Operations in this block are executed if the predicate is false.
    /// * `name` - Optional name of the operation.
    ///
    /// # Returns
    ///
    /// If no error, the tensors returned by the user. If not empty, the user must define both
    /// `then_block` and `else_block`; both should return the same number of tensors and each
    /// corresponding tensor should have the same element type.
    pub fn if_then_else(
        &self,
        predicate_tensor: &Tensor,
        then_block: IfThenElseBlock,
        else_block: IfThenElseBlock,
        name: Option<&str>,
    ) -> Box<[Retained<Tensor>]> {
        let name = name.map(NSString::from_str);
        // SAFETY: the argument and return types are intended to match
        // `-[MPSGraph ifWithPredicateTensor:thenBlock:elseBlock:name:]`
        // (tensor, block, block, nullable name -> tensor array).
        let result: Retained<NSArray<Tensor>> = unsafe {
            msg_send![
                self,
                ifWithPredicateTensor: predicate_tensor,
                thenBlock: &*then_block,
                elseBlock: &*else_block,
                name: name.as_deref(),
            ]
        };
        result.to_vec().into_boxed_slice()
    }

    /// Adds a while loop operation.
    ///
    /// `before_block` is run first; it then either calls `after_block` with its results or
    /// returns those results from the loop.
    ///
    /// # Arguments
    ///
    /// * `initial_inputs` - Tensors passed to `before_block` on the first iteration.
    /// * `before_block` - Run first; calls `after_block` with its results, or returns the
    ///   results from the loop.
    /// * `after_block` - Executed after the condition evaluation.
    /// * `name` - Optional name of the operation.
    ///
    /// # Returns
    ///
    /// A valid [`Tensor`] slice with results returned from the condition block, depending on
    /// the predicate tensor.
    pub fn while_loop(
        &self,
        initial_inputs: &[&Tensor],
        before_block: WhileBeforeBlock,
        after_block: WhileAfterBlock,
        name: Option<&str>,
    ) -> Box<[Retained<Tensor>]> {
        let initial_inputs_array = NSArray::from_slice(initial_inputs);
        let name = name.map(NSString::from_str);
        // SAFETY: the argument and return types are intended to match
        // `-[MPSGraph whileWithInitialInputs:before:after:name:]`
        // (tensor array, block, block, nullable name -> tensor array).
        let result: Retained<NSArray<Tensor>> = unsafe {
            msg_send![
                self,
                whileWithInitialInputs: &*initial_inputs_array,
                before: &*before_block,
                after: &*after_block,
                name: name.as_deref(),
            ]
        };
        result.to_vec().into_boxed_slice()
    }

    /// Adds a for loop operation.
    ///
    /// The lower and upper bounds specify a half-open range: the range includes the lower
    /// bound but does not include the upper bound.
    ///
    /// # Arguments
    ///
    /// * `lower_bound` - Lower bound value of the loop. This is a scalar tensor and is the
    ///   index the loop will start with.
    /// * `upper_bound` - Upper bound value of the loop. This is a scalar tensor.
    /// * `step` - Step value of the loop. This is a scalar tensor and must be positive.
    /// * `initial_body_arguments` - Initial set of iteration arguments passed to the `body`
    ///   block of the for loop.
    /// * `body` - This block will execute the body of the for loop.
    /// * `name` - Optional name of the operation.
    ///
    /// # Returns
    ///
    /// A valid [`Tensor`] slice with the same count and corresponding element types as
    /// `initial_body_arguments` and the return types of the for loop.
    pub fn for_loop(
        &self,
        lower_bound: &Tensor,
        upper_bound: &Tensor,
        step: &Tensor,
        initial_body_arguments: &[&Tensor],
        body: ForLoopBodyBlock,
        name: Option<&str>,
    ) -> Box<[Retained<Tensor>]> {
        let initial_body_arguments_array = NSArray::from_slice(initial_body_arguments);
        let name = name.map(NSString::from_str);
        // SAFETY: the argument and return types are intended to match
        // `-[MPSGraph forLoopWithLowerBound:upperBound:step:initialBodyArguments:body:name:]`
        // (three tensors, tensor array, block, nullable name -> tensor array).
        let result: Retained<NSArray<Tensor>> = unsafe {
            msg_send![
                self,
                forLoopWithLowerBound: lower_bound,
                upperBound: upper_bound,
                step: step,
                initialBodyArguments: &*initial_body_arguments_array,
                body: &*body,
                name: name.as_deref(),
            ]
        };
        result.to_vec().into_boxed_slice()
    }

    /// Adds a for loop operation with a specific number of iterations.
    ///
    /// # Arguments
    ///
    /// * `number_of_iterations` - [`Tensor`] with the number of iterations the loop will
    ///   execute.
    /// * `initial_body_arguments` - Initial set of iteration arguments passed to the `body`
    ///   block of the for loop.
    /// * `body` - This block executes the body of the for loop; the index will go from 0 to
    ///   `number_of_iterations` − 1.
    /// * `name` - Optional name of the operation.
    ///
    /// # Returns
    ///
    /// A valid [`Tensor`] slice with the same count and corresponding element types as
    /// `initial_body_arguments` and the return types of the for loop.
    pub fn for_loop_with_number_of_iterations(
        &self,
        number_of_iterations: &Tensor,
        initial_body_arguments: &[&Tensor],
        body: ForLoopBodyBlock,
        name: Option<&str>,
    ) -> Box<[Retained<Tensor>]> {
        let initial_body_arguments_array = NSArray::from_slice(initial_body_arguments);
        let name = name.map(NSString::from_str);
        // SAFETY: the argument and return types are intended to match
        // `-[MPSGraph forLoopWithNumberOfIterations:initialBodyArguments:body:name:]`
        // (tensor, tensor array, block, nullable name -> tensor array).
        let result: Retained<NSArray<Tensor>> = unsafe {
            msg_send![
                self,
                forLoopWithNumberOfIterations: number_of_iterations,
                initialBodyArguments: &*initial_body_arguments_array,
                body: &*body,
                name: name.as_deref(),
            ]
        };
        result.to_vec().into_boxed_slice()
    }
}