pub struct Feed<'a> {
pub tensor: &'a Tensor,
pub data: &'a TensorData,
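}
Ordered placeholder feed pairing used for graph execution. Each `Feed` pairs a placeholder `tensor` with the `data` to supply for it. A slice of feeds is passed to `Graph::run` to provide the graph's inputs (see the examples below).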
Fields
tensor: &'a Tensor
The placeholder tensor being fed. Mirrors the MPSGraph framework property for tensor.
data: &'a TensorData
The data supplied to `tensor` when the graph runs. Mirrors the MPSGraph framework property for data.
Implementations
impl<'a> Feed<'a>
pub const fn new(tensor: &'a Tensor, data: &'a TensorData) -> Self
Creates a `Feed` that pairs the placeholder `tensor` with the `data` to supply for it. Mirrors the MPSGraph framework constant fn.
Examples found in repository:
examples/05_concat_split.rs (line 21)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("graph");
7 let input = graph
8 .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
9 .expect("placeholder");
10 let concat = graph
11 .concat_pair(&input, &input, 1, Some("concat"))
12 .expect("concat");
13 let split = graph.split_num(&concat, 2, 1, Some("split"));
14 let stacked = graph
15 .stack(&[&split[0], &split[1]], 0, Some("stack"))
16 .expect("stack");
17
18 let input_data =
19 TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor data");
20 let results = graph
21 .run(&[Feed::new(&input, &input_data)], &[&stacked])
22 .expect("run");
23
24 println!(
25 "stacked tensor bytes: {}",
26 results[0].byte_len().expect("byte len")
27 );
28}
More examples
examples/03_arithmetic_topk.rs (line 21)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("graph");
7 let input = graph
8 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("input"))
9 .expect("placeholder");
10 let squared = graph
11 .unary_arithmetic(UnaryArithmeticOp::Square, &input, Some("square"))
12 .expect("square");
13 let row_sum = graph
14 .reduce_axes(ReductionAxesOp::Sum, &squared, &[1], Some("row_sum"))
15 .expect("reduce");
16 let topk = graph.top_k(&input, 2, Some("topk")).expect("topk");
17
18 let input_data = TensorData::from_f32_slice(&device, &[1.0, 3.0, 2.0, 4.0, 6.0, 5.0], &[2, 3])
19 .expect("tensor data");
20 let results = graph
21 .run(&[Feed::new(&input, &input_data)], &[&row_sum, &topk.0])
22 .expect("run");
23
24 println!("row sums: {:?}", results[0].read_f32().expect("row sums"));
25 println!(
26 "top-k values: {:?}",
27 results[1].read_f32().expect("topk values")
28 );
29}
examples/01_add_relu.rs (line 24)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("failed to create MPSGraph");
7
8 let input = graph
9 .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
10 .expect("failed to create placeholder");
11 let bias = graph
12 .constant_scalar(1.0, data_type::FLOAT32)
13 .expect("failed to create scalar constant");
14 let added = graph
15 .addition(&input, &bias, Some("add"))
16 .expect("failed to create addition op");
17 let output = graph
18 .relu(&added, Some("relu"))
19 .expect("failed to create relu op");
20
21 let input_data = TensorData::from_f32_slice(&device, &[1.0, -2.0, 3.0, -4.0], &[2, 2])
22 .expect("failed to create tensor data");
23 let results = graph
24 .run(&[Feed::new(&input, &input_data)], &[&output])
25 .expect("failed to execute graph");
26 let values = results[0].read_f32().expect("failed to read tensor output");
27
28 assert_eq!(values, vec![2.0, 0.0, 4.0, 0.0]);
29 println!("add+relu smoke passed: {values:?}");
30}
Trait Implementations
Auto Trait Implementations
impl<'a> Freeze for Feed<'a>
impl<'a> RefUnwindSafe for Feed<'a>
impl<'a> Send for Feed<'a>
impl<'a> Sync for Feed<'a>
impl<'a> Unpin for Feed<'a>
impl<'a> UnsafeUnpin for Feed<'a>
impl<'a> UnwindSafe for Feed<'a>
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.