// basic/basic.rs

use pjrt::ProgramFormat::MLIR;
use pjrt::{self, Client, HostBuffer, LoadedExecutable, Result};

4const CODE: &'static [u8] = include_bytes!("program.mlir");
5
6fn main() -> Result<()> {
7    let api = pjrt::plugin("pjrt_c_api_cpu_plugin.so").load()?;
8    println!("api_version = {:?}", api.version());
9
10    let client = Client::builder(&api).build()?;
11    println!("platform_name = {}", client.platform_name());
12
13    let program = pjrt::Program::new(MLIR, CODE);
14
15    let loaded_executable = LoadedExecutable::builder(&client, &program).build()?;
16
17    let a = HostBuffer::scalar(1.0f32);
18    println!("input = {:?}", a);
19
20    let inputs = a.copy_to_sync(&client)?;
21
22    let result = loaded_executable.execution(inputs).run_sync()?;
23
24    let ouput = &result[0][0];
25    let output = ouput.copy_to_host_sync()?;
26    println!("output= {:?}", output);
27
28    Ok(())
29}