use std::collections::HashMap;
use std::convert::TryInto;
use std::path::PathBuf;
use std::sync::Arc;
#[cfg(feature = "ndarray-backend")]
use crate::data::tensor::NdArrayDType;
#[cfg(feature = "tch-backend")]
use crate::data::tensor::TchDType;
use crate::data::action::RelayRLData;
use crate::data::tensor::{
AnyBurnTensor, BackendMatcher, BoolBurnTensor, DType, DeviceType, FloatBurnTensor,
IntBurnTensor,
};
use burn_tensor::{Shape, backend::Backend};
use crate::model::{ModelError, ModelModule};
/// Returns an owned snapshot of an auxiliary data dictionary.
///
/// Generalized over the value type (any `Clone`) so the helper is not tied
/// to `RelayRLData`; existing callers are unaffected because the type is
/// inferred at the call site. The result is always `Some`; the `Option`
/// return is kept for interface compatibility with existing callers.
pub fn convert_generic_dict<V: Clone>(
    dict: &HashMap<String, V>,
) -> Option<HashMap<String, V>> {
    Some(dict.clone())
}
/// Checks that the module's declared input/output tensor ranks are in the
/// supported range (1..=9), then dispatches shape validation with the input
/// rank lifted into a const generic parameter.
///
/// # Errors
/// Returns [`ModelError::UnsupportedRank`] when either rank falls outside
/// 1..=9; otherwise propagates whatever the downstream validation returns.
pub fn validate_module<B: Backend + BackendMatcher<Backend = B> + 'static>(
    module: &ModelModule<B>,
) -> Result<(), ModelError> {
    let device = module.resolve_device();
    let in_dims = &module.metadata.input_shape;
    let out_dims = &module.metadata.output_shape;
    let rank_supported = |rank: usize| (1..=9).contains(&rank);
    if !(rank_supported(in_dims.len()) && rank_supported(out_dims.len())) {
        return Err(ModelError::UnsupportedRank(format!(
            "Unsupported ranks: input {} output {}",
            in_dims.len(),
            out_dims.len()
        )));
    }
    // Turn the runtime rank into a compile-time const generic.
    match in_dims.len() {
        1 => validate_with_input::<B, 1>(module, &device, in_dims, out_dims),
        2 => validate_with_input::<B, 2>(module, &device, in_dims, out_dims),
        3 => validate_with_input::<B, 3>(module, &device, in_dims, out_dims),
        4 => validate_with_input::<B, 4>(module, &device, in_dims, out_dims),
        5 => validate_with_input::<B, 5>(module, &device, in_dims, out_dims),
        6 => validate_with_input::<B, 6>(module, &device, in_dims, out_dims),
        7 => validate_with_input::<B, 7>(module, &device, in_dims, out_dims),
        8 => validate_with_input::<B, 8>(module, &device, in_dims, out_dims),
        9 => validate_with_input::<B, 9>(module, &device, in_dims, out_dims),
        // Excluded by the range check above.
        _ => unreachable!(),
    }
}
/// Second dispatch stage: fixes the output rank as a const generic and
/// forwards to `call_validate` with both ranks known at compile time.
///
/// Returns [`ModelError::UnsupportedRank`] for output ranks outside 1..=9
/// (can only trigger when called without `validate_module`'s prior check).
fn validate_with_input<B: Backend + BackendMatcher<Backend = B> + 'static, const D_IN: usize>(
    module: &ModelModule<B>,
    device: &<B as Backend>::Device,
    input_shape: &[usize],
    output_shape: &[usize],
) -> Result<(), ModelError> {
    let out_rank = output_shape.len();
    if !(1..=9).contains(&out_rank) {
        return Err(ModelError::UnsupportedRank(format!(
            "Unsupported ranks: input {} output {}",
            input_shape.len(),
            out_rank
        )));
    }
    match out_rank {
        1 => call_validate::<B, D_IN, 1>(module, device, input_shape, output_shape),
        2 => call_validate::<B, D_IN, 2>(module, device, input_shape, output_shape),
        3 => call_validate::<B, D_IN, 3>(module, device, input_shape, output_shape),
        4 => call_validate::<B, D_IN, 4>(module, device, input_shape, output_shape),
        5 => call_validate::<B, D_IN, 5>(module, device, input_shape, output_shape),
        6 => call_validate::<B, D_IN, 6>(module, device, input_shape, output_shape),
        7 => call_validate::<B, D_IN, 7>(module, device, input_shape, output_shape),
        8 => call_validate::<B, D_IN, 8>(module, device, input_shape, output_shape),
        9 => call_validate::<B, D_IN, 9>(module, device, input_shape, output_shape),
        // Excluded by the range guard above.
        _ => unreachable!(),
    }
}
/// Converts the runtime shape slices into fixed-size arrays (failing with
/// `InvalidMetadata` on a rank mismatch) and runs the actual shape
/// validation with both ranks as const generics.
fn call_validate<
    B: Backend + BackendMatcher<Backend = B> + 'static,
    const D_IN: usize,
    const D_OUT: usize,
