#![allow(dead_code)]
use crate::compiler::error::Span;
use std::collections::HashMap;
pub use crate::compiler::mangler::MANGLER;
/// Builds a mangled name from `base` and the argument strings in `args`.
///
/// Thin wrapper over `MANGLER.wrap_args`; the exact mangling format is defined in
/// `crate::compiler::mangler`, not here.
pub fn mangle_wrap_args(base: &str, args: &[String]) -> String {
    MANGLER.wrap_args(base, args)
}
/// Returns the base-name portion of a mangled name.
///
/// Thin wrapper over `MANGLER.base_name`; how the base name is delimited is defined
/// in `crate::compiler::mangler`.
pub fn mangle_base_name(mangled: &str) -> &str {
    MANGLER.base_name(mangled)
}
/// Reports whether a mangled name carries arguments.
///
/// Thin wrapper over `MANGLER.has_args`.
pub fn mangle_has_args(mangled: &str) -> bool {
    MANGLER.has_args(mangled)
}
/// Extracts the argument substrings embedded in a mangled name.
///
/// Thin wrapper over `MANGLER.extract_args`; the returned slices borrow from the
/// input as permitted by that method's signature.
pub fn mangle_extract_args(mangled: &str) -> Vec<&str> {
    MANGLER.extract_args(mangled)
}
/// Parses one type string (as used in mangled names) into a [`Type`].
///
/// Thin wrapper over `MANGLER.parse_type_str`. NOTE(review): behavior on
/// unrecognized input is defined by the mangler — confirm there.
pub fn parse_mangled_type_str(s: &str) -> Type {
    MANGLER.parse_type_str(s)
}
/// Parses a slice of type strings into [`Type`]s.
///
/// Thin wrapper over `MANGLER.parse_type_strs`.
pub fn parse_mangled_type_strs(strs: &[&str]) -> Vec<Type> {
    MANGLER.parse_type_strs(strs)
}
/// A value paired with the source [`Span`] it was produced from.
///
/// Equality is defined manually (see the `PartialEq` impl below) and compares only
/// `inner`; the span does not participate.
#[derive(Debug, Clone)]
pub struct Spanned<T> {
    /// The wrapped value.
    pub inner: T,
    /// Source location of `inner`.
    pub span: Span,
}
/// Two `Spanned` values are equal when their `inner` values are equal; the `span`
/// fields are deliberately not compared, so identical nodes from different source
/// locations compare equal.
impl<T: PartialEq> PartialEq for Spanned<T> {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.inner, &other.inner)
    }
}
/// `Spanned<T>` is a full equivalence whenever `T` is, since equality only inspects `inner`.
impl<T> Eq for Spanned<T> where T: Eq {}
impl<T> Spanned<T> {
    /// Wraps `inner` together with the given source `span`.
    pub fn new(inner: T, span: Span) -> Self {
        Self { inner, span }
    }

    /// Wraps `inner` with `Span::default()`, for when no real source location is available.
    pub fn dummy(inner: T) -> Self {
        Self::new(inner, Span::default())
    }
}
impl<T> std::ops::Deref for Spanned<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T: Default> Default for Spanned<T> {
fn default() -> Self {
Spanned::dummy(T::default())
}
}
/// One dimension of a [`Type::TensorShaped`] shape.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Dim {
    /// A dimension given by a concrete `usize` value.
    Constant(usize),
    /// A dimension identified by a numeric id.
    /// NOTE(review): presumably an inference variable resolved later — confirm.
    Var(u32),
    /// A dimension identified by a symbolic name.
    Symbolic(String),
}

