Skip to main content

battler_data/moves/
multihit_type.rs

1use core::{
2    fmt,
3    fmt::Display,
4};
5
6use anyhow::Error;
7use serde::{
8    Deserialize,
9    Serialize,
10    Serializer,
11    de::Visitor,
12    ser::SerializeSeq,
13};
14
/// How many times a multihit move strikes.
///
/// The count is either fixed or given as a pair of bounds; see [`MultihitType::min`]
/// and [`MultihitType::max`] for how each variant maps onto those bounds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MultihitType {
    /// The move always hits exactly this many times.
    Static(u8),
    /// The hit count is chosen from a range; the first field is the minimum and the
    /// second field is the maximum.
    Range(u8, u8),
}

impl MultihitType {
    /// Returns the smallest number of hits the move can deal.
    ///
    /// For [`MultihitType::Static`] this is the fixed count; for
    /// [`MultihitType::Range`] it is the first field.
    pub fn min(&self) -> u8 {
        match *self {
            Self::Static(hits) | Self::Range(hits, _) => hits,
        }
    }

    /// Returns the largest number of hits the move can deal.
    ///
    /// For [`MultihitType::Static`] this is the fixed count; for
    /// [`MultihitType::Range`] it is the second field.
    pub fn max(&self) -> u8 {
        match *self {
            Self::Static(hits) | Self::Range(_, hits) => hits,
        }
    }
}
41
42impl Display for MultihitType {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        match self {
45            Self::Static(n) => write!(f, "{n}"),
46            Self::Range(begin, end) => write!(f, "[{begin},{end}]"),
47        }
48    }
49}
50
51impl From<u8> for MultihitType {
52    fn from(value: u8) -> Self {
53        Self::Static(value)
54    }
55}
56
57impl TryFrom<&[u8]> for MultihitType {
58    type Error = Error;
59    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
60        if value.len() != 2 {
61            return Err(Error::msg("multihit range must contain exactly 2 elements"));
62        }
63        Ok(Self::Range(value[0], value[1]))
64    }
65}
66
67impl Serialize for MultihitType {
68    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
69    where
70        S: Serializer,
71    {
72        match self {
73            Self::Static(n) => serializer.serialize_u8(*n),
74            Self::Range(begin, end) => {
75                let mut seq = serializer.serialize_seq(Some(2))?;
76                seq.serialize_element(begin)?;
77                seq.serialize_element(end)?;
78                seq.end()
79            }
80        }
81    }
82}
83
84struct MultihitTypeVisitor;
85
86impl<'de> Visitor<'de> for MultihitTypeVisitor {
87    type Value = MultihitType;
88
89    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
90        write!(formatter, "an integer or an array of 2 integers")
91    }
92
93    fn visit_u8<E>(self, v: u8) -> Result<Self::Value, E>
94    where
95        E: serde::de::Error,
96    {
97        Ok(Self::Value::from(v))
98    }
99
100    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
101    where
102        E: serde::de::Error,
103    {
104        Ok(Self::Value::from(v as u8))
105    }
106
107    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
108    where
109        A: serde::de::SeqAccess<'de>,
110    {
111        let begin = match seq.next_element()? {
112            Some(v) => v,
113            None => return Err(serde::de::Error::invalid_length(0, &self)),
114        };
115        let end = match seq.next_element()? {
116            Some(v) => v,
117            None => return Err(serde::de::Error::invalid_length(1, &self)),
118        };
119        if seq.next_element::<u8>()?.is_some() {
120            return Err(serde::de::Error::invalid_length(3, &self));
121        }
122        Ok(Self::Value::Range(begin, end))
123    }
124}
125
126impl<'de> Deserialize<'de> for MultihitType {
127    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
128    where
129        D: serde::Deserializer<'de>,
130    {
131        deserializer.deserialize_any(MultihitTypeVisitor)
132    }
133}
134
#[cfg(test)]
mod multihit_type_test {
    use crate::{
        MultihitType,
        test_util::test_serialization,
    };

    // NOTE(review): `test_serialization` is defined outside this file; presumably it
    // checks that the value serializes to (and deserializes from) the given
    // representation — confirm against `test_util`.
    #[test]
    fn serializes_to_string() {
        let fixed = MultihitType::Static(2);
        test_serialization(fixed, 2);

        let ranged = MultihitType::Range(1, 5);
        test_serialization(ranged, "[1,5]");
    }
}
147}