1use std::{
2 marker::PhantomData,
3 ops::Add,
4};
5
6use anyhow::{
7 Error,
8 Result,
9};
10use serde::{
11 Deserialize,
12 Serialize,
13};
14use serde_string_enum::{
15 DeserializeLabeledStringEnum,
16 SerializeLabeledStringEnum,
17};
18
19use crate::Stat;
20
/// A stat that can carry a boost value (one field of a [`BoostTable`]).
///
/// Covers the five non-HP stats (see the `TryFrom<Stat>` impl) plus accuracy and
/// evasion, which have no [`Stat`] counterpart.
///
/// Serialization: each variant serializes to the string in its `#[string]`
/// attribute (e.g. `Boost::SpAtk` -> `"spa"`). Deserialization additionally
/// accepts the `#[alias]` spellings (e.g. `"Special Attack"`). The tests in this
/// file also deserialize capitalized forms such as `"SpAtk"` and `"Acc"`.
/// NOTE(review): that implies the derive matches case-insensitively; confirm
/// against `serde_string_enum`.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    SerializeLabeledStringEnum,
    DeserializeLabeledStringEnum,
)]
pub enum Boost {
    #[string = "atk"]
    #[alias = "Attack"]
    Atk,
    #[string = "def"]
    #[alias = "Defense"]
    Def,
    #[string = "spa"]
    #[alias = "spatk"]
    #[alias = "Sp.Atk"]
    #[alias = "Special Attack"]
    SpAtk,
    #[string = "spd"]
    #[alias = "spdef"]
    #[alias = "Sp.Def"]
    #[alias = "Special Defense"]
    SpDef,
    #[string = "spe"]
    #[alias = "Speed"]
    Spe,
    #[string = "acc"]
    #[alias = "Accuracy"]
    Accuracy,
    #[string = "eva"]
    #[alias = "Evasion"]
    Evasion,
}
59
60impl TryFrom<Stat> for Boost {
61 type Error = Error;
62 fn try_from(value: Stat) -> Result<Self, Self::Error> {
63 match value {
64 Stat::HP => Err(Error::msg("HP cannot be boosted")),
65 Stat::Atk => Ok(Self::Atk),
66 Stat::Def => Ok(Self::Def),
67 Stat::SpAtk => Ok(Self::SpAtk),
68 Stat::SpDef => Ok(Self::SpDef),
69 Stat::Spe => Ok(Self::Spe),
70 }
71 }
72}
73
/// A container that may or may not hold a value of type `T` for each [`Boost`].
///
/// [`BoostTableEntries`] walks boosts in [`BoostOrderIterator`] order, calls
/// [`get_boost`](Self::get_boost) for each one, and skips boosts for which it
/// returns `None`.
pub trait ContainsOptionalBoosts<T> {
    /// Returns the entry for `boost`, or `None` if the container has no value for it.
    ///
    /// The pair is presumably expected to carry `boost` itself as its first element,
    /// as [`BoostTable`]'s implementation does. NOTE(review): confirm for other implementors.
    fn get_boost(&self, boost: Boost) -> Option<(Boost, T)>;
}
78
79#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
81pub struct BoostTable {
82 #[serde(default)]
83 pub atk: i8,
84 #[serde(default)]
85 pub def: i8,
86 #[serde(default)]
87 pub spa: i8,
88 #[serde(default)]
89 pub spd: i8,
90 #[serde(default)]
91 pub spe: i8,
92 #[serde(default)]
93 pub acc: i8,
94 #[serde(default)]
95 pub eva: i8,
96}
97
98impl BoostTable {
99 pub fn new() -> Self {
101 Self::default()
102 }
103
104 pub fn get(&self, boost: Boost) -> i8 {
106 match boost {
107 Boost::Atk => self.atk,
108 Boost::Def => self.def,
109 Boost::SpAtk => self.spa,
110 Boost::SpDef => self.spd,
111 Boost::Spe => self.spe,
112 Boost::Accuracy => self.acc,
113 Boost::Evasion => self.eva,
114 }
115 }
116
117 fn get_mut(&mut self, boost: Boost) -> &mut i8 {
119 match boost {
120 Boost::Atk => &mut self.atk,
121 Boost::Def => &mut self.def,
122 Boost::SpAtk => &mut self.spa,
123 Boost::SpDef => &mut self.spd,
124 Boost::Spe => &mut self.spe,
125 Boost::Accuracy => &mut self.acc,
126 Boost::Evasion => &mut self.eva,
127 }
128 }
129
130 pub fn set(&mut self, boost: Boost, value: i8) {
132 *self.get_mut(boost) = value;
133 }
134
135 pub fn iter<'a>(&'a self) -> impl Iterator<Item = (Boost, i8)> + 'a {
137 BoostTableEntries::new(self)
138 }
139
140 pub fn non_zero_iter<'a>(&'a self) -> impl Iterator<Item = (Boost, i8)> + 'a {
142 self.iter().filter(|(_, val)| *val != 0)
143 }
144
145 pub fn values<'a>(&'a self) -> impl Iterator<Item = i8> + 'a {
147 self.iter().map(|(_, val)| val)
148 }
149}
150
151impl FromIterator<(Boost, i8)> for BoostTable {
152 fn from_iter<T: IntoIterator<Item = (Boost, i8)>>(iter: T) -> Self {
153 let mut table = Self::new();
154 for (boost, value) in iter {
155 *table.get_mut(boost) = value;
156 }
157 table
158 }
159}
160
161impl Add for &BoostTable {
162 type Output = BoostTable;
163 fn add(self, rhs: Self) -> Self::Output {
164 BoostTable {
165 atk: self.atk + rhs.atk,
166 def: self.def + rhs.def,
167 spa: self.spa + rhs.spa,
168 spd: self.spd + rhs.spd,
169 spe: self.spe + rhs.spe,
170 acc: self.acc + rhs.acc,
171 eva: self.eva + rhs.eva,
172 }
173 }
174}
175
176impl ContainsOptionalBoosts<i8> for BoostTable {
177 fn get_boost(&self, boost: Boost) -> Option<(Boost, i8)> {
178 Some((boost, self.get(boost)))
179 }
180}
181
182pub struct BoostOrderIterator {
184 next: Option<Boost>,
185}
186
187impl BoostOrderIterator {
188 pub fn new() -> Self {
190 Self {
191 next: Some(Boost::Atk),
192 }
193 }
194
195 fn next_internal(&mut self) -> Option<Boost> {
196 let out = self.next;
197 self.next = match self.next {
198 Some(Boost::Atk) => Some(Boost::Def),
199 Some(Boost::Def) => Some(Boost::SpAtk),
200 Some(Boost::SpAtk) => Some(Boost::SpDef),
201 Some(Boost::SpDef) => Some(Boost::Spe),
202 Some(Boost::Spe) => Some(Boost::Accuracy),
203 Some(Boost::Accuracy) => Some(Boost::Evasion),
204 None | Some(Boost::Evasion) => None,
205 };
206 out
207 }
208}
209
210impl Iterator for BoostOrderIterator {
211 type Item = Boost;
212 fn next(&mut self) -> Option<Self::Item> {
213 self.next_internal()
214 }
215}
216
217pub struct BoostTableEntries<'m, B, T>
220where
221 B: ContainsOptionalBoosts<T>,
222 T: Copy,
223{
224 table: &'m B,
225 boost_iter: BoostOrderIterator,
226 _phantom: PhantomData<T>,
227}
228
229impl<'m, B, T> BoostTableEntries<'m, B, T>
230where
231 B: ContainsOptionalBoosts<T>,
232 T: Copy,
233{
234 pub fn new(table: &'m B) -> Self {
236 Self {
237 table,
238 boost_iter: BoostOrderIterator::new(),
239 _phantom: PhantomData,
240 }
241 }
242
243 fn next_non_zero_entry(&mut self) -> Option<(Boost, T)> {
244 while let Some(boost) = self.boost_iter.next() {
245 let entry = self.table.get_boost(boost);
246 if entry.is_some() {
247 return entry;
248 }
249 }
250 None
251 }
252}
253
254impl<'m, B, T> Iterator for BoostTableEntries<'m, B, T>
255where
256 B: ContainsOptionalBoosts<T>,
257 T: Copy,
258{
259 type Item = (Boost, T);
260 fn next(&mut self) -> Option<Self::Item> {
261 self.next_non_zero_entry()
262 }
263}
264
#[cfg(test)]
mod boost_test {
    use crate::{
        Boost,
        test_util::{
            test_string_deserialization,
            test_string_serialization,
        },
    };

