use serde::{Deserialize, Serialize};
use crate::error::{CnnError, CnnResult};
/// Hyper-parameters describing a convolution layer.
///
/// Values are usually produced through [`ConvConfig::builder`], whose `build`
/// step runs [`ConvConfig::validate`] before returning the config. All fields
/// are public, so a config assembled by hand or deserialized should be
/// validated explicitly.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ConvConfig {
/// Input channel count; must be non-zero and a multiple of `groups`.
pub in_channels: usize,
/// Output channel count; must be non-zero and a multiple of `groups`.
pub out_channels: usize,
/// Kernel extent along the dimension passed to `output_size`; must be non-zero.
pub kernel_size: usize,
/// Step between kernel applications; must be non-zero (it is the divisor in
/// `output_size`).
pub stride: usize,
/// Padding amount; `output_size` counts it twice (once per side).
pub padding: usize,
/// Dilation factor; must be non-zero. The effective kernel extent used by
/// `output_size` is `dilation * (kernel_size - 1) + 1`.
pub dilation: usize,
/// Group count; must be non-zero and divide both channel counts.
pub groups: usize,
/// Bias flag. Not read by the validation or size helpers on this type;
/// presumably toggles a learnable bias term (confirm against the layer code).
pub bias: bool,
}
impl ConvConfig {
    /// Returns a [`ConvConfigBuilder`] in its default state.
    pub fn builder() -> ConvConfigBuilder {
        ConvConfigBuilder::default()
    }

    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidConfig`] for the first failing check, in this
    /// order: `in_channels`, `out_channels`, `kernel_size`, `stride`, `dilation`
    /// and `groups` must each be non-zero; then `in_channels` and `out_channels`
    /// must each be divisible by `groups`.
    pub fn validate(&self) -> CnnResult<()> {
        let non_zero_checks = [
            (self.in_channels, "in_channels must be greater than 0"),
            (self.out_channels, "out_channels must be greater than 0"),
            (self.kernel_size, "kernel_size must be greater than 0"),
            (self.stride, "stride must be greater than 0"),
            (self.dilation, "dilation must be greater than 0"),
            (self.groups, "groups must be greater than 0"),
        ];
        for &(value, message) in &non_zero_checks {
            if value == 0 {
                return Err(CnnError::InvalidConfig(message.to_string()));
            }
        }

        // `groups` is known to be non-zero here, so the remainders below cannot
        // divide by zero.
        if self.in_channels % self.groups != 0 {
            return Err(CnnError::InvalidConfig(format!(
                "in_channels ({}) must be divisible by groups ({})",
                self.in_channels, self.groups
            )));
        }
        if self.out_channels % self.groups != 0 {
            return Err(CnnError::InvalidConfig(format!(
                "out_channels ({}) must be divisible by groups ({})",
                self.out_channels, self.groups
            )));
        }
        Ok(())
    }

    /// Computes the output length along one dimension for an input of
    /// `input_size` elements.
    ///
    /// The result is `(input_size + 2 * padding - effective_kernel) / stride + 1`
    /// with `effective_kernel = dilation * (kernel_size - 1) + 1`. When the
    /// dilated kernel does not fit into the padded input there is no valid
    /// kernel position, so `0` is returned instead of underflowing.
    ///
    /// # Panics
    ///
    /// Panics with a division by zero if `stride` is `0`; [`ConvConfig::validate`]
    /// rejects such configs.
    #[inline]
    pub fn output_size(&self, input_size: usize) -> usize {
        // Saturating arithmetic keeps absurdly large values from wrapping; for
        // realistic sizes it is identical to the plain operators.
        let effective_kernel = self
            .dilation
            .saturating_mul(self.kernel_size.saturating_sub(1))
            .saturating_add(1);
        let padded_input = self.padding.saturating_mul(2).saturating_add(input_size);

        match padded_input.checked_sub(effective_kernel) {
            Some(span) => span / self.stride + 1,
            // Kernel larger than the padded input: no output positions.
            None => 0,
        }
    }
}
/// Builder for [`ConvConfig`].
///
/// `in_channels`, `out_channels` and `kernel_size` are required. `stride`,
/// `dilation` and `groups` default to `1` when not set; `padding` defaults to
/// `0` and `bias` to `false`.
#[derive(Default)]
pub struct ConvConfigBuilder {
    in_channels: Option<usize>,
    out_channels: Option<usize>,
    kernel_size: Option<usize>,
    // `None` means "not set"; an explicit value (including 0) is kept as given so
    // that validation can reject it instead of it being mistaken for "unset".
    stride: Option<usize>,
    padding: usize,
    dilation: Option<usize>,
    groups: Option<usize>,
    bias: bool,
}

impl ConvConfigBuilder {
    /// Sets the number of input channels (required).
    pub fn in_channels(mut self, in_channels: usize) -> Self {
        self.in_channels = Some(in_channels);
        self
    }

    /// Sets the number of output channels (required).
    pub fn out_channels(mut self, out_channels: usize) -> Self {
        self.out_channels = Some(out_channels);
        self
    }

    /// Sets the kernel size (required).
    pub fn kernel_size(mut self, kernel_size: usize) -> Self {
        self.kernel_size = Some(kernel_size);
        self
    }

    /// Sets the stride; defaults to `1` when never called.
    pub fn stride(mut self, stride: usize) -> Self {
        self.stride = Some(stride);
        self
    }

    /// Sets the padding; defaults to `0` when never called.
    pub fn padding(mut self, padding: usize) -> Self {
        self.padding = padding;
        self
    }

    /// Sets the dilation; defaults to `1` when never called.
    pub fn dilation(mut self, dilation: usize) -> Self {
        self.dilation = Some(dilation);
        self
    }

    /// Sets the group count; defaults to `1` when never called.
    pub fn groups(mut self, groups: usize) -> Self {
        self.groups = Some(groups);
        self
    }

    /// Sets the bias flag; defaults to `false` when never called.
    pub fn bias(mut self, bias: bool) -> Self {
        self.bias = bias;
        self
    }

    /// Assembles the [`ConvConfig`], fills in defaults and validates it.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidConfig`] if a required field was never set or
    /// if `ConvConfig::validate` rejects the result. Explicitly supplied values
    /// are passed through unchanged, so e.g. `.stride(0)` is reported as an
    /// error rather than silently replaced by the default.
    pub fn build(self) -> CnnResult<ConvConfig> {
        let config = ConvConfig {
            in_channels: self.in_channels.ok_or_else(|| {
                CnnError::InvalidConfig("in_channels must be specified".to_string())
            })?,
            out_channels: self.out_channels.ok_or_else(|| {
                CnnError::InvalidConfig("out_channels must be specified".to_string())
            })?,
            kernel_size: self.kernel_size.ok_or_else(|| {
                CnnError::InvalidConfig("kernel_size must be specified".to_string())
            })?,
            stride: self.stride.unwrap_or(1),
            padding: self.padding,
            dilation: self.dilation.unwrap_or(1),
            groups: self.groups.unwrap_or(1),
            bias: self.bias,
        };
        config.validate()?;
        Ok(config)
    }
}
/// Hyper-parameters describing a pooling layer.
///
/// Usually produced through [`PoolConfig::builder`], whose `build` step runs
/// [`PoolConfig::validate`]. All fields are public, so a config assembled by
/// hand or deserialized should be validated explicitly.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PoolConfig {
/// Pooling window extent; must be non-zero.
pub kernel_size: usize,
/// Step between windows; must be non-zero. The builder falls back to
/// `kernel_size` when no stride is given.
pub stride: usize,
/// Padding amount; `output_size` counts it twice (once per side).
pub padding: usize,
/// When `true`, `output_size` rounds the division by `stride` up instead of
/// down.
pub ceil_mode: bool,
}
impl PoolConfig {
    /// Returns a [`PoolConfigBuilder`] in its default state.
    pub fn builder() -> PoolConfigBuilder {
        PoolConfigBuilder::default()
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidConfig`] if `kernel_size` or `stride` is zero
    /// (checked in that order).
    pub fn validate(&self) -> CnnResult<()> {
        if self.kernel_size == 0 {
            return Err(CnnError::InvalidConfig(
                "kernel_size must be greater than 0".to_string(),
            ));
        }
        if self.stride == 0 {
            return Err(CnnError::InvalidConfig(
                "stride must be greater than 0".to_string(),
            ));
        }
        Ok(())
    }

    /// Computes the output length along one dimension for an input of
    /// `input_size` elements.
    ///
    /// With `span = input_size + 2 * padding - kernel_size`, the result is
    /// `floor(span / stride) + 1`, or `ceil(span / stride) + 1` in `ceil_mode`.
    /// Two corner cases are handled explicitly:
    ///
    /// * If the window is larger than the padded input, `0` is returned instead
    ///   of underflowing.
    /// * In `ceil_mode`, a final window that would start at or beyond
    ///   `input_size + padding` (i.e. past the input and its leading padding) is
    ///   not counted, since it would cover no input element. This matches the
    ///   rule PyTorch documents for `ceil_mode` pooling.
    ///
    /// # Panics
    ///
    /// Panics with a division by zero if `stride` is `0`; [`PoolConfig::validate`]
    /// rejects such configs.
    #[inline]
    pub fn output_size(&self, input_size: usize) -> usize {
        let padded_input = self.padding.saturating_mul(2).saturating_add(input_size);
        let span = match padded_input.checked_sub(self.kernel_size) {
            Some(span) => span,
            // Window larger than the padded input: no output positions.
            None => return 0,
        };

        if !self.ceil_mode {
            return span / self.stride + 1;
        }

        // Ceiling division written without `span + stride - 1`, which could
        // overflow for very large operands.
        let steps = span / self.stride + usize::from(span % self.stride != 0);
        let last_window_start = steps.saturating_mul(self.stride);
        if last_window_start >= input_size.saturating_add(self.padding) {
            // The rounded-up extra window lies entirely past the input.
            steps
        } else {
            steps + 1
        }
    }
}
/// Builder for [`PoolConfig`]; see [`PoolConfigBuilder::build`] for how unset
/// fields are filled in and validated.
#[derive(Default)]
pub struct PoolConfigBuilder {
/// Required; `build` fails while this is still `None`.
kernel_size: Option<usize>,
/// Optional; `build` falls back to the kernel size when this is `None`.
stride: Option<usize>,
/// Starts at `0` through `Default`.
padding: usize,
/// Starts at `false` through `Default`.
ceil_mode: bool,
}
impl PoolConfigBuilder {
    /// Records the pooling window size (required).
    pub fn kernel_size(self, kernel_size: usize) -> Self {
        Self {
            kernel_size: Some(kernel_size),
            ..self
        }
    }

    /// Records an explicit stride; without it `build` reuses the kernel size.
    pub fn stride(self, stride: usize) -> Self {
        Self {
            stride: Some(stride),
            ..self
        }
    }

    /// Records the padding amount.
    pub fn padding(self, padding: usize) -> Self {
        Self { padding, ..self }
    }

    /// Records whether output sizes round up instead of down.
    pub fn ceil_mode(self, ceil_mode: bool) -> Self {
        Self { ceil_mode, ..self }
    }

    /// Turns the collected settings into a validated [`PoolConfig`].
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidConfig`] when no kernel size was provided or
    /// when `PoolConfig::validate` rejects the assembled config.
    pub fn build(self) -> CnnResult<PoolConfig> {
        let window = match self.kernel_size {
            Some(size) => size,
            None => {
                return Err(CnnError::InvalidConfig(
                    "kernel_size must be specified".to_string(),
                ))
            }
        };
        // A missing stride means non-overlapping windows: step by the kernel size.
        let step = match self.stride {
            Some(explicit) => explicit,
            None => window,
        };

        let assembled = PoolConfig {
            kernel_size: window,
            stride: step,
            padding: self.padding,
            ceil_mode: self.ceil_mode,
        };
        assembled.validate().map(|_| assembled)
    }
}
/// Hyper-parameters describing a normalization layer.
///
/// Usually produced through [`NormConfig::builder`], whose `build` step runs
/// [`NormConfig::validate`]. All fields are public, so a config assembled by
/// hand or deserialized should be validated explicitly.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NormConfig {
/// Feature count; must be non-zero.
pub num_features: usize,
/// Epsilon term; must be positive and finite. The builder uses `1e-5` when
/// none is set. Presumably added to the variance for numerical stability
/// (confirm against the layer code).
pub eps: f32,
/// Momentum; validated against the range `[0.0, 1.0]`. The builder uses `0.1`
/// when none is set. Presumably the running-statistics update factor (confirm
/// against the layer code).
pub momentum: f32,
/// Flag that is not checked by `validate`; presumably enables learnable
/// scale/shift parameters (confirm against the layer code).
pub affine: bool,
/// Flag that is not checked by `validate`; presumably enables running
/// mean/variance tracking (confirm against the layer code).
pub track_running_stats: bool,
}
impl NormConfig {
    /// Returns a [`NormConfigBuilder`] in its default state.
    pub fn builder() -> NormConfigBuilder {
        NormConfigBuilder::default()
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidConfig`] for the first failing check, in this
    /// order: `num_features` must be non-zero, `eps` must be positive and
    /// finite, and `momentum` must lie in `[0.0, 1.0]`. NaN is rejected for both
    /// floating-point fields.
    pub fn validate(&self) -> CnnResult<()> {
        if self.num_features == 0 {
            return Err(CnnError::InvalidConfig(
                "num_features must be greater than 0".to_string(),
            ));
        }
        if self.eps <= 0.0 || !self.eps.is_finite() {
            return Err(CnnError::InvalidConfig(
                "eps must be positive and finite".to_string(),
            ));
        }
        // A range `contains` test is false for NaN, whereas the pair of
        // comparisons `m < 0.0 || m > 1.0` would let NaN slip through.
        if !(0.0_f32..=1.0).contains(&self.momentum) {
            return Err(CnnError::InvalidConfig(
                "momentum must be in range [0.0, 1.0]".to_string(),
            ));
        }
        Ok(())
    }
}
/// Builder for [`NormConfig`].
///
/// `num_features` is required. `eps` defaults to `1e-5` and `momentum` to `0.1`
/// when not set; `affine` and `track_running_stats` default to `false`.
#[derive(Default)]
pub struct NormConfigBuilder {
    num_features: Option<usize>,
    // `None` means "not set". Keeping explicit values distinct from "unset" is
    // what allows `momentum(0.0)`, a value inside the validated range, to be
    // expressed at all.
    eps: Option<f32>,
    momentum: Option<f32>,
    affine: bool,
    track_running_stats: bool,
}

impl NormConfigBuilder {
    /// Sets the feature count (required).
    pub fn num_features(mut self, num_features: usize) -> Self {
        self.num_features = Some(num_features);
        self
    }

    /// Sets `eps`; defaults to `1e-5` when never called.
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = Some(eps);
        self
    }

    /// Sets `momentum`; defaults to `0.1` when never called.
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = Some(momentum);
        self
    }

    /// Sets the `affine` flag; defaults to `false` when never called.
    pub fn affine(mut self, affine: bool) -> Self {
        self.affine = affine;
        self
    }

    /// Sets the `track_running_stats` flag; defaults to `false` when never called.
    pub fn track_running_stats(mut self, track: bool) -> Self {
        self.track_running_stats = track;
        self
    }

    /// Assembles the [`NormConfig`], fills in defaults and validates it.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidConfig`] if `num_features` was never set or if
    /// `NormConfig::validate` rejects the result. Explicitly supplied values are
    /// passed through unchanged: `momentum(0.0)` yields a momentum of `0.0`, and
    /// `eps(0.0)` is reported as an error rather than replaced by the default.
    pub fn build(self) -> CnnResult<NormConfig> {
        let config = NormConfig {
            num_features: self.num_features.ok_or_else(|| {
                CnnError::InvalidConfig("num_features must be specified".to_string())
            })?,
            eps: self.eps.unwrap_or(1e-5),
            momentum: self.momentum.unwrap_or(0.1),
            affine: self.affine,
            track_running_stats: self.track_running_stats,
        };
        config.validate()?;
        Ok(config)
    }
}
/// Hyper-parameters describing a backbone network.
///
/// Usually produced through [`BackboneConfig::builder`], whose `build` step
/// runs [`BackboneConfig::validate`]. All fields are public, so a config
/// assembled by hand or deserialized should be validated explicitly.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BackboneConfig {
/// Input channel count; must be non-zero. The builder uses `3` when none is set.
pub in_channels: usize,
/// Output feature count; must be non-zero. The builder uses `512` when none
/// is set.
pub out_features: usize,
/// Base channel count; must be non-zero. The builder uses `64` when none is
/// set. Presumably the width of the first stage (confirm against the model
/// code).
pub base_channels: usize,
/// Stage count; must be non-zero. The builder uses `4` when none is set.
pub num_stages: usize,
/// Flag that is not checked by `validate`; presumably enables global pooling
/// of the final feature map (confirm against the model code).
pub global_pool: bool,
/// Dropout value; validated against the range `[0.0, 1.0]`.
pub dropout: f32,
}
impl BackboneConfig {
    /// Returns a [`BackboneConfigBuilder`] in its default state.
    pub fn builder() -> BackboneConfigBuilder {
        BackboneConfigBuilder::default()
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidConfig`] for the first failing check, in this
    /// order: `in_channels`, `out_features`, `base_channels` and `num_stages`
    /// must each be non-zero, and `dropout` must lie in `[0.0, 1.0]` (NaN is
    /// rejected).
    pub fn validate(&self) -> CnnResult<()> {
        let non_zero_checks = [
            (self.in_channels, "in_channels must be greater than 0"),
            (self.out_features, "out_features must be greater than 0"),
            (self.base_channels, "base_channels must be greater than 0"),
            (self.num_stages, "num_stages must be greater than 0"),
        ];
        for &(value, message) in &non_zero_checks {
            if value == 0 {
                return Err(CnnError::InvalidConfig(message.to_string()));
            }
        }

        // A range `contains` test is false for NaN, whereas the pair of
        // comparisons `d < 0.0 || d > 1.0` would let NaN slip through.
        if !(0.0_f32..=1.0).contains(&self.dropout) {
            return Err(CnnError::InvalidConfig(
                "dropout must be in range [0.0, 1.0]".to_string(),
            ));
        }
        Ok(())
    }
}
/// Builder for [`BackboneConfig`].
///
/// Every field is optional. When not set, `in_channels` defaults to `3`,
/// `out_features` to `512`, `base_channels` to `64`, `num_stages` to `4`,
/// `global_pool` to `false` and `dropout` to `0.0`.
#[derive(Default)]
pub struct BackboneConfigBuilder {
    // `None` means "not set"; an explicit value (including 0) is kept as given so
    // that validation can reject it instead of it being mistaken for "unset".
    in_channels: Option<usize>,
    out_features: Option<usize>,
    base_channels: Option<usize>,
    num_stages: Option<usize>,
    global_pool: bool,
    dropout: f32,
}

impl BackboneConfigBuilder {
    /// Sets the input channel count; defaults to `3` when never called.
    pub fn in_channels(mut self, in_channels: usize) -> Self {
        self.in_channels = Some(in_channels);
        self
    }

    /// Sets the output feature count; defaults to `512` when never called.
    pub fn out_features(mut self, out_features: usize) -> Self {
        self.out_features = Some(out_features);
        self
    }

    /// Sets the base channel count; defaults to `64` when never called.
    pub fn base_channels(mut self, base_channels: usize) -> Self {
        self.base_channels = Some(base_channels);
        self
    }

    /// Sets the stage count; defaults to `4` when never called.
    pub fn num_stages(mut self, num_stages: usize) -> Self {
        self.num_stages = Some(num_stages);
        self
    }

    /// Sets the `global_pool` flag; defaults to `false` when never called.
    pub fn global_pool(mut self, global_pool: bool) -> Self {
        self.global_pool = global_pool;
        self
    }

    /// Sets the dropout value; defaults to `0.0` when never called.
    pub fn dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// Assembles the [`BackboneConfig`], fills in defaults and validates it.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidConfig`] if `BackboneConfig::validate` rejects
    /// the result. Explicitly supplied values are passed through unchanged, so
    /// e.g. `.num_stages(0)` is reported as an error rather than silently
    /// replaced by the default.
    pub fn build(self) -> CnnResult<BackboneConfig> {
        let config = BackboneConfig {
            in_channels: self.in_channels.unwrap_or(3),
            out_features: self.out_features.unwrap_or(512),
            base_channels: self.base_channels.unwrap_or(64),
            num_stages: self.num_stages.unwrap_or(4),
            global_pool: self.global_pool,
            dropout: self.dropout,
        };
        config.validate()?;
        Ok(config)
    }
}
/// Hyper-parameters describing a projector head.
///
/// Usually produced through [`ProjectorConfig::builder`], whose `build` step
/// runs [`ProjectorConfig::validate`]. All fields are public, so a config
/// assembled by hand or deserialized should be validated explicitly.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProjectorConfig {
/// Input feature count; must be non-zero. Required by the builder.
pub in_features: usize,
/// Hidden dimension; must be non-zero. The builder uses `in_features * 2`
/// when none is set.
pub hidden_dim: usize,
/// Output feature count; must be non-zero. The builder uses `128` when none
/// is set.
pub out_features: usize,
/// Layer count; must be non-zero. The builder uses `2` when none is set.
pub num_layers: usize,
/// Flag that is not checked by `validate`; presumably enables batch
/// normalization inside the projector (confirm against the model code).
pub use_bn: bool,
}
impl ProjectorConfig {
    /// Hands out a fresh [`ProjectorConfigBuilder`] with nothing configured yet.
    pub fn builder() -> ProjectorConfigBuilder {
        Default::default()
    }

