use ghostflow_core::Tensor;
use crate::deep::layers::{Conv2d, Dense, BatchNorm2d};
use crate::deep::activations::ReLU;
pub struct DARTSCell {
operations: Vec<Vec<MixedOp>>,
num_nodes: usize,
}
struct MixedOp {
ops: Vec<Operation>,
weights: Vec<f32>,
}
enum Operation {
SepConv3x3(Conv2d, BatchNorm2d),
SepConv5x5(Conv2d, BatchNorm2d),
DilConv3x3(Conv2d, BatchNorm2d),
MaxPool3x3,
AvgPool3x3,
Identity,
Zero,
}
impl Operation {
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
match self {
Operation::SepConv3x3(conv, bn) => {
let mut out = conv.forward(x, training);
out = bn.forward(&out, training);
ReLU::new().forward(&out)
}
Operation::SepConv5x5(conv, bn) => {
let mut out = conv.forward(x, training);
out = bn.forward(&out, training);
ReLU::new().forward(&out)
}
Operation::DilConv3x3(conv, bn) => {
let mut out = conv.forward(x, training);
out = bn.forward(&out, training);
ReLU::new().forward(&out)
}
Operation::MaxPool3x3 => self.max_pool(x),
Operation::AvgPool3x3 => self.avg_pool(x),
Operation::Identity => x.clone(),
Operation::Zero => self.zeros_like(x),
}
}
fn max_pool(&self, x: &Tensor) -> Tensor {
x.clone() }
fn avg_pool(&self, x: &Tensor) -> Tensor {
x.clone() }
fn zeros_like(&self, x: &Tensor) -> Tensor {
Tensor::from_slice(&vec![0.0f32; x.data_f32().len()], x.dims()).unwrap()
}
}
impl MixedOp {
fn new(channels: usize) -> Self {
MixedOp {
ops: vec![
Operation::SepConv3x3(Conv2d::new(channels, channels, (3, 3)).padding((1, 1)), BatchNorm2d::new(channels)),
Operation::SepConv5x5(Conv2d::new(channels, channels, (5, 5)).padding((2, 2)), BatchNorm2d::new(channels)),
Operation::DilConv3x3(Conv2d::new(channels, channels, (3, 3)).padding((2, 2)), BatchNorm2d::new(channels)),
Operation::MaxPool3x3,
Operation::AvgPool3x3,
Operation::Identity,
Operation::Zero,
],
weights: vec![1.0 / 7.0; 7],
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut result = vec![0.0f32; x.data_f32().len()];
for (op, &weight) in self.ops.iter_mut().zip(self.weights.iter()) {
let op_out = op.forward(x, training);
let op_data = op_out.data_f32();
for (i, &val) in op_data.iter().enumerate() {
result[i] += weight * val;
}
}
Tensor::from_slice(&result, x.dims()).unwrap()
}
}
impl DARTSCell {
pub fn new(channels: usize, num_nodes: usize) -> Self {
let mut operations = Vec::new();
for i in 0..num_nodes {
let mut node_ops = Vec::new();
for _ in 0..=i+1 {
node_ops.push(MixedOp::new(channels));
}
operations.push(node_ops);
}
DARTSCell {
operations,
num_nodes,
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut states = vec![x.clone()];
for node_ops in &mut self.operations {
let mut node_out = vec![0.0f32; x.data_f32().len()];
for (i, op) in node_ops.iter_mut().enumerate() {
let op_result = op.forward(&states[i], training);
let op_data = op_result.data_f32();
for (j, &val) in op_data.iter().enumerate() {
node_out[j] += val;
}
}
states.push(Tensor::from_slice(&node_out, x.dims()).unwrap());
}
states[1..].iter().fold(states[1].clone(), |acc, s| acc)
}
}
pub struct DARTS {
stem: Conv2d,
cells: Vec<DARTSCell>,
fc: Dense,
}
impl DARTS {
pub fn new(num_classes: usize, num_cells: usize) -> Self {
DARTS {
stem: Conv2d::new(3, 48, (3, 3)).padding((1, 1)),
cells: (0..num_cells).map(|_| DARTSCell::new(48, 4)).collect(),
fc: Dense::new(48, num_classes),
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.stem.forward(x, training);
for cell in &mut self.cells {
out = cell.forward(&out, training);
}
self.fc.forward(&out, training)
}
}
pub struct ENASController {
lstm: Dense, classifier: Dense,
}
impl ENASController {
pub fn new(hidden_size: usize, num_ops: usize) -> Self {
ENASController {
lstm: Dense::new(hidden_size, hidden_size),
classifier: Dense::new(hidden_size, num_ops),
}
}
pub fn sample_architecture(&mut self, num_layers: usize, training: bool) -> Vec<usize> {
let mut architecture = Vec::new();
let mut hidden = Tensor::from_slice(&vec![0.1f32; 128], &[1, 128]).unwrap();
for _ in 0..num_layers {
hidden = self.lstm.forward(&hidden, training);
let logits = self.classifier.forward(&hidden, training);
let op_idx = self.argmax(&logits);
architecture.push(op_idx);
}
architecture
}
fn argmax(&self, x: &Tensor) -> usize {
let data = x.data_f32();
let mut max_idx = 0;
let mut max_val = data[0];
for (i, &val) in data.iter().enumerate() {
if val > max_val {
max_val = val;
max_idx = i;
}
}
max_idx
}
}
pub struct NASNetCell {
operations: Vec<NASNetOp>,
}
struct NASNetOp {
conv: Conv2d,
bn: BatchNorm2d,
}
impl NASNetOp {
fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
NASNetOp {
conv: Conv2d::new(in_channels, out_channels, (kernel_size, kernel_size)).padding((kernel_size / 2, kernel_size / 2)),
bn: BatchNorm2d::new(out_channels),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.conv.forward(x, training);
out = self.bn.forward(&out, training);
ReLU::new().forward(&out)
}
}
impl NASNetCell {
pub fn new(channels: usize) -> Self {
NASNetCell {
operations: vec![
NASNetOp::new(channels, channels, 3),
NASNetOp::new(channels, channels, 5),
NASNetOp::new(channels, channels, 3),
NASNetOp::new(channels, channels, 5),
],
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut outputs = Vec::new();
for op in &mut self.operations {
outputs.push(op.forward(x, training));
}
self.combine_outputs(&outputs)
}
fn combine_outputs(&self, outputs: &[Tensor]) -> Tensor {
if outputs.is_empty() {
return Tensor::from_slice(&[0.0f32], &[1, 1]).unwrap();
}
let mut result = outputs[0].data_f32().to_vec();
for output in &outputs[1..] {
let data = output.data_f32();
for (i, &val) in data.iter().enumerate() {
result[i] += val;
}
}
Tensor::from_slice(&result, outputs[0].dims()).unwrap()
}
}
pub struct NASNet {
stem: Conv2d,
cells: Vec<NASNetCell>,
fc: Dense,
}
impl NASNet {
pub fn new(num_classes: usize, num_cells: usize) -> Self {
NASNet {
stem: Conv2d::new(3, 96, (3, 3)).padding((1, 1)),
cells: (0..num_cells).map(|_| NASNetCell::new(96)).collect(),
fc: Dense::new(96, num_classes),
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.stem.forward(x, training);
for cell in &mut self.cells {
out = cell.forward(&out, training);
}
self.fc.forward(&out, training)
}
}
pub struct AmoebaNet {
stem: Conv2d,
cells: Vec<AmoebaCell>,
fc: Dense,
}
struct AmoebaCell {
ops: Vec<(Conv2d, BatchNorm2d)>,
}
impl AmoebaCell {
fn new(channels: usize) -> Self {
AmoebaCell {
ops: vec![
(Conv2d::new(channels, channels, (3, 3)).padding((1, 1)), BatchNorm2d::new(channels)),
(Conv2d::new(channels, channels, (5, 5)).padding((2, 2)), BatchNorm2d::new(channels)),
(Conv2d::new(channels, channels, (7, 7)).padding((3, 3)), BatchNorm2d::new(channels)),
],
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut outputs = Vec::new();
for (conv, bn) in &mut self.ops {
let mut out = conv.forward(x, training);
out = bn.forward(&out, training);
out = ReLU::new().forward(&out);
outputs.push(out);
}
self.combine_outputs(&outputs)
}
fn combine_outputs(&self, outputs: &[Tensor]) -> Tensor {
if outputs.is_empty() {
return Tensor::from_slice(&[0.0f32], &[1, 1]).unwrap();
}
let mut result = outputs[0].data_f32().to_vec();
for output in &outputs[1..] {
let data = output.data_f32();
for (i, &val) in data.iter().enumerate() {
result[i] += val;
}
}
Tensor::from_slice(&result, outputs[0].dims()).unwrap()
}
}
impl AmoebaNet {
pub fn new(num_classes: usize, num_cells: usize) -> Self {
AmoebaNet {
stem: Conv2d::new(3, 128, (3, 3)).padding((1, 1)),
cells: (0..num_cells).map(|_| AmoebaCell::new(128)).collect(),
fc: Dense::new(128, num_classes),
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.stem.forward(x, training);
for cell in &mut self.cells {
out = cell.forward(&out, training);
}
self.fc.forward(&out, training)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_darts() {
let mut model = DARTS::new(10, 8);
let input = Tensor::from_slice(&vec![0.5f32; 1 * 3 * 32 * 32], &[1, 3, 32, 32]).unwrap();
let output = model.forward(&input, false);
assert_eq!(output.dims()[1], 10);
}
#[test]
fn test_nasnet() {
let mut model = NASNet::new(1000, 6);
let input = Tensor::from_slice(&vec![0.5f32; 1 * 3 * 224 * 224], &[1, 3, 224, 224]).unwrap();
let output = model.forward(&input, false);
assert_eq!(output.dims()[1], 1000);
}
}