use std::collections::HashMap;
use std::path::PathBuf;
use syn::{Field, Fields, Item, ItemStruct, Path, PathSegment, Type, TypePath, TypeReference};
/// Shared index of registered type definitions, per-module exports, type aliases and
/// per-file imports, used to resolve type names to their registered definitions.
#[derive(Debug, Clone)]
pub struct GlobalTypeRegistry {
    /// Registered definitions keyed by fully-qualified name (`module::path::Name`, or just
    /// `Name` for the root module).
    pub types: HashMap<String, TypeDefinition>,
    /// Module path → bare names of the types registered in that module.
    pub module_exports: HashMap<Vec<String>, Vec<String>>,
    /// Alias name → target type name.
    pub type_aliases: HashMap<String, String>,
    /// Source file → the imports visible in that file.
    pub imports: HashMap<PathBuf, ImportScope>,
}
/// A single registered type: its identity, shape, generics and attached methods.
#[derive(Debug, Clone)]
pub struct TypeDefinition {
    /// Fully-qualified name, e.g. `geo::Point` (just `Point` at the root).
    pub name: String,
    /// What sort of item this is.
    pub kind: TypeKind,
    /// Field layout. Structs registered via `register_struct` always carry `Some` (empty for
    /// unit structs); presumably `None` is meant for kinds without fields.
    pub fields: Option<FieldRegistry>,
    /// Method signatures attached after registration.
    pub methods: Vec<MethodSignature>,
    /// Names of the type parameters; lifetime and const parameters are not recorded.
    pub generic_params: Vec<String>,
    /// Path of the module the type was declared in (empty for the root module).
    pub module_path: Vec<String>,
}
/// Field layout of a struct.
///
/// For structs registered via `register_struct`, at most one collection is populated:
/// named fields for braced structs, positional fields for tuple structs. Both are empty
/// for unit structs, which is also what [`Default`] produces.
#[derive(Debug, Clone, Default)]
pub struct FieldRegistry {
    /// Named fields keyed by field name.
    pub named_fields: HashMap<String, ResolvedFieldType>,
    /// Positional fields in declaration order (index `0`, `1`, ...).
    pub tuple_fields: Vec<ResolvedFieldType>,
}
/// A field's type, reduced to a name, reference/mutability flags and generic arguments.
///
/// Derives `PartialEq`/`Eq` so resolved types can be compared directly (e.g. in tests or
/// when checking two fields for the same type).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedFieldType {
    /// The type's path joined with `::` (e.g. `std::string::String`). For a reference it is
    /// the referent's name; for shapes that are not handled it is `"Unknown"`.
    pub type_name: String,
    /// `true` if the field type is `&T` or `&mut T`.
    pub is_reference: bool,
    /// `true` if the field type is `&mut T`.
    pub is_mutable: bool,
    /// Simplified names of the generic arguments on the type's last path segment,
    /// e.g. `["Item"]` for `Vec<Item>`.
    pub generic_args: Vec<String>,
}
/// Summary of a method attached to a registered type.
#[derive(Debug, Clone)]
pub struct MethodSignature {
    /// Method name.
    pub name: String,
    /// Receiver description; presumably `None` for associated functions without a `self`
    /// receiver — TODO confirm with whatever builds these signatures.
    pub self_param: Option<SelfParam>,
    /// Return type name, if one was recorded (assumed `None` for unit-returning methods —
    /// TODO confirm).
    pub return_type: Option<String>,
    /// Parameter type names in declaration order (assumed to exclude the receiver —
    /// TODO confirm).
    pub param_types: Vec<String>,
}
/// Shape of a method receiver.
///
/// Assumed mapping (TODO confirm with the producer): `&self` sets `is_reference`,
/// `&mut self` sets both flags. Two plain flags, so the type is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SelfParam {
    /// Receiver is taken by reference.
    pub is_reference: bool,
    /// Receiver is mutable.
    pub is_mutable: bool,
}
/// Category of a registered type.
///
/// A fieldless enum, so it is `Copy`, totally comparable (`Eq`) and hashable, which lets it
/// be used as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TypeKind {
    /// A struct with named (braced) fields.
    Struct,
    /// An enum (not produced by `register_struct`).
    Enum,
    /// A trait (not produced by `register_struct`).
    Trait,
    /// A type alias (not produced by `register_struct`).
    TypeAlias,
    /// A struct with positional fields, e.g. `struct Id(u32);`.
    TupleStruct,
    /// A struct with no fields, e.g. `struct Marker;`.
    UnitStruct,
}
/// Imports visible in one source file, as consulted by type-name resolution.
///
/// `ImportScope::default()` is the empty scope (no imports, no glob imports).
#[derive(Debug, Clone, Default)]
pub struct ImportScope {
    /// Local name → fully-qualified type name it refers to.
    pub imports: HashMap<String, String>,
    /// Module paths brought in with glob (`*`) imports, tried in order.
    pub wildcard_imports: Vec<Vec<String>>,
}
impl Default for GlobalTypeRegistry {
fn default() -> Self {
Self::new()
}
}
impl GlobalTypeRegistry {
pub fn new() -> Self {
Self {
types: HashMap::new(),
module_exports: HashMap::new(),
type_aliases: HashMap::new(),
imports: HashMap::new(),
}
}
pub fn register_struct(&mut self, module_path: Vec<String>, item: &ItemStruct) {
let name = item.ident.to_string();
let full_name = if module_path.is_empty() {
name.clone()
} else {
format!("{}::{}", module_path.join("::"), name)
};
let fields = self.extract_fields(&item.fields);
let generic_params = item
.generics
.params
.iter()
.filter_map(|param| match param {
syn::GenericParam::Type(type_param) => Some(type_param.ident.to_string()),
_ => None,
})
.collect();
let kind = match &item.fields {
Fields::Named(_) => TypeKind::Struct,
Fields::Unnamed(_) => TypeKind::TupleStruct,
Fields::Unit => TypeKind::UnitStruct,
};
let type_def = TypeDefinition {
name: full_name.clone(),
kind,
fields: Some(fields),
methods: Vec::new(),
generic_params,
module_path: module_path.clone(),
};
self.types.insert(full_name, type_def);
self.module_exports
.entry(module_path)
.or_default()
.push(name);
}
fn extract_fields(&self, fields: &Fields) -> FieldRegistry {
match fields {
Fields::Named(named_fields) => {
let mut named = HashMap::new();
for field in &named_fields.named {
if let Some(ident) = &field.ident {
let field_type = self.extract_field_type(field);
named.insert(ident.to_string(), field_type);
}
}
FieldRegistry {
named_fields: named,
tuple_fields: Vec::new(),
}
}
Fields::Unnamed(unnamed_fields) => {
let tuple_fields = unnamed_fields
.unnamed
.iter()
.map(|field| self.extract_field_type(field))
.collect();
FieldRegistry {
named_fields: HashMap::new(),
tuple_fields,
}
}
Fields::Unit => FieldRegistry {
named_fields: HashMap::new(),
tuple_fields: Vec::new(),
},
}
}
#[allow(dead_code)]
fn extract_type_name_from_path(path: &syn::Path) -> String {
path.segments
.iter()
.map(|seg| seg.ident.to_string())
.collect::<Vec<_>>()
.join("::")
}
fn extract_field_type(&self, field: &Field) -> ResolvedFieldType {
match &field.ty {
Type::Path(type_path) => Self::extract_path_type(type_path),
Type::Reference(type_ref) => Self::extract_reference_type(type_ref),
_ => Self::unknown_field_type(),
}
}
fn extract_path_type(type_path: &TypePath) -> ResolvedFieldType {
let type_name = Self::build_type_name(&type_path.path);
let generic_args = Self::extract_generic_args(&type_path.path);
ResolvedFieldType {
type_name,
is_reference: false,
is_mutable: false,
generic_args,
}
}
fn extract_reference_type(type_ref: &TypeReference) -> ResolvedFieldType {
let base_type = match &*type_ref.elem {
Type::Path(type_path) => ResolvedFieldType {
type_name: Self::build_type_name(&type_path.path),
is_reference: true,
is_mutable: type_ref.mutability.is_some(),
generic_args: Self::extract_generic_args(&type_path.path),
},
_ => Self::unknown_field_type(),
};
ResolvedFieldType {
is_reference: true,
is_mutable: type_ref.mutability.is_some(),
..base_type
}
}
fn build_type_name(path: &Path) -> String {
path.segments
.iter()
.map(|seg| seg.ident.to_string())
.collect::<Vec<_>>()
.join("::")
}
fn extract_generic_args(path: &Path) -> Vec<String> {
path.segments
.last()
.and_then(Self::extract_args_from_segment)
.unwrap_or_default()
}
fn extract_args_from_segment(segment: &PathSegment) -> Option<Vec<String>> {
match &segment.arguments {
syn::PathArguments::AngleBracketed(args) => Some(
args.args
.iter()
.filter_map(Self::extract_type_name_from_arg)
.collect(),
),
_ => None,
}
}
fn extract_type_name_from_arg(arg: &syn::GenericArgument) -> Option<String> {
match arg {
syn::GenericArgument::Type(Type::Path(type_path)) => type_path
.path
.segments
.last()
.map(|seg| seg.ident.to_string()),
_ => None,
}
}
fn unknown_field_type() -> ResolvedFieldType {
ResolvedFieldType {
type_name: "Unknown".to_string(),
is_reference: false,
is_mutable: false,
generic_args: Vec::new(),
}
}
pub fn get_type(&self, name: &str) -> Option<&TypeDefinition> {
self.types.get(name)
}
pub fn resolve_field(&self, type_name: &str, field_name: &str) -> Option<ResolvedFieldType> {
let type_def = self.get_type(type_name)?;
let fields = type_def.fields.as_ref()?;
fields.named_fields.get(field_name).cloned()
}
pub fn resolve_tuple_field(&self, type_name: &str, index: usize) -> Option<ResolvedFieldType> {
let type_def = self.get_type(type_name)?;
let fields = type_def.fields.as_ref()?;
fields.tuple_fields.get(index).cloned()
}
pub fn add_method(&mut self, type_name: &str, method: MethodSignature) {
if let Some(type_def) = self.types.get_mut(type_name) {
type_def.methods.push(method);
}
}
pub fn register_type_alias(&mut self, alias: String, target: String) {
self.type_aliases.insert(alias, target);
}
pub fn resolve_type_alias(&self, alias: &str) -> Option<&String> {
self.type_aliases.get(alias)
}
pub fn register_imports(&mut self, file: PathBuf, imports: ImportScope) {
self.imports.insert(file, imports);
}
pub fn get_imports(&self, file: &PathBuf) -> Option<&ImportScope> {
self.imports.get(file)
}
pub fn resolve_type_with_imports(&self, file: &PathBuf, name: &str) -> Option<String> {
if self.types.contains_key(name) {
return Some(name.to_string());
}
if let Some(import_scope) = self.get_imports(file) {
if let Some(full_name) = import_scope.imports.get(name) {
return Some(full_name.clone());
}
for module_path in &import_scope.wildcard_imports {
let potential_name = format!("{}::{}", module_path.join("::"), name);
if self.types.contains_key(&potential_name) {
return Some(potential_name);
}
}
}
if let Some(target) = self.resolve_type_alias(name) {
return Some(target.clone());
}
None
}
}
pub fn extract_type_definitions(
file: &syn::File,
module_path: Vec<String>,
registry: &mut GlobalTypeRegistry,
) {
for item in &file.items {
match item {
Item::Struct(item_struct) => {
registry.register_struct(module_path.clone(), item_struct);
}
Item::Mod(item_mod) => {
if let Some((_, items)) = &item_mod.content {
let mut nested_path = module_path.clone();
nested_path.push(item_mod.ident.to_string());
for item in items {
if let Item::Struct(item_struct) = item {
registry.register_struct(nested_path.clone(), item_struct);
}
}
}
}
_ => {}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::{parse_quote, Field};

    #[test]
    fn test_extract_type_name_from_simple_path() {
        let path: syn::Path = parse_quote!(String);
        assert_eq!(
            GlobalTypeRegistry::extract_type_name_from_path(&path),
            "String"
        );
    }

    #[test]
    fn test_extract_type_name_from_qualified_path() {
        let path: syn::Path = parse_quote!(std::collections::HashMap);
        assert_eq!(
            GlobalTypeRegistry::extract_type_name_from_path(&path),
            "std::collections::HashMap"
        );
    }

    #[test]
    fn test_extract_generic_args_none() {
        let path: syn::Path = parse_quote!(String);
        assert_eq!(
            GlobalTypeRegistry::extract_generic_args(&path),
            Vec::<String>::new()
        );
    }

    #[test]
    fn test_extract_generic_args_single() {
        let path: syn::Path = parse_quote!(Option<String>);
        assert_eq!(
            GlobalTypeRegistry::extract_generic_args(&path),
            vec!["String"]
        );
    }

    #[test]
    fn test_extract_generic_args_multiple() {
        let path: syn::Path = parse_quote!(HashMap<String, Value>);
        let args = GlobalTypeRegistry::extract_generic_args(&path);
        assert_eq!(args.len(), 2);
        assert!(args.contains(&"String".to_string()));
        assert!(args.contains(&"Value".to_string()));
    }

    #[test]
    fn test_extract_field_type_simple() {
        let registry = GlobalTypeRegistry::new();
        let field: Field = parse_quote!(pub name: String);
        let field_type = registry.extract_field_type(&field);
        assert_eq!(field_type.type_name, "String");
        assert!(!field_type.is_reference);
        assert!(!field_type.is_mutable);
        assert!(field_type.generic_args.is_empty());
    }

    #[test]
    fn test_extract_field_type_reference() {
        let registry = GlobalTypeRegistry::new();
        let field: Field = parse_quote!(pub name: &str);
        let field_type = registry.extract_field_type(&field);
        assert_eq!(field_type.type_name, "str");
        assert!(field_type.is_reference);
        assert!(!field_type.is_mutable);
    }

    #[test]
    fn test_extract_field_type_mutable_reference() {
        let registry = GlobalTypeRegistry::new();
        let field: Field = parse_quote!(pub name: &mut String);
        let field_type = registry.extract_field_type(&field);
        assert_eq!(field_type.type_name, "String");
        assert!(field_type.is_reference);
        assert!(field_type.is_mutable);
    }

    #[test]
    fn test_extract_field_type_with_generics() {
        let registry = GlobalTypeRegistry::new();
        let field: Field = parse_quote!(pub items: Vec<Item>);
        let field_type = registry.extract_field_type(&field);
        assert_eq!(field_type.type_name, "Vec");
        assert_eq!(field_type.generic_args, vec!["Item"]);
    }

    #[test]
    fn test_extract_field_type_unknown() {
        let registry = GlobalTypeRegistry::new();
        let field: Field = parse_quote!(pub callback: fn());
        let field_type = registry.extract_field_type(&field);
        assert_eq!(field_type.type_name, "Unknown");
    }

    #[test]
    fn test_register_named_struct() {
        let mut registry = GlobalTypeRegistry::new();
        let item: ItemStruct = parse_quote! {
            pub struct Point<'a, T> {
                x: T,
                y: &'a mut Vec<T>,
            }
        };
        registry.register_struct(vec!["geo".to_string()], &item);

        let def = registry
            .get_type("geo::Point")
            .expect("Point should be registered under its qualified name");
        assert_eq!(def.name, "geo::Point");
        assert_eq!(def.kind, TypeKind::Struct);
        // Lifetimes are not recorded as generic params.
        assert_eq!(def.generic_params, vec!["T"]);
        assert_eq!(def.module_path, vec!["geo"]);
        assert!(def.methods.is_empty());

        let x = registry.resolve_field("geo::Point", "x").expect("field x");
        assert_eq!(x.type_name, "T");
        assert!(!x.is_reference);

        let y = registry.resolve_field("geo::Point", "y").expect("field y");
        assert_eq!(y.type_name, "Vec");
        assert!(y.is_reference);
        assert!(y.is_mutable);
        assert_eq!(y.generic_args, vec!["T"]);

        assert!(registry.resolve_field("geo::Point", "z").is_none());
        assert!(registry.resolve_field("Point", "x").is_none());
        assert_eq!(
            registry.module_exports.get(&vec!["geo".to_string()]),
            Some(&vec!["Point".to_string()])
        );
    }

    #[test]
    fn test_register_tuple_and_unit_structs() {
        let mut registry = GlobalTypeRegistry::new();
        let wrapper: ItemStruct = parse_quote!(pub struct Wrapper(u32, &'static str););
        let marker: ItemStruct = parse_quote!(struct Marker;);
        registry.register_struct(Vec::new(), &wrapper);
        registry.register_struct(Vec::new(), &marker);

        assert_eq!(
            registry.get_type("Wrapper").unwrap().kind,
            TypeKind::TupleStruct
        );
        let second = registry
            .resolve_tuple_field("Wrapper", 1)
            .expect("field 1");
        assert_eq!(second.type_name, "str");
        assert!(second.is_reference);
        assert!(registry.resolve_tuple_field("Wrapper", 2).is_none());
        assert!(registry.resolve_field("Wrapper", "0").is_none());

        let marker_def = registry.get_type("Marker").unwrap();
        assert_eq!(marker_def.kind, TypeKind::UnitStruct);
        let marker_fields = marker_def.fields.as_ref().unwrap();
        assert!(marker_fields.named_fields.is_empty());
        assert!(marker_fields.tuple_fields.is_empty());
    }

    #[test]
    fn test_add_method_only_affects_registered_types() {
        let mut registry = GlobalTypeRegistry::new();
        let item: ItemStruct = parse_quote!(struct Counter { count: usize });
        registry.register_struct(Vec::new(), &item);
        let method = MethodSignature {
            name: "increment".to_string(),
            self_param: Some(SelfParam {
                is_reference: true,
                is_mutable: true,
            }),
            return_type: None,
            param_types: Vec::new(),
        };
        registry.add_method("Counter", method.clone());
        registry.add_method("Missing", method);

        let methods = &registry.get_type("Counter").unwrap().methods;
        assert_eq!(methods.len(), 1);
        assert_eq!(methods[0].name, "increment");
        assert!(registry.get_type("Missing").is_none());
    }

    #[test]
    fn test_resolve_type_with_imports_and_aliases() {
        let mut registry = GlobalTypeRegistry::new();
        let item: ItemStruct = parse_quote!(struct Config { debug: bool });
        let module = vec!["app".to_string(), "settings".to_string()];
        registry.register_struct(module.clone(), &item);

        let file = PathBuf::from("src/main.rs");
        let mut imports = HashMap::new();
        imports.insert("Cfg".to_string(), "app::settings::Config".to_string());
        registry.register_imports(
            file.clone(),
            ImportScope {
                imports,
                wildcard_imports: vec![module],
            },
        );

        let resolve = |name: &str| registry.resolve_type_with_imports(&file, name);
        assert_eq!(
            resolve("app::settings::Config").as_deref(),
            Some("app::settings::Config")
        );
        assert_eq!(resolve("Cfg").as_deref(), Some("app::settings::Config"));
        assert_eq!(resolve("Config").as_deref(), Some("app::settings::Config"));
        assert_eq!(resolve("Missing"), None);

        let other = PathBuf::from("src/other.rs");
        assert_eq!(registry.resolve_type_with_imports(&other, "Config"), None);

        registry.register_type_alias("Settings".to_string(), "app::settings::Config".to_string());
        assert_eq!(
            registry
                .resolve_type_with_imports(&other, "Settings")
                .as_deref(),
            Some("app::settings::Config")
        );
    }

    #[test]
    fn test_extract_type_definitions_registers_inline_modules() {
        let file: syn::File = parse_quote! {
            pub struct Top;
            mod inner {
                pub struct Nested(u8);
            }
            mod external;
        };
        let mut registry = GlobalTypeRegistry::new();
        extract_type_definitions(&file, vec!["krate".to_string()], &mut registry);

        assert_eq!(
            registry.get_type("krate::Top").unwrap().kind,
            TypeKind::UnitStruct
        );
        assert_eq!(
            registry.get_type("krate::inner::Nested").unwrap().kind,
            TypeKind::TupleStruct
        );
        assert_eq!(registry.types.len(), 2);
    }
}