1use std::{
9 any::type_name,
10 fmt,
11 hash::{BuildHasher, BuildHasherDefault, Hash, Hasher},
12 marker::PhantomData,
13};
14
15use la_arena::{Arena, Idx, RawIdx};
16use rustc_hash::FxHasher;
17use syntax::{AstNode, AstPtr, SyntaxNode, SyntaxNodePtr, ast};
18
/// An untyped id of a node inside an [`AstIdMap`].
///
/// The wrapped `u32` is the index of the node's pointer in the map's arena; the typed
/// counterpart is [`FileAstId`].
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ErasedFileAstId(u32);

impl ErasedFileAstId {
    /// Returns the underlying `u32` index.
    pub const fn into_raw(self) -> u32 {
        let ErasedFileAstId(raw) = self;
        raw
    }

    /// Wraps a raw `u32` index. The value is not checked against any map.
    pub const fn from_raw(raw: u32) -> Self {
        ErasedFileAstId(raw)
    }
}

// `Display` and `Debug` both print only the bare number (no type name). They delegate to
// the corresponding `u32` implementation, so formatter options such as width are honored.
impl fmt::Display for ErasedFileAstId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}

impl fmt::Debug for ErasedFileAstId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}
43
44pub struct FileAstId<N: AstIdNode> {
46 raw: ErasedFileAstId,
47 covariant: PhantomData<fn() -> N>,
48}
49
50impl<N: AstIdNode> Clone for FileAstId<N> {
51 fn clone(&self) -> FileAstId<N> {
52 *self
53 }
54}
55impl<N: AstIdNode> Copy for FileAstId<N> {}
56
57impl<N: AstIdNode> PartialEq for FileAstId<N> {
58 fn eq(&self, other: &Self) -> bool {
59 self.raw == other.raw
60 }
61}
62impl<N: AstIdNode> Eq for FileAstId<N> {}
63impl<N: AstIdNode> Hash for FileAstId<N> {
64 fn hash<H: Hasher>(&self, hasher: &mut H) {
65 self.raw.hash(hasher);
66 }
67}
68
69impl<N: AstIdNode> fmt::Debug for FileAstId<N> {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 write!(f, "FileAstId::<{}>({})", type_name::<N>(), self.raw)
72 }
73}
74
75impl<N: AstIdNode> FileAstId<N> {
76 pub fn upcast<M: AstIdNode>(self) -> FileAstId<M>
78 where
79 N: Into<M>,
80 {
81 FileAstId { raw: self.raw, covariant: PhantomData }
82 }
83
84 pub fn erase(self) -> ErasedFileAstId {
85 self.raw
86 }
87}
88
89pub trait AstIdNode: AstNode {}
90macro_rules! register_ast_id_node {
91 (impl AstIdNode for $($ident:ident),+ ) => {
92 $(
93 impl AstIdNode for ast::$ident {}
94 )+
95 fn should_alloc_id(kind: syntax::SyntaxKind) -> bool {
96 $(
97 ast::$ident::can_cast(kind)
98 )||+
99 }
100 };
101}
102register_ast_id_node! {
103 impl AstIdNode for
104 Item, AnyHasGenericParams,
105 Adt,
106 Enum,
107 Variant,
108 Struct,
109 Union,
110 AssocItem,
111 Const,
112 Fn,
113 MacroCall,
114 TypeAlias,
115 ExternBlock,
116 ExternCrate,
117 Impl,
118 Macro,
119 MacroDef,
120 MacroRules,
121 Module,
122 Static,
123 Trait,
124 TraitAlias,
125 Use,
126 BlockExpr, ConstArg
127}
128
129#[derive(Default)]
131pub struct AstIdMap {
132 arena: Arena<SyntaxNodePtr>,
134 map: hashbrown::HashMap<Idx<SyntaxNodePtr>, (), ()>,
136}
137
138impl fmt::Debug for AstIdMap {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 f.debug_struct("AstIdMap").field("arena", &self.arena).finish()
141 }
142}
143
144impl PartialEq for AstIdMap {
145 fn eq(&self, other: &Self) -> bool {
146 self.arena == other.arena
147 }
148}
149impl Eq for AstIdMap {}
150
151impl AstIdMap {
152 pub fn from_source(node: &SyntaxNode) -> AstIdMap {
153 assert!(node.parent().is_none());
154 let mut res = AstIdMap::default();
155
156 if !should_alloc_id(node.kind()) {
158 res.alloc(node);
159 }
160 bdfs(node, |it| {
165 if should_alloc_id(it.kind()) {
166 res.alloc(&it);
167 TreeOrder::BreadthFirst
168 } else {
169 TreeOrder::DepthFirst
170 }
171 });
172 res.map = hashbrown::HashMap::with_capacity_and_hasher(res.arena.len(), ());
173 for (idx, ptr) in res.arena.iter() {
174 let hash = hash_ptr(ptr);
175 match res.map.raw_entry_mut().from_hash(hash, |idx2| *idx2 == idx) {
176 hashbrown::hash_map::RawEntryMut::Occupied(_) => unreachable!(),
177 hashbrown::hash_map::RawEntryMut::Vacant(entry) => {
178 entry.insert_with_hasher(hash, idx, (), |&idx| hash_ptr(&res.arena[idx]));
179 }
180 }
181 }
182 res.arena.shrink_to_fit();
183 res
184 }
185
186 pub fn root(&self) -> SyntaxNodePtr {
188 self.arena[Idx::from_raw(RawIdx::from_u32(0))]
189 }
190
191 pub fn ast_id<N: AstIdNode>(&self, item: &N) -> FileAstId<N> {
192 let raw = self.erased_ast_id(item.syntax());
193 FileAstId { raw, covariant: PhantomData }
194 }
195
196 pub fn ast_id_for_ptr<N: AstIdNode>(&self, ptr: AstPtr<N>) -> FileAstId<N> {
197 let ptr = ptr.syntax_node_ptr();
198 let hash = hash_ptr(&ptr);
199 match self.map.raw_entry().from_hash(hash, |&idx| self.arena[idx] == ptr) {
200 Some((&raw, &())) => FileAstId {
201 raw: ErasedFileAstId(raw.into_raw().into_u32()),
202 covariant: PhantomData,
203 },
204 None => panic!(
205 "Can't find {:?} in AstIdMap:\n{:?}",
206 ptr,
207 self.arena.iter().map(|(_id, i)| i).collect::<Vec<_>>(),
208 ),
209 }
210 }
211
212 pub fn get<N: AstIdNode>(&self, id: FileAstId<N>) -> AstPtr<N> {
213 AstPtr::try_from_raw(self.arena[Idx::from_raw(RawIdx::from_u32(id.raw.into_raw()))])
214 .unwrap()
215 }
216
217 pub fn get_erased(&self, id: ErasedFileAstId) -> SyntaxNodePtr {
218 self.arena[Idx::from_raw(RawIdx::from_u32(id.into_raw()))]
219 }
220
221 fn erased_ast_id(&self, item: &SyntaxNode) -> ErasedFileAstId {
222 let ptr = SyntaxNodePtr::new(item);
223 let hash = hash_ptr(&ptr);
224 match self.map.raw_entry().from_hash(hash, |&idx| self.arena[idx] == ptr) {
225 Some((&idx, &())) => ErasedFileAstId(idx.into_raw().into_u32()),
226 None => panic!(
227 "Can't find {:?} in AstIdMap:\n{:?}\n source text: {}",
228 item,
229 self.arena.iter().map(|(_id, i)| i).collect::<Vec<_>>(),
230 item
231 ),
232 }
233 }
234
235 fn alloc(&mut self, item: &SyntaxNode) -> ErasedFileAstId {
236 ErasedFileAstId(self.arena.alloc(SyntaxNodePtr::new(item)).into_raw().into_u32())
237 }
238}
239
240fn hash_ptr(ptr: &SyntaxNodePtr) -> u64 {
241 BuildHasherDefault::<FxHasher>::default().hash_one(ptr)
242}
243
/// What `bdfs` does after its callback has seen a node.
#[derive(Clone, Copy, Eq, PartialEq)]
enum TreeOrder {
    /// Skip the node's subtree for now and queue its children for the next layer.
    BreadthFirst,
    /// Keep walking into the node's subtree in the current pass.
    DepthFirst,
}
249
250fn bdfs(node: &SyntaxNode, mut f: impl FnMut(SyntaxNode) -> TreeOrder) {
258 let mut curr_layer = vec![node.clone()];
259 let mut next_layer = vec![];
260 while !curr_layer.is_empty() {
261 curr_layer.drain(..).for_each(|node| {
262 let mut preorder = node.preorder();
263 while let Some(event) = preorder.next() {
264 match event {
265 syntax::WalkEvent::Enter(node) => {
266 if f(node.clone()) == TreeOrder::BreadthFirst {
267 next_layer.extend(node.children());
268 preorder.skip_subtree();
269 }
270 }
271 syntax::WalkEvent::Leave(_) => {}
272 }
273 }
274 });
275 std::mem::swap(&mut curr_layer, &mut next_layer);
276 }
277}