openinfer_simulator/runtime/executor/
fetch.rs1use anyhow::{anyhow, Result};
2
3use crate::runtime::state::RuntimeState;
4use crate::runtime::value_eval::{tensor_to_bool, tensor_to_i64};
5use crate::tensor::{Tensor, TensorElement, TensorValue};
6
7pub trait Fetchable: Sized {
9 fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self>;
10}
11
12impl Fetchable for TensorValue {
13 fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self> {
14 state.get_tensor(name)
15 }
16}
17
18impl<T: TensorElement> Fetchable for Tensor<T> {
19 fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self> {
20 state.fetch_typed(name)
21 }
22}
23
24macro_rules! impl_fetch_int {
25 ($t:ty) => {
26 impl Fetchable for $t {
27 fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self> {
28 let tensor: TensorValue = state.get_tensor(name)?;
29 Ok(tensor_to_i64(&tensor)? as $t)
30 }
31 }
32 };
33}
34
35macro_rules! impl_fetch_float {
36 ($t:ty) => {
37 impl Fetchable for $t {
38 fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self> {
39 let tensor: TensorValue = state.get_tensor(name)?;
40 Ok(tensor_to_f64(&tensor)? as $t)
41 }
42 }
43 };
44}
45
46impl_fetch_float!(f32);
47impl_fetch_float!(f64);
48impl_fetch_int!(i8);
49impl_fetch_int!(i16);
50impl_fetch_int!(i32);
51impl_fetch_int!(i64);
52impl_fetch_int!(u8);
53impl_fetch_int!(u16);
54impl_fetch_int!(u32);
55impl_fetch_int!(u64);
56
57impl Fetchable for bool {
58 fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self> {
59 let tensor: TensorValue = state.get_tensor(name)?;
60 tensor_to_bool(&tensor)
61 }
62}
63
64fn tensor_to_f64(value: &TensorValue) -> Result<f64> {
65 if value.len() != 1 {
66 return Err(anyhow!("expected scalar value"));
67 }
68 match value {
69 TensorValue::I8(t) => Ok(t.data[0] as f64),
70 TensorValue::I16(t) => Ok(t.data[0] as f64),
71 TensorValue::I32(t) => Ok(t.data[0] as f64),
72 TensorValue::I64(t) => Ok(t.data[0] as f64),
73 TensorValue::U8(t) => Ok(t.data[0] as f64),
74 TensorValue::U16(t) => Ok(t.data[0] as f64),
75 TensorValue::U32(t) => Ok(t.data[0] as f64),
76 TensorValue::U64(t) => Ok(t.data[0] as f64),
77 TensorValue::Bool(t) => Ok(if t.data[0] { 1.0 } else { 0.0 }),
78 TensorValue::F16(t) => Ok(t.data[0].to_f32() as f64),
79 TensorValue::BF16(t) => Ok(t.data[0].to_f32() as f64),
80 TensorValue::F8(t) => Ok(t.data[0].to_f32() as f64),
81 TensorValue::F32(t) => Ok(t.data[0] as f64),
82 TensorValue::F64(t) => Ok(t.data[0]),
83 TensorValue::Bitset(t) => Ok(t.data[0].bits as f64),
84 TensorValue::I4(_)
85 | TensorValue::I2(_)
86 | TensorValue::I1(_)
87 | TensorValue::U4(_)
88 | TensorValue::U2(_)
89 | TensorValue::U1(_)
90 | TensorValue::T2(_)
91 | TensorValue::T1(_) => Err(anyhow!("packed scalars are not supported")),
92 }
93}