use crate::graph::{Graph, NodeId};
/// Height and width of a feature map as it flows through the network.
struct Spatial {
    h: u32,
    w: u32,
}

impl Spatial {
    /// Spatial size after a square conv/pool using the standard floor
    /// formula: out = (in + 2*padding - kernel) / stride + 1.
    fn after_conv(&self, kernel: u32, stride: u32, padding: u32) -> Self {
        let out = |dim: u32| (dim + 2 * padding - kernel) / stride + 1;
        Self {
            h: out(self.h),
            w: out(self.w),
        }
    }
}
/// Builds a ResNet-18-style inference graph (stem + 8 basic blocks +
/// global average pool + fully connected head) for a 224x224 RGB input.
///
/// BN layers are pre-fused: each `*.fused_bias` parameter is a full
/// batch*channels*spatial buffer added elementwise after its conv
/// (see `fuse_bn_into_conv`). Returns the logits node (batch x 1000).
pub fn build_graph(g: &mut Graph, batch: u32) -> NodeId {
    let s = Spatial { h: 224, w: 224 };
    let image = g.input("image", &[(batch * 3 * 224 * 224) as usize]);
    // Stem: 7x7/2 conv -> fused-BN bias add -> relu -> 3x3/2 max pool.
    let conv1_w = g.parameter("conv1.weight", &[64 * 3 * 7 * 7]);
    let x = g.conv2d(image, conv1_w, batch, 3, s.h, s.w, 64, 7, 7, 2, 3);
    let s = s.after_conv(7, 2, 3);
    let bn1_bias = g.parameter("bn1.fused_bias", &[(batch * 64 * s.h * s.w) as usize]);
    let x = g.add(x, bn1_bias);
    let x = g.relu(x);
    let x = g.max_pool_2d(x, batch, 64, s.h, s.w, 3, 3, 2, 1);
    let s = s.after_conv(3, 2, 1);
    // Four stages of two basic blocks each; the first block of stages 2-4
    // halves the spatial size and doubles the channel count.
    let (x, s) = basic_block(g, x, &s, batch, 64, 64, 1, "layer1.0");
    let (x, s) = basic_block(g, x, &s, batch, 64, 64, 1, "layer1.1");
    let (x, s) = basic_block(g, x, &s, batch, 64, 128, 2, "layer2.0");
    let (x, s) = basic_block(g, x, &s, batch, 128, 128, 1, "layer2.1");
    let (x, s) = basic_block(g, x, &s, batch, 128, 256, 2, "layer3.0");
    let (x, s) = basic_block(g, x, &s, batch, 256, 256, 1, "layer3.1");
    let (x, s) = basic_block(g, x, &s, batch, 256, 512, 2, "layer4.0");
    let (x, s) = basic_block(g, x, &s, batch, 512, 512, 1, "layer4.1");
    // Use the tracked spatial size rather than a hard-coded 7*7 so the head
    // stays consistent with the blocks above (49 for a 224x224 input).
    let x = g.global_avg_pool(x, batch, 512, s.h * s.w);
    let fc_w = g.parameter("fc.weight", &[512, 1000]);
    let fc_b = g.parameter("fc.bias", &[1000]);
    let logits = g.matmul(x, fc_w);
    g.bias_add(logits, fc_b)
}
/// One ResNet-18/34 basic block: two 3x3 convs, each followed by a
/// fused-BN bias add, with a residual shortcut (identity or 1x1
/// projection) added before the final relu.
///
/// Returns the output node and the spatial size after the block.
fn basic_block(
    g: &mut Graph,
    x: NodeId,
    s: &Spatial,
    batch: u32,
    in_c: u32,
    out_c: u32,
    stride: u32,
    name: &str,
) -> (NodeId, Spatial) {
    let out_s = s.after_conv(3, stride, 1);

    // conv1 (carries the block stride) -> fused-BN bias -> relu.
    let conv1_w = g.parameter(
        &format!("{name}.conv1.weight"),
        &[(out_c * in_c * 3 * 3) as usize],
    );
    let mut y = g.conv2d(x, conv1_w, batch, in_c, s.h, s.w, out_c, 3, 3, stride, 1);
    let bn1 = g.parameter(
        &format!("{name}.bn1.fused_bias"),
        &[(batch * out_c * out_s.h * out_s.w) as usize],
    );
    y = g.add(y, bn1);
    y = g.relu(y);

    // conv2 (stride 1) -> fused-BN bias; no relu until after the residual add.
    let conv2_w = g.parameter(
        &format!("{name}.conv2.weight"),
        &[(out_c * out_c * 3 * 3) as usize],
    );
    y = g.conv2d(y, conv2_w, batch, out_c, out_s.h, out_s.w, out_c, 3, 3, 1, 1);
    let bn2 = g.parameter(
        &format!("{name}.bn2.fused_bias"),
        &[(batch * out_c * out_s.h * out_s.w) as usize],
    );
    y = g.add(y, bn2);

    // Projection shortcut when the block changes resolution or channel
    // count; plain identity otherwise.
    let needs_projection = stride > 1 || in_c != out_c;
    let shortcut = if needs_projection {
        let ds_w = g.parameter(
            &format!("{name}.downsample.0.weight"),
            &[(out_c * in_c) as usize],
        );
        let ds = g.conv2d(x, ds_w, batch, in_c, s.h, s.w, out_c, 1, 1, stride, 0);
        let ds_bn = g.parameter(
            &format!("{name}.downsample.1.fused_bias"),
            &[(batch * out_c * out_s.h * out_s.w) as usize],
        );
        g.add(ds, ds_bn)
    } else {
        x
    };

    let sum = g.add(y, shortcut);
    let activated = g.relu(sum);
    (activated, out_s)
}
/// Names of all learnable tensors in the ResNet-18 graph, in the order
/// `build_graph` registers them.
///
/// `batch` is accepted for signature parity with the builders but does
/// not influence the names.
pub fn weight_names(batch: u32) -> Vec<String> {
    let _ = batch;
    // (in_channels, out_channels, stride) for the eight basic blocks,
    // mirroring the calls in `build_graph`.
    let blocks: [(u32, u32, u32); 8] = [
        (64, 64, 1),
        (64, 64, 1),
        (64, 128, 2),
        (128, 128, 1),
        (128, 256, 2),
        (256, 256, 1),
        (256, 512, 2),
        (512, 512, 1),
    ];
    let mut names: Vec<String> = vec!["conv1.weight".into(), "bn1.fused_bias".into()];
    for (idx, &(in_c, out_c, stride)) in blocks.iter().enumerate() {
        // Two blocks per stage: layer1.0, layer1.1, layer2.0, ...
        let prefix = format!("layer{}.{}", idx / 2 + 1, idx % 2);
        for suffix in ["conv1.weight", "bn1.fused_bias", "conv2.weight", "bn2.fused_bias"] {
            names.push(format!("{prefix}.{suffix}"));
        }
        // Blocks that downsample or widen also carry a projection shortcut.
        if stride > 1 || in_c != out_c {
            names.push(format!("{prefix}.downsample.0.weight"));
            names.push(format!("{prefix}.downsample.1.fused_bias"));
        }
    }
    names.push("fc.weight".into());
    names.push("fc.bias".into());
    names
}
/// Folds batch-norm statistics into a convolution's weights and produces a
/// pre-broadcast bias buffer.
///
/// `kernel_size` is the number of kernel elements per (out, in) channel
/// pair (i.e. kh * kw); `in_channels` is recovered from the weight length.
/// Per output channel `co`:
///   w' = w * scale[co] / sqrt(var[co] + eps)
///   b' = bias[co] - mean[co] * scale[co] / sqrt(var[co] + eps)
/// The bias is replicated over the full batch x out_channels x (out_h*out_w)
/// buffer in NCHW order so it can be consumed by a plain elementwise add.
pub fn fuse_bn_into_conv(
    conv_weight: &[f32],
    scale: &[f32],
    bias: &[f32],
    mean: &[f32],
    var: &[f32],
    eps: f32,
    out_channels: usize,
    kernel_size: usize,
    batch: usize,
    out_h: usize,
    out_w: usize,
) -> (Vec<f32>, Vec<f32>) {
    let in_channels = conv_weight.len() / (out_channels * kernel_size);
    let group = in_channels * kernel_size;
    // Scale every weight belonging to output channel `co` by its inv-std.
    let mut w_fused = conv_weight.to_vec();
    for co in 0..out_channels {
        let inv_std = scale[co] / (var[co] + eps).sqrt();
        w_fused[co * group..(co + 1) * group]
            .iter_mut()
            .for_each(|w| *w *= inv_std);
    }
    // Broadcast the fused bias across batch and spatial positions.
    let spatial = out_h * out_w;
    let mut b_fused = Vec::with_capacity(batch * out_channels * spatial);
    for _ in 0..batch {
        for co in 0..out_channels {
            let inv_std = scale[co] / (var[co] + eps).sqrt();
            let b = bias[co] - mean[co] * inv_std;
            b_fused.extend(std::iter::repeat(b).take(spatial));
        }
    }
    (w_fused, b_fused)
}
/// Builds a ResNet-50 inference graph (stem + 16 bottleneck blocks in
/// stages of [3, 4, 6, 3] + global average pool + fully connected head)
/// for a 224x224 RGB input.
///
/// BN layers are pre-fused: each `*.fused_bias` parameter is a full
/// batch*channels*spatial buffer added elementwise after its conv
/// (see `fuse_bn_into_conv`). Returns the logits node (batch x 1000).
pub fn build_resnet50(g: &mut Graph, batch: u32) -> NodeId {
    let s = Spatial { h: 224, w: 224 };
    let image = g.input("image", &[(batch * 3 * 224 * 224) as usize]);
    // Stem: 7x7/2 conv -> fused-BN bias add -> relu -> 3x3/2 max pool.
    let conv1_w = g.parameter("conv1.weight", &[64 * 3 * 7 * 7]);
    let x = g.conv2d(image, conv1_w, batch, 3, s.h, s.w, 64, 7, 7, 2, 3);
    let s = s.after_conv(7, 2, 3);
    let bn1_bias = g.parameter("bn1.fused_bias", &[(batch * 64 * s.h * s.w) as usize]);
    let x = g.add(x, bn1_bias);
    let x = g.relu(x);
    let x = g.max_pool_2d(x, batch, 64, s.h, s.w, 3, 3, 2, 1);
    let s = s.after_conv(3, 2, 1);
    // The first block of stages 2-4 halves the spatial size; every block
    // expands mid channels 4x on its output.
    let (x, s) = bottleneck(g, x, &s, batch, 64, 64, 256, 1, "layer1.0");
    let (x, s) = bottleneck(g, x, &s, batch, 256, 64, 256, 1, "layer1.1");
    let (x, s) = bottleneck(g, x, &s, batch, 256, 64, 256, 1, "layer1.2");
    let (x, s) = bottleneck(g, x, &s, batch, 256, 128, 512, 2, "layer2.0");
    let (x, s) = bottleneck(g, x, &s, batch, 512, 128, 512, 1, "layer2.1");
    let (x, s) = bottleneck(g, x, &s, batch, 512, 128, 512, 1, "layer2.2");
    let (x, s) = bottleneck(g, x, &s, batch, 512, 128, 512, 1, "layer2.3");
    let (x, s) = bottleneck(g, x, &s, batch, 512, 256, 1024, 2, "layer3.0");
    let (x, s) = bottleneck(g, x, &s, batch, 1024, 256, 1024, 1, "layer3.1");
    let (x, s) = bottleneck(g, x, &s, batch, 1024, 256, 1024, 1, "layer3.2");
    let (x, s) = bottleneck(g, x, &s, batch, 1024, 256, 1024, 1, "layer3.3");
    let (x, s) = bottleneck(g, x, &s, batch, 1024, 256, 1024, 1, "layer3.4");
    let (x, s) = bottleneck(g, x, &s, batch, 1024, 256, 1024, 1, "layer3.5");
    let (x, s) = bottleneck(g, x, &s, batch, 1024, 512, 2048, 2, "layer4.0");
    let (x, s) = bottleneck(g, x, &s, batch, 2048, 512, 2048, 1, "layer4.1");
    let (x, s) = bottleneck(g, x, &s, batch, 2048, 512, 2048, 1, "layer4.2");
    // Use the tracked spatial size rather than a hard-coded 7*7 so the head
    // stays consistent with the blocks above (49 for a 224x224 input).
    let x = g.global_avg_pool(x, batch, 2048, s.h * s.w);
    let fc_w = g.parameter("fc.weight", &[2048, 1000]);
    let fc_b = g.parameter("fc.bias", &[1000]);
    let logits = g.matmul(x, fc_w);
    g.bias_add(logits, fc_b)
}
/// Wraps the ResNet-50 inference graph with a cross-entropy loss head for
/// training; the loss node becomes the graph's sole output.
pub fn build_resnet50_training(batch: u32) -> Graph {
    let mut graph = Graph::new();
    let logits = build_resnet50(&mut graph, batch);
    // Label targets over the 1000 ImageNet classes, one row per sample.
    let labels = graph.input("labels", &[batch as usize, 1000]);
    let loss = graph.cross_entropy_loss(logits, labels);
    graph.set_outputs(vec![loss]);
    graph
}
/// One ResNet-50 bottleneck block: a 1x1 channel-reducing conv, a 3x3 conv
/// that carries the block stride, and a 1x1 channel-expanding conv, each
/// followed by a fused-BN bias add; a residual shortcut (identity or 1x1
/// projection) is added before the final relu.
///
/// BN layers are pre-fused: each `*.fused_bias` parameter is a full
/// batch*channels*spatial buffer added elementwise after its conv.
/// Returns the output node and the spatial size after the block.
fn bottleneck(
    g: &mut Graph,
    x: NodeId,
    s: &Spatial,
    batch: u32,
    in_c: u32,
    mid_c: u32,
    out_c: u32,
    stride: u32,
    name: &str,
) -> (NodeId, Spatial) {
    // 1x1 reduce: in_c -> mid_c, stride 1, spatial size unchanged.
    let w1 = g.parameter(&format!("{name}.conv1.weight"), &[(mid_c * in_c) as usize]);
    let h = g.conv2d(x, w1, batch, in_c, s.h, s.w, mid_c, 1, 1, 1, 0);
    let bn1_b = g.parameter(
        &format!("{name}.bn1.fused_bias"),
        &[(batch * mid_c * s.h * s.w) as usize],
    );
    let h = g.add(h, bn1_b);
    let h = g.relu(h);
    // The 3x3 conv carries the block stride (stride-on-3x3, as in the
    // torchvision-style ResNet v1.5 layout).
    let s1 = s.after_conv(3, stride, 1);
    let w2 = g.parameter(
        &format!("{name}.conv2.weight"),
        &[(mid_c * mid_c * 9) as usize],
    );
    let h = g.conv2d(h, w2, batch, mid_c, s.h, s.w, mid_c, 3, 3, stride, 1);
    let bn2_b = g.parameter(
        &format!("{name}.bn2.fused_bias"),
        &[(batch * mid_c * s1.h * s1.w) as usize],
    );
    let h = g.add(h, bn2_b);
    let h = g.relu(h);
    // 1x1 expand: mid_c -> out_c; no relu until after the residual add.
    let w3 = g.parameter(&format!("{name}.conv3.weight"), &[(out_c * mid_c) as usize]);
    let h = g.conv2d(h, w3, batch, mid_c, s1.h, s1.w, out_c, 1, 1, 1, 0);
    let bn3_b = g.parameter(
        &format!("{name}.bn3.fused_bias"),
        &[(batch * out_c * s1.h * s1.w) as usize],
    );
    let h = g.add(h, bn3_b);
    // Projection shortcut when the block changes resolution or channel
    // count; plain identity otherwise.
    let shortcut = if stride > 1 || in_c != out_c {
        let ds_w = g.parameter(
            &format!("{name}.downsample.0.weight"),
            &[(out_c * in_c) as usize],
        );
        let ds = g.conv2d(x, ds_w, batch, in_c, s.h, s.w, out_c, 1, 1, stride, 0);
        let ds_bn_b = g.parameter(
            &format!("{name}.downsample.1.fused_bias"),
            &[(batch * out_c * s1.h * s1.w) as usize],
        );
        g.add(ds, ds_bn_b)
    } else {
        x
    };
    let out = g.add(h, shortcut);
    let out = g.relu(out);
    (out, s1)
}