1#[cfg(not(target_arch = "wasm32"))]
2use glutin::event::{ElementState, MouseButton, MouseScrollDelta, VirtualKeyCode, WindowEvent};
3use std::{
4 borrow::Cow,
5 cmp::Ordering,
6 collections::HashMap,
7 sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard},
8};
9use typid::ID;
10#[cfg(target_arch = "wasm32")]
11use winit::event::{ElementState, MouseButton, MouseScrollDelta, VirtualKeyCode, WindowEvent};
12
/// Controls whether a mapping stops an event from reaching the mappings
/// below it in an `InputContext` stack.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum InputConsume {
    /// Never stop propagation; lower mappings always see the event.
    #[default]
    None,
    /// Stop propagation only when one of this mapping's bindings handled the event.
    Hit,
    /// Always stop propagation, whether or not a binding matched.
    All,
}
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
22pub enum VirtualAction {
23 KeyButton(VirtualKeyCode),
24 MouseButton(MouseButton),
25 Axis(u32),
26}
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
29pub enum VirtualAxis {
30 KeyButton(VirtualKeyCode),
31 MousePositionX,
32 MousePositionY,
33 MouseWheelX,
34 MouseWheelY,
35 MouseButton(MouseButton),
36 Axis(u32),
37}
38
/// Frame-by-frame state of a digital (on/off) input.
///
/// A button cycles `Idle -> Pressed -> Hold -> Released -> Idle`; `Pressed`
/// and `Released` are edge states that `update` turns into `Hold`/`Idle`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum InputAction {
    #[default]
    Idle,
    Pressed,
    Hold,
    Released,
}

impl InputAction {
    /// Returns the next state given whether the input is currently held.
    pub fn change(self, hold: bool) -> Self {
        if hold {
            match self {
                Self::Idle | Self::Released => Self::Pressed,
                Self::Pressed | Self::Hold => Self::Hold,
            }
        } else {
            match self {
                Self::Pressed | Self::Hold => Self::Released,
                Self::Released | Self::Idle => Self::Idle,
            }
        }
    }

    /// Settles edge states: `Pressed` becomes `Hold` and `Released` becomes
    /// `Idle`; steady states are returned unchanged.
    pub fn update(self) -> Self {
        match self {
            Self::Pressed | Self::Hold => Self::Hold,
            Self::Released | Self::Idle => Self::Idle,
        }
    }

    /// `true` only in the `Idle` state.
    pub fn is_idle(self) -> bool {
        self == Self::Idle
    }

    /// `true` only in the `Pressed` state.
    pub fn is_pressed(self) -> bool {
        self == Self::Pressed
    }

    /// `true` only in the `Hold` state.
    pub fn is_hold(self) -> bool {
        self == Self::Hold
    }

    /// `true` only in the `Released` state.
    pub fn is_released(self) -> bool {
        self == Self::Released
    }

    /// `true` while the input is not held (`Idle` or `Released`).
    pub fn is_up(self) -> bool {
        !self.is_down()
    }

    /// `true` while the input is held (`Pressed` or `Hold`).
    pub fn is_down(self) -> bool {
        self.is_pressed() || self.is_hold()
    }

    /// `true` on an edge state (`Pressed` or `Released`).
    pub fn is_changing(self) -> bool {
        self.is_pressed() || self.is_released()
    }

    /// `true` on a steady state (`Idle` or `Hold`).
    pub fn is_continuing(self) -> bool {
        !self.is_changing()
    }

    /// Maps the state to `truthy` while held and to `falsy` otherwise.
    pub fn to_scalar(self, falsy: f32, truthy: f32) -> f32 {
        if self.is_up() {
            falsy
        } else {
            truthy
        }
    }
}
103
/// Analog input value such as a cursor coordinate or a device axis reading.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct InputAxis(pub f32);

impl InputAxis {
    /// Returns `true` when the axis value is greater than or equal to `value`.
    pub fn threshold(self, value: f32) -> bool {
        let Self(current) = self;
        current >= value
    }
}
112
/// Shared, lock-protected handle to a piece of input state.
///
/// The value lives behind an `Arc<RwLock<_>>`, so every clone of a handle
/// observes the same data. A poisoned lock is treated as unavailable:
/// `read`/`write` return `None`, `get` falls back to `T::default()` and
/// `set` does nothing.
#[derive(Debug, Default, Clone)]
pub struct InputRef<T: Default + Clone>(Arc<RwLock<T>>);

