use std::io::{Read, Seek, Write};
use std::mem;
use std::ops::{Deref, DerefMut};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use toml::Table;
use crate::chunks::io::{ChunkIdentifier, Header, ReadChunk, WriteChunk};
use crate::error::{Error, Result};
use crate::io::ReadMetadata;
/// Embedding file metadata.
///
/// Metadata is stored as an arbitrary TOML table; this newtype wraps
/// [`toml::Table`] and exposes it directly through `Deref`/`DerefMut`.
#[derive(Clone, Debug, PartialEq)]
pub struct Metadata {
    inner: Table,
}
impl Metadata {
pub fn new(inner: Table) -> Self {
Metadata { inner }
}
}
impl Deref for Metadata {
    type Target = Table;

    /// Borrow the wrapped TOML table.
    fn deref(&self) -> &Table {
        &self.inner
    }
}
impl DerefMut for Metadata {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl From<Table> for Metadata {
fn from(inner: Table) -> Self {
Metadata { inner }
}
}
impl ReadChunk for Metadata {
fn read_chunk<R>(read: &mut R) -> Result<Self>
where
R: Read + Seek,
{
ChunkIdentifier::ensure_chunk_type(read, ChunkIdentifier::Metadata)?;
let chunk_len = read
.read_u64::<LittleEndian>()
.map_err(|e| Error::read_error("Cannot read chunk length", e))?
as usize;
let mut buf = vec![0; chunk_len];
read.read_exact(&mut buf)
.map_err(|e| Error::read_error("Cannot read TOML metadata", e))?;
let buf_str = String::from_utf8(buf)
.map_err(|e| Error::Format(format!("TOML metadata contains invalid UTF-8: {}", e)))
.map_err(Error::from)?;
Ok(Metadata::new(
buf_str
.parse::<Table>()
.map_err(|e| Error::Format(format!("Cannot deserialize TOML metadata: {}", e)))
.map_err(Error::from)?,
))
}
}
impl WriteChunk for Metadata {
fn chunk_identifier(&self) -> ChunkIdentifier {
ChunkIdentifier::Metadata
}
fn chunk_len(&self, _offset: u64) -> u64 {
(mem::size_of::<u32>() + mem::size_of::<u64>() + self.to_string().len()) as u64
}
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
where
W: Write + Seek,
{
let metadata_str = self.to_string();
write
.write_u32::<LittleEndian>(self.chunk_identifier() as u32)
.map_err(|e| Error::write_error("Cannot write metadata chunk identifier", e))?;
write
.write_u64::<LittleEndian>(metadata_str.len() as u64)
.map_err(|e| Error::write_error("Cannot write metadata length", e))?;
write
.write_all(metadata_str.as_bytes())
.map_err(|e| Error::write_error("Cannot write metadata", e))?;
Ok(())
}
}
impl ReadMetadata for Option<Metadata> {
    /// Read optional metadata from the beginning of an embedding file.
    ///
    /// Reads the file header; when the first chunk is a metadata chunk,
    /// it is deserialized and returned, otherwise `None` is returned
    /// (the reader is then positioned at the first chunk).
    ///
    /// # Errors
    ///
    /// Returns a format error when the header lists no chunks at all.
    fn read_metadata<R>(read: &mut R) -> Result<Self>
    where
        R: Read + Seek,
    {
        let header = Header::read_chunk(read)?;
        let chunks = header.chunk_identifiers();
        if chunks.is_empty() {
            return Err(Error::Format(String::from(
                "Embedding file does not contain chunks",
            )));
        }

        // Reuse the binding above; the original re-invoked
        // `header.chunk_identifiers()` here for no reason.
        if chunks[0] == ChunkIdentifier::Metadata {
            Ok(Some(Metadata::read_chunk(read)?))
        } else {
            Ok(None)
        }
    }
}
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Read, Seek, SeekFrom};

    use byteorder::{LittleEndian, ReadBytesExt};
    use toml::toml;

    use super::Metadata;
    use crate::chunks::io::{ReadChunk, WriteChunk};

    /// Skip the `u32` chunk identifier and return the stored `u64`
    /// chunk length.
    fn read_chunk_size(read: &mut impl Read) -> u64 {
        read.read_u32::<LittleEndian>().unwrap();
        read.read_u64::<LittleEndian>().unwrap()
    }

    /// Fixed metadata table used by the tests below.
    fn test_metadata() -> Metadata {
        Metadata::new(toml! {
            [hyperparameters]
            dims = 300
            ns = 5
            [description]
            description = "Test model"
            language = "de"
        })
    }

    #[test]
    fn metadata_correct_chunk_size() {
        // The reported size must hold regardless of where the chunk
        // starts in the file.
        for start in 0..16u64 {
            let metadata = test_metadata();
            let mut cursor = Cursor::new(Vec::new());
            cursor.seek(SeekFrom::Start(start)).unwrap();
            metadata.write_chunk(&mut cursor).unwrap();

            // The stored length field must cover exactly the bytes that
            // follow it.
            cursor.seek(SeekFrom::Start(start)).unwrap();
            let stored_len = read_chunk_size(&mut cursor);
            let remaining = cursor.read_to_end(&mut Vec::new()).unwrap() as u64;
            assert_eq!(remaining, stored_len);

            // chunk_len() must predict the full on-disk footprint.
            let buf = cursor.into_inner();
            assert_eq!(buf.len() as u64 - start, metadata.chunk_len(start));
        }
    }

    #[test]
    fn metadata_write_read_roundtrip() {
        let written = test_metadata();
        let mut cursor = Cursor::new(Vec::new());
        written.write_chunk(&mut cursor).unwrap();

        // Reading the chunk back must reproduce the original table.
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let read_back = Metadata::read_chunk(&mut cursor).unwrap();
        assert_eq!(read_back, written);
    }
}