use std::collections::HashSet;
use std::hash::Hash;
/// Frame-oriented state tracker for a set of buttons identified by values of type `T`.
///
/// Three sets are maintained:
///
/// * the buttons currently held down (queried with `pressed()`),
/// * the buttons that went from up to down since the transient state was last
///   cleared (queried with `just_pressed()`),
/// * the buttons that went from down to up since the transient state was last
///   cleared (queried with `just_released()`).
///
/// Typical usage is to feed input events through `press()` / `release()`, let
/// consumers query the state, and then call `clear_just()` before the next frame so
/// the `just_*` queries only describe the most recent batch of events.
///
/// A button that is pressed and released within the same batch is reported by
/// both `just_pressed()` and `just_released()`, while `pressed()` reports it as up.
#[derive(Debug, Clone)]
pub struct ButtonState<T> {
    /// Buttons currently held down.
    pressed: HashSet<T>,
    /// Buttons that transitioned up -> down since the last `clear_just` / `reset`.
    just_pressed: HashSet<T>,
    /// Buttons that transitioned down -> up since the last `clear_just` / `reset`.
    just_released: HashSet<T>,
}

impl<T: Eq + Hash + Copy> ButtonState<T> {
    /// Creates a tracker with no buttons held and no recorded transitions.
    pub fn new() -> Self {
        Self {
            pressed: HashSet::new(),
            just_pressed: HashSet::new(),
            just_released: HashSet::new(),
        }
    }

    /// Records that `button` went down.
    ///
    /// Only an up-to-down transition marks the button as just pressed. Pressing a
    /// button that is already held (for example a repeated press event) leaves the
    /// state unchanged.
    pub fn press(&mut self, button: T) {
        // `insert` returns `false` when the button was already held.
        if self.pressed.insert(button) {
            self.just_pressed.insert(button);
        }
    }

    /// Records that `button` went up.
    ///
    /// Only a down-to-up transition marks the button as just released. Releasing a
    /// button that is not currently held is a no-op.
    pub fn release(&mut self, button: T) {
        // `remove` returns `false` when the button was not held.
        if self.pressed.remove(&button) {
            self.just_released.insert(button);
        }
    }

    /// Returns `true` while `button` is held down.
    pub fn pressed(&self, button: T) -> bool {
        self.pressed.contains(&button)
    }

    /// Returns `true` if `button` went down since the last `clear_just` / `reset`.
    pub fn just_pressed(&self, button: T) -> bool {
        self.just_pressed.contains(&button)
    }

    /// Returns `true` if `button` went up since the last `clear_just` / `reset`.
    pub fn just_released(&self, button: T) -> bool {
        self.just_released.contains(&button)
    }

    /// Forgets all recorded transitions while keeping the set of held buttons.
    ///
    /// Call this once per frame, after every consumer has read the `just_*` state.
    pub fn clear_just(&mut self) {
        self.just_pressed.clear();
        self.just_released.clear();
    }

    /// Returns the tracker to its initial state: nothing held, no recorded transitions.
    pub fn reset(&mut self) {
        self.pressed.clear();
        self.just_pressed.clear();
        self.just_released.clear();
    }

    /// Iterates over the buttons currently held down, in unspecified order.
    pub fn get_pressed(&self) -> impl Iterator<Item = &T> {
        self.pressed.iter()
    }

    /// Iterates over the buttons that went down since the last `clear_just` / `reset`,
    /// in unspecified order.
    pub fn get_just_pressed(&self) -> impl Iterator<Item = &T> {
        self.just_pressed.iter()
    }

    /// Iterates over the buttons that went up since the last `clear_just` / `reset`,
    /// in unspecified order.
    pub fn get_just_released(&self) -> impl Iterator<Item = &T> {
        self.just_released.iter()
    }

    /// Returns the number of buttons currently held down.
    pub fn pressed_count(&self) -> usize {
        self.pressed.len()
    }

    /// Returns `true` if at least one button is held down.
    pub fn any_pressed(&self) -> bool {
        !self.pressed.is_empty()
    }

    /// Returns `true` if at least one button went down since the last `clear_just` / `reset`.
    pub fn any_just_pressed(&self) -> bool {
        !self.just_pressed.is_empty()
    }

    /// Returns `true` if at least one button went up since the last `clear_just` / `reset`.
    pub fn any_just_released(&self) -> bool {
        !self.just_released.is_empty()
    }
}