impl<T: Default + Clone> InputRef<T> {
    /// Wraps `data` in a fresh shared handle.
    pub fn new(data: T) -> Self {
        let lock = RwLock::new(data);
        Self(Arc::new(lock))
    }

    /// Acquires shared read access, or `None` if the lock is poisoned.
    pub fn read(&self) -> Option<RwLockReadGuard<T>> {
        match self.0.read() {
            Ok(guard) => Some(guard),
            Err(_) => None,
        }
    }

    /// Acquires exclusive write access, or `None` if the lock is poisoned.
    pub fn write(&self) -> Option<RwLockWriteGuard<T>> {
        match self.0.write() {
            Ok(guard) => Some(guard),
            Err(_) => None,
        }
    }

    /// Returns a copy of the current value (`T::default()` if unreadable).
    pub fn get(&self) -> T {
        match self.read() {
            Some(guard) => (*guard).clone(),
            None => T::default(),
        }
    }

    /// Replaces the current value; silently ignored if the lock is poisoned.
    pub fn set(&self, value: T) {
        match self.write() {
            Some(mut guard) => *guard = value,
            None => {}
        }
    }
}
139
140pub type InputActionRef = InputRef<InputAction>;
141pub type InputAxisRef = InputRef<InputAxis>;
142pub type InputCharactersRef = InputRef<InputCharacters>;
143pub type InputMappingRef = InputRef<InputMapping>;
144
145#[derive(Debug, Default, Clone)]
146pub enum InputActionOrAxisRef {
147 #[default]
148 None,
149 Action(InputActionRef),
150 Axis(InputAxisRef),
151}
152
153impl InputActionOrAxisRef {
154 pub fn is_none(&self) -> bool {
155 matches!(self, Self::None)
156 }
157
158 pub fn is_some(&self) -> bool {
159 !self.is_none()
160 }
161
162 pub fn get_scalar(&self, falsy: f32, truthy: f32) -> f32 {
163 match self {
164 Self::None => falsy,
165 Self::Action(action) => action.get().to_scalar(falsy, truthy),
166 Self::Axis(axis) => axis.get().0,
167 }
168 }
169
170 pub fn threshold(&self, value: f32) -> bool {
171 match self {
172 Self::None => false,
173 Self::Action(action) => action.get().is_down(),
174 Self::Axis(axis) => axis.get().threshold(value),
175 }
176 }
177}
178
179impl From<InputActionRef> for InputActionOrAxisRef {
180 fn from(value: InputActionRef) -> Self {
181 Self::Action(value)
182 }
183}
184
185impl From<InputAxisRef> for InputActionOrAxisRef {
186 fn from(value: InputAxisRef) -> Self {
187 Self::Axis(value)
188 }
189}
190
/// Value derived on demand from other inputs by a user-supplied closure.
pub struct InputCombinator<T> {
    mapper: Box<dyn Fn() -> T>,
}

impl<T: Default> Default for InputCombinator<T> {
    /// A combinator that always yields `T::default()`.
    fn default() -> Self {
        Self::new(|| T::default())
    }
}

impl<T> InputCombinator<T> {
    /// Wraps `mapper`; it is invoked anew on every `get` call.
    pub fn new(mapper: impl Fn() -> T + 'static) -> Self {
        let mapper: Box<dyn Fn() -> T> = Box::new(mapper);
        Self { mapper }
    }

