use std::collections::VecDeque;
/// Candidate operation that can be assigned to an edge of an architecture graph.
///
/// The enum carries no weights; it is a tag consumed by the cost models
/// [`OpType::num_params`] and [`OpType::flops`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OpType {
    /// Parameter-free op.
    Identity,
    /// 3x3 convolution, costed as `9 * C * C` weights.
    Conv3x3,
    /// 5x5 convolution, costed as `25 * C * C` weights.
    Conv5x5,
    /// Dilated 3x3 convolution; costed exactly like [`OpType::Conv3x3`].
    DilatedConv3x3,
    /// Depthwise-separable 3x3 convolution, costed as `9 * C + C * C` weights.
    DepthwiseSep3x3,
    /// 3x3 max pooling; parameter-free.
    MaxPool3x3,
    /// 3x3 average pooling; parameter-free.
    AvgPool3x3,
    /// Parameter-free op; the cost models treat it like [`OpType::Identity`].
    Skip,
    /// Parameter-free op; the cost models treat it like [`OpType::Identity`].
    Zero,
    /// Fully connected layer; the payload is the output width used by the cost models.
    Linear(usize),
    /// Recurrent cell, costed as `4 * C * C` weights.
    GRU,
    /// Recurrent cell, costed identically to [`OpType::GRU`].
    LSTM,
}

impl OpType {
    /// Estimated number of weights of this operation for `in_channels` channels.
    ///
    /// Convolutions and recurrent cells are costed as mapping `in_channels` to
    /// `in_channels`; `Linear(out)` maps `in_channels` to `out`. None of the
    /// formulas includes a bias term.
    ///
    /// The result saturates at `usize::MAX` instead of overflowing.
    pub fn num_params(&self, in_channels: usize) -> usize {
        let channels_sq = in_channels.saturating_mul(in_channels);
        match self {
            Self::Identity
            | Self::Skip
            | Self::Zero
            | Self::MaxPool3x3
            | Self::AvgPool3x3 => 0,
            Self::Conv3x3 | Self::DilatedConv3x3 => channels_sq.saturating_mul(9),
            Self::Conv5x5 => channels_sq.saturating_mul(25),
            Self::DepthwiseSep3x3 => in_channels.saturating_mul(9).saturating_add(channels_sq),
            Self::Linear(out) => in_channels.saturating_mul(*out),
            Self::GRU | Self::LSTM => channels_sq.saturating_mul(4),
        }
    }

