use std::{
any::type_name,
fmt,
hash::{BuildHasher, BuildHasherDefault, Hash, Hasher},
marker::PhantomData,
};
use la_arena::{Arena, Idx, RawIdx};
use rustc_hash::FxHasher;
use syntax::{ast, AstNode, AstPtr, SyntaxNode, SyntaxNodePtr};
/// Type-erased id of a node in an `AstIdMap`: an index into the map's arena of
/// `SyntaxNodePtr`s.
pub type ErasedFileAstId = la_arena::Idx<syntax::SyntaxNodePtr>;
/// Id of a node of AST type `N` in an `AstIdMap`.
///
/// This is an [`ErasedFileAstId`] tagged with the node type `N`. The tag is
/// only a marker; no `N` value is stored.
pub struct FileAstId<N: AstIdNode> {
    raw: ErasedFileAstId,
    // `PhantomData<fn() -> N>` makes the id covariant in `N` without owning an
    // `N`, so auto traits such as `Send`/`Sync` do not depend on `N`.
    covariant: PhantomData<fn() -> N>,
}
// `Clone`/`Copy` are written by hand: `#[derive(Clone, Copy)]` would add
// `N: Clone` / `N: Copy` bounds, yet the id never stores an `N`.
impl<N: AstIdNode> Clone for FileAstId<N> {
    /// Ids are plain indices, so cloning is a bitwise copy.
    fn clone(&self) -> Self {
        *self
    }
}
impl<N: AstIdNode> Copy for FileAstId<N> {}
// Equality is decided by the arena index alone; the type tag carries no data.
impl<N: AstIdNode> PartialEq for FileAstId<N> {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.raw, &other.raw)
    }
}
impl<N: AstIdNode> Eq for FileAstId<N> {}
// Hashes exactly like the underlying erased id, consistent with `PartialEq`.
impl<N: AstIdNode> Hash for FileAstId<N> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.raw, state)
    }
}
// Renders as `FileAstId::<node type>(raw index)`.
impl<N: AstIdNode> fmt::Debug for FileAstId<N> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let node_type = type_name::<N>();
        let index = self.raw.into_raw();
        write!(f, "FileAstId::<{node_type}>({index})")
    }
}
impl<N: AstIdNode> FileAstId<N> {
    /// Reinterprets this id as an id of a wider node type `M` that `N` converts into.
    ///
    /// This is a method rather than a `From` impl: a blanket
    /// `impl<N, M> From<FileAstId<N>> for FileAstId<M>` would overlap with the
    /// standard library's reflexive `impl<T> From<T> for T`.
    pub fn upcast<M: AstIdNode>(self) -> FileAstId<M>
    where
        N: Into<M>,
    {
        let raw = self.erase();
        FileAstId { raw, covariant: PhantomData }
    }

