use axonml_autograd::Variable;
use axonml_nn::{BatchNorm2d, Conv2d, Module, Parameter, ReLU};
use axonml_tensor::Tensor;
use crate::ops::{FaceDetection, nms};
/// Single BlazeBlock: a depthwise-separable residual unit
/// (depthwise 3x3 -> BN -> ReLU -> pointwise 1x1 -> BN, plus shortcut).
struct BlazeBlock {
/// Depthwise 3x3 convolution (groups == in_channels); carries the stride.
dw_conv: Conv2d,
dw_bn: BatchNorm2d,
/// Pointwise 1x1 convolution mapping to the output channel count.
pw_conv: Conv2d,
pw_bn: BatchNorm2d,
/// Optional 1x1 conv + BN on the shortcut; present when the block changes
/// channel count or uses stride != 1 (see `new`).
project: Option<(Conv2d, BatchNorm2d)>,
relu: ReLU,
}
impl BlazeBlock {
    /// Builds a block going from `in_channels` to `out_channels` at the given
    /// `stride`. A 1x1 projection is added on the shortcut whenever the shape
    /// changes (different channel count or stride != 1).
    fn new(in_channels: usize, out_channels: usize, stride: usize) -> Self {
        let needs_proj = in_channels != out_channels || stride != 1;
        let project = needs_proj.then(|| {
            (
                Conv2d::with_options(
                    in_channels,
                    out_channels,
                    (1, 1),
                    (stride, stride),
                    (0, 0),
                    false,
                ),
                BatchNorm2d::new(out_channels),
            )
        });
        Self {
            // Depthwise 3x3: one group per channel, carries the stride.
            dw_conv: Conv2d::with_groups(
                in_channels,
                in_channels,
                (3, 3),
                (stride, stride),
                (1, 1),
                true,
                in_channels,
            ),
            dw_bn: BatchNorm2d::new(in_channels),
            // Pointwise 1x1 mixes channels up to the output width.
            pw_conv: Conv2d::with_options(in_channels, out_channels, (1, 1), (1, 1), (0, 0), true),
            pw_bn: BatchNorm2d::new(out_channels),
            project,
            relu: ReLU,
        }
    }

    /// dw-conv -> BN -> ReLU -> pw-conv -> BN, residual add, final ReLU.
    fn forward(&self, x: &Variable) -> Variable {
        let shortcut = match &self.project {
            Some((conv, bn)) => bn.forward(&conv.forward(x)),
            None => x.clone(),
        };
        let mut y = self.relu.forward(&self.dw_bn.forward(&self.dw_conv.forward(x)));
        y = self.pw_bn.forward(&self.pw_conv.forward(&y));
        self.relu.forward(&y.add_var(&shortcut))
    }

