luac_parser/
lib.rs

1#![feature(box_patterns)]
2
3use bincode::{Decode, Encode};
4use std::{borrow::Cow, rc::Rc};
5
6#[allow(unused_imports)]
7use nom::{
8    branch::alt,
9    bytes::complete::{escaped, tag, take_till, take_until, take_while, take_while_m_n},
10    character::{
11        complete::{alphanumeric1, char as cchar, multispace0, multispace1, none_of, one_of},
12        is_alphabetic, is_newline, is_space,
13        streaming::space1,
14    },
15    combinator::{fail, map, map_res, opt, value},
16    number::complete::be_u8,
17    sequence::{delimited, tuple},
18};
19use nom::{
20    bytes::complete::take,
21    combinator::success,
22    error::{context, ErrorKind, ParseError},
23    multi::{length_count, length_data},
24    number::{complete, Endianness},
25    Parser,
26};
27use nom_supreme::{error::*, ParserExt};
28use serde::{Deserialize, Serialize};
29use serde_bytes::ByteBuf;
30
31pub mod lua51;
32pub mod lua52;
33pub mod lua53;
34pub mod lua54;
35pub mod luajit;
36pub mod luau;
37pub mod utils;
38
39pub type IResult<I, O, E = ErrorTree<I>> = Result<(I, O), nom::Err<E>>;
40
41#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize, Encode, Decode)]
42pub struct LuaHeader {
43    pub lua_version: u8,
44    pub format_version: u8,
45    pub big_endian: bool,
46    pub int_size: u8,
47    pub size_t_size: u8,
48    pub instruction_size: u8,
49    pub number_size: u8,
50    pub number_integral: bool,
51    // for luajit
52    pub lj_flags: u8,
53}
54
55impl LuaHeader {
56    pub fn endian(&self) -> Endianness {
57        if self.big_endian {
58            Endianness::Big
59        } else {
60            Endianness::Little
61        }
62    }
63
64    pub fn version(&self) -> LuaVersion {
65        LuaVersion(self.lua_version)
66    }
67
68    pub fn test_luajit_flag(&self, flag: u8) -> bool {
69        self.lj_flags & flag != 0
70    }
71}
72
73#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Encode, Decode)]
74#[serde(untagged)]
75pub enum LuaNumber {
76    Integer(i64),
77    Float(f64),
78}
79
80impl std::fmt::Display for LuaNumber {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        match self {
83            Self::Float(n) => write!(f, "{n}"),
84            Self::Integer(i) => write!(f, "{i}"),
85        }
86    }
87}
88
89/// Constant table for luajit, notice that the index of the array part starts at 0
90#[derive(Clone, Debug, Default, Serialize, Deserialize, Encode, Decode)]
91pub struct ConstTable {
92    pub array: Vec<LuaConstant>,
93    pub hash: Vec<(LuaConstant, LuaConstant)>,
94}
95
96#[derive(Clone, Default, Deserialize, Encode, Decode)]
97#[serde(untagged)]
98pub enum LuaConstant {
99    #[default]
100    Null,
101    Bool(bool),
102    Number(LuaNumber),
103    String(#[bincode(with_serde)] Rc<ByteBuf>),
104    // for luajit
105    Proto(usize),
106    Table(Box<ConstTable>),
107    // // for luau
108    // Imp(u32),
109}
110
111impl Serialize for LuaConstant {
112    fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
113    where
114        S: serde::Serializer,
115    {
116        match self {
117            LuaConstant::Null => ser.serialize_unit(),
118            LuaConstant::Bool(b) => ser.serialize_bool(*b),
119            LuaConstant::Number(n) => n.serialize(ser),
120            LuaConstant::String(s) => {
121                if let Ok(s) = core::str::from_utf8(s) {
122                    ser.serialize_str(s)
123                } else {
124                    ser.serialize_bytes(s)
125                }
126            }
127            LuaConstant::Proto(p) => p.serialize(ser),
128            LuaConstant::Table(t) => t.serialize(ser),
129        }
130    }
131}
132
133impl<T: Into<Vec<u8>>> From<T> for LuaConstant {
134    fn from(value: T) -> Self {
135        Self::String(Rc::new(ByteBuf::from(value)))
136    }
137}
138
139impl std::fmt::Debug for LuaConstant {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        match self {
142            Self::Null => write!(f, "Null"),
143            Self::Bool(arg0) => f.debug_tuple("Bool").field(arg0).finish(),
144            Self::Number(arg0) => match arg0 {
145                LuaNumber::Float(n) => f.debug_tuple("Number").field(n).finish(),
146                LuaNumber::Integer(n) => f.debug_tuple("Integer").field(n).finish(),
147            },
148            Self::String(arg0) => f
149                .debug_tuple("String")
150                .field(&String::from_utf8_lossy(arg0))
151                .finish(),
152            Self::Proto(i) => f.debug_tuple("Proto").field(i).finish(),
153            Self::Table(box ConstTable { array, hash }) => f
154                .debug_struct("Table")
155                .field("array", array)
156                .field("hash", hash)
157                .finish(),
158            // Self::Imp(imp) => f.debug_tuple("Imp").field(imp).finish(),
159        }
160    }
161}
162
163#[derive(Debug, Default, Serialize, Deserialize, Encode, Decode)]
164pub struct LuaLocal {
165    pub name: String,
166    pub start_pc: u64,
167    pub end_pc: u64,
168    pub reg: u8, // for luau
169}
170
171#[derive(Debug, Serialize, Deserialize, Encode, Decode)]
172pub struct LuaVarArgInfo {
173    pub has_arg: bool,
174    pub needs_arg: bool,
175}
176
177impl LuaVarArgInfo {
178    pub fn new() -> Self {
179        Self {
180            has_arg: true,
181            needs_arg: true,
182        }
183    }
184}
185
186#[derive(Debug, Default, Serialize, Deserialize, Encode, Decode)]
187pub struct UpVal {
188    pub on_stack: bool,
189    pub id: u8,
190    pub kind: u8,
191}
192
193#[derive(Default, Serialize, Deserialize, Encode, Decode)]
194pub struct LuaChunk {
195    pub name: Vec<u8>,
196    pub line_defined: u64,
197    pub last_line_defined: u64,
198    pub num_upvalues: u8,
199    pub num_params: u8,
200    /// Equivalent to framesize for luajit
201    pub max_stack: u8,
202    /// for luajit
203    pub flags: u8,
204    pub is_vararg: Option<LuaVarArgInfo>,
205    pub instructions: Vec<u32>,
206    pub constants: Vec<LuaConstant>,
207    /// for luajit
208    pub num_constants: Vec<LuaNumber>,
209    pub prototypes: Vec<Self>,
210    pub source_lines: Vec<(u32, u32)>,
211    pub locals: Vec<LuaLocal>,
212    /// for lua53
213    pub upvalue_infos: Vec<UpVal>,
214    pub upvalue_names: Vec<Vec<u8>>,
215}
216
217impl std::fmt::Debug for LuaChunk {
218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        f.debug_struct("LuaChunk")
220            .field("name", &String::from_utf8_lossy(&self.name))
221            .field("line_defined", &self.line_defined)
222            .field("last_line_defined", &self.last_line_defined)
223            .field("is_vararg", &self.is_vararg.is_some())
224            .field("num_params", &self.num_params)
225            .field("num_upvalues", &self.num_upvalues)
226            .field("locals", &self.locals)
227            .field("constants", &self.constants)
228            .field("prototypes", &self.prototypes)
229            .field("upvalue_infos", &self.upvalue_infos)
230            .finish()
231    }
232}
233
234impl LuaChunk {
235    pub fn name(&self) -> Cow<'_, str> {
236        String::from_utf8_lossy(&self.name)
237    }
238
239    pub fn flags(&self) -> luajit::ProtoFlags {
240        luajit::ProtoFlags::from_bits(self.flags).unwrap()
241    }
242
243    pub fn is_empty(&self) -> bool {
244        self.instructions.is_empty()
245    }
246}
247
248#[derive(Debug, Serialize, Deserialize, Encode, Decode)]
249pub struct LuaBytecode {
250    pub header: LuaHeader,
251    pub main_chunk: LuaChunk,
252}
253
254fn lua_header(input: &[u8]) -> IResult<&[u8], LuaHeader, ErrorTree<&[u8]>> {
255    let (rest, (_, result)) = tuple((
256        tag(b"\x1BLua"),
257        alt((
258            map(
259                tuple((
260                    tag(b"\x51"),
261                    be_u8,
262                    be_u8,
263                    be_u8,
264                    be_u8,
265                    be_u8,
266                    be_u8,
267                    be_u8,
268                )),
269                |(
270                    _,
271                    format_version,
272                    big_endian,
273                    int_size,
274                    size_t_size,
275                    instruction_size,
276                    number_size,
277                    number_integral,
278                )| LuaHeader {
279                    lua_version: LUA51.0,
280                    format_version,
281                    big_endian: big_endian != 1,
282                    int_size,
283                    size_t_size,
284                    instruction_size,
285                    number_size,
286                    number_integral: number_integral != 0,
287                    ..Default::default()
288                },
289            ),
290            map(
291                tuple((
292                    tag(b"\x52"),
293                    be_u8,
294                    be_u8,
295                    be_u8,
296                    be_u8,
297                    be_u8,
298                    be_u8,
299                    be_u8,
300                    take(6usize), // LUAC_DATA
301                )),
302                |(
303                    _,
304                    format_version,
305                    big_endian,
306                    int_size,
307                    size_t_size,
308                    instruction_size,
309                    number_size,
310                    number_integral,
311                    _,
312                )| LuaHeader {
313                    lua_version: LUA52.0,
314                    format_version,
315                    big_endian: big_endian != 1,
316                    int_size,
317                    size_t_size,
318                    instruction_size,
319                    number_size,
320                    number_integral: number_integral != 0,
321                    ..Default::default()
322                },
323            ),
324            map(
325                tuple((
326                    tag(b"\x53"),
327                    be_u8,
328                    take(6usize), // LUAC_DATA
329                    be_u8,
330                    be_u8,
331                    be_u8,
332                    be_u8,
333                    be_u8,
334                    complete::le_i64,
335                    complete::le_f64,
336                    be_u8,
337                )),
338                |(
339                    _,
340                    format_version,
341                    _luac_data,
342                    int_size,
343                    size_t_size,
344                    instruction_size,
345                    _integer_size, // lua_Integer
346                    number_size,
347                    _,
348                    _,
349                    _,
350                )| LuaHeader {
351                    lua_version: LUA53.0,
352                    format_version,
353                    big_endian: cfg!(target_endian = "big"),
354                    int_size,
355                    size_t_size,
356                    instruction_size,
357                    number_size,
358                    number_integral: false,
359                    ..Default::default()
360                },
361            ),
362            map(
363                tuple((
364                    tag(b"\x54"),
365                    be_u8,
366                    take(6usize), // LUAC_DATA
367                    be_u8,
368                    be_u8,
369                    be_u8,
370                    complete::le_i64,
371                    complete::le_f64,
372                    be_u8,
373                )),
374                |(
375                    _,
376                    format_version,
377                    _luac_data,
378                    instruction_size,
379                    _integer_size, // lua_Integer
380                    number_size,
381                    _,
382                    _,
383                    _,
384                )| LuaHeader {
385                    lua_version: LUA54.0,
386                    format_version,
387                    big_endian: cfg!(target_endian = "big"),
388                    int_size: 4,
389                    size_t_size: 8,
390                    instruction_size,
391                    number_size,
392                    number_integral: false,
393                    ..Default::default()
394                },
395            ),
396        )),
397    ))(input)?;
398    Ok((rest, result))
399}
400
401fn must<I, O, E: ParseError<I>, P: Parser<I, O, E>>(
402    cond: bool,
403    mut parser: P,
404) -> impl FnMut(I) -> IResult<I, O, E> {
405    move |input| {
406        if cond {
407            parser.parse(input)
408        } else {
409            Err(nom::Err::Error(E::from_error_kind(
410                input,
411                ErrorKind::Switch,
412            )))
413        }
414    }
415}
416
417fn lua_int<'a>(header: &LuaHeader) -> impl Parser<&'a [u8], u64, ErrorTree<&'a [u8]>> {
418    let intsize = header.int_size;
419    alt((
420        must(
421            intsize == 8,
422            map(complete::u64(header.endian()), |v| v as u64),
423        ),
424        must(
425            intsize == 4,
426            map(complete::u32(header.endian()), |v| v as u64),
427        ),
428        must(
429            intsize == 2,
430            map(complete::u16(header.endian()), |v| v as u64),
431        ),
432        must(intsize == 1, map(be_u8, |v| v as u64)),
433    ))
434    .context("integer")
435}
436
437fn lua_size_t<'a>(header: &LuaHeader) -> impl Parser<&'a [u8], u64, ErrorTree<&'a [u8]>> {
438    let sizesize = header.size_t_size;
439    alt((
440        must(
441            sizesize == 8,
442            map(complete::u64(header.endian()), |v| v as u64),
443        ),
444        must(
445            sizesize == 4,
446            map(complete::u32(header.endian()), |v| v as u64),
447        ),
448        must(
449            sizesize == 2,
450            map(complete::u16(header.endian()), |v| v as u64),
451        ),
452        must(sizesize == 1, map(be_u8, |v| v as u64)),
453    ))
454    .context("size_t")
455}
456
457fn lua_number<'a>(header: &LuaHeader) -> impl Parser<&'a [u8], LuaNumber, ErrorTree<&'a [u8]>> {
458    let int = header.number_integral;
459    let size = header.number_size;
460    alt((
461        must(
462            int == true,
463            map(
464                alt((
465                    must(size == 1, map(complete::be_i8, |v| v as i64)),
466                    must(size == 2, map(complete::i16(header.endian()), |v| v as i64)),
467                    must(size == 4, map(complete::i32(header.endian()), |v| v as i64)),
468                    must(size == 8, map(complete::i64(header.endian()), |v| v as i64)),
469                )),
470                |v| LuaNumber::Integer(v),
471            ),
472        ),
473        must(
474            int == false,
475            map(
476                alt((
477                    must(size == 8, map(complete::f64(header.endian()), |v| v as f64)),
478                    must(size == 4, map(complete::f32(header.endian()), |v| v as f64)),
479                )),
480                |v| LuaNumber::Float(v),
481            ),
482        ),
483    ))
484    .context("number")
485}
486
487pub fn lua_bytecode(input: &[u8]) -> IResult<&[u8], LuaBytecode, ErrorTree<&[u8]>> {
488    let (input, header) = alt((lua_header, luajit::lj_header))(input)?;
489    log::trace!("header: {header:?}");
490    let (input, main_chunk) = match header.version() {
491        LUA51 => lua51::lua_chunk(&header).parse(input)?,
492        LUA52 => lua52::lua_chunk(&header).parse(input)?,
493        LUA53 => lua53::lua_chunk(&header).parse(input)?,
494        LUA54 => lua54::lua_chunk(&header).parse(input)?,
495        LUAJ1 | LUAJ2 => luajit::lj_chunk(&header).parse(input)?,
496        _ => context("unsupported lua version", fail)(input)?,
497    };
498    Ok((input, LuaBytecode { header, main_chunk }))
499}
500
501pub fn parse(input: &[u8]) -> Result<LuaBytecode, String> {
502    lua_bytecode(input).map(|x| x.1).map_err(|e| {
503        format!(
504            "{:#?}",
505            e.map(|e| e.map_locations(|p| unsafe { p.as_ptr().byte_offset_from(input.as_ptr()) }))
506        )
507    })
508}
509
/// MessagePack round-tripping, available with the `rmp-serde` feature.
#[cfg(feature = "rmp-serde")]
impl LuaBytecode {
    /// Decodes a [`LuaBytecode`] from MessagePack bytes via `serde`.
    pub fn from_msgpack(mp: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
        let decoded = rmp_serde::from_slice(mp)?;
        Ok(decoded)
    }

