use crate::{
flavor::ImperatorFlavor,
models::{HeaderBorrowed, HeaderOwned, Save},
tokens::TokenLookup,
FailedResolveStrategy, ImperatorError, ImperatorErrorKind,
};
use jomini::{BinaryDeserializer, TextDeserializer, TextTape};
use serde::de::{Deserialize, DeserializeOwned};
use std::io::{Cursor, Read, Seek, SeekFrom};
use zip::{result::ZipError, ZipArchive};
/// Describes the format the extractor found the data in.
///
/// Besides `PartialEq`, the type is `Eq` and `Hash` so it can be used as a
/// key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Encoding {
    /// Reported when the data was deserialized as plaintext
    Debug,

    /// Reported when the header was detected as binary, or when the save
    /// was detected as a zip archive
    Standard,
}
/// The memory allocation strategy used when inflating a zip entry.
///
/// Comparable (`PartialEq`/`Eq`) like its sibling [`Encoding`], so callers
/// can inspect which strategy they configured.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Extraction {
    /// Inflate the zip entry into a heap allocated buffer
    InMemory,

    /// Inflate the zip entry into an anonymous memory map
    #[cfg(feature = "mmap")]
    MmapTemporaries,
}
/// Holds the options that customize how a header or save is extracted.
#[derive(Debug, Clone)]
pub struct ImperatorExtractorBuilder {
/// Strategy used to inflate the zip entry when a save is a zip archive
extraction: Extraction,
/// Forwarded to the binary deserializer's `on_failed_resolve` setting
on_failed_resolve: FailedResolveStrategy,
}
impl Default for ImperatorExtractorBuilder {
fn default() -> Self {
ImperatorExtractorBuilder::new()
}
}
impl ImperatorExtractorBuilder {
    /// Create a builder that inflates zip entries in memory and ignores
    /// tokens that fail to resolve.
    pub fn new() -> Self {
        Self {
            extraction: Extraction::InMemory,
            on_failed_resolve: FailedResolveStrategy::Ignore,
        }
    }

    /// Set the memory allocation strategy used when inflating a zip entry.
    pub fn with_extraction(self, extraction: Extraction) -> Self {
        Self { extraction, ..self }
    }

    /// Set the behavior handed to the binary deserializer for tokens that
    /// fail to resolve.
    pub fn with_on_failed_resolve(self, strategy: FailedResolveStrategy) -> Self {
        Self {
            on_failed_resolve: strategy,
            ..self
        }
    }

    /// Extract the header of the save as an owned [`HeaderOwned`], along
    /// with the encoding that was detected.
    pub fn extract_header_owned(
        &self,
        data: &[u8],
    ) -> Result<(HeaderOwned, Encoding), ImperatorError> {
        self.extract_header_as::<HeaderOwned>(data)
    }

