use serde::{Deserialize, Serialize};
use crate::error::ConfigError;
/// Temporal kernel size of the two grouped convolutions in every TCN block
/// (passed as `k` to `tcn_block_params` by `WiFlowStdConfig::param_count`).
pub const TCN_KERNEL: usize = 3;
/// Dropout probability for the conv blocks, going by its name.
/// NOTE(review): not referenced in this file — presumably consumed by the
/// model-building code; confirm.
pub const CONV_BLOCK_DROPOUT: f64 = 0.3;
/// Strategy `WiFlowStdConfig::tcn_conv_groups` uses to pick the group count
/// of each grouped TCN convolution.
///
/// Serialized in `snake_case`: `"fixed"`, `"gcd"`, `"depthwise"`.
#[derive(
    Debug,
    Default,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Serialize,
    Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum TcnGroupsMode {
    /// Always use `tcn_groups`; `subcarriers` and every TCN width must be
    /// divisible by it. This is the default.
    #[default]
    Fixed,
    /// Use `gcd(channels, tcn_groups)`, which always divides `channels`.
    Gcd,
    /// Use one group per channel.
    Depthwise,
}
/// Greatest common divisor by Euclid's algorithm.
///
/// `gcd(x, 0) == x`, so `gcd(0, 0) == 0`. Recursion depth is logarithmic in
/// the smaller argument.
fn gcd(a: usize, b: usize) -> usize {
    match b {
        0 => a,
        divisor => gcd(divisor, a % divisor),
    }
}
/// Serde fallback for `input_pw_groups` when the field is absent: a single
/// group, i.e. an ungrouped input pointwise convolution.
fn default_input_pw_groups() -> usize {
    1_usize
}

