use crate::flow::cfg::{BasicBlock, BlockId, CFG, Terminator};
use crate::flow::dataflow::{DataflowResult, Direction, TransferFunction, find_node_by_id};
use crate::flow::symbol_table::{SymbolTable, ValueOrigin};
use crate::semantics::LanguageSemantics;
use std::collections::{HashMap, HashSet};
/// A coarse, language-agnostic type lattice produced by the heuristic inference.
///
/// `Unknown` means "no information"; [`InferredType::union`] treats it as the
/// identity element.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum InferredType {
    String,
    Number,
    Boolean,
    /// The `null` / `nil` / `None` literal type.
    Null,
    /// The `undefined` literal type.
    Undefined,
    /// Array/list with the given element type.
    Array(Box<InferredType>),
    Object,
    Function,
    /// Nullable wrapper around an inner type (displayed as `T?`).
    Optional(Box<InferredType>),
    /// Union of alternatives; [`InferredType::simplify`] keeps it flat and duplicate-free.
    Union(Vec<InferredType>),
    Unknown,
}
impl InferredType {
    /// Returns `true` if a value of this type may be null/undefined.
    ///
    /// `Null`, `Undefined` and `Optional(_)` are nullable, and so is any union
    /// containing a nullable member.
    pub fn is_nullable(&self) -> bool {
        matches!(
            self,
            InferredType::Null | InferredType::Undefined | InferredType::Optional(_)
        ) || matches!(self, InferredType::Union(types) if types.iter().any(|t| t.is_nullable()))
    }
    /// Returns `true` for `String`, `Number` and `Boolean`.
    pub fn is_primitive(&self) -> bool {
        matches!(
            self,
            InferredType::String | InferredType::Number | InferredType::Boolean
        )
    }
    /// Normalises a union type.
    ///
    /// Nested unions are flattened at any depth, and duplicate members are
    /// removed, keeping first-occurrence order. That includes duplicates
    /// contributed by a nested union. A single-member union collapses to that
    /// member, and an empty union becomes `Unknown`. Non-union types are
    /// returned unchanged.
    pub fn simplify(self) -> Self {
        /// Appends `t` unless an equal member is already present.
        fn push_unique(members: &mut Vec<InferredType>, t: InferredType) {
            if !members.contains(&t) {
                members.push(t);
            }
        }
        match self {
            InferredType::Union(types) => {
                let mut flat = Vec::with_capacity(types.len());
                for member in types {
                    match member.simplify() {
                        // A simplified nested union is flat, but its members may
                        // repeat ones already collected, so dedup each of them.
                        InferredType::Union(inner) => {
                            for t in inner {
                                push_unique(&mut flat, t);
                            }
                        }
                        other => push_unique(&mut flat, other),
                    }
                }
                match flat.len() {
                    0 => InferredType::Unknown,
                    1 => flat.pop().expect("length checked above"),
                    _ => InferredType::Union(flat),
                }
            }
            other => other,
        }
    }
    /// Least upper bound of two types.
    ///
    /// Equal types are returned as-is. `Unknown` is the identity. Anything else
    /// becomes a simplified union of both sides.
    pub fn union(self, other: InferredType) -> InferredType {
        if self == other {
            return self;
        }
        match (self, other) {
            (InferredType::Unknown, other) | (other, InferredType::Unknown) => other,
            (InferredType::Union(mut a), InferredType::Union(b)) => {
                a.extend(b);
                InferredType::Union(a).simplify()
            }
            (InferredType::Union(mut a), other) | (other, InferredType::Union(mut a)) => {
                a.push(other);
                InferredType::Union(a).simplify()
            }
            (a, b) => InferredType::Union(vec![a, b]).simplify(),
        }
    }
    /// Wraps the type in `Optional` unless it is already nullable.
    pub fn make_optional(self) -> InferredType {
        if self.is_nullable() {
            self
        } else {
            InferredType::Optional(Box::new(self))
        }
    }
    /// Strips one `Optional` layer; other types are returned unchanged.
    pub fn unwrap_optional(&self) -> &InferredType {
        match self {
            InferredType::Optional(inner) => inner,
            other => other,
        }
    }
}
impl std::fmt::Display for InferredType {
    /// Renders the type in a TypeScript-like notation (`Array<T>`, `T?`, `A | B`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            InferredType::String => write!(f, "String"),
            InferredType::Number => write!(f, "Number"),
            InferredType::Boolean => write!(f, "Boolean"),
            InferredType::Null => write!(f, "null"),
            InferredType::Undefined => write!(f, "undefined"),
            InferredType::Array(elem) => write!(f, "Array<{}>", elem),
            InferredType::Object => write!(f, "Object"),
            InferredType::Function => write!(f, "Function"),
            InferredType::Optional(inner) => write!(f, "{}?", inner),
            InferredType::Union(types) => {
                let type_strs: Vec<String> = types.iter().map(|t| t.to_string()).collect();
                write!(f, "{}", type_strs.join(" | "))
            }
            InferredType::Unknown => write!(f, "unknown"),
        }
    }
}
/// Nullability lattice used by the analysis, with `Unknown` as an absorbing element.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Nullability {
    DefinitelyNonNull,
    PossiblyNull,
    DefinitelyNull,
    Unknown,
}
impl Nullability {
    /// Joins the nullability of two control-flow paths.
    ///
    /// Identical inputs are returned unchanged and `Unknown` absorbs everything.
    /// Any other pair of distinct known states can only mean "might be null".
    pub fn merge(self, other: Nullability) -> Nullability {
        if self == other {
            return self;
        }
        if self == Nullability::Unknown || other == Nullability::Unknown {
            return Nullability::Unknown;
        }
        // Remaining pairs are two distinct members of
        // {DefinitelyNonNull, PossiblyNull, DefinitelyNull}.
        Nullability::PossiblyNull
    }
    /// `true` unless the value is known to be non-null (unknown counts as nullable).
    pub fn could_be_null(&self) -> bool {
        !matches!(self, Nullability::DefinitelyNonNull)
    }
    /// `true` only for `DefinitelyNonNull`.
    pub fn is_definitely_non_null(&self) -> bool {
        *self == Nullability::DefinitelyNonNull
    }
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TypeInfo {
pub inferred_type: InferredType,
pub nullability: Nullability,
}
impl TypeInfo {
pub fn new(inferred_type: InferredType) -> Self {
let nullability = if inferred_type.is_nullable() {
Nullability::PossiblyNull
} else {
Nullability::DefinitelyNonNull
};
Self {
inferred_type,
nullability,
}
}
pub fn null() -> Self {
Self {
inferred_type: InferredType::Null,
nullability: Nullability::DefinitelyNull,
}
}
pub fn undefined() -> Self {
Self {
inferred_type: InferredType::Undefined,
nullability: Nullability::DefinitelyNull,
}
}
pub fn with_nullability(inferred_type: InferredType, nullability: Nullability) -> Self {
Self {
inferred_type,
nullability,
}
}
pub fn unknown() -> Self {
Self {
inferred_type: InferredType::Unknown,
nullability: Nullability::Unknown,
}
}
pub fn merge(self, other: TypeInfo) -> TypeInfo {
TypeInfo {
inferred_type: self.inferred_type.union(other.inferred_type),
nullability: self.nullability.merge(other.nullability),
}
}
}
impl Default for TypeInfo {
fn default() -> Self {
Self::unknown()
}
}
/// Mapping from variable name to its inferred [`TypeInfo`].
#[derive(Debug, Clone, Default)]
pub struct TypeTable {
    types: HashMap<String, TypeInfo>,
}
impl TypeTable {
    /// Creates an empty table.
    pub fn new() -> Self {
        Self::default()
    }
    /// Looks up the entry for `name`.
    pub fn get(&self, name: &str) -> Option<&TypeInfo> {
        self.types.get(name)
    }
    /// Returns a clone of the entry for `name`, or [`TypeInfo::unknown`] if absent.
    pub fn get_or_unknown(&self, name: &str) -> TypeInfo {
        self.types
            .get(name)
            .cloned()
            .unwrap_or_else(TypeInfo::unknown)
    }
    /// Inserts or replaces the entry for `name`.
    pub fn set(&mut self, name: String, info: TypeInfo) {
        self.types.insert(name, info);
    }
    /// Removes and returns the entry for `name`, if present.
    pub fn remove(&mut self, name: &str) -> Option<TypeInfo> {
        self.types.remove(name)
    }
    /// Returns `true` if `name` has an entry.
    pub fn contains(&self, name: &str) -> bool {
        self.types.contains_key(name)
    }
    /// Returns just the inferred type for `name`, if present.
    pub fn get_type(&self, name: &str) -> Option<&InferredType> {
        self.types.get(name).map(|info| &info.inferred_type)
    }
    /// Nullability of `name`; `Unknown` when the name has no entry.
    pub fn get_nullability(&self, name: &str) -> Nullability {
        self.types
            .get(name)
            .map(|info| info.nullability)
            .unwrap_or(Nullability::Unknown)
    }
    /// `true` only if `name` has an entry that is `DefinitelyNull`.
    pub fn is_definitely_null(&self, name: &str) -> bool {
        self.types
            .get(name)
            .map(|info| info.nullability == Nullability::DefinitelyNull)
            .unwrap_or(false)
    }
    /// `true` if `name` could be null. A missing entry is treated
    /// conservatively as possibly null.
    pub fn is_possibly_null(&self, name: &str) -> bool {
        self.types
            .get(name)
            .map(|info| info.nullability.could_be_null())
            .unwrap_or(true)
    }
    /// `true` only if `name` has an entry that is `DefinitelyNonNull`.
    pub fn is_definitely_non_null(&self, name: &str) -> bool {
        self.types
            .get(name)
            .map(|info| info.nullability.is_definitely_non_null())
            .unwrap_or(false)
    }
    /// Joins `other` into `self`.
    ///
    /// Names present in both tables get `self_info.merge(other_info)`. Names
    /// present only in `other` are copied over.
    pub fn merge(&mut self, other: &TypeTable) {
        for (name, other_info) in &other.types {
            // Single hash lookup per name; `mem::take` moves the existing info
            // out instead of cloning it before the merge.
            self.types
                .entry(name.clone())
                .and_modify(|existing| {
                    *existing = std::mem::take(existing).merge(other_info.clone());
                })
                .or_insert_with(|| other_info.clone());
        }
    }
    /// Iterates over all `(name, info)` entries in arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &TypeInfo)> {
        self.types.iter()
    }
    /// Iterates over all names in arbitrary order.
    pub fn names(&self) -> impl Iterator<Item = &String> {
        self.types.keys()
    }
}
/// Dataflow fact binding one variable name to the type information inferred for it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TypeFact {
    pub var_name: String,
    pub type_info: TypeInfo,
}
impl TypeFact {
    /// Creates a fact from anything convertible into the variable name.
    pub fn new(var_name: impl Into<String>, type_info: TypeInfo) -> Self {
        let var_name: String = var_name.into();
        TypeFact {
            type_info,
            var_name,
        }
    }
}
/// Heuristic, syntax-directed type and nullability inference for expression nodes.
///
/// Node kinds are recognised through [`LanguageSemantics`] predicates, with a set
/// of hard-coded tree-sitter kind names as a fallback. Results are best-effort
/// approximations rather than sound types.
pub struct TypeInferrer {
    semantics: &'static LanguageSemantics,
    /// Callee names, matched against the last `.`-separated segment of the callee
    /// text, whose result may be absent (null/undefined/None).
    nullable_functions: HashSet<&'static str>,
}
impl TypeInferrer {
    /// Creates an inferrer for the given language semantics.
    pub fn new(semantics: &'static LanguageSemantics) -> Self {
        // Each name appears once. `findIndex` is deliberately absent: it returns
        // -1 when nothing matches, never null/undefined.
        const NULLABLE_FUNCTIONS: &[&str] = &[
            "find",
            "get",
            "getAttribute",
            "getElementById",
            "querySelector",
            "match",
            "exec",
            "pop",
            "shift",
            "first",
            "last",
            "ok",
            "err",
            "Get",
            "findFirst",
            "orElse",
        ];
        Self {
            semantics,
            nullable_functions: NULLABLE_FUNCTIONS.iter().copied().collect(),
        }
    }
    /// Infers the type of an expression node.
    ///
    /// Handled cases: literals, array and object literals, functions, calls,
    /// binary expressions, member accesses (unknown type, possibly null),
    /// parenthesised and `await` expressions (pass-through), and ternaries
    /// (merge of both arms). Anything else is [`TypeInfo::unknown`].
    pub fn infer_type(&self, node: tree_sitter::Node, source: &[u8]) -> TypeInfo {
        let kind = node.kind();
        let sem = self.semantics;
        if sem.is_string_literal(kind)
            || kind == "string"
            || kind == "template_string"
            || kind == "template_literal"
        {
            return TypeInfo::new(InferredType::String);
        }
        if sem.is_numeric_literal(kind)
            || kind == "number"
            || kind == "integer"
            || kind == "float"
            || kind == "integer_literal"
            || kind == "float_literal"
        {
            return TypeInfo::new(InferredType::Number);
        }
        if sem.is_boolean_literal(kind) || kind == "true" || kind == "false" {
            return TypeInfo::new(InferredType::Boolean);
        }
        if sem.is_null_literal(kind)
            || kind == "null"
            || kind == "nil"
            || kind == "None"
            || kind == "none"
        {
            return TypeInfo::null();
        }
        if kind == "undefined" {
            return TypeInfo::undefined();
        }
        if kind == "array" || kind == "array_expression" || kind == "list" {
            // Element type is approximated from the first element only.
            let elem_type = node
                .named_child(0)
                .map(|child| self.infer_type(child, source).inferred_type)
                .unwrap_or(InferredType::Unknown);
            return TypeInfo::new(InferredType::Array(Box::new(elem_type)));
        }
        if kind == "object"
            || kind == "object_expression"
            || kind == "dictionary"
            || kind == "dict"
            || kind == "map_literal"
        {
            return TypeInfo::new(InferredType::Object);
        }
        if sem.is_function_def(kind)
            || kind == "arrow_function"
            || kind == "function_expression"
            || kind == "lambda"
            || kind == "closure_expression"
        {
            return TypeInfo::new(InferredType::Function);
        }
        if sem.is_call(kind) {
            return self.infer_call_type(node, source);
        }
        if sem.is_binary_expression(kind) {
            return self.infer_binary_type(node, source);
        }
        if sem.is_member_access(kind) {
            return TypeInfo::with_nullability(InferredType::Unknown, Nullability::PossiblyNull);
        }
        if sem.is_identifier(kind) || kind == "identifier" {
            // Identifiers are resolved by the dataflow pass, not here.
            return TypeInfo::unknown();
        }
        if kind == "parenthesized_expression"
            && let Some(inner) = node.named_child(0)
        {
            return self.infer_type(inner, source);
        }
        if kind == "await_expression"
            && let Some(inner) = node.named_child(0)
        {
            return self.infer_type(inner, source);
        }
        if kind == "ternary_expression" || kind == "conditional_expression" {
            let consequence = node.child_by_field_name("consequence");
            let alternative = node.child_by_field_name("alternative");
            if let (Some(c), Some(a)) = (consequence, alternative) {
                let c_type = self.infer_type(c, source);
                let a_type = self.infer_type(a, source);
                return c_type.merge(a_type);
            }
        }
        TypeInfo::unknown()
    }
    /// Infers the result of a call expression from the callee's source text.
    ///
    /// The checks run in this order:
    /// 1. Known nullable methods.
    /// 2. Built-in conversions and collection helpers.
    /// 3. The constructor heuristic: a `new ...` or capitalised callee yields `Object`.
    ///
    /// Conversions must precede the constructor heuristic because `Number`,
    /// `String`, `Boolean` and `Array` are capitalised as well.
    fn infer_call_type(&self, node: tree_sitter::Node, source: &[u8]) -> TypeInfo {
        let Some(func) = node
            .child_by_field_name(self.semantics.function_field)
            .or_else(|| node.named_child(0))
        else {
            return TypeInfo::unknown();
        };
        let func_text = func.utf8_text(source).unwrap_or("");
        // Last `.`-separated segment: `method` for `obj.method`, else the whole text.
        let method_name = func_text.rsplit('.').next().unwrap_or(func_text);
        if self.nullable_functions.contains(method_name) {
            return TypeInfo::with_nullability(InferredType::Unknown, Nullability::PossiblyNull);
        }
        if func_text == "Number" || matches!(method_name, "parseInt" | "parseFloat") {
            return TypeInfo::new(InferredType::Number);
        }
        if func_text == "String" || method_name == "toString" {
            return TypeInfo::new(InferredType::String);
        }
        if func_text == "Boolean" {
            return TypeInfo::new(InferredType::Boolean);
        }
        if func_text == "Array" || func_text.ends_with(".map") || func_text.ends_with(".filter")
        {
            return TypeInfo::new(InferredType::Array(Box::new(InferredType::Unknown)));
        }
        if func_text.starts_with("new ")
            || func_text.chars().next().is_some_and(char::is_uppercase)
        {
            return TypeInfo::new(InferredType::Object);
        }
        TypeInfo::unknown()
    }
    /// Infers the result type of a binary expression from its operator.
    ///
    /// If there is no operator field, the first child whose kind mentions
    /// "operator" or is at most 3 characters long is used as the operator.
    fn infer_binary_type(&self, node: tree_sitter::Node, source: &[u8]) -> TypeInfo {
        let operator = node
            .child_by_field_name(self.semantics.operator_field)
            .or_else(|| {
                let mut cursor = node.walk();
                node.children(&mut cursor)
                    .find(|c| c.kind().contains("operator") || c.kind().len() <= 3)
            });
        let op_text = operator
            .and_then(|op| op.utf8_text(source).ok())
            .unwrap_or("");
        let left = node.child_by_field_name(self.semantics.left_field);
        let right = node.child_by_field_name(self.semantics.right_field);
        match op_text {
            "+" => {
                // String concatenation if either side is known to be a string.
                if let Some(l) = left {
                    let l_type = self.infer_type(l, source);
                    if l_type.inferred_type == InferredType::String {
                        return TypeInfo::new(InferredType::String);
                    }
                }
                if let Some(r) = right {
                    let r_type = self.infer_type(r, source);
                    if r_type.inferred_type == InferredType::String {
                        return TypeInfo::new(InferredType::String);
                    }
                }
                TypeInfo::new(InferredType::Number)
            }
            "-" | "*" | "/" | "%" | "**" | "^" | "&" | "|" | "<<" | ">>" | ">>>" => {
                TypeInfo::new(InferredType::Number)
            }
            // NOTE(review): in JavaScript/Python `a || b` / `a && b` evaluate to one
            // of the operands, not necessarily a boolean; kept as Boolean here.
            "==" | "===" | "!=" | "!==" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "!"
            | "instanceof" | "in" => TypeInfo::new(InferredType::Boolean),
            "??" => {
                // Heuristic: the fallback operand is assumed to be non-null.
                if let (Some(_l), Some(r)) = (left, right) {
                    let r_type = self.infer_type(r, source);
                    return TypeInfo::with_nullability(
                        r_type.inferred_type,
                        Nullability::DefinitelyNonNull,
                    );
                }
                TypeInfo::unknown()
            }
            _ => TypeInfo::unknown(),
        }
    }
    /// Derives type info from a symbol's recorded initializer origin.
    pub fn type_from_origin(&self, origin: &ValueOrigin) -> TypeInfo {
        // Method names (compared ASCII-case-insensitively) assumed to yield strings.
        const STRING_METHODS: [&str; 10] = [
            "concat",
            "join",
            "trim",
            "toLowerCase",
            "toUpperCase",
            "slice",
            "substring",
            "substr",
            "replace",
            "format",
        ];
        match origin {
            ValueOrigin::Literal(lit) => self.type_from_literal_text(lit),
            ValueOrigin::Parameter(_) => TypeInfo::unknown(),
            ValueOrigin::FunctionCall(func) => {
                let method = func.rsplit('.').next().unwrap_or(func);
                if self.nullable_functions.contains(method) {
                    TypeInfo::with_nullability(InferredType::Unknown, Nullability::PossiblyNull)
                } else {
                    TypeInfo::unknown()
                }
            }
            ValueOrigin::MemberAccess(_) => {
                TypeInfo::with_nullability(InferredType::Unknown, Nullability::PossiblyNull)
            }
            ValueOrigin::BinaryExpression => TypeInfo::unknown(),
            ValueOrigin::Variable(_) => TypeInfo::unknown(),
            ValueOrigin::StringConcat(_) => TypeInfo::new(InferredType::String),
            ValueOrigin::TemplateLiteral(_) => TypeInfo::new(InferredType::String),
            ValueOrigin::MethodCall { method, .. } => {
                if STRING_METHODS
                    .iter()
                    .any(|m| method.eq_ignore_ascii_case(m))
                {
                    TypeInfo::new(InferredType::String)
                } else {
                    TypeInfo::unknown()
                }
            }
            ValueOrigin::Unknown => TypeInfo::unknown(),
        }
    }
    /// Classifies the source text of a literal.
    ///
    /// Recognises quoted strings (`"`, `'`, backtick), booleans (`true`/`false`,
    /// `True`/`False`), null-likes (`null`/`nil`/`None`), `undefined`, and
    /// numbers. Numbers include `_` separators and `0x`/`0o`/`0b` prefixes.
    fn type_from_literal_text(&self, text: &str) -> TypeInfo {
        let trimmed = text.trim();
        if Self::is_quoted_literal(trimmed) {
            return TypeInfo::new(InferredType::String);
        }
        match trimmed {
            "true" | "false" | "True" | "False" => TypeInfo::new(InferredType::Boolean),
            "null" | "nil" | "None" => TypeInfo::null(),
            "undefined" => TypeInfo::undefined(),
            _ if Self::is_numeric_literal(trimmed) => TypeInfo::new(InferredType::Number),
            _ => TypeInfo::unknown(),
        }
    }
    /// `true` if `text` is at least two characters long and opens and closes
    /// with the same quote character. A lone quote is not a string literal.
    fn is_quoted_literal(text: &str) -> bool {
        text.len() >= 2
            && ['"', '\'', '`']
                .into_iter()
                .any(|q| text.starts_with(q) && text.ends_with(q))
    }
    /// `true` if `text` parses as a decimal integer/float, or is a digit-leading
    /// literal with `_` separators and/or a `0x`/`0o`/`0b` radix prefix.
    fn is_numeric_literal(text: &str) -> bool {
        if text.parse::<i64>().is_ok() || text.parse::<f64>().is_ok() {
            return true;
        }
        let unsigned = text.strip_prefix('-').unwrap_or(text);
        if !unsigned.starts_with(|c: char| c.is_ascii_digit()) {
            return false;
        }
        let compact: String = unsigned.chars().filter(|&c| c != '_').collect();
        let (radix, digits) = match compact.get(..2) {
            Some("0x" | "0X") => (16, &compact[2..]),
            Some("0o" | "0O") => (8, &compact[2..]),
            Some("0b" | "0B") => (2, &compact[2..]),
            _ => return compact.parse::<f64>().is_ok(),
        };
        !digits.is_empty() && digits.chars().all(|c| c.is_digit(radix))
    }
}
/// Forward dataflow transfer function that tracks one [`TypeFact`] per assigned variable.
pub struct TypeInferenceTransfer {
    inferrer: TypeInferrer,
    semantics: &'static LanguageSemantics,
}
impl TypeInferenceTransfer {
    /// Builds the transfer function, and its inner [`TypeInferrer`], for one language.
    pub fn new(semantics: &'static LanguageSemantics) -> Self {
        let inferrer = TypeInferrer::new(semantics);
        TypeInferenceTransfer {
            semantics,
            inferrer,
        }
    }
}
impl TransferFunction<TypeFact> for TypeInferenceTransfer {
    /// Applies each statement of `block`, in order, to a copy of the incoming
    /// facts, then runs branch refinement. Statements whose node id cannot be
    /// found in `tree` are skipped.
    fn transfer(
        &self,
        block: &BasicBlock,
        input: &HashSet<TypeFact>,
        cfg: &CFG,
        source: &[u8],
        tree: &tree_sitter::Tree,
    ) -> HashSet<TypeFact> {
        let mut facts = input.clone();
        let statement_nodes = block
            .statements
            .iter()
            .filter_map(|&node_id| find_node_by_id(tree, node_id));
        for statement in statement_nodes {
            self.process_statement(statement, source, &mut facts, cfg, block.id);
        }
        self.apply_branch_refinement(block, &mut facts, source, tree);
        facts
    }
}
impl TypeInferenceTransfer {
    /// Updates `state` for one statement and, recursively, every named descendant
    /// that is not a nested function definition.
    ///
    /// A recognised declaration or assignment replaces any existing fact for the
    /// bound variable. If a node is both, the assignment is applied last.
    fn process_statement(
        &self,
        node: tree_sitter::Node,
        source: &[u8],
        state: &mut HashSet<TypeFact>,
        cfg: &CFG,
        block_id: BlockId,
    ) {
        /// Drops every fact for `var_name` and records the new one.
        fn rebind(state: &mut HashSet<TypeFact>, var_name: String, type_info: TypeInfo) {
            state.retain(|fact| fact.var_name != var_name);
            state.insert(TypeFact::new(var_name, type_info));
        }
        let sem = self.semantics;
        let kind = node.kind();
        if sem.is_variable_declaration(kind) {
            if let Some((var_name, type_info)) = self.extract_declaration_type(node, source) {
                rebind(state, var_name, type_info);
            }
        }
        if sem.is_assignment(kind) {
            if let Some((var_name, type_info)) = self.extract_assignment_type(node, source) {
                rebind(state, var_name, type_info);
            }
        }
        let mut cursor = node.walk();
        for child in node.named_children(&mut cursor) {
            // Bindings inside nested function bodies belong to another scope.
            if sem.is_function_def(child.kind()) {
                continue;
            }
            self.process_statement(child, source, state, cfg, block_id);
        }
    }
    /// Extracts `(variable name, inferred type)` from a declaration node.
    ///
    /// Several grammar-specific node kinds are handled explicitly; others fall
    /// back to the semantics' name/left and value/right fields. Returns `None`
    /// when there is no name or the name is not an identifier. A missing
    /// initializer yields [`TypeInfo::unknown`].
    fn extract_declaration_type(
        &self,
        node: tree_sitter::Node,
        source: &[u8],
    ) -> Option<(String, TypeInfo)> {
        /// Unwraps an `expression_list` to its first element; other nodes pass through.
        fn head_of_list(n: tree_sitter::Node<'_>) -> Option<tree_sitter::Node<'_>> {
            if n.kind() == "expression_list" {
                n.named_child(0)
            } else {
                Some(n)
            }
        }
        let sem = self.semantics;
        let (name_node, value_node) = match node.kind() {
            "variable_declarator" => (
                node.child_by_field_name("name"),
                node.child_by_field_name("value"),
            ),
            "let_declaration" => (
                node.child_by_field_name("pattern"),
                node.child_by_field_name("value"),
            ),
            "short_var_declaration" => match (
                node.child_by_field_name("left"),
                node.child_by_field_name("right"),
            ) {
                (Some(lhs), Some(rhs)) => (head_of_list(lhs), head_of_list(rhs)),
                _ => (None, None),
            },
            "assignment" => (
                node.child_by_field_name("left"),
                node.child_by_field_name("right"),
            ),
            "local_variable_declaration" => {
                let mut cursor = node.walk();
                let declarator = node
                    .named_children(&mut cursor)
                    .find(|c| c.kind() == "variable_declarator");
                declarator.map_or((None, None), |d| {
                    (
                        d.child_by_field_name("name"),
                        d.child_by_field_name("value"),
                    )
                })
            }
            _ => (
                node.child_by_field_name(sem.name_field)
                    .or_else(|| node.child_by_field_name(sem.left_field)),
                node.child_by_field_name(sem.value_field)
                    .or_else(|| node.child_by_field_name(sem.right_field)),
            ),
        };
        let name = name_node?;
        let name_kind = name.kind();
        let is_identifier = sem.is_identifier(name_kind) || name_kind == "identifier";
        if !is_identifier {
            return None;
        }
        let var_name = name
            .utf8_text(source)
            .ok()?
            .trim_start_matches("mut ")
            .trim()
            .to_string();
        let type_info = match value_node {
            Some(value) => self.inferrer.infer_type(value, source),
            None => TypeInfo::unknown(),
        };
        Some((var_name, type_info))
    }
    /// Extracts `(variable name, inferred type)` from an assignment whose
    /// left-hand side is a plain identifier. Other targets yield `None`.
    fn extract_assignment_type(
        &self,
        node: tree_sitter::Node,
        source: &[u8],
    ) -> Option<(String, TypeInfo)> {
        let sem = self.semantics;
        let target = node.child_by_field_name(sem.left_field)?;
        let value = node.child_by_field_name(sem.right_field)?;
        let target_kind = target.kind();
        if !(sem.is_identifier(target_kind) || target_kind == "identifier") {
            return None;
        }
        let var_name = target.utf8_text(source).ok()?.to_string();
        Some((var_name, self.inferrer.infer_type(value, source)))
    }
    /// Hook for refining facts from the block's branch condition; currently a no-op.
    fn apply_branch_refinement(
        &self,
        _block: &BasicBlock,
        _state: &mut HashSet<TypeFact>,
        _source: &[u8],
        _tree: &tree_sitter::Tree,
    ) {
    }
    /// Recognises `ident ==/===/!=/!== null-like` comparisons, in either operand order.
    ///
    /// Returns `(variable name, true, is_equality)`, where `is_equality` is
    /// `true` for `==`/`===` and `false` for `!=`/`!==`. Returns `None` for
    /// anything else.
    fn extract_null_check(
        &self,
        node: tree_sitter::Node,
        source: &[u8],
    ) -> Option<(String, bool, bool)> {
        let sem = self.semantics;
        let kind = node.kind();
        if !(sem.is_binary_expression(kind) || kind == "binary_expression") {
            return None;
        }
        let op = node
            .child_by_field_name(sem.operator_field)
            .or_else(|| {
                let mut cursor = node.walk();
                node.children(&mut cursor)
                    .find(|c| !c.is_named() && c.kind().contains('='))
            })?;
        let op_text = op.utf8_text(source).ok()?;
        let left = node.child_by_field_name(sem.left_field)?;
        let right = node.child_by_field_name(sem.right_field)?;
        let is_equality = match op_text {
            "==" | "===" => true,
            "!=" | "!==" => false,
            _ => return None,
        };
        // The non-null-like operand is the variable being tested.
        let var = if self.is_null_or_undefined(right, source) {
            left
        } else if self.is_null_or_undefined(left, source) {
            right
        } else {
            return None;
        };
        if !(sem.is_identifier(var.kind()) || var.kind() == "identifier") {
            return None;
        }
        let var_name = var.utf8_text(source).ok()?.to_string();
        Some((var_name, true, is_equality))
    }
    /// `true` if `node` is a null/undefined literal, or an `identifier` whose
    /// text is `null`, `undefined`, `nil` or `None`.
    fn is_null_or_undefined(&self, node: tree_sitter::Node, source: &[u8]) -> bool {
        let kind = node.kind();
        if self.semantics.is_null_literal(kind)
            || matches!(kind, "null" | "nil" | "None" | "undefined")
        {
            return true;
        }
        kind == "identifier"
            && node
                .utf8_text(source)
                .is_ok_and(|text| matches!(text, "null" | "undefined" | "nil" | "None"))
    }
}
/// Per-block nullability overrides for variables, derived from branch conditions.
#[derive(Debug, Clone, Default)]
pub struct NullabilityRefinements {
    /// block id -> (variable name -> refined nullability)
    refinements: HashMap<BlockId, HashMap<String, Nullability>>,
}
impl NullabilityRefinements {
    /// Creates an empty set of refinements.
    pub fn new() -> Self {
        Self::default()
    }
    /// Records (or overwrites) the refinement of `var_name` within `block_id`.
    pub fn set(&mut self, block_id: BlockId, var_name: String, nullability: Nullability) {
        let per_block = self.refinements.entry(block_id).or_default();
        per_block.insert(var_name, nullability);
    }
    /// Returns the refinement of `var_name` within `block_id`, if one was recorded.
    pub fn get(&self, block_id: BlockId, var_name: &str) -> Option<Nullability> {
        let per_block = self.refinements.get(&block_id)?;
        per_block.get(var_name).copied()
    }
    /// `true` if a refinement for `var_name` exists within `block_id`.
    pub fn has_refinement(&self, block_id: BlockId, var_name: &str) -> bool {
        self.refinements
            .get(&block_id)
            .is_some_and(|per_block| per_block.contains_key(var_name))
    }
}
/// Runs forward type-inference dataflow over `cfg` and returns per-block facts.
pub fn analyze_types(
    cfg: &CFG,
    tree: &tree_sitter::Tree,
    source: &[u8],
    semantics: &'static LanguageSemantics,
) -> DataflowResult<TypeFact> {
    super::dataflow::solve(
        cfg,
        Direction::Forward,
        &TypeInferenceTransfer::new(semantics),
        source,
        tree,
    )
}
/// Builds a [`TypeTable`] with one entry per symbol, typed from its initializer origin.
pub fn infer_types_from_symbols(
    symbols: &SymbolTable,
    semantics: &'static LanguageSemantics,
) -> TypeTable {
    let inferrer = TypeInferrer::new(semantics);
    symbols
        .iter()
        .fold(TypeTable::new(), |mut table, (name, info)| {
            table.set(name.clone(), inferrer.type_from_origin(&info.initializer));
            table
        })
}
/// Derives per-block nullability refinements from null-check branch conditions.
///
/// For each reachable block ending in a `Branch` whose condition compares an
/// identifier against a null-like value:
/// - With `==`/`===`, the true successor gets `DefinitelyNull` and the false
///   successor gets `DefinitelyNonNull`.
/// - With `!=`/`!==`, the two are swapped.
///
/// NOTE(review): strict `===`/`!==` against `null` does not exclude `undefined`,
/// so the non-null side may be optimistic for strict comparisons — confirm intended.
pub fn compute_nullability_refinements(
    cfg: &CFG,
    tree: &tree_sitter::Tree,
    source: &[u8],
    semantics: &'static LanguageSemantics,
) -> NullabilityRefinements {
    let transfer = TypeInferenceTransfer::new(semantics);
    let mut refinements = NullabilityRefinements::new();
    for block in cfg.blocks.iter().filter(|b| b.reachable) {
        let Terminator::Branch {
            condition_node,
            true_block,
            false_block,
        } = &block.terminator
        else {
            continue;
        };
        let Some(cond) = find_node_by_id(tree, *condition_node) else {
            continue;
        };
        let Some((var_name, _is_null_check, is_equality)) =
            transfer.extract_null_check(cond, source)
        else {
            continue;
        };
        let (on_true, on_false) = if is_equality {
            (Nullability::DefinitelyNull, Nullability::DefinitelyNonNull)
        } else {
            (Nullability::DefinitelyNonNull, Nullability::DefinitelyNull)
        };
        refinements.set(*true_block, var_name.clone(), on_true);
        refinements.set(*false_block, var_name, on_false);
    }
    refinements
}
impl DataflowResult<TypeFact> {
    /// Type info recorded for `var_name` on entry to `block_id`, if any.
    pub fn type_at_entry(&self, block_id: BlockId, var_name: &str) -> Option<TypeInfo> {
        let facts = self.block_entry.get(&block_id)?;
        facts
            .iter()
            .find(|fact| fact.var_name == var_name)
            .map(|fact| fact.type_info.clone())
    }
    /// Type info recorded for `var_name` on exit from `block_id`, if any.
    pub fn type_at_exit(&self, block_id: BlockId, var_name: &str) -> Option<TypeInfo> {
        let facts = self.block_exit.get(&block_id)?;
        facts
            .iter()
            .find(|fact| fact.var_name == var_name)
            .map(|fact| fact.type_info.clone())
    }
    /// Inferred type on entry to `block_id`; `Unknown` when no fact exists.
    pub fn inferred_type_at_entry(&self, block_id: BlockId, var_name: &str) -> InferredType {
        match self.type_at_entry(block_id, var_name) {
            Some(info) => info.inferred_type,
            None => InferredType::Unknown,
        }
    }
    /// Nullability on entry to `block_id`; `Unknown` when no fact exists.
    pub fn nullability_at_entry(&self, block_id: BlockId, var_name: &str) -> Nullability {
        match self.type_at_entry(block_id, var_name) {
            Some(info) => info.nullability,
            None => Nullability::Unknown,
        }
    }
    /// `true` unless `var_name` is known non-null on entry to `block_id`.
    pub fn is_possibly_null_at_entry(&self, block_id: BlockId, var_name: &str) -> bool {
        let nullability = self.nullability_at_entry(block_id, var_name);
        nullability.could_be_null()
    }
    /// `true` only if `var_name` is known non-null on entry to `block_id`.
    pub fn is_definitely_non_null_at_entry(&self, block_id: BlockId, var_name: &str) -> bool {
        let nullability = self.nullability_at_entry(block_id, var_name);
        nullability.is_definitely_non_null()
    }
    /// Collects all entry facts of `block_id` into a [`TypeTable`]; empty if none.
    pub fn type_table_at_entry(&self, block_id: BlockId) -> TypeTable {
        let mut table = TypeTable::new();
        for fact in self.block_entry.get(&block_id).into_iter().flatten() {
            table.set(fact.var_name.clone(), fact.type_info.clone());
        }
        table
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::flow::cfg::CFG;
    use rma_common::Language;
    use rma_parser::ParserEngine;
    use std::path::Path;

    /// Parses `code` as the contents of `test.js` with the default config.
    fn parse_js(code: &str) -> rma_parser::ParsedFile {
        let parser = ParserEngine::new(rma_common::RmaConfig::default());
        parser
            .parse_file(Path::new("test.js"), code)
            .expect("parse failed")
    }

    /// JavaScript semantics shared by every test.
    fn js_semantics() -> &'static LanguageSemantics {
        LanguageSemantics::for_language(Language::JavaScript)
    }

