use crate::ast::{Expr, Type};
use anyhow::{anyhow, Result};
use std::collections::HashMap;
/// A named type parameter of a generic type, together with the constraints that any
/// concrete type substituted for it must satisfy.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TypeParameter {
    /// Parameter name, e.g. `"T"`.
    pub name: String,
    /// Constraints evaluated against the concrete type when the owning generic type is
    /// instantiated (see `GenericType::instantiate`).
    pub constraints: Vec<TypeConstraint>,
}
/// A requirement placed on a type parameter.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TypeConstraint {
    /// Must implement the named interface. Currently accepted without any check by
    /// `GenericType::instantiate`.
    Interface(String),
    /// Must be `Type::Integer` or `Type::Real`.
    Numeric,
    /// Must support comparison. Currently accepted without any check by
    /// `GenericType::instantiate`.
    Comparable,
    /// Must be `Type::Pointer(_)` or `Type::String`.
    Reference,
}
/// A generic type definition (e.g. `Array` with parameter `T`) plus a cache of the
/// instantiations produced from it.
#[derive(Debug, Clone)]
pub struct GenericType {
    /// Name of the generic type; copied into every `Type::GenericInstance` it produces.
    pub base_type: String,
    /// Declared type parameters, in positional order.
    pub type_params: Vec<TypeParameter>,
    /// Memoised instantiations: the `Debug` rendering of each concrete type argument,
    /// paired with the resulting instantiated `Type`.
    pub instantiations: Vec<(Vec<String>, Type)>,
}
impl GenericType {
    /// Creates a generic type definition with the given parameters and an empty
    /// instantiation cache.
    pub fn new(base_type: String, type_params: Vec<TypeParameter>) -> Self {
        Self {
            base_type,
            type_params,
            instantiations: Vec::new(),
        }
    }

    /// Instantiates this generic type with `concrete_types`, one per type parameter.
    ///
    /// Each concrete type is checked against the constraints of its parameter. Results are
    /// memoised in `self.instantiations`. Instantiating again with the same arguments
    /// returns a clone of the cached `Type::GenericInstance`.
    ///
    /// # Errors
    ///
    /// Returns an error if `concrete_types.len()` differs from the number of type
    /// parameters, or if a concrete type violates a constraint of its parameter.
    pub fn instantiate(&mut self, concrete_types: Vec<Type>) -> Result<Type> {
        if concrete_types.len() != self.type_params.len() {
            return Err(anyhow!(
                "Expected {} type parameters, got {}",
                self.type_params.len(),
                concrete_types.len()
            ));
        }
        for (param, concrete) in self.type_params.iter().zip(concrete_types.iter()) {
            self.check_constraints(param, concrete)?;
        }
        // The key format is exposed through the public `instantiations` field, so it stays
        // the `Debug` rendering of each argument.
        let type_key: Vec<String> = concrete_types.iter().map(|t| format!("{:?}", t)).collect();
        if let Some((_, instantiated)) =
            self.instantiations.iter().find(|(key, _)| key == &type_key)
        {
            return Ok(instantiated.clone());
        }
        let instantiated = Type::GenericInstance {
            base_type: self.base_type.clone(),
            type_arguments: concrete_types,
        };
        self.instantiations.push((type_key, instantiated.clone()));
        Ok(instantiated)
    }

    /// Verifies that `concrete` satisfies every constraint declared on `param`.
    ///
    /// Only `Numeric` (`Integer` or `Real`) and `Reference` (`Pointer` or `String`) are
    /// enforced. `Comparable` and `Interface` constraints accept any type.
    ///
    /// # Errors
    ///
    /// Returns an error naming the first violated constraint.
    fn check_constraints(&self, param: &TypeParameter, concrete: &Type) -> Result<()> {
        for constraint in &param.constraints {
            match constraint {
                TypeConstraint::Numeric => {
                    if !matches!(concrete, Type::Integer | Type::Real) {
                        return Err(anyhow!("Type must be numeric"));
                    }
                }
                TypeConstraint::Reference => {
                    if !matches!(concrete, Type::Pointer(_) | Type::String) {
                        return Err(anyhow!("Type must be a reference type"));
                    }
                }
                // Not enforced: these constraints currently accept every concrete type.
                TypeConstraint::Comparable | TypeConstraint::Interface(_) => {}
            }
        }
        Ok(())
    }
}
/// State for expression type inference: the inferred type of each program variable and a
/// counter used to mint fresh type variables.
pub struct TypeInference {
    /// Inferred type of each program variable, keyed by the variable's name.
    type_vars: HashMap<String, Type>,
    /// NOTE(review): initialised in `new` but never read or written by the visible code.
    constraints: Vec<TypeConstraint>,
    /// Index of the next fresh type variable (`T0`, `T1`, ...).
    next_type_var: u32,
}
impl TypeInference {
    /// Creates an inference context with no variables recorded.
    pub fn new() -> Self {
        Self {
            type_vars: HashMap::new(),
            constraints: Vec::new(),
            next_type_var: 0,
        }
    }

