Skip to main content

diplomacy/
unit.rs

1use crate::parser::{Error, ErrorKind};
2use crate::{Command, Nation, Order, ShortName, geo::Location, geo::RegionKey};
3use std::borrow::{Borrow, Cow};
4use std::collections::HashMap;
5use std::fmt;
6use std::hash::{BuildHasher, Hash};
7use std::str::FromStr;
8
/// The kind of a military unit.
///
/// An army is a land-based unit that can be convoyed; a fleet is a sea-going unit
/// that is able to convoy armies.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum UnitType {
    /// A land-based, convoyable unit that can move through inland and coastal terrain.
    ///
    /// Renamed to `"A"` for serde when the `serde` feature is enabled.
    #[cfg_attr(feature = "serde", serde(rename = "A"))]
    Army,

    /// A sea-based unit that can move through sea and coastal terrain.
    ///
    /// Renamed to `"F"` for serde when the `serde` feature is enabled.
    #[cfg_attr(feature = "serde", serde(rename = "F"))]
    Fleet,
}
22
23impl FromStr for UnitType {
24    type Err = Error;
25
26    fn from_str(s: &str) -> Result<Self, Self::Err> {
27        match &s.to_lowercase()[..] {
28            "a" | "army" => Ok(UnitType::Army),
29            "f" | "fleet" => Ok(UnitType::Fleet),
30            _ => Err(Error::new(ErrorKind::InvalidUnitType, s)),
31        }
32    }
33}
34
35impl ShortName for UnitType {
36    fn short_name(&self) -> std::borrow::Cow<'_, str> {
37        Cow::Borrowed(match *self {
38            UnitType::Army => "A",
39            UnitType::Fleet => "F",
40        })
41    }
42}
43
44/// A specific unit that belongs to a nation.
45///
46/// Diplomacy doesn't invest much in unit identity across turns, so there's no difference
47/// between one Austrian fleet and another.
48#[derive(Debug, Clone, PartialEq, Eq, Hash)]
49#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
50pub struct Unit<'a> {
51    nation: Cow<'a, Nation>,
52    unit_type: UnitType,
53}
54
55impl<'a> Unit<'a> {
56    pub fn new(nation: impl Into<Cow<'a, Nation>>, unit_type: UnitType) -> Self {
57        Self {
58            nation: nation.into(),
59            unit_type,
60        }
61    }
62
63    pub fn nation(&self) -> &Nation {
64        self.nation.as_ref()
65    }
66
67    pub fn unit_type(&self) -> UnitType {
68        self.unit_type
69    }
70}
71
72/// A unit's instantaneous position in a region.
73#[derive(Debug, Clone, PartialEq, Eq, Hash)]
74#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
75pub struct UnitPosition<'a, L = &'a RegionKey> {
76    pub unit: Unit<'a>,
77    /// The unit's current location.
78    pub region: L,
79}
80
81impl<'a, L> UnitPosition<'a, L> {
82    /// Create a new unit at a given position.
83    pub fn new(unit: Unit<'a>, region: L) -> Self {
84        Self { unit, region }
85    }
86
87    pub fn nation(&self) -> &Nation {
88        self.unit.nation()
89    }
90
91    /// Create a view of the unit position that has a reference to its region.
92    pub fn as_region_ref(&self) -> UnitPosition<'a, &L> {
93        UnitPosition {
94            unit: self.unit.clone(),
95            region: &self.region,
96        }
97    }
98}
99
100impl<'a, L: Clone> UnitPosition<'a, &L> {
101    /// Returns a [`UnitPosition`] that converts the region to an owned value by cloning.
102    pub fn with_cloned_region(&self) -> UnitPosition<'a, L> {
103        UnitPosition {
104            unit: self.unit.clone(),
105            region: (*self.region).clone(),
106        }
107    }
108}
109
110impl<'a> FromStr for UnitPosition<'a, RegionKey> {
111    type Err = Error;
112
113    fn from_str(s: &str) -> Result<Self, Self::Err> {
114        let mut words = s.split_ascii_whitespace();
115        let nation = if let Some(first_word) = words.next() {
116            Nation::from(first_word.trim_end_matches(':'))
117        } else {
118            return Err(Error::new(ErrorKind::TooFewWords(3), s));
119        };
120
121        let unit_type = if let Some(word) = words.next() {
122            UnitType::from_str(word)?
123        } else {
124            return Err(Error::new(ErrorKind::TooFewWords(3), s));
125        };
126
127        let region = if let Some(word) = words.next() {
128            RegionKey::from_str(word)?
129        } else {
130            return Err(Error::new(ErrorKind::TooFewWords(3), s));
131        };
132
133        Ok(UnitPosition::new(
134            Unit::new(Cow::Owned(nation), unit_type),
135            region,
136        ))
137    }
138}
139
140impl fmt::Display for UnitPosition<'_, RegionKey> {
141    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142        write!(
143            f,
144            "{}: {} {}",
145            self.unit.nation().short_name(),
146            self.unit.unit_type().short_name(),
147            self.region.short_name()
148        )
149    }
150}
151
152/// Knowledge of unit positions at a point in time.
153pub trait UnitPositions<L: Location> {
154    /// The current unit positions. The order is unspecified, but every unit should be
155    /// returned exactly once.
156    fn unit_positions(&self) -> Vec<UnitPosition<'_, &L>>;
157
158    /// Get the unit currently occupying a province.
159    ///
160    /// This function returns the region of the occupier as well.
161    fn find_province_occupier(&self, province: &L::Province) -> Option<UnitPosition<'_, &L>>;
162
163    /// Get the unit currently occupying a specific region.
164    fn find_region_occupier(&self, region: &L) -> Option<Unit<'_>>;
165}
166
167impl<'a, L: Location> UnitPositions<L> for Vec<UnitPosition<'a, L>> {
168    fn unit_positions(&self) -> Vec<UnitPosition<'_, &L>> {
169        self.iter()
170            .map(|pos| UnitPosition::new(pos.unit.clone(), &pos.region))
171            .collect()
172    }
173
174    fn find_province_occupier(&self, province: &L::Province) -> Option<UnitPosition<'_, &L>> {
175        self.iter()
176            .find(|pos| pos.region.province() == province)
177            .map(|pos| UnitPosition::new(pos.unit.clone(), &pos.region))
178    }
179
180    fn find_region_occupier(&self, region: &L) -> Option<Unit<'_>> {
181        self.iter()
182            .find(|pos| pos.region == *region)
183            .map(|pos| pos.unit.clone())
184    }
185}
186
187/// Infer unit positions from a collection of orders. This assumes orders are trustworthy
188/// and complete:
189///
190/// 1. There is an order for every unit.
191/// 2. Orders are only issued to units that exist.
192/// 3. There is at most one order per province.
193impl<L: Location, C: Command<L>> UnitPositions<L> for Vec<Order<L, C>> {
194    fn unit_positions(&self) -> Vec<UnitPosition<'_, &L>> {
195        self.iter().map(UnitPosition::from).collect()
196    }
197
198    fn find_province_occupier(&self, province: &L::Province) -> Option<UnitPosition<'_, &L>> {
199        self.iter()
200            .find(|ord| ord.region.province() == province)
201            .map(UnitPosition::from)
202    }
203
204    fn find_region_occupier(&self, region: &L) -> Option<Unit<'_>> {
205        self.iter()
206            .find(|ord| ord.region == *region)
207            .map(Unit::from)
208    }
209}
210
211/// Infer unit positions from a collection of orders. This assumes orders are trustworthy
212/// and complete:
213///
214/// 1. There is an order for every unit.
215/// 2. Orders are only issued to units that exist.
216/// 3. There is at most one order per province.
217impl<L: Location, C: Command<L>> UnitPositions<L> for Vec<&'_ Order<L, C>> {
218    fn unit_positions(&self) -> Vec<UnitPosition<'_, &L>> {
219        self.iter().copied().map(UnitPosition::from).collect()
220    }
221
222    fn find_province_occupier(&self, province: &L::Province) -> Option<UnitPosition<'_, &L>> {
223        self.iter()
224            .copied()
225            .find(|ord| ord.region.province() == province)
226            .map(UnitPosition::from)
227    }
228
229    fn find_region_occupier(&self, region: &L) -> Option<Unit<'_>> {
230        self.iter()
231            .copied()
232            .find(|ord| ord.region == *region)
233            .map(Unit::from)
234    }
235}
236
237impl<L, K, H> UnitPositions<L> for HashMap<K, UnitPosition<'_, &L>, H>
238where
239    L: Location + Eq,
240    L::Province: Eq + Hash,
241    K: Borrow<L::Province> + Eq + Hash,
242    H: BuildHasher,
243{
244    fn unit_positions(&self) -> Vec<UnitPosition<'_, &L>> {
245        self.values().cloned().collect()
246    }
247
248    fn find_province_occupier(&self, province: &L::Province) -> Option<UnitPosition<'_, &L>> {
249        self.get(province).cloned()
250    }
251
252    fn find_region_occupier(&self, region: &L) -> Option<Unit<'_>> {
253        let up = self.get(region.province())?;
254        if up.region == region {
255            Some(up.unit.clone())
256        } else {
257            None
258        }
259    }
260}
261
#[cfg(test)]
mod test {
    use super::{UnitPosition, UnitType};
    use crate::{Nation, geo::RegionKey};

