use std::collections::HashMap;
use crate::params::ParamState;
/// A coarse region of the body that parameters can be overridden for and that targets can be
/// tagged with.
///
/// Arms and legs are split into a left and a right region, while hands and feet are each a
/// single region. The `Hash` + `Eq` derives let it key the override map in `RegionParams`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum BodyRegion {
/// The head.
Head,
/// The neck.
Neck,
/// The torso; also the fallback region for target names with no recognized keyword.
Torso,
/// The left arm.
LeftArm,
/// The right arm.
RightArm,
/// The left leg.
LeftLeg,
/// The right leg.
RightLeg,
/// Both hands (no left/right split).
Hands,
/// Both feet (no left/right split).
Feet,
}
impl BodyRegion {
pub fn all() -> &'static [BodyRegion] {
&[
BodyRegion::Head,
BodyRegion::Neck,
BodyRegion::Torso,
BodyRegion::LeftArm,
BodyRegion::RightArm,
BodyRegion::LeftLeg,
BodyRegion::RightLeg,
BodyRegion::Hands,
BodyRegion::Feet,
]
}
pub fn name(&self) -> &'static str {
match self {
BodyRegion::Head => "Head",
BodyRegion::Neck => "Neck",
BodyRegion::Torso => "Torso",
BodyRegion::LeftArm => "Left Arm",
BodyRegion::RightArm => "Right Arm",
BodyRegion::LeftLeg => "Left Leg",
BodyRegion::RightLeg => "Right Leg",
BodyRegion::Hands => "Hands",
BodyRegion::Feet => "Feet",
}
}
}
/// Per-region overrides of the global [`ParamState`].
///
/// A region with no entry in `overrides` uses the global parameters unchanged
/// (see `RegionParams::effective_params`).
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct RegionParams {
/// Override parameters keyed by region; at most one override per region.
pub overrides: HashMap<BodyRegion, ParamState>,
}
/// Linearly interpolates from `a` (at `t = 0`) to `b` (at `t = 1`).
///
/// `t` is clamped to `[0.0, 1.0]` before use, so out-of-range factors pin to the nearest
/// endpoint. A NaN `t` propagates to a NaN result (`f32::clamp` keeps NaN).
fn lerp_param(a: f32, b: f32, t: f32) -> f32 {
    let factor = t.clamp(0.0, 1.0);
    let span = b - a;
    a + span * factor
}
/// Blends `override_state` toward `global` by factor `t`, field by field.
///
/// Every scalar field and every `extra` entry goes through [`lerp_param`], so `t` is clamped to
/// `[0.0, 1.0]` (`0.0` = override values, `1.0` = global values). The resulting `extra` map holds
/// the union of both states' keys; a key missing from one side is treated as `0.0` on that side.
fn blend_param_state(override_state: &ParamState, global: &ParamState, t: f32) -> ParamState {
    // Keys the override defines, moved toward the global value (or 0.0 if global lacks the key).
    let from_override = override_state.extra.iter().map(|(key, &start)| {
        let end = global.extra.get(key).copied().unwrap_or(0.0);
        (key.clone(), lerp_param(start, end, t))
    });
    // Keys only the global state defines, grown from 0.0 toward the global value.
    let from_global_only = global
        .extra
        .iter()
        .filter(|(key, _)| !override_state.extra.contains_key(*key))
        .map(|(key, &end)| (key.clone(), lerp_param(0.0, end, t)));
    let extra: HashMap<String, f32> = from_override.chain(from_global_only).collect();

    ParamState {
        height: lerp_param(override_state.height, global.height, t),
        weight: lerp_param(override_state.weight, global.weight, t),
        muscle: lerp_param(override_state.muscle, global.muscle, t),
        age: lerp_param(override_state.age, global.age, t),
        extra,
    }
}
impl RegionParams {
pub fn new() -> Self {
Self::default()
}
pub fn set_region(&mut self, region: BodyRegion, params: ParamState) {
self.overrides.insert(region, params);
}
pub fn effective_params(&self, region: BodyRegion, global: &ParamState) -> ParamState {
self.overrides
.get(®ion)
.cloned()
.unwrap_or_else(|| global.clone())
}
pub fn clear_region(&mut self, region: BodyRegion) {
self.overrides.remove(®ion);
}
pub fn has_overrides(&self) -> bool {
!self.overrides.is_empty()
}
pub fn blend_toward_global(&self, global: &ParamState, t: f32) -> RegionParams {
let mut result = RegionParams::new();
for (®ion, override_state) in &self.overrides {
let blended = blend_param_state(override_state, global, t);
result.overrides.insert(region, blended);
}
result
}
}
/// Associates a target, identified by name, with the body regions it affects.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RegionTag {
/// Name of the target, e.g. `head/head-age-young.target` (presumably a target file path;
/// confirm against callers).
pub target_name: String,
/// Regions the target affects; `infer_from_name` always produces at least one.
pub regions: Vec<BodyRegion>,
}
impl RegionTag {
    /// Creates a tag for `target_name` with an explicit list of regions.
    pub fn new(target_name: impl Into<String>, regions: Vec<BodyRegion>) -> Self {
        RegionTag {
            target_name: target_name.into(),
            regions,
        }
    }

