1use bevy::{
9 ecs::{
10 entity::{Entity, MapEntities},
11 event::Event,
12 query::QueryFilter,
13 },
14 math::{Vec2, Vec3},
15 platform::collections::{HashMap, HashSet},
16 prelude::{EntityMapper, EventWriter, Query, Res},
17};
18use serde::{Deserialize, Serialize};
19
20use crate::buttonlike::ButtonValue;
21use crate::{action_state::ActionKindData, prelude::ActionState, Actionlike};
22
/// Describes how the state of a single action changed between two
/// [`SummarizedActionState`] snapshots.
///
/// Produced by [`SummarizedActionState::entity_diffs`]. Derives `Serialize`/`Deserialize`,
/// so it can be stored or transmitted with serde.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ActionDiff<A: Actionlike> {
    /// A button-like action changed and is now pressed.
    ///
    /// This is also emitted when an already-pressed button's value changes.
    Pressed {
        /// The action that changed.
        action: A,
        /// The button's value after the change.
        value: f32,
    },
    /// A button-like action changed and is now not pressed.
    Released {
        /// The action that changed.
        action: A,
    },
    /// A single-axis action's value changed.
    AxisChanged {
        /// The action that changed.
        action: A,
        /// The new axis value.
        value: f32,
    },
    /// A dual-axis action's value changed.
    DualAxisChanged {
        /// The action that changed.
        action: A,
        /// The new pair of axis values.
        axis_pair: Vec2,
    },
    /// A triple-axis action's value changed.
    TripleAxisChanged {
        /// The action that changed.
        action: A,
        /// The new triple of axis values.
        axis_triple: Vec3,
    },
}
65
/// A Bevy event carrying every [`ActionDiff`] detected for one owner of an [`ActionState`].
///
/// [`SummarizedActionState::send_diffs`] writes one of these per owner with at least one diff.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Event)]
pub struct ActionDiffEvent<A: Actionlike> {
    /// The entity whose [`ActionState`] component changed.
    ///
    /// `None` when the diffs describe the global [`ActionState`] resource. `send_diffs` maps the
    /// internal `Entity::PLACEHOLDER` key to `None`.
    pub owner: Option<Entity>,
    /// The individual changes detected for `owner`.
    pub action_diffs: Vec<ActionDiff<A>>,
}
78
79impl<A: Actionlike> MapEntities for ActionDiffEvent<A> {
84 fn map_entities<M: EntityMapper>(&mut self, entity_mapper: &mut M) {
85 self.owner = self.owner.map(|entity| entity_mapper.get_mapped(entity));
86 }
87}
88
/// A snapshot of the values of every action, for every summarized owner, split by action kind.
///
/// Two snapshots can be compared with [`SummarizedActionState::entity_diffs`] or
/// [`SummarizedActionState::send_diffs`] to find what changed. The global [`ActionState`]
/// resource, when summarized, is stored under the key [`Entity::PLACEHOLDER`].
#[derive(Debug, PartialEq, Clone)]
pub struct SummarizedActionState<A: Actionlike> {
    /// Button values, keyed first by owning entity and then by action.
    button_state_map: HashMap<Entity, HashMap<A, ButtonValue>>,
    /// Single-axis values, keyed first by owning entity and then by action.
    axis_state_map: HashMap<Entity, HashMap<A, f32>>,
    /// Dual-axis values, keyed first by owning entity and then by action.
    dual_axis_state_map: HashMap<Entity, HashMap<A, Vec2>>,
    /// Triple-axis values, keyed first by owning entity and then by action.
    triple_axis_state_map: HashMap<Entity, HashMap<A, Vec3>>,
}
99
100impl<A: Actionlike> SummarizedActionState<A> {
101 pub fn all_entities(&self) -> HashSet<Entity> {
105 let mut entities = HashSet::default();
106 let button_entities = self.button_state_map.keys();
107 let axis_entities = self.axis_state_map.keys();
108 let dual_axis_entities = self.dual_axis_state_map.keys();
109 let triple_axis_entities = self.triple_axis_state_map.keys();
110
111 entities.extend(button_entities);
112 entities.extend(axis_entities);
113 entities.extend(dual_axis_entities);
114 entities.extend(triple_axis_entities);
115
116 entities
117 }
118
119 pub fn summarize(
121 global_action_state: Option<Res<ActionState<A>>>,
122 action_state_query: Query<(Entity, &ActionState<A>)>,
123 ) -> Self {
124 Self::summarize_filtered(global_action_state, action_state_query)
125 }
126
127 pub fn summarize_filtered<F: QueryFilter>(
130 global_action_state: Option<Res<ActionState<A>>>,
131 action_state_query: Query<(Entity, &ActionState<A>), F>,
132 ) -> Self {
133 let mut button_state_map = HashMap::default();
134 let mut axis_state_map = HashMap::default();
135 let mut dual_axis_state_map = HashMap::default();
136 let mut triple_axis_state_map = HashMap::default();
137
138 if let Some(global_action_state) = global_action_state {
139 let mut per_entity_button_state = HashMap::default();
140 let mut per_entity_axis_state = HashMap::default();
141 let mut per_entity_dual_axis_state = HashMap::default();
142 let mut per_entity_triple_axis_state = HashMap::default();
143
144 for (action, action_data) in global_action_state.all_action_data() {
145 match &action_data.kind_data {
146 ActionKindData::Button(button_data) => {
147 per_entity_button_state
148 .insert(action.clone(), button_data.to_button_value());
149 }
150 ActionKindData::Axis(axis_data) => {
151 per_entity_axis_state.insert(action.clone(), axis_data.value);
152 }
153 ActionKindData::DualAxis(dual_axis_data) => {
154 per_entity_dual_axis_state.insert(action.clone(), dual_axis_data.pair);
155 }
156 ActionKindData::TripleAxis(triple_axis_data) => {
157 per_entity_triple_axis_state
158 .insert(action.clone(), triple_axis_data.triple);
159 }
160 }
161 }
162
163 button_state_map.insert(Entity::PLACEHOLDER, per_entity_button_state);
164 axis_state_map.insert(Entity::PLACEHOLDER, per_entity_axis_state);
165 dual_axis_state_map.insert(Entity::PLACEHOLDER, per_entity_dual_axis_state);
166 triple_axis_state_map.insert(Entity::PLACEHOLDER, per_entity_triple_axis_state);
167 }
168
169 for (entity, action_state) in action_state_query.iter() {
170 let mut per_entity_button_state = HashMap::default();
171 let mut per_entity_axis_state = HashMap::default();
172 let mut per_entity_dual_axis_state = HashMap::default();
173 let mut per_entity_triple_axis_state = HashMap::default();
174
175 for (action, action_data) in action_state.all_action_data() {
176 match &action_data.kind_data {
177 ActionKindData::Button(button_data) => {
178 per_entity_button_state
179 .insert(action.clone(), button_data.to_button_value());
180 }
181 ActionKindData::Axis(axis_data) => {
182 per_entity_axis_state.insert(action.clone(), axis_data.value);
183 }
184 ActionKindData::DualAxis(dual_axis_data) => {
185 per_entity_dual_axis_state.insert(action.clone(), dual_axis_data.pair);
186 }
187 ActionKindData::TripleAxis(triple_axis_data) => {
188 per_entity_triple_axis_state
189 .insert(action.clone(), triple_axis_data.triple);
190 }
191 }
192 }
193
194 button_state_map.insert(entity, per_entity_button_state);
195 axis_state_map.insert(entity, per_entity_axis_state);
196 dual_axis_state_map.insert(entity, per_entity_dual_axis_state);
197 triple_axis_state_map.insert(entity, per_entity_triple_axis_state);
198 }
199
200 Self {
201 button_state_map,
202 axis_state_map,
203 dual_axis_state_map,
204 triple_axis_state_map,
205 }
206 }
207
208 pub fn button_diff(
214 action: A,
215 previous_button: Option<ButtonValue>,
216 current_button: Option<ButtonValue>,
217 ) -> Option<ActionDiff<A>> {
218 let previous_button = previous_button.unwrap_or_default();
219 let current_button = current_button?;
220
221 (previous_button != current_button).then(|| {
222 if current_button.pressed {
223 ActionDiff::Pressed {
224 action,
225 value: current_button.value,
226 }
227 } else {
228 ActionDiff::Released { action }
229 }
230 })
231 }
232
233 pub fn axis_diff(
238 action: A,
239 previous_axis: Option<f32>,
240 current_axis: Option<f32>,
241 ) -> Option<ActionDiff<A>> {
242 let previous_axis = previous_axis.unwrap_or_default();
243 let current_axis = current_axis?;
244
245 (previous_axis != current_axis).then(|| ActionDiff::AxisChanged {
246 action,
247 value: current_axis,
248 })
249 }
250
251 pub fn dual_axis_diff(
254 action: A,
255 previous_dual_axis: Option<Vec2>,
256 current_dual_axis: Option<Vec2>,
257 ) -> Option<ActionDiff<A>> {
258 let previous_dual_axis = previous_dual_axis.unwrap_or_default();
259 let current_dual_axis = current_dual_axis?;
260
261 (previous_dual_axis != current_dual_axis).then(|| ActionDiff::DualAxisChanged {
262 action,
263 axis_pair: current_dual_axis,
264 })
265 }
266
267 pub fn triple_axis_diff(
270 action: A,
271 previous_triple_axis: Option<Vec3>,
272 current_triple_axis: Option<Vec3>,
273 ) -> Option<ActionDiff<A>> {
274 let previous_triple_axis = previous_triple_axis.unwrap_or_default();
275 let current_triple_axis = current_triple_axis?;
276
277 (previous_triple_axis != current_triple_axis).then(|| ActionDiff::TripleAxisChanged {
278 action,
279 axis_triple: current_triple_axis,
280 })
281 }
282
283 pub fn entity_diffs(&self, entity: &Entity, previous: &Self) -> Vec<ActionDiff<A>> {
285 let mut action_diffs = Vec::new();
286
287 if let Some(current_button_state) = self.button_state_map.get(entity) {
288 let previous_button_state = previous.button_state_map.get(entity);
289 for (action, current_button) in current_button_state {
290 let previous_button = previous_button_state
291 .and_then(|previous_button_state| previous_button_state.get(action))
292 .copied();
293
294 if let Some(diff) =
295 Self::button_diff(action.clone(), previous_button, Some(*current_button))
296 {
297 action_diffs.push(diff);
298 }
299 }
300 }
301
302 if let Some(current_axis_state) = self.axis_state_map.get(entity) {
303 let previous_axis_state = previous.axis_state_map.get(entity);
304 for (action, current_axis) in current_axis_state {
305 let previous_axis = previous_axis_state
306 .and_then(|previous_axis_state| previous_axis_state.get(action))
307 .copied();
308
309 if let Some(diff) =
310 Self::axis_diff(action.clone(), previous_axis, Some(*current_axis))
311 {
312 action_diffs.push(diff);
313 }
314 }
315 }
316
317 if let Some(current_dual_axis_state) = self.dual_axis_state_map.get(entity) {
318 let previous_dual_axis_state = previous.dual_axis_state_map.get(entity);
319 for (action, current_dual_axis) in current_dual_axis_state {
320 let previous_dual_axis = previous_dual_axis_state
321 .and_then(|previous_dual_axis_state| previous_dual_axis_state.get(action))
322 .copied();
323
324 if let Some(diff) = Self::dual_axis_diff(
325 action.clone(),
326 previous_dual_axis,
327 Some(*current_dual_axis),
328 ) {
329 action_diffs.push(diff);
330 }
331 }
332 }
333
334 if let Some(current_triple_axis_state) = self.triple_axis_state_map.get(entity) {
335 let previous_triple_axis_state = previous.triple_axis_state_map.get(entity);
336 for (action, current_triple_axis) in current_triple_axis_state {
337 let previous_triple_axis = previous_triple_axis_state
338 .and_then(|previous_triple_axis_state| previous_triple_axis_state.get(action))
339 .copied();
340
341 if let Some(diff) = Self::triple_axis_diff(
342 action.clone(),
343 previous_triple_axis,
344 Some(*current_triple_axis),
345 ) {
346 action_diffs.push(diff);
347 }
348 }
349 }
350
351 action_diffs
352 }
353
354 pub fn send_diffs(&self, previous: &Self, writer: &mut EventWriter<ActionDiffEvent<A>>) {
356 for entity in self.all_entities() {
357 let owner = (entity != Entity::PLACEHOLDER).then_some(entity);
358
359 let action_diffs = self.entity_diffs(&entity, previous);
360
361 if !action_diffs.is_empty() {
362 writer.write(ActionDiffEvent {
363 owner,
364 action_diffs,
365 });
366 }
367 }
368 }
369}
370
371impl<A: Actionlike> Default for SummarizedActionState<A> {
373 fn default() -> Self {
374 Self {
375 button_state_map: Default::default(),
376 axis_state_map: Default::default(),
377 dual_axis_state_map: Default::default(),
378 triple_axis_state_map: Default::default(),
379 }
380 }
381}
382
#[cfg(test)]
mod tests {
    use crate as leafwing_input_manager;

