use super::{CodeGraphV2, TypeFlowGraphV2};
use crate::ast::ASTRegistry;
use crate::symbol::SymbolRegistry;
use crate::SymbolId;
use ryo_source::pure::{PureAttrMeta, PureItem};
use serde::Serialize;
use slotmap::SecondaryMap;
use smallvec::SmallVec;
/// Per-symbol index of `#[derive(...)]` entries and field type names.
///
/// Populated from struct and enum items via [`DeriveIndex::build`],
/// [`DeriveIndex::rebuild_all`], or [`DeriveIndex::rebuild_for_symbols`].
#[derive(Clone, Default, Debug, Serialize)]
pub struct DeriveIndex {
    /// Trimmed entries of every `#[derive(...)]` list on the symbol, in source order.
    /// Symbols with no derive entries have no key here.
    symbol_derives: SecondaryMap<SymbolId, SmallVec<[String; 4]>>,
    /// For each struct/enum, the name of the first resolvable type used by each of its
    /// child symbols in the code graph (presumably fields/variants — confirm against
    /// `CodeGraphV2::children_of`). Symbols with no resolvable types have no key here.
    field_type_names: SecondaryMap<SymbolId, SmallVec<[String; 8]>>,
}
impl DeriveIndex {
    /// Returns an empty index with no derives and no field types recorded.
    pub fn new() -> Self {
        Self::default()
    }

    /// Constructs an index populated from every item in `ast_registry`.
    ///
    /// Shorthand for creating an empty index and calling [`DeriveIndex::rebuild_all`].
    pub fn build(
        ast_registry: &ASTRegistry,
        code_graph: &CodeGraphV2,
        typeflow: &TypeFlowGraphV2,
        symbol_registry: &SymbolRegistry,
    ) -> Self {
        let mut fresh = Self::default();
        fresh.rebuild_all(ast_registry, code_graph, typeflow, symbol_registry);
        fresh
    }

    /// Discards everything indexed so far and re-indexes every item in `ast_registry`.
    pub fn rebuild_all(
        &mut self,
        ast_registry: &ASTRegistry,
        code_graph: &CodeGraphV2,
        typeflow: &TypeFlowGraphV2,
        symbol_registry: &SymbolRegistry,
    ) {
        // Start from a clean slate so items no longer in the registry leave nothing behind.
        self.field_type_names.clear();
        self.symbol_derives.clear();
        for (symbol_id, pure_item) in ast_registry.iter() {
            self.index_item(symbol_id, pure_item, code_graph, typeflow, symbol_registry);
        }
    }

    /// Re-indexes only the listed symbols.
    ///
    /// Existing entries for each symbol are dropped first. A symbol that is missing from
    /// `ast_registry`, or is no longer a struct/enum, therefore ends up with no entries.
    pub fn rebuild_for_symbols(
        &mut self,
        symbols: &[SymbolId],
        ast_registry: &ASTRegistry,
        code_graph: &CodeGraphV2,
        typeflow: &TypeFlowGraphV2,
        symbol_registry: &SymbolRegistry,
    ) {
        for symbol_id in symbols.iter().copied() {
            self.field_type_names.remove(symbol_id);
            self.symbol_derives.remove(symbol_id);
            if let Some(pure_item) = ast_registry.get(symbol_id) {
                self.index_item(symbol_id, pure_item, code_graph, typeflow, symbol_registry);
            }
        }
    }

