use crate::{
error::{VisionError, VisionResult},
handle::LcgRng,
vit::vit_block::{gelu_exact, layer_norm, linear, softmax_rows},
};
/// Additive logit penalty written into the shifted-window attention mask for token
/// pairs that fall in different regions. It is large enough that softmax assigns
/// those pairs effectively zero weight while staying finite (no `-inf` / NaN risk).
const MASK_PENALTY: f32 = -1.0e2;
/// Geometry and hyper-parameters of a single Swin block.
///
/// Build it through `SwinConfig::new`, which validates every field. All fields are
/// plain `usize`, so the type is cheaply comparable and hashable (e.g. as a cache key).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SwinConfig {
    /// Channels per token (embedding width); must be a non-zero multiple of `n_heads`.
    pub dim: usize,
    /// Number of attention heads; `dim / n_heads` is the per-head width.
    pub n_heads: usize,
    /// Side length `M` of the square attention window; must divide `input_h` and `input_w`.
    pub window_size: usize,
    /// Height of the token grid.
    pub input_h: usize,
    /// Width of the token grid.
    pub input_w: usize,
    /// `0` selects regular windows (W-MSA), `1` selects shifted windows (SW-MSA).
    pub shift: usize,
    /// MLP hidden width multiplier: hidden width is `mlp_ratio * dim`.
    pub mlp_ratio: usize,
}
impl SwinConfig {
    /// Builds a validated configuration.
    ///
    /// Checks run in the order listed below; the first failing check is reported.
    ///
    /// # Errors
    ///
    /// * [`VisionError::InvalidEmbedDim`] if `dim == 0`.
    /// * [`VisionError::InvalidNumHeads`] if `n_heads == 0`.
    /// * [`VisionError::HeadDimMismatch`] if `dim` is not a multiple of `n_heads`.
    /// * [`VisionError::InvalidPatchSize`] if `window_size == 0`.
    /// * [`VisionError::InvalidImageSize`] if `input_h` or `input_w` is zero.
    /// * [`VisionError::InvalidPatchSize`] if `window_size` does not divide `input_h`
    ///   or `input_w`.
    /// * [`VisionError::Internal`] if `shift > 1` or `mlp_ratio == 0`.
    // Seven independent scalar hyper-parameters; a builder would only add API surface
    // for a plain validated value type, so the lint is silenced deliberately.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        dim: usize,
        n_heads: usize,
        window_size: usize,
        input_h: usize,
        input_w: usize,
        shift: usize,
        mlp_ratio: usize,
    ) -> VisionResult<Self> {
        if dim == 0 {
            return Err(VisionError::InvalidEmbedDim(dim));
        }
        if n_heads == 0 {
            return Err(VisionError::InvalidNumHeads(n_heads));
        }
        if dim % n_heads != 0 {
            return Err(VisionError::HeadDimMismatch {
                n_heads,
                embed_dim: dim,
            });
        }
        // Must precede the divisibility checks below, which would divide by zero.
        if window_size == 0 {
            return Err(VisionError::InvalidPatchSize {
                patch_size: window_size,
                img_size: input_h,
            });
        }
        if input_h == 0 || input_w == 0 {
            return Err(VisionError::InvalidImageSize {
                height: input_h,
                width: input_w,
                channels: dim,
            });
        }
        if input_h % window_size != 0 {
            return Err(VisionError::InvalidPatchSize {
                patch_size: window_size,
                img_size: input_h,
            });
        }
        if input_w % window_size != 0 {
            return Err(VisionError::InvalidPatchSize {
                patch_size: window_size,
                img_size: input_w,
            });
        }
        if shift > 1 {
            return Err(VisionError::Internal(format!(
                "shift must be 0 (W-MSA) or 1 (SW-MSA), got {shift}"
            )));
        }
        if mlp_ratio == 0 {
            return Err(VisionError::Internal("mlp_ratio must be >= 1".to_string()));
        }
        Ok(Self {
            dim,
            n_heads,
            window_size,
            input_h,
            input_w,
            shift,
            mlp_ratio,
        })
    }

    /// Per-head channel width, `dim / n_heads`.
    #[must_use]
    #[inline]
    pub fn head_dim(&self) -> usize {
        self.dim / self.n_heads
    }

    /// MLP hidden width, `mlp_ratio * dim`.
    #[must_use]
    #[inline]
    pub fn hidden_dim(&self) -> usize {
        self.mlp_ratio * self.dim
    }

    /// Tokens per window, `window_size²`.
    #[must_use]
    #[inline]
    pub fn window_tokens(&self) -> usize {
        self.window_size * self.window_size
    }

    /// Number of non-overlapping windows tiling the grid.
    #[must_use]
    #[inline]
    pub fn n_windows(&self) -> usize {
        (self.input_h / self.window_size) * (self.input_w / self.window_size)
    }

    /// Entries per head in the relative-position bias table, `(2M - 1)²`.
    #[must_use]
    #[inline]
    pub fn rel_table_len(&self) -> usize {
        let span = 2 * self.window_size - 1;
        span * span
    }

    /// Whether this block actually displaces the windows (SW-MSA).
    ///
    /// Requires `shift == 1`, a window smaller than both grid sides (otherwise there is
    /// only one window along the shorter side and shifting is pointless), and a window
    /// of at least two tokens per side. With `window_size == 1`,
    /// `shift_size() == window_size / 2 == 0`, so there is no real displacement.
    /// Reporting `false` keeps this consistent with [`Self::shift_size`] and spares
    /// callers a needless mask and two identity rolls.
    #[must_use]
    #[inline]
    pub fn is_shifted(&self) -> bool {
        self.shift != 0
            && self.window_size > 1
            && self.window_size < self.input_h.min(self.input_w)
    }

    /// Roll distance used by SW-MSA: `window_size / 2` when shifted, else `0`.
    #[must_use]
    #[inline]
    pub fn shift_size(&self) -> usize {
        if self.is_shifted() {
            self.window_size / 2
        } else {
            0
        }
    }
}
/// Learnable parameters of one Swin block, stored as flat `f32` buffers.
///
/// Lengths below are the ones `SwinWeights::default_init` produces (`d = dim`,
/// `hidden = mlp_ratio * dim`, `M = window_size`). The element order inside each matrix
/// is whatever `vit_block::linear` expects. NOTE(review): confirm row/column-major there.
#[derive(Debug, Clone, PartialEq)]
pub struct SwinWeights {
    /// Fused Q/K/V projection, `3 * d * d` values (input `d`, output `3 * d`).
    pub qkv_weight: Vec<f32>,
    /// Fused Q/K/V bias, `3 * d` values.
    pub qkv_bias: Vec<f32>,
    /// Attention output projection, `d * d` values.
    pub proj_weight: Vec<f32>,
    /// Attention output bias, `d` values.
    pub proj_bias: Vec<f32>,
    /// Relative-position bias, head-major: `n_heads` consecutive tables of `(2M - 1)²`.
    pub relative_position_bias_table: Vec<f32>,
    /// First MLP matrix, `hidden * d` values (input `d`, output `hidden`).
    pub mlp_w1: Vec<f32>,
    /// First MLP bias, `hidden` values.
    pub mlp_b1: Vec<f32>,
    /// Second MLP matrix, `d * hidden` values (input `hidden`, output `d`).
    pub mlp_w2: Vec<f32>,
    /// Second MLP bias, `d` values.
    pub mlp_b2: Vec<f32>,
    /// LayerNorm before attention: scale, `d` values.
    pub ln1_gamma: Vec<f32>,
    /// LayerNorm before attention: shift, `d` values.
    pub ln1_beta: Vec<f32>,
    /// LayerNorm before the MLP: scale, `d` values.
    pub ln2_gamma: Vec<f32>,
    /// LayerNorm before the MLP: shift, `d` values.
    pub ln2_beta: Vec<f32>,
}
impl SwinWeights {
    /// Creates randomly initialised parameters shaped for `cfg`.
    ///
    /// Each weight matrix is filled by `rng.fill_normal` and multiplied by
    /// `1 / sqrt(fan_in)`. The fan-in is `dim` for the QKV, output-projection and first
    /// MLP matrices, and `hidden_dim()` for the second MLP matrix. The relative-position
    /// bias table is scaled by `0.02`. All biases and both LayerNorm betas start at `0`,
    /// and both LayerNorm gammas start at `1`.
    ///
    /// Draws happen in a fixed order: QKV, projection, bias table, MLP1, MLP2. So,
    /// assuming `fill_normal` is deterministic, the result depends only on `cfg` and the
    /// incoming `rng` state.
    pub fn default_init(cfg: &SwinConfig, rng: &mut LcgRng) -> Self {
        let d = cfg.dim;
        let hidden = cfg.hidden_dim();
        // Fan-in scaling per matrix. `mlp_w2` consumes `hidden` features rather than `d`,
        // so it needs its own factor; reusing 1/sqrt(d) there would inflate the MLP
        // output variance by roughly `mlp_ratio`.
        let scale_d = 1.0 / (d as f32).sqrt();
        let scale_hidden = 1.0 / (hidden as f32).sqrt();
        let fill_scaled = |rng: &mut LcgRng, n: usize, sc: f32| -> Vec<f32> {
            let mut v = vec![0.0f32; n];
            rng.fill_normal(&mut v);
            for x in &mut v {
                *x *= sc;
            }
            v
        };
        let qkv_weight = fill_scaled(rng, 3 * d * d, scale_d);
        let qkv_bias = vec![0.0f32; 3 * d];
        let proj_weight = fill_scaled(rng, d * d, scale_d);
        let proj_bias = vec![0.0f32; d];
        let table_len = cfg.n_heads * cfg.rel_table_len();
        let relative_position_bias_table = fill_scaled(rng, table_len, 0.02);
        let mlp_w1 = fill_scaled(rng, hidden * d, scale_d);
        let mlp_b1 = vec![0.0f32; hidden];
        let mlp_w2 = fill_scaled(rng, d * hidden, scale_hidden);
        let mlp_b2 = vec![0.0f32; d];
        let ln1_gamma = vec![1.0f32; d];
        let ln1_beta = vec![0.0f32; d];
        let ln2_gamma = vec![1.0f32; d];
        let ln2_beta = vec![0.0f32; d];
        Self {
            qkv_weight,
            qkv_bias,
            proj_weight,
            proj_bias,
            relative_position_bias_table,
            mlp_w1,
            mlp_b1,
            mlp_w2,
            mlp_b2,
            ln1_gamma,
            ln1_beta,
            ln2_gamma,
            ln2_beta,
        }
    }
}
/// One Swin Transformer block: pre-norm (shifted-)window multi-head self-attention
/// followed by a pre-norm GELU MLP, each wrapped in a residual connection.
pub struct SwinBlock {
    /// Grid geometry and hyper-parameters this block was built for.
    pub cfg: SwinConfig,
    /// Learnable parameters sized for `cfg`.
    pub weights: SwinWeights,
}
impl SwinBlock {
    /// Creates a block with freshly initialised weights drawn from `rng`.
    ///
    /// `cfg` is rebuilt through [`SwinConfig::new`] because its fields are public and
    /// may have been filled in without validation.
    ///
    /// # Errors
    ///
    /// Propagates any validation error from [`SwinConfig::new`].
    pub fn new(cfg: SwinConfig, rng: &mut LcgRng) -> VisionResult<Self> {
        let cfg = SwinConfig::new(
            cfg.dim,
            cfg.n_heads,
            cfg.window_size,
            cfg.input_h,
            cfg.input_w,
            cfg.shift,
            cfg.mlp_ratio,
        )?;
        let weights = SwinWeights::default_init(&cfg, rng);
        Ok(Self { cfg, weights })
    }

