1#![feature(box_patterns)]
2
3use bincode::{Decode, Encode};
4use std::{borrow::Cow, rc::Rc};
5
6#[allow(unused_imports)]
7use nom::{
8 branch::alt,
9 bytes::complete::{escaped, tag, take_till, take_until, take_while, take_while_m_n},
10 character::{
11 complete::{alphanumeric1, char as cchar, multispace0, multispace1, none_of, one_of},
12 is_alphabetic, is_newline, is_space,
13 streaming::space1,
14 },
15 combinator::{fail, map, map_res, opt, value},
16 number::complete::be_u8,
17 sequence::{delimited, tuple},
18};
19use nom::{
20 bytes::complete::take,
21 combinator::success,
22 error::{context, ErrorKind, ParseError},
23 multi::{length_count, length_data},
24 number::{complete, Endianness},
25 Parser,
26};
27use nom_supreme::{error::*, ParserExt};
28use serde::{Deserialize, Serialize};
29use serde_bytes::ByteBuf;
30
31pub mod lua51;
32pub mod lua52;
33pub mod lua53;
34pub mod lua54;
35pub mod luajit;
36pub mod luau;
37pub mod utils;
38
/// Result type returned by this crate's nom parsers.
///
/// Same shape as `nom::IResult`, but the error type defaults to
/// `nom_supreme`'s [`ErrorTree`], so failures keep the `context` labels the
/// parsers attach.
pub type IResult<I, O, E = ErrorTree<I>> = Result<(I, O), nom::Err<E>>;
40
/// Bytecode header, normalized across the supported Lua and LuaJIT versions.
///
/// Field order is part of the bincode (`Encode`/`Decode`) encoding; do not
/// reorder fields.
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize, Encode, Decode)]
pub struct LuaHeader {
    /// Version byte; compare it via [`LuaHeader::version`] against the
    /// `LUA5x` / `LUAJx` constants.
    pub lua_version: u8,
    /// Format byte that follows the version byte in Lua 5.x headers.
    pub format_version: u8,
    /// Whether multi-byte values are read big-endian (see [`LuaHeader::endian`]).
    pub big_endian: bool,
    /// Width in bytes of integers read with `lua_int`.
    pub int_size: u8,
    /// Width in bytes of sizes read with `lua_size_t`.
    pub size_t_size: u8,
    /// Width in bytes of one VM instruction.
    pub instruction_size: u8,
    /// Width in bytes of a number read with `lua_number`.
    pub number_size: u8,
    /// `true` when `lua_number` should read integers instead of floats.
    pub number_integral: bool,
    /// LuaJIT header flag bits, tested with [`LuaHeader::test_luajit_flag`];
    /// left at 0 by the PUC-Rio header parser (`..Default::default()`).
    pub lj_flags: u8,
}
54
55impl LuaHeader {
56 pub fn endian(&self) -> Endianness {
57 if self.big_endian {
58 Endianness::Big
59 } else {
60 Endianness::Little
61 }
62 }
63
64 pub fn version(&self) -> LuaVersion {
65 LuaVersion(self.lua_version)
66 }
67
68 pub fn test_luajit_flag(&self, flag: u8) -> bool {
69 self.lj_flags & flag != 0
70 }
71}
72
/// A numeric constant, either integer or floating point.
///
/// With `#[serde(untagged)]` it serializes as the bare number. Deserialization
/// tries `Integer` first, then `Float`. Variant order is therefore significant
/// for serde as well as for bincode.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Encode, Decode)]
#[serde(untagged)]
pub enum LuaNumber {
    /// Integer value.
    Integer(i64),
    /// Floating-point value.
    Float(f64),
}
79
80impl std::fmt::Display for LuaNumber {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 match self {
83 Self::Float(n) => write!(f, "{n}"),
84 Self::Integer(i) => write!(f, "{i}"),
85 }
86 }
87}
88
/// A constant table value: an array part plus key/value pairs.
#[derive(Clone, Debug, Default, Serialize, Deserialize, Encode, Decode)]
pub struct ConstTable {
    /// Array part, in order.
    pub array: Vec<LuaConstant>,
    /// Hash part as `(key, value)` pairs; stored in a `Vec`, so insertion order is kept.
    pub hash: Vec<(LuaConstant, LuaConstant)>,
}
95
/// A constant referenced by a function prototype.
///
/// Serialization uses the hand-written `Serialize` impl, which writes bare
/// values without variant tags. Deserialization is `#[serde(untagged)]` and
/// tries variants in declaration order.
///
/// NOTE(review): because `Number` is tried before `Proto`, a serialized
/// `Proto(n)` (a bare integer) comes back as `Number(LuaNumber::Integer(n))`
/// in self-describing formats. The bincode `Encode`/`Decode` derives are not
/// affected.
#[derive(Clone, Default, Deserialize, Encode, Decode)]
#[serde(untagged)]
pub enum LuaConstant {
    /// Absence of a value (presumably Lua `nil`); the `Default` variant.
    #[default]
    Null,
    /// Boolean constant.
    Bool(bool),
    /// Numeric constant.
    Number(LuaNumber),
    /// Raw string bytes; not required to be valid UTF-8.
    String(#[bincode(with_serde)] Rc<ByteBuf>),
    /// Prototype reference by index (presumably into the owning chunk's
    /// `prototypes`; confirm in the per-version parsers).
    Proto(usize),
    /// Constant table.
    Table(Box<ConstTable>),
}
110
111impl Serialize for LuaConstant {
112 fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
113 where
114 S: serde::Serializer,
115 {
116 match self {
117 LuaConstant::Null => ser.serialize_unit(),
118 LuaConstant::Bool(b) => ser.serialize_bool(*b),
119 LuaConstant::Number(n) => n.serialize(ser),
120 LuaConstant::String(s) => {
121 if let Ok(s) = core::str::from_utf8(s) {
122 ser.serialize_str(s)
123 } else {
124 ser.serialize_bytes(s)
125 }
126 }
127 LuaConstant::Proto(p) => p.serialize(ser),
128 LuaConstant::Table(t) => t.serialize(ser),
129 }
130 }
131}
132
133impl<T: Into<Vec<u8>>> From<T> for LuaConstant {
134 fn from(value: T) -> Self {
135 Self::String(Rc::new(ByteBuf::from(value)))
136 }
137}
138
139impl std::fmt::Debug for LuaConstant {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 match self {
142 Self::Null => write!(f, "Null"),
143 Self::Bool(arg0) => f.debug_tuple("Bool").field(arg0).finish(),
144 Self::Number(arg0) => match arg0 {
145 LuaNumber::Float(n) => f.debug_tuple("Number").field(n).finish(),
146 LuaNumber::Integer(n) => f.debug_tuple("Integer").field(n).finish(),
147 },
148 Self::String(arg0) => f
149 .debug_tuple("String")
150 .field(&String::from_utf8_lossy(arg0))
151 .finish(),
152 Self::Proto(i) => f.debug_tuple("Proto").field(i).finish(),
153 Self::Table(box ConstTable { array, hash }) => f
154 .debug_struct("Table")
155 .field("array", array)
156 .field("hash", hash)
157 .finish(),
158 }
160 }
161}
162
/// Debug information for one local variable of a function prototype.
///
/// NOTE(review): the pc/register meanings below are inferred from field names;
/// confirm them against the per-version parsers.
#[derive(Debug, Default, Serialize, Deserialize, Encode, Decode)]
pub struct LuaLocal {
    /// Variable name.
    pub name: String,
    /// Presumably the first instruction index at which the local is live.
    pub start_pc: u64,
    /// Presumably the instruction index at which the local goes out of scope.
    pub end_pc: u64,
    /// Presumably the register slot holding the local.
    pub reg: u8,
}
170
/// Vararg details of a function prototype (held in `LuaChunk::is_vararg`).
///
/// NOTE(review): these presumably mirror Lua 5.1's `VARARG_HASARG` /
/// `VARARG_NEEDSARG` bits; confirm in the Lua 5.1 parser.
#[derive(Debug, Serialize, Deserialize, Encode, Decode)]
pub struct LuaVarArgInfo {
    /// Whether the function has the implicit `arg` parameter.
    pub has_arg: bool,
    /// Whether the function needs the implicit `arg` table.
    pub needs_arg: bool,
}
176
177impl LuaVarArgInfo {
178 pub fn new() -> Self {
179 Self {
180 has_arg: true,
181 needs_arg: true,
182 }
183 }
184}
185
/// Description of one upvalue captured by a function prototype.
///
/// NOTE(review): field meanings are inferred from names; confirm them against
/// the per-version parsers.
#[derive(Debug, Default, Serialize, Deserialize, Encode, Decode)]
pub struct UpVal {
    /// Presumably whether the upvalue refers to a register of the enclosing
    /// function rather than one of its upvalues (Lua 5.2+ `instack`).
    pub on_stack: bool,
    /// Presumably the register or upvalue index in the enclosing function.
    pub id: u8,
    /// Presumably the upvalue kind byte (present in Lua 5.4 dumps).
    pub kind: u8,
}
192
/// One function prototype. Nested functions are stored in `prototypes`.
///
/// The same struct is filled by both the PUC-Rio (5.1-5.4) and the LuaJIT
/// parsers. Some fields are presumably populated by only one family, for
/// example `flags`, which is decoded as `luajit::ProtoFlags`. Confirm in the
/// per-version parsers.
#[derive(Default, Serialize, Deserialize, Encode, Decode)]
pub struct LuaChunk {
    /// Raw source name bytes; see [`LuaChunk::name`] for a lossy UTF-8 view.
    pub name: Vec<u8>,
    /// Line where the function definition starts.
    pub line_defined: u64,
    /// Line where the function definition ends.
    pub last_line_defined: u64,
    /// Number of upvalues.
    pub num_upvalues: u8,
    /// Number of fixed parameters.
    pub num_params: u8,
    /// Number of registers the function needs.
    pub max_stack: u8,
    /// Raw prototype flag bits; decode with [`LuaChunk::flags`].
    pub flags: u8,
    /// `Some` when the function is vararg.
    pub is_vararg: Option<LuaVarArgInfo>,
    /// Raw instruction words.
    pub instructions: Vec<u32>,
    /// Constants referenced by the instructions.
    pub constants: Vec<LuaConstant>,
    /// Numeric constants kept separately from `constants`
    /// (NOTE(review): presumably LuaJIT's numeric constant array; confirm).
    pub num_constants: Vec<LuaNumber>,
    /// Nested function prototypes.
    pub prototypes: Vec<Self>,
    /// Line information as pairs (NOTE(review): exact meaning of the two
    /// values depends on the parser; confirm).
    pub source_lines: Vec<(u32, u32)>,
    /// Local variable debug information.
    pub locals: Vec<LuaLocal>,
    /// Upvalue descriptors.
    pub upvalue_infos: Vec<UpVal>,
    /// Upvalue names as raw bytes.
    pub upvalue_names: Vec<Vec<u8>>,
}
216
217impl std::fmt::Debug for LuaChunk {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 f.debug_struct("LuaChunk")
220 .field("name", &String::from_utf8_lossy(&self.name))
221 .field("line_defined", &self.line_defined)
222 .field("last_line_defined", &self.last_line_defined)
223 .field("is_vararg", &self.is_vararg.is_some())
224 .field("num_params", &self.num_params)
225 .field("num_upvalues", &self.num_upvalues)
226 .field("locals", &self.locals)
227 .field("constants", &self.constants)
228 .field("prototypes", &self.prototypes)
229 .field("upvalue_infos", &self.upvalue_infos)
230 .finish()
231 }
232}
233
234impl LuaChunk {
235 pub fn name(&self) -> Cow<'_, str> {
236 String::from_utf8_lossy(&self.name)
237 }
238
239 pub fn flags(&self) -> luajit::ProtoFlags {
240 luajit::ProtoFlags::from_bits(self.flags).unwrap()
241 }
242
243 pub fn is_empty(&self) -> bool {
244 self.instructions.is_empty()
245 }
246}
247
/// A fully parsed bytecode dump, as produced by [`lua_bytecode`] / [`parse`].
#[derive(Debug, Serialize, Deserialize, Encode, Decode)]
pub struct LuaBytecode {
    /// Normalized file header.
    pub header: LuaHeader,
    /// Top-level (main) function prototype.
    pub main_chunk: LuaChunk,
}
253
254fn lua_header(input: &[u8]) -> IResult<&[u8], LuaHeader, ErrorTree<&[u8]>> {
255 let (rest, (_, result)) = tuple((
256 tag(b"\x1BLua"),
257 alt((
258 map(
259 tuple((
260 tag(b"\x51"),
261 be_u8,
262 be_u8,
263 be_u8,
264 be_u8,
265 be_u8,
266 be_u8,
267 be_u8,
268 )),
269 |(
270 _,
271 format_version,
272 big_endian,
273 int_size,
274 size_t_size,
275 instruction_size,
276 number_size,
277 number_integral,
278 )| LuaHeader {
279 lua_version: LUA51.0,
280 format_version,
281 big_endian: big_endian != 1,
282 int_size,
283 size_t_size,
284 instruction_size,
285 number_size,
286 number_integral: number_integral != 0,
287 ..Default::default()
288 },
289 ),
290 map(
291 tuple((
292 tag(b"\x52"),
293 be_u8,
294 be_u8,
295 be_u8,
296 be_u8,
297 be_u8,
298 be_u8,
299 be_u8,
300 take(6usize), )),
302 |(
303 _,
304 format_version,
305 big_endian,
306 int_size,
307 size_t_size,
308 instruction_size,
309 number_size,
310 number_integral,
311 _,
312 )| LuaHeader {
313 lua_version: LUA52.0,
314 format_version,
315 big_endian: big_endian != 1,
316 int_size,
317 size_t_size,
318 instruction_size,
319 number_size,
320 number_integral: number_integral != 0,
321 ..Default::default()
322 },
323 ),
324 map(
325 tuple((
326 tag(b"\x53"),
327 be_u8,
328 take(6usize), be_u8,
330 be_u8,
331 be_u8,
332 be_u8,
333 be_u8,
334 complete::le_i64,
335 complete::le_f64,
336 be_u8,
337 )),
338 |(
339 _,
340 format_version,
341 _luac_data,
342 int_size,
343 size_t_size,
344 instruction_size,
345 _integer_size, number_size,
347 _,
348 _,
349 _,
350 )| LuaHeader {
351 lua_version: LUA53.0,
352 format_version,
353 big_endian: cfg!(target_endian = "big"),
354 int_size,
355 size_t_size,
356 instruction_size,
357 number_size,
358 number_integral: false,
359 ..Default::default()
360 },
361 ),
362 map(
363 tuple((
364 tag(b"\x54"),
365 be_u8,
366 take(6usize), be_u8,
368 be_u8,
369 be_u8,
370 complete::le_i64,
371 complete::le_f64,
372 be_u8,
373 )),
374 |(
375 _,
376 format_version,
377 _luac_data,
378 instruction_size,
379 _integer_size, number_size,
381 _,
382 _,
383 _,
384 )| LuaHeader {
385 lua_version: LUA54.0,
386 format_version,
387 big_endian: cfg!(target_endian = "big"),
388 int_size: 4,
389 size_t_size: 8,
390 instruction_size,
391 number_size,
392 number_integral: false,
393 ..Default::default()
394 },
395 ),
396 )),
397 ))(input)?;
398 Ok((rest, result))
399}
400
401fn must<I, O, E: ParseError<I>, P: Parser<I, O, E>>(
402 cond: bool,
403 mut parser: P,
404) -> impl FnMut(I) -> IResult<I, O, E> {
405 move |input| {
406 if cond {
407 parser.parse(input)
408 } else {
409 Err(nom::Err::Error(E::from_error_kind(
410 input,
411 ErrorKind::Switch,
412 )))
413 }
414 }
415}
416
417fn lua_int<'a>(header: &LuaHeader) -> impl Parser<&'a [u8], u64, ErrorTree<&'a [u8]>> {
418 let intsize = header.int_size;
419 alt((
420 must(
421 intsize == 8,
422 map(complete::u64(header.endian()), |v| v as u64),
423 ),
424 must(
425 intsize == 4,
426 map(complete::u32(header.endian()), |v| v as u64),
427 ),
428 must(
429 intsize == 2,
430 map(complete::u16(header.endian()), |v| v as u64),
431 ),
432 must(intsize == 1, map(be_u8, |v| v as u64)),
433 ))
434 .context("integer")
435}
436
437fn lua_size_t<'a>(header: &LuaHeader) -> impl Parser<&'a [u8], u64, ErrorTree<&'a [u8]>> {
438 let sizesize = header.size_t_size;
439 alt((
440 must(
441 sizesize == 8,
442 map(complete::u64(header.endian()), |v| v as u64),
443 ),
444 must(
445 sizesize == 4,
446 map(complete::u32(header.endian()), |v| v as u64),
447 ),
448 must(
449 sizesize == 2,
450 map(complete::u16(header.endian()), |v| v as u64),
451 ),
452 must(sizesize == 1, map(be_u8, |v| v as u64)),
453 ))
454 .context("size_t")
455}
456
457fn lua_number<'a>(header: &LuaHeader) -> impl Parser<&'a [u8], LuaNumber, ErrorTree<&'a [u8]>> {
458 let int = header.number_integral;
459 let size = header.number_size;
460 alt((
461 must(
462 int == true,
463 map(
464 alt((
465 must(size == 1, map(complete::be_i8, |v| v as i64)),
466 must(size == 2, map(complete::i16(header.endian()), |v| v as i64)),
467 must(size == 4, map(complete::i32(header.endian()), |v| v as i64)),
468 must(size == 8, map(complete::i64(header.endian()), |v| v as i64)),
469 )),
470 |v| LuaNumber::Integer(v),
471 ),
472 ),
473 must(
474 int == false,
475 map(
476 alt((
477 must(size == 8, map(complete::f64(header.endian()), |v| v as f64)),
478 must(size == 4, map(complete::f32(header.endian()), |v| v as f64)),
479 )),
480 |v| LuaNumber::Float(v),
481 ),
482 ),
483 ))
484 .context("number")
485}
486
487pub fn lua_bytecode(input: &[u8]) -> IResult<&[u8], LuaBytecode, ErrorTree<&[u8]>> {
488 let (input, header) = alt((lua_header, luajit::lj_header))(input)?;
489 log::trace!("header: {header:?}");
490 let (input, main_chunk) = match header.version() {
491 LUA51 => lua51::lua_chunk(&header).parse(input)?,
492 LUA52 => lua52::lua_chunk(&header).parse(input)?,
493 LUA53 => lua53::lua_chunk(&header).parse(input)?,
494 LUA54 => lua54::lua_chunk(&header).parse(input)?,
495 LUAJ1 | LUAJ2 => luajit::lj_chunk(&header).parse(input)?,
496 _ => context("unsupported lua version", fail)(input)?,
497 };
498 Ok((input, LuaBytecode { header, main_chunk }))
499}
500
501pub fn parse(input: &[u8]) -> Result<LuaBytecode, String> {
502 lua_bytecode(input).map(|x| x.1).map_err(|e| {
503 format!(
504 "{:#?}",
505 e.map(|e| e.map_locations(|p| unsafe { p.as_ptr().byte_offset_from(input.as_ptr()) }))
506 )
507 })
508}
509
#[cfg(feature = "rmp-serde")]
impl LuaBytecode {
    /// Decodes a `LuaBytecode` from MessagePack bytes, for example ones
    /// produced by [`LuaBytecode::to_msgpack`].
    pub fn from_msgpack(mp: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
        rmp_serde::from_slice::<Self>(mp)
    }

