rten_model_file/
header.rs

1use std::error::Error;
2use std::fmt::{Display, Formatter};
3
4use rten_base::num::LeBytes;
5
/// Cursor over a byte slice from which little-endian encoded primitive
/// values are decoded one after another.
struct ValueReader<'a> {
    /// The bytes being decoded.
    buf: &'a [u8],
    /// Index into `buf` of the next byte that has not been consumed yet.
    pos: usize,
}
11
12impl<'a> ValueReader<'a> {
13    fn new(buf: &'a [u8]) -> Self {
14        Self { pos: 0, buf }
15    }
16
17    /// Return the next N bytes from the buffer, or None if there aren't enough.
18    fn read_n<const N: usize>(&mut self) -> Option<[u8; N]> {
19        let chunk = self.buf.get(self.pos..self.pos + N)?;
20        self.pos += N;
21        Some(chunk.try_into().unwrap())
22    }
23
24    /// Read a little-endian encoded value.
25    ///
26    /// Returns None if there are not enough bytes left in the buffer.
27    fn read<T: LeBytes>(&mut self) -> Option<T> {
28        let chunk = self
29            .buf
30            .get(self.pos..self.pos + std::mem::size_of::<T>())?;
31        self.pos += chunk.len();
32
33        let chunk_array = chunk.try_into().unwrap();
34        Some(T::from_le_bytes(chunk_array))
35    }
36}
37
/// Errors produced when reading the header for an RTen model file.
///
/// Equality between variants is total, so the type implements `Eq` in
/// addition to `PartialEq`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum HeaderError {
    /// The buffer ends before the header is complete.
    TooShort,

    /// The file format version specified in the header is unsupported.
    UnsupportedVersion,

    /// The header doesn't start with the magic bytes "RTEN".
    InvalidMagic,

    /// A segment offset in the header is invalid: it points inside the
    /// header itself or past the end of the file.
    InvalidOffset,

    /// A segment length in the header is invalid: the segment would extend
    /// past the end of the file.
    InvalidLength,
}
56
/// Header for an RTen model file.
///
/// This specifies the file version and offset of the model data and tensor
/// data within the file.
///
/// All fields are plain integers, so equality is total and the type
/// implements `Eq` in addition to `PartialEq`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Header {
    /// Major version of the file format. Currently 2.
    pub version: u32,

    /// Offset, in bytes, of the FlatBuffers data describing the model.
    pub model_offset: u64,

    /// Length, in bytes, of the FlatBuffers data describing the model.
    pub model_len: u64,

    /// Offset, in bytes, of tensor data stored outside the model.
    pub tensor_data_offset: u64,
}
75
76impl Header {
77    /// Size of the serialized header in bytes.
78    pub const LEN: usize = 32;
79
80    /// Read the file header from a byte buffer.
81    ///
82    /// `buf` is expected to be a slice that contains the entire file, as its
83    /// length is used to validate offsets in the header.
84    pub fn from_buf(buf: &[u8]) -> Result<Header, HeaderError> {
85        let too_short = Err(HeaderError::TooShort);
86
87        // This could be passed in separately if we wanted to avoid needing to
88        // read or mmap the entire file just to read the header.
89        let file_size = buf.len() as u64;
90
91        let mut reader = ValueReader::new(buf);
92
93        let Some(magic) = reader.read_n::<4>() else {
94            return too_short;
95        };
96        if &magic != b"RTEN" {
97            return Err(HeaderError::InvalidMagic);
98        }
99
100        let Some(version) = reader.read() else {
101            return too_short;
102        };
103        if version != 2 {
104            return Err(HeaderError::UnsupportedVersion);
105        }
106
107        let Some(model_offset) = reader.read::<u64>() else {
108            return too_short;
109        };
110        if model_offset < Self::LEN as u64 || model_offset > file_size {
111            return Err(HeaderError::InvalidOffset);
112        }
113        let Some(model_len) = reader.read() else {
114            return too_short;
115        };
116        if model_offset.saturating_add(model_len) > file_size {
117            return Err(HeaderError::InvalidLength);
118        }
119
120        let Some(tensor_data_offset) = reader.read() else {
121            return too_short;
122        };
123        if tensor_data_offset < Self::LEN as u64 || tensor_data_offset > file_size {
124            return Err(HeaderError::InvalidOffset);
125        }
126
127        Ok(Header {
128            version,
129            model_offset,
130            model_len,
131            tensor_data_offset,
132        })
133    }
134
135    /// Serialize this header to a byte buffer.
136    pub fn to_buf(&self) -> Vec<u8> {
137        let mut buffer = Vec::new();
138
139        buffer.extend(b"RTEN");
140        buffer.extend(self.version.to_le_bytes());
141        buffer.extend(self.model_offset.to_le_bytes());
142        buffer.extend(self.model_len.to_le_bytes());
143        buffer.extend(self.tensor_data_offset.to_le_bytes());
144
145        buffer
146    }
147}
148
149impl Display for HeaderError {
150    fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result {
151        match self {
152            HeaderError::TooShort => write!(fmt, "header is too short"),
153            HeaderError::UnsupportedVersion => write!(fmt, "unsupported file version"),
154            HeaderError::InvalidMagic => write!(fmt, "incorrect file magic"),
155            HeaderError::InvalidOffset => write!(fmt, "segment offset is invalid"),
156            HeaderError::InvalidLength => write!(fmt, "segment length is invalid"),
157        }
158    }
159}
160
161impl Error for HeaderError {}
162
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{Header, HeaderError};

    /// Serialize a header with the given field values, without validating
    /// them.
    fn serialized(
        version: u32,
        model_offset: u64,
        model_len: u64,
        tensor_data_offset: u64,
    ) -> Vec<u8> {
        Header {
            version,
            model_offset,
            model_len,
            tensor_data_offset,
        }
        .to_buf()
    }

    /// A header written with `to_buf` is read back unchanged by `from_buf`.
    #[test]
    fn test_read_header() {
        // nb. Offsets must be >= header size and <= length of buffer, so the
        // serialized header is followed by 32 bytes of padding below.
        let expected_header = Header {
            version: 2,
            model_offset: 32,
            model_len: 32,
            tensor_data_offset: 64,
        };

        let mut header_buf = expected_header.to_buf();
        let padded_len = header_buf.len() + 32;
        header_buf.resize(padded_len, 0);

        let header = Header::from_buf(&header_buf).unwrap();
        assert_eq!(header, expected_header);
    }

    /// Each malformed buffer is rejected with the expected error.
    #[test]
    fn test_invalid_header() {
        #[derive(Debug)]
        struct Case {
            buf: Vec<u8>,
            expected: HeaderError,
        }

        let cases = [
            // Empty buffer.
            Case {
                buf: Vec::new(),
                expected: HeaderError::TooShort,
            },
            // Buffer that does not start with the magic bytes.
            Case {
                buf: b"This is some random ASCII text and not a valid header".to_vec(),
                expected: HeaderError::InvalidMagic,
            },
            // Unknown file format version.
            Case {
                buf: serialized(10, 0, 0, 0),
                expected: HeaderError::UnsupportedVersion,
            },
            // Offsets too small.
            Case {
                buf: serialized(2, 0, 0, 0),
                expected: HeaderError::InvalidOffset,
            },
            // Offsets exceed buffer size.
            Case {
                buf: serialized(2, 500, 0, 500),
                expected: HeaderError::InvalidOffset,
            },
            // Offset + length exceeds buffer size
            Case {
                buf: serialized(2, 32, 1024, 0),
                expected: HeaderError::InvalidLength,
            },
        ];

        cases.test_each(|Case { buf, expected }| {
            let result = Header::from_buf(buf);
            assert_eq!(result.as_ref(), Err(expected));
        })
    }
}