use std::fmt::{self, Display};
use std::fs::File;
use std::io::{BufReader, Read, Seek, Write};
use std::mem;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use crate::error::{Error, Result};
/// Version of the file format written by, and accepted by, this module.
const MODEL_VERSION: u32 = 0;
/// Magic bytes at the very start of every file: ASCII `FiFu`.
const MAGIC: [u8; 4] = *b"FiFu";
/// Identifies the kind of a chunk.
///
/// The explicit discriminants are the `u32` values stored in files; they are
/// written with `identifier as u32` in little-endian byte order, so they must
/// never be renumbered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u32)]
pub enum ChunkIdentifier {
    /// The file header chunk.
    Header = 0,
    SimpleVocab = 1,
    NdArray = 2,
    BucketSubwordVocab = 3,
    QuantizedArray = 4,
    Metadata = 5,
    NdNorms = 6,
    FastTextSubwordVocab = 7,
    ExplicitSubwordVocab = 8,
    FloretSubwordVocab = 9,
}
impl ChunkIdentifier {
pub fn ensure_chunk_type<R>(read: &mut R, identifier: ChunkIdentifier) -> Result<()>
where
R: Read,
{
let chunk_id = read
.read_u32::<LittleEndian>()
.map_err(|e| Error::read_error("Cannot read chunk identifier", e))?;
let chunk_id = ChunkIdentifier::try_from(chunk_id)?;
if chunk_id != identifier {
return Err(Error::Format(format!(
"Invalid chunk identifier, expected: {}, got: {}",
identifier, chunk_id
)));
}
Ok(())
}
}
impl Display for ChunkIdentifier {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use self::ChunkIdentifier::*;
match self {
Header => write!(f, "Header"),
SimpleVocab => write!(f, "SimpleVocab"),
NdArray => write!(f, "NdArray"),
FastTextSubwordVocab => write!(f, "FastTextSubwordVocab"),
ExplicitSubwordVocab => write!(f, "ExplicitSubwordVocab"),
FloretSubwordVocab => write!(f, "FloretSubwordVocab"),
BucketSubwordVocab => write!(f, "BucketSubwordVocab"),
QuantizedArray => write!(f, "QuantizedArray"),
Metadata => write!(f, "Metadata"),
NdNorms => write!(f, "NdNorms"),
}
}
}
impl TryFrom<u32> for ChunkIdentifier {
type Error = Error;
fn try_from(identifier: u32) -> Result<Self> {
use self::ChunkIdentifier::*;
match identifier {
1 => Ok(SimpleVocab),
2 => Ok(NdArray),
3 => Ok(BucketSubwordVocab),
4 => Ok(QuantizedArray),
5 => Ok(Metadata),
6 => Ok(NdNorms),
7 => Ok(FastTextSubwordVocab),
8 => Ok(ExplicitSubwordVocab),
9 => Ok(FloretSubwordVocab),
unknown => Err(Error::UnknownChunkIdentifier(unknown)),
}
}
}
/// Associates an element type with the `u32` type identifier stored in files.
pub trait TypeId {
    /// The type identifier of `Self`.
    fn type_id() -> u32;