    /// Returns a fresh, unconstrained type variable, named `T0`, `T1`, ... in creation order.
    pub fn fresh_type_var(&mut self) -> Type {
        let var_name = format!("T{}", self.next_type_var);
        self.next_type_var += 1;
        Type::Generic {
            name: var_name,
            constraints: vec![],
        }
    }

    /// Infers the type of `expr`, recording the type of every program variable it mentions.
    ///
    /// * A variable seen for the first time is assigned a fresh type variable.
    /// * The operands of a binary operator are unified. A type variable on either side is
    ///   thereby bound to the other side's type.
    /// * Comparisons yield `Boolean`. Other binary operators yield the unified operand type,
    ///   promoted to `Real` when either operand is `Real`.
    /// * `not` yields `Boolean`; any other unary operator yields its operand's type.
    /// * Any other expression kind, and any unrecognised literal, defaults to `Integer`.
    ///
    /// # Errors
    ///
    /// Returns an error if the operands of a binary operator cannot be unified.
    pub fn infer_expr(&mut self, expr: &Expr) -> Result<Type> {
        match expr {
            Expr::Literal(lit) => Ok(match lit {
                crate::ast::Literal::Integer(_) => Type::Integer,
                crate::ast::Literal::Real(_) => Type::Real,
                crate::ast::Literal::Boolean(_) => Type::Boolean,
                crate::ast::Literal::Char(_) => Type::Char,
                crate::ast::Literal::String(_) => Type::String,
                _ => Type::Integer,
            }),
            Expr::Variable(name) => {
                if let Some(typ) = self.type_vars.get(name) {
                    Ok(typ.clone())
                } else {
                    let typ = self.fresh_type_var();
                    self.type_vars.insert(name.clone(), typ.clone());
                    Ok(typ)
                }
            }
            Expr::BinaryOp {
                operator,
                left,
                right,
            } => {
                let left_type = self.infer_expr(left)?;
                let right_type = self.infer_expr(right)?;
                self.unify(&left_type, &right_type)?;
                match operator.as_str() {
                    "=" | "<>" | "<" | "<=" | ">" | ">=" => Ok(Type::Boolean),
                    // Arithmetic and every other operator produce the (promoted) operand type.
                    _ => Ok(Self::unified_operand_type(&left_type, &right_type)),
                }
            }
            Expr::UnaryOp { operator, operand } => {
                // Infer the operand even for `not` so the variables it mentions are recorded.
                let operand_type = self.infer_expr(operand)?;
                if operator == "not" {
                    Ok(Type::Boolean)
                } else {
                    Ok(operand_type)
                }
            }
            _ => Ok(Type::Integer),
        }
    }

    /// Checks that `t1` and `t2` are compatible, binding type variables as needed.
    ///
    /// Identical types always unify. `Integer` and `Real` unify with each other (numeric
    /// promotion). A type variable (`Type::Generic`) unifies with anything: every program
    /// variable currently mapped to it is re-mapped to the other type.
    ///
    /// # Errors
    ///
    /// Returns an error for any other pair of types.
    fn unify(&mut self, t1: &Type, t2: &Type) -> Result<()> {
        if t1 == t2 {
            return Ok(());
        }
        match (t1, t2) {
            (Type::Generic { .. }, _) => {
                self.bind(t1, t2);
                Ok(())
            }
            (_, Type::Generic { .. }) => {
                self.bind(t2, t1);
                Ok(())
            }
            (Type::Integer, Type::Real) | (Type::Real, Type::Integer) => Ok(()),
            _ => Err(anyhow!("Cannot unify types {:?} and {:?}", t1, t2)),
        }
    }

    /// Replaces every occurrence of the type variable `var` in `type_vars` with `target`.
    fn bind(&mut self, var: &Type, target: &Type) {
        for bound in self.type_vars.values_mut() {
            if *bound == *var {
                *bound = target.clone();
            }
        }
    }

    /// Result type of a non-comparison binary operator whose operands were just unified.
    fn unified_operand_type(left: &Type, right: &Type) -> Type {
        match (left, right) {
            // `unify` bound the type variable to the other side, so that side is the answer.
            (Type::Generic { .. }, other) | (other, Type::Generic { .. }) => other.clone(),
            (Type::Real, _) | (_, Type::Real) => Type::Real,
            _ => left.clone(),
        }
    }

