use std::path::Path;
use tch::{nn, nn::Module, nn::ModuleT, Device, Kind, Tensor};
use ruvector_attn_mincut::attn_mincut;
use ruvector_attention::attention::ScaledDotProductAttention;
use ruvector_attention::traits::Attention;
use crate::config::TrainingConfig;
use crate::error::TrainError;
/// Tensors produced by a single forward pass of `WiFiDensePoseModel`.
#[derive(Debug)]
pub struct ModelOutput {
/// Keypoint head output: `num_keypoints` channels, bilinearly upsampled to
/// `heatmap_size` x `heatmap_size` (i.e. `[B, num_keypoints, H, H]`).
pub keypoints: Tensor,
/// Raw (un-normalised) part scores with `num_body_parts + 1` channels at
/// `heatmap_size` x `heatmap_size` resolution.
pub part_logits: Tensor,
/// Sigmoid-squashed UV output with `num_body_parts * 2` channels at
/// `heatmap_size` x `heatmap_size` resolution; every value lies in `[0, 1]`.
pub uv_coords: Tensor,
/// Feature map returned by the backbone, i.e. the input shared by both heads.
pub features: Tensor,
}
/// CSI-to-pose network: a modality translator feeding a residual backbone,
/// followed by a keypoint head and a DensePose (part + UV) head.
pub struct WiFiDensePoseModel {
// Owns every parameter of the sub-modules below; used for save/load.
vs: nn::VarStore,
// Maps flattened amplitude/phase CSI to a 3-channel 48x48 tensor.
translator: ModalityTranslator,
// Residual feature extractor applied to the translator output.
backbone: Backbone,
// Produces per-keypoint heatmaps from backbone features.
kp_head: KeypointHead,
// Produces part logits and UV coordinates from backbone features.
dp_head: DensePoseHead,
/// Copy of the configuration the model was built from.
pub config: TrainingConfig,
}
impl WiFiDensePoseModel {
pub fn new(config: &TrainingConfig, device: Device) -> Self {
let vs = nn::VarStore::new(device);
let root = vs.root();
let n_ant = (config.window_frames
* config.num_antennas_tx
* config.num_antennas_rx) as i64;
let n_sc = config.num_subcarriers as i64;
let flat_csi = n_ant * n_sc;
let num_parts = config.num_body_parts as i64;
let translator =
ModalityTranslator::new(&root / "translator", flat_csi, n_ant, n_sc);
let backbone = Backbone::new(&root / "backbone", config.backbone_channels as i64);
let kp_head = KeypointHead::new(
&root / "kp_head",
config.backbone_channels as i64,
config.num_keypoints as i64,
);
let dp_head = DensePoseHead::new(
&root / "dp_head",
config.backbone_channels as i64,
num_parts,
);
WiFiDensePoseModel {
vs,
translator,
backbone,
kp_head,
dp_head,
config: config.clone(),
}
}
pub fn forward_t(&self, amplitude: &Tensor, phase: &Tensor) -> ModelOutput {
self.forward_impl(amplitude, phase, true)
}
pub fn forward_inference(&self, amplitude: &Tensor, phase: &Tensor) -> ModelOutput {
tch::no_grad(|| self.forward_impl(amplitude, phase, false))
}
pub fn save(&self, path: &Path) -> Result<(), TrainError> {
self.vs
.save(path)
.map_err(|e| TrainError::training_step(format!("save failed: {e}")))
}
pub fn load(&mut self, path: &Path) -> Result<(), TrainError> {
self.vs
.load(path)
.map_err(|e| TrainError::training_step(format!("load failed: {e}")))
}
pub fn varstore(&self) -> &nn::VarStore {
&self.vs
}
pub fn varstore_mut(&mut self) -> &mut nn::VarStore {
&mut self.vs
}
pub fn var_store(&self) -> &nn::VarStore {
&self.vs
}
pub fn var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.vs
}
pub fn forward_train(&self, amplitude: &Tensor, phase: &Tensor) -> ModelOutput {
self.forward_t(amplitude, phase)
}
pub fn num_parameters(&self) -> i64 {
self.vs
.trainable_variables()
.iter()
.map(|t| t.numel())
.sum()
}
fn forward_impl(&self, amplitude: &Tensor, phase: &Tensor, train: bool) -> ModelOutput {
let cfg = &self.config;
let clean_phase = phase_sanitize(phase);
let batch = amplitude.size()[0];
let flat_amp = amplitude.reshape([batch, -1]);
let flat_phase = clean_phase.reshape([batch, -1]);
let spatial = self.translator.forward_t(&flat_amp, &flat_phase, train);
let features = self.backbone.forward_t(&spatial, train);
let hs = cfg.heatmap_size as i64;
let keypoints = self.kp_head.forward_t(&features, hs, train);
let (part_logits, uv_coords) = self.dp_head.forward_t(&features, hs, train);
ModelOutput {
keypoints,
part_logits,
uv_coords,
features,
}
}
}
/// Replaces `phase` by its first difference along dimension 2.
///
/// The result has the same size along dimension 2 as the input: index 0 is
/// zero and index `k >= 1` holds `phase[.., k] - phase[.., k - 1]`. When
/// dimension 2 has at most one element an all-zero tensor shaped like `phase`
/// is returned.
///
/// The zero column is built from dimensions 0 and 1 only, so the input is
/// expected to be 3-D. Indexing panics for tensors with fewer than three
/// dimensions.
fn phase_sanitize(phase: &Tensor) -> Tensor {
    let dims = phase.size();
    let width = dims[2];
    if width <= 1 {
        return phase.zeros_like();
    }

    // Adjacent-element difference: elements 1..width minus elements 0..width-1.
    let steps = width - 1;
    let diff = phase.narrow(2, 1, steps) - phase.narrow(2, 0, steps);

    // NOTE(review): the padding column is always `Kind::Float`, unlike the
    // early-return branch which keeps the input kind — confirm callers only
    // pass float tensors.
    let leading = Tensor::zeros([dims[0], dims[1], 1], (Kind::Float, phase.device()));
    Tensor::cat(&[leading, diff], 2)
}
/// Runs `attn_mincut` independently on every sample of a `[B, n_ant, n_sc]`
/// tensor and returns a tensor of the same shape, kind and device.
///
/// Each sample is copied to the CPU as `f32`, handed to `attn_mincut` (the
/// same buffer is passed three times — presumably as query, key and value;
/// confirm against the `ruvector_attn_mincut` API), and the `output` field of
/// the result is reshaped back to `[n_ant, n_sc]`.
///
/// If either `n_ant` or `n_sc` is at most 1 the input is returned unchanged
/// (as a shallow clone).
///
/// NOTE(review): the round trip through a host `Vec<f32>` looks like it
/// detaches the result from autograd — verify this is intended.
fn apply_antenna_attention(x: &Tensor, lambda: f32) -> Tensor {
    let dims = x.size();
    let (rows, cols) = (dims[1], dims[2]);
    if rows <= 1 || cols <= 1 {
        return x.shallow_clone();
    }

    let batch = dims[0];
    let row_count = rows as usize;
    let col_count = cols as usize;
    let (device, kind) = (x.device(), x.kind());

    let per_sample: Vec<Tensor> = (0..batch)
        .map(|sample| {
            let host: Vec<f32> = Vec::from(
                x.select(0, sample)
                    .to_kind(Kind::Float)
                    .to_device(Device::Cpu)
                    .contiguous(),
            );
            let result = attn_mincut(
                &host, &host, &host, col_count, row_count, lambda, 1, 1e-6,
            );
            Tensor::from_slice(&result.output)
                .reshape([rows, cols])
                .to_device(device)
                .to_kind(kind)
        })
        .collect();

    Tensor::stack(&per_sample, 0)
}
/// Applies `ScaledDotProductAttention` over the spatial positions of a
/// `[B, C, H, W]` tensor and returns a tensor of the same shape, kind and
/// device.
///
/// For every sample, each of the `H * W` positions becomes a token of length
/// `C`. Every token is used as a query against all tokens of the same sample
/// (which serve as both keys and values). If the attention call fails for a
/// token, that token is copied through unchanged.
// Not called from anywhere in this file, hence the lint allowance.
#[allow(dead_code)]
fn apply_spatial_attention(x: &Tensor) -> Tensor {
    let dims = x.size();
    let (batch, channels, height, width) = (dims[0], dims[1], dims[2], dims[3]);
    let positions = height * width;
    let token_count = positions as usize;
    let token_len = channels as usize;
    let (device, kind) = (x.device(), x.kind());
    let attention = ScaledDotProductAttention::new(token_len);

    let per_sample: Vec<Tensor> = (0..batch)
        .map(|sample| {
            // [C, H, W] -> [H*W, C] so that each row is one spatial token.
            let token_major = x
                .select(0, sample)
                .reshape([channels, positions])
                .transpose(0, 1);
            let host: Vec<f32> = Vec::from(
                token_major
                    .to_kind(Kind::Float)
                    .to_device(Device::Cpu)
                    .contiguous(),
            );
            let tokens: Vec<&[f32]> = (0..token_count)
                .map(|t| &host[t * token_len..(t + 1) * token_len])
                .collect();

            let mut attended_flat = vec![0.0f32; token_count * token_len];
            for (t, &query) in tokens.iter().enumerate() {
                let dst = &mut attended_flat[t * token_len..(t + 1) * token_len];
                if let Ok(attended) = attention.compute(query, &tokens, &tokens) {
                    dst.copy_from_slice(&attended);
                } else {
                    // Fall back to the unmodified token when attention fails.
                    dst.copy_from_slice(query);
                }
            }

            // [H*W, C] -> [C, H, W], restoring the original layout.
            Tensor::from_slice(&attended_flat)
                .reshape([positions, channels])
                .transpose(0, 1)
                .reshape([channels, height, width])
                .to_device(device)
                .to_kind(kind)
        })
        .collect();

    Tensor::stack(&per_sample, 0)
}
/// Turns flattened amplitude and phase CSI into a 3-channel 48x48 tensor.
struct ModalityTranslator {
// Amplitude branch: flat_csi -> 512 -> 256.
amp_fc1: nn::Linear,
amp_fc2: nn::Linear,
// Phase branch: flat_csi -> 512 -> 256.
ph_fc1: nn::Linear,
ph_fc2: nn::Linear,
// Fuses the concatenated branches: 512 -> 3 * 48 * 48.
fuse_fc: nn::Linear,
// Spatial refinement: 3x3 conv (3 -> 32, no bias), batch norm, 3x3 conv (32 -> 3).
sp_conv1: nn::Conv2D,
sp_bn1: nn::BatchNorm,
sp_conv2: nn::Conv2D,
// Dimensions used to reshape the flat input to [B, n_ant, n_sc] before
// antenna attention is applied.
n_ant: i64,
n_sc: i64,
}
impl ModalityTranslator {
    /// Channel count of the tensor produced by the fusion layer.
    const PSEUDO_CHANNELS: i64 = 3;
    /// Height and width of the tensor produced by the fusion layer.
    const PSEUDO_SIDE: i64 = 48;
    /// `lambda` argument passed to `apply_antenna_attention` for both branches.
    const ATTENTION_LAMBDA: f32 = 0.3;
    /// Dropout probability used inside the amplitude and phase branches.
    const DROPOUT_P: f64 = 0.2;

