use std::collections::HashMap;
use std::convert::TryInto;
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;
use tempfile::NamedTempFile;
#[cfg(feature = "tch-backend")]
use crate::data::tensor::TchDType;
#[cfg(feature = "ndarray-backend")]
use crate::data::tensor::NdArrayDType;
use crate::data::action::RelayRLData;
use crate::data::tensor::{
AnyBurnTensor, BackendMatcher, BoolBurnTensor, DType, DeviceType, FloatBurnTensor,
IntBurnTensor
};
use burn_tensor::{Shape, backend::Backend};
use crate::model::{ModelError, ModelModule};
/// Produces an independent snapshot of an auxiliary-data dictionary.
///
/// Always returns `Some`; the `Option` return type is kept for
/// interface compatibility with callers that expect a fallible shape.
pub fn convert_generic_dict(
    dict: &HashMap<String, RelayRLData>,
) -> Option<HashMap<String, RelayRLData>> {
    let snapshot = dict.to_owned();
    Some(snapshot)
}
/// Validates that a [`ModelModule`] produces outputs matching its declared
/// metadata shapes by running a synthetic forward pass.
///
/// Supported tensor ranks are 1 through 9 for both input and output; any
/// rank outside that window yields [`ModelError::UnsupportedRank`].
///
/// # Errors
/// Propagates shape/metadata errors from the downstream validation helpers.
pub fn validate_module<B: Backend + BackendMatcher<Backend = B> + 'static>(
    module: &ModelModule<B>,
) -> Result<(), ModelError> {
    let device = module.resolve_device();
    let input_shape = &module.metadata.input_shape;
    let output_shape = &module.metadata.output_shape;
    let in_rank = input_shape.len();
    let out_rank = output_shape.len();
    // Guard both ranks up front so every match arm below is reachable only
    // for supported ranks.
    let supported = |rank: usize| (1..=9).contains(&rank);
    if !supported(in_rank) || !supported(out_rank) {
        return Err(ModelError::UnsupportedRank(format!(
            "Unsupported ranks: input {in_rank} output {out_rank}"
        )));
    }
    // Const generics cannot be chosen at runtime, so the runtime rank is
    // mapped onto a monomorphized call per supported value.
    match in_rank {
        1 => validate_with_input::<B, 1>(module, &device, input_shape, output_shape),
        2 => validate_with_input::<B, 2>(module, &device, input_shape, output_shape),
        3 => validate_with_input::<B, 3>(module, &device, input_shape, output_shape),
        4 => validate_with_input::<B, 4>(module, &device, input_shape, output_shape),
        5 => validate_with_input::<B, 5>(module, &device, input_shape, output_shape),
        6 => validate_with_input::<B, 6>(module, &device, input_shape, output_shape),
        7 => validate_with_input::<B, 7>(module, &device, input_shape, output_shape),
        8 => validate_with_input::<B, 8>(module, &device, input_shape, output_shape),
        9 => validate_with_input::<B, 9>(module, &device, input_shape, output_shape),
        // Unreachable: the guard above restricted in_rank to 1..=9.
        _ => unreachable!(),
    }
}
/// Second dispatch stage: with the input rank `D_IN` already fixed at
/// compile time, selects the const output rank from the runtime length of
/// `output_shape` and forwards to [`call_validate`].
///
/// # Errors
/// Returns [`ModelError::UnsupportedRank`] for output ranks outside 1..=9
/// (defensive; `validate_module` already rejects these).
fn validate_with_input<B: Backend + BackendMatcher<Backend = B> + 'static, const D_IN: usize>(
    module: &ModelModule<B>,
    device: &<B as Backend>::Device,
    input_shape: &[usize],
    output_shape: &[usize],
) -> Result<(), ModelError> {
    let out_rank = output_shape.len();
    match out_rank {
        1 => call_validate::<B, D_IN, 1>(module, device, input_shape, output_shape),
        2 => call_validate::<B, D_IN, 2>(module, device, input_shape, output_shape),
        3 => call_validate::<B, D_IN, 3>(module, device, input_shape, output_shape),
        4 => call_validate::<B, D_IN, 4>(module, device, input_shape, output_shape),
        5 => call_validate::<B, D_IN, 5>(module, device, input_shape, output_shape),
        6 => call_validate::<B, D_IN, 6>(module, device, input_shape, output_shape),
        7 => call_validate::<B, D_IN, 7>(module, device, input_shape, output_shape),
        8 => call_validate::<B, D_IN, 8>(module, device, input_shape, output_shape),
        9 => call_validate::<B, D_IN, 9>(module, device, input_shape, output_shape),
        other => Err(ModelError::UnsupportedRank(format!(
            "Unsupported ranks: input {} output {other}",
            input_shape.len()
        ))),
    }
}
/// Final dispatch stage: both ranks are now const generics, so the runtime
/// slices can be converted into fixed-size arrays and wrapped into Burn
/// [`Shape`] values before invoking the actual shape validation.
///
/// # Errors
/// Returns [`ModelError::InvalidMetadata`] if either slice length disagrees
/// with its const rank.
fn call_validate<
    B: Backend + BackendMatcher<Backend = B> + 'static,
    const D_IN: usize,
    const D_OUT: usize,