    /// Evaluates the mapper and returns its result.
    pub fn get(&self) -> T {
        let compute = &self.mapper;
        compute()
    }
}
212
213#[derive(Default)]
214pub struct CardinalInputCombinator(InputCombinator<[f32; 2]>);
215
216impl CardinalInputCombinator {
217 pub fn new(
218 left: impl Into<InputActionOrAxisRef>,
219 right: impl Into<InputActionOrAxisRef>,
220 up: impl Into<InputActionOrAxisRef>,
221 down: impl Into<InputActionOrAxisRef>,
222 ) -> Self {
223 let left = left.into();
224 let right = right.into();
225 let up = up.into();
226 let down = down.into();
227 Self(InputCombinator::new(move || {
228 let left = left.get_scalar(0.0, -1.0);
229 let right = right.get_scalar(0.0, 1.0);
230 let up = up.get_scalar(0.0, -1.0);
231 let down = down.get_scalar(0.0, 1.0);
232 [left + right, up + down]
233 }))
234 }
235
236 pub fn get(&self) -> [f32; 2] {
237 self.0.get()
238 }
239}
240
241#[derive(Default)]
242pub struct DualInputCombinator(InputCombinator<f32>);
243
244impl DualInputCombinator {
245 pub fn new(
246 negative: impl Into<InputActionOrAxisRef>,
247 positive: impl Into<InputActionOrAxisRef>,
248 ) -> Self {
249 let negative = negative.into();
250 let positive = positive.into();
251 Self(InputCombinator::new(move || {
252 let negative = negative.get_scalar(0.0, -1.0);
253 let positive = positive.get_scalar(0.0, 1.0);
254 negative + positive
255 }))
256 }
257
258 pub fn get(&self) -> f32 {
259 self.0.get()
260 }
261}
262
263pub struct ArrayInputCombinator<const N: usize>(InputCombinator<[f32; N]>);
264
265impl<const N: usize> Default for ArrayInputCombinator<N> {
266 fn default() -> Self {
267 Self(InputCombinator::new(|| {
268 std::array::from_fn(|_| Default::default())
269 }))
270 }
271}
272
273impl<const N: usize> ArrayInputCombinator<N> {
274 pub fn new(inputs: [impl Into<InputActionOrAxisRef>; N]) -> Self {
275 let mut items = std::array::from_fn::<InputActionOrAxisRef, N, _>(|_| Default::default());
276 for (index, input) in inputs.into_iter().enumerate() {
277 items[index] = input.into();
278 }
279 Self(InputCombinator::new(move || {
280 std::array::from_fn(|index| items[index].get_scalar(0.0, 1.0))
281 }))
282 }
283
284 pub fn get(&self) -> [f32; N] {
285 self.0.get()
286 }
287}
288
/// Buffer of characters received from the window since it was last drained.
#[derive(Debug, Default, Clone)]
pub struct InputCharacters {
    characters: String,
}

impl InputCharacters {
    /// Borrows the buffered text.
    pub fn read(&self) -> &str {
        self.characters.as_str()
    }

    /// Mutably borrows the underlying buffer.
    pub fn write(&mut self) -> &mut String {
        let Self { characters } = self;
        characters
    }

