1use bevy::{
9 ecs::{
10 entity::{Entity, MapEntities},
11 message::Message,
12 query::QueryFilter,
13 },
14 math::{Vec2, Vec3},
15 platform::collections::{HashMap, HashSet},
16 prelude::{EntityMapper, MessageWriter, Query},
17};
18use serde::{Deserialize, Serialize};
19
20use crate::buttonlike::ButtonValue;
21use crate::{Actionlike, action_state::ActionKindData, prelude::ActionState};
22
/// A single change to one action's state for one entity.
///
/// Values of this type are produced by comparing two [`SummarizedActionState`]s
/// (see [`SummarizedActionState::entity_diffs`]); each variant carries the action
/// that changed together with its new value.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ActionDiff<A: Actionlike> {
    /// A button action whose [`ButtonValue`] changed and which is now pressed.
    Pressed {
        /// The action that changed.
        action: A,
        /// The button's current `value` at the time the diff was generated.
        value: f32,
    },
    /// A button action whose [`ButtonValue`] changed and which is now not pressed.
    Released {
        /// The action that changed.
        action: A,
    },
    /// An axis action whose value differs from the previous summary.
    AxisChanged {
        /// The action that changed.
        action: A,
        /// The new axis value.
        value: f32,
    },
    /// A dual-axis action whose pair differs from the previous summary.
    DualAxisChanged {
        /// The action that changed.
        action: A,
        /// The new axis pair.
        axis_pair: Vec2,
    },
    /// A triple-axis action whose triple differs from the previous summary.
    TripleAxisChanged {
        /// The action that changed.
        action: A,
        /// The new axis triple.
        axis_triple: Vec3,
    },
}
65
/// All [`ActionDiff`]s computed for a single entity in one
/// [`SummarizedActionState::send_diffs`] call.
///
/// `send_diffs` only writes this message when `action_diffs` is non-empty.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Message)]
pub struct ActionDiffMessage<A: Actionlike> {
    /// The entity whose `ActionState<A>` these diffs describe.
    /// This is the only field rewritten by the `MapEntities` impl.
    pub owner: Entity,
    /// The changes for `owner`, grouped by kind: buttons, then axes, then dual axes,
    /// then triple axes. Order within each group follows hash-map iteration order.
    pub action_diffs: Vec<ActionDiff<A>>,
}
76
77impl<A: Actionlike> MapEntities for ActionDiffMessage<A> {
82 fn map_entities<M: EntityMapper>(&mut self, entity_mapper: &mut M) {
83 self.owner = entity_mapper.get_mapped(self.owner);
84 }
85}
86
/// A snapshot of the (enabled) action data of every summarized entity, split by
/// action kind.
///
/// Two snapshots can be compared with [`SummarizedActionState::entity_diffs`] /
/// [`SummarizedActionState::send_diffs`] to produce [`ActionDiff`]s.
#[derive(Debug, PartialEq, Clone)]
pub struct SummarizedActionState<A: Actionlike> {
    /// Per-entity button values, keyed by action.
    button_state_map: HashMap<Entity, HashMap<A, ButtonValue>>,
    /// Per-entity single-axis values, keyed by action.
    axis_state_map: HashMap<Entity, HashMap<A, f32>>,
    /// Per-entity dual-axis pairs, keyed by action.
    dual_axis_state_map: HashMap<Entity, HashMap<A, Vec2>>,
    /// Per-entity triple-axis values, keyed by action.
    triple_axis_state_map: HashMap<Entity, HashMap<A, Vec3>>,
}
96
97impl<A: Actionlike> SummarizedActionState<A> {
98 pub fn all_entities(&self) -> HashSet<Entity> {
100 let mut entities = HashSet::default();
101 let button_entities = self.button_state_map.keys();
102 let axis_entities = self.axis_state_map.keys();
103 let dual_axis_entities = self.dual_axis_state_map.keys();
104 let triple_axis_entities = self.triple_axis_state_map.keys();
105
106 entities.extend(button_entities);
107 entities.extend(axis_entities);
108 entities.extend(dual_axis_entities);
109 entities.extend(triple_axis_entities);
110
111 entities
112 }
113
114 pub fn summarize(action_state_query: Query<(Entity, &ActionState<A>)>) -> Self {
116 Self::summarize_filtered(action_state_query)
117 }
118
119 pub fn summarize_filtered<F: QueryFilter>(
122 action_state_query: Query<(Entity, &ActionState<A>), F>,
123 ) -> Self {
124 let mut button_state_map = HashMap::default();
125 let mut axis_state_map = HashMap::default();
126 let mut dual_axis_state_map = HashMap::default();
127 let mut triple_axis_state_map = HashMap::default();
128
129 for (entity, action_state) in action_state_query
130 .iter()
131 .filter(|(_, action_state)| !action_state.disabled())
132 {
133 let mut per_entity_button_state = HashMap::default();
134 let mut per_entity_axis_state = HashMap::default();
135 let mut per_entity_dual_axis_state = HashMap::default();
136 let mut per_entity_triple_axis_state = HashMap::default();
137
138 for (action, action_data) in action_state
139 .all_action_data()
140 .iter()
141 .filter(|(_, action_data)| !action_data.disabled)
142 {
143 match &action_data.kind_data {
144 ActionKindData::Button(button_data) => {
145 per_entity_button_state
146 .insert(action.clone(), button_data.to_button_value());
147 }
148 ActionKindData::Axis(axis_data) => {
149 per_entity_axis_state.insert(action.clone(), axis_data.value);
150 }
151 ActionKindData::DualAxis(dual_axis_data) => {
152 per_entity_dual_axis_state.insert(action.clone(), dual_axis_data.pair);
153 }
154 ActionKindData::TripleAxis(triple_axis_data) => {
155 per_entity_triple_axis_state
156 .insert(action.clone(), triple_axis_data.triple);
157 }
158 }
159 }
160
161 button_state_map.insert(entity, per_entity_button_state);
162 axis_state_map.insert(entity, per_entity_axis_state);
163 dual_axis_state_map.insert(entity, per_entity_dual_axis_state);
164 triple_axis_state_map.insert(entity, per_entity_triple_axis_state);
165 }
166
167 Self {
168 button_state_map,
169 axis_state_map,
170 dual_axis_state_map,
171 triple_axis_state_map,
172 }
173 }
174
175 pub fn button_diff(
181 action: A,
182 previous_button: Option<ButtonValue>,
183 current_button: Option<ButtonValue>,
184 ) -> Option<ActionDiff<A>> {
185 let previous_button = previous_button.unwrap_or_default();
186 let current_button = current_button?;
187
188 (previous_button != current_button).then(|| {
189 if current_button.pressed {
190 ActionDiff::Pressed {
191 action,
192 value: current_button.value,
193 }
194 } else {
195 ActionDiff::Released { action }
196 }
197 })
198 }
199
200 pub fn axis_diff(
205 action: A,
206 previous_axis: Option<f32>,
207 current_axis: Option<f32>,
208 ) -> Option<ActionDiff<A>> {
209 let previous_axis = previous_axis.unwrap_or_default();
210 let current_axis = current_axis?;
211
212 (previous_axis != current_axis).then(|| ActionDiff::AxisChanged {
213 action,
214 value: current_axis,
215 })
216 }
217
218 pub fn dual_axis_diff(
221 action: A,
222 previous_dual_axis: Option<Vec2>,
223 current_dual_axis: Option<Vec2>,
224 ) -> Option<ActionDiff<A>> {
225 let previous_dual_axis = previous_dual_axis.unwrap_or_default();
226 let current_dual_axis = current_dual_axis?;
227
228 (previous_dual_axis != current_dual_axis).then(|| ActionDiff::DualAxisChanged {
229 action,
230 axis_pair: current_dual_axis,
231 })
232 }
233
234 pub fn triple_axis_diff(
237 action: A,
238 previous_triple_axis: Option<Vec3>,
239 current_triple_axis: Option<Vec3>,
240 ) -> Option<ActionDiff<A>> {
241 let previous_triple_axis = previous_triple_axis.unwrap_or_default();
242 let current_triple_axis = current_triple_axis?;
243
244 (previous_triple_axis != current_triple_axis).then(|| ActionDiff::TripleAxisChanged {
245 action,
246 axis_triple: current_triple_axis,
247 })
248 }
249
250 pub fn entity_diffs(&self, entity: &Entity, previous: &Self) -> Vec<ActionDiff<A>> {
252 let mut action_diffs = Vec::new();
253
254 if let Some(current_button_state) = self.button_state_map.get(entity) {
255 let previous_button_state = previous.button_state_map.get(entity);
256 for (action, current_button) in current_button_state {
257 let previous_button = previous_button_state
258 .and_then(|previous_button_state| previous_button_state.get(action))
259 .copied();
260
261 if let Some(diff) =
262 Self::button_diff(action.clone(), previous_button, Some(*current_button))
263 {
264 action_diffs.push(diff);
265 }
266 }
267 }
268
269 if let Some(current_axis_state) = self.axis_state_map.get(entity) {
270 let previous_axis_state = previous.axis_state_map.get(entity);
271 for (action, current_axis) in current_axis_state {
272 let previous_axis = previous_axis_state
273 .and_then(|previous_axis_state| previous_axis_state.get(action))
274 .copied();
275
276 if let Some(diff) =
277 Self::axis_diff(action.clone(), previous_axis, Some(*current_axis))
278 {
279 action_diffs.push(diff);
280 }
281 }
282 }
283
284 if let Some(current_dual_axis_state) = self.dual_axis_state_map.get(entity) {
285 let previous_dual_axis_state = previous.dual_axis_state_map.get(entity);
286 for (action, current_dual_axis) in current_dual_axis_state {
287 let previous_dual_axis = previous_dual_axis_state
288 .and_then(|previous_dual_axis_state| previous_dual_axis_state.get(action))
289 .copied();
290
291 if let Some(diff) = Self::dual_axis_diff(
292 action.clone(),
293 previous_dual_axis,
294 Some(*current_dual_axis),
295 ) {
296 action_diffs.push(diff);
297 }
298 }
299 }
300
301 if let Some(current_triple_axis_state) = self.triple_axis_state_map.get(entity) {
302 let previous_triple_axis_state = previous.triple_axis_state_map.get(entity);
303 for (action, current_triple_axis) in current_triple_axis_state {
304 let previous_triple_axis = previous_triple_axis_state
305 .and_then(|previous_triple_axis_state| previous_triple_axis_state.get(action))
306 .copied();
307
308 if let Some(diff) = Self::triple_axis_diff(
309 action.clone(),
310 previous_triple_axis,
311 Some(*current_triple_axis),
312 ) {
313 action_diffs.push(diff);
314 }
315 }
316 }
317
318 action_diffs
319 }
320
321 pub fn send_diffs(&self, previous: &Self, writer: &mut MessageWriter<ActionDiffMessage<A>>) {
323 for entity in self.all_entities() {
324 let action_diffs = self.entity_diffs(&entity, previous);
325
326 if !action_diffs.is_empty() {
327 writer.write(ActionDiffMessage {
328 owner: entity,
329 action_diffs,
330 });
331 }
332 }
333 }
334}
335
336impl<A: Actionlike> Default for SummarizedActionState<A> {
338 fn default() -> Self {
339 Self {
340 button_state_map: Default::default(),
341 axis_state_map: Default::default(),
342 dual_axis_state_map: Default::default(),
343 triple_axis_state_map: Default::default(),
344 }
345 }
346}
347
#[cfg(test)]
mod tests {
    use crate as leafwing_input_manager;