    /// Encodes this bytecode as MessagePack with `rmp_serde::to_vec`.
    pub fn to_msgpack(&self) -> Result<Vec<u8>, rmp_serde::encode::Error> {
        let encoded = rmp_serde::to_vec(self)?;
        Ok(encoded)
    }
}
520
/// Bytecode format version, wrapping the value stored in
/// [`LuaHeader::lua_version`].
///
/// PUC-Rio versions use the header's version byte (e.g. 0x51 for Lua 5.1).
/// LuaJIT versions are represented by [`LUAJ1`] and [`LUAJ2`]. Ordering
/// compares the raw byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LuaVersion(pub u8);
523
524impl std::fmt::Display for LuaVersion {
525 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526 match *self {
527 LUAJ1 => write!(f, "luajit1"),
528 LUAJ2 => write!(f, "luajit2"),
529 v => write!(f, "lua{:x}", v.0),
530 }
531 }
532}
533
534impl LuaVersion {
535 pub fn is_luajit(self) -> bool {
536 matches!(self, LUAJ1 | LUAJ2)
537 }
538}
539
/// Lua 5.1 (header version byte 0x51).
pub const LUA51: LuaVersion = LuaVersion(0x51);
/// Lua 5.2 (header version byte 0x52).
pub const LUA52: LuaVersion = LuaVersion(0x52);
/// Lua 5.3 (header version byte 0x53).
pub const LUA53: LuaVersion = LuaVersion(0x53);
/// Lua 5.4 (header version byte 0x54).
pub const LUA54: LuaVersion = LuaVersion(0x54);
/// First LuaJIT bytecode version; displayed as `luajit1`.
/// NOTE(review): 0x11 is not a `"\x1BLua"` version byte; confirm how
/// `luajit::lj_header` maps the dump's version onto this value.
pub const LUAJ1: LuaVersion = LuaVersion(0x11);
/// Second LuaJIT bytecode version; displayed as `luajit2`.
/// NOTE(review): confirm the mapping in `luajit::lj_header`, as for `LUAJ1`.
pub const LUAJ2: LuaVersion = LuaVersion(0x12);