    /// Estimated operation count of this op on a `spatial x spatial` feature map
    /// with `in_channels` channels.
    ///
    /// `Linear`, `GRU` and `LSTM` ignore `spatial`. Every parameter-free op is
    /// charged `in_channels * spatial^2`.
    ///
    /// The result saturates at `usize::MAX` instead of overflowing.
    ///
    /// NOTE(review): `DepthwiseSep3x3` is not multiplied by 2 like the dense
    /// convolutions, and `Zero`/`Identity`/`Skip` are charged like pooling.
    /// Both are preserved from the existing model; confirm they are intended.
    pub fn flops(&self, in_channels: usize, spatial: usize) -> usize {
        let spatial_sq = spatial.saturating_mul(spatial);
        let channels_sq = in_channels.saturating_mul(in_channels);
        match self {
            // 9 taps * 2 (multiply + add)
            Self::Conv3x3 | Self::DilatedConv3x3 => {
                channels_sq.saturating_mul(18).saturating_mul(spatial_sq)
            }
            // 25 taps * 2 (multiply + add)
            Self::Conv5x5 => channels_sq.saturating_mul(50).saturating_mul(spatial_sq),
            Self::DepthwiseSep3x3 => in_channels
                .saturating_mul(9)
                .saturating_add(channels_sq)
                .saturating_mul(spatial_sq),
            Self::Linear(out) => in_channels.saturating_mul(*out).saturating_mul(2),
            Self::GRU | Self::LSTM => channels_sq.saturating_mul(8),
            // Listed explicitly (no `_` arm) so that adding a variant forces a
            // decision about its cost here.
            Self::Identity
            | Self::Skip
            | Self::Zero
            | Self::MaxPool3x3
            | Self::AvgPool3x3 => in_channels.saturating_mul(spatial_sq),
        }
    }
}
/// A vertex of an architecture graph.
#[derive(Debug, Clone)]
pub struct ArchNode {
/// Numeric identifier. `SearchSpace::sample_random` sets it to the node's
/// position in `Architecture::nodes`.
pub id: usize,
/// Human-readable label, e.g. `cell0_node1` as produced by
/// `SearchSpace::sample_random`.
pub name: String,
/// Channel count recorded for this node; `SearchSpace::sample_random` sets it
/// to the architecture-wide channel width.
pub output_channels: usize,
}
/// A directed edge of an architecture graph, labelled with the operation applied along it.
#[derive(Debug, Clone)]
pub struct ArchEdge {
/// Index into `Architecture::nodes` of the source node.
pub from: usize,
/// Index into `Architecture::nodes` of the destination node.
pub to: usize,
/// Operation assigned to this edge.
pub op: OpType,
}
/// A candidate network described as a directed graph of nodes and operation-labelled edges.
#[derive(Debug, Clone)]
pub struct Architecture {
/// Graph vertices; edges refer to them by position in this vector.
pub nodes: Vec<ArchNode>,
/// Directed, operation-labelled edges between entries of `nodes`.
pub edges: Vec<ArchEdge>,
/// Number of cells; `SearchSpace::sample_random` creates
/// `n_cells * n_nodes_per_cell` nodes.
pub n_cells: usize,
/// Channel width passed as `in_channels` to every edge's cost model in
/// `total_params` and `total_flops`.
pub channels: usize,
/// Number of output classes. Stored only; not read by the methods visible here.
pub n_classes: usize,
}
impl Architecture {
    /// Creates an architecture with the given hyper-parameters and an empty graph
    /// (no nodes, no edges).
    pub fn new(n_cells: usize, channels: usize, n_classes: usize) -> Self {
        Self {
            nodes: Vec::new(),
            edges: Vec::new(),
            n_cells,
            channels,
            n_classes,
        }
    }

    /// Sum of [`OpType::num_params`] over all edges, evaluated at `self.channels`.
    ///
    /// The accumulation saturates at `usize::MAX` instead of overflowing.
    pub fn total_params(&self) -> usize {
        self.edges
            .iter()
            .map(|edge| edge.op.num_params(self.channels))
            .fold(0usize, usize::saturating_add)
    }

    /// Sum of [`OpType::flops`] over all edges, evaluated at `self.channels` and
    /// the given `spatial` side length.
    ///
    /// The accumulation saturates at `usize::MAX` instead of overflowing.
    pub fn total_flops(&self, spatial: usize) -> usize {
        self.edges
            .iter()
            .map(|edge| edge.op.flops(self.channels, spatial))
            .fold(0usize, usize::saturating_add)
    }

    /// Returns node indices in a topological order (Kahn's algorithm).
    ///
    /// Nodes are identified by their position in `self.nodes`, not by
    /// `ArchNode::id`. Edges whose `from` or `to` is out of range are ignored.
    ///
    /// If the graph contains a cycle, the nodes on the cycle and everything
    /// reachable only through it are omitted, so the returned vector is shorter
    /// than `self.nodes`.
    pub fn topological_sort(&self) -> Vec<usize> {
        let node_count = self.nodes.len();
        // Number of not-yet-emitted predecessors per node (parallel edges count
        // once each, matching the duplicates kept in `successors`).
        let mut pending = vec![0usize; node_count];
        let mut successors: Vec<Vec<usize>> = vec![Vec::new(); node_count];

        for edge in &self.edges {
            if edge.from < node_count && edge.to < node_count {
                successors[edge.from].push(edge.to);
                pending[edge.to] = pending[edge.to].saturating_add(1);
            }
        }

        // Seed with all source nodes in ascending index order.
        let mut ready: VecDeque<usize> = (0..node_count).filter(|&i| pending[i] == 0).collect();
        let mut order = Vec::with_capacity(node_count);

        while let Some(node) = ready.pop_front() {
            order.push(node);
            for &next in &successors[node] {
                // Cannot underflow: each decrement pairs with one increment above.
                pending[next] -= 1;
                if pending[next] == 0 {
                    ready.push_back(next);
                }
            }
        }
        order
    }