    /// Ensures `x` holds exactly `input_h * input_w * dim` values.
    fn check_input_len(&self, x: &[f32]) -> VisionResult<()> {
        let expected = self.cfg.input_h * self.cfg.input_w * self.cfg.dim;
        if x.len() != expected {
            return Err(VisionError::DimensionMismatch {
                expected,
                got: x.len(),
            });
        }
        Ok(())
    }

    /// Reorders a row-major, channel-last `(input_h, input_w, dim)` grid so that the
    /// tokens of each `window_size × window_size` window become contiguous.
    ///
    /// Windows are emitted in row-major order over the window grid, and tokens in
    /// row-major order inside each window. The output has the same length as `x`.
    ///
    /// # Errors
    ///
    /// [`VisionError::DimensionMismatch`] if `x` has the wrong length.
    pub fn window_partition(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
        self.check_input_len(x)?;
        let c = self.cfg.dim;
        let h = self.cfg.input_h;
        let w = self.cfg.input_w;
        let m = self.cfg.window_size;
        let win_rows = h / m;
        let win_cols = w / m;
        let mut out = vec![0.0f32; h * w * c];
        let mut dst = 0usize;
        for wr in 0..win_rows {
            for wc in 0..win_cols {
                for i in 0..m {
                    let row = wr * m + i;
                    for j in 0..m {
                        let col = wc * m + j;
                        let src = (row * w + col) * c;
                        out[dst..dst + c].copy_from_slice(&x[src..src + c]);
                        dst += c;
                    }
                }
            }
        }
        Ok(out)
    }

