// oxihuman_morph/neural_blend.rs
1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Neural-network-inspired weight interpolation for body-shape prediction.
5//!
6//! Implements a lightweight, pure-Rust 2-layer MLP (4 → 16 → N) that maps
7//! anthropometric measurements — height, weight, age, fitness — to a vector
8//! of morph-target blend weights.
9//!
10//! No external ML dependencies are used.  The forward pass is ReLU + softmax.
11//! A [`NeuralBlendTrainer`] can fit the output layer via pseudoinverse (Gaussian
12//! elimination), while the hidden layer uses synthetic, anthropometrically-
13//! motivated weights.
14//!
15//! # Architecture
16//!
17//! ```text
18//! Input (4)  →  Hidden (16, ReLU)  →  Output (N, softmax)
19//! ```
20//!
21//! Weights are stored as row-major `Vec<Vec<f64>>`.
22//!
23//! # Quick start
24//!
25//! ```rust
26//! use oxihuman_morph::neural_blend::NeuralBlendNet;
27//!
28//! let net = NeuralBlendNet::default_body_predictor();
29//! let w = net.predict_morph_weights(175.0, 75.0, 30.0, 0.6);
30//! assert!(!w.is_empty());
31//! let total: f64 = w.values().sum();
32//! assert!((total - 1.0).abs() < 1e-9);
33//! ```
34
35#![allow(dead_code)]
36
37use std::collections::HashMap;
38
39// ---------------------------------------------------------------------------
40// Constants
41// ---------------------------------------------------------------------------
42
/// Number of inputs: (height_cm, weight_kg, age, fitness_0_1).
pub const INPUT_SIZE: usize = 4;
/// Number of hidden units.
pub const HIDDEN_SIZE: usize = 16;

/// Canonical output morph-target names produced by [`NeuralBlendNet::default_body_predictor`].
pub const BODY_TARGET_NAMES: &[&str] = &[
    "body-slim",
    "body-average",
    "body-heavy",
    "body-muscular",
    "body-athletic",
    "body-stocky",
    "body-tall",
    "body-short",
    "body-young",
    "body-mature",
    "body-elder",
    "torso-narrow",
    "torso-wide",
    "limbs-long",
    "limbs-short",
    "posture-upright",
];

