use super::stream::{BinaryOp, Expression, LiteralValue, UnaryOp};
use crate::core::{PrimitiveType, SharedUniverse, TypeId, TypeKind};
use im::HashMap as ImHashMap;
/// Expression type inferencer backed by a shared type universe and a
/// persistent map of name-to-type bindings.
pub struct TypeInference {
/// Shared universe queried for type definitions and symbol lookups.
universe: SharedUniverse,
/// Bindings recorded through `bind`; identifier resolution consults these
/// after the bindings carried by the `LookaheadContext`.
bindings: ImHashMap<String, TypeId>,
}
/// Contextual hints passed alongside an expression when its type is inferred.
///
/// The inference methods defined next to this type read `scope_bindings`
/// (identifier resolution) and `assignment_target` (fallback / prediction).
/// The other fields are stored but not read by any code visible here.
#[derive(Debug, Clone, Default)]
pub struct LookaheadContext {
/// NOTE(review): presumably the return type expected by the enclosing
/// function — not read by the visible code; confirm against callers.
pub expected_return: Option<TypeId>,
/// Type expected at the current position; used as the fallback result of
/// inference and as the single prediction of `lookahead_predict`.
pub assignment_target: Option<TypeId>,
/// Name-to-type bindings that take precedence over the inferencer's own
/// bindings when an identifier is resolved.
pub scope_bindings: ImHashMap<String, TypeId>,
/// Parameter types stored by `in_call`.
pub call_params: Vec<TypeId>,
/// Index stored by `in_call`; presumably the position of the argument being
/// typed within `call_params` — TODO confirm.
pub param_index: usize,
/// NOTE(review): presumably the type of the preceding expression — not
/// read by the visible code; confirm.
pub previous_type: Option<TypeId>,
/// NOTE(review): presumably types of neighbouring expressions — not read
/// by the visible code; confirm.
pub sibling_types: Vec<TypeId>,
}
impl LookaheadContext {
    /// Creates an empty context; every field takes its `Default` value.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a context whose `assignment_target` is `expected`, leaving all
    /// other fields at their defaults.
    pub fn expecting(expected: TypeId) -> Self {
        Self {
            assignment_target: Some(expected),
            ..Default::default()
        }
    }

    /// Returns the context with `name` bound to `typ` in `scope_bindings`,
    /// replacing any binding already present under the same name.
    pub fn with_binding(mut self, name: impl Into<String>, typ: TypeId) -> Self {
        // The context is owned here, so the map can be modified in place
        // rather than building an updated copy and dropping the old one.
        self.scope_bindings.insert(name.into(), typ);
        self
    }

    /// Returns the context with `call_params` set to `params` and
    /// `param_index` set to `index`.
    pub fn in_call(mut self, params: Vec<TypeId>, index: usize) -> Self {
        self.call_params = params;
        self.param_index = index;
        self
    }
}
/// An inferred type together with a confidence value and a list of
/// alternative candidates.
///
/// NOTE(review): nothing in the visible code constructs or consumes this
/// type, so the value ranges below are assumptions to confirm.
#[derive(Debug, Clone)]
pub struct InferenceResult {
/// The inferred type.
pub type_id: TypeId,
/// Confidence attached to `type_id`; presumably in `0.0..=1.0` — TODO confirm.
pub confidence: f32,
/// Other candidate types, each paired with its own score.
pub alternatives: Vec<(TypeId, f32)>,
}
impl TypeInference {
    /// Creates an inferencer over `universe` with no recorded bindings.
    pub fn new(universe: SharedUniverse) -> Self {
        Self {
            universe,
            bindings: ImHashMap::new(),
        }
    }

    /// Infers the type of `expr`, using `ctx` for scope bindings and as a
    /// fallback.
    ///
    /// Type assertions and composite literals yield their declared type
    /// directly. Expression kinds not listed here fall back to
    /// `ctx.assignment_target`. Returns `None` when no type can be determined.
    pub fn infer(&self, expr: &Expression, ctx: &LookaheadContext) -> Option<TypeId> {
        match expr {
            Expression::Identifier(name) => self.infer_identifier(name, ctx),
            Expression::Literal(lit) => self.infer_literal(lit),
            Expression::Binary { op, left, right } => self.infer_binary_op(op, left, right, ctx),
            Expression::Unary { op, operand } => self.infer_unary_op(op, operand, ctx),
            Expression::Call { func, args } => self.infer_call(func, args, ctx),
            Expression::Selector { base, field } => self.infer_selector(base, field, ctx),
            Expression::TypeAssertion { typ, .. } => Some(*typ),
            Expression::Composite { typ, .. } => Some(*typ),
            _ => ctx.assignment_target,
        }
    }

