use std::path::Path;
use nalgebra as na;
use ndarray::{s, Array2, Array3};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use crate::error::FlameError;
use crate::mesh::Mesh;
use crate::params::FlameParams;
/// Vertex positions and normals for a whole batch of FLAME evaluations.
///
/// Every item in the batch shares the same triangle list (`faces`);
/// `vertices[i]` and `normals[i]` hold the per-vertex data of item `i`.
#[derive(Debug, Clone)]
pub struct BatchedFlameOutput {
    /// Per-item vertex positions.
    pub vertices: Vec<Vec<na::Point3<f32>>>,
    /// Per-item vertex normals, parallel to `vertices`.
    pub normals: Vec<Vec<na::Vector3<f32>>>,
    /// Triangle vertex indices shared by every item.
    pub faces: Vec<[u32; 3]>,
    /// Number of items this output was created for.
    pub batch_size: usize,
}

impl BatchedFlameOutput {
    /// Creates an output with `batch_size` items, each pre-filled with
    /// `num_vertices` origin points and zero normals so the batch `forward_*`
    /// methods can write into them in place.
    ///
    /// Despite the name, the buffers are fully allocated and initialised, not
    /// merely reserved.
    #[must_use]
    pub fn with_capacity(batch_size: usize, num_vertices: usize, faces: Vec<[u32; 3]>) -> Self {
        Self {
            vertices: vec![vec![na::Point3::origin(); num_vertices]; batch_size],
            normals: vec![vec![na::Vector3::zeros(); num_vertices]; batch_size],
            faces,
            batch_size,
        }
    }

    /// Returns an owned [`Mesh`] for item `index`, cloning its vertices and
    /// normals and the shared face list.
    ///
    /// Returns `None` when `index >= batch_size`. It also returns `None` when the
    /// public `vertices`/`normals` vectors no longer contain `index`, rather than
    /// panicking on an out-of-bounds access.
    #[must_use]
    pub fn get_mesh(&self, index: usize) -> Option<Mesh> {
        if index >= self.batch_size {
            return None;
        }
        let vertices = self.vertices.get(index)?.clone();
        let normals = self.normals.get(index)?.clone();
        Some(Mesh {
            vertices,
            normals,
            faces: self.faces.clone(),
        })
    }

    /// Consumes the batch and returns one [`Mesh`] per item, cloning the face
    /// list into each mesh.
    ///
    /// Items are paired positionally, so the result has
    /// `min(vertices.len(), normals.len())` meshes.
    #[must_use]
    pub fn into_meshes(self) -> Vec<Mesh> {
        let faces = self.faces;
        self.vertices
            .into_iter()
            .zip(self.normals)
            .map(|(vertices, normals)| Mesh {
                vertices,
                normals,
                faces: faces.clone(),
            })
            .collect()
    }

    /// Vertex count of the first item, or `0` for an empty batch.
    #[must_use]
    pub fn num_vertices(&self) -> usize {
        self.vertices.first().map_or(0, Vec::len)
    }
}
/// Reusable per-item scratch buffers for pooled batch evaluation.
///
/// Slot `i` of every vector belongs to batch item `i`. The pool only grows:
/// `ensure_capacity` appends freshly initialised slots and never shrinks.
#[derive(Debug, Clone)]
pub struct BatchBufferPool {
    /// Shape/expression-blended vertices, one `(num_vertices, 3)` array per slot.
    v_shaped: Vec<Array2<f32>>,
    /// Pose-corrected vertices, one `(num_vertices, 3)` array per slot.
    v_posed: Vec<Array2<f32>>,
    /// Per-joint rotation matrices, `n_joints` per slot.
    rot_mats: Vec<Vec<na::Matrix3<f32>>>,
    /// Per-joint skinning transforms, `n_joints` per slot.
    skinning: Vec<Vec<na::Matrix4<f32>>>,
    /// Row count used when allocating new vertex arrays.
    num_vertices: usize,
    /// Matrix count used when allocating new joint buffers.
    n_joints: usize,
    /// Number of slots currently allocated.
    batch_capacity: usize,
}