    /// Encodes this bytecode as MessagePack via `serde`.
    pub fn to_msgpack(&self) -> Result<Vec<u8>, rmp_serde::encode::Error> {
        let encoded = rmp_serde::to_vec(self)?;
        Ok(encoded)
    }
}
520
/// Identifies a bytecode dialect by its version byte.
///
/// PUC-Rio versions keep the major version in the high nibble and the minor version in the low
/// nibble (`0x51` = Lua 5.1). The two LuaJIT bytecode versions use `0x11` and `0x12`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LuaVersion(pub u8);

impl std::fmt::Display for LuaVersion {
    /// Writes `luajit1` / `luajit2` for LuaJIT, and `lua` followed by the lowercase hex byte
    /// otherwise (e.g. `lua51`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if *self == LUAJ1 {
            f.write_str("luajit1")
        } else if *self == LUAJ2 {
            f.write_str("luajit2")
        } else {
            write!(f, "lua{:x}", self.0)
        }
    }
}

impl LuaVersion {
    /// Returns `true` for either LuaJIT bytecode version.
    pub fn is_luajit(self) -> bool {
        self == LUAJ1 || self == LUAJ2
    }
}

/// Lua 5.1.
pub const LUA51: LuaVersion = LuaVersion(0x51);
/// Lua 5.2.
pub const LUA52: LuaVersion = LuaVersion(0x52);
/// Lua 5.3.
pub const LUA53: LuaVersion = LuaVersion(0x53);
/// Lua 5.4.
pub const LUA54: LuaVersion = LuaVersion(0x54);
/// LuaJIT bytecode version 1.
pub const LUAJ1: LuaVersion = LuaVersion(0x11);
/// LuaJIT bytecode version 2.
pub const LUAJ2: LuaVersion = LuaVersion(0x12);