leafwing_input_manager/
action_diff.rs

//! Serialization-friendly representation of changes to [`ActionState`].
//!
//! These are predominantly intended for use in networked games,
//! where the server needs to know what the players are doing.
//! Such a server wants a compact, semantically meaningful representation of changes to the game
//! state, without needing to know about things like keybindings or input devices.
7
8use bevy::{
9    ecs::{
10        entity::{Entity, MapEntities},
11        event::Event,
12        query::QueryFilter,
13    },
14    math::{Vec2, Vec3},
15    platform::collections::{HashMap, HashSet},
16    prelude::{EntityMapper, EventWriter, Query, Res},
17};
18use serde::{Deserialize, Serialize};
19
20use crate::buttonlike::ButtonValue;
21use crate::{action_state::ActionKindData, prelude::ActionState, Actionlike};
22
23/// Stores presses and releases of buttons without timing information
24///
25/// These are typically accessed using the `Events<ActionDiffEvent>` resource.
26/// Uses a minimal storage format to facilitate transport over the network.
27///
28/// An `ActionState` can be fully reconstructed from a stream of `ActionDiff`.
29#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
30pub enum ActionDiff<A: Actionlike> {
31    /// The action was pressed
32    Pressed {
33        /// The value of the action
34        action: A,
35        /// The new value of the action
36        value: f32,
37    },
38    /// The action was released
39    Released {
40        /// The value of the action
41        action: A,
42    },
43    /// The value of the action changed
44    AxisChanged {
45        /// The value of the action
46        action: A,
47        /// The new value of the action
48        value: f32,
49    },
50    /// The axis pair of the action changed
51    DualAxisChanged {
52        /// The value of the action
53        action: A,
54        /// The new value of the axes
55        axis_pair: Vec2,
56    },
57    /// The axis pair of the action changed
58    TripleAxisChanged {
59        /// The value of the action
60        action: A,
61        /// The new value of the axes
62        axis_triple: Vec3,
63    },
64}
65
66/// Will store an `ActionDiff` as well as what generated it (either an Entity, or nothing if the
67/// input actions are represented by a `Resource`)
68///
69/// These are typically accessed using the `Events<ActionDiffEvent>` resource.
70#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Event)]
71pub struct ActionDiffEvent<A: Actionlike> {
72    /// If some: the entity that has the `ActionState<A>` component
73    /// If none: `ActionState<A>` is a Resource, not a component
74    pub owner: Option<Entity>,
75    /// The `ActionDiff` that was generated
76    pub action_diffs: Vec<ActionDiff<A>>,
77}
78
79/// Implements entity mapping for `ActionDiffEvent`.
80///
81/// This allows the owner entity to be remapped when transferring event diffs
82/// between different ECS worlds (e.g. client and server).
83impl<A: Actionlike> MapEntities for ActionDiffEvent<A> {
84    fn map_entities<M: EntityMapper>(&mut self, entity_mapper: &mut M) {
85        self.owner = self.owner.map(|entity| entity_mapper.get_mapped(entity));
86    }
87}
88
89/// Stores the state of all actions in the current frame.
90///
91/// Inside of the hashmap, [`Entity::PLACEHOLDER`] represents the global / resource state of the action.
92#[derive(Debug, PartialEq, Clone)]
93pub struct SummarizedActionState<A: Actionlike> {
94    button_state_map: HashMap<Entity, HashMap<A, ButtonValue>>,
95    axis_state_map: HashMap<Entity, HashMap<A, f32>>,
96    dual_axis_state_map: HashMap<Entity, HashMap<A, Vec2>>,
97    triple_axis_state_map: HashMap<Entity, HashMap<A, Vec3>>,
98}
99
100impl<A: Actionlike> SummarizedActionState<A> {
101    /// Returns a list of all entities that are contained within this data structure.
102    ///
103    /// This includes the global / resource state, using [`Entity::PLACEHOLDER`].
104    pub fn all_entities(&self) -> HashSet<Entity> {
105        let mut entities = HashSet::default();
106        let button_entities = self.button_state_map.keys();
107        let axis_entities = self.axis_state_map.keys();
108        let dual_axis_entities = self.dual_axis_state_map.keys();
109        let triple_axis_entities = self.triple_axis_state_map.keys();
110
111        entities.extend(button_entities);
112        entities.extend(axis_entities);
113        entities.extend(dual_axis_entities);
114        entities.extend(triple_axis_entities);
115
116        entities
117    }
118
119    /// Captures the raw values for each action in the current frame, for all entities with `ActionState<A>`.
120    pub fn summarize(
121        global_action_state: Option<Res<ActionState<A>>>,
122        action_state_query: Query<(Entity, &ActionState<A>)>,
123    ) -> Self {
124        Self::summarize_filtered(global_action_state, action_state_query)
125    }
126
127    /// Captures the raw values for each action in the current frame, for entities with `ActionState<A>`
128    /// matching the query filter.
129    pub fn summarize_filtered<F: QueryFilter>(
130        global_action_state: Option<Res<ActionState<A>>>,
131        action_state_query: Query<(Entity, &ActionState<A>), F>,
132    ) -> Self {
133        let mut button_state_map = HashMap::default();
134        let mut axis_state_map = HashMap::default();
135        let mut dual_axis_state_map = HashMap::default();
136        let mut triple_axis_state_map = HashMap::default();
137
138        if let Some(global_action_state) = global_action_state {
139            let mut per_entity_button_state = HashMap::default();
140            let mut per_entity_axis_state = HashMap::default();
141            let mut per_entity_dual_axis_state = HashMap::default();
142            let mut per_entity_triple_axis_state = HashMap::default();
143
144            for (action, action_data) in global_action_state.all_action_data() {
145                match &action_data.kind_data {
146                    ActionKindData::Button(button_data) => {
147                        per_entity_button_state
148                            .insert(action.clone(), button_data.to_button_value());
149                    }
150                    ActionKindData::Axis(axis_data) => {
151                        per_entity_axis_state.insert(action.clone(), axis_data.value);
152                    }
153                    ActionKindData::DualAxis(dual_axis_data) => {
154                        per_entity_dual_axis_state.insert(action.clone(), dual_axis_data.pair);
155                    }
156                    ActionKindData::TripleAxis(triple_axis_data) => {
157                        per_entity_triple_axis_state
158                            .insert(action.clone(), triple_axis_data.triple);
159                    }
160                }
161            }
162
163            button_state_map.insert(Entity::PLACEHOLDER, per_entity_button_state);
164            axis_state_map.insert(Entity::PLACEHOLDER, per_entity_axis_state);
165            dual_axis_state_map.insert(Entity::PLACEHOLDER, per_entity_dual_axis_state);
166            triple_axis_state_map.insert(Entity::PLACEHOLDER, per_entity_triple_axis_state);
167        }
168
169        for (entity, action_state) in action_state_query.iter() {
170            let mut per_entity_button_state = HashMap::default();
171            let mut per_entity_axis_state = HashMap::default();
172            let mut per_entity_dual_axis_state = HashMap::default();
173            let mut per_entity_triple_axis_state = HashMap::default();
174
175            for (action, action_data) in action_state.all_action_data() {
176                match &action_data.kind_data {
177                    ActionKindData::Button(button_data) => {
178                        per_entity_button_state
179                            .insert(action.clone(), button_data.to_button_value());
180                    }
181                    ActionKindData::Axis(axis_data) => {
182                        per_entity_axis_state.insert(action.clone(), axis_data.value);
183                    }
184                    ActionKindData::DualAxis(dual_axis_data) => {
185                        per_entity_dual_axis_state.insert(action.clone(), dual_axis_data.pair);
186                    }
187                    ActionKindData::TripleAxis(triple_axis_data) => {
188                        per_entity_triple_axis_state
189                            .insert(action.clone(), triple_axis_data.triple);
190                    }
191                }
192            }
193
194            button_state_map.insert(entity, per_entity_button_state);
195            axis_state_map.insert(entity, per_entity_axis_state);
196            dual_axis_state_map.insert(entity, per_entity_dual_axis_state);
197            triple_axis_state_map.insert(entity, per_entity_triple_axis_state);
198        }
199
200        Self {
201            button_state_map,
202            axis_state_map,
203            dual_axis_state_map,
204            triple_axis_state_map,
205        }
206    }
207
208    /// Generates an [`ActionDiff`] for button data,
209    /// if the button has changed state.
210    ///
211    ///
212    /// Previous values will be treated as default if they were not present.
213    pub fn button_diff(
214        action: A,
215        previous_button: Option<ButtonValue>,
216        current_button: Option<ButtonValue>,
217    ) -> Option<ActionDiff<A>> {
218        let previous_button = previous_button.unwrap_or_default();
219        let current_button = current_button?;
220
221        (previous_button != current_button).then(|| {
222            if current_button.pressed {
223                ActionDiff::Pressed {
224                    action,
225                    value: current_button.value,
226                }
227            } else {
228                ActionDiff::Released { action }
229            }
230        })
231    }
232
233    /// Generates an [`ActionDiff`] for axis data,
234    /// if the axis has changed state.
235    ///
236    /// Previous values will be treated as default if they were not present.
237    pub fn axis_diff(
238        action: A,
239        previous_axis: Option<f32>,
240        current_axis: Option<f32>,
241    ) -> Option<ActionDiff<A>> {
242        let previous_axis = previous_axis.unwrap_or_default();
243        let current_axis = current_axis?;
244
245        (previous_axis != current_axis).then(|| ActionDiff::AxisChanged {
246            action,
247            value: current_axis,
248        })
249    }
250
251    /// Generates an [`ActionDiff`] for dual axis data,
252    /// if the dual axis has changed state.
253    pub fn dual_axis_diff(
254        action: A,
255        previous_dual_axis: Option<Vec2>,
256        current_dual_axis: Option<Vec2>,
257    ) -> Option<ActionDiff<A>> {
258        let previous_dual_axis = previous_dual_axis.unwrap_or_default();
259        let current_dual_axis = current_dual_axis?;
260
261        (previous_dual_axis != current_dual_axis).then(|| ActionDiff::DualAxisChanged {
262            action,
263            axis_pair: current_dual_axis,
264        })
265    }
266
267    /// Generates an [`ActionDiff`] for triple axis data,
268    /// if the triple axis has changed state.
269    pub fn triple_axis_diff(
270        action: A,
271        previous_triple_axis: Option<Vec3>,
272        current_triple_axis: Option<Vec3>,
273    ) -> Option<ActionDiff<A>> {
274        let previous_triple_axis = previous_triple_axis.unwrap_or_default();
275        let current_triple_axis = current_triple_axis?;
276
277        (previous_triple_axis != current_triple_axis).then(|| ActionDiff::TripleAxisChanged {
278            action,
279            axis_triple: current_triple_axis,
280        })
281    }
282
283    /// Generates all [`ActionDiff`]s for a single entity.
284    pub fn entity_diffs(&self, entity: &Entity, previous: &Self) -> Vec<ActionDiff<A>> {
285        let mut action_diffs = Vec::new();
286
287        if let Some(current_button_state) = self.button_state_map.get(entity) {
288            let previous_button_state = previous.button_state_map.get(entity);
289            for (action, current_button) in current_button_state {
290                let previous_button = previous_button_state
291                    .and_then(|previous_button_state| previous_button_state.get(action))
292                    .copied();
293
294                if let Some(diff) =
295                    Self::button_diff(action.clone(), previous_button, Some(*current_button))
296                {
297                    action_diffs.push(diff);
298                }
299            }
300        }
301
302        if let Some(current_axis_state) = self.axis_state_map.get(entity) {
303            let previous_axis_state = previous.axis_state_map.get(entity);
304            for (action, current_axis) in current_axis_state {
305                let previous_axis = previous_axis_state
306                    .and_then(|previous_axis_state| previous_axis_state.get(action))
307                    .copied();
308
309                if let Some(diff) =
310                    Self::axis_diff(action.clone(), previous_axis, Some(*current_axis))
311                {
312                    action_diffs.push(diff);
313                }
314            }
315        }
316
317        if let Some(current_dual_axis_state) = self.dual_axis_state_map.get(entity) {
318            let previous_dual_axis_state = previous.dual_axis_state_map.get(entity);
319            for (action, current_dual_axis) in current_dual_axis_state {
320                let previous_dual_axis = previous_dual_axis_state
321                    .and_then(|previous_dual_axis_state| previous_dual_axis_state.get(action))
322                    .copied();
323
324                if let Some(diff) = Self::dual_axis_diff(
325                    action.clone(),
326                    previous_dual_axis,
327                    Some(*current_dual_axis),
328                ) {
329                    action_diffs.push(diff);
330                }
331            }
332        }
333
334        if let Some(current_triple_axis_state) = self.triple_axis_state_map.get(entity) {
335            let previous_triple_axis_state = previous.triple_axis_state_map.get(entity);
336            for (action, current_triple_axis) in current_triple_axis_state {
337                let previous_triple_axis = previous_triple_axis_state
338                    .and_then(|previous_triple_axis_state| previous_triple_axis_state.get(action))
339                    .copied();
340
341                if let Some(diff) = Self::triple_axis_diff(
342                    action.clone(),
343                    previous_triple_axis,
344                    Some(*current_triple_axis),
345                ) {
346                    action_diffs.push(diff);
347                }
348            }
349        }
350
351        action_diffs
352    }
353
354    /// Compares the current frame to the previous frame, generates [`ActionDiff`]s and then sends them as batched [`ActionDiffEvent`]s.
355    pub fn send_diffs(&self, previous: &Self, writer: &mut EventWriter<ActionDiffEvent<A>>) {
356        for entity in self.all_entities() {
357            let owner = (entity != Entity::PLACEHOLDER).then_some(entity);
358
359            let action_diffs = self.entity_diffs(&entity, previous);
360
361            if !action_diffs.is_empty() {
362                writer.write(ActionDiffEvent {
363                    owner,
364                    action_diffs,
365                });
366            }
367        }
368    }
369}
370
371// Manual impl due to A not being bounded by Default messing with the derive
372impl<A: Actionlike> Default for SummarizedActionState<A> {
373    fn default() -> Self {
374        Self {
375            button_state_map: Default::default(),
376            axis_state_map: Default::default(),
377            dual_axis_state_map: Default::default(),
378            triple_axis_state_map: Default::default(),
379        }
380    }
381}
382
#[cfg(test)]
mod tests {
    use crate as leafwing_input_manager;

