use std::fs::File;
use std::io::{BufReader, Read, Seek, Write};
use failure::Error;
pub trait ReadEmbeddings
where
Self: Sized,
{
fn read_embeddings<R>(read: &mut R) -> Result<Self, Error>
where
R: Read + Seek;
}
pub trait ReadMetadata
where
Self: Sized,
{
fn read_metadata<R>(read: &mut R) -> Result<Self, Error>
where
R: Read + Seek;
}
/// Construct a value from embeddings stored in a file on disk.
///
/// Unlike the generic reader-based traits, this one is tied to a concrete
/// `BufReader<File>`. Going by the name, implementations are expected to
/// memory-map the data rather than copy it; confirm against the implementors.
pub trait MmapEmbeddings: Sized {
    /// Build `Self` from the file wrapped by `read`.
    ///
    /// Any failure is reported through the returned `Error`.
    fn mmap_embeddings(read: &mut BufReader<File>) -> Result<Self, Error>;
}
pub trait WriteEmbeddings {
fn write_embeddings<W>(&self, write: &mut W) -> Result<(), Error>
where
W: Write + Seek;
}
pub(crate) mod private {
use std::fs::File;
use std::io::{BufReader, Read, Seek, Write};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use failure::{ensure, format_err, Error, ResultExt};
/// Format version stored in the header; reading fails for any other value.
const MODEL_VERSION: u32 = 0;
/// Magic bytes (`FiFu`) written at the very start of the header and checked
/// first when a header is read.
const MAGIC: [u8; 4] = [b'F', b'i', b'F', b'u'];
/// Numeric tag naming a kind of chunk.
///
/// The discriminant of each variant is the `u32` value that is written to,
/// and read back from, the header's list of chunk identifiers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u32)]
pub enum ChunkIdentifier {
    Header = 0,
    SimpleVocab = 1,
    NdArray = 2,
    SubwordVocab = 3,
    QuantizedArray = 4,
    Metadata = 5,
}

impl ChunkIdentifier {
    /// Map a raw `u32` back to its chunk identifier.
    ///
    /// Returns `None` for any value without a mapping. Note that `0`
    /// (`Header`) is deliberately not mapped here, so it also yields `None`.
    pub fn try_from(identifier: u32) -> Option<Self> {
        let chunk = match identifier {
            1 => ChunkIdentifier::SimpleVocab,
            2 => ChunkIdentifier::NdArray,
            3 => ChunkIdentifier::SubwordVocab,
            4 => ChunkIdentifier::QuantizedArray,
            5 => ChunkIdentifier::Metadata,
            // Everything else, including 0, is rejected.
            _ => return None,
        };
        Some(chunk)
    }
}
/// Associates a fixed numeric identifier with a type.
///
/// NOTE(review): presumably this identifier tags the element type of stored
/// data; the usage is not visible in this part of the file, so confirm.
pub trait TypeId {
    /// The numeric identifier assigned to the implementing type.
    fn type_id() -> u32;
}

/// Implements [`TypeId`] for a type, returning the given identifier.
///
/// Usage: `typeid_impl!(SomeType, 42);`
macro_rules! typeid_impl {
    ($scalar:ty, $code:expr) => {
        impl TypeId for $scalar {
            fn type_id() -> u32 {
                $code
            }
        }
    };
}

typeid_impl!(f32, 10);
typeid_impl!(u8, 1);
pub trait ReadChunk
where
Self: Sized,
{
fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
where
R: Read + Seek;
}
/// Construct a chunk from data stored in a file on disk.
///
/// This trait is tied to a concrete `BufReader<File>`. Going by the name,
/// implementations are expected to memory-map the chunk; confirm against the
/// implementors.
pub trait MmapChunk: Sized {
    /// Build `Self` from the file wrapped by `read`.
    ///
    /// Any failure is reported through the returned `Error`.
    fn mmap_chunk(read: &mut BufReader<File>) -> Result<Self, Error>;
}
pub trait WriteChunk {
fn chunk_identifier(&self) -> ChunkIdentifier;
fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
where
W: Write + Seek;
}
/// File header: the ordered list of chunk identifiers it announces.
#[derive(Debug, Eq, PartialEq)]
pub(crate) struct Header {
    chunk_identifiers: Vec<ChunkIdentifier>,
}

impl Header {
    /// Create a header from anything convertible into a vector of chunk
    /// identifiers. The order of the identifiers is kept as given.
    pub fn new(chunk_identifiers: impl Into<Vec<ChunkIdentifier>>) -> Self {
        let chunk_identifiers = chunk_identifiers.into();
        Header { chunk_identifiers }
    }

