Skip to main content

leafwing_input_manager/
action_diff.rs

1//! Serialization-friendly representation of changes to [`ActionState`].
2//!
3//! These are predominantly intended for use in networked games,
4//! where the server needs to know what the players are doing.
//! Such a server wants a compact, semantically meaningful representation of changes to the game
//! state, without needing to know about things like keybindings or input devices.
7
8use bevy::{
9    ecs::{
10        entity::{Entity, MapEntities},
11        message::Message,
12        query::QueryFilter,
13    },
14    math::{Vec2, Vec3},
15    platform::collections::{HashMap, HashSet},
16    prelude::{EntityMapper, MessageWriter, Query},
17};
18use serde::{Deserialize, Serialize};
19
20use crate::buttonlike::ButtonValue;
21use crate::{Actionlike, action_state::ActionKindData, prelude::ActionState};
22
23/// Stores presses and releases of buttons without timing information
24///
25/// These are typically accessed using the `Messages<ActionDiffMessage>` resource.
26/// Uses a minimal storage format to facilitate transport over the network.
27///
28/// An `ActionState` can be fully reconstructed from a stream of `ActionDiff`.
29#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
30pub enum ActionDiff<A: Actionlike> {
31    /// The action was pressed
32    Pressed {
33        /// The value of the action
34        action: A,
35        /// The new value of the action
36        value: f32,
37    },
38    /// The action was released
39    Released {
40        /// The value of the action
41        action: A,
42    },
43    /// The value of the action changed
44    AxisChanged {
45        /// The value of the action
46        action: A,
47        /// The new value of the action
48        value: f32,
49    },
50    /// The axis pair of the action changed
51    DualAxisChanged {
52        /// The value of the action
53        action: A,
54        /// The new value of the axes
55        axis_pair: Vec2,
56    },
57    /// The axis pair of the action changed
58    TripleAxisChanged {
59        /// The value of the action
60        action: A,
61        /// The new value of the axes
62        axis_triple: Vec3,
63    },
64}
65
66/// Will store an `ActionDiff` as well as the entity that generated it.
67///
68/// These are typically accessed using the `Messages<ActionDiffMessage>` resource.
69#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Message)]
70pub struct ActionDiffMessage<A: Actionlike> {
71    /// The entity that has the `ActionState<A>` component
72    pub owner: Entity,
73    /// The `ActionDiff` that was generated
74    pub action_diffs: Vec<ActionDiff<A>>,
75}
76
77/// Implements entity mapping for [`ActionDiffMessage`].
78///
79/// This allows the owner entity to be remapped when transferring message diffs
80/// between different ECS worlds (e.g. client and server).
81impl<A: Actionlike> MapEntities for ActionDiffMessage<A> {
82    fn map_entities<M: EntityMapper>(&mut self, entity_mapper: &mut M) {
83        self.owner = entity_mapper.get_mapped(self.owner);
84    }
85}
86
87/// Stores the state of all actions in the current frame, keyed by the entity that owns each
88/// [`ActionState`].
89#[derive(Debug, PartialEq, Clone)]
90pub struct SummarizedActionState<A: Actionlike> {
91    button_state_map: HashMap<Entity, HashMap<A, ButtonValue>>,
92    axis_state_map: HashMap<Entity, HashMap<A, f32>>,
93    dual_axis_state_map: HashMap<Entity, HashMap<A, Vec2>>,
94    triple_axis_state_map: HashMap<Entity, HashMap<A, Vec3>>,
95}
96
97impl<A: Actionlike> SummarizedActionState<A> {
98    /// Returns a list of all entities that are contained within this data structure.
99    pub fn all_entities(&self) -> HashSet<Entity> {
100        let mut entities = HashSet::default();
101        let button_entities = self.button_state_map.keys();
102        let axis_entities = self.axis_state_map.keys();
103        let dual_axis_entities = self.dual_axis_state_map.keys();
104        let triple_axis_entities = self.triple_axis_state_map.keys();
105
106        entities.extend(button_entities);
107        entities.extend(axis_entities);
108        entities.extend(dual_axis_entities);
109        entities.extend(triple_axis_entities);
110
111        entities
112    }
113
114    /// Captures the raw values for each action in the current frame, for all entities with `ActionState<A>`.
115    pub fn summarize(action_state_query: Query<(Entity, &ActionState<A>)>) -> Self {
116        Self::summarize_filtered(action_state_query)
117    }
118
119    /// Captures the raw values for each action in the current frame, for entities with `ActionState<A>`
120    /// matching the query filter.
121    pub fn summarize_filtered<F: QueryFilter>(
122        action_state_query: Query<(Entity, &ActionState<A>), F>,
123    ) -> Self {
124        let mut button_state_map = HashMap::default();
125        let mut axis_state_map = HashMap::default();
126        let mut dual_axis_state_map = HashMap::default();
127        let mut triple_axis_state_map = HashMap::default();
128
129        for (entity, action_state) in action_state_query
130            .iter()
131            .filter(|(_, action_state)| !action_state.disabled())
132        {
133            let mut per_entity_button_state = HashMap::default();
134            let mut per_entity_axis_state = HashMap::default();
135            let mut per_entity_dual_axis_state = HashMap::default();
136            let mut per_entity_triple_axis_state = HashMap::default();
137
138            for (action, action_data) in action_state
139                .all_action_data()
140                .iter()
141                .filter(|(_, action_data)| !action_data.disabled)
142            {
143                match &action_data.kind_data {
144                    ActionKindData::Button(button_data) => {
145                        per_entity_button_state
146                            .insert(action.clone(), button_data.to_button_value());
147                    }
148                    ActionKindData::Axis(axis_data) => {
149                        per_entity_axis_state.insert(action.clone(), axis_data.value);
150                    }
151                    ActionKindData::DualAxis(dual_axis_data) => {
152                        per_entity_dual_axis_state.insert(action.clone(), dual_axis_data.pair);
153                    }
154                    ActionKindData::TripleAxis(triple_axis_data) => {
155                        per_entity_triple_axis_state
156                            .insert(action.clone(), triple_axis_data.triple);
157                    }
158                }
159            }
160
161            button_state_map.insert(entity, per_entity_button_state);
162            axis_state_map.insert(entity, per_entity_axis_state);
163            dual_axis_state_map.insert(entity, per_entity_dual_axis_state);
164            triple_axis_state_map.insert(entity, per_entity_triple_axis_state);
165        }
166
167        Self {
168            button_state_map,
169            axis_state_map,
170            dual_axis_state_map,
171            triple_axis_state_map,
172        }
173    }
174
175    /// Generates an [`ActionDiff`] for button data,
176    /// if the button has changed state.
177    ///
178    ///
179    /// Previous values will be treated as default if they were not present.
180    pub fn button_diff(
181        action: A,
182        previous_button: Option<ButtonValue>,
183        current_button: Option<ButtonValue>,
184    ) -> Option<ActionDiff<A>> {
185        let previous_button = previous_button.unwrap_or_default();
186        let current_button = current_button?;
187
188        (previous_button != current_button).then(|| {
189            if current_button.pressed {
190                ActionDiff::Pressed {
191                    action,
192                    value: current_button.value,
193                }
194            } else {
195                ActionDiff::Released { action }
196            }
197        })
198    }
199
200    /// Generates an [`ActionDiff`] for axis data,
201    /// if the axis has changed state.
202    ///
203    /// Previous values will be treated as default if they were not present.
204    pub fn axis_diff(
205        action: A,
206        previous_axis: Option<f32>,
207        current_axis: Option<f32>,
208    ) -> Option<ActionDiff<A>> {
209        let previous_axis = previous_axis.unwrap_or_default();
210        let current_axis = current_axis?;
211
212        (previous_axis != current_axis).then(|| ActionDiff::AxisChanged {
213            action,
214            value: current_axis,
215        })
216    }
217
218    /// Generates an [`ActionDiff`] for dual axis data,
219    /// if the dual axis has changed state.
220    pub fn dual_axis_diff(
221        action: A,
222        previous_dual_axis: Option<Vec2>,
223        current_dual_axis: Option<Vec2>,
224    ) -> Option<ActionDiff<A>> {
225        let previous_dual_axis = previous_dual_axis.unwrap_or_default();
226        let current_dual_axis = current_dual_axis?;
227
228        (previous_dual_axis != current_dual_axis).then(|| ActionDiff::DualAxisChanged {
229            action,
230            axis_pair: current_dual_axis,
231        })
232    }
233
234    /// Generates an [`ActionDiff`] for triple axis data,
235    /// if the triple axis has changed state.
236    pub fn triple_axis_diff(
237        action: A,
238        previous_triple_axis: Option<Vec3>,
239        current_triple_axis: Option<Vec3>,
240    ) -> Option<ActionDiff<A>> {
241        let previous_triple_axis = previous_triple_axis.unwrap_or_default();
242        let current_triple_axis = current_triple_axis?;
243
244        (previous_triple_axis != current_triple_axis).then(|| ActionDiff::TripleAxisChanged {
245            action,
246            axis_triple: current_triple_axis,
247        })
248    }
249
250    /// Generates all [`ActionDiff`]s for a single entity.
251    pub fn entity_diffs(&self, entity: &Entity, previous: &Self) -> Vec<ActionDiff<A>> {
252        let mut action_diffs = Vec::new();
253
254        if let Some(current_button_state) = self.button_state_map.get(entity) {
255            let previous_button_state = previous.button_state_map.get(entity);
256            for (action, current_button) in current_button_state {
257                let previous_button = previous_button_state
258                    .and_then(|previous_button_state| previous_button_state.get(action))
259                    .copied();
260
261                if let Some(diff) =
262                    Self::button_diff(action.clone(), previous_button, Some(*current_button))
263                {
264                    action_diffs.push(diff);
265                }
266            }
267        }
268
269        if let Some(current_axis_state) = self.axis_state_map.get(entity) {
270            let previous_axis_state = previous.axis_state_map.get(entity);
271            for (action, current_axis) in current_axis_state {
272                let previous_axis = previous_axis_state
273                    .and_then(|previous_axis_state| previous_axis_state.get(action))
274                    .copied();
275
276                if let Some(diff) =
277                    Self::axis_diff(action.clone(), previous_axis, Some(*current_axis))
278                {
279                    action_diffs.push(diff);
280                }
281            }
282        }
283
284        if let Some(current_dual_axis_state) = self.dual_axis_state_map.get(entity) {
285            let previous_dual_axis_state = previous.dual_axis_state_map.get(entity);
286            for (action, current_dual_axis) in current_dual_axis_state {
287                let previous_dual_axis = previous_dual_axis_state
288                    .and_then(|previous_dual_axis_state| previous_dual_axis_state.get(action))
289                    .copied();
290
291                if let Some(diff) = Self::dual_axis_diff(
292                    action.clone(),
293                    previous_dual_axis,
294                    Some(*current_dual_axis),
295                ) {
296                    action_diffs.push(diff);
297                }
298            }
299        }
300
301        if let Some(current_triple_axis_state) = self.triple_axis_state_map.get(entity) {
302            let previous_triple_axis_state = previous.triple_axis_state_map.get(entity);
303            for (action, current_triple_axis) in current_triple_axis_state {
304                let previous_triple_axis = previous_triple_axis_state
305                    .and_then(|previous_triple_axis_state| previous_triple_axis_state.get(action))
306                    .copied();
307
308                if let Some(diff) = Self::triple_axis_diff(
309                    action.clone(),
310                    previous_triple_axis,
311                    Some(*current_triple_axis),
312                ) {
313                    action_diffs.push(diff);
314                }
315            }
316        }
317
318        action_diffs
319    }
320
321    /// Compares the current frame to the previous frame, generates [`ActionDiff`]s and then sends them as batched [`ActionDiffMessage`]s.
322    pub fn send_diffs(&self, previous: &Self, writer: &mut MessageWriter<ActionDiffMessage<A>>) {
323        for entity in self.all_entities() {
324            let action_diffs = self.entity_diffs(&entity, previous);
325
326            if !action_diffs.is_empty() {
327                writer.write(ActionDiffMessage {
328                    owner: entity,
329                    action_diffs,
330                });
331            }
332        }
333    }
334}
335
336// Manual impl due to A not being bounded by Default messing with the derive
337impl<A: Actionlike> Default for SummarizedActionState<A> {
338    fn default() -> Self {
339        Self {
340            button_state_map: Default::default(),
341            axis_state_map: Default::default(),
342            dual_axis_state_map: Default::default(),
343            triple_axis_state_map: Default::default(),
344        }
345    }
346}
347
#[cfg(test)]
mod tests {
    use crate as leafwing_input_manager;

