use std::io::Write;
use std::io::{self};
use thiserror::Error;
use crate::{
chess::{position::Position, r#move::Move},
common::{
compressed_training_file_writer::CompressedTrainingDataFileWriter,
entry::PackedTrainingDataEntry, entry::TrainingDataEntry,
},
};
use super::move_score_list::PackedMoveScoreList;
/// Number of bytes in one kibibyte.
const KI_B: usize = 1 << 10;
/// Number of bytes in one mebibyte.
const MI_B: usize = KI_B << 10;
/// Once at least this many packed bytes are buffered, they are handed to the
/// file writer before the next full entry is packed.
const SUGGESTED_CHUNK_SIZE: usize = MI_B;
/// Extra headroom allocated in the packing buffer beyond [`SUGGESTED_CHUNK_SIZE`]
/// for data written after the chunk threshold has been reached.
const MAX_MOVELIST_SIZE: usize = 10 * KI_B;
#[derive(Debug, Error)]
pub enum CompressedWriterError {
#[error("IO error: {0}")]
Io(#[from] io::Error),
#[error("Invalid data format: {0}")]
InvalidFormat(String),
#[error("End of file reached")]
EndOfFile,
}
type Result<T> = std::result::Result<T, CompressedWriterError>;
/// Packs [`TrainingDataEntry`] values into a staging buffer and hands the
/// packed bytes to a [`CompressedTrainingDataFileWriter`].
#[derive(Debug)]
pub struct CompressedTrainingDataEntryWriter<T>
where
    T: Write,
{
    /// Destination file writer; becomes `None` once `into_inner` has taken it.
    output_file: Option<CompressedTrainingDataFileWriter<T>>,
    /// The entry most recently passed to `write_entry`, compared against the
    /// next one to decide whether it is a continuation.
    last_entry: TrainingDataEntry,
    /// Moves and scores collected for continuation entries.
    movelist: PackedMoveScoreList,
    /// Count of valid bytes at the front of `packed_entries`.
    packed_size: usize,
    /// Staging buffer holding packed bytes not yet handed to `output_file`.
    packed_entries: Vec<u8>,
    /// Gate for writing a movelist: while `true`, no movelist is emitted before
    /// the next full entry is packed or when flushing.
    is_first: bool,
}
impl<T: Write> CompressedTrainingDataEntryWriter<T> {
    /// Creates a writer that packs training entries and hands them to a
    /// [`CompressedTrainingDataFileWriter`] wrapping `file`.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the underlying file writer.
    pub fn new(file: T) -> Result<Self> {
        Ok(Self {
            output_file: Some(CompressedTrainingDataFileWriter::new(file)?),
            last_entry: Self::initial_last_entry(),
            movelist: PackedMoveScoreList::new(),
            packed_size: 0,
            // Room for one full chunk plus the movelist that is written before the
            // chunk-size check; `put_bytes` grows the buffer if that is not enough.
            packed_entries: vec![0u8; SUGGESTED_CHUNK_SIZE + MAX_MOVELIST_SIZE],
            is_first: true,
        })
    }

    /// Hands any buffered data to the file writer and returns the underlying writer.
    ///
    /// Buffered entries, including the movelist of the chain in progress, are
    /// written out first. Nothing is lost when this is called without a preceding
    /// [`Self::flush_and_end`]. After this call the writer has no output.
    /// Anything written afterwards cannot be delivered, and flushing it reports
    /// an error.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - handing the buffered data to the file writer fails;
    /// - the file writer's own `into_inner` fails;
    /// - the output was already taken by an earlier call.
    pub fn into_inner(&mut self) -> io::Result<T> {
        self.finish_pending()?;
        let file = self.output_file.take().ok_or_else(Self::consumed_error)?;
        file.into_inner()
    }

    /// Adds one entry.
    ///
    /// An entry that `last_entry.is_continuation` accepts is appended to the
    /// current movelist. Any other entry closes the current chain (writing its
    /// movelist) and is packed in full.
    ///
    /// # Errors
    ///
    /// Returns [`CompressedWriterError::Io`] if handing a full chunk to the file
    /// writer fails, or if the output was already taken by `into_inner`.
    pub fn write_entry(&mut self, entry: &TrainingDataEntry) -> Result<()> {
        if self.last_entry.is_continuation(entry) {
            self.movelist
                .add_move_score(&entry.pos, entry.mv, entry.score);
        } else {
            self.start_chain(entry)?;
        }
        self.last_entry = *entry;
        Ok(())
    }

    /// Writes out everything buffered so far and flushes the underlying file writer.
    ///
    /// The buffered data includes the movelist of the chain in progress. This
    /// method returns nothing, so failures are reported on stderr. An entry
    /// written afterwards starts a new chain instead of extending the one that
    /// was just written.
    pub fn flush_and_end(&mut self) {
        if let Err(e) = self.flush_packed() {
            eprintln!("Error flushing writer: {}", e);
        }
    }

    /// Flushes the underlying file writer without writing out entries that are
    /// still buffered in this writer. Failures are reported on stderr.
    pub fn flush(&mut self) {
        if let Some(file) = self.output_file.as_mut() {
            if let Err(e) = file.flush() {
                eprintln!("Error flushing writer: {}", e);
            }
        }
    }

    /// Writes out all pending data (see `finish_pending`) and then flushes the
    /// file writer, if it is still present.
    fn flush_packed(&mut self) -> Result<()> {
        self.finish_pending()?;
        if let Some(file) = self.output_file.as_mut() {
            file.flush()?;
        }
        Ok(())
    }

    /// Closes the open chain and hands every buffered byte to the file writer.
    ///
    /// The continuation state is also reset, so a later entry is packed in full
    /// rather than added to a movelist that has already been written.
    fn finish_pending(&mut self) -> io::Result<()> {
        if !self.is_first {
            self.close_chain();
            self.last_entry = Self::initial_last_entry();
        }
        if self.packed_size > 0 {
            self.append_packed()?;
        }
        Ok(())
    }

    /// Starts a new chain whose first element is `entry`, packed in full.
    fn start_chain(&mut self, entry: &TrainingDataEntry) -> io::Result<()> {
        self.close_chain();
        if self.packed_size >= SUGGESTED_CHUNK_SIZE {
            self.append_packed()?;
        }
        let packed = PackedTrainingDataEntry::from_entry(entry);
        Self::put_bytes(&mut self.packed_entries, &mut self.packed_size, &packed.data);
        self.movelist.clear(entry);
        self.is_first = false;
        Ok(())
    }

    /// Writes the movelist of the open chain, if there is one, and marks the
    /// chain as closed.
    ///
    /// Closing the chain before any fallible append means a retry after an I/O
    /// error cannot write the same movelist twice.
    fn close_chain(&mut self) {
        if !self.is_first {
            self.write_movelist();
            self.is_first = true;
        }
    }

    /// Hands `packed_entries[..packed_size]` to the file writer and empties the buffer.
    fn append_packed(&mut self) -> io::Result<()> {
        let file = self.output_file.as_mut().ok_or_else(Self::consumed_error)?;
        file.append(&self.packed_entries[..self.packed_size])?;
        self.packed_size = 0;
        Ok(())
    }

    /// Appends the movelist to the buffer.
    ///
    /// The layout is `num_plies` as two bytes (high byte first), followed by the
    /// movetext when `num_plies > 0`.
    fn write_movelist(&mut self) {
        let num_plies = self.movelist.num_plies;
        Self::put_bytes(
            &mut self.packed_entries,
            &mut self.packed_size,
            &[(num_plies >> 8) as u8, num_plies as u8],
        );
        if num_plies > 0 {
            let movetext = self.movelist.movetext();
            Self::put_bytes(&mut self.packed_entries, &mut self.packed_size, movetext);
        }
    }

    /// Copies `bytes` into `buf` at offset `*pos` and advances `*pos` past them.
    ///
    /// `buf` is grown first if it is too short, so an unusually long movelist
    /// cannot cause an out-of-bounds panic.
    fn put_bytes(buf: &mut Vec<u8>, pos: &mut usize, bytes: &[u8]) {
        let end = *pos + bytes.len();
        if buf.len() < end {
            buf.resize(end, 0);
        }
        buf[*pos..end].copy_from_slice(bytes);
        *pos = end;
    }

    /// Sentinel `last_entry` used before the first entry and after the stream has
    /// been finalized.
    ///
    /// NOTE(review): this relies on the sentinel (ply 0xFFFF, result 0x7FFF)
    /// never being reported as a continuation of a real entry. TODO: confirm
    /// against `is_continuation`.
    fn initial_last_entry() -> TrainingDataEntry {
        TrainingDataEntry {
            ply: 0xFFFF,
            result: 0x7FFF,
            pos: Position::default(),
            mv: Move::default(),
            score: 0,
        }
    }

    /// Error returned when the output has already been taken by `into_inner`.
    fn consumed_error() -> io::Error {
        io::Error::new(
            io::ErrorKind::Other,
            "compressed writer output was already taken by into_inner",
        )
    }
}
impl<T> Drop for CompressedTrainingDataEntryWriter<T>
where
    T: Write,
{
    /// Writes out any still-buffered data when the writer goes out of scope.
    /// `drop` cannot return an error, so failures are printed to stderr.
    fn drop(&mut self) {
        match self.flush_packed() {
            Ok(()) => {}
            Err(err) => eprintln!("Error flushing writer: {}", err),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{fs, io::Cursor};

    use super::*;
    use crate::chess::{
        coords::Square,
        piece::Piece,
        position::Position,
        r#move::{Move, MoveType},
    };

    /// Reference binpack that the writer's output for `sample_entries` must match.
    const REFERENCE_BINPACK: &str = "test/ep1.binpack";

    /// Returns three consecutive positions from one game (plies 68-70).
    ///
    /// Each entry's move leads to the next entry's position.
    fn sample_entries() -> Vec<TrainingDataEntry> {
        vec![
            TrainingDataEntry {
                pos: Position::from_fen("1q5b/1r5k/4p2p/1b2P1pN/3p4/6PP/1nP3B1/1Q2B1K1 w - - 0 35")
                    .unwrap(),
                mv: Move::new(
                    Square::new(10),
                    Square::new(26),
                    MoveType::Normal,
                    Piece::none(),
                ),
                score: -201,
                ply: 68,
                result: 0,
            },
            TrainingDataEntry {
                pos: Position::from_fen("1q5b/1r5k/4p2p/1b2P1pN/2Pp4/6PP/1n4B1/1Q2B1K1 b - - 0 35")
                    .unwrap(),
                mv: Move::new(
                    Square::new(27),
                    Square::new(19),
                    MoveType::Normal,
                    Piece::none(),
                ),
                score: 254,
                ply: 69,
                result: 0,
            },
            TrainingDataEntry {
                pos: Position::from_fen(
                    "1q5b/1r5k/4p2p/1b2P1pN/2P5/3p2PP/1n4B1/1Q2B1K1 w - - 0 36",
                )
                .unwrap(),
                mv: Move::new(
                    Square::new(14),
                    Square::new(49),
                    MoveType::Normal,
                    Piece::none(),
                ),
                score: -220,
                ply: 70,
                result: 0,
            },
        ]
    }

    /// Writes `entries` into an in-memory buffer and returns every byte produced.
    fn write_to_memory(entries: &[TrainingDataEntry]) -> Vec<u8> {
        let mut writer = CompressedTrainingDataEntryWriter::new(Cursor::new(Vec::new())).unwrap();
        for entry in entries {
            writer.write_entry(entry).unwrap();
        }
        writer.flush_and_end();
        writer.into_inner().unwrap().into_inner()
    }

    #[test]
    fn test_compressed_writer() {
        let path = "test/ep_new1.binpack";
        let entries = sample_entries();
        {
            // `File::create` truncates, so bytes left over from an earlier run
            // cannot trail the freshly written data.
            let file = fs::File::create(path).unwrap();
            let mut writer = CompressedTrainingDataEntryWriter::new(file).unwrap();
            for entry in &entries {
                writer.write_entry(entry).unwrap();
            }
            // Dropping the writer here flushes the buffered data to the file.
        }
        let written = fs::read(path).unwrap();
        // Remove the scratch file before asserting so a failure does not leave it behind.
        let _ = fs::remove_file(path);
        assert_eq!(written, fs::read(REFERENCE_BINPACK).unwrap());
    }

    #[test]
    fn test_compressed_writer_in_memory_file() {
        let expected_bytes = fs::read(REFERENCE_BINPACK).unwrap();
        assert_eq!(write_to_memory(&sample_entries()), expected_bytes);
    }

    #[test]
    fn test_compressed_writer_big_score_diff() {
        // Same first two positions and moves as `sample_entries`, with scores far apart.
        let mut entries = sample_entries();
        entries.truncate(2);
        entries[0].score = -31999;
        entries[1].score = -1500;

        let expected_bytes = [
            66, 73, 78, 80, 37, 0, 0, 0, 130, 130, 144, 210, 8, 192, 70, 82, 72, 58, 64, 0, 81, 16,
            18, 113, 155, 5, 0, 0, 0, 0, 0, 0, 10, 104, 249, 253, 0, 68, 0, 0, 0, 1, 29, 83, 79,
        ];
        assert_eq!(write_to_memory(&entries), expected_bytes);
    }
}