use std::io;
use std::ops::Deref;
use std::path::Path;
use super::aligned_vec::AlignedVec;
use super::header::{Header, SectionOffsets};
use super::ids::{StringId, TypeId};
use super::instructions::{Call, Match, Opcode, Return, Trampoline};
use super::sections::{FieldSymbol, NodeSymbol, TriviaEntry};
use super::type_meta::{TypeData, TypeDef, TypeKind, TypeMember, TypeName};
use super::{Entrypoint, STEP_SIZE, VERSION};
/// Reads a little-endian `u16` from `bytes` starting at `offset`.
///
/// # Panics
/// Panics if `bytes` is shorter than `offset + 2`.
#[inline]
fn read_u16_le(bytes: &[u8], offset: usize) -> u16 {
    let low = u16::from(bytes[offset]);
    let high = u16::from(bytes[offset + 1]);
    (high << 8) | low
}
/// Reads a little-endian `u32` from `bytes` starting at `offset`.
///
/// # Panics
/// Panics if `bytes` is shorter than `offset + 4`.
#[inline]
fn read_u32_le(bytes: &[u8], offset: usize) -> u32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(&bytes[offset..offset + 4]);
    u32::from_le_bytes(word)
}
/// Bytes backing a [`Module`]: either a borrowed `'static` slice or an owned
/// [`AlignedVec`].
///
/// Dereferences to `[u8]`, so the storage can be sliced and measured directly.
pub enum ByteStorage {
    /// Borrowed `'static` bytes; `from_static` checks their 64-byte alignment.
    Static(&'static [u8]),
    /// Owned bytes held in an [`AlignedVec`].
    Aligned(AlignedVec),
}

impl Deref for ByteStorage {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        match *self {
            Self::Static(slice) => slice,
            Self::Aligned(ref buf) => buf,
        }
    }
}

impl std::fmt::Debug for ByteStorage {
    /// Prints the variant name and the byte length, never the bytes themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, len) = match self {
            Self::Static(slice) => ("Static", slice.len()),
            Self::Aligned(buf) => ("Aligned", buf.len()),
        };
        f.debug_tuple(variant).field(&len).finish()
    }
}

impl ByteStorage {
    /// Wraps `'static` bytes without copying.
    ///
    /// # Panics
    /// Panics if `bytes` does not start at a 64-byte aligned address.
    pub fn from_static(bytes: &'static [u8]) -> Self {
        let address = bytes.as_ptr() as usize;
        assert!(
            address.is_multiple_of(64),
            "static bytes must be 64-byte aligned; use include_query_aligned! macro"
        );
        Self::Static(bytes)
    }

    /// Takes ownership of an existing [`AlignedVec`].
    pub fn from_aligned(vec: AlignedVec) -> Self {
        Self::Aligned(vec)
    }

    /// Copies `bytes` into a new [`AlignedVec`].
    pub fn copy_from_slice(bytes: &[u8]) -> Self {
        let buf = AlignedVec::copy_from_slice(bytes);
        Self::Aligned(buf)
    }

    /// Reads the file at `path` into a new [`AlignedVec`].
    ///
    /// # Errors
    /// Propagates any error from `AlignedVec::from_file`.
    pub fn from_file(path: impl AsRef<Path>) -> io::Result<Self> {
        let buf = AlignedVec::from_file(path)?;
        Ok(Self::Aligned(buf))
    }
}
/// One decoded bytecode instruction, borrowing from the module bytes where needed.
#[derive(Clone, Copy, Debug)]
pub enum Instruction<'a> {
    /// Decoded from the whole remaining slice, not just the first 8 bytes.
    /// NOTE(review): presumably variable-length; confirm in `Match::from_bytes`.
    Match(Match<'a>),
    /// Decoded from exactly 8 bytes.
    Call(Call),
    /// Decoded from exactly 8 bytes.
    Return(Return),
    /// Decoded from exactly 8 bytes.
    Trampoline(Trampoline),
}

