Skip to main content

astrelis_input/
lib.rs

//! Input state management for tracking keyboard, mouse, and text input.
//!
//! This module provides a unified input system that tracks the current state
//! of input devices, making it easy to query whether keys or buttons are
//! pressed, just pressed, or just released.
//!
//! # Example
//!
//! ```ignore
//! use astrelis_input::InputState;
//!
//! let mut input = InputState::new();
//!
//! // In your event loop:
//! input.handle_events(&mut events);
//!
//! // Query input state:
//! if input.is_key_pressed(KeyCode::Space) {
//!     player.jump();
//! }
//!
//! if input.is_key_just_pressed(KeyCode::Escape) {
//!     game.pause();
//! }
//!
//! let mouse_delta = input.mouse_delta();
//!
//! // At the end of each frame:
//! input.end_frame();
//! ```
31
32use astrelis_core::alloc::HashSet;
33use astrelis_core::math::Vec2;
34use astrelis_core::profiling::profile_function;
35use astrelis_winit::event::{
36    ElementState, Event, EventBatch, HandleStatus, KeyCode, MouseButton as WinitMouseButton,
37    MouseScrollDelta, PhysicalKey,
38};
39
40// Re-export KeyCode for convenience
41pub use astrelis_winit::event::KeyCode as Key;
42
/// Mouse button identifiers, independent of the windowing backend's own type.
///
/// Buttons without a dedicated variant are kept as [`MouseButton::Other`]
/// together with their numeric id.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MouseButton {
    /// The left mouse button.
    Left,
    /// The right mouse button.
    Right,
    /// The middle mouse button (often the wheel click).
    Middle,
    /// The "back" side button.
    Back,
    /// The "forward" side button.
    Forward,
    /// Any other button, identified by its numeric id.
    Other(u16),
}
53
54impl From<WinitMouseButton> for MouseButton {
55    fn from(button: WinitMouseButton) -> Self {
56        match button {
57            WinitMouseButton::Left => MouseButton::Left,
58            WinitMouseButton::Right => MouseButton::Right,
59            WinitMouseButton::Middle => MouseButton::Middle,
60            WinitMouseButton::Back => MouseButton::Back,
61            WinitMouseButton::Forward => MouseButton::Forward,
62            WinitMouseButton::Other(id) => MouseButton::Other(id),
63        }
64    }
65}
66
/// Snapshot of which modifier keys are held.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Modifiers {
    /// Shift is held.
    pub shift: bool,
    /// Control is held.
    pub ctrl: bool,
    /// Alt is held.
    pub alt: bool,
    /// Meta is held: Command on macOS, Windows key on Windows.
    pub meta: bool,
}

impl Modifiers {
    /// Returns a value with every modifier flag cleared (identical to `Default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when at least one modifier flag is set.
    pub fn any(&self) -> bool {
        [self.shift, self.ctrl, self.alt, self.meta].contains(&true)
    }

