1use crate::extractor::TypeInfo;
2use crate::parser::ParsedFile;
3use log::{debug, warn};
4use std::collections::{HashMap, HashSet};
5
6pub struct TypeResolver {
8 parsed_files: Vec<ParsedFile>,
10 type_cache: HashMap<String, ResolvedType>,
12 resolving_stack: HashSet<String>,
14}
15
16#[derive(Debug, Clone)]
18pub struct ResolvedType {
19 pub name: String,
21 pub kind: TypeKind,
23}
24
25#[derive(Debug, Clone)]
27pub enum TypeKind {
28 Struct(StructDef),
30 Enum(EnumDef),
32 Primitive(PrimitiveType),
34 Generic(String),
36}
37
38#[derive(Debug, Clone)]
40pub struct StructDef {
41 pub fields: Vec<FieldDef>,
43}
44
45#[derive(Debug, Clone)]
47pub struct FieldDef {
48 pub name: String,
50 pub type_info: TypeInfo,
52 pub optional: bool,
54 pub serde_attrs: SerdeAttributes,
56}
57
/// Parsed shape of an enum definition.
///
/// `PartialEq`/`Eq` are derived so definitions can be compared directly
/// (e.g. in tests or when de-duplicating generated output).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnumDef {
    /// Variant identifiers in declaration order; variant payloads are not recorded.
    pub variants: Vec<String>,
}
64
/// Built-in scalar and string types that need no further resolution.
///
/// This is a fieldless enum, so it derives `Copy` (values can be passed by
/// value freely) and `Hash` (usable as a `HashMap`/`HashSet` key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrimitiveType {
    /// Rust `String` or `str`.
    String,
    I8,
    I16,
    I32,
    I64,
    I128,
    U8,
    U16,
    U32,
    U64,
    U128,
    F32,
    F64,
    Bool,
    Char,
}
84
/// Field-level serde options that change how a field is serialized.
///
/// `PartialEq`/`Eq` are derived so parsed attributes can be compared as a whole.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SerdeAttributes {
    /// Serialized field name taken from a serde `rename` option, if any.
    pub rename: Option<String>,
    /// Set when the field carries a serde skip option such as `#[serde(skip)]`.
    pub skip: bool,
    /// Set when the field is marked `#[serde(flatten)]`.
    pub flatten: bool,
}
95
96impl TypeResolver {
97 pub fn new(parsed_files: Vec<ParsedFile>) -> Self {
99 debug!("Initializing TypeResolver with {} files", parsed_files.len());
100 Self {
101 parsed_files,
102 type_cache: HashMap::new(),
103 resolving_stack: HashSet::new(),
104 }
105 }
106
107 pub fn find_struct_definition(&self, name: &str) -> Option<&syn::ItemStruct> {
109 debug!("Searching for struct definition: {}", name);
110
111 for parsed_file in &self.parsed_files {
112 for item in &parsed_file.syntax_tree.items {
113 if let syn::Item::Struct(item_struct) = item {
114 if item_struct.ident == name {
115 debug!("Found struct {} in {}", name, parsed_file.path.display());
116 return Some(item_struct);
117 }
118 }
119 }
120 }
121
122 debug!("Struct {} not found", name);
123 None
124 }
125
126 pub fn find_enum_definition(&self, name: &str) -> Option<&syn::ItemEnum> {
128 debug!("Searching for enum definition: {}", name);
129
130 for parsed_file in &self.parsed_files {
131 for item in &parsed_file.syntax_tree.items {
132 if let syn::Item::Enum(item_enum) = item {
133 if item_enum.ident == name {
134 debug!("Found enum {} in {}", name, parsed_file.path.display());
135 return Some(item_enum);
136 }
137 }
138 }
139 }
140
141 debug!("Enum {} not found", name);
142 None
143 }
144
145 pub fn resolve_type(&mut self, type_name: &str) -> Option<ResolvedType> {
147 debug!("Resolving type: {}", type_name);
148
149 if let Some(cached) = self.type_cache.get(type_name) {
151 debug!("Type {} found in cache", type_name);
152 return Some(cached.clone());
153 }
154
155 if self.resolving_stack.contains(type_name) {
157 warn!("Circular reference detected for type: {}", type_name);
158 let placeholder = ResolvedType {
160 name: type_name.to_string(),
161 kind: TypeKind::Generic(format!("CircularRef<{}>", type_name)),
162 };
163 return Some(placeholder);
164 }
165
166 self.resolving_stack.insert(type_name.to_string());
168
169 if let Some(primitive) = Self::parse_primitive_type(type_name) {
171 let resolved = ResolvedType {
172 name: type_name.to_string(),
173 kind: TypeKind::Primitive(primitive),
174 };
175 self.type_cache.insert(type_name.to_string(), resolved.clone());
176 self.resolving_stack.remove(type_name);
177 return Some(resolved);
178 }
179
180 let result = if let Some(struct_def) = self.find_struct_definition(type_name) {
182 let resolved = self.parse_struct_definition(struct_def);
183 self.type_cache.insert(type_name.to_string(), resolved.clone());
184 Some(resolved)
185 } else if let Some(enum_def) = self.find_enum_definition(type_name) {
186 let resolved = self.parse_enum_definition(enum_def);
188 self.type_cache.insert(type_name.to_string(), resolved.clone());
189 Some(resolved)
190 } else {
191 warn!("Could not resolve type: {}", type_name);
192 None
193 };
194
195 self.resolving_stack.remove(type_name);
197
198 result
199 }
200
201 pub fn resolve_nested_types(&mut self, type_info: &TypeInfo) {
203 debug!("Resolving nested types for: {}", type_info.name);
204
205 if Self::parse_primitive_type(&type_info.name).is_none() {
207 self.resolve_type(&type_info.name);
208 }
209
210 for generic_arg in &type_info.generic_args {
212 self.resolve_nested_types(generic_arg);
213 }
214 }
215
216 fn parse_struct_definition(&self, item_struct: &syn::ItemStruct) -> ResolvedType {
218 let struct_name = item_struct.ident.to_string();
219 debug!("Parsing struct definition: {}", struct_name);
220
221 let fields = self.parse_struct_fields(item_struct);
222
223 ResolvedType {
224 name: struct_name,
225 kind: TypeKind::Struct(StructDef { fields }),
226 }
227 }
228
229 fn parse_enum_definition(&self, item_enum: &syn::ItemEnum) -> ResolvedType {
231 let enum_name = item_enum.ident.to_string();
232 debug!("Parsing enum definition: {}", enum_name);
233
234 let variants: Vec<String> = item_enum
235 .variants
236 .iter()
237 .map(|v| v.ident.to_string())
238 .collect();
239
240 debug!("Parsed {} variants", variants.len());
241
242 ResolvedType {
243 name: enum_name,
244 kind: TypeKind::Enum(EnumDef { variants }),
245 }
246 }
247
248 fn parse_struct_fields(&self, item_struct: &syn::ItemStruct) -> Vec<FieldDef> {
250 let mut fields = Vec::new();
251
252 if let syn::Fields::Named(named_fields) = &item_struct.fields {
253 for field in &named_fields.named {
254 if let Some(field_def) = self.parse_field(field) {
255 fields.push(field_def);
256 }
257 }
258 }
259
260 debug!("Parsed {} fields", fields.len());
261 fields
262 }
263
264 fn parse_field(&self, field: &syn::Field) -> Option<FieldDef> {
266 let field_name = field.ident.as_ref()?.to_string();
267 debug!("Parsing field: {}", field_name);
268
269 let type_info = Self::extract_type_info(&field.ty);
270 let optional = type_info.is_option;
271 let serde_attrs = Self::parse_serde_attributes(&field.attrs);
272
273 Some(FieldDef {
274 name: field_name,
275 type_info,
276 optional,
277 serde_attrs,
278 })
279 }
280
281 fn parse_serde_attributes(attrs: &[syn::Attribute]) -> SerdeAttributes {
283 let mut serde_attrs = SerdeAttributes::default();
284
285 for attr in attrs {
286 if !attr.path().is_ident("serde") {
288 continue;
289 }
290
291 if let Ok(meta_list) = attr.meta.require_list() {
293 let tokens_str = meta_list.tokens.to_string();
295
296 if let Some(value) = Self::extract_rename_value(&tokens_str) {
298 debug!("Found serde rename: {}", value);
299 serde_attrs.rename = Some(value);
300 }
301
302 if tokens_str.contains("skip") && !tokens_str.contains("skip_serializing_if") {
304 debug!("Found serde skip");
305 serde_attrs.skip = true;
306 }
307
308 if tokens_str.contains("flatten") {
310 debug!("Found serde flatten");
311 serde_attrs.flatten = true;
312 }
313 }
314 }
315
316 serde_attrs
317 }
318
319 fn extract_rename_value(tokens_str: &str) -> Option<String> {
321 if let Some(rename_pos) = tokens_str.find("rename") {
323 let after_rename = &tokens_str[rename_pos..];
324 if let Some(eq_pos) = after_rename.find('=') {
325 let after_eq = &after_rename[eq_pos + 1..];
326 if let Some(start_quote) = after_eq.find('"') {
328 let after_start = &after_eq[start_quote + 1..];
329 if let Some(end_quote) = after_start.find('"') {
330 let value = &after_start[..end_quote];
331 return Some(value.to_string());
332 }
333 }
334 }
335 }
336 None
337 }
338
339 fn extract_type_info(ty: &syn::Type) -> TypeInfo {
341 match ty {
342 syn::Type::Path(type_path) => {
343 Self::extract_type_info_from_path(&type_path.path)
344 }
345 _ => {
346 TypeInfo::new("Unknown".to_string())
348 }
349 }
350 }
351
352 fn extract_type_info_from_path(path: &syn::Path) -> TypeInfo {
354 if let Some(segment) = path.segments.last() {
355 let type_name = segment.ident.to_string();
356
357 if type_name == "Option" {
359 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
360 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
361 let inner_type_info = Self::extract_type_info(inner_ty);
362 return TypeInfo::option(inner_type_info);
363 }
364 }
365 }
366
367 if type_name == "Vec" {
369 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
370 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
371 let inner_type_info = Self::extract_type_info(inner_ty);
372 return TypeInfo::vec(inner_type_info);
373 }
374 }
375 }
376
377 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
379 let mut generic_args = Vec::new();
380 for arg in &args.args {
381 if let syn::GenericArgument::Type(inner_ty) = arg {
382 generic_args.push(Self::extract_type_info(inner_ty));
383 }
384 }
385
386 return TypeInfo {
387 name: type_name,
388 is_generic: !generic_args.is_empty(),
389 generic_args,
390 is_option: false,
391 is_vec: false,
392 };
393 }
394
395 TypeInfo::new(type_name)
397 } else {
398 TypeInfo::new("Unknown".to_string())
399 }
400 }
401
402 fn parse_primitive_type(type_name: &str) -> Option<PrimitiveType> {
404 match type_name {
405 "String" | "str" => Some(PrimitiveType::String),
406 "i8" => Some(PrimitiveType::I8),
407 "i16" => Some(PrimitiveType::I16),
408 "i32" => Some(PrimitiveType::I32),
409 "i64" => Some(PrimitiveType::I64),
410 "i128" => Some(PrimitiveType::I128),
411 "u8" => Some(PrimitiveType::U8),
412 "u16" => Some(PrimitiveType::U16),
413 "u32" => Some(PrimitiveType::U32),
414 "u64" => Some(PrimitiveType::U64),
415 "u128" => Some(PrimitiveType::U128),
416 "f32" => Some(PrimitiveType::F32),
417 "f64" => Some(PrimitiveType::F64),
418 "bool" => Some(PrimitiveType::Bool),
419 "char" => Some(PrimitiveType::Char),
420 _ => None,
421 }
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::AstParser;
    use std::fs;
    use tempfile::TempDir;

    /// Writes `content` to `name` inside `dir` and returns the file's path.
    fn create_temp_file(dir: &TempDir, name: &str, content: &str) -> std::path::PathBuf {
        let file_path = dir.path().join(name);
        fs::write(&file_path, content).unwrap();
        file_path
    }

    /// Parses `code` as a single Rust source file and builds a resolver over it.
    ///
    /// The temporary directory is removed when this returns. That relies on
    /// `ParsedFile` owning its syntax tree (it carries no lifetime parameter).
    fn create_resolver_from_code(code: &str) -> TypeResolver {
        let temp_dir = TempDir::new().unwrap();
        let file_path = create_temp_file(&temp_dir, "test.rs", code);
        let parsed = AstParser::parse_file(&file_path).unwrap();
        TypeResolver::new(vec![parsed])
    }

    /// Extracts the struct definition from a resolution result.
    ///
    /// Panics with context if the type did not resolve to a struct.
    fn expect_struct(resolved: Option<ResolvedType>) -> StructDef {
        match resolved.map(|resolved| resolved.kind) {
            Some(TypeKind::Struct(struct_def)) => struct_def,
            other => panic!("Expected struct type, got {:?}", other),
        }
    }

    #[test]
    fn test_resolve_primitive_types() {
        let mut resolver = create_resolver_from_code("");

        let primitives = vec![
            ("String", PrimitiveType::String),
            ("str", PrimitiveType::String),
            ("i32", PrimitiveType::I32),
            ("u64", PrimitiveType::U64),
            ("f32", PrimitiveType::F32),
            ("bool", PrimitiveType::Bool),
            ("char", PrimitiveType::Char),
        ];

        for (type_name, expected_primitive) in primitives {
            let resolved = resolver
                .resolve_type(type_name)
                .unwrap_or_else(|| panic!("{} should resolve", type_name));
            assert_eq!(resolved.name, type_name);

            match resolved.kind {
                TypeKind::Primitive(prim) => assert_eq!(prim, expected_primitive),
                other => panic!("Expected primitive type for {}, got {:?}", type_name, other),
            }
        }
    }

    #[test]
    fn test_resolve_simple_struct() {
        let code = r#"
            pub struct User {
                pub id: u32,
                pub name: String,
                pub active: bool,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);
        let resolved = resolver.resolve_type("User").expect("User should resolve");
        assert_eq!(resolved.name, "User");

        let struct_def = expect_struct(Some(resolved));
        assert_eq!(struct_def.fields.len(), 3);

        let names: Vec<&str> = struct_def.fields.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, ["id", "name", "active"]);

        let types: Vec<&str> = struct_def
            .fields
            .iter()
            .map(|f| f.type_info.name.as_str())
            .collect();
        assert_eq!(types, ["u32", "String", "bool"]);
    }

    #[test]
    fn test_resolve_struct_with_option() {
        let code = r#"
            pub struct User {
                pub id: u32,
                pub email: Option<String>,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);
        let struct_def = expect_struct(resolver.resolve_type("User"));
        assert_eq!(struct_def.fields.len(), 2);

        let email_field = &struct_def.fields[1];
        assert_eq!(email_field.name, "email");
        assert!(email_field.type_info.is_option);
        assert!(email_field.optional);
        assert_eq!(email_field.type_info.name, "String");

        let id_field = &struct_def.fields[0];
        assert!(!id_field.optional);
    }

    #[test]
    fn test_resolve_struct_with_vec() {
        let code = r#"
            pub struct Post {
                pub id: u32,
                pub tags: Vec<String>,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);
        let struct_def = expect_struct(resolver.resolve_type("Post"));
        assert_eq!(struct_def.fields.len(), 2);

        let tags_field = &struct_def.fields[1];
        assert_eq!(tags_field.name, "tags");
        assert!(tags_field.type_info.is_vec);
        assert_eq!(tags_field.type_info.name, "String");
    }

    #[test]
    fn test_parse_serde_rename() {
        let code = r#"
            use serde::{Deserialize, Serialize};

            #[derive(Serialize, Deserialize)]
            pub struct User {
                pub id: u32,
                #[serde(rename = "userName")]
                pub name: String,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);
        let struct_def = expect_struct(resolver.resolve_type("User"));

        let name_field = &struct_def.fields[1];
        assert_eq!(name_field.name, "name");
        assert_eq!(name_field.serde_attrs.rename, Some("userName".to_string()));
        assert!(!name_field.serde_attrs.skip);
        assert!(!name_field.serde_attrs.flatten);

        // Fields without serde attributes keep the defaults.
        assert_eq!(struct_def.fields[0].serde_attrs.rename, None);
    }

    #[test]
    fn test_parse_serde_skip() {
        let code = r#"
            use serde::{Deserialize, Serialize};

            #[derive(Serialize, Deserialize)]
            pub struct User {
                pub id: u32,
                #[serde(skip)]
                pub password: String,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);
        let struct_def = expect_struct(resolver.resolve_type("User"));

        let password_field = &struct_def.fields[1];
        assert_eq!(password_field.name, "password");
        assert!(password_field.serde_attrs.skip);
        assert!(!struct_def.fields[0].serde_attrs.skip);
    }

    #[test]
    fn test_parse_serde_flatten() {
        let code = r#"
            use serde::{Deserialize, Serialize};

            #[derive(Serialize, Deserialize)]
            pub struct User {
                pub id: u32,
                #[serde(flatten)]
                pub metadata: Metadata,
            }

            pub struct Metadata {
                pub created_at: String,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);
        let struct_def = expect_struct(resolver.resolve_type("User"));

        let metadata_field = &struct_def.fields[1];
        assert_eq!(metadata_field.name, "metadata");
        assert!(metadata_field.serde_attrs.flatten);
        assert!(!struct_def.fields[0].serde_attrs.flatten);
    }

    #[test]
    fn test_resolve_nested_struct() {
        let code = r#"
            pub struct User {
                pub id: u32,
                pub profile: Profile,
            }

            pub struct Profile {
                pub bio: String,
                pub avatar: String,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);

        assert!(resolver.resolve_type("User").is_some());

        let struct_def = expect_struct(resolver.resolve_type("Profile"));
        assert_eq!(struct_def.fields.len(), 2);
        assert_eq!(struct_def.fields[0].name, "bio");
        assert_eq!(struct_def.fields[1].name, "avatar");
    }

    #[test]
    fn test_resolve_enum() {
        let code = r#"
            pub enum Status {
                Active,
                Inactive,
                Pending,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);
        let resolved = resolver.resolve_type("Status").expect("Status should resolve");
        assert_eq!(resolved.name, "Status");

        match resolved.kind {
            TypeKind::Enum(enum_def) => {
                assert_eq!(enum_def.variants, ["Active", "Inactive", "Pending"]);
            }
            other => panic!("Expected enum type, got {:?}", other),
        }
    }

    #[test]
    fn test_type_caching() {
        let code = r#"
            pub struct User {
                pub id: u32,
                pub name: String,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);
        assert!(!resolver.type_cache.contains_key("User"));

        let first = expect_struct(resolver.resolve_type("User"));
        assert!(resolver.type_cache.contains_key("User"));

        // The second lookup is served from the cache and matches the first.
        let second = expect_struct(resolver.resolve_type("User"));
        assert_eq!(first.fields.len(), second.fields.len());
        assert_eq!(resolver.type_cache.len(), 1);
    }

    #[test]
    fn test_circular_reference_detection() {
        let code = r#"
            pub struct Node {
                pub value: i32,
                pub next: Option<Box<Node>>,
            }
        "#;

        // A self-referential struct still resolves normally from the top.
        let mut resolver = create_resolver_from_code(code);
        assert!(resolver.resolve_type("Node").is_some());

        // Simulate re-entering `Node` while it is mid-resolution. The resolver
        // must hand back a placeholder instead of resolving it again.
        let mut reentrant = create_resolver_from_code(code);
        reentrant.resolving_stack.insert("Node".to_string());

        let placeholder = reentrant
            .resolve_type("Node")
            .expect("circular lookups yield a placeholder");
        assert_eq!(placeholder.name, "Node");
        match placeholder.kind {
            TypeKind::Generic(marker) => assert_eq!(marker, "CircularRef<Node>"),
            other => panic!("Expected circular placeholder, got {:?}", other),
        }

        // The placeholder must not be cached in place of the real definition.
        assert!(!reentrant.type_cache.contains_key("Node"));
    }

    #[test]
    fn test_resolve_nonexistent_type() {
        let code = r#"
            pub struct User {
                pub id: u32,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);
        assert!(resolver.resolve_type("NonExistent").is_none());

        // Failed lookups are not cached, and nothing is left marked in progress.
        assert!(resolver.type_cache.is_empty());
        assert!(resolver.resolving_stack.is_empty());
    }

    #[test]
    fn test_resolve_nested_types_recursively() {
        let code = r#"
            pub struct User {
                pub id: u32,
                pub posts: Vec<Post>,
            }

            pub struct Post {
                pub id: u32,
                pub title: String,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);

        let struct_def = expect_struct(resolver.resolve_type("User"));
        let posts_field = &struct_def.fields[1];
        assert_eq!(posts_field.name, "posts");

        // Resolving `User` alone does not walk into its fields.
        assert!(!resolver.type_cache.contains_key("Post"));

        resolver.resolve_nested_types(&posts_field.type_info);

        assert!(
            resolver.type_cache.contains_key("Post"),
            "resolve_nested_types should resolve the element type of Vec<Post>"
        );
        assert!(resolver.resolve_type("Post").is_some());
    }

    #[test]
    fn test_complex_generic_types() {
        let code = r#"
            pub struct Response {
                pub data: Option<Vec<String>>,
            }
        "#;

        let mut resolver = create_resolver_from_code(code);
        let struct_def = expect_struct(resolver.resolve_type("Response"));

        let data_field = &struct_def.fields[0];
        assert_eq!(data_field.name, "data");
        assert!(data_field.type_info.is_option);

        let inner = data_field
            .type_info
            .generic_args
            .first()
            .expect("Expected generic args for Option");
        assert!(inner.is_vec);
        assert_eq!(inner.name, "String");
    }
}