impl<'a> Instruction<'a> {
    /// Decodes the instruction that starts at `bytes[0]`.
    ///
    /// The opcode is the low 4 bits of the first byte.
    /// - `Call`, `Return` and `Trampoline` are built from the first 8 bytes.
    /// - Every other opcode is treated as a `Match` and receives the full slice.
    ///
    /// # Panics
    /// - In debug builds, panics if `bytes` is shorter than 8 bytes.
    /// - In release builds, a short slice panics on indexing: `bytes[0]`
    ///   always, and `bytes[..8]` for the fixed-size variants.
    #[inline]
    pub fn from_bytes(bytes: &'a [u8]) -> Self {
        debug_assert!(bytes.len() >= 8, "instruction too short");
        // Copy out the fixed 8-byte encoding lazily: the `Match` arm never needs it.
        let head = || -> [u8; 8] { bytes[..8].try_into().unwrap() };
        match Opcode::from_u8(bytes[0] & 0xF) {
            Opcode::Call => Self::Call(Call::from_bytes(head())),
            Opcode::Return => Self::Return(Return::from_bytes(head())),
            Opcode::Trampoline => Self::Trampoline(Trampoline::from_bytes(head())),
            _ => Self::Match(Match::from_bytes(bytes)),
        }
    }
}
/// Reasons a [`Module`] can fail to load.
#[derive(Debug, thiserror::Error)]
pub enum ModuleError {
    /// The header's magic check failed.
    #[error("invalid magic: expected PTKQ")]
    InvalidMagic,

    /// The header declares a version this build does not accept; carries that version.
    #[error("unsupported version: {0} (expected {VERSION})")]
    UnsupportedVersion(u32),

    /// The input is shorter than the fixed 64-byte header; carries the actual length.
    #[error("file too small: {0} bytes (minimum 64)")]
    FileTooSmall(usize),

    /// The header's recorded total size disagrees with the bytes supplied.
    #[error("size mismatch: header says {header} bytes, got {actual}")]
    SizeMismatch {
        /// Total size recorded in the header.
        header: u32,
        /// Number of bytes actually supplied.
        actual: usize,
    },

    /// An underlying I/O failure, e.g. from reading a module file.
    #[error("io error: {0}")]
    Io(#[from] io::Error),
}
/// A loaded compiled query module.
///
/// Holds the raw bytes ([`ByteStorage`]), the parsed [`Header`], and the
/// [`SectionOffsets`] computed from that header. Every accessor returns a
/// lightweight view that borrows from the stored bytes.
///
/// NOTE(review): construction checks only the minimum length, magic, version
/// and `total_size`. Section offsets and counts are never bounds-checked, so a
/// corrupt module shows up as a slice-index panic in an accessor, not as a
/// [`ModuleError`].
#[derive(Debug)]
pub struct Module {
    storage: ByteStorage,
    header: Header,
    offsets: SectionOffsets,
}

impl Module {
    /// Length of the fixed header at the start of every module.
    const HEADER_LEN: usize = 64;

    /// Loads a module from an owned [`AlignedVec`] without copying.
    ///
    /// # Errors
    /// Returns a [`ModuleError`] if header validation fails.
    pub fn from_aligned(vec: AlignedVec) -> Result<Self, ModuleError> {
        let storage = ByteStorage::from_aligned(vec);
        Self::from_storage(storage)
    }

    /// Loads a module directly from `'static` bytes without copying.
    ///
    /// # Panics
    /// Panics if `bytes` is not 64-byte aligned.
    ///
    /// # Errors
    /// Returns a [`ModuleError`] if header validation fails.
    pub fn from_static(bytes: &'static [u8]) -> Result<Self, ModuleError> {
        let storage = ByteStorage::from_static(bytes);
        Self::from_storage(storage)
    }

