Skip to main content

oxihuman_morph/
expression_randomizer.rs

// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan) / SPDX-License-Identifier: Apache-2.0
2#![allow(dead_code)]
3
//! Expression randomizer — deterministic pseudo-random expression weight sampling from a seed.
5
/// Parameters controlling how random expression weight vectors are sampled.
#[derive(Debug, Clone, PartialEq)]
#[allow(dead_code)]
pub struct ExpressionRandomizerConfig {
    /// How many expression channels (weights) each sample contains.
    pub channel_count: usize,
    /// Scale applied to each signed sample; intended range is `0..=1`.
    pub amplitude: f32,
    /// Fraction of channels forced to zero: `0` keeps every channel, `1` zeroes all of them.
    pub sparsity: f32,
}

impl Default for ExpressionRandomizerConfig {
    /// Eight channels, amplitude `0.6`, sparsity `0.3`.
    fn default() -> Self {
        ExpressionRandomizerConfig {
            sparsity: 0.3,
            amplitude: 0.6,
            channel_count: 8,
        }
    }
}
27
/// Lightweight PCG32 (XSH-RR) pseudo-random generator — deterministic, no external deps.
///
/// Advances the 64-bit LCG `state` in place. It then derives a 32-bit output from the *new*
/// state using PCG's "xorshift high, random rotate" permutation.
///
/// The rotation must be applied to the truncated 32-bit value. If the 37-bit intermediate is
/// rotated in 64 bits and truncated afterwards, most rotation amounts leave leading zero bits.
/// That skews every derived sample toward zero.
fn pcg32(state: &mut u64) -> u32 {
    *state = state
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1_442_695_040_888_963_407);
    // XSH: fold high state bits downward, then keep the low 32 bits (truncation intended).
    let xorshifted = (((*state >> 18) ^ *state) >> 27) as u32;
    // RR: the top 5 state bits choose the rotation amount (0..=31).
    let rot = (*state >> 59) as u32;
    xorshifted.rotate_right(rot)
}
37
38fn pcg_f32(state: &mut u64) -> f32 {
39    // Map u32 to [0, 1)
40    (pcg32(state) >> 8) as f32 / 16_777_216.0
41}
42
43/// Sample a randomized expression weight vector.
44#[allow(dead_code)]
45pub fn sample_expression(seed: u64, cfg: &ExpressionRandomizerConfig) -> Vec<f32> {
46    let mut rng = seed ^ 0xDEAD_BEEF_CAFE_1234;
47    (0..cfg.channel_count)
48        .map(|_| {
49            let sparse_val = pcg_f32(&mut rng);
50            if sparse_val < cfg.sparsity {
51                0.0
52            } else {
53                let v = pcg_f32(&mut rng);
54                // Signed: map [0,1) -> [-1,1)
55                let signed = v * 2.0 - 1.0;
56                (signed * cfg.amplitude).clamp(-1.0, 1.0)
57            }
58        })
59        .collect()
60}
61
62/// Sample and return only the non-zero channels as (index, weight) pairs.
63#[allow(dead_code)]
64pub fn sample_sparse_expression(seed: u64, cfg: &ExpressionRandomizerConfig) -> Vec<(usize, f32)> {
65    sample_expression(seed, cfg)
66        .into_iter()
67        .enumerate()
68        .filter(|(_, w)| w.abs() > 1e-6)
69        .collect()
70}
71
/// Linearly interpolate two weight vectors: `a * (1 - t) + b * t`.
///
/// `t` is clamped to `[0, 1]` first. `f32::clamp` passes NaN through, so a NaN `t` yields NaN
/// weights. The result has `min(a.len(), b.len())` entries; any extra trailing weights in the
/// longer slice are ignored.
#[allow(dead_code)]
pub fn blend_sampled(a: &[f32], b: &[f32], t: f32) -> Vec<f32> {
    let weight_b = t.clamp(0.0, 1.0);
    let weight_a = 1.0 - weight_b;
    let mut blended = Vec::with_capacity(a.len().min(b.len()));
    for (&from, &to) in a.iter().zip(b) {
        blended.push(from * weight_a + to * weight_b);
    }
    blended
}
82
/// L1 norm (sum of absolute values) of a weight vector; zero for an empty slice.
#[allow(dead_code)]
pub fn expression_energy(weights: &[f32]) -> f32 {
    weights.iter().copied().map(f32::abs).sum::<f32>()
}
88
/// Index of the channel with the largest absolute weight, or `None` for an empty slice.
///
/// Ties go to the *last* such channel, so an all-zero vector yields its last index.
/// Comparisons involving NaN count as ties.
#[allow(dead_code)]
pub fn dominant_channel(weights: &[f32]) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (index, weight) in weights.iter().enumerate() {
        let magnitude = weight.abs();
        match best {
            // Only a strictly larger incumbent survives; equal/unordered lets the newcomer win.
            Some((_, incumbent)) if incumbent > magnitude => {}
            _ => best = Some((index, magnitude)),
        }
    }
    best.map(|(index, _)| index)
}
102
/// Scale `weights` to unit L2 norm.
///
/// When the L2 norm is below `1e-8`, the input is returned unchanged as a new `Vec`. This
/// covers the all-zero and empty vectors and avoids dividing by a (near-)zero norm.
#[allow(dead_code)]
pub fn normalize_expression(weights: &[f32]) -> Vec<f32> {
    let sum_sq: f32 = weights.iter().map(|w| w * w).sum();
    let norm = sum_sq.sqrt();
    if norm < 1e-8 {
        return weights.to_vec();
    }
    weights.iter().map(|&w| w / norm).collect()
}
113
/// Serialise weights to a compact JSON array string, e.g. `[0.5000,-0.2500]`.
///
/// Finite values are written with 4 decimal places. JSON has no representation for NaN or
/// infinities, so non-finite weights are emitted as `null`, as serde_json does. Plain `{:.4}`
/// would instead write the invalid tokens `NaN` / `inf`.
#[allow(dead_code)]
pub fn expression_to_json(weights: &[f32]) -> String {
    let items: Vec<String> = weights
        .iter()
        .map(|w| {
            if w.is_finite() {
                format!("{w:.4}")
            } else {
                "null".to_string()
            }
        })
        .collect();
    format!("[{}]", items.join(","))
}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the default configuration used by most tests.
    fn cfg() -> ExpressionRandomizerConfig {
        ExpressionRandomizerConfig::default()
    }

    #[test]
    fn sample_length_matches_config() {
        let c = cfg();
        assert_eq!(sample_expression(42, &c).len(), c.channel_count);
        // Zero channels is a valid request and yields an empty vector.
        let empty = ExpressionRandomizerConfig {
            channel_count: 0,
            ..cfg()
        };
        assert!(sample_expression(42, &empty).is_empty());
    }

    #[test]
    fn deterministic_same_seed() {
        let a = sample_expression(99, &cfg());
        let b = sample_expression(99, &cfg());
        assert_eq!(a, b);
    }

    #[test]
    fn different_seeds_differ() {
        let a = sample_expression(1, &cfg());
        let b = sample_expression(2, &cfg());
        assert_ne!(a, b);
    }

    #[test]
    fn weights_in_range() {
        // An amplitude above 1 exercises the final clamp.
        let loud = ExpressionRandomizerConfig {
            amplitude: 5.0,
            ..cfg()
        };
        for c in &[cfg(), loud] {
            for seed in 0..32 {
                let w = sample_expression(seed, c);
                assert!(w.iter().all(|v| (-1.0..=1.0).contains(v)));
            }
        }
    }

    #[test]
    fn full_sparsity_zeroes_everything() {
        let c = ExpressionRandomizerConfig {
            sparsity: 1.0,
            ..cfg()
        };
        assert!(sample_expression(3, &c).iter().all(|&v| v == 0.0));
        assert!(sample_sparse_expression(3, &c).is_empty());
    }

    #[test]
    fn sparse_expression_matches_dense() {
        let c = cfg();
        let full = sample_expression(55, &c);
        let sparse = sample_sparse_expression(55, &c);
        assert!(sparse.len() <= full.len());
        for &(i, w) in &sparse {
            assert_eq!(full[i], w);
            assert!(w.abs() > 1e-6);
        }
        // Every channel missing from the sparse view must be (near) zero in the dense one.
        for (i, w) in full.iter().enumerate() {
            if !sparse.iter().any(|&(j, _)| j == i) {
                assert!(w.abs() <= 1e-6);
            }
        }
    }

    #[test]
    fn blend_midpoint_and_endpoints() {
        let a = vec![0.0f32; 4];
        let b = vec![1.0f32; 4];
        let m = blend_sampled(&a, &b, 0.5);
        assert!(m.iter().all(|v| (v - 0.5).abs() < 1e-6));
        assert_eq!(blend_sampled(&a, &b, 0.0), a);
        assert_eq!(blend_sampled(&a, &b, 1.0), b);
        // Out-of-range t is clamped.
        assert_eq!(blend_sampled(&a, &b, 2.0), b);
        assert_eq!(blend_sampled(&a, &b, -1.0), a);
    }

    #[test]
    fn energy_zero_for_zeros() {
        let z = vec![0.0f32; 6];
        assert!(expression_energy(&z) < 1e-8);
        assert_eq!(expression_energy(&[0.5, -0.25]), 0.75);
    }

    #[test]
    fn dominant_channel_found() {
        let w = [0.1f32, -0.9, 0.3];
        assert_eq!(dominant_channel(&w), Some(1));
        assert_eq!(dominant_channel(&[]), None);
    }

    #[test]
    fn normalize_unit_length() {
        let w = vec![3.0f32, 4.0];
        let n = normalize_expression(&w);
        let l2: f32 = n.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((l2 - 1.0).abs() < 1e-5);
        // Degenerate input is returned unchanged rather than divided by ~0.
        assert_eq!(normalize_expression(&[0.0, 0.0]), vec![0.0, 0.0]);
    }

    #[test]
    fn json_array_format() {
        let w = vec![0.5f32, -0.5];
        assert_eq!(expression_to_json(&w), "[0.5000,-0.5000]");
        assert_eq!(expression_to_json(&[]), "[]");
    }
}