#![allow(missing_docs)]
use super::{Type, TypeVar, Kind, TypeScheme};
use crate::diagnostics::{Error, Result, Span};
use crate::eval::value::Value;
use std::collections::{HashMap, HashSet};
use std::fmt;
/// An algebraic data type declaration: a named type with type parameters and constructors.
#[derive(Debug, Clone, PartialEq)]
pub struct AlgebraicDataType {
    /// Name of the type; `PatternMatcher::register_type` uses it as the lookup key.
    pub name: String,
    /// Type parameters in declaration order; `apply` expects one argument per parameter.
    pub type_params: Vec<TypeVar>,
    /// Data constructors; `add_constructor` sets each one's `tag` to its index in this list.
    pub constructors: Vec<DataConstructor>,
    /// Kind of the type constructor; `compute_kind` derives the kind implied by
    /// `type_params`.
    pub kind: Kind,
    /// Declaration style (sum, product or GADT); selects the layout used by `Display`.
    pub variant_type: AlgebraicVariant,
    /// Source location of the declaration, if known; used as the span of `apply` errors.
    pub span: Option<Span>,
}
/// How an [`AlgebraicDataType`] was declared; this also selects how it is displayed.
#[derive(Debug, Clone, PartialEq)]
pub enum AlgebraicVariant {
    /// A choice between constructors, displayed as `data T = A | B`.
    Sum,
    /// Displayed with its constructors listed inside braces, separated by commas.
    Product,
    /// GADT-style declaration, displayed as `data T where` followed by one constructor per
    /// line; constructors may carry an explicit `return_type`.
    GADT,
}
/// A data constructor belonging to an [`AlgebraicDataType`].
#[derive(Debug, Clone, PartialEq)]
pub struct DataConstructor {
    /// Constructor name; `AlgebraicDataType::get_constructor` looks constructors up by it.
    pub name: String,
    /// Types of the constructor's fields; their count is the constructor's arity.
    pub param_types: Vec<Type>,
    /// Explicit result type: `DataConstructor::new` leaves it `None`, while
    /// `DataConstructor::gadt` sets it.
    pub return_type: Option<Type>,
    /// Index of the constructor within its owning type, assigned by
    /// `AlgebraicDataType::add_constructor` (`0` until then).
    pub tag: usize,
    /// Source location of the constructor declaration, if known.
    pub span: Option<Span>,
}
/// A pattern appearing in a match clause.
#[derive(Debug, Clone, PartialEq)]
pub enum Pattern {
    /// `_`: matches anything and binds nothing.
    Wildcard,
    /// Matches anything and binds it to the given name.
    Variable(String),
    /// Matches a specific literal value.
    Literal(crate::ast::Literal),
    /// A constructor applied to sub-patterns, displayed as `Name (p1 p2 ...)`.
    Constructor {
        name: String,
        patterns: Vec<Pattern>,
    },
    /// A tuple of sub-patterns.
    Tuple(Vec<Pattern>),
    /// Named field sub-patterns plus an optional pattern for the remaining fields
    /// (displayed as `..rest`).
    Record {
        fields: HashMap<String, Pattern>,
        rest: Option<Box<Pattern>>,
    },
    /// Alternatives; type checking requires every branch to bind the same variables with
    /// the same types.
    Or(Vec<Pattern>),
    /// A pattern paired with a guard expression.
    Guard {
        pattern: Box<Pattern>,
        /// Guard expression, stored as an unparsed string.
        guard: String,
    },
}
/// One arm of a match expression.
#[derive(Debug, Clone)]
pub struct MatchClause {
    /// Pattern the scrutinee is tested against.
    pub pattern: Pattern,
    /// Optional guard expression, stored as an unparsed string.
    pub guard: Option<String>,
    /// Clause body, stored as an unparsed string; it becomes the `action` of the
    /// `DecisionTree::Success` node for this clause.
    pub body: String,
    /// Source location of the clause, if known.
    pub span: Option<Span>,
}
/// A match expression: a scrutinee plus the clauses tried against it.
#[derive(Debug, Clone)]
pub struct MatchExpression {
    /// Scrutinee expression, stored as an unparsed string.
    pub scrutinee: String,
    /// Clauses; `PatternMatcher` tries them from first to last.
    pub clauses: Vec<MatchClause>,
    /// Source location of the whole expression, if known.
    pub span: Option<Span>,
}
/// Compiles match expressions into decision trees and caches the results.
pub struct PatternMatcher {
    /// Registered algebraic data types, keyed by type name.
    types: HashMap<String, AlgebraicDataType>,
    /// Compiled matches, keyed by the `Debug` rendering of the `MatchExpression`.
    cache: HashMap<String, CompiledPattern>,
}
/// The result of compiling a match expression.
#[derive(Debug, Clone)]
pub struct CompiledPattern {
    /// Decision tree that selects which clause body applies.
    pub tree: DecisionTree,
    /// Names of the variables bound by the clauses' patterns.
    pub bindings: Vec<String>,
}
/// A decision tree produced by compiling a match expression.
#[derive(Debug, Clone)]
pub enum DecisionTree {
    /// A clause matched; `action` is that clause's body. `PatternMatcher` currently leaves
    /// `bindings` empty.
    Success {
        bindings: HashMap<String, Value>,
        action: String,
    },
    /// No clause matched.
    Failure,
    /// Evaluate `test`, then continue with `success` or `failure`.
    Test {
        test: PatternTest,
        success: Box<DecisionTree>,
        failure: Box<DecisionTree>,
    },
    /// Multi-way branch on `scrutinee`. `PatternMatcher::compile_match` does not currently
    /// emit this variant.
    /// NOTE(review): branch keys are presumably constructor names — confirm with consumers.
    Switch {
        scrutinee: String,
        branches: HashMap<String, DecisionTree>,
        default: Option<Box<DecisionTree>>,
    },
}
/// A runtime check performed at a `DecisionTree::Test` node.
#[derive(Debug, Clone)]
pub enum PatternTest {
    /// Constructor test; `arity` is the number of sub-patterns in the source pattern.
    Constructor {
        name: String,
        arity: usize,
    },
    /// Literal test.
    Literal(crate::ast::Literal),
    /// Type test; `PatternMatcher` does not currently emit this variant.
    Type(Type),
    /// Guard test; the payload is the guard expression as an unparsed string.
    Guard(String),
}
impl AlgebraicDataType {
pub fn new(
name: String,
type_params: Vec<TypeVar>,
variant_type: AlgebraicVariant,
span: Option<Span>,
) -> Self {
Self {
name,
type_params,
constructors: Vec::new(),
kind: Kind::Type, variant_type,
span,
}
}
pub fn add_constructor(&mut self, mut constructor: DataConstructor) {
constructor.tag = self.constructors.len();
self.constructors.push(constructor);
}
pub fn compute_kind(&self) -> Kind {
self.type_params.iter().fold(Kind::Type, |acc, _| {
Kind::arrow(Kind::Type, acc)
})
}
pub fn apply(&self, args: Vec<Type>) -> Result<Type> {
if args.len() != self.type_params.len() {
return Err(Box::new(Error::type_error(
format!(
"Type {} expects {} arguments, got {}",
self.name,
self.type_params.len(),
args.len()
),
self.span.unwrap_or_default(),
)));
}
let mut substitution = HashMap::new();
for (param, arg) in self.type_params.iter().zip(args.iter()) {
substitution.insert(param.clone(), arg.clone());
}
Ok(Type::Constructor {
name: self.name.clone(),
kind: self.compute_kind(),
})
}
pub fn constructors(&self) -> &[DataConstructor] {
&self.constructors
}
pub fn get_constructor(&self, name: &str) -> Option<&DataConstructor> {
self.constructors.iter().find(|c| c.name == name)
}
pub fn is_recursive(&self) -> bool {
self.constructors.iter().any(|c| {
c.param_types.iter().any(|t| self.contains_self_reference(t))
})
}
fn contains_self_reference(&self, ty: &Type) -> bool {
match ty {
Type::Constructor { name, .. } => name == &self.name,
Type::Application { constructor, argument } => {
self.contains_self_reference(constructor) || self.contains_self_reference(argument)
}
Type::Function { params, return_type } => {
params.iter().any(|p| self.contains_self_reference(p))
|| self.contains_self_reference(return_type)
}
_ => false,
}
}
}
impl DataConstructor {
pub fn new(name: String, param_types: Vec<Type>, span: Option<Span>) -> Self {
Self {
name,
param_types,
return_type: None,
tag: 0, span,
}
}
pub fn gadt(
name: String,
param_types: Vec<Type>,
return_type: Type,
span: Option<Span>,
) -> Self {
Self {
name,
param_types,
return_type: Some(return_type),
tag: 0,
span,
}
}
pub fn arity(&self) -> usize {
self.param_types.len()
}
pub fn is_nullary(&self) -> bool {
self.param_types.is_empty()
}
pub fn type_scheme(&self, result_type: &Type) -> TypeScheme {
if self.param_types.is_empty() {
TypeScheme::monomorphic(result_type.clone())
} else {
let func_type = Type::function(self.param_types.clone(), result_type.clone());
TypeScheme::monomorphic(func_type)
}
}
}
impl Pattern {
    /// Returns `true` when this pattern matches any value, judged purely from its syntax.
    ///
    /// The check is conservative. Constructor patterns are always refutable because no type
    /// information is available here, so even single-constructor types are not recognised.
    /// Literals and guards are refutable.
    pub fn is_irrefutable(&self) -> bool {
        match self {
            Pattern::Wildcard | Pattern::Variable(_) => true,
            Pattern::Tuple(patterns) => patterns.iter().all(Pattern::is_irrefutable),
            // NOTE(review): an open record (`..rest`) is reported as refutable while a closed
            // one is not; confirm this matches the intended record semantics.
            Pattern::Record { fields, rest } => {
                fields.values().all(Pattern::is_irrefutable) && rest.is_none()
            }
            // A disjunction always matches as soon as one of its alternatives always matches.
            Pattern::Or(alternatives) => alternatives.iter().any(Pattern::is_irrefutable),
            Pattern::Literal(_) | Pattern::Constructor { .. } | Pattern::Guard { .. } => false,
        }
    }

