Skip to main content

astrelis_input/
lib.rs

1//! Input state management for tracking keyboard, mouse, and gamepad state.
2//!
3//! This module provides a unified input system that tracks the current state
4//! of input devices, making it easy to query whether keys or buttons are
5//! pressed, just pressed, or just released.
6//!
7//! # Example
8//!
9//! ```ignore
10//! use astrelis_input::InputState;
11//!
12//! let mut input = InputState::new();
13//!
14//! // In your event loop:
15//! input.handle_events(&mut events);
16//!
17//! // Query input state:
18//! if input.is_key_pressed(KeyCode::Space) {
19//!     player.jump();
20//! }
21//!
22//! if input.is_key_just_pressed(KeyCode::Escape) {
23//!     game.pause();
24//! }
25//!
26//! let mouse_delta = input.mouse_delta();
27//!
28//! // At the end of each frame:
29//! input.end_frame();
30//! ```
31
32use astrelis_core::alloc::HashSet;
33use astrelis_core::math::Vec2;
34use astrelis_core::profiling::profile_function;
35use astrelis_winit::event::{
36    ElementState, Event, EventBatch, HandleStatus, KeyCode, MouseButton as WinitMouseButton,
37    MouseScrollDelta, PhysicalKey,
38};
39
// Re-export winit's `KeyCode` under the shorter alias `Key` for convenience.
41pub use astrelis_winit::event::KeyCode as Key;
42
/// Mouse button identifiers used by this crate's input queries.
///
/// Values are produced from winit's mouse button type via `From`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MouseButton {
    /// The left mouse button.
    Left,
    /// The right mouse button.
    Right,
    /// The middle (wheel) mouse button.
    Middle,
    /// The "back" side button.
    Back,
    /// The "forward" side button.
    Forward,
    /// Any other button, identified by its raw numeric id.
    Other(u16),
}
53
54impl From<WinitMouseButton> for MouseButton {
55    fn from(button: WinitMouseButton) -> Self {
56        match button {
57            WinitMouseButton::Left => MouseButton::Left,
58            WinitMouseButton::Right => MouseButton::Right,
59            WinitMouseButton::Middle => MouseButton::Middle,
60            WinitMouseButton::Back => MouseButton::Back,
61            WinitMouseButton::Forward => MouseButton::Forward,
62            WinitMouseButton::Other(id) => MouseButton::Other(id),
63        }
64    }
65}
66
/// Snapshot of which modifier keys are currently held.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Modifiers {
    /// Shift is held.
    pub shift: bool,
    /// Control is held.
    pub ctrl: bool,
    /// Alt (Option on macOS) is held.
    pub alt: bool,
    /// Meta is held: Command on macOS, Windows key on Windows.
    pub meta: bool,
}

impl Modifiers {
    /// Create a modifier snapshot in which every key is released.
    pub fn new() -> Self {
        Self {
            shift: false,
            ctrl: false,
            alt: false,
            meta: false,
        }
    }

    /// Returns `true` when at least one modifier is held.
    pub fn any(&self) -> bool {
        [self.shift, self.ctrl, self.alt, self.meta].contains(&true)
    }

