rend3_anim/
lib.rs

//! Utility library to play gltf animations.
//!
//! This library is meant to be used together with `rend3` and `rend3-gltf` and
//! allows posing meshes according to the animation data stored in a gltf file.
//!
//! In order to play animations, you need to:
//! - Create an [`AnimationData`] once when spawning your scene (one per scene
//!   instance) and store it.
//! - Each simulation frame, use [`pose_animation_frame`] to set the mesh's
//!   joints to a specific animation at a specific time.
//!
//! For now, this library aims to be a simple utility abstraction. Updating the
//! current state of the animation by changing the currently played animation or
//! increasing the playback time should be handled in user code.
14
15use std::collections::HashMap;
16
17use itertools::Itertools;
18use rend3::{
19    types::{
20        glam::{Mat4, Quat, Vec3},
21        SkeletonHandle,
22    },
23    util::typedefs::{FastHashMap, FastHashSet},
24    Renderer,
25};
26use rend3_gltf::{AnimationChannel, GltfSceneInstance, LoadedGltfScene};
27
/// Position of an animation in the loaded scene's `animations` list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AnimationIndex(pub usize);

/// Position of a skin in the loaded scene's `skins` list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SkinIndex(pub usize);

/// Position of a node in the scene instance's `nodes` list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeIndex(pub usize);