/// Serde fallback for `min_feature_width` when the field is absent.
fn default_min_feature_width() -> usize {
    15_usize
}
/// Architecture configuration for the WiFlow-Std network: a TCN trunk, a
/// conv trunk, axial attention and a keypoint decoder (the components that
/// `param_count` sums).
///
/// `tcn_groups_mode`, `input_pw_groups` and `min_feature_width` carry serde
/// defaults, so older serialized configs without them still deserialize.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WiFlowStdConfig {
/// Number of input subcarriers: the channel count entering the first TCN level.
pub subcarriers: usize,
/// Temporal window length. Only checked to be >= 1 here; it does not
/// affect `param_count`.
pub window: usize,
/// Output channel count of each TCN level, in order. Must be non-empty; the
/// last entry is the starting width for `conv_strides`.
pub tcn_channels: Vec<usize>,
/// Group parameter for the grouped TCN convolutions; how it is applied
/// depends on `tcn_groups_mode`. Must be >= 1.
pub tcn_groups: usize,
/// Grouping strategy for the TCN; `Fixed` when absent from serialized data.
#[serde(default)]
pub tcn_groups_mode: TcnGroupsMode,
/// Groups of the first TCN level's pointwise and downsample convolutions;
/// must divide both `subcarriers` and `tcn_channels[0]`. 1 when absent.
#[serde(default = "default_input_pw_groups")]
pub input_pw_groups: usize,
/// Output channel count of each conv block. The last entry must be even and
/// divisible by `attention_groups`.
pub conv_channels: Vec<usize>,
/// Similarity groups of the axial-attention layers; must divide the last
/// `conv_channels` entry.
pub attention_groups: usize,
/// Number of predicted keypoints; sets `output_shape` but not `param_count`.
pub keypoints: usize,
/// Floor for the conv trunk's width: a conv block only strides by 2 when the
/// halved width stays >= this value. 15 when absent from serialized data.
#[serde(default = "default_min_feature_width")]
pub min_feature_width: usize,
/// Dropout probability; must be finite and in `[0, 1)`.
pub dropout: f64,
}
impl Default for WiFlowStdConfig {
    /// Full-size configuration: 540 subcarriers, a 20-step window, TCN levels
    /// `[540, 440, 340, 240]` with 20 fixed groups, conv blocks
    /// `[8, 16, 32, 64]`, 8 attention groups, 15 keypoints and 0.5 dropout.
    fn default() -> Self {
        Self {
            subcarriers: 540,
            window: 20,
            keypoints: 15,
            tcn_channels: [540, 440, 340, 240].to_vec(),
            tcn_groups: 20,
            tcn_groups_mode: TcnGroupsMode::Fixed,
            input_pw_groups: 1,
            conv_channels: [8, 16, 32, 64].to_vec(),
            attention_groups: 8,
            min_feature_width: 15,
            dropout: 0.5,
        }
    }
}
impl WiFlowStdConfig {
pub fn for_keypoints(keypoints: usize) -> Self {
WiFlowStdConfig {
keypoints,
..Self::default()
}
}
pub fn half() -> Self {
WiFlowStdConfig {
tcn_channels: vec![270, 220, 170, 120],
tcn_groups_mode: TcnGroupsMode::Gcd,
conv_channels: vec![4, 8, 16, 32],
attention_groups: 4,
..Self::default()
}
}
pub fn quarter() -> Self {
WiFlowStdConfig {
tcn_channels: vec![135, 110, 85, 60],
tcn_groups_mode: TcnGroupsMode::Gcd,
conv_channels: vec![2, 4, 8, 16],
attention_groups: 2,
..Self::default()
}
}
pub fn tiny() -> Self {
WiFlowStdConfig {
tcn_channels: vec![68, 56, 44, 32],
tcn_groups_mode: TcnGroupsMode::Depthwise,
input_pw_groups: 4,
conv_channels: vec![2, 4, 8, 16],
attention_groups: 2,
..Self::default()
}
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.subcarriers == 0 {
return Err(ConfigError::invalid_value("subcarriers", "must be >= 1"));
}
if self.window == 0 {
return Err(ConfigError::invalid_value("window", "must be >= 1"));
}
if self.tcn_groups == 0 {
return Err(ConfigError::invalid_value("tcn_groups", "must be >= 1"));
}
let fixed = self.tcn_groups_mode == TcnGroupsMode::Fixed;
if fixed && self.subcarriers % self.tcn_groups != 0 {
return Err(ConfigError::invalid_value(
"subcarriers",
format!(
"{} is not divisible by tcn_groups={} (grouped conv requirement)",
self.subcarriers, self.tcn_groups
),
));
}
if self.tcn_channels.is_empty() {
return Err(ConfigError::invalid_value(
"tcn_channels",
"must contain at least one level",
));
}
for (i, &c) in self.tcn_channels.iter().enumerate() {
if c == 0 || (fixed && c % self.tcn_groups != 0) {
return Err(ConfigError::invalid_value(
"tcn_channels",
format!(
"level {i} has {c} channels; must be > 0 and divisible by tcn_groups={}",
self.tcn_groups
),
));
}
}
if self.input_pw_groups == 0
|| self.subcarriers % self.input_pw_groups != 0
|| self.tcn_channels[0] % self.input_pw_groups != 0
{
return Err(ConfigError::invalid_value(
"input_pw_groups",
format!(
"{} must be >= 1 and divide both subcarriers={} and tcn_channels[0]={}",
self.input_pw_groups, self.subcarriers, self.tcn_channels[0]
),
));
}
if self.conv_channels.is_empty() {
return Err(ConfigError::invalid_value(
"conv_channels",
"must contain at least one block",
));
}
if self.conv_channels.iter().any(|&c| c == 0) {
return Err(ConfigError::invalid_value(
"conv_channels",
"all blocks must have > 0 channels",
));
}
let c_last = *self.conv_channels.last().expect("non-empty checked above");
if self.attention_groups == 0 || c_last % self.attention_groups != 0 {
return Err(ConfigError::invalid_value(
"attention_groups",
format!(
"{} must be >= 1 and divide the last conv channel count {c_last}",
self.attention_groups
),
));
}
if c_last < 2 || c_last % 2 != 0 {
return Err(ConfigError::invalid_value(
"conv_channels",
format!("last block has {c_last} channels; decoder needs an even count >= 2"),
));
}
if self.keypoints == 0 {
return Err(ConfigError::invalid_value("keypoints", "must be >= 1"));
}
if self.min_feature_width == 0 {
return Err(ConfigError::invalid_value(
"min_feature_width",
"must be >= 1",
));
}
if !self.dropout.is_finite() || !(0.0..1.0).contains(&self.dropout) {
return Err(ConfigError::invalid_value(
"dropout",
format!("{} is outside [0, 1)", self.dropout),
));
}
Ok(())
}
pub fn tcn_output_channels(&self) -> usize {
*self.tcn_channels.last().unwrap_or(&0)
}
pub fn tcn_conv_groups(&self, channels: usize) -> usize {
match self.tcn_groups_mode {
TcnGroupsMode::Fixed => self.tcn_groups,
TcnGroupsMode::Gcd => gcd(channels, self.tcn_groups),
TcnGroupsMode::Depthwise => channels,
}
}
pub fn conv_strides(&self) -> Vec<usize> {
let mut w = self.tcn_output_channels();
let mut strides = Vec::with_capacity(self.conv_channels.len());
for _ in &self.conv_channels {
let next = w.div_ceil(2);
if next >= self.min_feature_width {
strides.push(2);
w = next;
} else {
strides.push(1);
}
}
strides
}
pub fn feature_width(&self) -> usize {
let mut w = self.tcn_output_channels();
for s in self.conv_strides() {
if s == 2 {
w = w.div_ceil(2);
}
}
w
}
pub fn decoder_mid(&self) -> usize {
(self.conv_channels.last().unwrap_or(&0) / 2).max(4)
}
pub fn output_shape(&self, batch: usize) -> (usize, usize, usize) {
(batch, self.keypoints, 2)
}
pub fn param_count(&self) -> usize {
if self.validate().is_err() {
return 0;
}
let mut total = 0;
let mut c_in = self.subcarriers;
for (i, &c_out) in self.tcn_channels.iter().enumerate() {
let pw_groups = if i == 0 { self.input_pw_groups } else { 1 };
total += tcn_block_params(
c_in,
c_out,
TCN_KERNEL,
self.tcn_conv_groups(c_in),
self.tcn_conv_groups(c_out),
pw_groups,
);
c_in = c_out;
}
let mut c_in = 1;
total += conv_block_params(c_in, self.conv_channels[0]);
c_in = self.conv_channels[0];
for &c_out in &self.conv_channels {
total += conv_block_params(c_in, c_out);
c_in = c_out;
}
total += 2 * axial_attention_params(c_in, self.attention_groups);
total += decoder_params(c_in, self.decoder_mid());
total
}
}
/// Parameter count of one TCN block mapping `c_in` to `c_out` channels.
///
/// The count has five parts:
/// - a bias-free grouped conv over `c_in` (`g_in` groups, kernel `k`);
/// - a bias-free pointwise `c_in → c_out` conv with `pw_groups` groups;
/// - a bias-free grouped conv over `c_out` (`g_out` groups, kernel `k`);
/// - an ungrouped `c_out → c_out` pointwise conv;
/// - when the widths differ, a `pw_groups`-grouped `c_in → c_out` shortcut.
///
/// Each conv is followed by a two-parameter-per-channel norm layer.
///
/// # Panics
///
/// Panics (division by zero) if `g_in`, `g_out` or `pw_groups` is 0.
fn tcn_block_params(
    c_in: usize,
    c_out: usize,
    k: usize,
    g_in: usize,
    g_out: usize,
    pw_groups: usize,
) -> usize {
    let norm = |channels: usize| 2 * channels;
    let in_per_pw_group = c_in / pw_groups;

    let grouped_in = c_in * (c_in / g_in) * k + norm(c_in);
    let pointwise_in = c_out * in_per_pw_group + norm(c_out);
    let grouped_out = c_out * (c_out / g_out) * k + norm(c_out);
    let pointwise_out = c_out * c_out + norm(c_out);
    let residual = if c_in == c_out {
        0
    } else {
        in_per_pw_group * c_out + norm(c_out)
    };

    grouped_in + pointwise_in + grouped_out + pointwise_out + residual
}
/// Parameter count of one conv block mapping `c_in` to `c_out` channels.
///
/// The block has:
/// - three width-3 convs with bias (`c_in → c_out`, then two `c_out → c_out`);
/// - three two-parameter-per-channel norm layers on `c_out`;
/// - a bias-free `c_in → c_out` shortcut with its own norm, counted even when
///   `c_in == c_out`.
///
/// NOTE(review): layer roles are inferred from the arithmetic — confirm
/// against the model definition.
fn conv_block_params(c_in: usize, c_out: usize) -> usize {
    let conv3 = |fan_in: usize| c_out * fan_in * 3 + c_out;
    let convs = conv3(c_in) + 2 * conv3(c_out);
    let norms = 3 * 2 * c_out;
    let shortcut = c_in * c_out + 2 * c_out;
    convs + norms + shortcut
}
/// Parameter count of one axial-attention layer over `c` channels with
/// `groups` similarity groups.
///
/// The layer has a bias-free `c → 3c` projection, plus two-parameter norms
/// over the `3c` projected channels, the `groups` similarity maps and the `c`
/// outputs.
fn axial_attention_params(c: usize, groups: usize) -> usize {
    let projected = 3 * c;
    let weights = c * projected;
    let normalized_channels = projected + groups + c;
    weights + 2 * normalized_channels
}
/// Parameter count of the decoder head over `c` input channels with `mid`
/// hidden channels.
///
/// The head has a `c → mid` conv with bias and a 9-element kernel (presumably
/// 3×3), then a `mid → 2` conv with bias. Each conv is followed by a
/// two-parameter-per-channel norm.
fn decoder_params(c: usize, mid: usize) -> usize {
    // Output channels; presumably the (x, y) pair of `output_shape`.
    const OUT: usize = 2;
    let reduce = mid * c * 9 + mid;
    let project = OUT * mid + OUT;
    let norms = 2 * mid + 2 * OUT;
    reduce + project + norms
}
// Unit tests: validation rules, per-component and preset parameter counts,
// stride schedules, and serde backward compatibility.
#[cfg(test)]
mod tests {
use super::*;
/// Parameter count of `WiFlowStdConfig::default()`. It is the sum of the
/// figures pinned in `per_component_breakdown_matches_hand_calculation`:
/// - TCN: 675_000 + 746_180 + 464_780 + 249_380 = 2_135_340;
/// - conv: 504 + 728 + 2_224 + 8_544 + 33_472 = 45_472;
/// - attention: 2 * 12_816 = 25_632;
/// - decoder: 18_598.
const REFERENCE_PARAMS: usize = 2_225_042;
#[test]
fn default_config_is_valid() {
WiFlowStdConfig::default()
.validate()
.expect("default config must validate");
}
#[test]
fn default_param_count_matches_verified_reference() {
assert_eq!(WiFlowStdConfig::default().param_count(), REFERENCE_PARAMS);
}
// `keypoints` only affects `output_shape`, so the trunk count is unchanged.
#[test]
fn param_count_is_independent_of_keypoints() {
let kp17 = WiFlowStdConfig::for_keypoints(17);
kp17.validate().expect("17-keypoint config must validate");
assert_eq!(kp17.param_count(), REFERENCE_PARAMS);
}
// Default config components, in order:
// - TCN levels 540→540, 540→440, 440→340, 340→240;
// - the leading 1→8 conv block;
// - the four conv blocks;
// - one attention layer;
// - the decoder (mid = 64 / 2).
#[test]
fn per_component_breakdown_matches_hand_calculation() {
assert_eq!(tcn_block_params(540, 540, 3, 20, 20, 1), 675_000);
assert_eq!(tcn_block_params(540, 440, 3, 20, 20, 1), 746_180);
assert_eq!(tcn_block_params(440, 340, 3, 20, 20, 1), 464_780);
assert_eq!(tcn_block_params(340, 240, 3, 20, 20, 1), 249_380);
assert_eq!(conv_block_params(1, 8), 504);
assert_eq!(conv_block_params(8, 8), 728);
assert_eq!(conv_block_params(8, 16), 2_224);
assert_eq!(conv_block_params(16, 32), 8_544);
assert_eq!(conv_block_params(32, 64), 33_472);
assert_eq!(axial_attention_params(64, 8), 12_816);
assert_eq!(decoder_params(64, 32), 18_598);
}
// The expected counts in the next three tests were recorded, per the test
// names, from trained checkpoints of each preset.
#[test]
fn half_preset_param_count_matches_trained_checkpoint() {
let cfg = WiFlowStdConfig::half();
cfg.validate().expect("half preset must validate");
assert_eq!(cfg.param_count(), 843_834);
}
#[test]
fn quarter_preset_param_count_matches_trained_checkpoint() {
let cfg = WiFlowStdConfig::quarter();
cfg.validate().expect("quarter preset must validate");
assert_eq!(cfg.param_count(), 338_600);
}
#[test]
fn tiny_preset_param_count_matches_trained_checkpoint() {
let cfg = WiFlowStdConfig::tiny();
cfg.validate().expect("tiny preset must validate");
assert_eq!(cfg.param_count(), 56_290);
}
// Gcd mode (half preset) gives (g_in, g_out) per level, with
// gcd(540,20)=20, gcd(270,20)=10, gcd(220,20)=20, gcd(170,20)=10 and
// gcd(120,20)=20. Depthwise mode (tiny preset) gives one group per channel.
#[test]
fn preset_tcn_groups_match_sweep_per_block_record() {
let half = WiFlowStdConfig::half();
let groups: Vec<(usize, usize)> = {
let mut c_in = half.subcarriers;
half.tcn_channels
.iter()
.map(|&c_out| {
let g = (half.tcn_conv_groups(c_in), half.tcn_conv_groups(c_out));
c_in = c_out;
g
})
.collect()
};
assert_eq!(groups, [(20, 10), (10, 20), (20, 10), (10, 20)]);
let tiny = WiFlowStdConfig::tiny();
assert_eq!(tiny.tcn_conv_groups(540), 540); assert_eq!(tiny.tcn_conv_groups(68), 68);
}
// The width starts at the last TCN level and halves (rounding up) while it
// stays >= 15:
// - default: 240→120→60→30→15;
// - half: 120→60→30→15, then 8 < 15 stops;
// - quarter: 60→30→15;
// - tiny: 32→16, then 8 < 15 stops.
#[test]
fn preset_stride_schedules_match_sweep_record() {
assert_eq!(WiFlowStdConfig::default().conv_strides(), [2, 2, 2, 2]);
assert_eq!(WiFlowStdConfig::half().conv_strides(), [2, 2, 2, 1]);
assert_eq!(WiFlowStdConfig::quarter().conv_strides(), [2, 2, 1, 1]);
assert_eq!(WiFlowStdConfig::tiny().conv_strides(), [2, 1, 1, 1]);
assert_eq!(WiFlowStdConfig::half().feature_width(), 15);
assert_eq!(WiFlowStdConfig::quarter().feature_width(), 15);
assert_eq!(WiFlowStdConfig::tiny().feature_width(), 16);
}
// A 17-keypoint config keeps the default trunk and its 15-wide features;
// per the test name, the 15→17 mapping happens in pooling downstream.
#[test]
fn for_keypoints_17_keeps_trained_trunk_and_pools_15_to_17() {
let cfg = WiFlowStdConfig::for_keypoints(17);
assert_eq!(cfg.min_feature_width, 15);
assert_eq!(cfg.conv_strides(), [2, 2, 2, 2]);
assert_eq!(cfg.feature_width(), 15);
assert_eq!(cfg.output_shape(1), (1, 17, 2));
}
// Floor 30 on default: 240→120→60→30, then 15 < 30 stops.
// Floor 8 on tiny: 32→16→8, then 4 < 8 stops.
#[test]
fn min_feature_width_override_changes_schedule_as_designed() {
let cfg = WiFlowStdConfig {
min_feature_width: 30,
..Default::default()
};
cfg.validate().expect("floor 30 validates");
assert_eq!(cfg.conv_strides(), [2, 2, 2, 1]);
assert_eq!(cfg.feature_width(), 30);
let cfg = WiFlowStdConfig {
min_feature_width: 8,
..WiFlowStdConfig::tiny()
};
cfg.validate().expect("floor 8 validates");
assert_eq!(cfg.conv_strides(), [2, 2, 1, 1]);
assert_eq!(cfg.feature_width(), 8);
}
#[test]
fn rejects_zero_min_feature_width() {
let cfg = WiFlowStdConfig {
min_feature_width: 0,
..Default::default()
};
assert!(cfg.validate().is_err());
}
// The validate() guard in param_count must turn each invalid config into 0
// instead of indexing an empty list or dividing by a zero group count.
#[test]
fn param_count_returns_zero_for_invalid_configs() {
for cfg in [
WiFlowStdConfig {
conv_channels: vec![],
..Default::default()
},
WiFlowStdConfig {
tcn_groups: 0,
..Default::default()
},
WiFlowStdConfig {
input_pw_groups: 0,
..Default::default()
},
WiFlowStdConfig {
tcn_channels: vec![],
..Default::default()
},
] {
assert!(cfg.validate().is_err(), "precondition: {cfg:?} is invalid");
assert_eq!(cfg.param_count(), 0, "no panic, returns 0: {cfg:?}");
}
}
// Every default width is a multiple of 20, so gcd mode picks the same group
// counts as fixed mode and the parameter count is unchanged.
#[test]
fn fixed_mode_with_defaults_is_unchanged_by_new_knobs() {
let mut cfg = WiFlowStdConfig::default();
assert_eq!(cfg.param_count(), REFERENCE_PARAMS);
cfg.tcn_groups_mode = TcnGroupsMode::Gcd;
cfg.validate().expect("gcd mode validates at defaults");
assert_eq!(cfg.param_count(), REFERENCE_PARAMS);
assert_eq!(WiFlowStdConfig::default().decoder_mid(), 32);
}
// 7 does not divide 540; 27 divides 540 but not tiny's first level (68);
// 0 is rejected outright.
#[test]
fn rejects_bad_input_pw_groups() {
let cfg = WiFlowStdConfig {
input_pw_groups: 7,
..Default::default()
};
assert!(cfg.validate().is_err());
let cfg = WiFlowStdConfig {
input_pw_groups: 27,
..WiFlowStdConfig::tiny()
};
assert!(cfg.validate().is_err());
let zero = WiFlowStdConfig {
input_pw_groups: 0,
..Default::default()
};
assert!(zero.validate().is_err());
}
// JSON written before tcn_groups_mode, input_pw_groups and min_feature_width
// existed must still deserialize to the default config via serde defaults.
#[test]
fn serde_defaults_for_new_fields_are_backward_compatible() {
let legacy = r#"{
"subcarriers": 540, "window": 20,
"tcn_channels": [540, 440, 340, 240], "tcn_groups": 20,
"conv_channels": [8, 16, 32, 64], "attention_groups": 8,
"keypoints": 15, "dropout": 0.5
}"#;
let cfg: WiFlowStdConfig = serde_json::from_str(legacy).expect("deserialize");
assert_eq!(cfg, WiFlowStdConfig::default());
assert_eq!(cfg.param_count(), REFERENCE_PARAMS);
}
#[test]
fn serde_roundtrip_preserves_presets() {
for cfg in [
WiFlowStdConfig::half(),
WiFlowStdConfig::quarter(),
WiFlowStdConfig::tiny(),
] {
let json = serde_json::to_string(&cfg).expect("serialize");
let back: WiFlowStdConfig = serde_json::from_str(&json).expect("deserialize");
assert_eq!(back, cfg);
}
}
#[test]
fn output_shape_default_and_esp32() {
assert_eq!(WiFlowStdConfig::default().output_shape(4), (4, 15, 2));
assert_eq!(
WiFlowStdConfig::for_keypoints(17).output_shape(1),
(1, 17, 2)
);
}
#[test]
fn feature_width_default_is_15() {
assert_eq!(WiFlowStdConfig::default().feature_width(), 15);
}
#[test]
fn tcn_output_channels_default_is_240() {
assert_eq!(WiFlowStdConfig::default().tcn_output_channels(), 240);
}
// Fixed mode (the default) requires subcarriers % tcn_groups == 0; 541 % 20 != 0.
#[test]
fn rejects_subcarriers_not_divisible_by_groups() {
let cfg = WiFlowStdConfig {
subcarriers: 541,
..Default::default()
};
assert!(cfg.validate().is_err());
}
#[test]
fn rejects_zero_dimensions() {
for cfg in [
WiFlowStdConfig {
subcarriers: 0,
..Default::default()
},
WiFlowStdConfig {
window: 0,
..Default::default()
},
WiFlowStdConfig {
keypoints: 0,
..Default::default()
},
WiFlowStdConfig {
tcn_groups: 0,
..Default::default()
},
] {
assert!(cfg.validate().is_err(), "expected rejection: {cfg:?}");
}
}
// 441 is not a multiple of the default tcn_groups (20) in fixed mode.
#[test]
fn rejects_empty_or_indivisible_tcn_channels() {
let empty = WiFlowStdConfig {
tcn_channels: vec![],
..Default::default()
};
assert!(empty.validate().is_err());
let indivisible = WiFlowStdConfig {
tcn_channels: vec![540, 441],
..Default::default()
};
assert!(indivisible.validate().is_err());
}
// `odd_last` sets attention_groups to 1 so that only the even-width decoder
// rule can reject it.
#[test]
fn rejects_bad_conv_channels() {
let empty = WiFlowStdConfig {
conv_channels: vec![],
..Default::default()
};
assert!(empty.validate().is_err());
let zero = WiFlowStdConfig {
conv_channels: vec![8, 0, 64],
..Default::default()
};
assert!(zero.validate().is_err());
let odd_last = WiFlowStdConfig {
conv_channels: vec![8, 16, 33],
attention_groups: 1,
..Default::default()
};
assert!(odd_last.validate().is_err());
}
// 7 does not divide the default last conv width (64); 0 is rejected outright.
#[test]
fn rejects_attention_group_mismatch() {
let cfg = WiFlowStdConfig {
attention_groups: 7, ..Default::default()
};
assert!(cfg.validate().is_err());
let zero = WiFlowStdConfig {
attention_groups: 0,
..Default::default()
};
assert!(zero.validate().is_err());
}
// The upper bound 1.0 is excluded, as are negative values and NaN.
#[test]
fn rejects_out_of_range_dropout() {
for d in [1.0, 1.5, -0.1, f64::NAN] {
let cfg = WiFlowStdConfig {
dropout: d,
..Default::default()
};
assert!(cfg.validate().is_err(), "dropout {d} must be rejected");
}
}
#[test]
fn serde_roundtrip_preserves_config() {
let cfg = WiFlowStdConfig::for_keypoints(17);
let json = serde_json::to_string(&cfg).expect("serialize");
let back: WiFlowStdConfig = serde_json::from_str(&json).expect("deserialize");
assert_eq!(back, cfg);
}
}