use crate::ast::*;
use super::scope::{InferredType, TypeScope};
/// Flattens and de-duplicates `members`, then collapses the result to the
/// smallest equivalent type expression.
///
/// Nested unions are expanded in place, `Never` members are dropped, and a
/// repeated member keeps only its first occurrence. An empty result becomes
/// `TypeExpr::Never`, a single survivor is returned unwrapped, and anything
/// larger is wrapped in `TypeExpr::Union`.
pub(super) fn simplify_union(members: Vec<TypeExpr>) -> TypeExpr {
    let mut unique: Vec<TypeExpr> = Vec::new();
    members
        .into_iter()
        .for_each(|ty| collect_simplified_union_member(ty, &mut unique));

    if unique.len() > 1 {
        return TypeExpr::Union(unique);
    }
    // At most one member is left: `pop` hands back the lone survivor, if any.
    unique.pop().unwrap_or(TypeExpr::Never)
}
/// Appends `member` to `filtered` unless it adds nothing to the union.
///
/// A nested `Union` is expanded recursively (depth-first, preserving order),
/// `Never` is discarded, and a member already present in `filtered` is not
/// pushed a second time.
fn collect_simplified_union_member(member: TypeExpr, filtered: &mut Vec<TypeExpr>) {
    match member {
        TypeExpr::Union(nested) => nested
            .into_iter()
            .for_each(|inner| collect_simplified_union_member(inner, filtered)),
        TypeExpr::Never => {}
        // Duplicate of something already collected: nothing to add.
        leaf if filtered.contains(&leaf) => {}
        leaf => filtered.push(leaf),
    }
}
/// Returns `ty` with every `nil` alternative stripped out.
///
/// * The named type `nil` yields `None`.
/// * A union is processed member by member (recursively); if nothing is left
///   the result is `None`, otherwise the survivors go through
///   [`simplify_union`].
/// * Any other type is returned as a clone of itself.
pub(super) fn without_nil(ty: &TypeExpr) -> InferredType {
    match ty {
        TypeExpr::Union(members) => {
            let mut kept: Vec<TypeExpr> = Vec::new();
            for member in members {
                if let Some(stripped) = without_nil(member) {
                    kept.push(stripped);
                }
            }
            if kept.is_empty() {
                None
            } else {
                Some(simplify_union(kept))
            }
        }
        TypeExpr::Named(name) if name == "nil" => None,
        _ => Some(ty.clone()),
    }
}
/// Reports whether `ty` is the named type `nil` or a union that (possibly
/// through nested unions) has `nil` among its members.
pub(super) fn contains_nil(ty: &TypeExpr) -> bool {
    if let TypeExpr::Union(members) = ty {
        return members.iter().any(contains_nil);
    }
    matches!(ty, TypeExpr::Named(name) if name == "nil")
}
/// Tests whether `member` corresponds to the runtime kind string `target`.
///
/// A `Named` type matches when its name equals `target`. Structural and
/// literal types match the fixed kind string listed below; every other
/// variant matches nothing.
fn member_matches_runtime_kind(member: &TypeExpr, target: &str) -> bool {
    let kind = match member {
        TypeExpr::Named(name) => name.as_str(),
        TypeExpr::List(_) => "list",
        // Both dictionary flavours share the "dict" kind.
        TypeExpr::DictType(_, _) | TypeExpr::Shape(_) => "dict",
        TypeExpr::FnType { .. } => "closure",
        TypeExpr::Iter(_) => "iter",
        TypeExpr::Generator(_) => "generator",
        TypeExpr::Stream(_) => "stream",
        TypeExpr::LitString(_) => "string",
        TypeExpr::LitInt(_) => "int",
        _ => return false,
    };
    kind == target
}
/// Drops every member of `members` whose runtime kind is `to_remove`.
///
/// Always returns `Some`: `TypeExpr::Never` when nothing survives, the sole
/// survivor when exactly one does, and a `TypeExpr::Union` of the survivors
/// (in their original order) otherwise.
pub(super) fn remove_from_union(members: &[TypeExpr], to_remove: &str) -> InferredType {
    let mut survivors: Vec<TypeExpr> = Vec::new();
    for member in members {
        if !member_matches_runtime_kind(member, to_remove) {
            survivors.push(member.clone());
        }
    }
    let narrowed = if survivors.len() > 1 {
        TypeExpr::Union(survivors)
    } else {
        survivors.pop().unwrap_or(TypeExpr::Never)
    };
    Some(narrowed)
}
/// Keeps only the members of `members` whose runtime kind is `target`.
///
/// Returns `None` when no member matches, the matching member itself when
/// there is exactly one, and a `TypeExpr::Union` of all matches (in their
/// original order) when there are several.
pub(super) fn narrow_to_single(members: &[TypeExpr], target: &str) -> InferredType {
    let mut hits: Vec<TypeExpr> = Vec::new();
    for member in members {
        if member_matches_runtime_kind(member, target) {
            hits.push(member.clone());
        }
    }
    if hits.len() > 1 {
        Some(TypeExpr::Union(hits))
    } else {
        // Zero hits -> `None`, one hit -> that hit.
        hits.pop()
    }
}
/// Recognises a call of the form `type_of(<identifier>)` and returns the
/// identifier's name.
///
/// Returns `None` for any other node: a non-call, a call to a different
/// function, a call with an argument count other than one, or an argument
/// that is not a plain identifier.
pub(super) fn extract_type_of_var(node: &SNode) -> Option<String> {
    let Node::FunctionCall { name, args, .. } = &node.node else {
        return None;
    };
    if name != "type_of" || args.len() != 1 {
        return None;
    }
    match &args[0].node {
        Node::Identifier(var) => Some(var.clone()),
        _ => None,
    }
}
/// Intersects two shapes field by field.
///
/// * A field present on both sides gets the intersection of the two field
///   types and is optional only if it is optional on both sides. If the field
///   types do not intersect, the whole result is `None`.
/// * A field present only in `current_fields` is kept unchanged.
/// * A field present only in `schema_fields` is appended when it is required;
///   optional schema-only fields are left out.
///
/// Fields from `current_fields` come first, in their original order.
fn intersect_shapes(
    current_fields: &[ShapeField],
    schema_fields: &[ShapeField],
) -> Option<TypeExpr> {
    let mut merged: Vec<ShapeField> = Vec::with_capacity(current_fields.len());

    for field in current_fields {
        let counterpart = schema_fields.iter().find(|f| f.name == field.name);
        let combined = match counterpart {
            None => field.clone(),
            Some(schema_field) => {
                let type_expr = intersect_types(&field.type_expr, &schema_field.type_expr)?;
                ShapeField {
                    name: field.name.clone(),
                    type_expr,
                    optional: field.optional && schema_field.optional,
                }
            }
        };
        merged.push(combined);
    }

    for schema_field in schema_fields {
        // Checked against `merged` as it grows, so a name is added at most once.
        let skip = schema_field.optional || merged.iter().any(|f| f.name == schema_field.name);
        if !skip {
            merged.push(schema_field.clone());
        }
    }

    Some(TypeExpr::Shape(merged))
}
/// Intersects each member of a union with `other` and gathers the members
/// that survive.
///
/// `flip` selects the argument order handed to [`intersect_types`]: when
/// `false` the member is the first argument, when `true` `other` is.
///
/// Returns `None` if no member intersects, the single intersection if exactly
/// one does, and a `TypeExpr::Union` of the intersections otherwise.
fn intersect_union_with(members: &[TypeExpr], other: &TypeExpr, flip: bool) -> Option<TypeExpr> {
    let mut kept: Vec<TypeExpr> = Vec::new();
    for member in members {
        let (lhs, rhs) = if flip {
            (other, member)
        } else {
            (member, other)
        };
        if let Some(common) = intersect_types(lhs, rhs) {
            kept.push(common);
        }
    }
    if kept.len() > 1 {
        Some(TypeExpr::Union(kept))
    } else {
        kept.pop()
    }
}
/// Computes the intersection of `current` with `schema_type`.
///
/// Returns `Some(narrowed)` when the two types overlap and `None` when no
/// rule below applies. Rules are tried top to bottom:
///
/// * `Owned` wrappers are looked through on either side and the result is
///   re-wrapped in a single `Owned`.
/// * Equal string / integer literals intersect to themselves; a string
///   literal intersects with the named type `string`, an integer literal with
///   `int` or `float`, yielding the literal.
/// * A union on either side is intersected member by member.
/// * Identical named types intersect to themselves.
/// * A bare container name (`dict`, `list`, `iter`, `generator`/`Generator`,
///   `stream`/`Stream`) intersected with the matching structured type yields
///   the structured type.
/// * Two structured types of the same kind are intersected component-wise.
/// * `unknown` / `any` intersected with anything yields the other type.
///
/// The argument order is significant for shapes (see `intersect_shapes`), so
/// every recursive call keeps `current`-side data on the left and
/// `schema_type`-side data on the right.
pub(super) fn intersect_types(current: &TypeExpr, schema_type: &TypeExpr) -> Option<TypeExpr> {
    match (current, schema_type) {
        (TypeExpr::Owned(c), TypeExpr::Owned(s)) => {
            intersect_types(c, s).map(|t| TypeExpr::Owned(Box::new(t)))
        }
        (TypeExpr::Owned(inner), other) => {
            intersect_types(inner, other).map(|t| TypeExpr::Owned(Box::new(t)))
        }
        // Only the schema side is `Owned`: unwrap it but keep `other` on the
        // current side, so order-sensitive intersections (shapes, unions) are
        // not silently flipped.
        (other, TypeExpr::Owned(inner)) => {
            intersect_types(other, inner).map(|t| TypeExpr::Owned(Box::new(t)))
        }
        (TypeExpr::LitString(a), TypeExpr::LitString(b)) if a == b => {
            Some(TypeExpr::LitString(a.clone()))
        }
        (TypeExpr::LitInt(a), TypeExpr::LitInt(b)) if a == b => Some(TypeExpr::LitInt(*a)),
        (TypeExpr::LitString(s), TypeExpr::Named(n))
        | (TypeExpr::Named(n), TypeExpr::LitString(s))
            if n == "string" =>
        {
            Some(TypeExpr::LitString(s.clone()))
        }
        (TypeExpr::LitInt(v), TypeExpr::Named(n)) | (TypeExpr::Named(n), TypeExpr::LitInt(v))
            if n == "int" || n == "float" =>
        {
            Some(TypeExpr::LitInt(*v))
        }
        (TypeExpr::Union(members), other) => intersect_union_with(members, other, false),
        (other, TypeExpr::Union(members)) => intersect_union_with(members, other, true),
        (TypeExpr::Named(left), TypeExpr::Named(right)) if left == right => {
            Some(TypeExpr::Named(left.clone()))
        }
        (TypeExpr::Named(name), TypeExpr::Shape(fields))
        | (TypeExpr::Shape(fields), TypeExpr::Named(name))
            if name == "dict" =>
        {
            Some(TypeExpr::Shape(fields.clone()))
        }
        (TypeExpr::Named(name), TypeExpr::DictType(key, value))
        | (TypeExpr::DictType(key, value), TypeExpr::Named(name))
            if name == "dict" =>
        {
            Some(TypeExpr::DictType(key.clone(), value.clone()))
        }
        (TypeExpr::Named(name), TypeExpr::List(inner))
        | (TypeExpr::List(inner), TypeExpr::Named(name))
            if name == "list" =>
        {
            Some(TypeExpr::List(inner.clone()))
        }
        (TypeExpr::Named(name), TypeExpr::Iter(inner))
        | (TypeExpr::Iter(inner), TypeExpr::Named(name))
            if name == "iter" =>
        {
            Some(TypeExpr::Iter(inner.clone()))
        }
        (TypeExpr::Named(name), TypeExpr::Generator(inner))
        | (TypeExpr::Generator(inner), TypeExpr::Named(name))
            if name == "generator" || name == "Generator" =>
        {
            Some(TypeExpr::Generator(inner.clone()))
        }
        (TypeExpr::Named(name), TypeExpr::Stream(inner))
        | (TypeExpr::Stream(inner), TypeExpr::Named(name))
            if name == "stream" || name == "Stream" =>
        {
            Some(TypeExpr::Stream(inner.clone()))
        }
        (TypeExpr::Shape(c), TypeExpr::Shape(s)) => intersect_shapes(c, s),
        (TypeExpr::List(c), TypeExpr::List(s)) => {
            intersect_types(c, s).map(|i| TypeExpr::List(Box::new(i)))
        }
        (TypeExpr::Iter(c), TypeExpr::Iter(s)) => {
            intersect_types(c, s).map(|i| TypeExpr::Iter(Box::new(i)))
        }
        (TypeExpr::Generator(c), TypeExpr::Generator(s)) => {
            intersect_types(c, s).map(|i| TypeExpr::Generator(Box::new(i)))
        }
        (TypeExpr::Stream(c), TypeExpr::Stream(s)) => {
            intersect_types(c, s).map(|i| TypeExpr::Stream(Box::new(i)))
        }
        (TypeExpr::DictType(ck, cv), TypeExpr::DictType(sk, sv)) => {
            let key = intersect_types(ck, sk)?;
            let value = intersect_types(cv, sv)?;
            Some(TypeExpr::DictType(Box::new(key), Box::new(value)))
        }
        // Top types come last so the more specific rules above win first.
        (TypeExpr::Named(name), other) | (other, TypeExpr::Named(name))
            if matches!(name.as_str(), "unknown" | "any") =>
        {
            Some(other.clone())
        }
        _ => None,
    }
}
/// Removes from `current` whatever overlaps with `schema_type`.
///
/// * For a union, members that intersect `schema_type` are dropped; the
///   result is `None` if nothing remains, the lone remaining member if there
///   is one, or a `TypeExpr::Union` of the remaining members.
/// * The named types `unknown` and `any` are returned unchanged.
/// * Any other type yields `None` if it intersects `schema_type`, otherwise a
///   clone of itself.
pub(super) fn subtract_type(current: &TypeExpr, schema_type: &TypeExpr) -> Option<TypeExpr> {
    if let TypeExpr::Union(members) = current {
        let mut leftover: Vec<TypeExpr> = Vec::new();
        for member in members {
            if intersect_types(member, schema_type).is_none() {
                leftover.push(member.clone());
            }
        }
        return if leftover.len() > 1 {
            Some(TypeExpr::Union(leftover))
        } else {
            leftover.pop()
        };
    }

    let is_top = matches!(
        current,
        TypeExpr::Named(name) if matches!(name.as_str(), "unknown" | "any")
    );
    // Short-circuit: top types are kept without consulting `intersect_types`.
    if is_top || intersect_types(current, schema_type).is_none() {
        Some(current.clone())
    } else {
        None
    }
}
/// Literal value that can tag one member of a union of shapes.
///
/// Built by [`DiscriminantValue::from_type`] from a string-literal or
/// integer-literal type expression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) enum DiscriminantValue {
/// Payload of a `TypeExpr::LitString`.
Str(String),
/// Payload of a `TypeExpr::LitInt`.
Int(i64),
}
impl DiscriminantValue {
    /// Extracts the literal carried by `t`.
    ///
    /// Returns `Some(Str(..))` for a string-literal type, `Some(Int(..))` for
    /// an integer-literal type, and `None` for every other type expression.
    pub(super) fn from_type(t: &TypeExpr) -> Option<Self> {
        if let TypeExpr::LitString(text) = t {
            return Some(Self::Str(text.clone()));
        }
        if let TypeExpr::LitInt(number) = t {
            return Some(Self::Int(*number));
        }
        None
    }
}
/// Finds the name of a field that discriminates between the shapes in
/// `members`.
///
/// Returns `None` when there are fewer than two members or when any member is
/// not a `TypeExpr::Shape`. Otherwise the fields of the first shape are tried
/// in order, and the first one that qualifies is returned. A field qualifies
/// when, in every shape, the first field of that name is required, has a
/// string- or integer-literal type, and carries a literal that no other shape
/// uses.
pub(super) fn discriminant_field(members: &[TypeExpr]) -> Option<String> {
    /// Literal carried by the first field called `name`, provided that field
    /// exists, is required, and has a literal type.
    fn required_tag(fields: &[ShapeField], name: &str) -> Option<DiscriminantValue> {
        let field = fields.iter().find(|f| f.name == name)?;
        if field.optional {
            return None;
        }
        DiscriminantValue::from_type(&field.type_expr)
    }

    /// Whether `candidate` (taken from the first shape) has a required
    /// literal tag that stays pairwise distinct across all shapes in `rest`.
    fn distinguishes_all(candidate: &ShapeField, rest: &[&[ShapeField]]) -> bool {
        if candidate.optional {
            return false;
        }
        let Some(own_tag) = DiscriminantValue::from_type(&candidate.type_expr) else {
            return false;
        };
        let mut seen: Vec<DiscriminantValue> = vec![own_tag];
        for fields in rest {
            match required_tag(fields, &candidate.name) {
                Some(tag) if !seen.contains(&tag) => seen.push(tag),
                // Missing, optional, non-literal, or duplicate tag.
                _ => return false,
            }
        }
        true
    }

    if members.len() < 2 {
        return None;
    }

    let mut shapes: Vec<&[ShapeField]> = Vec::with_capacity(members.len());
    for member in members {
        let TypeExpr::Shape(fields) = member else {
            return None;
        };
        shapes.push(fields.as_slice());
    }

    let (first, rest) = shapes.split_first()?;
    first
        .iter()
        .find(|candidate| distinguishes_all(candidate, rest))
        .map(|candidate| candidate.name.clone())
}
/// Splits a union of shapes into the member tagged with `tag_value` and the
/// rest.
///
/// Every member must be a `TypeExpr::Shape` containing a field named
/// `tag_field` whose type is a string or integer literal; otherwise the
/// result is `None`. The result is also `None` when no member, or more than
/// one member, carries `tag_value`.
///
/// On success returns `(matched, residual)`, where `residual` is the
/// remaining members passed through [`simplify_union`].
pub(super) fn narrow_shape_union_by_tag(
    members: &[TypeExpr],
    tag_field: &str,
    tag_value: &DiscriminantValue,
) -> Option<(TypeExpr, TypeExpr)> {
    let mut hit: Option<&TypeExpr> = None;
    let mut others: Vec<TypeExpr> = Vec::with_capacity(members.len());

    for member in members {
        let tag = match member {
            TypeExpr::Shape(fields) => fields
                .iter()
                .find(|f| f.name == tag_field)
                .and_then(|f| DiscriminantValue::from_type(&f.type_expr))?,
            _ => return None,
        };
        if tag != *tag_value {
            others.push(member.clone());
        } else if hit.replace(member).is_some() {
            // A second member with the same tag makes the narrowing ambiguous.
            return None;
        }
    }

    hit.map(|matched| (matched.clone(), simplify_union(others)))
}
/// Applies each `(variable, narrowed type)` pair in `refinements` to `scope`.
///
/// Before a variable is redefined, its current entry (as returned by
/// `scope.get_var`) is stored in `scope.narrowed_vars`, but only if that map
/// has no entry for the variable yet and the variable is currently known.
/// The variable is then redefined with the narrowed type.
pub(super) fn apply_refinements(scope: &mut TypeScope, refinements: &[(String, InferredType)]) {
    for (var_name, narrowed_type) in refinements {
        let already_saved = scope.narrowed_vars.contains_key(var_name);
        let original = if already_saved {
            None
        } else {
            scope.get_var(var_name).cloned()
        };
        if let Some(original) = original {
            scope.narrowed_vars.insert(var_name.clone(), original);
        }
        scope.define_var(var_name, narrowed_type.clone());
    }
}