    use super::*;
    use crate::buttonlike::ButtonValue;
    use bevy::{ecs::system::SystemState, prelude::*};

    #[derive(Actionlike, Debug, Clone, Copy, PartialEq, Eq, Hash, Reflect)]
    enum TestAction {
        Button,
        #[actionlike(Axis)]
        Axis,
        #[actionlike(DualAxis)]
        DualAxis,
        #[actionlike(TripleAxis)]
        TripleAxis,
    }

    /// An action state in which every action has a non-default value.
    fn test_action_state() -> ActionState<TestAction> {
        let mut action_state = ActionState::default();
        action_state.press(&TestAction::Button);
        action_state.set_value(&TestAction::Axis, 0.3);
        action_state.set_axis_pair(&TestAction::DualAxis, Vec2::new(0.5, 0.7));
        action_state.set_axis_triple(&TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        action_state
    }

    /// Marker for entities that the filtered summary should skip.
    #[derive(Component)]
    struct NotSummarized;

    /// The summary expected for [`test_action_state`], stored under `entity`.
    fn expected_summary(entity: Entity) -> SummarizedActionState<TestAction> {
        let mut button_state_map = HashMap::default();
        let mut axis_state_map = HashMap::default();
        let mut dual_axis_state_map = HashMap::default();
        let mut triple_axis_state_map = HashMap::default();

        let mut per_entity_button_state = HashMap::default();
        per_entity_button_state.insert(TestAction::Button, ButtonValue::from_pressed(true));
        button_state_map.insert(entity, per_entity_button_state);

        let mut per_entity_axis_state = HashMap::default();
        per_entity_axis_state.insert(TestAction::Axis, 0.3);
        axis_state_map.insert(entity, per_entity_axis_state);

        let mut per_entity_dual_axis_state = HashMap::default();
        per_entity_dual_axis_state.insert(TestAction::DualAxis, Vec2::new(0.5, 0.7));
        dual_axis_state_map.insert(entity, per_entity_dual_axis_state);

        let mut per_entity_triple_axis_state = HashMap::default();
        per_entity_triple_axis_state.insert(TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        triple_axis_state_map.insert(entity, per_entity_triple_axis_state);

        SummarizedActionState {
            button_state_map,
            axis_state_map,
            dual_axis_state_map,
            triple_axis_state_map,
        }
    }

    /// Summarizes every `ActionState<TestAction>` in `world`, diffs that summary against
    /// `previous`, and returns the events that were written.
    fn collect_diff_events(
        world: &mut World,
        previous: &SummarizedActionState<TestAction>,
    ) -> Vec<ActionDiffEvent<TestAction>> {
        world.init_resource::<Events<ActionDiffEvent<TestAction>>>();

        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
            EventWriter<ActionDiffEvent<TestAction>>,
        )> = SystemState::new(world);
        let (global_action_state, action_state_query, mut action_diff_writer) =
            system_state.get_mut(world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);
        summarized.send_diffs(previous, &mut action_diff_writer);

        let mut system_state: SystemState<EventReader<ActionDiffEvent<TestAction>>> =
            SystemState::new(world);
        let mut event_reader = system_state.get_mut(world);
        event_reader.read().cloned().collect()
    }

    #[test]
    fn summarize_from_resource() {
        let mut world = World::new();
        world.insert_resource(test_action_state());
        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query) = system_state.get(&world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);

        // Resources use the placeholder entity
        assert_eq!(summarized, expected_summary(Entity::PLACEHOLDER));
    }