    /// Inverse of [`Self::window_partition`]: scatters window-major tokens back into
    /// row-major grid order.
    ///
    /// # Errors
    ///
    /// [`VisionError::DimensionMismatch`] if `windows` has the wrong length.
    pub fn window_reverse(&self, windows: &[f32]) -> VisionResult<Vec<f32>> {
        let c = self.cfg.dim;
        let h = self.cfg.input_h;
        let w = self.cfg.input_w;
        let expected = h * w * c;
        if windows.len() != expected {
            return Err(VisionError::DimensionMismatch {
                expected,
                got: windows.len(),
            });
        }
        let m = self.cfg.window_size;
        let win_rows = h / m;
        let win_cols = w / m;
        let mut out = vec![0.0f32; expected];
        let mut src = 0usize;
        for wr in 0..win_rows {
            for wc in 0..win_cols {
                for i in 0..m {
                    let row = wr * m + i;
                    for j in 0..m {
                        let col = wc * m + j;
                        let dst = (row * w + col) * c;
                        out[dst..dst + c].copy_from_slice(&windows[src..src + c]);
                        src += c;
                    }
                }
            }
        }
        Ok(out)
    }

    /// Cyclically rolls the token grid by `shift` positions along both height and width.
    ///
    /// The token at grid position `(r, c)` moves to
    /// `((r + shift) mod input_h, (c + shift) mod input_w)`, so a negative `shift` rolls
    /// toward the origin, and rolling by `k` then by `-k` is the identity.
    ///
    /// # Errors
    ///
    /// [`VisionError::DimensionMismatch`] if `x` has the wrong length.
    pub fn cyclic_shift(&self, x: &[f32], shift: i32) -> VisionResult<Vec<f32>> {
        self.check_input_len(x)?;
        let c = self.cfg.dim;
        let h = self.cfg.input_h;
        let w = self.cfg.input_w;
        // Reduce the signed shift into `0..len` entirely in `usize`. This way no grid
        // dimension is narrowed to `i32` and no flat index is formed in `i32`, where it
        // could overflow. `u32 -> usize` is lossless on 32- and 64-bit targets.
        let wrap = |len: usize| -> usize {
            let r = shift.unsigned_abs() as usize % len;
            if shift >= 0 {
                r
            } else {
                (len - r) % len
            }
        };
        let sh = wrap(h);
        let sw = wrap(w);
        let mut out = vec![0.0f32; x.len()];
        for dr in 0..h {
            // `sh < h`, so `dr + h - sh` lies in `1..2h` and cannot underflow.
            let sr = (dr + h - sh) % h;
            for dc in 0..w {
                let sc = (dc + w - sw) % w;
                let dst = (dr * w + dc) * c;
                let src = (sr * w + sc) * c;
                out[dst..dst + c].copy_from_slice(&x[src..src + c]);
            }
        }
        Ok(out)
    }