    /// Reads the file at `path` into aligned storage and loads it.
    ///
    /// # Errors
    /// Returns [`ModuleError::Io`] if the file cannot be read. Otherwise returns
    /// any header-validation error.
    pub fn from_path(path: impl AsRef<Path>) -> Result<Self, ModuleError> {
        let storage = ByteStorage::from_file(path.as_ref())?;
        Self::from_storage(storage)
    }

    /// Copies `bytes` into new aligned storage and loads the copy.
    ///
    /// # Errors
    /// Returns a [`ModuleError`] if header validation fails.
    pub fn load(bytes: &[u8]) -> Result<Self, ModuleError> {
        let storage = ByteStorage::copy_from_slice(bytes);
        Self::from_storage(storage)
    }

    /// Legacy constructor; copies `bytes` exactly like `Module::load`.
    #[deprecated(
        since = "0.1.0",
        note = "use `Module::from_aligned` for AlignedVec or `Module::load` for copying"
    )]
    pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, ModuleError> {
        Self::load(bytes.as_slice())
    }

    /// Validation path shared by every constructor.
    ///
    /// Checks, in order:
    /// 1. the storage holds at least the header;
    /// 2. the magic is valid;
    /// 3. the version is supported;
    /// 4. the header's `total_size` equals the storage length.
    fn from_storage(storage: ByteStorage) -> Result<Self, ModuleError> {
        let actual = storage.len();
        if actual < Self::HEADER_LEN {
            return Err(ModuleError::FileTooSmall(actual));
        }

        let header = Header::from_bytes(&storage[..Self::HEADER_LEN]);
        if !header.validate_magic() {
            return Err(ModuleError::InvalidMagic);
        }
        if !header.validate_version() {
            return Err(ModuleError::UnsupportedVersion(header.version));
        }
        if header.total_size as usize != actual {
            return Err(ModuleError::SizeMismatch {
                header: header.total_size,
                actual,
            });
        }

        let offsets = header.compute_offsets();
        Ok(Self {
            storage,
            header,
            offsets,
        })
    }

    /// The parsed header.
    pub fn header(&self) -> &Header {
        &self.header
    }

    /// Section offsets computed from the header.
    pub fn offsets(&self) -> &SectionOffsets {
        &self.offsets
    }

    /// The whole underlying byte buffer, header included.
    pub fn bytes(&self) -> &[u8] {
        &self.storage
    }

    /// Decodes the instruction at transition slot `step`.
    ///
    /// Slots are `STEP_SIZE` bytes wide and counted from the start of the
    /// transitions section.
    ///
    /// # Panics
    /// Panics if the slot starts past the end of the storage.
    #[inline]
    pub fn decode_step(&self, step: u16) -> Instruction<'_> {
        let start = self.offsets.transitions as usize + usize::from(step) * STEP_SIZE;
        Instruction::from_bytes(&self.storage[start..])
    }

    /// View over the interned strings: the offset table plus the string blob.
    pub fn strings(&self) -> StringsView<'_> {
        let blob_start = self.offsets.str_blob as usize;
        StringsView {
            blob: &self.storage[blob_start..],
            table: self.string_table_slice(),
        }
    }

    /// View over the node-type symbol table (4 bytes per entry).
    pub fn node_types(&self) -> SymbolsView<'_, NodeSymbol> {
        let count = self.header.node_types_count as usize;
        SymbolsView {
            bytes: self.section(self.offsets.node_types as usize, count * 4),
            count,
            _marker: std::marker::PhantomData,
        }
    }

    /// View over the field symbol table (4 bytes per entry).
    pub fn node_fields(&self) -> SymbolsView<'_, FieldSymbol> {
        let count = self.header.node_fields_count as usize;
        SymbolsView {
            bytes: self.section(self.offsets.node_fields as usize, count * 4),
            count,
            _marker: std::marker::PhantomData,
        }
    }

    /// View over the trivia table (2 bytes per entry).
    pub fn trivia(&self) -> TriviaView<'_> {
        let count = self.header.trivia_count as usize;
        TriviaView {
            bytes: self.section(self.offsets.trivia as usize, count * 2),
            count,
        }
    }

    /// View over the regex table plus the regex blob.
    pub fn regexes(&self) -> RegexView<'_> {
        let blob_start = self.offsets.regex_blob as usize;
        RegexView {
            blob: &self.storage[blob_start..],
            table: self.regex_table_slice(),
        }
    }

    /// View over the type definitions, members and names (4 bytes per entry each).
    pub fn types(&self) -> TypesView<'_> {
        let header = &self.header;
        let defs_count = header.type_defs_count as usize;
        let members_count = header.type_members_count as usize;
        let names_count = header.type_names_count as usize;
        TypesView {
            defs_bytes: self.section(self.offsets.type_defs as usize, defs_count * 4),
            members_bytes: self.section(self.offsets.type_members as usize, members_count * 4),
            names_bytes: self.section(self.offsets.type_names as usize, names_count * 4),
            defs_count,
            members_count,
            names_count,
        }
    }

    /// View over the entrypoint table (8 bytes per entry).
    pub fn entrypoints(&self) -> EntrypointsView<'_> {
        let count = self.header.entrypoints_count as usize;
        EntrypointsView {
            bytes: self.section(self.offsets.entrypoints as usize, count * 8),
            count,
        }
    }

    /// Returns `len` bytes of storage starting at `start`.
    ///
    /// Panics if the range does not fit inside the storage.
    fn section(&self, start: usize, len: usize) -> &[u8] {
        &self.storage[start..start + len]
    }

    /// The string offset table: `str_table_count + 1` 4-byte entries.
    ///
    /// The extra entry supplies the end offset of the last string.
    fn string_table_slice(&self) -> &[u8] {
        let entries = self.header.str_table_count as usize + 1;
        self.section(self.offsets.str_table as usize, entries * 4)
    }

    /// The regex table: `regex_table_count + 1` 8-byte entries.
    ///
    /// The extra entry supplies the end offset of the last regex.
    fn regex_table_slice(&self) -> &[u8] {
        let entries = self.header.regex_table_count as usize + 1;
        self.section(self.offsets.regex_table as usize, entries * 8)
    }
}
/// Borrowed view over the string table.
///
/// `table` is a run of little-endian `u32` offsets into `blob`. String `i`
/// occupies `blob[table[i]..table[i + 1]]`.
pub struct StringsView<'a> {
    blob: &'a [u8],
    table: &'a [u8],
}