    /// Returns `true` when every modifier is released.
    pub fn none(&self) -> bool {
        *self == Self::new()
    }
}
92
93/// Input state tracker for keyboard, mouse, and other input devices.
94///
95/// This struct tracks the current state of all input devices and provides
96/// convenient methods for querying input state. It distinguishes between
97/// keys/buttons that are currently held, just pressed this frame, or just
98/// released this frame.
99#[derive(Debug)]
100pub struct InputState {
101    // Keyboard state
102    keys_pressed: HashSet<KeyCode>,
103    keys_just_pressed: HashSet<KeyCode>,
104    keys_just_released: HashSet<KeyCode>,
105    modifiers: Modifiers,
106
107    // Mouse state
108    mouse_buttons_pressed: HashSet<MouseButton>,
109    mouse_buttons_just_pressed: HashSet<MouseButton>,
110    mouse_buttons_just_released: HashSet<MouseButton>,
111    mouse_position: Vec2,
112    mouse_position_prev: Vec2,
113    mouse_delta: Vec2,
114    scroll_delta: Vec2,
115    mouse_in_window: bool,
116
117    // Text input
118    text_input: String,
119}
120
121impl InputState {
122    /// Create a new input state tracker.
123    pub fn new() -> Self {
124        Self {
125            keys_pressed: HashSet::new(),
126            keys_just_pressed: HashSet::new(),
127            keys_just_released: HashSet::new(),
128            modifiers: Modifiers::new(),
129
130            mouse_buttons_pressed: HashSet::new(),
131            mouse_buttons_just_pressed: HashSet::new(),
132            mouse_buttons_just_released: HashSet::new(),
133            mouse_position: Vec2::ZERO,
134            mouse_position_prev: Vec2::ZERO,
135            mouse_delta: Vec2::ZERO,
136            scroll_delta: Vec2::ZERO,
137            mouse_in_window: false,
138
139            text_input: String::new(),
140        }
141    }
142
143    /// Process events from the event batch.
144    ///
145    /// This should be called each frame before querying input state.
146    pub fn handle_events(&mut self, events: &mut EventBatch) {
147        profile_function!();
148        events.dispatch(|event| {
149            match event {
150                Event::KeyInput(key_event) => {
151                    if let PhysicalKey::Code(key_code) = key_event.physical_key {
152                        match key_event.state {
153                            ElementState::Pressed => {
154                                if !key_event.repeat {
155                                    self.keys_just_pressed.insert(key_code);
156                                }
157                                self.keys_pressed.insert(key_code);
158                                self.update_modifiers(key_code, true);
159                            }
160                            ElementState::Released => {
161                                self.keys_just_released.insert(key_code);
162                                self.keys_pressed.remove(&key_code);
163                                self.update_modifiers(key_code, false);
164                            }
165                        }
166                    }
167
168                    // Collect text input
169                    if key_event.state == ElementState::Pressed
170                        && let Some(ref text) = key_event.text {
171                            self.text_input.push_str(text);
172                        }
173
174                    HandleStatus::handled()
175                }
176                Event::MouseButtonDown(button) => {
177                    let button = MouseButton::from(*button);
178                    self.mouse_buttons_just_pressed.insert(button);
179                    self.mouse_buttons_pressed.insert(button);
180                    HandleStatus::handled()
181                }
182                Event::MouseButtonUp(button) => {
183                    let button = MouseButton::from(*button);
184                    self.mouse_buttons_just_released.insert(button);
185                    self.mouse_buttons_pressed.remove(&button);
186                    HandleStatus::handled()
187                }
188                Event::MouseMoved(pos) => {
189                    self.mouse_position = Vec2::new(pos.x as f32, pos.y as f32);
190                    HandleStatus::handled()
191                }
192                Event::MouseScrolled(delta) => {
193                    let (dx, dy) = match delta {
194                        MouseScrollDelta::LineDelta(x, y) => (*x, *y),
195                        MouseScrollDelta::PixelDelta(pos) => (pos.x as f32, pos.y as f32),
196                    };
197                    self.scroll_delta = Vec2::new(dx, dy);
198                    HandleStatus::handled()
199                }
200                Event::MouseEntered => {
201                    self.mouse_in_window = true;
202                    HandleStatus::handled()
203                }
204                Event::MouseLeft => {
205                    self.mouse_in_window = false;
206                    HandleStatus::handled()
207                }
208                _ => HandleStatus::ignored(),
209            }
210        });
211
212        // Calculate mouse delta
213        self.mouse_delta = self.mouse_position - self.mouse_position_prev;
214    }
215
216    /// Clear per-frame state. Call this at the end of each frame.
217    pub fn end_frame(&mut self) {
218        profile_function!();
219        self.keys_just_pressed.clear();
220        self.keys_just_released.clear();
221        self.mouse_buttons_just_pressed.clear();
222        self.mouse_buttons_just_released.clear();
223        self.mouse_position_prev = self.mouse_position;
224        self.mouse_delta = Vec2::ZERO;
225        self.scroll_delta = Vec2::ZERO;
226        self.text_input.clear();
227    }
228
229    // ==================== Keyboard Queries ====================
230
231    /// Check if a key is currently pressed (held down).
232    pub fn is_key_pressed(&self, key: KeyCode) -> bool {
233        self.keys_pressed.contains(&key)
234    }
235
236    /// Check if a key was just pressed this frame.
237    pub fn is_key_just_pressed(&self, key: KeyCode) -> bool {
238        self.keys_just_pressed.contains(&key)
239    }
240
241    /// Check if a key was just released this frame.
242    pub fn is_key_just_released(&self, key: KeyCode) -> bool {
243        self.keys_just_released.contains(&key)
244    }
245
246    /// Check if any of the given keys are pressed.
247    pub fn is_any_key_pressed(&self, keys: &[KeyCode]) -> bool {
248        keys.iter().any(|k| self.is_key_pressed(*k))
249    }
250
251    /// Check if all of the given keys are pressed.
252    pub fn are_all_keys_pressed(&self, keys: &[KeyCode]) -> bool {
253        keys.iter().all(|k| self.is_key_pressed(*k))
254    }
255
256    /// Get the current modifier key state.
257    pub fn modifiers(&self) -> Modifiers {
258        self.modifiers
259    }
260
261    /// Check if Shift is held.
262    pub fn is_shift_pressed(&self) -> bool {
263        self.modifiers.shift
264    }
265
266    /// Check if Ctrl (or Cmd on macOS) is held.
267    pub fn is_ctrl_pressed(&self) -> bool {
268        self.modifiers.ctrl
269    }
270
271    /// Check if Alt (or Option on macOS) is held.
272    pub fn is_alt_pressed(&self) -> bool {
273        self.modifiers.alt
274    }
275
276    /// Check if Meta (Windows key or Cmd on macOS) is held.
277    pub fn is_meta_pressed(&self) -> bool {
278        self.modifiers.meta
279    }
280
281    /// Get text input received this frame.
282    pub fn text_input(&self) -> &str {
283        &self.text_input
284    }
285
286    /// Get all keys currently pressed.
287    pub fn pressed_keys(&self) -> impl Iterator<Item = &KeyCode> {
288        self.keys_pressed.iter()
289    }
290
291    // ==================== Mouse Queries ====================
292
293    /// Check if a mouse button is currently pressed.
294    pub fn is_mouse_button_pressed(&self, button: MouseButton) -> bool {
295        self.mouse_buttons_pressed.contains(&button)
296    }
297
298    /// Check if a mouse button was just pressed this frame.
299    pub fn is_mouse_button_just_pressed(&self, button: MouseButton) -> bool {
300        self.mouse_buttons_just_pressed.contains(&button)
301    }
302
303    /// Check if a mouse button was just released this frame.
304    pub fn is_mouse_button_just_released(&self, button: MouseButton) -> bool {
305        self.mouse_buttons_just_released.contains(&button)
306    }
307
308    /// Check if left mouse button is pressed.
309    pub fn is_left_mouse_pressed(&self) -> bool {
310        self.is_mouse_button_pressed(MouseButton::Left)
311    }
312
313    /// Check if left mouse button was just pressed.
314    pub fn is_left_mouse_just_pressed(&self) -> bool {
315        self.is_mouse_button_just_pressed(MouseButton::Left)
316    }
317
318    /// Check if right mouse button is pressed.
319    pub fn is_right_mouse_pressed(&self) -> bool {
320        self.is_mouse_button_pressed(MouseButton::Right)
321    }
322
323    /// Check if right mouse button was just pressed.
324    pub fn is_right_mouse_just_pressed(&self) -> bool {
325        self.is_mouse_button_just_pressed(MouseButton::Right)
326    }
327
328    /// Check if middle mouse button is pressed.
329    pub fn is_middle_mouse_pressed(&self) -> bool {
330        self.is_mouse_button_pressed(MouseButton::Middle)
331    }
332
333    /// Get the current mouse position in window coordinates.
334    pub fn mouse_position(&self) -> Vec2 {
335        self.mouse_position
336    }
337
338    /// Get the mouse movement delta since last frame.
339    pub fn mouse_delta(&self) -> Vec2 {
340        self.mouse_delta
341    }
342
343    /// Get the scroll wheel delta since last frame.
344    ///
345    /// Positive Y = scroll up, Negative Y = scroll down.
346    pub fn scroll_delta(&self) -> Vec2 {
347        self.scroll_delta
348    }
349
350    /// Check if the mouse cursor is inside the window.
351    pub fn is_mouse_in_window(&self) -> bool {
352        self.mouse_in_window
353    }
354
355    // ==================== Helper Methods ====================
356
357    /// Get horizontal input axis (-1, 0, or 1) from arrow keys or WASD.
358    pub fn horizontal_axis(&self) -> f32 {
359        let mut axis = 0.0;
360        if self.is_key_pressed(KeyCode::ArrowLeft) || self.is_key_pressed(KeyCode::KeyA) {
361            axis -= 1.0;
362        }
363        if self.is_key_pressed(KeyCode::ArrowRight) || self.is_key_pressed(KeyCode::KeyD) {
364            axis += 1.0;
365        }
366        axis
367    }
368
369    /// Get vertical input axis (-1, 0, or 1) from arrow keys or WASD.
370    pub fn vertical_axis(&self) -> f32 {
371        let mut axis = 0.0;
372        if self.is_key_pressed(KeyCode::ArrowUp) || self.is_key_pressed(KeyCode::KeyW) {
373            axis -= 1.0;
374        }
375        if self.is_key_pressed(KeyCode::ArrowDown) || self.is_key_pressed(KeyCode::KeyS) {
376            axis += 1.0;
377        }
378        axis
379    }
380
381    /// Get movement direction as a normalized vector.
382    pub fn movement_direction(&self) -> Vec2 {
383        let dir = Vec2::new(self.horizontal_axis(), self.vertical_axis());
384        if dir.length_squared() > 0.0 {
385            dir.normalize()
386        } else {
387            dir
388        }
389    }
390
391    /// Reset all input state.
392    pub fn reset(&mut self) {
393        self.keys_pressed.clear();
394        self.keys_just_pressed.clear();
395        self.keys_just_released.clear();
396        self.modifiers = Modifiers::new();
397        self.mouse_buttons_pressed.clear();
398        self.mouse_buttons_just_pressed.clear();
399        self.mouse_buttons_just_released.clear();
400        self.mouse_delta = Vec2::ZERO;
401        self.scroll_delta = Vec2::ZERO;
402        self.text_input.clear();
403    }
404
405    // ==================== Internal Methods ====================
406
407    fn update_modifiers(&mut self, key: KeyCode, pressed: bool) {
408        match key {
409            KeyCode::ShiftLeft | KeyCode::ShiftRight => self.modifiers.shift = pressed,
410            KeyCode::ControlLeft | KeyCode::ControlRight => self.modifiers.ctrl = pressed,
411            KeyCode::AltLeft | KeyCode::AltRight => self.modifiers.alt = pressed,
412            KeyCode::SuperLeft | KeyCode::SuperRight | KeyCode::Meta => self.modifiers.meta = pressed,
413            _ => {}
414        }
415    }
416}
417
418impl Default for InputState {
419    fn default() -> Self {
420        Self::new()
421    }
422}
423
424/// An input system that wraps InputState and provides additional functionality.
425pub struct InputSystem {
426    state: InputState,
427}
428
429impl InputSystem {
430    /// Create a new input system.
431    pub fn new() -> Self {
432        Self {
433            state: InputState::new(),
434        }
435    }
436
437    /// Get the input state.
438    pub fn state(&self) -> &InputState {
439        &self.state
440    }
441
442    /// Get mutable access to the input state.
443    pub fn state_mut(&mut self) -> &mut InputState {
444        &mut self.state
445    }
446
447    /// Process events from the event batch.
448    pub fn handle_events(&mut self, events: &mut EventBatch) {
449        profile_function!();
450        self.state.handle_events(events);
451    }
452
453    /// Clear per-frame state.
454    pub fn end_frame(&mut self) {
455        self.state.end_frame();
456    }
457}
458
459impl Default for InputSystem {
460    fn default() -> Self {
461        Self::new()
462    }
463}
464
465impl std::ops::Deref for InputSystem {
466    type Target = InputState;
467
468    fn deref(&self) -> &Self::Target {
469        &self.state
470    }
471}
472
473impl std::ops::DerefMut for InputSystem {
474    fn deref_mut(&mut self) -> &mut Self::Target {
475        &mut self.state
476    }
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_modifiers_default() {
        let mods = Modifiers::new();
        assert!(!mods.any());
        assert!(mods.none());
    }

    #[test]
    fn test_modifiers_any() {
        let mut mods = Modifiers::new();
        mods.shift = true;
        assert!(mods.any());
        assert!(!mods.none());
    }

    #[test]
    fn test_input_state_new() {
        let state = InputState::new();
        assert!(!state.is_key_pressed(KeyCode::Space));
        assert!(!state.is_left_mouse_pressed());
        assert_eq!(state.mouse_position(), Vec2::ZERO);
        assert_eq!(state.scroll_delta(), Vec2::ZERO);
        assert!(!state.is_mouse_in_window());
        assert!(state.modifiers().none());
    }

    #[test]
    fn test_movement_direction_normalized() {
        let mut state = InputState::new();
        state.keys_pressed.insert(KeyCode::KeyW);
        state.keys_pressed.insert(KeyCode::KeyD);

        let dir = state.movement_direction();
        let len = dir.length();
        assert!((len - 1.0).abs() < 0.001, "Direction should be normalized");
    }

    // Up is negative and right is positive (y-down convention).
    #[test]
    fn test_axes_sign_convention() {
        let mut state = InputState::new();
        state.keys_pressed.insert(KeyCode::KeyW);
        assert_eq!(state.vertical_axis(), -1.0);
        state.keys_pressed.insert(KeyCode::ArrowRight);
        assert_eq!(state.horizontal_axis(), 1.0);
    }

    // Opposing directions cancel to zero, so there is no NaN from normalize.
    #[test]
    fn test_opposing_keys_cancel() {
        let mut state = InputState::new();
        state.keys_pressed.insert(KeyCode::ArrowLeft);
        state.keys_pressed.insert(KeyCode::KeyD);
        assert_eq!(state.horizontal_axis(), 0.0);
        assert_eq!(state.movement_direction(), Vec2::ZERO);
    }

    #[test]
    fn test_end_frame_clears_per_frame_state_only() {
        let mut state = InputState::new();
        state.keys_pressed.insert(KeyCode::Space);
        state.keys_just_pressed.insert(KeyCode::Space);
        state.mouse_buttons_just_released.insert(MouseButton::Left);
        state.scroll_delta = Vec2::new(0.0, 2.0);
        state.text_input.push('a');

        state.end_frame();

        assert!(state.is_key_pressed(KeyCode::Space));
        assert!(!state.is_key_just_pressed(KeyCode::Space));
        assert!(!state.is_mouse_button_just_released(MouseButton::Left));
        assert_eq!(state.scroll_delta(), Vec2::ZERO);
        assert!(state.text_input().is_empty());
    }

    #[test]
    fn test_reset_clears_held_state() {
        let mut state = InputState::new();
        state.keys_pressed.insert(KeyCode::ShiftLeft);
        state.modifiers.shift = true;
        state.mouse_buttons_pressed.insert(MouseButton::Right);

        state.reset();

        assert!(!state.is_key_pressed(KeyCode::ShiftLeft));
        assert!(!state.is_shift_pressed());
        assert!(!state.is_right_mouse_pressed());
    }

    #[test]
    fn test_input_system_derefs_to_state() {
        let mut system = InputSystem::new();
        system.state_mut().keys_pressed.insert(KeyCode::KeyA);
        assert!(system.is_key_pressed(KeyCode::KeyA));
        assert_eq!(system.horizontal_axis(), -1.0);
    }
}