#![allow(dead_code)]
/// Activation function (nonlinearity) selected for a `NeuralBlendShape`.
///
/// This is a field-less `Copy` enum, so `Eq` and `Hash` are derived as well. That lets the
/// variants be used as `HashMap`/`HashSet` keys and compared with total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NbsActivation {
    /// Rectified linear unit, `max(x, 0)`.
    Relu,
    /// Hyperbolic tangent, `tanh(x)`.
    Tanh,
    /// Logistic sigmoid, `1 / (1 + e^(-x))`.
    Sigmoid,
}
/// Parameters and settings of a neural blend shape that maps `input_dim` input values to
/// `output_dim` output values.
///
/// All fields are public, so nothing here enforces that the buffer lengths match the
/// dimensions once a value has been constructed.
#[derive(Clone, Debug)]
pub struct NeuralBlendShape {
    /// Number of input values.
    pub input_dim: usize,
    /// Number of output values.
    pub output_dim: usize,
    /// Activation function selected for this shape.
    pub activation: NbsActivation,
    /// Flattened weight buffer; `NeuralBlendShape::new` allocates `input_dim * output_dim`
    /// zeros. NOTE(review): the element ordering (row- vs column-major) is not fixed by
    /// this type; confirm against the code that consumes it.
    pub weights: Vec<f32>,
    /// Bias buffer; `NeuralBlendShape::new` allocates `output_dim` zeros.
    pub bias: Vec<f32>,
    /// On/off switch for this shape.
    pub enabled: bool,
}
impl NeuralBlendShape {
    /// Creates a blend shape mapping `input_dim` inputs to `output_dim` outputs.
    ///
    /// All weights and biases start at `0.0`, the activation is `NbsActivation::Relu`, and
    /// the shape starts enabled. `weights` receives `input_dim * output_dim` entries and
    /// `bias` receives `output_dim` entries.
    ///
    /// # Panics
    ///
    /// Panics if `input_dim * output_dim` overflows `usize`.
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        // Checked explicitly: a plain `*` wraps silently in release builds and would
        // allocate a weight buffer whose length disagrees with the stored dimensions.
        let weight_count = input_dim
            .checked_mul(output_dim)
            .expect("NeuralBlendShape::new: input_dim * output_dim overflows usize");
        NeuralBlendShape {
            input_dim,
            output_dim,
            activation: NbsActivation::Relu,
            weights: vec![0.0; weight_count],
            bias: vec![0.0; output_dim],
            enabled: true,
        }
    }
}
/// Free-function constructor; equivalent to calling `NeuralBlendShape::new(input_dim, output_dim)`.
///
/// Provided so callers can stay within the `nbs_*`/`new_*` free-function API of this module.
pub fn new_neural_blend_shape(input_dim: usize, output_dim: usize) -> NeuralBlendShape {
    let shape = NeuralBlendShape::new(input_dim, output_dim);
    shape
}
pub fn nbs_forward(nbs: &NeuralBlendShape, input: &[f32]) -> Vec<f32> {
let _ = input;
vec![0.0; nbs.output_dim]
}
pub fn nbs_set_activation(nbs: &mut NeuralBlendShape, activation: NbsActivation) {
nbs.activation = activation;
}
pub fn nbs_set_enabled(nbs: &mut NeuralBlendShape, enabled: bool) {
nbs.enabled = enabled;
}
/// Copies `weights` into the front of `nbs.weights`, element by element.
///
/// The existing buffer is never resized. If `weights` is longer, the surplus values are
/// dropped. If it is shorter, the remaining slots keep their previous contents.
pub fn nbs_load_weights(nbs: &mut NeuralBlendShape, weights: &[f32]) {
    // `zip` stops at whichever side runs out first, which gives the truncating/partial semantics.
    for (slot, &value) in nbs.weights.iter_mut().zip(weights) {
        *slot = value;
    }
}
/// Renders a compact JSON object with the shape's `input_dim`, `output_dim` and `enabled`
/// fields, in that order and without whitespace.
///
/// The activation, weights and bias are not included in the output.
pub fn nbs_to_json(nbs: &NeuralBlendShape) -> String {
    format!(
        r#"{{"input_dim":{inp},"output_dim":{out},"enabled":{on}}}"#,
        inp = nbs.input_dim,
        out = nbs.output_dim,
        on = nbs.enabled,
    )
}
#[cfg(test)]
mod tests {
    //! Unit tests for the neural blend shape helpers. Every assertion carries a message so
    //! that a failure explains which invariant broke.
    use super::*;