/// A type as represented in the AST.
///
/// Named user-defined types appear either as plain [`Type::Struct`] / [`Type::Enum`]
/// or as [`Type::UnifiedType`], which additionally records a mangled name and
/// whether the type is an enum. The accessors in `impl Type` (`as_struct_like`,
/// `as_enum_like`, `as_named_type`, `is_struct_type`, `is_enum_type`, ...) treat both
/// representations uniformly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Type {
    F32,
    F64,
    F16,
    BF16,
    Bool,
    /// NOTE(review): the meaning of the `String` payload is not visible in this file.
    String(String),
    /// NOTE(review): carries a `String` payload (not a `char`); meaning not visible here.
    Char(String),
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    Usize,
    Entity,
    /// Tensor with an element type and a `usize` (presumably the rank — confirm).
    Tensor(Box<Type>, usize),
    /// Tensor with an element type and an explicit per-dimension shape.
    TensorShaped(Box<Type>, Vec<Dim>),
    /// Pointer wrapper around an inner type.
    Ptr(Box<Type>),
    /// Type variable identified by a numeric id.
    TypeVar(u32),
    /// Named struct type with its generic type arguments.
    Struct(String, Vec<Type>),
    /// Named enum type with its generic type arguments.
    Enum(String, Vec<Type>),
    /// Tuple of element types.
    Tuple(Vec<Type>),
    /// Path of name segments plus type arguments; its base name is the last segment.
    Path(Vec<String>, Vec<Type>),
    /// Array with an element type and a `usize` (presumably the length — confirm).
    Array(Box<Type>, usize),
    /// Named struct or enum type that also carries its mangled name.
    UnifiedType {
        /// Unmangled type name (returned by `get_base_name`).
        base_name: String,
        /// Generic type arguments.
        type_args: Vec<Type>,
        /// Mangled name (returned by `mangled_name_or_name` / `codegen_name`).
        mangled_name: String,
        /// `true` for an enum, `false` for a struct.
        is_enum: bool,
    },
    Void,
    Never,
    /// NOTE(review): the meaning of the `u64` payload is not visible in this file.
    Undefined(u64),
}

impl Type {
    /// Returns the base name of this type: its name without mangling or type arguments.
    ///
    /// - Scalar and built-in variants return their variant name (e.g. `"I32"`,
    ///   `"Void"`); `String(_)` / `Char(_)` return `"String"` / `"Char"` regardless
    ///   of payload.
    /// - Both `Tensor` and `TensorShaped` return `"Tensor"`.
    /// - `Struct` / `Enum` return their name; `UnifiedType` returns `base_name`.
    /// - `Path` returns its last segment, or `""` for an empty path.
    /// - `Array` returns `"Array_"` followed by the element type's base name.
    pub fn get_base_name(&self) -> String {
        match self {
            Type::F32 => "F32".to_string(),
            Type::F64 => "F64".to_string(),
            Type::F16 => "F16".to_string(),
            Type::BF16 => "BF16".to_string(),
            Type::Bool => "Bool".to_string(),
            Type::String(_) => "String".to_string(),
            Type::Char(_) => "Char".to_string(),
            Type::I8 => "I8".to_string(),
            Type::I16 => "I16".to_string(),
            Type::I32 => "I32".to_string(),
            Type::I64 => "I64".to_string(),
            Type::U8 => "U8".to_string(),
            Type::U16 => "U16".to_string(),
            Type::U32 => "U32".to_string(),
            Type::U64 => "U64".to_string(),
            Type::Usize => "Usize".to_string(),
            Type::Entity => "Entity".to_string(),
            Type::Tensor(_, _) => "Tensor".to_string(),
            Type::TensorShaped(_, _) => "Tensor".to_string(),
            Type::Struct(n, _) => n.clone(),
            Type::Enum(n, _) => n.clone(),
            Type::UnifiedType { base_name, .. } => base_name.clone(),
            Type::Path(p, _) => p.last().cloned().unwrap_or_default(),
            Type::Tuple(_) => "Tuple".to_string(),
            Type::Void => "Void".to_string(),
            Type::Never => "Never".to_string(),
            Type::Undefined(_) => "Undefined".to_string(),
            Type::TypeVar(_) => "TypeVar".to_string(),
            Type::Ptr(_inner) => "Ptr".to_string(),
            Type::Array(inner, _) => format!("Array_{}", inner.get_base_name()),
        }
    }

    /// Returns `(name, type_args)` if this is a struct type: a `Struct`, or a
    /// `UnifiedType` with `is_enum == false`. Returns `None` otherwise.
    pub fn as_struct_like(&self) -> Option<(&str, &[Type])> {
        match self {
            Type::Struct(name, args) => Some((name, args)),
            Type::UnifiedType { base_name, type_args, is_enum: false, .. } => Some((base_name, type_args)),
            _ => None,
        }
    }

    /// Returns `(name, type_args)` if this is an enum type: an `Enum`, or a
    /// `UnifiedType` with `is_enum == true`. Returns `None` otherwise.
    pub fn as_enum_like(&self) -> Option<(&str, &[Type])> {
        match self {
            Type::Enum(name, args) => Some((name, args)),
            Type::UnifiedType { base_name, type_args, is_enum: true, .. } => Some((base_name, type_args)),
            _ => None,
        }
    }

    /// Returns `(name, type_args)` for any named struct or enum type (`Struct`,
    /// `Enum`, or `UnifiedType`), using the unmangled name. Returns `None` otherwise.
    pub fn as_named_type(&self) -> Option<(&str, &[Type])> {
        match self {
            Type::Struct(name, args) | Type::Enum(name, args) => Some((name, args)),
            Type::UnifiedType { base_name, type_args, .. } => Some((base_name, type_args)),
            _ => None,
        }
    }

