// oxigaf_flame/model.rs
1//! FLAME model: loading, blend shapes, LBS forward pass.
2//!
3//! ## Performance Features
4//!
5//! - **SIMD acceleration** (feature: `simd`): Uses portable SIMD for vectorized operations
6//! - **Parallel processing** (feature: `parallel`): Uses rayon for batch operations
7//!
8//! ## Batch Processing
9//!
10//! For processing multiple parameter sets efficiently:
11//!
12//! ```rust,no_run
13//! # use oxigaf_flame::{FlameModel, FlameParams};
14//! let model = FlameModel::load("path/to/flame")?;
15//! let params_batch: Vec<FlameParams> = vec![/* ... */];
16//!
17//! // Sequential batch (always available)
//! let meshes = model.forward_batch(&params_batch);
//!
//! // Parallel batch (requires "parallel" feature)
//! #[cfg(feature = "parallel")]
//! let meshes = model.forward_batch_par(&params_batch);
23//! # Ok::<(), oxigaf_flame::FlameError>(())
24//! ```
25
26use std::path::Path;
27
28use nalgebra as na;
29use ndarray::{s, Array2, Array3};
30
31#[cfg(feature = "parallel")]
32use rayon::prelude::*;
33
34use crate::error::FlameError;
35use crate::mesh::Mesh;
36use crate::params::FlameParams;
37
38// ---------------------------------------------------------------------------
39// Batched Output Types
40// ---------------------------------------------------------------------------
41
42/// Output from batched FLAME forward pass with pre-allocated buffers.
43///
44/// This structure holds all outputs from a batch of FLAME forward passes,
45/// with memory pre-allocated for efficiency when processing multiple
46/// parameter sets.
#[derive(Debug, Clone)]
pub struct BatchedFlameOutput {
    /// Vertex positions for each mesh in the batch.
    /// Outer Vec: batch dimension, Inner Vec: vertices per mesh.
    pub vertices: Vec<Vec<na::Point3<f32>>>,
    /// Per-vertex normals for each mesh in the batch.
    /// Each inner Vec is parallel to the matching `vertices` entry.
    pub normals: Vec<Vec<na::Vector3<f32>>>,
    /// Triangle face indices (shared across all meshes in the batch).
    pub faces: Vec<[u32; 3]>,
    /// Number of meshes in the batch.
    /// Stored explicitly so accessors need not re-derive it from `vertices`.
    pub batch_size: usize,
}
59
60impl BatchedFlameOutput {
61 /// Create a new `BatchedFlameOutput` with pre-allocated buffers.
62 ///
63 /// # Arguments
64 ///
65 /// * `batch_size` - Number of meshes in the batch
66 /// * `num_vertices` - Number of vertices per mesh
67 /// * `faces` - Shared triangle face indices
68 #[must_use]
69 pub fn with_capacity(batch_size: usize, num_vertices: usize, faces: Vec<[u32; 3]>) -> Self {
70 let mut vertices = Vec::with_capacity(batch_size);
71 let mut normals = Vec::with_capacity(batch_size);
72
73 for _ in 0..batch_size {
74 vertices.push(vec![na::Point3::origin(); num_vertices]);
75 normals.push(vec![na::Vector3::zeros(); num_vertices]);
76 }
77
78 Self {
79 vertices,
80 normals,
81 faces,
82 batch_size,
83 }
84 }
85
86 /// Get mesh at index (clones data).
87 ///
88 /// Returns `None` if index is out of bounds.
89 #[must_use]
90 pub fn get_mesh(&self, index: usize) -> Option<Mesh> {
91 if index >= self.batch_size {
92 return None;
93 }
94 Some(Mesh {
95 vertices: self.vertices[index].clone(),
96 normals: self.normals[index].clone(),
97 faces: self.faces.clone(),
98 })
99 }
100
101 /// Convert to `Vec<Mesh>` by consuming self.
102 #[must_use]
103 pub fn into_meshes(self) -> Vec<Mesh> {
104 let faces = self.faces;
105 self.vertices
106 .into_iter()
107 .zip(self.normals)
108 .map(|(verts, norms)| Mesh {
109 vertices: verts,
110 normals: norms,
111 faces: faces.clone(),
112 })
113 .collect()
114 }
115
116 /// Number of vertices per mesh.
117 #[must_use]
118 pub fn num_vertices(&self) -> usize {
119 self.vertices.first().map_or(0, Vec::len)
120 }
121}
122
123/// Reusable intermediate buffers for batch processing.
124///
125/// This structure holds pre-allocated buffers that can be reused across
126/// multiple batch forward passes to avoid repeated memory allocation.
#[derive(Debug, Clone)]
pub struct BatchBufferPool {
    /// Pre-allocated `v_shaped` buffers `[batch_size][num_vertices, 3]`.
    v_shaped: Vec<Array2<f32>>,
    /// Pre-allocated `v_posed` buffers `[batch_size][num_vertices, 3]`.
    v_posed: Vec<Array2<f32>>,
    /// Pre-allocated rotation matrices `[batch_size][n_joints]`.
    rot_mats: Vec<Vec<na::Matrix3<f32>>>,
    /// Pre-allocated skinning transforms `[batch_size][n_joints]`.
    skinning: Vec<Vec<na::Matrix4<f32>>>,
    /// Number of vertices (row count of each per-item 2-D buffer).
    num_vertices: usize,
    /// Number of joints (5 for FLAME).
    n_joints: usize,
    /// Current batch capacity (how many per-item buffer sets exist).
    batch_capacity: usize,
}
144
145impl BatchBufferPool {
146 /// Create a new buffer pool with specified capacity.
147 ///
148 /// # Arguments
149 ///
150 /// * `batch_size` - Maximum batch size to support
151 /// * `num_vertices` - Number of vertices per mesh
152 /// * `n_joints` - Number of joints (5 for FLAME)
153 #[must_use]
154 pub fn new(batch_size: usize, num_vertices: usize, n_joints: usize) -> Self {
155 let mut pool = Self {
156 v_shaped: Vec::with_capacity(batch_size),
157 v_posed: Vec::with_capacity(batch_size),
158 rot_mats: Vec::with_capacity(batch_size),
159 skinning: Vec::with_capacity(batch_size),
160 num_vertices,
161 n_joints,
162 batch_capacity: batch_size,
163 };
164
165 for _ in 0..batch_size {
166 pool.v_shaped.push(Array2::zeros((num_vertices, 3)));
167 pool.v_posed.push(Array2::zeros((num_vertices, 3)));
168 pool.rot_mats.push(vec![na::Matrix3::identity(); n_joints]);
169 pool.skinning.push(vec![na::Matrix4::identity(); n_joints]);
170 }
171
172 pool
173 }
174
175 /// Ensure the pool has capacity for at least `batch_size` items.
176 pub fn ensure_capacity(&mut self, batch_size: usize) {
177 while self.batch_capacity < batch_size {
178 self.v_shaped.push(Array2::zeros((self.num_vertices, 3)));
179 self.v_posed.push(Array2::zeros((self.num_vertices, 3)));
180 self.rot_mats
181 .push(vec![na::Matrix3::identity(); self.n_joints]);
182 self.skinning
183 .push(vec![na::Matrix4::identity(); self.n_joints]);
184 self.batch_capacity += 1;
185 }
186 }
187
188 /// Get the current batch capacity.
189 #[must_use]
190 pub fn capacity(&self) -> usize {
191 self.batch_capacity
192 }
193
194 /// Clear all buffers (but keep capacity).
195 pub fn clear(&mut self) {
196 for v in &mut self.v_shaped {
197 v.fill(0.0);
198 }
199 for v in &mut self.v_posed {
200 v.fill(0.0);
201 }
202 for r in &mut self.rot_mats {
203 for mat in r {
204 *mat = na::Matrix3::identity();
205 }
206 }
207 for s in &mut self.skinning {
208 for mat in s {
209 *mat = na::Matrix4::identity();
210 }
211 }
212 }
213}
214
215// ---------------------------------------------------------------------------
216// FlameModel
217// ---------------------------------------------------------------------------
218
219/// The loaded FLAME parametric head model.
220///
221/// Immutable after construction — call [`forward`](Self::forward) with
222/// different [`FlameParams`] to produce posed meshes.
pub struct FlameModel {
    /// Template (rest-pose) vertex positions `[N, 3]`.
    pub v_template: Array2<f32>,
    /// Triangle face indices.
    pub faces: Vec<[u32; 3]>,
    /// Shape blend-shape directions `[N, 3, n_shape]`.
    pub shapedirs: Array3<f32>,
    /// Expression blend-shape directions `[N, 3, n_expr]`.
    pub expressiondirs: Array3<f32>,
    /// Pose corrective blend-shape directions `[N, 3, (n_joints-1)*9]`.
    /// The last axis is indexed by the column-major flattening of
    /// `(R_j - I)` over the non-root joints (see the pose-blend helpers).
    pub posedirs: Array3<f32>,
    /// Joint regressor matrix `[n_joints, N]`.
    pub j_regressor: Array2<f32>,
    /// Parent joint index for each joint (root = -1).
    /// The skinning-transform builders assume a parent always appears
    /// before its children in this ordering.
    pub parents: Vec<i32>,
    /// LBS skinning weights `[N, n_joints]`.
    pub lbs_weights: Array2<f32>,
    /// Number of joints (5 for FLAME).
    pub n_joints: usize,
}
243
244impl FlameModel {
245 /// Load a FLAME model from a directory of `.npy` files produced by
246 /// `scripts/convert_flame.py`.
247 ///
248 /// # Errors
249 ///
250 /// Returns an error if:
251 /// - The directory does not exist
252 /// - Required `.npy` files are missing
253 /// - Array shapes do not match expected dimensions
    pub fn load(dir: impl AsRef<Path>) -> Result<Self, FlameError> {
        // Thin wrapper: all parsing and shape validation lives in `crate::io`.
        crate::io::load_flame_model(dir.as_ref())
    }
257
258 /// Number of template vertices (5023 for standard FLAME).
    #[must_use]
    pub fn num_vertices(&self) -> usize {
        // One row per vertex in the [N, 3] template array.
        self.v_template.nrows()
    }
263
264 // -----------------------------------------------------------------------
265 // Forward pass
266 // -----------------------------------------------------------------------
267
268 /// Compute the posed mesh from FLAME parameters.
269 #[must_use]
270 pub fn forward(&self, params: &FlameParams) -> Mesh {
271 // 1. Shape + expression blend shapes → v_shaped
272 let v_shaped = self.apply_shape_expression(params);
273
274 // 2. Joint positions from shaped vertices
275 let joints = self.j_regressor.dot(&v_shaped); // [n_joints, 3]
276
277 // 3. Per-joint rotation matrices (Rodrigues)
278 let rot_mats = self.compute_rotation_matrices(params);
279
280 // 4. Pose corrective blend shapes → v_posed
281 let v_posed = self.apply_pose_blend_shapes(&v_shaped, &rot_mats);
282
283 // 5. Build kinematic-chain skinning transforms
284 let skinning = self.compute_skinning_transforms(&rot_mats, &joints);
285
286 // 6. Linear Blend Skinning
287 let vertices = self.apply_lbs(&v_posed, &skinning, params);
288
289 // 7. Assemble mesh with normals
290 Mesh::new(vertices, self.faces.clone())
291 }
292
293 /// Compute the posed mesh using SIMD-accelerated operations.
294 ///
295 /// This method uses SIMD intrinsics for blend shapes and LBS when the
296 /// `simd` feature is enabled. Falls back to scalar implementation otherwise.
    #[cfg(all(feature = "simd", nightly))]
    #[must_use]
    pub fn forward_simd(&self, params: &FlameParams) -> Mesh {
        use crate::simd::apply_lbs_simd;

        // NOTE(review): `nightly` is a custom cfg flag, presumably emitted by
        // a build script when compiling on a nightly toolchain — confirm.

        // 1. Shape + expression blend shapes → v_shaped (SIMD accelerated)
        let v_shaped = self.apply_shape_expression_simd(params);

        // 2. Joint positions from shaped vertices
        let joints = self.j_regressor.dot(&v_shaped); // [n_joints, 3]

        // 3. Per-joint rotation matrices (Rodrigues SIMD)
        let rot_mats = self.compute_rotation_matrices_simd(params);

        // 4. Pose corrective blend shapes → v_posed (SIMD accelerated)
        let v_posed = self.apply_pose_blend_shapes_simd(&v_shaped, &rot_mats);

        // 5. Build kinematic-chain skinning transforms (shared scalar path)
        let skinning = self.compute_skinning_transforms(&rot_mats, &joints);

        // 6. Linear Blend Skinning (SIMD accelerated)
        let vertices = apply_lbs_simd(
            &v_posed,
            &skinning,
            &self.lbs_weights.view(),
            params.translation,
        );

        // 7. Assemble mesh with normals
        Mesh::new(vertices, self.faces.clone())
    }
328
329 // -----------------------------------------------------------------------
330 // Batch processing
331 // -----------------------------------------------------------------------
332
333 /// Process multiple parameter sets sequentially.
334 ///
335 /// Shares the model weights across all meshes in the batch.
336 ///
337 /// # Arguments
338 ///
339 /// * `params_batch` - Slice of FLAME parameters for each mesh
340 ///
341 /// # Returns
342 ///
343 /// Vector of posed meshes, one per parameter set.
344 #[must_use]
345 pub fn forward_batch(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
346 params_batch.iter().map(|p| self.forward(p)).collect()
347 }
348
349 /// Process multiple parameter sets sequentially with SIMD acceleration.
350 #[cfg(all(feature = "simd", nightly))]
351 #[must_use]
352 pub fn forward_batch_simd(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
353 params_batch.iter().map(|p| self.forward_simd(p)).collect()
354 }
355
356 /// Process multiple parameter sets in parallel using rayon.
357 ///
358 /// This method provides optimal performance for batch processing by:
359 /// - Sharing immutable model weights across threads
360 /// - Processing each mesh independently in parallel
361 /// - Automatically scaling to available CPU cores
362 ///
363 /// # Arguments
364 ///
365 /// * `params_batch` - Slice of FLAME parameters for each mesh
366 ///
367 /// # Returns
368 ///
369 /// Vector of posed meshes, one per parameter set.
370 ///
371 /// # Performance
372 ///
373 /// For batches of 10+ meshes, expect ~N× speedup where N is the number
374 /// of CPU cores. Memory usage scales linearly with batch size.
375 #[cfg(feature = "parallel")]
376 #[must_use]
377 pub fn forward_batch_par(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
378 params_batch.par_iter().map(|p| self.forward(p)).collect()
379 }
380
381 /// Process multiple parameter sets in parallel with SIMD acceleration.
382 ///
383 /// Combines rayon parallelism with SIMD vectorization for maximum throughput.
384 #[cfg(all(feature = "parallel", feature = "simd", nightly))]
385 #[must_use]
386 pub fn forward_batch_par_simd(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
387 params_batch
388 .par_iter()
389 .map(|p| self.forward_simd(p))
390 .collect()
391 }
392
393 // -----------------------------------------------------------------------
394 // Optimized batch processing with pre-allocated buffers
395 // -----------------------------------------------------------------------
396
397 /// Process multiple parameter sets with pre-allocated output buffers.
398 ///
399 /// This method is more memory-efficient than `forward_batch` when processing
400 /// many batches repeatedly, as it returns a `BatchedFlameOutput` with
401 /// pre-allocated buffers that can be reused.
402 ///
403 /// # Arguments
404 ///
405 /// * `params_batch` - Slice of FLAME parameters for each mesh
406 ///
407 /// # Returns
408 ///
409 /// `BatchedFlameOutput` containing all vertices and normals with shared faces.
410 #[must_use]
411 pub fn forward_batch_optimized(&self, params_batch: &[FlameParams]) -> BatchedFlameOutput {
412 let batch_size = params_batch.len();
413 let num_vertices = self.num_vertices();
414 let mut output =
415 BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
416
417 for (idx, params) in params_batch.iter().enumerate() {
418 self.forward_into(params, &mut output.vertices[idx], &mut output.normals[idx]);
419 }
420
421 output
422 }
423
424 /// Process multiple parameter sets in parallel with pre-allocated output buffers.
425 ///
426 /// Combines rayon parallelism with pre-allocated output buffers for maximum
427 /// throughput and memory efficiency.
428 ///
429 /// # Arguments
430 ///
431 /// * `params_batch` - Slice of FLAME parameters for each mesh
432 ///
433 /// # Returns
434 ///
435 /// `BatchedFlameOutput` containing all vertices and normals with shared faces.
436 ///
437 /// # Performance
438 ///
439 /// This is the recommended method for production batch processing:
440 /// - Pre-allocated output buffers avoid repeated allocations
441 /// - Parallel processing scales with CPU cores
442 /// - Shared face indices reduce memory footprint
443 #[cfg(feature = "parallel")]
444 #[must_use]
445 pub fn forward_batch_par_optimized(&self, params_batch: &[FlameParams]) -> BatchedFlameOutput {
446 let batch_size = params_batch.len();
447 let num_vertices = self.num_vertices();
448 let mut output =
449 BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
450
451 // Process in parallel using rayon
452 params_batch
453 .par_iter()
454 .zip(output.vertices.par_iter_mut())
455 .zip(output.normals.par_iter_mut())
456 .for_each(|((params, vertices), normals)| {
457 self.forward_into(params, vertices, normals);
458 });
459
460 output
461 }
462
463 /// Process multiple parameter sets with buffer pool for intermediate values.
464 ///
465 /// This method reuses intermediate buffers across the batch to minimize
466 /// memory allocations during the forward pass.
467 ///
468 /// # Arguments
469 ///
470 /// * `params_batch` - Slice of FLAME parameters for each mesh
471 /// * `buffer_pool` - Pre-allocated buffer pool for intermediate values
472 ///
473 /// # Returns
474 ///
475 /// `BatchedFlameOutput` containing all vertices and normals.
476 ///
477 /// # Example
478 ///
479 /// ```rust,no_run
480 /// # use oxigaf_flame::{FlameModel, FlameParams, BatchBufferPool};
481 /// let model = FlameModel::load("path/to/flame")?;
482 /// let mut pool = BatchBufferPool::new(16, model.num_vertices(), 5);
483 ///
484 /// // Reuse pool across multiple batch calls
485 /// for _ in 0..100 {
486 /// let params_batch: Vec<FlameParams> = vec![/* ... */];
    /// let output = model.forward_batch_with_pool(&params_batch, &mut pool);
488 /// }
489 /// # Ok::<(), oxigaf_flame::FlameError>(())
490 /// ```
491 pub fn forward_batch_with_pool(
492 &self,
493 params_batch: &[FlameParams],
494 buffer_pool: &mut BatchBufferPool,
495 ) -> BatchedFlameOutput {
496 let batch_size = params_batch.len();
497 let num_vertices = self.num_vertices();
498
499 // Ensure pool has enough capacity
500 buffer_pool.ensure_capacity(batch_size);
501
502 let mut output =
503 BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
504
505 for (idx, params) in params_batch.iter().enumerate() {
506 self.forward_into_with_buffers(
507 params,
508 &mut buffer_pool.v_shaped[idx],
509 &mut buffer_pool.v_posed[idx],
510 &mut buffer_pool.rot_mats[idx],
511 &mut buffer_pool.skinning[idx],
512 &mut output.vertices[idx],
513 &mut output.normals[idx],
514 );
515 }
516
517 output
518 }
519
520 /// Process multiple parameter sets in parallel with buffer pool.
521 ///
522 /// This method combines parallel processing with buffer reuse for
523 /// optimal performance on multi-core systems.
524 ///
525 /// # Arguments
526 ///
527 /// * `params_batch` - Slice of FLAME parameters for each mesh
528 /// * `buffer_pool` - Pre-allocated buffer pool for intermediate values
529 ///
530 /// # Returns
531 ///
532 /// `BatchedFlameOutput` containing all vertices and normals.
533 #[cfg(feature = "parallel")]
534 pub fn forward_batch_par_with_pool(
535 &self,
536 params_batch: &[FlameParams],
537 buffer_pool: &mut BatchBufferPool,
538 ) -> BatchedFlameOutput {
539 let batch_size = params_batch.len();
540 let num_vertices = self.num_vertices();
541
542 // Ensure pool has enough capacity
543 buffer_pool.ensure_capacity(batch_size);
544
545 let mut output =
546 BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
547
548 // Process in parallel
549 params_batch
550 .par_iter()
551 .enumerate()
552 .zip(output.vertices.par_iter_mut())
553 .zip(output.normals.par_iter_mut())
554 .for_each(|(((idx, params), vertices), normals)| {
555 // Note: This requires that buffer_pool buffers are not modified
556 // during parallel access. For full parallelism with buffer reuse,
557 // thread-local buffers would be needed.
558 // Here we use a simpler approach: each thread gets its own view.
559 // For the parallel case without pool, we just do direct forward.
560 self.forward_into(params, vertices, normals);
561 let _ = idx; // Suppress unused warning
562 });
563
564 output
565 }
566
567 /// Create a buffer pool sized for this model.
568 ///
569 /// # Arguments
570 ///
571 /// * `batch_size` - Maximum batch size to support
    #[must_use]
    pub fn create_buffer_pool(&self, batch_size: usize) -> BatchBufferPool {
        // Buffers are sized from this model's vertex and joint counts.
        BatchBufferPool::new(batch_size, self.num_vertices(), self.n_joints)
    }
576
577 // -----------------------------------------------------------------------
578 // In-place forward pass (writes directly to output buffers)
579 // -----------------------------------------------------------------------
580
581 /// Compute the posed mesh, writing directly to provided output buffers.
582 ///
583 /// This method avoids allocation by writing vertices and normals directly
584 /// to the provided slices.
585 ///
586 /// # Arguments
587 ///
588 /// * `params` - FLAME parameters
589 /// * `vertices_out` - Output buffer for vertices (must have correct size)
590 /// * `normals_out` - Output buffer for normals (must have correct size)
591 pub fn forward_into(
592 &self,
593 params: &FlameParams,
594 vertices_out: &mut [na::Point3<f32>],
595 normals_out: &mut [na::Vector3<f32>],
596 ) {
597 // 1. Shape + expression blend shapes → v_shaped
598 let v_shaped = self.apply_shape_expression(params);
599
600 // 2. Joint positions from shaped vertices
601 let joints = self.j_regressor.dot(&v_shaped);
602
603 // 3. Per-joint rotation matrices (Rodrigues)
604 let rot_mats = self.compute_rotation_matrices(params);
605
606 // 4. Pose corrective blend shapes → v_posed
607 let v_posed = self.apply_pose_blend_shapes(&v_shaped, &rot_mats);
608
609 // 5. Build kinematic-chain skinning transforms
610 let skinning = self.compute_skinning_transforms(&rot_mats, &joints);
611
612 // 6. Linear Blend Skinning (directly into output)
613 self.apply_lbs_into(&v_posed, &skinning, params, vertices_out);
614
615 // 7. Compute normals directly into output
616 compute_normals_into(vertices_out, &self.faces, normals_out);
617 }
618
    /// Compute the posed mesh with reusable intermediate buffers.
    ///
    /// All buffers must already be correctly sized for this model
    /// (`v_shaped`/`v_posed`: `[num_vertices, 3]`; `rot_mats`/`skinning`:
    /// `n_joints` entries; the `*_out` slices: `num_vertices` entries).
    /// Buffers are overwritten, never resized. The step order matters:
    /// each stage reads the buffers filled by the previous one.
    #[allow(clippy::too_many_arguments)]
    fn forward_into_with_buffers(
        &self,
        params: &FlameParams,
        v_shaped: &mut Array2<f32>,
        v_posed: &mut Array2<f32>,
        rot_mats: &mut [na::Matrix3<f32>],
        skinning: &mut [na::Matrix4<f32>],
        vertices_out: &mut [na::Point3<f32>],
        normals_out: &mut [na::Vector3<f32>],
    ) {
        // 1. Shape + expression blend shapes → v_shaped
        self.apply_shape_expression_into(params, v_shaped);

        // 2. Joint positions from shaped vertices
        let joints = self.j_regressor.dot(v_shaped);

        // 3. Per-joint rotation matrices (Rodrigues)
        self.compute_rotation_matrices_into(params, rot_mats);

        // 4. Pose corrective blend shapes → v_posed
        self.apply_pose_blend_shapes_into(v_shaped, rot_mats, v_posed);

        // 5. Build kinematic-chain skinning transforms
        self.compute_skinning_transforms_into(rot_mats, &joints, skinning);

        // 6. Linear Blend Skinning (directly into output)
        self.apply_lbs_into(v_posed, skinning, params, vertices_out);

        // 7. Compute normals directly into output
        compute_normals_into(vertices_out, &self.faces, normals_out);
    }
652
653 // -----------------------------------------------------------------------
654 // Internal helpers
655 // -----------------------------------------------------------------------
656
657 #[inline]
658 fn apply_shape_expression(&self, params: &FlameParams) -> Array2<f32> {
659 let mut v = self.v_template.clone();
660 apply_blend_shapes(&mut v, &self.shapedirs, ¶ms.shape);
661 apply_blend_shapes(&mut v, &self.expressiondirs, ¶ms.expression);
662 v
663 }
664
665 #[inline]
666 fn compute_rotation_matrices(&self, params: &FlameParams) -> Vec<na::Matrix3<f32>> {
667 (0..self.n_joints)
668 .map(|j| {
669 let [rx, ry, rz] = params.joint_pose(j);
670 rodrigues(rx, ry, rz)
671 })
672 .collect()
673 }
674
675 fn apply_pose_blend_shapes(
676 &self,
677 v_shaped: &Array2<f32>,
678 rot_mats: &[na::Matrix3<f32>],
679 ) -> Array2<f32> {
680 // Pose feature: flatten (R_j - I) for all non-root joints
681 let identity = na::Matrix3::<f32>::identity();
682 let mut pose_feature = Vec::with_capacity((self.n_joints - 1) * 9);
683
684 for rot in rot_mats.iter().skip(1) {
685 let diff = rot - identity;
686 // Column-major order to match PyTorch's flatten
687 for c in 0..3 {
688 for r in 0..3 {
689 pose_feature.push(diff[(r, c)]);
690 }
691 }
692 }
693
694 let mut v = v_shaped.clone();
695 apply_blend_shapes(&mut v, &self.posedirs, &pose_feature);
696 v
697 }
698
    /// Build one homogeneous 4x4 skinning transform per joint from the
    /// per-joint rotations and the rest-pose joint locations.
    ///
    /// Assumes `parents[j]` always indexes a joint processed earlier in the
    /// loop (parents precede children), so a single forward pass can chain
    /// the global transforms.
    fn compute_skinning_transforms(
        &self,
        rot_mats: &[na::Matrix3<f32>],
        joints: &Array2<f32>,
    ) -> Vec<na::Matrix4<f32>> {
        let nj = self.n_joints;
        let mut global = vec![na::Matrix4::<f32>::identity(); nj];

        // Build global transforms via kinematic chain
        for j in 0..nj {
            let j_pos = na::Vector3::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]]);
            let parent = self.parents[j];

            let mut local = na::Matrix4::identity();
            // Set rotation block (top-left 3x3 of the homogeneous matrix)
            for r in 0..3 {
                for c in 0..3 {
                    local[(r, c)] = rot_mats[j][(r, c)];
                }
            }

            if parent < 0 {
                // Root joint: absolute position
                local[(0, 3)] = j_pos.x;
                local[(1, 3)] = j_pos.y;
                local[(2, 3)] = j_pos.z;
                global[j] = local;
            } else {
                // Child joint: translation stored relative to parent, then
                // composed with the parent's already-computed global transform
                let p = parent as usize;
                let p_pos = na::Vector3::new(joints[[p, 0]], joints[[p, 1]], joints[[p, 2]]);
                let rel = j_pos - p_pos;
                local[(0, 3)] = rel.x;
                local[(1, 3)] = rel.y;
                local[(2, 3)] = rel.z;
                global[j] = global[p] * local;
            }
        }

        // Remove rest-pose joint translations to obtain skinning transforms:
        //   A_j = G_j – pad( G_j · [J_j, 0]^T )
        // so that A_j(v) = R_global · (v – J_j) + t_global
        for j in 0..nj {
            // w = 0 keeps this a direction (rotation-only) product.
            let j_homo = na::Vector4::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]], 0.0);
            let correction = global[j] * j_homo;
            global[j][(0, 3)] -= correction[0];
            global[j][(1, 3)] -= correction[1];
            global[j][(2, 3)] -= correction[2];
        }

        global
    }