    /// Returns `true` when every node appears in [`Architecture::topological_sort`],
    /// i.e. the graph formed by the in-range edges is acyclic.
    ///
    /// Edges that point at non-existent nodes are ignored by the sort and
    /// therefore do not make the architecture invalid.
    pub fn is_valid(&self) -> bool {
        self.topological_sort().len() == self.nodes.len()
    }
}
/// Description of the space from which architectures are sampled.
#[derive(Debug, Clone)]
pub struct SearchSpace {
    /// Operations that may be assigned to an edge.
    pub operations: Vec<OpType>,
    /// Number of nodes created per cell.
    pub n_nodes_per_cell: usize,
    /// Maximum number of preceding nodes in the same cell that feed a node.
    pub n_input_nodes: usize,
    /// Candidate channel widths; one is picked per sampled architecture.
    pub channels: Vec<usize>,
    /// Inclusive `(low, high)` range for the number of cells. When `low >= high`
    /// the sampler always uses `low`.
    pub n_cells_range: (usize, usize),
}
impl SearchSpace {
    /// Builds a search space with eight candidate operations (skip, zero, two
    /// poolings, four convolution flavours), two input nodes per node, channel
    /// widths `[16, 32, 64, 128]` and between 2 and 20 cells.
    ///
    /// `n_nodes` is the number of nodes per cell.
    pub fn darts_like(n_nodes: usize) -> Self {
        Self {
            operations: vec![
                OpType::Skip,
                OpType::Zero,
                OpType::MaxPool3x3,
                OpType::AvgPool3x3,
                OpType::Conv3x3,
                OpType::Conv5x5,
                OpType::DilatedConv3x3,
                OpType::DepthwiseSep3x3,
            ],
            n_nodes_per_cell: n_nodes,
            n_input_nodes: 2,
            channels: vec![16, 32, 64, 128],
            n_cells_range: (2, 20),
        }
    }

    /// Size estimate of the space: `operations.len() ^ (n_input_nodes * n_nodes_per_cell)`,
    /// saturating at `u64::MAX`.
    ///
    /// The estimate does not take the number of cells or the channel choices
    /// into account.
    pub fn n_architectures(&self) -> u64 {
        let n_ops = self.operations.len() as u64;
        let n_edges = (self.n_input_nodes as u64).saturating_mul(self.n_nodes_per_cell as u64);
        // Clamp before narrowing: a plain `as u32` would wrap large edge counts
        // (e.g. 2^32 -> 0) and report a tiny space. Clamping is exact here,
        // because any base >= 2 already saturates long before 2^32.
        let exponent = n_edges.min(u64::from(u32::MAX)) as u32;
        n_ops.saturating_pow(exponent)
    }