    use super::*;
    use crate::buttonlike::ButtonValue;
    use bevy::{ecs::system::SystemState, prelude::*};

    #[derive(Actionlike, Debug, Clone, Copy, PartialEq, Eq, Hash, Reflect)]
    enum TestAction {
        Button,
        #[actionlike(Axis)]
        Axis,
        #[actionlike(DualAxis)]
        DualAxis,
        #[actionlike(TripleAxis)]
        TripleAxis,
    }

    /// Builds an action state in which every test action holds a known, non-default value.
    fn test_action_state() -> ActionState<TestAction> {
        let mut action_state = ActionState::default();
        action_state.press(&TestAction::Button);
        action_state.set_value(&TestAction::Axis, 0.3);
        action_state.set_axis_pair(&TestAction::DualAxis, Vec2::new(0.5, 0.7));
        action_state.set_axis_triple(&TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        action_state
    }

    /// Marker used to exclude an entity from a filtered summary.
    #[derive(Component)]
    struct NotSummarized;

    /// The summary expected when `test_action_state()` is stored under `entity`.
    fn expected_summary(entity: Entity) -> SummarizedActionState<TestAction> {
        let mut button_state_map = HashMap::default();
        let mut axis_state_map = HashMap::default();
        let mut dual_axis_state_map = HashMap::default();
        let mut triple_axis_state_map = HashMap::default();

        let mut global_button_state = HashMap::default();
        global_button_state.insert(TestAction::Button, ButtonValue::from_pressed(true));
        button_state_map.insert(entity, global_button_state);

        let mut global_axis_state = HashMap::default();
        global_axis_state.insert(TestAction::Axis, 0.3);
        axis_state_map.insert(entity, global_axis_state);

        let mut global_dual_axis_state = HashMap::default();
        global_dual_axis_state.insert(TestAction::DualAxis, Vec2::new(0.5, 0.7));
        dual_axis_state_map.insert(entity, global_dual_axis_state);

        let mut global_triple_axis_state = HashMap::default();
        global_triple_axis_state.insert(TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        triple_axis_state_map.insert(entity, global_triple_axis_state);

        SummarizedActionState {
            button_state_map,
            axis_state_map,
            dual_axis_state_map,
            triple_axis_state_map,
        }
    }

    #[test]
    fn summarize_from_resource() {
        let mut world = World::new();
        world.insert_resource(test_action_state());
        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query) = system_state.get(&world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);

        assert_eq!(summarized, expected_summary(Entity::PLACEHOLDER));
    }

    #[test]
    fn summarize_from_component() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();
        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query) = system_state.get(&world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);

