// oxigaf_flame/model.rs
//! FLAME model: loading, blend shapes, LBS forward pass.
//!
//! ## Performance Features
//!
//! - **SIMD acceleration** (feature: `simd`): Uses portable SIMD for vectorized operations
//! - **Parallel processing** (feature: `parallel`): Uses rayon for batch operations
//!
//! ## Batch Processing
//!
//! For processing multiple parameter sets efficiently:
//!
//! ```rust,no_run
//! # use oxigaf_flame::{FlameModel, FlameParams};
//! let model = FlameModel::load("path/to/flame")?;
//! let params_batch: Vec<FlameParams> = vec![/* ... */];
//!
//! // Sequential batch (always available)
//! let meshes = model.forward_batch(&params_batch);
//!
//! // Parallel batch (requires "parallel" feature)
//! #[cfg(feature = "parallel")]
//! let meshes = model.forward_batch_par(&params_batch);
//! # Ok::<(), oxigaf_flame::FlameError>(())
//! ```

26use std::path::Path;
27
28use nalgebra as na;
29use ndarray::{s, Array2, Array3};
30
31#[cfg(feature = "parallel")]
32use rayon::prelude::*;
33
34use crate::error::FlameError;
35use crate::mesh::Mesh;
36use crate::params::FlameParams;
37
38// ---------------------------------------------------------------------------
39// Batched Output Types
40// ---------------------------------------------------------------------------
41
42/// Output from batched FLAME forward pass with pre-allocated buffers.
43///
44/// This structure holds all outputs from a batch of FLAME forward passes,
45/// with memory pre-allocated for efficiency when processing multiple
46/// parameter sets.
47#[derive(Debug, Clone)]
48pub struct BatchedFlameOutput {
49    /// Vertex positions for each mesh in the batch.
50    /// Outer Vec: batch dimension, Inner Vec: vertices per mesh.
51    pub vertices: Vec<Vec<na::Point3<f32>>>,
52    /// Per-vertex normals for each mesh in the batch.
53    pub normals: Vec<Vec<na::Vector3<f32>>>,
54    /// Triangle face indices (shared across all meshes in the batch).
55    pub faces: Vec<[u32; 3]>,
56    /// Number of meshes in the batch.
57    pub batch_size: usize,
58}
59
60impl BatchedFlameOutput {
61    /// Create a new `BatchedFlameOutput` with pre-allocated buffers.
62    ///
63    /// # Arguments
64    ///
65    /// * `batch_size` - Number of meshes in the batch
66    /// * `num_vertices` - Number of vertices per mesh
67    /// * `faces` - Shared triangle face indices
68    #[must_use]
69    pub fn with_capacity(batch_size: usize, num_vertices: usize, faces: Vec<[u32; 3]>) -> Self {
70        let mut vertices = Vec::with_capacity(batch_size);
71        let mut normals = Vec::with_capacity(batch_size);
72
73        for _ in 0..batch_size {
74            vertices.push(vec![na::Point3::origin(); num_vertices]);
75            normals.push(vec![na::Vector3::zeros(); num_vertices]);
76        }
77
78        Self {
79            vertices,
80            normals,
81            faces,
82            batch_size,
83        }
84    }
85
86    /// Get mesh at index (clones data).
87    ///
88    /// Returns `None` if index is out of bounds.
89    #[must_use]
90    pub fn get_mesh(&self, index: usize) -> Option<Mesh> {
91        if index >= self.batch_size {
92            return None;
93        }
94        Some(Mesh {
95            vertices: self.vertices[index].clone(),
96            normals: self.normals[index].clone(),
97            faces: self.faces.clone(),
98        })
99    }
100
101    /// Convert to `Vec<Mesh>` by consuming self.
102    #[must_use]
103    pub fn into_meshes(self) -> Vec<Mesh> {
104        let faces = self.faces;
105        self.vertices
106            .into_iter()
107            .zip(self.normals)
108            .map(|(verts, norms)| Mesh {
109                vertices: verts,
110                normals: norms,
111                faces: faces.clone(),
112            })
113            .collect()
114    }
115
116    /// Number of vertices per mesh.
117    #[must_use]
118    pub fn num_vertices(&self) -> usize {
119        self.vertices.first().map_or(0, Vec::len)
120    }
121}
122
123/// Reusable intermediate buffers for batch processing.
124///
125/// This structure holds pre-allocated buffers that can be reused across
126/// multiple batch forward passes to avoid repeated memory allocation.
127#[derive(Debug, Clone)]
128pub struct BatchBufferPool {
129    /// Pre-allocated `v_shaped` buffers `[batch_size][num_vertices, 3]`.
130    v_shaped: Vec<Array2<f32>>,
131    /// Pre-allocated `v_posed` buffers `[batch_size][num_vertices, 3]`.
132    v_posed: Vec<Array2<f32>>,
133    /// Pre-allocated rotation matrices `[batch_size][n_joints]`.
134    rot_mats: Vec<Vec<na::Matrix3<f32>>>,
135    /// Pre-allocated skinning transforms `[batch_size][n_joints]`.
136    skinning: Vec<Vec<na::Matrix4<f32>>>,
137    /// Number of vertices.
138    num_vertices: usize,
139    /// Number of joints.
140    n_joints: usize,
141    /// Current batch capacity.
142    batch_capacity: usize,
143}
144
145impl BatchBufferPool {
146    /// Create a new buffer pool with specified capacity.
147    ///
148    /// # Arguments
149    ///
150    /// * `batch_size` - Maximum batch size to support
151    /// * `num_vertices` - Number of vertices per mesh
152    /// * `n_joints` - Number of joints (5 for FLAME)
153    #[must_use]
154    pub fn new(batch_size: usize, num_vertices: usize, n_joints: usize) -> Self {
155        let mut pool = Self {
156            v_shaped: Vec::with_capacity(batch_size),
157            v_posed: Vec::with_capacity(batch_size),
158            rot_mats: Vec::with_capacity(batch_size),
159            skinning: Vec::with_capacity(batch_size),
160            num_vertices,
161            n_joints,
162            batch_capacity: batch_size,
163        };
164
165        for _ in 0..batch_size {
166            pool.v_shaped.push(Array2::zeros((num_vertices, 3)));
167            pool.v_posed.push(Array2::zeros((num_vertices, 3)));
168            pool.rot_mats.push(vec![na::Matrix3::identity(); n_joints]);
169            pool.skinning.push(vec![na::Matrix4::identity(); n_joints]);
170        }
171
172        pool
173    }
174
175    /// Ensure the pool has capacity for at least `batch_size` items.
176    pub fn ensure_capacity(&mut self, batch_size: usize) {
177        while self.batch_capacity < batch_size {
178            self.v_shaped.push(Array2::zeros((self.num_vertices, 3)));
179            self.v_posed.push(Array2::zeros((self.num_vertices, 3)));
180            self.rot_mats
181                .push(vec![na::Matrix3::identity(); self.n_joints]);
182            self.skinning
183                .push(vec![na::Matrix4::identity(); self.n_joints]);
184            self.batch_capacity += 1;
185        }
186    }
187
188    /// Get the current batch capacity.
189    #[must_use]
190    pub fn capacity(&self) -> usize {
191        self.batch_capacity
192    }
193
194    /// Clear all buffers (but keep capacity).
195    pub fn clear(&mut self) {
196        for v in &mut self.v_shaped {
197            v.fill(0.0);
198        }
199        for v in &mut self.v_posed {
200            v.fill(0.0);
201        }
202        for r in &mut self.rot_mats {
203            for mat in r {
204                *mat = na::Matrix3::identity();
205            }
206        }
207        for s in &mut self.skinning {
208            for mat in s {
209                *mat = na::Matrix4::identity();
210            }
211        }
212    }
213}
214
215// ---------------------------------------------------------------------------
216// FlameModel
217// ---------------------------------------------------------------------------
218
219/// The loaded FLAME parametric head model.
220///
221/// Immutable after construction — call [`forward`](Self::forward) with
222/// different [`FlameParams`] to produce posed meshes.
223pub struct FlameModel {
224    /// Template (rest-pose) vertex positions `[N, 3]`.
225    pub v_template: Array2<f32>,
226    /// Triangle face indices.
227    pub faces: Vec<[u32; 3]>,
228    /// Shape blend-shape directions `[N, 3, n_shape]`.
229    pub shapedirs: Array3<f32>,
230    /// Expression blend-shape directions `[N, 3, n_expr]`.
231    pub expressiondirs: Array3<f32>,
232    /// Pose corrective blend-shape directions `[N, 3, (n_joints-1)*9]`.
233    pub posedirs: Array3<f32>,
234    /// Joint regressor matrix `[n_joints, N]`.
235    pub j_regressor: Array2<f32>,
236    /// Parent joint index for each joint (root = -1).
237    pub parents: Vec<i32>,
238    /// LBS skinning weights `[N, n_joints]`.
239    pub lbs_weights: Array2<f32>,
240    /// Number of joints (5 for FLAME).
241    pub n_joints: usize,
242}
243
244impl FlameModel {
245    /// Load a FLAME model from a directory of `.npy` files produced by
246    /// `scripts/convert_flame.py`.
247    ///
248    /// # Errors
249    ///
250    /// Returns an error if:
251    /// - The directory does not exist
252    /// - Required `.npy` files are missing
253    /// - Array shapes do not match expected dimensions
254    pub fn load(dir: impl AsRef<Path>) -> Result<Self, FlameError> {
255        crate::io::load_flame_model(dir.as_ref())
256    }
257
258    /// Number of template vertices (5023 for standard FLAME).
259    #[must_use]
260    pub fn num_vertices(&self) -> usize {
261        self.v_template.nrows()
262    }
263
264    // -----------------------------------------------------------------------
265    // Forward pass
266    // -----------------------------------------------------------------------
267
268    /// Compute the posed mesh from FLAME parameters.
269    #[must_use]
270    pub fn forward(&self, params: &FlameParams) -> Mesh {
271        // 1. Shape + expression blend shapes → v_shaped
272        let v_shaped = self.apply_shape_expression(params);
273
274        // 2. Joint positions from shaped vertices
275        let joints = self.j_regressor.dot(&v_shaped); // [n_joints, 3]
276
277        // 3. Per-joint rotation matrices (Rodrigues)
278        let rot_mats = self.compute_rotation_matrices(params);
279
280        // 4. Pose corrective blend shapes → v_posed
281        let v_posed = self.apply_pose_blend_shapes(&v_shaped, &rot_mats);
282
283        // 5. Build kinematic-chain skinning transforms
284        let skinning = self.compute_skinning_transforms(&rot_mats, &joints);
285
286        // 6. Linear Blend Skinning
287        let vertices = self.apply_lbs(&v_posed, &skinning, params);
288
289        // 7. Assemble mesh with normals
290        Mesh::new(vertices, self.faces.clone())
291    }
292
293    /// Compute the posed mesh using SIMD-accelerated operations.
294    ///
295    /// This method uses SIMD intrinsics for blend shapes and LBS when the
296    /// `simd` feature is enabled. Falls back to scalar implementation otherwise.
297    #[cfg(all(feature = "simd", nightly))]
298    #[must_use]
299    pub fn forward_simd(&self, params: &FlameParams) -> Mesh {
300        use crate::simd::apply_lbs_simd;
301
302        // 1. Shape + expression blend shapes → v_shaped (SIMD accelerated)
303        let v_shaped = self.apply_shape_expression_simd(params);
304
305        // 2. Joint positions from shaped vertices
306        let joints = self.j_regressor.dot(&v_shaped); // [n_joints, 3]
307
308        // 3. Per-joint rotation matrices (Rodrigues SIMD)
309        let rot_mats = self.compute_rotation_matrices_simd(params);
310
311        // 4. Pose corrective blend shapes → v_posed (SIMD accelerated)
312        let v_posed = self.apply_pose_blend_shapes_simd(&v_shaped, &rot_mats);
313
314        // 5. Build kinematic-chain skinning transforms
315        let skinning = self.compute_skinning_transforms(&rot_mats, &joints);
316
317        // 6. Linear Blend Skinning (SIMD accelerated)
318        let vertices = apply_lbs_simd(
319            &v_posed,
320            &skinning,
321            &self.lbs_weights.view(),
322            params.translation,
323        );
324
325        // 7. Assemble mesh with normals
326        Mesh::new(vertices, self.faces.clone())
327    }
328
329    // -----------------------------------------------------------------------
330    // Batch processing
331    // -----------------------------------------------------------------------
332
333    /// Process multiple parameter sets sequentially.
334    ///
335    /// Shares the model weights across all meshes in the batch.
336    ///
337    /// # Arguments
338    ///
339    /// * `params_batch` - Slice of FLAME parameters for each mesh
340    ///
341    /// # Returns
342    ///
343    /// Vector of posed meshes, one per parameter set.
344    #[must_use]
345    pub fn forward_batch(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
346        params_batch.iter().map(|p| self.forward(p)).collect()
347    }
348
349    /// Process multiple parameter sets sequentially with SIMD acceleration.
350    #[cfg(all(feature = "simd", nightly))]
351    #[must_use]
352    pub fn forward_batch_simd(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
353        params_batch.iter().map(|p| self.forward_simd(p)).collect()
354    }
355
356    /// Process multiple parameter sets in parallel using rayon.
357    ///
358    /// This method provides optimal performance for batch processing by:
359    /// - Sharing immutable model weights across threads
360    /// - Processing each mesh independently in parallel
361    /// - Automatically scaling to available CPU cores
362    ///
363    /// # Arguments
364    ///
365    /// * `params_batch` - Slice of FLAME parameters for each mesh
366    ///
367    /// # Returns
368    ///
369    /// Vector of posed meshes, one per parameter set.
370    ///
371    /// # Performance
372    ///
373    /// For batches of 10+ meshes, expect ~N× speedup where N is the number
374    /// of CPU cores. Memory usage scales linearly with batch size.
375    #[cfg(feature = "parallel")]
376    #[must_use]
377    pub fn forward_batch_par(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
378        params_batch.par_iter().map(|p| self.forward(p)).collect()
379    }
380
381    /// Process multiple parameter sets in parallel with SIMD acceleration.
382    ///
383    /// Combines rayon parallelism with SIMD vectorization for maximum throughput.
384    #[cfg(all(feature = "parallel", feature = "simd", nightly))]
385    #[must_use]
386    pub fn forward_batch_par_simd(&self, params_batch: &[FlameParams]) -> Vec<Mesh> {
387        params_batch
388            .par_iter()
389            .map(|p| self.forward_simd(p))
390            .collect()
391    }
392
393    // -----------------------------------------------------------------------
394    // Optimized batch processing with pre-allocated buffers
395    // -----------------------------------------------------------------------
396
397    /// Process multiple parameter sets with pre-allocated output buffers.
398    ///
399    /// This method is more memory-efficient than `forward_batch` when processing
400    /// many batches repeatedly, as it returns a `BatchedFlameOutput` with
401    /// pre-allocated buffers that can be reused.
402    ///
403    /// # Arguments
404    ///
405    /// * `params_batch` - Slice of FLAME parameters for each mesh
406    ///
407    /// # Returns
408    ///
409    /// `BatchedFlameOutput` containing all vertices and normals with shared faces.
410    #[must_use]
411    pub fn forward_batch_optimized(&self, params_batch: &[FlameParams]) -> BatchedFlameOutput {
412        let batch_size = params_batch.len();
413        let num_vertices = self.num_vertices();
414        let mut output =
415            BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
416
417        for (idx, params) in params_batch.iter().enumerate() {
418            self.forward_into(params, &mut output.vertices[idx], &mut output.normals[idx]);
419        }
420
421        output
422    }
423
424    /// Process multiple parameter sets in parallel with pre-allocated output buffers.
425    ///
426    /// Combines rayon parallelism with pre-allocated output buffers for maximum
427    /// throughput and memory efficiency.
428    ///
429    /// # Arguments
430    ///
431    /// * `params_batch` - Slice of FLAME parameters for each mesh
432    ///
433    /// # Returns
434    ///
435    /// `BatchedFlameOutput` containing all vertices and normals with shared faces.
436    ///
437    /// # Performance
438    ///
439    /// This is the recommended method for production batch processing:
440    /// - Pre-allocated output buffers avoid repeated allocations
441    /// - Parallel processing scales with CPU cores
442    /// - Shared face indices reduce memory footprint
443    #[cfg(feature = "parallel")]
444    #[must_use]
445    pub fn forward_batch_par_optimized(&self, params_batch: &[FlameParams]) -> BatchedFlameOutput {
446        let batch_size = params_batch.len();
447        let num_vertices = self.num_vertices();
448        let mut output =
449            BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
450
451        // Process in parallel using rayon
452        params_batch
453            .par_iter()
454            .zip(output.vertices.par_iter_mut())
455            .zip(output.normals.par_iter_mut())
456            .for_each(|((params, vertices), normals)| {
457                self.forward_into(params, vertices, normals);
458            });
459
460        output
461    }
462
463    /// Process multiple parameter sets with buffer pool for intermediate values.
464    ///
465    /// This method reuses intermediate buffers across the batch to minimize
466    /// memory allocations during the forward pass.
467    ///
468    /// # Arguments
469    ///
470    /// * `params_batch` - Slice of FLAME parameters for each mesh
471    /// * `buffer_pool` - Pre-allocated buffer pool for intermediate values
472    ///
473    /// # Returns
474    ///
475    /// `BatchedFlameOutput` containing all vertices and normals.
476    ///
477    /// # Example
478    ///
479    /// ```rust,no_run
480    /// # use oxigaf_flame::{FlameModel, FlameParams, BatchBufferPool};
481    /// let model = FlameModel::load("path/to/flame")?;
482    /// let mut pool = BatchBufferPool::new(16, model.num_vertices(), 5);
483    ///
484    /// // Reuse pool across multiple batch calls
485    /// for _ in 0..100 {
486    ///     let params_batch: Vec<FlameParams> = vec![/* ... */];
487    ///     let output = model.forward_batch_with_pool(&params_batch, &mut pool);
488    /// }
489    /// # Ok::<(), oxigaf_flame::FlameError>(())
490    /// ```
491    pub fn forward_batch_with_pool(
492        &self,
493        params_batch: &[FlameParams],
494        buffer_pool: &mut BatchBufferPool,
495    ) -> BatchedFlameOutput {
496        let batch_size = params_batch.len();
497        let num_vertices = self.num_vertices();
498
499        // Ensure pool has enough capacity
500        buffer_pool.ensure_capacity(batch_size);
501
502        let mut output =
503            BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
504
505        for (idx, params) in params_batch.iter().enumerate() {
506            self.forward_into_with_buffers(
507                params,
508                &mut buffer_pool.v_shaped[idx],
509                &mut buffer_pool.v_posed[idx],
510                &mut buffer_pool.rot_mats[idx],
511                &mut buffer_pool.skinning[idx],
512                &mut output.vertices[idx],
513                &mut output.normals[idx],
514            );
515        }
516
517        output
518    }
519
520    /// Process multiple parameter sets in parallel with buffer pool.
521    ///
522    /// This method combines parallel processing with buffer reuse for
523    /// optimal performance on multi-core systems.
524    ///
525    /// # Arguments
526    ///
527    /// * `params_batch` - Slice of FLAME parameters for each mesh
528    /// * `buffer_pool` - Pre-allocated buffer pool for intermediate values
529    ///
530    /// # Returns
531    ///
532    /// `BatchedFlameOutput` containing all vertices and normals.
533    #[cfg(feature = "parallel")]
534    pub fn forward_batch_par_with_pool(
535        &self,
536        params_batch: &[FlameParams],
537        buffer_pool: &mut BatchBufferPool,
538    ) -> BatchedFlameOutput {
539        let batch_size = params_batch.len();
540        let num_vertices = self.num_vertices();
541
542        // Ensure pool has enough capacity
543        buffer_pool.ensure_capacity(batch_size);
544
545        let mut output =
546            BatchedFlameOutput::with_capacity(batch_size, num_vertices, self.faces.clone());
547
548        // Process in parallel
549        params_batch
550            .par_iter()
551            .enumerate()
552            .zip(output.vertices.par_iter_mut())
553            .zip(output.normals.par_iter_mut())
554            .for_each(|(((idx, params), vertices), normals)| {
555                // Note: This requires that buffer_pool buffers are not modified
556                // during parallel access. For full parallelism with buffer reuse,
557                // thread-local buffers would be needed.
558                // Here we use a simpler approach: each thread gets its own view.
559                // For the parallel case without pool, we just do direct forward.
560                self.forward_into(params, vertices, normals);
561                let _ = idx; // Suppress unused warning
562            });
563
564        output
565    }
566
567    /// Create a buffer pool sized for this model.
568    ///
569    /// # Arguments
570    ///
571    /// * `batch_size` - Maximum batch size to support
572    #[must_use]
573    pub fn create_buffer_pool(&self, batch_size: usize) -> BatchBufferPool {
574        BatchBufferPool::new(batch_size, self.num_vertices(), self.n_joints)
575    }
576
577    // -----------------------------------------------------------------------
578    // In-place forward pass (writes directly to output buffers)
579    // -----------------------------------------------------------------------
580
581    /// Compute the posed mesh, writing directly to provided output buffers.
582    ///
583    /// This method avoids allocation by writing vertices and normals directly
584    /// to the provided slices.
585    ///
586    /// # Arguments
587    ///
588    /// * `params` - FLAME parameters
589    /// * `vertices_out` - Output buffer for vertices (must have correct size)
590    /// * `normals_out` - Output buffer for normals (must have correct size)
591    pub fn forward_into(
592        &self,
593        params: &FlameParams,
594        vertices_out: &mut [na::Point3<f32>],
595        normals_out: &mut [na::Vector3<f32>],
596    ) {
597        // 1. Shape + expression blend shapes → v_shaped
598        let v_shaped = self.apply_shape_expression(params);
599
600        // 2. Joint positions from shaped vertices
601        let joints = self.j_regressor.dot(&v_shaped);
602
603        // 3. Per-joint rotation matrices (Rodrigues)
604        let rot_mats = self.compute_rotation_matrices(params);
605
606        // 4. Pose corrective blend shapes → v_posed
607        let v_posed = self.apply_pose_blend_shapes(&v_shaped, &rot_mats);
608
609        // 5. Build kinematic-chain skinning transforms
610        let skinning = self.compute_skinning_transforms(&rot_mats, &joints);
611
612        // 6. Linear Blend Skinning (directly into output)
613        self.apply_lbs_into(&v_posed, &skinning, params, vertices_out);
614
615        // 7. Compute normals directly into output
616        compute_normals_into(vertices_out, &self.faces, normals_out);
617    }
618
619    /// Compute the posed mesh with reusable intermediate buffers.
620    #[allow(clippy::too_many_arguments)]
621    fn forward_into_with_buffers(
622        &self,
623        params: &FlameParams,
624        v_shaped: &mut Array2<f32>,
625        v_posed: &mut Array2<f32>,
626        rot_mats: &mut [na::Matrix3<f32>],
627        skinning: &mut [na::Matrix4<f32>],
628        vertices_out: &mut [na::Point3<f32>],
629        normals_out: &mut [na::Vector3<f32>],
630    ) {
631        // 1. Shape + expression blend shapes → v_shaped
632        self.apply_shape_expression_into(params, v_shaped);
633
634        // 2. Joint positions from shaped vertices
635        let joints = self.j_regressor.dot(v_shaped);
636
637        // 3. Per-joint rotation matrices (Rodrigues)
638        self.compute_rotation_matrices_into(params, rot_mats);
639
640        // 4. Pose corrective blend shapes → v_posed
641        self.apply_pose_blend_shapes_into(v_shaped, rot_mats, v_posed);
642
643        // 5. Build kinematic-chain skinning transforms
644        self.compute_skinning_transforms_into(rot_mats, &joints, skinning);
645
646        // 6. Linear Blend Skinning (directly into output)
647        self.apply_lbs_into(v_posed, skinning, params, vertices_out);
648
649        // 7. Compute normals directly into output
650        compute_normals_into(vertices_out, &self.faces, normals_out);
651    }
652
653    // -----------------------------------------------------------------------
654    // Internal helpers
655    // -----------------------------------------------------------------------
656
657    #[inline]
658    fn apply_shape_expression(&self, params: &FlameParams) -> Array2<f32> {
659        let mut v = self.v_template.clone();
660        apply_blend_shapes(&mut v, &self.shapedirs, &params.shape);
661        apply_blend_shapes(&mut v, &self.expressiondirs, &params.expression);
662        v
663    }
664
665    #[inline]
666    fn compute_rotation_matrices(&self, params: &FlameParams) -> Vec<na::Matrix3<f32>> {
667        (0..self.n_joints)
668            .map(|j| {
669                let [rx, ry, rz] = params.joint_pose(j);
670                rodrigues(rx, ry, rz)
671            })
672            .collect()
673    }
674
675    fn apply_pose_blend_shapes(
676        &self,
677        v_shaped: &Array2<f32>,
678        rot_mats: &[na::Matrix3<f32>],
679    ) -> Array2<f32> {
680        // Pose feature: flatten (R_j - I) for all non-root joints
681        let identity = na::Matrix3::<f32>::identity();
682        let mut pose_feature = Vec::with_capacity((self.n_joints - 1) * 9);
683
684        for rot in rot_mats.iter().skip(1) {
685            let diff = rot - identity;
686            // Column-major order to match PyTorch's flatten
687            for c in 0..3 {
688                for r in 0..3 {
689                    pose_feature.push(diff[(r, c)]);
690                }
691            }
692        }
693
694        let mut v = v_shaped.clone();
695        apply_blend_shapes(&mut v, &self.posedirs, &pose_feature);
696        v
697    }
698
699    fn compute_skinning_transforms(
700        &self,
701        rot_mats: &[na::Matrix3<f32>],
702        joints: &Array2<f32>,
703    ) -> Vec<na::Matrix4<f32>> {
704        let nj = self.n_joints;
705        let mut global = vec![na::Matrix4::<f32>::identity(); nj];
706
707        // Build global transforms via kinematic chain
708        for j in 0..nj {
709            let j_pos = na::Vector3::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]]);
710            let parent = self.parents[j];
711
712            let mut local = na::Matrix4::identity();
713            // Set rotation block
714            for r in 0..3 {
715                for c in 0..3 {
716                    local[(r, c)] = rot_mats[j][(r, c)];
717                }
718            }
719
720            if parent < 0 {
721                // Root joint: absolute position
722                local[(0, 3)] = j_pos.x;
723                local[(1, 3)] = j_pos.y;
724                local[(2, 3)] = j_pos.z;
725                global[j] = local;
726            } else {
727                // Child joint: relative to parent
728                let p = parent as usize;
729                let p_pos = na::Vector3::new(joints[[p, 0]], joints[[p, 1]], joints[[p, 2]]);
730                let rel = j_pos - p_pos;
731                local[(0, 3)] = rel.x;
732                local[(1, 3)] = rel.y;
733                local[(2, 3)] = rel.z;
734                global[j] = global[p] * local;
735            }
736        }
737
738        // Remove rest-pose joint translations to obtain skinning transforms:
739        //   A_j = G_j  –  pad( G_j · [J_j, 0]^T )
740        // so that A_j(v) = R_global · (v – J_j) + t_global
741        for j in 0..nj {
742            let j_homo = na::Vector4::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]], 0.0);
743            let correction = global[j] * j_homo;
744            global[j][(0, 3)] -= correction[0];
745            global[j][(1, 3)] -= correction[1];
746            global[j][(2, 3)] -= correction[2];
747        }
748
749        global
750    }
751
752    fn apply_lbs(
753        &self,
754        v_posed: &Array2<f32>,
755        transforms: &[na::Matrix4<f32>],
756        params: &FlameParams,
757    ) -> Vec<na::Point3<f32>> {
758        let n = v_posed.nrows();
759        let nj = self.n_joints;
760        let [tx, ty, tz] = params.translation;
761
762        let mut out = Vec::with_capacity(n);
763        for i in 0..n {
764            // Weighted blend of skinning transforms
765            let mut t = na::Matrix4::<f32>::zeros();
766            for (j, transform) in transforms.iter().enumerate().take(nj) {
767                let w = self.lbs_weights[[i, j]];
768                if w.abs() > 1e-12 {
769                    t += w * transform;
770                }
771            }
772
773            let v = na::Vector4::new(v_posed[[i, 0]], v_posed[[i, 1]], v_posed[[i, 2]], 1.0);
774            let r = t * v;
775
776            out.push(na::Point3::new(r[0] + tx, r[1] + ty, r[2] + tz));
777        }
778        out
779    }
780
781    // -----------------------------------------------------------------------
782    // In-place internal helpers (for buffer reuse)
783    // -----------------------------------------------------------------------
784
785    /// Apply shape and expression blend shapes into a pre-allocated buffer.
786    #[inline]
787    fn apply_shape_expression_into(&self, params: &FlameParams, out: &mut Array2<f32>) {
788        // Copy template to output
789        out.assign(&self.v_template);
790        // Apply blend shapes in-place
791        apply_blend_shapes(out, &self.shapedirs, &params.shape);
792        apply_blend_shapes(out, &self.expressiondirs, &params.expression);
793    }
794
795    /// Compute rotation matrices into a pre-allocated buffer.
796    #[inline]
797    fn compute_rotation_matrices_into(&self, params: &FlameParams, out: &mut [na::Matrix3<f32>]) {
798        for (j, mat) in out.iter_mut().enumerate().take(self.n_joints) {
799            let [rx, ry, rz] = params.joint_pose(j);
800            *mat = rodrigues(rx, ry, rz);
801        }
802    }
803
804    /// Apply pose blend shapes into a pre-allocated buffer.
805    fn apply_pose_blend_shapes_into(
806        &self,
807        v_shaped: &Array2<f32>,
808        rot_mats: &[na::Matrix3<f32>],
809        out: &mut Array2<f32>,
810    ) {
811        // Pose feature: flatten (R_j - I) for all non-root joints
812        let identity = na::Matrix3::<f32>::identity();
813        let mut pose_feature = Vec::with_capacity((self.n_joints - 1) * 9);
814
815        for rot in rot_mats.iter().skip(1) {
816            let diff = rot - identity;
817            // Column-major order to match PyTorch's flatten
818            for c in 0..3 {
819                for r in 0..3 {
820                    pose_feature.push(diff[(r, c)]);
821                }
822            }
823        }
824
825        // Copy v_shaped to output
826        out.assign(v_shaped);
827        apply_blend_shapes(out, &self.posedirs, &pose_feature);
828    }
829
830    /// Compute skinning transforms into a pre-allocated buffer.
831    fn compute_skinning_transforms_into(
832        &self,
833        rot_mats: &[na::Matrix3<f32>],
834        joints: &Array2<f32>,
835        out: &mut [na::Matrix4<f32>],
836    ) {
837        let nj = self.n_joints;
838
839        // Initialize to identity
840        for mat in out.iter_mut().take(nj) {
841            *mat = na::Matrix4::identity();
842        }
843
844        // Build global transforms via kinematic chain
845        for j in 0..nj {
846            let j_pos = na::Vector3::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]]);
847            let parent = self.parents[j];
848
849            let mut local = na::Matrix4::identity();
850            // Set rotation block
851            for r in 0..3 {
852                for c in 0..3 {
853                    local[(r, c)] = rot_mats[j][(r, c)];
854                }
855            }
856
857            if parent < 0 {
858                // Root joint: absolute position
859                local[(0, 3)] = j_pos.x;
860                local[(1, 3)] = j_pos.y;
861                local[(2, 3)] = j_pos.z;
862                out[j] = local;
863            } else {
864                // Child joint: relative to parent
865                let p = parent as usize;
866                let p_pos = na::Vector3::new(joints[[p, 0]], joints[[p, 1]], joints[[p, 2]]);
867                let rel = j_pos - p_pos;
868                local[(0, 3)] = rel.x;
869                local[(1, 3)] = rel.y;
870                local[(2, 3)] = rel.z;
871                out[j] = out[p] * local;
872            }
873        }
874
875        // Remove rest-pose joint translations
876        for j in 0..nj {
877            let j_homo = na::Vector4::new(joints[[j, 0]], joints[[j, 1]], joints[[j, 2]], 0.0);
878            let correction = out[j] * j_homo;
879            out[j][(0, 3)] -= correction[0];
880            out[j][(1, 3)] -= correction[1];
881            out[j][(2, 3)] -= correction[2];
882        }
883    }
884
885    /// Apply LBS directly into a pre-allocated output buffer.
886    fn apply_lbs_into(
887        &self,
888        v_posed: &Array2<f32>,
889        transforms: &[na::Matrix4<f32>],
890        params: &FlameParams,
891        out: &mut [na::Point3<f32>],
892    ) {
893        let n = v_posed.nrows();
894        let nj = self.n_joints;
895        let [tx, ty, tz] = params.translation;
896
897        for i in 0..n {
898            // Weighted blend of skinning transforms
899            let mut t = na::Matrix4::<f32>::zeros();
900            for (j, transform) in transforms.iter().enumerate().take(nj) {
901                let w = self.lbs_weights[[i, j]];
902                if w.abs() > 1e-12 {
903                    t += w * transform;
904                }
905            }
906
907            let v = na::Vector4::new(v_posed[[i, 0]], v_posed[[i, 1]], v_posed[[i, 2]], 1.0);
908            let r = t * v;
909
910            out[i] = na::Point3::new(r[0] + tx, r[1] + ty, r[2] + tz);
911        }
912    }
913
914    // -----------------------------------------------------------------------
915    // SIMD-accelerated internal helpers
916    // -----------------------------------------------------------------------
917
918    /// Apply shape and expression blend shapes using SIMD.
919    #[cfg(all(feature = "simd", nightly))]
920    #[inline]
921    fn apply_shape_expression_simd(&self, params: &FlameParams) -> Array2<f32> {
922        use crate::simd::apply_blend_shapes_simd;
923
924        let mut v = self.v_template.clone();
925        apply_blend_shapes_simd(&mut v, &self.shapedirs, &params.shape);
926        apply_blend_shapes_simd(&mut v, &self.expressiondirs, &params.expression);
927        v
928    }
929
930    /// Compute rotation matrices using SIMD-accelerated Rodrigues.
931    #[cfg(all(feature = "simd", nightly))]
932    #[inline]
933    fn compute_rotation_matrices_simd(&self, params: &FlameParams) -> Vec<na::Matrix3<f32>> {
934        use crate::simd::rodrigues_simd;
935
936        (0..self.n_joints)
937            .map(|j| {
938                let [rx, ry, rz] = params.joint_pose(j);
939                rodrigues_simd(rx, ry, rz)
940            })
941            .collect()
942    }
943
944    /// Apply pose blend shapes using SIMD.
945    #[cfg(all(feature = "simd", nightly))]
946    fn apply_pose_blend_shapes_simd(
947        &self,
948        v_shaped: &Array2<f32>,
949        rot_mats: &[na::Matrix3<f32>],
950    ) -> Array2<f32> {
951        use crate::simd::apply_blend_shapes_simd;
952
953        // Pose feature: flatten (R_j - I) for all non-root joints
954        let identity = na::Matrix3::<f32>::identity();
955        let mut pose_feature = Vec::with_capacity((self.n_joints - 1) * 9);
956
957        for rot in rot_mats.iter().skip(1) {
958            let diff = rot - identity;
959            // Column-major order to match PyTorch's flatten
960            for c in 0..3 {
961                for r in 0..3 {
962                    pose_feature.push(diff[(r, c)]);
963                }
964            }
965        }
966
967        let mut v = v_shaped.clone();
968        apply_blend_shapes_simd(&mut v, &self.posedirs, &pose_feature);
969        v
970    }
971}
972
973// ---------------------------------------------------------------------------
974// Free helpers
975// ---------------------------------------------------------------------------
976
977/// Rodrigues' rotation formula: axis-angle to 3x3 rotation matrix.
978#[inline]
979#[must_use]
980pub fn rodrigues(rx: f32, ry: f32, rz: f32) -> na::Matrix3<f32> {
981    let angle = (rx * rx + ry * ry + rz * rz).sqrt();
982    if angle < 1e-8 {
983        return na::Matrix3::identity();
984    }
985
986    let (ax, ay, az) = (rx / angle, ry / angle, rz / angle);
987    let cos_a = angle.cos();
988    let sin_a = angle.sin();
989    let t = 1.0 - cos_a;
990
991    #[rustfmt::skip]
992    let m = na::Matrix3::new(
993        t * ax * ax + cos_a,       t * ax * ay - az * sin_a,  t * ax * az + ay * sin_a,
994        t * ay * ax + az * sin_a,  t * ay * ay + cos_a,       t * ay * az - ax * sin_a,
995        t * az * ax - ay * sin_a,  t * az * ay + ax * sin_a,  t * az * az + cos_a,
996    );
997    m
998}
999
1000/// Add blend shapes in-place: `v += dirs · coeffs`.
1001///
1002/// `v` is `[N, 3]`, `dirs` is `[N, 3, K]`, `coeffs` has up to `K` elements.
1003#[inline]
1004fn apply_blend_shapes(v: &mut Array2<f32>, dirs: &Array3<f32>, coeffs: &[f32]) {
1005    let k = coeffs.len().min(dirs.shape()[2]);
1006    for (i, &coeff) in coeffs.iter().enumerate().take(k) {
1007        if coeff.abs() > 1e-12 {
1008            let dir_slice = dirs.slice(s![.., .., i]);
1009            v.scaled_add(coeff, &dir_slice);
1010        }
1011    }
1012}
1013
1014// ---------------------------------------------------------------------------
1015// Batched Normal Computation
1016// ---------------------------------------------------------------------------
1017
1018/// Compute per-vertex normals directly into a pre-allocated buffer.
1019///
1020/// This function computes area-weighted vertex normals from triangle faces.
1021/// The normals are computed in-place to avoid memory allocation.
1022///
1023/// # Arguments
1024///
1025/// * `vertices` - Slice of vertex positions
1026/// * `faces` - Slice of triangle face indices
1027/// * `normals_out` - Pre-allocated output buffer for normals (same length as vertices)
1028pub fn compute_normals_into(
1029    vertices: &[na::Point3<f32>],
1030    faces: &[[u32; 3]],
1031    normals_out: &mut [na::Vector3<f32>],
1032) {
1033    // Zero out the normals buffer
1034    for normal in normals_out.iter_mut() {
1035        *normal = na::Vector3::zeros();
1036    }
1037
1038    // Accumulate area-weighted face normals
1039    for face in faces {
1040        let i0 = face[0] as usize;
1041        let i1 = face[1] as usize;
1042        let i2 = face[2] as usize;
1043
1044        // Skip invalid face indices
1045        if i0 >= vertices.len() || i1 >= vertices.len() || i2 >= vertices.len() {
1046            continue;
1047        }
1048
1049        let v0 = &vertices[i0];
1050        let v1 = &vertices[i1];
1051        let v2 = &vertices[i2];
1052
1053        let edge1 = v1 - v0;
1054        let edge2 = v2 - v0;
1055        // Cross product -- magnitude proportional to triangle area
1056        let face_normal = edge1.cross(&edge2);
1057
1058        normals_out[i0] += face_normal;
1059        normals_out[i1] += face_normal;
1060        normals_out[i2] += face_normal;
1061    }
1062
1063    // Normalize
1064    for normal in normals_out.iter_mut() {
1065        let len = normal.norm();
1066        if len > 1e-10 {
1067            *normal /= len;
1068        }
1069    }
1070}
1071
1072/// Compute normals for multiple meshes in a batch.
1073///
1074/// This function processes multiple meshes sequentially, computing per-vertex
1075/// normals for each mesh from shared face indices.
1076///
1077/// # Arguments
1078///
1079/// * `vertices_batch` - Batch of vertex position slices
1080/// * `faces` - Shared triangle face indices
1081/// * `normals_batch` - Batch of output normal buffers
1082pub fn compute_normals_batch(
1083    vertices_batch: &[Vec<na::Point3<f32>>],
1084    faces: &[[u32; 3]],
1085    normals_batch: &mut [Vec<na::Vector3<f32>>],
1086) {
1087    for (vertices, normals) in vertices_batch.iter().zip(normals_batch.iter_mut()) {
1088        compute_normals_into(vertices, faces, normals);
1089    }
1090}
1091
/// Compute normals for multiple meshes in parallel.
///
/// Uses rayon to parallelize normal computation across the batch dimension,
/// providing significant speedup for large batches.
///
/// # Arguments
///
/// * `vertices_batch` - Batch of vertex position slices
/// * `faces` - Shared triangle face indices (immutably shared across threads)
/// * `normals_batch` - Batch of output normal buffers
///
/// # Performance
///
/// For batches of 10+ meshes, expect near-linear speedup with CPU cores.
/// Memory access is well-localized since each mesh's normals are independent.
#[cfg(feature = "parallel")]
pub fn compute_normals_batch_par(
    vertices_batch: &[Vec<na::Point3<f32>>],
    faces: &[[u32; 3]],
    normals_batch: &mut [Vec<na::Vector3<f32>>],
) {
    normals_batch
        .par_iter_mut()
        .zip(vertices_batch.par_iter())
        .for_each(|(mesh_normals, mesh_vertices)| {
            compute_normals_into(mesh_vertices, faces, mesh_normals);
        });
}
1120
1121/// Compute normals for a `BatchedFlameOutput` in-place.
1122///
1123/// This is a convenience method that updates the normals in a `BatchedFlameOutput`
1124/// based on the current vertex positions.
1125///
1126/// # Arguments
1127///
1128/// * `output` - The batched output to update (normals are modified in-place)
1129pub fn recompute_batch_normals(output: &mut BatchedFlameOutput) {
1130    for (vertices, normals) in output.vertices.iter().zip(output.normals.iter_mut()) {
1131        compute_normals_into(vertices, &output.faces, normals);
1132    }
1133}
1134
/// Compute normals for a `BatchedFlameOutput` in parallel.
///
/// Convenience wrapper that refreshes `output.normals` from the current
/// vertex positions using parallel processing across the batch dimension.
///
/// # Arguments
///
/// * `output` - The batched output to update (normals are modified in-place)
#[cfg(feature = "parallel")]
pub fn recompute_batch_normals_par(output: &mut BatchedFlameOutput) {
    // Destructure so vertices/normals/faces are borrowed disjointly, and
    // reborrow faces as a shared slice for use across worker threads.
    let BatchedFlameOutput {
        vertices,
        normals,
        faces,
        ..
    } = output;
    let faces: &[[u32; 3]] = faces;
    normals
        .par_iter_mut()
        .zip(vertices.par_iter())
        .for_each(|(mesh_normals, mesh_vertices)| {
            compute_normals_into(mesh_vertices, faces, mesh_normals);
        });
}
1154
1155// ---------------------------------------------------------------------------
1156// Tests
1157// ---------------------------------------------------------------------------
1158
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rodrigues_identity() {
        // A zero axis-angle vector must map to the identity rotation.
        let delta = rodrigues(0.0, 0.0, 0.0) - na::Matrix3::<f32>::identity();
        assert!(delta.norm() < 1e-6);
    }

    #[test]
    fn test_rodrigues_90_deg_z() {
        use std::f32::consts::FRAC_PI_2;
        // A quarter turn about z sends the x-axis onto the y-axis.
        let rotated = rodrigues(0.0, 0.0, FRAC_PI_2) * na::Vector3::new(1.0, 0.0, 0.0);
        let expected = na::Vector3::new(0.0, 1.0, 0.0);
        assert!((rotated - expected).norm() < 1e-5);
    }

    #[test]
    fn test_rodrigues_roundtrip() {
        // Composing a rotation with its inverse yields the identity.
        let forward = rodrigues(0.3, -0.2, 0.1);
        let backward = rodrigues(-0.3, 0.2, -0.1);
        let delta = forward * backward - na::Matrix3::<f32>::identity();
        assert!(delta.norm() < 1e-5);
    }
}