1use std::io;
7use std::ops::Deref;
8use std::path::Path;
9
10use super::header::Header;
11use super::ids::{StringId, TypeId};
12use super::instructions::{Call, Match, Opcode, Return, Trampoline};
13use super::sections::{FieldSymbol, NodeSymbol, TriviaEntry};
14use super::type_meta::{TypeData, TypeDef, TypeKind, TypeMember, TypeMetaHeader, TypeName};
15use super::{Entrypoint, SECTION_ALIGN, STEP_SIZE, VERSION};
16
/// Reads a little-endian `u16` starting at `offset`.
///
/// Panics (index out of bounds) if `bytes` has fewer than `offset + 2` bytes.
#[inline]
fn read_u16_le(bytes: &[u8], offset: usize) -> u16 {
    let low = u16::from(bytes[offset]);
    let high = u16::from(bytes[offset + 1]);
    (high << 8) | low
}
22
/// Reads a little-endian `u32` starting at `offset`.
///
/// Panics (index out of bounds) if `bytes` has fewer than `offset + 4` bytes.
#[inline]
fn read_u32_le(bytes: &[u8], offset: usize) -> u32 {
    // `from_fn` fills the array in index order, so bytes are read
    // offset, offset+1, offset+2, offset+3 — the same order as before.
    let quad: [u8; 4] = std::array::from_fn(|i| bytes[offset + i]);
    u32::from_le_bytes(quad)
}
33
/// Owned backing buffer for a loaded module.
///
/// Dereferences to `[u8]`, so callers can index, slice and query its length
/// directly.
#[derive(Debug)]
pub struct ByteStorage(Vec<u8>);

impl Deref for ByteStorage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl ByteStorage {
    /// Takes ownership of an existing byte vector without copying it.
    pub fn from_vec(bytes: Vec<u8>) -> Self {
        ByteStorage(bytes)
    }