    /// Resolves the type of the identifier `name`.
    ///
    /// Lookup order: the context's `scope_bindings`, then this inferencer's
    /// own bindings, then the universe's symbol table. A name unknown to all
    /// three yields `None`. A name found only in the symbol table yields
    /// `ctx.assignment_target`.
    fn infer_identifier(&self, name: &str, ctx: &LookaheadContext) -> Option<TypeId> {
        let bound = ctx
            .scope_bindings
            .get(name)
            .or_else(|| self.bindings.get(name))
            .copied();
        if bound.is_some() {
            return bound;
        }

        let symbol = self.universe.symbols().lookup(None, name)?;
        // NOTE(review): the entity is looked up but its type is not used yet;
        // the context's assignment target stands in for it.
        let _entity = self.universe.lookup_by_symbol(symbol);
        ctx.assignment_target
    }

    /// Infers the type of a literal.
    ///
    /// Every literal kind currently resolves to `TypeId(2)`, and only if the
    /// universe has a type registered under that id. The per-literal
    /// `PrimitiveType` computed below does not influence the result.
    ///
    /// NOTE(review): this looks like a placeholder until untyped-constant
    /// primitives are mapped to real type ids — confirm.
    fn infer_literal(&self, lit: &LiteralValue) -> Option<TypeId> {
        let _prim = match lit {
            LiteralValue::Int(_) => PrimitiveType::UntypedInt,
            LiteralValue::Float(_) => PrimitiveType::UntypedFloat,
            LiteralValue::String(_) => PrimitiveType::UntypedString,
            LiteralValue::Bool(_) => PrimitiveType::UntypedBool,
            LiteralValue::Nil => PrimitiveType::UntypedNil,
        };
        self.universe.get_type(TypeId(2)).map(|_| TypeId(2))
    }

    /// Infers the result type of a binary operation.
    ///
    /// Both operands must be inferable, otherwise the result is `None`.
    /// Arithmetic operators go through `unify_numeric`. Comparison operators
    /// yield `TypeId(1)` if the universe has that id registered (presumably
    /// the boolean type — TODO confirm). All other operators yield the left
    /// operand's type.
    fn infer_binary_op(
        &self,
        op: &BinaryOp,
        left: &Expression,
        right: &Expression,
        ctx: &LookaheadContext,
    ) -> Option<TypeId> {
        let left_type = self.infer(left, ctx)?;
        let right_type = self.infer(right, ctx)?;
        match op {
            BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {
                self.unify_numeric(left_type, right_type)
            }
            BinaryOp::And | BinaryOp::Or => Some(left_type),
            BinaryOp::Eq
            | BinaryOp::Ne
            | BinaryOp::Lt
            | BinaryOp::Le
            | BinaryOp::Gt
            | BinaryOp::Ge => self.universe.get_type(TypeId(1)).map(|_| TypeId(1)),
            _ => Some(left_type),
        }
    }

    /// Infers the result type of a unary operation.
    ///
    /// The operand must be inferable, otherwise the result is `None`. `Not`
    /// yields `TypeId(1)` if the universe has that id registered. Every other
    /// operator yields the operand's own type.
    ///
    /// NOTE(review): `Addr` also returns the operand's type unchanged; no
    /// pointer type is constructed here — confirm this is intended.
    fn infer_unary_op(
        &self,
        op: &UnaryOp,
        operand: &Expression,
        ctx: &LookaheadContext,
    ) -> Option<TypeId> {
        let operand_type = self.infer(operand, ctx)?;
        match op {
            UnaryOp::Not => self.universe.get_type(TypeId(1)).map(|_| TypeId(1)),
            // Neg, Pos, Addr and any remaining operators keep the operand type.
            _ => Some(operand_type),
        }
    }

