use crate::error::IoError;
use crate::neural_network::layer::activation_layer::{
relu::ReLU, sigmoid::Sigmoid, softmax::Softmax, tanh::Tanh,
};
use crate::neural_network::layer::convolution_layer::{
conv_1d::Conv1D, conv_2d::Conv2D, conv_3d::Conv3D, depthwise_conv_2d::DepthwiseConv2D,
separable_conv_2d::SeparableConv2D,
};
use crate::neural_network::layer::dense::Dense;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::layer::recurrent_layer::{gru::GRU, lstm::LSTM, simple_rnn::SimpleRNN};
use crate::neural_network::layer::regularization_layer::normalization_layer::{
batch_normalization::BatchNormalization, group_normalization::GroupNormalization,
instance_normalization::InstanceNormalization, layer_normalization::LayerNormalization,
};
use crate::neural_network::neural_network_trait::{ApplyWeights, Layer};
use crate::{Deserialize, Serialize};
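/// Serializable mirror of [`LayerWeight`] in which every `ndarray` container
/// is flattened into nested `Vec`s so the weights can round-trip through
/// serde.
///
/// With `#[serde(tag = "type")]`, the enum uses serde's internally tagged
/// representation; in JSON, for example, a `Dense` variant serializes
/// roughly as (values illustrative):
///
/// ```json
/// { "type": "Dense", "weight": [[0.1, 0.2]], "bias": [[0.0]] }
/// ```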
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializableLayerWeight {
Dense(SerializableDenseWeight),
SimpleRNN(SerializableSimpleRNNWeight),
LSTM(SerializableLSTMWeight),
GRU(SerializableGRUWeight),
Conv1D(SerializableConv1DWeight),
Conv2D(SerializableConv2DWeight),
Conv3D(SerializableConv3DWeight),
SeparableConv2D(SerializableSeparableConv2DWeight),
DepthwiseConv2D(SerializableDepthwiseConv2DWeight),
BatchNormalization(SerializableBatchNormalizationWeight),
LayerNormalization(SerializableLayerNormalizationWeight),
InstanceNormalization(SerializableInstanceNormalizationWeight),
GroupNormalization(SerializableGroupNormalizationWeight),
Empty,
}
impl SerializableLayerWeight {
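    /// Converts a borrowed [`LayerWeight`] into its owned, serializable
    /// counterpart.
    ///
    /// Each `outer_iter()` call peels one axis off an `ndarray` view and the
    /// innermost 1-D rows are copied out with `to_vec()`, so a 3-D kernel
    /// becomes a `Vec<Vec<Vec<_>>>`, and so on for higher dimensions. The
    /// normalization weights are instead flattened with `iter()` and their
    /// original shape recorded separately in a `shape` field.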
pub fn from_layer_weight(weight: &LayerWeight) -> Self {
match weight {
LayerWeight::Empty => SerializableLayerWeight::Empty,
LayerWeight::Dense(w) => SerializableLayerWeight::Dense(SerializableDenseWeight {
weight: w.weight.outer_iter().map(|row| row.to_vec()).collect(),
bias: w.bias.outer_iter().map(|row| row.to_vec()).collect(),
}),
LayerWeight::SimpleRNN(w) => {
SerializableLayerWeight::SimpleRNN(SerializableSimpleRNNWeight {
kernel: w.kernel.outer_iter().map(|row| row.to_vec()).collect(),
recurrent_kernel: w
.recurrent_kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
bias: w.bias.outer_iter().map(|row| row.to_vec()).collect(),
})
}
LayerWeight::LSTM(w) => SerializableLayerWeight::LSTM(SerializableLSTMWeight {
input: SerializableGateWeight {
kernel: w
.input
.kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
recurrent_kernel: w
.input
.recurrent_kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
bias: w.input.bias.outer_iter().map(|row| row.to_vec()).collect(),
},
forget: SerializableGateWeight {
kernel: w
.forget
.kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
recurrent_kernel: w
.forget
.recurrent_kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
bias: w.forget.bias.outer_iter().map(|row| row.to_vec()).collect(),
},
cell: SerializableGateWeight {
kernel: w.cell.kernel.outer_iter().map(|row| row.to_vec()).collect(),
recurrent_kernel: w
.cell
.recurrent_kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
bias: w.cell.bias.outer_iter().map(|row| row.to_vec()).collect(),
},
output: SerializableGateWeight {
kernel: w
.output
.kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
recurrent_kernel: w
.output
.recurrent_kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
bias: w.output.bias.outer_iter().map(|row| row.to_vec()).collect(),
},
}),
LayerWeight::GRU(w) => SerializableLayerWeight::GRU(SerializableGRUWeight {
reset: SerializableGateWeight {
kernel: w
.reset
.kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
recurrent_kernel: w
.reset
.recurrent_kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
bias: w.reset.bias.outer_iter().map(|row| row.to_vec()).collect(),
},
update: SerializableGateWeight {
kernel: w
.update
.kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
recurrent_kernel: w
.update
.recurrent_kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
bias: w.update.bias.outer_iter().map(|row| row.to_vec()).collect(),
},
candidate: SerializableGateWeight {
kernel: w
.candidate
.kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
recurrent_kernel: w
.candidate
.recurrent_kernel
.outer_iter()
.map(|row| row.to_vec())
.collect(),
bias: w
.candidate
.bias
.outer_iter()
.map(|row| row.to_vec())
.collect(),
},
}),
LayerWeight::Conv1D(w) => SerializableLayerWeight::Conv1D(SerializableConv1DWeight {
weight: w
.weight
.outer_iter()
.map(|d1| d1.outer_iter().map(|d2| d2.to_vec()).collect())
.collect(),
bias: w.bias.outer_iter().map(|row| row.to_vec()).collect(),
}),
LayerWeight::Conv2D(w) => SerializableLayerWeight::Conv2D(SerializableConv2DWeight {
weight: w
.weight
.outer_iter()
.map(|d1| {
d1.outer_iter()
.map(|d2| d2.outer_iter().map(|d3| d3.to_vec()).collect())
.collect()
})
.collect(),
bias: w.bias.outer_iter().map(|row| row.to_vec()).collect(),
}),
LayerWeight::Conv3D(w) => SerializableLayerWeight::Conv3D(SerializableConv3DWeight {
weight: w
.weight
.outer_iter()
.map(|d1| {
d1.outer_iter()
.map(|d2| {
d2.outer_iter()
.map(|d3| d3.outer_iter().map(|d4| d4.to_vec()).collect())
.collect()
})
.collect()
})
.collect(),
bias: w.bias.outer_iter().map(|row| row.to_vec()).collect(),
}),
LayerWeight::SeparableConv2DLayer(w) => {
SerializableLayerWeight::SeparableConv2D(SerializableSeparableConv2DWeight {
depthwise_weight: w
.depthwise_weight
.outer_iter()
.map(|d1| {
d1.outer_iter()
.map(|d2| d2.outer_iter().map(|d3| d3.to_vec()).collect())
.collect()
})
.collect(),
pointwise_weight: w
.pointwise_weight
.outer_iter()
.map(|d1| {
d1.outer_iter()
.map(|d2| d2.outer_iter().map(|d3| d3.to_vec()).collect())
.collect()
})
.collect(),
bias: w.bias.outer_iter().map(|row| row.to_vec()).collect(),
})
}
LayerWeight::DepthwiseConv2DLayer(w) => {
SerializableLayerWeight::DepthwiseConv2D(SerializableDepthwiseConv2DWeight {
weight: w
.weight
.outer_iter()
.map(|d1| {
d1.outer_iter()
.map(|d2| d2.outer_iter().map(|d3| d3.to_vec()).collect())
.collect()
})
.collect(),
bias: w.bias.to_vec(),
})
}
LayerWeight::BatchNormalization(w) => {
SerializableLayerWeight::BatchNormalization(SerializableBatchNormalizationWeight {
gamma: w.gamma.iter().cloned().collect(),
beta: w.beta.iter().cloned().collect(),
running_mean: w.running_mean.iter().cloned().collect(),
running_var: w.running_var.iter().cloned().collect(),
shape: w.gamma.shape().to_vec(),
})
}
LayerWeight::LayerNormalizationLayer(w) => {
SerializableLayerWeight::LayerNormalization(SerializableLayerNormalizationWeight {
gamma: w.gamma.iter().cloned().collect(),
beta: w.beta.iter().cloned().collect(),
shape: w.gamma.shape().to_vec(),
})
}
LayerWeight::InstanceNormalizationLayer(w) => {
SerializableLayerWeight::InstanceNormalization(
SerializableInstanceNormalizationWeight {
gamma: w.gamma.iter().cloned().collect(),
beta: w.beta.iter().cloned().collect(),
shape: w.gamma.shape().to_vec(),
},
)
}
LayerWeight::GroupNormalizationLayer(w) => {
SerializableLayerWeight::GroupNormalization(SerializableGroupNormalizationWeight {
gamma: w.gamma.iter().cloned().collect(),
beta: w.beta.iter().cloned().collect(),
shape: w.gamma.shape().to_vec(),
})
}
}
}
}
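/// Metadata describing a single layer; both fields are stored as display
/// strings rather than structured data.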
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerInfo {
pub layer_type: String,
pub output_shape: String,
}
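/// A single serialized layer: its metadata together with its weights.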
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableLayer {
pub info: LayerInfo,
pub weights: SerializableLayerWeight,
}
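/// Serializable snapshot of a sequential model: its layers in order, each
/// carrying metadata and weights.
///
/// A minimal serde round-trip sketch (assumes `serde_json` is available as
/// a dependency; the empty model is illustrative only):
///
/// ```ignore
/// let snapshot = SerializableSequential { layers: vec![] };
/// let json = serde_json::to_string(&snapshot)?;
/// let restored: SerializableSequential = serde_json::from_str(&json)?;
/// ```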
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableSequential {
pub layers: Vec<SerializableLayer>,
}
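/// Applies `$weight` to `$layer_any` for a layer type that is generic over
/// its activation function, trying each supported activation in turn via
/// `try_apply_with_activations!`. Expands to an early `return Err(..)` when
/// no downcast matches, so it must be invoked inside a function returning
/// `Result<_, IoError>`.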
macro_rules! apply_weights_with_activations {
($layer_any:expr, $weight:expr, $layer_type:ident, $layer_name:expr, $expected_type:expr) => {{
let applied = try_apply_with_activations!($layer_any, $weight, $layer_type, $layer_name);
if !applied {
return Err(IoError::StdIoError(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Expected {} layer but got {}", $layer_name, $expected_type),
)));
}
}};
}
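/// Applies `$weight` to `$layer_any` for a layer type with no activation
/// type parameter (the normalization layers). Expands to an early
/// `return Err(..)` on a failed downcast, so it must be invoked inside a
/// function returning `Result<_, IoError>`.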
macro_rules! apply_weights_simple {
($layer_any:expr, $weight:expr, $layer_type:ident, $layer_name:expr, $expected_type:expr) => {{
if let Some(layer) = $layer_any.downcast_mut::<$layer_type>() {
$weight.apply_to_layer(layer)?;
} else {
return Err(IoError::StdIoError(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Expected {} layer but got {}", $layer_name, $expected_type),
)));
}
}};
}
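/// Tries to downcast `$layer_any` to `$layer_type<A>` for each supported
/// activation (`ReLU`, `Sigmoid`, `Softmax`, `Tanh`) and applies `$weight`
/// to the first match. Evaluates to `true` if a downcast succeeded and
/// `false` otherwise; it uses `?`, so the surrounding function must return
/// `Result<_, IoError>`.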
macro_rules! try_apply_with_activations {
($layer_any:expr, $weight:expr, $layer_type:ident, $layer_name:expr) => {{
if let Some(layer) = $layer_any.downcast_mut::<$layer_type<ReLU>>() {
$weight.apply_to_layer(layer)?;
true
} else if let Some(layer) = $layer_any.downcast_mut::<$layer_type<Sigmoid>>() {
$weight.apply_to_layer(layer)?;
true
} else if let Some(layer) = $layer_any.downcast_mut::<$layer_type<Softmax>>() {
$weight.apply_to_layer(layer)?;
true
} else if let Some(layer) = $layer_any.downcast_mut::<$layer_type<Tanh>>() {
$weight.apply_to_layer(layer)?;
true
} else {
false
}
}};
}
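/// Applies deserialized `weights` to a `dyn Layer` trait object by
/// downcasting it to its concrete layer type.
///
/// `expected_type` is the layer's type name as recorded in the serialized
/// model; it is used only to build the error message when the weights and
/// the layer do not match.
///
/// # Errors
///
/// Returns an [`IoError`] if the weight variant does not correspond to the
/// concrete type of `layer`, or if applying the weights fails.
///
/// A usage sketch (the `Dense::<ReLU>::new` constructor and the
/// `layer_weight` accessor shown here are hypothetical):
///
/// ```ignore
/// let mut layer = Dense::<ReLU>::new(4, 8);
/// let weights = SerializableLayerWeight::from_layer_weight(&layer.layer_weight());
/// apply_weights_to_layer(&mut layer, &weights, "Dense")?;
/// ```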
pub fn apply_weights_to_layer(
layer: &mut dyn Layer,
weights: &SerializableLayerWeight,
expected_type: &str,
) -> Result<(), IoError> {
use std::any::Any;
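    // Upcast the `Layer` trait object to `&mut dyn Any` so each arm below can
    // attempt a downcast to a concrete layer type. This coercion relies on
    // trait-object upcasting, and therefore on `Layer` having `Any` as a
    // supertrait.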
let layer_any: &mut dyn Any = layer;
match weights {
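        // Layers without trainable weights (e.g. activations) need nothing applied.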
SerializableLayerWeight::Empty => {}
SerializableLayerWeight::Dense(w) => {
apply_weights_with_activations!(layer_any, w, Dense, "Dense", expected_type);
}
SerializableLayerWeight::SimpleRNN(w) => {
apply_weights_with_activations!(layer_any, w, SimpleRNN, "SimpleRNN", expected_type);
}
SerializableLayerWeight::LSTM(w) => {
apply_weights_with_activations!(layer_any, w, LSTM, "LSTM", expected_type);
}
SerializableLayerWeight::GRU(w) => {
apply_weights_with_activations!(layer_any, w, GRU, "GRU", expected_type);
}
SerializableLayerWeight::Conv1D(w) => {
apply_weights_with_activations!(layer_any, w, Conv1D, "Conv1D", expected_type);
}
SerializableLayerWeight::Conv2D(w) => {
apply_weights_with_activations!(layer_any, w, Conv2D, "Conv2D", expected_type);
}
SerializableLayerWeight::Conv3D(w) => {
apply_weights_with_activations!(layer_any, w, Conv3D, "Conv3D", expected_type);
}
SerializableLayerWeight::SeparableConv2D(w) => {
apply_weights_with_activations!(
layer_any,
w,
SeparableConv2D,
"SeparableConv2D",
expected_type
);
}
SerializableLayerWeight::DepthwiseConv2D(w) => {
apply_weights_with_activations!(
layer_any,
w,
DepthwiseConv2D,
"DepthwiseConv2D",
expected_type
);
}
SerializableLayerWeight::BatchNormalization(w) => {
apply_weights_simple!(
layer_any,
w,
BatchNormalization,
"BatchNormalization",
expected_type
);
}
SerializableLayerWeight::LayerNormalization(w) => {
apply_weights_simple!(
layer_any,
w,
LayerNormalization,
"LayerNormalization",
expected_type
);
}
SerializableLayerWeight::InstanceNormalization(w) => {
apply_weights_simple!(
layer_any,
w,
InstanceNormalization,
"InstanceNormalization",
expected_type
);
}
SerializableLayerWeight::GroupNormalization(w) => {
apply_weights_simple!(
layer_any,
w,
GroupNormalization,
"GroupNormalization",
expected_type
);
}
}
Ok(())
}
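// Submodules: `helper_function` is private; the per-layer serializable
// weight modules that follow are all re-exported so callers can reach the
// types through this module directly.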
mod helper_function;
pub mod serializable_batch_normalization_weight;
pub mod serializable_conv_1d_weight;
pub mod serializable_conv_2d_weight;
pub mod serializable_conv_3d_weight;
pub mod serializable_dense_weight;
pub mod serializable_depthwise_conv_2d_weight;
pub mod serializable_gate_weight;
pub mod serializable_group_normalization;
pub mod serializable_gru_weight;
pub mod serializable_instance_normalization;
pub mod serializable_layer_normalization_weight;
pub mod serializable_lstm_weight;
pub mod serializable_separable_conv_2d_weight;
pub mod serializable_simple_rnn_weight;
pub use serializable_batch_normalization_weight::*;
pub use serializable_conv_1d_weight::*;
pub use serializable_conv_2d_weight::*;
pub use serializable_conv_3d_weight::*;
pub use serializable_dense_weight::*;
pub use serializable_depthwise_conv_2d_weight::*;
pub use serializable_gate_weight::*;
pub use serializable_group_normalization::*;
pub use serializable_gru_weight::*;
pub use serializable_instance_normalization::*;
pub use serializable_layer_normalization_weight::*;
pub use serializable_lstm_weight::*;
pub use serializable_separable_conv_2d_weight::*;
pub use serializable_simple_rnn_weight::*;