1use std::fmt;
2use std::fmt::{Debug, Display, Formatter};
3use std::ops::ControlFlow;
4
5use internal_iterator::InternalIterator;
6use itertools::Itertools;
7use rand::Rng;
8
9use crate::board::{
10 AllMovesIterator, Alternating, AvailableMovesIterator, Board, BoardDone, BoardMoves, BoardSymmetry, Outcome,
11 PlayError, Player,
12};
13use crate::symmetry::D4Symmetry;
14use crate::util::bits::{get_nth_set_bit, BitIter};
15use crate::util::iter::ClonableInternal;
16
/// A tile on the 9x9 board, stored as the single index `9 * om + os`, where `om`
/// selects one of the nine 3x3 grids and `os` the tile inside that grid.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Coord(u8);
19
20#[derive(Clone, Eq, PartialEq, Hash)]
21pub struct STTTBoard {
22 grids: [u32; 9],
23 main_grid: u32,
24
25 last_move: Option<Coord>,
26 next_player: Player,
27 outcome: Option<Outcome>,
28
29 macro_mask: u32,
30 macro_open: u32,
31}
32
33impl Default for STTTBoard {
34 fn default() -> STTTBoard {
35 STTTBoard {
36 grids: [0; 9],
37 main_grid: 0,
38 last_move: None,
39 next_player: Player::A,
40 outcome: None,
41 macro_mask: STTTBoard::FULL_MASK,
42 macro_open: STTTBoard::FULL_MASK,
43 }
44 }
45}
46
47impl STTTBoard {
48 const FULL_MASK: u32 = 0b111_111_111;
49
50 pub fn tile(&self, coord: Coord) -> Option<Player> {
51 get_player(self.grids[coord.om() as usize], coord.os())
52 }
53
54 pub fn macr(&self, om: u8) -> Option<Player> {
55 debug_assert!(om < 9);
56 get_player(self.main_grid, om)
57 }
58
59 pub fn is_macro_open(&self, om: u8) -> bool {
60 debug_assert!(om < 9);
61 has_bit(self.macro_open, om)
62 }
63
64 pub fn count_tiles(&self) -> u32 {
66 self.grids.iter().map(|tile| tile.count_ones()).sum()
67 }
68
69 fn set_tile_and_update(&mut self, player: Player, coord: Coord) {
70 let om = coord.om();
71 let os = coord.os();
72 let p = 9 * player.index();
73
74 let new_grid = self.grids[om as usize] | (1 << (os + p));
76 self.grids[om as usize] = new_grid;
77
78 let grid_win = is_win_grid((new_grid >> p) & STTTBoard::FULL_MASK);
79 if grid_win {
80 let new_main_grid = self.main_grid | (1 << (om + p));
81 self.main_grid = new_main_grid;
82
83 if is_win_grid((new_main_grid >> p) & STTTBoard::FULL_MASK) {
84 self.outcome = Some(Outcome::WonBy(player));
85 }
86 }
87
88 if grid_win || new_grid.count_ones() == 9 {
90 self.macro_open &= !(1 << om);
91 if self.macro_open == 0 && self.outcome.is_none() {
92 self.outcome = Some(Outcome::Draw);
93 }
94 }
95 self.macro_mask = self.calc_macro_mask(os);
96 }
97
98 fn calc_macro_mask(&self, os: u8) -> u32 {
99 if has_bit(self.macro_open, os) {
100 1u32 << os
101 } else {
102 self.macro_open
103 }
104 }
105}
106
107impl Board for STTTBoard {
108 type Move = Coord;
109
110 fn next_player(&self) -> Player {
111 self.next_player
112 }
113
114 fn is_available_move(&self, mv: Self::Move) -> Result<bool, BoardDone> {
115 self.check_done()?;
116
117 let can_play_in_macro = has_bit(self.macro_mask, mv.om());
118 let micro_occupied = has_bit(compact_grid(self.grids[mv.om() as usize]), mv.os());
119
120 Ok(can_play_in_macro && !micro_occupied)
121 }
122
123 fn random_available_move(&self, rng: &mut impl Rng) -> Result<Self::Move, BoardDone> {
124 self.check_done()?;
128
129 let mut count = 0;
130 for om in BitIter::new(self.macro_mask) {
131 count += 9 - self.grids[om as usize].count_ones();
132 }
133
134 let mut index = rng.gen_range(0..count);
135
136 for om in BitIter::new(self.macro_mask) {
137 let grid = self.grids[om as usize];
138 let grid_count = 9 - grid.count_ones();
139
140 if index < grid_count {
141 let os = get_nth_set_bit(!compact_grid(grid), index);
142 return Ok(Coord::from_oo(om, os));
143 }
144
145 index -= grid_count;
146 }
147
148 unreachable!()
149 }
150
151 fn play(&mut self, mv: Self::Move) -> Result<(), PlayError> {
152 self.check_can_play(mv)?;
153
154 self.set_tile_and_update(self.next_player, mv);
156
157 self.last_move = Some(mv);
159 self.next_player = self.next_player.other();
160
161 Ok(())
162 }
163
164 fn outcome(&self) -> Option<Outcome> {
165 self.outcome
166 }
167
168 fn can_lose_after_move() -> bool {
169 false
170 }
171}
172
// `play` always hands the turn to `next_player.other()`, so the two players strictly alternate.
impl Alternating for STTTBoard {}
174
175impl BoardSymmetry<STTTBoard> for STTTBoard {
176 type Symmetry = D4Symmetry;
177 type CanonicalKey = (u32, Option<Coord>, u32, u32);
178
179 fn map(&self, sym: D4Symmetry) -> STTTBoard {
180 let mut grids = [0; 9];
181 for oo in 0..9 {
182 grids[map_oo(sym, oo) as usize] = map_grid(sym, self.grids[oo as usize])
183 }
184
185 STTTBoard {
186 grids,
187 main_grid: map_grid(sym, self.main_grid),
188 last_move: self.last_move.map(|c| self.map_move(sym, c)),
189 next_player: self.next_player,
190 outcome: self.outcome,
191 macro_mask: map_grid(sym, self.macro_mask),
192 macro_open: map_grid(sym, self.macro_open),
193 }
194 }
195
196 fn map_move(&self, sym: D4Symmetry, mv: Coord) -> Coord {
197 Coord::from_oo(map_oo(sym, mv.om()), map_oo(sym, mv.os()))
198 }
199
200 fn canonical_key(&self) -> Self::CanonicalKey {
201 (self.main_grid, self.last_move, self.macro_mask, self.macro_open)
202 }
203}
204
205impl<'a> BoardMoves<'a, STTTBoard> for STTTBoard {
206 type AllMovesIterator = ClonableInternal<CoordIter>;
207 type AvailableMovesIterator = AvailableMovesIterator<'a, STTTBoard>;
208
209 fn all_possible_moves() -> Self::AllMovesIterator {
210 ClonableInternal::new(Coord::all())
211 }
212
213 fn available_moves(&'a self) -> Result<Self::AvailableMovesIterator, BoardDone> {
214 AvailableMovesIterator::new(self)
215 }
216}
217
218impl InternalIterator for AllMovesIterator<STTTBoard> {
219 type Item = Coord;
220
221 fn try_for_each<R, F>(self, f: F) -> ControlFlow<R>
222 where
223 F: FnMut(Self::Item) -> ControlFlow<R>,
224 {
225 Coord::all().try_for_each(f)
226 }
227}
228
229impl<'a> InternalIterator for AvailableMovesIterator<'a, STTTBoard> {
230 type Item = Coord;
231
232 fn try_for_each<R, F: FnMut(Self::Item) -> ControlFlow<R>>(self, mut f: F) -> ControlFlow<R> {
233 let board = self.board();
234 for om in BitIter::new(board.macro_mask) {
235 let free_grid = (!compact_grid(board.grids[om as usize])) & STTTBoard::FULL_MASK;
236 for os in BitIter::new(free_grid) {
237 f(Coord::from_oo(om, os))?;
238 }
239 }
240
241 ControlFlow::Continue(())
242 }
243}
244
245pub type CoordIter = std::iter::Map<std::ops::Range<u8>, fn(u8) -> Coord>;
246
247impl Coord {
248 pub fn all() -> CoordIter {
249 (0..81).map(Self::from_o)
250 }
251
252 pub fn all_yx() -> CoordIter {
253 (0..81).map(|i| Self::from_xy(i % 9, i / 9))
254 }
255
256 pub fn from_oo(om: u8, os: u8) -> Coord {
257 debug_assert!(om < 9);
258 debug_assert!(os < 9);
259 Coord(9 * om + os)
260 }
261
262 pub fn from_o(o: u8) -> Coord {
263 debug_assert!(o < 81);
264 Coord(o)
265 }
266
267 pub fn from_xy(x: u8, y: u8) -> Coord {
268 debug_assert!(x < 9 && y < 9);
269 Coord(((x / 3) + (y / 3) * 3) * 9 + ((x % 3) + (y % 3) * 3))
270 }
271
272 pub fn om(self) -> u8 {
273 self.0 / 9
274 }
275
276 pub fn os(self) -> u8 {
277 self.0 % 9
278 }
279
280 pub fn o(self) -> u8 {
281 9 * self.om() + self.os()
282 }
283
284 pub fn yx(self) -> u8 {
285 9 * self.y() + self.x()
286 }
287
288 pub fn x(self) -> u8 {
289 (self.om() % 3) * 3 + (self.os() % 3)
290 }
291
292 pub fn y(self) -> u8 {
293 (self.om() / 3) * 3 + (self.os() / 3)
294 }
295}
296
297impl Debug for Coord {
298 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
299 write!(f, "Coord({}, {})", self.om(), self.os())
300 }
301}
302
303impl Display for Coord {
304 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
305 write!(f, "({}, {})", self.om(), self.os())
306 }
307}
308
309fn map_oo(sym: D4Symmetry, oo: u8) -> u8 {
310 let (x, y) = sym.map_xy(oo % 3, oo / 3, 3);
311 x + y * 3
312}
313
314fn map_grid(sym: D4Symmetry, grid: u32) -> u32 {
315 let mut result = 0;
317 for oo_input in 0..9 {
318 let oo_result = map_oo(sym, oo_input);
319 let get = (grid >> oo_input) & 0b1_000_000_001;
320 result |= get << oo_result;
321 }
322 result
323}
324
325fn is_win_grid(grid: u32) -> bool {
326 debug_assert!(has_mask(STTTBoard::FULL_MASK, grid));
327
328 const WIN_GRIDS: [u32; 16] = [
329 2155905152, 4286611584, 4210076288, 4293962368, 3435954304, 4291592320, 4277971584, 4294748800, 2863300736,
330 4294635760, 4210731648, 4294638320, 4008607872, 4294897904, 4294967295, 4294967295,
331 ];
332 has_bit(WIN_GRIDS[(grid / 32) as usize], (grid % 32) as u8)
333}
334
/// Whether bit `i` of `x` is set.
fn has_bit(x: u32, i: u8) -> bool {
    (x & (1u32 << i)) != 0
}
338
/// Whether every bit set in `mask` is also set in `x`.
fn has_mask(x: u32, mask: u32) -> bool {
    mask & !x == 0
}
342
343fn compact_grid(grid: u32) -> u32 {
344 (grid | grid >> 9) & STTTBoard::FULL_MASK
345}
346
347fn get_player(grid: u32, index: u8) -> Option<Player> {
348 if has_bit(grid, index) {
349 Some(Player::A)
350 } else if has_bit(grid, index + 9) {
351 Some(Player::B)
352 } else {
353 None
354 }
355}
356
357fn symbol_from_tile(board: &STTTBoard, coord: Coord) -> char {
358 let is_last = Some(coord) == board.last_move;
359 let is_available = board.is_available_move(coord).unwrap_or(false);
360 let player = board.tile(coord);
361 symbol_from_tuple(is_available, is_last, player)
362}
363
364fn symbol_from_tuple(is_available: bool, is_last: bool, player: Option<Player>) -> char {
365 let tuple = (is_available, is_last, player);
366 match tuple {
367 (false, false, Some(Player::A)) => 'x',
368 (false, true, Some(Player::A)) => 'X',
369 (false, false, Some(Player::B)) => 'o',
370 (false, true, Some(Player::B)) => 'O',
371 (true, false, None) => '.',
372 (false, false, None) => ' ',
373 _ => unreachable!("Invalid tile state {:?}", tuple),
374 }
375}
376
377fn symbol_to_tuple(c: char) -> (bool, bool, Option<Player>) {
378 match c {
379 'x' => (false, false, Some(Player::A)),
380 'X' => (false, true, Some(Player::A)),
381 'o' => (false, false, Some(Player::B)),
382 'O' => (false, true, Some(Player::B)),
383 ' ' => (false, false, None),
384 '.' => (true, false, None),
385 _ => panic!("unexpected character '{}'", c),
386 }
387}
388
389impl fmt::Debug for STTTBoard {
390 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
391 write!(f, "STTTBoard({:?})", board_to_compact_string(self))
392 }
393}
394
395impl fmt::Display for STTTBoard {
396 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
397 for y in 0..9 {
398 if y == 3 || y == 6 {
399 writeln!(f, "---+---+--- +---+")?;
400 }
401
402 for x in 0..9 {
403 if x == 3 || x == 6 {
404 write!(f, "|")?;
405 }
406 write!(f, "{}", symbol_from_tile(self, Coord::from_xy(x, y)))?;
407 }
408
409 if (3..6).contains(&y) {
410 write!(f, " |")?;
411 let ym = y - 3;
412 for xm in 0..3 {
413 let om = xm + 3 * ym;
414 write!(f, "{}", symbol_from_tuple(self.is_macro_open(om), false, self.macr(om)))?;
415 }
416 write!(f, "|")?;
417 }
418
419 writeln!(f)?;
420 }
421
422 Ok(())
423 }
424}
425
426pub fn board_to_compact_string(board: &STTTBoard) -> String {
427 Coord::all().map(|coord| symbol_from_tile(board, coord)).join("")
428}
429
430pub fn board_from_compact_string(s: &str) -> STTTBoard {
431 assert!(s.chars().count() == 81, "compact string should have length 81");
432
433 let mut board = STTTBoard::default();
434 let mut last_move = None;
435
436 for (o, c) in s.chars().enumerate() {
437 let coord = Coord::from_o(o as u8);
438 let (_, last, player) = symbol_to_tuple(c);
439
440 if last {
441 assert!(last_move.is_none(), "Compact string cannot contain multiple last moves");
442 let player = player.expect("Last move must have been played by a player");
443 last_move = Some((player, coord));
444 }
445
446 if let Some(player) = player {
447 board.set_tile_and_update(player, coord);
448 }
449 }
450
451 if let Some((last_player, last_coord)) = last_move {
452 board.set_tile_and_update(last_player, last_coord);
453 board.last_move = Some(last_coord);
454 board.next_player = last_player.other()
455 }
456
457 board
458}