/// Position of a joint within a single skin's joint list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct JointIndex(pub usize);
36
37/// Cached data structures per each of the Skins in a gltf model. This struct is
38/// part of [`AnimationData`]
39pub struct PerSkinData {
40    /// Translates node indices to joint indices for this particular skin. This
41    /// translation is necessary because an animation may be animating a scene
42    /// node which is present in multiple skins.
43    pub node_to_joint_idx: FastHashMap<NodeIndex, JointIndex>,
44    /// Stores the node indices for this skin's joints in topological order.
45    /// This is used to avoid iterating all the scene hierarchy when
46    /// computing global positions for a node.
47    pub joint_nodes_topological_order: Vec<NodeIndex>,
48    /// The list of skeletons deformed by this animation. There's one skeleton
49    /// for each of the mesh primitives. Every skeleton in this list is
50    /// shares the same bone structure.
51    pub skeletons: Vec<SkeletonHandle>,
52}
53
54/// Caches animation data necessary to run [`pose_animation_frame`].
55pub struct AnimationData {
56    /// For each skin, stores several cached data structures that speed up the
57    /// animation loop at runtime.
58    pub skin_data: FastHashMap<SkinIndex, PerSkinData>,
59    /// For each animation, stores the list of skins it affects. An animation
60    /// affects a skin if it deforms any of its joints. This is used to avoid
61    /// iterating unaffected skins when playing an animation.
62    pub animation_skin_usage: FastHashMap<AnimationIndex, Vec<SkinIndex>>,
63}
64
65impl AnimationData {
66    /// Creates an [`AnimationData`] from a loaded gltf scene and instance.
67    ///
68    /// Note that the instance is necessary, as one `AnimationData` must exist
69    /// per each instance of the same scene.
70    ///
71    /// ## Parameters
72    /// - scene: The loaded scene, as returned by
73    ///   [`load_gltf`](rend3_gltf::load_gltf) or
74    ///   [`load_gltf_data`](rend3_gltf::load_gltf_data)
75    /// - instance: An instance of `scene`, as returned by
76    ///   [`load_gltf`](rend3_gltf::load_gltf) or
77    ///   [`instance_loaded_scene`](rend3_gltf::instance_loaded_scene)
78    pub fn from_gltf_scene(scene: &LoadedGltfScene, instance: &GltfSceneInstance) -> Self {
79        // The set of joints that each animation affects, stored as node indices
80        // NOTE: Uses a std HashMap because `GroupingMap::collect()` is
81        // hardcoded to return that.
82        let animation_to_joint_nodes: HashMap<AnimationIndex, FastHashSet<NodeIndex>> = scene
83            .animations
84            .iter()
85            .enumerate()
86            .flat_map(|(anim_idx, anim)| {
87                anim.inner
88                    .channels
89                    .keys()
90                    .map(move |node_idx| (AnimationIndex(anim_idx), NodeIndex(*node_idx)))
91            })
92            .into_grouping_map()
93            .collect::<FastHashSet<_>>();
94
95        let mut animation_skin_usage = FastHashMap::<AnimationIndex, Vec<SkinIndex>>::default();
96        for animation_idx in 0..scene.animations.len() {
97            let animation_idx = AnimationIndex(animation_idx);
98            for (skin_index, skin) in scene.skins.iter().enumerate() {
99                let skin_index = SkinIndex(skin_index);
100
101                let anim_affected_nodes = &animation_to_joint_nodes[&animation_idx];
102                if skin
103                    .inner
104                    .joints
105                    .iter()
106                    .any(|j| anim_affected_nodes.contains(&NodeIndex(j.inner.node_idx)))
107                {
108                    let entry = animation_skin_usage
109                        .entry(animation_idx)
110                        .or_insert_with(Default::default);
111                    entry.push(skin_index);
112                }
113            }
114        }
115
116        let mut skin_data = FastHashMap::default();
117        for (skin_index, skin) in scene.skins.iter().enumerate() {
118            let skin_index = SkinIndex(skin_index);
119
120            let node_to_joint_idx = skin
121                .inner
122                .joints
123                .iter()
124                .enumerate()
125                .map(|(idx, joint)| (NodeIndex(joint.inner.node_idx), JointIndex(idx)))
126                .collect();
127
128            // Nodes affected by this skin (i.e. joints)
129            let skin_nodes: Vec<NodeIndex> = skin.inner.joints.iter().map(|j| NodeIndex(j.inner.node_idx)).collect();
130
131            let joint_nodes_topological_order: Vec<NodeIndex> = instance
132                .topological_order
133                .iter()
134                .map(|node_idx| NodeIndex(*node_idx))
135                .filter(|node_idx| skin_nodes.contains(node_idx))
136                .collect();
137
138            let skeletons: Vec<SkeletonHandle> = instance
139                .nodes
140                .iter()
141                .flat_map(|node| &node.inner.object)
142                .flat_map(|object| &object.inner.armature)
143                .filter(|armature| armature.skin_index == skin_index.0)
144                .flat_map(|armature| &armature.skeletons)
145                .cloned()
146                .collect();
147
148            skin_data.insert(
149                skin_index,
150                PerSkinData {
151                    node_to_joint_idx,
152                    joint_nodes_topological_order,
153                    skeletons,
154                },
155            );
156        }
157
158        AnimationData {
159            skin_data,
160            animation_skin_usage,
161        }
162    }
163}
164
165/// Helper trait that exposes a generic `lerp` function for various `glam` types
166pub trait Lerp {
167    fn lerp(self, other: Self, t: f32) -> Self;
168}
169impl Lerp for Vec3 {
170    fn lerp(self, other: Self, t: f32) -> Self {
171        self.lerp(other, t)
172    }
173}
174impl Lerp for Quat {
175    fn lerp(self, other: Self, t: f32) -> Self {
176        // Uses Normalized Linear Interpolation (a.k.a. nlerp) as slerp replacement
177        // See: *"Understanding Slerp, Then Not Using It"*
178        // http://number-none.com/product/Understanding%20Slerp,%20Then%20Not%20Using%20It/
179        self.lerp(other, t).normalize()
180    }
181}
182
183/// Samples the data value for an animation channel at a given time. Will
184/// interpolate between the two closest keyframes.
185fn sample_at_time<T: Lerp + Copy>(channel: &AnimationChannel<T>, current_time: f32) -> T {
186    let next_idx = channel
187        .times
188        .iter()
189        .position(|time| *time > current_time)
190        .unwrap_or(channel.times.len() - 1);
191    let prev_idx = next_idx.saturating_sub(1);
192
193    let interp_factor = f32::clamp(
194        (current_time - channel.times[prev_idx]) / (channel.times[next_idx] - channel.times[prev_idx]),
195        0.0,
196        1.0,
197    );
198
199    channel.values[prev_idx].lerp(channel.values[next_idx], interp_factor)
200}
201
202/// Sets the pose of the meshes at the given scene by using the animation at
203/// index `animation_index` at a given `time`. The provided time gets clamped to
204/// the valid range of times for the selected animation.
205pub fn pose_animation_frame(
206    renderer: &Renderer,
207    scene: &LoadedGltfScene,
208    instance: &GltfSceneInstance,
209    animation_data: &AnimationData,
210    animation_index: usize,
211    time: f32,
212) {
213    let animation = &scene.animations[animation_index];
214    let time = time.clamp(0.0, animation.inner.duration);
215
216    for (skin_index, per_skin_data) in &animation_data.skin_data {
217        let skin = &scene.skins[skin_index.0];
218        let inv_bind_mats = &skin.inner.inverse_bind_matrices;
219
220        // The local position of each joint, relative to its parent
221        let mut joint_local_matrices = vec![Mat4::IDENTITY; inv_bind_mats.len()];
222
223        let node_to_joint_idx = &per_skin_data.node_to_joint_idx;
224
225        // Compute each bone's local transformation
226        for (&node_idx, channels) in &animation.inner.channels {
227            // NOTE: If a channel's property is not present, we need to set the
228            // joint at its bind pose for that individual property
229            let local_transform = instance.nodes[node_idx].inner.local_transform;
230            let (bind_scale, bind_rotation, bind_translation) = local_transform.to_scale_rotation_translation();
231
232            let translation = channels
233                .translation
234                .as_ref()
235                .map(|tra| sample_at_time(tra, time))
236                .unwrap_or(bind_translation);
237            let rotation = channels
238                .rotation
239                .as_ref()
240                .map(|rot| sample_at_time(rot, time))
241                .unwrap_or(bind_rotation);
242            let scale = channels
243                .scale
244                .as_ref()
245                .map(|sca| sample_at_time(sca, time))
246                .unwrap_or(bind_scale);
247
248            let matrix = Mat4::from_scale_rotation_translation(scale, rotation, translation);
249            let joint_idx = node_to_joint_idx[&NodeIndex(node_idx)];
250            joint_local_matrices[joint_idx.0] = matrix;
251        }
252
253        let mut global_joint_transforms = vec![Mat4::IDENTITY; inv_bind_mats.len()];
254
255        // Compute bone global transformations
256        for node_idx in &per_skin_data.joint_nodes_topological_order {
257            let node = &instance.nodes[node_idx.0].inner;
258            let joint_idx = node_to_joint_idx[node_idx];
259            if let Some(parent_joint_idx) = node.parent.map(|pi| node_to_joint_idx.get(&NodeIndex(pi))) {
260                // This is guaranteed to be computed because we're iterating
261                // the hierarchy nodes in topological order
262                let parent_transform = parent_joint_idx
263                    .map(|p| global_joint_transforms[p.0])
264                    .unwrap_or(Mat4::IDENTITY);
265                let current_transform = joint_local_matrices[joint_idx.0];
266
267                global_joint_transforms[joint_idx.0] = parent_transform * current_transform;
268            } else {
269                global_joint_transforms[joint_idx.0] = joint_local_matrices[joint_idx.0];
270            }
271        }
272
273        // Set the joint positions in rend3
274        for skeleton in &per_skin_data.skeletons {
275            renderer.set_skeleton_joint_transforms(skeleton, &global_joint_transforms, inv_bind_mats);
276        }
277    }
278}