emmylua_code_analysis/semantic/decl/
mod.rs

1use std::collections::HashSet;
2
3use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaIndexExpr, LuaSyntaxKind};
4use rowan::NodeOrToken;
5
6use crate::{
7    DbIndex, LuaDecl, LuaDeclId, LuaInferCache, LuaSemanticDeclId, LuaType, ModuleInfo,
8    SemanticDeclLevel, SemanticModel, infer_node_semantic_decl,
9    semantic::semantic_info::infer_token_semantic_decl,
10};
11
12pub fn enum_variable_is_param(
13    db: &DbIndex,
14    cache: &mut LuaInferCache,
15    index_expr: &LuaIndexExpr,
16    prefix_typ: &LuaType,
17) -> Option<()> {
18    let LuaType::Ref(id) = prefix_typ else {
19        return None;
20    };
21
22    let type_decl = db.get_type_index().get_type_decl(id)?;
23    if !type_decl.is_enum() {
24        return None;
25    }
26
27    let prefix_expr = index_expr.get_prefix_expr()?;
28    let prefix_decl = infer_node_semantic_decl(
29        db,
30        cache,
31        prefix_expr.syntax().clone(),
32        SemanticDeclLevel::default(),
33    )?;
34
35    let LuaSemanticDeclId::LuaDecl(decl_id) = prefix_decl else {
36        return None;
37    };
38
39    let mut decl_guard = DeclGuard::new();
40    let origin_decl_id = find_enum_origin(db, cache, decl_id, &mut decl_guard).unwrap_or(decl_id);
41    let decl = db.get_decl_index().get_decl(&origin_decl_id)?;
42
43    if decl.is_param() { Some(()) } else { None }
44}
45
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub struct DeclGuard {
48    decl_set: HashSet<LuaDeclId>,
49}
50
51impl DeclGuard {
52    pub fn new() -> Self {
53        Self {
54            decl_set: HashSet::new(),
55        }
56    }
57
58    pub fn check(&mut self, decl_id: LuaDeclId) -> Option<()> {
59        if self.decl_set.contains(&decl_id) {
60            None
61        } else {
62            self.decl_set.insert(decl_id);
63            Some(())
64        }
65    }
66}
67
68fn find_enum_origin(
69    db: &DbIndex,
70    cache: &mut LuaInferCache,
71    decl_id: LuaDeclId,
72    decl_guard: &mut DeclGuard,
73) -> Option<LuaDeclId> {
74    decl_guard.check(decl_id)?;
75    let syntax_tree = db.get_vfs().get_syntax_tree(&decl_id.file_id)?;
76    let root = syntax_tree.get_red_root();
77
78    let node = db
79        .get_decl_index()
80        .get_decl(&decl_id)?
81        .get_value_syntax_id()?
82        .to_node_from_root(&root)?;
83
84    let semantic_decl = match node.into() {
85        NodeOrToken::Node(node) => {
86            infer_node_semantic_decl(db, cache, node, SemanticDeclLevel::NoTrace)
87        }
88        NodeOrToken::Token(token) => {
89            infer_token_semantic_decl(db, cache, token, SemanticDeclLevel::NoTrace)
90        }
91    };
92
93    match semantic_decl {
94        Some(LuaSemanticDeclId::Member(_)) => None,
95        Some(LuaSemanticDeclId::LuaDecl(new_decl_id)) => {
96            let decl = db.get_decl_index().get_decl(&new_decl_id)?;
97            if decl.get_value_syntax_id().is_some() {
98                Some(find_enum_origin(db, cache, new_decl_id, decl_guard).unwrap_or(new_decl_id))
99            } else {
100                Some(new_decl_id)
101            }
102        }
103        _ => None,
104    }
105}
106
107/// 解析 require 调用表达式并获取模块信息
108pub fn parse_require_module_info<'a>(
109    semantic_model: &'a SemanticModel,
110    decl: &LuaDecl,
111) -> Option<&'a ModuleInfo> {
112    let value_syntax_id = decl.get_value_syntax_id()?;
113    if value_syntax_id.get_kind() != LuaSyntaxKind::RequireCallExpr {
114        return None;
115    }
116
117    let node = semantic_model
118        .get_db()
119        .get_vfs()
120        .get_syntax_tree(&decl.get_file_id())
121        .and_then(|tree| {
122            let root = tree.get_red_root();
123            semantic_model
124                .get_db()
125                .get_decl_index()
126                .get_decl(&decl.get_id())
127                .and_then(|decl| decl.get_value_syntax_id())
128                .and_then(|syntax_id| syntax_id.to_node_from_root(&root))
129        })?;
130
131    let call_expr = LuaCallExpr::cast(node)?;
132    let arg_list = call_expr.get_args_list()?;
133    let first_arg = arg_list.get_args().next()?;
134    let require_path_type = semantic_model.infer_expr(first_arg.clone()).ok()?;
135    let module_path: String = match &require_path_type {
136        LuaType::StringConst(module_path) => module_path.as_ref().to_string(),
137        _ => {
138            return None;
139        }
140    };
141
142    semantic_model
143        .get_db()
144        .get_module_index()
145        .find_module(&module_path)
146}