use crate::blocks::BlockKind;
/// Supported ResNet depth variants.
///
/// Each variant selects a per-stage residual-block count (see
/// [`ResNetDepth::layers`]) and a residual block kind (see
/// [`ResNetDepth::block`]).
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ResNetDepth {
    /// Stage block counts `[2, 2, 2, 2]`, basic blocks.
    R18,
    /// Stage block counts `[3, 4, 6, 3]`, basic blocks.
    R34,
    /// Stage block counts `[3, 4, 6, 3]`, bottleneck blocks.
    R50,
    /// Stage block counts `[3, 4, 23, 3]`, bottleneck blocks.
    R101,
    /// Stage block counts `[3, 8, 36, 3]`, bottleneck blocks.
    R152,
}
impl ResNetDepth {
    /// Number of residual blocks in each of the four stages.
    ///
    /// For example, `R50` yields `[3, 4, 6, 3]` and `R152` yields `[3, 8, 36, 3]`.
    pub fn layers(self) -> [usize; 4] {
        match self {
            Self::R18 => [2, 2, 2, 2],
            // R34 and R50 share stage counts; they differ only in block kind.
            Self::R34 | Self::R50 => [3, 4, 6, 3],
            Self::R101 => [3, 4, 23, 3],
            Self::R152 => [3, 8, 36, 3],
        }
    }

    /// Residual block kind used by this depth.
    ///
    /// The deeper variants (`R50`, `R101`, `R152`) use `BlockKind::Bottleneck`.
    /// The shallower ones (`R18`, `R34`) use `BlockKind::Basic`.
    pub fn block(self) -> BlockKind {
        match self {
            Self::R50 | Self::R101 | Self::R152 => BlockKind::Bottleneck,
            Self::R18 | Self::R34 => BlockKind::Basic,
        }
    }

    /// Expansion factor of this depth's block kind.
    ///
    /// This delegates to `BlockKind::expansion` on the kind returned by
    /// [`ResNetDepth::block`].
    pub fn expansion(self) -> usize {
        let kind = self.block();
        kind.expansion()
    }
}
/// Selects what a network built from a [`ResNetConfig`] outputs.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare modes directly
/// (e.g. `mode == OutputMode::Features`) or use them as map keys.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum OutputMode {
    /// Classification output over `num_classes` classes.
    ///
    /// NOTE(review): presumably this means a classifier head with
    /// `num_classes` outputs; confirm against the model builder.
    Classification {
        /// Number of target classes.
        num_classes: usize,
    },
    /// Feature output with no classification head.
    ///
    /// NOTE(review): the exact feature tensor returned is defined by the
    /// model builder, not here; confirm there.
    Features,
}
/// Settings describing a ResNet: its depth variant, its output mode, and a
/// maximum batch size.
#[derive(Debug, Clone)]
pub struct ResNetConfig {
    /// Depth variant; determines stage block counts and block kind.
    pub depth: ResNetDepth,
    /// What the network outputs (classification or features).
    pub output: OutputMode,
    /// Upper bound on batch size. `ResNetConfig::new` initialises it to `1`.
    pub max_batch_size: usize,
}
impl ResNetConfig {
    /// Creates a configuration with the given depth and output mode.
    ///
    /// `max_batch_size` starts at `1`; use
    /// [`ResNetConfig::with_max_batch_size`] to change it.
    #[must_use]
    pub fn new(depth: ResNetDepth, output: OutputMode) -> Self {
        Self {
            depth,
            output,
            max_batch_size: 1,
        }
    }

    /// Returns this configuration with `max_batch_size` replaced.
    ///
    /// This consumes `self` and returns the updated value. A bare
    /// `cfg.with_max_batch_size(8);` would throw the change away, hence
    /// `#[must_use]`.
    ///
    /// NOTE(review): the value is stored unvalidated, so `0` is accepted;
    /// confirm whether downstream code treats that as invalid.
    #[must_use]
    pub fn with_max_batch_size(mut self, max_batch_size: usize) -> Self {
        self.max_batch_size = max_batch_size;
        self
    }
}