use tch::{nn, Device, Tensor};
use super::config::WiFlowStdConfig;
use super::layers::{ConvBlock, DualAxialAttention, GroupedTemporalBlock};
use crate::error::TrainError;
/// WiFlow model implemented on top of `tch`.
///
/// The model owns its variable store together with every layer registered in
/// it. The forward pass (see `forward_impl`) applies, in order: the temporal
/// blocks, the input convolution, the convolution stack, the axial attention,
/// and a two-stage convolution / batch-norm / SiLU decoder whose result is
/// pooled down to `config.keypoints` rows.
pub struct WiFlowStdModel {
/// Variable store created in `new`; every layer below is registered under its root path.
vs: nn::VarStore,
/// Temporal blocks, one per entry of `config.tcn_channels`, applied in order.
/// Block `i` is built with dilation `1 << i`.
tcn: Vec<GroupedTemporalBlock>,
/// Convolution block built with 1 input channel and `config.conv_channels[0]` output channels.
conv_in: ConvBlock,
/// One convolution block per entry of `config.conv_channels`, with strides taken
/// from `config.conv_strides()`.
conv_blocks: Vec<ConvBlock>,
/// Attention layer applied after the convolution stack, once the last two axes are swapped.
attention: DualAxialAttention,
/// First decoder convolution: kernel size 3, padding 1, `config.decoder_mid()` output channels.
dec_conv1: nn::Conv2D,
/// Batch norm applied to the output of `dec_conv1`.
dec_bn1: nn::BatchNorm,
/// Second decoder convolution: kernel size 1, 2 output channels.
dec_conv2: nn::Conv2D,
/// Batch norm applied to the output of `dec_conv2`.
dec_bn2: nn::BatchNorm,
/// Clone of the configuration the model was built from.
pub config: WiFlowStdConfig,
}
impl WiFlowStdModel {
pub fn new(config: &WiFlowStdConfig, device: Device) -> Result<Self, TrainError> {
config.validate()?;
let vs = nn::VarStore::new(device);
let root = vs.root();
let mut tcn = Vec::with_capacity(config.tcn_channels.len());
let mut c_in = config.subcarriers;
for (i, &c_out) in config.tcn_channels.iter().enumerate() {
let dilation = 1_i64 << i;
let pw_groups = if i == 0 { config.input_pw_groups } else { 1 };
tcn.push(GroupedTemporalBlock::new(
&root / format!("tcn{i}"),
c_in as i64,
c_out as i64,
dilation,
config.tcn_conv_groups(c_in) as i64,
config.tcn_conv_groups(c_out) as i64,
pw_groups as i64,
config.dropout,
));
c_in = c_out;
}
let c0 = config.conv_channels[0] as i64;
let conv_in = ConvBlock::new(&root / "conv_in", 1, c0, 1);
let mut conv_blocks = Vec::with_capacity(config.conv_channels.len());
let strides = config.conv_strides();
let mut c_in = c0;
for (i, &c_out) in config.conv_channels.iter().enumerate() {
conv_blocks.push(ConvBlock::new(
&root / format!("conv{i}"),
c_in,
c_out as i64,
strides[i] as i64,
));
c_in = c_out as i64;
}
let attention =
DualAxialAttention::new(&root / "attention", c_in, config.attention_groups as i64);
let mid = config.decoder_mid() as i64;
let dec_conv1 = nn::conv2d(
&root / "dec_conv1",
c_in,
mid,
3,
nn::ConvConfig {
padding: 1,
..Default::default()
},
);
let dec_bn1 = nn::batch_norm2d(&root / "dec_bn1", mid, super::layers::bn_cfg());
let dec_conv2 = nn::conv2d(&root / "dec_conv2", mid, 2, 1, Default::default());
let dec_bn2 = nn::batch_norm2d(&root / "dec_bn2", 2, super::layers::bn_cfg());
Ok(WiFlowStdModel {
vs,
tcn,
conv_in,
conv_blocks,
attention,
dec_conv1,
dec_bn1,
dec_conv2,
dec_bn2,
config: config.clone(),
})
}
/// Runs the network on `csi` with the `train` flag set to `true` for every layer.
pub fn forward_t(&self, csi: &Tensor) -> Tensor {
    let train = true;
    self.forward_impl(csi, train)
}
/// Runs the network on `csi` with the `train` flag set to `false`, inside
/// `tch::no_grad`.
pub fn forward_inference(&self, csi: &Tensor) -> Tensor {
    tch::no_grad(|| {
        let train = false;
        self.forward_impl(csi, train)
    })
}
/// Saves the model's variable store to `path`.
///
/// # Errors
///
/// A failure reported by the variable store is returned as
/// `TrainError::training_step` with the message `save failed: <cause>`.
pub fn save(&self, path: &std::path::Path) -> Result<(), TrainError> {
    let written = self.vs.save(path);
    written.map_err(|cause| TrainError::training_step(format!("save failed: {cause}")))
}
/// Loads values from `path` into the model's variable store.
///
/// # Errors
///
/// A failure reported by the variable store is returned as
/// `TrainError::training_step` with the message `load failed: <cause>`.
pub fn load(&mut self, path: &std::path::Path) -> Result<(), TrainError> {
    let restored = self.vs.load(path);
    restored.map_err(|cause| TrainError::training_step(format!("load failed: {cause}")))
}
pub fn var_store(&self) -> &nn::VarStore {
&self.vs
}
pub fn var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.vs
}
/// Total number of scalar elements across the variable store's trainable variables.
pub fn num_parameters(&self) -> i64 {
    let mut total = 0_i64;
    for tensor in self.vs.trainable_variables() {
        total += tensor.numel() as i64;
    }
    total
}
/// Shared forward pass; `train` is forwarded to every layer that accepts it.
///
/// The tests in this file feed `csi` as `[batch, subcarriers, window]` and
/// expect `[batch, keypoints, 2]` back — other input layouts are untested here.
fn forward_impl(&self, csi: &Tensor, train: bool) -> Tensor {
    // Temporal blocks, threaded one after another.
    let temporal = self
        .tcn
        .iter()
        .fold(csi.shallow_clone(), |acc, layer| layer.forward_t(&acc, train));

    // Swap axes 1 and 2, then insert a singleton channel axis for the 2-D stack.
    let image = temporal.transpose(1, 2).unsqueeze(1);

    // Input convolution followed by the convolution stack.
    let features = self
        .conv_blocks
        .iter()
        .fold(self.conv_in.forward_t(&image, train), |acc, layer| {
            layer.forward_t(&acc, train)
        });

    // The attention layer sees the feature map with its last two axes swapped.
    let swapped = features.permute([0, 1, 3, 2]);
    let attended = self.attention.forward_t(&swapped, train);

    // Two decoder stages: convolution -> batch norm -> SiLU.
    let stage_one = attended
        .apply(&self.dec_conv1)
        .apply_t(&self.dec_bn1, train)
        .silu();
    let stage_two = stage_one
        .apply(&self.dec_conv2)
        .apply_t(&self.dec_bn2, train)
        .silu();

    // Pool to `[keypoints, 1]`, drop the trailing axis, and move the 2 output
    // channels to the last axis.
    let keypoints = self.config.keypoints as i64;
    let pooled = stage_two.adaptive_avg_pool2d([keypoints, 1]);
    pooled.squeeze_dim(-1).transpose(1, 2)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tch::Kind;
/// Random float tensor on the CPU shaped `[batch, cfg.subcarriers, cfg.window]`.
fn random_csi(cfg: &WiFlowStdConfig, batch: i64) -> Tensor {
    let shape = [batch, cfg.subcarriers as i64, cfg.window as i64];
    let options = (Kind::Float, Device::Cpu);
    Tensor::rand(shape, options)
}
/// The default model's trainable parameter count must match both the
/// config's own formula and the pinned constant.
#[test]
fn param_count_matches_pure_rust_formula() {
    tch::manual_seed(0);
    let cfg = WiFlowStdConfig::default();
    let built = WiFlowStdModel::new(&cfg, Device::Cpu).expect("default config builds");
    let counted = built.num_parameters();
    assert_eq!(counted, cfg.param_count() as i64);
    assert_eq!(counted, 2_225_042);
}
/// Each compact preset must build, hit its pinned parameter count, agree
/// with the config formula, and produce a `[2, 15, 2]` output for a batch of 2.
#[test]
fn compact_preset_param_counts_and_shapes() {
    let presets = [
        (WiFlowStdConfig::half(), 843_834_i64),
        (WiFlowStdConfig::quarter(), 338_600),
        (WiFlowStdConfig::tiny(), 56_290),
    ];
    for (cfg, expected) in presets {
        tch::manual_seed(0);
        let built = WiFlowStdModel::new(&cfg, Device::Cpu).expect("preset builds");
        let counted = built.num_parameters();
        assert_eq!(counted, expected);
        assert_eq!(counted, cfg.param_count() as i64);

        let batch = random_csi(&cfg, 2);
        let out = built.forward_inference(&batch);
        assert_eq!(out.size(), &[2, 15, 2]);
    }
}
/// A training-mode forward pass on the default config yields `[batch, 15, 2]`.
#[test]
fn forward_output_shape_15_keypoints() {
    tch::manual_seed(0);
    let cfg = WiFlowStdConfig::default();
    let built = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");
    let batch = random_csi(&cfg, 2);
    let out = built.forward_t(&batch);
    assert_eq!(out.size(), &[2, 15, 2]);
}
/// A config built with `for_keypoints(17)` yields `[batch, 17, 2]` in inference mode.
#[test]
fn forward_output_shape_17_keypoints_esp32() {
    tch::manual_seed(0);
    let cfg = WiFlowStdConfig::for_keypoints(17);
    let built = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");
    let sample = random_csi(&cfg, 1);
    let out = built.forward_inference(&sample);
    assert_eq!(out.size(), &[1, 17, 2]);
}
/// Two inference passes over the same input must be finite and element-wise equal.
#[test]
fn inference_outputs_are_finite_and_deterministic() {
    tch::manual_seed(7);
    let cfg = WiFlowStdConfig::default();
    let built = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");
    let csi = random_csi(&cfg, 1);

    let first = built.forward_inference(&csi);
    let second = built.forward_inference(&csi);

    let all_finite = bool::try_from(first.isfinite().all()).unwrap();
    assert!(all_finite, "non-finite output");

    let identical = bool::try_from(first.eq_tensor(&second).all()).unwrap();
    assert!(
        identical,
        "inference must be deterministic (dropout disabled)"
    );
}
/// Prints every variable name with its shape (sorted) and checks the store is not empty.
#[test]
fn dump_variable_names() {
    let cfg = WiFlowStdConfig::default();
    let built = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");

    let mut names: Vec<(String, Vec<i64>)> = built
        .var_store()
        .variables()
        .into_iter()
        .map(|(name, tensor)| (name, tensor.size()))
        .collect();
    names.sort();

    for (name, shape) in names.iter() {
        println!("{name} {shape:?}");
    }
    println!("total: {} variables", names.len());
    assert!(!names.is_empty());
}
/// Construction must fail for a config with `subcarriers = 541`.
/// NOTE(review): presumably 541 violates a constraint checked by `validate` — confirm.
#[test]
fn invalid_config_is_rejected() {
    let cfg = WiFlowStdConfig {
        subcarriers: 541,
        ..Default::default()
    };
    let outcome = WiFlowStdModel::new(&cfg, Device::Cpu);
    assert!(outcome.is_err());
}
/// Weights saved from one model and loaded into another must reproduce the
/// first model's inference output.
///
/// The restored model is built from a different seed so it starts from
/// different weights. Reloading into the same model would pass even if
/// `load` restored nothing.
#[test]
fn save_and_load_roundtrip() {
    use tempfile::tempdir;

    let cfg = WiFlowStdConfig::default();

    tch::manual_seed(42);
    let source = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");
    let tmp = tempdir().expect("tempdir");
    let path = tmp.path().join("wiflow_std.safetensors");
    source.save(&path).expect("save");

    // Different seed => different initial weights until `load` overwrites them.
    tch::manual_seed(43);
    let mut restored = WiFlowStdModel::new(&cfg, Device::Cpu).expect("build");
    restored.load(&path).expect("load");

    let csi = random_csi(&cfg, 1);
    let expected = source.forward_inference(&csi);
    let actual = restored.forward_inference(&csi);

    assert_eq!(actual.size(), &[1, 15, 2]);
    assert!(
        actual.allclose(&expected, 1e-5, 1e-6, false),
        "restored model must reproduce the saved model's output"
    );
}
}