// openinfer_simulator/runtime/executor/fetch.rs

1use anyhow::{anyhow, Result};
2
3use crate::runtime::state::RuntimeState;
4use crate::runtime::value_eval::{tensor_to_bool, tensor_to_i64};
5use crate::tensor::{Tensor, TensorElement, TensorValue};
6
7/// Adapter for fetching runtime values by name.
8pub trait Fetchable: Sized {
9    fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self>;
10}
11
12impl Fetchable for TensorValue {
13    fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self> {
14        state.get_tensor(name)
15    }
16}
17
18impl<T: TensorElement> Fetchable for Tensor<T> {
19    fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self> {
20        state.fetch_typed(name)
21    }
22}
23
/// Implements [`Fetchable`] for an integer type.
///
/// The named value is read with `RuntimeState::get_tensor`, turned into an
/// `i64` scalar by `tensor_to_i64`, and then narrowed to the target type.
///
/// The narrowing is checked. A value that does not fit the target type (e.g.
/// `300` fetched as `u8`, or `-1` fetched as `u32`) produces an error instead
/// of being silently wrapped/truncated by an `as` cast.
///
/// NOTE(review): if `tensor_to_i64` maps `U64` values above `i64::MAX` to
/// negative numbers, fetching such values as `u64` now errors instead of
/// round-tripping through the old wrapping cast — confirm against
/// `value_eval`.
macro_rules! impl_fetch_int {
    ($t:ty) => {
        impl Fetchable for $t {
            fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self> {
                let tensor: TensorValue = state.get_tensor(name)?;
                let wide = tensor_to_i64(&tensor)?;
                // Fully qualified so expansion sites do not need `TryFrom`
                // in scope (it is only in the 2021+ prelude).
                <$t as ::std::convert::TryFrom<i64>>::try_from(wide).map_err(|_| {
                    anyhow!(
                        "value {} of '{}' is out of range for {}",
                        wide,
                        name,
                        stringify!($t)
                    )
                })
            }
        }
    };
}
34
/// Implements [`Fetchable`] for a floating-point type.
///
/// The named value is read with `RuntimeState::get_tensor` and converted by
/// `tensor_to_f64`. The result is then cast with `as`: an `f64` -> `f32` cast
/// rounds to the nearest representable value, and out-of-range magnitudes
/// become infinity.
macro_rules! impl_fetch_float {
    ($float:ty) => {
        impl Fetchable for $float {
            fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self> {
                let tensor: TensorValue = state.get_tensor(name)?;
                tensor_to_f64(&tensor).map(|wide: f64| wide as $float)
            }
        }
    };
}
45
46impl_fetch_float!(f32);
47impl_fetch_float!(f64);
48impl_fetch_int!(i8);
49impl_fetch_int!(i16);
50impl_fetch_int!(i32);
51impl_fetch_int!(i64);
52impl_fetch_int!(u8);
53impl_fetch_int!(u16);
54impl_fetch_int!(u32);
55impl_fetch_int!(u64);
56
57impl Fetchable for bool {
58    fn fetch(state: &mut RuntimeState, name: &str) -> Result<Self> {
59        let tensor: TensorValue = state.get_tensor(name)?;
60        tensor_to_bool(&tensor)
61    }
62}
63
64fn tensor_to_f64(value: &TensorValue) -> Result<f64> {
65    if value.len() != 1 {
66        return Err(anyhow!("expected scalar value"));
67    }
68    match value {
69        TensorValue::I8(t) => Ok(t.data[0] as f64),
70        TensorValue::I16(t) => Ok(t.data[0] as f64),
71        TensorValue::I32(t) => Ok(t.data[0] as f64),
72        TensorValue::I64(t) => Ok(t.data[0] as f64),
73        TensorValue::U8(t) => Ok(t.data[0] as f64),
74        TensorValue::U16(t) => Ok(t.data[0] as f64),
75        TensorValue::U32(t) => Ok(t.data[0] as f64),
76        TensorValue::U64(t) => Ok(t.data[0] as f64),
77        TensorValue::Bool(t) => Ok(if t.data[0] { 1.0 } else { 0.0 }),
78        TensorValue::F16(t) => Ok(t.data[0].to_f32() as f64),
79        TensorValue::BF16(t) => Ok(t.data[0].to_f32() as f64),
80        TensorValue::F8(t) => Ok(t.data[0].to_f32() as f64),
81        TensorValue::F32(t) => Ok(t.data[0] as f64),
82        TensorValue::F64(t) => Ok(t.data[0]),
83        TensorValue::Bitset(t) => Ok(t.data[0].bits as f64),
84        TensorValue::I4(_)
85        | TensorValue::I2(_)
86        | TensorValue::I1(_)
87        | TensorValue::U4(_)
88        | TensorValue::U2(_)
89        | TensorValue::U1(_)
90        | TensorValue::T2(_)
91        | TensorValue::T1(_) => Err(anyhow!("packed scalars are not supported")),
92    }
93}