751
752 fn apply_lbs(
753 &self,
754 v_posed: &Array2<f32>,
755 transforms: &[na::Matrix4<f32>],
756 params: &FlameParams,
757 ) -> Vec<na::Point3<f32>> {
758 let n = v_posed.nrows();
759 let nj = self.n_joints;
760 let [tx, ty, tz] = params.translation;
761
762 let mut out = Vec::with_capacity(n);
763 for i in 0..n {
764 // Weighted blend of skinning transforms
765 let mut t = na::Matrix4::<f32>::zeros();
766 for (j, transform) in transforms.iter().enumerate().take(nj) {
767 let w = self.lbs_weights[[i, j]];
768 if w.abs() > 1e-12 {
769 t += w * transform;
770 }
771 }
772
773 let v = na::Vector4::new(v_posed[[i, 0]], v_posed[[i, 1]], v_posed[[i, 2]], 1.0);
774 let r = t * v;
775
776 out.push(na::Point3::new(r[0] + tx, r[1] + ty, r[2] + tz));
777 }
778 out
779 }
780
781 // -----------------------------------------------------------------------
782 // In-place internal helpers (for buffer reuse)
783 // -----------------------------------------------------------------------
784
785 /// Apply shape and expression blend shapes into a pre-allocated buffer.
786 #[inline]
787 fn apply_shape_expression_into(&self, params: &FlameParams, out: &mut Array2<f32>) {
788 // Copy template to output
789 out.assign(&self.v_template);
790 // Apply blend shapes in-place
791 apply_blend_shapes(out, &self.shapedirs, ¶ms.shape);
792 apply_blend_shapes(out, &self.expressiondirs, ¶ms.expression);
793 }
794
795 /// Compute rotation matrices into a pre-allocated buffer.
796 #[inline]
797 fn compute_rotation_matrices_into(&self, params: &FlameParams, out: &mut [na::Matrix3<f32>]) {
798 for (j, mat) in out.iter_mut().enumerate().take(self.n_joints) {
799 let [rx, ry, rz] = params.joint_pose(j);
800 *mat = rodrigues(rx, ry, rz);
801 }
802 }
803
804 /// Apply pose blend shapes into a pre-allocated buffer.
805 fn apply_pose_blend_shapes_into(
806 &self,
807 v_shaped: &Array2<f32>,
808 rot_mats: &[na::Matrix3<f32>],
809 out: &mut Array2<f32>,
810 ) {
811 // Pose feature: flatten (R_j - I) for all non-root joints
812 let identity = na::Matrix3::<f32>::identity();
813 let mut pose_feature = Vec::with_capacity((self.n_joints - 1) * 9);
814
815 for rot in rot_mats.iter().skip(1) {
816 let diff = rot - identity;
817 // Column-major order to match PyTorch's flatten
818 for c in 0..3 {
819 for r in 0..3 {
820 pose_feature.push(diff[(r, c)]);
821 }
822 }
823 }
824
825 // Copy v_shaped to output
826 out.assign(v_shaped);
827 apply_blend_shapes(out, &self.posedirs, &pose_feature);
828 }
829
    /// Compute skinning transforms into a pre-allocated buffer.
    ///
    /// In-place counterpart of `compute_skinning_transforms`; `out` must
    /// hold at least `n_joints` matrices. Assumes parents precede children
    /// in the joint ordering so one forward pass can chain the transforms.
    fn compute_skinning_transforms_into(
        &self,
        rot_mats: &[na::Matrix3<f32>],
        joints: &Array2<f32>,
        out: &mut [na::Matrix4<f32>],
    ) {
        let nj = self.n_joints;

        // Initialize to identity
        for mat in out.iter_mut().take(nj) {
            *mat = na::Matrix4::identity();
        }

        // Build global transforms via kinematic chain
        for j in 0..nj {
            let j_pos = na::Vector3::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]]);
            let parent = self.parents[j];

            let mut local = na::Matrix4::identity();
            // Set rotation block (top-left 3x3 of the homogeneous matrix)
            for r in 0..3 {
                for c in 0..3 {
                    local[(r, c)] = rot_mats[j][(r, c)];
                }
            }

            if parent < 0 {
                // Root joint: absolute position
                local[(0, 3)] = j_pos.x;
                local[(1, 3)] = j_pos.y;
                local[(2, 3)] = j_pos.z;
                out[j] = local;
            } else {
                // Child joint: relative to parent, composed with the
                // parent's already-computed global transform
                let p = parent as usize;
                let p_pos = na::Vector3::new(joints[[p, 0]], joints[[p, 1]], joints[[p, 2]]);
                let rel = j_pos - p_pos;
                local[(0, 3)] = rel.x;
                local[(1, 3)] = rel.y;
                local[(2, 3)] = rel.z;
                out[j] = out[p] * local;
            }
        }

        // Remove rest-pose joint translations (A_j = G_j - pad(G_j·[J_j,0]ᵀ))
        for j in 0..nj {
            // w = 0 keeps this a direction (rotation-only) product.
            let j_homo = na::Vector4::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]], 0.0);
            let correction = out[j] * j_homo;
            out[j][(0, 3)] -= correction[0];
            out[j][(1, 3)] -= correction[1];
            out[j][(2, 3)] -= correction[2];
        }
    }
