#![warn(missing_docs)]
#![doc = include_str!("../readme.md")]
#[cfg(feature = "bevy_assets")]
use bevy::asset::VisitAssetDependencies;
use bevy::prelude::*;
use rand::{Rng, seq::SliceRandom as _};
use std::{
fmt::Debug,
hash::{Hash, Hasher},
};
/// A collection of items that are handed out in shuffled "drafts".
///
/// Each draft is a shuffled list of indices into [`ShuffleBag::full_collection`].
/// [`ShuffleBag::pick`] consumes the draft from its end and starts a new draft
/// as soon as the current one is used up, so every item is picked once per draft.
///
/// The type derives `Component` and `Resource`, so it can be used as either.
/// With the `serialize` feature it additionally derives `serde::Serialize` and
/// `serde::Deserialize`.
#[derive(Component, Resource, Debug, Reflect)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[reflect(Component, Resource)]
#[non_exhaustive]
pub struct ShuffleBag<T> {
/// Every item the bag can hand out. The constructors reject an empty collection.
pub full_collection: Vec<T>,
/// Indices into `full_collection` that are still to be picked in the current
/// draft. The next pick is the last element.
pub current_draft: Vec<usize>,
/// Index into `full_collection` of the most recently picked item, or `None`
/// if nothing has been picked since construction or the last reset.
pub last_pick: Option<usize>,
}
impl<T> ShuffleBag<T> {
    /// Creates a bag from any iterator of items.
    ///
    /// The items are collected into a `Vec` and handed to [`ShuffleBag::try_new`].
    /// Returns `None` if the iterator yields no items.
    pub fn try_from_iter(iter: impl IntoIterator<Item = T>, rng: &mut impl Rng) -> Option<Self> {
        let full_collection: Vec<_> = iter.into_iter().collect();
        Self::try_new(full_collection, rng)
    }

    /// Creates a bag from `full_collection` and shuffles its first draft with `rng`.
    ///
    /// Returns `None` if the collection is empty, because an empty bag has
    /// nothing to pick.
    pub fn try_new(full_collection: impl Into<Vec<T>>, rng: &mut impl Rng) -> Option<Self> {
        let full_collection = full_collection.into();
        if full_collection.is_empty() {
            return None;
        }
        let mut bag = Self {
            full_collection,
            current_draft: Vec::new(),
            last_pick: None,
        };
        bag.shuffle_new_draft(rng);
        Some(bag)
    }

    /// Discards the current draft and replaces it with a freshly shuffled one
    /// that contains every item index exactly once.
    ///
    /// If the bag holds more than one item and the new draft would start with
    /// the item recorded in `last_pick`, that entry is swapped with a randomly
    /// chosen other entry so the same item is not handed out twice in a row.
    pub fn shuffle_new_draft(&mut self, rng: &mut impl Rng) {
        self.current_draft = (0..self.full_collection.len()).collect();
        self.current_draft.shuffle(rng);

        let len = self.current_draft.len();
        if len <= 1 {
            // Nothing to swap with; a repeat is unavoidable (or there is no item).
            return;
        }
        // The draft is consumed from the end, so the last entry is the next pick.
        // With `len > 1` the left side is always `Some`, hence a `None`
        // `last_pick` also takes this early return.
        if self.current_draft.last().copied() != self.last_pick {
            return;
        }
        // Pick any position except the last one and exchange the two entries.
        let max_index = len - 2;
        let index = rng.random_range(0..=max_index);
        self.current_draft.swap(index, len - 1);
    }

    /// Forgets the last pick and starts over with a freshly shuffled draft.
    pub fn reset(&mut self, rng: &mut impl Rng) {
        self.last_pick = None;
        self.shuffle_new_draft(rng);
    }

    /// Removes the next item index from the current draft and returns the
    /// corresponding item.
    ///
    /// When the pick exhausts the draft, a new draft is shuffled immediately so
    /// that [`ShuffleBag::peek`] keeps working. If the draft is already empty on
    /// entry (for example after the public field was cleared or the bag was
    /// deserialized in that state), it is refilled before picking.
    ///
    /// # Panics
    ///
    /// Panics if `full_collection` is empty or if the draft contains an index
    /// that is out of bounds for `full_collection`.
    pub fn pick(&mut self, rng: &mut impl Rng) -> &T {
        if self.current_draft.is_empty() {
            self.shuffle_new_draft(rng);
        }
        let pick = self
            .current_draft
            .pop()
            .expect("a `ShuffleBag` must contain at least one item");
        self.last_pick = Some(pick);
        if self.current_draft.is_empty() {
            self.shuffle_new_draft(rng);
        }
        &self.full_collection[pick]
    }

