#[cfg(feature = "grpc_network")]
use crate::network::client::agent_grpc::RelayRLAgentGrpc;
#[cfg(feature = "zmq_network")]
use crate::network::client::agent_zmq::RelayRLAgentZmq;
use crate::sys_utils::config_loader::{ConfigLoader, DEFAULT_CONFIG_CONTENT, DEFAULT_CONFIG_PATH};
use crate::types::action::{RelayRLData, TensorData};
use crate::{get_or_create_config_json_path, resolve_config_json_path};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;
use tch::{CModule, Device, IValue, Kind, Tensor, no_grad};
/// Names the transport that backs `agent`.
///
/// Returns `Some("zmq")` when a ZMQ agent is present (checked first), `Some("grpc")` when only
/// a gRPC agent is present, and `None` when neither slot is populated.
fn agent_type_to_string(agent: Arc<RelayRLAgent>) -> Option<&'static str> {
    match (&agent.agent_zmq, &agent.agent_grpc) {
        (Some(_), _) => Some("zmq"),
        (None, Some(_)) => Some("grpc"),
        (None, None) => None,
    }
}
/// Wrapper that holds the agent for one of the supported network transports.
///
/// The constructors in this module (`new_zmq_agent` / `new_grpc_agent`) set exactly one of
/// these fields to `Some` and leave the other `None`. The `restart_agent` / `enable_agent` /
/// `disable_agent` methods dispatch to `agent_zmq` first if both happen to be set.
///
/// NOTE(review): the field gates use `any(feature = "networks", ...)`, while the `use`
/// statements for the transport types are gated on the specific feature only. Presumably
/// `networks` enables both `zmq_network` and `grpc_network`; confirm in Cargo.toml.
pub struct RelayRLAgent {
    /// ZMQ-backed agent, if this instance uses the ZMQ transport.
    #[cfg(any(feature = "networks", feature = "zmq_network"))]
    pub agent_zmq: Option<RelayRLAgentZmq>,
    /// gRPC-backed agent, if this instance uses the gRPC transport.
    #[cfg(any(feature = "networks", feature = "grpc_network"))]
    pub agent_grpc: Option<RelayRLAgentGrpc>,
}
/// Sanity-checks a TorchScript policy before it is handed to an agent.
///
/// The model must expose three scripted methods:
/// - `get_input_dim()` and `get_output_dim()`, each returning a non-negative integer;
/// - `step(obs, mask)`. It is invoked once here, under `no_grad`, with an all-zeros
///   `[1, input_dim]` observation and an all-zeros `[1, output_dim]` mask. Both are `Float`
///   tensors on the CPU. It must return a 2-tuple `(Tensor, non-empty dict)`.
///
/// # Panics
///
/// Panics with a descriptive message in any of these cases:
/// - a required method is missing or fails;
/// - a dimension is not an integer or is negative;
/// - `step` returns anything other than the expected `(Tensor, non-empty dict)` tuple.
pub fn validate_model(model: &CModule) {
    let input_dim: i64 = read_model_dim(model, "get_input_dim", "Input");
    let output_dim: i64 = read_model_dim(model, "get_output_dim", "Output");

    // NOTE(review): the probe mask is all zeros. If the model treats 0 as "action disallowed",
    // this probe is a fully-masked edge case for `step`; confirm this is intended.
    let obs = Tensor::zeros([1, input_dim], (Kind::Float, Device::Cpu));
    let mask = Tensor::zeros([1, output_dim], (Kind::Float, Device::Cpu));
    let probe: [IValue; 2] = [IValue::Tensor(obs), IValue::Tensor(mask)];

    let output: IValue = no_grad(|| model.method_is::<IValue>("step", &probe))
        .expect("Failed to run forward 'step' pass");

    let IValue::Tuple(values) = &output else {
        panic!("Model forward must return a tuple");
    };
    assert_eq!(
        values.len(),
        2,
        "Model forward must return a tuple of length 2"
    );
    assert!(
        matches!(values[0], IValue::Tensor(_)),
        "First element of tuple must be a Tensor"
    );
    let IValue::GenericDict(dict) = &values[1] else {
        panic!("Second element of tuple must be a Dictionary");
    };
    assert!(
        !dict.is_empty(),
        "Second element of tuple must be a non-empty dictionary"
    );
}

