use glam::Vec2;
use std::collections::HashSet;
use winit::event::{ElementState, MouseButton as WinitMouseButton, WindowEvent};
use winit::keyboard::{KeyCode as WinitKeyCode, PhysicalKey};
/// A mouse button, independent of the windowing backend.
///
/// Every winit button maps to a distinct variant. Side ("thumb") buttons are
/// `Back`/`Forward`, and any other button keeps its raw id in `Other`. They are
/// never folded into `Left`, which would make them register as left clicks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MouseButton {
    Left,
    Right,
    Middle,
    Back,
    Forward,
    /// Any additional button, identified by winit's raw button id.
    Other(u16),
}

impl From<WinitMouseButton> for MouseButton {
    /// Converts a winit mouse button one-to-one; no two winit buttons share a variant.
    fn from(btn: WinitMouseButton) -> Self {
        // Exhaustive on purpose: a new winit variant should fail to compile here
        // rather than be silently misreported as another button.
        match btn {
            WinitMouseButton::Left => MouseButton::Left,
            WinitMouseButton::Right => MouseButton::Right,
            WinitMouseButton::Middle => MouseButton::Middle,
            WinitMouseButton::Back => MouseButton::Back,
            WinitMouseButton::Forward => MouseButton::Forward,
            WinitMouseButton::Other(id) => MouseButton::Other(id),
        }
    }
}
/// A keyboard key, independent of the windowing backend.
///
/// Produced from winit physical key codes by the `From<WinitKeyCode>` impl.
/// Keys without a dedicated variant become [`KeyCode::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeyCode {
// Letter keys.
A, B, C, D, E, F, G, H, I, J, K, L, M,
N, O, P, Q, R, S, T, U, V, W, X, Y, Z,
// Top-row digit keys (converted from winit's `Digit0`..`Digit9`).
Key0, Key1, Key2, Key3, Key4, Key5, Key6, Key7, Key8, Key9,
// Function keys.
F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12,
// Arrow keys.
Up, Down, Left, Right,
// Editing and control keys.
Space, Enter, Escape, Tab, Backspace, Delete,
// Modifiers: the left and right physical keys both map to the same variant.
Shift, Control, Alt,
/// Any key without a dedicated variant. Carries the winit `KeyCode`'s enum
/// discriminant, which is implicit and therefore not guaranteed stable across
/// winit versions; do not persist it.
Other(u32),
}
impl From<WinitKeyCode> for KeyCode {
fn from(key: WinitKeyCode) -> Self {
match key {
WinitKeyCode::KeyA => KeyCode::A,
WinitKeyCode::KeyB => KeyCode::B,
WinitKeyCode::KeyC => KeyCode::C,
WinitKeyCode::KeyD => KeyCode::D,
WinitKeyCode::KeyE => KeyCode::E,
WinitKeyCode::KeyF => KeyCode::F,
WinitKeyCode::KeyG => KeyCode::G,
WinitKeyCode::KeyH => KeyCode::H,
WinitKeyCode::KeyI => KeyCode::I,
WinitKeyCode::KeyJ => KeyCode::J,
WinitKeyCode::KeyK => KeyCode::K,
WinitKeyCode::KeyL => KeyCode::L,
WinitKeyCode::KeyM => KeyCode::M,
WinitKeyCode::KeyN => KeyCode::N,
WinitKeyCode::KeyO => KeyCode::O,
WinitKeyCode::KeyP => KeyCode::P,
WinitKeyCode::KeyQ => KeyCode::Q,
WinitKeyCode::KeyR => KeyCode::R,
WinitKeyCode::KeyS => KeyCode::S,
WinitKeyCode::KeyT => KeyCode::T,
WinitKeyCode::KeyU => KeyCode::U,
WinitKeyCode::KeyV => KeyCode::V,
WinitKeyCode::KeyW => KeyCode::W,
WinitKeyCode::KeyX => KeyCode::X,
WinitKeyCode::KeyY => KeyCode::Y,
WinitKeyCode::KeyZ => KeyCode::Z,
WinitKeyCode::Digit0 => KeyCode::Key0,
WinitKeyCode::Digit1 => KeyCode::Key1,
WinitKeyCode::Digit2 => KeyCode::Key2,
WinitKeyCode::Digit3 => KeyCode::Key3,
WinitKeyCode::Digit4 => KeyCode::Key4,
WinitKeyCode::Digit5 => KeyCode::Key5,
WinitKeyCode::Digit6 => KeyCode::Key6,
WinitKeyCode::Digit7 => KeyCode::Key7,
WinitKeyCode::Digit8 => KeyCode::Key8,
WinitKeyCode::Digit9 => KeyCode::Key9,
WinitKeyCode::F1 => KeyCode::F1,
WinitKeyCode::F2 => KeyCode::F2,
WinitKeyCode::F3 => KeyCode::F3,
WinitKeyCode::F4 => KeyCode::F4,
WinitKeyCode::F5 => KeyCode::F5,
WinitKeyCode::F6 => KeyCode::F6,
WinitKeyCode::F7 => KeyCode::F7,
WinitKeyCode::F8 => KeyCode::F8,
WinitKeyCode::F9 => KeyCode::F9,
WinitKeyCode::F10 => KeyCode::F10,
WinitKeyCode::F11 => KeyCode::F11,
WinitKeyCode::F12 => KeyCode::F12,
WinitKeyCode::ArrowUp => KeyCode::Up,
WinitKeyCode::ArrowDown => KeyCode::Down,
WinitKeyCode::ArrowLeft => KeyCode::Left,
WinitKeyCode::ArrowRight => KeyCode::Right,
WinitKeyCode::Space => KeyCode::Space,
WinitKeyCode::Enter => KeyCode::Enter,
WinitKeyCode::Escape => KeyCode::Escape,
WinitKeyCode::Tab => KeyCode::Tab,
WinitKeyCode::Backspace => KeyCode::Backspace,
WinitKeyCode::Delete => KeyCode::Delete,
WinitKeyCode::ShiftLeft | WinitKeyCode::ShiftRight => KeyCode::Shift,
WinitKeyCode::ControlLeft | WinitKeyCode::ControlRight => KeyCode::Control,
WinitKeyCode::AltLeft | WinitKeyCode::AltRight => KeyCode::Alt,
_ => KeyCode::Other(key as u32),
}
}
}
#[derive(Debug, Default)]
pub struct Input {
keys_held: HashSet<KeyCode>,
keys_pressed: HashSet<KeyCode>,
keys_released: HashSet<KeyCode>,
mouse_held: HashSet<MouseButton>,
mouse_pressed: HashSet<MouseButton>,
mouse_released: HashSet<MouseButton>,
mouse_position: Vec2,
mouse_ndc: Vec2,
mouse_delta: Vec2,
last_mouse_position: Vec2,
scroll_delta: f32,
window_size: (u32, u32),
}
impl Input {
pub fn new() -> Self {
Self {
window_size: (800, 600),
..Default::default()
}
}
pub fn key_pressed(&self, key: KeyCode) -> bool {
self.keys_pressed.contains(&key)
}
pub fn key_held(&self, key: KeyCode) -> bool {
self.keys_held.contains(&key)
}
pub fn key_released(&self, key: KeyCode) -> bool {
self.keys_released.contains(&key)
}
pub fn mouse_pressed(&self, button: MouseButton) -> bool {
self.mouse_pressed.contains(&button)
}
pub fn mouse_held(&self, button: MouseButton) -> bool {
self.mouse_held.contains(&button)
}
pub fn mouse_released(&self, button: MouseButton) -> bool {
self.mouse_released.contains(&button)
}
pub fn mouse_position(&self) -> Vec2 {
self.mouse_position
}
pub fn mouse_ndc(&self) -> Vec2 {
self.mouse_ndc
}
pub fn mouse_delta(&self) -> Vec2 {
self.mouse_delta
}
pub fn scroll_delta(&self) -> f32 {
self.scroll_delta
}
pub fn window_size(&self) -> (u32, u32) {
self.window_size
}
pub fn aspect_ratio(&self) -> f32 {
let (w, h) = self.window_size;
if h > 0 {
w as f32 / h as f32
} else {
1.0
}
}
pub(crate) fn begin_frame(&mut self) {
self.keys_pressed.clear();
self.keys_released.clear();
self.mouse_pressed.clear();
self.mouse_released.clear();
self.mouse_delta = Vec2::ZERO;
self.scroll_delta = 0.0;
}
pub(crate) fn set_window_size(&mut self, width: u32, height: u32) {
self.window_size = (width, height);
}
pub(crate) fn handle_event(&mut self, event: &WindowEvent) {
match event {
WindowEvent::KeyboardInput { event, .. } => {
if let PhysicalKey::Code(keycode) = event.physical_key {
let key = KeyCode::from(keycode);
match event.state {
ElementState::Pressed => {
if !self.keys_held.contains(&key) {
self.keys_pressed.insert(key);
}
self.keys_held.insert(key);
}
ElementState::Released => {
self.keys_held.remove(&key);
self.keys_released.insert(key);
}
}
}
}
WindowEvent::MouseInput { state, button, .. } => {
let btn = MouseButton::from(*button);
match state {
ElementState::Pressed => {
self.mouse_pressed.insert(btn);
self.mouse_held.insert(btn);
}
ElementState::Released => {
self.mouse_held.remove(&btn);
self.mouse_released.insert(btn);
}
}
}
WindowEvent::CursorMoved { position, .. } => {
let new_pos = Vec2::new(position.x as f32, position.y as f32);
self.mouse_delta = new_pos - self.last_mouse_position;
self.last_mouse_position = self.mouse_position;
self.mouse_position = new_pos;
let (w, h) = self.window_size;
if w > 0 && h > 0 {
self.mouse_ndc = Vec2::new(
(position.x as f32 / w as f32) * 2.0 - 1.0,
1.0 - (position.y as f32 / h as f32) * 2.0, );
}
}
WindowEvent::MouseWheel { delta, .. } => {
self.scroll_delta = match delta {
winit::event::MouseScrollDelta::LineDelta(_, y) => *y,
winit::event::MouseScrollDelta::PixelDelta(pos) => pos.y as f32 / 100.0,
};
}
_ => {}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use winit::dpi::PhysicalPosition;
    use winit::event::{DeviceId, WindowEvent};

    /// Builds a `CursorMoved` event at the given position.
    fn cursor_moved(x: f64, y: f64) -> WindowEvent {
        WindowEvent::CursorMoved {
            // SAFETY: winit only forbids passing a dummy id back into winit
            // functions; here it is merely stored in an event that `Input` matches
            // with `..` and never forwards.
            device_id: unsafe { DeviceId::dummy() },
            position: PhysicalPosition::new(x, y),
        }
    }

    #[test]
    fn test_key_state() {
        let mut input = Input::new();
        assert!(!input.key_held(KeyCode::Space));
        assert!(!input.key_pressed(KeyCode::Space));

        input.keys_pressed.insert(KeyCode::Space);
        input.keys_held.insert(KeyCode::Space);
        assert!(input.key_held(KeyCode::Space));
        assert!(input.key_pressed(KeyCode::Space));

        // A new frame clears the edge-triggered "pressed" state but keeps "held".
        input.begin_frame();
        assert!(input.key_held(KeyCode::Space));
        assert!(!input.key_pressed(KeyCode::Space));
    }

    #[test]
    fn test_mouse_ndc() {
        let mut input = Input::new();
        input.set_window_size(800, 600);

        // Drive the real event path instead of writing `mouse_ndc` by hand.
        input.handle_event(&cursor_moved(400.0, 300.0));
        assert_eq!(input.mouse_position(), Vec2::new(400.0, 300.0));
        assert!(input.mouse_ndc().abs_diff_eq(Vec2::ZERO, 1e-6));

        // Top-left maps to (-1, 1): NDC y points up while window y points down.
        input.handle_event(&cursor_moved(0.0, 0.0));
        assert!(input.mouse_ndc().abs_diff_eq(Vec2::new(-1.0, 1.0), 1e-6));

        input.handle_event(&cursor_moved(800.0, 600.0));
        assert!(input.mouse_ndc().abs_diff_eq(Vec2::new(1.0, -1.0), 1e-6));
    }
}