    /// Returns the item that the next call to [`ShuffleBag::pick`] will return,
    /// without consuming it.
    ///
    /// # Panics
    ///
    /// Panics if the current draft is empty or contains an index that is out of
    /// bounds for `full_collection`.
    pub fn peek(&self) -> &T {
        let next = *self
            .current_draft
            .last()
            .expect("the current draft of a `ShuffleBag` must not be empty");
        &self.full_collection[next]
    }
}
/// Cloning copies the items together with the draft state, so the clone
/// continues from the same position in the same draft as the original.
impl<T: Clone> Clone for ShuffleBag<T> {
    fn clone(&self) -> Self {
        let Self {
            full_collection,
            current_draft,
            last_pick,
        } = self;
        Self {
            full_collection: full_collection.clone(),
            current_draft: current_draft.clone(),
            last_pick: *last_pick,
        }
    }
}
/// Two bags are equal when they hold the same items, have the same remaining
/// draft, and recorded the same last pick.
impl<T: PartialEq> PartialEq for ShuffleBag<T> {
    fn eq(&self, other: &Self) -> bool {
        let Self {
            full_collection,
            current_draft,
            last_pick,
        } = self;
        *full_collection == other.full_collection
            && *current_draft == other.current_draft
            && *last_pick == other.last_pick
    }
}
/// Bag equality is a full equivalence relation whenever item equality is.
impl<T> Eq for ShuffleBag<T> where T: Eq {}
/// Hashes the same three fields that the `PartialEq` impl compares, in the
/// same order, so equal bags produce equal hashes.
///
/// Only `T: Hash` is required; hashing never clones an item.
impl<T: Hash> Hash for ShuffleBag<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.full_collection.hash(state);
        self.current_draft.hash(state);
        self.last_pick.hash(state);
    }
}
/// Reports the asset dependencies of every item in the bag, so a bag can be
/// used as a `#[dependency]` field of an asset.
///
/// Gated on the `bevy_assets` feature because the `VisitAssetDependencies`
/// import at the top of the file is only present with that feature.
#[cfg(feature = "bevy_assets")]
impl<T: VisitAssetDependencies> VisitAssetDependencies for ShuffleBag<T> {
    fn visit_dependencies(&self, visit: &mut impl FnMut(bevy::asset::UntypedAssetId)) {
        // Visit all items, not just the ones left in the current draft.
        for item in &self.full_collection {
            item.visit_dependencies(visit);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bevy::reflect::Typed;
    use paste::paste;
    #[cfg(feature = "serialize")]
    use serde::{Deserialize, Serialize};

    /// An empty collection cannot be turned into a bag.
    #[test]
    fn fails_to_create_empty_shuffle_bag() {
        let mut rng = rand::rng();
        let bag = ShuffleBag::<usize>::try_new(vec![], &mut rng);
        assert!(bag.is_none());
    }

    /// A bag with a single item keeps returning that item.
    #[test]
    fn picks_same_item_from_singular_bag() {
        let mut rng = rand::rng();
        let mut bag = ShuffleBag::<usize>::try_new(vec![1], &mut rng).unwrap();
        for _ in 0..100 {
            assert_eq!(*bag.pick(&mut rng), 1);
        }
    }

    /// 99 picks from three items are 33 complete drafts, so each item must
    /// show up exactly 33 times.
    #[test]
    fn picks_all_items_from_bag() {
        let mut rng = rand::rng();
        let mut bag = ShuffleBag::<usize>::try_new(vec![1, 2, 3], &mut rng).unwrap();
        let mut picked = Vec::new();
        for _ in 0..99 {
            let item = bag.pick(&mut rng);
            picked.push(*item);
        }
        assert_eq!(picked.len(), 99, "expected 99 items, got {}", picked.len());
        let ones = picked.iter().filter(|&&item| item == 1).count();
        let twos = picked.iter().filter(|&&item| item == 2).count();
        let threes = picked.iter().filter(|&&item| item == 3).count();
        assert_eq!(ones, 33, "ones: {} (expected 33)", ones);
        assert_eq!(twos, 33, "twos: {} (expected 33)", twos);
        assert_eq!(threes, 33, "threes: {} (expected 33)", threes);
    }

    /// Consecutive picks never return the same item, even across drafts.
    #[test]
    fn never_picks_the_same_item_twice() {
        let mut rng = rand::rng();
        let mut bag = ShuffleBag::<usize>::try_new(vec![1, 2, 3], &mut rng).unwrap();
        let mut last_pick = None;
        for _ in 0..1000 {
            let pick = *bag.pick(&mut rng);
            assert_ne!(Some(pick), last_pick);
            last_pick = Some(pick);
        }
    }

    // One generated `is_<Trait>` test per trait; each fails to compile if the
    // bag stops implementing that trait.
    assert_implements_type!(
        Eq,
        Hash,
        Clone,
        Debug,
        PartialEq,
        PartialReflect,
        Reflect,
        Struct,
        TypePath,
        Typed,
        Component,
        Resource
    );

    #[cfg(feature = "serialize")]
    assert_implements_type!(Serialize);

    /// `Deserialize` has a lifetime parameter, so it is checked by hand
    /// instead of through `assert_implements_type!`.
    #[cfg(feature = "serialize")]
    #[test]
    fn is_deserialize() {
        fn accept_type<T: for<'a> Deserialize<'a>>(_: T) {}
        let mut rng = rand::rng();
        let bag = ShuffleBag::try_new(vec![1, 2, 3], &mut rng).unwrap();
        accept_type(bag);
    }

    // Compile-time check that a bag can be a `#[dependency]` field of an asset.
    // Only meaningful when the asset integration is enabled.
    #[cfg(feature = "bevy_assets")]
    #[derive(Asset, TypePath)]
    struct _TestAsset {
        #[dependency]
        shuffle_bag: ShuffleBag<Handle<()>>,
    }

    // Generates a test named `is_<Trait>` for every listed trait. Each test
    // passes a freshly built bag to a function bounded on that trait.
    macro_rules! assert_implements_type {
        ($($name:ty),*) => {
            $(
                paste! {
                    #[test]
                    #[allow(non_snake_case)] // the test name embeds the trait name verbatim
                    fn [<is_ $name>]() {
                        fn accept_type<T: $name>(_: T) {}
                        let mut rng = rand::rng();
                        let bag = ShuffleBag::try_new(vec![1, 2, 3], &mut rng).unwrap();
                        accept_type(bag);
                    }
                }
            )*
        };
    }
    // Brings the macro into path-based scope, which is what allows the
    // invocations above its textual definition to resolve.
    use assert_implements_type;
}