mod declaration;
mod expression;
mod statement;
mod types;
use std::collections::{HashMap, HashSet};
use ty_ree::{
declaration::{
Callable, Constant, Declaration, ExternFunction, ImplBlock, Module, Newtype, Parameter,
ParameterList,
},
expression::operator::{
ArithmeticOperator, BinaryOperator, BitwiseOperator, ComparisonOperator, LogicalOperator,
UnaryOperator,
},
expression::{
Binding, Block, Expression, ExpressionKind, Literal, MatchArm, MatchPattern, Statement,
},
types::{
ConstExpression, FloatWidth, FunctionSignature, GenericArgument, GenericParameter,
GenericParameterKind, IntWidth, Lifetime, Mutability, Signedness, Type,
},
};
use crate::lexer::SpannedToken;
use crate::token::Token;
use crate::{Error, Span};
/// Parser-local shorthand for [`crate::Result`].
type Result<T> = crate::Result<T>;
/// Name and type information the parser accumulates while walking the token stream.
struct TypeContext {
/// Return type recorded for each function name; read by `lookup_function`.
return_types: HashMap<String, Type>,
/// Parameter types recorded per function name (not read by the methods in this file section).
function_params: HashMap<String, Vec<Type>>,
/// Underlying type for each newtype name; read by `resolve_named` and `lookup_field`.
newtypes: HashMap<String, Type>,
/// Types of named constants; `lookup_variable` falls back to this after searching every scope.
constants: HashMap<String, Type>,
/// Names of generic functions (not read by the methods in this file section).
generic_functions: HashSet<String>,
/// Generic parameter kinds per item name — presumably in declaration order; confirm at the writer.
generic_parameter_kinds: HashMap<String, Vec<GenericParameterKind>>,
/// Generic parameter names currently in scope; queried by `Parser::is_generic_param`.
active_generic_params: HashSet<String>,
/// Stack of lexical variable scopes, innermost last.
scopes: Vec<HashMap<String, Type>>,
/// Auto-referenced variable names, one set per entry in `scopes`;
/// `push_scope`/`pop_scope` keep the two stacks in step.
auto_ref_variables: Vec<HashSet<String>>,
}
impl TypeContext {
    /// Creates an empty context with no scopes open.
    fn new() -> Self {
        Self {
            return_types: HashMap::new(),
            function_params: HashMap::new(),
            newtypes: HashMap::new(),
            constants: HashMap::new(),
            generic_functions: HashSet::new(),
            generic_parameter_kinds: HashMap::new(),
            active_generic_params: HashSet::new(),
            scopes: Vec::new(),
            auto_ref_variables: Vec::new(),
        }
    }
    /// Opens a new innermost variable scope together with its auto-ref set.
    fn push_scope(&mut self) {
        self.scopes.push(HashMap::new());
        self.auto_ref_variables.push(HashSet::new());
    }
    /// Closes the innermost variable scope together with its auto-ref set.
    fn pop_scope(&mut self) {
        self.scopes.pop();
        self.auto_ref_variables.pop();
    }
    /// Marks `name` as auto-referenced in the innermost scope.
    ///
    /// Does nothing when no scope is open.
    fn mark_auto_ref(&mut self, name: &str) {
        if let Some(scope) = self.auto_ref_variables.last_mut() {
            scope.insert(name.into());
        }
    }
    /// Returns whether `name` is marked auto-referenced in any open scope.
    ///
    /// NOTE(review): this checks every scope, so a mark made in an outer scope
    /// still applies when an inner scope rebinds the same name — confirm intended.
    fn is_auto_ref(&self, name: &str) -> bool {
        self.auto_ref_variables
            .iter()
            .rev()
            .any(|scope| scope.contains(name))
    }
    /// Binds `name` to `variable_type` in the innermost scope.
    ///
    /// The binding is silently dropped when no scope is open.
    fn define_variable(&mut self, name: String, variable_type: Type) {
        if let Some(scope) = self.scopes.last_mut() {
            scope.insert(name, variable_type);
        }
    }
    /// Looks up a variable's type, searching the innermost scope first and
    /// falling back to named constants.
    fn lookup_variable(&self, name: &str) -> Option<&Type> {
        self.scopes
            .iter()
            .rev()
            .find_map(|scope| scope.get(name))
            .or_else(|| self.constants.get(name))
    }
    /// Returns the recorded return type of function `name`, if any.
    fn lookup_function(&self, name: &str) -> Option<&Type> {
        self.return_types.get(name)
    }
    /// Finds the member of newtype `type_name`'s tuple representation that is
    /// `Type::Named(field_name, _)`, returning a clone of that member.
    ///
    /// Returns `None` when `type_name` is not a known newtype, when it is not
    /// backed by a tuple, or when no member carries that name.
    fn lookup_field(&self, type_name: &str, field_name: &str) -> Option<Type> {
        let Type::Tuple(fields) = self.newtypes.get(type_name)? else {
            return None;
        };
        fields
            .iter()
            .find(|field| matches!(field, Type::Named(name, _) if name == field_name))
            .cloned()
    }
    /// Follows newtype aliases until reaching a type that is not a known
    /// newtype name, and returns a clone of that type.
    ///
    /// Non-`Named` types and unknown names are returned unchanged.
    fn resolve_named(&self, resolved_type: &Type) -> Type {
        let mut current = resolved_type;
        // An acyclic alias chain passes through each newtype at most once, so
        // `newtypes.len()` hops always suffice. The bound only takes effect on
        // cycles such as `type A = B; type B = A;`, which previously recursed
        // until the stack overflowed.
        for _ in 0..self.newtypes.len() {
            if let Type::Named(name, _) = current
                && let Some(inner) = self.newtypes.get(name)
            {
                current = inner;
            } else {
                break;
            }
        }
        current.clone()
    }
}
/// Derives a synthetic name for an anonymous tuple or enum type.
///
/// Tuples are prefixed `_t_` and enums `_e_`; the member labels follow,
/// joined with `_`. A named member contributes its name, a nested tuple/enum
/// contributes its own synthetic name (recursively), and any other type
/// contributes its `Debug` rendering.
///
/// NOTE(review): `Debug` renderings may contain characters such as `(`, `,` or
/// spaces that are not valid in identifiers — confirm downstream users cope.
///
/// # Panics
///
/// Panics via `unreachable!` if `compound` is neither `Type::Tuple` nor `Type::Enum`.
fn anonymous_compound_name(compound: &Type) -> String {
    /// Renders a single member type as a name fragment.
    fn member_label(member: &Type) -> String {
        if let Type::Named(name, _) = member {
            return name.clone();
        }
        if matches!(member, Type::Tuple(_) | Type::Enum(_)) {
            return anonymous_compound_name(member);
        }
        format!("{member:?}")
    }
    let labelled = |prefix: &str, labels: Vec<String>| format!("{prefix}{}", labels.join("_"));
    match compound {
        Type::Tuple(fields) => labelled("_t_", fields.iter().map(member_label).collect()),
        Type::Enum(variants) => labelled("_e_", variants.iter().map(member_label).collect()),
        _ => unreachable!(),
    }
}
/// Returns true when `expression` is a variable marked auto-ref in `context`,
/// or a chain of field/index accesses rooted at such a variable.
fn is_auto_ref_chain(expression: &Expression, context: &TypeContext) -> bool {
    let mut cursor = expression;
    loop {
        match &cursor.kind {
            // Walk down through projections to the root of the access chain.
            ExpressionKind::Field(base, _) | ExpressionKind::Index(base, _) => cursor = &**base,
            ExpressionKind::Variable(name) => return context.is_auto_ref(name),
            _ => return false,
        }
    }
}
/// Returns the variable name of `x` or `*x`, or `None` for any other shape.
///
/// Only a single level of dereference is looked through: `**x` yields `None`.
fn extract_variable_name(expression: &Expression) -> Option<&str> {
    let candidate = match &expression.kind {
        ExpressionKind::Dereference(pointee) => &pointee.kind,
        other => other,
    };
    match candidate {
        ExpressionKind::Variable(name) => Some(name),
        _ => None,
    }
}
fn auto_deref(expression: Expression, context: &TypeContext) -> Expression {
if !is_auto_ref_chain(&expression, context) {
return expression;
}
let resolved = context.resolve_named(&expression.resolved_type);
if let Type::Pointer(_, _, inner) = resolved {
Expression::new(ExpressionKind::Dereference(Box::new(expression)), *inner)
} else {
expression
}
}
/// Applies a binding mode to `inner`.
///
/// `Binding::Value` returns `inner` unchanged. `Binding::Variable` and
/// `Binding::Reference` wrap it in a mutable or shared pointer, respectively,
/// with an inferred lifetime.
fn wrap_binding_type(binding: Binding, inner: Type) -> Type {
    let mutability = match binding {
        Binding::Value => return inner,
        Binding::Variable => Mutability::Mutable,
        Binding::Reference => Mutability::Shared,
    };
    Type::Pointer(mutability, Lifetime::Inferred, Box::new(inner))
}
/// Type a block evaluates to: its tail expression's type, or unit when it has no tail.
fn block_result_type(block: &Block) -> Type {
    match &block.result {
        Some(tail) => tail.resolved_type.clone(),
        None => Type::unit(),
    }
}
/// Returns true for `if`, `match`, and bare block expressions.
fn is_block_expression(expression: &Expression) -> bool {
    let kind = &expression.kind;
    matches!(
        kind,
        ExpressionKind::Block(_) | ExpressionKind::If { .. } | ExpressionKind::Match { .. }
    )
}
/// Turns a lexed token stream into a [`Module`] of declarations, tracking
/// name and type information as it goes.
pub struct Parser {
/// Token stream being parsed; `peek_at` treats positions past its end as `Token::End`.
tokens: Vec<SpannedToken>,
/// Index of the current token; `advance` never moves it past the last token.
position: usize,
/// Scopes, function signatures, newtypes and constants gathered so far.
context: TypeContext,
/// Declarations synthesized for anonymous compound types; `parse_module`
/// splices them in after the last `Declaration::Type`.
anonymous_types: Vec<Declaration>,
}
impl Parser {
    /// Creates a parser positioned at the first token of `tokens`.
    ///
    /// `parse_module` stops when it reaches [`Token::End`]; the lexer is
    /// presumably expected to terminate the stream with it — confirm.
    pub fn new(tokens: Vec<SpannedToken>) -> Self {
        Self {
            tokens,
            position: 0,
            context: TypeContext::new(),
            anonymous_types: Vec::new(),
        }
    }
    /// Returns the current token without consuming it.
    ///
    /// Delegates to `peek_at(0)`, so an empty token stream reads as
    /// [`Token::End`] instead of panicking on an out-of-bounds index.
    fn peek(&self) -> &Token {
        self.peek_at(0)
    }
    /// Returns the token `offset` positions ahead of the current one, or
    /// [`Token::End`] when that position lies past the end of the stream.
    fn peek_at(&self, offset: usize) -> &Token {
        self.tokens
            .get(self.position + offset)
            .map_or(&Token::End, |token| &token.token)
    }
    /// Returns the source span of the current token.
    ///
    /// # Panics
    ///
    /// Panics if the token stream is empty.
    fn span(&self) -> Span {
        self.tokens[self.position].span
    }
    /// Consumes and returns the current token.
    ///
    /// The cursor never moves past the last token, so once it is reached every
    /// further call returns that final token again.
    fn advance(&mut self) -> Token {
        let token = self.peek().clone();
        // Written as `position + 1 < len` rather than `position < len - 1`,
        // because the latter underflows when the stream is empty.
        if self.position + 1 < self.tokens.len() {
            self.position += 1;
        }
        token
    }
    /// Returns whether the current token has the same variant as `token`.
    ///
    /// Only the variant is compared; any payload (such as an identifier's
    /// text) is ignored.
    fn at(&self, token: &Token) -> bool {
        std::mem::discriminant(self.peek()) == std::mem::discriminant(token)
    }
    /// Consumes the current token if its variant matches `expected`.
    ///
    /// # Errors
    ///
    /// Returns an error naming both the expected and the actual token when
    /// they differ; the token is then left unconsumed.
    fn expect(&mut self, expected: Token) -> Result<()> {
        if self.at(&expected) {
            self.advance();
            Ok(())
        } else {
            Err(Error {
                message: format!("expected {expected}, got {}", self.peek()),
                span: Some(self.span()),
            })
        }
    }
    /// Consumes an identifier token and returns its name.
    ///
    /// # Errors
    ///
    /// Returns an error, without consuming anything, if the current token is
    /// not an identifier.
    fn expect_identifier(&mut self) -> Result<String> {
        if let Token::Identifier(name) = self.peek().clone() {
            self.advance();
            Ok(name)
        } else {
            Err(Error {
                message: format!("expected identifier, got {}", self.peek()),
                span: Some(self.span()),
            })
        }
    }
    /// Returns whether `name` is a generic parameter currently in scope.
    fn is_generic_param(&self, name: &str) -> bool {
        self.context.active_generic_params.contains(name)
    }
    /// Parses the whole token stream into a module.
    ///
    /// Declarations are collected in source order. Any other top-level
    /// statements are gathered into a synthetic `main` (see
    /// `lower_top_level_statements_into_main`). Declarations created for
    /// anonymous compound types are then spliced in after the last type
    /// declaration.
    ///
    /// # Errors
    ///
    /// Propagates the first parse error. Also fails when top-level statements
    /// coexist with an existing `main`.
    pub fn parse_module(&mut self) -> Result<Module> {
        let mut declarations = Vec::new();
        let mut top_level_statements = Vec::new();
        // Tracked explicitly rather than inferred from
        // `top_level_statements.is_empty()`: a statement may expand to zero
        // `Statement`s, which previously caused an extra scope to be pushed and
        // never popped.
        let mut top_level_scope_open = false;
        while !self.at(&Token::End) {
            if self.is_declaration_start() {
                declarations.extend(self.parse_declaration()?);
            } else {
                if !top_level_scope_open {
                    self.context.push_scope();
                    top_level_scope_open = true;
                }
                top_level_statements.extend(self.parse_top_level_statement()?);
            }
        }
        if top_level_scope_open {
            self.context.pop_scope();
        }
        if !top_level_statements.is_empty() {
            Self::lower_top_level_statements_into_main(&mut declarations, top_level_statements)?;
        }
        self.splice_anonymous_type_declarations(&mut declarations);
        Ok(declarations)
    }
    /// Appends a unit newtype named `main` plus an `impl` block whose
    /// parameterless callable runs `statements` and returns unit.
    ///
    /// # Errors
    ///
    /// Fails if `declarations` already contains an `impl` block targeting a
    /// type named `main`. That is how an existing `fn main` is detected here.
    /// NOTE(review): this assumes `parse_function` lowers `fn main` to such an
    /// impl — confirm.
    fn lower_top_level_statements_into_main(
        declarations: &mut Vec<Declaration>,
        statements: Vec<Statement>,
    ) -> Result<()> {
        let has_main = declarations.iter().any(|declaration| {
            if let Declaration::Impl(impl_block) = declaration {
                matches!(&impl_block.target_type, Type::Named(name, _) if name == "main")
            } else {
                false
            }
        });
        if has_main {
            return Err(Error {
                message: "top-level statements not allowed when fn main exists".into(),
                span: None,
            });
        }
        declarations.push(Declaration::Type(Newtype {
            name: "main".into(),
            generic_parameters: vec![],
            inner_type: Type::unit(),
            public: false,
        }));
        declarations.push(Declaration::Impl(ImplBlock {
            generic_parameters: vec![],
            target_type: Type::Named("main".into(), vec![]),
            callable: Callable {
                return_type: Type::unit(),
                body: Block {
                    statements,
                    result: None,
                },
                parameters: Some(vec![]),
            },
            const_fn: false,
        }));
        Ok(())
    }
    /// Moves the accumulated anonymous-type declarations into `declarations`.
    ///
    /// They go immediately after the last `Declaration::Type`, or at the front
    /// when there is none, and keep their relative order.
    fn splice_anonymous_type_declarations(&mut self, declarations: &mut Vec<Declaration>) {
        let anonymous = std::mem::take(&mut self.anonymous_types);
        if anonymous.is_empty() {
            return;
        }
        let insert_position = declarations
            .iter()
            .rposition(|declaration| matches!(declaration, Declaration::Type(_)))
            .map_or(0, |index| index + 1);
        // A single splice shifts the tail once, instead of once per inserted item.
        declarations.splice(insert_position..insert_position, anonymous);
    }
    /// Returns whether the current token begins a declaration.
    fn is_declaration_start(&self) -> bool {
        matches!(
            self.peek(),
            Token::Fn
                | Token::Extern
                | Token::Type
                | Token::Const
                | Token::Import
                | Token::Pub
                | Token::Impl
        )
    }
    /// Dispatches to the parser for the declaration that starts at the current token.
    ///
    /// # Panics
    ///
    /// Hits `unreachable!` if called when `is_declaration_start` is false.
    fn parse_declaration(&mut self) -> Result<Vec<Declaration>> {
        match self.peek().clone() {
            Token::Pub => self.parse_pub_declaration(),
            Token::Fn => self.parse_function(false),
            Token::Extern => self.parse_extern(),
            Token::Type => self.parse_type_declarations(false),
            Token::Const => self.parse_constant(false),
            Token::Import => Ok(vec![self.parse_import()?]),
            Token::Impl => self.parse_impl(),
            _ => unreachable!(),
        }
    }
    /// Consumes `pub` and parses the public `fn`, `const`, or `type` that follows.
    ///
    /// # Errors
    ///
    /// Fails if `pub` is followed by anything else.
    fn parse_pub_declaration(&mut self) -> Result<Vec<Declaration>> {
        self.advance();
        match self.peek().clone() {
            Token::Fn => self.parse_function(true),
            Token::Const => self.parse_constant(true),
            Token::Type => self.parse_type_declarations(true),
            _ => Err(Error {
                message: "pub can only be used with fn, const, or type".into(),
                span: Some(self.span()),
            }),
        }
    }
    /// Parses one top-level statement.
    ///
    /// Anything that is not a statement keyword is parsed as an expression
    /// statement, and its trailing semicolon is optional.
    fn parse_top_level_statement(&mut self) -> Result<Vec<Statement>> {
        if self.is_statement_start() {
            return self.parse_statement();
        }
        let expression = self.parse_expression()?;
        if self.at(&Token::Semicolon) {
            self.advance();
        }
        Ok(vec![Statement::Expression(expression)])
    }
}
/// Binding power of the postfix operator tokens `.`, `(`, `[`, `as`, and `!`.
///
/// All of them share power 23, which is above every left power returned by
/// `infix_binding_power` (highest 17). Any other token yields `None`.
fn postfix_binding_power(token: &Token) -> Option<u8> {
    const POSTFIX_POWER: u8 = 23;
    let is_postfix = matches!(
        token,
        Token::Dot | Token::LeftParen | Token::LeftBracket | Token::As | Token::Bang
    );
    is_postfix.then_some(POSTFIX_POWER)
}
/// Returns the `(left, right)` binding powers of a binary operator token,
/// or `None` if the token is not a binary operator.
///
/// Tiers from loosest to tightest: `||`, `&&`, comparisons, `|`, `^`, `&`,
/// shifts, additive, multiplicative.
fn infix_binding_power(token: &Token) -> Option<(u8, u8)> {
    let left = match token {
        Token::PipePipe => 1,
        Token::AmpAmp => 3,
        Token::EqualEqual
        | Token::BangEqual
        | Token::Less
        | Token::LessEqual
        | Token::Greater
        | Token::GreaterEqual => 5,
        Token::Pipe => 7,
        Token::Caret => 9,
        Token::Amp => 11,
        Token::LessLess | Token::GreaterGreater => 13,
        Token::Plus | Token::Minus => 15,
        Token::Star | Token::Slash | Token::Percent => 17,
        _ => return None,
    };
    // The right power is always one above the left. In a conventional
    // binding-power loop this makes every tier left-associative.
    Some((left, left + 1))
}
fn token_to_binary_operator(token: &Token) -> Option<BinaryOperator> {
match token {
Token::Plus => Some(BinaryOperator::Arithmetic(ArithmeticOperator::Add)),
Token::Minus => Some(BinaryOperator::Arithmetic(ArithmeticOperator::Subtract)),
Token::Star => Some(BinaryOperator::Arithmetic(ArithmeticOperator::Multiply)),
Token::Slash => Some(BinaryOperator::Arithmetic(ArithmeticOperator::Divide)),
Token::Percent => Some(BinaryOperator::Arithmetic(ArithmeticOperator::Remainder)),
Token::EqualEqual => Some(BinaryOperator::Comparison(ComparisonOperator::Equal)),
Token::BangEqual => Some(BinaryOperator::Comparison(ComparisonOperator::NotEqual)),
Token::Less => Some(BinaryOperator::Comparison(ComparisonOperator::Less)),
Token::LessEqual => Some(BinaryOperator::Comparison(ComparisonOperator::LessEqual)),
Token::Greater => Some(BinaryOperator::Comparison(ComparisonOperator::Greater)),
Token::GreaterEqual => Some(BinaryOperator::Comparison(ComparisonOperator::GreaterEqual)),
Token::AmpAmp => Some(BinaryOperator::Logical(LogicalOperator::And)),
Token::PipePipe => Some(BinaryOperator::Logical(LogicalOperator::Or)),
Token::Amp => Some(BinaryOperator::Bitwise(BitwiseOperator::And)),
Token::Pipe => Some(BinaryOperator::Bitwise(BitwiseOperator::Or)),
Token::Caret => Some(BinaryOperator::Bitwise(BitwiseOperator::Xor)),
Token::LessLess => Some(BinaryOperator::Bitwise(BitwiseOperator::ShiftLeft)),
Token::GreaterGreater => Some(BinaryOperator::Bitwise(BitwiseOperator::ShiftRight)),
_ => None,
}
}
/// Result type of applying `operator` to operands of type `operand_type`.
///
/// Comparison and logical operators always produce `bool`. Arithmetic and
/// bitwise operators produce the operand type.
fn operator_result_type(operator: &BinaryOperator, operand_type: &Type) -> Type {
    let yields_bool = match operator {
        BinaryOperator::Comparison(_) | BinaryOperator::Logical(_) => true,
        BinaryOperator::Arithmetic(_) | BinaryOperator::Bitwise(_) => false,
    };
    if yields_bool {
        Type::Bool
    } else {
        operand_type.clone()
    }
}
fn parse_type_name(name: &str) -> Type {
match name {
"bool" => Type::Bool,
"never" => Type::Never,
"void" => Type::unit(),
"i8" => Type::Int(IntWidth::W8, Signedness::Signed),
"i16" => Type::Int(IntWidth::W16, Signedness::Signed),
"i32" => Type::Int(IntWidth::W32, Signedness::Signed),
"i64" => Type::Int(IntWidth::W64, Signedness::Signed),
"i128" => Type::Int(IntWidth::W128, Signedness::Signed),
"u8" => Type::Int(IntWidth::W8, Signedness::Unsigned),
"u16" => Type::Int(IntWidth::W16, Signedness::Unsigned),
"u32" => Type::Int(IntWidth::W32, Signedness::Unsigned),
"u64" => Type::Int(IntWidth::W64, Signedness::Unsigned),
"u128" => Type::Int(IntWidth::W128, Signedness::Unsigned),
"f32" => Type::Float(FloatWidth::W32),
"f64" => Type::Float(FloatWidth::W64),
_ => Type::Named(name.into(), vec![]),
}
}