    /// Draws a random architecture from this space.
    ///
    /// Random draws happen in this order: the cell count (only when
    /// `n_cells_range.0 < n_cells_range.1`), the channel width, then one
    /// operation per edge in creation order. `n_classes` is fixed to 10.
    ///
    /// Wiring: node `j` of a cell receives one edge from each of the
    /// `min(j, n_input_nodes)` nodes directly preceding it in the same cell. A
    /// node with no such predecessor receives a single edge from global node 0
    /// instead; that edge is skipped when it would be a self-loop.
    ///
    /// NOTE(review): the first node of every cell after the first is therefore
    /// linked to global node 0 rather than to the previous cell. Confirm this is
    /// intended.
    ///
    /// # Panics
    ///
    /// Indexing panics if the RNG returns an out-of-range index. Presumably
    /// `random_range` itself panics on an empty range, i.e. when `channels` is
    /// empty, or when `operations` is empty and an edge is needed. Verify
    /// against the `scirs2_core` RNG documentation.
    pub fn sample_random(
        &self,
        rng: &mut (impl scirs2_core::random::Rng + ?Sized),
    ) -> Architecture {
        use scirs2_core::random::{Rng, RngExt};

        let (cells_lo, cells_hi) = self.n_cells_range;
        let n_cells = if cells_lo >= cells_hi {
            cells_lo
        } else {
            rng.random_range(cells_lo..=cells_hi)
        };
        let ch_idx = rng.random_range(0..self.channels.len());
        let channels = self.channels[ch_idx];
        let n_classes = 10;
        let mut arch = Architecture::new(n_cells, channels, n_classes);

        // Nodes are laid out cell by cell, so a node's id equals its index.
        for c in 0..n_cells {
            for j in 0..self.n_nodes_per_cell {
                arch.nodes.push(ArchNode {
                    id: c * self.n_nodes_per_cell + j,
                    name: format!("cell{}_node{}", c, j),
                    output_channels: channels,
                });
            }
        }

        for c in 0..n_cells {
            for j in 0..self.n_nodes_per_cell {
                let to = c * self.n_nodes_per_cell + j;
                let n_inputs = j.min(self.n_input_nodes);
                // `.max(1)` keeps one iteration for nodes without in-cell inputs.
                for k in 0..n_inputs.max(1) {
                    let from = if n_inputs == 0 {
                        0
                    } else {
                        c * self.n_nodes_per_cell + j.saturating_sub(k + 1)
                    };
                    if from == to {
                        // Only reachable for global node 0; no RNG draw is made.
                        continue;
                    }
                    let op_idx = rng.random_range(0..self.operations.len());
                    arch.edges.push(ArchEdge {
                        from,
                        to,
                        op: self.operations[op_idx].clone(),
                    });
                }
            }
        }
        arch
    }
}
// Unit tests for the cost model, graph utilities and random sampler.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::random::{rngs::StdRng, SeedableRng};
// A seeded sample from the DARTS-like space has cells, channels and nodes.
#[test]
fn test_search_space_sample_produces_arch() {
let space = SearchSpace::darts_like(4);
let mut rng = StdRng::seed_from_u64(42);
let arch = space.sample_random(&mut rng);
assert!(arch.n_cells > 0);
assert!(arch.channels > 0);
assert!(!arch.nodes.is_empty());
}
// A single Conv3x3 edge at 32 channels contributes 9 * 32 * 32 parameters
// and a non-zero FLOP count.
#[test]
fn test_architecture_params_nonzero_for_conv() {
let mut arch = Architecture::new(2, 32, 10);
arch.nodes.push(ArchNode {
id: 0,
name: "node0".into(),
output_channels: 32,
});
arch.nodes.push(ArchNode {
id: 1,
name: "node1".into(),
output_channels: 32,
});
arch.edges.push(ArchEdge {
from: 0,
to: 1,
op: OpType::Conv3x3,
});
assert_eq!(arch.total_params(), 9 * 32 * 32);
assert!(arch.total_flops(8) > 0);
}
// A three-node chain 0 -> 1 -> 2 is valid and sorts with node 0 first.
#[test]
fn test_topological_sort_linear_dag() {
let mut arch = Architecture::new(1, 32, 10);
for i in 0..3_usize {
arch.nodes.push(ArchNode {
id: i,
name: format!("n{}", i),
output_channels: 32,
});
}
arch.edges.push(ArchEdge {
from: 0,
to: 1,
op: OpType::Skip,
});
arch.edges.push(ArchEdge {
from: 1,
to: 2,
op: OpType::Conv3x3,
});
assert!(arch.is_valid());
let order = arch.topological_sort();
assert_eq!(order.len(), 3);
assert_eq!(order[0], 0);
}
// The size estimate of the DARTS-like space is strictly positive.
#[test]
fn test_n_architectures_positive() {
let space = SearchSpace::darts_like(4);
assert!(space.n_architectures() > 0);
}
// Pooling and skip operations carry no parameters.
#[test]
fn test_op_type_zero_params_for_pooling() {
assert_eq!(OpType::MaxPool3x3.num_params(64), 0);
assert_eq!(OpType::AvgPool3x3.num_params(64), 0);
assert_eq!(OpType::Skip.num_params(64), 0);
}
}