    #[test]
    fn summarize_from_component() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();
        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query) = system_state.get(&world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);

        // Components use the entity
        assert_eq!(summarized, expected_summary(entity));
    }

    #[test]
    fn summarize_filtered_entities_from_component() {
        // Spawn two entities, one to be summarized and one to be filtered out
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();
        world.spawn((test_action_state(), NotSummarized));

        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>), Without<NotSummarized>>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query) = system_state.get(&world);
        let summarized =
            SummarizedActionState::summarize_filtered(global_action_state, action_state_query);

        // Check that only the entity without NotSummarized was summarized
        assert_eq!(summarized, expected_summary(entity));
    }

    #[test]
    fn diffs_are_sent() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();

        let action_diff_events =
            collect_diff_events(&mut world, &SummarizedActionState::default());

        assert_eq!(action_diff_events.len(), 1);
        let action_diff_event = &action_diff_events[0];
        assert_eq!(action_diff_event.owner, Some(entity));
        assert_eq!(action_diff_event.action_diffs.len(), 4);
    }

    #[test]
    fn resource_diffs_have_no_owner() {
        let mut world = World::new();
        world.insert_resource(test_action_state());

        let action_diff_events =
            collect_diff_events(&mut world, &SummarizedActionState::default());

        // The resource is summarized under the placeholder entity, which maps back to `None`.
        assert_eq!(action_diff_events.len(), 1);
        let action_diff_event = &action_diff_events[0];
        assert_eq!(action_diff_event.owner, None);
        assert_eq!(action_diff_event.action_diffs.len(), 4);
    }

    #[test]
    fn no_diffs_when_unchanged() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();

        // Diffing against an identical previous summary must not write any events.
        let action_diff_events = collect_diff_events(&mut world, &expected_summary(entity));

        assert!(action_diff_events.is_empty());
    }
}