use fxhash::FxHashMap;
use std::{fs, sync::OnceLock};
/// A single entry of the opening book.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Opening {
    /// ECO code column of the book entry (e.g. `"E22"`).
    pub eco: String,
    /// Opening name column of the book entry.
    pub name: String,
    /// Move text as stored in the book. Move-number tokens (anything containing `.`)
    /// are skipped by [`Opening::iter_moves`].
    pub moves: String,
}
impl Opening {
    /// Iterates over the SAN moves of this opening, skipping move-number tokens.
    ///
    /// Tokens are separated by any run of whitespace. Repeated spaces, leading or trailing
    /// blanks, or a stray `\r` therefore never produce empty or corrupted moves. An empty
    /// move string yields no moves at all. Any token containing `.` (such as `1.` or `12...`)
    /// is treated as a move number and dropped.
    pub fn iter_moves(&self) -> impl Iterator<Item = &str> {
        self.moves
            .split_whitespace()
            .filter(|token| !token.contains('.'))
    }
}
/// Returns the default opening book, loaded from `src/data/book.csv` on first call.
///
/// Each line is split on `;` into `eco`, `name` and `moves` (in that order). Any extra
/// fields are ignored. Lines with fewer than three fields, such as blank lines, are skipped.
/// Both `\n` and `\r\n` line endings are accepted, so no `\r` leaks into the move text.
/// The parsed book is cached for the lifetime of the process.
///
/// # Panics
///
/// Panics if the file cannot be read. The path is relative, so it is resolved against the
/// process's current working directory.
pub fn default_book() -> &'static Vec<Opening> {
    const BOOK_PATH: &str = "src/data/book.csv";
    static DEFAULT_BOOK: OnceLock<Vec<Opening>> = OnceLock::new();
    DEFAULT_BOOK.get_or_init(|| {
        let contents = fs::read_to_string(BOOK_PATH).unwrap_or_else(|err| {
            panic!("failed to read opening book {}: {}", BOOK_PATH, err)
        });
        contents
            // `lines()` strips both "\n" and "\r\n", unlike `split('\n')`.
            .lines()
            .filter_map(|line| {
                let mut fields = line.split(';');
                Some(Opening {
                    eco: fields.next()?.into(),
                    name: fields.next()?.into(),
                    moves: fields.next()?.into(),
                })
            })
            .collect()
    })
}
pub fn default_tree() -> &'static OpeningTree {
static DEFAULT_TREE: OnceLock<OpeningTree> = OnceLock::new();
DEFAULT_TREE.get_or_init(|| OpeningTree::from(default_book().as_slice()))
}
/// A trie of openings keyed by SAN move.
///
/// Each node may hold the [`Opening`] whose move sequence ends exactly at that node. It
/// also maps the next move's SAN string to the corresponding subtree.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct OpeningTree {
    /// Opening whose full move sequence ends at this node, if any.
    pub element: Option<Opening>,
    /// Subtrees keyed by the next SAN move; `None` when nothing has been inserted below.
    pub children: Option<FxHashMap<String, OpeningTree>>,
}
impl OpeningTree {
    /// Finds the opening with the longest move prefix matching `sans`.
    ///
    /// The tree is walked one move at a time. The walk stops at the first move with no
    /// matching child, or when `sans` is exhausted. The result is the opening stored on the
    /// deepest visited node that has one. The root node's own `element` is never returned.
    /// Moves are compared verbatim against the stored keys, so `sans` should not contain
    /// move-number tokens. Returns `None` when no visited node carries an opening.
    #[must_use]
    pub fn get_opening<'a>(
        &self,
        sans: impl Iterator<Item = &'a str>,
    ) -> Option<&Opening> {
        let mut node = self;
        let mut best = None;
        for san in sans {
            let Some(next) = node.children.as_ref().and_then(|kids| kids.get(san)) else {
                break;
            };
            node = next;
            // Keep the previous match unless this node carries its own opening.
            best = node.element.as_ref().or(best);
        }
        best
    }
}
impl From<&[Opening]> for OpeningTree {
    /// Builds a trie from `value`. Each opening is inserted under the path of moves yielded
    /// by [`Opening::iter_moves`].
    ///
    /// Intermediate nodes are created on demand. If two openings have the exact same move
    /// sequence, the later one in the slice replaces the earlier one. An opening with no
    /// moves is stored on the root node.
    fn from(value: &[Opening]) -> Self {
        let mut root = Self::default();
        for opening in value.iter() {
            // Descend (creating nodes as needed) and end up at the node for the last move.
            let leaf = opening.iter_moves().fold(&mut root, |node, san| {
                node.children
                    .get_or_insert_with(FxHashMap::default)
                    .entry(san.to_owned())
                    .or_default()
            });
            leaf.element = Some(opening.clone());
        }
        root
    }
}
impl From<OpeningTree> for Vec<Opening> {
fn from(value: OpeningTree) -> Self {
let mut output = Self::new();
if let Some(opening) = value.element {
output.push(opening);
}
if let Some(children) = value.children {
for child in children.into_values() {
output.append(&mut child.into());
}
}
output
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::prelude::*;
    use std::hint::black_box;
    use std::time::Instant;

    #[test]
    fn test_default_book() {
        // Borrow the shared book; deep-cloning it just to count entries is unnecessary.
        assert_eq!(3455, default_book().len());
    }

    #[test]
    fn test_default_tree() {
        let tree = default_tree();
        let moves = "d4 Nf6 c4 e6 Nc3 Bb4 Qb3 Bxc3+ Qxc3 O-O Bg5 c5 dxc5 Nc6 Nf3 Qa5 Bxf6 gxf6 Qxa5 Nxa5 e3 Rd8 Rd1 Kg7 Be2 b6 Rd4 bxc5 Rg4+ Kh6 Bd3 f5 Rh4+ Kg6 g4 Ba6 gxf5+ exf5 Ne5+ Kf6 Rh6+ Kxe5 f4#".split(' ');
        let opening = tree.get_opening(moves).unwrap();
        assert_eq!(opening.eco, "E22");
    }

    #[test]
    fn test_tree_perf() {
        const MOVE_SEQUENCES_PATH: &str = "test_data/bare_move_sequences.json";

        // `Instant` is monotonic. `SystemTime` can step backwards, which would make the
        // `Duration` subtraction used previously panic.
        let start = Instant::now();
        // NOTE: the tree lives in a process-wide `OnceLock`. If another test already built
        // it, this only measures the cached access.
        let tree: &OpeningTree = default_tree();
        println!("File -> Tree: {:?} elapsed.", start.elapsed());

        let json = fs::read_to_string(MOVE_SEQUENCES_PATH)
            .expect("failed to read test_data/bare_move_sequences.json");
        let games: Vec<String> =
            serde_json::from_str(&json).expect("move sequence fixture is not a JSON string array");
        let games: Vec<Vec<&str>> = games
            .par_iter()
            .map(|v: &String| v.split(' ').collect())
            .collect();
        let game_count = games.len();

        let start = Instant::now();
        games.into_par_iter().for_each(|game| {
            // `black_box` keeps the optimizer from discarding the otherwise unused lookup.
            black_box(tree.get_opening(game.into_iter()));
        });
        println!("{} games: {:?} elapsed.", game_count, start.elapsed());
    }
}