rten_model_file/
header.rs1use std::error::Error;
2use std::fmt::{Display, Formatter};
3
4use rten_base::num::LeBytes;
5
/// Cursor over a borrowed byte slice.
///
/// Tracks how many bytes have been consumed so far so that successive reads
/// return consecutive, non-overlapping parts of the slice.
struct ValueReader<'a> {
    /// The bytes being read from.
    buf: &'a [u8],
    /// Index into `buf` of the next unread byte.
    pos: usize,
}
11
12impl<'a> ValueReader<'a> {
13 fn new(buf: &'a [u8]) -> Self {
14 Self { pos: 0, buf }
15 }
16
17 fn read_n<const N: usize>(&mut self) -> Option<[u8; N]> {
19 let chunk = self.buf.get(self.pos..self.pos + N)?;
20 self.pos += N;
21 Some(chunk.try_into().unwrap())
22 }
23
24 fn read<T: LeBytes>(&mut self) -> Option<T> {
28 let chunk = self
29 .buf
30 .get(self.pos..self.pos + std::mem::size_of::<T>())?;
31 self.pos += chunk.len();
32
33 let chunk_array = chunk.try_into().unwrap();
34 Some(T::from_le_bytes(chunk_array))
35 }
36}
37
/// Reasons why a buffer could not be parsed as a valid [`Header`].
#[derive(Debug, Clone, PartialEq)]
pub enum HeaderError {
    /// The buffer ended before all header fields could be read.
    TooShort,

    /// The version field holds a value this code does not support.
    UnsupportedVersion,

    /// The buffer does not start with the expected magic bytes.
    InvalidMagic,

    /// A segment offset points inside the header or past the end of the
    /// buffer.
    InvalidOffset,

    /// A segment's offset plus its length extends past the end of the buffer.
    InvalidLength,
}
56
/// Parsed contents of the fixed-size header at the start of a model file.
///
/// All offsets and lengths are in bytes, with offsets measured from the
/// start of the buffer passed to `Header::from_buf`.
#[derive(Debug, Clone, PartialEq)]
pub struct Header {
    /// File format version.
    pub version: u32,

    /// Offset of the model segment.
    pub model_offset: u64,

    /// Length of the model segment.
    pub model_len: u64,

    /// Offset of the tensor data.
    pub tensor_data_offset: u64,
}
75
76impl Header {
77 pub const LEN: usize = 32;
79
80 pub fn from_buf(buf: &[u8]) -> Result<Header, HeaderError> {
85 let too_short = Err(HeaderError::TooShort);
86
87 let file_size = buf.len() as u64;
90
91 let mut reader = ValueReader::new(buf);
92
93 let Some(magic) = reader.read_n::<4>() else {
94 return too_short;
95 };
96 if &magic != b"RTEN" {
97 return Err(HeaderError::InvalidMagic);
98 }
99
100 let Some(version) = reader.read() else {
101 return too_short;
102 };
103 if version != 2 {
104 return Err(HeaderError::UnsupportedVersion);
105 }
106
107 let Some(model_offset) = reader.read::<u64>() else {
108 return too_short;
109 };
110 if model_offset < Self::LEN as u64 || model_offset > file_size {
111 return Err(HeaderError::InvalidOffset);
112 }
113 let Some(model_len) = reader.read() else {
114 return too_short;
115 };
116 if model_offset.saturating_add(model_len) > file_size {
117 return Err(HeaderError::InvalidLength);
118 }
119
120 let Some(tensor_data_offset) = reader.read() else {
121 return too_short;
122 };
123 if tensor_data_offset < Self::LEN as u64 || tensor_data_offset > file_size {
124 return Err(HeaderError::InvalidOffset);
125 }
126
127 Ok(Header {
128 version,
129 model_offset,
130 model_len,
131 tensor_data_offset,
132 })
133 }
134
135 pub fn to_buf(&self) -> Vec<u8> {
137 let mut buffer = Vec::new();
138
139 buffer.extend(b"RTEN");
140 buffer.extend(self.version.to_le_bytes());
141 buffer.extend(self.model_offset.to_le_bytes());
142 buffer.extend(self.model_len.to_le_bytes());
143 buffer.extend(self.tensor_data_offset.to_le_bytes());
144
145 buffer
146 }
147}
148
149impl Display for HeaderError {
150 fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result {
151 match self {
152 HeaderError::TooShort => write!(fmt, "header is too short"),
153 HeaderError::UnsupportedVersion => write!(fmt, "unsupported file version"),
154 HeaderError::InvalidMagic => write!(fmt, "incorrect file magic"),
155 HeaderError::InvalidOffset => write!(fmt, "segment offset is invalid"),
156 HeaderError::InvalidLength => write!(fmt, "segment length is invalid"),
157 }
158 }
159}
160
161impl Error for HeaderError {}
162
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{Header, HeaderError};

    /// A valid header followed by the bytes it refers to parses back to the
    /// same field values.
    #[test]
    fn test_read_header() {
        let expected_header = Header {
            version: 2,
            model_offset: 32,
            model_len: 32,
            tensor_data_offset: 64,
        };

        // Pad the serialized header with 32 zero bytes so that the model
        // segment and the tensor data offset fall within the buffer.
        let mut header_buf = expected_header.to_buf();
        header_buf.resize(header_buf.len() + 32, 0);

        let header = Header::from_buf(&header_buf).unwrap();

        assert_eq!(header, expected_header);
    }

    /// Malformed buffers are rejected with the expected error.
    #[test]
    fn test_invalid_header() {
        #[derive(Debug)]
        struct Case {
            buf: Vec<u8>,
            expected: HeaderError,
        }

        // Builds a case whose buffer contains only the serialized `header`.
        let header_case = |header: Header, expected: HeaderError| Case {
            buf: header.to_buf(),
            expected,
        };

        let cases = [
            // Empty buffer.
            Case {
                buf: Vec::new(),
                expected: HeaderError::TooShort,
            },
            // Buffer that does not start with the magic bytes.
            Case {
                buf: b"This is some random ASCII text and not a valid header".to_vec(),
                expected: HeaderError::InvalidMagic,
            },
            // Version other than the supported one.
            header_case(
                Header {
                    version: 10,
                    model_offset: 0,
                    model_len: 0,
                    tensor_data_offset: 0,
                },
                HeaderError::UnsupportedVersion,
            ),
            // Model offset that points inside the header.
            header_case(
                Header {
                    version: 2,
                    model_offset: 0,
                    model_len: 0,
                    tensor_data_offset: 0,
                },
                HeaderError::InvalidOffset,
            ),
            // Model offset beyond the end of the buffer.
            header_case(
                Header {
                    version: 2,
                    model_offset: 500,
                    model_len: 0,
                    tensor_data_offset: 500,
                },
                HeaderError::InvalidOffset,
            ),
            // Model segment that extends past the end of the buffer.
            header_case(
                Header {
                    version: 2,
                    model_offset: 32,
                    model_len: 1024,
                    tensor_data_offset: 0,
                },
                HeaderError::InvalidLength,
            ),
        ];

        cases.test_each(|Case { buf, expected }| {
            let result = Header::from_buf(buf);
            assert_eq!(result.as_ref(), Err(expected));
        })
    }
}