    /// Returns a snapshot of the inferred type of every variable seen so far.
    pub fn resolve(&self) -> HashMap<String, Type> {
        self.type_vars.clone()
    }

    /// Infers a type for `expr` without any inference context.
    ///
    /// Comparisons and `and`/`or`/`xor` yield `Boolean`, and `not` yields `Boolean`.
    /// Other binary operators yield `Real` if either operand is `Real`, otherwise the left
    /// operand's type. Variables and any other expression kind are assumed to be `Integer`.
    pub fn infer_from_expr(expr: &Expr) -> Type {
        match expr {
            Expr::Literal(crate::ast::Literal::Integer(_)) => Type::Integer,
            Expr::Literal(crate::ast::Literal::Real(_)) => Type::Real,
            Expr::Literal(crate::ast::Literal::Boolean(_)) => Type::Boolean,
            Expr::Literal(crate::ast::Literal::Char(_)) => Type::Char,
            Expr::Literal(crate::ast::Literal::String(_)) => Type::String,
            Expr::BinaryOp { operator, left, right } => {
                let lt = Self::infer_from_expr(left);
                let rt = Self::infer_from_expr(right);
                match operator.as_str() {
                    "=" | "<>" | "<" | "<=" | ">" | ">=" | "and" | "or" | "xor" => Type::Boolean,
                    _ => {
                        if matches!(lt, Type::Real) || matches!(rt, Type::Real) {
                            Type::Real
                        } else {
                            lt
                        }
                    }
                }
            }
            Expr::UnaryOp { operator, operand } => {
                let t = Self::infer_from_expr(operand);
                if operator == "not" {
                    Type::Boolean
                } else {
                    t
                }
            }
            _ => Type::Integer,
        }
    }
}
/// Replaces the declared type of every variable in `block` that has an initial value with
/// the type statically inferred from that value (via `TypeInference::infer_from_expr`).
///
/// Declarations without an initial value keep their current `variable_type`.
pub fn infer_block_variable_types(block: &mut crate::ast::Block) {
    for decl in &mut block.vars {
        let inferred = match &decl.initial_value {
            Some(init) => TypeInference::infer_from_expr(init),
            None => continue,
        };
        decl.variable_type = inferred;
    }
}
impl Default for TypeInference {
fn default() -> Self {
Self::new()
}
}
/// A registered implementation of an operator for specific operand types.
#[derive(Debug, Clone)]
pub struct OperatorOverload {
    /// The operator being overloaded.
    pub operator: OverloadableOperator,
    /// Type of the left operand; `OperatorRegistry::lookup` requires an exact match.
    pub left_type: Type,
    /// Type of the right operand; presumably `None` for unary operators — confirm with callers.
    pub right_type: Option<Type>,
    /// Type produced by applying the overload.
    pub return_type: Type,
    /// Implementation identifier (the tests use `"string_concat"`); presumably the name of
    /// the routine to call — TODO confirm.
    pub implementation: String,
}
/// Operators that may be given user-defined implementations.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OverloadableOperator {
    /// Spelled `+`.
    Add,
    /// Spelled `-`.
    Subtract,
    /// Spelled `*`.
    Multiply,
    /// Spelled `/` or `div`.
    Divide,
    /// Spelled `=`.
    Equal,
    /// Spelled `<>`.
    NotEqual,
    /// Spelled `<`.
    Less,
    /// Spelled `>`.
    Greater,
    /// Indexing; `from_operator_str` never produces it.
    Index,
    /// Invocation; `from_operator_str` never produces it.
    Call,
}