    /// Creates the translator layers under `vs`.
    ///
    /// * `flat_csi` — input width of both branches; must equal `n_ant * n_sc`
    ///   for the reshape in [`Self::forward_t`] to succeed.
    /// * `n_ant`, `n_sc` — dimensions the flat input is reshaped to before
    ///   antenna attention.
    fn new(vs: nn::Path, flat_csi: i64, n_ant: i64, n_sc: i64) -> Self {
        let fused_len = Self::PSEUDO_CHANNELS * Self::PSEUDO_SIDE * Self::PSEUDO_SIDE;

        // Layers are created in a fixed order so that parameter
        // initialisation stays reproducible under a fixed seed.
        let amp_fc1 = nn::linear(&vs / "amp_fc1", flat_csi, 512, Default::default());
        let amp_fc2 = nn::linear(&vs / "amp_fc2", 512, 256, Default::default());
        let ph_fc1 = nn::linear(&vs / "ph_fc1", flat_csi, 512, Default::default());
        let ph_fc2 = nn::linear(&vs / "ph_fc2", 512, 256, Default::default());
        let fuse_fc = nn::linear(&vs / "fuse_fc", 512, fused_len, Default::default());
        let sp_conv1 = nn::conv2d(
            &vs / "sp_conv1",
            Self::PSEUDO_CHANNELS,
            32,
            3,
            nn::ConvConfig {
                padding: 1,
                bias: false,
                ..Default::default()
            },
        );
        let sp_bn1 = nn::batch_norm2d(&vs / "sp_bn1", 32, Default::default());
        let sp_conv2 = nn::conv2d(
            &vs / "sp_conv2",
            32,
            Self::PSEUDO_CHANNELS,
            3,
            nn::ConvConfig {
                padding: 1,
                ..Default::default()
            },
        );

        ModalityTranslator {
            amp_fc1,
            amp_fc2,
            ph_fc1,
            ph_fc2,
            fuse_fc,
            sp_conv1,
            sp_bn1,
            sp_conv2,
            n_ant,
            n_sc,
        }
    }