    /// Returns the names of all variables bound anywhere inside this pattern, including
    /// inside every or-pattern alternative and the record `rest` pattern.
    pub fn bound_variables(&self) -> HashSet<String> {
        let mut vars = HashSet::new();
        self.collect_variables(&mut vars);
        vars
    }

    /// Recursively inserts every bound variable name into `vars`.
    fn collect_variables(&self, vars: &mut HashSet<String>) {
        match self {
            Pattern::Variable(name) => {
                vars.insert(name.clone());
            }
            Pattern::Constructor { patterns, .. }
            | Pattern::Tuple(patterns)
            | Pattern::Or(patterns) => {
                for pattern in patterns {
                    pattern.collect_variables(vars);
                }
            }
            Pattern::Record { fields, rest } => {
                for pattern in fields.values() {
                    pattern.collect_variables(vars);
                }
                if let Some(rest_pattern) = rest {
                    rest_pattern.collect_variables(vars);
                }
            }
            Pattern::Guard { pattern, .. } => pattern.collect_variables(vars),
            Pattern::Wildcard | Pattern::Literal(_) => {}
        }
    }

    /// Type-checks the pattern against `expected_type` and returns the type of every variable
    /// it binds.
    ///
    /// Sub-patterns of constructor, tuple and record patterns are checked against
    /// `Type::Dynamic`, so variables bound inside them get `Type::Dynamic`. Guard expressions
    /// are not checked.
    ///
    /// # Errors
    /// Returns a type error if a literal's type is incompatible with the expected type, or if
    /// the alternatives of an or-pattern bind different variables or types.
    pub fn type_check(&self, expected_type: &Type) -> Result<HashMap<String, Type>> {
        let mut bindings = HashMap::new();
        self.type_check_impl(expected_type, &mut bindings)?;
        Ok(bindings)
    }