    /// All trainable parameters: main path first, then the projection (if any).
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.dw_conv.parameters();
        params.extend(self.dw_bn.parameters());
        params.extend(self.pw_conv.parameters());
        params.extend(self.pw_bn.parameters());
        if let Some((conv, bn)) = self.project.as_ref() {
            params.extend(conv.parameters());
            params.extend(bn.parameters());
        }
        params
    }

    /// Switches every BatchNorm in the block to training mode.
    fn train_mode(&mut self) {
        self.dw_bn.train();
        self.pw_bn.train();
        if let Some((_, bn)) = self.project.as_mut() {
            bn.train();
        }
    }

    /// Switches every BatchNorm in the block to inference mode.
    fn eval_mode(&mut self) {
        self.dw_bn.eval();
        self.pw_bn.eval();
        if let Some((_, bn)) = self.project.as_mut() {
            bn.eval();
        }
    }
}
/// Double BlazeBlock: two stacked depthwise-separable stages with one
/// residual shortcut wrapped around both (see `forward`).
struct DoubleBlazeBlock {
/// Stage 1 depthwise 3x3 (carries the stride).
dw_conv1: Conv2d,
dw_bn1: BatchNorm2d,
/// Stage 1 pointwise 1x1 to the intermediate channel count.
pw_conv1: Conv2d,
pw_bn1: BatchNorm2d,
/// Stage 2 depthwise 3x3 (always stride 1).
dw_conv2: Conv2d,
dw_bn2: BatchNorm2d,
/// Stage 2 pointwise 1x1 to the output channel count.
pw_conv2: Conv2d,
pw_bn2: BatchNorm2d,
/// Optional 1x1 conv + BN on the shortcut; present when shape changes.
project: Option<(Conv2d, BatchNorm2d)>,
relu: ReLU,
}
impl DoubleBlazeBlock {
    /// Builds a double block `in_channels -> mid_channels -> out_channels`;
    /// the first depthwise stage carries the stride, and a 1x1 projection is
    /// placed on the shortcut whenever the shape changes.
    fn new(in_channels: usize, mid_channels: usize, out_channels: usize, stride: usize) -> Self {
        let project = (in_channels != out_channels || stride != 1).then(|| {
            (
                Conv2d::with_options(
                    in_channels,
                    out_channels,
                    (1, 1),
                    (stride, stride),
                    (0, 0),
                    false,
                ),
                BatchNorm2d::new(out_channels),
            )
        });
        Self {
            // Stage 1: depthwise 3x3 (strided) then pointwise to mid width.
            dw_conv1: Conv2d::with_groups(
                in_channels,
                in_channels,
                (3, 3),
                (stride, stride),
                (1, 1),
                true,
                in_channels,
            ),
            dw_bn1: BatchNorm2d::new(in_channels),
            pw_conv1: Conv2d::with_options(in_channels, mid_channels, (1, 1), (1, 1), (0, 0), true),
            pw_bn1: BatchNorm2d::new(mid_channels),
            // Stage 2: depthwise 3x3 (stride 1) then pointwise to out width.
            dw_conv2: Conv2d::with_groups(
                mid_channels,
                mid_channels,
                (3, 3),
                (1, 1),
                (1, 1),
                true,
                mid_channels,
            ),
            dw_bn2: BatchNorm2d::new(mid_channels),
            pw_conv2: Conv2d::with_options(
                mid_channels,
                out_channels,
                (1, 1),
                (1, 1),
                (0, 0),
                true,
            ),
            pw_bn2: BatchNorm2d::new(out_channels),
            project,
            relu: ReLU,
        }
    }

    /// Two (dw -> BN -> ReLU -> pw -> BN) stages; ReLU between them, none
    /// after the last BN, then residual add and a final ReLU.
    fn forward(&self, x: &Variable) -> Variable {
        let shortcut = match &self.project {
            Some((conv, bn)) => bn.forward(&conv.forward(x)),
            None => x.clone(),
        };
        let mut y = self.relu.forward(&self.dw_bn1.forward(&self.dw_conv1.forward(x)));
        y = self.relu.forward(&self.pw_bn1.forward(&self.pw_conv1.forward(&y)));
        y = self.relu.forward(&self.dw_bn2.forward(&self.dw_conv2.forward(&y)));
        y = self.pw_bn2.forward(&self.pw_conv2.forward(&y));
        self.relu.forward(&y.add_var(&shortcut))
    }

