// emmylua_parser/syntax/mod.rs

1mod node;
2mod traits;
3mod tree;
4
5use serde::de::{self, Visitor};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use std::fmt;
8use std::iter::successors;
9use std::marker::PhantomData;
10
11use rowan::{Language, TextRange, TextSize};
12
13use crate::kind::{LuaKind, LuaSyntaxKind, LuaTokenKind};
14pub use node::*;
15pub use traits::*;
16pub use tree::{LuaSyntaxTree, LuaTreeBuilder};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
19pub struct LuaLanguage;
20
21impl Language for LuaLanguage {
22    type Kind = LuaKind;
23
24    fn kind_from_raw(raw: rowan::SyntaxKind) -> Self::Kind {
25        LuaKind::from_raw(raw.0)
26    }
27
28    fn kind_to_raw(kind: Self::Kind) -> rowan::SyntaxKind {
29        rowan::SyntaxKind(kind.get_raw())
30    }
31}
32
33pub type LuaSyntaxNode = rowan::SyntaxNode<LuaLanguage>;
34pub type LuaSyntaxToken = rowan::SyntaxToken<LuaLanguage>;
35pub type LuaSyntaxElement = rowan::NodeOrToken<LuaSyntaxNode, LuaSyntaxToken>;
36pub type LuaSyntaxElementChildren = rowan::SyntaxElementChildren<LuaLanguage>;
37pub type LuaSyntaxNodeChildren = rowan::SyntaxNodeChildren<LuaLanguage>;
38pub type LuaSyntaxNodePtr = rowan::ast::SyntaxNodePtr<LuaLanguage>;
39
40impl From<LuaSyntaxKind> for rowan::SyntaxKind {
41    fn from(kind: LuaSyntaxKind) -> Self {
42        let lua_kind = LuaKind::from(kind);
43        rowan::SyntaxKind(lua_kind.get_raw())
44    }
45}
46
47impl From<rowan::SyntaxKind> for LuaSyntaxKind {
48    fn from(kind: rowan::SyntaxKind) -> Self {
49        LuaKind::from_raw(kind.0).into()
50    }
51}
52
53impl From<LuaTokenKind> for rowan::SyntaxKind {
54    fn from(kind: LuaTokenKind) -> Self {
55        let lua_kind = LuaKind::from(kind);
56        rowan::SyntaxKind(lua_kind.get_raw())
57    }
58}
59
60impl From<rowan::SyntaxKind> for LuaTokenKind {
61    fn from(kind: rowan::SyntaxKind) -> Self {
62        LuaKind::from_raw(kind.0).into()
63    }
64}
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
67pub struct LuaSyntaxId {
68    kind: LuaKind,
69    range: TextRange,
70}
71
72impl LuaSyntaxId {
73    pub fn new(kind: LuaKind, range: TextRange) -> Self {
74        LuaSyntaxId { kind, range }
75    }
76
77    pub fn from_ptr(ptr: LuaSyntaxNodePtr) -> Self {
78        LuaSyntaxId {
79            kind: ptr.kind(),
80            range: ptr.text_range(),
81        }
82    }
83
84    pub fn from_node(node: &LuaSyntaxNode) -> Self {
85        LuaSyntaxId {
86            kind: node.kind(),
87            range: node.text_range(),
88        }
89    }
90
91    pub fn from_token(token: &LuaSyntaxToken) -> Self {
92        LuaSyntaxId {
93            kind: token.kind(),
94            range: token.text_range(),
95        }
96    }
97
98    pub fn get_kind(&self) -> LuaSyntaxKind {
99        self.kind.into()
100    }
101
102    pub fn get_token_kind(&self) -> LuaTokenKind {
103        self.kind.into()
104    }
105
106    pub fn is_token(&self) -> bool {
107        self.kind.is_token()
108    }
109
110    pub fn is_node(&self) -> bool {
111        self.kind.is_syntax()
112    }
113
114    pub fn get_range(&self) -> TextRange {
115        self.range
116    }
117
118    pub fn to_node(&self, tree: &LuaSyntaxTree) -> Option<LuaSyntaxNode> {
119        let root = tree.get_red_root();
120        if root.parent().is_some() {
121            return None;
122        }
123        self.to_node_from_root(&root)
124    }
125
126    pub fn to_node_from_root(&self, root: &LuaSyntaxNode) -> Option<LuaSyntaxNode> {
127        successors(Some(root.clone()), |node| {
128            node.child_or_token_at_range(self.range)?.into_node()
129        })
130        .find(|it| it.text_range() == self.range && it.kind() == self.kind)
131    }
132
133    pub fn to_token(&self, tree: &LuaSyntaxTree) -> Option<LuaSyntaxToken> {
134        let root = tree.get_red_root();
135        if root.parent().is_some() {
136            return None;
137        }
138        self.to_token_from_root(&root)
139    }
140
141    pub fn to_token_from_root(&self, root: &LuaSyntaxNode) -> Option<LuaSyntaxToken> {
142        let mut current_node = Some(root.clone());
143        while let Some(node) = current_node {
144            let node_or_token = node.child_or_token_at_range(self.range)?;
145            match node_or_token {
146                rowan::NodeOrToken::Node(node) => {
147                    current_node = Some(node);
148                }
149                rowan::NodeOrToken::Token(token) => {
150                    if token.text_range() == self.range && token.kind() == self.kind {
151                        return Some(token);
152                    }
153                    return None;
154                }
155            }
156        }
157        None
158    }
159
160    pub fn to_node_at_range(root: &LuaSyntaxNode, range: TextRange) -> Option<LuaSyntaxNode> {
161        successors(Some(root.clone()), |node| {
162            node.child_or_token_at_range(range)?.into_node()
163        })
164        .find(|it| it.text_range() == range)
165    }
166}
167
168impl Serialize for LuaSyntaxId {
169    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
170    where
171        S: Serializer,
172    {
173        let kind_raw = self.kind.get_raw();
174        let start = u32::from(self.range.start());
175        let end = u32::from(self.range.end());
176        let range_combined = ((start as u64) << 32) | (end as u64);
177        let value = format!("{:x}:{:x}", kind_raw, range_combined);
178        serializer.serialize_str(&value)
179    }
180}
181
182impl<'de> Deserialize<'de> for LuaSyntaxId {
183    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
184    where
185        D: Deserializer<'de>,
186    {
187        struct LuaSyntaxIdVisitor;
188
189        impl Visitor<'_> for LuaSyntaxIdVisitor {
190            type Value = LuaSyntaxId;
191
192            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
193                formatter.write_str("a string with format 'kind:range'")
194            }
195
196            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
197            where
198                E: de::Error,
199            {
200                let parts: Vec<&str> = value.split(':').collect();
201                if parts.len() != 2 {
202                    return Err(E::custom("expected format 'kind:range'"));
203                }
204
205                let kind_raw = u16::from_str_radix(parts[0], 16)
206                    .map_err(|e| E::custom(format!("invalid kind: {}", e)))?;
207                let range_combined = u64::from_str_radix(parts[1], 16)
208                    .map_err(|e| E::custom(format!("invalid range: {}", e)))?;
209
210                let start = TextSize::new(((range_combined >> 32) & 0xFFFFFFFF) as u32);
211                let end = TextSize::new((range_combined & 0xFFFFFFFF) as u32);
212
213                Ok(LuaSyntaxId {
214                    kind: LuaKind::from_raw(kind_raw),
215                    range: TextRange::new(start, end),
216                })
217            }
218        }
219
220        deserializer.deserialize_str(LuaSyntaxIdVisitor)
221    }
222}
223
224#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
225pub struct LuaAstPtr<T: LuaAstNode> {
226    pub syntax_id: LuaSyntaxId,
227    _phantom: PhantomData<T>,
228}
229
230impl<T: LuaAstNode> LuaAstPtr<T> {
231    pub fn new(node: &T) -> Self {
232        LuaAstPtr {
233            syntax_id: node.get_syntax_id(),
234            _phantom: PhantomData,
235        }
236    }
237
238    pub fn get_syntax_id(&self) -> LuaSyntaxId {
239        self.syntax_id
240    }
241
242    pub fn to_node(&self, root: &LuaChunk) -> Option<T> {
243        let syntax_node = self.syntax_id.to_node_from_root(root.syntax());
244        if let Some(node) = syntax_node {
245            T::cast(node)
246        } else {
247            None
248        }
249    }
250}
251
252unsafe impl<T: LuaAstNode> Send for LuaAstPtr<T> {}
253unsafe impl<T: LuaAstNode> Sync for LuaAstPtr<T> {}