    /// Returns the `M² × M²` lookup from a token pair to its relative-position-bias slot.
    ///
    /// For window tokens `a = (ha, wa)` and `b = (hb, wb)` (row-major inside the window),
    /// entry `a * M² + b` is `(ha - hb + M - 1) * (2M - 1) + (wa - wb + M - 1)`. That value
    /// lies in `0..(2M - 1)²`, giving one slot per distinct 2-D offset.
    #[must_use]
    pub fn relative_position_index(&self) -> Vec<usize> {
        let m = self.cfg.window_size;
        let span = 2 * m - 1;
        let n = m * m;
        let mut idx = vec![0usize; n * n];
        for a in 0..n {
            let ha = a / m;
            let wa = a % m;
            for b in 0..n {
                let hb = b / m;
                let wb = b % m;
                // Offsets are shifted by M - 1 so both components are non-negative.
                let rel_h = ha as i64 - hb as i64 + (m as i64 - 1);
                let rel_w = wa as i64 - wb as i64 + (m as i64 - 1);
                idx[a * n + b] = (rel_h * span as i64 + rel_w) as usize;
            }
        }
        idx
    }

    /// Builds the additive SW-MSA mask, `n_windows` blocks of `M² × M²` values.
    ///
    /// The result is all zeros when the block is not shifted. Otherwise the (already
    /// rolled) grid is labelled with 3 × 3 regions. The row bands are
    /// `[0, H - M)`, `[H - M, H - s)` and `[H - s, H)`, with the same split for columns,
    /// where `s = shift_size()`. Pairs of tokens in the same window whose labels differ
    /// get [`MASK_PENALTY`], so the wrapped-around pieces of a window do not attend to
    /// each other.
    ///
    /// # Errors
    ///
    /// Currently infallible; the `Result` is kept for API stability.
    pub fn attention_mask(&self) -> VisionResult<Vec<f32>> {
        let n_windows = self.cfg.n_windows();
        let m = self.cfg.window_size;
        let win_tok = m * m;
        let mut mask = vec![0.0f32; n_windows * win_tok * win_tok];
        if !self.cfg.is_shifted() {
            return Ok(mask);
        }
        let h = self.cfg.input_h;
        let w = self.cfg.input_w;
        let shift_size = self.cfg.shift_size();
        // `is_shifted()` guarantees `m < h` and `m < w`, so these subtractions are safe.
        let h_slices = [(0usize, h - m), (h - m, h - shift_size), (h - shift_size, h)];
        let w_slices = [(0usize, w - m), (w - m, w - shift_size), (w - shift_size, w)];
        let mut region = vec![0usize; h * w];
        let mut region_id = 0usize;
        for &(h0, h1) in &h_slices {
            for &(w0, w1) in &w_slices {
                for r in h0..h1 {
                    region[r * w + w0..r * w + w1].fill(region_id);
                }
                region_id += 1;
            }
        }
        let win_rows = h / m;
        let win_cols = w / m;
        // Scratch buffer reused across windows; every entry is overwritten each time.
        let mut win_region = vec![0usize; win_tok];
        for wr in 0..win_rows {
            for wc in 0..win_cols {
                let win = wr * win_cols + wc;
                for i in 0..m {
                    let row_start = (wr * m + i) * w + wc * m;
                    win_region[i * m..(i + 1) * m]
                        .copy_from_slice(&region[row_start..row_start + m]);
                }
                let base = win * win_tok * win_tok;
                for (a, &ra) in win_region.iter().enumerate() {
                    for (b, &rb) in win_region.iter().enumerate() {
                        if ra != rb {
                            mask[base + a * win_tok + b] = MASK_PENALTY;
                        }
                    }
                }
            }
        }
        Ok(mask)
    }