    /// Infers the type produced by calling `func`.
    ///
    /// Yields the type of the callee's first result. Returns `None` when the
    /// callee cannot be inferred, is not registered in the universe, is not a
    /// function type, or has no results. The arguments are not inspected.
    fn infer_call(
        &self,
        func: &Expression,
        _args: &[Expression],
        ctx: &LookaheadContext,
    ) -> Option<TypeId> {
        let func_type = self.infer(func, ctx)?;
        let typ = self.universe.get_type(func_type)?;
        match &typ.kind {
            TypeKind::Func { results, .. } => results.first().map(|r| r.typ),
            _ => None,
        }
    }

    /// Infers the type of selecting `field` on `base`.
    ///
    /// Yields the type of the struct field whose name equals `field`. Returns
    /// `None` when the base cannot be inferred, is not registered in the
    /// universe, is not a struct type, or has no such field.
    fn infer_selector(
        &self,
        base: &Expression,
        field: &str,
        ctx: &LookaheadContext,
    ) -> Option<TypeId> {
        let base_type = self.infer(base, ctx)?;
        let typ = self.universe.get_type(base_type)?;
        match &typ.kind {
            TypeKind::Struct { fields } => fields
                .iter()
                .find(|f| f.name.as_ref() == field)
                .map(|f| f.typ),
            _ => None,
        }
    }

    /// Picks the result type for an arithmetic operation on two operand types.
    ///
    /// Always yields the left operand's type; the right operand does not
    /// influence the result.
    ///
    /// NOTE(review): presumably a placeholder for real numeric unification —
    /// confirm.
    fn unify_numeric(&self, left: TypeId, _right: TypeId) -> Option<TypeId> {
        Some(left)
    }

    /// Predicts candidate types for a partially typed expression.
    ///
    /// Yields a single candidate, `ctx.assignment_target` with score `0.9`,
    /// when the context has an assignment target, and an empty list
    /// otherwise. `_partial` is not inspected.
    pub fn lookahead_predict(&self, _partial: &str, ctx: &LookaheadContext) -> Vec<(TypeId, f32)> {
        ctx.assignment_target
            .into_iter()
            .map(|expected| (expected, 0.9_f32))
            .collect()
    }

    /// Infers the type of the parameter at `param_index` of the function
    /// denoted by `func_expr`.
    ///
    /// Returns `None` when the callee cannot be inferred, is not registered in
    /// the universe, is not a function type, or has no parameter at that
    /// index.
    pub fn infer_param_type(
        &self,
        func_expr: &Expression,
        param_index: usize,
        ctx: &LookaheadContext,
    ) -> Option<TypeId> {
        let func_type = self.infer(func_expr, ctx)?;
        let typ = self.universe.get_type(func_type)?;
        match &typ.kind {
            TypeKind::Func { params, .. } => params.get(param_index).map(|p| p.typ),
            _ => None,
        }
    }

    /// Records that `name` has type `typ`, replacing any earlier binding for
    /// the same name.
    pub fn bind(&mut self, name: impl Into<String>, typ: TypeId) {
        // Exclusive access is already held, so insert in place instead of
        // building an updated copy of the map and reassigning it.
        self.bindings.insert(name.into(), typ);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::TypeUniverse;
    use std::sync::Arc;

    /// Builds an inferencer over a freshly created universe.
    fn setup_inference() -> TypeInference {
        TypeInference::new(Arc::new(TypeUniverse::new()))
    }

    /// An integer literal must resolve to some type under a default context.
    #[test]
    fn test_literal_inference() {
        let engine = setup_inference();
        let ctx = LookaheadContext::default();
        let inferred = engine.infer(&Expression::Literal(LiteralValue::Int(42)), &ctx);
        assert!(inferred.is_some());
    }

    /// The builder methods must store the binding and the call position.
    #[test]
    fn test_lookahead_context() {
        let params = vec![TypeId(2), TypeId(3)];
        let ctx = LookaheadContext::new()
            .with_binding("x", TypeId(1))
            .in_call(params, 0);
        assert_eq!(ctx.scope_bindings.get("x"), Some(&TypeId(1)));
        assert_eq!(ctx.param_index, 0);
    }
}