use crate::ast::*;
use super::scope::{InferredType, TypeScope};
/// Collapses a list of union members into the simplest equivalent type.
///
/// Every `TypeExpr::Never` member is discarded first. The survivors decide
/// the result:
///
/// * none left: `TypeExpr::Never`
/// * exactly one: that member, unwrapped
/// * two or more: `TypeExpr::Union` holding them in their original order
///
/// Duplicate members and nested unions are passed through unchanged.
pub(super) fn simplify_union(members: Vec<TypeExpr>) -> TypeExpr {
    let mut kept = members;
    kept.retain(|member| !matches!(member, TypeExpr::Never));

    if kept.len() > 1 {
        return TypeExpr::Union(kept);
    }
    // Zero or one survivor: `pop` yields the lone member, or nothing at all.
    kept.pop().unwrap_or(TypeExpr::Never)
}
/// Removes every `TypeExpr::Named` member called `to_remove` from `members`.
///
/// Members of any other variant are always kept. The result is always
/// `Some`: `TypeExpr::Never` when nothing survives, the single survivor when
/// exactly one does, and a `TypeExpr::Union` of the survivors (original
/// order) otherwise.
pub(super) fn remove_from_union(members: &[TypeExpr], to_remove: &str) -> InferredType {
    let is_target =
        |member: &TypeExpr| matches!(member, TypeExpr::Named(name) if name == to_remove);

    let mut survivors: Vec<TypeExpr> = Vec::new();
    for member in members {
        if !is_target(member) {
            survivors.push(member.clone());
        }
    }

    let narrowed = if survivors.len() > 1 {
        TypeExpr::Union(survivors)
    } else {
        survivors.pop().unwrap_or(TypeExpr::Never)
    };
    Some(narrowed)
}
/// Narrows a union to the named type `target`, if the union contains it.
///
/// Returns `Some(TypeExpr::Named(target))` when any member is a
/// `TypeExpr::Named` with exactly that name, and `None` otherwise.
pub(super) fn narrow_to_single(members: &[TypeExpr], target: &str) -> InferredType {
    let present = members.iter().any(|member| match member {
        TypeExpr::Named(name) => name == target,
        _ => false,
    });
    present.then(|| TypeExpr::Named(target.to_string()))
}
/// Recognises a call of the form `type_of(<identifier>)`.
///
/// Returns the identifier's name when `node` is a `Node::FunctionCall` named
/// `type_of` with exactly one argument and that argument is a
/// `Node::Identifier`. Any other node yields `None`.
pub(super) fn extract_type_of_var(node: &SNode) -> Option<String> {
    let Node::FunctionCall { name, args } = &node.node else {
        return None;
    };
    if name != "type_of" || args.len() != 1 {
        return None;
    }
    match &args[0].node {
        Node::Identifier(var) => Some(var.clone()),
        _ => None,
    }
}
/// Computes the overlap between `current` and `schema_type`.
///
/// Returns `None` when no rule below applies, i.e. the two types are treated
/// as disjoint. The rules, in match order:
///
/// * Two equal string literals, or two equal int literals, intersect to that
///   literal.
/// * A string literal against the named type `string` keeps the literal; an
///   int literal against the named type `int` or `float` keeps the literal.
/// * A union on either side is intersected member by member. Members with no
///   overlap are dropped; zero survivors give `None`, one survivor is
///   returned on its own, several are wrapped in a `TypeExpr::Union`.
/// * Two named types overlap only when their names are equal.
/// * The named type `dict` against a shape or a `DictType`, and the named
///   type `list` against a `List`, keep the more specific side.
/// * Two shapes always overlap and the result is a copy of `schema_type`'s
///   fields; the fields of `current` are not consulted.
/// * Two lists intersect their element types; two `DictType`s intersect keys
///   and values, and both must overlap.
pub(super) fn intersect_types(current: &TypeExpr, schema_type: &TypeExpr) -> Option<TypeExpr> {
    // Folds the surviving members of a union intersection into a single type.
    fn collapse(mut survivors: Vec<TypeExpr>) -> Option<TypeExpr> {
        if survivors.len() > 1 {
            Some(TypeExpr::Union(survivors))
        } else {
            survivors.pop()
        }
    }

    match (current, schema_type) {
        // Literal against an identical literal.
        (TypeExpr::LitString(lhs), TypeExpr::LitString(rhs)) if lhs == rhs => {
            Some(TypeExpr::LitString(lhs.clone()))
        }
        (TypeExpr::LitInt(lhs), TypeExpr::LitInt(rhs)) if lhs == rhs => {
            Some(TypeExpr::LitInt(*lhs))
        }

        // Literal against its base named type, in either order.
        (TypeExpr::LitString(text), TypeExpr::Named(base))
        | (TypeExpr::Named(base), TypeExpr::LitString(text))
            if base == "string" =>
        {
            Some(TypeExpr::LitString(text.clone()))
        }
        (TypeExpr::LitInt(number), TypeExpr::Named(base))
        | (TypeExpr::Named(base), TypeExpr::LitInt(number))
            if base == "int" || base == "float" =>
        {
            Some(TypeExpr::LitInt(*number))
        }

        // Unions distribute over their members; the left-hand union wins
        // when both sides are unions.
        (TypeExpr::Union(members), other) => collapse(
            members
                .iter()
                .filter_map(|member| intersect_types(member, other))
                .collect(),
        ),
        (other, TypeExpr::Union(members)) => collapse(
            members
                .iter()
                .filter_map(|member| intersect_types(other, member))
                .collect(),
        ),

        // Identical named types.
        (TypeExpr::Named(lhs), TypeExpr::Named(rhs)) if lhs == rhs => {
            Some(TypeExpr::Named(lhs.clone()))
        }

        // Generic container name against a structured container type: the
        // structured side is kept, whichever side it is on.
        (TypeExpr::Named(base), TypeExpr::Shape(fields))
        | (TypeExpr::Shape(fields), TypeExpr::Named(base))
            if base == "dict" =>
        {
            Some(TypeExpr::Shape(fields.clone()))
        }
        (TypeExpr::Named(base), TypeExpr::List(element))
        | (TypeExpr::List(element), TypeExpr::Named(base))
            if base == "list" =>
        {
            Some(TypeExpr::List(element.clone()))
        }
        (TypeExpr::Named(base), TypeExpr::DictType(key, value))
        | (TypeExpr::DictType(key, value), TypeExpr::Named(base))
            if base == "dict" =>
        {
            Some(TypeExpr::DictType(key.clone(), value.clone()))
        }

        // Shape against shape: the schema's fields replace the current ones.
        (TypeExpr::Shape(_), TypeExpr::Shape(schema_fields)) => {
            Some(TypeExpr::Shape(schema_fields.clone()))
        }

        // Structured containers intersect component-wise.
        (TypeExpr::List(current_elem), TypeExpr::List(schema_elem)) => {
            let element = intersect_types(current_elem, schema_elem)?;
            Some(TypeExpr::List(Box::new(element)))
        }
        (
            TypeExpr::DictType(current_key, current_value),
            TypeExpr::DictType(schema_key, schema_value),
        ) => intersect_types(current_key, schema_key).and_then(|key| {
            intersect_types(current_value, schema_value)
                .map(|value| TypeExpr::DictType(Box::new(key), Box::new(value)))
        }),

        _ => None,
    }
}
/// Removes from `current` everything that overlaps `schema_type`.
///
/// Overlap is decided by `intersect_types` returning `Some`. For a union,
/// each overlapping member is dropped; the survivors are returned as `None`
/// (none left), the single survivor, or a `TypeExpr::Union` of them. For any
/// other type the result is `None` when it overlaps and a clone of `current`
/// when it does not.
pub(super) fn subtract_type(current: &TypeExpr, schema_type: &TypeExpr) -> Option<TypeExpr> {
    let overlaps = |candidate: &TypeExpr| intersect_types(candidate, schema_type).is_some();

    if let TypeExpr::Union(members) = current {
        let mut leftover: Vec<TypeExpr> = members
            .iter()
            .filter(|&member| !overlaps(member))
            .cloned()
            .collect();
        return if leftover.len() > 1 {
            Some(TypeExpr::Union(leftover))
        } else {
            leftover.pop()
        };
    }

    if overlaps(current) {
        None
    } else {
        Some(current.clone())
    }
}
/// A literal value read from a shape field's type, used in this module to
/// tell the members of a union of shapes apart.
///
/// Built by `DiscriminantValue::from_type`; values compare equal only when
/// both the variant and the payload match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) enum DiscriminantValue {
/// Payload of a `TypeExpr::LitString`.
Str(String),
/// Payload of a `TypeExpr::LitInt`.
Int(i64),
}
impl DiscriminantValue {
    /// Extracts a discriminant from a literal type.
    ///
    /// `TypeExpr::LitString` becomes `Str` and `TypeExpr::LitInt` becomes
    /// `Int`. Every other type expression yields `None`.
    pub(super) fn from_type(t: &TypeExpr) -> Option<Self> {
        if let TypeExpr::LitString(text) = t {
            return Some(Self::Str(text.clone()));
        }
        if let TypeExpr::LitInt(number) = t {
            return Some(Self::Int(*number));
        }
        None
    }
}
/// Finds a field that can serve as the tag of a union of shapes.
///
/// Returns `None` when there are fewer than two members or when any member
/// is not a `TypeExpr::Shape`. Otherwise the fields of the first shape are
/// tried in declaration order, and the first one satisfying all of the
/// following is returned by name:
///
/// * it is present and non-optional in every shape,
/// * its type in every shape is a string or int literal (see
///   `DiscriminantValue::from_type`),
/// * no two shapes carry the same literal value for it.
pub(super) fn discriminant_field(members: &[TypeExpr]) -> Option<String> {
    // Literal tag carried by a field, provided the field is required.
    fn required_literal(field: &ShapeField) -> Option<DiscriminantValue> {
        if field.optional {
            None
        } else {
            DiscriminantValue::from_type(&field.type_expr)
        }
    }

    if members.len() < 2 {
        return None;
    }

    let mut shapes: Vec<&[ShapeField]> = Vec::with_capacity(members.len());
    for member in members {
        match member {
            TypeExpr::Shape(fields) => shapes.push(fields.as_slice()),
            _ => return None,
        }
    }

    let (head, tail) = shapes.split_first()?;
    head.iter()
        .find(|candidate| {
            let Some(first_tag) = required_literal(candidate) else {
                return false;
            };
            // Tags seen so far; a repeat means the field cannot discriminate.
            let mut seen: Vec<DiscriminantValue> = vec![first_tag];
            tail.iter().all(|fields| {
                let tag = fields
                    .iter()
                    .find(|field| field.name == candidate.name)
                    .and_then(required_literal);
                match tag {
                    Some(tag) if !seen.contains(&tag) => {
                        seen.push(tag);
                        true
                    }
                    _ => false,
                }
            })
        })
        .map(|candidate| candidate.name.clone())
}
/// Splits a union of shapes into the member whose `tag_field` equals
/// `tag_value` and everything else.
///
/// On success returns `(matched, residual)`, where `matched` is a clone of
/// the single matching shape and `residual` is the remaining members passed
/// through `simplify_union`.
///
/// Returns `None` when any member is not a `TypeExpr::Shape`, when any shape
/// lacks `tag_field` or has a non-literal type for it, when more than one
/// shape matches, or when none does.
pub(super) fn narrow_shape_union_by_tag(
    members: &[TypeExpr],
    tag_field: &str,
    tag_value: &DiscriminantValue,
) -> Option<(TypeExpr, TypeExpr)> {
    let mut hit: Option<&TypeExpr> = None;
    let mut rest: Vec<TypeExpr> = Vec::with_capacity(members.len());

    for member in members {
        let fields = match member {
            TypeExpr::Shape(fields) => fields,
            _ => return None,
        };
        let tag = fields
            .iter()
            .find(|field| field.name == tag_field)
            .and_then(|field| DiscriminantValue::from_type(&field.type_expr))?;

        if tag != *tag_value {
            rest.push(member.clone());
        } else if hit.replace(member).is_some() {
            // A second shape with the same tag makes the narrowing ambiguous.
            return None;
        }
    }

    hit.map(|shape| (shape.clone(), simplify_union(rest)))
}
/// Applies a batch of `(variable, narrowed type)` refinements to `scope`.
///
/// For each pair, in order:
///
/// 1. If `scope.narrowed_vars` has no entry for the variable and
///    `scope.get_var` finds a current value for it, a clone of that value is
///    stored in `narrowed_vars` under the variable's name. An existing entry
///    is never overwritten.
/// 2. The variable is then redefined via `scope.define_var` with a clone of
///    the narrowed type.
///
/// NOTE(review): `narrowed_vars` presumably records the pre-narrowing type so
/// it can be restored later — confirm against `TypeScope`.
pub(super) fn apply_refinements(scope: &mut TypeScope, refinements: &[(String, InferredType)]) {
    for (var_name, narrowed_type) in refinements {
        let already_recorded = scope.narrowed_vars.contains_key(var_name);
        if !already_recorded {
            let previous = scope.get_var(var_name).cloned();
            if let Some(previous) = previous {
                scope.narrowed_vars.insert(var_name.clone(), previous);
            }
        }
        scope.define_var(var_name, narrowed_type.clone());
    }
}