use crate::proto::{RelayRlAction as GrpcRelayRLAction, Trajectory};
use crate::types::action::{RelayRLAction, SafeTensorError, TensorData};
use crate::types::trajectory::{RelayRLTrajectory, RelayRLTrajectoryTrait};
use tch::{CModule, Device, TchError};
use tempfile::NamedTempFile;
use crate::types::action::RelayRLData;
use std::collections::HashMap;
use std::io::Write;
use std::path::PathBuf;
pub(crate) fn serialize_action(action: &RelayRLAction) -> GrpcRelayRLAction {
let obs_bytes = action
.obs
.as_ref()
.map_or_else(Vec::new, |td| td.data.clone());
let act_bytes = action
.act
.as_ref()
.map_or_else(Vec::new, |td| td.data.clone());
let mask_bytes = action
.mask
.as_ref()
.map_or_else(Vec::new, |td| td.data.clone());
let data: HashMap<String, Vec<u8>> = action.data.as_ref().map_or_else(HashMap::new, |map| {
map.iter()
.map(|(k, v)| {
let serialized =
serde_json::to_vec(v).expect("Serialization of RelayRLData failed");
(k.clone(), serialized)
})
.collect()
});
GrpcRelayRLAction {
obs: obs_bytes,
action: act_bytes,
mask: mask_bytes,
reward: action.rew,
data,
done: action.done,
reward_update_flag: false,
}
}
pub(crate) fn deserialize_action(
grpc_action: GrpcRelayRLAction,
) -> Result<RelayRLAction, SafeTensorError> {
let obs: Option<TensorData> = if grpc_action.obs.is_empty() {
None
} else {
Some(RelayRLAction::from_bytes(grpc_action.obs)?)
};
let act: Option<TensorData> = if grpc_action.action.is_empty() {
None
} else {
Some(RelayRLAction::from_bytes(grpc_action.action)?)
};
let mask: Option<TensorData> = if grpc_action.mask.is_empty() {
None
} else {
Some(RelayRLAction::from_bytes(grpc_action.mask)?)
};
let data: Option<HashMap<String, RelayRLData>> = if grpc_action.data.is_empty() {
None
} else {
let mut map = HashMap::new();
for (k, v) in grpc_action.data.into_iter() {
let deserialized: RelayRLData = serde_json::from_slice(&v)
.map_err(|e| SafeTensorError::SerializationError(e.to_string()))?;
map.insert(k, deserialized);
}
Some(map)
};
Ok(RelayRLAction {
obs,
act,
mask,
rew: grpc_action.reward,
data,
done: grpc_action.done,
reward_updated: false,
})
}
/// Converts a gRPC [`Trajectory`] into a [`RelayRLTrajectory`].
///
/// The result is created with `RelayRLTrajectory::new(Some(max_traj_length), None)`.
/// Presumably this caps the trajectory length — confirm against `RelayRLTrajectory::new`.
/// Each wire action is decoded with `deserialize_action` and appended, in order, via
/// `add_action(&action, false)`.
///
/// # Panics
///
/// Panics if any action fails to deserialize.
pub(crate) fn grpc_trajectory_to_relayrl_trajectory(
    trajectory: Trajectory,
    max_traj_length: u32,
) -> RelayRLTrajectory {
    let mut converted: RelayRLTrajectory = RelayRLTrajectory::new(Some(max_traj_length), None);

    // Lazily decode so each action is deserialized immediately before it is appended.
    let decoded = trajectory
        .actions
        .into_iter()
        .map(|wire_action| deserialize_action(wire_action).expect("failed to deserialize action"));

    for action in decoded {
        converted.add_action(&action, false);
    }

    converted
}
/// Serializes a TorchScript module to bytes.
///
/// `CModule::save` writes to a path, so the module is first saved into a temporary
/// `_model*.pt` file created inside `dir`, and that file's contents are then read back. The
/// temporary file is a `NamedTempFile` owned by this function.
///
/// # Panics
///
/// Panics if the temporary file cannot be created, the module cannot be saved, or the saved
/// bytes cannot be read back.
pub(crate) fn serialize_model(model: &CModule, dir: PathBuf) -> Vec<u8> {
    let mut builder = tempfile::Builder::new();
    builder.prefix("_model").suffix(".pt");
    let staging = builder
        .tempfile_in(dir)
        .expect("Failed to create temp file");

    let staged_path = staging.path();
    model.save(staged_path).expect("Failed to save model");
    std::fs::read(staged_path).expect("Failed to read model bytes")
}
/// Loads a TorchScript module, placed on the CPU, from its serialized bytes.
///
/// `CModule::load_on_device` loads from a path, so the bytes are first written to a
/// `NamedTempFile`. That file is owned by this function and is kept alive until loading has
/// finished.
///
/// # Errors
///
/// Returns a [`TchError`] if the temporary file cannot be created, written, or flushed
/// (converted from the underlying `std::io::Error`). Also returns one if libtorch fails to
/// load a module from the written bytes, for example because they are truncated or corrupt.
pub(crate) fn deserialize_model(model_bytes: Vec<u8>) -> Result<CModule, TchError> {
    let mut temp_file = NamedTempFile::new()?;
    temp_file.write_all(&model_bytes)?;
    // Ensure every byte has reached the file before libtorch reopens it by path.
    temp_file.flush()?;

    // Propagate load failures instead of panicking: the signature already promises a
    // `Result`. Binding the model first makes it explicit that `temp_file` outlives the load.
    let model = CModule::load_on_device(temp_file.path(), Device::Cpu)?;
    Ok(model)
}