use crate::ast::*;
use super::scope::InferredType;
use super::union::{contains_nil, simplify_union, without_nil};
/// Reports whether `ty` is treated as a dictionary for the purposes of `+`.
///
/// Accepted forms: the bare named type `dict`, any `DictType(..)`, and any
/// closed `Shape`. Every other type expression yields `false`.
fn dict_like(ty: &TypeExpr) -> bool {
    match ty {
        TypeExpr::Named(name) => name == "dict",
        TypeExpr::DictType(..) | TypeExpr::Shape(_) => true,
        _ => false,
    }
}
/// Merges two shape field lists, with `right` taking precedence over `left`.
///
/// The result starts as a copy of `left` (preserving its field order). For
/// each field of `right`:
/// - if no field of that name exists yet, it is appended unchanged;
/// - otherwise the existing entry is updated in place: it stays optional only
///   when both sides are optional, and its type becomes the right-hand type
///   when the right field is required, or the simplified union of both types
///   when the right field is optional.
pub(in crate::typechecker) fn merge_shape_fields(
    left: &[ShapeField],
    right: &[ShapeField],
) -> Vec<ShapeField> {
    let mut merged = left.to_vec();
    for incoming in right {
        match merged
            .iter_mut()
            .find(|existing| existing.name == incoming.name)
        {
            None => merged.push(incoming.clone()),
            Some(existing) => {
                // A required right-hand field always supplies the value, so its
                // type wins outright; an optional one may be absent, leaving the
                // left-hand value in place, hence the union.
                let merged_type = if incoming.optional {
                    simplify_union(vec![
                        existing.type_expr.clone(),
                        incoming.type_expr.clone(),
                    ])
                } else {
                    incoming.type_expr.clone()
                };
                existing.optional &= incoming.optional;
                existing.type_expr = merged_type;
            }
        }
    }
    merged
}
pub(in crate::typechecker) fn fold_open_shape(
fields: Vec<ShapeField>,
rests: Vec<TypeExpr>,
) -> TypeExpr {
let mut acc = fields;
let mut leftover: Vec<TypeExpr> = Vec::new();
let mut work: Vec<TypeExpr> = rests;
let mut i = 0;
while i < work.len() {
match std::mem::replace(&mut work[i], TypeExpr::Never) {
TypeExpr::Shape(rf) => acc = merge_shape_fields(&acc, &rf),
TypeExpr::OpenShape {
fields: rf,
rests: rr,
} => {
acc = merge_shape_fields(&acc, &rf);
let tail = work.split_off(i + 1);
work.truncate(i + 1);
work.extend(rr);
work.extend(tail);
}
other => leftover.push(other),
}
i += 1;
}
if leftover.is_empty() {
TypeExpr::Shape(acc)
} else {
TypeExpr::OpenShape {
fields: acc,
rests: leftover,
}
}
}
/// Result type of an arithmetic operator on two primitive type names.
///
/// `int op int` stays `int`; any `int`/`float` mix that involves a `float`
/// widens to `float`. Any non-numeric operand yields `None` (unknown), so
/// e.g. `float + string` is no longer reported as `float`.
fn arithmetic_result_type(l: &str, r: &str) -> InferredType {
    match (l, r) {
        ("int", "int") => Some(TypeExpr::Named("int".into())),
        ("float", "int" | "float") | ("int", "float") => Some(TypeExpr::Named("float".into())),
        _ => None,
    }
}

/// Infers the static result type of the binary operator `op` applied to
/// operands of type `left` and `right`.
///
/// - Comparison, logical and membership operators always yield `bool`.
/// - `+`:
///   - numeric operands follow `arithmetic_result_type`;
///   - `string + string`, `list + list` and `dict + dict` keep their type;
///   - two `Shape`s merge their fields (right side wins, see
///     `merge_shape_fields`);
///   - any other pair of dict-like operands yields `dict`.
/// - `-`, `/`, `%`, `**`: numeric operands only.
/// - `*`: numeric operands, or `string * int` / `int * string` yielding
///   `string`.
/// - `??`: the left type with `nil` removed, unioned with the right type when
///   the left side may actually be `nil`; a left side that is only `nil`
///   yields the right type.
/// - `|>` and unrecognized operators yield `None` (unknown).
pub(super) fn infer_binary_op_type(
    op: &str,
    left: &InferredType,
    right: &InferredType,
) -> InferredType {
    match op {
        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "in" | "not_in" => {
            Some(TypeExpr::Named("bool".into()))
        }
        "+" => match (left, right) {
            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
                match (l.as_str(), r.as_str()) {
                    ("string", "string") => Some(TypeExpr::Named("string".into())),
                    ("list", "list") => Some(TypeExpr::Named("list".into())),
                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
                    (a, b) => arithmetic_result_type(a, b),
                }
            }
            (Some(TypeExpr::Shape(l)), Some(TypeExpr::Shape(r))) => {
                Some(TypeExpr::Shape(merge_shape_fields(l, r)))
            }
            (Some(l), Some(r)) if dict_like(l) && dict_like(r) => {
                Some(TypeExpr::Named("dict".into()))
            }
            _ => None,
        },
        "-" | "/" | "%" | "**" => match (left, right) {
            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => arithmetic_result_type(l, r),
            _ => None,
        },
        "*" => match (left, right) {
            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
                match (l.as_str(), r.as_str()) {
                    // String repetition.
                    ("string", "int") | ("int", "string") => {
                        Some(TypeExpr::Named("string".into()))
                    }
                    (a, b) => arithmetic_result_type(a, b),
                }
            }
            _ => None,
        },
        "??" => match (left, right) {
            (Some(left), right) => match without_nil(left) {
                // Nothing but `nil` on the left: the right side always wins.
                None => right.clone(),
                // Left can never be `nil`: the right side is never evaluated.
                Some(non_nil_left) if !contains_nil(left) => Some(non_nil_left),
                Some(non_nil_left) => match right {
                    Some(right) if &non_nil_left == right => Some(non_nil_left),
                    Some(right) => Some(simplify_union(vec![non_nil_left, right.clone()])),
                    None => Some(non_nil_left),
                },
            },
            (None, _) => right.clone(),
        },
        // Includes `|>`: its result type is not derived from operand types here.
        _ => None,
    }
}