pub mod hot_reloadable;
pub mod utils;
use std::collections::HashMap;
use std::fmt::Debug;
use std::fs;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::Mutex;
use burn_tensor::backend::Backend;
#[cfg(feature = "onnx-model")]
use ort::tensor::IntoTensorElementType;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::data::action::RelayRLData;
use crate::data::tensor::{
AnyBurnTensor, BackendMatcher, ConversionBurnTensor, DType, DeviceType, SupportedTensorBackend,
TensorData,
};
use half::f16;
#[cfg(feature = "tch-backend")]
use half::bf16;
#[cfg(feature = "ndarray-backend")]
use crate::data::tensor::NdArrayDType;
#[cfg(feature = "tch-backend")]
use crate::data::tensor::TchDType;
#[cfg(feature = "tch-model")]
use tch::{CModule, Tensor as TchTensor, no_grad};
#[cfg(feature = "onnx-model")]
use ort::{
session::{Session, SessionInputValue},
value::Value as OrtValue,
};
pub use burn_tensor::Shape;
pub use hot_reloadable::HotReloadableModel;
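/// Errors surfaced while loading, saving, converting, or running models.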
#[derive(Debug, Clone, Error)]
pub enum ModelError {
#[error("Serialization error: {0}")]
SerializationError(String),
#[error("Deserialization error: {0}")]
DeserializationError(String),
#[error("Backend error: {0}")]
BackendError(String),
#[error("DType error: {0}")]
DTypeError(String),
#[error("Invalid input dimension: {0}")]
InvalidInputDimension(String),
#[error("Invalid output dimension: {0}")]
InvalidOutputDimension(String),
#[error("Unsupported rank: {0}")]
UnsupportedRank(String),
#[error("Unsupported backend: {0}")]
UnsupportedBackend(String),
#[error("IO error: {0}")]
IoError(String),
#[error("JSON error: {0}")]
JsonError(String),
#[error("Unsupported model type: {0}")]
UnsupportedModelType(String),
#[error("Invalid metadata: {0}")]
InvalidMetadata(String),
}
impl From<std::io::Error> for ModelError {
fn from(e: std::io::Error) -> Self {
ModelError::IoError(e.to_string())
}
}
impl From<serde_json::Error> for ModelError {
fn from(e: serde_json::Error) -> Self {
ModelError::JsonError(e.to_string())
}
}
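/// On-disk model format, inferred from the file extension (`.pt` for
/// TorchScript, `.onnx` for ONNX).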
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ModelFileType {
Pt,
Onnx,
}
impl ModelFileType {
pub fn from_path(path: &Path) -> Result<Self, ModelError> {
        match path
            .extension()
            .and_then(|e| e.to_str())
            .map(|e| e.to_ascii_lowercase())
            .unwrap_or_default()
            .as_str()
        {
"pt" => Ok(ModelFileType::Pt),
"onnx" => Ok(ModelFileType::Onnx),
other => Err(ModelError::UnsupportedModelType(format!(
"Unsupported extension: {}",
other
))),
}
}
}
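/// Sidecar metadata persisted as `metadata.json` next to the model file.
///
/// A minimal sketch of the on-disk layout (shapes are illustrative;
/// `input_dtype`, `output_dtype`, and `default_device` are also present and
/// follow the serde representations of `DType` and `DeviceType`):
///
/// ```json
/// {
///   "model_file": "policy.onnx",
///   "model_type": "onnx",
///   "input_shape": [4],
///   "output_shape": [2]
/// }
/// ```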
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
pub model_file: String,
pub model_type: ModelFileType,
pub input_dtype: DType,
pub output_dtype: DType,
pub input_shape: Vec<usize>,
pub output_shape: Vec<usize>,
pub default_device: Option<DeviceType>,
}
impl ModelMetadata {
pub fn load_from_dir(dir: impl Into<PathBuf>) -> Result<Self, ModelError> {
let dir: PathBuf = dir.into();
let meta_path: PathBuf = dir.join("metadata.json");
        let s = fs::read_to_string(&meta_path)?;
let meta: ModelMetadata = serde_json::from_str(&s)?;
if meta.model_file.trim().is_empty() {
return Err(ModelError::InvalidMetadata(
"metadata.model_file is empty".to_string(),
));
}
if meta.input_shape.is_empty() || meta.output_shape.is_empty() {
return Err(ModelError::InvalidMetadata(
"metadata input_shape/output_shape cannot be empty".to_string(),
));
}
Ok(meta)
}
pub fn save_to_dir(&self, dir: impl Into<PathBuf>) -> Result<(), ModelError> {
let dir: PathBuf = dir.into();
fs::create_dir_all(&dir)?;
let meta_path: PathBuf = dir.join("metadata.json");
let s = serde_json::to_string_pretty(self)?;
fs::write(meta_path, s)?;
Ok(())
}
pub fn resolve_model_path(&self, dir: &Path) -> PathBuf {
dir.join(&self.model_file)
}
}
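/// Backend-specific inference handle; `Unsupported` is used when the
/// matching model feature is disabled at compile time.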
#[derive(Debug, Clone)]
pub enum InferenceModel {
#[cfg(feature = "tch-model")]
Pt(Arc<CModule>),
#[cfg(feature = "onnx-model")]
Onnx(Arc<Mutex<Session>>),
Unsupported,
}
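/// Raw model bytes plus the inference handle built for the enabled model
/// feature (`tch-model` or `onnx-model`).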
#[derive(Debug, Clone)]
pub struct Model<B: Backend + BackendMatcher<Backend = B>> {
pub file_type: ModelFileType,
raw_bytes: Arc<[u8]>,
inference: InferenceModel,
_phantom: PhantomData<B>,
}
impl<B: Backend + BackendMatcher<Backend = B>> Model<B> {
fn load_from_file(file_type: ModelFileType, path: &Path) -> Result<Self, ModelError> {
let raw_bytes: Arc<[u8]> = fs::read(path)?.into();
let inference: InferenceModel = Self::build_inference(file_type.clone(), path)?;
Ok(Self {
file_type,
raw_bytes,
inference,
_phantom: PhantomData,
})
}
fn build_inference(
file_type: ModelFileType,
path: &Path,
) -> Result<InferenceModel, ModelError> {
match file_type {
ModelFileType::Pt => {
#[cfg(feature = "tch-model")]
{
let module = CModule::load(path)
.map_err(|err| ModelError::BackendError(err.to_string()))?;
Ok(InferenceModel::Pt(Arc::new(module)))
}
#[cfg(not(feature = "tch-model"))]
{
Ok(InferenceModel::Unsupported)
}
}
ModelFileType::Onnx => {
#[cfg(feature = "onnx-model")]
{
let session = Arc::new(std::sync::Mutex::new(
Session::builder()
.map_err(|err| ModelError::BackendError(err.to_string()))?
.commit_from_file(path)
.map_err(|err| ModelError::BackendError(err.to_string()))?,
));
Ok(InferenceModel::Onnx(session))
}
#[cfg(not(feature = "onnx-model"))]
{
Ok(InferenceModel::Unsupported)
}
}
}
}
fn save_to_path(&self, path: &Path) -> Result<(), ModelError> {
if let Some(parent) = path.parent()
&& !parent.as_os_str().is_empty()
{
fs::create_dir_all(parent)?;
}
fs::write(path, self.raw_bytes.as_ref())?;
Ok(())
}
fn inference(&self) -> &InferenceModel {
&self.inference
}
pub fn from_onnx_bytes(bytes: Vec<u8>) -> Result<Self, ModelError> {
#[cfg(feature = "onnx-model")]
{
let session = Arc::new(std::sync::Mutex::new(
Session::builder()
.map_err(|e| ModelError::BackendError(e.to_string()))?
.commit_from_memory(&bytes)
.map_err(|e| ModelError::BackendError(e.to_string()))?,
));
let raw_bytes: Arc<[u8]> = bytes.into();
Ok(Self {
file_type: ModelFileType::Onnx,
raw_bytes,
inference: InferenceModel::Onnx(session),
_phantom: PhantomData,
})
}
#[cfg(not(feature = "onnx-model"))]
{
let raw_bytes: Arc<[u8]> = bytes.into();
Ok(Self {
file_type: ModelFileType::Onnx,
raw_bytes,
inference: InferenceModel::Unsupported,
_phantom: PhantomData,
})
}
}
}
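/// A loaded model paired with its [`ModelMetadata`], generic over the burn
/// backend used for observation tensors.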
#[derive(Clone)]
#[cfg(all(
any(feature = "tch-model", feature = "onnx-model"),
any(feature = "ndarray-backend", feature = "tch-backend")
))]
pub struct ModelModule<B: Backend + BackendMatcher<Backend = B>> {
pub model: Model<B>,
pub metadata: ModelMetadata,
}
#[cfg(all(
    any(feature = "tch-model", feature = "onnx-model"),
    any(feature = "ndarray-backend", feature = "tch-backend")
))]
impl<B: Backend + BackendMatcher<Backend = B>> ModelModule<B> {
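    /// Loads a model bundle given either the bundle directory, the path to
    /// its `metadata.json`, or the path to a model file sitting next to a
    /// `metadata.json`.
    ///
    /// A minimal usage sketch (assumes the `ndarray` backend and a
    /// hypothetical bundle at `./model_dir`):
    ///
    /// ```ignore
    /// use burn_ndarray::NdArray;
    ///
    /// let module = ModelModule::<NdArray>::load_from_path("./model_dir")?;
    /// module.save("./model_dir_backup")?;
    /// ```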
pub fn load_from_path(path: impl Into<PathBuf>) -> Result<Self, ModelError> {
let path: PathBuf = path.into();
let dir = if path.is_dir() {
path
} else if path
.file_name()
.and_then(|n| n.to_str())
.map(|n| n.eq_ignore_ascii_case("metadata.json"))
.unwrap_or(false)
{
path.parent().unwrap_or(Path::new(".")).to_path_buf()
} else {
let dir = path.parent().unwrap_or(Path::new(".")).to_path_buf();
let meta_path = dir.join("metadata.json");
if !meta_path.exists() {
return Err(ModelError::InvalidMetadata(format!(
"metadata.json not found at {}",
meta_path.display()
)));
}
dir
};
let metadata = ModelMetadata::load_from_dir(&dir)?;
let model_path = metadata.resolve_model_path(&dir);
let file_type = ModelFileType::from_path(&model_path)?;
let model = Model::<B>::load_from_file(file_type, &model_path)?;
Ok(Self { model, metadata })
}
pub fn save(&self, dir: impl Into<PathBuf>) -> Result<(), ModelError> {
let dir: PathBuf = dir.into();
self.metadata.save_to_dir(&dir)?;
let model_path = self.metadata.resolve_model_path(&dir);
self.model.save_to_path(&model_path)?;
Ok(())
}
pub fn from_onnx_bytes(bytes: Vec<u8>, metadata: ModelMetadata) -> Result<Self, ModelError> {
let model = Model::<B>::from_onnx_bytes(bytes)?;
Ok(Self { model, metadata })
}
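    /// Builds a module from raw TorchScript bytes by spilling them to a
    /// temporary file before handing the path to `CModule::load`; the
    /// original bytes are retained for later `save` calls.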
#[cfg(feature = "tch-model")]
pub fn from_pt_bytes(bytes: Vec<u8>, metadata: ModelMetadata) -> Result<Self, ModelError> {
use std::io::Write;
let mut temp_file = tempfile::NamedTempFile::new()
.map_err(|e| ModelError::BackendError(format!("Failed to create temp file: {}", e)))?;
temp_file
.write_all(&bytes)
.map_err(|e| ModelError::BackendError(format!("Failed to write temp file: {}", e)))?;
let temp_path = temp_file.path();
let module = CModule::load(temp_path)
.map_err(|e| ModelError::BackendError(format!("Failed to load CModule: {}", e)))?;
let raw_bytes: Arc<[u8]> = bytes.into();
let model = Self {
model: Model {
file_type: ModelFileType::Pt,
raw_bytes,
inference: InferenceModel::Pt(Arc::new(module)),
_phantom: PhantomData,
},
metadata,
};
Ok(model)
}
#[cfg(not(feature = "tch-model"))]
pub fn from_pt_bytes(bytes: Vec<u8>, metadata: ModelMetadata) -> Result<Self, ModelError> {
let raw_bytes: Arc<[u8]> = bytes.into();
Ok(Self {
model: Model {
file_type: ModelFileType::Pt,
raw_bytes,
inference: InferenceModel::Unsupported,
_phantom: PhantomData,
},
metadata,
})
}
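    /// Runs one inference step over a single observation.
    ///
    /// On inference failure the action falls back to zeros shaped by
    /// `metadata.output_shape`. An optional mask is converted to the output
    /// dtype and applied element-wise, assuming a binary (zero/non-zero)
    /// mask.
    ///
    /// A usage sketch (tensor construction elided; the unit tests below show
    /// a concrete `AnyBurnTensor` setup):
    ///
    /// ```ignore
    /// let (action, mask_td, aux) = module.step::<1, 1>(observation, Some(mask));
    /// ```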
#[cfg(all(
any(feature = "tch-model", feature = "onnx-model"),
any(feature = "ndarray-backend", feature = "tch-backend")
))]
pub fn step<const D_IN: usize, const D_OUT: usize>(
&self,
observation: Arc<AnyBurnTensor<B, D_IN>>,
mask: Option<Arc<AnyBurnTensor<B, D_OUT>>>,
) -> (TensorData, Option<TensorData>, HashMap<String, RelayRLData>) {
let base_action = self
.run_inference::<D_IN, D_OUT>(observation)
.unwrap_or_else(|_| {
self.zeros_action::<D_OUT>()
.expect("Failed to create zeros action")
});
        let mask_td: Option<TensorData> = mask.map(|mask_tensor| {
            self.mask_to_tensor_data(mask_tensor)
                .expect("Failed to convert mask tensor to TensorData")
        });
        let act_td: TensorData = match mask_td {
            Some(ref mask) => Self::apply_mask_to_action(base_action, mask),
            None => base_action,
        };
(act_td, mask_td, HashMap::new())
}
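    /// Runs one forward pass over a whole batch of observations.
    ///
    /// Observations are converted to the input dtype, stacked along a new
    /// leading batch axis, run through the model once, and the output is
    /// split back into one row per observation before per-row masks are
    /// applied. On inference failure the entire batch falls back to zero
    /// actions.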
#[cfg(all(
any(feature = "tch-model", feature = "onnx-model"),
any(feature = "ndarray-backend", feature = "tch-backend")
))]
pub fn step_batch<const D_IN: usize, const D_OUT: usize>(
&self,
observations: &[Arc<AnyBurnTensor<B, D_IN>>],
masks: &[Option<Arc<AnyBurnTensor<B, D_OUT>>>],
) -> Result<Vec<(TensorData, Option<TensorData>, HashMap<String, RelayRLData>)>, ModelError>
{
if observations.is_empty() {
return Ok(Vec::new());
}
if observations.len() != masks.len() {
return Err(ModelError::InvalidInputDimension(format!(
"batch observation/mask length mismatch: {} observations vs {} masks",
observations.len(),
masks.len()
)));
}
let observation_data: Vec<TensorData> = observations
.iter()
.map(|observation| self.observation_to_tensor_data(observation.clone()))
.collect::<Result<_, _>>()?;
let mask_data: Vec<Option<TensorData>> = masks
.iter()
.map(|mask| {
mask.as_ref()
.map(|mask_tensor| self.mask_to_tensor_data(mask_tensor.clone()))
.transpose()
})
.collect::<Result<_, _>>()?;
let batched_input = Self::stack_tensor_data(&observation_data)?;
let batched_output = self
.run_inference_tensor_data(batched_input)
.unwrap_or_else(|_| self.zeros_batch_action(observations.len()));
let split_actions = Self::split_tensor_data_rows(batched_output, observations.len())?;
Ok(split_actions
.into_iter()
.zip(mask_data)
.map(|(base_action, mask_td)| {
let act_td = match mask_td.as_ref() {
Some(mask) => Self::apply_mask_to_action(base_action, mask),
None => base_action,
};
(act_td, mask_td, HashMap::new())
})
.collect())
}
fn observation_to_tensor_data<const D_IN: usize>(
&self,
observation: Arc<AnyBurnTensor<B, D_IN>>,
) -> Result<TensorData, ModelError> {
match observation.as_ref() {
AnyBurnTensor::Float(wrapper) => TensorData::try_from(ConversionBurnTensor {
inner: wrapper.tensor.clone(),
conversion_dtype: self.metadata.input_dtype.clone(),
}),
AnyBurnTensor::Int(wrapper) => TensorData::try_from(ConversionBurnTensor {
inner: wrapper.tensor.clone(),
conversion_dtype: self.metadata.input_dtype.clone(),
}),
AnyBurnTensor::Bool(wrapper) => TensorData::try_from(ConversionBurnTensor {
inner: wrapper.tensor.clone(),
conversion_dtype: self.metadata.input_dtype.clone(),
}),
}
.map_err(|e| ModelError::BackendError(format!("Tensor conversion failed: {e}")))
}
fn mask_to_tensor_data<const D_OUT: usize>(
&self,
mask: Arc<AnyBurnTensor<B, D_OUT>>,
) -> Result<TensorData, ModelError> {
match mask.as_ref() {
AnyBurnTensor::Float(wrapper) => TensorData::try_from(ConversionBurnTensor {
inner: wrapper.tensor.clone(),
conversion_dtype: self.metadata.output_dtype.clone(),
}),
AnyBurnTensor::Int(wrapper) => TensorData::try_from(ConversionBurnTensor {
inner: wrapper.tensor.clone(),
conversion_dtype: self.metadata.output_dtype.clone(),
}),
AnyBurnTensor::Bool(wrapper) => TensorData::try_from(ConversionBurnTensor {
inner: wrapper.tensor.clone(),
conversion_dtype: self.metadata.output_dtype.clone(),
}),
}
.map_err(|e| ModelError::BackendError(format!("Mask conversion failed: {e}")))
}
    fn apply_mask_to_action(base_action: TensorData, mask: &TensorData) -> TensorData {
        // Gate whole elements instead of multiplying raw bytes: a byte-wise
        // product corrupts multi-byte encodings such as f32. An action
        // element is kept when the corresponding mask element (already
        // converted to the output dtype) contains any non-zero byte, and
        // zeroed otherwise; this assumes a binary mask.
        let elem_count: usize = base_action.shape.iter().product::<usize>().max(1);
        let elem_bytes: usize = (base_action.data.len() / elem_count).max(1);
        let mut data = base_action.data;
        for (index, chunk) in data.chunks_mut(elem_bytes).enumerate() {
            let keep = mask
                .data
                .get(index * elem_bytes..(index + 1) * elem_bytes)
                .is_some_and(|mask_elem| mask_elem.iter().any(|&byte| byte != 0));
            if !keep {
                chunk.fill(0);
            }
        }
        TensorData {
            shape: base_action.shape,
            dtype: base_action.dtype,
            data,
            supported_backend: base_action.supported_backend,
        }
    }
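    /// Stacks same-shaped rows into a single tensor with a new leading batch
    /// axis; every row must share shape, dtype, and backend.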
fn stack_tensor_data(rows: &[TensorData]) -> Result<TensorData, ModelError> {
let first = rows.first().ok_or_else(|| {
ModelError::InvalidInputDimension("cannot stack an empty tensor batch".to_string())
})?;
for row in rows.iter().skip(1) {
if row.dtype != first.dtype || row.supported_backend != first.supported_backend {
return Err(ModelError::DTypeError(
"all batched observations must share dtype/backend".to_string(),
));
}
if row.shape != first.shape {
return Err(ModelError::InvalidInputDimension(format!(
"all batched observations must share shape: expected {:?}, got {:?}",
first.shape, row.shape
)));
}
}
let mut shape = Vec::with_capacity(first.shape.len() + 1);
shape.push(rows.len());
shape.extend(first.shape.iter().copied());
let mut data = Vec::with_capacity(first.data.len() * rows.len());
for row in rows {
data.extend_from_slice(&row.data);
}
Ok(TensorData::new(
shape,
first.dtype.clone(),
data,
first.supported_backend.clone(),
))
}
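    /// Splits a batched tensor back into `rows` equally sized row tensors,
    /// dropping the leading batch dimension from each row's shape.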
fn split_tensor_data_rows(
batch: TensorData,
rows: usize,
) -> Result<Vec<TensorData>, ModelError> {
if rows == 0 {
return Ok(Vec::new());
}
let Some((&batch_rows, row_shape)) = batch.shape.split_first() else {
return Err(ModelError::InvalidOutputDimension(
"batched action tensor must have at least one dimension".to_string(),
));
};
if batch_rows != rows {
return Err(ModelError::InvalidOutputDimension(format!(
"batched action count mismatch: tensor has {} rows, expected {}",
batch_rows, rows
)));
}
let row_bytes = batch.data.len() / rows;
if row_bytes * rows != batch.data.len() {
return Err(ModelError::InvalidOutputDimension(
"batched action data length is not divisible by row count".to_string(),
));
}
let row_shape = row_shape.to_vec();
let mut result = Vec::with_capacity(rows);
for index in 0..rows {
let start = index * row_bytes;
let end = start + row_bytes;
result.push(TensorData::new(
row_shape.clone(),
batch.dtype.clone(),
batch.data[start..end].to_vec(),
batch.supported_backend.clone(),
));
}
Ok(result)
}
fn resolve_device(&self) -> <B as Backend>::Device {
let preferred = self.metadata.default_device.clone().unwrap_or_default();
<B as BackendMatcher>::get_device(&preferred)
.or_else(|_| <B as BackendMatcher>::get_device(&DeviceType::default()))
.expect("Failed to resolve backend device")
}
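    /// Zero-filled fallback for a whole batch: one zero row per observation,
    /// with a leading batch dimension prepended to `metadata.output_shape`.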
fn zeros_batch_action(&self, rows: usize) -> TensorData {
let mut shape = Vec::with_capacity(self.metadata.output_shape.len() + 1);
shape.push(rows);
shape.extend(self.metadata.output_shape.iter().copied());
let row_zero = self
.zeros_action::<1>()
.expect("Failed to create zeros action for batch fallback");
let mut data = Vec::with_capacity(row_zero.data.len() * rows);
for _ in 0..rows {
data.extend_from_slice(&row_zero.data);
}
TensorData::new(shape, row_zero.dtype, data, row_zero.supported_backend)
}
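    /// Builds a zero-filled action matching `metadata.output_shape` and
    /// `metadata.output_dtype`; used as the fallback when inference fails.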
    fn zeros_action<const D_OUT: usize>(&self) -> Result<TensorData, ModelError> {
        // Helper: a zero-initialized buffer of `count` elements of `T`,
        // re-encoded as raw bytes.
        fn zero_bytes<T: bytemuck::Pod + Default>(count: usize) -> Vec<u8> {
            bytemuck::cast_slice::<T, u8>(&vec![T::default(); count]).to_vec()
        }
        let shape = Shape::from(self.metadata.output_shape.clone());
        let count: usize = shape.dims.iter().product();
        let (bytes, backend) = match &self.metadata.output_dtype {
            #[cfg(feature = "ndarray-backend")]
            DType::NdArray(dtype) => (
                match dtype {
                    NdArrayDType::F16 => zero_bytes::<f16>(count),
                    NdArrayDType::F32 => zero_bytes::<f32>(count),
                    NdArrayDType::F64 => zero_bytes::<f64>(count),
                    NdArrayDType::I8 => zero_bytes::<i8>(count),
                    NdArrayDType::I16 => zero_bytes::<i16>(count),
                    NdArrayDType::I32 => zero_bytes::<i32>(count),
                    NdArrayDType::I64 => zero_bytes::<i64>(count),
                    // `false` is the all-zero byte pattern, so a plain u8
                    // buffer suffices.
                    NdArrayDType::Bool => vec![0u8; count],
                },
                SupportedTensorBackend::NdArray,
            ),
            #[cfg(feature = "tch-backend")]
            DType::Tch(dtype) => (
                match dtype {
                    TchDType::F16 => zero_bytes::<f16>(count),
                    TchDType::Bf16 => zero_bytes::<bf16>(count),
                    TchDType::F32 => zero_bytes::<f32>(count),
                    TchDType::F64 => zero_bytes::<f64>(count),
                    TchDType::I8 => zero_bytes::<i8>(count),
                    TchDType::I16 => zero_bytes::<i16>(count),
                    TchDType::I32 => zero_bytes::<i32>(count),
                    TchDType::I64 => zero_bytes::<i64>(count),
                    TchDType::U8 => vec![0u8; count],
                    TchDType::Bool => vec![0u8; count],
                },
                SupportedTensorBackend::Tch,
            ),
        };
        Ok(TensorData::new(
            shape.dims.to_vec(),
            self.metadata.output_dtype.clone(),
            bytes,
            backend,
        ))
    }
fn run_inference<const D_IN: usize, const D_OUT: usize>(
&self,
observation: Arc<AnyBurnTensor<B, D_IN>>,
) -> Result<TensorData, ModelError> {
match self.model.inference() {
#[cfg(feature = "tch-model")]
InferenceModel::Pt(module) => {
self.run_libtorch_step::<D_IN, D_OUT>(module, observation)
}
#[cfg(feature = "onnx-model")]
InferenceModel::Onnx(session) => {
self.run_onnx_step::<D_IN, D_OUT>(session, observation)
}
_ => Err(ModelError::UnsupportedModelType(
"Unsupported model type".to_string(),
)),
}
}
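    /// Runs inference directly over pre-batched `TensorData`, bypassing the
    /// burn tensor conversion path.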
pub fn flat_batch_inference(&self, input_data: TensorData) -> Result<TensorData, ModelError> {
self.run_inference_tensor_data(input_data)
}
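    /// Zero-filled fallback shaped like a `rows`-sized batch of output
    /// actions.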
pub fn flat_batch_zeros(&self, rows: usize) -> TensorData {
self.zeros_batch_action(rows)
}
fn run_inference_tensor_data(&self, input_data: TensorData) -> Result<TensorData, ModelError> {
match self.model.inference() {
#[cfg(feature = "tch-model")]
InferenceModel::Pt(module) => self.run_libtorch_step_data(module, input_data),
#[cfg(feature = "onnx-model")]
InferenceModel::Onnx(session) => self.run_onnx_step_data(session, input_data),
_ => Err(ModelError::UnsupportedModelType(
"Unsupported model type".to_string(),
)),
}
}
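    /// TorchScript inference over raw `TensorData`: rebuilds a `tch` tensor
    /// from the byte buffer, runs a no-grad forward pass, and re-encodes the
    /// flattened output in the configured output dtype.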
#[cfg(all(
feature = "tch-model",
any(feature = "ndarray-backend", feature = "tch-backend")
))]
fn run_libtorch_step_data(
&self,
module: &Arc<CModule>,
input_data: TensorData,
) -> Result<TensorData, ModelError> {
let obs_shape_i64: Vec<i64> = input_data.shape.iter().map(|&d| d as i64).collect();
let obs_tensor: TchTensor = match &input_data.dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(nd) => match nd {
NdArrayDType::F16 => {
TchTensor::from_slice::<f16>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::F32 => {
TchTensor::from_slice::<f32>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::F64 => {
TchTensor::from_slice::<f64>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::I8 => {
TchTensor::from_slice::<i8>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::I16 => {
TchTensor::from_slice::<i16>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::I32 => {
TchTensor::from_slice::<i32>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::I64 => {
TchTensor::from_slice::<i64>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::Bool => {
TchTensor::from_slice::<u8>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
},
#[cfg(feature = "tch-backend")]
DType::Tch(tch) => match tch {
TchDType::F16 => {
TchTensor::from_slice::<f16>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::Bf16 => {
TchTensor::from_slice::<bf16>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::F32 => {
TchTensor::from_slice::<f32>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::F64 => {
TchTensor::from_slice::<f64>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::I8 => TchTensor::from_slice::<i8>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice()),
TchDType::I16 => {
TchTensor::from_slice::<i16>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::I32 => {
TchTensor::from_slice::<i32>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::I64 => {
TchTensor::from_slice::<i64>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::U8 => TchTensor::from_slice::<u8>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice()),
TchDType::Bool => {
TchTensor::from_slice::<u8>(bytemuck::cast_slice(&input_data.data))
.reshape(obs_shape_i64.as_slice())
}
},
};
        let act_tensor: TchTensor = no_grad(|| module.forward_ts(&[&obs_tensor]))
            .map_err(|e| ModelError::BackendError(format!("Failed to run forward pass: {e}")))?;
let output_shape: Vec<usize> = act_tensor
.size()
.into_iter()
.map(|dim| dim as usize)
.collect();
let flattened_act: TchTensor = act_tensor.flatten(0, -1);
let act_bytes: Vec<u8> = match &self.metadata.output_dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(dtype) => match dtype {
NdArrayDType::F16 => bytemuck::cast_slice(
&Vec::<f16>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f16"),
)
.to_vec(),
NdArrayDType::F32 => bytemuck::cast_slice(
&Vec::<f32>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f32"),
)
.to_vec(),
NdArrayDType::F64 => bytemuck::cast_slice(
&Vec::<f64>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f64"),
)
.to_vec(),
NdArrayDType::I8 => bytemuck::cast_slice(
&Vec::<i8>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i8"),
)
.to_vec(),
NdArrayDType::I16 => bytemuck::cast_slice(
&Vec::<i16>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i16"),
)
.to_vec(),
NdArrayDType::I32 => bytemuck::cast_slice(
&Vec::<i32>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i32"),
)
.to_vec(),
NdArrayDType::I64 => bytemuck::cast_slice(
&Vec::<i64>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i64"),
)
.to_vec(),
NdArrayDType::Bool => Vec::<bool>::try_from(flattened_act)
.expect("Failed to convert flattened_act to bool")
.into_iter()
.map(|b| if b { 1u8 } else { 0u8 })
.collect(),
},
#[cfg(feature = "tch-backend")]
DType::Tch(dtype) => match dtype {
TchDType::F16 => bytemuck::cast_slice(
&Vec::<f16>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f16"),
)
.to_vec(),
TchDType::Bf16 => bytemuck::cast_slice(
&Vec::<bf16>::try_from(flattened_act)
.expect("Failed to convert flattened_act to bf16"),
)
.to_vec(),
TchDType::F32 => bytemuck::cast_slice(
&Vec::<f32>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f32"),
)
.to_vec(),
TchDType::F64 => bytemuck::cast_slice(
&Vec::<f64>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f64"),
)
.to_vec(),
TchDType::I8 => bytemuck::cast_slice(
&Vec::<i8>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i8"),
)
.to_vec(),
TchDType::I16 => bytemuck::cast_slice(
&Vec::<i16>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i16"),
)
.to_vec(),
TchDType::I32 => bytemuck::cast_slice(
&Vec::<i32>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i32"),
)
.to_vec(),
TchDType::I64 => bytemuck::cast_slice(
&Vec::<i64>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i64"),
)
.to_vec(),
TchDType::U8 => bytemuck::cast_slice(
&Vec::<u8>::try_from(flattened_act)
.expect("Failed to convert flattened_act to u8"),
)
.to_vec(),
TchDType::Bool => Vec::<bool>::try_from(flattened_act)
.expect("Failed to convert flattened_act to bool")
.into_iter()
.map(|b| if b { 1u8 } else { 0u8 })
.collect(),
},
};
Ok(TensorData::new(
output_shape,
self.metadata.output_dtype.clone(),
act_bytes,
TensorData::get_backend_from_dtype(&self.metadata.output_dtype),
))
}
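    /// ONNX inference over raw `TensorData`. The session is fed a single
    /// input named `"input"` and the first output is taken, so graphs with
    /// differently named inputs are not currently supported; the output is
    /// reshaped to `rows` copies of `metadata.output_shape`.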
#[cfg(all(
feature = "onnx-model",
any(feature = "ndarray-backend", feature = "tch-backend")
))]
fn run_onnx_step_data(
&self,
session: &Arc<std::sync::Mutex<Session>>,
input_data: TensorData,
) -> Result<TensorData, ModelError> {
fn convert_obs_to_act<IN, OUT>(
tensor_data: TensorData,
session_: &Arc<std::sync::Mutex<Session>>,
) -> Result<Vec<u8>, ModelError>
where
IN: IntoTensorElementType
+ ort::tensor::PrimitiveTensorElementType
+ Debug
+ Clone
+ bytemuck::Pod,
OUT: IntoTensorElementType
+ ort::tensor::PrimitiveTensorElementType
+ Debug
+ Clone
+ bytemuck::Pod,
{
let typed_data: &[IN] = bytemuck::cast_slice(&tensor_data.data);
let data_vec: Vec<IN> = typed_data.to_vec();
let shape = ort::tensor::Shape::from(tensor_data.shape.as_slice());
let ort_value = OrtValue::from_array((shape, data_vec)).map_err(|e| {
ModelError::BackendError(format!("Failed to create OrtValue: {}", e))
})?;
let input = SessionInputValue::from(ort_value);
let mut inputs_map = HashMap::new();
inputs_map.insert("input".to_string(), input);
let mut session_guard = session_
.lock()
.map_err(|e| ModelError::BackendError(format!("Failed to lock session: {}", e)))?;
let output_value = session_guard
.run(inputs_map)
.map_err(|e| ModelError::BackendError(format!("Failed to run session: {}", e)))?;
let first = output_value.into_iter().next().ok_or_else(|| {
ModelError::BackendError("No output from ONNX session".to_string())
})?;
let (_, value) = first;
let (_, owned_slice) = value.try_extract_tensor::<OUT>().map_err(|e| {
ModelError::BackendError(format!("Failed to extract tensor from output: {:?}", e))
})?;
Ok(bytemuck::cast_slice(&owned_slice.to_vec()).to_vec())
}
fn match_obs_to_act<IN>(
input_data: TensorData,
output_dtype: DType,
session_: &Arc<std::sync::Mutex<Session>>,
) -> Result<Vec<u8>, ModelError>
where
IN: IntoTensorElementType
+ ort::tensor::PrimitiveTensorElementType
+ Debug
+ Clone
+ bytemuck::Pod,
{
match &output_dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(nd) => match nd {
                    NdArrayDType::F16 => {
                        // Extract the ONNX output as f32, then narrow so the
                        // byte buffer matches the declared f16 output dtype.
                        let f32_bytes = convert_obs_to_act::<IN, f32>(input_data, session_)?;
                        let halves: Vec<f16> = bytemuck::cast_slice::<u8, f32>(&f32_bytes)
                            .iter()
                            .map(|&v| f16::from_f32(v))
                            .collect();
                        Ok(bytemuck::cast_slice::<f16, u8>(&halves).to_vec())
                    }
NdArrayDType::F32 => convert_obs_to_act::<IN, f32>(input_data, session_),
NdArrayDType::F64 => convert_obs_to_act::<IN, f64>(input_data, session_),
NdArrayDType::I8 => convert_obs_to_act::<IN, i8>(input_data, session_),
NdArrayDType::I16 => convert_obs_to_act::<IN, i16>(input_data, session_),
NdArrayDType::I32 => convert_obs_to_act::<IN, i32>(input_data, session_),
NdArrayDType::I64 => convert_obs_to_act::<IN, i64>(input_data, session_),
NdArrayDType::Bool => convert_obs_to_act::<IN, u8>(input_data, session_),
},
#[cfg(feature = "tch-backend")]
DType::Tch(tch) => match tch {
                    TchDType::F16 => {
                        // Narrow the f32 output so the bytes match the
                        // declared f16 output dtype.
                        let f32_bytes = convert_obs_to_act::<IN, f32>(input_data, session_)?;
                        let halves: Vec<f16> = bytemuck::cast_slice::<u8, f32>(&f32_bytes)
                            .iter()
                            .map(|&v| f16::from_f32(v))
                            .collect();
                        Ok(bytemuck::cast_slice::<f16, u8>(&halves).to_vec())
                    }
                    TchDType::Bf16 => {
                        let f32_bytes = convert_obs_to_act::<IN, f32>(input_data, session_)?;
                        let halves: Vec<bf16> = bytemuck::cast_slice::<u8, f32>(&f32_bytes)
                            .iter()
                            .map(|&v| bf16::from_f32(v))
                            .collect();
                        Ok(bytemuck::cast_slice::<bf16, u8>(&halves).to_vec())
                    }
TchDType::F32 => convert_obs_to_act::<IN, f32>(input_data, session_),
TchDType::F64 => convert_obs_to_act::<IN, f64>(input_data, session_),
TchDType::I8 => convert_obs_to_act::<IN, i8>(input_data, session_),
TchDType::I16 => convert_obs_to_act::<IN, i16>(input_data, session_),
TchDType::I32 => convert_obs_to_act::<IN, i32>(input_data, session_),
TchDType::I64 => convert_obs_to_act::<IN, i64>(input_data, session_),
TchDType::U8 => convert_obs_to_act::<IN, u8>(input_data, session_),
TchDType::Bool => convert_obs_to_act::<IN, u8>(input_data, session_),
},
}
}
let rows = input_data.shape.first().copied().unwrap_or(1);
let act_bytes = match &input_data.dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(nd) => match nd {
                NdArrayDType::F16 => {
                    // f16 bytes cannot be reinterpreted as f32 element-wise;
                    // widen the batch to f32 before calling into this path.
                    return Err(ModelError::DTypeError(
                        "batched ONNX inference over f16 inputs is not supported; widen to f32 first"
                            .to_string(),
                    ));
                }
NdArrayDType::F32 => match_obs_to_act::<f32>(
input_data,
self.metadata.output_dtype.clone(),
session,
)?,
NdArrayDType::F64 => match_obs_to_act::<f64>(
input_data,
self.metadata.output_dtype.clone(),
session,
)?,
NdArrayDType::I8 => {
match_obs_to_act::<i8>(input_data, self.metadata.output_dtype.clone(), session)?
}
NdArrayDType::I16 => match_obs_to_act::<i16>(
input_data,
self.metadata.output_dtype.clone(),
session,
)?,
NdArrayDType::I32 => match_obs_to_act::<i32>(
input_data,
self.metadata.output_dtype.clone(),
session,
)?,
NdArrayDType::I64 => match_obs_to_act::<i64>(
input_data,
self.metadata.output_dtype.clone(),
session,
)?,
NdArrayDType::Bool => {
match_obs_to_act::<u8>(input_data, self.metadata.output_dtype.clone(), session)?
}
},
#[cfg(feature = "tch-backend")]
DType::Tch(tch) => match tch {
                TchDType::F16 => {
                    return Err(ModelError::DTypeError(
                        "batched ONNX inference over f16 inputs is not supported; widen to f32 first"
                            .to_string(),
                    ));
                }
                TchDType::Bf16 => {
                    return Err(ModelError::DTypeError(
                        "batched ONNX inference over bf16 inputs is not supported; widen to f32 first"
                            .to_string(),
                    ));
                }
TchDType::F32 => match_obs_to_act::<f32>(
input_data,
self.metadata.output_dtype.clone(),
session,
)?,
TchDType::F64 => match_obs_to_act::<f64>(
input_data,
self.metadata.output_dtype.clone(),
session,
)?,
TchDType::I8 => {
match_obs_to_act::<i8>(input_data, self.metadata.output_dtype.clone(), session)?
}
TchDType::I16 => match_obs_to_act::<i16>(
input_data,
self.metadata.output_dtype.clone(),
session,
)?,
TchDType::I32 => match_obs_to_act::<i32>(
input_data,
self.metadata.output_dtype.clone(),
session,
)?,
TchDType::I64 => match_obs_to_act::<i64>(
input_data,
self.metadata.output_dtype.clone(),
session,
)?,
TchDType::U8 => {
match_obs_to_act::<u8>(input_data, self.metadata.output_dtype.clone(), session)?
}
TchDType::Bool => {
match_obs_to_act::<u8>(input_data, self.metadata.output_dtype.clone(), session)?
}
},
};
let mut output_shape = Vec::with_capacity(self.metadata.output_shape.len() + 1);
output_shape.push(rows);
output_shape.extend(self.metadata.output_shape.iter().copied());
Ok(TensorData::new(
output_shape,
self.metadata.output_dtype.clone(),
act_bytes,
TensorData::get_backend_from_dtype(&self.metadata.output_dtype),
))
}
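    /// TorchScript inference over a single burn observation: converts it to
    /// the configured input dtype, rebuilds a `tch` tensor, and re-encodes
    /// the flattened forward output as `TensorData`.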
#[cfg(all(
feature = "tch-model",
any(feature = "ndarray-backend", feature = "tch-backend")
))]
fn run_libtorch_step<const D_IN: usize, const D_OUT: usize>(
&self,
module: &Arc<CModule>,
observation: Arc<AnyBurnTensor<B, D_IN>>,
) -> Result<TensorData, ModelError> {
let obs_tensor: TchTensor = match &self.metadata.input_dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(nd) => match nd {
NdArrayDType::F16 => {
let obs_tensor_data = observation.clone().into_f16_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f16: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<f16>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::F32 => {
let obs_tensor_data = observation.clone().into_f32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f32: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<f32>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::F64 => {
let obs_tensor_data = observation.clone().into_f64_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f64: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<f64>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::I8 => {
let obs_tensor_data = observation.clone().into_i8_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i8: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<i8>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::I16 => {
let obs_tensor_data = observation.clone().into_i16_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i16: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<i16>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::I32 => {
let obs_tensor_data = observation.clone().into_i32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i32: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<i32>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::I64 => {
let obs_tensor_data = observation.clone().into_i64_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i64: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<i64>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
NdArrayDType::Bool => {
let obs_tensor_data = observation.clone().into_bool_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to bool: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<u8>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
},
#[cfg(feature = "tch-backend")]
DType::Tch(tch) => match tch {
TchDType::F16 => {
let obs_tensor_data = observation.clone().into_f16_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f16: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<f16>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::Bf16 => {
let obs_tensor_data = observation.clone().into_bf16_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to bf16: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<bf16>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::F32 => {
let obs_tensor_data = observation.clone().into_f32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f32: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<f32>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::F64 => {
let obs_tensor_data = observation.clone().into_f64_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f64: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<f64>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::I8 => {
let obs_tensor_data = observation.clone().into_i8_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i8: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<i8>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::I16 => {
let obs_tensor_data = observation.clone().into_i16_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i16: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<i16>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::I32 => {
let obs_tensor_data = observation.clone().into_i32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i32: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<i32>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::I64 => {
let obs_tensor_data = observation.clone().into_i64_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i64: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<i64>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::U8 => {
let obs_tensor_data = observation.clone().into_u8_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to u8: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<u8>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
TchDType::Bool => {
let obs_tensor_data = observation.clone().into_bool_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to bool: {}",
e
))
})?;
let obs_shape_i64: Vec<i64> =
obs_tensor_data.shape.iter().map(|&d| d as i64).collect();
TchTensor::from_slice::<u8>(bytemuck::cast_slice(&obs_tensor_data.data))
.reshape(obs_shape_i64.as_slice())
}
},
};
        let act_tensor: TchTensor = no_grad(|| module.forward_ts(&[&obs_tensor]))
            .map_err(|e| ModelError::BackendError(format!("Failed to run forward pass: {e}")))?;
let flattened_act: TchTensor = act_tensor.flatten(0, -1);
let act_bytes: Vec<u8> = match &self.metadata.output_dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(dtype) => match dtype {
NdArrayDType::F16 => {
let vec: Vec<f16> = Vec::<f16>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f16");
bytemuck::cast_slice(&vec).to_vec()
}
NdArrayDType::F32 => {
let vec: Vec<f32> = Vec::<f32>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f32");
bytemuck::cast_slice(&vec).to_vec()
}
NdArrayDType::F64 => {
let vec: Vec<f64> = Vec::<f64>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f64");
bytemuck::cast_slice(&vec).to_vec()
}
NdArrayDType::I8 => {
let vec: Vec<i8> = Vec::<i8>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i8");
bytemuck::cast_slice(&vec).to_vec()
}
NdArrayDType::I16 => {
let vec: Vec<i16> = Vec::<i16>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i16");
bytemuck::cast_slice(&vec).to_vec()
}
NdArrayDType::I32 => {
let vec: Vec<i32> = Vec::<i32>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i32");
bytemuck::cast_slice(&vec).to_vec()
}
NdArrayDType::I64 => {
let vec: Vec<i64> = Vec::<i64>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i64");
bytemuck::cast_slice(&vec).to_vec()
}
NdArrayDType::Bool => {
let vec: Vec<bool> = Vec::<bool>::try_from(flattened_act)
.expect("Failed to convert flattened_act to bool");
vec.into_iter().map(|b| if b { 1u8 } else { 0u8 }).collect()
}
},
#[cfg(feature = "tch-backend")]
DType::Tch(dtype) => match dtype {
TchDType::F16 => {
let vec: Vec<f16> = Vec::<f16>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f16");
bytemuck::cast_slice(&vec).to_vec()
}
TchDType::Bf16 => {
let vec: Vec<bf16> = Vec::<bf16>::try_from(flattened_act)
.expect("Failed to convert flattened_act to bf16");
bytemuck::cast_slice(&vec).to_vec()
}
TchDType::F32 => {
let vec: Vec<f32> = Vec::<f32>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f32");
bytemuck::cast_slice(&vec).to_vec()
}
TchDType::F64 => {
let vec: Vec<f64> = Vec::<f64>::try_from(flattened_act)
.expect("Failed to convert flattened_act to f64");
bytemuck::cast_slice(&vec).to_vec()
}
TchDType::I8 => {
let vec: Vec<i8> = Vec::<i8>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i8");
bytemuck::cast_slice(&vec).to_vec()
}
TchDType::I16 => {
let vec: Vec<i16> = Vec::<i16>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i16");
bytemuck::cast_slice(&vec).to_vec()
}
TchDType::I32 => {
let vec: Vec<i32> = Vec::<i32>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i32");
bytemuck::cast_slice(&vec).to_vec()
}
TchDType::I64 => {
let vec: Vec<i64> = Vec::<i64>::try_from(flattened_act)
.expect("Failed to convert flattened_act to i64");
bytemuck::cast_slice(&vec).to_vec()
}
TchDType::U8 => {
let vec: Vec<u8> = Vec::<u8>::try_from(flattened_act)
.expect("Failed to convert flattened_act to u8");
bytemuck::cast_slice(&vec).to_vec()
}
TchDType::Bool => {
let vec: Vec<bool> = Vec::<bool>::try_from(flattened_act)
.expect("Failed to convert flattened_act to bool");
vec.into_iter().map(|b| if b { 1u8 } else { 0u8 }).collect()
}
},
};
Ok(TensorData::new(
self.metadata.output_shape.clone(),
self.metadata.output_dtype.clone(),
act_bytes,
TensorData::get_backend_from_dtype(&self.metadata.output_dtype),
))
}
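    /// ONNX inference over a single burn observation; like the batched path,
    /// it assumes a single graph input named `"input"` and takes the first
    /// output. Half-precision observations are widened to f32 before being
    /// handed to the session.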
#[cfg(all(
feature = "onnx-model",
any(feature = "ndarray-backend", feature = "tch-backend")
))]
fn run_onnx_step<const D_IN: usize, const D_OUT: usize>(
&self,
session: &Arc<std::sync::Mutex<Session>>,
observation: Arc<AnyBurnTensor<B, D_IN>>,
) -> Result<TensorData, ModelError> {
fn convert_obs_to_act<IN, OUT>(
tensor_data: TensorData,
session_: &Arc<std::sync::Mutex<Session>>,
) -> Result<Vec<u8>, ModelError>
where
IN: IntoTensorElementType
+ ort::tensor::PrimitiveTensorElementType
+ Debug
+ Clone
+ bytemuck::Pod,
OUT: IntoTensorElementType
+ ort::tensor::PrimitiveTensorElementType
+ Debug
+ Clone
+ bytemuck::Pod,
{
let typed_data: &[IN] = bytemuck::cast_slice(&tensor_data.data);
let data_vec: Vec<IN> = typed_data.to_vec();
let shape = ort::tensor::Shape::from(tensor_data.shape.as_slice());
let ort_value = OrtValue::from_array((shape, data_vec)).map_err(|e| {
ModelError::BackendError(format!("Failed to create OrtValue: {}", e))
})?;
let input = SessionInputValue::from(ort_value);
let mut inputs_map = HashMap::new();
inputs_map.insert("input".to_string(), input);
let mut session_guard = session_
.lock()
.map_err(|e| ModelError::BackendError(format!("Failed to lock session: {}", e)))?;
let output_value = session_guard
.run(inputs_map)
.map_err(|e| ModelError::BackendError(format!("Failed to run session: {}", e)))?;
let first = output_value.into_iter().next().ok_or_else(|| {
ModelError::BackendError("No output from ONNX session".to_string())
})?;
let (_, value) = first;
let (_, owned_slice) = value.try_extract_tensor::<OUT>().map_err(|e| {
ModelError::BackendError(format!("Failed to extract tensor from output: {:?}", e))
})?;
let act_vec: Vec<OUT> = owned_slice.to_vec();
let act_bytes: Vec<u8> = bytemuck::cast_slice(&act_vec).to_vec();
Ok(act_bytes)
}
fn match_obs_to_act<IN>(
input_data: TensorData,
output_dtype: DType,
session_: &Arc<std::sync::Mutex<Session>>,
) -> Result<Vec<u8>, ModelError>
where
IN: IntoTensorElementType
+ ort::tensor::PrimitiveTensorElementType
+ Debug
+ Clone
+ bytemuck::Pod,
{
match &output_dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(nd) => match nd {
                    NdArrayDType::F16 => {
                        // Extract the ONNX output as f32, then narrow so the
                        // byte buffer matches the declared f16 output dtype.
                        let f32_bytes = convert_obs_to_act::<IN, f32>(input_data, session_)?;
                        let halves: Vec<f16> = bytemuck::cast_slice::<u8, f32>(&f32_bytes)
                            .iter()
                            .map(|&v| f16::from_f32(v))
                            .collect();
                        Ok(bytemuck::cast_slice::<f16, u8>(&halves).to_vec())
                    }
NdArrayDType::F32 => convert_obs_to_act::<IN, f32>(input_data, session_),
NdArrayDType::F64 => convert_obs_to_act::<IN, f64>(input_data, session_),
NdArrayDType::I8 => convert_obs_to_act::<IN, i8>(input_data, session_),
NdArrayDType::I16 => convert_obs_to_act::<IN, i16>(input_data, session_),
NdArrayDType::I32 => convert_obs_to_act::<IN, i32>(input_data, session_),
NdArrayDType::I64 => convert_obs_to_act::<IN, i64>(input_data, session_),
NdArrayDType::Bool => convert_obs_to_act::<IN, u8>(input_data, session_),
},
#[cfg(feature = "tch-backend")]
DType::Tch(tch) => match tch {
                    TchDType::F16 => {
                        let f32_bytes = convert_obs_to_act::<IN, f32>(input_data, session_)?;
                        let halves: Vec<f16> = bytemuck::cast_slice::<u8, f32>(&f32_bytes)
                            .iter()
                            .map(|&v| f16::from_f32(v))
                            .collect();
                        Ok(bytemuck::cast_slice::<f16, u8>(&halves).to_vec())
                    }
                    TchDType::Bf16 => {
                        let f32_bytes = convert_obs_to_act::<IN, f32>(input_data, session_)?;
                        let halves: Vec<bf16> = bytemuck::cast_slice::<u8, f32>(&f32_bytes)
                            .iter()
                            .map(|&v| bf16::from_f32(v))
                            .collect();
                        Ok(bytemuck::cast_slice::<bf16, u8>(&halves).to_vec())
                    }
TchDType::F32 => convert_obs_to_act::<IN, f32>(input_data, session_),
TchDType::F64 => convert_obs_to_act::<IN, f64>(input_data, session_),
TchDType::I8 => convert_obs_to_act::<IN, i8>(input_data, session_),
TchDType::I16 => convert_obs_to_act::<IN, i16>(input_data, session_),
TchDType::I32 => convert_obs_to_act::<IN, i32>(input_data, session_),
TchDType::I64 => convert_obs_to_act::<IN, i64>(input_data, session_),
TchDType::U8 => convert_obs_to_act::<IN, u8>(input_data, session_),
TchDType::Bool => convert_obs_to_act::<IN, u8>(input_data, session_),
},
}
}
let act_bytes = match &self.metadata.input_dtype {
#[cfg(feature = "ndarray-backend")]
DType::NdArray(nd) => match nd {
NdArrayDType::F16 => {
let obs_tensor_data: TensorData =
observation.clone().into_f32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f32 (from f16): {}",
e
))
})?;
match_obs_to_act::<f32>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
NdArrayDType::F32 => {
let obs_tensor_data = observation.clone().into_f32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f32: {}",
e
))
})?;
match_obs_to_act::<f32>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
NdArrayDType::F64 => {
let obs_tensor_data = observation.clone().into_f64_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f64: {}",
e
))
})?;
match_obs_to_act::<f64>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
NdArrayDType::I8 => {
let obs_tensor_data = observation.clone().into_i8_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i8: {}",
e
))
})?;
match_obs_to_act::<i8>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
NdArrayDType::I16 => {
let obs_tensor_data = observation.clone().into_i16_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i16: {}",
e
))
})?;
match_obs_to_act::<i16>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
NdArrayDType::I32 => {
let obs_tensor_data = observation.clone().into_i32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i32: {}",
e
))
})?;
match_obs_to_act::<i32>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
NdArrayDType::I64 => {
let obs_tensor_data = observation.clone().into_i64_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i64: {}",
e
))
})?;
match_obs_to_act::<i64>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
NdArrayDType::Bool => {
let obs_tensor_data = observation.clone().into_bool_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to bool: {}",
e
))
})?;
match_obs_to_act::<u8>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
},
#[cfg(feature = "tch-backend")]
DType::Tch(tch) => match tch {
TchDType::F16 => {
let obs_tensor_data = observation.clone().into_f32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f32 (from f16): {}",
e
))
})?;
match_obs_to_act::<f32>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
TchDType::Bf16 => {
let obs_tensor_data = observation.clone().into_f32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f32 (from bf16): {}",
e
))
})?;
match_obs_to_act::<f32>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
TchDType::F32 => {
let obs_tensor_data = observation.clone().into_f32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f32: {}",
e
))
})?;
match_obs_to_act::<f32>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
TchDType::F64 => {
let obs_tensor_data = observation.clone().into_f64_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to f64: {}",
e
))
})?;
match_obs_to_act::<f64>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
TchDType::I8 => {
let obs_tensor_data = observation.clone().into_i8_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i8: {}",
e
))
})?;
match_obs_to_act::<i8>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
TchDType::I16 => {
let obs_tensor_data = observation.clone().into_i16_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i16: {}",
e
))
})?;
match_obs_to_act::<i16>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
TchDType::I32 => {
let obs_tensor_data = observation.clone().into_i32_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i32: {}",
e
))
})?;
match_obs_to_act::<i32>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
TchDType::I64 => {
let obs_tensor_data = observation.clone().into_i64_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to i64: {}",
e
))
})?;
match_obs_to_act::<i64>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
TchDType::U8 => {
let obs_tensor_data = observation.clone().into_u8_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to u8: {}",
e
))
})?;
match_obs_to_act::<u8>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
TchDType::Bool => {
let obs_tensor_data = observation.clone().into_bool_data().map_err(|e| {
ModelError::BackendError(format!(
"Failed to convert observation to bool: {}",
e
))
})?;
match_obs_to_act::<u8>(
obs_tensor_data,
self.metadata.output_dtype.clone(),
session,
)?
}
},
};
Ok(TensorData {
shape: self.metadata.output_shape.clone(),
dtype: self.metadata.output_dtype.clone(),
data: act_bytes,
supported_backend: TensorData::get_backend_from_dtype(&self.metadata.output_dtype),
})
}
}
#[cfg(test)]
mod unit_tests {
use super::*;
use std::marker::PhantomData;
use crate::data::tensor::FloatBurnTensor;
use burn_tensor::TensorData as BurnTensorData;
use uuid::Uuid;
#[cfg(all(
feature = "ndarray-backend",
any(feature = "tch-model", feature = "onnx-model")
))]
use burn_ndarray::NdArray;
#[cfg(all(
feature = "ndarray-backend",
any(feature = "tch-model", feature = "onnx-model")
))]
use burn_tensor::{Float, Tensor};
fn temp_dir_path(label: &str) -> PathBuf {
std::env::temp_dir().join(format!("relayrl-model-{label}-{}", Uuid::new_v4()))
}
#[test]
fn model_file_type_parses_supported_extensions() {
assert_eq!(
ModelFileType::from_path(Path::new("policy.pt")).unwrap(),
ModelFileType::Pt
);
assert_eq!(
ModelFileType::from_path(Path::new("policy.onnx")).unwrap(),
ModelFileType::Onnx
);
assert!(matches!(
ModelFileType::from_path(Path::new("policy.bin")),
Err(ModelError::UnsupportedModelType(message)) if message.contains("Unsupported extension")
));
}
#[test]
#[cfg(feature = "ndarray-backend")]
fn model_metadata_save_load_round_trip_preserves_paths() {
let dir = temp_dir_path("metadata-roundtrip");
let metadata = ModelMetadata {
model_file: "policy.onnx".to_string(),
model_type: ModelFileType::Onnx,
input_dtype: DType::NdArray(NdArrayDType::F32),
output_dtype: DType::NdArray(NdArrayDType::F32),
input_shape: vec![2],
output_shape: vec![2],
default_device: Some(DeviceType::Cpu),
};
metadata.save_to_dir(&dir).unwrap();
let loaded = ModelMetadata::load_from_dir(&dir).unwrap();
assert_eq!(loaded.model_file, "policy.onnx");
assert_eq!(loaded.resolve_model_path(&dir), dir.join("policy.onnx"));
let _ = fs::remove_dir_all(dir);
}
#[test]
#[cfg(feature = "ndarray-backend")]
fn model_metadata_load_rejects_invalid_fields() {
let dir = temp_dir_path("metadata-invalid");
let metadata = ModelMetadata {
model_file: String::new(),
model_type: ModelFileType::Onnx,
input_dtype: DType::NdArray(NdArrayDType::F32),
output_dtype: DType::NdArray(NdArrayDType::F32),
input_shape: vec![2],
output_shape: vec![2],
default_device: Some(DeviceType::Cpu),
};
metadata.save_to_dir(&dir).unwrap();
let err = ModelMetadata::load_from_dir(&dir)
.expect_err("metadata with an empty model file should be rejected");
assert!(matches!(
err,
ModelError::InvalidMetadata(message) if message.contains("model_file is empty")
));
let _ = fs::remove_dir_all(dir);
}
#[cfg(all(
feature = "ndarray-backend",
any(feature = "tch-model", feature = "onnx-model")
))]
fn stub_module(output_shape: Vec<usize>) -> ModelModule<NdArray> {
ModelModule {
model: Model {
file_type: ModelFileType::Onnx,
raw_bytes: Arc::<[u8]>::from(vec![1u8, 2, 3]),
inference: InferenceModel::Unsupported,
_phantom: PhantomData,
},
metadata: ModelMetadata {
model_file: "test.onnx".to_string(),
model_type: ModelFileType::Onnx,
input_dtype: DType::NdArray(NdArrayDType::F32),
output_dtype: DType::NdArray(NdArrayDType::F32),
input_shape: vec![2],
output_shape,
default_device: Some(DeviceType::Cpu),
},
}
}
#[cfg(all(
feature = "ndarray-backend",
any(feature = "tch-model", feature = "onnx-model")
))]
fn float_any_tensor(values: &[f32]) -> Arc<AnyBurnTensor<NdArray, 1>> {
let device = NdArray::get_device(&DeviceType::Cpu).unwrap();
let tensor = Tensor::<NdArray, 1, Float>::from_data(
BurnTensorData::new(values.to_vec(), [values.len()]),
&device,
);
Arc::new(AnyBurnTensor::Float(FloatBurnTensor {
tensor: Arc::new(tensor),
dtype: DType::NdArray(NdArrayDType::F32),
}))
}
#[test]
#[cfg(all(
feature = "ndarray-backend",
any(feature = "tch-model", feature = "onnx-model")
))]
fn model_module_save_writes_metadata_and_model_bytes() {
let dir = temp_dir_path("module-save");
let module = stub_module(vec![2]);
module.save(&dir).unwrap();
assert!(dir.join("metadata.json").exists());
assert_eq!(fs::read(dir.join("test.onnx")).unwrap(), vec![1, 2, 3]);
let _ = fs::remove_dir_all(dir);
}
#[test]
#[cfg(all(
feature = "ndarray-backend",
any(feature = "tch-model", feature = "onnx-model")
))]
fn resolve_device_returns_cpu_for_ndarray_models() {
let module = stub_module(vec![2]);
assert!(matches!(
module.resolve_device(),
burn_tensor::Device::<NdArray>::Cpu
));
}
#[test]
#[cfg(all(
feature = "ndarray-backend",
any(feature = "tch-model", feature = "onnx-model")
))]
fn zeros_action_matches_output_shape_dtype_and_backend() {
let module = stub_module(vec![2]);
let zero_action = module.zeros_action::<1>().unwrap();
assert_eq!(zero_action.shape, vec![2]);
assert_eq!(zero_action.dtype, DType::NdArray(NdArrayDType::F32));
assert_eq!(
zero_action.supported_backend,
SupportedTensorBackend::NdArray
);
assert_eq!(zero_action.data, vec![0; 8]);
}
#[test]
#[cfg(all(
feature = "ndarray-backend",
any(feature = "tch-model", feature = "onnx-model")
))]
fn step_falls_back_to_zero_actions_when_inference_is_unavailable() {
let module = stub_module(vec![2]);
let observation = float_any_tensor(&[1.0, 2.0]);
let mask = float_any_tensor(&[1.0, 0.0]);
let (action, mask_data, aux) = module.step::<1, 1>(observation, Some(mask));
assert!(aux.is_empty());
assert_eq!(action.shape, vec![2]);
assert_eq!(action.data, vec![0; 8]);
assert_eq!(
mask_data.expect("mask data should be preserved").data,
[1.0f32, 0.0]
.into_iter()
.flat_map(|value| value.to_le_bytes())
.collect::<Vec<_>>()
);
}
}