    /// Maps flattened amplitude (`amp`) and phase (`ph`) inputs, each
    /// `[B, n_ant * n_sc]`, to a `[B, 3, 48, 48]` tensor with values in
    /// `[-1, 1]` (the last operation is `tanh`).
    ///
    /// `train` is forwarded to the dropout and batch-norm layers.
    fn forward_t(&self, amp: &Tensor, ph: &Tensor, train: bool) -> Tensor {
        let b = amp.size()[0];

        // Restore the [B, n_ant, n_sc] layout so attention can run per sample.
        let amp_3d = amp.reshape([b, self.n_ant, self.n_sc]);
        let ph_3d = ph.reshape([b, self.n_ant, self.n_sc]);
        let amp_attended = apply_antenna_attention(&amp_3d, Self::ATTENTION_LAMBDA);
        let ph_attended = apply_antenna_attention(&ph_3d, Self::ATTENTION_LAMBDA);
        let amp_flat = amp_attended.reshape([b, -1]);
        let ph_flat = ph_attended.reshape([b, -1]);

        // Two independent MLP branches, each ending in a 256-wide embedding.
        let a = amp_flat
            .apply(&self.amp_fc1)
            .relu()
            .dropout(Self::DROPOUT_P, train)
            .apply(&self.amp_fc2)
            .relu();
        let p = ph_flat
            .apply(&self.ph_fc1)
            .relu()
            .dropout(Self::DROPOUT_P, train)
            .apply(&self.ph_fc2)
            .relu();

        // Concatenate to 512 features and project to a 3x48x48 tensor.
        let fused = Tensor::cat(&[a, p], 1)
            .apply(&self.fuse_fc)
            .view([b, Self::PSEUDO_CHANNELS, Self::PSEUDO_SIDE, Self::PSEUDO_SIDE])
            .relu();

        // Convolutional refinement.
        fused
            .apply(&self.sp_conv1)
            .apply_t(&self.sp_bn1, train)
            .relu()
            .apply(&self.sp_conv2)
            .tanh()
    }
}
/// Residual feature extractor: a 3x3 stem followed by three stages of two
/// `BasicBlock`s each.
struct Backbone {
// Stem: 3 -> 64 channels, 3x3 conv without bias, then batch norm.
stem_conv: nn::Conv2D,
stem_bn: nn::BatchNorm,
// Stage 1: 64 -> 64 channels, stride 1.
l1b1: BasicBlock,
l1b2: BasicBlock,
// Stage 2: 64 -> 128 channels; the first block uses stride 2.
l2b1: BasicBlock,
l2b2: BasicBlock,
// Stage 3: 128 -> `out_channels`; the first block uses stride 2.
l3b1: BasicBlock,
l3b2: BasicBlock,
}
impl Backbone {
    /// Creates the stem and the six residual blocks under `vs`.
    ///
    /// `out_channels` is the channel count produced by the last stage; the
    /// first two stages are fixed at 64 and 128 channels.
    fn new(vs: nn::Path, out_channels: i64) -> Self {
        let stem_cfg = nn::ConvConfig {
            padding: 1,
            bias: false,
            ..Default::default()
        };

        // Field initialisers run in the order written, which keeps the
        // parameter creation order: stem first, then blocks stage by stage.
        Self {
            stem_conv: nn::conv2d(&vs / "stem_conv", 3, 64, 3, stem_cfg),
            stem_bn: nn::batch_norm2d(&vs / "stem_bn", 64, Default::default()),
            l1b1: BasicBlock::new(&vs / "l1b1", 64, 64, 1),
            l1b2: BasicBlock::new(&vs / "l1b2", 64, 64, 1),
            l2b1: BasicBlock::new(&vs / "l2b1", 64, 128, 2),
            l2b2: BasicBlock::new(&vs / "l2b2", 128, 128, 1),
            l3b1: BasicBlock::new(&vs / "l3b1", 128, out_channels, 2),
            l3b2: BasicBlock::new(&vs / "l3b2", out_channels, out_channels, 1),
        }
    }

