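pub struct Feed<'a> {
    pub tensor: &'a Tensor,
    pub data: &'a TensorData,
}

Description

Ordered placeholder feed: pairs a placeholder `Tensor` with the `TensorData` bound to it when a graph is executed. Feeds are passed to `Graph::run` as an ordered slice.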
Fields

tensor: &'a Tensor
    The placeholder tensor being fed.
data: &'a TensorData
    The data supplied for that placeholder.

Implementations
impl<'a> Feed<'a>

pub const fn new(tensor: &'a Tensor, data: &'a TensorData) -> Self
    Creates a feed that binds `data` to the placeholder `tensor`.
Examples found in repository:

examples/05_concat_split.rs (line 19)

fn main() {
    let device = MetalDevice::system_default().expect("no Metal device available");
    let graph = Graph::new().expect("graph");
    let input = graph
        .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
        .expect("placeholder");
    let concat = graph
        .concat_pair(&input, &input, 1, Some("concat"))
        .expect("concat");
    let split = graph.split_num(&concat, 2, 1, Some("split"));
    let stacked = graph.stack(&[&split[0], &split[1]], 0, Some("stack")).expect("stack");

    let input_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2])
        .expect("tensor data");
    let results = graph
        .run(&[Feed::new(&input, &input_data)], &[&stacked])
        .expect("run");

    println!("stacked tensor bytes: {}", results[0].byte_len().expect("byte len"));
}

More examples
examples/03_arithmetic_topk.rs (line 23)

fn main() {
    let device = MetalDevice::system_default().expect("no Metal device available");
    let graph = Graph::new().expect("graph");
    let input = graph
        .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("input"))
        .expect("placeholder");
    let squared = graph
        .unary_arithmetic(UnaryArithmeticOp::Square, &input, Some("square"))
        .expect("square");
    let row_sum = graph
        .reduce_axes(ReductionAxesOp::Sum, &squared, &[1], Some("row_sum"))
        .expect("reduce");
    let topk = graph.top_k(&input, 2, Some("topk")).expect("topk");

    let input_data = TensorData::from_f32_slice(&device, &[1.0, 3.0, 2.0, 4.0, 6.0, 5.0], &[2, 3])
        .expect("tensor data");
    let results = graph
        .run(&[Feed::new(&input, &input_data)], &[&row_sum, &topk.0])
        .expect("run");

    println!("row sums: {:?}", results[0].read_f32().expect("row sums"));
    println!("top-k values: {:?}", results[1].read_f32().expect("topk values"));
}

examples/01_add_relu.rs (line 24)
fn main() {
    let device = MetalDevice::system_default().expect("no Metal device available");
    let graph = Graph::new().expect("failed to create MPSGraph");

    let input = graph
        .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
        .expect("failed to create placeholder");
    let bias = graph
        .constant_scalar(1.0, data_type::FLOAT32)
        .expect("failed to create scalar constant");
    let added = graph
        .addition(&input, &bias, Some("add"))
        .expect("failed to create addition op");
    let output = graph
        .relu(&added, Some("relu"))
        .expect("failed to create relu op");

    let input_data = TensorData::from_f32_slice(&device, &[1.0, -2.0, 3.0, -4.0], &[2, 2])
        .expect("failed to create tensor data");
    let results = graph
        .run(&[Feed::new(&input, &input_data)], &[&output])
        .expect("failed to execute graph");
    let values = results[0].read_f32().expect("failed to read tensor output");

    assert_eq!(values, vec![2.0, 0.0, 4.0, 0.0]);
    println!("add+relu smoke passed: {values:?}");
}

Trait Implementations
Auto Trait Implementations
impl<'a> Freeze for Feed<'a>
impl<'a> RefUnwindSafe for Feed<'a>
impl<'a> Send for Feed<'a>
impl<'a> Sync for Feed<'a>
impl<'a> Unpin for Feed<'a>
impl<'a> UnsafeUnpin for Feed<'a>
impl<'a> UnwindSafe for Feed<'a>
Blanket Implementations

impl<T> BorrowMut<T> for T
where
    T: ?Sized,

fn borrow_mut(&mut self) -> &mut T
    Mutably borrows from an owned value. See the `BorrowMut` trait documentation for details.