    /// Verifies that every size field is non-zero.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidConfig`] naming the first zero field, checked
    /// in the order `in_features`, `hidden_dim`, `out_features`, `num_layers`.
    pub fn validate(&self) -> CnnResult<()> {
        // Each entry pairs a field value with the message reported when it is zero.
        let size_checks = [
            (self.in_features, "in_features must be greater than 0"),
            (self.hidden_dim, "hidden_dim must be greater than 0"),
            (self.out_features, "out_features must be greater than 0"),
            (self.num_layers, "num_layers must be greater than 0"),
        ];

        match size_checks.iter().find(|(value, _)| *value == 0) {
            Some((_, message)) => Err(CnnError::InvalidConfig(message.to_string())),
            None => Ok(()),
        }
    }
}
/// Builder for [`ProjectorConfig`].
///
/// `in_features` is required. When not set, `hidden_dim` defaults to twice
/// `in_features`, `out_features` to `128`, `num_layers` to `2` and `use_bn` to
/// `false`.
#[derive(Default)]
pub struct ProjectorConfigBuilder {
    in_features: Option<usize>,
    // `None` means "not set"; an explicit value (including 0) is kept as given so
    // that validation can reject it instead of it being mistaken for "unset".
    hidden_dim: Option<usize>,
    out_features: Option<usize>,
    num_layers: Option<usize>,
    use_bn: bool,
}

