1use super::Key;
13use std::collections::{HashMap, HashSet};
14
/// Name of a logical action (e.g. `"attack"`, `"menu_up"`); bindings are stored
/// and queried under this name.
pub type Action = String;
20
/// Which key-state transition a [`Binding`] is tested against.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Trigger {
    /// The key is in this frame's `just_pressed` set.
    JustPressed,
    /// The key (and every modifier) is in the `held` set.
    Held,
    /// The key is in this frame's `just_released` set.
    JustReleased,
    /// The key is in either the `just_pressed` or the `just_released` set.
    EdgeBoth,
}

impl Default for Trigger {
    /// The default trigger is [`Trigger::JustPressed`].
    fn default() -> Self {
        Self::JustPressed
    }
}
37
38#[derive(Debug, Clone)]
42pub struct Binding {
43 pub key: Key,
45 pub modifiers: Vec<Key>,
47 pub trigger: Trigger,
49 pub priority: i32,
51 pub display: String,
53}
54
55impl Binding {
56 pub fn simple(key: Key) -> Self {
57 Self {
58 key,
59 modifiers: Vec::new(),
60 trigger: Trigger::JustPressed,
61 priority: 0,
62 display: format!("{key:?}"),
63 }
64 }
65
66 pub fn held(key: Key) -> Self {
67 Self { trigger: Trigger::Held, ..Self::simple(key) }
68 }
69
70 pub fn with_modifier(mut self, modifier: Key) -> Self {
71 self.modifiers.push(modifier);
72 self.display = format!("{modifier:?}+{:?}", self.key);
73 self
74 }
75
76 pub fn with_priority(mut self, p: i32) -> Self {
77 self.priority = p;
78 self
79 }
80
81 pub fn with_display(mut self, s: impl Into<String>) -> Self {
82 self.display = s.into();
83 self
84 }
85
86 pub fn matches_held(&self, held: &HashSet<Key>) -> bool {
88 if !held.contains(&self.key) { return false; }
89 for &m in &self.modifiers {
90 if !held.contains(&m) { return false; }
91 }
92 true
93 }
94
95 pub fn matches(&self, trigger: Trigger, held: &HashSet<Key>, just_pressed: &HashSet<Key>, just_released: &HashSet<Key>) -> bool {
97 match trigger {
98 Trigger::JustPressed => {
99 just_pressed.contains(&self.key) && self.modifiers.iter().all(|m| held.contains(m))
100 }
101 Trigger::Held => self.matches_held(held),
102 Trigger::JustReleased => {
103 just_released.contains(&self.key)
104 }
105 Trigger::EdgeBoth => {
106 just_pressed.contains(&self.key) || just_released.contains(&self.key)
107 }
108 }
109 }
110}
111
112#[derive(Debug, Clone)]
117pub struct ChordBinding {
118 pub action: Action,
119 pub sequence: Vec<Key>,
121 pub time_window: f32,
123}
124
125struct ChordTracker {
127 binding: ChordBinding,
128 progress: usize,
129 last_press: f32,
130}
131
132impl ChordTracker {
133 fn new(binding: ChordBinding) -> Self {
134 Self { binding, progress: 0, last_press: f32::NEG_INFINITY }
135 }
136
137 fn on_key_press(&mut self, key: Key, t: f32) -> bool {
139 let expected = self.binding.sequence.get(self.progress);
140 if Some(&key) == expected {
141 if self.progress > 0 && (t - self.last_press) > self.binding.time_window {
142 self.progress = 0;
143 } else {
144 self.progress += 1;
145 self.last_press = t;
146 if self.progress == self.binding.sequence.len() {
147 self.progress = 0;
148 return true;
149 }
150 }
151 } else {
152 self.progress = 0;
154 }
155 false
156 }
157}
158
159#[derive(Debug, Clone, Copy, PartialEq)]
163pub enum AxisSource {
164 KeyPair { negative: Key, positive: Key },
166 MouseDeltaX,
168 MouseDeltaY,
170 Scroll,
172 Constant(f32),
174}
175
176#[derive(Debug, Clone)]
178pub struct AxisBinding {
179 pub name: String,
180 pub source: AxisSource,
181 pub scale: f32,
183 pub dead_zone: f32,
185 pub smoothing: f32,
187}
188
189impl AxisBinding {
190 pub fn key_pair(name: impl Into<String>, neg: Key, pos: Key) -> Self {
191 Self {
192 name: name.into(),
193 source: AxisSource::KeyPair { negative: neg, positive: pos },
194 scale: 1.0,
195 dead_zone: 0.0,
196 smoothing: 0.0,
197 }
198 }
199
200 pub fn mouse_x(name: impl Into<String>, scale: f32) -> Self {
201 Self {
202 name: name.into(),
203 source: AxisSource::MouseDeltaX,
204 scale,
205 dead_zone: 0.5,
206 smoothing: 0.1,
207 }
208 }
209
210 pub fn mouse_y(name: impl Into<String>, scale: f32) -> Self {
211 Self {
212 name: name.into(),
213 source: AxisSource::MouseDeltaY,
214 scale,
215 dead_zone: 0.5,
216 smoothing: 0.1,
217 }
218 }
219}
220
221#[derive(Debug, Clone, Default)]
226pub struct BindingGroup {
227 pub name: String,
228 pub enabled: bool,
229 bindings: HashMap<Action, Vec<Binding>>,
230 axes: Vec<AxisBinding>,
231 chords: Vec<ChordBinding>,
232}
233
234impl BindingGroup {
235 pub fn new(name: impl Into<String>, enabled: bool) -> Self {
236 Self { name: name.into(), enabled, ..Default::default() }
237 }
238
239 pub fn bind(&mut self, action: impl Into<Action>, binding: Binding) {
240 self.bindings.entry(action.into()).or_default().push(binding);
241 }
242
243 pub fn bind_axis(&mut self, axis: AxisBinding) {
244 self.axes.push(axis);
245 }
246
247 pub fn bind_chord(&mut self, chord: ChordBinding) {
248 self.chords.push(chord);
249 }
250
251 pub fn actions(&self) -> impl Iterator<Item = &Action> {
252 self.bindings.keys()
253 }
254}
255
256#[derive(Default)]
263pub struct KeyBindings {
264 groups: Vec<BindingGroup>,
265 chord_trackers: Vec<ChordTracker>,
266 axis_cache: HashMap<String, f32>,
268 chord_fired: HashSet<String>,
270}
271
272impl KeyBindings {
273 pub fn new() -> Self { Self::default() }
274
275 pub fn add_group(&mut self, group: BindingGroup) {
279 for chord in &group.chords {
281 self.chord_trackers.push(ChordTracker::new(chord.clone()));
282 }
283 self.groups.push(group);
284 }
285
286 pub fn set_group_enabled(&mut self, name: &str, enabled: bool) {
288 if let Some(g) = self.groups.iter_mut().find(|g| g.name == name) {
289 g.enabled = enabled;
290 }
291 }
292
293 pub fn is_group_enabled(&self, name: &str) -> bool {
294 self.groups.iter().any(|g| g.name == name && g.enabled)
295 }
296
297 pub fn bind(&mut self, action: impl Into<Action>, key: Key) {
301 let action = action.into();
302 if let Some(g) = self.groups.iter_mut().find(|g| g.enabled) {
303 g.bind(action, Binding::simple(key));
304 } else {
305 let mut g = BindingGroup::new("default", true);
306 g.bind(action, Binding::simple(key));
307 self.groups.push(g);
308 }
309 }
310
311 pub fn key_for(&self, action: &str) -> Option<Key> {
313 for group in &self.groups {
314 if !group.enabled { continue; }
315 if let Some(bindings) = group.bindings.get(action) {
316 if let Some(b) = bindings.first() {
317 return Some(b.key);
318 }
319 }
320 }
321 None
322 }
323
324 pub fn update(
329 &mut self,
330 held: &HashSet<Key>,
331 just_pressed: &HashSet<Key>,
332 just_released: &HashSet<Key>,
333 mouse_delta: (f32, f32),
334 scroll: f32,
335 time: f32,
336 dt: f32,
337 ) {
338 self.chord_fired.clear();
339
340 for key in just_pressed {
342 for tracker in &mut self.chord_trackers {
343 if tracker.on_key_press(*key, time) {
344 self.chord_fired.insert(tracker.binding.action.clone());
345 }
346 }
347 }
348
349 self.axis_cache.clear();
351 for group in &self.groups {
352 if !group.enabled { continue; }
353 for axis in &group.axes {
354 let raw = match axis.source {
355 AxisSource::KeyPair { negative, positive } => {
356 let neg = if held.contains(&negative) { -1.0f32 } else { 0.0 };
357 let pos = if held.contains(&positive) { 1.0f32 } else { 0.0 };
358 neg + pos
359 }
360 AxisSource::MouseDeltaX => mouse_delta.0,
361 AxisSource::MouseDeltaY => mouse_delta.1,
362 AxisSource::Scroll => scroll,
363 AxisSource::Constant(v) => v,
364 };
365 let scaled = raw * axis.scale;
366 let deadzoned = if scaled.abs() < axis.dead_zone { 0.0 } else { scaled };
367 let prev = self.axis_cache.get(&axis.name).copied().unwrap_or(0.0);
368 let smoothed = prev * axis.smoothing + deadzoned * (1.0 - axis.smoothing);
369 let _ = dt; self.axis_cache.insert(axis.name.clone(), smoothed);
371 }
372 }
373 }
374
375 pub fn is_active(
379 &self,
380 action: &str,
381 trigger: Trigger,
382 held: &HashSet<Key>,
383 just_pressed: &HashSet<Key>,
384 just_released: &HashSet<Key>,
385 ) -> bool {
386 if trigger == Trigger::JustPressed && self.chord_fired.contains(action) {
388 return true;
389 }
390
391 let mut highest_priority = i32::MIN;
392 let mut result = false;
393
394 for group in &self.groups {
395 if !group.enabled { continue; }
396 if let Some(bindings) = group.bindings.get(action) {
397 for binding in bindings {
398 if binding.priority >= highest_priority {
399 if binding.matches(trigger, held, just_pressed, just_released) {
400 if binding.priority > highest_priority {
401 result = true;
402 highest_priority = binding.priority;
403 } else {
404 result = true;
405 }
406 }
407 }
408 }
409 }
410 }
411 result
412 }
413
414 pub fn just_pressed(
416 &self, action: &str,
417 held: &HashSet<Key>,
418 just_pressed: &HashSet<Key>,
419 just_released: &HashSet<Key>,
420 ) -> bool {
421 self.is_active(action, Trigger::JustPressed, held, just_pressed, just_released)
422 }
423
424 pub fn is_held(
426 &self, action: &str,
427 held: &HashSet<Key>,
428 just_pressed: &HashSet<Key>,
429 just_released: &HashSet<Key>,
430 ) -> bool {
431 self.is_active(action, Trigger::Held, held, just_pressed, just_released)
432 }
433
434 pub fn axis(&self, name: &str) -> f32 {
436 self.axis_cache.get(name).copied().unwrap_or(0.0)
437 }
438
439 pub fn all_actions(&self) -> Vec<&str> {
441 let mut seen = HashSet::new();
442 let mut result = Vec::new();
443 for group in &self.groups {
444 for action in group.actions() {
445 if seen.insert(action.as_str()) {
446 result.push(action.as_str());
447 }
448 }
449 }
450 result
451 }
452
453 pub fn display_binding(&self, action: &str) -> Option<&str> {
455 for group in &self.groups {
456 if !group.enabled { continue; }
457 if let Some(bindings) = group.bindings.get(action) {
458 if let Some(b) = bindings.iter().max_by_key(|b| b.priority) {
459 return Some(&b.display);
460 }
461 }
462 }
463 None
464 }
465
466 pub fn remap(&mut self, action: &str, new_key: Key) {
468 for group in &mut self.groups {
469 if let Some(bindings) = group.bindings.get_mut(action) {
470 if let Some(b) = bindings.first_mut() {
471 b.key = new_key;
472 b.display = format!("{new_key:?}");
473 }
474 }
475 }
476 }
477
478 pub fn chord_just_fired(&self, action: &str) -> bool {
480 self.chord_fired.contains(action)
481 }
482}
483
484pub fn chaos_rpg_defaults() -> KeyBindings {
488 let mut kb = KeyBindings::new();
489
490 let mut gameplay = BindingGroup::new("gameplay", true);
491
492 gameplay.bind("attack", Binding::simple(Key::A));
494 gameplay.bind("heavy_attack", Binding::simple(Key::H));
495 gameplay.bind("defend", Binding::simple(Key::D));
496 gameplay.bind("dodge", Binding::simple(Key::Space));
497 gameplay.bind("flee", Binding::simple(Key::F));
498 gameplay.bind("taunt", Binding::simple(Key::T));
499
500 gameplay.bind("skill_1", Binding::simple(Key::Num1));
502 gameplay.bind("skill_2", Binding::simple(Key::Num2));
503 gameplay.bind("skill_3", Binding::simple(Key::Num3));
504 gameplay.bind("skill_4", Binding::simple(Key::Num4));
505
506 gameplay.bind("confirm", Binding::simple(Key::Enter));
508 gameplay.bind("back", Binding::simple(Key::Escape));
509 gameplay.bind("menu", Binding::simple(Key::Escape));
510
511 gameplay.bind("char_sheet", Binding::simple(Key::C));
513 gameplay.bind("passive_tree", Binding::simple(Key::P));
514 gameplay.bind("chaos_viz", Binding::simple(Key::V));
515 gameplay.bind("inventory", Binding::simple(Key::I));
516 gameplay.bind("map", Binding::simple(Key::M));
517 gameplay.bind("log_collapse", Binding::simple(Key::Z));
518
519 gameplay.bind_axis(AxisBinding::key_pair("move_x", Key::Left, Key::Right));
521 gameplay.bind_axis(AxisBinding::key_pair("move_y", Key::Down, Key::Up));
522 gameplay.bind_axis(AxisBinding::key_pair("move_x_wasd", Key::A, Key::D));
523 gameplay.bind_axis(AxisBinding::key_pair("move_y_wasd", Key::S, Key::W));
524
525 gameplay.bind_chord(ChordBinding {
527 action: "combo_12".into(),
528 sequence: vec![Key::Num1, Key::Num2],
529 time_window: 0.3,
530 });
531
532 kb.add_group(gameplay);
533
534 let mut debug = BindingGroup::new("debug", false);
536 debug.bind("debug_toggle", Binding::simple(Key::F1));
537 debug.bind("debug_profiler", Binding::simple(Key::F2));
538 debug.bind("debug_wireframe", Binding::simple(Key::F3));
539 debug.bind("debug_physics", Binding::simple(Key::F4));
540 debug.bind("debug_reload", Binding::simple(Key::F5));
541 debug.bind("debug_screenshot",Binding::simple(Key::F12));
542 kb.add_group(debug);
543
544 let mut menu = BindingGroup::new("menu", false);
546 menu.bind("menu_up", Binding::simple(Key::Up));
547 menu.bind("menu_down", Binding::simple(Key::Down));
548 menu.bind("menu_left", Binding::simple(Key::Left));
549 menu.bind("menu_right", Binding::simple(Key::Right));
550 menu.bind("menu_select", Binding::simple(Key::Enter));
551 menu.bind("menu_back", Binding::simple(Key::Escape));
552 menu.bind("menu_tab_next", Binding::simple(Key::Tab));
553 kb.add_group(menu);
554
555 kb
556}
557
558pub fn minimal_bindings() -> KeyBindings {
560 let mut kb = KeyBindings::new();
561 let mut g = BindingGroup::new("default", true);
562 g.bind("quit", Binding::simple(Key::Escape));
563 g.bind("accept", Binding::simple(Key::Enter));
564 g.bind("left", Binding::held(Key::Left));
565 g.bind("right", Binding::held(Key::Right));
566 g.bind("up", Binding::held(Key::Up));
567 g.bind("down", Binding::held(Key::Down));
568 g.bind_axis(AxisBinding::key_pair("h_axis", Key::Left, Key::Right));
569 g.bind_axis(AxisBinding::key_pair("v_axis", Key::Down, Key::Up));
570 kb.add_group(g);
571 kb
572}
573
574#[derive(Debug, Clone)]
578pub struct BindingConflict {
579 pub group: String,
580 pub action_a: Action,
581 pub action_b: Action,
582 pub key: Key,
583}
584
585pub fn detect_conflicts(kb: &KeyBindings) -> Vec<BindingConflict> {
587 let mut conflicts = Vec::new();
588 for group in &kb.groups {
589 let actions: Vec<(&Action, &Vec<Binding>)> = group.bindings.iter().collect();
590 for i in 0..actions.len() {
591 for j in (i + 1)..actions.len() {
592 let (a_name, a_binds) = &actions[i];
593 let (b_name, b_binds) = &actions[j];
594 for ab in *a_binds {
595 for bb in *b_binds {
596 if ab.key == bb.key && ab.modifiers == bb.modifiers && ab.trigger == bb.trigger {
597 conflicts.push(BindingConflict {
598 group: group.name.clone(),
599 action_a: (*a_name).clone(),
600 action_b: (*b_name).clone(),
601 key: ab.key,
602 });
603 }
604 }
605 }
606 }
607 }
608 }
609 conflicts
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple_binding_activates() {
        let mut kb = KeyBindings::new();
        let mut g = BindingGroup::new("default", true);
        g.bind("jump", Binding::simple(Key::Space));
        kb.add_group(g);

        let mut just_pressed = HashSet::new();
        just_pressed.insert(Key::Space);
        let held: HashSet<Key> = just_pressed.clone();
        let just_released: HashSet<Key> = HashSet::new();

        assert!(kb.is_active("jump", Trigger::JustPressed, &held, &just_pressed, &just_released));
    }

    #[test]
    fn held_binding_fires_while_held() {
        let mut kb = KeyBindings::new();
        let mut g = BindingGroup::new("default", true);
        g.bind("run", Binding::held(Key::Space));
        kb.add_group(g);

        let mut held = HashSet::new();
        held.insert(Key::Space);
        let jp: HashSet<Key> = HashSet::new();
        let jr: HashSet<Key> = HashSet::new();

        assert!(kb.is_active("run", Trigger::Held, &held, &jp, &jr));
    }

    #[test]
    fn disabled_group_does_not_fire() {
        let mut kb = KeyBindings::new();
        let mut g = BindingGroup::new("menu", false);
        g.bind("select", Binding::simple(Key::Enter));
        kb.add_group(g);

        let mut jp = HashSet::new();
        jp.insert(Key::Enter);
        let held: HashSet<Key> = jp.clone();
        let jr: HashSet<Key> = HashSet::new();

        assert!(!kb.is_active("select", Trigger::JustPressed, &held, &jp, &jr));
    }

    #[test]
    fn chord_fires_on_sequence() {
        let mut kb = KeyBindings::new();
        let g = BindingGroup::new("default", true);
        kb.add_group(g);
        let chord = ChordBinding {
            action: "konami".into(),
            sequence: vec![Key::Up, Key::Up, Key::Down],
            time_window: 1.0,
        };
        kb.chord_trackers.push(ChordTracker::new(chord));

        let empty: HashSet<Key> = HashSet::new();
        let held: HashSet<Key> = HashSet::new();

        let mut jp = HashSet::new();
        jp.insert(Key::Up);
        kb.update(&held, &jp, &empty, (0.0, 0.0), 0.0, 0.0, 0.016);
        jp.clear();
        jp.insert(Key::Up);
        kb.update(&held, &jp, &empty, (0.0, 0.0), 0.0, 0.5, 0.016);
        jp.clear();
        jp.insert(Key::Down);
        kb.update(&held, &jp, &empty, (0.0, 0.0), 0.0, 1.0, 0.016);

        assert!(kb.chord_just_fired("konami"));
    }

    /// The shipped defaults bind Escape to both `back` and `menu` in `"gameplay"`.
    /// That must stay the only clash, so any new accidental overlap fails here.
    /// (The previous version of this test discarded the result and asserted nothing.)
    #[test]
    fn default_bindings_only_conflict_is_shared_escape() {
        let kb = chaos_rpg_defaults();
        let conflicts = detect_conflicts(&kb);
        assert_eq!(conflicts.len(), 1, "unexpected conflicts: {conflicts:?}");

        let c = &conflicts[0];
        assert_eq!(c.group, "gameplay");
        assert_eq!(c.key, Key::Escape);
        let mut pair = [c.action_a.as_str(), c.action_b.as_str()];
        pair.sort_unstable();
        assert_eq!(pair, ["back", "menu"]);
    }

    #[test]
    fn key_pair_axis_follows_held_keys() {
        let mut kb = minimal_bindings();
        let mut held = HashSet::new();
        held.insert(Key::Right);
        let empty: HashSet<Key> = HashSet::new();

        kb.update(&held, &empty, &empty, (0.0, 0.0), 0.0, 0.0, 0.016);

        assert_eq!(kb.axis("h_axis"), 1.0);
        assert_eq!(kb.axis("v_axis"), 0.0);
        assert_eq!(kb.axis("missing"), 0.0);
    }

    #[test]
    fn remap_changes_key_and_label() {
        let mut kb = KeyBindings::new();
        kb.bind("jump", Key::Space);
        kb.remap("jump", Key::Enter);

        assert_eq!(kb.key_for("jump"), Some(Key::Enter));
        let expected = format!("{:?}", Key::Enter);
        assert_eq!(kb.display_binding("jump"), Some(expected.as_str()));
    }
}