    /// Each variant serializes to its `#[string]` value.
    #[test]
    fn serializes_to_string() {
        test_string_serialization(Boost::Atk, "atk");
        test_string_serialization(Boost::Def, "def");
        test_string_serialization(Boost::SpAtk, "spa");
        test_string_serialization(Boost::SpDef, "spd");
        test_string_serialization(Boost::Spe, "spe");
        test_string_serialization(Boost::Accuracy, "acc");
        test_string_serialization(Boost::Evasion, "eva");
    }

    /// The serialized form round-trips, and the short `spatk`/`spdef` aliases are accepted.
    #[test]
    fn deserializes_primary_strings_and_short_aliases() {
        test_string_deserialization("atk", Boost::Atk);
        test_string_deserialization("def", Boost::Def);
        test_string_deserialization("spa", Boost::SpAtk);
        test_string_deserialization("spatk", Boost::SpAtk);
        test_string_deserialization("spd", Boost::SpDef);
        test_string_deserialization("spdef", Boost::SpDef);
        test_string_deserialization("spe", Boost::Spe);
        test_string_deserialization("acc", Boost::Accuracy);
        test_string_deserialization("eva", Boost::Evasion);
    }

    /// Capitalized spellings of the short names are accepted.
    #[test]
    fn deserializes_capitalized() {
        test_string_deserialization("Atk", Boost::Atk);
        test_string_deserialization("Def", Boost::Def);
        test_string_deserialization("SpAtk", Boost::SpAtk);
        test_string_deserialization("SpDef", Boost::SpDef);
        test_string_deserialization("Spe", Boost::Spe);
        test_string_deserialization("Acc", Boost::Accuracy);
        test_string_deserialization("Eva", Boost::Evasion);
    }