    #[test]
    fn test_new_dims() {
        let nbs = new_neural_blend_shape(8, 4);
        assert_eq!(nbs.input_dim, 8, "input_dim must match the constructor argument");
        assert_eq!(nbs.output_dim, 4, "output_dim must match the constructor argument");
    }

    #[test]
    fn test_default_enabled() {
        let nbs = new_neural_blend_shape(4, 2);
        assert!(nbs.enabled, "a freshly constructed shape should be enabled");
    }

    #[test]
    fn test_default_activation_is_relu() {
        let nbs = new_neural_blend_shape(4, 2);
        assert_eq!(
            nbs.activation,
            NbsActivation::Relu,
            "the constructor should default to ReLU"
        );
    }

    #[test]
    fn test_forward_output_length() {
        let nbs = new_neural_blend_shape(4, 6);
        let out = nbs_forward(&nbs, &[0.0; 4]);
        assert_eq!(out.len(), 6, "forward must return exactly output_dim values");
    }

    #[test]
    fn test_forward_zero_parameters_yield_zeros() {
        // With all-zero weights/bias and the default ReLU activation, every output is 0.
        let nbs = new_neural_blend_shape(3, 5);
        let out = nbs_forward(&nbs, &[1.0, -2.0, 3.0]);
        assert!(
            out.iter().all(|&v| v == 0.0),
            "zero-initialised ReLU shape must output zeros, got {:?}",
            out
        );
    }

    #[test]
    fn test_forward_disabled_still_runs() {
        let mut nbs = new_neural_blend_shape(4, 3);
        nbs_set_enabled(&mut nbs, false);
        let out = nbs_forward(&nbs, &[1.0; 4]);
        assert_eq!(out.len(), 3, "a disabled shape must still return output_dim values");
    }

    #[test]
    fn test_set_activation() {
        let mut nbs = new_neural_blend_shape(2, 2);
        nbs_set_activation(&mut nbs, NbsActivation::Tanh);
        assert_eq!(
            nbs.activation,
            NbsActivation::Tanh,
            "nbs_set_activation must store the new activation"
        );
    }

    #[test]
    fn test_set_enabled_toggles() {
        let mut nbs = new_neural_blend_shape(2, 2);
        nbs_set_enabled(&mut nbs, false);
        assert!(!nbs.enabled, "nbs_set_enabled(false) must disable the shape");
        nbs_set_enabled(&mut nbs, true);
        assert!(nbs.enabled, "nbs_set_enabled(true) must re-enable the shape");
    }

    #[test]
    fn test_load_weights() {
        let mut nbs = new_neural_blend_shape(2, 2);
        nbs_load_weights(&mut nbs, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(
            nbs.weights,
            vec![1.0, 2.0, 3.0, 4.0],
            "a full-length load must copy every weight"
        );
    }

    #[test]
    fn test_load_weights_partial() {
        let mut nbs = new_neural_blend_shape(4, 4);
        nbs_load_weights(&mut nbs, &[5.0, 6.0]);
        let mut expected = vec![0.0; 16];
        expected[0] = 5.0;
        expected[1] = 6.0;
        assert_eq!(
            nbs.weights, expected,
            "a short load must overwrite only the leading weights"
        );
    }

    #[test]
    fn test_load_weights_ignores_surplus() {
        let mut nbs = new_neural_blend_shape(2, 1);
        nbs_load_weights(&mut nbs, &[7.0, 8.0, 9.0]);
        assert_eq!(
            nbs.weights,
            vec![7.0, 8.0],
            "surplus weights must be dropped without resizing the buffer"
        );
    }

    #[test]
    fn test_to_json_contains_dims() {
        let nbs = new_neural_blend_shape(3, 5);
        let j = nbs_to_json(&nbs);
        assert!(j.contains("\"input_dim\""), "JSON must contain input_dim: {}", j);
        assert!(j.contains("\"output_dim\""), "JSON must contain output_dim: {}", j);
        assert_eq!(
            j, r#"{"input_dim":3,"output_dim":5,"enabled":true}"#,
            "unexpected JSON rendering"
        );
    }

    #[test]
    fn test_weight_count() {
        let nbs = new_neural_blend_shape(3, 4);
        assert_eq!(nbs.weights.len(), 12, "weights must hold input_dim * output_dim entries");
    }

    #[test]
    fn test_bias_count() {
        let nbs = new_neural_blend_shape(3, 4);
        assert_eq!(nbs.bias.len(), 4, "bias must hold output_dim entries");
    }
}