    /// Recursive worker for [`Pattern::type_check`]; adds discovered bindings to `bindings`.
    fn type_check_impl(
        &self,
        expected_type: &Type,
        bindings: &mut HashMap<String, Type>,
    ) -> Result<()> {
        match self {
            Pattern::Wildcard => Ok(()),
            Pattern::Variable(name) => {
                bindings.insert(name.clone(), expected_type.clone());
                Ok(())
            }
            Pattern::Literal(lit) => {
                let lit_type = literal_to_type(lit);
                if types_compatible(&lit_type, expected_type) {
                    Ok(())
                } else {
                    Err(Box::new(Error::type_error(
                        format!(
                            "Pattern literal type {lit_type} doesn't match expected type {expected_type}"
                        ),
                        Span::default(),
                    )))
                }
            }
            Pattern::Constructor { patterns, .. } | Pattern::Tuple(patterns) => {
                for pattern in patterns {
                    pattern.type_check_impl(&Type::Dynamic, bindings)?;
                }
                Ok(())
            }
            Pattern::Record { fields, rest } => {
                for pattern in fields.values() {
                    pattern.type_check_impl(&Type::Dynamic, bindings)?;
                }
                // The rest pattern binds variables too (see `collect_variables`), so those
                // variables must receive a type as well.
                if let Some(rest_pattern) = rest {
                    rest_pattern.type_check_impl(&Type::Dynamic, bindings)?;
                }
                Ok(())
            }
            Pattern::Or(patterns) => {
                let mut first_bindings = HashMap::new();
                if let Some(first) = patterns.first() {
                    first.type_check_impl(expected_type, &mut first_bindings)?;
                }
                for pattern in patterns.iter().skip(1) {
                    let mut pattern_bindings = HashMap::new();
                    pattern.type_check_impl(expected_type, &mut pattern_bindings)?;
                    if first_bindings != pattern_bindings {
                        return Err(Box::new(Error::type_error(
                            "All branches in or-pattern must bind the same variables with the same types".to_string(),
                            Span::default(),
                        )));
                    }
                }
                bindings.extend(first_bindings);
                Ok(())
            }
            Pattern::Guard { pattern, .. } => {
                pattern.type_check_impl(expected_type, bindings)
            }
        }
    }
}
impl PatternMatcher {
    /// Creates a matcher with no registered types and an empty compilation cache.
    pub fn new() -> Self {
        Self {
            types: HashMap::new(),
            cache: HashMap::new(),
        }
    }

