use std::collections::HashMap;
use crate::ast::ExprId;
use crate::intern::InternedStr;
use crate::type_repr::TypeRepr;
/// A type constraint recorded against a single expression.
#[derive(Clone, Debug)]
pub struct TypeConstraint {
    /// The expression this constraint applies to.
    pub expr_id: ExprId,
    /// The type the expression is constrained to.
    pub ty: TypeRepr,
    /// Free-form note on where the constraint came from; stored, never interpreted here.
    pub context: String,
}
impl TypeConstraint {
    /// Builds a constraint tying expression `expr_id` to type `ty`.
    ///
    /// `context` is any string-like note describing where the constraint
    /// came from; it is converted into an owned `String` and stored verbatim.
    pub fn new(expr_id: ExprId, ty: TypeRepr, context: impl Into<String>) -> Self {
        let context: String = context.into();
        Self { expr_id, ty, context }
    }

    /// Short provenance label for this constraint's type, delegated to
    /// `TypeRepr::source_display` (e.g. `"c-header"` or `"apidoc"`).
    pub fn source_display(&self) -> &'static str {
        let ty = &self.ty;
        ty.source_display()
    }
}
/// A recorded link from an expression to the parameter it refers to.
///
/// Instances are created in this module by `TypeEnv::link_expr_to_param`.
#[derive(Clone, Debug)]
pub struct ParamLink {
    /// The expression that refers to the parameter.
    pub expr_id: ExprId,
    /// Name of the linked parameter.
    pub param_name: InternedStr,
    /// Free-form note describing the link; stored, never interpreted here.
    pub context: String,
}
/// Collection of type constraints (on parameters, expressions and return
/// values) together with expression-to-parameter links.
#[derive(Clone, Debug, Default)]
pub struct TypeEnv {
    /// Constraints keyed by parameter name, kept in insertion order per parameter.
    pub param_constraints: HashMap<InternedStr, Vec<TypeConstraint>>,
    /// Constraints keyed by the expression they apply to, in insertion order per expression.
    pub expr_constraints: HashMap<ExprId, Vec<TypeConstraint>>,
    /// Return-type constraints, in insertion order.
    pub return_constraints: Vec<TypeConstraint>,
    /// Every expression-to-parameter link, in insertion order (duplicates are kept).
    pub expr_to_param: Vec<ParamLink>,
    /// Reverse index of `expr_to_param`: parameter name -> linked expressions.
    /// Kept in sync by `TypeEnv::link_expr_to_param` and `TypeEnv::merge`.
    pub param_to_exprs: HashMap<InternedStr, Vec<ExprId>>,
}
impl TypeEnv {
    /// Creates an empty environment (same as `TypeEnv::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `constraint` to the list recorded for parameter `param`.
    pub fn add_param_constraint(&mut self, param: InternedStr, constraint: TypeConstraint) {
        self.param_constraints
            .entry(param)
            .or_default()
            .push(constraint);
    }

    /// Appends `constraint` to the list recorded for its own `expr_id`.
    pub fn add_expr_constraint(&mut self, constraint: TypeConstraint) {
        self.expr_constraints
            .entry(constraint.expr_id)
            .or_default()
            .push(constraint);
    }

    /// Alias for [`TypeEnv::add_expr_constraint`].
    pub fn add_constraint(&mut self, constraint: TypeConstraint) {
        self.add_expr_constraint(constraint);
    }

    /// Appends `constraint` to `return_constraints`.
    pub fn add_return_constraint(&mut self, constraint: TypeConstraint) {
        self.return_constraints.push(constraint);
    }

    /// Records that expression `expr_id` refers to parameter `param_name`.
    ///
    /// Updates both the forward list `expr_to_param` and the reverse index
    /// `param_to_exprs`. Repeated links are appended, not deduplicated.
    pub fn link_expr_to_param(&mut self, expr_id: ExprId, param_name: InternedStr, context: impl Into<String>) {
        self.expr_to_param.push(ParamLink {
            expr_id,
            param_name,
            context: context.into(),
        });
        self.param_to_exprs
            .entry(param_name)
            .or_default()
            .push(expr_id);
    }

    /// Returns the constraints recorded for parameter `param`, or `None` if
    /// none were ever added.
    pub fn get_param_constraints(&self, param: InternedStr) -> Option<&Vec<TypeConstraint>> {
        // `HashMap::get` takes the key by reference.
        self.param_constraints.get(&param)
    }

    /// Returns the constraints recorded for expression `expr_id`, or `None`
    /// if none were ever added.
    pub fn get_expr_constraints(&self, expr_id: ExprId) -> Option<&Vec<TypeConstraint>> {
        self.expr_constraints.get(&expr_id)
    }

    /// Returns the parameter linked to `expr_id`, if any.
    ///
    /// If the expression was linked more than once, the earliest link wins.
    /// This is a linear scan over `expr_to_param`.
    pub fn get_linked_param(&self, expr_id: ExprId) -> Option<InternedStr> {
        self.expr_to_param
            .iter()
            .find(|link| link.expr_id == expr_id)
            .map(|link| link.param_name)
    }