>(
    module: &ModelModule<B>,
    device: &<B as Backend>::Device,
    input_shape: &[usize],
    output_shape: &[usize],
) -> Result<(), ModelError> {
    let input = Shape::from(slice_to_array::<D_IN>(input_shape)?);
    let output = Shape::from(slice_to_array::<D_OUT>(output_shape)?);
    validate_model_shapes::<B, D_IN, D_OUT>(module, device, &input, &output)
}
/// Converts a dimension slice into a fixed-size array of length `N`.
///
/// # Errors
/// Returns [`ModelError::InvalidMetadata`] when `shape.len() != N`.
fn slice_to_array<const N: usize>(shape: &[usize]) -> Result<[usize; N], ModelError> {
    match shape.try_into() {
        Ok(array) => Ok(array),
        Err(_) => Err(ModelError::InvalidMetadata(format!(
            "Expected dimension of length {N}, but got {}",
            shape.len()
        ))),
    }
}
fn validate_model_shapes<
B: Backend + BackendMatcher<Backend = B> + 'static,
const D_IN: usize,
const D_OUT: usize,
>(
module: &ModelModule<B>,
device: &<B as Backend>::Device,
input_shape: &Shape,
output_shape: &Shape,
) -> Result<(), ModelError> {
let obs: Arc<AnyBurnTensor<B, D_IN>> = match &module.metadata.input_dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(nd) => match nd {
NdArrayDType::F16 | NdArrayDType::F32 | NdArrayDType::F64 => {
Arc::new(AnyBurnTensor::Float(FloatBurnTensor::empty(
input_shape,
&DType::NdArray(nd.clone()),
device,
)))
}
NdArrayDType::I8 | NdArrayDType::I16 | NdArrayDType::I32 | NdArrayDType::I64 => {
Arc::new(AnyBurnTensor::Int(IntBurnTensor::empty(
input_shape,
&DType::NdArray(nd.clone()),
device,
)))
}
NdArrayDType::Bool => Arc::new(AnyBurnTensor::Bool(BoolBurnTensor::empty(
input_shape,
&DType::NdArray(nd.clone()),
device,
))),
},
#[cfg(feature = "tch-backend")]
DType::Tch(tch) => match tch {
TchDType::F16 | TchDType::Bf16 | TchDType::F32 | TchDType::F64 => {
Arc::new(AnyBurnTensor::Float(FloatBurnTensor::empty(
input_shape,
&DType::Tch(tch.clone()),
device,
)))
}
TchDType::I8 | TchDType::I16 | TchDType::I32 | TchDType::I64 | TchDType::U8 => {
Arc::new(AnyBurnTensor::Int(IntBurnTensor::empty(
input_shape,
&DType::Tch(tch.clone()),
device,
)))
}
TchDType::Bool => Arc::new(AnyBurnTensor::Bool(BoolBurnTensor::empty(
input_shape,
&DType::Tch(tch.clone()),
device,
))),
},
};
let mask: Arc<AnyBurnTensor<B, D_OUT>> = match &module.metadata.output_dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(nd) => match nd {
NdArrayDType::F16 | NdArrayDType::F32 | NdArrayDType::F64 => {
Arc::new(AnyBurnTensor::Float(FloatBurnTensor::empty(
output_shape,
&DType::NdArray(nd.clone()),
device,
)))
}
NdArrayDType::I8 | NdArrayDType::I16 | NdArrayDType::I32 | NdArrayDType::I64 => {
Arc::new(AnyBurnTensor::Int(IntBurnTensor::empty(
output_shape,
&DType::NdArray(nd.clone()),
device,
)))
}
NdArrayDType::Bool => Arc::new(AnyBurnTensor::Bool(BoolBurnTensor::empty(
output_shape,
&DType::NdArray(nd.clone()),
device,
))),
},
#[cfg(feature = "tch-backend")]
DType::Tch(tch) => match tch {
TchDType::F16 | TchDType::Bf16 | TchDType::F32 | TchDType::F64 => {
Arc::new(AnyBurnTensor::Float(FloatBurnTensor::empty(
output_shape,
&DType::Tch(tch.clone()),
device,
)))
}
TchDType::I8 | TchDType::I16 | TchDType::I32 | TchDType::I64 | TchDType::U8 => {
Arc::new(AnyBurnTensor::Int(IntBurnTensor::empty(
output_shape,
&DType::Tch(tch.clone()),
device,
)))
}
TchDType::Bool => Arc::new(AnyBurnTensor::Bool(BoolBurnTensor::empty(
output_shape,
&DType::Tch(tch.clone()),
device,
))),
},
};
let (action_tensor, _, _) = module.step::<D_IN, D_OUT>(obs, Some(mask));
let action_dims: &Vec<usize> = &action_tensor.shape;
let output_dims: &Vec<usize> = &output_shape.dims;
for (a, o) in action_dims.iter().zip(output_dims.iter()) {
if *a != *o {
return Err(ModelError::InvalidOutputDimension(format!(
"Model output shape mismatch: expected {:?}, got {:?}",
output_dims, action_dims
)));
}
}
Ok(())
}
/// Serializes a [`ModelModule`] to raw bytes by saving it to a temporary
/// `.pt` file inside `dir` and reading the file back.
///
/// The temp file is removed automatically when the guard is dropped at the
/// end of this function.
///
/// # Panics
/// Panics if the temp file cannot be created, the model cannot be saved,
/// or the bytes cannot be read back.
pub fn serialize_model_module<B: Backend + BackendMatcher<Backend = B>>(
    model: &ModelModule<B>,
    dir: PathBuf,
) -> Vec<u8> {
    let scratch = tempfile::Builder::new()
        .prefix("_model")
        .suffix(".pt")
        .tempfile_in(dir)
        .expect("Failed to create temp file");
    ModelModule::<B>::save(model, scratch.path()).expect("Failed to save model");
    std::fs::read(scratch.path()).expect("Failed to read model bytes")
}
/// Reconstructs a [`ModelModule`] from serialized bytes by spilling them to
/// a named temp file and loading from that path.
///
/// `_device` is currently unused; it is kept for interface compatibility.
///
/// # Panics
/// Panics on temp-file I/O failures or if loading fails.
// NOTE(review): the load failure is `expect`ed rather than mapped into the
// `Result` — converting it to `Err` would need a suitable `ModelError`
// variant, which is not visible from here; verify and clean up if one exists.
pub fn deserialize_model_module<B: Backend + BackendMatcher<Backend = B>>(
    model_bytes: Vec<u8>,
    _device: DeviceType,
) -> Result<ModelModule<B>, ModelError> {
    let mut scratch = NamedTempFile::new().expect("Failed to create temp file");
    scratch
        .write_all(&model_bytes)
        .expect("Failed to write model bytes");
    scratch.flush().expect("Failed to flush temp file");
    let module =
        ModelModule::<B>::load_from_path(scratch.path()).expect("Failed to load model from bytes");
    Ok(module)
}
#[cfg(all(
    test,
    feature = "ndarray-backend",
    any(feature = "tch-model", feature = "onnx-model")
))]
mod unit_tests {
    use std::collections::HashMap;
    use std::marker::PhantomData;
    use std::sync::Arc;
    use burn_ndarray::NdArray;
    use super::{convert_generic_dict, validate_module};
    use crate::data::action::RelayRLData;
    use crate::data::tensor::{DType, DeviceType, NdArrayDType};
    use crate::model::{
        InferenceModel, Model, ModelError, ModelFileType, ModelMetadata, ModelModule,
    };
    /// Builds a minimal stub module whose input and output shapes are both
    /// `[2; rank]`, backed by an empty ONNX model that is never executed.
    fn stub_module(rank: usize) -> ModelModule<NdArray> {
        let shape = vec![2; rank];
        let model = Model {
            file_type: ModelFileType::Onnx,
            raw_bytes: Arc::<[u8]>::from(Vec::<u8>::new()),
            inference: InferenceModel::Unsupported,
            _phantom: PhantomData,
        };
        let metadata = ModelMetadata {
            model_file: "test.onnx".to_string(),
            model_type: ModelFileType::Onnx,
            input_dtype: DType::NdArray(NdArrayDType::F32),
            output_dtype: DType::NdArray(NdArrayDType::F32),
            input_shape: shape.clone(),
            output_shape: shape,
            default_device: Some(DeviceType::Cpu),
        };
        ModelModule { model, metadata }
    }
    #[test]
    fn convert_generic_dict_clones_auxiliary_data() {
        let mut source = HashMap::new();
        source.insert("reward".to_string(), RelayRLData::F32(1.25));
        let snapshot =
            convert_generic_dict(&source).expect("the helper should always return data");
        // Mutating the source after cloning must not affect the snapshot.
        source.insert("done".to_string(), RelayRLData::Bool(true));
        assert_eq!(snapshot.len(), 1);
        let reward = snapshot.get("reward");
        assert!(matches!(
            reward,
            Some(RelayRLData::F32(value)) if (*value - 1.25).abs() < f32::EPSILON
        ));
        assert!(!snapshot.contains_key("done"));
    }
    #[test]
    fn validate_module_accepts_rank_one() {
        assert!(validate_module::<NdArray>(&stub_module(1)).is_ok());
    }
    #[test]
    fn validate_module_accepts_rank_nine() {
        assert!(validate_module::<NdArray>(&stub_module(9)).is_ok());
    }
    #[test]
    fn validate_module_rejects_rank_zero() {
        let err = validate_module::<NdArray>(&stub_module(0))
            .expect_err("rank 0 should remain outside the supported range");
        let ModelError::UnsupportedRank(message) = err else {
            panic!("expected an UnsupportedRank error");
        };
        assert!(message.contains("input 0 output 0"));
    }
    #[test]
    fn validate_module_rejects_rank_ten() {
        let err = validate_module::<NdArray>(&stub_module(10))
            .expect_err("rank 10 should remain outside the supported range");
        let ModelError::UnsupportedRank(message) = err else {
            panic!("expected an UnsupportedRank error");
        };
        assert!(message.contains("input 10 output 10"));
    }
}