impl ProjectorConfigBuilder {
    /// Sets the input feature count (required).
    pub fn in_features(mut self, in_features: usize) -> Self {
        self.in_features = Some(in_features);
        self
    }

    /// Sets the hidden dimension; defaults to `in_features * 2` when never called.
    pub fn hidden_dim(mut self, hidden_dim: usize) -> Self {
        self.hidden_dim = Some(hidden_dim);
        self
    }

    /// Sets the output feature count; defaults to `128` when never called.
    pub fn out_features(mut self, out_features: usize) -> Self {
        self.out_features = Some(out_features);
        self
    }

    /// Sets the layer count; defaults to `2` when never called.
    pub fn num_layers(mut self, num_layers: usize) -> Self {
        self.num_layers = Some(num_layers);
        self
    }

    /// Sets the `use_bn` flag; defaults to `false` when never called.
    pub fn use_bn(mut self, use_bn: bool) -> Self {
        self.use_bn = use_bn;
        self
    }

    /// Assembles the [`ProjectorConfig`], fills in defaults and validates it.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidConfig`] if `in_features` was never set or if
    /// `ProjectorConfig::validate` rejects the result. Explicitly supplied
    /// values are passed through unchanged, so e.g. `.hidden_dim(0)` is reported
    /// as an error rather than silently replaced by the default.
    pub fn build(self) -> CnnResult<ProjectorConfig> {
        let in_features = self.in_features.ok_or_else(|| {
            CnnError::InvalidConfig("in_features must be specified".to_string())
        })?;
        let config = ProjectorConfig {
            in_features,
            // Saturating so that an extreme `in_features` cannot overflow while
            // deriving the default.
            hidden_dim: self.hidden_dim.unwrap_or(in_features.saturating_mul(2)),
            out_features: self.out_features.unwrap_or(128),
            num_layers: self.num_layers.unwrap_or(2),
            use_bn: self.use_bn,
        };
        config.validate()?;
        Ok(config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 3 -> 64 channel convolution (kernel 3, stride 1, padding 1)
    /// shared by the convolution tests.
    fn same_padding_conv() -> ConvConfig {
        ConvConfig::builder()
            .in_channels(3)
            .out_channels(64)
            .kernel_size(3)
            .stride(1)
            .padding(1)
            .build()
            .unwrap()
    }

    /// Builds the kernel-2, stride-2 pooling config shared by the pooling tests.
    fn halving_pool() -> PoolConfig {
        PoolConfig::builder()
            .kernel_size(2)
            .stride(2)
            .build()
            .unwrap()
    }

    /// Builds a 64 -> 128 channel, kernel-3 convolution with the given group count.
    fn grouped_conv(groups: usize) -> CnnResult<ConvConfig> {
        ConvConfig::builder()
            .in_channels(64)
            .out_channels(128)
            .kernel_size(3)
            .groups(groups)
            .build()
    }

    /// Explicit fields are kept and unset ones fall back to their defaults.
    #[test]
    fn test_conv_config_builder() {
        let ConvConfig {
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            ..
        } = same_padding_conv();

        assert_eq!(in_channels, 3);
        assert_eq!(out_channels, 64);
        assert_eq!(kernel_size, 3);
        assert_eq!(stride, 1);
        assert_eq!(padding, 1);
        assert_eq!(dilation, 1);
        assert_eq!(groups, 1);
    }

    /// A 3x3 kernel with stride 1 and padding 1 preserves the input size.
    #[test]
    fn test_conv_output_size() {
        let conv = same_padding_conv();
        assert_eq!(conv.output_size(224), 224);
    }

    /// Group counts must divide both channel counts.
    #[test]
    fn test_conv_validation_grouped() {
        let divides_both = grouped_conv(32);
        assert!(divides_both.is_ok());

        let divides_neither = grouped_conv(3);
        assert!(divides_neither.is_err());
    }

    /// Explicit pooling fields are kept and padding defaults to zero.
    #[test]
    fn test_pool_config_builder() {
        let PoolConfig {
            kernel_size,
            stride,
            padding,
            ..
        } = halving_pool();

        assert_eq!(kernel_size, 2);
        assert_eq!(stride, 2);
        assert_eq!(padding, 0);
    }

    /// Kernel 2 with stride 2 halves an even input size.
    #[test]
    fn test_pool_output_size() {
        let pool = halving_pool();
        assert_eq!(pool.output_size(224), 112);
    }

    /// Normalization settings passed to the builder end up in the config.
    #[test]
    fn test_norm_config_builder() {
        let norm = NormConfig::builder()
            .num_features(64)
            .eps(1e-5)
            .momentum(0.1)
            .affine(true)
            .track_running_stats(true)
            .build()
            .unwrap();

        assert_eq!(norm.num_features, 64);
        assert_eq!(norm.eps, 1e-5);
        assert!(norm.affine);
    }

    /// Backbone settings passed to the builder end up in the config.
    #[test]
    fn test_backbone_config_builder() {
        let backbone = BackboneConfig::builder()
            .in_channels(3)
            .out_features(512)
            .base_channels(64)
            .num_stages(4)
            .global_pool(true)
            .build()
            .unwrap();

        assert_eq!(backbone.in_channels, 3);
        assert_eq!(backbone.out_features, 512);
        assert_eq!(backbone.num_stages, 4);
        assert!(backbone.global_pool);
    }

    /// Projector settings passed to the builder end up in the config.
    #[test]
    fn test_projector_config_builder() {
        let projector = ProjectorConfig::builder()
            .in_features(512)
            .hidden_dim(1024)
            .out_features(128)
            .num_layers(2)
            .use_bn(true)
            .build()
            .unwrap();

        assert_eq!(projector.in_features, 512);
        assert_eq!(projector.hidden_dim, 1024);
        assert_eq!(projector.out_features, 128);
        assert!(projector.use_bn);
    }
}