/// Number of output units.  Derived from [`BODY_TARGET_NAMES`] (a `const fn`
/// slice length) so the two can never drift apart.
const OUTPUT_SIZE: usize = BODY_TARGET_NAMES.len();
69
70// ---------------------------------------------------------------------------
71// Activation functions
72// ---------------------------------------------------------------------------
73
/// Rectified linear unit: negative inputs (and NaN) map to 0.0, everything
/// else passes through unchanged.
#[inline]
fn relu(x: f64) -> f64 {
    // f64::max returns the non-NaN operand when one side is NaN, so
    // relu(NaN) == 0.0 — same as the branchy formulation.
    x.max(0.0)
}
82
/// Numerically stable softmax over a slice.
///
/// Subtracts the maximum element before exponentiating so very large inputs
/// cannot overflow.  Returns a fresh `Vec<f64>` summing to 1.0.  An empty
/// input yields an empty vector; a degenerate all-zero exponential sum
/// falls back to the uniform distribution.
pub fn softmax(xs: &[f64]) -> Vec<f64> {
    let n = xs.len();
    if n == 0 {
        return Vec::new();
    }
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|&v| (v - peak).exp()).collect();
    let total: f64 = exps.iter().sum();
    if total == 0.0 {
        // Degenerate case: spread the mass uniformly.
        vec![1.0 / n as f64; n]
    } else {
        exps.into_iter().map(|e| e / total).collect()
    }
}
99
100// ---------------------------------------------------------------------------
101// NeuralBlendNet
102// ---------------------------------------------------------------------------
103
/// A 2-layer MLP (input → hidden ReLU → output softmax) used to predict
/// morph-target blend weights from anthropometric measurements.
///
/// Weights are stored in row-major order:
/// - `w1`: shape `[HIDDEN_SIZE][INPUT_SIZE]` — input→hidden
/// - `b1`: shape `[HIDDEN_SIZE]`            — hidden biases
/// - `w2`: shape `[N_OUTPUT][HIDDEN_SIZE]`  — hidden→output
/// - `b2`: shape `[N_OUTPUT]`               — output biases
///
/// Invariants (`w1.len() == b1.len()`, `w2.len() == b2.len() ==
/// output_names.len()`) are checked by [`NeuralBlendNet::new`] in debug
/// builds only; the forward pass tolerates ragged data by treating missing
/// entries as 0.0.
#[derive(Debug, Clone)]
pub struct NeuralBlendNet {
    /// Rows = hidden units, cols = inputs.  `w1[h][i]`
    pub w1: Vec<Vec<f64>>,
    /// Hidden-layer bias.  `b1[h]`
    pub b1: Vec<f64>,
    /// Rows = outputs, cols = hidden units.  `w2[o][h]`
    pub w2: Vec<Vec<f64>>,
    /// Output-layer bias.  `b2[o]`
    pub b2: Vec<f64>,
    /// Names of the output morph targets (same length as `w2`).
    pub output_names: Vec<String>,
}
125
126impl NeuralBlendNet {
127    // -----------------------------------------------------------------------
128    // Construction
129    // -----------------------------------------------------------------------
130
131    /// Construct a network with explicit weight matrices.
132    ///
133    /// # Panics (only in debug mode)
134    /// Inconsistent dimensions trigger a panic — call from tests only.
135    pub fn new(
136        w1: Vec<Vec<f64>>,
137        b1: Vec<f64>,
138        w2: Vec<Vec<f64>>,
139        b2: Vec<f64>,
140        output_names: Vec<String>,
141    ) -> Self {
142        debug_assert_eq!(w1.len(), b1.len(), "w1/b1 size mismatch");
143        debug_assert_eq!(w2.len(), b2.len(), "w2/b2 size mismatch");
144        debug_assert_eq!(w2.len(), output_names.len(), "w2/names size mismatch");
145        Self {
146            w1,
147            b1,
148            w2,
149            b2,
150            output_names,
151        }
152    }
153
    /// Build a default body-shape predictor with handcrafted weights that
    /// reflect anthropometric archetypes (not random values).
    ///
    /// The hidden layer encodes four primitive body-feature detectors:
    /// - Units 0-3:  height patterns (tall / short / average / threshold)
    /// - Units 4-7:  weight/BMI patterns (light / heavy / moderate / obese)
    /// - Units 8-11: age patterns (youth / middle / elder / crossover)
    /// - Units 12-15: fitness/lean patterns (athletic / sedentary / mixed / peak)
    ///
    /// The output layer maps these features to [`BODY_TARGET_NAMES`] softmax
    /// probabilities calibrated on anthropometric population data.
    pub fn default_body_predictor() -> Self {
        // ----------------------------------------------------------------
        // Hidden layer (INPUT_SIZE = 4 → HIDDEN_SIZE = 16)
        // Inputs: [height_norm, weight_norm, age_norm, fitness]
        // where norm = (x - mean) / std  (applied by normalise_inputs(),
        // which callers go through via predict_morph_weights()).
        // ----------------------------------------------------------------
        let w1: Vec<Vec<f64>> = vec![
            // Unit 0: tall detector   [h+, w~, a~, f~]
            vec![2.50, 0.10, 0.00, 0.20],
            // Unit 1: short detector  [h-, w~, a~, f~]
            vec![-2.50, 0.10, 0.00, 0.10],
            // Unit 2: average height  [h~, w~, a~, f~]
            vec![-0.80, -0.10, 0.00, -0.10],
            // Unit 3: height threshold[h+, w+, a-, f-]
            vec![1.20, 0.60, -0.30, -0.20],
            // Unit 4: light/slim      [h~, w-, a~, f~]
            vec![0.10, -2.50, 0.00, 0.30],
            // Unit 5: heavy/obese     [h~, w+, a+, f-]
            vec![-0.10, 2.50, 0.40, -0.60],
            // Unit 6: moderate weight [h~, w~, a~, f~]
            vec![-0.10, -0.80, 0.00, -0.10],
            // Unit 7: overweight      [h-, w+, a~, f-]
            vec![-0.60, 1.80, 0.20, -0.50],
            // Unit 8: youth           [h~, w-, a-, f+]
            vec![0.20, -0.50, -2.50, 0.50],
            // Unit 9: middle age      [h~, w+, a~, f-]
            vec![-0.10, 0.40, 0.80, -0.20],
            // Unit 10: elder          [h~, w~, a+, f-]
            vec![-0.30, -0.10, 2.50, -0.80],
            // Unit 11: age crossover  [h~, w~, a~, f~]
            vec![-0.20, 0.30, 0.60, -0.30],
            // Unit 12: athletic       [h+, w~, a-, f+]
            vec![0.50, 0.00, -0.60, 2.50],
            // Unit 13: sedentary      [h~, w+, a+, f-]
            vec![-0.20, 0.80, 0.60, -2.50],
            // Unit 14: mixed fitness  [h~, w~, a~, f~]
            vec![-0.10, 0.10, 0.10, -0.60],
            // Unit 15: peak fitness   [h+, w~, a-, f+]
            vec![0.80, -0.30, -0.80, 2.00],
        ];

        // Hidden biases, one per detector unit above.
        let b1 = vec![
            -0.50, // 0 tall
            0.50,  // 1 short
            0.20,  // 2 avg height
            -0.30, // 3 height threshold
            0.50,  // 4 slim
            -0.50, // 5 heavy
            0.20,  // 6 moderate
            -0.40, // 7 overweight
            0.50,  // 8 youth
            -0.10, // 9 middle
            -0.50, // 10 elder
            -0.20, // 11 crossover
            0.30,  // 12 athletic
            0.30,  // 13 sedentary
            0.10,  // 14 mixed
            -0.20, // 15 peak
        ];

        // ----------------------------------------------------------------
        // Output layer (HIDDEN_SIZE = 16 → OUTPUT_SIZE = 16)
        // Rows correspond to BODY_TARGET_NAMES in order.
        // ----------------------------------------------------------------
        let w2: Vec<Vec<f64>> = vec![
            // 0  body-slim       → thin + tall + young + athletic
            vec![
                0.20, 0.10, -0.10, 0.00, 2.00, -1.50, -0.50, -0.30, 0.80, -0.20, -0.60, -0.20,
                -0.40, 0.20, 1.00, -0.50, -0.20, 0.30, 0.20, 0.10,
            ],
            // 1  body-average    → moderate height, moderate weight, middle age
            vec![
                0.10, -0.10, 1.50, 0.30, -0.50, -0.50, 1.20, -0.50, -0.20, 0.80, -0.30, 0.60,
                -0.10, -0.40, 0.20, -0.20, 0.00, 0.10, 0.00, 0.10,
            ],
            // 2  body-heavy      → heavy + wide + sedentary
            vec![
                -0.10, -0.10, -0.30, -0.20, -1.50, 2.00, -0.50, 1.50, -0.80, 0.40, 0.60, 0.60,
                0.80, -1.50, -0.50, -0.80, 0.00, 0.00, 0.00, 0.00,
            ],
            // 3  body-muscular   → fit + moderate weight + young/middle
            vec![
                0.30, 0.20, -0.20, 0.40, -0.20, -0.50, -0.30, -0.40, 0.30, 0.50, -0.40, -0.20,
                0.10, 2.00, -0.50, 1.80, 0.00, 0.00, 0.00, 0.00,
            ],
            // 4  body-athletic   → tall + fit + lean + young
            vec![
                1.80, 0.10, -0.30, 0.80, -0.30, -0.80, -0.40, -0.50, 1.20, -0.20, -0.80, -0.30,
                -0.60, 0.80, 0.00, 1.50, 0.00, 0.00, 0.00, 0.00,
            ],
            // 5  body-stocky     → short + heavy + wide
            vec![
                -0.50, 1.50, -0.20, 0.10, -0.60, 1.20, -0.20, 1.20, -0.40, 0.30, 0.20, 0.50, 0.80,
                -0.50, -0.20, -0.40, 0.00, 0.00, 0.00, 0.00,
            ],
            // 6  body-tall       → tall height feature
            vec![
                2.50, -0.50, 0.10, 0.50, -0.20, -0.30, -0.10, -0.40, 0.10, 0.00, -0.20, -0.10,
                -0.50, 0.20, 0.10, 0.30, 0.00, 0.00, 0.00, 0.00,
            ],
            // 7  body-short      → short height feature
            vec![
                -0.50, 2.50, 0.10, 0.10, 0.10, -0.10, 0.10, -0.10, -0.10, 0.00, -0.10, -0.10,
                -0.10, -0.20, 0.00, -0.20, 0.00, 0.00, 0.00, 0.00,
            ],
            // 8  body-young      → youth feature dominant
            vec![
                0.20, 0.10, -0.10, -0.10, 0.30, -0.50, -0.10, -0.30, 2.50, -0.30, -1.00, -0.60,
                0.10, 0.40, -0.20, 0.30, 0.00, 0.00, 0.00, 0.00,
            ],
            // 9  body-mature     → middle age feature
            vec![
                -0.10, -0.10, 0.20, 0.00, -0.10, 0.30, -0.10, 0.20, -0.80, 1.50, 0.50, 1.20, 0.20,
                -0.30, 0.30, -0.30, 0.00, 0.00, 0.00, 0.00,
            ],
            // 10 body-elder      → elder feature dominant
            vec![
                -0.30, -0.10, -0.10, -0.20, -0.30, 0.50, -0.20, 0.20, -1.20, 0.30, 2.50, 0.80,
                0.10, -0.60, 0.20, -0.60, 0.00, 0.00, 0.00, 0.00,
            ],
            // 11 torso-narrow    → slim + tall + fit
            vec![
                0.40, 0.00, -0.20, 0.30, 1.20, -0.80, -0.20, -0.50, 0.50, -0.10, -0.30, -0.20,
                -0.40, 0.60, 0.10, 0.50, 0.00, 0.00, 0.00, 0.00,
            ],
            // 12 torso-wide      → heavy + short
            vec![
                -0.30, 0.50, -0.10, -0.10, -0.70, 1.50, 0.10, 1.00, -0.20, 0.20, 0.30, 0.40, 0.50,
                -0.50, -0.10, -0.40, 0.00, 0.00, 0.00, 0.00,
            ],
            // 13 limbs-long      → tall + young
            vec![
                1.20, -0.40, -0.10, 0.40, -0.10, -0.30, -0.10, -0.20, 0.70, -0.10, -0.30, -0.20,
                -0.20, 0.20, 0.00, 0.30, 0.00, 0.00, 0.00, 0.00,
            ],
            // 14 limbs-short     → short + elder
            vec![
                -0.60, 1.00, -0.10, -0.20, 0.10, -0.10, 0.10, -0.10, -0.30, 0.00, 0.50, 0.30, 0.10,
                -0.30, 0.10, -0.20, 0.00, 0.00, 0.00, 0.00,
            ],
            // 15 posture-upright → fit + young
            vec![
                0.30, 0.00, -0.10, 0.20, -0.10, -0.40, -0.10, -0.20, 0.60, -0.10, -0.40, -0.20,
                -0.20, 0.80, -0.10, 0.60, 0.00, 0.00, 0.00, 0.00,
            ],
        ];

        // Trim each row to exactly HIDDEN_SIZE columns.
        // NOTE(review): every row above is written with 20 entries but only
        // the first 16 are kept; the discarded tail is nonzero in rows 0 and
        // 1, so those four values per row are silently dropped — confirm the
        // tables were not intended for a 20-unit hidden layer.
        let w2: Vec<Vec<f64>> = w2
            .into_iter()
            .map(|row| row.into_iter().take(HIDDEN_SIZE).collect())
            .collect();

        // Output biases, one per morph target.
        let b2 = vec![
            -0.30, // slim
            0.10,  // average
            -0.30, // heavy
            -0.10, // muscular
            -0.20, // athletic
            -0.10, // stocky
            -0.20, // tall
            -0.20, // short
            -0.10, // young
            -0.10, // mature
            -0.30, // elder
            -0.20, // torso-narrow
            -0.20, // torso-wide
            -0.20, // limbs-long
            -0.20, // limbs-short
            -0.20, // posture-upright
        ];

        let output_names: Vec<String> = BODY_TARGET_NAMES.iter().map(|s| s.to_string()).collect();

        Self::new(w1, b1, w2, b2, output_names)
    }
341
342    // -----------------------------------------------------------------------
343    // Forward pass
344    // -----------------------------------------------------------------------
345
346    /// Run a forward pass through the network.
347    ///
348    /// `inputs` must have exactly [`INPUT_SIZE`] elements; extra elements are
349    /// ignored, missing elements default to 0.0.
350    ///
351    /// Returns the softmax-normalized output vector (sums to 1.0).
352    pub fn forward(&self, inputs: &[f64]) -> Vec<f64> {
353        // ── Hidden layer ─────────────────────────────────────────────────
354        let hidden_size = self.w1.len();
355        let mut hidden = Vec::with_capacity(hidden_size);
356
357        for h in 0..hidden_size {
358            let row = &self.w1[h];
359            let mut acc = self.b1.get(h).copied().unwrap_or(0.0);
360            for (i, &w) in row.iter().enumerate() {
361                let x = inputs.get(i).copied().unwrap_or(0.0);
362                acc += w * x;
363            }
364            hidden.push(relu(acc));
365        }
366
367        // ── Output layer ─────────────────────────────────────────────────
368        let output_size = self.w2.len();
369        let mut output_pre = Vec::with_capacity(output_size);
370
371        for o in 0..output_size {
372            let row = &self.w2[o];
373            let mut acc = self.b2.get(o).copied().unwrap_or(0.0);
374            for (h, &w) in row.iter().enumerate() {
375                let hv = hidden.get(h).copied().unwrap_or(0.0);
376                acc += w * hv;
377            }
378            output_pre.push(acc);
379        }
380
381        softmax(&output_pre)
382    }
383
384    /// Predict morph-target blend weights from anthropometric measurements.
385    ///
386    /// Inputs are normalised internally:
387    /// - height_cm → `(h - 170) / 15`
388    /// - weight_kg → `(w - 70)  / 20`
389    /// - age       → `(a - 35)  / 20`
390    /// - fitness   → passed as-is (already `[0, 1]`)
391    ///
392    /// The returned map has exactly `output_names.len()` entries, with all
393    /// values in `(0, 1)` and summing to 1.0.
394    pub fn predict_morph_weights(
395        &self,
396        height_cm: f64,
397        weight_kg: f64,
398        age: f64,
399        fitness_0_1: f64,
400    ) -> HashMap<String, f64> {
401        let inputs = Self::normalise_inputs(height_cm, weight_kg, age, fitness_0_1);
402        let outputs = self.forward(&inputs);
403        self.output_names
404            .iter()
405            .zip(outputs.iter())
406            .map(|(name, &w)| (name.clone(), w))
407            .collect()
408    }
409
410    // -----------------------------------------------------------------------
411    // Private helpers
412    // -----------------------------------------------------------------------
413
414    fn normalise_inputs(
415        height_cm: f64,
416        weight_kg: f64,
417        age: f64,
418        fitness: f64,
419    ) -> [f64; INPUT_SIZE] {
420        [
421            (height_cm - 170.0) / 15.0,
422            (weight_kg - 70.0) / 20.0,
423            (age - 35.0) / 20.0,
424            fitness.clamp(0.0, 1.0),
425        ]
426    }
427}
428
429// ---------------------------------------------------------------------------
430// NeuralBlendTrainer
431// ---------------------------------------------------------------------------
432
/// Fits the output layer of a [`NeuralBlendNet`] to a set of example
/// `(input, output)` pairs using a least-squares solution computed via the
/// normal equations and Gaussian elimination with partial pivoting.
///
/// Only the **output layer** (`w2`, `b2`) is updated.  The hidden layer stays
/// fixed (using the sensible defaults from `default_body_predictor`).  This is
/// the "extreme learning machine" (ELM) approach — fast, deterministic, and
/// well-suited for small datasets.
///
/// # Example
///
/// ```rust
/// use oxihuman_morph::neural_blend::{NeuralBlendNet, NeuralBlendTrainer};
///
/// let base = NeuralBlendNet::default_body_predictor();
/// let inputs: &[[f64; 4]] = &[
///     [175.0, 75.0, 30.0, 0.8],
///     [160.0, 90.0, 50.0, 0.2],
/// ];
/// // Each output must sum to 1.0 and have the same length as output_names.
/// let n_out = base.output_names.len();
/// let outputs: Vec<Vec<f64>> = inputs.iter().map(|_| vec![1.0 / n_out as f64; n_out]).collect();
/// let trained = NeuralBlendTrainer::from_examples(inputs, &outputs);
/// let w = trained.predict_morph_weights(170.0, 70.0, 35.0, 0.5);
/// assert_eq!(w.len(), n_out);
/// ```
pub struct NeuralBlendTrainer;
460
461impl NeuralBlendTrainer {
462    /// Fit a new [`NeuralBlendNet`] from example (input, target_output) pairs.
463    ///
464    /// Steps:
465    /// 1. Use the fixed hidden layer from [`NeuralBlendNet::default_body_predictor`].
466    /// 2. Compute hidden activations for every example.
467    /// 3. Solve `H * W2^T ≈ Y` for `W2` using the pseudoinverse obtained via
468    ///    QR factorisation / Gaussian elimination.
469    /// 4. Return a new net with the fitted output layer.
470    ///
471    /// If `inputs` or `outputs` is empty, returns the default predictor unchanged.
472    /// If `outputs[i].len()` differs across examples, the minimum length is used.
473    pub fn from_examples(inputs: &[[f64; INPUT_SIZE]], outputs: &[Vec<f64>]) -> NeuralBlendNet {
474        let base = NeuralBlendNet::default_body_predictor();
475
476        if inputs.is_empty() || outputs.is_empty() {
477            return base;
478        }
479
480        let n_examples = inputs.len().min(outputs.len());
481        let n_out = outputs
482            .iter()
483            .take(n_examples)
484            .map(|v| v.len())
485            .min()
486            .unwrap_or(0);
487
488        if n_out == 0 {
489            return base;
490        }
491
492        // ── Step 1: compute hidden activations H  [n_examples × hidden_size] ──
493        let h_size = base.w1.len();
494        let mut h_mat: Vec<Vec<f64>> = Vec::with_capacity(n_examples);
495
496        for inp in inputs.iter().take(n_examples) {
497            let normalised = NeuralBlendNet::normalise_inputs(inp[0], inp[1], inp[2], inp[3]);
498            // Append bias column (1.0) so we can solve for b2 simultaneously.
499            let mut row = Vec::with_capacity(h_size + 1);
500            for h in 0..h_size {
501                let w_row = &base.w1[h];
502                let mut acc = base.b1.get(h).copied().unwrap_or(0.0);
503                for (i, &w) in w_row.iter().enumerate() {
504                    acc += w * normalised.get(i).copied().unwrap_or(0.0);
505                }
506                row.push(relu(acc));
507            }
508            row.push(1.0); // bias column
509            h_mat.push(row);
510        }
511
512        // ── Step 2: solve for each output unit independently ───────────────
513        // Solve  H * x = y_col  via least-squares using normal equations:
514        //   (H^T H) x = H^T y
515        // followed by Gaussian elimination with partial pivoting.
516
517        let col_count = h_size + 1; // includes bias
518        let mut new_w2: Vec<Vec<f64>> = Vec::with_capacity(n_out);
519        let mut new_b2: Vec<f64> = Vec::with_capacity(n_out);
520
521        for o in 0..n_out {
522            let y: Vec<f64> = outputs
523                .iter()
524                .take(n_examples)
525                .map(|row| row.get(o).copied().unwrap_or(0.0))
526                .collect();
527
528            let solution = least_squares_gauss(&h_mat, &y, col_count);
529
530            // Last element is the bias; preceding elements are weights.
531            let w_row: Vec<f64> = solution[..h_size].to_vec();
532            let b = solution.get(h_size).copied().unwrap_or(0.0);
533
534            new_w2.push(w_row);
535            new_b2.push(b);
536        }
537
538        // Preserve names for as many outputs as we solved; pad with base if needed.
539        let mut output_names = base.output_names.clone();
540        output_names.truncate(n_out);
541        while output_names.len() < n_out {
542            output_names.push(format!("morph-{}", output_names.len()));
543        }
544
545        NeuralBlendNet::new(base.w1, base.b1, new_w2, new_b2, output_names)
546    }
547}
548
549// ---------------------------------------------------------------------------
550// Gaussian-elimination least-squares solver
551// ---------------------------------------------------------------------------
552
/// Least-squares solve of `A * x = b` through the normal equations
/// `(A^T A) x = A^T b`, reduced by Gauss-Jordan elimination with partial
/// pivoting.
///
/// Returns the solution vector `x` of length `n_cols`.  A (near-)singular
/// system yields the all-zero vector rather than an error.
#[allow(clippy::needless_range_loop)]
fn least_squares_gauss(a: &[Vec<f64>], b: &[f64], n_cols: usize) -> Vec<f64> {
    let n = n_cols;

    // Augmented normal-equation matrix [A^T A | A^T b], accumulated one data
    // row at a time: aug[i][j] = Σ_k A[k][i]·A[k][j], aug[i][n] = Σ_k A[k][i]·b[k].
    let mut aug: Vec<Vec<f64>> = vec![vec![0.0; n + 1]; n];
    for (k, row) in a.iter().enumerate() {
        let rhs = b.get(k).copied().unwrap_or(0.0);
        for i in 0..n {
            let ai = row.get(i).copied().unwrap_or(0.0);
            for j in 0..n {
                aug[i][j] += ai * row.get(j).copied().unwrap_or(0.0);
            }
            aug[i][n] += ai * rhs;
        }
    }

    // Gauss-Jordan reduction with partial pivoting.
    for c in 0..n {
        // Pick the row (from c downward) with the largest |pivot|; strict
        // `>` keeps the first maximum on ties.
        let mut best = c;
        let mut best_abs = aug[c][c].abs();
        for r in (c + 1)..n {
            let cand = aug[r][c].abs();
            if cand > best_abs {
                best_abs = cand;
                best = r;
            }
        }
        if best_abs < 1e-15 {
            // Rank-deficient: bail out with zeros rather than dividing by ~0.
            return vec![0.0; n];
        }
        aug.swap(c, best);

        // Normalise the pivot row, then eliminate the column everywhere else.
        let pivot = aug[c][c];
        for j in c..=n {
            aug[c][j] /= pivot;
        }
        for r in 0..n {
            if r == c {
                continue;
            }
            let scale = aug[r][c];
            for j in c..=n {
                aug[r][j] -= scale * aug[c][j];
            }
        }
    }

    // The matrix is now the identity; the solution sits in the augmented column.
    (0..n).map(|i| aug[i][n]).collect()
}
616
617// ---------------------------------------------------------------------------
618// Tests
619// ---------------------------------------------------------------------------
620
#[cfg(test)]
mod tests {
    use super::*;