    /// Total number of parameter constraints across all parameters.
    pub fn param_constraint_count(&self) -> usize {
        self.param_constraints.values().map(|v| v.len()).sum()
    }

    /// Total number of expression constraints across all expressions.
    pub fn expr_constraint_count(&self) -> usize {
        self.expr_constraints.values().map(|v| v.len()).sum()
    }

    /// Number of return constraints.
    pub fn return_constraint_count(&self) -> usize {
        self.return_constraints.len()
    }

    /// Returns the type of the first recorded return constraint, if any.
    ///
    /// Later return constraints are not consulted.
    pub fn get_return_type(&self) -> Option<&TypeRepr> {
        self.return_constraints.first().map(|c| &c.ty)
    }

    /// Sum of parameter, expression and return constraint counts.
    /// Links are not constraints and are not counted.
    pub fn total_constraint_count(&self) -> usize {
        self.param_constraint_count() + self.expr_constraint_count() + self.return_constraint_count()
    }

    /// Returns `true` when no parameter, expression or return constraints
    /// have been recorded. Links (`expr_to_param` / `param_to_exprs`) are
    /// deliberately not considered.
    pub fn is_empty(&self) -> bool {
        self.param_constraints.is_empty()
            && self.expr_constraints.is_empty()
            && self.return_constraints.is_empty()
    }

    /// Absorbs `other` into `self`.
    ///
    /// For keys present in both environments, the lists are concatenated with
    /// `self`'s entries first. Return constraints and links are appended.
    /// Nothing is deduplicated.
    pub fn merge(&mut self, other: TypeEnv) {
        for (param, constraints) in other.param_constraints {
            self.param_constraints
                .entry(param)
                .or_default()
                .extend(constraints);
        }
        for (expr_id, constraints) in other.expr_constraints {
            self.expr_constraints
                .entry(expr_id)
                .or_default()
                .extend(constraints);
        }
        self.return_constraints.extend(other.return_constraints);
        self.expr_to_param.extend(other.expr_to_param);
        for (param, expr_ids) in other.param_to_exprs {
            self.param_to_exprs
                .entry(param)
                .or_default()
                .extend(expr_ids);
        }
    }