    use super::*;
    use crate::buttonlike::ButtonValue;
    use bevy::{ecs::system::SystemState, prelude::*};

    #[derive(Actionlike, Debug, Clone, Copy, PartialEq, Eq, Hash, Reflect)]
    enum TestAction {
        Button,
        #[actionlike(Axis)]
        Axis,
        #[actionlike(DualAxis)]
        DualAxis,
        #[actionlike(TripleAxis)]
        TripleAxis,
    }

    /// An action state with one action of every kind set to a non-default value.
    fn test_action_state() -> ActionState<TestAction> {
        let mut action_state = ActionState::default();
        action_state.press(&TestAction::Button);
        action_state.set_value(&TestAction::Axis, 0.3);
        action_state.set_axis_pair(&TestAction::DualAxis, Vec2::new(0.5, 0.7));
        action_state.set_axis_triple(&TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        action_state
    }

    /// Marker used to exclude entities via a `Without<NotSummarized>` query filter.
    #[derive(Component)]
    struct NotSummarized;

    /// The summary expected for a single entity holding `test_action_state()`.
    fn expected_summary(entity: Entity) -> SummarizedActionState<TestAction> {
        let mut button_state_map = HashMap::default();
        let mut axis_state_map = HashMap::default();
        let mut dual_axis_state_map = HashMap::default();
        let mut triple_axis_state_map = HashMap::default();

        let mut global_button_state = HashMap::default();
        global_button_state.insert(TestAction::Button, ButtonValue::from_pressed(true));
        button_state_map.insert(entity, global_button_state);

        let mut global_axis_state = HashMap::default();
        global_axis_state.insert(TestAction::Axis, 0.3);
        axis_state_map.insert(entity, global_axis_state);

        let mut global_dual_axis_state = HashMap::default();
        global_dual_axis_state.insert(TestAction::DualAxis, Vec2::new(0.5, 0.7));
        dual_axis_state_map.insert(entity, global_dual_axis_state);

        let mut global_triple_axis_state = HashMap::default();
        global_triple_axis_state.insert(TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        triple_axis_state_map.insert(entity, global_triple_axis_state);

        SummarizedActionState {
            button_state_map,
            axis_state_map,
            dual_axis_state_map,
            triple_axis_state_map,
        }
    }

    #[test]
    fn summarize_from_component() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();
        let mut system_state: SystemState<Query<(Entity, &ActionState<TestAction>)>> =
            SystemState::new(&mut world);
        let action_state_query = system_state.get(&world).unwrap();
        let summarized = SummarizedActionState::summarize(action_state_query);

        assert_eq!(summarized, expected_summary(entity));
    }

    #[test]
    fn summarize_filtered_entities_from_component() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();
        world.spawn((test_action_state(), NotSummarized));

        let mut system_state: SystemState<
            Query<(Entity, &ActionState<TestAction>), Without<NotSummarized>>,
        > = SystemState::new(&mut world);
        let action_state_query = system_state.get(&world).unwrap();
        let summarized = SummarizedActionState::summarize_filtered(action_state_query);

        assert_eq!(summarized, expected_summary(entity));
    }