    /// Runs stem conv -> batch norm -> ReLU, then the six residual blocks in
    /// order. `train` is forwarded to every batch-norm layer.
    ///
    /// The two stride-2 blocks each halve the spatial size, so a 48x48 input
    /// yields a 12x12 feature map.
    fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
        let stem_out = x
            .apply(&self.stem_conv)
            .apply_t(&self.stem_bn, train)
            .relu();

        let stages = [
            &self.l1b1, &self.l1b2, &self.l2b1, &self.l2b2, &self.l3b1, &self.l3b2,
        ];
        stages
            .iter()
            .fold(stem_out, |acc, block| block.forward_t(&acc, train))
    }
}
/// Two-convolution residual block with an optional projection shortcut.
struct BasicBlock {
// First 3x3 conv (carries the block's stride) and its batch norm.
conv1: nn::Conv2D,
bn1: nn::BatchNorm,
// Second 3x3 conv and its batch norm.
conv2: nn::Conv2D,
bn2: nn::BatchNorm,
// 1x1 conv + batch norm applied to the shortcut; present only when the
// channel count changes or the stride is not 1.
downsample: Option<(nn::Conv2D, nn::BatchNorm)>,
}
impl BasicBlock {
    /// Creates a residual block mapping `in_ch` to `out_ch` channels.
    ///
    /// `stride` is applied by the first convolution (and by the shortcut
    /// projection, when one is needed). A projection shortcut is created only
    /// if `in_ch != out_ch` or `stride != 1`.
    fn new(vs: nn::Path, in_ch: i64, out_ch: i64, stride: i64) -> Self {
        // Shared settings of the two 3x3 convolutions.
        let padded_no_bias = || nn::ConvConfig {
            padding: 1,
            bias: false,
            ..Default::default()
        };

        let conv1 = nn::conv2d(
            &vs / "conv1",
            in_ch,
            out_ch,
            3,
            nn::ConvConfig {
                stride,
                ..padded_no_bias()
            },
        );
        let bn1 = nn::batch_norm2d(&vs / "bn1", out_ch, Default::default());
        let conv2 = nn::conv2d(&vs / "conv2", out_ch, out_ch, 3, padded_no_bias());
        let bn2 = nn::batch_norm2d(&vs / "bn2", out_ch, Default::default());

        let needs_projection = in_ch != out_ch || stride != 1;
        let downsample = needs_projection.then(|| {
            let proj_cfg = nn::ConvConfig {
                stride,
                bias: false,
                ..Default::default()
            };
            let proj = nn::conv2d(&vs / "ds_conv", in_ch, out_ch, 1, proj_cfg);
            let proj_bn = nn::batch_norm2d(&vs / "ds_bn", out_ch, Default::default());
            (proj, proj_bn)
        });

        Self {
            conv1,
            bn1,
            conv2,
            bn2,
            downsample,
        }
    }