    /// One-line debugging summary.
    ///
    /// `params` and `exprs` count distinct keys, not individual constraints;
    /// use the `*_constraint_count` methods for constraint totals.
    pub fn summary(&self) -> String {
        format!(
            "TypeEnv {{ params: {}, exprs: {}, returns: {}, links: {} }}",
            self.param_constraints.len(),
            self.expr_constraints.len(),
            self.return_constraints.len(),
            self.expr_to_param.len(),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intern::StringInterner;
    use crate::type_repr::{CTypeSource, CTypeSpecs, InferredType, IntSize, RustTypeRepr, RustTypeSource};

    fn c_int_type() -> TypeRepr {
        TypeRepr::CType {
            specs: CTypeSpecs::Int { signed: true, size: IntSize::Int },
            derived: vec![],
            source: CTypeSource::Header,
        }
    }

    fn rust_c_int_type() -> TypeRepr {
        TypeRepr::RustType {
            repr: RustTypeRepr::from_type_string("c_int"),
            source: RustTypeSource::FnParam { func_name: "test".to_string(), param_index: 0 },
        }
    }

    fn apidoc_sv_ptr_type() -> TypeRepr {
        let interner = StringInterner::new();
        TypeRepr::from_apidoc_string("SV *", &interner)
    }

    #[test]
    fn test_type_env_new() {
        let env = TypeEnv::new();
        assert!(env.is_empty());
        assert_eq!(env.total_constraint_count(), 0);
    }

    #[test]
    fn test_add_expr_constraint() {
        let mut env = TypeEnv::new();
        let expr_id = ExprId::next();
        let constraint = TypeConstraint::new(
            expr_id,
            c_int_type(),
            "test context",
        );
        env.add_expr_constraint(constraint);
        assert!(!env.is_empty());
        assert_eq!(env.expr_constraint_count(), 1);
        assert_eq!(env.get_expr_constraints(expr_id).unwrap().len(), 1);
    }

    #[test]
    fn test_add_multiple_constraints() {
        let mut env = TypeEnv::new();
        let expr_id = ExprId::next();
        env.add_constraint(TypeConstraint::new(
            expr_id,
            c_int_type(),
            "from C header",
        ));
        env.add_constraint(TypeConstraint::new(
            expr_id,
            rust_c_int_type(),
            "from bindings",
        ));
        let constraints = env.get_expr_constraints(expr_id).unwrap();
        assert_eq!(constraints.len(), 2);
        assert_eq!(constraints[0].source_display(), "c-header");
        assert_eq!(constraints[1].source_display(), "rust-bindings");
    }

    #[test]
    fn test_add_param_constraint() {
        let mut env = TypeEnv::new();
        let mut interner = StringInterner::new();
        let param = interner.intern("sv");
        let unrelated = interner.intern("other");
        env.add_param_constraint(
            param,
            TypeConstraint::new(ExprId::next(), apidoc_sv_ptr_type(), "param type from apidoc"),
        );
        env.add_param_constraint(
            param,
            TypeConstraint::new(ExprId::next(), c_int_type(), "param type from C header"),
        );
        assert!(!env.is_empty());
        assert_eq!(env.param_constraint_count(), 2);
        assert_eq!(env.total_constraint_count(), 2);
        assert_eq!(env.get_param_constraints(param).map(Vec::len), Some(2));
        assert!(env.get_param_constraints(unrelated).is_none());
    }

    #[test]
    fn test_link_expr_to_param() {
        let mut env = TypeEnv::new();
        let expr_id = ExprId::next();
        let mut interner = StringInterner::new();
        let param_name = interner.intern("x");
        env.link_expr_to_param(expr_id, param_name, "parameter reference");
        assert_eq!(env.get_linked_param(expr_id), Some(param_name));
    }

    #[test]
    fn test_link_updates_reverse_index() {
        let mut env = TypeEnv::new();
        let mut interner = StringInterner::new();
        let param_name = interner.intern("x");
        let first = ExprId::next();
        let second = ExprId::next();
        env.link_expr_to_param(first, param_name, "first use");
        env.link_expr_to_param(second, param_name, "second use");
        assert_eq!(env.param_to_exprs.get(&param_name), Some(&vec![first, second]));
        assert_eq!(env.get_linked_param(second), Some(param_name));
        assert_eq!(env.get_linked_param(ExprId::next()), None);
        // Links alone are not constraints.
        assert!(env.is_empty());
        assert_eq!(env.total_constraint_count(), 0);
    }

    #[test]
    fn test_merge() {
        let mut env1 = TypeEnv::new();
        let mut env2 = TypeEnv::new();
        let expr1 = ExprId::next();
        let expr2 = ExprId::next();
        env1.add_constraint(TypeConstraint::new(
            expr1,
            c_int_type(),
            "env1",
        ));
        env2.add_constraint(TypeConstraint::new(
            expr2,
            TypeRepr::CType {
                specs: CTypeSpecs::Char { signed: None },
                derived: vec![],
                source: CTypeSource::Apidoc { raw: "char".to_string() },
            },
            "env2",
        ));
        env1.merge(env2);
        assert_eq!(env1.expr_constraint_count(), 2);
        assert!(env1.get_expr_constraints(expr1).is_some());
        assert!(env1.get_expr_constraints(expr2).is_some());
    }

    #[test]
    fn test_merge_combines_params_returns_and_links() {
        let mut interner = StringInterner::new();
        let param = interner.intern("p");
        let mut env1 = TypeEnv::new();
        let mut env2 = TypeEnv::new();
        let e1 = ExprId::next();
        let e2 = ExprId::next();
        env1.add_param_constraint(param, TypeConstraint::new(e1, c_int_type(), "env1 param"));
        env2.add_param_constraint(param, TypeConstraint::new(e2, rust_c_int_type(), "env2 param"));
        env1.link_expr_to_param(e1, param, "env1 link");
        env2.link_expr_to_param(e2, param, "env2 link");
        env2.add_return_constraint(TypeConstraint::new(e2, apidoc_sv_ptr_type(), "env2 return"));
        env1.merge(env2);
        assert_eq!(env1.get_param_constraints(param).map(Vec::len), Some(2));
        assert_eq!(env1.return_constraint_count(), 1);
        assert_eq!(env1.get_return_type().map(|t| t.source_display()), Some("apidoc"));
        assert_eq!(env1.param_to_exprs.get(&param), Some(&vec![e1, e2]));
        assert_eq!(env1.get_linked_param(e2), Some(param));
        assert_eq!(env1.total_constraint_count(), 3);
    }

    #[test]
    fn test_return_constraints() {
        let mut env = TypeEnv::new();
        let expr_id = ExprId::next();
        env.add_return_constraint(TypeConstraint::new(
            expr_id,
            apidoc_sv_ptr_type(),
            "return type from apidoc",
        ));
        assert_eq!(env.return_constraint_count(), 1);
        assert_eq!(env.return_constraints[0].source_display(), "apidoc");
    }

    #[test]
    fn test_summary_counts_distinct_keys() {
        let mut env = TypeEnv::new();
        let expr_id = ExprId::next();
        env.add_constraint(TypeConstraint::new(expr_id, c_int_type(), "a"));
        env.add_constraint(TypeConstraint::new(expr_id, rust_c_int_type(), "b"));
        assert_eq!(
            env.summary(),
            "TypeEnv { params: 0, exprs: 1, returns: 0, links: 0 }"
        );
    }

    #[test]
    fn test_type_repr_source_display() {
        assert_eq!(c_int_type().source_display(), "c-header");
        assert_eq!(apidoc_sv_ptr_type().source_display(), "apidoc");
        assert_eq!(rust_c_int_type().source_display(), "rust-bindings");
        let inferred = TypeRepr::Inferred(InferredType::IntLiteral);
        assert_eq!(inferred.source_display(), "inferred");
    }
}