impl<'a> StringsView<'a> {
    /// Resolves `id` to its string.
    ///
    /// # Panics
    /// Same conditions as [`StringsView::get_by_index`].
    pub fn get(&self, id: StringId) -> &'a str {
        let idx = id.get() as usize;
        self.get_by_index(idx)
    }

    /// Returns the string at position `idx` in the table.
    ///
    /// # Panics
    /// Panics in either of these cases:
    /// - `idx` or `idx + 1` falls outside the offset table, or the offsets do
    ///   not form a valid range of the blob;
    /// - the selected bytes are not valid UTF-8.
    pub fn get_by_index(&self, idx: usize) -> &'a str {
        let (start, end) = (
            read_u32_le(self.table, idx * 4) as usize,
            read_u32_le(self.table, (idx + 1) * 4) as usize,
        );
        let raw = &self.blob[start..end];
        std::str::from_utf8(raw).expect("invalid UTF-8 in string table")
    }
}
pub struct SymbolsView<'a, T> {
bytes: &'a [u8],
count: usize,
_marker: std::marker::PhantomData<T>,
}
impl<'a> SymbolsView<'a, NodeSymbol> {
pub fn get(&self, idx: usize) -> NodeSymbol {
assert!(idx < self.count, "node symbol index out of bounds");
let offset = idx * 4;
NodeSymbol::new(
read_u16_le(self.bytes, offset),
StringId::new(read_u16_le(self.bytes, offset + 2)),
)
}
pub fn len(&self) -> usize {
self.count
}
pub fn is_empty(&self) -> bool {
self.count == 0
}
}
impl<'a> SymbolsView<'a, FieldSymbol> {
pub fn get(&self, idx: usize) -> FieldSymbol {
assert!(idx < self.count, "field symbol index out of bounds");
let offset = idx * 4;
FieldSymbol::new(
read_u16_le(self.bytes, offset),
StringId::new(read_u16_le(self.bytes, offset + 2)),
)
}
pub fn len(&self) -> usize {
self.count
}
pub fn is_empty(&self) -> bool {
self.count == 0
}
}
/// Borrowed view over the trivia table: `count` little-endian `u16` entries.
pub struct TriviaView<'a> {
    bytes: &'a [u8],
    count: usize,
}