    /// Computes `relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))`, where
    /// the shortcut is the projection when present and `x` itself otherwise.
    /// `train` is forwarded to every batch-norm layer.
    fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
        let shortcut = if let Some((proj, proj_bn)) = &self.downsample {
            x.apply(proj).apply_t(proj_bn, train)
        } else {
            x.shallow_clone()
        };

        let main = x
            .apply(&self.conv1)
            .apply_t(&self.bn1, train)
            .relu()
            .apply(&self.conv2)
            .apply_t(&self.bn2, train);

        (main + shortcut).relu()
    }
}
/// Convolutional head that turns backbone features into keypoint heatmaps.
struct KeypointHead {
// in_ch -> 256, 3x3 conv without bias, followed by batch norm.
conv1: nn::Conv2D,
bn1: nn::BatchNorm,
// 256 -> 128, 3x3 conv without bias, followed by batch norm.
conv2: nn::Conv2D,
bn2: nn::BatchNorm,
// 1x1 conv producing one channel per keypoint.
out_conv: nn::Conv2D,
}
impl KeypointHead {
    /// Creates the head under `vs` for `in_ch` input channels and `num_kp`
    /// output heatmap channels.
    fn new(vs: nn::Path, in_ch: i64, num_kp: i64) -> Self {
        // Both 3x3 convolutions share these settings.
        let conv3x3 = || nn::ConvConfig {
            padding: 1,
            bias: false,
            ..Default::default()
        };

        // Field initialisers run top to bottom, preserving creation order.
        Self {
            conv1: nn::conv2d(&vs / "conv1", in_ch, 256, 3, conv3x3()),
            bn1: nn::batch_norm2d(&vs / "bn1", 256, Default::default()),
            conv2: nn::conv2d(&vs / "conv2", 256, 128, 3, conv3x3()),
            bn2: nn::batch_norm2d(&vs / "bn2", 128, Default::default()),
            out_conv: nn::conv2d(&vs / "out_conv", 128, num_kp, 1, Default::default()),
        }
    }

