// emmylua_code_analysis/db_index/declaration/decl.rs

1use std::fmt;
2
3use crate::LuaSignatureId;
4use crate::{db_index::LuaType, FileId};
5use emmylua_parser::{LuaKind, LuaSyntaxId, LuaSyntaxKind};
6use rowan::{TextRange, TextSize};
7use serde::de::{self, Visitor};
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9use smol_str::SmolStr;
10
11#[derive(Eq, PartialEq, Hash, Debug, Clone)]
12pub struct LuaDecl {
13    name: SmolStr,
14    file_id: FileId,
15    range: TextRange,
16    expr_id: Option<LuaSyntaxId>,
17    pub extra: LuaDeclExtra,
18}
19
20#[derive(Eq, PartialEq, Hash, Debug, Clone)]
21pub enum LuaDeclExtra {
22    Local {
23        kind: LuaKind,
24        decl_type: Option<LuaType>,
25        attrib: Option<LocalAttribute>,
26    },
27    Param {
28        idx: usize,
29        signature_id: LuaSignatureId,
30    },
31    Global {
32        kind: LuaKind,
33        decl_type: Option<LuaType>,
34    },
35}
36
37impl LuaDecl {
38    pub fn new(
39        name: &str,
40        file_id: FileId,
41        range: TextRange,
42        extra: LuaDeclExtra,
43        expr_id: Option<LuaSyntaxId>,
44    ) -> Self {
45        Self {
46            name: SmolStr::new(name),
47            file_id,
48            range,
49            expr_id,
50            extra,
51        }
52    }
53
54    pub fn get_file_id(&self) -> FileId {
55        self.file_id
56    }
57
58    pub fn get_id(&self) -> LuaDeclId {
59        LuaDeclId::new(self.file_id, self.range.start())
60    }
61
62    pub fn get_name(&self) -> &str {
63        &self.name
64    }
65
66    pub fn get_position(&self) -> TextSize {
67        self.range.start()
68    }
69
70    pub fn get_range(&self) -> TextRange {
71        self.range
72    }
73
74    pub fn get_type(&self) -> Option<&LuaType> {
75        match &self.extra {
76            LuaDeclExtra::Local { decl_type, .. } => decl_type.as_ref(),
77            LuaDeclExtra::Global { decl_type, .. } => decl_type.as_ref(),
78            LuaDeclExtra::Param { .. } => None,
79        }
80    }
81
82    pub(crate) fn set_decl_type(&mut self, decl_type: LuaType) {
83        match &mut self.extra {
84            LuaDeclExtra::Local { decl_type: dt, .. } => *dt = Some(decl_type),
85            LuaDeclExtra::Global { decl_type: dt, .. } => *dt = Some(decl_type),
86            LuaDeclExtra::Param { .. } => {}
87        }
88    }
89
90    pub fn get_syntax_id(&self) -> LuaSyntaxId {
91        match self.extra {
92            LuaDeclExtra::Local { kind, .. } => LuaSyntaxId::new(kind, self.range),
93            LuaDeclExtra::Param { .. } => {
94                LuaSyntaxId::new(LuaSyntaxKind::ParamName.into(), self.range)
95            }
96            LuaDeclExtra::Global { kind, .. } => LuaSyntaxId::new(kind, self.range),
97        }
98    }
99
100    pub fn get_value_syntax_id(&self) -> Option<LuaSyntaxId> {
101        self.expr_id
102    }
103
104    pub fn is_local(&self) -> bool {
105        matches!(
106            &self.extra,
107            LuaDeclExtra::Local { .. } | LuaDeclExtra::Param { .. }
108        )
109    }
110
111    pub fn is_param(&self) -> bool {
112        matches!(&self.extra, LuaDeclExtra::Param { .. })
113    }
114
115    pub fn is_global(&self) -> bool {
116        matches!(&self.extra, LuaDeclExtra::Global { .. })
117    }
118}
119
120#[derive(Eq, PartialEq, Hash, Debug, Clone, Copy)]
121pub struct LuaDeclId {
122    pub file_id: FileId,
123    pub position: TextSize,
124}
125
126impl Serialize for LuaDeclId {
127    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
128    where
129        S: Serializer,
130    {
131        let value = format!("{}|{}", self.file_id.id, u32::from(self.position));
132        serializer.serialize_str(&value)
133    }
134}
135
136impl<'de> Deserialize<'de> for LuaDeclId {
137    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
138    where
139        D: Deserializer<'de>,
140    {
141        struct LuaDeclIdVisitor;
142
143        impl<'de> Visitor<'de> for LuaDeclIdVisitor {
144            type Value = LuaDeclId;
145
146            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
147                formatter.write_str("a string with format 'file_id:position'")
148            }
149
150            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
151            where
152                E: de::Error,
153            {
154                let parts: Vec<&str> = value.split('|').collect();
155                if parts.len() != 2 {
156                    return Err(E::custom("expected format 'file_id:position'"));
157                }
158
159                let file_id = FileId {
160                    id: parts[0]
161                        .parse()
162                        .map_err(|e| E::custom(format!("invalid file_id: {}", e)))?,
163                };
164                let position = TextSize::new(
165                    parts[1]
166                        .parse()
167                        .map_err(|e| E::custom(format!("invalid position: {}", e)))?,
168                );
169
170                Ok(LuaDeclId { file_id, position })
171            }
172        }
173
174        deserializer.deserialize_str(LuaDeclIdVisitor)
175    }
176}
177
178impl LuaDeclId {
179    pub fn new(file_id: FileId, position: TextSize) -> Self {
180        Self { file_id, position }
181    }
182}
183
/// Attribute attached to a local declaration.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LocalAttribute {
    /// Presumably the Lua 5.4 `<const>` attribute.
    Const,
    /// Presumably the Lua 5.4 `<close>` attribute.
    Close,
    /// NOTE(review): presumably a loop/iterator control variable treated as read-only — confirm at the producer.
    IterConst,
}