pub struct TensorData { /* private fields */ }
Safe owner for an Objective-C MPSGraphTensorData.
Implementations§
impl TensorData
impl TensorData
pub fn from_bytes(
device: &MetalDevice,
bytes: &[u8],
shape: &[usize],
data_type: u32,
) -> Option<Self>
pub fn from_bytes( device: &MetalDevice, bytes: &[u8], shape: &[usize], data_type: u32, ) -> Option<Self>
Build tensor data by copying CPU bytes onto the given Metal device.
pub fn from_f32_slice(
device: &MetalDevice,
values: &[f32],
shape: &[usize],
) -> Option<Self>
pub fn from_f32_slice( device: &MetalDevice, values: &[f32], shape: &[usize], ) -> Option<Self>
Build tensor data from a contiguous f32 slice.
Examples found in repository?
examples/05_concat_split.rs (line 16)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("graph");
7 let input = graph
8 .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
9 .expect("placeholder");
10 let concat = graph
11 .concat_pair(&input, &input, 1, Some("concat"))
12 .expect("concat");
13 let split = graph.split_num(&concat, 2, 1, Some("split"));
14 let stacked = graph.stack(&[&split[0], &split[1]], 0, Some("stack")).expect("stack");
15
16 let input_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2])
17 .expect("tensor data");
18 let results = graph
19 .run(&[Feed::new(&input, &input_data)], &[&stacked])
20 .expect("run");
21
22 println!("stacked tensor bytes: {}", results[0].byte_len().expect("byte len"));
23}
More examples
examples/03_arithmetic_topk.rs (line 20)
6fn main() {
7 let device = MetalDevice::system_default().expect("no Metal device available");
8 let graph = Graph::new().expect("graph");
9 let input = graph
10 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("input"))
11 .expect("placeholder");
12 let squared = graph
13 .unary_arithmetic(UnaryArithmeticOp::Square, &input, Some("square"))
14 .expect("square");
15 let row_sum = graph
16 .reduce_axes(ReductionAxesOp::Sum, &squared, &[1], Some("row_sum"))
17 .expect("reduce");
18 let topk = graph.top_k(&input, 2, Some("topk")).expect("topk");
19
20 let input_data = TensorData::from_f32_slice(&device, &[1.0, 3.0, 2.0, 4.0, 6.0, 5.0], &[2, 3])
21 .expect("tensor data");
22 let results = graph
23 .run(&[Feed::new(&input, &input_data)], &[&row_sum, &topk.0])
24 .expect("run");
25
26 println!("row sums: {:?}", results[0].read_f32().expect("row sums"));
27 println!("top-k values: {:?}", results[1].read_f32().expect("topk values"));
28}
examples/01_add_relu.rs (line 21)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("failed to create MPSGraph");
7
8 let input = graph
9 .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
10 .expect("failed to create placeholder");
11 let bias = graph
12 .constant_scalar(1.0, data_type::FLOAT32)
13 .expect("failed to create scalar constant");
14 let added = graph
15 .addition(&input, &bias, Some("add"))
16 .expect("failed to create addition op");
17 let output = graph
18 .relu(&added, Some("relu"))
19 .expect("failed to create relu op");
20
21 let input_data = TensorData::from_f32_slice(&device, &[1.0, -2.0, 3.0, -4.0], &[2, 2])
22 .expect("failed to create tensor data");
23 let results = graph
24 .run(&[Feed::new(&input, &input_data)], &[&output])
25 .expect("failed to execute graph");
26 let values = results[0].read_f32().expect("failed to read tensor output");
27
28 assert_eq!(values, vec![2.0, 0.0, 4.0, 0.0]);
29 println!("add+relu smoke passed: {values:?}");
30}
examples/02_compile_matmul.rs (line 32)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let queue = device
7 .new_command_queue()
8 .expect("failed to create command queue");
9 let graph = Graph::new().expect("failed to create MPSGraph");
10
11 let left = graph
12 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("left"))
13 .expect("failed to create left placeholder");
14 let right = graph
15 .placeholder(Some(&[3, 2]), data_type::FLOAT32, Some("right"))
16 .expect("failed to create right placeholder");
17 let output = graph
18 .matrix_multiplication(&left, &right, Some("matmul"))
19 .expect("failed to create matrix multiplication op");
20
21 let executable = graph
22 .compile(
23 &device,
24 &[
25 FeedDescription::new(&left, &[2, 3], data_type::FLOAT32),
26 FeedDescription::new(&right, &[3, 2], data_type::FLOAT32),
27 ],
28 &[&output],
29 )
30 .expect("failed to compile executable");
31
32 let left_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
33 .expect("failed to create left tensor data");
34 let right_data =
35 TensorData::from_f32_slice(&device, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2])
36 .expect("failed to create right tensor data");
37
38 let results = executable
39 .run(&queue, &[&left_data, &right_data])
40 .expect("failed to run executable");
41 let values = results[0].read_f32().expect("failed to read tensor output");
42 let expected = [58.0_f32, 64.0, 139.0, 154.0];
43 for (actual, expected_value) in values.iter().zip(expected) {
44 assert!(
45 (actual - expected_value).abs() < 1.0e-4,
46 "unexpected matrix multiply result: {values:?}"
47 );
48 }
49
50 println!("compile+matmul smoke passed: {values:?}");
51}
pub fn from_buffer(
buffer: &MetalBuffer,
shape: &[usize],
data_type: u32,
) -> Option<Self>
pub fn from_buffer( buffer: &MetalBuffer, shape: &[usize], data_type: u32, ) -> Option<Self>
Alias an existing MTLBuffer as tensor data.
pub const fn as_ptr(&self) -> *mut c_void
pub fn data_type(&self) -> u32
pub fn shape(&self) -> Vec<usize>
pub fn element_count(&self) -> usize
pub fn byte_len(&self) -> Result<usize>
pub fn byte_len(&self) -> Result<usize>
Examples found in repository?
examples/05_concat_split.rs (line 22)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("graph");
7 let input = graph
8 .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
9 .expect("placeholder");
10 let concat = graph
11 .concat_pair(&input, &input, 1, Some("concat"))
12 .expect("concat");
13 let split = graph.split_num(&concat, 2, 1, Some("split"));
14 let stacked = graph.stack(&[&split[0], &split[1]], 0, Some("stack")).expect("stack");
15
16 let input_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0], &[2, 2])
17 .expect("tensor data");
18 let results = graph
19 .run(&[Feed::new(&input, &input_data)], &[&stacked])
20 .expect("run");
21
22 println!("stacked tensor bytes: {}", results[0].byte_len().expect("byte len"));
23}
pub fn read_bytes(&self) -> Result<Vec<u8>>
pub fn read_f32(&self) -> Result<Vec<f32>>
pub fn read_f32(&self) -> Result<Vec<f32>>
Examples found in repository?
examples/03_arithmetic_topk.rs (line 26)
6fn main() {
7 let device = MetalDevice::system_default().expect("no Metal device available");
8 let graph = Graph::new().expect("graph");
9 let input = graph
10 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("input"))
11 .expect("placeholder");
12 let squared = graph
13 .unary_arithmetic(UnaryArithmeticOp::Square, &input, Some("square"))
14 .expect("square");
15 let row_sum = graph
16 .reduce_axes(ReductionAxesOp::Sum, &squared, &[1], Some("row_sum"))
17 .expect("reduce");
18 let topk = graph.top_k(&input, 2, Some("topk")).expect("topk");
19
20 let input_data = TensorData::from_f32_slice(&device, &[1.0, 3.0, 2.0, 4.0, 6.0, 5.0], &[2, 3])
21 .expect("tensor data");
22 let results = graph
23 .run(&[Feed::new(&input, &input_data)], &[&row_sum, &topk.0])
24 .expect("run");
25
26 println!("row sums: {:?}", results[0].read_f32().expect("row sums"));
27 println!("top-k values: {:?}", results[1].read_f32().expect("topk values"));
28}
More examples
examples/01_add_relu.rs (line 26)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let graph = Graph::new().expect("failed to create MPSGraph");
7
8 let input = graph
9 .placeholder(Some(&[2, 2]), data_type::FLOAT32, Some("input"))
10 .expect("failed to create placeholder");
11 let bias = graph
12 .constant_scalar(1.0, data_type::FLOAT32)
13 .expect("failed to create scalar constant");
14 let added = graph
15 .addition(&input, &bias, Some("add"))
16 .expect("failed to create addition op");
17 let output = graph
18 .relu(&added, Some("relu"))
19 .expect("failed to create relu op");
20
21 let input_data = TensorData::from_f32_slice(&device, &[1.0, -2.0, 3.0, -4.0], &[2, 2])
22 .expect("failed to create tensor data");
23 let results = graph
24 .run(&[Feed::new(&input, &input_data)], &[&output])
25 .expect("failed to execute graph");
26 let values = results[0].read_f32().expect("failed to read tensor output");
27
28 assert_eq!(values, vec![2.0, 0.0, 4.0, 0.0]);
29 println!("add+relu smoke passed: {values:?}");
30}
examples/02_compile_matmul.rs (line 41)
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let queue = device
7 .new_command_queue()
8 .expect("failed to create command queue");
9 let graph = Graph::new().expect("failed to create MPSGraph");
10
11 let left = graph
12 .placeholder(Some(&[2, 3]), data_type::FLOAT32, Some("left"))
13 .expect("failed to create left placeholder");
14 let right = graph
15 .placeholder(Some(&[3, 2]), data_type::FLOAT32, Some("right"))
16 .expect("failed to create right placeholder");
17 let output = graph
18 .matrix_multiplication(&left, &right, Some("matmul"))
19 .expect("failed to create matrix multiplication op");
20
21 let executable = graph
22 .compile(
23 &device,
24 &[
25 FeedDescription::new(&left, &[2, 3], data_type::FLOAT32),
26 FeedDescription::new(&right, &[3, 2], data_type::FLOAT32),
27 ],
28 &[&output],
29 )
30 .expect("failed to compile executable");
31
32 let left_data = TensorData::from_f32_slice(&device, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
33 .expect("failed to create left tensor data");
34 let right_data =
35 TensorData::from_f32_slice(&device, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2])
36 .expect("failed to create right tensor data");
37
38 let results = executable
39 .run(&queue, &[&left_data, &right_data])
40 .expect("failed to run executable");
41 let values = results[0].read_f32().expect("failed to read tensor output");
42 let expected = [58.0_f32, 64.0, 139.0, 154.0];
43 for (actual, expected_value) in values.iter().zip(expected) {
44 assert!(
45 (actual - expected_value).abs() < 1.0e-4,
46 "unexpected matrix multiply result: {values:?}"
47 );
48 }
49
50 println!("compile+matmul smoke passed: {values:?}");
51}
impl TensorData
impl TensorData
pub fn graph_device_type(&self) -> Option<u32>
pub fn graph_device_type(&self) -> Option<u32>
Return the graph-device type that backs this tensor data.
Trait Implementations§
Auto Trait Implementations§
impl Freeze for TensorData
impl RefUnwindSafe for TensorData
impl Unpin for TensorData
impl UnsafeUnpin for TensorData
impl UnwindSafe for TensorData
Blanket Implementations§
impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more