use torsh_core::device::DeviceType;
use torsh_core::error::Result;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
#[cfg(feature = "serialize")]
use serde_json;
/// Uniform construction interface for modules.
///
/// Implementors only have to supply [`ModuleConstruct::try_new`]; every other
/// constructor in this trait has a provided implementation that funnels into it.
///
/// Note: the provided `default` is an associated function of *this* trait. On a
/// type that also implements [`Default`], a bare `T::default()` call is
/// ambiguous; disambiguate with `<T as ModuleConstruct>::default()`.
pub trait ModuleConstruct {
    /// The value produced by the constructors.
    type Output;

    /// Fallible constructor; the one item every implementor must define.
    fn try_new() -> Result<Self::Output>;

    /// Infallible convenience wrapper over [`ModuleConstruct::try_new`].
    ///
    /// # Panics
    ///
    /// Panics with the message `"Module construction failed"` (followed by the
    /// error's `Debug` rendering) when `try_new` returns an error.
    fn new() -> Self::Output
    where
        Self::Output: Sized,
    {
        <Self as ModuleConstruct>::try_new().expect("Module construction failed")
    }

    /// Fallible default construction; simply delegates to `try_new` unless
    /// an implementor overrides it.
    fn default() -> Result<Self::Output> {
        <Self as ModuleConstruct>::try_new()
    }

    /// Construction from a [`ModuleConfig`].
    ///
    /// The provided implementation ignores `_config` entirely and delegates to
    /// `try_new`; implementors that want to honor the configuration must
    /// override this method.
    fn with_config(_config: &ModuleConfig) -> Result<Self::Output> {
        <Self as ModuleConstruct>::try_new()
    }
}
/// Common construction options passed to modules.
///
/// Typically built through the chained setters on `ModuleConfig`
/// (`ModuleConfig::new().training(false).bias(false)...`).
#[derive(Debug, Clone)]
pub struct ModuleConfig {
    /// Training-mode flag (defaults to `true`).
    pub training: bool,
    /// Target device (defaults to `DeviceType::Cpu`).
    pub device: DeviceType,
    /// Bias flag (defaults to `true`).
    /// NOTE(review): presumably controls whether layers allocate a bias term — confirm with consumers.
    pub bias: bool,
    /// Dropout value (defaults to `0.0`). Stored as given; nothing here validates its range.
    /// NOTE(review): presumably a probability in `[0, 1)` — confirm and consider validating.
    pub dropout: f32,
    /// Free-form extra parameters keyed by name, stored as JSON values when the
    /// `serialize` feature is enabled.
    #[cfg(feature = "serialize")]
    pub custom: HashMap<String, serde_json::Value>,
    /// Free-form extra parameters keyed by name, stored as their `Display`
    /// string form when the `serialize` feature is disabled.
    #[cfg(not(feature = "serialize"))]
    pub custom: HashMap<String, String>,
}
impl Default for ModuleConfig {
    /// Baseline configuration: training mode on, CPU device, bias enabled,
    /// zero dropout, and an empty custom-parameter map.
    fn default() -> Self {
        let custom = HashMap::new();
        Self {
            custom,
            dropout: 0.0,
            bias: true,
            device: DeviceType::Cpu,
            training: true,
        }
    }
}
impl ModuleConfig {
    /// Creates a configuration populated with the `Default` values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the training-mode flag and returns the updated configuration.
    #[must_use]
    pub fn training(mut self, training: bool) -> Self {
        self.training = training;
        self
    }

    /// Sets the target device and returns the updated configuration.
    #[must_use]
    pub fn device(mut self, device: DeviceType) -> Self {
        self.device = device;
        self
    }

    /// Sets the bias flag and returns the updated configuration.
    #[must_use]
    pub fn bias(mut self, bias: bool) -> Self {
        self.bias = bias;
        self
    }

    /// Sets the dropout value and returns the updated configuration.
    ///
    /// The value is stored as given; no range check is performed here.
    #[must_use]
    pub fn dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// Stores `value` under `name` as a JSON value, replacing any previous entry.
    ///
    /// This is best-effort: if `serde_json::to_value` fails for `value`, nothing
    /// is inserted and any existing entry under `name` is left untouched.
    #[cfg(feature = "serialize")]
    #[must_use]
    pub fn custom_param<T: serde::Serialize>(mut self, name: &str, value: T) -> Self {
        if let Ok(json_value) = serde_json::to_value(value) {
            self.custom.insert(name.to_string(), json_value);
        }
        self
    }

    /// Stores the `Display` rendering of `value` under `name`, replacing any
    /// previous entry.
    // `core::fmt::Display` is the same trait as `std::fmt::Display`, but it also
    // resolves when the crate is built without the `std` feature.
    #[cfg(not(feature = "serialize"))]
    #[must_use]
    pub fn custom_param<T: core::fmt::Display>(mut self, name: &str, value: T) -> Self {
        self.custom.insert(name.to_string(), value.to_string());
        self
    }

    /// Looks up the custom parameter `name` and deserializes it into `T`.
    ///
    /// Returns `None` when the key is absent or when the stored value does not
    /// deserialize into `T`.
    #[cfg(feature = "serialize")]
    pub fn get_custom<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
        // `&serde_json::Value` implements `Deserializer`, so the stored value can
        // be read in place instead of deep-cloning the JSON tree on every lookup.
        self.custom
            .get(name)
            .and_then(|value| T::deserialize(value).ok())
    }

    /// Returns a copy of the string stored under `name`, or `None` if absent.
    #[cfg(not(feature = "serialize"))]
    pub fn get_custom(&self, name: &str) -> Option<String> {
        self.custom.get(name).cloned()
    }
}
/// Implements `ModuleConstruct` for `$module_type`, using the expression
/// `$constructor` as the body of `try_new` (with `Output = $module_type`).
///
/// `$constructor` must evaluate to `Result<$module_type>`. It is pasted into the
/// body of `try_new`, so it is re-evaluated on every call.
///
/// The expansion names `ModuleConstruct` and `Result` without a path, so both
/// must be in scope at the invocation site, and `Result` there must be the
/// single-parameter alias the trait is declared with. A trailing comma after the
/// constructor expression is accepted.
///
/// # Examples
///
/// ```ignore
/// impl_module_constructor!(Identity, Ok(Identity));
/// impl_module_constructor!(
///     Linear,
///     Linear::build(),
/// );
/// ```
#[macro_export]
macro_rules! impl_module_constructor {
    ($module_type:ty, $constructor:expr $(,)?) => {
        impl ModuleConstruct for $module_type {
            type Output = $module_type;
            fn try_new() -> Result<Self::Output> {
                $constructor
            }
        }
    };
}