use relayrl_types::HyperparameterArgs;
use std::collections::HashMap;
// Namespace prefixes and context labels shared with the parent module (`pub(super)`).
// Each constant is compiled only under the cargo feature combination in its `cfg`.
// NOTE(review): presumably these strings are combined into namespaced identifiers
// (e.g. for logging or registry keys) — confirm against the call sites.

// ---- Client-side labels (feature "client") ----
#[cfg(feature = "client")]
pub(super) const CLIENT_NAMESPACE_PREFIX: &str = "client";
#[cfg(feature = "client")]
pub(super) const ACTOR_CONTEXT: &str = "actor";
#[cfg(feature = "client")]
pub(super) const ENVIRONMENT_CONTEXT_PREFIX: &str = "env";
#[cfg(feature = "client")]
pub(super) const SCALE_MANAGER_CONTEXT: &str = "scaler";
// Transport-specific client labels: each needs "client" plus its own transport feature.
#[cfg(all(feature = "client", feature = "zmq-transport"))]
pub(super) const ZMQ_CLIENT_CONTEXT: &str = "zmq-client";
#[cfg(all(feature = "client", feature = "nats-transport"))]
pub(super) const NATS_CLIENT_CONTEXT: &str = "nats-client";
#[cfg(feature = "client")]
pub(super) const ROUTER_NAMESPACE_PREFIX: &str = "router";
// NOTE(review): the two stacked `cfg` attributes below are ANDed, so this constant exists
// only when BOTH "client" and "nats-transport" are enabled; the second `any(...)` adds
// nothing. If a ZMQ-only client also needs it, the first attribute is too narrow —
// confirm the intended gating.
#[cfg(all(feature = "client", feature = "nats-transport"))]
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
pub(super) const RECEIVER_CONTEXT: &str = "receiver";
#[cfg(feature = "client")]
pub(super) const BUFFER_CONTEXT: &str = "buffer";

// ---- Server-side labels: a server role feature plus at least one transport ----
#[cfg(all(
    feature = "training-server",
    any(feature = "nats-transport", feature = "zmq-transport")
))]
pub(super) const TRAINING_SERVER_NAMESPACE_PREFIX: &str = "training-server";
#[cfg(all(
    feature = "inference-server",
    any(feature = "nats-transport", feature = "zmq-transport")
))]
pub(super) const INFERENCE_SERVER_NAMESPACE_PREFIX: &str = "inference-server";
// Shared by both server roles.
#[cfg(all(
    any(feature = "training-server", feature = "inference-server"),
    any(feature = "nats-transport", feature = "zmq-transport")
))]
pub(super) const WORKER_CONTEXT: &str = "worker";
// Transport-specific server labels: any server role plus the matching transport.
#[cfg(all(
    any(feature = "training-server", feature = "inference-server"),
    feature = "zmq-transport"
))]
pub(super) const ZMQ_SERVER_CONTEXT: &str = "zmq-server";
#[cfg(all(
    any(feature = "training-server", feature = "inference-server"),
    feature = "nats-transport"
))]
pub(super) const NATS_SERVER_CONTEXT: &str = "nats-server";
// Submodule compiled only when the "client" feature is enabled.
#[cfg(feature = "client")]
pub mod client;
// Submodule compiled when either server role ("inference-server" or "training-server") is enabled.
#[cfg(any(feature = "inference-server", feature = "training-server"))]
pub mod server;
/// Which message transport backend to use.
///
/// The type exists only when at least one transport feature is enabled, and each
/// variant is compiled only under its own feature, so code matching on it sees just
/// the backends that are actually built in.
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
#[derive(Clone, Copy, Debug)]
pub enum TransportType {
    /// NATS transport; present only with the `nats-transport` feature.
    #[cfg(feature = "nats-transport")]
    NATS,
    /// ZMQ transport; present only with the `zmq-transport` feature.
    #[cfg(feature = "zmq-transport")]
    ZMQ,
}
#[cfg(any(feature = "nats-transport", feature = "zmq-transport"))]
impl Default for TransportType {
    /// Picks the preferred transport among those compiled in.
    ///
    /// NATS wins whenever `nats-transport` is enabled, even if `zmq-transport` is
    /// enabled too. Otherwise this impl can only exist because `zmq-transport` is on,
    /// so ZMQ is the fallback.
    fn default() -> Self {
        #[cfg(feature = "nats-transport")]
        let preferred = Self::NATS;
        #[cfg(not(feature = "nats-transport"))]
        let preferred = Self::ZMQ;
        preferred
    }
}
/// Normalizes optional hyperparameter arguments into a flat `key -> value` string map.
///
/// * `Some(HyperparameterArgs::Map(..))` — every entry is converted with `to_string()`.
/// * `Some(HyperparameterArgs::List(..))` — each element must be either `key=value`
///   (split on the FIRST `=`, so the value may itself contain `=`) or, when it has no
///   `=`, exactly two whitespace-separated tokens `key value`. Whitespace around the
///   key and value is trimmed.
/// * `None` — returns an empty map.
///
/// When a key repeats, the later entry overwrites the earlier one.
///
/// # Panics
///
/// Panics if a list element matches neither accepted form or produces an empty key.
pub fn parse_args(hyperparameter_args: &Option<HyperparameterArgs>) -> HashMap<String, String> {
    let mut hyperparams_map: HashMap<String, String> = HashMap::new();
    match hyperparameter_args {
        Some(HyperparameterArgs::Map(map)) => {
            hyperparams_map.extend(
                map.iter()
                    .map(|entry| (entry.0.to_string(), entry.1.to_string())),
            );
        }
        Some(HyperparameterArgs::List(args)) => {
            for arg in args {
                let (key, value) = split_hyperparameter(arg).unwrap_or_else(|| {
                    panic!(
                        "[parse_args] Invalid hyperparameter argument: {} (expected `key=value` or `key value`)",
                        arg
                    )
                });
                hyperparams_map.insert(key.to_string(), value.to_string());
            }
        }
        None => {}
    }
    hyperparams_map
}

/// Splits one list-form hyperparameter into a trimmed `(key, value)` pair.
///
/// Returns `None` when `arg` is neither `key=value` nor exactly two whitespace-separated
/// tokens, or when the resulting key is empty.
fn split_hyperparameter(arg: &str) -> Option<(&str, &str)> {
    let (key, value) = match arg.split_once('=') {
        // Only the first `=` separates, so values such as URLs or paths may contain `=`.
        Some((key, value)) => (key.trim(), value.trim()),
        None => {
            // `split_whitespace` tolerates repeated spaces and tabs between the two tokens.
            let mut tokens = arg.split_whitespace();
            match (tokens.next(), tokens.next(), tokens.next()) {
                (Some(key), Some(value), None) => (key, value),
                _ => return None,
            }
        }
    };
    if key.is_empty() {
        None
    } else {
        Some((key, value))
    }
}