    /// Reads the entire file at `path` into memory.
    ///
    /// # Errors
    /// Propagates any I/O error returned by [`std::fs::read`].
    pub fn from_file(path: impl AsRef<Path>) -> io::Result<Self> {
        std::fs::read(path).map(ByteStorage)
    }
}
58
59#[derive(Clone, Copy, Debug)]
61pub enum Instruction<'a> {
62 Match(Match<'a>),
63 Call(Call),
64 Return(Return),
65 Trampoline(Trampoline),
66}
67
68impl<'a> Instruction<'a> {
69 #[inline]
71 pub fn from_bytes(bytes: &'a [u8]) -> Self {
72 debug_assert!(bytes.len() >= 8, "instruction too short");
73
74 let opcode = Opcode::from_u8(bytes[0] & 0xF);
75 match opcode {
76 Opcode::Call => {
77 let arr: [u8; 8] = bytes[..8].try_into().unwrap();
78 Self::Call(Call::from_bytes(arr))
79 }
80 Opcode::Return => {
81 let arr: [u8; 8] = bytes[..8].try_into().unwrap();
82 Self::Return(Return::from_bytes(arr))
83 }
84 Opcode::Trampoline => {
85 let arr: [u8; 8] = bytes[..8].try_into().unwrap();
86 Self::Trampoline(Trampoline::from_bytes(arr))
87 }
88 _ => Self::Match(Match::from_bytes(bytes)),
89 }
90 }
91}
92
/// Errors returned while loading or validating a [`Module`].
#[derive(Debug, thiserror::Error)]
pub enum ModuleError {
    /// The header's magic check (`Header::validate_magic`) failed.
    #[error("invalid magic: expected PTKQ")]
    InvalidMagic,
    /// The header's version check failed; carries the version read from the file.
    #[error("unsupported version: {0} (expected {VERSION})")]
    UnsupportedVersion(u32),
    /// The input is shorter than the 64-byte header; carries the actual length.
    #[error("file too small: {0} bytes (minimum 64)")]
    FileTooSmall(usize),
    /// The header's declared total size differs from the number of bytes supplied.
    #[error("size mismatch: header says {header} bytes, got {actual}")]
    SizeMismatch { header: u32, actual: usize },
    /// Reading the module from disk failed (converted automatically via `?`).
    #[error("io error: {0}")]
    Io(#[from] io::Error),
}
107
108#[derive(Debug)]
113pub struct Module {
114 storage: ByteStorage,
115 header: Header,
116}
117
118impl Module {
119 pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, ModuleError> {
121 Self::from_storage(ByteStorage::from_vec(bytes))
122 }
123
124 pub fn from_path(path: impl AsRef<Path>) -> Result<Self, ModuleError> {
126 let storage = ByteStorage::from_file(&path)?;
127 Self::from_storage(storage)
128 }
129
130 fn from_storage(storage: ByteStorage) -> Result<Self, ModuleError> {
132 if storage.len() < 64 {
133 return Err(ModuleError::FileTooSmall(storage.len()));
134 }
135
136 let header = Header::from_bytes(&storage[..64]);
137
138 if !header.validate_magic() {
139 return Err(ModuleError::InvalidMagic);
140 }
141 if !header.validate_version() {
142 return Err(ModuleError::UnsupportedVersion(header.version));
143 }
144 if header.total_size as usize != storage.len() {
145 return Err(ModuleError::SizeMismatch {
146 header: header.total_size,
147 actual: storage.len(),
148 });
149 }
150
151 Ok(Self { storage, header })
152 }
153
154 pub fn header(&self) -> &Header {
156 &self.header
157 }
158
159 pub fn bytes(&self) -> &[u8] {
161 &self.storage
162 }
163
164 #[inline]
166 pub fn decode_step(&self, step: u16) -> Instruction<'_> {
167 let offset = self.header.transitions_offset as usize + (step as usize) * STEP_SIZE;
168 Instruction::from_bytes(&self.storage[offset..])
169 }
170
171 pub fn strings(&self) -> StringsView<'_> {
173 StringsView {
174 blob: &self.storage[self.header.str_blob_offset as usize..],
175 table: self.string_table_slice(),
176 }
177 }
178
179 pub fn node_types(&self) -> SymbolsView<'_, NodeSymbol> {
181 let offset = self.header.node_types_offset as usize;
182 let count = self.header.node_types_count as usize;
183 SymbolsView {
184 bytes: &self.storage[offset..offset + count * 4],
185 count,
186 _marker: std::marker::PhantomData,
187 }
188 }
189
190 pub fn node_fields(&self) -> SymbolsView<'_, FieldSymbol> {
192 let offset = self.header.node_fields_offset as usize;
193 let count = self.header.node_fields_count as usize;
194 SymbolsView {
195 bytes: &self.storage[offset..offset + count * 4],
196 count,
197 _marker: std::marker::PhantomData,
198 }
199 }
200
201 pub fn trivia(&self) -> TriviaView<'_> {
203 let offset = self.header.trivia_offset as usize;
204 let count = self.header.trivia_count as usize;
205 TriviaView {
206 bytes: &self.storage[offset..offset + count * 2],
207 count,
208 }
209 }
210
211 pub fn types(&self) -> TypesView<'_> {
213 let meta_offset = self.header.type_meta_offset as usize;
214 let meta_header = TypeMetaHeader::from_bytes(&self.storage[meta_offset..]);
215
216 let defs_offset = align64(meta_offset + 8);
218 let defs_count = meta_header.type_defs_count as usize;
219 let members_offset = align64(defs_offset + defs_count * 4);
220 let members_count = meta_header.type_members_count as usize;
221 let names_offset = align64(members_offset + members_count * 4);
222 let names_count = meta_header.type_names_count as usize;
223
224 TypesView {
225 defs_bytes: &self.storage[defs_offset..defs_offset + defs_count * 4],
226 members_bytes: &self.storage[members_offset..members_offset + members_count * 4],
227 names_bytes: &self.storage[names_offset..names_offset + names_count * 4],
228 defs_count,
229 members_count,
230 names_count,
231 }
232 }
233
234 pub fn entrypoints(&self) -> EntrypointsView<'_> {
236 let offset = self.header.entrypoints_offset as usize;
237 let count = self.header.entrypoints_count as usize;
238 EntrypointsView {
239 bytes: &self.storage[offset..offset + count * 8],
240 count,
241 }
242 }
243
244 fn string_table_slice(&self) -> &[u8] {
247 let offset = self.header.str_table_offset as usize;
248 let count = self.header.str_table_count as usize;
249 &self.storage[offset..offset + (count + 1) * 4]
250 }
251}
252
253fn align64(offset: usize) -> usize {
255 let rem = offset % SECTION_ALIGN;
256 if rem == 0 {
257 offset
258 } else {
259 offset + SECTION_ALIGN - rem
260 }
261}
262
263pub struct StringsView<'a> {
265 blob: &'a [u8],
266 table: &'a [u8],
267}
268
269impl<'a> StringsView<'a> {
270 pub fn get(&self, id: StringId) -> &'a str {
272 self.get_by_index(id.get() as usize)
273 }
274
275 pub fn get_by_index(&self, idx: usize) -> &'a str {
280 let start = read_u32_le(self.table, idx * 4) as usize;
281 let end = read_u32_le(self.table, (idx + 1) * 4) as usize;
282 std::str::from_utf8(&self.blob[start..end]).expect("invalid UTF-8 in string table")
283 }
284}
285
286pub struct SymbolsView<'a, T> {
288 bytes: &'a [u8],
289 count: usize,
290 _marker: std::marker::PhantomData<T>,
291}
292
293impl<'a> SymbolsView<'a, NodeSymbol> {
294 pub fn get(&self, idx: usize) -> NodeSymbol {
296 assert!(idx < self.count, "node symbol index out of bounds");
297 let offset = idx * 4;
298 NodeSymbol::new(
299 read_u16_le(self.bytes, offset),
300 StringId::new(read_u16_le(self.bytes, offset + 2)),
301 )
302 }
303
304 pub fn len(&self) -> usize {
306 self.count
307 }
308
309 pub fn is_empty(&self) -> bool {
311 self.count == 0
312 }
313}
314
315impl<'a> SymbolsView<'a, FieldSymbol> {
316 pub fn get(&self, idx: usize) -> FieldSymbol {
318 assert!(idx < self.count, "field symbol index out of bounds");
319 let offset = idx * 4;
320 FieldSymbol::new(
321 read_u16_le(self.bytes, offset),
322 StringId::new(read_u16_le(self.bytes, offset + 2)),
323 )
324 }
325
326 pub fn len(&self) -> usize {
328 self.count
329 }
330
331 pub fn is_empty(&self) -> bool {
333 self.count == 0
334 }
335}
336
337pub struct TriviaView<'a> {
339 bytes: &'a [u8],
340 count: usize,
341}
342
343impl<'a> TriviaView<'a> {
344 pub fn get(&self, idx: usize) -> TriviaEntry {
346 assert!(idx < self.count, "trivia index out of bounds");
347 TriviaEntry::new(read_u16_le(self.bytes, idx * 2))
348 }
349
350 pub fn len(&self) -> usize {
352 self.count
353 }
354
355 pub fn is_empty(&self) -> bool {
357 self.count == 0
358 }
359
360 pub fn contains(&self, node_type: u16) -> bool {
362 (0..self.count).any(|i| self.get(i).node_type() == node_type)
363 }
364}
365
366pub struct TypesView<'a> {
373 defs_bytes: &'a [u8],
374 members_bytes: &'a [u8],
375 names_bytes: &'a [u8],
376 defs_count: usize,
377 members_count: usize,
378 names_count: usize,
379}
380
381impl<'a> TypesView<'a> {
382 pub fn get_def(&self, idx: usize) -> TypeDef {
384 assert!(idx < self.defs_count, "type def index out of bounds");
385 let offset = idx * 4;
386 TypeDef::from_bytes(&self.defs_bytes[offset..])
387 }
388
389 pub fn get(&self, id: TypeId) -> Option<TypeDef> {
391 let idx = id.0 as usize;
392 if idx < self.defs_count {
393 Some(self.get_def(idx))
394 } else {
395 None
396 }
397 }
398
399 pub fn get_member(&self, idx: usize) -> TypeMember {
401 assert!(idx < self.members_count, "type member index out of bounds");
402 let offset = idx * 4;
403 TypeMember::new(
404 StringId::new(read_u16_le(self.members_bytes, offset)),
405 TypeId(read_u16_le(self.members_bytes, offset + 2)),
406 )
407 }
408
409 pub fn get_name(&self, idx: usize) -> TypeName {
411 assert!(idx < self.names_count, "type name index out of bounds");
412 let offset = idx * 4;
413 TypeName::new(
414 StringId::new(read_u16_le(self.names_bytes, offset)),
415 TypeId(read_u16_le(self.names_bytes, offset + 2)),
416 )
417 }
418
419 pub fn defs_count(&self) -> usize {
421 self.defs_count
422 }
423
424 pub fn members_count(&self) -> usize {
426 self.members_count
427 }
428
429 pub fn names_count(&self) -> usize {
431 self.names_count
432 }
433
434 pub fn members_of(&self, def: &TypeDef) -> impl Iterator<Item = TypeMember> + '_ {
436 let (start, count) = match def.classify() {
437 TypeData::Composite {
438 member_start,
439 member_count,
440 ..
441 } => (member_start as usize, member_count as usize),
442 _ => (0, 0),
443 };
444 (0..count).map(move |i| self.get_member(start + i))
445 }
446
447 pub fn unwrap_optional(&self, type_id: TypeId) -> (TypeId, bool) {
450 let Some(type_def) = self.get(type_id) else {
451 return (type_id, false);
452 };
453 match type_def.classify() {
454 TypeData::Wrapper {
455 kind: TypeKind::Optional,
456 inner,
457 } => (inner, true),
458 _ => (type_id, false),
459 }
460 }
461}
462
463pub struct EntrypointsView<'a> {
465 bytes: &'a [u8],
466 count: usize,
467}
468
469impl<'a> EntrypointsView<'a> {
470 pub fn get(&self, idx: usize) -> Entrypoint {
472 assert!(idx < self.count, "entrypoint index out of bounds");
473 let offset = idx * 8;
474 Entrypoint::from_bytes(&self.bytes[offset..])
475 }
476
477 pub fn len(&self) -> usize {
479 self.count
480 }
481
482 pub fn is_empty(&self) -> bool {
484 self.count == 0
485 }
486
487 pub fn find_by_name(&self, name: &str, strings: &StringsView<'_>) -> Option<Entrypoint> {
489 (0..self.count)
490 .map(|i| self.get(i))
491 .find(|e| strings.get(e.name()) == name)
492 }
493}