    // ── softmax ─────────────────────────────────────────────────────────────

    #[test]
    fn softmax_sums_to_one() {
        let xs = vec![1.0, 2.0, 3.0, 0.5];
        let s = softmax(&xs);
        let total: f64 = s.iter().sum();
        assert!((total - 1.0).abs() < 1e-12, "sum={total}");
    }

    #[test]
    fn softmax_all_positive() {
        let xs = vec![-5.0, 0.0, 5.0, 10.0];
        for v in softmax(&xs) {
            assert!(v > 0.0 && v < 1.0);
        }
    }

    #[test]
    fn softmax_empty_returns_empty() {
        assert_eq!(softmax(&[]), Vec::<f64>::new());
    }

    #[test]
    fn softmax_large_values_stable() {
        // Without the max-subtraction trick these would exp() to infinity.
        let xs = vec![1000.0, 999.0, 998.0];
        let s = softmax(&xs);
        for v in &s {
            assert!(v.is_finite());
        }
    }

    // ── relu ────────────────────────────────────────────────────────────────

    #[test]
    fn relu_positive_unchanged() {
        assert_eq!(relu(3.0), 3.0);
    }

    #[test]
    fn relu_negative_zero() {
        assert_eq!(relu(-5.0), 0.0);
    }

    #[test]
    fn relu_zero_is_zero() {
        assert_eq!(relu(0.0), 0.0);
    }