>(
    module: &ModelModule<B>,
    device: &<B as Backend>::Device,
    input_shape: &[usize],
    output_shape: &[usize],
) -> Result<(), ModelError> {
    // Array lengths are inferred from the annotations; no turbofish needed.
    let in_dims: [usize; D_IN] = slice_to_array(input_shape)?;
    let out_dims: [usize; D_OUT] = slice_to_array(output_shape)?;
    validate_model_shapes::<B, D_IN, D_OUT>(
        module,
        device,
        &Shape::from(in_dims),
        &Shape::from(out_dims),
    )
}
/// Copies a `&[usize]` into a fixed-size `[usize; N]`, reporting a
/// metadata error when the slice length does not equal `N`.
fn slice_to_array<const N: usize>(shape: &[usize]) -> Result<[usize; N], ModelError> {
    match shape.try_into() {
        Ok(array) => Ok(array),
        Err(_) => Err(ModelError::InvalidMetadata(format!(
            "Expected dimension of length {N}, but got {}",
            shape.len()
        ))),
    }
}
fn validate_model_shapes<
B: Backend + BackendMatcher<Backend = B> + 'static,
const D_IN: usize,
const D_OUT: usize,
>(
module: &ModelModule<B>,
device: &<B as Backend>::Device,
input_shape: &Shape,
output_shape: &Shape,
) -> Result<(), ModelError> {
let obs: Arc<AnyBurnTensor<B, D_IN>> = match &module.metadata.input_dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(nd) => match nd {
NdArrayDType::F16 | NdArrayDType::F32 | NdArrayDType::F64 => {
Arc::new(AnyBurnTensor::Float(FloatBurnTensor::empty(
input_shape,
&DType::NdArray(nd.clone()),
device,
)))
}
NdArrayDType::I8 | NdArrayDType::I16 | NdArrayDType::I32 | NdArrayDType::I64 => {
Arc::new(AnyBurnTensor::Int(IntBurnTensor::empty(
input_shape,
&DType::NdArray(nd.clone()),
device,
)))
}
NdArrayDType::Bool => Arc::new(AnyBurnTensor::Bool(BoolBurnTensor::empty(
input_shape,
&DType::NdArray(nd.clone()),
device,
))),
},
#[cfg(feature = "tch-backend")]
DType::Tch(tch) => match tch {
TchDType::F16 | TchDType::Bf16 | TchDType::F32 | TchDType::F64 => {
Arc::new(AnyBurnTensor::Float(FloatBurnTensor::empty(
input_shape,
&DType::Tch(tch.clone()),
device,
)))
}
TchDType::I8 | TchDType::I16 | TchDType::I32 | TchDType::I64 | TchDType::U8 => {
Arc::new(AnyBurnTensor::Int(IntBurnTensor::empty(
input_shape,
&DType::Tch(tch.clone()),
device,
)))
}
TchDType::Bool => Arc::new(AnyBurnTensor::Bool(BoolBurnTensor::empty(
input_shape,
&DType::Tch(tch.clone()),
device,
))),
},
};
let mask: Arc<AnyBurnTensor<B, D_OUT>> = match &module.metadata.output_dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(nd) => match nd {
NdArrayDType::F16 | NdArrayDType::F32 | NdArrayDType::F64 => {
Arc::new(AnyBurnTensor::Float(FloatBurnTensor::empty(
output_shape,
&DType::NdArray(nd.clone()),
device,
)))
}
NdArrayDType::I8 | NdArrayDType::I16 | NdArrayDType::I32 | NdArrayDType::I64 => {
Arc::new(AnyBurnTensor::Int(IntBurnTensor::empty(
output_shape,
&DType::NdArray(nd.clone()),
device,
)))
}
NdArrayDType::Bool => Arc::new(AnyBurnTensor::Bool(BoolBurnTensor::empty(
output_shape,
&DType::NdArray(nd.clone()),
device,
))),
},
#[cfg(feature = "tch-backend")]
DType::Tch(tch) => match tch {
TchDType::F16 | TchDType::Bf16 | TchDType::F32 | TchDType::F64 => {
Arc::new(AnyBurnTensor::Float(FloatBurnTensor::empty(
output_shape,
&DType::Tch(tch.clone()),
device,
)))
}
TchDType::I8 | TchDType::I16 | TchDType::I32 | TchDType::I64 | TchDType::U8 => {
Arc::new(AnyBurnTensor::Int(IntBurnTensor::empty(
output_shape,
&DType::Tch(tch.clone()),
device,
)))
}
TchDType::Bool => Arc::new(AnyBurnTensor::Bool(BoolBurnTensor::empty(
output_shape,
&DType::Tch(tch.clone()),
device,
))),
},
};
let (action_tensor, _, _) = module.step::<D_IN, D_OUT>(obs, Some(mask));
let action_dims: &Vec<usize> = &action_tensor.shape;
let output_dims: &Vec<usize> = &output_shape.dims;
for (a, o) in action_dims.iter().zip(output_dims.iter()) {
if *a != *o {
return Err(ModelError::InvalidOutputDimension(format!(
"Model output shape mismatch: expected {:?}, got {:?}",
output_dims, action_dims
)));
}
}
Ok(())
}
pub fn serialize_model_module<B: Backend + BackendMatcher<Backend = B>>(
model: &ModelModule<B>,
dir: PathBuf,
) -> Vec<u8> {
let temp_dir = tempfile::Builder::new()
.prefix("_model_save_")
.tempdir_in(dir)
.expect("Failed to create temp dir");
let temp_path = temp_dir.path().to_path_buf();
ModelModule::<B>::save(model, &temp_path).expect("Failed to save model");
let meta_bytes =
std::fs::read(temp_path.join("metadata.json")).expect("Failed to read metadata.json");
let model_path = temp_path.join(&model.metadata.model_file);
let model_bytes = std::fs::read(&model_path).expect("Failed to read model file");
let meta_len = meta_bytes.len() as u32;
let mut bundle = Vec::with_capacity(4 + meta_bytes.len() + model_bytes.len());
bundle.extend_from_slice(&meta_len.to_le_bytes());
bundle.extend_from_slice(&meta_bytes);
bundle.extend_from_slice(&model_bytes);
bundle
}
/// Reconstructs a model module from a bundle produced by
/// `serialize_model_module`: `[meta_len: u32 LE][metadata.json][model bytes]`.
///
/// The bytes are staged into a fresh temporary directory so the existing
/// path-based loader can be reused; the directory is removed on return.
/// `_device` is currently unused — presumably reserved for device placement
/// on load; confirm before removing.
///
/// # Errors
/// Returns [`ModelError::InvalidMetadata`] when the bundle is too short or
/// truncated, the metadata is not valid JSON, the staging files cannot be
/// written, or loading the staged model fails.
pub fn deserialize_model_module<B: Backend + BackendMatcher<Backend = B>>(
    bundle: Vec<u8>,
    _device: DeviceType,
) -> Result<ModelModule<B>, ModelError> {
    if bundle.len() < 4 {
        return Err(ModelError::InvalidMetadata("bundle too short".into()));
    }
    // Length-4 slice into a [u8; 4] cannot fail; guarded by the check above.
    let meta_len = u32::from_le_bytes(bundle[..4].try_into().unwrap()) as usize;
    if bundle.len() < 4 + meta_len {
        return Err(ModelError::InvalidMetadata("bundle truncated".into()));
    }
    let meta_bytes = &bundle[4..4 + meta_len];
    let model_bytes = &bundle[4 + meta_len..];
    // Propagate staging I/O failures instead of panicking: this function
    // already reports recoverable problems through `ModelError`.
    let temp_dir = tempfile::TempDir::new()
        .map_err(|e| ModelError::InvalidMetadata(format!("Failed to create temp dir: {e}")))?;
    let temp_path = temp_dir.path();
    let meta_json: serde_json::Value = serde_json::from_slice(meta_bytes)
        .map_err(|e| ModelError::InvalidMetadata(e.to_string()))?;
    // Fall back to the conventional file name when metadata omits it.
    let model_file = meta_json
        .get("model_file")
        .and_then(|v| v.as_str())
        .unwrap_or("model.onnx");
    std::fs::write(temp_path.join("metadata.json"), meta_bytes)
        .map_err(|e| ModelError::InvalidMetadata(format!("Failed to write metadata.json: {e}")))?;
    std::fs::write(temp_path.join(model_file), model_bytes)
        .map_err(|e| ModelError::InvalidMetadata(format!("Failed to write model file: {e}")))?;
    let result = ModelModule::<B>::load_from_path(temp_path)
        .map_err(|e| ModelError::InvalidMetadata(format!("{e}")))?;
    // Keep `temp_dir` alive until after the load; dropping it deletes the files.
    drop(temp_dir);
    Ok(result)
}
#[cfg(all(
    test,
    feature = "ndarray-backend",
    any(feature = "tch-model", feature = "onnx-model")
))]
mod unit_tests {
    use std::collections::HashMap;
    use std::marker::PhantomData;
    use std::sync::Arc;
    use burn_ndarray::NdArray;
    use super::{convert_generic_dict, validate_module};
    use crate::data::action::RelayRLData;
    use crate::data::tensor::{DType, DeviceType, NdArrayDType};
    use crate::model::{
        InferenceModel, Model, ModelError, ModelFileType, ModelMetadata, ModelModule,
    };
    /// Builds a minimal ONNX-flavoured module stub whose input and output
    /// shapes are both `[2; rank]`; the model payload is empty and marked
    /// unsupported for inference.
    fn stub_module(rank: usize) -> ModelModule<NdArray> {
        let shape = vec![2; rank];
        let metadata = ModelMetadata {
            model_file: "test.onnx".to_string(),
            model_type: ModelFileType::Onnx,
            input_dtype: DType::NdArray(NdArrayDType::F32),
            output_dtype: DType::NdArray(NdArrayDType::F32),
            input_shape: shape.clone(),
            output_shape: shape,
            default_device: Some(DeviceType::Cpu),
        };
        let model = Model {
            file_type: ModelFileType::Onnx,
            raw_bytes: Arc::<[u8]>::from(Vec::<u8>::new()),
            inference: InferenceModel::Unsupported,
            _phantom: PhantomData,
        };
        ModelModule { model, metadata }
    }
    /// The returned snapshot must be independent of later mutations of
    /// the source map.
    #[test]
    fn convert_generic_dict_clones_auxiliary_data() {
        let mut source = HashMap::new();
        source.insert("reward".to_string(), RelayRLData::F32(1.25));
        let snapshot =
            convert_generic_dict(&source).expect("the helper should always return data");
        source.insert("done".to_string(), RelayRLData::Bool(true));
        assert!(!snapshot.contains_key("done"));
        assert_eq!(snapshot.len(), 1);
        assert!(matches!(
            snapshot.get("reward"),
            Some(RelayRLData::F32(value)) if (*value - 1.25).abs() < f32::EPSILON
        ));
    }
    // Lower boundary of the supported rank range.
    #[test]
    fn validate_module_accepts_rank_one() {
        assert!(validate_module::<NdArray>(&stub_module(1)).is_ok());
    }
    // Upper boundary of the supported rank range.
    #[test]
    fn validate_module_accepts_rank_nine() {
        assert!(validate_module::<NdArray>(&stub_module(9)).is_ok());
    }
    // Just below the supported range.
    #[test]
    fn validate_module_rejects_rank_zero() {
        let failure = validate_module::<NdArray>(&stub_module(0))
            .expect_err("rank 0 should remain outside the supported range");
        assert!(matches!(
            failure,
            ModelError::UnsupportedRank(message) if message.contains("input 0 output 0")
        ));
    }
    // Just above the supported range.
    #[test]
    fn validate_module_rejects_rank_ten() {
        let failure = validate_module::<NdArray>(&stub_module(10))
            .expect_err("rank 10 should remain outside the supported range");
        assert!(matches!(
            failure,
            ModelError::UnsupportedRank(message) if message.contains("input 10 output 10")
        ));
    }
}