impl BatchBufferPool {
    /// Builds a pool with `batch_size` slots: zeroed vertex arrays and identity
    /// joint matrices.
    #[must_use]
    pub fn new(batch_size: usize, num_vertices: usize, n_joints: usize) -> Self {
        let mut pool = Self {
            v_shaped: Vec::with_capacity(batch_size),
            v_posed: Vec::with_capacity(batch_size),
            rot_mats: Vec::with_capacity(batch_size),
            skinning: Vec::with_capacity(batch_size),
            num_vertices,
            n_joints,
            batch_capacity: 0,
        };
        pool.ensure_capacity(batch_size);
        pool
    }

    /// Appends freshly initialised slots until at least `batch_size` exist;
    /// a no-op when the pool is already large enough.
    pub fn ensure_capacity(&mut self, batch_size: usize) {
        if self.batch_capacity >= batch_size {
            return;
        }
        let missing = batch_size - self.batch_capacity;
        // Copy the dimensions out first so the closures below don't borrow `self`.
        let (rows, joints) = (self.num_vertices, self.n_joints);
        self.v_shaped
            .extend((0..missing).map(|_| Array2::zeros((rows, 3))));
        self.v_posed
            .extend((0..missing).map(|_| Array2::zeros((rows, 3))));
        self.rot_mats
            .extend((0..missing).map(|_| vec![na::Matrix3::identity(); joints]));
        self.skinning
            .extend((0..missing).map(|_| vec![na::Matrix4::identity(); joints]));
        self.batch_capacity = batch_size;
    }