    /// Drops the node-type tag and returns the underlying arena index.
    pub fn erase(self) -> ErasedFileAstId {
        let FileAstId { raw, .. } = self;
        raw
    }
}
/// Marker for AST node types that may be identified by a [`FileAstId`].
pub trait AstIdNode: AstNode {}
/// Implements [`AstIdNode`] for every listed `ast::` type and generates
/// `should_alloc_id`.
///
/// `should_alloc_id` reports whether a `SyntaxKind` can be cast to any of the
/// listed types. `AstIdMap::from_source` uses it to decide which nodes receive
/// an id.
macro_rules! register_ast_id_node {
    (impl AstIdNode for $($node:ident),+ ) => {
        $( impl AstIdNode for ast::$node {} )+

        fn should_alloc_id(syntax_kind: syntax::SyntaxKind) -> bool {
            $( ast::$node::can_cast(syntax_kind) )||+
        }
    };
}
// Every node type listed here gets an id in `AstIdMap`. The list must not
// end with a trailing comma: the macro's matcher does not accept one.
register_ast_id_node! {
    impl AstIdNode for
    // Items and their parts.
    Item,
    Adt,
    Enum,
    Variant,
    Struct,
    RecordField,
    TupleField,
    Union,
    AssocItem,
    Const,
    Fn,
    MacroCall,
    TypeAlias,
    ExternBlock,
    ExternCrate,
    Impl,
    Macro,
    MacroDef,
    MacroRules,
    Module,
    Static,
    Trait,
    TraitAlias,
    Use,
    // Non-item nodes that also get ids.
    BlockExpr,
    ConstArg,
    Param,
    SelfParam
}
/// Bidirectional mapping between the id-bearing nodes of one syntax tree and
/// their ids.
///
/// An id is an index into `arena`; `map` is the reverse index from a node
/// pointer back to its id.
#[derive(Default)]
pub struct AstIdMap {
    /// Id -> pointer storage; a node's id is its index in this arena.
    arena: Arena<SyntaxNodePtr>,
    /// Pointer -> id index. Keys are arena indices, and the hasher is `()`.
    /// Hashes are always computed with `hash_ptr` on the pointer stored in
    /// `arena` and passed explicitly through hashbrown's raw-entry API
    /// (presumably so each pointer is stored only once — NOTE(review): confirm intent).
    map: hashbrown::HashMap<Idx<SyntaxNodePtr>, (), ()>,
}
// Only the arena is printed: `map` is an index rebuilt from it.
impl fmt::Debug for AstIdMap {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("AstIdMap");
        out.field("arena", &self.arena);
        out.finish()
    }
}
impl PartialEq for AstIdMap {
fn eq(&self, other: &Self) -> bool {
self.arena == other.arena
}
}
impl Eq for AstIdMap {}
impl AstIdMap {
pub fn from_source(node: &SyntaxNode) -> AstIdMap {
assert!(node.parent().is_none());
let mut res = AstIdMap::default();
if !should_alloc_id(node.kind()) {
res.alloc(node);
}
bdfs(node, |it| {
if should_alloc_id(it.kind()) {
res.alloc(&it);
TreeOrder::BreadthFirst
} else {
TreeOrder::DepthFirst
}
});
res.map = hashbrown::HashMap::with_capacity_and_hasher(res.arena.len(), ());
for (idx, ptr) in res.arena.iter() {
let hash = hash_ptr(ptr);
match res.map.raw_entry_mut().from_hash(hash, |idx2| *idx2 == idx) {
hashbrown::hash_map::RawEntryMut::Occupied(_) => unreachable!(),
hashbrown::hash_map::RawEntryMut::Vacant(entry) => {
entry.insert_with_hasher(hash, idx, (), |&idx| hash_ptr(&res.arena[idx]));
}
}
}
res.arena.shrink_to_fit();
res
}
pub fn root(&self) -> SyntaxNodePtr {
self.arena[Idx::from_raw(RawIdx::from_u32(0))]
}
pub fn ast_id<N: AstIdNode>(&self, item: &N) -> FileAstId<N> {
let raw = self.erased_ast_id(item.syntax());
FileAstId { raw, covariant: PhantomData }
}
pub fn ast_id_for_ptr<N: AstIdNode>(&self, ptr: AstPtr<N>) -> FileAstId<N> {
let ptr = ptr.syntax_node_ptr();
let hash = hash_ptr(&ptr);
match self.map.raw_entry().from_hash(hash, |&idx| self.arena[idx] == ptr) {
Some((&raw, &())) => FileAstId { raw, covariant: PhantomData },
None => panic!(
"Can't find {:?} in AstIdMap:\n{:?}",
ptr,
self.arena.iter().map(|(_id, i)| i).collect::<Vec<_>>(),
),
}
}
pub fn get<N: AstIdNode>(&self, id: FileAstId<N>) -> AstPtr<N> {
AstPtr::try_from_raw(self.arena[id.raw]).unwrap()
}
pub fn get_erased(&self, id: ErasedFileAstId) -> SyntaxNodePtr {
self.arena[id]
}
fn erased_ast_id(&self, item: &SyntaxNode) -> ErasedFileAstId {
let ptr = SyntaxNodePtr::new(item);
let hash = hash_ptr(&ptr);
match self.map.raw_entry().from_hash(hash, |&idx| self.arena[idx] == ptr) {
Some((&idx, &())) => idx,
None => panic!(
"Can't find {:?} in AstIdMap:\n{:?}",
item,
self.arena.iter().map(|(_id, i)| i).collect::<Vec<_>>(),
),
}
}
fn alloc(&mut self, item: &SyntaxNode) -> ErasedFileAstId {
self.arena.alloc(SyntaxNodePtr::new(item))
}
}
/// Hashes a node pointer with `FxHasher`.
///
/// This is the single hash function for `AstIdMap::map`: entries are inserted,
/// rehashed and looked up using hashes produced here.
fn hash_ptr(ptr: &SyntaxNodePtr) -> u64 {
    let build_hasher: BuildHasherDefault<FxHasher> = Default::default();
    build_hasher.hash_one(ptr)
}
/// Decision returned by the `bdfs` callback for the node it was just given.
#[derive(Clone, Copy, Eq, PartialEq)]
enum TreeOrder {
    /// Skip the node's subtree for now and queue its children for the next layer.
    BreadthFirst,
    /// Continue into the node's subtree immediately, in pre-order.
    DepthFirst,
}
/// Visits the subtree rooted at `node` and calls `f` exactly once per node.
///
/// The walk proceeds in layers, and each node of the current layer is explored
/// in pre-order:
///
/// - When `f` returns [`TreeOrder::DepthFirst`], the walk continues into that
///   node's subtree right away.
/// - When `f` returns [`TreeOrder::BreadthFirst`], the node's children are
///   queued for the next layer and its subtree is skipped for now.
///
/// The queue therefore only ever holds children of nodes for which `f` chose
/// breadth-first.
fn bdfs(node: &SyntaxNode, mut f: impl FnMut(SyntaxNode) -> TreeOrder) {
    let mut layer = vec![node.clone()];
    let mut deferred = Vec::new();
    while !layer.is_empty() {
        for start in layer.drain(..) {
            let mut walk = start.preorder();
            while let Some(event) = walk.next() {
                // Only entering a node matters; `Leave` events are ignored.
                if let syntax::WalkEvent::Enter(entered) = event {
                    if f(entered.clone()) == TreeOrder::BreadthFirst {
                        deferred.extend(entered.children());
                        walk.skip_subtree();
                    }
                }
            }
        }
        // `layer` is now empty; swap so its buffer is reused as the next queue.
        std::mem::swap(&mut layer, &mut deferred);
    }
}