use std::{ops::Deref, sync::Arc};
use emmylua_parser::{
LuaAssignStat, LuaAst, LuaAstNode, LuaCallArgList, LuaCallExpr, LuaExpr, LuaIndexMemberExpr,
LuaLiteralToken, LuaLocalStat, LuaReturnStat, LuaTableExpr, LuaTableField,
};
use crate::{
InFiled, InferGuard, LuaArrayType, LuaDeclId, LuaInferCache, LuaMemberId, LuaTupleStatus,
LuaTupleType, LuaUnionType, TypeOps, VariadicType, check_type_compact,
db_index::{DbIndex, LuaType},
infer_call_expr_func, infer_expr,
};
use super::{
InferFailReason, InferResult,
infer_index::{infer_member_by_member_key, infer_member_by_operator},
};
pub fn infer_table_expr(
db: &DbIndex,
cache: &mut LuaInferCache,
table: LuaTableExpr,
) -> InferResult {
if table.is_array() {
return infer_table_tuple_or_array(db, cache, table);
}
Ok(LuaType::TableConst(crate::InFiled {
file_id: cache.get_file_id(),
value: table.get_range(),
}))
}
fn infer_table_tuple_or_array(
db: &DbIndex,
cache: &mut LuaInferCache,
table: LuaTableExpr,
) -> InferResult {
let fields = table.get_fields().collect::<Vec<_>>();
if fields.len() > 50 {
let first_type = infer_expr(
db,
cache,
fields[0].get_value_expr().ok_or(InferFailReason::None)?,
)?;
return Ok(LuaType::Array(
LuaArrayType::from_base_type(first_type).into(),
));
}
if let Some(first_field) = fields.first() {
let first_value_expr = first_field.get_value_expr().ok_or(InferFailReason::None)?;
if is_dots_expr(&first_value_expr).unwrap_or(false) {
let first_expr_type = infer_expr(db, cache, first_value_expr)?;
match &first_expr_type {
LuaType::Variadic(multi) => match &multi.deref() {
VariadicType::Base(base) => {
return Ok(LuaType::Array(
LuaArrayType::from_base_type(base.clone()).into(),
));
}
VariadicType::Multi(tuple) => {
return Ok(LuaType::Tuple(
LuaTupleType::new(tuple.clone(), LuaTupleStatus::InferResolve).into(),
));
}
},
_ => {
return Ok(LuaType::Array(
LuaArrayType::from_base_type(first_expr_type).into(),
));
}
};
}
}
if let Some(last_field) = fields.last() {
let last_value_expr = last_field.get_value_expr().ok_or(InferFailReason::None)?;
let last_expr_type = infer_expr(db, cache, last_value_expr)?;
if let LuaType::Variadic(multi) = last_expr_type
&& let VariadicType::Base(base) = &multi.deref()
{
let non_nil_base = TypeOps::Remove.apply(db, base, &LuaType::Nil);
if fields.len() <= 1 {
return Ok(LuaType::Array(
LuaArrayType::from_base_type(non_nil_base).into(),
));
}
let len = fields.len() - 1;
let mut all_can_accept_base = true;
for i in 0..len {
let field = fields.get(i).ok_or(InferFailReason::None)?;
let value_expr = field.get_value_expr().ok_or(InferFailReason::None)?;
let typ = infer_expr(db, cache, value_expr)?;
if check_type_compact(db, &non_nil_base, &typ).is_err() {
all_can_accept_base = false;
break;
}
}
if all_can_accept_base {
return Ok(LuaType::Array(
LuaArrayType::from_base_type(non_nil_base).into(),
));
}
};
}
let mut types = Vec::new();
for field in fields {
let value_expr = field.get_value_expr().ok_or(InferFailReason::None)?;
let typ = infer_expr(db, cache, value_expr)?;
match typ {
LuaType::Variadic(multi) => flatten_multi_into_tuple(&mut types, &multi),
_ => {
types.push(typ);
}
}
}
Ok(LuaType::Tuple(
LuaTupleType::new(types, LuaTupleStatus::InferResolve).into(),
))
}
fn flatten_multi_into_tuple(tuple_list: &mut Vec<LuaType>, multi: &VariadicType) {
match multi {
VariadicType::Base(base) => {
tuple_list.push(LuaType::Variadic(VariadicType::Base(base.clone()).into()));
}
VariadicType::Multi(multi) => {
for typ in multi {
match typ {
LuaType::Variadic(multi) => {
flatten_multi_into_tuple(tuple_list, multi.deref());
}
_ => {
tuple_list.push(typ.clone());
}
}
}
}
}
}
fn is_dots_expr(expr: &LuaExpr) -> Option<bool> {
if let LuaExpr::LiteralExpr(literal) = expr
&& let LuaLiteralToken::Dots(_) = literal.get_literal()?
{
return Some(true);
}
Some(false)
}
pub fn infer_table_should_be(
db: &DbIndex,
cache: &mut LuaInferCache,
table: LuaTableExpr,
) -> InferResult {
match table.get_parent::<LuaAst>().ok_or(InferFailReason::None)? {
LuaAst::LuaCallArgList(call_arg_list) => {
infer_table_type_by_callee(db, cache, call_arg_list, table)
}
LuaAst::LuaTableField(field) => infer_table_field_type_by_parent(db, cache, field),
LuaAst::LuaLocalStat(local) => infer_table_type_by_local(db, cache, local, table),
LuaAst::LuaAssignStat(assign_stat) => {
infer_table_type_by_assign_stat(db, cache, assign_stat, table)
}
LuaAst::LuaReturnStat(return_stat) => {
infer_table_type_by_return_stat(db, cache, return_stat, table)
}
_ => Err(InferFailReason::None),
}
}
pub fn infer_table_field_value_should_be(
db: &DbIndex,
cache: &mut LuaInferCache,
table_field: LuaTableField,
) -> InferResult {
let parnet_table_expr = table_field
.get_parent::<LuaTableExpr>()
.ok_or(InferFailReason::None)?;
let parent_table_expr_type = infer_table_should_be(db, cache, parnet_table_expr)?;
let index = LuaIndexMemberExpr::TableField(table_field.clone());
let reason = match infer_member_by_member_key(
db,
cache,
&parent_table_expr_type,
index.clone(),
&InferGuard::new(),
) {
Ok(member_type) => return Ok(member_type),
Err(InferFailReason::FieldNotFound) => InferFailReason::FieldNotFound,
Err(err) => return Err(err),
};
match infer_member_by_operator(
db,
cache,
&parent_table_expr_type,
index,
&InferGuard::new(),
) {
Ok(member_type) => return Ok(member_type),
Err(InferFailReason::FieldNotFound) => {}
Err(err) => return Err(err),
}
let member_id = LuaMemberId::new(table_field.get_syntax_id(), cache.get_file_id());
if let Some(type_cache) = db.get_type_index().get_type_cache(&member_id.into()) {
return Ok(type_cache.as_type().clone());
};
Err(reason)
}
fn infer_table_type_by_callee(
db: &DbIndex,
cache: &mut LuaInferCache,
call_arg_list: LuaCallArgList,
table_expr: LuaTableExpr,
) -> InferResult {
let call_expr = call_arg_list
.get_parent::<LuaCallExpr>()
.ok_or(InferFailReason::None)?;
let prefix_expr = call_expr.get_prefix_expr().ok_or(InferFailReason::None)?;
let prefix_type = infer_expr(db, cache, prefix_expr)?;
let func_type = infer_call_expr_func(
db,
cache,
call_expr.clone(),
prefix_type,
&InferGuard::new(),
None,
)?;
let param_types = func_type.get_params();
let mut call_arg_number = call_arg_list
.children::<LuaAst>()
.enumerate()
.find(|(_, arg)| arg.get_position() == table_expr.get_position())
.ok_or(InferFailReason::None)?
.0;
match (func_type.is_colon_define(), call_expr.is_colon_call()) {
(true, true) | (false, false) => {}
(false, true) => {
call_arg_number += 1;
}
(true, false) => {
call_arg_number = call_arg_number.saturating_sub(1);
}
}
let typ = param_types
.get(call_arg_number)
.ok_or(InferFailReason::None)?
.1
.clone()
.unwrap_or(LuaType::Any);
match &typ {
LuaType::TableConst(_) => {}
LuaType::Union(union) => {
return Ok(union_remove_non_table_type(db, union));
}
_ => {}
}
Ok(typ)
}
fn union_remove_non_table_type(db: &DbIndex, union: &Arc<LuaUnionType>) -> LuaType {
let mut result = None;
for typ in union.into_set().into_iter() {
match typ {
LuaType::Signature(_) | LuaType::DocFunction(_) => {}
_ if typ.is_string() || typ.is_number() || typ.is_boolean() => {}
_ => {
result = Some(match result {
Some(result) => TypeOps::Union.apply(db, &result, &typ),
None => typ,
});
}
}
}
result.unwrap_or(LuaType::Unknown)
}
fn infer_table_field_type_by_parent(
db: &DbIndex,
cache: &mut LuaInferCache,
field: LuaTableField,
) -> InferResult {
let member_id = LuaMemberId::new(field.get_syntax_id(), cache.get_file_id());
if let Some(type_cache) = db.get_type_index().get_type_cache(&member_id.into()) {
if type_cache.is_doc() {
let typ = type_cache.as_type();
match typ {
LuaType::TableConst(_) => {}
LuaType::Tuple(tuple) => {
let types = tuple.get_types();
if tuple.is_infer_resolve() && types.len() == 1 && types[0].is_unknown() {
} else {
return Ok(typ.clone());
}
}
typ => return Ok(typ.clone()),
}
}
} else if field.is_value_field() {
return infer_table_field_value_should_be(db, cache, field);
} else {
return Err(InferFailReason::UnResolveMemberType(member_id));
}
let parnet_table_expr = field
.get_parent::<LuaTableExpr>()
.ok_or(InferFailReason::None)?;
let parent_table_expr_type = infer_table_should_be(db, cache, parnet_table_expr)?;
let index = LuaIndexMemberExpr::TableField(field);
let reason = match infer_member_by_member_key(
db,
cache,
&parent_table_expr_type,
index.clone(),
&InferGuard::new(),
) {
Ok(member_type) => return Ok(member_type),
Err(InferFailReason::FieldNotFound) => InferFailReason::FieldNotFound,
Err(err) => return Err(err),
};
match infer_member_by_operator(
db,
cache,
&parent_table_expr_type,
index,
&InferGuard::new(),
) {
Ok(member_type) => return Ok(member_type),
Err(InferFailReason::FieldNotFound) => {}
Err(err) => return Err(err),
}
Err(reason)
}
fn infer_table_type_by_local(
db: &DbIndex,
cache: &mut LuaInferCache,
local: LuaLocalStat,
table_expr: LuaTableExpr,
) -> InferResult {
let local_names = local.get_local_name_list().collect::<Vec<_>>();
let values = local.get_value_exprs().collect::<Vec<_>>();
let num = values
.iter()
.enumerate()
.find(|(_, value)| value.get_position() == table_expr.get_position())
.ok_or(InferFailReason::None)?
.0;
let local_name = local_names.get(num).ok_or(InferFailReason::None)?;
let decl_id = LuaDeclId::new(cache.get_file_id(), local_name.get_position());
match db.get_type_index().get_type_cache(&decl_id.into()) {
Some(type_cache) => match type_cache.as_type() {
LuaType::TableConst(_) => Err(InferFailReason::None),
typ => Ok(typ.clone()),
},
None => Err(InferFailReason::UnResolveDeclType(decl_id)),
}
}
fn infer_table_type_by_assign_stat(
db: &DbIndex,
cache: &mut LuaInferCache,
assign_stat: LuaAssignStat,
table_expr: LuaTableExpr,
) -> InferResult {
let (vars, exprs) = assign_stat.get_var_and_expr_list();
let num = exprs
.iter()
.enumerate()
.find(|(_, expr)| expr.get_position() == table_expr.get_position())
.ok_or(InferFailReason::None)?
.0;
let name = vars.get(num).ok_or(InferFailReason::None)?;
let decl_id = LuaDeclId::new(cache.get_file_id(), name.get_position());
if db.get_decl_index().get_decl(&decl_id).is_some() {
match db.get_type_index().get_type_cache(&decl_id.into()) {
Some(type_cache) => match type_cache.as_type() {
LuaType::TableConst(_) => Err(InferFailReason::None),
typ => Ok(typ.clone()),
},
None => Err(InferFailReason::UnResolveDeclType(decl_id)),
}
} else {
infer_expr(
db,
cache,
LuaExpr::cast(name.syntax().clone()).ok_or(InferFailReason::None)?,
)
}
}
fn infer_table_type_by_return_stat(
db: &DbIndex,
cache: &mut LuaInferCache,
return_stat: LuaReturnStat,
table_expr: LuaTableExpr,
) -> InferResult {
let in_file_syntax_id = InFiled::new(cache.get_file_id(), return_stat.get_syntax_id());
let cache_type = match db
.get_type_index()
.get_type_cache(&in_file_syntax_id.into())
{
Some(cache) => cache,
None => {
let in_file_range = InFiled::new(cache.get_file_id(), table_expr.get_range());
return Ok(LuaType::TableConst(in_file_range));
}
};
Ok(cache_type.as_type().clone())
}