use std::collections::BTreeSet;
use std::fmt;
use crate::hir::{HirLabelId, HirProtoRef, LocalId, ParamId, TempId, UpvalueId};
/// Identifier of a synthetic local binding in the AST.
///
/// The payload is a HIR [`TempId`]; presumably the temporary that AST construction
/// promoted to a named local — confirm against the producer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub struct AstSyntheticLocalId(pub TempId);
impl AstSyntheticLocalId {
    /// Returns the numeric index of the wrapped [`TempId`].
    pub const fn index(self) -> usize {
        let Self(temp) = self;
        temp.index()
    }
}
/// Root node of the AST for one chunk.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct AstModule {
/// HIR prototype of the chunk's entry (top-level) function.
pub entry_function: HirProtoRef,
/// Statements making up the entry function's body.
pub body: AstBlock,
}
/// An ordered list of statements: a function body, loop body, `if` arm or `do` block.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct AstBlock {
/// The block's statements, in order.
pub stmts: Vec<AstStmt>,
}
/// A statement node.
#[derive(Debug, Clone, PartialEq)]
pub enum AstStmt {
/// `local` declaration, optionally with `<const>`/`<close>` attributes.
LocalDecl(Box<AstLocalDecl>),
/// `global` declaration (only some dialects; see `AstFeature::GlobalDecl`).
GlobalDecl(Box<AstGlobalDecl>),
/// Assignment `targets = values`.
Assign(Box<AstAssign>),
/// A function or method call used as a statement.
CallStmt(Box<AstCallStmt>),
/// `return` with zero or more values.
Return(Box<AstReturn>),
/// `if cond then ... [else ...] end`.
If(Box<AstIf>),
/// `while cond do ... end`.
While(Box<AstWhile>),
/// `repeat ... until cond`.
Repeat(Box<AstRepeat>),
/// Numeric `for i = start, limit, step do ... end`.
NumericFor(Box<AstNumericFor>),
/// Generic `for a, b in exprs do ... end`.
GenericFor(Box<AstGenericFor>),
/// `break`.
Break,
/// `continue`; only some dialects support it (see `AstFeature::ContinueStmt`).
Continue,
/// `goto label` (see `AstFeature::GotoLabel`).
Goto(Box<AstGoto>),
/// `::label::` target for a `goto`.
Label(Box<AstLabel>),
/// Explicit `do ... end` scope.
DoBlock(Box<AstBlock>),
/// `function a.b.c() ... end` or `function a.b:m() ... end`.
FunctionDecl(Box<AstFunctionDecl>),
/// `local function name() ... end`.
LocalFunctionDecl(Box<AstLocalFunctionDecl>),
/// Placeholder for a statement that could not be represented; the string presumably
/// carries a diagnostic message — confirm against producers.
Error(String),
}
/// An expression node.
#[derive(Debug, Clone, PartialEq)]
pub enum AstExpr {
/// `nil`.
Nil,
/// `true` / `false`.
Boolean(bool),
/// Integer literal.
Integer(i64),
/// Floating-point number literal.
Number(f64),
/// String literal contents.
String(String),
/// Signed 64-bit integer literal kept distinct from `Integer`; presumably a LuaJIT
/// `LL` cdata literal — confirm.
Int64(i64),
/// Unsigned 64-bit integer literal; presumably a LuaJIT `ULL` cdata literal — confirm.
UInt64(u64),
/// Complex-number literal; presumably a LuaJIT imaginary (`i`-suffixed) literal — confirm.
Complex { real: f64, imag: f64 },
/// Reference to a named variable (see `AstNameRef` for the kinds of name).
Var(AstNameRef),
/// `base.field`.
FieldAccess(Box<AstFieldAccess>),
/// `base[index]`.
IndexAccess(Box<AstIndexAccess>),
/// Unary operator application.
Unary(Box<AstUnaryExpr>),
/// Binary operator application; `and`/`or` are the separate `LogicalAnd`/`LogicalOr` variants.
Binary(Box<AstBinaryExpr>),
/// Short-circuit `lhs and rhs`.
LogicalAnd(Box<AstLogicalExpr>),
/// Short-circuit `lhs or rhs`.
LogicalOr(Box<AstLogicalExpr>),
/// Function call `callee(args)`.
Call(Box<AstCallExpr>),
/// Method call `receiver:method(args)`.
MethodCall(Box<AstMethodCallExpr>),
/// Expression truncated to a single value; presumably printed as `(expr)`, which is how
/// Lua drops the extra results of a call or `...` — confirm.
SingleValue(Box<AstExpr>),
/// Vararg expression `...`.
VarArg,
/// Table constructor `{ ... }`.
TableConstructor(Box<AstTableConstructor>),
/// Anonymous `function (...) ... end`.
FunctionExpr(Box<AstFunctionExpr>),
/// Placeholder for an expression that could not be represented; the string presumably
/// carries a diagnostic message — confirm against producers.
Error(String),
}
/// Multiple assignment `t1, t2, ... = v1, v2, ...`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstAssign {
/// Places being assigned, left to right.
pub targets: Vec<AstLValue>,
/// Right-hand-side expressions, left to right.
pub values: Vec<AstExpr>,
}
/// A place that can appear on the left of `=`.
#[derive(Debug, Clone, PartialEq)]
pub enum AstLValue {
/// A named variable.
Name(AstNameRef),
/// `base.field`.
FieldAccess(Box<AstFieldAccess>),
/// `base[index]`.
IndexAccess(Box<AstIndexAccess>),
}
/// Any name an expression or assignment target can refer to.
#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum AstNameRef {
/// A function parameter.
Param(ParamId),
/// A HIR local.
Local(LocalId),
/// A HIR temporary.
Temp(TempId),
/// A local introduced at the AST level (see `AstSyntheticLocalId`).
SyntheticLocal(AstSyntheticLocalId),
/// An upvalue (variable captured from an enclosing function).
Upvalue(UpvalueId),
/// A global variable, by name.
Global(AstGlobalName),
}
/// Reference to a binding that a statement in this AST can declare: the local, temp and
/// synthetic-local subset of `AstNameRef`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum AstBindingRef {
/// A HIR local.
Local(LocalId),
/// A HIR temporary.
Temp(TempId),
/// A local introduced at the AST level.
SyntheticLocal(AstSyntheticLocalId),
}
/// Name of a global variable.
#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub struct AstGlobalName {
/// The global's identifier text.
pub text: String,
}
/// `return v1, v2, ...` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct AstReturn {
/// Returned expressions; empty for a bare `return`.
pub values: Vec<AstExpr>,
}
/// A function body together with its signature.
///
/// Used directly as an expression and embedded in `AstFunctionDecl` / `AstLocalFunctionDecl`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstFunctionExpr {
/// HIR prototype this function corresponds to.
pub function: HirProtoRef,
/// Declared parameters, in order.
pub params: Vec<ParamId>,
/// Whether the function accepts `...`.
pub is_vararg: bool,
/// Binding that names the vararg pack, if any; presumably for dialects with named
/// varargs (e.g. Lua 5.5 `...name`) — confirm.
pub named_vararg: Option<AstBindingRef>,
/// The function's statements.
pub body: AstBlock,
/// Bindings this function captures; presumably locals of enclosing functions that it
/// uses as upvalues — confirm against the producer.
pub captured_bindings: BTreeSet<AstBindingRef>,
}
/// `function <target>(...) ... end` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct AstFunctionDecl {
/// Name (path, optionally with a method part) being assigned.
pub target: AstFunctionName,
/// The function itself.
pub func: AstFunctionExpr,
}
/// `local function <name>(...) ... end` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct AstLocalFunctionDecl {
/// The local binding being declared.
pub name: AstBindingRef,
/// The function itself.
pub func: AstFunctionExpr,
}
/// Name part of a `function` declaration.
#[derive(Debug, Clone, PartialEq)]
pub enum AstFunctionName {
/// `function a.b.c()`.
Plain(AstNamePath),
/// `function a.b:m()`; the `String` is the method name after `:`.
Method(AstNamePath, String),
}
/// Dotted name path `root.f1.f2...`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstNamePath {
/// The leading variable.
pub root: AstNameRef,
/// Field names following `root`, in order.
pub fields: Vec<String>,
}
/// The dialect the AST targets, together with the feature set currently allowed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AstTargetDialect {
/// Which Lua dialect/version.
pub version: AstDialectVersion,
/// Feature flags in effect. These are not necessarily the dialect's real feature set,
/// because `AstTargetDialect::relaxed_for_lowering` force-enables extra features.
pub caps: AstDialectCaps,
}
/// Per-feature availability flags for a target dialect.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AstDialectCaps {
/// `goto` statements and `::label::` targets.
pub goto_label: bool,
/// `continue` statement.
pub continue_stmt: bool,
/// `<const>` attribute on locals.
pub local_const: bool,
/// `<close>` attribute on locals.
pub local_close: bool,
/// `global` declarations.
pub global_decl: bool,
/// `global<const>` declarations; `AstDialectCaps::supports` only reports this feature
/// when `global_decl` is also set.
pub global_const: bool,
}
/// A language feature whose availability varies by dialect.
///
/// Each variant corresponds to one flag of `AstDialectCaps`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum AstFeature {
    /// `goto` / `::label::`.
    GotoLabel,
    /// `continue` statement.
    ContinueStmt,
    /// `local x <const>`.
    LocalConst,
    /// `local x <close>`.
    LocalClose,
    /// `global` declarations.
    GlobalDecl,
    /// `global<const>` declarations.
    GlobalConst,
}
impl AstFeature {
    /// Returns the Lua source spelling that identifies this feature, e.g. `"local<close>"`.
    pub const fn as_str(self) -> &'static str {
        match self {
            AstFeature::GotoLabel => "goto",
            AstFeature::ContinueStmt => "continue",
            AstFeature::LocalConst => "local<const>",
            AstFeature::LocalClose => "local<close>",
            AstFeature::GlobalDecl => "global",
            AstFeature::GlobalConst => "global<const>",
        }
    }
}
/// Lua dialect (language version or implementation) that the AST targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AstDialectVersion {
    /// Lua 5.1.
    Lua51,
    /// Lua 5.2.
    Lua52,
    /// Lua 5.3.
    Lua53,
    /// Lua 5.4.
    Lua54,
    /// Lua 5.5.
    Lua55,
    /// LuaJIT.
    LuaJit,
    /// Luau.
    Luau,
}
impl AstDialectVersion {
    /// Returns the short lowercase identifier for this dialect, e.g. `"lua5.4"` or `"luajit"`.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Lua51 => "lua5.1",
            Self::Lua52 => "lua5.2",
            Self::Lua53 => "lua5.3",
            Self::Lua54 => "lua5.4",
            Self::Lua55 => "lua5.5",
            Self::LuaJit => "luajit",
            Self::Luau => "luau",
        }
    }
}
/// Formats the dialect as its `as_str` identifier.
///
/// Uses `Formatter::pad` rather than `write_str` so that width, fill, alignment and
/// precision flags (e.g. `{:>8}`) are honoured exactly as they are for a plain `&str`.
impl fmt::Display for AstDialectVersion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
impl AstTargetDialect {
pub const fn new(version: AstDialectVersion) -> Self {
let caps = match version {
AstDialectVersion::Lua51 => AstDialectCaps {
goto_label: false,
continue_stmt: false,
local_const: false,
local_close: false,
global_decl: false,
global_const: false,
},
AstDialectVersion::Lua52 | AstDialectVersion::Lua53 => AstDialectCaps {
goto_label: true,
continue_stmt: false,
local_const: false,
local_close: false,
global_decl: false,
global_const: false,
},
AstDialectVersion::Lua54 => AstDialectCaps {
goto_label: true,
continue_stmt: false,
local_const: true,
local_close: true,
global_decl: false,
global_const: false,
},
AstDialectVersion::Lua55 => AstDialectCaps {
goto_label: true,
continue_stmt: false,
local_const: true,
local_close: true,
global_decl: true,
global_const: true,
},
AstDialectVersion::LuaJit => AstDialectCaps {
goto_label: true,
continue_stmt: false,
local_const: false,
local_close: false,
global_decl: false,
global_const: false,
},
AstDialectVersion::Luau => AstDialectCaps {
goto_label: false,
continue_stmt: true,
local_const: false,
local_close: false,
global_decl: false,
global_const: false,
},
};
Self { version, caps }
}
pub const fn relaxed_for_lowering(version: AstDialectVersion) -> Self {
let mut caps = Self::new(version).caps;
caps.goto_label = true;
caps.local_const = true;
caps.local_close = true;
caps.global_decl = true;
caps.global_const = true;
Self { version, caps }
}
pub const fn supports_feature(self, feature: AstFeature) -> bool {
self.caps.supports(feature)
}
}
impl AstDialectCaps {
    /// Reports whether these capabilities allow `feature`.
    ///
    /// `GlobalConst` additionally requires `global_decl`. A `global<const>` declaration is
    /// still a `global` declaration, so it is never reported as available without one.
    pub const fn supports(self, feature: AstFeature) -> bool {
        let Self {
            goto_label,
            continue_stmt,
            local_const,
            local_close,
            global_decl,
            global_const,
        } = self;
        match feature {
            AstFeature::GotoLabel => goto_label,
            AstFeature::ContinueStmt => continue_stmt,
            AstFeature::LocalConst => local_const,
            AstFeature::LocalClose => local_close,
            AstFeature::GlobalDecl => global_decl,
            AstFeature::GlobalConst => global_decl && global_const,
        }
    }
}
/// `local` declaration: `local b1 <attr>, b2 = v1, v2`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstLocalDecl {
/// Declared bindings, each with its own attribute.
pub bindings: Vec<AstLocalBinding>,
/// Initialiser expressions; presumably empty for a bare `local x` — confirm.
pub values: Vec<AstExpr>,
}
/// `global` declaration (only some dialects; see `AstFeature::GlobalDecl`).
#[derive(Debug, Clone, PartialEq)]
pub struct AstGlobalDecl {
/// Declared global names (or a wildcard), each with its own attribute.
pub bindings: Vec<AstGlobalBinding>,
/// Initialiser expressions.
pub values: Vec<AstExpr>,
}
/// One name introduced by an `AstLocalDecl`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstLocalBinding {
/// The binding being declared.
pub id: AstBindingRef,
/// Attribute on this binding (`<const>`, `<close>`, or none).
pub attr: AstLocalAttr,
/// Provenance of the binding; see `AstLocalOrigin`.
pub origin: AstLocalOrigin,
}
/// One entry of an `AstGlobalDecl`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstGlobalBinding {
/// Named global or wildcard being declared.
pub target: AstGlobalBindingTarget,
/// Attribute on this entry (`<const>` or none).
pub attr: AstGlobalAttr,
}
/// Target of a `global` declaration entry.
#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum AstGlobalBindingTarget {
/// A specific global name.
Name(AstGlobalName),
/// Wildcard entry; presumably Lua 5.5's `global *` form — confirm.
Wildcard,
}
/// Attribute on a `local` binding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AstLocalAttr {
/// No attribute.
None,
/// `<const>`.
Const,
/// `<close>`.
Close,
}
/// Provenance of a local binding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AstLocalOrigin {
/// Presumably inferred by analysis without name information — confirm.
Recovered,
/// Presumably informed by debug information (e.g. local names) — confirm.
DebugHinted,
}
/// Attribute on a `global` declaration entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AstGlobalAttr {
/// No attribute.
None,
/// `<const>`.
Const,
}
/// Field access `base.field`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstFieldAccess {
/// Expression being indexed.
pub base: AstExpr,
/// Field name written after the `.`.
pub field: String,
}
/// Index access `base[index]`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstIndexAccess {
/// Expression being indexed.
pub base: AstExpr,
/// Key expression inside the brackets.
pub index: AstExpr,
}
/// Unary operation `op expr`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstUnaryExpr {
/// The operator.
pub op: AstUnaryOpKind,
/// The operand.
pub expr: AstExpr,
}
/// Binary operation `lhs op rhs`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstBinaryExpr {
/// The operator.
pub op: AstBinaryOpKind,
/// Left operand.
pub lhs: AstExpr,
/// Right operand.
pub rhs: AstExpr,
}
/// Operands of a short-circuit `and`/`or`. Which operator it is depends on the enclosing
/// `AstExpr::LogicalAnd` / `AstExpr::LogicalOr` variant.
#[derive(Debug, Clone, PartialEq)]
pub struct AstLogicalExpr {
/// Left operand.
pub lhs: AstExpr,
/// Right operand.
pub rhs: AstExpr,
}
/// Plain call `callee(args)`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstCallExpr {
/// Expression being called.
pub callee: AstExpr,
/// Arguments, in order.
pub args: Vec<AstExpr>,
}
/// Method call `receiver:method(args)`; the receiver is passed as the implicit `self`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstMethodCallExpr {
/// Object the method is looked up on and passed to.
pub receiver: AstExpr,
/// Method name written after `:`.
pub method: String,
/// Explicit arguments, in order (not including the receiver).
pub args: Vec<AstExpr>,
}
/// A call used as a statement.
#[derive(Debug, Clone, PartialEq)]
pub struct AstCallStmt {
/// The call being made.
pub call: AstCallKind,
}
/// The call in an `AstCallStmt`: either a plain or a method call.
#[derive(Debug, Clone, PartialEq)]
pub enum AstCallKind {
/// `callee(args)`.
Call(Box<AstCallExpr>),
/// `receiver:method(args)`.
MethodCall(Box<AstMethodCallExpr>),
}
/// Table constructor `{ f1, f2, ... }`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstTableConstructor {
/// Entries, in order.
pub fields: Vec<AstTableField>,
}
/// One entry of a table constructor.
#[derive(Debug, Clone, PartialEq)]
pub enum AstTableField {
/// Positional entry `expr`.
Array(AstExpr),
/// Keyed entry `key = value`.
Record(AstRecordField),
}
/// Keyed table-constructor entry `key = value`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstRecordField {
/// The key.
pub key: AstTableKey,
/// The value stored under the key.
pub value: AstExpr,
}
/// Key of a keyed table-constructor entry.
#[derive(Debug, Clone, PartialEq)]
pub enum AstTableKey {
/// Identifier key, `name = v`.
Name(String),
/// Computed key, `[expr] = v`.
Expr(AstExpr),
}
/// `if cond then ... [else ...] end`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstIf {
/// Branch condition.
pub cond: AstExpr,
/// Block run when `cond` is truthy.
pub then_block: AstBlock,
/// Optional `else` block; `elseif` chains presumably appear as a nested `If` inside
/// it — confirm.
pub else_block: Option<AstBlock>,
}
/// `while cond do body end`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstWhile {
/// Loop condition, checked before each iteration.
pub cond: AstExpr,
/// Loop body.
pub body: AstBlock,
}
/// `repeat body until cond`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstRepeat {
/// Loop body.
pub body: AstBlock,
/// Exit condition, checked after each iteration.
pub cond: AstExpr,
}
/// Numeric `for binding = start, limit, step do body end`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstNumericFor {
/// Loop control variable.
pub binding: AstBindingRef,
/// Initial value.
pub start: AstExpr,
/// Limit value.
pub limit: AstExpr,
/// Step expression; always present in this representation (not optional).
pub step: AstExpr,
/// Loop body.
pub body: AstBlock,
}
/// Generic `for b1, b2 in iterator do body end`.
#[derive(Debug, Clone, PartialEq)]
pub struct AstGenericFor {
/// Loop variables, in order.
pub bindings: Vec<AstBindingRef>,
/// Expression list after `in`.
pub iterator: Vec<AstExpr>,
/// Loop body.
pub body: AstBlock,
}
/// `goto` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct AstGoto {
/// Label jumped to.
pub target: AstLabelId,
}
/// `::label::` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct AstLabel {
/// Identifier of this label.
pub id: AstLabelId,
}
/// Identifier of a `goto` target label in the AST.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub struct AstLabelId(pub usize);
impl AstLabelId {
    /// Returns the raw numeric index of this label.
    pub const fn index(self) -> usize {
        let Self(raw) = self;
        raw
    }
}
/// Carries a HIR label over to the AST by reusing its numeric index unchanged.
impl From<HirLabelId> for AstLabelId {
    fn from(label: HirLabelId) -> Self {
        AstLabelId(label.index())
    }
}
/// Unary operators. Spellings noted below are the Lua 5.3+ forms; how they are rendered
/// for other dialects is not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AstUnaryOpKind {
/// `not`.
Not,
/// Arithmetic negation `-`.
Neg,
/// Bitwise not `~`.
BitNot,
/// Length `#`.
Length,
}
/// Binary operators other than `and`/`or`. Spellings noted below are the Lua 5.3+ forms.
///
/// There are no `~=`, `>` or `>=` variants; presumably those are expressed via `Not` or
/// by swapping operands — confirm against producers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AstBinaryOpKind {
/// `+`.
Add,
/// `-`.
Sub,
/// `*`.
Mul,
/// `/`.
Div,
/// `//`.
FloorDiv,
/// `%`.
Mod,
/// `^`.
Pow,
/// `&`.
BitAnd,
/// `|`.
BitOr,
/// Binary `~`.
BitXor,
/// `<<`.
Shl,
/// `>>`.
Shr,
/// `..`.
Concat,
/// `==`.
Eq,
/// `<`.
Lt,
/// `<=`.
Le,
}