    /// Guesses which body regions a target affects from keywords in its name.
    ///
    /// Matching is case-insensitive. The name is also split into alphanumeric words (any other
    /// character is a separator), so that ears and left/right limbs are only recognized at word
    /// boundaries rather than inside unrelated words:
    ///
    /// - `Head`: contains `head`, `face`, `eye`, `nose` or `mouth`, or has a word starting with
    ///   `ear` (so `forearm` and `heart` do not count as the head).
    /// - `Neck`: contains `neck`.
    /// - `Torso`: contains `torso`, `chest`, `belly`, `back` or `waist`.
    /// - `LeftArm` / `RightArm` / `LeftLeg` / `RightLeg`: either a word made of a side marker
    ///   (`l`/`left`, `r`/`right`) immediately followed by the limb (`larm`, `rightleg`), or a
    ///   standalone side-marker word whose next word contains the limb (`l-arm`, `left-arm`,
    ///   `r-upperleg`). Words like `lowerarm` or `upperleg` no longer imply the right side.
    /// - `Hands`: contains `hand`.
    /// - `Feet`: contains `foot` or `feet`.
    ///
    /// Several regions may match; they are listed in the order above. A name that matches none
    /// of them is assigned to `Torso`.
    pub fn infer_from_name(target_name: &str) -> Self {
        const LEFT: &[&str] = &["l", "left"];
        const RIGHT: &[&str] = &["r", "right"];

        let lower = target_name.to_lowercase();
        let words: Vec<&str> = lower
            .split(|c: char| !c.is_alphanumeric())
            .filter(|word| !word.is_empty())
            .collect();

        let mut regions = Vec::new();
        if Self::contains_any(&lower, &["head", "face", "eye", "nose", "mouth"])
            || words.iter().any(|word| word.starts_with("ear"))
        {
            regions.push(BodyRegion::Head);
        }
        if lower.contains("neck") {
            regions.push(BodyRegion::Neck);
        }
        if Self::contains_any(&lower, &["torso", "chest", "belly", "back", "waist"]) {
            regions.push(BodyRegion::Torso);
        }
        for (region, markers, limb) in [
            (BodyRegion::LeftArm, LEFT, "arm"),
            (BodyRegion::RightArm, RIGHT, "arm"),
            (BodyRegion::LeftLeg, LEFT, "leg"),
            (BodyRegion::RightLeg, RIGHT, "leg"),
        ] {
            if Self::has_sided_limb(&words, markers, limb) {
                regions.push(region);
            }
        }
        if lower.contains("hand") {
            regions.push(BodyRegion::Hands);
        }
        if Self::contains_any(&lower, &["foot", "feet"]) {
            regions.push(BodyRegion::Feet);
        }
        if regions.is_empty() {
            regions.push(BodyRegion::Torso);
        }
        RegionTag {
            target_name: target_name.to_string(),
            regions,
        }
    }

    /// Returns `true` if `haystack` contains any of `keywords` as a substring.
    fn contains_any(haystack: &str, keywords: &[&str]) -> bool {
        keywords.iter().any(|&keyword| haystack.contains(keyword))
    }