    #[test]
    fn diffs_are_sent() {
        let mut world = World::new();
        world.init_resource::<Messages<ActionDiffMessage<TestAction>>>();

        let entity = world.spawn(test_action_state()).id();
        let mut system_state: SystemState<(
            Query<(Entity, &ActionState<TestAction>)>,
            MessageWriter<ActionDiffMessage<TestAction>>,
        )> = SystemState::new(&mut world);
        let (action_state_query, mut action_diff_writer) =
            system_state.get_mut(&mut world).unwrap();
        let summarized = SummarizedActionState::summarize(action_state_query);

        let previous = SummarizedActionState::default();
        summarized.send_diffs(&previous, &mut action_diff_writer);

        let mut system_state: SystemState<MessageReader<ActionDiffMessage<TestAction>>> =
            SystemState::new(&mut world);
        let mut message_reader = system_state.get_mut(&mut world).unwrap();
        let action_diff_messages = message_reader.read().collect::<Vec<_>>();

        assert_eq!(action_diff_messages.len(), 1);
        let action_diff_message = action_diff_messages[0];
        assert_eq!(action_diff_message.owner, entity);
        assert_eq!(action_diff_message.action_diffs.len(), 4);
    }

    #[test]
    fn no_diffs_sent_when_unchanged() {
        let mut world = World::new();
        world.init_resource::<Messages<ActionDiffMessage<TestAction>>>();

        world.spawn(test_action_state());
        let mut system_state: SystemState<(
            Query<(Entity, &ActionState<TestAction>)>,
            MessageWriter<ActionDiffMessage<TestAction>>,
        )> = SystemState::new(&mut world);
        let (action_state_query, mut action_diff_writer) =
            system_state.get_mut(&mut world).unwrap();
        let summarized = SummarizedActionState::summarize(action_state_query);

        // Comparing a summary against an identical copy must not emit any message.
        let previous = summarized.clone();
        summarized.send_diffs(&previous, &mut action_diff_writer);

        let mut system_state: SystemState<MessageReader<ActionDiffMessage<TestAction>>> =
            SystemState::new(&mut world);
        let mut message_reader = system_state.get_mut(&mut world).unwrap();
        assert_eq!(message_reader.read().count(), 0);
    }

