use crate::params::FlameParams;
/// Incrementally assembles a [`FlameParams`] value.
///
/// Every field starts out at its `Default` value (empty vectors and a zero
/// translation). The setter methods consume the builder and return it again,
/// so calls can be chained and finished with `build()`.
#[derive(Clone, Debug, Default)]
pub struct FlameParamsBuilder {
    /// Shape coefficients, handed to `FlameParams::shape` unchanged.
    shape: Vec<f32>,
    /// Expression coefficients, handed to `FlameParams::expression` unchanged.
    expression: Vec<f32>,
    /// Flat pose vector. The per-joint setters address it as consecutive
    /// groups of three values; it stays empty until a setter touches it.
    pose: Vec<f32>,
    /// Translation, handed to `FlameParams::translation` unchanged.
    translation: [f32; 3],
}
impl FlameParamsBuilder {
    /// Index of the first root-rotation component in the flat pose vector.
    const ROOT_OFFSET: usize = 0;
    /// Index of the first neck-rotation component in the flat pose vector.
    const NECK_OFFSET: usize = 3;
    /// Index of the first jaw-rotation component in the flat pose vector.
    const JAW_OFFSET: usize = 6;
    /// Index of the first left-eye-rotation component in the flat pose vector.
    const LEFT_EYE_OFFSET: usize = 9;
    /// Index of the first right-eye-rotation component in the flat pose vector.
    const RIGHT_EYE_OFFSET: usize = 12;

    /// Replaces the shape coefficients with `shape`.
    #[must_use]
    pub fn shape(self, shape: Vec<f32>) -> Self {
        Self { shape, ..self }
    }

    /// Replaces the expression coefficients with `expression`.
    #[must_use]
    pub fn expression(self, expression: Vec<f32>) -> Self {
        Self { expression, ..self }
    }

    /// Replaces the whole pose vector with `pose`.
    ///
    /// The vector is stored as given; its length is not checked or adjusted
    /// here. Later per-joint setters pad it with zeros if it is too short.
    #[must_use]
    pub fn pose(self, pose: Vec<f32>) -> Self {
        Self { pose, ..self }
    }

    /// Writes `rotation` into pose components `0..3` (root).
    #[must_use]
    pub fn root_rotation(mut self, rotation: [f32; 3]) -> Self {
        self.write_rotation(Self::ROOT_OFFSET, rotation);
        self
    }

    /// Writes `rotation` into pose components `3..6` (neck).
    #[must_use]
    pub fn neck_rotation(mut self, rotation: [f32; 3]) -> Self {
        self.write_rotation(Self::NECK_OFFSET, rotation);
        self
    }

    /// Sets only the first jaw component (pose index 6) to `angle`.
    ///
    /// The other two jaw components (indices 7 and 8) keep whatever value
    /// they already had; on a fresh builder that is zero.
    #[must_use]
    pub fn jaw_rotation(mut self, angle: f32) -> Self {
        self.ensure_pose_size();
        self.pose[Self::JAW_OFFSET] = angle;
        self
    }

    /// Writes `rotation` into pose components `6..9` (jaw), all three values.
    #[must_use]
    pub fn jaw_rotation_full(mut self, rotation: [f32; 3]) -> Self {
        self.write_rotation(Self::JAW_OFFSET, rotation);
        self
    }

    /// Writes `rotation` into pose components `9..12` (left eye).
    #[must_use]
    pub fn left_eye_rotation(mut self, rotation: [f32; 3]) -> Self {
        self.write_rotation(Self::LEFT_EYE_OFFSET, rotation);
        self
    }

    /// Writes `rotation` into pose components `12..15` (right eye).
    #[must_use]
    pub fn right_eye_rotation(mut self, rotation: [f32; 3]) -> Self {
        self.write_rotation(Self::RIGHT_EYE_OFFSET, rotation);
        self
    }

    /// Replaces the translation with `translation`.
    #[must_use]
    pub fn translation(self, translation: [f32; 3]) -> Self {
        Self { translation, ..self }
    }