/// Calls the zero-argument scripted method `method` on `model` and returns its integer result.
///
/// The result is validated as non-negative so it can be used as a tensor extent.
/// `label` (e.g. `"Input"`) is only used to build panic messages.
fn read_model_dim(model: &CModule, method: &str, label: &str) -> i64 {
    let value: IValue = model
        .method_is::<IValue>(method, &[])
        .unwrap_or_else(|e| panic!("Failed to get {} dimension: {e:?}", label.to_lowercase()));
    let IValue::Int(dim) = value else {
        panic!("{label} dimension must be an integer");
    };
    assert!(dim >= 0, "{label} dimension must be non-negative");
    dim
}
/// Converts the entries of a TorchScript generic dictionary into a `RelayRLData` map.
///
/// Only entries whose key is an `IValue::String` are kept. Values are converted as follows:
/// - `Tensor` becomes `RelayRLData::Tensor`, after casting the tensor to `Kind::Float`;
/// - `Int` becomes `RelayRLData::Int`;
/// - `Double` becomes `RelayRLData::Double`.
///
/// Entries with any other key or value variant are silently skipped. If a key appears more
/// than once, the last occurrence wins. The function currently always returns `Some`.
///
/// # Panics
///
/// Panics if a tensor cannot be converted into `TensorData`. Also panics if an integer does
/// not fit the payload type of `RelayRLData::Int` (the message refers to `i32`).
pub fn convert_generic_dict(dict: &[(IValue, IValue)]) -> Option<HashMap<String, RelayRLData>> {
    let mut map: HashMap<String, RelayRLData> = HashMap::with_capacity(dict.len());
    for (key, value) in dict {
        let IValue::String(name) = key else {
            continue;
        };
        let data: RelayRLData = match value {
            IValue::Tensor(tensor) => RelayRLData::Tensor(
                TensorData::try_from(&tensor.to_kind(Kind::Float))
                    .expect("Failed to convert tensor to TensorData"),
            ),
            IValue::Int(i) => {
                RelayRLData::Int((*i).try_into().expect("Failed to convert int to i32"))
            }
            IValue::Double(f) => RelayRLData::Double(*f),
            // Other value variants have no RelayRLData counterpart; skip them.
            _ => continue,
        };
        map.insert(name.clone(), data);
    }
    Some(map)
}
impl RelayRLAgent {
    /// Builds an agent for the requested transport.
    ///
    /// `config_path` is first passed through `resolve_config_json_path!`. The training-server
    /// address is assembled as `{prefix}{host}:{port}`. Each part comes from the matching
    /// argument when given; otherwise it comes from the `train_server` section of the loaded
    /// configuration.
    ///
    /// `server_type` is lowercased and compared against `"grpc"` and `"zmq"`. `None` selects
    /// ZMQ.
    ///
    /// # Panics
    ///
    /// Panics if `server_type` is neither `"grpc"` nor `"zmq"`. Also panics if the selected
    /// transport's cargo feature is not enabled.
    pub async fn new(
        model: Option<CModule>,
        config_path: Option<PathBuf>,
        server_type: Option<String>,
        training_prefix: Option<String>,
        training_port: Option<String>,
        training_host: Option<String>,
    ) -> RelayRLAgent {
        let config_path: Option<PathBuf> = resolve_config_json_path!(config_path);
        // Scoped so the `ConfigLoader` is dropped before any `.await` below.
        let training_server: String = {
            let config: ConfigLoader = ConfigLoader::new(None, config_path.clone());
            let prefix: String = training_prefix.unwrap_or(config.train_server.prefix);
            let host: String = training_host.unwrap_or(config.train_server.host);
            let port: String = training_port.unwrap_or(config.train_server.port);
            format!("{prefix}{host}:{port}")
        };
        let requested: Option<String> = server_type.map(|kind| kind.to_lowercase());
        match requested.as_deref() {
            Some("grpc") => new_grpc_agent(model, config_path, Some(training_server)).await,
            Some("zmq") | None => new_zmq_agent(model, config_path, Some(training_server)).await,
            Some(_) => {
                panic!("[RelayRLAgent - new] Server type unavailable: Input 'zmq' or 'grpc'")
            }
        }
    }

    /// Restarts whichever transport agent is present, optionally pointing it at
    /// `training_server_address`. ZMQ takes precedence over gRPC.
    ///
    /// Returns `None` after logging to stderr when neither transport agent is set.
    ///
    /// NOTE(review): this takes `self` by value, so the wrapper and the inner agent are
    /// dropped when the call returns. Confirm callers do not expect to keep using the agent.
    pub async fn restart_agent(
        self,
        training_server_address: Option<String>,
    ) -> Option<Vec<Result<(), Box<dyn std::error::Error>>>> {
        match (self.agent_zmq, self.agent_grpc) {
            (Some(mut zmq_agent), _) => {
                Some(zmq_agent.restart_agent(training_server_address).await)
            }
            (_, Some(mut grpc_agent)) => {
                Some(grpc_agent.restart_agent(training_server_address).await)
            }
            _ => {
                eprintln!("Agent instance not available");
                None
            }
        }
    }