    /// Returns `true` if some word names `limb` on the side given by `markers`.
    ///
    /// A word matches when it is a marker immediately followed by the limb (`larm`), or when it
    /// is exactly a marker and the following word contains the limb (`l` + `lowerarm`).
    fn has_sided_limb(words: &[&str], markers: &[&str], limb: &str) -> bool {
        words.iter().enumerate().any(|(index, word)| {
            markers.iter().any(|&marker| match word.strip_prefix(marker) {
                Some("") => matches!(words.get(index + 1), Some(next) if next.contains(limb)),
                Some(rest) => rest.starts_with(limb),
                None => false,
            })
        })
    }
}
// Unit tests for region enumeration, per-region overrides, blending and name-based tagging.
#[cfg(test)]
mod tests {
use super::*;
// `BodyRegion::all` must list every variant exactly once.
#[test]
fn all_regions_has_nine() {
assert_eq!(BodyRegion::all().len(), 9);
}
// With no overrides, every region resolves to the global parameters.
#[test]
fn effective_params_falls_back_to_global() {
let rp = RegionParams::new();
let global = ParamState::new(0.7, 0.3, 0.6, 0.4);
let effective = rp.effective_params(BodyRegion::Head, &global);
assert_eq!(effective, global);
}
// An override applies only to its own region; other regions still use the global parameters.
#[test]
fn effective_params_uses_override() {
let mut rp = RegionParams::new();
let global = ParamState::new(0.5, 0.5, 0.5, 0.5);
let head_override = ParamState::new(0.9, 0.1, 0.8, 0.2);
rp.set_region(BodyRegion::Head, head_override.clone());
let effective_head = rp.effective_params(BodyRegion::Head, &global);
assert_eq!(effective_head, head_override);
let effective_torso = rp.effective_params(BodyRegion::Torso, &global);
assert_eq!(effective_torso, global);
}
// Clearing an override makes the region fall back to the global parameters again.
#[test]
fn clear_region_reverts_to_global() {
let mut rp = RegionParams::new();
let global = ParamState::new(0.5, 0.5, 0.5, 0.5);
let head_override = ParamState::new(0.9, 0.1, 0.8, 0.2);
rp.set_region(BodyRegion::Head, head_override);
rp.clear_region(BodyRegion::Head);
let effective = rp.effective_params(BodyRegion::Head, &global);
assert_eq!(effective, global);
}
#[test]
fn has_overrides_false_when_empty() {
let rp = RegionParams::new();
assert!(!rp.has_overrides());
}
// Blending fully (t = 1.0) must reproduce the global scalar values, within float tolerance.
#[test]
fn blend_toward_global_at_t1_equals_global() {
let mut rp = RegionParams::new();
let global = ParamState::new(0.5, 0.5, 0.5, 0.5);
let head_override = ParamState::new(0.9, 0.1, 0.8, 0.2);
rp.set_region(BodyRegion::Head, head_override);
let blended = rp.blend_toward_global(&global, 1.0);
let effective = blended.effective_params(BodyRegion::Head, &global);
assert!((effective.height - global.height).abs() < 1e-6);
assert!((effective.weight - global.weight).abs() < 1e-6);
assert!((effective.muscle - global.muscle).abs() < 1e-6);
assert!((effective.age - global.age).abs() < 1e-6);
}
#[test]
fn infer_from_name_head() {
let tag = RegionTag::infer_from_name("head/head-age-young.target");
assert!(tag.regions.contains(&BodyRegion::Head));
}
// Names with no recognized keyword fall back to the torso.
#[test]
fn infer_from_name_default_torso() {
let tag = RegionTag::infer_from_name("other/unknown.target");
assert!(tag.regions.contains(&BodyRegion::Torso));
}
// A JSON round trip keeps the overrides and their values.
#[test]
fn region_params_serialization() {
let mut rp = RegionParams::new();
let global = ParamState::new(0.5, 0.5, 0.5, 0.5);
let head_override = ParamState::new(0.9, 0.1, 0.8, 0.2);
rp.set_region(BodyRegion::Head, head_override);
let json = serde_json::to_string(&rp).expect("serialize");
let deserialized: RegionParams = serde_json::from_str(&json).expect("deserialize");
assert_eq!(rp.has_overrides(), deserialized.has_overrides());
let effective_orig = rp.effective_params(BodyRegion::Head, &global);
let effective_deser = deserialized.effective_params(BodyRegion::Head, &global);
assert!((effective_orig.height - effective_deser.height).abs() < 1e-6);
}
}