1use pjrt::{self, Client, HostBuffer, Result};
2
3#[tokio::main]
4async fn main() -> Result<()> {
5 let api = pjrt::plugin("pjrt_c_api_cpu_plugin.so").load()?;
6 println!("{:?}", api.plugin_attributes());
7
8 let client = Client::builder(&api).build()?;
9
10 let host_buf = HostBuffer::builder()
11 .data([1.0f32, 2.0, 3.0, 4.0])
12 .dims([2, 2])
13 .build();
14 println!("{:?}", host_buf);
15
16 let dev1 = client.lookup_addressable_device(0)?;
17 let dev2 = client.lookup_addressable_device(1)?;
18
19 println!("-- ASYNC --");
20 let dev_buf = host_buf.copy_to(&dev1).await?;
21 println!("to {:?}, {:?}", dev_buf.dims(), dev_buf.layout());
22
23 let b = dev_buf.copy_to_host().await?;
24 println!("to_host_buffer {:?}", b);
25
26 let b = dev_buf.copy_to_device(&dev2).await?;
27 println!("copy_to_device {:?}, {:?}", b.dims(), b.layout());
28
29 println!("-- SYNC --");
30 let dev_buf = host_buf.copy_to_sync(&dev1)?;
31 println!("to {:?}, {:?}", dev_buf.dims(), dev_buf.layout());
32
33 let b = dev_buf.copy_to_host_sync()?;
34 println!("to_host_buffer {:?}", b);
35
36 let b = dev_buf.copy_to_device_sync(&dev2)?;
37 println!("copy_to_device {:?}, {:?}", b.dims(), b.layout());
38
39 Ok(())
40}