use crate::ast::*;
use super::scope::InferredType;
use super::union::{contains_nil, simplify_union, without_nil};
/// Reports whether `ty` denotes a dictionary-shaped type: the named `dict`
/// type, a `TypeExpr::DictType`, or a `TypeExpr::Shape`.
fn dict_like(ty: &TypeExpr) -> bool {
    match ty {
        TypeExpr::Named(name) => name == "dict",
        TypeExpr::DictType(..) | TypeExpr::Shape(_) => true,
        _ => false,
    }
}
/// Infers the static result type of a binary operator expression.
///
/// `op` is the operator's source spelling (e.g. `"+"`, `"??"`), and `left` /
/// `right` are the already-inferred operand types (`None` when unknown).
/// Returns `None` whenever the result type cannot be determined here.
///
/// Arithmetic (`+`, `-`, `*`, `/`, `%`, `**`) on named types:
/// - `int` with `int` yields `int` (this includes `/`, `%` and `**`);
/// - `float` with `float`, or `float` mixed with `int`, yields `float`;
/// - any other pairing is not inferred (`None`).
///
/// Additionally, `+` on two `string`s, two `list`s, or two dict-like operands
/// yields that kind, and `*` of `string` with `int` (either order) yields
/// `string`.
pub(super) fn infer_binary_op_type(
    op: &str,
    left: &InferredType,
    right: &InferredType,
) -> InferredType {
    /// Result of an arithmetic operator on two named types. Only `int` and
    /// `float` participate; anything else (e.g. `float` with `string`) yields
    /// `None` rather than being widened to `float`.
    fn numeric_result(l: &str, r: &str) -> InferredType {
        match (l, r) {
            ("int", "int") => Some(TypeExpr::Named("int".into())),
            ("float", "float") | ("float", "int") | ("int", "float") => {
                Some(TypeExpr::Named("float".into()))
            }
            _ => None,
        }
    }

    match op {
        // Comparison, logical and membership operators always produce a bool.
        "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "in" | "not_in" => {
            Some(TypeExpr::Named("bool".into()))
        }
        "+" => match (left, right) {
            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
                match (l.as_str(), r.as_str()) {
                    ("string", "string") => Some(TypeExpr::Named("string".into())),
                    ("list", "list") => Some(TypeExpr::Named("list".into())),
                    ("dict", "dict") => Some(TypeExpr::Named("dict".into())),
                    (l, r) => numeric_result(l, r),
                }
            }
            // Two dict-like operands (including `DictType` / `Shape`) yield a
            // plain `dict`.
            (Some(l), Some(r)) if dict_like(l) && dict_like(r) => {
                Some(TypeExpr::Named("dict".into()))
            }
            _ => None,
        },
        "-" | "/" | "%" | "**" => match (left, right) {
            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => numeric_result(l, r),
            _ => None,
        },
        "*" => match (left, right) {
            (Some(TypeExpr::Named(l)), Some(TypeExpr::Named(r))) => {
                match (l.as_str(), r.as_str()) {
                    ("string", "int") | ("int", "string") => {
                        Some(TypeExpr::Named("string".into()))
                    }
                    (l, r) => numeric_result(l, r),
                }
            }
            _ => None,
        },
        // Nil-coalescing: the result is the left type with nil stripped,
        // unioned with the right type only when the left side may be nil.
        "??" => match (left, right) {
            (Some(left), right) => match without_nil(left) {
                // Nothing remains once nil is removed (presumably the left side
                // is always nil), so the right operand determines the result.
                None => right.clone(),
                // The left side can never be nil, so the right side is unused.
                Some(non_nil_left) if !contains_nil(left) => Some(non_nil_left),
                Some(non_nil_left) => match right {
                    Some(right) if &non_nil_left == right => Some(non_nil_left),
                    Some(right) => Some(simplify_union(vec![non_nil_left, right.clone()])),
                    None => Some(non_nil_left),
                },
            },
            (None, _) => right.clone(),
        },
        // The pipe operator's result type is not inferred by this function.
        "|>" => None,
        _ => None,
    }
}