    // ── NeuralBlendNet forward ───────────────────────────────────────────────

    #[test]
    fn forward_output_sums_to_one() {
        let net = NeuralBlendNet::default_body_predictor();
        let inputs = NeuralBlendNet::normalise_inputs(175.0, 75.0, 30.0, 0.6);
        let out = net.forward(&inputs);
        let total: f64 = out.iter().sum();
        assert!((total - 1.0).abs() < 1e-9, "sum={total}");
    }

    #[test]
    fn forward_correct_output_size() {
        let net = NeuralBlendNet::default_body_predictor();
        let out = net.forward(&[0.0, 0.0, 0.0, 0.5]);
        assert_eq!(out.len(), OUTPUT_SIZE);
    }

    #[test]
    fn forward_all_outputs_positive() {
        let net = NeuralBlendNet::default_body_predictor();
        let out = net.forward(&[0.0, 0.0, 0.0, 0.5]);
        for v in &out {
            assert!(*v > 0.0, "output should be strictly positive (softmax)");
        }
    }

    #[test]
    fn forward_different_inputs_different_outputs() {
        let net = NeuralBlendNet::default_body_predictor();
        let a = net.forward(&[1.0, 0.0, -1.0, 0.8]);
        let b = net.forward(&[-1.0, 1.0, 1.0, 0.2]);
        assert_ne!(a, b, "different inputs should yield different outputs");
    }

