Skip to main content

oxihuman_morph/
learned_corrective.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3#![allow(dead_code)]
4
5//! ML-learned corrective shape stub.
6
/// A single learned corrective entry mapping a driver value to a per-vertex delta.
///
/// The fields are stored verbatim; [`lc_evaluate`] is currently a stub and does
/// not read them.
#[derive(Debug, Clone, PartialEq)]
pub struct CorrectiveEntry {
    /// Which driver this entry responds to — presumably an index into the
    /// `drivers` slice given to [`lc_evaluate`] (TODO confirm once evaluation exists).
    pub driver_index: usize,
    /// Driver value associated with this entry (exact meaning not yet defined by
    /// the stub evaluator).
    pub driver_value: f32,
    /// Per-vertex 3-component offsets; presumably `vertex_count` long, but this is
    /// not validated when the entry is added.
    pub delta: Vec<[f32; 3]>,
    /// Scalar weight for this entry; how it is applied is not yet defined.
    pub weight: f32,
}

/// Learned corrective shape system: a list of entries plus the vertex count and
/// an enable flag.
#[derive(Debug, Clone, PartialEq)]
pub struct LearnedCorrective {
    /// Registered corrective entries, in insertion order.
    pub entries: Vec<CorrectiveEntry>,
    /// Number of vertices; [`lc_evaluate`] returns one delta per vertex.
    pub vertex_count: usize,
    /// Whether the system is enabled (set via [`lc_set_enabled`]; the stub
    /// evaluator does not consult it).
    pub enabled: bool,
}
23
24impl LearnedCorrective {
25    pub fn new(vertex_count: usize) -> Self {
26        LearnedCorrective {
27            entries: Vec::new(),
28            vertex_count,
29            enabled: true,
30        }
31    }
32}
33
34/// Create a new learned corrective system.
35pub fn new_learned_corrective(vertex_count: usize) -> LearnedCorrective {
36    LearnedCorrective::new(vertex_count)
37}
38
39/// Add a corrective entry.
40pub fn lc_add_entry(lc: &mut LearnedCorrective, entry: CorrectiveEntry) {
41    lc.entries.push(entry);
42}
43
44/// Evaluate all corrective entries and accumulate deltas (stub: zeroed output).
45pub fn lc_evaluate(lc: &LearnedCorrective, _drivers: &[f32]) -> Vec<[f32; 3]> {
46    /* Stub: returns zeroed delta array */
47    vec![[0.0; 3]; lc.vertex_count]
48}
49
50/// Return entry count.
51pub fn lc_entry_count(lc: &LearnedCorrective) -> usize {
52    lc.entries.len()
53}
54
55/// Enable or disable the corrective system.
56pub fn lc_set_enabled(lc: &mut LearnedCorrective, enabled: bool) {
57    lc.enabled = enabled;
58}
59
60/// Serialize to JSON-like string.
61pub fn lc_to_json(lc: &LearnedCorrective) -> String {
62    format!(
63        r#"{{"vertex_count":{},"entry_count":{},"enabled":{}}}"#,
64        lc.vertex_count,
65        lc.entries.len(),
66        lc.enabled
67    )
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_vertex_count() {
        let lc = new_learned_corrective(100);
        assert_eq!(lc.vertex_count, 100, "vertex count must match");
    }

    #[test]
    fn test_default_no_entries() {
        let lc = new_learned_corrective(10);
        assert_eq!(lc_entry_count(&lc), 0, "initially no entries");
    }

    #[test]
    fn test_add_entry() {
        let mut lc = new_learned_corrective(4);
        let e = CorrectiveEntry {
            driver_index: 0,
            driver_value: 1.0,
            delta: vec![[0.1, 0.0, 0.0]; 4],
            weight: 1.0,
        };
        lc_add_entry(&mut lc, e);
        assert_eq!(lc_entry_count(&lc), 1, "entry count must be 1 after add");
    }

    #[test]
    fn test_evaluate_length() {
        let lc = new_learned_corrective(8);
        let out = lc_evaluate(&lc, &[]);
        assert_eq!(out.len(), 8, "evaluate output length must match vertex count");
    }

    #[test]
    fn test_evaluate_zeroed() {
        let lc = new_learned_corrective(3);
        let out = lc_evaluate(&lc, &[1.0]);
        // Check every component of every vertex, not just the first one.
        assert!(
            out.iter().flatten().all(|c| c.abs() < 1e-6),
            "stub output must be zero"
        );
    }

    #[test]
    fn test_set_enabled_false() {
        let mut lc = new_learned_corrective(4);
        lc_set_enabled(&mut lc, false);
        assert!(!lc.enabled, "enabled flag must be false");
    }

    #[test]
    fn test_set_enabled_round_trip() {
        let mut lc = new_learned_corrective(4);
        lc_set_enabled(&mut lc, false);
        lc_set_enabled(&mut lc, true);
        assert!(lc.enabled, "enabled flag must be restored to true");
    }

    #[test]
    fn test_to_json_contains_vertex_count() {
        let lc = new_learned_corrective(20);
        let j = lc_to_json(&lc);
        assert!(j.contains("\"vertex_count\""), "json must contain vertex_count");
    }

    #[test]
    fn test_to_json_contains_entry_count() {
        let lc = new_learned_corrective(5);
        let j = lc_to_json(&lc);
        assert!(j.contains("\"entry_count\""), "json must contain entry_count");
    }

    #[test]
    fn test_to_json_exact() {
        let mut lc = new_learned_corrective(4);
        assert_eq!(
            lc_to_json(&lc),
            r#"{"vertex_count":4,"entry_count":0,"enabled":true}"#
        );
        lc_add_entry(
            &mut lc,
            CorrectiveEntry {
                driver_index: 0,
                driver_value: 1.0,
                delta: vec![[0.0; 3]; 4],
                weight: 1.0,
            },
        );
        lc_set_enabled(&mut lc, false);
        assert_eq!(
            lc_to_json(&lc),
            r#"{"vertex_count":4,"entry_count":1,"enabled":false}"#
        );
    }

    #[test]
    fn test_multiple_entries() {
        let mut lc = new_learned_corrective(2);
        for i in 0..5 {
            lc_add_entry(
                &mut lc,
                CorrectiveEntry {
                    driver_index: i,
                    driver_value: 0.5,
                    delta: vec![[0.0; 3]; 2],
                    weight: 1.0,
                },
            );
        }
        assert_eq!(lc_entry_count(&lc), 5, "five entries must be stored");
    }

    #[test]
    fn test_enabled_by_default() {
        let lc = new_learned_corrective(1);
        assert!(lc.enabled, "must be enabled by default");
    }
}