    /// Registers `adt` under its name, replacing any type previously registered under that
    /// name.
    pub fn register_type(&mut self, adt: AlgebraicDataType) {
        self.types.insert(adt.name.clone(), adt);
    }

    /// Compiles `match_expr` into a decision tree and memoises the result.
    ///
    /// The cache key is the expression's `Debug` rendering, spans included. Identical
    /// matches at different source locations are therefore compiled separately.
    ///
    /// # Errors
    /// Returns an error if some clause's pattern cannot be turned into a runtime test.
    pub fn compile_match(&mut self, match_expr: &MatchExpression) -> Result<CompiledPattern> {
        let cache_key = format!("{match_expr:?}");
        if let Some(cached) = self.cache.get(&cache_key) {
            return Ok(cached.clone());
        }
        let tree = self.compile_clauses(&match_expr.clauses)?;
        let bindings = self.extract_bindings(&match_expr.clauses);
        let compiled = CompiledPattern { tree, bindings };
        self.cache.insert(cache_key, compiled.clone());
        Ok(compiled)
    }

    /// Compiles `clauses` into a tree that tries them in order and ends in
    /// `DecisionTree::Failure` when none apply.
    ///
    /// A clause's `guard` is tested after its pattern has matched. A failing guard falls
    /// through to the next clause. The tree has no sharing, so the fallback subtree is cloned
    /// wherever more than one edge must reach it. Long runs of guarded clauses or or-pattern
    /// clauses can therefore grow the tree considerably.
    fn compile_clauses(&self, clauses: &[MatchClause]) -> Result<DecisionTree> {
        let (first, rest) = match clauses.split_first() {
            Some(split) => split,
            None => return Ok(DecisionTree::Failure),
        };
        // Success bindings are not populated yet; only the clause body is recorded.
        let success = DecisionTree::Success {
            bindings: HashMap::new(),
            action: first.body.clone(),
        };
        if first.guard.is_none() && first.pattern.is_irrefutable() {
            // Unconditional match: every later clause is unreachable.
            return Ok(success);
        }
        let fallback = self.compile_clauses(rest)?;
        let on_match = match &first.guard {
            Some(guard) => Self::test_node(
                PatternTest::Guard(guard.clone()),
                success,
                fallback.clone(),
            ),
            None => success,
        };
        self.pattern_tree(&first.pattern, on_match, fallback)
    }