impl<T: Eq + Hash + Copy> Default for ButtonState<T> {
    /// Equivalent to [`ButtonState::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal button identifier used to exercise `ButtonState`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum TestButton {
        A,
        B,
    }

    #[test]
    fn default_is_empty() {
        let state: ButtonState<TestButton> = ButtonState::default();
        assert!(!state.any_pressed());
        assert!(!state.any_just_pressed());
        assert_eq!(state.pressed_count(), 0);
        assert_eq!(state.get_just_released().count(), 0);
    }

    #[test]
    fn press_and_query() {
        let mut state = ButtonState::new();
        state.press(TestButton::A);
        assert!(state.pressed(TestButton::A));
        assert!(state.just_pressed(TestButton::A));
        assert!(!state.just_released(TestButton::A));
        assert!(!state.pressed(TestButton::B));
    }

    #[test]
    fn release_and_query() {
        let mut state = ButtonState::new();
        state.press(TestButton::A);
        state.clear_just();
        state.release(TestButton::A);
        assert!(!state.pressed(TestButton::A));
        assert!(state.just_released(TestButton::A));
        assert!(!state.just_pressed(TestButton::A));
    }

    #[test]
    fn clear_just_resets_transient() {
        let mut state = ButtonState::new();
        state.press(TestButton::A);
        state.press(TestButton::B);
        state.release(TestButton::B);
        assert!(state.just_pressed(TestButton::A));
        assert!(state.just_released(TestButton::B));
        state.clear_just();
        // The held state survives; only the per-frame transitions are dropped.
        assert!(state.pressed(TestButton::A));
        assert!(!state.just_pressed(TestButton::A));
        assert!(!state.just_released(TestButton::B));
    }

    #[test]
    fn double_press_no_duplicate_just() {
        let mut state = ButtonState::new();
        state.press(TestButton::A);
        state.press(TestButton::A);
        assert!(state.pressed(TestButton::A));
        assert_eq!(state.pressed_count(), 1);
        assert_eq!(state.get_just_pressed().count(), 1);
        // Pressing a button that is still held must not re-trigger
        // `just_pressed` in a later frame.
        state.clear_just();
        state.press(TestButton::A);
        assert!(state.pressed(TestButton::A));
        assert!(!state.just_pressed(TestButton::A));
        assert!(!state.any_just_pressed());
    }

    #[test]
    fn double_release_no_duplicate_just() {
        let mut state = ButtonState::new();
        state.press(TestButton::A);
        state.release(TestButton::A);
        state.clear_just();
        // Releasing a button that is already up must not re-trigger
        // `just_released`.
        state.release(TestButton::A);
        assert!(!state.just_released(TestButton::A));
        assert!(!state.pressed(TestButton::A));
    }

    #[test]
    fn release_without_press_noop() {
        let mut state = ButtonState::new();
        state.release(TestButton::A);
        assert!(!state.just_released(TestButton::A));
        assert!(!state.pressed(TestButton::A));
        assert_eq!(state.pressed_count(), 0);
        assert_eq!(state.get_just_released().count(), 0);
    }

    #[test]
    fn press_and_release_same_frame() {
        let mut state = ButtonState::new();
        state.press(TestButton::A);
        state.release(TestButton::A);
        // Both transitions are visible until the frame is cleared.
        assert!(state.just_pressed(TestButton::A));
        assert!(state.just_released(TestButton::A));
        assert!(!state.pressed(TestButton::A));
    }

    #[test]
    fn reset_clears_everything() {
        let mut state = ButtonState::new();
        state.press(TestButton::A);
        state.press(TestButton::B);
        state.release(TestButton::B);
        state.reset();
        assert!(!state.pressed(TestButton::A));
        assert!(!state.pressed(TestButton::B));
        assert!(!state.just_released(TestButton::B));
        assert!(!state.any_pressed());
        assert!(!state.any_just_pressed());
        assert_eq!(state.get_just_released().count(), 0);
    }

    #[test]
    fn multiple_buttons() {
        let mut state = ButtonState::new();
        state.press(TestButton::A);
        state.press(TestButton::B);
        assert_eq!(state.pressed_count(), 2);
        assert!(state.any_pressed());
        state.release(TestButton::A);
        assert_eq!(state.pressed_count(), 1);
        assert!(state.pressed(TestButton::B));
        assert!(state.just_released(TestButton::A));
        let held: Vec<TestButton> = state.get_pressed().copied().collect();
        assert_eq!(held, vec![TestButton::B]);
    }

    #[test]
    fn frame_lifecycle() {
        let mut state = ButtonState::new();
        state.press(TestButton::A);
        assert!(state.just_pressed(TestButton::A));
        state.clear_just();
        assert!(state.pressed(TestButton::A));
        assert!(!state.just_pressed(TestButton::A));
        state.press(TestButton::B);
        assert!(state.just_pressed(TestButton::B));
        state.clear_just();
        state.release(TestButton::A);
        assert!(state.just_released(TestButton::A));
        assert!(!state.pressed(TestButton::A));
        assert!(state.pressed(TestButton::B));
    }
}