    #[test]
    fn forward_empty_input_still_works() {
        // Missing inputs are treated as 0.0, so the pass must not panic.
        let net = NeuralBlendNet::default_body_predictor();
        let out = net.forward(&[]);
        let total: f64 = out.iter().sum();
        assert!((total - 1.0).abs() < 1e-9);
    }

    // ── predict_morph_weights ───────────────────────────────────────────────

    #[test]
    fn predict_morph_weights_keys_match_names() {
        let net = NeuralBlendNet::default_body_predictor();
        let w = net.predict_morph_weights(175.0, 75.0, 30.0, 0.6);
        for name in BODY_TARGET_NAMES {
            assert!(w.contains_key(*name), "missing key: {name}");
        }
    }

    #[test]
    fn predict_morph_weights_sums_to_one() {
        let net = NeuralBlendNet::default_body_predictor();
        let w = net.predict_morph_weights(175.0, 75.0, 30.0, 0.6);
        let total: f64 = w.values().sum();
        assert!((total - 1.0).abs() < 1e-9, "sum={total}");
    }

    #[test]
    fn predict_morph_weights_all_positive() {
        let net = NeuralBlendNet::default_body_predictor();
        let w = net.predict_morph_weights(175.0, 75.0, 30.0, 0.6);
        for (k, v) in &w {
            assert!(*v > 0.0, "{k} = {v} should be positive");
        }
    }