    use super::*;
    use crate::buttonlike::ButtonValue;
    use bevy::{ecs::system::SystemState, prelude::*};

    #[derive(Actionlike, Debug, Clone, Copy, PartialEq, Eq, Hash, Reflect)]
    enum TestAction {
        Button,
        #[actionlike(Axis)]
        Axis,
        #[actionlike(DualAxis)]
        DualAxis,
        #[actionlike(TripleAxis)]
        TripleAxis,
    }

    /// An action state where every test action has a non-default value.
    fn test_action_state() -> ActionState<TestAction> {
        let mut action_state = ActionState::default();
        action_state.press(&TestAction::Button);
        action_state.set_value(&TestAction::Axis, 0.3);
        action_state.set_axis_pair(&TestAction::DualAxis, Vec2::new(0.5, 0.7));
        action_state.set_axis_triple(&TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        action_state
    }

    #[derive(Component)]
    struct NotSummarized;

    /// The summary expected for a single entity holding `test_action_state()`.
    fn expected_summary(entity: Entity) -> SummarizedActionState<TestAction> {
        let mut button_state_map = HashMap::default();
        let mut axis_state_map = HashMap::default();
        let mut dual_axis_state_map = HashMap::default();
        let mut triple_axis_state_map = HashMap::default();

        let mut global_button_state = HashMap::default();
        global_button_state.insert(TestAction::Button, ButtonValue::from_pressed(true));
        button_state_map.insert(entity, global_button_state);

        let mut global_axis_state = HashMap::default();
        global_axis_state.insert(TestAction::Axis, 0.3);
        axis_state_map.insert(entity, global_axis_state);

        let mut global_dual_axis_state = HashMap::default();
        global_dual_axis_state.insert(TestAction::DualAxis, Vec2::new(0.5, 0.7));
        dual_axis_state_map.insert(entity, global_dual_axis_state);

        let mut global_triple_axis_state = HashMap::default();
        global_triple_axis_state.insert(TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        triple_axis_state_map.insert(entity, global_triple_axis_state);

        SummarizedActionState {
            button_state_map,
            axis_state_map,
            dual_axis_state_map,
            triple_axis_state_map,
        }
    }

    #[test]
    fn summarize_from_component() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();
        let mut system_state: SystemState<Query<(Entity, &ActionState<TestAction>)>> =
            SystemState::new(&mut world);
        let action_state_query = system_state.get(&world).unwrap();
        let summarized = SummarizedActionState::summarize(action_state_query);

        // Components use the entity
        assert_eq!(summarized, expected_summary(entity));
    }

    #[test]
    fn summarize_filtered_entities_from_component() {
        // Spawn two entities, one to be summarized and one to be filtered out
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();
        world.spawn((test_action_state(), NotSummarized));

        let mut system_state: SystemState<
            Query<(Entity, &ActionState<TestAction>), Without<NotSummarized>>,
        > = SystemState::new(&mut world);
        let action_state_query = system_state.get(&world).unwrap();
        let summarized = SummarizedActionState::summarize_filtered(action_state_query);

        // Check that only the entity without NotSummarized was summarized
        assert_eq!(summarized, expected_summary(entity));
    }

    #[test]
    fn diffs_are_sent() {
        let mut world = World::new();
        world.init_resource::<Messages<ActionDiffMessage<TestAction>>>();

        let entity = world.spawn(test_action_state()).id();
        let mut system_state: SystemState<(
            Query<(Entity, &ActionState<TestAction>)>,
            MessageWriter<ActionDiffMessage<TestAction>>,
        )> = SystemState::new(&mut world);
        let (action_state_query, mut action_diff_writer) =
            system_state.get_mut(&mut world).unwrap();
        let summarized = SummarizedActionState::summarize(action_state_query);

        let previous = SummarizedActionState::default();
        summarized.send_diffs(&previous, &mut action_diff_writer);

        let mut system_state: SystemState<MessageReader<ActionDiffMessage<TestAction>>> =
            SystemState::new(&mut world);
        let mut message_reader = system_state.get_mut(&mut world).unwrap();
        let action_diff_messages = message_reader.read().collect::<Vec<_>>();

        assert_eq!(action_diff_messages.len(), 1);
        let action_diff_message = action_diff_messages[0];
        assert_eq!(action_diff_message.owner, entity);
        assert_eq!(action_diff_message.action_diffs.len(), 4);
    }

    #[test]
    fn no_diffs_when_unchanged() {
        let mut world = World::new();
        let entity = world.spawn_empty().id();
        let summary = expected_summary(entity);

        assert!(summary.entity_diffs(&entity, &summary).is_empty());
    }

    #[test]
    fn releasing_a_button_produces_released_diff() {
        let diff = SummarizedActionState::button_diff(
            TestAction::Button,
            Some(ButtonValue::from_pressed(true)),
            Some(ButtonValue::from_pressed(false)),
        );

        assert_eq!(
            diff,
            Some(ActionDiff::Released {
                action: TestAction::Button
            })
        );
    }

    /// `test_action_state()` with the whole `ActionState` disabled.
    fn test_action_state_disabled() -> ActionState<TestAction> {
        let mut action_state = test_action_state();
        action_state.disable();
        action_state
    }

    fn expected_summary_when_disabled() -> SummarizedActionState<TestAction> {
        SummarizedActionState::default()
    }

    #[test]
    fn summarize_filtered_from_disabled_component() {
        let mut world = World::new();
        // Deliberately *without* `NotSummarized`: the entity passes the query filter,
        // so only the disabled check can exclude it.
        world.spawn(test_action_state_disabled());

        let mut system_state: SystemState<
            Query<(Entity, &ActionState<TestAction>), Without<NotSummarized>>,
        > = SystemState::new(&mut world);
        let action_state_query = system_state.get(&world).unwrap();
        let summarized = SummarizedActionState::summarize_filtered(action_state_query);

        // A disabled `ActionState` contributes nothing to the summary
        assert_eq!(summarized, expected_summary_when_disabled());
    }

    /// `test_action_state()` with only the button action disabled.
    fn test_action_state_disabled_action() -> ActionState<TestAction> {
        let mut action_state = test_action_state();
        action_state.disable_action(&TestAction::Button);
        action_state
    }

    fn expected_summary_with_disabled_action(entity: Entity) -> SummarizedActionState<TestAction> {
        let mut summary = expected_summary(entity);
        // The entity keeps a button map, but the disabled button is not recorded in it.
        summary.button_state_map.insert(entity, HashMap::default());
        summary
    }

    #[test]
    fn summarize_filtered_entities_from_component_disabled_action() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state_disabled_action()).id();

        let mut system_state: SystemState<
            Query<(Entity, &ActionState<TestAction>), Without<NotSummarized>>,
        > = SystemState::new(&mut world);
        let action_state_query = system_state.get(&world).unwrap();
        let summarized = SummarizedActionState::summarize_filtered(action_state_query);

        // The disabled button action is omitted; all other actions are still summarized
        assert_eq!(summarized, expected_summary_with_disabled_action(entity));
    }
}
560}