    /// Extract the header of the save as a [`HeaderBorrowed`] that borrows
    /// from `data`, along with the encoding that was detected.
    pub fn extract_header_borrowed<'a>(
        &self,
        data: &'a [u8],
    ) -> Result<(HeaderBorrowed<'a>, Encoding), ImperatorError> {
        self.extract_header_as::<HeaderBorrowed<'a>>(data)
    }

    /// Extract the header of the save as a caller chosen type.
    ///
    /// The first line of `data` is skipped. When the remainder contains a
    /// zip archive, only the bytes in front of the archive are deserialized;
    /// otherwise everything after the first line is. The result is paired
    /// with [`Encoding::Standard`] when the header is sniffed as binary and
    /// with [`Encoding::Debug`] when it is deserialized as text.
    pub fn extract_header_as<'de, T>(
        &self,
        data: &'de [u8],
    ) -> Result<(T, Encoding), ImperatorError>
    where
        T: Deserialize<'de>,
    {
        let body = skip_save_prefix(data);
        let mut cursor = Cursor::new(body);
        let header_end = match detect_encoding(&mut cursor)? {
            BodyEncoding::Zip(archive) => archive.offset() as usize,
            BodyEncoding::Plain => body.len(),
        };
        let header = &body[..header_end];

        if !sniff_is_binary(header) {
            let parsed = TextDeserializer::from_utf8_slice(header)?;
            return Ok((parsed, Encoding::Debug));
        }

        let parsed = BinaryDeserializer::builder_flavor(ImperatorFlavor)
            .on_failed_resolve(self.on_failed_resolve)
            .from_slice(header, &TokenLookup)?;
        Ok((parsed, Encoding::Standard))
    }

    /// Extract the whole save (header and gamestate) from a seekable reader,
    /// along with the encoding that was detected.
    pub fn extract_save<R>(&self, reader: R) -> Result<(Save, Encoding), ImperatorError>
    where
        R: Read + Seek,
    {
        self.extract_save_as(reader)
    }

    /// Implementation behind [`ImperatorExtractorBuilder::extract_save`].
    ///
    /// A reader holding a zip archive has its `gamestate` entry inflated
    /// with the configured [`Extraction`] strategy and deserialized as
    /// binary. Any other reader is read fully and deserialized as text.
    fn extract_save_as<R>(&self, mut reader: R) -> Result<(Save, Encoding), ImperatorError>
    where
        R: Read + Seek,
    {
        let mut buffer = Vec::new();
        match detect_encoding(&mut reader)? {
            BodyEncoding::Zip(mut archive) => {
                let strategy = self.on_failed_resolve;
                let save: Save = match self.extraction {
                    Extraction::InMemory => {
                        melt_in_memory(&mut buffer, "gamestate", &mut archive, strategy)
                    }
                    #[cfg(feature = "mmap")]
                    Extraction::MmapTemporaries => {
                        melt_with_temporary("gamestate", &mut archive, strategy)
                    }
                }?;
                Ok((save, Encoding::Standard))
            }
            BodyEncoding::Plain => {
                // The reader was just handed to the zip probe, so rewind
                // before measuring the stream and reading it in one go.
                reader.seek(SeekFrom::Start(0))?;
                let stream_len = reader.seek(SeekFrom::End(0))?;
                reader.seek(SeekFrom::Start(0))?;
                buffer.reserve(stream_len as usize);
                reader.read_to_end(&mut buffer)?;

                // Header and gamestate are both deserialized from one tape.
                let tape = TextTape::from_slice(skip_save_prefix(&buffer))?;
                let header = TextDeserializer::from_utf8_tape(&tape)?;
                let gamestate = TextDeserializer::from_utf8_tape(&tape)?;
                Ok((Save { header, gamestate }, Encoding::Debug))
            }
        }
    }
}
/// Namespace for extracting headers and saves with the default builder
/// settings; see `ImperatorExtractor::builder` for customizing them.
#[derive(Debug, Clone)]
pub struct ImperatorExtractor {}
impl ImperatorExtractor {
    /// Create a builder for customizing how a header or save is extracted.
    pub fn builder() -> ImperatorExtractorBuilder {
        ImperatorExtractorBuilder::new()
    }

    /// Extract an owned header from `data` using the default settings.
    pub fn extract_header(data: &[u8]) -> Result<(HeaderOwned, Encoding), ImperatorError> {
        let extractor = ImperatorExtractorBuilder::new();
        extractor.extract_header_owned(data)
    }