    /// Builds a tree that continues with `on_match` when `pattern` matches and with `on_fail`
    /// otherwise.
    ///
    /// Guard patterns test their inner pattern first and then the guard. Or-patterns try
    /// their alternatives left to right. Constructor and literal patterns contribute a single
    /// test for their outermost shape (see `pattern_to_test`); their sub-patterns are not
    /// tested.
    fn pattern_tree(
        &self,
        pattern: &Pattern,
        on_match: DecisionTree,
        on_fail: DecisionTree,
    ) -> Result<DecisionTree> {
        if pattern.is_irrefutable() {
            return Ok(on_match);
        }
        match pattern {
            Pattern::Guard { pattern: inner, guard } => {
                // The guard runs only after the wrapped pattern has matched.
                let guarded = Self::test_node(
                    PatternTest::Guard(guard.clone()),
                    on_match,
                    on_fail.clone(),
                );
                self.pattern_tree(inner, guarded, on_fail)
            }
            // Built back to front so each alternative falls through to the next one.
            Pattern::Or(alternatives) => {
                alternatives
                    .iter()
                    .rev()
                    .try_fold(on_fail, |next, alternative| {
                        self.pattern_tree(alternative, on_match.clone(), next)
                    })
            }
            _ => {
                let test = self.pattern_to_test(pattern)?;
                Ok(Self::test_node(test, on_match, on_fail))
            }
        }
    }

    /// Wraps `test` in a `DecisionTree::Test` node with the given outcomes.
    fn test_node(test: PatternTest, success: DecisionTree, failure: DecisionTree) -> DecisionTree {
        DecisionTree::Test {
            test,
            success: Box::new(success),
            failure: Box::new(failure),
        }
    }

    /// Converts a refutable pattern into the single runtime test for its outermost shape.
    ///
    /// Constructor tests record only the name and the sub-pattern count; the sub-patterns
    /// themselves are not examined.
    ///
    /// # Errors
    /// Returns a type error for pattern forms that have no test, such as wildcards,
    /// variables, tuples, records and or-patterns.
    fn pattern_to_test(&self, pattern: &Pattern) -> Result<PatternTest> {
        match pattern {
            Pattern::Literal(lit) => Ok(PatternTest::Literal(lit.clone())),
            Pattern::Constructor { name, patterns } => Ok(PatternTest::Constructor {
                name: name.clone(),
                arity: patterns.len(),
            }),
            Pattern::Guard { guard, .. } => Ok(PatternTest::Guard(guard.clone())),
            _ => Err(Box::new(Error::type_error(
                "Cannot convert pattern to test".to_string(),
                Span::default(),
            ))),
        }
    }

    /// Returns the sorted, de-duplicated names of the variables bound by any clause's
    /// pattern.
    fn extract_bindings(&self, clauses: &[MatchClause]) -> Vec<String> {
        let mut names: Vec<String> = clauses
            .iter()
            .flat_map(|clause| clause.pattern.bound_variables())
            .collect();
        // `bound_variables` yields a `HashSet`, whose iteration order varies between runs.
        // Sorting makes the result deterministic, and dedup removes names shared by clauses.
        names.sort_unstable();
        names.dedup();
        names
    }