    /// Produces heatmaps of shape `[B, num_kp, heatmap_size, heatmap_size]`.
    ///
    /// Two conv/batch-norm/ReLU stages are followed by a 1x1 convolution, and
    /// the result is bilinearly resized (`align_corners = false`). `train` is
    /// forwarded to the batch-norm layers.
    fn forward_t(&self, x: &Tensor, heatmap_size: i64, train: bool) -> Tensor {
        let stage1 = x.apply(&self.conv1).apply_t(&self.bn1, train).relu();
        let stage2 = stage1
            .apply(&self.conv2)
            .apply_t(&self.bn2, train)
            .relu();
        let coarse = stage2.apply(&self.out_conv);
        coarse.upsample_bilinear2d(&[heatmap_size, heatmap_size], false, None, None)
    }
}
/// Head with a shared convolutional trunk and two 1x1 output branches
/// (part logits and UV coordinates).
struct DensePoseHead {
// Shared trunk: in_ch -> 256 -> 256, 3x3 convs without bias, each with batch norm.
shared_conv1: nn::Conv2D,
shared_bn1: nn::BatchNorm,
shared_conv2: nn::Conv2D,
shared_bn2: nn::BatchNorm,
// 1x1 conv producing `num_parts + 1` channels.
part_out: nn::Conv2D,
// 1x1 conv producing `num_parts * 2` channels.
uv_out: nn::Conv2D,
}
impl DensePoseHead {
    /// Creates the head under `vs` for `in_ch` input channels and `num_parts`
    /// body parts.
    ///
    /// The part branch emits `num_parts + 1` channels and the UV branch
    /// `num_parts * 2` channels.
    fn new(vs: nn::Path, in_ch: i64, num_parts: i64) -> Self {
        // Both trunk convolutions share these settings.
        let conv3x3 = || nn::ConvConfig {
            padding: 1,
            bias: false,
            ..Default::default()
        };
        let trunk_width = 256;

        // Field initialisers run top to bottom, preserving creation order.
        Self {
            shared_conv1: nn::conv2d(&vs / "shared_conv1", in_ch, trunk_width, 3, conv3x3()),
            shared_bn1: nn::batch_norm2d(&vs / "shared_bn1", trunk_width, Default::default()),
            shared_conv2: nn::conv2d(
                &vs / "shared_conv2",
                trunk_width,
                trunk_width,
                3,
                conv3x3(),
            ),
            shared_bn2: nn::batch_norm2d(&vs / "shared_bn2", trunk_width, Default::default()),
            part_out: nn::conv2d(
                &vs / "part_out",
                trunk_width,
                num_parts + 1,
                1,
                Default::default(),
            ),
            uv_out: nn::conv2d(
                &vs / "uv_out",
                trunk_width,
                num_parts * 2,
                1,
                Default::default(),
            ),
        }
    }