    /// Both the one-letter and the full spellings parse, regardless of letter case.
    #[test]
    fn parse_unit_type() {
        let cases = [
            ("Army", UnitType::Army),
            ("Fleet", UnitType::Fleet),
            ("A", UnitType::Army),
            ("F", UnitType::Fleet),
            ("a", UnitType::Army),
            ("f", UnitType::Fleet),
            ("ARMY", UnitType::Army),
            ("fLeEt", UnitType::Fleet),
        ];

        for (input, expected) in cases {
            assert_eq!(Ok(expected), input.parse(), "input: {:?}", input);
        }
    }

    /// Anything other than the accepted spellings is an error, not a default.
    #[test]
    fn parse_unit_type_rejects_unknown_names() {
        for input in ["", "x", "armies", "a f"] {
            assert!(input.parse::<UnitType>().is_err(), "input: {:?}", input);
        }
    }

    /// The nation's trailing colon is dropped and the unit type is carried through.
    #[test]
    fn parse_unit_position() {
        let pos: UnitPosition<'_, RegionKey> = "FRA: F bre".parse().unwrap();
        assert_eq!(pos.nation(), &Nation::from("FRA"));
        assert_eq!(pos.unit.unit_type(), UnitType::Fleet);
    }

    /// A position needs a nation, a unit type and a region.
    #[test]
    fn parse_unit_position_requires_three_words() {
        for input in ["", "FRA:", "FRA: F"] {
            assert!(
                input.parse::<UnitPosition<'_, RegionKey>>().is_err(),
                "input: {:?}",
                input
            );
        }
    }

    /// An unrecognized unit type makes the whole position invalid.
    #[test]
    fn parse_unit_position_rejects_unknown_unit_type() {
        assert!("FRA: X bre".parse::<UnitPosition<'_, RegionKey>>().is_err());
    }
}
282}