    /// Returns `true` when no modifier flag is set.
    pub fn none(&self) -> bool {
        !(self.shift || self.ctrl || self.alt || self.meta)
    }
}
92
93/// Input state tracker for keyboard, mouse, and other input devices.
94///
95/// This struct tracks the current state of all input devices and provides
96/// convenient methods for querying input state. It distinguishes between
97/// keys/buttons that are currently held, just pressed this frame, or just
98/// released this frame.
99#[derive(Debug)]
100pub struct InputState {
101    // Keyboard state
102    keys_pressed: HashSet<KeyCode>,
103    keys_just_pressed: HashSet<KeyCode>,
104    keys_just_released: HashSet<KeyCode>,
105    modifiers: Modifiers,
106
107    // Mouse state
108    mouse_buttons_pressed: HashSet<MouseButton>,
109    mouse_buttons_just_pressed: HashSet<MouseButton>,
110    mouse_buttons_just_released: HashSet<MouseButton>,
111    mouse_position: Vec2,
112    mouse_position_prev: Vec2,
113    mouse_delta: Vec2,
114    scroll_delta: Vec2,
115    mouse_in_window: bool,
116
117    // Text input
118    text_input: String,
119}
120
121impl InputState {
122    /// Create a new input state tracker.
123    pub fn new() -> Self {
124        Self {
125            keys_pressed: HashSet::new(),
126            keys_just_pressed: HashSet::new(),
127            keys_just_released: HashSet::new(),
128            modifiers: Modifiers::new(),
129
130            mouse_buttons_pressed: HashSet::new(),
131            mouse_buttons_just_pressed: HashSet::new(),
132            mouse_buttons_just_released: HashSet::new(),
133            mouse_position: Vec2::ZERO,
134            mouse_position_prev: Vec2::ZERO,
135            mouse_delta: Vec2::ZERO,
136            scroll_delta: Vec2::ZERO,
137            mouse_in_window: false,
138
139            text_input: String::new(),
140        }
141    }
142
143    /// Process events from the event batch.
144    ///
145    /// This should be called each frame before querying input state.
146    pub fn handle_events(&mut self, events: &mut EventBatch) {
147        profile_function!();
148        events.dispatch(|event| {
149            match event {
150                Event::KeyInput(key_event) => {
151                    if let PhysicalKey::Code(key_code) = key_event.physical_key {
152                        match key_event.state {
153                            ElementState::Pressed => {
154                                if !key_event.repeat {
155                                    self.keys_just_pressed.insert(key_code);
156                                }
157                                self.keys_pressed.insert(key_code);
158                                self.update_modifiers(key_code, true);
159                            }
160                            ElementState::Released => {
161                                self.keys_just_released.insert(key_code);
162                                self.keys_pressed.remove(&key_code);
163                                self.update_modifiers(key_code, false);
164                            }
165                        }
166                    }
167
168                    // Collect text input
169                    if key_event.state == ElementState::Pressed
170                        && let Some(ref text) = key_event.text
171                    {
172                        self.text_input.push_str(text);
173                    }
174
175                    HandleStatus::handled()
176                }
177                Event::MouseButtonDown(button) => {
178                    let button = MouseButton::from(*button);
179                    self.mouse_buttons_just_pressed.insert(button);
180                    self.mouse_buttons_pressed.insert(button);
181                    HandleStatus::handled()
182                }
183                Event::MouseButtonUp(button) => {
184                    let button = MouseButton::from(*button);
185                    self.mouse_buttons_just_released.insert(button);
186                    self.mouse_buttons_pressed.remove(&button);
187                    HandleStatus::handled()
188                }
189                Event::MouseMoved(pos) => {
190                    self.mouse_position = Vec2::new(pos.x as f32, pos.y as f32);
191                    HandleStatus::handled()
192                }
193                Event::MouseScrolled(delta) => {
194                    let (dx, dy) = match delta {
195                        MouseScrollDelta::LineDelta(x, y) => (*x, *y),
196                        MouseScrollDelta::PixelDelta(pos) => (pos.x as f32, pos.y as f32),
197                    };
198                    self.scroll_delta = Vec2::new(dx, dy);
199                    HandleStatus::handled()
200                }
201                Event::MouseEntered => {
202                    self.mouse_in_window = true;
203                    HandleStatus::handled()
204                }
205                Event::MouseLeft => {
206                    self.mouse_in_window = false;
207                    HandleStatus::handled()
208                }
209                _ => HandleStatus::ignored(),
210            }
211        });
212
213        // Calculate mouse delta
214        self.mouse_delta = self.mouse_position - self.mouse_position_prev;
215    }
216
217    /// Clear per-frame state. Call this at the end of each frame.
218    pub fn end_frame(&mut self) {
219        profile_function!();
220        self.keys_just_pressed.clear();
221        self.keys_just_released.clear();
222        self.mouse_buttons_just_pressed.clear();
223        self.mouse_buttons_just_released.clear();
224        self.mouse_position_prev = self.mouse_position;
225        self.mouse_delta = Vec2::ZERO;
226        self.scroll_delta = Vec2::ZERO;
227        self.text_input.clear();
228    }
229
230    // ==================== Keyboard Queries ====================
231
232    /// Check if a key is currently pressed (held down).
233    pub fn is_key_pressed(&self, key: KeyCode) -> bool {
234        self.keys_pressed.contains(&key)
235    }
236
237    /// Check if a key was just pressed this frame.
238    pub fn is_key_just_pressed(&self, key: KeyCode) -> bool {
239        self.keys_just_pressed.contains(&key)
240    }
241
242    /// Check if a key was just released this frame.
243    pub fn is_key_just_released(&self, key: KeyCode) -> bool {
244        self.keys_just_released.contains(&key)
245    }
246
247    /// Check if any of the given keys are pressed.
248    pub fn is_any_key_pressed(&self, keys: &[KeyCode]) -> bool {
249        keys.iter().any(|k| self.is_key_pressed(*k))
250    }
251
252    /// Check if all of the given keys are pressed.
253    pub fn are_all_keys_pressed(&self, keys: &[KeyCode]) -> bool {
254        keys.iter().all(|k| self.is_key_pressed(*k))
255    }
256
257    /// Get the current modifier key state.
258    pub fn modifiers(&self) -> Modifiers {
259        self.modifiers
260    }
261
262    /// Check if Shift is held.
263    pub fn is_shift_pressed(&self) -> bool {
264        self.modifiers.shift
265    }
266
267    /// Check if Ctrl (or Cmd on macOS) is held.
268    pub fn is_ctrl_pressed(&self) -> bool {
269        self.modifiers.ctrl
270    }
271
272    /// Check if Alt (or Option on macOS) is held.
273    pub fn is_alt_pressed(&self) -> bool {
274        self.modifiers.alt
275    }
276
277    /// Check if Meta (Windows key or Cmd on macOS) is held.
278    pub fn is_meta_pressed(&self) -> bool {
279        self.modifiers.meta
280    }
281
282    /// Get text input received this frame.
283    pub fn text_input(&self) -> &str {
284        &self.text_input
285    }
286
287    /// Get all keys currently pressed.
288    pub fn pressed_keys(&self) -> impl Iterator<Item = &KeyCode> {
289        self.keys_pressed.iter()
290    }
291
292    // ==================== Mouse Queries ====================
293
294    /// Check if a mouse button is currently pressed.
295    pub fn is_mouse_button_pressed(&self, button: MouseButton) -> bool {
296        self.mouse_buttons_pressed.contains(&button)
297    }
298
299    /// Check if a mouse button was just pressed this frame.
300    pub fn is_mouse_button_just_pressed(&self, button: MouseButton) -> bool {
301        self.mouse_buttons_just_pressed.contains(&button)
302    }
303
304    /// Check if a mouse button was just released this frame.
305    pub fn is_mouse_button_just_released(&self, button: MouseButton) -> bool {
306        self.mouse_buttons_just_released.contains(&button)
307    }
308
309    /// Check if left mouse button is pressed.
310    pub fn is_left_mouse_pressed(&self) -> bool {
311        self.is_mouse_button_pressed(MouseButton::Left)
312    }
313
314    /// Check if left mouse button was just pressed.
315    pub fn is_left_mouse_just_pressed(&self) -> bool {
316        self.is_mouse_button_just_pressed(MouseButton::Left)
317    }
318
319    /// Check if right mouse button is pressed.
320    pub fn is_right_mouse_pressed(&self) -> bool {
321        self.is_mouse_button_pressed(MouseButton::Right)
322    }
323
324    /// Check if right mouse button was just pressed.
325    pub fn is_right_mouse_just_pressed(&self) -> bool {
326        self.is_mouse_button_just_pressed(MouseButton::Right)
327    }
328
329    /// Check if middle mouse button is pressed.
330    pub fn is_middle_mouse_pressed(&self) -> bool {
331        self.is_mouse_button_pressed(MouseButton::Middle)
332    }
333
334    /// Get the current mouse position in window coordinates.
335    pub fn mouse_position(&self) -> Vec2 {
336        self.mouse_position
337    }
338
339    /// Get the mouse movement delta since last frame.
340    pub fn mouse_delta(&self) -> Vec2 {
341        self.mouse_delta
342    }
343
344    /// Get the scroll wheel delta since last frame.
345    ///
346    /// Positive Y = scroll up, Negative Y = scroll down.
347    pub fn scroll_delta(&self) -> Vec2 {
348        self.scroll_delta
349    }
350
351    /// Check if the mouse cursor is inside the window.
352    pub fn is_mouse_in_window(&self) -> bool {
353        self.mouse_in_window
354    }
355
356    // ==================== Helper Methods ====================
357
358    /// Get horizontal input axis (-1, 0, or 1) from arrow keys or WASD.
359    pub fn horizontal_axis(&self) -> f32 {
360        let mut axis = 0.0;
361        if self.is_key_pressed(KeyCode::ArrowLeft) || self.is_key_pressed(KeyCode::KeyA) {
362            axis -= 1.0;
363        }
364        if self.is_key_pressed(KeyCode::ArrowRight) || self.is_key_pressed(KeyCode::KeyD) {
365            axis += 1.0;
366        }
367        axis
368    }
369
370    /// Get vertical input axis (-1, 0, or 1) from arrow keys or WASD.
371    pub fn vertical_axis(&self) -> f32 {
372        let mut axis = 0.0;
373        if self.is_key_pressed(KeyCode::ArrowUp) || self.is_key_pressed(KeyCode::KeyW) {
374            axis -= 1.0;
375        }
376        if self.is_key_pressed(KeyCode::ArrowDown) || self.is_key_pressed(KeyCode::KeyS) {
377            axis += 1.0;
378        }
379        axis
380    }
381
382    /// Get movement direction as a normalized vector.
383    pub fn movement_direction(&self) -> Vec2 {
384        let dir = Vec2::new(self.horizontal_axis(), self.vertical_axis());
385        if dir.length_squared() > 0.0 {
386            dir.normalize()
387        } else {
388            dir
389        }
390    }
391
392    /// Reset all input state.
393    pub fn reset(&mut self) {
394        self.keys_pressed.clear();
395        self.keys_just_pressed.clear();
396        self.keys_just_released.clear();
397        self.modifiers = Modifiers::new();
398        self.mouse_buttons_pressed.clear();
399        self.mouse_buttons_just_pressed.clear();
400        self.mouse_buttons_just_released.clear();
401        self.mouse_delta = Vec2::ZERO;
402        self.scroll_delta = Vec2::ZERO;
403        self.text_input.clear();
404    }
405
406    // ==================== Internal Methods ====================
407
408    fn update_modifiers(&mut self, key: KeyCode, pressed: bool) {
409        match key {
410            KeyCode::ShiftLeft | KeyCode::ShiftRight => self.modifiers.shift = pressed,
411            KeyCode::ControlLeft | KeyCode::ControlRight => self.modifiers.ctrl = pressed,
412            KeyCode::AltLeft | KeyCode::AltRight => self.modifiers.alt = pressed,
413            KeyCode::SuperLeft | KeyCode::SuperRight | KeyCode::Meta => {
414                self.modifiers.meta = pressed
415            }
416            _ => {}
417        }
418    }
419}
420
421impl Default for InputState {
422    fn default() -> Self {
423        Self::new()
424    }
425}
426
427/// An input system that wraps InputState and provides additional functionality.
428pub struct InputSystem {
429    state: InputState,
430}
431
432impl InputSystem {
433    /// Create a new input system.
434    pub fn new() -> Self {
435        Self {
436            state: InputState::new(),
437        }
438    }
439
440    /// Get the input state.
441    pub fn state(&self) -> &InputState {
442        &self.state
443    }
444
445    /// Get mutable access to the input state.
446    pub fn state_mut(&mut self) -> &mut InputState {
447        &mut self.state
448    }
449
450    /// Process events from the event batch.
451    pub fn handle_events(&mut self, events: &mut EventBatch) {
452        profile_function!();
453        self.state.handle_events(events);
454    }
455
456    /// Clear per-frame state.
457    pub fn end_frame(&mut self) {
458        self.state.end_frame();
459    }
460}
461
462impl Default for InputSystem {
463    fn default() -> Self {
464        Self::new()
465    }
466}
467
468impl std::ops::Deref for InputSystem {
469    type Target = InputState;
470
471    fn deref(&self) -> &Self::Target {
472        &self.state
473    }
474}
475
476impl std::ops::DerefMut for InputSystem {
477    fn deref_mut(&mut self) -> &mut Self::Target {
478        &mut self.state
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built `Modifiers` has every flag cleared.
    #[test]
    fn test_modifiers_default() {
        let modifiers = Modifiers::new();
        assert!(modifiers.none());
        assert!(!modifiers.any());
    }

    /// Setting a single flag flips both `any` and `none`.
    #[test]
    fn test_modifiers_any() {
        let modifiers = Modifiers {
            shift: true,
            ..Modifiers::new()
        };
        assert!(!modifiers.none());
        assert!(modifiers.any());
    }

    /// A new tracker reports nothing held and the cursor at the origin.
    #[test]
    fn test_input_state_new() {
        let fresh = InputState::new();
        assert_eq!(fresh.mouse_position(), Vec2::ZERO);
        assert!(!fresh.is_left_mouse_pressed());
        assert!(!fresh.is_key_pressed(KeyCode::Space));
    }

    /// Diagonal movement (W + D) is scaled back to unit length.
    #[test]
    fn test_movement_direction_normalized() {
        let mut input = InputState::new();
        for key in [KeyCode::KeyW, KeyCode::KeyD] {
            input.keys_pressed.insert(key);
        }

        let length = input.movement_direction().length();
        assert!((length - 1.0).abs() < 0.001, "Direction should be normalized");
    }
}