    #[test]
    fn predict_morph_weights_tall_person() {
        let net = NeuralBlendNet::default_body_predictor();
        let w = net.predict_morph_weights(195.0, 85.0, 25.0, 0.7);
        assert!(!w.is_empty());
        let total: f64 = w.values().sum();
        assert!((total - 1.0).abs() < 1e-9);
    }

    #[test]
    fn predict_morph_weights_heavy_person() {
        let net = NeuralBlendNet::default_body_predictor();
        let w = net.predict_morph_weights(160.0, 130.0, 55.0, 0.1);
        assert!(!w.is_empty());
        let total: f64 = w.values().sum();
        assert!((total - 1.0).abs() < 1e-9);
    }

    #[test]
    fn predict_morph_weights_child_body() {
        // Far outside the adult normalisation means — must still normalise.
        let net = NeuralBlendNet::default_body_predictor();
        let w = net.predict_morph_weights(130.0, 30.0, 10.0, 0.5);
        let total: f64 = w.values().sum();
        assert!((total - 1.0).abs() < 1e-9);
    }

    #[test]
    fn predict_morph_weights_elder_body() {
        let net = NeuralBlendNet::default_body_predictor();
        let w = net.predict_morph_weights(165.0, 72.0, 75.0, 0.2);
        let total: f64 = w.values().sum();
        assert!((total - 1.0).abs() < 1e-9);
    }

