use std::collections::HashMap;
use ndarray::ArrayD;
use ort::session::Session;
use ort::value::ValueType;
use crate::errors::error::{SurrealError, SurrealErrorStatus};
use crate::execution::session::get_session;
use crate::safe_eject;
use crate::storage::surml_file::SurMlFile;
/// Runs model computations against a mutably borrowed [`SurMlFile`].
pub struct ModelComputation<'a> {
/// The loaded file; the methods on this type read its `header` (input keys and
/// normalisers) and its `model` field.
pub surml_file: &'a mut SurMlFile,
}
impl ModelComputation<'_> {
/// Builds a one-dimensional dynamic tensor from named input values.
///
/// The element order is whatever `input_vector_from_key_bindings` produces for the
/// same map.
///
/// # Errors
/// Propagates the error returned by `input_vector_from_key_bindings`.
pub fn input_tensor_from_key_bindings(
    &self,
    input_values: HashMap<String, f32>,
) -> Result<ArrayD<f32>, SurrealError> {
    self.input_vector_from_key_bindings(input_values)
        .map(|ordered| ndarray::Array1::from_vec(ordered).into_dyn())
}
/// Reads the declared shape of the session's first input as `usize` dimensions.
///
/// A negative dimension is replaced by its absolute value, so `-1` becomes `1`.
/// NOTE(review): negative entries presumably mark dynamic axes in the model — confirm.
///
/// # Errors
/// Returns a `SurrealError` with `Unknown` status when the session exposes no inputs
/// or when the first input is not a tensor.
fn process_input_dims(session_ref: &Session) -> Result<Vec<usize>, SurrealError> {
    let inputs = session_ref.inputs();
    if inputs.is_empty() {
        return Err(SurrealError {
            message: "No inputs found in session".into(),
            status: SurrealErrorStatus::Unknown,
        });
    }

    // Only the first input is inspected; any further inputs are ignored.
    let value_type = inputs[0].dtype();
    let shape = match value_type {
        ValueType::Tensor { shape, .. } => shape,
        _ => {
            return Err(SurrealError {
                message: "input dims not found".into(),
                status: SurrealErrorStatus::Unknown,
            });
        }
    };

    Ok(shape
        .iter()
        .map(|&dim| if dim < 0 { (-dim) as usize } else { dim as usize })
        .collect())
}
/// Orders the supplied values according to the key order stored in the file header.
///
/// For every key in `header.keys.store`, the matching value is copied out of
/// `input_values`. Entries of `input_values` that are not header keys are ignored.
///
/// # Errors
/// Returns a `SurrealError` with `NotFound` status when a header key has no entry in
/// `input_values`.
pub fn input_vector_from_key_bindings(
    &self,
    input_values: HashMap<String, f32>,
) -> Result<Vec<f32>, SurrealError> {
    let keys = &self.surml_file.header.keys.store;
    let mut ordered = Vec::with_capacity(keys.len());
    for key in keys {
        // Copy rather than `mem::take`: taking would reset the entry to 0.0, so a key
        // that appears more than once in the header would silently read 0.0 afterwards.
        match input_values.get(key) {
            Some(value) => ordered.push(*value),
            None => {
                return Err(SurrealError::new(
                    format!(
                        "src/execution/compute.rs 67: Key {} not found in input values",
                        key
                    ),
                    SurrealErrorStatus::NotFound,
                ));
            }
        }
    }
    Ok(ordered)
}
pub fn raw_compute(
&self,
tensor: ArrayD<f32>,
_dims: Option<(i32, i32)>,
) -> Result<Vec<f32>, SurrealError> {
let mut session = get_session(self.surml_file.model.clone())?;
let dims_cache = ModelComputation::process_input_dims(&session)?;
let tensor = if dims_cache.is_empty() {
tensor
} else {
match tensor.into_shape_with_order(dims_cache) {
Ok(tensor) => tensor,
Err(_) => {
return Err(SurrealError::new(
"Failed to reshape tensor to input dimensions".to_string(),
SurrealErrorStatus::Unknown,
));
}
}
};
let tensor = match ort::value::Tensor::from_array(tensor) {
Ok(tensor) => tensor,
Err(_) => {
return Err(SurrealError::new(
"Failed to convert tensor to ort tensor".to_string(),
SurrealErrorStatus::Unknown,
));
}
};
let x = ort::inputs![tensor];
let outputs = safe_eject!(session.run(x), SurrealErrorStatus::Unknown);
let mut buffer: Vec<f32> = Vec::new();
match outputs[0].try_extract_tensor::<f32>() {
Ok((_shape, data)) => {
for i in data.iter() {
buffer.push(*i);
}
}
Err(_) => {
let (_shape, data) = safe_eject!(
outputs[0].try_extract_tensor::<i64>(),
SurrealErrorStatus::Unknown
);
for i in data.iter() {
buffer.push(*i as f32);
}
}
};
Ok(buffer)
}
pub fn buffered_compute(
&self,
input_values: &mut HashMap<String, f32>,
) -> Result<Vec<f32>, SurrealError> {
for (key, value) in &mut *input_values {
let value_ref = *value;
if let Some(normaliser) = self.surml_file.header.get_normaliser(&key.to_string())? {
*value = normaliser.normalise(value_ref);
}
}
let tensor = self.input_tensor_from_key_bindings(input_values.clone())?;
let output = self.raw_compute(tensor, None)?;
if self.surml_file.header.output.normaliser.is_none() {
return Ok(output);
}
let output_normaliser = match self.surml_file.header.output.normaliser.as_ref() {
Some(normaliser) => normaliser,
None => {
return Err(SurrealError::new(
String::from(
"No normaliser present for output which shouldn't happen as passed initial check for",
)
.to_string(),
SurrealErrorStatus::Unknown,
));
}
};
let mut buffer = Vec::with_capacity(output.len());
for value in output {
buffer.push(output_normaliser.inverse_normalise(value));
}
Ok(buffer)
}
}
#[cfg(test)]
mod tests {
#[cfg(any(
feature = "sklearn-tests",
feature = "onnx-tests",
feature = "torch-tests",
feature = "tensorflow-tests"
))]
use super::*;
/// Raw compute on the sklearn linear model returns the single expected prediction.
// The former stacked `#[cfg(any(...))]` gate was redundant: it is implied by the
// `sklearn-tests` feature gate below.
#[cfg(feature = "sklearn-tests")]
#[test]
fn test_raw_compute_linear_sklearn() {
    let mut surml = SurMlFile::from_file("./model_stash/sklearn/surml/linear.surml").unwrap();
    let computation = ModelComputation {
        surml_file: &mut surml,
    };
    let inputs: HashMap<String, f32> = HashMap::from([
        (String::from("squarefoot"), 1000.0),
        (String::from("num_floors"), 2.0),
    ]);
    let tensor = computation.input_tensor_from_key_bindings(inputs).unwrap();
    let prediction = computation.raw_compute(tensor, Some((1, 2))).unwrap();
    assert_eq!(prediction.len(), 1);
    assert_eq!(prediction[0], 985.57745);
}
/// Buffered compute on the sklearn linear model yields exactly one output value.
#[cfg(feature = "sklearn-tests")]
#[test]
fn test_buffered_compute_linear_sklearn() {
    let mut surml = SurMlFile::from_file("./model_stash/sklearn/surml/linear.surml").unwrap();
    let computation = ModelComputation {
        surml_file: &mut surml,
    };
    let mut inputs: HashMap<String, f32> = HashMap::from([
        (String::from("squarefoot"), 1000.0),
        (String::from("num_floors"), 2.0),
    ]);
    let prediction = computation.buffered_compute(&mut inputs).unwrap();
    assert_eq!(prediction.len(), 1);
}
/// Raw compute on the ONNX linear model returns the single expected prediction.
#[cfg(feature = "onnx-tests")]
#[test]
fn test_raw_compute_linear_onnx() {
    let mut surml = SurMlFile::from_file("./model_stash/onnx/surml/linear.surml").unwrap();
    let computation = ModelComputation {
        surml_file: &mut surml,
    };
    let inputs: HashMap<String, f32> = HashMap::from([
        (String::from("squarefoot"), 1000.0),
        (String::from("num_floors"), 2.0),
    ]);
    let tensor = computation.input_tensor_from_key_bindings(inputs).unwrap();
    let prediction = computation.raw_compute(tensor, Some((1, 2))).unwrap();
    assert_eq!(prediction.len(), 1);
    assert_eq!(prediction[0], 985.57745);
}
/// Buffered compute on the ONNX linear model yields exactly one output value.
#[cfg(feature = "onnx-tests")]
#[test]
fn test_buffered_compute_linear_onnx() {
    let mut surml = SurMlFile::from_file("./model_stash/onnx/surml/linear.surml").unwrap();
    let computation = ModelComputation {
        surml_file: &mut surml,
    };
    let mut inputs: HashMap<String, f32> = HashMap::from([
        (String::from("squarefoot"), 1000.0),
        (String::from("num_floors"), 2.0),
    ]);
    let prediction = computation.buffered_compute(&mut inputs).unwrap();
    assert_eq!(prediction.len(), 1);
}
/// Raw compute on the torch linear model yields exactly one output value.
#[cfg(feature = "torch-tests")]
#[test]
fn test_raw_compute_linear_torch() {
    let mut surml = SurMlFile::from_file("./model_stash/torch/surml/linear.surml").unwrap();
    let computation = ModelComputation {
        surml_file: &mut surml,
    };
    let inputs: HashMap<String, f32> = HashMap::from([
        (String::from("squarefoot"), 1000.0),
        (String::from("num_floors"), 2.0),
    ]);
    let tensor = computation.input_tensor_from_key_bindings(inputs).unwrap();
    let prediction = computation.raw_compute(tensor, None).unwrap();
    assert_eq!(prediction.len(), 1);
}
/// Buffered compute on the torch linear model yields exactly one output value.
#[cfg(feature = "torch-tests")]
#[test]
fn test_buffered_compute_linear_torch() {
    let mut surml = SurMlFile::from_file("./model_stash/torch/surml/linear.surml").unwrap();
    let computation = ModelComputation {
        surml_file: &mut surml,
    };
    let mut inputs: HashMap<String, f32> = HashMap::from([
        (String::from("squarefoot"), 1000.0),
        (String::from("num_floors"), 2.0),
    ]);
    let prediction = computation.buffered_compute(&mut inputs).unwrap();
    assert_eq!(prediction.len(), 1);
}
/// Raw compute on the tensorflow linear model yields exactly one output value.
#[cfg(feature = "tensorflow-tests")]
#[test]
fn test_raw_compute_linear_tensorflow() {
    let mut surml = SurMlFile::from_file("./model_stash/tensorflow/surml/linear.surml").unwrap();
    let computation = ModelComputation {
        surml_file: &mut surml,
    };
    let inputs: HashMap<String, f32> = HashMap::from([
        (String::from("squarefoot"), 1000.0),
        (String::from("num_floors"), 2.0),
    ]);
    let tensor = computation.input_tensor_from_key_bindings(inputs).unwrap();
    let prediction = computation.raw_compute(tensor, None).unwrap();
    assert_eq!(prediction.len(), 1);
}
/// Buffered compute on the tensorflow linear model yields exactly one output value.
#[cfg(feature = "tensorflow-tests")]
#[test]
fn test_buffered_compute_linear_tensorflow() {
    let mut surml = SurMlFile::from_file("./model_stash/tensorflow/surml/linear.surml").unwrap();
    let computation = ModelComputation {
        surml_file: &mut surml,
    };
    let mut inputs: HashMap<String, f32> = HashMap::from([
        (String::from("squarefoot"), 1000.0),
        (String::from("num_floors"), 2.0),
    ]);
    let prediction = computation.buffered_compute(&mut inputs).unwrap();
    assert_eq!(prediction.len(), 1);
}
}