rust2vec/
metadata.rs

1//! Metadata
2
3use std::io::{Read, Seek, Write};
4
5use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
6use failure::{ensure, err_msg, Error};
7use toml::Value;
8
9use crate::io::{
10    private::{ChunkIdentifier, Header, ReadChunk, WriteChunk},
11    ReadMetadata,
12};
13
/// Embeddings metadata.
///
/// finalfusion metadata in TOML format.
///
/// This is a thin newtype around a `toml::Value`. The wrapped value is
/// public, so it can be read or replaced directly through the `.0` field.
#[derive(Clone, Debug, PartialEq)]
pub struct Metadata(pub Value);
19
20impl ReadChunk for Metadata {
21    fn read_chunk<R>(read: &mut R) -> Result<Self, Error>
22    where
23        R: Read + Seek,
24    {
25        let chunk_id = ChunkIdentifier::try_from(read.read_u32::<LittleEndian>()?)
26            .ok_or_else(|| err_msg("Unknown chunk identifier"))?;
27        ensure!(
28            chunk_id == ChunkIdentifier::Metadata,
29            "Cannot read chunk {:?} as Metadata",
30            chunk_id
31        );
32
33        // Read chunk length.
34        let chunk_len = read.read_u64::<LittleEndian>()? as usize;
35
36        // Read TOML data.
37        let mut buf = vec![0; chunk_len];
38        read.read_exact(&mut buf)?;
39        let buf_str = String::from_utf8(buf)?;
40
41        Ok(Metadata(buf_str.parse::<Value>()?))
42    }
43}
44
45impl WriteChunk for Metadata {
46    fn chunk_identifier(&self) -> ChunkIdentifier {
47        ChunkIdentifier::Metadata
48    }
49
50    fn write_chunk<W>(&self, write: &mut W) -> Result<(), Error>
51    where
52        W: Write + Seek,
53    {
54        let metadata_str = self.0.to_string();
55
56        write.write_u32::<LittleEndian>(self.chunk_identifier() as u32)?;
57        write.write_u64::<LittleEndian>(metadata_str.len() as u64)?;
58        write.write_all(metadata_str.as_bytes())?;
59
60        Ok(())
61    }
62}
63
64impl ReadMetadata for Option<Metadata> {
65    fn read_metadata<R>(read: &mut R) -> Result<Self, Error>
66    where
67        R: Read + Seek,
68    {
69        let header = Header::read_chunk(read)?;
70        let chunks = header.chunk_identifiers();
71        ensure!(!chunks.is_empty(), "Embedding file without chunks.");
72
73        if header.chunk_identifiers()[0] == ChunkIdentifier::Metadata {
74            Ok(Some(Metadata::read_chunk(read)?))
75        } else {
76            Ok(None)
77        }
78    }
79}
80
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Read, Seek, SeekFrom};

    use byteorder::{LittleEndian, ReadBytesExt};
    use toml::{toml, toml_internal};

    use super::Metadata;
    use crate::io::private::{ReadChunk, WriteChunk};

    /// Consumes the chunk identifier and returns the length field after it.
    fn read_chunk_size(read: &mut impl Read) -> u64 {
        let _identifier = read.read_u32::<LittleEndian>().unwrap();
        read.read_u64::<LittleEndian>().unwrap()
    }

    /// Metadata fixture shared by the tests in this module.
    fn test_metadata() -> Metadata {
        Metadata(toml! {
            [hyperparameters]
            dims = 300
            ns = 5

            [description]
            description = "Test model"
            language = "de"
        })
    }

    /// Writes `metadata` as a chunk into a fresh in-memory buffer and rewinds
    /// the cursor to the start so the chunk can be read back.
    fn written_chunk(metadata: &Metadata) -> Cursor<Vec<u8>> {
        let mut cursor = Cursor::new(Vec::new());
        metadata.write_chunk(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        cursor
    }

    #[test]
    fn metadata_correct_chunk_size() {
        let mut cursor = written_chunk(&test_metadata());
        let declared_len = read_chunk_size(&mut cursor);

        // Everything after the identifier and length field is payload.
        let mut payload = Vec::new();
        let payload_len = cursor.read_to_end(&mut payload).unwrap();

        assert_eq!(payload_len, declared_len as usize);
    }

    #[test]
    fn metadata_write_read_roundtrip() {
        let check_metadata = test_metadata();
        let mut cursor = written_chunk(&check_metadata);

        let metadata = Metadata::read_chunk(&mut cursor).unwrap();

        assert_eq!(metadata, check_metadata);
    }
}