    /// All trainable parameters in forward order, projection last.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.dw_conv1.parameters();
        params.extend(self.dw_bn1.parameters());
        params.extend(self.pw_conv1.parameters());
        params.extend(self.pw_bn1.parameters());
        params.extend(self.dw_conv2.parameters());
        params.extend(self.dw_bn2.parameters());
        params.extend(self.pw_conv2.parameters());
        params.extend(self.pw_bn2.parameters());
        if let Some((conv, bn)) = self.project.as_ref() {
            params.extend(conv.parameters());
            params.extend(bn.parameters());
        }
        params
    }

    /// Switches every BatchNorm in the block to training mode.
    fn train_mode(&mut self) {
        for bn in [
            &mut self.dw_bn1,
            &mut self.pw_bn1,
            &mut self.dw_bn2,
            &mut self.pw_bn2,
        ] {
            bn.train();
        }
        if let Some((_, bn)) = self.project.as_mut() {
            bn.train();
        }
    }

    /// Switches every BatchNorm in the block to inference mode.
    fn eval_mode(&mut self) {
        for bn in [
            &mut self.dw_bn1,
            &mut self.pw_bn1,
            &mut self.dw_bn2,
            &mut self.pw_bn2,
        ] {
            bn.eval();
        }
        if let Some((_, bn)) = self.project.as_mut() {
            bn.eval();
        }
    }
}
/// BlazeFace-style face detector: stem conv, a stack of `BlazeBlock`s, a
/// stack of `DoubleBlazeBlock`s, and two detection heads (2 and 6 anchors
/// per cell respectively — see `new` and `generate_anchors`).
pub struct BlazeFace {
stem: Conv2d,
stem_bn: BatchNorm2d,
relu: ReLU,
/// Front stack; its output is the first head's feature map.
front_blocks: Vec<BlazeBlock>,
/// Back stack; its output is the second head's feature map.
back_blocks: Vec<DoubleBlazeBlock>,
/// 3x3 conv refining features ahead of the first classification head.
cls_pre1: Conv2d,
cls_pre1_bn: BatchNorm2d,
/// 2 class logits per cell (one per anchor) on the first feature map.
cls_head1: Conv2d,
/// 2 anchors x 4 box deltas per cell on the first feature map.
bbox_head1: Conv2d,
/// 3x3 conv refining features ahead of the second classification head.
cls_pre2: Conv2d,
cls_pre2_bn: BatchNorm2d,
/// 6 class logits per cell on the second feature map.
cls_head2: Conv2d,
/// 6 anchors x 4 box deltas per cell on the second feature map.
bbox_head2: Conv2d,
}
/// Description of one anchor grid scale. Not referenced anywhere yet
/// (hence `dead_code`); presumably kept for a future refactor of
/// `generate_anchors` — confirm before deleting.
#[allow(dead_code)]
struct AnchorConfig {
/// Feature-map side length in cells.
feature_size: usize,
/// Anchors per cell at this scale.
num_anchors: usize,
/// Input pixels per feature-map cell.
stride: f32,
/// Base anchor sizes for this scale.
anchor_sizes: Vec<f32>,
}
impl Default for BlazeFace {
fn default() -> Self {
Self::new()
}
}
impl BlazeFace {
/// Constructs the full network with its fixed channel plan.
///
/// Spatial strides: stem /2, two strided front blocks -> /8 at the first
/// head, one strided back block -> /16 at the second head.
pub fn new() -> Self {
    // Front stack: single BlazeBlocks, 24 -> 64 channels.
    let front_blocks = vec![
        BlazeBlock::new(24, 24, 1),
        BlazeBlock::new(24, 28, 1),
        BlazeBlock::new(28, 32, 2),
        BlazeBlock::new(32, 36, 1),
        BlazeBlock::new(36, 42, 1),
        BlazeBlock::new(42, 48, 2),
        BlazeBlock::new(48, 56, 1),
        BlazeBlock::new(56, 64, 1),
    ];
    // Back stack: DoubleBlazeBlocks, 64 -> 96 channels.
    let back_blocks = vec![
        DoubleBlazeBlock::new(64, 64, 96, 2),
        DoubleBlazeBlock::new(96, 96, 96, 1),
        DoubleBlazeBlock::new(96, 96, 96, 1),
    ];
    Self {
        stem: Conv2d::with_options(3, 24, (5, 5), (2, 2), (2, 2), true),
        stem_bn: BatchNorm2d::new(24),
        relu: ReLU,
        front_blocks,
        back_blocks,
        // Scale-1 heads: 2 anchors/cell -> 2 logits, 2*4 box values.
        cls_pre1: Conv2d::with_options(64, 64, (3, 3), (1, 1), (1, 1), true),
        cls_pre1_bn: BatchNorm2d::new(64),
        cls_head1: Conv2d::with_options(64, 2, (3, 3), (1, 1), (1, 1), true),
        bbox_head1: Conv2d::with_options(64, 2 * 4, (3, 3), (1, 1), (1, 1), true),
        // Scale-2 heads: 6 anchors/cell -> 6 logits, 6*4 box values.
        cls_pre2: Conv2d::with_options(96, 96, (3, 3), (1, 1), (1, 1), true),
        cls_pre2_bn: BatchNorm2d::new(96),
        cls_head2: Conv2d::with_options(96, 6, (3, 3), (1, 1), (1, 1), true),
        bbox_head2: Conv2d::with_options(96, 6 * 4, (3, 3), (1, 1), (1, 1), true),
    }
}
/// Forward pass producing raw training outputs: classification logits of
/// shape `[batch, A]` and box deltas of shape `[batch, A, 4]`, where `A`
/// is the total anchor count over both feature maps.
pub fn forward_train(&self, x: &Variable) -> (Variable, Variable) {
    let (feat1, feat2) = self.forward_features(x);

    // Pre-head trunks for classification: conv -> BN -> ReLU.
    let cls1_feat = self
        .relu
        .forward(&self.cls_pre1_bn.forward(&self.cls_pre1.forward(&feat1)));
    let cls2_feat = self
        .relu
        .forward(&self.cls_pre2_bn.forward(&self.cls_pre2.forward(&feat2)));

    let cls1 = self.cls_head1.forward(&cls1_feat);
    let cls2 = self.cls_head2.forward(&cls2_feat);
    // Box heads read the backbone features directly (no pre-head trunk).
    let bbox1 = self.bbox_head1.forward(&feat1);
    let bbox2 = self.bbox_head2.forward(&feat2);

    let batch = x.shape()[0];
    let (h1, w1) = (cls1.shape()[2], cls1.shape()[3]);
    let (h2, w2) = (cls2.shape()[2], cls2.shape()[3]);

    // NCHW -> NHWC before flattening, so the flattened anchor axis matches
    // the (row, col, anchor) order used by `generate_anchors`.
    let nhwc = |t: &Variable| t.transpose(1, 2).transpose(2, 3);

    let cls_all = Variable::cat(
        &[
            &nhwc(&cls1).reshape(&[batch, h1 * w1 * 2]),
            &nhwc(&cls2).reshape(&[batch, h2 * w2 * 6]),
        ],
        1,
    );
    let bbox_all = Variable::cat(
        &[
            &nhwc(&bbox1).reshape(&[batch, 2 * h1 * w1, 4]),
            &nhwc(&bbox2).reshape(&[batch, 6 * h2 * w2, 4]),
        ],
        1,
    );
    (cls_all, bbox_all)
}
/// Generates the fixed anchor grid for a square input of `input_size`
/// pixels, in `[cx, cy, w, h]` pixel coordinates.
///
/// Scale 1 covers the stride-8 map with 2 anchors per cell; scale 2 covers
/// the stride-16 map with 6. Ordering is (row, col, anchor) per scale,
/// matching the flattened head outputs in `forward_train`.
pub fn generate_anchors(input_size: usize) -> Vec<[f32; 4]> {
    // (scale, width_ratio, height_ratio) per anchor slot.
    let scale1_specs: [(f32, f32, f32); 2] = [(0.75, 1.0, 1.0), (1.5, 1.0, 1.0)];
    let scale2_specs: [(f32, f32, f32); 6] = [
        (1.0, 1.0, 1.0),
        (1.5, 1.0, 1.0),
        (2.5, 1.0, 1.0),
        (4.0, 1.0, 1.0),
        (1.5, 1.0, 1.3),
        (6.0, 1.0, 1.0),
    ];

    let mut anchors = Vec::new();
    // Emits the anchors for one square grid of `grid` cells per side.
    let mut emit = |grid: usize, specs: &[(f32, f32, f32)]| {
        let stride = input_size as f32 / grid as f32;
        for row in 0..grid {
            for col in 0..grid {
                let cx = (col as f32 + 0.5) * stride;
                let cy = (row as f32 + 0.5) * stride;
                for &(scale, wr, hr) in specs {
                    let base = stride * scale;
                    anchors.push([cx, cy, base * wr, base * hr]);
                }
            }
        }
    };
    emit(input_size / 8, &scale1_specs);
    emit(input_size / 16, &scale2_specs);
    anchors
}
/// Runs single-image face detection end to end: forward pass, anchor
/// decoding, score filtering, then non-maximum suppression.
///
/// NOTE(review): only the first `num_anchors` class logits are read, so
/// this effectively assumes batch size 1 — confirm callers never pass more.
///
/// * `image` - NCHW input batch; anchors are sized from dim 2, so the
///   input is assumed square.
/// * `score_threshold` - minimum sigmoid confidence to keep a candidate.
/// * `nms_threshold` - IoU threshold handed to `nms`.
pub fn detect(
    &self,
    image: &Variable,
    score_threshold: f32,
    nms_threshold: f32,
) -> Vec<FaceDetection> {
    let input_size = image.shape()[2];
    let (cls_logits, bbox_preds) = self.forward_train(image);
    let cls_data = cls_logits.data().to_vec();
    let bbox_data = bbox_preds.data().to_vec();
    let anchors = Self::generate_anchors(input_size);
    let num_anchors = anchors.len();
    let mut all_boxes = Vec::new();
    let mut all_scores = Vec::new();
    for i in 0..num_anchors {
        // Sigmoid of the raw logit; drop anchors below the confidence cut.
        let score = 1.0 / (1.0 + (-cls_data[i]).exp());
        if score < score_threshold {
            continue;
        }
        let anchor = &anchors[i];
        let (acx, acy, aw, ah) = (anchor[0], anchor[1], anchor[2], anchor[3]);
        // Four regression deltas per anchor: center offsets + log sizes.
        let dx = bbox_data[i * 4];
        let dy = bbox_data[i * 4 + 1];
        let dw = bbox_data[i * 4 + 2];
        let dh = bbox_data[i * 4 + 3];
        // Decode: offsets are scaled by anchor size; width/height come from
        // exp of the predicted log-scale (standard SSD-style decoding).
        let pred_cx = acx + dx * aw;
        let pred_cy = acy + dy * ah;
        let pred_w = aw * dw.exp();
        let pred_h = ah * dh.exp();
        // Convert center/size to corner coordinates (x1, y1, x2, y2).
        all_boxes.push([
            pred_cx - pred_w / 2.0,
            pred_cy - pred_h / 2.0,
            pred_cx + pred_w / 2.0,
            pred_cy + pred_h / 2.0,
        ]);
        all_scores.push(score);
    }
    if all_scores.is_empty() {
        return vec![];
    }
    // Pack the surviving candidates into tensors for NMS.
    let n = all_scores.len();
    let boxes_flat: Vec<f32> = all_boxes.iter().flat_map(|b| b.iter().copied()).collect();
    let boxes_tensor = Tensor::from_vec(boxes_flat, &[n, 4]).unwrap();
    let scores_tensor = Tensor::from_vec(all_scores.clone(), &[n]).unwrap();
    let keep = nms(&boxes_tensor, &scores_tensor, nms_threshold);
    keep.iter()
        .map(|&i| FaceDetection {
            bbox: all_boxes[i],
            confidence: all_scores[i],
            landmarks: None,
        })
        .collect()
}
/// Runs the backbone and returns the two detection feature maps:
/// `feat1` from the front stack (64 channels, stride 8) and `feat2` from
/// the back stack (96 channels, stride 16).
fn forward_features(&self, x: &Variable) -> (Variable, Variable) {
    // Stem: conv -> BN -> ReLU.
    let mut out = self
        .relu
        .forward(&self.stem_bn.forward(&self.stem.forward(x)));
    for block in &self.front_blocks {
        out = block.forward(&out);
    }
    // Keep a handle on the front-stack output for the first head, then let
    // the back stack continue from the same tensor. (The original cloned
    // twice — `feat1 = out.clone(); out = feat1.clone();` — the second
    // clone was redundant.)
    let feat1 = out.clone();
    for block in &self.back_blocks {
        out = block.forward(&out);
    }
    (feat1, out)
}
/// Classification logits only; the box branch of `forward_train` is dropped.
pub(crate) fn forward_cls(&self, x: &Variable) -> Variable {
    self.forward_train(x).0
}
/// Box deltas only, flattened from `[batch, A, 4]` to `[batch, A * 4]`.
#[allow(dead_code)]
pub(crate) fn forward_bbox(&self, x: &Variable) -> Variable {
    let (_cls, bbox) = self.forward_train(x);
    let shape = bbox.shape();
    bbox.reshape(&[shape[0], shape[1] * 4])
}
}
impl Module for BlazeFace {
fn forward(&self, x: &Variable) -> Variable {
self.forward_cls(x)
}
fn parameters(&self) -> Vec<Parameter> {
let mut p = Vec::new();
p.extend(self.stem.parameters());
p.extend(self.stem_bn.parameters());
for block in &self.front_blocks {
p.extend(block.parameters());
}
for block in &self.back_blocks {
p.extend(block.parameters());
}
p.extend(self.cls_pre1.parameters());
p.extend(self.cls_pre1_bn.parameters());
p.extend(self.cls_head1.parameters());
p.extend(self.bbox_head1.parameters());
p.extend(self.cls_pre2.parameters());
p.extend(self.cls_pre2_bn.parameters());
p.extend(self.cls_head2.parameters());
p.extend(self.bbox_head2.parameters());
p
}
fn train(&mut self) {
self.stem_bn.train();
self.cls_pre1_bn.train();
self.cls_pre2_bn.train();
for block in &mut self.front_blocks {
block.train_mode();
}
for block in &mut self.back_blocks {
block.train_mode();
}
}
fn eval(&mut self) {
self.stem_bn.eval();
self.cls_pre1_bn.eval();
self.cls_pre2_bn.eval();
for block in &mut self.front_blocks {
block.eval_mode();
}
for block in &mut self.back_blocks {
block.eval_mode();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a constant-filled NCHW input variable for the tests.
    fn const_input(value: f32, shape: &[usize], requires_grad: bool) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(
            Tensor::from_vec(vec![value; numel], shape).unwrap(),
            requires_grad,
        )
    }

    /// A stride-1, equal-channel block preserves the input shape.
    #[test]
    fn test_blazeblock() {
        let block = BlazeBlock::new(24, 24, 1);
        let out = block.forward(&const_input(0.1, &[1, 24, 16, 16], false));
        assert_eq!(out.shape(), vec![1, 24, 16, 16]);
    }

    /// A stride-2 block halves the spatial dims and widens the channels.
    #[test]
    fn test_blazeblock_downsample() {
        let block = BlazeBlock::new(24, 48, 2);
        let out = block.forward(&const_input(0.1, &[1, 24, 16, 16], false));
        assert_eq!(out.shape(), vec![1, 48, 8, 8]);
    }

    #[test]
    fn test_double_blazeblock() {
        let block = DoubleBlazeBlock::new(64, 64, 96, 2);
        let out = block.forward(&const_input(0.1, &[1, 64, 16, 16], false));
        assert_eq!(out.shape(), vec![1, 96, 8, 8]);
    }

    /// The full model stays under its intended parameter budget.
    #[test]
    fn test_blazeface_creation() {
        let model = BlazeFace::new();
        let params = model.parameters();
        assert!(!params.is_empty());
        let total: usize = params
            .iter()
            .map(|p| p.variable().data().to_vec().len())
            .sum();
        println!("BlazeFace total params: {}", total);
        assert!(total < 300_000);
    }

    /// 128x128 input yields 896 anchors total: 2*16*16 + 6*8*8.
    #[test]
    fn test_blazeface_forward_train() {
        let model = BlazeFace::new();
        let (cls, bbox) = model.forward_train(&const_input(0.1, &[1, 3, 128, 128], false));
        assert_eq!(cls.shape(), vec![1, 896]);
        assert_eq!(bbox.shape(), vec![1, 896, 4]);
    }

    #[test]
    fn test_blazeface_anchors() {
        let anchors = BlazeFace::generate_anchors(128);
        assert_eq!(anchors.len(), 896);
        let first = &anchors[0];
        assert!(first[0] > 0.0 && first[0] < 128.0);
        assert!(first[1] > 0.0 && first[1] < 128.0);
    }

    /// End-to-end detection should run without panicking.
    #[test]
    fn test_blazeface_detect() {
        let model = BlazeFace::new();
        let _dets = model.detect(&const_input(0.5, &[1, 3, 128, 128], false), 0.5, 0.3);
    }

    /// Gradients flow back into at least some parameters.
    #[test]
    fn test_blazeface_backward() {
        let model = BlazeFace::new();
        let (cls, _bbox) = model.forward_train(&const_input(0.1, &[1, 3, 128, 128], true));
        cls.mean().backward();
        let with_grad = model
            .parameters()
            .iter()
            .filter(|p| p.variable().grad().is_some())
            .count();
        assert!(with_grad > 0, "At least some parameters should have gradients");
    }
}