    /// Symbol-table-driven type inference over a JavaScript snippet.
    fn js_type_table(code: &str) -> TypeTable {
        let parsed = parse_js(code);
        let symbols = SymbolTable::build(&parsed, Language::JavaScript);
        infer_types_from_symbols(&symbols, js_semantics())
    }

    #[test]
    fn test_infer_string_literal() {
        let table = js_type_table(r#"const x = "hello";"#);
        assert!(table.contains("x"));
        let info = table.get("x").unwrap();
        assert_eq!(info.inferred_type, InferredType::String);
        assert_eq!(info.nullability, Nullability::DefinitelyNonNull);
    }

    #[test]
    fn test_infer_number_literal() {
        let table = js_type_table("const x = 42;");
        assert!(table.contains("x"));
        assert_eq!(table.get("x").unwrap().inferred_type, InferredType::Number);
    }

    #[test]
    fn test_infer_boolean_literal() {
        let table = js_type_table("const x = true;");
        assert!(table.contains("x"));
        assert_eq!(table.get("x").unwrap().inferred_type, InferredType::Boolean);
    }

    #[test]
    fn test_infer_null_literal() {
        let table = js_type_table("const x = null;");
        assert!(table.contains("x"));
        let info = table.get("x").unwrap();
        assert_eq!(info.inferred_type, InferredType::Null);
        assert_eq!(info.nullability, Nullability::DefinitelyNull);
    }

    #[test]
    fn test_assignment_propagation() {
        let table = js_type_table(r#"const x = "hello";"#);
        let x_info = table.get("x").expect("x should exist");
        assert_eq!(x_info.inferred_type, InferredType::String);
    }

    #[test]
    fn test_reassignment_type_change() {
        let table = js_type_table(r#"let x = "hello";"#);
        let x_info = table.get("x").expect("x should exist");
        assert_eq!(x_info.inferred_type, InferredType::String);
    }

    #[test]
    fn test_dataflow_type_propagation() {
        let code = r#"
const x = "hello";
const y = 42;
"#;
        let parsed = parse_js(code);
        let cfg = CFG::build(&parsed, Language::JavaScript);
        let result = analyze_types(&cfg, &parsed.tree, code.as_bytes(), js_semantics());
        assert!(result.iterations > 0 || cfg.block_count() <= 1);
        let _any_types_inferred = result.block_exit.values().any(|facts| !facts.is_empty());
    }

    #[test]
    fn test_nullable_function_call() {
        let table = js_type_table("const x = array.find(item => item.id === 1);");
        assert!(table.is_possibly_null("x"));
    }

    #[test]
    fn test_null_check_refinement() {
        let code = r#"
const x = getData();
if (x != null) {
console.log(x);
}
"#;
        let parsed = parse_js(code);
        let cfg = CFG::build(&parsed, Language::JavaScript);
        let _refinements =
            compute_nullability_refinements(&cfg, &parsed.tree, code.as_bytes(), js_semantics());
        assert!(cfg.block_count() >= 3);
    }

    #[test]
    fn test_equality_null_check() {
        let code = r#"
const x = getData();
if (x == null) {
return;
}
console.log(x);
"#;
        let parsed = parse_js(code);
        let cfg = CFG::build(&parsed, Language::JavaScript);
        let _refinements =
            compute_nullability_refinements(&cfg, &parsed.tree, code.as_bytes(), js_semantics());
        assert!(cfg.block_count() >= 3);
    }

    #[test]
    fn test_type_union() {
        let InferredType::Union(types) = InferredType::String.union(InferredType::Number) else {
            panic!("Expected Union type");
        };
        assert!(types.contains(&InferredType::String));
        assert!(types.contains(&InferredType::Number));
    }

    #[test]
    fn test_type_simplify() {
        let union = InferredType::Union(vec![
            InferredType::String,
            InferredType::String,
            InferredType::Number,
        ]);
        let InferredType::Union(types) = union.simplify() else {
            panic!("Expected simplified Union");
        };
        assert_eq!(types.len(), 2);
    }

    #[test]
    fn test_single_type_union_simplifies() {
        let simplified = InferredType::Union(vec![InferredType::String]).simplify();
        assert_eq!(simplified, InferredType::String);
    }

    #[test]
    fn test_nullability_merge_same() {
        let merged = Nullability::DefinitelyNonNull.merge(Nullability::DefinitelyNonNull);
        assert_eq!(merged, Nullability::DefinitelyNonNull);
    }

    #[test]
    fn test_nullability_merge_conflict() {
        let merged = Nullability::DefinitelyNull.merge(Nullability::DefinitelyNonNull);
        assert_eq!(merged, Nullability::PossiblyNull);
    }

    #[test]
    fn test_nullability_merge_with_possibly() {
        let merged = Nullability::DefinitelyNonNull.merge(Nullability::PossiblyNull);
        assert_eq!(merged, Nullability::PossiblyNull);
    }

    #[test]
    fn test_type_table_operations() {
        let mut table = TypeTable::new();
        table.set("x".to_string(), TypeInfo::new(InferredType::String));
        table.set("y".to_string(), TypeInfo::null());
        assert!(table.contains("x"));
        assert!(table.contains("y"));
        assert!(!table.contains("z"));
        assert!(table.is_definitely_non_null("x"));
        assert!(table.is_definitely_null("y"));
        assert!(table.is_possibly_null("z"));
    }

    #[test]
    fn test_type_table_merge() {
        let mut table1 = TypeTable::new();
        table1.set("x".to_string(), TypeInfo::new(InferredType::String));
        let mut table2 = TypeTable::new();
        table2.set("x".to_string(), TypeInfo::new(InferredType::Number));
        table2.set("y".to_string(), TypeInfo::new(InferredType::Boolean));
        table1.merge(&table2);
        assert!(
            matches!(table1.get("x").unwrap().inferred_type, InferredType::Union(_)),
            "Expected Union type after merge"
        );
        assert!(table1.contains("y"));
    }

    #[test]
    fn test_type_inference_dataflow() {
        let code = r#"
const x = "hello";
const y = 42;
const z = true;
"#;
        let parsed = parse_js(code);
        let cfg = CFG::build(&parsed, Language::JavaScript);
        let result = analyze_types(&cfg, &parsed.tree, code.as_bytes(), js_semantics());
        assert!(result.iterations > 0 || cfg.block_count() <= 1);
        let _table = result.type_table_at_entry(cfg.exit);
    }

    #[test]
    fn test_conditional_type_inference() {
        let code = r#"
let x;
if (condition) {
x = "hello";
} else {
x = 42;
}
"#;
        let parsed = parse_js(code);
        let cfg = CFG::build(&parsed, Language::JavaScript);
        let result = analyze_types(&cfg, &parsed.tree, code.as_bytes(), js_semantics());
        assert!(result.iterations < cfg.block_count() * 25);
    }

    #[test]
    fn test_type_display() {
        let cases = [
            (InferredType::String, "String"),
            (InferredType::Number, "Number"),
            (InferredType::Boolean, "Boolean"),
            (InferredType::Null, "null"),
            (InferredType::Undefined, "undefined"),
            (
                InferredType::Array(Box::new(InferredType::Number)),
                "Array<Number>",
            ),
            (
                InferredType::Optional(Box::new(InferredType::String)),
                "String?",
            ),
            (
                InferredType::Union(vec![InferredType::String, InferredType::Number]),
                "String | Number",
            ),
        ];
        for (ty, expected) in cases {
            assert_eq!(format!("{}", ty), expected);
        }
    }

    #[test]
    fn test_type_info_merge() {
        let merged = TypeInfo::new(InferredType::String).merge(TypeInfo::null());
        assert!(matches!(
            merged.inferred_type,
            InferredType::Union(_) | InferredType::Optional(_)
        ));
        assert_eq!(merged.nullability, Nullability::PossiblyNull);
    }

    #[test]
    fn test_optional_type() {
        let optional = InferredType::String.make_optional();
        assert!(optional.is_nullable());
        assert_eq!(*optional.unwrap_optional(), InferredType::String);
    }
}