impl<'a> TriviaView<'a> {
    /// Returns the trivia entry at `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.len()`.
    pub fn get(&self, idx: usize) -> TriviaEntry {
        assert!(idx < self.count, "trivia index out of bounds");
        let raw = read_u16_le(self.bytes, idx * 2);
        TriviaEntry::new(raw)
    }

    /// Number of trivia entries.
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` if there are no trivia entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` if any entry's `node_type` equals `node_type`.
    ///
    /// This is a linear scan over all entries.
    pub fn contains(&self, node_type: u16) -> bool {
        for idx in 0..self.count {
            if self.get(idx).node_type == node_type {
                return true;
            }
        }
        false
    }
}
/// Borrowed view over the regex table and the regex blob.
///
/// `table` is a run of 8-byte entries followed by one trailing sentinel entry.
/// - Bytes `0..2` of an entry hold a little-endian string id.
/// - Bytes `4..8` hold a little-endian `u32` start offset into `blob`.
///
/// Regex `i` spans from entry `i`'s offset to entry `i + 1`'s offset. The
/// sentinel exists only to supply the end offset of the last regex.
pub struct RegexView<'a> {
    blob: &'a [u8],
    table: &'a [u8],
}

impl<'a> RegexView<'a> {
    const ENTRY_SIZE: usize = 8;

    /// Number of real regex entries, not counting the trailing sentinel.
    pub fn len(&self) -> usize {
        (self.table.len() / Self::ENTRY_SIZE).saturating_sub(1)
    }

    /// Returns `true` if the table holds no regexes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the serialized bytes of regex `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.len()`, or if the stored offsets do not form a
    /// valid range of the blob.
    pub fn get_by_index(&self, idx: usize) -> &'a [u8] {
        self.check_index(idx);
        let start = self.blob_offset(idx);
        let end = self.blob_offset(idx + 1);
        &self.blob[start..end]
    }

    /// Returns the string id stored in regex entry `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.len()`. Previously, `idx == len()` silently
    /// returned the sentinel's id field.
    pub fn get_string_id(&self, idx: usize) -> super::StringId {
        self.check_index(idx);
        let string_id = read_u16_le(self.table, idx * Self::ENTRY_SIZE);
        super::StringId::new(string_id)
    }

    /// Rejects indices that would hit the sentinel entry or go past it.
    fn check_index(&self, idx: usize) {
        assert!(idx < self.len(), "regex index out of bounds");
    }

    /// Blob start offset stored in table entry `entry`; the sentinel is allowed here.
    fn blob_offset(&self, entry: usize) -> usize {
        read_u32_le(self.table, entry * Self::ENTRY_SIZE + 4) as usize
    }
}
/// Borrowed view over the type-metadata sections: definitions, members and names.
///
/// Each section is a packed array of 4-byte entries with a matching count.
pub struct TypesView<'a> {
    defs_bytes: &'a [u8],
    members_bytes: &'a [u8],
    names_bytes: &'a [u8],
    defs_count: usize,
    members_count: usize,
    names_count: usize,
}

