use std::collections::HashMap;
#[cfg(feature = "client")]
pub mod client;
#[cfg(any(feature = "inference_server", feature = "training_server"))]
pub mod server;
/// Selects which transport backend to use.
///
/// The type exists only when `async_transport` or `sync_transport` is enabled, and each
/// variant is additionally gated on its own backend feature.
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
#[derive(Debug, Clone, Copy)]
pub enum TransportType {
    /// ZeroMQ backend; compiled in only with the `zmq_transport` feature.
    #[cfg(feature = "zmq_transport")]
    ZMQ,
}
#[cfg(any(feature = "async_transport", feature = "sync_transport"))]
impl Default for TransportType {
    /// Returns the ZeroMQ transport, which is the only variant compiled in when the
    /// `zmq_transport` feature is enabled.
    fn default() -> Self {
        #[cfg(feature = "zmq_transport")]
        {
            Self::ZMQ
        }
    }
}
/// Hyperparameters supplied either as an already-split map or as raw string entries.
///
/// The list form is interpreted by [`parse_args`], which expects each entry to hold a
/// `key=value` or `key value` pair.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum HyperparameterArgs {
    /// Ready-made `name -> value` pairs.
    Map(HashMap<String, String>),
    /// Unparsed entries such as `"lr=0.01"` or `"batch 32"`.
    List(Vec<String>),
}
/// Normalizes optional hyperparameter arguments into a `key -> value` map.
///
/// * `None` yields an empty map.
/// * `HyperparameterArgs::Map` is copied as-is.
/// * `HyperparameterArgs::List` entries are parsed as `key=value`, or as `key value` when
///   the entry contains no `=`. Only the first separator splits an entry, so values may
///   themselves contain `=` or spaces (`"expr=a=b"` gives `expr -> "a=b"`). Whitespace
///   around the entry, the key and the value is trimmed, so `"lr = 0.1"` gives
///   `lr -> "0.1"`. When a key repeats, the later entry wins.
///
/// # Panics
///
/// Panics if a list entry, after trimming, contains neither `=` nor whitespace.
pub fn parse_args(hyperparameter_args: &Option<HyperparameterArgs>) -> HashMap<String, String> {
    match hyperparameter_args {
        None => HashMap::new(),
        Some(HyperparameterArgs::Map(map)) => map.clone(),
        Some(HyperparameterArgs::List(args)) => args
            .iter()
            .map(|arg| {
                let (key, value) = split_hyperparameter(arg).unwrap_or_else(|| {
                    panic!(
                        "[TrainingServer - new] Invalid hyperparameter argument: {}",
                        arg
                    )
                });
                (key.to_string(), value.to_string())
            })
            .collect(),
    }
}

/// Splits one `key=value` / `key value` entry at its first separator and trims both halves.
///
/// `=` takes precedence over whitespace, so `"a b=c"` splits into `("a b", "c")`.
/// Returns `None` when the trimmed entry contains no separator at all.
fn split_hyperparameter(arg: &str) -> Option<(&str, &str)> {
    // Trim first so that a leading or trailing space is not mistaken for the separator.
    let arg = arg.trim();
    let (key, value) = arg
        .split_once('=')
        .or_else(|| arg.split_once(char::is_whitespace))?;
    Some((key.trim(), value.trim()))
}