impl OverloadableOperator {
    /// Maps a source operator token to its overloadable operator.
    ///
    /// Both `/` and `div` map to `Divide`. Tokens with no matching variant (for example
    /// `<=`, `>=`, `mod`) yield `None`.
    pub fn from_operator_str(op: &str) -> Option<Self> {
        const SPELLINGS: [(&str, OverloadableOperator); 9] = [
            ("+", OverloadableOperator::Add),
            ("-", OverloadableOperator::Subtract),
            ("*", OverloadableOperator::Multiply),
            ("/", OverloadableOperator::Divide),
            ("div", OverloadableOperator::Divide),
            ("=", OverloadableOperator::Equal),
            ("<>", OverloadableOperator::NotEqual),
            ("<", OverloadableOperator::Less),
            (">", OverloadableOperator::Greater),
        ];
        SPELLINGS
            .iter()
            .find(|(token, _)| *token == op)
            .map(|(_, operator)| operator.clone())
    }
}
/// A collection of operator overloads, searched in registration order.
pub struct OperatorRegistry {
    /// Registered overloads, in the order they were registered.
    overloads: Vec<OperatorOverload>,
}
impl OperatorRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            overloads: Vec::new(),
        }
    }

    /// Appends `overload` to the registry.
    ///
    /// Existing overloads with the same signature are not replaced. Because `lookup`
    /// returns the first match, the earliest registration wins.
    pub fn register(&mut self, overload: OperatorOverload) {
        self.overloads.push(overload);
    }

    /// Returns the first registered overload of `op` whose operand types equal `left` and
    /// `right`.
    ///
    /// `right` must match `right_type` exactly, so `None` only matches overloads registered
    /// with `right_type: None`.
    pub fn lookup(
        &self,
        op: &OverloadableOperator,
        left: &Type,
        right: Option<&Type>,
    ) -> Option<&OperatorOverload> {
        // Compare with `PartialEq`. This avoids allocating `Debug` strings for every
        // candidate and does not depend on how `Type` is formatted.
        self.overloads.iter().find(|o| {
            o.operator == *op && o.left_type == *left && o.right_type.as_ref() == right
        })
    }

    /// Returns `true` if `lookup` would find an overload for these operand types.
    pub fn is_overloaded(
        &self,
        op: &OverloadableOperator,
        left: &Type,
        right: Option<&Type>,
    ) -> bool {
        self.lookup(op, left, right).is_some()
    }
}
impl Default for OperatorRegistry {
fn default() -> Self {
Self::new()
}
}
/// A named type class: a set of method signatures plus the types registered as instances.
#[derive(Debug, Clone)]
pub struct TypeClass {
    /// Name of the type class.
    pub name: String,
    /// Methods declared by the class, in declaration order.
    pub methods: Vec<TypeClassMethod>,
    /// Types registered as instances of the class, in registration order.
    pub instances: Vec<TypeClassInstance>,
}
/// A method declared by a type class.
#[derive(Debug, Clone)]
pub struct TypeClassMethod {
    /// Method name.
    pub name: String,
    /// Method signature, expressed as a `Type`.
    pub signature: Type,
}
/// A type registered as an instance of a type class.
#[derive(Debug, Clone)]
pub struct TypeClassInstance {
    /// The instance type.
    pub typ: Type,
    /// Presumably maps each class method name to the name of its implementation for `typ`
    /// — confirm with callers.
    pub implementations: HashMap<String, String>,
}
impl TypeClass {
    /// Creates a type class with no methods and no instances.
    pub fn new(name: String) -> Self {
        Self {
            name,
            methods: Vec::new(),
            instances: Vec::new(),
        }
    }

    /// Declares a method of this type class with the given signature.
    pub fn add_method(&mut self, name: String, signature: Type) {
        self.methods.push(TypeClassMethod { name, signature });
    }

    /// Registers `typ` as an instance of this type class.
    ///
    /// An existing instance for an equal type is not replaced; the new one is appended.
    pub fn add_instance(&mut self, typ: Type, implementations: HashMap<String, String>) {
        self.instances.push(TypeClassInstance {
            typ,
            implementations,
        });
    }

    /// Returns `true` if an instance has been registered for a type equal to `typ`.
    pub fn has_instance(&self, typ: &Type) -> bool {
        // Compare with `PartialEq` rather than `Debug` output. This avoids two string
        // allocations per instance and does not depend on how `Type` is formatted.
        self.instances.iter().any(|inst| inst.typ == *typ)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Instantiating a one-parameter generic yields a `Type::GenericInstance`.
    #[test]
    fn test_generic_type() {
        let element_param = TypeParameter {
            name: "T".to_string(),
            constraints: Vec::new(),
        };
        let mut array = GenericType::new("Array".to_string(), vec![element_param]);
        let instance = array.instantiate(vec![Type::Integer]).unwrap();
        assert!(matches!(instance, Type::GenericInstance { .. }));
    }

    /// An integer literal is inferred as `Type::Integer`.
    #[test]
    fn test_type_inference() {
        let literal = Expr::Literal(crate::ast::Literal::Integer(42));
        let inferred = TypeInference::new().infer_expr(&literal).unwrap();
        assert_eq!(inferred, Type::Integer);
    }

    /// A registered `String + String` overload can be found again.
    #[test]
    fn test_operator_overload() {
        let concat = OperatorOverload {
            operator: OverloadableOperator::Add,
            left_type: Type::String,
            right_type: Some(Type::String),
            return_type: Type::String,
            implementation: "string_concat".to_string(),
        };
        let mut registry = OperatorRegistry::new();
        registry.register(concat);

        let found =
            registry.is_overloaded(&OverloadableOperator::Add, &Type::String, Some(&Type::String));
        assert!(found);
    }
}