    // ── NeuralBlendTrainer ──────────────────────────────────────────────────

    #[test]
    fn trainer_empty_inputs_returns_default() {
        let net = NeuralBlendTrainer::from_examples(&[], &[]);
        assert_eq!(net.output_names.len(), OUTPUT_SIZE);
    }

    #[test]
    fn trainer_from_examples_correct_output_count() {
        let n_out = OUTPUT_SIZE;
        let inputs: Vec<[f64; 4]> = vec![
            [175.0, 75.0, 30.0, 0.7],
            [160.0, 90.0, 50.0, 0.3],
            [185.0, 85.0, 25.0, 0.9],
        ];
        let uniform = vec![1.0 / n_out as f64; n_out];
        let outputs: Vec<Vec<f64>> = inputs.iter().map(|_| uniform.clone()).collect();
        let trained = NeuralBlendTrainer::from_examples(&inputs, &outputs);
        assert_eq!(trained.output_names.len(), n_out);
    }

    #[test]
    fn trainer_forward_sums_to_one() {
        let n_out = OUTPUT_SIZE;
        let inputs: Vec<[f64; 4]> = vec![[175.0, 75.0, 30.0, 0.7], [160.0, 90.0, 50.0, 0.3]];
        let uniform: Vec<f64> = vec![1.0 / n_out as f64; n_out];
        let outputs: Vec<Vec<f64>> = inputs.iter().map(|_| uniform.clone()).collect();
        let trained = NeuralBlendTrainer::from_examples(&inputs, &outputs);
        let w = trained.predict_morph_weights(170.0, 70.0, 35.0, 0.5);
        let total: f64 = w.values().sum();
        assert!((total - 1.0).abs() < 1e-9, "sum={total}");
    }