        assert_eq!(summarized, expected_summary(entity));
    }

    #[test]
    fn summarize_filtered_entities_from_component() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();
        world.spawn((test_action_state(), NotSummarized));

        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>), Without<NotSummarized>>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query) = system_state.get(&world);
        let summarized =
            SummarizedActionState::summarize_filtered(global_action_state, action_state_query);

        assert_eq!(summarized, expected_summary(entity));
    }

    #[test]
    fn diffs_are_sent() {
        let mut world = World::new();
        world.init_resource::<Events<ActionDiffEvent<TestAction>>>();

        let entity = world.spawn(test_action_state()).id();
        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
            EventWriter<ActionDiffEvent<TestAction>>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query, mut action_diff_writer) =
            system_state.get_mut(&mut world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);

        let previous = SummarizedActionState::default();
        summarized.send_diffs(&previous, &mut action_diff_writer);

        let mut system_state: SystemState<EventReader<ActionDiffEvent<TestAction>>> =
            SystemState::new(&mut world);
        let mut event_reader = system_state.get_mut(&mut world);
        let action_diff_events = event_reader.read().collect::<Vec<_>>();

        assert_eq!(action_diff_events.len(), 1);
        let action_diff_event = action_diff_events[0];
        assert_eq!(action_diff_event.owner, Some(entity));
        assert_eq!(action_diff_event.action_diffs.len(), 4);
    }

    #[test]
    fn global_resource_diffs_have_no_owner() {
        let mut world = World::new();
        world.init_resource::<Events<ActionDiffEvent<TestAction>>>();
        world.insert_resource(test_action_state());

        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
            EventWriter<ActionDiffEvent<TestAction>>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query, mut action_diff_writer) =
            system_state.get_mut(&mut world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);

        let previous = SummarizedActionState::default();
        summarized.send_diffs(&previous, &mut action_diff_writer);

        let mut system_state: SystemState<EventReader<ActionDiffEvent<TestAction>>> =
            SystemState::new(&mut world);
        let mut event_reader = system_state.get_mut(&mut world);
        let action_diff_events = event_reader.read().collect::<Vec<_>>();

        assert_eq!(action_diff_events.len(), 1);
        assert_eq!(action_diff_events[0].owner, None);
        assert_eq!(action_diff_events[0].action_diffs.len(), 4);
    }

    #[test]
    fn no_diffs_are_sent_when_nothing_changed() {
        let mut world = World::new();
        world.init_resource::<Events<ActionDiffEvent<TestAction>>>();
        world.spawn(test_action_state());

        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
            EventWriter<ActionDiffEvent<TestAction>>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query, mut action_diff_writer) =
            system_state.get_mut(&mut world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);

        // Comparing a snapshot against itself must not produce any events.
        summarized.send_diffs(&summarized, &mut action_diff_writer);

        let mut system_state: SystemState<EventReader<ActionDiffEvent<TestAction>>> =
            SystemState::new(&mut world);
        let mut event_reader = system_state.get_mut(&mut world);

        assert_eq!(event_reader.read().count(), 0);
    }

    #[test]
    fn releasing_a_button_produces_released_diff() {
        let diff = SummarizedActionState::<TestAction>::button_diff(
            TestAction::Button,
            Some(ButtonValue::from_pressed(true)),
            Some(ButtonValue::from_pressed(false)),
        );

        assert_eq!(
            diff,
            Some(ActionDiff::Released {
                action: TestAction::Button
            })
        );
    }

    #[test]
    fn missing_previous_axis_is_treated_as_default() {
        let unchanged =
            SummarizedActionState::<TestAction>::axis_diff(TestAction::Axis, None, Some(0.0));
        assert_eq!(unchanged, None);

        let changed =
            SummarizedActionState::<TestAction>::axis_diff(TestAction::Axis, None, Some(0.3));
        assert_eq!(
            changed,
            Some(ActionDiff::AxisChanged {
                action: TestAction::Axis,
                value: 0.3
            })
        );
    }
}