    /// Records derive entries and field type names for a single item.
    ///
    /// Only `PureItem::Struct` and `PureItem::Enum` are indexed; other item kinds are
    /// ignored. A map entry is inserted only when its list would be non-empty.
    fn index_item(
        &mut self,
        id: SymbolId,
        item: &PureItem,
        code_graph: &CodeGraphV2,
        typeflow: &TypeFlowGraphV2,
        symbol_registry: &SymbolRegistry,
    ) {
        let attrs = match item {
            PureItem::Struct(strukt) => &strukt.attrs,
            PureItem::Enum(enumeration) => &enumeration.attrs,
            _ => return,
        };

        // Every comma-separated, non-blank entry of each `derive` list, trimmed, in order.
        let derives: SmallVec<[String; 4]> = attrs
            .iter()
            .filter(|attr| attr.path == "derive")
            .filter_map(|attr| match &attr.meta {
                PureAttrMeta::List(args) => Some(args),
                _ => None,
            })
            .flat_map(|args| args.split(','))
            .map(str::trim)
            .filter(|name| !name.is_empty())
            .map(String::from)
            .collect();
        if !derives.is_empty() {
            self.symbol_derives.insert(id, derives);
        }

        // For each child symbol, keep only the first used type the registry can resolve;
        // children with no resolvable type contribute nothing.
        let mut field_types: SmallVec<[String; 8]> = SmallVec::new();
        for child_id in code_graph.children_of(id) {
            for use_id in typeflow.types_used_by(child_id) {
                match symbol_registry.resolve(use_id) {
                    Some(path) => {
                        field_types.push(path.name().to_string());
                        break;
                    }
                    None => continue,
                }
            }
        }
        if !field_types.is_empty() {
            self.field_type_names.insert(id, field_types);
        }
    }

    /// Iterates over every symbol that has at least one recorded derive entry.
    pub fn iter_derives(&self) -> impl Iterator<Item = (SymbolId, &SmallVec<[String; 4]>)> {
        self.symbol_derives.iter()
    }

    /// Returns the derive entries recorded for `id`, if any.
    pub fn get_derives(&self, id: SymbolId) -> Option<&SmallVec<[String; 4]>> {
        self.symbol_derives.get(id)
    }

    /// Returns the field type names recorded for `id`, if any.
    pub fn get_field_types(&self, id: SymbolId) -> Option<&SmallVec<[String; 8]>> {
        self.field_type_names.get(id)
    }

    /// Reports whether `id` has a derive entry equal to `trait_name`.
    ///
    /// Matching is exact against the recorded text. A path-qualified derive such as
    /// `serde::Serialize` only matches that full string.
    pub fn has_derive(&self, id: SymbolId, trait_name: &str) -> bool {
        match self.symbol_derives.get(id) {
            Some(derives) => derives.iter().any(|d| d == trait_name),
            None => false,
        }
    }

    /// Collects every symbol with a derive entry exactly equal to `trait_name`.
    ///
    /// Results follow the iteration order of the underlying map.
    pub fn symbols_deriving(&self, trait_name: &str) -> Vec<SymbolId> {
        let mut found = Vec::new();
        for (id, derives) in self.symbol_derives.iter() {
            if derives.iter().any(|d| d == trait_name) {
                found.push(id);
            }
        }
        found
    }

    /// Summarizes how many symbols and entries each map currently holds.
    pub fn stats(&self) -> DeriveIndexStats {
        DeriveIndexStats {
            symbols_with_derives: self.symbol_derives.len(),
            total_derives: self.symbol_derives.values().map(|d| d.len()).sum(),
            symbols_with_fields: self.field_type_names.len(),
            total_field_types: self.field_type_names.values().map(|f| f.len()).sum(),
        }
    }
}
/// Summary counters describing the current contents of a `DeriveIndex`.
///
/// A plain bundle of `usize` counts. It is `Copy`, comparable, and defaults to all
/// zeros, so callers can snapshot and compare it cheaply.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DeriveIndexStats {
    /// Number of symbols with at least one recorded derive entry.
    pub symbols_with_derives: usize,
    /// Total derive entries summed over all symbols.
    pub total_derives: usize,
    /// Number of symbols with at least one recorded field type name.
    pub symbols_with_fields: usize,
    /// Total field type names summed over all symbols.
    pub total_field_types: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed index reports zero for every counter.
    #[test]
    fn test_derive_index_creation() {
        let index = DeriveIndex::new();
        let stats = index.stats();
        assert_eq!(stats.symbols_with_derives, 0);
        assert_eq!(stats.total_derives, 0);
        assert_eq!(stats.symbols_with_fields, 0);
        assert_eq!(stats.total_field_types, 0);
    }

    /// Queries on an empty index yield nothing rather than panicking.
    #[test]
    fn test_empty_index_queries() {
        let index = DeriveIndex::default();
        assert_eq!(index.iter_derives().count(), 0);
        assert!(index.symbols_deriving("Debug").is_empty());
    }
}