    /// Borrowed name identifying this type after mangling.
    ///
    /// `UnifiedType` yields its `mangled_name`; `Struct` / `Enum` yield their plain
    /// name (type arguments are not folded in); every other variant yields `None`.
    pub fn mangled_name_or_name(&self) -> Option<&str> {
        match self {
            Type::Struct(name, _) | Type::Enum(name, _) => Some(name),
            Type::UnifiedType { mangled_name, .. } => Some(mangled_name),
            _ => None,
        }
    }

    /// Owned copy of [`Type::mangled_name_or_name`].
    ///
    /// NOTE(review): a `Struct` with non-empty type arguments yields its bare name,
    /// not a mangled one — confirm generic structs never need a mangled name here
    /// (e.g. via `mangle_wrap_args`).
    pub fn codegen_name(&self) -> Option<String> {
        // Delegate so the borrowed and owned accessors can never disagree.
        self.mangled_name_or_name().map(str::to_owned)
    }

    /// `true` for `Enum` and for `UnifiedType` with `is_enum == true`.
    pub fn is_enum_type(&self) -> bool {
        matches!(self, Type::Enum(_, _) | Type::UnifiedType { is_enum: true, .. })
    }

    /// `true` for `Struct` and for `UnifiedType` with `is_enum == false`.
    pub fn is_struct_type(&self) -> bool {
        matches!(self, Type::Struct(_, _) | Type::UnifiedType { is_enum: false, .. })
    }
}
/// A function definition: signature plus body.
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionDef {
    /// Function name.
    pub name: String,
    /// Parameters as `(name, type)` pairs.
    pub args: Vec<(String, Type)>,
    /// Declared return type.
    pub return_type: Type,
    /// Statements of the function body.
    /// NOTE(review): presumably empty when `is_extern` is set — confirm.
    pub body: Vec<Stmt>,
    /// Names of the function's generic type parameters.
    pub generics: Vec<String>,
    /// Whether this is an `extern` function.
    pub is_extern: bool,
}
/// A struct definition.
#[derive(Debug, Clone, PartialEq)]
pub struct StructDef {
    /// Struct name.
    pub name: String,
    /// Fields as `(name, type)` pairs.
    pub fields: Vec<(String, Type)>,
    /// Names of the struct's generic type parameters.
    pub generics: Vec<String>,
}
/// An `impl` block: a set of methods attached to a target type.
#[derive(Debug, Clone, PartialEq)]
pub struct ImplBlock {
    /// The type the methods are implemented for.
    pub target_type: Type,
    /// Names of the block's generic type parameters.
    pub generics: Vec<String>,
    /// Method definitions in this block.
    pub methods: Vec<FunctionDef>,
}
/// One variant of an [`EnumDef`].
#[derive(Debug, Clone, PartialEq)]
pub struct VariantDef {
    /// Variant name.
    pub name: String,
    /// Shape of the variant's payload.
    pub kind: VariantKind,
}
/// Payload shape of an enum variant.
#[derive(Debug, Clone, PartialEq)]
pub enum VariantKind {
    /// No payload.
    Unit,
    /// Positional payload with the given field types.
    Tuple(Vec<Type>),
    /// Named payload fields as `(name, type)` pairs.
    Struct(Vec<(String, Type)>),
    /// Element type plus a `usize` count.
    /// NOTE(review): presumably a fixed-length array payload — confirm.
    Array(Type, usize),
}
/// An enum definition.
#[derive(Debug, Clone, PartialEq)]
pub struct EnumDef {
    /// Enum name.
    pub name: String,
    /// The enum's variants.
    pub variants: Vec<VariantDef>,
    /// Names of the enum's generic type parameters.
    pub generics: Vec<String>,
}
/// A statement: a [`StmtKind`] paired with its source span.
pub type Stmt = Spanned<StmtKind>;
/// An assignment target (the `lhs` of [`StmtKind::Assign`]).
#[derive(Debug, Clone, PartialEq)]
pub enum LValue {
    /// A named variable.
    Variable(String),
    /// A named field of another l-value.
    FieldAccess(Box<LValue>, String),
    /// Another l-value indexed by one or more index expressions.
    IndexAccess(Box<LValue>, Vec<Expr>),
}
/// Statement node kinds; see the [`Stmt`] alias for the spanned form.
#[derive(Debug, Clone, PartialEq)]
pub enum StmtKind {
    /// Tensor declaration with a required type annotation and an optional initializer.
    TensorDecl {
        name: String,
        type_annotation: Type,
        init: Option<Expr>,
    },
    /// `let` binding with an optional type annotation and a mutability flag.
    Let {
        name: String,
        type_annotation: Option<Type>,
        value: Expr,
        mutable: bool,
    },
    /// Assignment (plain or compound, per `op`) to an l-value.
    Assign {
        lhs: LValue,
        op: AssignOp,
        value: Expr,
    },
    /// Expression used as a statement.
    Expr(Expr),
    /// `return`, with an optional value.
    Return(Option<Expr>),
    /// `for` loop binding `loop_var` to each item produced by `iterator`.
    For {
        loop_var: String,
        iterator: Expr,
        body: Vec<Stmt>,
    },
    /// `while` loop with a condition.
    While {
        cond: Expr,
        body: Vec<Stmt>,
    },
    /// `use` declaration: a path, an optional alias, and imported item names.
    Use {
        path: Vec<String>,
        alias: Option<String>,
        items: Vec<String>,
    },
    Break,
    Continue,
    /// `loop` with a body and no condition.
    Loop {
        body: Vec<Stmt>,
    },
}
/// Operator of an assignment statement (the `op` of [`StmtKind::Assign`]).
///
/// Fieldless, so it also derives `Eq` and `Hash`, which allows use as a hash-map or
/// hash-set key; equality semantics are unchanged.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AssignOp {
    /// Plain assignment.
    Assign,
    AddAssign,
    SubAssign,
    MulAssign,
    DivAssign,
    ModAssign,
    /// NOTE(review): semantics are defined by the consumer, not visible here.
    MaxAssign,
    /// NOTE(review): semantics are defined by the consumer, not visible here.
    AvgAssign,
}
/// A clause of an [`ExprKind::TensorComprehension`].
#[derive(Debug, Clone, PartialEq)]
pub enum ComprehensionClause {
    /// Binds `name` over the values of the `range` expression.
    Generator { name: String, range: Expr },
    /// Condition expression (NOTE(review): presumably filters generated values — confirm).
    Condition(Expr),
}
/// An expression: an [`ExprKind`] paired with its source span.
pub type Expr = Spanned<ExprKind>;
/// Expression node kinds; see the [`Expr`] alias for the spanned form.
#[derive(Debug, Clone, PartialEq)]
pub enum ExprKind {
    /// Floating-point literal value.
    Float(f64),
    /// Integer literal value.
    Int(i64),
    /// Boolean literal value.
    Bool(bool),
    StringLiteral(String),
    CharLiteral(char),
    /// Tuple of element expressions.
    Tuple(Vec<Expr>),
    /// Range between two bound expressions.
    /// NOTE(review): inclusivity of the end bound is decided by the consumer.
    Range(Box<Expr>, Box<Expr>),
    /// Tensor comprehension over named indices.
    TensorComprehension {
        /// Index variable names.
        indices: Vec<String>,
        /// Generator / condition clauses.
        clauses: Vec<ComprehensionClause>,
        /// Optional body expression.
        body: Option<Box<Expr>>,
    },
    /// Tensor literal built from element expressions.
    TensorLiteral(Vec<Expr>),
    /// NOTE(review): distinct from `TensorLiteral`; the difference is not visible here.
    TensorConstLiteral(Vec<Expr>),
    /// Named symbol.
    Symbol(String),
    /// Named logic variable.
    LogicVar(String),
    Wildcard,
    /// Reference to a variable by name.
    Variable(String),
    /// Base expression indexed by one or more index expressions.
    IndexAccess(Box<Expr>, Vec<Expr>),
    /// Positional tuple element access.
    TupleAccess(Box<Expr>, usize),
    /// Named field access.
    FieldAccess(Box<Expr>, String),
    /// Binary operation: left operand, operator, right operand.
    BinOp(Box<Expr>, BinOp, Box<Expr>),
    /// Unary operation applied to an operand.
    UnOp(UnOp, Box<Expr>),
    /// NOTE(review): presumably error propagation (`expr?`) — confirm.
    Try(Box<Expr>),
    /// Call of a function by name with argument expressions.
    FnCall(String, Vec<Expr>),
    /// Method call: receiver, method name, arguments.
    MethodCall(Box<Expr>, String, Vec<Expr>),
    /// Type-qualified call: type, function name, arguments.
    StaticMethodCall(Type, String, Vec<Expr>),
    /// Cast of an expression to a type.
    As(Box<Expr>, Type),
    /// `if` with a condition, a then-block, and an optional else-block.
    IfExpr(Box<Expr>, Vec<Stmt>, Option<Vec<Stmt>>),
    /// `if let`: matches `expr` against `pattern`.
    IfLet {
        pattern: Pattern,
        expr: Box<Expr>,
        then_block: Vec<Stmt>,
        else_block: Option<Vec<Stmt>>,
    },
    /// Block of statements used as an expression.
    Block(Vec<Stmt>),
    /// Struct construction: type plus `(field name, value)` pairs.
    StructInit(Type, Vec<(String, Expr)>),
    /// Enum variant construction.
    EnumInit {
        enum_name: String,
        variant_name: String,
        /// Explicit generic type arguments.
        generics: Vec<Type>,
        payload: EnumVariantInit,
    },
    /// `match` over `expr` with `(pattern, arm body)` pairs.
    Match {
        expr: Box<Expr>,
        arms: Vec<(Pattern, Expr)>,
    },
}
/// A pattern used in `match` arms and `if let`.
#[derive(Debug, Clone, PartialEq)]
pub enum Pattern {
    /// Matches a specific enum variant and binds its payload per `bindings`.
    EnumPattern {
        enum_name: String,
        variant_name: String,
        bindings: EnumPatternBindings,
    },
    Wildcard,
    /// Matches against a literal expression.
    Literal(Box<Expr>),
}
/// Binary operator of an [`ExprKind::BinOp`] expression.
///
/// Fieldless, so it also derives `Eq` and `Hash`, which allows use as a hash-map or
/// hash-set key (e.g. operator tables); equality semantics are unchanged.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    Eq,
    Neq,
    Lt,
    Gt,
    Le,
    Ge,
    And,
    Or,
}
/// Unary operator of an [`ExprKind::UnOp`] expression.
///
/// Fieldless, so it also derives `Eq` and `Hash`, which allows use as a hash-map or
/// hash-set key; equality semantics are unchanged.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UnOp {
    Neg,
    Not,
    /// NOTE(review): semantics are not visible in this file.
    Query,
    /// NOTE(review): semantics are not visible in this file (presumably takes a reference).
    Ref,
}
/// A logic atom: a predicate name applied to argument expressions.
#[derive(Debug, Clone, PartialEq)]
pub struct Atom {
    pub predicate: String,
    pub args: Vec<Expr>,
}
/// A relation declaration: a name and its typed arguments as `(name, type)` pairs.
#[derive(Debug, Clone, PartialEq)]
pub struct RelationDecl {
    pub name: String,
    pub args: Vec<(String, Type)>,
}
/// A logic rule: a head atom derived from a body of literals.
#[derive(Debug, Clone, PartialEq)]
pub struct Rule {
    pub head: Atom,
    /// Body literals (positive or negated atoms).
    pub body: Vec<LogicLiteral>,
    /// Optional weight. NOTE(review): how the weight is interpreted is not visible here.
    pub weight: Option<f64>,
}
/// A literal in a rule body: an atom, either positive or negated.
#[derive(Debug, Clone, PartialEq)]
pub enum LogicLiteral {
    Pos(Atom),
    Neg(Atom),
}
/// A module: its top-level items grouped by kind, plus nested submodules.
#[derive(Debug, Clone, PartialEq)]
pub struct Module {
    pub structs: Vec<StructDef>,
    pub enums: Vec<EnumDef>,
    pub impls: Vec<ImplBlock>,
    pub functions: Vec<FunctionDef>,
    /// Top-level tensor declarations stored as statements
    /// (NOTE(review): presumably `StmtKind::TensorDecl` — confirm).
    pub tensor_decls: Vec<Stmt>,
    pub relations: Vec<RelationDecl>,
    pub rules: Vec<Rule>,
    pub queries: Vec<Expr>,
    /// NOTE(review): the format of these import strings is not visible here.
    pub imports: Vec<String>,
    /// Nested modules, keyed by (presumably) module name.
    pub submodules: HashMap<String, Module>,
}
/// Payload supplied when constructing an enum variant (see [`ExprKind::EnumInit`]).
#[derive(Debug, Clone, PartialEq)]
pub enum EnumVariantInit {
    /// No payload.
    Unit,
    /// Positional payload expressions.
    Tuple(Vec<Expr>),
    /// Named payload as `(field name, value)` pairs.
    Struct(Vec<(String, Expr)>),
}
/// How an [`Pattern::EnumPattern`] binds the matched variant's payload.
#[derive(Debug, Clone, PartialEq)]
pub enum EnumPatternBindings {
    /// No payload bindings.
    Unit,
    /// Binding names for positional payload fields.
    Tuple(Vec<String>),
    /// Pairs of names for named payload fields.
    /// NOTE(review): presumably `(field name, binding name)` — confirm the order.
    Struct(Vec<(String, String)>),
}