impl<'a> TypesView<'a> {
    /// Returns the type definition at `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.defs_count()`.
    pub fn get_def(&self, idx: usize) -> TypeDef {
        assert!(idx < self.defs_count, "type def index out of bounds");
        let entry = &self.defs_bytes[idx * 4..];
        TypeDef::from_bytes(entry)
    }

    /// Looks up the definition for `id`, returning `None` if it is out of range.
    pub fn get(&self, id: TypeId) -> Option<TypeDef> {
        let idx = id.0 as usize;
        (idx < self.defs_count).then(|| self.get_def(idx))
    }

    /// Returns the type member at `idx`: a `(name, type)` pair of `u16`s.
    ///
    /// # Panics
    /// Panics if `idx >= self.members_count()`.
    pub fn get_member(&self, idx: usize) -> TypeMember {
        assert!(idx < self.members_count, "type member index out of bounds");
        let base = idx * 4;
        let name = StringId::new(read_u16_le(self.members_bytes, base));
        let ty = TypeId(read_u16_le(self.members_bytes, base + 2));
        TypeMember::new(name, ty)
    }

    /// Returns the type name entry at `idx`: a `(name, type)` pair of `u16`s.
    ///
    /// # Panics
    /// Panics if `idx >= self.names_count()`.
    pub fn get_name(&self, idx: usize) -> TypeName {
        assert!(idx < self.names_count, "type name index out of bounds");
        let base = idx * 4;
        let name = StringId::new(read_u16_le(self.names_bytes, base));
        let ty = TypeId(read_u16_le(self.names_bytes, base + 2));
        TypeName::new(name, ty)
    }

    /// Number of type definitions.
    pub fn defs_count(&self) -> usize {
        self.defs_count
    }

    /// Number of type members.
    pub fn members_count(&self) -> usize {
        self.members_count
    }

    /// Number of type names.
    pub fn names_count(&self) -> usize {
        self.names_count
    }

    /// Iterates the members of a composite `def`.
    ///
    /// Yields nothing for any other kind of definition.
    pub fn members_of(&self, def: &TypeDef) -> impl Iterator<Item = TypeMember> + '_ {
        let range = match def.classify() {
            TypeData::Composite {
                member_start,
                member_count,
                ..
            } => {
                let first = member_start as usize;
                first..first + member_count as usize
            }
            _ => 0..0,
        };
        range.map(move |i| self.get_member(i))
    }

    /// Strips one optional wrapper.
    ///
    /// - If `type_id` names an `Optional` wrapper, returns `(inner, true)`.
    /// - Otherwise (including unknown ids) returns `(type_id, false)`.
    pub fn unwrap_optional(&self, type_id: TypeId) -> (TypeId, bool) {
        if let Some(type_def) = self.get(type_id) {
            if let TypeData::Wrapper {
                kind: TypeKind::Optional,
                inner,
            } = type_def.classify()
            {
                return (inner, true);
            }
        }
        (type_id, false)
    }
}
/// Borrowed view over the entrypoint table: `count` entries of 8 bytes each.
pub struct EntrypointsView<'a> {
    bytes: &'a [u8],
    count: usize,
}

impl<'a> EntrypointsView<'a> {
    /// Returns the entrypoint at `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.len()`.
    pub fn get(&self, idx: usize) -> Entrypoint {
        assert!(idx < self.count, "entrypoint index out of bounds");
        let entry = &self.bytes[idx * 8..];
        Entrypoint::from_bytes(entry)
    }

    /// Number of entrypoints.
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` if there are no entrypoints.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the first entrypoint whose name, resolved through `strings`,
    /// equals `name`.
    pub fn find_by_name(&self, name: &str, strings: &StringsView<'_>) -> Option<Entrypoint> {
        for idx in 0..self.count {
            let candidate = self.get(idx);
            if strings.get(candidate.name()) == name {
                return Some(candidate);
            }
        }
        None
    }
}