use crate::graph::{Graph, NodeId};
/// Hyper-parameters for a small Stable-Diffusion-style UNet training graph.
///
/// Used by `build_training_graph` to size every tensor and parameter; see the
/// preset constructors (`tiny`, `small`) for typical values.
pub struct SDUNetConfig {
/// Number of samples per training batch.
pub batch_size: u32,
/// Channel count of the input latent (and of the predicted noise output).
pub in_channels: u32,
/// Channel count after the stem conv; multiplied by 1, 2, 4, … per level.
pub base_channels: u32,
/// Number of encoder/decoder resolution levels.
pub num_levels: usize,
/// Spatial size of the input latent; used for both height and width.
pub resolution: u32,
/// Group count passed to every group-norm node; should divide each
/// level's channel count — TODO confirm the graph layer validates this.
pub num_groups: u32,
/// Epsilon added inside group normalization for numerical stability.
pub gn_eps: f32,
}
/// Preset constructors and derived values for [`SDUNetConfig`].
impl SDUNetConfig {
    /// Smallest preset: 32 base channels, three levels, 32x32 latents.
    /// Intended for quick smoke tests.
    pub fn tiny() -> Self {
        Self {
            in_channels: 4,
            batch_size: 4,
            base_channels: 32,
            num_groups: 8,
            num_levels: 3,
            resolution: 32,
            gn_eps: 1e-5,
        }
    }

    /// A slightly larger preset: 64 base channels with a smaller batch.
    pub fn small() -> Self {
        Self {
            in_channels: 4,
            batch_size: 2,
            base_channels: 64,
            num_groups: 16,
            num_levels: 3,
            resolution: 32,
            gn_eps: 1e-5,
        }
    }

