use anyhow::Result;
extern crate xla;
/// When `true`, load the HLO module from its text dump (`examples/fn_hlo.txt`);
/// otherwise load it from the protobuf file (`examples/fn_hlo.pb`).
const USE_TEXT_FORMAT: bool = false;

/// Loads an HLO module from disk, compiles it with the CPU PjRt client, runs it
/// on two 2x2 `f32` literals, and prints the shape and values of the result.
///
/// # Errors
///
/// Returns an error if the CPU client cannot be created, the HLO file cannot be
/// loaded, compilation or execution fails, or the output literal cannot be
/// unpacked from its 1-tuple or converted to a `Vec<f32>`.
fn main() -> Result<()> {
    // Only surface XLA/TensorFlow log messages at Warning level or above.
    xla::set_tf_min_log_level(xla::TfLogLevel::Warning);

    let client = xla::PjRtClient::cpu()?;
    println!(
        "{} {} {}",
        client.platform_name(),
        client.platform_version(),
        client.device_count()
    );

    let proto = if USE_TEXT_FORMAT {
        xla::HloModuleProto::from_text_file("examples/fn_hlo.txt")?
    } else {
        // NOTE(review): the meaning of the `true` flag is defined by the xla crate
        // (presumably "file is binary-encoded") — confirm against its docs.
        xla::HloModuleProto::from_proto_file("examples/fn_hlo.pb", true)?
    };
    let computation = xla::XlaComputation::from_proto(&proto);
    let executable = client.compile(&computation)?;

    let x = xla::Literal::vec1(&[1f32, 2f32, 3f32, 4f32]).reshape(&[2, 2])?;
    let y = xla::Literal::vec1(&[1f32, 1f32, 1f32, 1f32]).reshape(&[2, 2])?;

    // `execute` returns nested buffers, presumably indexed [replica][output]; with a
    // single CPU device and a single output, the first entry of each is the result.
    let buffers = executable.execute::<xla::Literal>(&[x, y])?;
    let output = buffers[0][0].to_literal_sync()?;
    // The computation's result is expected to be wrapped in a 1-tuple; unwrap it.
    let output = output.to_tuple1()?;

    let shape = output.shape()?;
    // Propagate conversion failures instead of printing `Err(..)` and exiting with Ok.
    let values = output.to_vec::<f32>()?;
    println!("Result: {:?} {:?}", shape, values);
    Ok(())
}