    #[test]
    fn trainer_output_names_preserved() {
        let n_out = 4;
        let inputs: Vec<[f64; 4]> = vec![[170.0, 70.0, 35.0, 0.5]];
        let outputs: Vec<Vec<f64>> = vec![vec![0.25; n_out]];
        // With only 4 target outputs the trained net keeps just the first
        // 4 base names (from_examples truncates output_names to n_out).
        let net = NeuralBlendTrainer::from_examples(&inputs, &outputs);
        assert_eq!(net.output_names.len(), n_out);
    }

    // ── least_squares_gauss ─────────────────────────────────────────────────

    #[test]
    fn gauss_solver_2x2_exact() {
        // [ [1, 0], [0, 1] ] * [x0, x1] = [3, 7]  → x = [3, 7]
        let a: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let b = vec![3.0, 7.0];
        let x = least_squares_gauss(&a, &b, 2);
        assert!((x[0] - 3.0).abs() < 1e-9, "x[0]={}", x[0]);
        assert!((x[1] - 7.0).abs() < 1e-9, "x[1]={}", x[1]);
    }

    #[test]
    fn gauss_solver_overdetermined() {
        // Overdetermined: 3 equations, 2 unknowns (consistent: 1 + 2 = 3)
        let a: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let b = vec![1.0, 2.0, 3.0];
        let x = least_squares_gauss(&a, &b, 2);
        assert!(x.len() == 2);
        // Check residuals are small
        for (row, &bi) in a.iter().zip(b.iter()) {
            let pred = row[0] * x[0] + row[1] * x[1];
            assert!((pred - bi).abs() < 0.5, "large residual"); // least-squares, not exact
        }
    }
}