    /// The chunk identifiers held by this header, in stored order.
    pub fn chunk_identifiers(&self) -> &[ChunkIdentifier] {
        self.chunk_identifiers.as_slice()
    }
}
impl WriteChunk for Header {
fn chunk_identifier(&self) -> ChunkIdentifier {
ChunkIdentifier::Header
}
fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
where
W: Write + Seek,
{
write.write_all(&MAGIC)?;
write.write_u32::<LittleEndian>(MODEL_VERSION)?;
write.write_u32::<LittleEndian>(self.chunk_identifiers.len() as u32)?;
for &identifier in &self.chunk_identifiers {
write.write_u32::<LittleEndian>(identifier as u32)?
}
Ok(())
}
}
impl ReadChunk for Header {
    /// Read and validate a header from `read`.
    ///
    /// Expects, in order: the magic bytes, the model version, the number of
    /// chunk identifiers, and then that many identifiers. All integers are
    /// little-endian `u32`.
    ///
    /// # Errors
    ///
    /// Fails when the magic bytes or the version do not match, when an
    /// identifier is unknown, or when the reader runs out of data or reports
    /// an I/O error.
    fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
    where
        R: Read + Seek,
    {
        // Upper bound on the number of identifier slots reserved up front.
        // The vector still grows past this if the file really has more.
        const MAX_PREALLOCATED_IDENTIFIERS: usize = 64;

        let mut magic = [0u8; 4];
        read.read_exact(&mut magic)?;
        ensure!(
            magic == MAGIC,
            "File does not have finalfusion magic, expected: {}, was: {}",
            String::from_utf8_lossy(&MAGIC),
            String::from_utf8_lossy(&magic)
        );

        let version = read.read_u32::<LittleEndian>()?;
        ensure!(
            version == MODEL_VERSION,
            "Unknown model version, expected: {}, was: {}",
            MODEL_VERSION,
            version
        );

        let chunk_identifiers_len = read.read_u32::<LittleEndian>()? as usize;

        // The count comes straight from the file. Reserving it verbatim would
        // let a corrupt or hostile header request gigabytes of memory before
        // a single identifier has been read, so the reservation is capped. A
        // bogus count then simply fails in the loop when the data runs out.
        let mut chunk_identifiers =
            Vec::with_capacity(chunk_identifiers_len.min(MAX_PREALLOCATED_IDENTIFIERS));

        for _ in 0..chunk_identifiers_len {
            let identifier = read
                .read_u32::<LittleEndian>()
                .with_context(|e| format!("Cannot read chunk identifier: {}", e))?;
            let chunk_identifier = ChunkIdentifier::try_from(identifier)
                .ok_or_else(|| format_err!("Unknown chunk identifier: {}", identifier))?;
            chunk_identifiers.push(chunk_identifier);
        }

        Ok(Header { chunk_identifiers })
    }
}
}
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Seek, SeekFrom};

    use crate::io::private::{ChunkIdentifier, Header, ReadChunk, WriteChunk};

    /// A header written to an in-memory buffer must read back equal to the
    /// header that was written.
    #[test]
    fn header_write_read_roundtrip() {
        let expected = Header::new(vec![ChunkIdentifier::SimpleVocab, ChunkIdentifier::NdArray]);

        let mut buffer = Cursor::new(Vec::new());
        expected.write_chunk(&mut buffer).unwrap();

        // Rewind so that reading starts at the magic bytes.
        buffer.seek(SeekFrom::Start(0)).unwrap();
        let actual = Header::read_chunk(&mut buffer).unwrap();

        assert_eq!(actual, expected);
    }
}