use std::io::{self};
use std::io::{Read, Seek};
use thiserror::Error;
use crate::common::{
binpack_error::BinpackError, compressed_training_file_reader::CompressedTrainingDataFileReader,
entry::PackedTrainingDataEntry, entry::TrainingDataEntry,
};
use super::move_score_list_reader::PackedMoveScoreListReader;
const SUGGESTED_CHUNK_SIZE: usize = 8192;
#[derive(Debug, Error)]
pub enum CompressedReaderError {
#[error("IO error: {0}")]
Io(#[from] io::Error),
#[error("Invalid data format: {0}")]
InvalidFormat(String),
#[error("End of file reached")]
EndOfFile,
#[error("Binpack error: {0}")]
BinpackError(#[from] BinpackError),
}
type Result<T> = std::result::Result<T, CompressedReaderError>;
#[derive(Debug)]
pub struct CompressedTrainingDataEntryReader<T: Read + Seek> {
chunk: Vec<u8>,
movelist_reader: Option<PackedMoveScoreListReader>,
input_file: Option<CompressedTrainingDataFileReader<T>>,
offset: usize,
is_end: bool,
}
impl<T: Read + Seek> CompressedTrainingDataEntryReader<T> {
pub fn new(file: T) -> Result<Self> {
let chunk = Vec::with_capacity(SUGGESTED_CHUNK_SIZE);
let mut reader = Self {
chunk,
movelist_reader: None,
input_file: Some(CompressedTrainingDataFileReader::new(file)?),
offset: 0,
is_end: false,
};
if !reader.input_file.as_mut().unwrap().has_next_chunk() {
reader.is_end = true;
return Err(CompressedReaderError::EndOfFile);
} else {
reader
.input_file
.as_mut()
.unwrap()
.read_next_chunk_into(&mut reader.chunk)?;
}
Ok(reader)
}
pub fn into_inner(&mut self) -> io::Result<T> {
self.input_file.take().unwrap().into_inner()
}
pub fn read_bytes(&self) -> u64 {
self.input_file.as_ref().unwrap().read_bytes()
}
pub fn has_next(&self) -> bool {
!self.is_end
}
pub fn is_next_entry_continuation(&self) -> bool {
if let Some(ref reader) = self.movelist_reader {
return reader.has_next();
}
false
}
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> TrainingDataEntry {
if let Some(ref mut reader) = self.movelist_reader {
let entry = reader.next_entry(&self.chunk[self.offset..]);
if !reader.has_next() {
self.offset += reader.num_read_bytes();
self.movelist_reader = None;
self.fetch_next_chunk_if_needed();
}
return entry;
}
let entry = self.read_entry();
let num_plies = self.read_plies();
if num_plies > 0 {
self.movelist_reader = Some(PackedMoveScoreListReader::new(entry, num_plies));
} else {
self.fetch_next_chunk_if_needed();
}
entry
}
fn read_entry(&mut self) -> TrainingDataEntry {
let size = PackedTrainingDataEntry::byte_size();
debug_assert!(self.offset + size <= self.chunk.len());
let packed =
PackedTrainingDataEntry::from_slice(&self.chunk[self.offset..self.offset + size]);
self.offset += size;
packed.unpack_entry()
}
fn read_plies(&mut self) -> u16 {
let ply = ((self.chunk[self.offset] as u16) << 8) | (self.chunk[self.offset + 1] as u16);
self.offset += 2;
ply
}
fn fetch_next_chunk_if_needed(&mut self) {
if self.offset + PackedTrainingDataEntry::byte_size() + 2 > self.chunk.len() {
if self.input_file.as_mut().unwrap().has_next_chunk() {
let chunk = self.input_file.as_mut().unwrap().read_next_chunk().unwrap();
self.chunk = chunk;
self.offset = 0;
} else {
self.is_end = true;
}
}
}
}
#[cfg(test)]
mod tests {
use std::{fs::OpenOptions, io::Cursor};
use crate::chess::{
coords::Square,
piece::Piece,
position::Position,
r#move::{Move, MoveType},
};
use super::*;
#[test]
fn test_reader_simple() {
let file = OpenOptions::new()
.read(true)
.write(true)
.create(false)
.append(false)
.open("./test/ep1.binpack")
.unwrap();
let mut reader = CompressedTrainingDataEntryReader::new(file).unwrap();
let mut entries: Vec<TrainingDataEntry> = Vec::new();
while reader.has_next() {
let entry = reader.next();
entries.push(entry);
}
let expected = vec![
TrainingDataEntry {
pos: Position::from_fen("1q5b/1r5k/4p2p/1b2P1pN/3p4/6PP/1nP3B1/1Q2B1K1 w - - 0 35")
.unwrap(),
mv: Move::new(
Square::new(10),
Square::new(26),
MoveType::Normal,
Piece::none(),
),
score: -201,
ply: 68,
result: 0,
},
TrainingDataEntry {
pos: Position::from_fen("1q5b/1r5k/4p2p/1b2P1pN/2Pp4/6PP/1n4B1/1Q2B1K1 b - - 0 35")
.unwrap(),
mv: Move::new(
Square::new(27),
Square::new(19),
MoveType::Normal,
Piece::none(),
),
score: 254,
ply: 69,
result: 0,
},
TrainingDataEntry {
pos: Position::from_fen(
"1q5b/1r5k/4p2p/1b2P1pN/2P5/3p2PP/1n4B1/1Q2B1K1 w - - 0 36",
)
.unwrap(),
mv: Move::new(
Square::new(14),
Square::new(49),
MoveType::Normal,
Piece::none(),
),
score: -220,
ply: 70,
result: 0,
},
];
assert_eq!(entries, expected);
}
#[test]
fn test_reader_big_score_diff() {
let cursor: Cursor<Vec<u8>> = Cursor::new(Vec::from([
66, 73, 78, 80, 37, 0, 0, 0, 130, 130, 144, 210, 8, 192, 70, 82, 72, 58, 64, 0, 81, 16,
18, 113, 155, 5, 0, 0, 0, 0, 0, 0, 10, 104, 249, 253, 0, 68, 0, 0, 0, 1, 29, 83, 79,
]));
let mut reader = CompressedTrainingDataEntryReader::new(cursor).unwrap();
let mut entries: Vec<TrainingDataEntry> = Vec::new();
while reader.has_next() {
let entry = reader.next();
entries.push(entry);
}
let expected = vec![
TrainingDataEntry {
pos: Position::from_fen("1q5b/1r5k/4p2p/1b2P1pN/3p4/6PP/1nP3B1/1Q2B1K1 w - - 0 35")
.unwrap(),
mv: Move::new(
Square::new(10),
Square::new(26),
MoveType::Normal,
Piece::none(),
),
score: -31999,
ply: 68,
result: 0,
},
TrainingDataEntry {
pos: Position::from_fen("1q5b/1r5k/4p2p/1b2P1pN/2Pp4/6PP/1n4B1/1Q2B1K1 b - - 0 35")
.unwrap(),
mv: Move::new(
Square::new(27),
Square::new(19),
MoveType::Normal,
Piece::none(),
),
score: -1500,
ply: 69,
result: 0,
},
];
assert_eq!(entries, expected);
}
}