igs/transposition_table/
protected.rs

1use crate::dbs::{NimbersProvider, NimbersStorer, HasLen};
2use crate::game::{Game, SerializableGame};
3use std::collections::HashMap;
4use std::hash::Hash;
5use std::io;
6use std::path::Path;
7use std::fs::{File, OpenOptions};
8use std::io::{Read, Seek, SeekFrom, BufWriter, Write};
9
10/// Transposition Table that protects nimbers of some positions
11/// (usually the positions that are close to a root of a search tree).
12///
13/// Nimbers of protected positions are stored in `protected_part`,
14/// which never overwrites them and is saved in `backup` (which usually is a file).
15/// Nimbers of the rest positions are stored in `unprotected_part`.
16/// The predicate `should_be_protected` points which positions are protected.
17pub struct ProtectedTT<'g, G: Game, UnprotectedTT, ProtectPred: Fn(&G, &G::Position) -> bool, F> {
18    game: &'g G,
19    /// Stores nimbers of unprotected positions.
20    unprotected_part: UnprotectedTT,
21    /// Stores nimbers of protected positions.
22    protected_part: HashMap<G::Position, u8>,
23    /// The predicate that points if given position is protected.
24    should_be_protected: ProtectPred,
25    /// A copy of the protected part; usually the file.
26    backup: F
27}
28
29impl<'g, G, UnprotectedTT, ProtectPred> ProtectedTT<'g, G, UnprotectedTT, ProtectPred, BufWriter<File>>
30where G: Game + SerializableGame,
31      <G as Game>::Position: Eq + Hash,
32      UnprotectedTT: NimbersStorer<G::Position>,
33      ProtectPred: Fn(&G, &G::Position) -> bool
34{
35    pub fn new<P: AsRef<Path>>(game: &'g G, backup_file_name: P, should_be_protected: ProtectPred, mut unprotected_part: UnprotectedTT) -> Self {
36        let mut protected_part = HashMap::<G::Position, u8>::new();
37        let mut backup_position = 0;
38        let mut backup = OpenOptions::new()
39            .read(true)
40            .write(true)
41            .create(true)
42            //.truncate(false)
43            .open(backup_file_name)
44            .unwrap();
45        let mut backup_has_extra_positions = false;
46        while let Ok(position) = game.read_position(&mut backup) {
47            let mut nimber = 0u8;
48            if backup.read_exact(std::slice::from_mut(&mut nimber)).is_ok() {
49                if should_be_protected(game, &position) {
50                    protected_part.store_nimber(position, nimber);
51                    backup_position = backup.stream_position().unwrap();
52                } else {    // file has been created with different predicate and some position are not protected now:
53                    unprotected_part.store_nimber(position, nimber);
54                    backup_has_extra_positions = true;
55                }
56            } else {
57                break;
58            }
59        }
60        backup.seek(SeekFrom::Start(backup_position)).unwrap();
61        let mut backup = BufWriter::with_capacity(game.position_size_bytes() + 1, backup);
62        if backup_has_extra_positions {
63            backup.rewind().unwrap();
64            for (p, n) in &protected_part {
65                game.write_position(&mut backup, p).expect("ProtectedTT cannot write the position to the backup");
66                backup.write_all(&n.to_ne_bytes()).expect("ProtectedTT cannot write the nimber to the backup");
67            }
68            backup.flush().expect("ProtectedTT cannot flush the backup");
69            let current_size = backup.stream_position().expect("ProtectedTT cannot shrink the file");
70            backup.get_mut().set_len(current_size).expect("ProtectedTT cannot shrink the file");
71        }
72        Self {
73            game,
74            unprotected_part,
75            protected_part,
76            should_be_protected,
77            backup
78        }
79    }
80}
81
82impl<'g, G, UnprotectedTT, ProtectPred, F> NimbersProvider<G::Position> for ProtectedTT<'g, G, UnprotectedTT, ProtectPred, F>
83where G: Game,
84      <G as Game>::Position: Eq + Hash,
85      UnprotectedTT: NimbersProvider<G::Position>,
86    ProtectPred: Fn(&G, &G::Position) -> bool
87{
88    #[inline(always)]
89    fn get_nimber(&self, position: &G::Position) -> Option<u8> {
90        if (self.should_be_protected)(self.game, position) {
91            self.protected_part.get_nimber(position)
92        } else {
93            self.unprotected_part.get_nimber(position)
94        }
95    }
96
97    #[inline(always)]
98    fn get_nimber_and_self_organize(&mut self, position: &G::Position) -> Option<u8> {
99        if (self.should_be_protected)(self.game, position) {
100            self.protected_part.get_nimber_and_self_organize(position)
101        } else {
102            self.unprotected_part.get_nimber_and_self_organize(position)
103        }
104    }
105}
106
107impl<'g, G, UnprotectedTT, ProtectPred, F> NimbersStorer<G::Position> for ProtectedTT<'g, G, UnprotectedTT, ProtectPred, F>
108    where G: Game + SerializableGame,
109          <G as Game>::Position: Eq + Hash,
110          UnprotectedTT: NimbersStorer<G::Position> + NimbersProvider<G::Position>,
111          ProtectPred: Fn(&G, &G::Position) -> bool,
112        F: io::Write
113{
114    fn store_nimber(&mut self, position: G::Position, nimber: u8) {
115        if (self.should_be_protected)(self.game, &position) {
116            /*let mut buff = Vec::<u8>::new();
117            self.game.write_position(&mut buff, &position).expect("ProtectedTT cannot write the position to the backup");
118            buff.push(nimber);
119            self.backup.write_all(&buff);*/
120            self.game.write_position(&mut self.backup, &position).expect("ProtectedTT cannot write the position to the backup");
121            self.backup.write_all(&nimber.to_ne_bytes()).expect("ProtectedTT cannot write the nimber to the backup");
122            self.backup.flush().expect("ProtectedTT cannot flush the backup");
123            self.protected_part.store_nimber(position, nimber)
124        } else {
125            self.unprotected_part.store_nimber(position, nimber)
126        }
127    }
128}
129
130impl<'g, G, UnprotectedTT, ProtectPred, F> HasLen for ProtectedTT<'g, G, UnprotectedTT, ProtectPred, F>
131where G: Game, ProtectPred: Fn(&G, &G::Position) -> bool, UnprotectedTT: HasLen
132{
133    #[inline] fn len(&self) -> usize {
134        self.protected_part.len() + self.unprotected_part.len()
135    }
136}