    /// Conservatively reports whether `patterns` cover every value.
    ///
    /// Returns `Ok(true)` only when some pattern is irrefutable. Constructor coverage is not
    /// analysed yet, so `None` together with `Some (_)` yields `Ok(false)`. `_ty` is
    /// currently unused.
    pub fn check_exhaustiveness(&self, patterns: &[Pattern], _ty: &Type) -> Result<bool> {
        Ok(patterns.iter().any(Pattern::is_irrefutable))
    }

    /// Returns the indices of redundant patterns.
    ///
    /// Not implemented yet: it always returns an empty list.
    pub fn check_redundancy(&self, _patterns: &[Pattern]) -> Vec<usize> {
        Vec::new()
    }
}
fn literal_to_type(lit: &crate::ast::Literal) -> Type {
match lit {
crate::ast::Literal::ExactInteger(_) => Type::Number,
crate::ast::Literal::InexactReal(_) => Type::Number,
crate::ast::Literal::Number(_) => Type::Number,
crate::ast::Literal::Rational { .. } => Type::Number,
crate::ast::Literal::Complex { .. } => Type::Number,
crate::ast::Literal::String(_) => Type::String,
crate::ast::Literal::Boolean(_) => Type::Boolean,
crate::ast::Literal::Character(_) => Type::Char,
crate::ast::Literal::Bytevector(_) => Type::Bytevector,
crate::ast::Literal::Nil => Type::Unit,
crate::ast::Literal::Unspecified => Type::Unit,
}
}
/// Returns `true` if a value of type `t1` may be used where `t2` is expected.
///
/// `Type::Dynamic` on either side is compatible with anything; otherwise the types must be
/// equal.
fn types_compatible(t1: &Type, t2: &Type) -> bool {
    let either_dynamic = matches!(t1, Type::Dynamic) || matches!(t2, Type::Dynamic);
    either_dynamic || t1 == t2
}
impl fmt::Display for AlgebraicDataType {
    /// Renders the declaration as `data Name (params)`. The constructors follow in a layout
    /// chosen by `variant_type`: `= A | B` for sums, `{A, B}` for products, and `where`
    /// with one constructor per line for GADTs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "data {}", self.name)?;
        if let Some((first, others)) = self.type_params.split_first() {
            write!(f, " ({first}")?;
            for param in others {
                write!(f, " {param}")?;
            }
            f.write_str(")")?;
        }
        match self.variant_type {
            AlgebraicVariant::Sum => {
                for (idx, constructor) in self.constructors.iter().enumerate() {
                    let separator = if idx == 0 { " = " } else { " | " };
                    write!(f, "{separator}{constructor}")?;
                }
            }
            AlgebraicVariant::Product => {
                f.write_str(" {")?;
                for (idx, constructor) in self.constructors.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{constructor}")?;
                }
                f.write_str("}")?;
            }
            AlgebraicVariant::GADT => {
                f.write_str(" where")?;
                for constructor in &self.constructors {
                    write!(f, "\n {constructor}")?;
                }
            }
        }
        Ok(())
    }
}
impl fmt::Display for DataConstructor {
    /// Renders `Name`, then `(T1 T2 ...)` when there are fields, then `: R` when an explicit
    /// return type is present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.name)?;
        if let Some((first, others)) = self.param_types.split_first() {
            write!(f, " ({first}")?;
            for param_type in others {
                write!(f, " {param_type}")?;
            }
            f.write_str(")")?;
        }
        match &self.return_type {
            Some(return_type) => write!(f, " : {return_type}"),
            None => Ok(()),
        }
    }
}
impl fmt::Display for Pattern {
    /// Renders the pattern in surface syntax, e.g. `Some (x)`, `(a, _)`, `{x = p, ..r}`,
    /// `A | B`, or `p if guard`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Writes `items` separated by `sep`, with no leading or trailing separator.
        fn joined(f: &mut fmt::Formatter<'_>, items: &[Pattern], sep: &str) -> fmt::Result {
            for (idx, item) in items.iter().enumerate() {
                if idx > 0 {
                    f.write_str(sep)?;
                }
                write!(f, "{item}")?;
            }
            Ok(())
        }