    /// Returns `(part_logits, uv_coords)`, both at `out_size` x `out_size`.
    ///
    /// The shared trunk output is bilinearly resized (`align_corners = false`)
    /// before the two 1x1 branches are applied. Part logits are returned raw;
    /// UV values pass through a sigmoid. `train` is forwarded to the
    /// batch-norm layers.
    fn forward_t(&self, x: &Tensor, out_size: i64, train: bool) -> (Tensor, Tensor) {
        let stage1 = x
            .apply(&self.shared_conv1)
            .apply_t(&self.shared_bn1, train)
            .relu();
        let trunk = stage1
            .apply(&self.shared_conv2)
            .apply_t(&self.shared_bn2, train)
            .relu();
        let resized = trunk.upsample_bilinear2d(&[out_size, out_size], false, None, None);

        let part_logits = resized.apply(&self.part_out);
        let uv_coords = resized.apply(&self.uv_out).sigmoid();
        (part_logits, uv_coords)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TrainingConfig;
    use tch::Device;

    /// Small configuration that keeps the forward pass cheap on CPU.
    fn tiny_config() -> TrainingConfig {
        let mut cfg = TrainingConfig::default();
        cfg.num_subcarriers = 8;
        cfg.window_frames = 4;
        cfg.num_antennas_tx = 1;
        cfg.num_antennas_rx = 1;
        cfg.heatmap_size = 12;
        cfg.backbone_channels = 64;
        cfg.num_epochs = 2;
        cfg.warmup_epochs = 1;
        cfg
    }

    /// Shape of a CSI input batch for `cfg`: `[batch, tx * rx * frames, subcarriers]`.
    fn csi_shape(cfg: &TrainingConfig, batch: i64) -> [i64; 3] {
        let antennas =
            (cfg.num_antennas_tx * cfg.num_antennas_rx * cfg.window_frames) as i64;
        [batch, antennas, cfg.num_subcarriers as i64]
    }

    /// Uniform random amplitude/phase pair on the CPU.
    fn random_inputs(cfg: &TrainingConfig, batch: i64) -> (Tensor, Tensor) {
        let shape = csi_shape(cfg, batch);
        let amp = Tensor::rand(shape, (Kind::Float, Device::Cpu));
        let ph = Tensor::rand(shape, (Kind::Float, Device::Cpu));
        (amp, ph)
    }

    #[test]
    fn model_forward_output_shapes() {
        tch::manual_seed(0);
        let cfg = tiny_config();
        let device = Device::Cpu;
        let model = WiFiDensePoseModel::new(&cfg, device);
        let batch = 2_i64;
        let shape = csi_shape(&cfg, batch);
        let amp = Tensor::ones(shape, (Kind::Float, device));
        let ph = Tensor::zeros(shape, (Kind::Float, device));
        let out = model.forward_t(&amp, &ph);
        assert_eq!(out.keypoints.size()[0], batch);
        assert_eq!(out.keypoints.size()[1], cfg.num_keypoints as i64);
        assert_eq!(out.keypoints.size()[2], cfg.heatmap_size as i64);
        assert_eq!(out.keypoints.size()[3], cfg.heatmap_size as i64);
        assert_eq!(out.part_logits.size()[0], batch);
        assert_eq!(out.part_logits.size()[1], (cfg.num_body_parts + 1) as i64);
        assert_eq!(out.uv_coords.size()[0], batch);
        assert_eq!(out.uv_coords.size()[1], (cfg.num_body_parts * 2) as i64);
    }

    #[test]
    fn model_has_nonzero_parameters() {
        tch::manual_seed(0);
        let cfg = tiny_config();
        let model = WiFiDensePoseModel::new(&cfg, Device::Cpu);
        let n = model.num_parameters();
        assert!(n > 0, "model must have trainable parameters");
    }

    #[test]
    fn inference_mode_gives_same_shapes() {
        tch::manual_seed(0);
        let cfg = tiny_config();
        let model = WiFiDensePoseModel::new(&cfg, Device::Cpu);
        let batch = 1_i64;
        let (amp, ph) = random_inputs(&cfg, batch);
        let out = model.forward_inference(&amp, &ph);
        assert_eq!(out.keypoints.size()[0], batch);
        assert_eq!(out.part_logits.size()[0], batch);
        assert_eq!(out.uv_coords.size()[0], batch);
    }

    #[test]
    fn uv_coords_bounded_zero_one() {
        tch::manual_seed(0);
        let cfg = tiny_config();
        let model = WiFiDensePoseModel::new(&cfg, Device::Cpu);
        let (amp, ph) = random_inputs(&cfg, 2);
        let out = model.forward_inference(&amp, &ph);
        let uv_min: f64 = out.uv_coords.min().double_value(&[]);
        let uv_max: f64 = out.uv_coords.max().double_value(&[]);
        assert!(
            uv_min >= 0.0 - 1e-5,
            "UV min should be >= 0, got {uv_min}"
        );
        assert!(
            uv_max <= 1.0 + 1e-5,
            "UV max should be <= 1, got {uv_max}"
        );
    }

    #[test]
    fn phase_sanitize_zeros_first_column() {
        let ph = Tensor::ones([2, 3, 8], (Kind::Float, Device::Cpu));
        let out = phase_sanitize(&ph);
        let first_col = out.slice(2, 0, 1, 1);
        let max_abs: f64 = first_col.abs().max().double_value(&[]);
        assert!(max_abs < 1e-6, "first diff column should be 0");
    }

    #[test]
    fn phase_sanitize_captures_ramp() {
        // A unit-slope ramp along the last axis must yield a constant diff of 1.
        let ph = Tensor::arange(8, (Kind::Float, Device::Cpu))
            .reshape([1, 1, 8])
            .expand([2, 3, 8], true);
        let out = phase_sanitize(&ph);
        let tail = out.slice(2, 1, 8, 1);
        let min_val: f64 = tail.min().double_value(&[]);
        let max_val: f64 = tail.max().double_value(&[]);
        assert!(
            (min_val - 1.0).abs() < 1e-5,
            "expected 1.0 diff, got {min_val}"
        );
        assert!(
            (max_val - 1.0).abs() < 1e-5,
            "expected 1.0 diff, got {max_val}"
        );
    }

    #[test]
    fn save_and_load_roundtrip() {
        use tempfile::tempdir;
        tch::manual_seed(42);
        let cfg = tiny_config();
        let mut model = WiFiDensePoseModel::new(&cfg, Device::Cpu);
        let tmp = tempdir().expect("tempdir");
        let path = tmp.path().join("weights.pt");
        model.save(&path).expect("save should succeed");
        model.load(&path).expect("load should succeed");
        let batch = 1_i64;
        let (amp, ph) = random_inputs(&cfg, batch);
        let out = model.forward_inference(&amp, &ph);
        assert_eq!(out.keypoints.size()[0], batch);
    }

    #[test]
    fn varstore_accessible() {
        let cfg = tiny_config();
        let mut model = WiFiDensePoseModel::new(&cfg, Device::Cpu);
        let _vs = model.varstore();
        let _vs_mut = model.varstore_mut();
    }
}