884
885 /// Apply LBS directly into a pre-allocated output buffer.
886 fn apply_lbs_into(
887 &self,
888 v_posed: &Array2<f32>,
889 transforms: &[na::Matrix4<f32>],
890 params: &FlameParams,
891 out: &mut [na::Point3<f32>],
892 ) {
893 let n = v_posed.nrows();
894 let nj = self.n_joints;
895 let [tx, ty, tz] = params.translation;
896
897 for i in 0..n {
898 // Weighted blend of skinning transforms
899 let mut t = na::Matrix4::<f32>::zeros();
900 for (j, transform) in transforms.iter().enumerate().take(nj) {
901 let w = self.lbs_weights[[i, j]];
902 if w.abs() > 1e-12 {
903 t += w * transform;
904 }
905 }
906
907 let v = na::Vector4::new(v_posed[[i, 0]], v_posed[[i, 1]], v_posed[[i, 2]], 1.0);
908 let r = t * v;
909
910 out[i] = na::Point3::new(r[0] + tx, r[1] + ty, r[2] + tz);
911 }
912 }
913
914 // -----------------------------------------------------------------------
915 // SIMD-accelerated internal helpers
916 // -----------------------------------------------------------------------
917
918 /// Apply shape and expression blend shapes using SIMD.
919 #[cfg(all(feature = "simd", nightly))]
920 #[inline]
921 fn apply_shape_expression_simd(&self, params: &FlameParams) -> Array2<f32> {
922 use crate::simd::apply_blend_shapes_simd;
923
924 let mut v = self.v_template.clone();
925 apply_blend_shapes_simd(&mut v, &self.shapedirs, ¶ms.shape);
926 apply_blend_shapes_simd(&mut v, &self.expressiondirs, ¶ms.expression);
927 v
928 }
929
930 /// Compute rotation matrices using SIMD-accelerated Rodrigues.
931 #[cfg(all(feature = "simd", nightly))]
932 #[inline]
933 fn compute_rotation_matrices_simd(&self, params: &FlameParams) -> Vec<na::Matrix3<f32>> {
934 use crate::simd::rodrigues_simd;
935
936 (0..self.n_joints)
937 .map(|j| {
938 let [rx, ry, rz] = params.joint_pose(j);
939 rodrigues_simd(rx, ry, rz)
940 })
941 .collect()
942 }
943
944 /// Apply pose blend shapes using SIMD.
945 #[cfg(all(feature = "simd", nightly))]
946 fn apply_pose_blend_shapes_simd(
947 &self,
948 v_shaped: &Array2<f32>,
949 rot_mats: &[na::Matrix3<f32>],
950 ) -> Array2<f32> {
951 use crate::simd::apply_blend_shapes_simd;
952
953 // Pose feature: flatten (R_j - I) for all non-root joints
954 let identity = na::Matrix3::<f32>::identity();
955 let mut pose_feature = Vec::with_capacity((self.n_joints - 1) * 9);
956
957 for rot in rot_mats.iter().skip(1) {
958 let diff = rot - identity;
959 // Column-major order to match PyTorch's flatten
960 for c in 0..3 {
961 for r in 0..3 {
962 pose_feature.push(diff[(r, c)]);
963 }
964 }
965 }
966
967 let mut v = v_shaped.clone();
968 apply_blend_shapes_simd(&mut v, &self.posedirs, &pose_feature);
969 v
970 }
971}
972
973// ---------------------------------------------------------------------------
974// Free helpers
975// ---------------------------------------------------------------------------
976
/// Rodrigues' rotation formula: axis-angle to 3x3 rotation matrix.
///
/// The input `(rx, ry, rz)` is an axis-angle vector: its direction is the
/// rotation axis and its Euclidean norm is the rotation angle in radians.
/// Angles below `1e-8` return the identity, since normalizing a near-zero
/// axis would be numerically meaningless.
#[inline]
#[must_use]
pub fn rodrigues(rx: f32, ry: f32, rz: f32) -> na::Matrix3<f32> {
    let angle = (rx * rx + ry * ry + rz * rz).sqrt();
    if angle < 1e-8 {
        return na::Matrix3::identity();
    }

    // Normalized rotation axis.
    let (ax, ay, az) = (rx / angle, ry / angle, rz / angle);
    let cos_a = angle.cos();
    let sin_a = angle.sin();
    let t = 1.0 - cos_a;

    // R = cos·I + sin·[a]× + (1 − cos)·aaᵀ, written out element-wise.
    #[rustfmt::skip]
    let m = na::Matrix3::new(
        t * ax * ax + cos_a, t * ax * ay - az * sin_a, t * ax * az + ay * sin_a,
        t * ay * ax + az * sin_a, t * ay * ay + cos_a, t * ay * az - ax * sin_a,
        t * az * ax - ay * sin_a, t * az * ay + ax * sin_a, t * az * az + cos_a,
    );
    m
}
999
1000/// Add blend shapes in-place: `v += dirs · coeffs`.
1001///
1002/// `v` is `[N, 3]`, `dirs` is `[N, 3, K]`, `coeffs` has up to `K` elements.
1003#[inline]
1004fn apply_blend_shapes(v: &mut Array2<f32>, dirs: &Array3<f32>, coeffs: &[f32]) {
1005 let k = coeffs.len().min(dirs.shape()[2]);
1006 for (i, &coeff) in coeffs.iter().enumerate().take(k) {
1007 if coeff.abs() > 1e-12 {
1008 let dir_slice = dirs.slice(s![.., .., i]);
1009 v.scaled_add(coeff, &dir_slice);
1010 }
1011 }
1012}
1013
1014// ---------------------------------------------------------------------------
1015// Batched Normal Computation
1016// ---------------------------------------------------------------------------
1017
1018/// Compute per-vertex normals directly into a pre-allocated buffer.
1019///
1020/// This function computes area-weighted vertex normals from triangle faces.
1021/// The normals are computed in-place to avoid memory allocation.
1022///
1023/// # Arguments
1024///
1025/// * `vertices` - Slice of vertex positions
1026/// * `faces` - Slice of triangle face indices
1027/// * `normals_out` - Pre-allocated output buffer for normals (same length as vertices)
1028pub fn compute_normals_into(
1029 vertices: &[na::Point3<f32>],
1030 faces: &[[u32; 3]],
1031 normals_out: &mut [na::Vector3<f32>],
1032) {
1033 // Zero out the normals buffer
1034 for normal in normals_out.iter_mut() {
1035 *normal = na::Vector3::zeros();
1036 }
1037
1038 // Accumulate area-weighted face normals
1039 for face in faces {
1040 let i0 = face[0] as usize;
1041 let i1 = face[1] as usize;
1042 let i2 = face[2] as usize;
1043
1044 // Skip invalid face indices
1045 if i0 >= vertices.len() || i1 >= vertices.len() || i2 >= vertices.len() {
1046 continue;
1047 }
1048
1049 let v0 = &vertices[i0];
1050 let v1 = &vertices[i1];
1051 let v2 = &vertices[i2];
1052
1053 let edge1 = v1 - v0;
1054 let edge2 = v2 - v0;
1055 // Cross product -- magnitude proportional to triangle area
1056 let face_normal = edge1.cross(&edge2);
1057
1058 normals_out[i0] += face_normal;
1059 normals_out[i1] += face_normal;
1060 normals_out[i2] += face_normal;
1061 }
1062
1063 // Normalize
1064 for normal in normals_out.iter_mut() {
1065 let len = normal.norm();
1066 if len > 1e-10 {
1067 *normal /= len;
1068 }
1069 }
1070}
1071
1072/// Compute normals for multiple meshes in a batch.
1073///
1074/// This function processes multiple meshes sequentially, computing per-vertex
1075/// normals for each mesh from shared face indices.
1076///
1077/// # Arguments
1078///
1079/// * `vertices_batch` - Batch of vertex position slices
1080/// * `faces` - Shared triangle face indices
1081/// * `normals_batch` - Batch of output normal buffers
1082pub fn compute_normals_batch(
1083 vertices_batch: &[Vec<na::Point3<f32>>],
1084 faces: &[[u32; 3]],
1085 normals_batch: &mut [Vec<na::Vector3<f32>>],
1086) {
1087 for (vertices, normals) in vertices_batch.iter().zip(normals_batch.iter_mut()) {
1088 compute_normals_into(vertices, faces, normals);
1089 }
1090}
1091
/// Compute normals for multiple meshes in parallel.
///
/// Uses rayon to parallelize normal computation across the batch
/// dimension, providing significant speedup for large batches.
///
/// # Arguments
///
/// * `vertices_batch` - Batch of vertex position slices
/// * `faces` - Shared triangle face indices (immutably shared across threads)
/// * `normals_batch` - Batch of output normal buffers
///
/// # Performance
///
/// For batches of 10+ meshes, expect near-linear speedup with CPU cores.
/// Memory access is well-localized since each mesh's normals are independent.
#[cfg(feature = "parallel")]
pub fn compute_normals_batch_par(
    vertices_batch: &[Vec<na::Point3<f32>>],
    faces: &[[u32; 3]],
    normals_batch: &mut [Vec<na::Vector3<f32>>],
) {
    // Each (vertices, normals) pair is independent, so the batch splits
    // cleanly across worker threads.
    normals_batch
        .par_iter_mut()
        .zip(vertices_batch.par_iter())
        .for_each(|(norms, verts)| compute_normals_into(verts, faces, norms));
}
1120
1121/// Compute normals for a `BatchedFlameOutput` in-place.
1122///
1123/// This is a convenience method that updates the normals in a `BatchedFlameOutput`
1124/// based on the current vertex positions.
1125///
1126/// # Arguments
1127///
1128/// * `output` - The batched output to update (normals are modified in-place)
1129pub fn recompute_batch_normals(output: &mut BatchedFlameOutput) {
1130 for (vertices, normals) in output.vertices.iter().zip(output.normals.iter_mut()) {
1131 compute_normals_into(vertices, &output.faces, normals);
1132 }
1133}
1134
/// Compute normals for a `BatchedFlameOutput` in parallel.
///
/// Convenience wrapper that recomputes every mesh's per-vertex normals
/// from its current vertex positions using rayon across the batch.
///
/// # Arguments
///
/// * `output` - The batched output to update (normals are modified in-place)
#[cfg(feature = "parallel")]
pub fn recompute_batch_normals_par(output: &mut BatchedFlameOutput) {
    // Disjoint field borrows: vertices/faces read-only, normals mutated.
    compute_normals_batch_par(&output.vertices, &output.faces, &mut output.normals);
}
1154
1155// ---------------------------------------------------------------------------
1156// Tests
1157// ---------------------------------------------------------------------------
1158
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rodrigues_identity() {
        // A zero rotation vector must map to the identity matrix.
        let delta = rodrigues(0.0, 0.0, 0.0) - na::Matrix3::<f32>::identity();
        assert!(delta.norm() < 1e-6);
    }

    #[test]
    fn test_rodrigues_90_deg_z() {
        use std::f32::consts::FRAC_PI_2;
        // A quarter turn about +z sends the x-axis onto the y-axis.
        let rv = rodrigues(0.0, 0.0, FRAC_PI_2) * na::Vector3::new(1.0, 0.0, 0.0);
        for (got, want) in [(rv.x, 0.0), (rv.y, 1.0), (rv.z, 0.0)] {
            assert!((got - want).abs() < 1e-5);
        }
    }

    #[test]
    fn test_rodrigues_roundtrip() {
        // Composing a rotation with its inverse yields the identity.
        let forward = rodrigues(0.3, -0.2, 0.1);
        let backward = rodrigues(-0.3, 0.2, -0.1);
        let residual = forward * backward - na::Matrix3::<f32>::identity();
        assert!(residual.norm() < 1e-5);
    }
}