    /// Extract a whole save from `reader` using the default settings.
    pub fn extract_save<R>(reader: R) -> Result<(Save, Encoding), ImperatorError>
    where
        R: Read + Seek,
    {
        let extractor = ImperatorExtractorBuilder::new();
        extractor.extract_save(reader)
    }
}
/// Inflate the zip entry `name` into `buffer` and deserialize the binary
/// data it contains.
///
/// `buffer` is cleared first, so any previous contents are discarded.
///
/// # Errors
///
/// - `ZipMissingEntry` when the archive has no entry called `name`
/// - `ZipSize` when the entry declares, or actually inflates to, more than
///   200 MiB
/// - `ZipExtraction` when inflating the entry fails
/// - `Deserialize` when the inflated bytes cannot be deserialized
fn melt_in_memory<T, R>(
    buffer: &mut Vec<u8>,
    name: &'static str,
    zip: &mut zip::ZipArchive<R>,
    on_failed_resolve: FailedResolveStrategy,
) -> Result<T, ImperatorError>
where
    R: Read + Seek,
    T: DeserializeOwned,
{
    // Largest inflated entry (200 MiB) that will be accepted.
    const MAX_INFLATED_SIZE: u64 = 1024 * 1024 * 200;

    buffer.clear();
    let mut zip_file = zip
        .by_name(name)
        .map_err(|e| ImperatorErrorKind::ZipMissingEntry(name, e))?;

    // Protect against excessively large uncompressed data.
    let declared_size = zip_file.size();
    if declared_size > MAX_INFLATED_SIZE {
        return Err(ImperatorErrorKind::ZipSize(name).into());
    }

    // Lossless cast: `declared_size` was just bounded by the limit above.
    buffer.reserve(declared_size as usize);

    // The declared size is metadata supplied by whoever produced the
    // archive, so also cap the bytes actually inflated. One byte past the
    // limit is allowed through purely so that an overflow can be detected.
    let inflated = zip_file
        .by_ref()
        .take(MAX_INFLATED_SIZE + 1)
        .read_to_end(buffer)
        .map_err(|e| ImperatorErrorKind::ZipExtraction(name, e))?;
    if inflated as u64 > MAX_INFLATED_SIZE {
        return Err(ImperatorErrorKind::ZipSize(name).into());
    }

    let res = BinaryDeserializer::builder_flavor(ImperatorFlavor)
        .on_failed_resolve(on_failed_resolve)
        .from_slice(buffer, &TokenLookup)
        .map_err(|e| ImperatorErrorKind::Deserialize {
            part: Some(name.to_string()),
            err: e,
        })?;

    Ok(res)
}
/// Inflate the zip entry `name` into an anonymous memory map and
/// deserialize the binary data it contains.
///
/// # Errors
///
/// - `ZipMissingEntry` when the archive has no entry called `name`
/// - `ZipSize` when the entry declares more than 200 MiB of inflated data
/// - an I/O error when the anonymous memory map cannot be created
/// - `ZipExtraction` when inflating the entry fails, which includes the
///   entry inflating to more bytes than the map was sized for
/// - `Deserialize` when the inflated bytes cannot be deserialized
#[cfg(feature = "mmap")]
fn melt_with_temporary<T, R>(
    name: &'static str,
    zip: &mut zip::ZipArchive<R>,
    on_failed_resolve: FailedResolveStrategy,
) -> Result<T, ImperatorError>
where
    R: Read + Seek,
    T: DeserializeOwned,
{
    // Largest inflated entry (200 MiB) that will be accepted.
    const MAX_INFLATED_SIZE: u64 = 1024 * 1024 * 200;

    let mut zip_file = zip
        .by_name(name)
        .map_err(|e| ImperatorErrorKind::ZipMissingEntry(name, e))?;

    // Protect against excessively large uncompressed data.
    let declared_size = zip_file.size();
    if declared_size > MAX_INFLATED_SIZE {
        return Err(ImperatorErrorKind::ZipSize(name).into());
    }

    // Lossless cast: `declared_size` was just bounded by the limit above.
    let mut mmap = memmap::MmapMut::map_anon(declared_size as usize)?;
    let inflated = std::io::copy(&mut zip_file, &mut mmap.as_mut())
        .map_err(|e| ImperatorErrorKind::ZipExtraction(name, e))?;

    // Deserialize only the bytes that were actually written. If the entry
    // inflates to less than it declared, the rest of the map is padding
    // that is not part of the save. `inflated` cannot exceed the length of
    // the map that was copied into, so the cast is lossless.
    let buffer = &mmap[..inflated as usize];
    let res = BinaryDeserializer::builder_flavor(ImperatorFlavor)
        .on_failed_resolve(on_failed_resolve)
        .from_slice(buffer, &TokenLookup)
        .map_err(|e| ImperatorErrorKind::Deserialize {
            part: Some(name.to_string()),
            err: e,
        })?;

    Ok(res)
}
/// Return `data` with its first line removed.
///
/// Everything up to and including the first `\n` is dropped. Input without
/// any newline is returned unchanged.
fn skip_save_prefix(data: &[u8]) -> &[u8] {
    match data.iter().position(|&byte| byte == b'\n') {
        Some(newline) => &data[newline + 1..],
        None => data,
    }
}
/// Report whether `data` should be treated as binary.
///
/// True exactly when the bytes at index 2 and 3 are `0x01 0x00`; input
/// shorter than four bytes is never considered binary.
fn sniff_is_binary(data: &[u8]) -> bool {
    matches!(data.get(2..4), Some([0x01, 0x00]))
}
/// Result of probing a reader for a zip archive (see `detect_encoding`).
pub(crate) enum BodyEncoding<'a, R>
where
R: Read + Seek,
{
/// The reader holds a zip archive; carries the opened archive, which
/// mutably borrows the reader for `'a`
Zip(ZipArchive<&'a mut R>),
/// The reader was not recognized as a zip archive
Plain,
}
pub(crate) fn detect_encoding<R>(reader: &mut R) -> Result<BodyEncoding<R>, ImperatorError>
where
R: Read + Seek,
{
let zip_attempt = zip::ZipArchive::new(reader);
match zip_attempt {
Ok(x) => Ok(BodyEncoding::Zip(x)),
Err(ZipError::InvalidArchive(_)) => Ok(BodyEncoding::Plain),
Err(e) => Err(ImperatorError::new(
ImperatorErrorKind::ZipCentralDirectory(e),
)),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The first line, newline included, must be stripped from the input.
    #[test]
    fn test_skip_save_prefix() {
        let remainder = skip_save_prefix(b"abc\n123");
        assert_eq!(remainder, b"123");
    }
}