        match self {
            Pattern::Wildcard => f.write_str("_"),
            Pattern::Variable(name) => f.write_str(name),
            Pattern::Literal(lit) => write!(f, "{lit}"),
            Pattern::Constructor { name, patterns } => {
                f.write_str(name)?;
                if patterns.is_empty() {
                    return Ok(());
                }
                f.write_str(" (")?;
                joined(f, patterns, " ")?;
                f.write_str(")")
            }
            Pattern::Tuple(patterns) => {
                f.write_str("(")?;
                joined(f, patterns, ", ")?;
                f.write_str(")")
            }
            Pattern::Record { fields, rest } => {
                f.write_str("{")?;
                for (idx, (name, pattern)) in fields.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{name} = {pattern}")?;
                }
                if let Some(rest_pattern) = rest {
                    if !fields.is_empty() {
                        f.write_str(", ")?;
                    }
                    write!(f, "..{rest_pattern}")?;
                }
                f.write_str("}")
            }
            Pattern::Or(patterns) => joined(f, patterns, " | "),
            Pattern::Guard { pattern, guard } => write!(f, "{pattern} if {guard}"),
        }
    }
}
impl Default for PatternMatcher {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructors are stored in insertion order and keep their field types.
    #[test]
    fn test_algebraic_data_type_creation() {
        let mut maybe_type = AlgebraicDataType::new(
            "Maybe".to_string(),
            vec![TypeVar::with_name("a")],
            AlgebraicVariant::Sum,
            None,
        );
        maybe_type.add_constructor(DataConstructor::new("None".to_string(), vec![], None));
        maybe_type.add_constructor(DataConstructor::new(
            "Some".to_string(),
            vec![Type::named_var("a")],
            None,
        ));

        let names: Vec<&str> = maybe_type
            .constructors
            .iter()
            .map(|c| c.name.as_str())
            .collect();
        assert_eq!(names, ["None", "Some"]);
        assert_eq!(maybe_type.constructors[1].arity(), 1);
    }

    /// Variables nested inside a constructor pattern are reported as bound.
    #[test]
    fn test_pattern_variables() {
        let some_x = Pattern::Constructor {
            name: "Some".to_string(),
            patterns: vec![Pattern::Variable("x".to_string())],
        };
        let vars = some_x.bound_variables();
        assert_eq!(vars.len(), 1);
        assert!(vars.contains("x"));
    }

    /// Wildcards, variables and tuples of them always match; literals do not.
    #[test]
    fn test_pattern_irrefutability() {
        let irrefutable = [
            Pattern::Wildcard,
            Pattern::Variable("x".to_string()),
            Pattern::Tuple(vec![Pattern::Variable("x".to_string()), Pattern::Wildcard]),
        ];
        for pattern in &irrefutable {
            assert!(pattern.is_irrefutable(), "expected `{pattern}` to be irrefutable");
        }
        assert!(!Pattern::Literal(crate::ast::Literal::Boolean(true)).is_irrefutable());
    }

    /// A constructor with fields gets a function type from its fields to the result type.
    #[test]
    fn test_constructor_type_scheme() {
        let element = Type::named_var("a");
        let constructor = DataConstructor::new(
            "Cons".to_string(),
            vec![element.clone(), Type::list(element.clone())],
            None,
        );
        let result_type = Type::list(element);
        let scheme = constructor.type_scheme(&result_type);
        if let Type::Function { params, return_type } = &scheme.type_ {
            assert_eq!(params.len(), 2);
            assert_eq!(**return_type, result_type);
        } else {
            panic!("Expected function type");
        }
    }
}