    /// Consumes the builder and produces the final [`FlameParams`].
    ///
    /// If no pose was ever set, the pose is filled with
    /// `FlameParams::NUM_JOINTS * 3` zeros. A non-empty pose is passed
    /// through exactly as stored, whatever its length.
    #[must_use]
    pub fn build(self) -> FlameParams {
        let Self {
            shape,
            expression,
            mut pose,
            translation,
        } = self;

        // Only a completely untouched pose gets the zero default.
        if pose.is_empty() {
            pose = vec![0.0; FlameParams::NUM_JOINTS * 3];
        }

        FlameParams {
            shape,
            expression,
            pose,
            translation,
        }
    }

    /// Pads the pose vector with zeros up to `FlameParams::NUM_JOINTS * 3`
    /// entries. A pose that is already that long (or longer) is left alone.
    fn ensure_pose_size(&mut self) {
        let required = FlameParams::NUM_JOINTS * 3;
        if self.pose.len() >= required {
            return;
        }
        self.pose.resize(required, 0.0);
    }

    /// Pads the pose if needed, then copies `rotation` into the three pose
    /// slots starting at `offset`.
    ///
    /// Panics if the padded pose is still too short to hold
    /// `offset..offset + 3`.
    fn write_rotation(&mut self, offset: usize, rotation: [f32; 3]) {
        self.ensure_pose_size();
        let end = offset + rotation.len();
        self.pose[offset..end].copy_from_slice(&rotation);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // A builder with nothing set yields empty coefficients, a 15-entry pose
    // and a zero translation.
    #[test]
    fn test_builder_empty() {
        let params = FlameParams::builder().build();

        assert!(params.shape.is_empty());
        assert!(params.expression.is_empty());
        assert_eq!(params.pose.len(), 15);

        let [tx, ty, tz] = params.translation;
        assert_relative_eq!(tx, 0.0);
        assert_relative_eq!(ty, 0.0);
        assert_relative_eq!(tz, 0.0);
    }

    // Shape coefficients are stored exactly as given.
    #[test]
    fn test_builder_with_shape() {
        let builder = FlameParams::builder().shape(vec![0.5, -0.3]);
        let params = builder.build();

        assert_eq!(params.shape, vec![0.5, -0.3]);
    }

    // The single-angle jaw setter writes pose[6] and leaves pose[7..9] at zero.
    #[test]
    fn test_builder_jaw_rotation() {
        let params = FlameParams::builder().jaw_rotation(0.15).build();

        let jaw = &params.pose[6..9];
        assert!((jaw[0] - 0.15).abs() < 1e-6);
        assert!((jaw[1]).abs() < 1e-6);
        assert!((jaw[2]).abs() < 1e-6);
    }

    // Chaining several setters keeps every value that was set.
    #[test]
    fn test_builder_full() {
        let builder = FlameParams::builder()
            .shape(vec![0.1, 0.2])
            .expression(vec![0.5])
            .root_rotation([0.1, 0.0, 0.0])
            .jaw_rotation(0.2)
            .translation([0.0, 0.1, 0.0]);
        let params = builder.build();

        assert_eq!(params.shape, vec![0.1, 0.2]);
        assert_eq!(params.expression, vec![0.5]);

        let root_first = params.pose[0];
        let jaw_first = params.pose[6];
        assert!((root_first - 0.1).abs() < 1e-6);
        assert!((jaw_first - 0.2).abs() < 1e-6);

        let [tx, ty, tz] = params.translation;
        assert_relative_eq!(tx, 0.0);
        assert_relative_eq!(ty, 0.1);
        assert_relative_eq!(tz, 0.0);
    }

    // These parameter values are expected to pass validation.
    #[test]
    fn test_validate_success() {
        let builder = FlameParams::builder()
            .shape(vec![1.0, -1.5, 0.5])
            .expression(vec![0.8, -0.5])
            .jaw_rotation(0.15);
        let params = builder.build();

        assert!(params.validate());
    }

    // A shape value of 5.0 is expected to be rejected by validation.
    #[test]
    fn test_validate_shape_out_of_range() {
        let params = FlameParams::builder().shape(vec![5.0]).build();

        assert!(!params.validate());
    }
}