    /// Runs the block on a row-major, channel-last `(input_h, input_w, dim)` token grid.
    ///
    /// Pre-norm structure: `y = x + Attn(LN1(x))`, then `out = y + MLP(LN2(y))`, where
    /// the MLP is `linear -> exact GELU -> linear` and both LayerNorms use `eps = 1e-5`.
    /// For shifted configurations, the normalised grid is rolled by `-shift_size` before
    /// window attention (which then uses the region mask) and rolled back afterwards.
    ///
    /// # Errors
    ///
    /// * [`VisionError::DimensionMismatch`] if `x` has the wrong length.
    /// * [`VisionError::NonFinite`] if the attention output or the block output
    ///   contains NaN or infinity.
    pub fn forward(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
        self.check_input_len(x)?;
        let d = self.cfg.dim;
        let n_tok = self.cfg.input_h * self.cfg.input_w;
        let w = &self.weights;

        // Attention branch.
        let normed = layer_norm(x, &w.ln1_gamma, &w.ln1_beta, n_tok, d, 1e-5);
        let shifted_block = self.cfg.is_shifted();
        let shift_size = self.cfg.shift_size();
        // `shift_size` is at most `window_size / 2`, far below `i32::MAX` for any grid
        // that can actually be allocated, so these casts do not truncate.
        let shifted = if shifted_block {
            self.cyclic_shift(&normed, -(shift_size as i32))?
        } else {
            normed
        };
        let windows = self.window_partition(&shifted)?;
        let mask = if shifted_block {
            Some(self.attention_mask()?)
        } else {
            None
        };
        let attn = self.window_attention(&windows, mask.as_deref())?;
        let merged = self.window_reverse(&attn)?;
        let merged = if shifted_block {
            self.cyclic_shift(&merged, shift_size as i32)?
        } else {
            merged
        };
        let mut h: Vec<f32> = x.iter().zip(merged.iter()).map(|(a, b)| a + b).collect();

        // MLP branch.
        let normed2 = layer_norm(&h, &w.ln2_gamma, &w.ln2_beta, n_tok, d, 1e-5);
        let hidden = self.cfg.hidden_dim();
        let mid = linear(&normed2, &w.mlp_w1, &w.mlp_b1, d, hidden);
        let mid: Vec<f32> = mid.into_iter().map(gelu_exact).collect();
        let mlp_out = linear(&mid, &w.mlp_w2, &w.mlp_b2, hidden, d);
        for (o, m_v) in h.iter_mut().zip(mlp_out.iter()) {
            *o += m_v;
        }
        if h.iter().any(|v| !v.is_finite()) {
            return Err(VisionError::NonFinite("swin block output"));
        }
        Ok(h)
    }