    /// Reads a little-endian `u32` type identifier from `read` and checks
    /// that it matches [`TypeId::type_id`] for `Self`.
    fn ensure_data_type<R>(read: &mut R) -> Result<()>
    where
        R: Read;
}
macro_rules! typeid_impl {
($type:ty, $id:expr) => {
impl TypeId for $type {
fn ensure_data_type<R>(read: &mut R) -> Result<()>
where
R: Read,
{
let type_id = read
.read_u32::<LittleEndian>()
.map_err(|e| Error::read_error("Cannot read type identifier", e))?;
if type_id != Self::type_id() {
return Err(Error::Format(format!(
"Invalid type, expected: {}, got: {}",
Self::type_id(),
type_id
)));
}
Ok(())
}
fn type_id() -> u32 {
$id
}
}
};
}
typeid_impl!(f32, 10);
typeid_impl!(u8, 1);
/// Types that can be deserialized from a chunk in a seekable reader.
pub trait ReadChunk: Sized {
    /// Reads a chunk of this type from `read`.
    fn read_chunk<R>(read: &mut R) -> Result<Self>
    where
        R: Read + Seek;
}
/// Types that can be constructed from a chunk of a buffered file.
///
/// NOTE(review): the name suggests the chunk data is memory-mapped rather
/// than copied; that is up to each implementation and is not visible here.
pub trait MmapChunk: Sized {
    /// Constructs a chunk of this type from `read`.
    fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self>;
}
/// Types that can be serialized as a chunk.
pub trait WriteChunk {
    /// The identifier of the chunk type this value is written as.
    fn chunk_identifier(&self) -> ChunkIdentifier;

    /// The number of bytes `write_chunk` will produce.
    ///
    /// NOTE(review): `offset` is presumably the stream position at which the
    /// chunk starts, for chunks whose length depends on alignment padding;
    /// confirm against the implementors.
    fn chunk_len(&self, offset: u64) -> u64;