    /// Long-form aliases are accepted.
    #[test]
    fn deserializes_full_names() {
        test_string_deserialization("Attack", Boost::Atk);
        test_string_deserialization("Defense", Boost::Def);
        test_string_deserialization("Special Attack", Boost::SpAtk);
        test_string_deserialization("Sp.Atk", Boost::SpAtk);
        test_string_deserialization("Special Defense", Boost::SpDef);
        test_string_deserialization("Sp.Def", Boost::SpDef);
        test_string_deserialization("Speed", Boost::Spe);
        test_string_deserialization("Accuracy", Boost::Accuracy);
        test_string_deserialization("Evasion", Boost::Evasion);
    }
}
310
#[cfg(test)]
mod boost_table_test {

    use crate::{
        Boost,
        BoostTable,
    };

    /// `get` reads the field that corresponds to each boost.
    #[test]
    fn gets_associated_value() {
        let bt = BoostTable {
            atk: 1,
            def: 2,
            spa: 3,
            spd: 4,
            spe: 5,
            acc: 6,
            eva: 7,
        };
        assert_eq!(bt.get(Boost::Atk), 1);
        assert_eq!(bt.get(Boost::Def), 2);
        assert_eq!(bt.get(Boost::SpAtk), 3);
        assert_eq!(bt.get(Boost::SpDef), 4);
        assert_eq!(bt.get(Boost::Spe), 5);
        assert_eq!(bt.get(Boost::Accuracy), 6);
        assert_eq!(bt.get(Boost::Evasion), 7);
    }

    /// `set` overwrites exactly one field, and `values` reports fields in canonical order.
    #[test]
    fn sets_single_entry_and_lists_values_in_order() {
        let mut table = BoostTable::new();
        table.set(Boost::Spe, -2);
        table.set(Boost::Spe, 3);
        assert_eq!(table.get(Boost::Spe), 3);
        assert_eq!(
            table.values().collect::<Vec<i8>>(),
            vec![0, 0, 0, 0, 3, 0, 0],
        );
    }

    /// Adding two tables sums each field independently.
    #[test]
    fn adds_tables_field_by_field() {
        let lhs = BoostTable::from_iter([(Boost::Atk, 1), (Boost::Evasion, -1)]);
        let rhs = BoostTable::from_iter([(Boost::Atk, 2), (Boost::Def, -3)]);
        assert_eq!(
            &lhs + &rhs,
            BoostTable {
                atk: 3,
                def: -3,
                eva: -1,
                ..BoostTable::new()
            },
        );
    }

    /// Iteration follows canonical boost order; `non_zero_iter` drops zero entries.
    #[test]
    fn iterates_entries_in_order() {
        let mut table = BoostTable::new();
        assert_eq!(
            table.non_zero_iter().collect::<Vec<(Boost, i8)>>(),
            Vec::<(Boost, i8)>::new(),
        );

        *table.get_mut(Boost::SpAtk) = 1;
        assert_eq!(
            table.non_zero_iter().collect::<Vec<(Boost, i8)>>(),
            vec![(Boost::SpAtk, 1)],
        );

        *table.get_mut(Boost::Atk) = 2;
        assert_eq!(
            table.non_zero_iter().collect::<Vec<(Boost, i8)>>(),
            vec![(Boost::Atk, 2), (Boost::SpAtk, 1)],
        );

        *table.get_mut(Boost::Accuracy) = -1;
        assert_eq!(
            table.non_zero_iter().collect::<Vec<(Boost, i8)>>(),
            vec![(Boost::Atk, 2), (Boost::SpAtk, 1), (Boost::Accuracy, -1)],
        );

        let table = BoostTable::from_iter([
            (Boost::Atk, 1),
            (Boost::Def, 1),
            (Boost::SpAtk, 1),
            (Boost::SpDef, 1),
            (Boost::Spe, 1),
            (Boost::Accuracy, 1),
            (Boost::Evasion, 1),
        ]);
        assert_eq!(
            table.iter().collect::<Vec<(Boost, i8)>>(),
            vec![
                (Boost::Atk, 1),
                (Boost::Def, 1),
                (Boost::SpAtk, 1),
                (Boost::SpDef, 1),
                (Boost::Spe, 1),
                (Boost::Accuracy, 1),
                (Boost::Evasion, 1),
            ],
        );
    }
}