    /// Per-level channel multipliers: 1, 2, 4, … (doubling at each level).
    fn channel_mult(&self) -> Vec<u32> {
        let mut mults = Vec::with_capacity(self.num_levels);
        for level in 0..self.num_levels {
            mults.push(1u32 << level);
        }
        mults
    }
}
/// Tracks the current tensor geometry while the graph is being built.
struct SpatialState {
/// Current feature-map height.
h: u32,
/// Current feature-map width.
w: u32,
/// Current channel count.
c: u32,
}
/// Appends a residual block to the graph:
/// GN -> SiLU -> 3x3 conv (c_in -> out_c) -> GN -> SiLU -> 3x3 conv,
/// then adds a skip connection (projected by a 1x1 conv when the channel
/// count changes). Spatial size is preserved (stride 1, pad 1).
///
/// Parameter names are registered under `prefix` (e.g. `{prefix}.norm1.weight`).
/// Returns the node id of the block's output.
fn resblock(
    g: &mut Graph,
    x: NodeId,
    prefix: &str,
    cfg: &SDUNetConfig,
    s: &SpatialState,
    out_c: u32,
) -> NodeId {
    let n = cfg.batch_size;
    let hw = s.h * s.w;
    let c_in = s.c;

    // First norm / activation / 3x3 conv: c_in -> out_c.
    let norm1_weight = g.parameter(&format!("{prefix}.norm1.weight"), &[c_in as usize]);
    let norm1_bias = g.parameter(&format!("{prefix}.norm1.bias"), &[c_in as usize]);
    let mut hidden =
        g.group_norm(x, norm1_weight, norm1_bias, n, c_in, hw, cfg.num_groups, cfg.gn_eps);
    hidden = g.silu(hidden);
    let w1 = g.parameter(&format!("{prefix}.conv1.weight"), &[(out_c * c_in * 9) as usize]);
    hidden = g.conv2d(hidden, w1, n, c_in, s.h, s.w, out_c, 3, 3, 1, 1);

    // Second norm / activation / 3x3 conv: out_c -> out_c, spatial unchanged.
    let norm2_weight = g.parameter(&format!("{prefix}.norm2.weight"), &[out_c as usize]);
    let norm2_bias = g.parameter(&format!("{prefix}.norm2.bias"), &[out_c as usize]);
    hidden =
        g.group_norm(hidden, norm2_weight, norm2_bias, n, out_c, hw, cfg.num_groups, cfg.gn_eps);
    hidden = g.silu(hidden);
    let w2 = g.parameter(&format!("{prefix}.conv2.weight"), &[(out_c * out_c * 9) as usize]);
    hidden = g.conv2d(hidden, w2, n, out_c, s.h, s.w, out_c, 3, 3, 1, 1);

    // Residual connection: identity when channels match, 1x1 conv otherwise.
    if c_in == out_c {
        g.add(x, hidden)
    } else {
        let proj_w = g.parameter(
            &format!("{prefix}.res_conv.weight"),
            &[(out_c * c_in) as usize],
        );
        let shortcut = g.conv2d(x, proj_w, n, c_in, s.h, s.w, out_c, 1, 1, 1, 0);
        g.add(shortcut, hidden)
    }
}
/// Builds the full UNet training graph and returns the node id of the scalar
/// MSE loss between predicted and target noise.
///
/// Layout: stem conv -> encoder (resblock + strided-conv downsample per
/// level) -> middle resblock -> decoder (upsample + skip-concat + resblock
/// per level) -> GN/SiLU/conv output head. The loss is mean((pred - target)^2),
/// expressed with the available neg/add/mul/mean_all ops.
///
/// # Panics
/// Panics if decoder spatial sizes do not match their encoder skips, or if
/// the decoder fails to restore the stem geometry (e.g. a resolution whose
/// stride-2 down / 2x up round trip is lossy).
pub fn build_training_graph(g: &mut Graph, cfg: &SDUNetConfig) -> NodeId {
    let batch = cfg.batch_size;
    let res = cfg.resolution;
    let in_c = cfg.in_channels;
    let in_size = (batch * in_c * res * res) as usize;
    let ch_mults = cfg.channel_mult();

    // Graph inputs: the noised latent and the noise the model must predict.
    let noisy = g.input("noisy_latent", &[in_size]);
    let target = g.input("noise_target", &[in_size]);

    // Stem: lift in_c channels to base_c with a 3x3 conv (stride 1, pad 1).
    let base_c = cfg.base_channels;
    let conv_in_w = g.parameter("conv_in.weight", &[(base_c * in_c * 3 * 3) as usize]);
    let mut x = g.conv2d(noisy, conv_in_w, batch, in_c, res, res, base_c, 3, 3, 1, 1);
    let mut s = SpatialState {
        h: res,
        w: res,
        c: base_c,
    };

    // Encoder: one resblock per level; stride-2 conv downsample between levels.
    // Each level's output is stashed for the decoder's skip connections.
    let mut skip_connections: Vec<(NodeId, SpatialState)> = Vec::new();
    for (level, &mult) in ch_mults.iter().enumerate() {
        let out_c = base_c * mult;
        x = resblock(g, x, &format!("encoder.{level}.resblock"), cfg, &s, out_c);
        s.c = out_c;
        skip_connections.push((
            x,
            SpatialState {
                h: s.h,
                w: s.w,
                c: s.c,
            },
        ));
        if level < cfg.num_levels - 1 {
            let down_w = g.parameter(
                &format!("encoder.{level}.downsample.weight"),
                &[(out_c * out_c * 3 * 3) as usize],
            );
            x = g.conv2d(x, down_w, batch, out_c, s.h, s.w, out_c, 3, 3, 2, 1);
            // Conv output size: (dim + 2*pad - kernel) / stride + 1.
            s.h = (s.h + 2 - 3) / 2 + 1;
            s.w = (s.w + 2 - 3) / 2 + 1;
        }
    }

    // Bottleneck at the lowest resolution; channel count unchanged.
    x = resblock(g, x, "middle.resblock", cfg, &s, s.c);

    // Decoder: upsample, concat the matching encoder skip along channels,
    // then a resblock that maps concat_c -> out_c.
    for level in (0..cfg.num_levels).rev() {
        let out_c = base_c * ch_mults[level];
        if level < cfg.num_levels - 1 {
            x = g.upsample_2x(x, batch, s.c, s.h, s.w);
            s.h *= 2;
            s.w *= 2;
        }
        let &(skip, ref skip_s) = &skip_connections[level];
        assert_eq!(s.h, skip_s.h, "spatial mismatch at level {level}");
        assert_eq!(s.w, skip_s.w, "spatial mismatch at level {level}");
        let spatial = s.h * s.w;
        x = g.concat(x, skip, batch, s.c, skip_s.c, spatial);
        let dec_s = SpatialState {
            h: s.h,
            w: s.w,
            c: s.c + skip_s.c,
        };
        x = resblock(
            g,
            x,
            &format!("decoder.{level}.resblock"),
            cfg,
            &dec_s,
            out_c,
        );
        s.c = out_c;
    }

    // The output head previously hard-coded `res * res` and `base_c`,
    // silently assuming the decoder restored the stem geometry. Make that
    // invariant explicit, then size the head from the tracked state so it
    // always matches the actual tensor.
    assert_eq!(s.h, res, "decoder did not restore input height");
    assert_eq!(s.w, res, "decoder did not restore input width");
    assert_eq!(s.c, base_c, "decoder did not restore base channel count");

    // Output head: GN -> SiLU -> 3x3 conv back down to in_c channels.
    let gn_out_w = g.parameter("conv_out.norm.weight", &[s.c as usize]);
    let gn_out_b = g.parameter("conv_out.norm.bias", &[s.c as usize]);
    x = g.group_norm(
        x,
        gn_out_w,
        gn_out_b,
        batch,
        s.c,
        s.h * s.w,
        cfg.num_groups,
        cfg.gn_eps,
    );
    x = g.silu(x);
    let conv_out_w = g.parameter("conv_out.weight", &[(in_c * s.c * 3 * 3) as usize]);
    let pred = g.conv2d(x, conv_out_w, batch, s.c, s.h, s.w, in_c, 3, 3, 1, 1);

    // MSE loss: mean((pred - target)^2), built from the available primitives.
    let neg_target = g.neg(target);
    let diff = g.add(pred, neg_target);
    let sq = g.mul(diff, diff);
    g.mean_all(sq)
}
/// Counts the total number of learnable parameter elements in the model.
///
/// Builds a throwaway training graph for `cfg` and sums the element counts
/// of every `Parameter` node it contains.
pub fn count_params(cfg: &SDUNetConfig) -> usize {
    let mut graph = Graph::new();
    let _loss = build_training_graph(&mut graph, cfg);
    let mut total = 0usize;
    for node in graph.nodes().iter() {
        if matches!(node.op, crate::graph::Op::Parameter { .. }) {
            total += node.ty.num_elements();
        }
    }
    total
}