use iree_embedded::{Arena, Context, Device, Instance, Tensor, include_vmfb};
/// Compiled micro_speech module for the host architecture, embedded at build time.
#[cfg(target_arch = "aarch64")]
static VMFB: &[u8] = include_vmfb!("fixtures/micro_speech-aarch64.vmfb");
/// Compiled micro_speech module for the host architecture, embedded at build time.
#[cfg(target_arch = "x86_64")]
static VMFB: &[u8] = include_vmfb!("fixtures/micro_speech-x86_64.vmfb");
// Only the two architectures above have a prebuilt fixture; fail loudly with a
// clear message instead of an unresolved-name error for `VMFB`.
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
compile_error!("micro_speech test fixtures are only provided for aarch64 and x86_64");

/// Raw u8 input features fed to the model; the test expects exactly 49 * 40 bytes.
/// Presumably features extracted from a "yes" utterance — NOTE(review): confirm.
static YES_FEATURES: &[u8] = include_bytes!("fixtures/yes_features.bin");

/// Class names, indexed by position in the model's output logits.
const LABELS: [&str; 4] = ["silence", "unknown", "yes", "no"];
/// End-to-end check: run the embedded micro_speech module on the `YES_FEATURES`
/// fixture and expect the highest-scoring output class to be "yes".
#[test]
fn micro_speech_predicts_yes() {
    // Backing storage for the runtime arena. Being a `static`, the 4 MiB buffer
    // lives in static memory rather than on the test thread's stack.
    static mut BUF: [u8; 4 * 1024 * 1024] = [0; 4 * 1024 * 1024];
    // SAFETY: `BUF` is only nameable inside this function and the test harness
    // runs this body once per process, so this is the only reference ever created
    // to it. NOTE(review): any further contract of `Arena::new` is assumed to be
    // met by exclusive, zero-initialized storage — confirm against its docs.
    let arena = unsafe { Arena::new(&mut *core::ptr::addr_of_mut!(BUF)) };

    let instance = Instance::new(&arena).expect("instance");
    let device = Device::local_sync(&arena).expect("device");
    let ctx = Context::new(&instance, &device, VMFB, &arena).expect("context");
    let infer = ctx.resolve("module.tf2onnx").expect("resolve");

    // The fixture must exactly fill the [1, 49, 40, 1] input tensor below.
    assert_eq!(YES_FEATURES.len(), 49 * 40);
    let input = Tensor::from_u8(&device, &[1, 49, 40, 1], YES_FEATURES).expect("input");

    let outputs = ctx.invoke(infer, &[&input], &arena).expect("invoke");
    // One logit per label; sized from LABELS so the two cannot drift apart.
    let mut logits = [0.0f32; LABELS.len()];
    outputs[0]
        .read_into_f32(&device, &mut logits)
        .expect("read");

    // `total_cmp` orders a positive NaN above every number, so a NaN logit would
    // silently win the argmax below; reject non-finite output explicitly.
    assert!(
        logits.iter().all(|x| x.is_finite()),
        "non-finite logits = {logits:?}"
    );

    let best = logits
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
        .expect("logits array is non-empty");
    assert_eq!(LABELS[best], "yes", "logits = {logits:?}");
}