    /// Enables whichever transport agent is present, optionally pointing it at
    /// `training_server_address`. ZMQ takes precedence over gRPC.
    ///
    /// Returns `None` after logging to stderr when neither transport agent is set.
    ///
    /// NOTE(review): this takes `self` by value, so the inner agent is dropped when the call
    /// returns. Confirm this is intended for an "enable" operation.
    pub async fn enable_agent(
        self,
        training_server_address: Option<String>,
    ) -> Option<Result<(), Box<dyn std::error::Error>>> {
        match (self.agent_zmq, self.agent_grpc) {
            (Some(mut zmq_agent), _) => Some(zmq_agent.enable_agent(training_server_address).await),
            (_, Some(mut grpc_agent)) => {
                Some(grpc_agent.enable_agent(training_server_address).await)
            }
            _ => {
                eprintln!("Agent instance not available");
                None
            }
        }
    }

    /// Disables whichever transport agent is present. ZMQ takes precedence over gRPC.
    ///
    /// Returns `None` after logging to stderr when neither transport agent is set.
    /// Consumes `self`.
    pub async fn disable_agent(self) -> Option<Result<(), Box<dyn std::error::Error>>> {
        match (self.agent_zmq, self.agent_grpc) {
            (Some(mut zmq_agent), _) => Some(zmq_agent.disable_agent().await),
            (_, Some(mut grpc_agent)) => Some(grpc_agent.disable_agent().await),
            _ => {
                eprintln!("Agent instance not available");
                None
            }
        }
    }
}
/// Creates a `RelayRLAgent` backed by the gRPC transport.
///
/// The model, config path and training-server address are passed unchanged to
/// `RelayRLAgentGrpc::init_agent`. The ZMQ slot, when compiled in, is left `None`.
#[cfg(feature = "grpc_network")]
async fn new_grpc_agent(
    model: Option<CModule>,
    config_path: Option<PathBuf>,
    training_server: Option<String>,
) -> RelayRLAgent {
    RelayRLAgent {
        // This gate must match the `agent_zmq` field declaration on `RelayRLAgent`. Otherwise
        // the field can exist (via `networks`) without being initialized.
        #[cfg(any(feature = "networks", feature = "zmq_network"))]
        agent_zmq: None,
        agent_grpc: Some(RelayRLAgentGrpc::init_agent(model, config_path, training_server).await),
    }
}

/// Fallback used when the `grpc_network` feature is disabled.
///
/// # Panics
///
/// Always panics, because gRPC support was not compiled in.
#[cfg(not(feature = "grpc_network"))]
async fn new_grpc_agent(
    _model: Option<CModule>,
    _config_path: Option<PathBuf>,
    _training_server: Option<String>,
) -> RelayRLAgent {
    panic!("[RelayRLAgent - new] gRPC feature not enabled")
}
/// Creates a `RelayRLAgent` backed by the ZMQ transport.
///
/// The model, config path and training-server address are passed unchanged to
/// `RelayRLAgentZmq::init_agent`, which is called synchronously (not awaited). The gRPC slot,
/// when compiled in, is left `None`.
#[cfg(feature = "zmq_network")]
async fn new_zmq_agent(
    model: Option<CModule>,
    config_path: Option<PathBuf>,
    training_server: Option<String>,
) -> RelayRLAgent {
    RelayRLAgent {
        agent_zmq: Some(RelayRLAgentZmq::init_agent(
            model,
            config_path,
            training_server,
        )),
        // This gate must match the `agent_grpc` field declaration on `RelayRLAgent`. Otherwise
        // the field can exist (via `networks`) without being initialized.
        #[cfg(any(feature = "networks", feature = "grpc_network"))]
        agent_grpc: None,
    }
}

/// Fallback used when the `zmq_network` feature is disabled.
///
/// # Panics
///
/// Always panics, because ZMQ support was not compiled in.
#[cfg(not(feature = "zmq_network"))]
async fn new_zmq_agent(
    _model: Option<CModule>,
    _config_path: Option<PathBuf>,
    _training_server: Option<String>,
) -> RelayRLAgent {
    panic!("[RelayRLAgent - new] ZMQ feature not enabled")
}