    /// Like `test_action_state()`, but with the whole action state disabled.
    fn test_action_state_disabled() -> ActionState<TestAction> {
        let mut action_state = ActionState::default();
        action_state.press(&TestAction::Button);
        action_state.set_value(&TestAction::Axis, 0.3);
        action_state.set_axis_pair(&TestAction::DualAxis, Vec2::new(0.5, 0.7));
        action_state.set_axis_triple(&TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        action_state.disable();
        action_state
    }

    /// A summary with no entities at all.
    fn expected_summary_when_disabled() -> SummarizedActionState<TestAction> {
        let button_state_map = HashMap::default();
        let axis_state_map = HashMap::default();
        let dual_axis_state_map = HashMap::default();
        let triple_axis_state_map = HashMap::default();

        SummarizedActionState {
            button_state_map,
            axis_state_map,
            dual_axis_state_map,
            triple_axis_state_map,
        }
    }

    #[test]
    fn summarize_filtered_from_disabled_component() {
        let mut world = World::new();
        // The disabled entity deliberately lacks `NotSummarized`, so it passes the
        // query filter and must be excluded by the `disabled()` check instead.
        world.spawn(test_action_state_disabled());
        // An enabled entity excluded by the query filter.
        world.spawn((test_action_state(), NotSummarized));

        let mut system_state: SystemState<
            Query<(Entity, &ActionState<TestAction>), Without<NotSummarized>>,
        > = SystemState::new(&mut world);
        let action_state_query = system_state.get(&world).unwrap();
        let summarized = SummarizedActionState::summarize_filtered(action_state_query);

        assert_eq!(summarized, expected_summary_when_disabled());
    }

    /// Like `test_action_state()`, but with only the button action disabled.
    fn test_action_state_disabled_action() -> ActionState<TestAction> {
        let mut action_state = ActionState::default();
        action_state.press(&TestAction::Button);
        action_state.set_value(&TestAction::Axis, 0.3);
        action_state.set_axis_pair(&TestAction::DualAxis, Vec2::new(0.5, 0.7));
        action_state.set_axis_triple(&TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        action_state.disable_action(&TestAction::Button);
        action_state
    }

    /// The expected summary when only the button action is disabled.
    ///
    /// The entity still has an (empty) entry in the button map.
    fn expected_summary_with_disabled_action(entity: Entity) -> SummarizedActionState<TestAction> {
        let mut button_state_map = HashMap::default();
        let mut axis_state_map = HashMap::default();
        let mut dual_axis_state_map = HashMap::default();
        let mut triple_axis_state_map = HashMap::default();

        let global_button_state = HashMap::default();
        button_state_map.insert(entity, global_button_state);

        let mut global_axis_state = HashMap::default();
        global_axis_state.insert(TestAction::Axis, 0.3);
        axis_state_map.insert(entity, global_axis_state);

        let mut global_dual_axis_state = HashMap::default();
        global_dual_axis_state.insert(TestAction::DualAxis, Vec2::new(0.5, 0.7));
        dual_axis_state_map.insert(entity, global_dual_axis_state);

        let mut global_triple_axis_state = HashMap::default();
        global_triple_axis_state.insert(TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        triple_axis_state_map.insert(entity, global_triple_axis_state);

        SummarizedActionState {
            button_state_map,
            axis_state_map,
            dual_axis_state_map,
            triple_axis_state_map,
        }
    }

    #[test]
    fn summarize_filtered_entities_from_component_disabled_action() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state_disabled_action()).id();

        let mut system_state: SystemState<
            Query<(Entity, &ActionState<TestAction>), Without<NotSummarized>>,
        > = SystemState::new(&mut world);
        let action_state_query = system_state.get(&world).unwrap();
        let summarized = SummarizedActionState::summarize_filtered(action_state_query);

        assert_eq!(summarized, expected_summary_with_disabled_action(entity));
    }
}