    /// Serializes the chunk to `write`.
    fn write_chunk<W>(&self, write: &mut W) -> Result<()>
    where
        W: Write + Seek;
}
/// The file header: magic, format version and the list of chunk identifiers.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct Header {
    /// Identifiers of the chunks listed in the header, in order.
    chunk_identifiers: Vec<ChunkIdentifier>,
}
impl Header {
pub fn new(chunk_identifiers: impl Into<Vec<ChunkIdentifier>>) -> Self {
Header {
chunk_identifiers: chunk_identifiers.into(),
}
}
pub fn chunk_identifiers(&self) -> &[ChunkIdentifier] {
&self.chunk_identifiers
}
}
impl WriteChunk for Header {
fn chunk_identifier(&self) -> ChunkIdentifier {
ChunkIdentifier::Header
}
fn chunk_len(&self, _offset: u64) -> u64 {
(MAGIC.len()
+ mem::size_of_val(&MODEL_VERSION)
+ mem::size_of::<u32>()
+ self.chunk_identifiers.len() * mem::size_of::<u32>()) as u64
}
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
where
W: Write + Seek,
{
write
.write_all(&MAGIC)
.map_err(|e| Error::write_error("Cannot write magic", e))?;
write
.write_u32::<LittleEndian>(MODEL_VERSION)
.map_err(|e| Error::write_error("Cannot write model version", e))?;
write
.write_u32::<LittleEndian>(self.chunk_identifiers.len() as u32)
.map_err(|e| Error::write_error("Cannot write chunk identifiers length", e))?;
for &identifier in &self.chunk_identifiers {
write
.write_u32::<LittleEndian>(identifier as u32)
.map_err(|e| Error::write_error("Cannot write chunk identifier", e))?;
}
Ok(())
}
}
impl ReadChunk for Header {
    /// Reads a header: the magic, the model version, the number of chunk
    /// identifiers and the identifiers themselves, each as a little-endian
    /// `u32`.
    ///
    /// # Errors
    ///
    /// Returns an error when the magic is not `FiFu`, when the version differs
    /// from the supported model version, when any field cannot be read (for
    /// example on truncated input), or when an identifier is not a known
    /// chunk identifier.
    fn read_chunk<R>(read: &mut R) -> Result<Self>
    where
        R: Read + Seek,
    {
        // Upper bound on how many identifiers are preallocated up front. The
        // identifier count comes from the file itself, so it must not size the
        // allocation directly: a corrupt or malicious header could otherwise
        // request gigabytes of memory before the first identifier read fails.
        // The vector still grows normally beyond this bound.
        const MAX_PREALLOC: usize = 64;

        let mut magic = [0u8; 4];
        read.read_exact(&mut magic)
            .map_err(|e| Error::read_error("Cannot read magic", e))?;
        if magic != MAGIC {
            return Err(Error::Format(format!(
                "Expected 'FiFu' as magic, got: {}",
                String::from_utf8_lossy(&magic)
            )));
        }

        let version = read
            .read_u32::<LittleEndian>()
            .map_err(|e| Error::read_error("Cannot read model version", e))?;
        if version != MODEL_VERSION {
            return Err(Error::Format(format!(
                "Unknown finalfusion version: {}",
                version
            )));
        }

        let chunk_identifiers_len = read
            .read_u32::<LittleEndian>()
            .map_err(|e| Error::read_error("Cannot read chunk identifiers length", e))?
            as usize;

        let mut chunk_identifiers =
            Vec::with_capacity(chunk_identifiers_len.min(MAX_PREALLOC));
        for _ in 0..chunk_identifiers_len {
            let identifier = read
                .read_u32::<LittleEndian>()
                .map_err(|e| Error::read_error("Cannot read chunk identifier", e))?;
            chunk_identifiers.push(ChunkIdentifier::try_from(identifier)?);
        }

        Ok(Header { chunk_identifiers })
    }
}
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Seek, SeekFrom};

    use byteorder::{LittleEndian, WriteBytesExt};

    use super::{ChunkIdentifier, Header, ReadChunk, WriteChunk, MAGIC, MODEL_VERSION};
    use crate::error::Error;

    /// Serializes a header by hand so that malformed headers can be built.
    fn header_bytes(magic: &[u8], version: u32, identifiers: &[u32]) -> Vec<u8> {
        let mut data = Vec::new();
        data.extend_from_slice(magic);
        data.write_u32::<LittleEndian>(version).unwrap();
        data.write_u32::<LittleEndian>(identifiers.len() as u32)
            .unwrap();
        for &identifier in identifiers {
            data.write_u32::<LittleEndian>(identifier).unwrap();
        }
        data
    }

    #[test]
    fn header_chunk_len_is_correct() {
        let check_header =
            Header::new(vec![ChunkIdentifier::SimpleVocab, ChunkIdentifier::NdArray]);
        let mut cursor = Cursor::new(Vec::new());
        check_header.write_chunk(&mut cursor).unwrap();
        let data = cursor.into_inner();
        assert_eq!(data.len() as u64, check_header.chunk_len(0));
    }

    #[test]
    fn header_write_read_roundtrip() {
        let check_header =
            Header::new(vec![ChunkIdentifier::SimpleVocab, ChunkIdentifier::NdArray]);
        let mut cursor = Cursor::new(Vec::new());
        check_header.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let header = Header::read_chunk(&mut cursor).unwrap();
        assert_eq!(header, check_header);
    }

    #[test]
    fn header_rejects_bad_magic() {
        let data = header_bytes(b"FuFi", MODEL_VERSION, &[1]);
        let result = Header::read_chunk(&mut Cursor::new(data));
        assert!(matches!(result, Err(Error::Format(_))));
    }

    #[test]
    fn header_rejects_unknown_version() {
        let data = header_bytes(&MAGIC, MODEL_VERSION + 1, &[1]);
        let result = Header::read_chunk(&mut Cursor::new(data));
        assert!(matches!(result, Err(Error::Format(_))));
    }

    #[test]
    fn header_rejects_unknown_chunk_identifier() {
        let data = header_bytes(&MAGIC, MODEL_VERSION, &[1, 42]);
        let result = Header::read_chunk(&mut Cursor::new(data));
        assert!(matches!(result, Err(Error::UnknownChunkIdentifier(42))));
    }

    #[test]
    fn header_rejects_truncated_identifiers() {
        let mut data = header_bytes(&MAGIC, MODEL_VERSION, &[1, 2, 3]);
        // Drop the last identifier while keeping the count at 3.
        data.truncate(data.len() - 4);
        assert!(Header::read_chunk(&mut Cursor::new(data)).is_err());
    }
}