    /// Number of slots currently allocated.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.batch_capacity
    }

    /// Resets every slot to its initial state: zeroed vertex arrays and
    /// identity matrices. Allocation sizes are left unchanged.
    pub fn clear(&mut self) {
        self.v_shaped
            .iter_mut()
            .chain(self.v_posed.iter_mut())
            .for_each(|buf| buf.fill(0.0));
        self.rot_mats
            .iter_mut()
            .flatten()
            .for_each(|m| *m = na::Matrix3::identity());
        self.skinning
            .iter_mut()
            .flatten()
            .for_each(|m| *m = na::Matrix4::identity());
    }
}
pub struct FlameModel {
pub v_template: Array2<f32>,
pub faces: Vec<[u32; 3]>,
pub shapedirs: Array3<f32>,
pub expressiondirs: Array3<f32>,
pub posedirs: Array3<f32>,
pub j_regressor: Array2<f32>,
pub parents: Vec<i32>,
pub lbs_weights: Array2<f32>,
pub n_joints: usize,
}
impl FlameModel {
pub fn load(dir: impl AsRef<Path>) -> Result<Self, FlameError> {
crate::io::load_flame_model(dir.as_ref())
}
#[must_use]
pub fn num_vertices(&self) -> usize {
self.v_template.nrows()
}
#[must_use]
pub fn forward(&self, params: &FlameParams) -> Mesh {
let v_shaped = self.apply_shape_expression(params);
let joints = self.j_regressor.dot(&v_shaped);
let rot_mats = self.compute_rotation_matrices(params);
let v_posed = self.apply_pose_blend_shapes(&v_shaped, &rot_mats);
let skinning = self.compute_skinning_transforms(&rot_mats, &joints);
let vertices = self.apply_lbs(&v_posed, &skinning, params);
Mesh::new(vertices, self.faces.clone())
}
#[cfg(all(feature = "simd", nightly))]
#[must_use]
pub fn forward_simd(&self, params: &FlameParams) -> Mesh {
use crate::simd::apply_lbs_simd;
let v_shaped = self.apply_shape_expression_simd(params);
let joints = self.j_regressor.dot(&v_shaped);
let rot_mats = self.compute_rotation_matrices_simd(params);
let v_posed = self.apply_pose_blend_shapes_simd(&v_shaped, &rot_mats);
let skinning = self.compute_skinning_transforms(&rot_mats, &joints);
let vertices = apply_lbs_simd(
&v_posed,
&skinning,
&self.lbs_weights.view(),
params.translation,
);
Mesh::new(vertices, self.faces.clone())
}
#[must_use]
pub fn forward_batch(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
params_batch.iter().map(|p| self.forward(p)).collect()
}
#[cfg(all(feature = "simd", nightly))]
#[must_use]
pub fn forward_batch_simd(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
params_batch.iter().map(|p| self.forward_simd(p)).collect()
}
#[cfg(feature = "parallel")]
#[must_use]
pub fn forward_batch_par(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
params_batch.par_iter().map(|p| self.forward(p)).collect()
}
#[cfg(all(feature = "parallel", feature = "simd", nightly))]
#[must_use]
pub fn forward_batch_par_simd(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
params_batch
.par_iter()
.map(|p| self.forward_simd(p))
.collect()
}
#[must_use]
pub fn forward_batch_optimized(&self, params_batch: &[FlameParams]) -> BatchedFlameOutput {
let batch_size = params_batch.len();
let num_vertices = self.num_vertices();
let mut output =
BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
for (idx, params) in params_batch.iter().enumerate() {
self.forward_into(params, &mut output.vertices[idx], &mut output.normals[idx]);
}
output
}
#[cfg(feature = "parallel")]
#[must_use]
pub fn forward_batch_par_optimized(&self, params_batch: &[FlameParams]) -> BatchedFlameOutput {
let batch_size = params_batch.len();
let num_vertices = self.num_vertices();
let mut output =
BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
params_batch
.par_iter()
.zip(output.vertices.par_iter_mut())
.zip(output.normals.par_iter_mut())
.for_each(|((params, vertices), normals)| {
self.forward_into(params, vertices, normals);
});
output
}
pub fn forward_batch_with_pool(
&self,
params_batch: &[FlameParams],
buffer_pool: &mut BatchBufferPool,
) -> BatchedFlameOutput {
let batch_size = params_batch.len();
let num_vertices = self.num_vertices();
buffer_pool.ensure_capacity(batch_size);
let mut output =
BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
for (idx, params) in params_batch.iter().enumerate() {
self.forward_into_with_buffers(
params,
&mut buffer_pool.v_shaped[idx],
&mut buffer_pool.v_posed[idx],
&mut buffer_pool.rot_mats[idx],
&mut buffer_pool.skinning[idx],
&mut output.vertices[idx],
&mut output.normals[idx],
);
}
output
}
#[cfg(feature = "parallel")]
pub fn forward_batch_par_with_pool(
&self,
params_batch: &[FlameParams],
buffer_pool: &mut BatchBufferPool,
) -> BatchedFlameOutput {
let batch_size = params_batch.len();
let num_vertices = self.num_vertices();
buffer_pool.ensure_capacity(batch_size);
let mut output =
BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
params_batch
.par_iter()
.enumerate()
.zip(output.vertices.par_iter_mut())
.zip(output.normals.par_iter_mut())
.for_each(|(((idx, params), vertices), normals)| {
self.forward_into(params, vertices, normals);
let _ = idx; });
output
}
#[must_use]
pub fn create_buffer_pool(&self, batch_size: usize) -> BatchBufferPool {
BatchBufferPool::new(batch_size, self.num_vertices(), self.n_joints)
}
pub fn forward_into(
&self,
params: &FlameParams,
vertices_out: &mut [na::Point3<f32>],
normals_out: &mut [na::Vector3<f32>],
) {
let v_shaped = self.apply_shape_expression(params);
let joints = self.j_regressor.dot(&v_shaped);
let rot_mats = self.compute_rotation_matrices(params);
let v_posed = self.apply_pose_blend_shapes(&v_shaped, &rot_mats);
let skinning = self.compute_skinning_transforms(&rot_mats, &joints);
self.apply_lbs_into(&v_posed, &skinning, params, vertices_out);
compute_normals_into(vertices_out, &self.faces, normals_out);
}
#[allow(clippy::too_many_arguments)]
fn forward_into_with_buffers(
&self,
params: &FlameParams,
v_shaped: &mut Array2<f32>,
v_posed: &mut Array2<f32>,
rot_mats: &mut [na::Matrix3<f32>],
skinning: &mut [na::Matrix4<f32>],
vertices_out: &mut [na::Point3<f32>],
normals_out: &mut [na::Vector3<f32>],
) {
self.apply_shape_expression_into(params, v_shaped);
let joints = self.j_regressor.dot(v_shaped);
self.compute_rotation_matrices_into(params, rot_mats);
self.apply_pose_blend_shapes_into(v_shaped, rot_mats, v_posed);
self.compute_skinning_transforms_into(rot_mats, &joints, skinning);
self.apply_lbs_into(v_posed, skinning, params, vertices_out);
compute_normals_into(vertices_out, &self.faces, normals_out);
}
#[inline]
fn apply_shape_expression(&self, params: &FlameParams) -> Array2<f32> {
let mut v = self.v_template.clone();
apply_blend_shapes(&mut v, &self.shapedirs, ¶ms.shape);
apply_blend_shapes(&mut v, &self.expressiondirs, ¶ms.expression);
v
}
#[inline]
fn compute_rotation_matrices(&self, params: &FlameParams) -> Vec<na::Matrix3<f32>> {
(0..self.n_joints)
.map(|j| {
let [rx, ry, rz] = params.joint_pose(j);
rodrigues(rx, ry, rz)
})
.collect()
}
fn apply_pose_blend_shapes(
&self,
v_shaped: &Array2<f32>,
rot_mats: &[na::Matrix3<f32>],
) -> Array2<f32> {
let identity = na::Matrix3::<f32>::identity();
let mut pose_feature = Vec::with_capacity((self.n_joints - 1) * 9);
for rot in rot_mats.iter().skip(1) {
let diff = rot - identity;
for c in 0..3 {
for r in 0..3 {
pose_feature.push(diff[(r, c)]);
}
}
}
let mut v = v_shaped.clone();
apply_blend_shapes(&mut v, &self.posedirs, &pose_feature);
v
}
fn compute_skinning_transforms(
&self,
rot_mats: &[na::Matrix3<f32>],
joints: &Array2<f32>,
) -> Vec<na::Matrix4<f32>> {
let nj = self.n_joints;
let mut global = vec![na::Matrix4::<f32>::identity(); nj];
for j in 0..nj {
let j_pos = na::Vector3::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]]);
let parent = self.parents[j];
let mut local = na::Matrix4::identity();
for r in 0..3 {
for c in 0..3 {
local[(r, c)] = rot_mats[j][(r, c)];
}
}
if parent < 0 {
local[(0, 3)] = j_pos.x;
local[(1, 3)] = j_pos.y;
local[(2, 3)] = j_pos.z;
global[j] = local;
} else {
let p = parent as usize;
let p_pos = na::Vector3::new(joints[[p, 0]], joints[[p, 1]], joints[[p, 2]]);
let rel = j_pos - p_pos;
local[(0, 3)] = rel.x;
local[(1, 3)] = rel.y;
local[(2, 3)] = rel.z;
global[j] = global[p] * local;
}
}
for j in 0..nj {
let j_homo = na::Vector4::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]], 0.0);
let correction = global[j] * j_homo;
global[j][(0, 3)] -= correction[0];
global[j][(1, 3)] -= correction[1];
global[j][(2, 3)] -= correction[2];
}
global
}
fn apply_lbs(
&self,
v_posed: &Array2<f32>,
transforms: &[na::Matrix4<f32>],
params: &FlameParams,
) -> Vec<na::Point3<f32>> {
let n = v_posed.nrows();
let nj = self.n_joints;
let [tx, ty, tz] = params.translation;
let mut out = Vec::with_capacity(n);
for i in 0..n {
let mut t = na::Matrix4::<f32>::zeros();
for (j, transform) in transforms.iter().enumerate().take(nj) {
let w = self.lbs_weights[[i, j]];
if w.abs() > 1e-12 {
t += w * transform;
}
}
let v = na::Vector4::new(v_posed[[i, 0]], v_posed[[i, 1]], v_posed[[i, 2]], 1.0);
let r = t * v;
out.push(na::Point3::new(r[0] + tx, r[1] + ty, r[2] + tz));
}
out
}
#[inline]
fn apply_shape_expression_into(&self, params: &FlameParams, out: &mut Array2<f32>) {
out.assign(&self.v_template);
apply_blend_shapes(out, &self.shapedirs, ¶ms.shape);
apply_blend_shapes(out, &self.expressiondirs, ¶ms.expression);
}
#[inline]
fn compute_rotation_matrices_into(&self, params: &FlameParams, out: &mut [na::Matrix3<f32>]) {
for (j, mat) in out.iter_mut().enumerate().take(self.n_joints) {
let [rx, ry, rz] = params.joint_pose(j);
*mat = rodrigues(rx, ry, rz);
}
}
fn apply_pose_blend_shapes_into(
&self,
v_shaped: &Array2<f32>,
rot_mats: &[na::Matrix3<f32>],
out: &mut Array2<f32>,
) {
let identity = na::Matrix3::<f32>::identity();
let mut pose_feature = Vec::with_capacity((self.n_joints - 1) * 9);
for rot in rot_mats.iter().skip(1) {
let diff = rot - identity;
for c in 0..3 {
for r in 0..3 {
pose_feature.push(diff[(r, c)]);
}
}
}
out.assign(v_shaped);
apply_blend_shapes(out, &self.posedirs, &pose_feature);
}
fn compute_skinning_transforms_into(
&self,
rot_mats: &[na::Matrix3<f32>],
joints: &Array2<f32>,
out: &mut [na::Matrix4<f32>],
) {
let nj = self.n_joints;
for mat in out.iter_mut().take(nj) {
*mat = na::Matrix4::identity();
}
for j in 0..nj {
let j_pos = na::Vector3::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]]);
let parent = self.parents[j];
let mut local = na::Matrix4::identity();
for r in 0..3 {
for c in 0..3 {
local[(r, c)] = rot_mats[j][(r, c)];
}
}
if parent < 0 {
local[(0, 3)] = j_pos.x;
local[(1, 3)] = j_pos.y;
local[(2, 3)] = j_pos.z;
out[j] = local;
} else {
let p = parent as usize;
let p_pos = na::Vector3::new(joints[[p, 0]], joints[[p, 1]], joints[[p, 2]]);
let rel = j_pos - p_pos;
local[(0, 3)] = rel.x;
local[(1, 3)] = rel.y;
local[(2, 3)] = rel.z;
out[j] = out[p] * local;
}
}
for j in 0..nj {
let j_homo = na::Vector4::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]], 0.0);
let correction = out[j] * j_homo;
out[j][(0, 3)] -= correction[0];
out[j][(1, 3)] -= correction[1];
out[j][(2, 3)] -= correction[2];
}
}
fn apply_lbs_into(
&self,
v_posed: &Array2<f32>,
transforms: &[na::Matrix4<f32>],
params: &FlameParams,
out: &mut [na::Point3<f32>],
) {
let n = v_posed.nrows();
let nj = self.n_joints;
let [tx, ty, tz] = params.translation;
for i in 0..n {
let mut t = na::Matrix4::<f32>::zeros();
for (j, transform) in transforms.iter().enumerate().take(nj) {
let w = self.lbs_weights[[i, j]];
if w.abs() > 1e-12 {
t += w * transform;
}
}
let v = na::Vector4::new(v_posed[[i, 0]], v_posed[[i, 1]], v_posed[[i, 2]], 1.0);
let r = t * v;
out[i] = na::Point3::new(r[0] + tx, r[1] + ty, r[2] + tz);
}
}
#[cfg(all(feature = "simd", nightly))]
#[inline]
fn apply_shape_expression_simd(&self, params: &FlameParams) -> Array2<f32> {
use crate::simd::apply_blend_shapes_simd;
let mut v = self.v_template.clone();
apply_blend_shapes_simd(&mut v, &self.shapedirs, ¶ms.shape);
apply_blend_shapes_simd(&mut v, &self.expressiondirs, ¶ms.expression);
v
}
#[cfg(all(feature = "simd", nightly))]
#[inline]
fn compute_rotation_matrices_simd(&self, params: &FlameParams) -> Vec<na::Matrix3<f32>> {
use crate::simd::rodrigues_simd;
(0..self.n_joints)
.map(|j| {
let [rx, ry, rz] = params.joint_pose(j);
rodrigues_simd(rx, ry, rz)
})
.collect()
}
#[cfg(all(feature = "simd", nightly))]
fn apply_pose_blend_shapes_simd(
&self,
v_shaped: &Array2<f32>,
rot_mats: &[na::Matrix3<f32>],
) -> Array2<f32> {
use crate::simd::apply_blend_shapes_simd;
let identity = na::Matrix3::<f32>::identity();
let mut pose_feature = Vec::with_capacity((self.n_joints - 1) * 9);
for rot in rot_mats.iter().skip(1) {
let diff = rot - identity;
for c in 0..3 {
for r in 0..3 {
pose_feature.push(diff[(r, c)]);
}
}
}
let mut v = v_shaped.clone();
apply_blend_shapes_simd(&mut v, &self.posedirs, &pose_feature);
v
}
}
/// Converts an axis-angle vector `(rx, ry, rz)` into a 3x3 rotation matrix
/// using Rodrigues' formula `R = cos·I + (1 - cos)·a·aᵀ + sin·[a]×`.
///
/// The rotation angle is the vector's length and the axis is its direction.
/// Vectors shorter than `1e-8` return the identity.
#[inline]
#[must_use]
pub fn rodrigues(rx: f32, ry: f32, rz: f32) -> na::Matrix3<f32> {
    let theta = (rx * rx + ry * ry + rz * rz).sqrt();
    // The axis is numerically undefined for a (near-)zero vector.
    if theta < 1e-8 {
        return na::Matrix3::identity();
    }
    let (x, y, z) = (rx / theta, ry / theta, rz / theta);
    let (s, c) = theta.sin_cos();
    let k = 1.0 - c;
    let row0 = na::RowVector3::new(k * x * x + c, k * x * y - z * s, k * x * z + y * s);
    let row1 = na::RowVector3::new(k * y * x + z * s, k * y * y + c, k * y * z - x * s);
    let row2 = na::RowVector3::new(k * z * x - y * s, k * z * y + x * s, k * z * z + c);
    na::Matrix3::from_rows(&[row0, row1, row2])
}
/// Adds `sum_i coeffs[i] * dirs[.., .., i]` to `v` in place.
///
/// Only the first `min(coeffs.len(), dirs.shape()[2])` components are used.
/// Coefficients whose magnitude is not above `1e-12` (including NaN) are
/// skipped to avoid a full-array pass that would change nothing.
#[inline]
fn apply_blend_shapes(v: &mut Array2<f32>, dirs: &Array3<f32>, coeffs: &[f32]) {
    let active = coeffs.len().min(dirs.shape()[2]);
    coeffs[..active]
        .iter()
        .enumerate()
        .filter(|(_, weight)| weight.abs() > 1e-12)
        .for_each(|(component, &weight)| {
            v.scaled_add(weight, &dirs.slice(s![.., .., component]));
        });
}
/// Recomputes per-vertex normals into `normals_out`.
///
/// Each face adds its unnormalised edge cross product `(v1 - v0) × (v2 - v0)`
/// to its three vertices, so larger faces contribute more. Every accumulated
/// normal longer than `1e-10` is then normalised; shorter ones are left as-is.
/// Faces that reference a vertex index outside `vertices` are skipped.
///
/// # Panics
///
/// Panics if `normals_out` is shorter than a valid vertex index used by a face.
pub fn compute_normals_into(
    vertices: &[na::Point3<f32>],
    faces: &[[u32; 3]],
    normals_out: &mut [na::Vector3<f32>],
) {
    normals_out.fill(na::Vector3::zeros());
    let vertex_count = vertices.len();
    for &[a, b, c] in faces {
        let (a, b, c) = (a as usize, b as usize, c as usize);
        if a >= vertex_count || b >= vertex_count || c >= vertex_count {
            continue;
        }
        let origin = vertices[a];
        let face_normal = (vertices[b] - origin).cross(&(vertices[c] - origin));
        for corner in [a, b, c] {
            normals_out[corner] += face_normal;
        }
    }
    for normal in normals_out.iter_mut() {
        let magnitude = normal.norm();
        if magnitude > 1e-10 {
            *normal /= magnitude;
        }
    }
}
/// Applies [`compute_normals_into`] to each `(vertices, normals)` pair,
/// sequentially. Pairs are matched by position; surplus entries on either
/// side are left untouched.
pub fn compute_normals_batch(
    vertices_batch: &[Vec<na::Point3<f32>>],
    faces: &[[u32; 3]],
    normals_batch: &mut [Vec<na::Vector3<f32>>],
) {
    vertices_batch
        .iter()
        .zip(normals_batch)
        .for_each(|(verts, norms)| compute_normals_into(verts, faces, norms));
}
/// Parallel (rayon) version of [`compute_normals_batch`]. Pairs are matched
/// by position; surplus entries on either side are left untouched.
#[cfg(feature = "parallel")]
pub fn compute_normals_batch_par(
    vertices_batch: &[Vec<na::Point3<f32>>],
    faces: &[[u32; 3]],
    normals_batch: &mut [Vec<na::Vector3<f32>>],
) {
    normals_batch
        .par_iter_mut()
        .zip(vertices_batch.par_iter())
        .for_each(|(norms, verts)| compute_normals_into(verts, faces, norms));
}
pub fn recompute_batch_normals(output: &mut BatchedFlameOutput) {
for (vertices, normals) in output.vertices.iter().zip(output.normals.iter_mut()) {
compute_normals_into(vertices, &output.faces, normals);
}
}
/// Parallel (rayon) version of [`recompute_batch_normals`].
#[cfg(feature = "parallel")]
pub fn recompute_batch_normals_par(output: &mut BatchedFlameOutput) {
    compute_normals_batch_par(&output.vertices, &output.faces, &mut output.normals);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Frobenius distance between `m` and the 3x3 identity.
    fn distance_from_identity(m: na::Matrix3<f32>) -> f32 {
        (m - na::Matrix3::<f32>::identity()).norm()
    }

    #[test]
    fn test_rodrigues_identity() {
        assert!(distance_from_identity(rodrigues(0.0, 0.0, 0.0)) < 1e-6);
    }

    #[test]
    fn test_rodrigues_90_deg_z() {
        use std::f32::consts::FRAC_PI_2;
        // A quarter turn about +z should carry +x onto +y.
        let rotated = rodrigues(0.0, 0.0, FRAC_PI_2) * na::Vector3::new(1.0, 0.0, 0.0);
        let expected = [0.0_f32, 1.0, 0.0];
        for (got, want) in rotated.iter().zip(expected) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    #[test]
    fn test_rodrigues_roundtrip() {
        // Rotating by an axis-angle vector and then by its negation is a no-op.
        let forward = rodrigues(0.3, -0.2, 0.1);
        let backward = rodrigues(-0.3, 0.2, -0.1);
        assert!(distance_from_identity(forward * backward) < 1e-5);
    }
}