    /// Multi-head self-attention applied independently inside each window.
    ///
    /// `windows` is the output of [`Self::window_partition`]. The optional `mask` holds
    /// `n_windows` blocks of `M² × M²` additive logits, as built by
    /// [`Self::attention_mask`]. For each head, logits are
    /// `q·k * head_dim^-0.5 + relative_bias (+ mask)`. They are row-softmaxed and used
    /// to average `v`. The concatenated heads then pass through the output projection.
    ///
    /// # Errors
    ///
    /// * [`VisionError::DimensionMismatch`] if `windows` or `mask` has the wrong length.
    /// * [`VisionError::NonFinite`] if the projected output contains NaN or infinity.
    fn window_attention(&self, windows: &[f32], mask: Option<&[f32]>) -> VisionResult<Vec<f32>> {
        let d = self.cfg.dim;
        let n_heads = self.cfg.n_heads;
        let head_dim = self.cfg.head_dim();
        let m = self.cfg.window_size;
        let win_tok = m * m;
        let n_windows = self.cfg.n_windows();
        let table_per_head = self.cfg.rel_table_len();

        // Validate up front so bad lengths become errors instead of slice panics.
        let expected = n_windows * win_tok * d;
        if windows.len() != expected {
            return Err(VisionError::DimensionMismatch {
                expected,
                got: windows.len(),
            });
        }
        let mask_len = n_windows * win_tok * win_tok;
        if let Some(mk) = mask.filter(|mk| mk.len() != mask_len) {
            return Err(VisionError::DimensionMismatch {
                expected: mask_len,
                got: mk.len(),
            });
        }

        let scale = 1.0 / (head_dim as f32).sqrt();
        let rel_index = self.relative_position_index();
        let w = &self.weights;
        let mut out = vec![0.0f32; expected];
        // Per-window scratch buffers, reused across windows and heads.
        let mut q = vec![0.0f32; win_tok * d];
        let mut k = vec![0.0f32; win_tok * d];
        let mut v = vec![0.0f32; win_tok * d];
        let mut scores = vec![0.0f32; win_tok * win_tok];
        for win in 0..n_windows {
            let win_in = &windows[win * win_tok * d..(win + 1) * win_tok * d];
            // NOTE(review): assumes `linear` maps each `d`-wide token row to a `3d`-wide
            // row laid out as [q | k | v]; confirm against vit_block::linear.
            let qkv = linear(win_in, &w.qkv_weight, &w.qkv_bias, d, 3 * d);
            for t in 0..win_tok {
                let src = &qkv[t * 3 * d..(t + 1) * 3 * d];
                q[t * d..(t + 1) * d].copy_from_slice(&src[..d]);
                k[t * d..(t + 1) * d].copy_from_slice(&src[d..2 * d]);
                v[t * d..(t + 1) * d].copy_from_slice(&src[2 * d..]);
            }
            let win_concat = &mut out[win * win_tok * d..(win + 1) * win_tok * d];
            let win_mask =
                mask.map(|mk| &mk[win * win_tok * win_tok..(win + 1) * win_tok * win_tok]);
            for head in 0..n_heads {
                // Each head owns a contiguous `head_dim` slice of every token vector
                // and a contiguous `(2M - 1)²` slice of the bias table.
                let hd_off = head * head_dim;
                let table_off = head * table_per_head;
                for a in 0..win_tok {
                    for b in 0..win_tok {
                        let mut dot = 0.0f32;
                        for dd in 0..head_dim {
                            dot += q[a * d + hd_off + dd] * k[b * d + hd_off + dd];
                        }
                        let bias =
                            w.relative_position_bias_table[table_off + rel_index[a * win_tok + b]];
                        let mut logit = dot * scale + bias;
                        if let Some(mk) = win_mask {
                            logit += mk[a * win_tok + b];
                        }
                        scores[a * win_tok + b] = logit;
                    }
                }
                softmax_rows(&mut scores, win_tok, win_tok);
                for a in 0..win_tok {
                    for dd in 0..head_dim {
                        let mut acc = 0.0f32;
                        for b in 0..win_tok {
                            acc += scores[a * win_tok + b] * v[b * d + hd_off + dd];
                        }
                        win_concat[a * d + hd_off + dd] = acc;
                    }
                }
            }
        }
        let projected = linear(&out, &w.proj_weight, &w.proj_bias, d, d);
        if projected.iter().any(|v| !v.is_finite()) {
            return Err(VisionError::NonFinite("swin window attention"));
        }
        Ok(projected)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn wmsa_cfg() -> SwinConfig {
        SwinConfig::new(32, 4, 2, 4, 4, 0, 4).expect("valid config")
    }

    fn swmsa_cfg() -> SwinConfig {
        SwinConfig::new(32, 4, 2, 4, 4, 1, 4).expect("valid config")
    }

    fn random_input(cfg: &SwinConfig, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut x = vec![0.0f32; cfg.input_h * cfg.input_w * cfg.dim];
        rng.fill_normal(&mut x);
        x
    }

    #[test]
    fn config_derived_quantities() {
        let cfg = wmsa_cfg();
        assert_eq!(cfg.head_dim(), 8);
        assert_eq!(cfg.hidden_dim(), 128);
        assert_eq!(cfg.window_tokens(), 4);
        assert_eq!(cfg.n_windows(), 4);
        assert_eq!(cfg.rel_table_len(), 9);
    }

    #[test]
    fn window_partition_length() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(1);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 2);
        let parts = block.window_partition(&x).expect("partition");
        assert_eq!(parts.len(), cfg.input_h * cfg.input_w * cfg.dim);
    }

    #[test]
    fn window_partition_groups_window_tokens() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(27);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 28);
        let parts = block.window_partition(&x).expect("partition");
        let (w, c, m) = (cfg.input_w, cfg.dim, cfg.window_size);
        // Window 0, token 1 is grid position (0, 1).
        assert_eq!(&parts[c..2 * c], &x[c..2 * c]);
        // Window 0, token m is grid position (1, 0).
        assert_eq!(&parts[m * c..(m + 1) * c], &x[w * c..(w + 1) * c]);
        // Window 1, token 0 is grid position (0, m).
        let win1 = m * m * c;
        assert_eq!(&parts[win1..win1 + c], &x[m * c..(m + 1) * c]);
    }

    #[test]
    fn window_reverse_round_trip_exact() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(3);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 4);
        let parts = block.window_partition(&x).expect("partition");
        let back = block.window_reverse(&parts).expect("reverse");
        assert_eq!(back.len(), x.len());
        for (a, b) in back.iter().zip(x.iter()) {
            assert_eq!(a, b, "round-trip not exact");
        }
    }

    #[test]
    fn window_reverse_round_trip_nonsquare() {
        let cfg = SwinConfig::new(16, 2, 2, 4, 6, 0, 2).expect("cfg");
        let mut rng = LcgRng::new(31);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 32);
        let parts = block.window_partition(&x).expect("partition");
        let back = block.window_reverse(&parts).expect("reverse");
        for (a, b) in back.iter().zip(x.iter()) {
            assert_eq!(a, b, "non-square round-trip not exact");
        }
    }

    #[test]
    fn n_windows_matches_formula() {
        let cfg = SwinConfig::new(16, 2, 2, 6, 8, 0, 2).expect("cfg");
        assert_eq!(cfg.n_windows(), (6 / 2) * (8 / 2));
    }

    #[test]
    fn cyclic_shift_round_trip_identity() {
        let cfg = swmsa_cfg();
        let mut rng = LcgRng::new(5);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 6);
        for k in [1i32, 2, 3, -1, -2] {
            let s = block.cyclic_shift(&x, k).expect("shift");
            let back = block.cyclic_shift(&s, -k).expect("unshift");
            for (a, b) in back.iter().zip(x.iter()) {
                assert_eq!(a, b, "cyclic shift round-trip failed for k={k}");
            }
        }
    }

    #[test]
    fn cyclic_shift_moves_tokens_like_roll() {
        let cfg = swmsa_cfg();
        let mut rng = LcgRng::new(25);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 26);
        let (h, w, c) = (cfg.input_h, cfg.input_w, cfg.dim);
        let s = block.cyclic_shift(&x, 1).expect("shift");
        for r in 0..h {
            for col in 0..w {
                let src = (r * w + col) * c;
                let dst = (((r + 1) % h) * w + (col + 1) % w) * c;
                assert_eq!(
                    &s[dst..dst + c],
                    &x[src..src + c],
                    "token ({r}, {col}) misplaced"
                );
            }
        }
    }

    #[test]
    fn cyclic_shift_zero_is_identity() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(7);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 8);
        let s = block.cyclic_shift(&x, 0).expect("shift");
        assert_eq!(s, x);
    }

    #[test]
    fn relative_position_index_length_and_bounds() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(9);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let idx = block.relative_position_index();
        let m = cfg.window_size;
        let span = 2 * m - 1;
        assert_eq!(idx.len(), m * m * m * m, "length must be M⁴");
        assert!(idx.iter().all(|&v| v < span * span), "index out of bounds");
    }

    #[test]
    fn relative_position_index_diagonal_is_center() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(10);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let idx = block.relative_position_index();
        let m = cfg.window_size;
        let n = m * m;
        let span = 2 * m - 1;
        let center = (m - 1) * span + (m - 1);
        for a in 0..n {
            assert_eq!(idx[a * n + a], center, "diagonal must map to center bias");
        }
    }

    #[test]
    fn attention_mask_length() {
        let cfg = swmsa_cfg();
        let mut rng = LcgRng::new(11);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let mask = block.attention_mask().expect("mask");
        let win_tok = cfg.window_tokens();
        assert_eq!(mask.len(), cfg.n_windows() * win_tok * win_tok);
    }

    #[test]
    fn attention_mask_values_are_zero_or_penalty() {
        let cfg = swmsa_cfg();
        let mut rng = LcgRng::new(12);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let mask = block.attention_mask().expect("mask");
        assert!(
            mask.iter().all(|&v| v == 0.0 || v == MASK_PENALTY),
            "mask must be 0.0 or -100.0 only"
        );
    }

    #[test]
    fn swmsa_mask_has_penalties() {
        let cfg = swmsa_cfg();
        let mut rng = LcgRng::new(29);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let mask = block.attention_mask().expect("mask");
        assert!(
            mask.iter().any(|&v| v == MASK_PENALTY),
            "shifted windows spanning the wrap-around must mask some pairs"
        );
    }

    #[test]
    fn attention_mask_is_symmetric_with_zero_diagonal() {
        let cfg = swmsa_cfg();
        let mut rng = LcgRng::new(30);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let mask = block.attention_mask().expect("mask");
        let n = cfg.window_tokens();
        for win in 0..cfg.n_windows() {
            let base = win * n * n;
            for a in 0..n {
                assert_eq!(mask[base + a * n + a], 0.0, "token must attend to itself");
                for b in 0..n {
                    assert_eq!(
                        mask[base + a * n + b],
                        mask[base + b * n + a],
                        "mask must be symmetric"
                    );
                }
            }
        }
    }

    #[test]
    fn single_window_config_mask_all_zero() {
        let cfg = SwinConfig::new(16, 2, 4, 4, 4, 1, 2).expect("cfg");
        let mut rng = LcgRng::new(13);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        assert_eq!(cfg.n_windows(), 1);
        let mask = block.attention_mask().expect("mask");
        assert!(
            mask.iter().all(|&v| v == 0.0),
            "single window must be unmasked"
        );
    }

    #[test]
    fn wmsa_mask_all_zero() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(14);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let mask = block.attention_mask().expect("mask");
        assert!(
            mask.iter().all(|&v| v == 0.0),
            "W-MSA mask must be all zero"
        );
    }

    #[test]
    fn forward_output_shape_wmsa() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(15);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 16);
        let out = block.forward(&x).expect("forward");
        assert_eq!(out.len(), x.len());
    }

    #[test]
    fn forward_output_shape_swmsa() {
        let cfg = swmsa_cfg();
        let mut rng = LcgRng::new(17);
        let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 18);
        let out = block.forward(&x).expect("forward");
        assert_eq!(out.len(), x.len());
    }

    #[test]
    fn forward_finite_wmsa_and_swmsa() {
        for cfg in [wmsa_cfg(), swmsa_cfg()] {
            let mut rng = LcgRng::new(19);
            let block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
            let x = random_input(&cfg, 20);
            let out = block.forward(&x).expect("forward");
            assert!(out.iter().all(|v| v.is_finite()), "non-finite output");
        }
    }

    #[test]
    fn forward_changes_with_bias_table() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(21);
        let mut block = SwinBlock::new(cfg.clone(), &mut rng).expect("block");
        let x = random_input(&cfg, 22);
        let before = block.forward(&x).expect("forward");
        // Softmax is invariant to adding the same constant to every logit in a row, so a
        // uniform offset on the table would only change the output through rounding noise.
        // An index-dependent ramp gives different relative offsets different biases.
        for (i, v) in block
            .weights
            .relative_position_bias_table
            .iter_mut()
            .enumerate()
        {
            *v += 0.5 * i as f32;
        }
        let after = block.forward(&x).expect("forward");
        let diff: f32 = before
            .iter()
            .zip(after.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 1e-3, "bias table change did not affect output");
    }

    #[test]
    fn err_head_dim_mismatch() {
        let r = SwinConfig::new(32, 5, 2, 4, 4, 0, 4);
        assert!(matches!(r, Err(VisionError::HeadDimMismatch { .. })));
    }

    #[test]
    fn err_input_not_divisible_by_window() {
        let r = SwinConfig::new(32, 4, 3, 4, 4, 0, 4);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn err_shift_too_large() {
        let r = SwinConfig::new(32, 4, 2, 4, 4, 2, 4);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    #[test]
    fn err_mlp_ratio_zero() {
        let r = SwinConfig::new(32, 4, 2, 4, 4, 0, 0);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    #[test]
    fn err_window_size_zero() {
        let r = SwinConfig::new(32, 4, 0, 4, 4, 0, 4);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn err_dim_zero() {
        let r = SwinConfig::new(0, 4, 2, 4, 4, 0, 4);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn err_forward_wrong_length() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(23);
        let block = SwinBlock::new(cfg, &mut rng).expect("block");
        let x = vec![0.0f32; 7];
        let r = block.forward(&x);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn err_partition_wrong_length() {
        let cfg = wmsa_cfg();
        let mut rng = LcgRng::new(24);
        let block = SwinBlock::new(cfg, &mut rng).expect("block");
        let x = vec![0.0f32; 9];
        let r = block.window_partition(&x);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn deterministic_given_seed() {
        let cfg = swmsa_cfg();
        let mut rng_a = LcgRng::new(99);
        let mut rng_b = LcgRng::new(99);
        let block_a = SwinBlock::new(cfg.clone(), &mut rng_a).expect("block");
        let block_b = SwinBlock::new(cfg.clone(), &mut rng_b).expect("block");
        let x = random_input(&cfg, 100);
        let out_a = block_a.forward(&x).expect("forward");
        let out_b = block_b.forward(&x).expect("forward");
        assert_eq!(out_a, out_b, "same seed must give identical output");
    }
}