    /// Drains the buffer, returning everything collected so far.
    pub fn take(&mut self) -> String {
        let buffer = &mut self.characters;
        std::mem::take(buffer)
    }
}
307
308#[derive(Debug, Default, Clone)]
309pub struct InputMapping {
310 pub actions: HashMap<VirtualAction, InputActionRef>,
311 pub axes: HashMap<VirtualAxis, InputAxisRef>,
312 pub consume: InputConsume,
313 pub layer: isize,
314 pub name: Cow<'static, str>,
315}
316
317impl InputMapping {
318 pub fn action(mut self, id: VirtualAction, action: InputActionRef) -> Self {
319 self.actions.insert(id, action);
320 self
321 }
322
323 pub fn axis(mut self, id: VirtualAxis, axis: InputAxisRef) -> Self {
324 self.axes.insert(id, axis);
325 self
326 }
327
328 pub fn consume(mut self, consume: InputConsume) -> Self {
329 self.consume = consume;
330 self
331 }
332
333 pub fn layer(mut self, value: isize) -> Self {
334 self.layer = value;
335 self
336 }
337
338 pub fn name(mut self, value: impl Into<Cow<'static, str>>) -> Self {
339 self.name = value.into();
340 self
341 }
342}
343
344impl From<InputMapping> for InputMappingRef {
345 fn from(value: InputMapping) -> Self {
346 Self::new(value)
347 }
348}
349
350#[derive(Debug, Clone)]
351pub struct InputContext {
352 pub mouse_wheel_line_scale: f32,
353 mappings_stack: Vec<(ID<InputMapping>, InputMappingRef)>,
355 characters: InputCharactersRef,
356}
357
358impl Default for InputContext {
359 fn default() -> Self {
360 Self {
361 mouse_wheel_line_scale: Self::default_mouse_wheel_line_scale(),
362 mappings_stack: Default::default(),
363 characters: Default::default(),
364 }
365 }
366}
367
368impl InputContext {
369 fn default_mouse_wheel_line_scale() -> f32 {
370 10.0
371 }
372
373 pub fn push_mapping(&mut self, mapping: impl Into<InputMappingRef>) -> ID<InputMapping> {
374 let mapping = mapping.into();
375 let id = ID::default();
376 let layer = mapping.read().unwrap().layer;
377 let index = self
378 .mappings_stack
379 .binary_search_by(|(_, mapping)| {
380 mapping
381 .read()
382 .unwrap()
383 .layer
384 .cmp(&layer)
385 .then(Ordering::Less)
386 })
387 .unwrap_or_else(|index| index);
388 self.mappings_stack.insert(index, (id, mapping));
389 id
390 }
391
392 pub fn pop_mapping(&mut self) -> Option<InputMappingRef> {
393 self.mappings_stack.pop().map(|(_, mapping)| mapping)
394 }
395
396 pub fn top_mapping(&self) -> Option<&InputMappingRef> {
397 self.mappings_stack.last().map(|(_, mapping)| mapping)
398 }
399
400 pub fn remove_mapping(&mut self, id: ID<InputMapping>) -> Option<InputMappingRef> {
401 self.mappings_stack
402 .iter()
403 .position(|(mid, _)| mid == &id)
404 .map(|index| self.mappings_stack.remove(index).1)
405 }
406
407 pub fn mapping(&self, id: ID<InputMapping>) -> Option<RwLockReadGuard<InputMapping>> {
408 self.mappings_stack
409 .iter()
410 .find(|(mid, _)| mid == &id)
411 .and_then(|(_, mapping)| mapping.read())
412 }
413
414 pub fn stack(&self) -> impl Iterator<Item = &InputMappingRef> {
415 self.mappings_stack.iter().map(|(_, mapping)| mapping)
416 }
417
418 pub fn characters(&self) -> InputCharactersRef {
419 self.characters.clone()
420 }
421
422 pub fn maintain(&mut self) {
423 for (_, mapping) in &mut self.mappings_stack {
424 if let Some(mut mapping) = mapping.write() {
425 for action in mapping.actions.values_mut() {
426 if let Some(mut action) = action.write() {
427 *action = action.update();
428 }
429 }
430 for (id, axis) in &mut mapping.axes {
431 if let VirtualAxis::MouseWheelX | VirtualAxis::MouseWheelY = id {
432 if let Some(mut axis) = axis.write() {
433 axis.0 = 0.0;
434 }
435 }
436 }
437 }
438 }
439 }
440
441 pub fn on_event(&mut self, event: &WindowEvent) {
442 match event {
443 WindowEvent::ReceivedCharacter(character) => {
444 if let Some(mut characters) = self.characters.write() {
445 characters.characters.push(*character);
446 }
447 }
448 WindowEvent::KeyboardInput { input, .. } => {
449 if let Some(key) = input.virtual_keycode {
450 for (_, mapping) in self.mappings_stack.iter().rev() {
451 if let Some(mapping) = mapping.read() {
452 let mut consume = mapping.consume == InputConsume::All;
453 for (id, data) in &mapping.actions {
454 if let VirtualAction::KeyButton(button) = id {
455 if *button == key {
456 if let Some(mut data) = data.write() {
457 *data =
458 data.change(input.state == ElementState::Pressed);
459 if mapping.consume == InputConsume::Hit {
460 consume = true;
461 }
462 }
463 }
464 }
465 }
466 for (id, data) in &mapping.axes {
467 if let VirtualAxis::KeyButton(button) = id {
468 if *button == key {
469 if let Some(mut data) = data.write() {
470 data.0 = if input.state == ElementState::Pressed {
471 1.0
472 } else {
473 0.0
474 };
475 if mapping.consume == InputConsume::Hit {
476 consume = true;
477 }
478 }
479 }
480 }
481 }
482 if consume {
483 break;
484 }
485 }
486 }
487 }
488 }
489 WindowEvent::CursorMoved { position, .. } => {
490 for (_, mapping) in self.mappings_stack.iter().rev() {
491 if let Some(mapping) = mapping.read() {
492 let mut consume = mapping.consume == InputConsume::All;
493 for (id, data) in &mapping.axes {
494 match id {
495 VirtualAxis::MousePositionX => {
496 if let Some(mut data) = data.write() {
497 data.0 = position.x as _;
498 if mapping.consume == InputConsume::Hit {
499 consume = true;
500 }
501 }
502 }
503 VirtualAxis::MousePositionY => {
504 if let Some(mut data) = data.write() {
505 data.0 = position.y as _;
506 if mapping.consume == InputConsume::Hit {
507 consume = true;
508 }
509 }
510 }
511 _ => {}
512 }
513 }
514 if consume {
515 break;
516 }
517 }
518 }
519 }
520 WindowEvent::MouseWheel { delta, .. } => {
521 for (_, mapping) in self.mappings_stack.iter().rev() {
522 if let Some(mapping) = mapping.read() {
523 let mut consume = mapping.consume == InputConsume::All;
524 for (id, data) in &mapping.axes {
525 match id {
526 VirtualAxis::MouseWheelX => {
527 if let Some(mut data) = data.write() {
528 data.0 = match delta {
529 MouseScrollDelta::LineDelta(x, _) => *x,
530 MouseScrollDelta::PixelDelta(pos) => pos.x as _,
531 };
532 if mapping.consume == InputConsume::Hit {
533 consume = true;
534 }
535 }
536 }
537 VirtualAxis::MouseWheelY => {
538 if let Some(mut data) = data.write() {
539 data.0 = match delta {
540 MouseScrollDelta::LineDelta(_, y) => *y,
541 MouseScrollDelta::PixelDelta(pos) => pos.y as _,
542 };
543 if mapping.consume == InputConsume::Hit {
544 consume = true;
545 }
546 }
547 }
548 _ => {}
549 }
550 }
551 if consume {
552 break;
553 }
554 }
555 }
556 }
557 WindowEvent::MouseInput { state, button, .. } => {
558 for (_, mapping) in self.mappings_stack.iter().rev() {
559 if let Some(mapping) = mapping.read() {
560 let mut consume = mapping.consume == InputConsume::All;
561 for (id, data) in &mapping.actions {
562 if let VirtualAction::MouseButton(btn) = id {
563 if button == btn {
564 if let Some(mut data) = data.write() {
565 *data = data.change(*state == ElementState::Pressed);
566 if mapping.consume == InputConsume::Hit {
567 consume = true;
568 }
569 }
570 }
571 }
572 }
573 for (id, data) in &mapping.axes {
574 if let VirtualAxis::MouseButton(btn) = id {
575 if button == btn {
576 if let Some(mut data) = data.write() {
577 data.0 = if *state == ElementState::Pressed {
578 1.0
579 } else {
580 0.0
581 };
582 if mapping.consume == InputConsume::Hit {
583 consume = true;
584 }
585 }
586 }
587 }
588 }
589 if consume {
590 break;
591 }
592 }
593 }
594 }
595 WindowEvent::AxisMotion { axis, value, .. } => {
596 for (_, mapping) in self.mappings_stack.iter().rev() {
597 if let Some(mapping) = mapping.read() {
598 let mut consume = mapping.consume == InputConsume::All;
599 for (id, data) in &mapping.actions {
600 if let VirtualAction::Axis(index) = id {
601 if axis == index {
602 if let Some(mut data) = data.write() {
603 *data = data.change(value.abs() > 0.5);
604 if mapping.consume == InputConsume::Hit {
605 consume = true;
606 }
607 }
608 }
609 }
610 }
611 for (id, data) in &mapping.axes {
612 if let VirtualAxis::Axis(index) = id {
613 if axis == index {
614 if let Some(mut data) = data.write() {
615 data.0 = *value as _;
616 if mapping.consume == InputConsume::Hit {
617 consume = true;
618 }
619 }
620 }
621 }
622 }
623 if consume {
624 break;
625 }
626 }
627 }
628 }
629 _ => {}
630 }
631 }
632}
633
#[cfg(test)]
mod tests {
    use crate::{InputContext, InputMapping};

    /// Mappings pushed on mixed layers must end up sorted by layer, keeping
    /// push order within each layer.
    #[test]
    fn test_stack() {
        let pushed: [(&'static str, isize); 11] = [
            ("a", 0),
            ("b", 0),
            ("c", 0),
            ("d", -1),
            ("e", 1),
            ("f", -1),
            ("g", 1),
            ("h", -2),
            ("i", -2),
            ("j", 2),
            ("k", 2),
        ];
        let mut context = InputContext::default();
        for &(name, layer) in &pushed {
            context.push_mapping(InputMapping::default().name(name).layer(layer));
        }

        let expected: Vec<(String, isize)> = [
            ("h", -2),
            ("i", -2),
            ("d", -1),
            ("f", -1),
            ("a", 0),
            ("b", 0),
            ("c", 0),
            ("e", 1),
            ("g", 1),
            ("j", 2),
            ("k", 2),
        ]
        .iter()
        .map(|&(name, layer)| (name.to_owned(), layer))
        .collect();

        let mut provided: Vec<(String, isize)> = Vec::new();
        for mapping in context.stack() {
            let guard = mapping.read().unwrap();
            provided.push((guard.name.to_string(), guard.layer));
        }

        assert_eq!(provided, expected);
    }
}