1use std::path::Path;
7
8use forge_core::schema::{
9 EnumDef, EnumVariant, FieldDef, FunctionArg, FunctionDef, FunctionKind, RustType,
10 SchemaRegistry, TableDef,
11};
12use quote::ToTokens;
13use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType};
14use walkdir::WalkDir;
15
16use crate::Error;
17
18pub fn parse_project(src_dir: &Path) -> Result<SchemaRegistry, Error> {
20 let registry = SchemaRegistry::new();
21
22 let mut files: Vec<_> = WalkDir::new(src_dir)
23 .into_iter()
24 .filter_map(|e| e.ok())
25 .filter(|e| e.path().extension().map(|ext| ext == "rs").unwrap_or(false))
26 .collect();
27 files.sort_by(|a, b| a.path().cmp(b.path()));
28
29 for entry in files {
30 let content = std::fs::read_to_string(entry.path())?;
31 if let Err(e) = parse_file(&content, ®istry) {
32 tracing::debug!(file = ?entry.path(), error = %e, "Failed to parse file");
33 }
34 }
35
36 Ok(registry)
37}
38
39fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> {
41 let file = syn::parse_file(content).map_err(|e| Error::Template(e.to_string()))?;
42
43 for item in file.items {
44 match item {
45 syn::Item::Struct(item_struct) => {
46 if has_forge_model_attr(&item_struct.attrs) {
47 if let Some(table) = parse_model(&item_struct) {
48 registry.register_table(table);
49 }
50 } else if has_serde_derive(&item_struct.attrs) {
51 if let Some(table) = parse_dto_struct(&item_struct) {
53 registry.register_table(table);
54 }
55 }
56 }
57 syn::Item::Enum(item_enum) => {
58 if has_forge_enum_attr(&item_enum.attrs) {
59 if let Some(enum_def) = parse_enum(&item_enum) {
60 registry.register_enum(enum_def);
61 }
62 } else if has_serde_derive(&item_enum.attrs) {
63 if let Some(enum_def) = parse_enum(&item_enum) {
65 registry.register_enum(enum_def);
66 }
67 }
68 }
69 syn::Item::Fn(item_fn) => {
70 if let Some(func) = parse_function(&item_fn) {
71 registry.register_function(func);
72 }
73 }
74 _ => {}
75 }
76 }
77
78 Ok(())
79}
80
81fn has_forge_model_attr(attrs: &[Attribute]) -> bool {
83 attrs.iter().any(|attr| {
84 let path = attr.path();
85 path.is_ident("model")
86 || path.segments.len() == 2
87 && path.segments[0].ident == "forge"
88 && path.segments[1].ident == "model"
89 })
90}
91
92fn has_forge_enum_attr(attrs: &[Attribute]) -> bool {
94 attrs.iter().any(|attr| {
95 let path = attr.path();
96 path.is_ident("forge_enum")
97 || path.is_ident("enum_type")
98 || path.segments.len() == 2
99 && path.segments[0].ident == "forge"
100 && path.segments[1].ident == "enum_type"
101 })
102}
103
104fn has_serde_derive(attrs: &[Attribute]) -> bool {
106 attrs.iter().any(|attr| {
107 if !attr.path().is_ident("derive") {
108 return false;
109 }
110 let tokens = attr.meta.to_token_stream().to_string();
111 tokens.contains("Serialize") || tokens.contains("Deserialize")
112 })
113}
114
115fn parse_dto_struct(item: &syn::ItemStruct) -> Option<TableDef> {
117 let struct_name = item.ident.to_string();
118
119 let mut table = TableDef::new(&struct_name, &struct_name);
121
122 table.is_dto = true;
124
125 table.doc = get_doc_comment(&item.attrs);
127
128 if let Fields::Named(fields) = &item.fields {
130 for field in &fields.named {
131 if let Some(field_name) = &field.ident {
132 let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
133 table.fields.push(field_def);
134 }
135 }
136 }
137
138 Some(table)
139}
140
141fn parse_model(item: &syn::ItemStruct) -> Option<TableDef> {
143 let struct_name = item.ident.to_string();
144 let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| {
145 let snake = to_snake_case(&struct_name);
146 pluralize(&snake)
147 });
148
149 let mut table = TableDef::new(&table_name, &struct_name);
150
151 table.doc = get_doc_comment(&item.attrs);
153
154 if let Fields::Named(fields) = &item.fields {
156 for field in &fields.named {
157 if let Some(field_name) = &field.ident {
158 let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
159 table.fields.push(field_def);
160 }
161 }
162 }
163
164 Some(table)
165}
166
167fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef {
169 let rust_type = type_to_rust_type(ty);
170 let mut field = FieldDef::new(&name, rust_type);
171 field.column_name = to_snake_case(&name);
172 field.doc = get_doc_comment(attrs);
173 field
174}
175
176fn parse_enum(item: &syn::ItemEnum) -> Option<EnumDef> {
178 let enum_name = item.ident.to_string();
179 let mut enum_def = EnumDef::new(&enum_name);
180 enum_def.doc = get_doc_comment(&item.attrs);
181
182 for variant in &item.variants {
183 let variant_name = variant.ident.to_string();
184 let mut enum_variant = EnumVariant::new(&variant_name);
185 enum_variant.doc = get_doc_comment(&variant.attrs);
186
187 if let Some((_, Expr::Lit(lit))) = &variant.discriminant {
189 if let Lit::Int(int_lit) = &lit.lit {
190 if let Ok(value) = int_lit.base10_parse::<i32>() {
191 enum_variant.int_value = Some(value);
192 }
193 }
194 }
195
196 enum_def.variants.push(enum_variant);
197 }
198
199 Some(enum_def)
200}
201
202fn parse_function(item: &syn::ItemFn) -> Option<FunctionDef> {
204 let kind = get_function_kind(&item.attrs)?;
205 let func_name = item.sig.ident.to_string();
206
207 let return_type = match &item.sig.output {
209 ReturnType::Default => RustType::Custom("()".to_string()),
210 ReturnType::Type(_, ty) => extract_result_type(ty),
211 };
212
213 let mut func = FunctionDef::new(&func_name, kind, return_type);
214 func.doc = get_doc_comment(&item.attrs);
215 func.is_async = item.sig.asyncness.is_some();
216
217 let mut skip_first = true;
219 for arg in &item.sig.inputs {
220 if let FnArg::Typed(pat_type) = arg {
221 if skip_first {
223 skip_first = false;
224 let type_str = quote::quote!(#pat_type.ty).to_string();
226 if type_str.contains("Context")
227 || type_str.contains("QueryContext")
228 || type_str.contains("MutationContext")
229 {
230 continue;
231 }
232 }
233
234 if let Pat::Ident(pat_ident) = &*pat_type.pat {
236 let arg_name = pat_ident.ident.to_string();
237 let arg_type = type_to_rust_type(&pat_type.ty);
238 func.args.push(FunctionArg::new(arg_name, arg_type));
239 }
240 }
241 }
242
243 Some(func)
244}
245
246fn get_function_kind(attrs: &[Attribute]) -> Option<FunctionKind> {
248 for attr in attrs {
249 let path = attr.path();
250 let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
251
252 let kind_str = if segments.len() == 2 && segments[0] == "forge" {
254 Some(segments[1].as_str())
255 } else if segments.len() == 1 {
256 Some(segments[0].as_str())
257 } else {
258 None
259 };
260
261 if let Some(kind) = kind_str {
262 match kind {
263 "query" => return Some(FunctionKind::Query),
264 "mutation" => return Some(FunctionKind::Mutation),
265 "job" => return Some(FunctionKind::Job),
266 "cron" => return Some(FunctionKind::Cron),
267 "workflow" => return Some(FunctionKind::Workflow),
268 _ => {}
269 }
270 }
271 }
272 None
273}
274
275fn extract_result_type(ty: &syn::Type) -> RustType {
277 let type_str = quote::quote!(#ty).to_string().replace(' ', "");
278
279 if let Some(rest) = type_str.strip_prefix("Result<") {
281 let mut depth = 0;
283 let mut end_idx = 0;
284 for (i, c) in rest.chars().enumerate() {
285 match c {
286 '<' => depth += 1,
287 '>' => {
288 if depth == 0 {
289 end_idx = i;
290 break;
291 }
292 depth -= 1;
293 }
294 ',' if depth == 0 => {
295 end_idx = i;
296 break;
297 }
298 _ => {}
299 }
300 }
301 let inner = &rest[..end_idx];
302 return type_to_rust_type(
303 &syn::parse_str(inner)
304 .unwrap_or_else(|_| syn::parse_str::<syn::Type>("String").unwrap()),
305 );
306 }
307
308 type_to_rust_type(ty)
309}
310
311fn type_to_rust_type(ty: &syn::Type) -> RustType {
313 let type_str = quote::quote!(#ty).to_string().replace(' ', "");
314
315 match type_str.as_str() {
317 "String" | "&str" => RustType::String,
318 "i32" => RustType::I32,
319 "i64" => RustType::I64,
320 "f32" => RustType::F32,
321 "f64" => RustType::F64,
322 "bool" => RustType::Bool,
323 "Uuid" | "uuid::Uuid" => RustType::Uuid,
324 "DateTime<Utc>" | "chrono::DateTime<Utc>" | "chrono::DateTime<chrono::Utc>" => {
325 RustType::DateTime
326 }
327 "NaiveDate" | "chrono::NaiveDate" => RustType::Date,
328 "NaiveTime" | "chrono::NaiveTime" => RustType::Custom("NaiveTime".to_string()),
329 "serde_json::Value" | "Value" => RustType::Json,
330 "Vec<u8>" => RustType::Bytes,
331 _ => {
332 if let Some(inner) = type_str
334 .strip_prefix("Option<")
335 .and_then(|s| s.strip_suffix('>'))
336 {
337 let inner_type = match inner {
338 "String" => RustType::String,
339 "i32" => RustType::I32,
340 "i64" => RustType::I64,
341 "f64" => RustType::F64,
342 "bool" => RustType::Bool,
343 "Uuid" => RustType::Uuid,
344 _ => RustType::Custom(inner.to_string()),
345 };
346 return RustType::Option(Box::new(inner_type));
347 }
348
349 if let Some(inner) = type_str
351 .strip_prefix("Vec<")
352 .and_then(|s| s.strip_suffix('>'))
353 {
354 let inner_type = match inner {
355 "String" => RustType::String,
356 "i32" => RustType::I32,
357 "u8" => return RustType::Bytes,
358 _ => RustType::Custom(inner.to_string()),
359 };
360 return RustType::Vec(Box::new(inner_type));
361 }
362
363 RustType::Custom(type_str)
365 }
366 }
367}
368
369fn get_table_name_from_attrs(attrs: &[Attribute]) -> Option<String> {
371 for attr in attrs {
372 if attr.path().is_ident("table") {
373 if let Meta::List(list) = &attr.meta {
374 let tokens = list.tokens.to_string();
375 if let Some(value) = extract_name_value(&tokens) {
376 return Some(value);
377 }
378 }
379 }
380 }
381 None
382}
383
384fn get_attribute_string_value(attr: &Attribute) -> Option<String> {
386 if let Meta::NameValue(nv) = &attr.meta {
387 if let Expr::Lit(lit) = &nv.value {
388 if let Lit::Str(s) = &lit.lit {
389 return Some(s.value());
390 }
391 }
392 }
393 None
394}
395
396fn get_doc_comment(attrs: &[Attribute]) -> Option<String> {
398 let docs: Vec<String> = attrs
399 .iter()
400 .filter_map(|attr| {
401 if attr.path().is_ident("doc") {
402 get_attribute_string_value(attr)
403 } else {
404 None
405 }
406 })
407 .collect();
408
409 if docs.is_empty() {
410 None
411 } else {
412 Some(
413 docs.into_iter()
414 .map(|s| s.trim().to_string())
415 .collect::<Vec<_>>()
416 .join("\n"),
417 )
418 }
419}
420
/// Pulls the quoted value out of a `key = "value"` token string.
///
/// Everything after the first `=` is trimmed and must be wrapped in double quotes.
/// The quotes are removed and the inner text is returned verbatim; escape sequences
/// are not interpreted. Returns `None` when there is no `=` or the value is not quoted.
fn extract_name_value(s: &str) -> Option<String> {
    let (_key, value) = s.split_once('=')?;
    value
        .trim()
        .strip_prefix('"')?
        .strip_suffix('"')
        .map(str::to_string)
}
432
/// Converts a `CamelCase` / `camelCase` identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except the first, so
/// consecutive capitals are split individually (`"ID"` -> `"i_d"`). Characters that
/// are not uppercase are copied unchanged. The full lowercase mapping is used, so
/// characters that lowercase to several code points (e.g. `'İ'`) are not truncated.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
448
/// Naive English pluralization used to derive default table names.
///
/// - Words ending in `s`, `sh`, `ch`, `x` or `z` get `es` (`box` -> `boxes`).
/// - Words whose trailing `y` is not preceded by `a`, `e`, `o` or `u` swap the `y`
///   for `ies` (`category` -> `categories`).
/// - Everything else, including `ay`/`ey`/`oy`/`uy` endings (`day` -> `days`), gets `s`.
fn pluralize(s: &str) -> String {
    const ES_SUFFIXES: [&str; 5] = ["s", "sh", "ch", "x", "z"];

    if ES_SUFFIXES.iter().any(|suffix| s.ends_with(*suffix)) {
        return format!("{}es", s);
    }

    match s.strip_suffix('y') {
        Some(stem) if !stem.ends_with(|c: char| matches!(c, 'a' | 'e' | 'o' | 'u')) => {
            format!("{}ies", stem)
        }
        _ => format!("{}s", s),
    }
}
468
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_model_source() {
        let source = r#"
            #[model]
            struct User {
                #[id]
                id: Uuid,
                email: String,
                name: Option<String>,
                #[indexed]
                created_at: DateTime<Utc>,
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let table = registry.get_table("users").unwrap();
        assert_eq!(table.struct_name, "User");
        assert_eq!(table.fields.len(), 4);
    }

    #[test]
    fn test_parse_serde_dto_struct() {
        let source = r#"
            #[derive(Debug, Serialize, Deserialize)]
            struct UserSummary {
                user_id: Uuid,
                display_name: String,
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let table = registry.get_table("UserSummary").unwrap();
        assert!(table.is_dto);
        assert_eq!(table.fields.len(), 2);
    }

    #[test]
    fn test_parse_enum_source() {
        let source = r#"
            #[forge_enum]
            enum ProjectStatus {
                Draft,
                Active,
                Completed,
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let enum_def = registry.get_enum("ProjectStatus").unwrap();
        assert_eq!(enum_def.variants.len(), 3);
    }

    #[test]
    fn test_parse_enum_discriminants() {
        let source = r#"
            #[forge_enum]
            enum Priority {
                Low = 1,
                High = 10,
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let enum_def = registry.get_enum("Priority").unwrap();
        assert_eq!(enum_def.variants[0].int_value, Some(1));
        assert_eq!(enum_def.variants[1].int_value, Some(10));
    }

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("UserProfile"), "user_profile");
        assert_eq!(to_snake_case("ID"), "i_d");
        assert_eq!(to_snake_case("createdAt"), "created_at");
    }

    #[test]
    fn test_pluralize() {
        assert_eq!(pluralize("user"), "users");
        assert_eq!(pluralize("category"), "categories");
        assert_eq!(pluralize("box"), "boxes");
        assert_eq!(pluralize("address"), "addresses");
    }

    #[test]
    fn test_parse_query_function() {
        let source = r#"
            #[query]
            async fn get_user(ctx: QueryContext, id: Uuid) -> Result<User> {
                todo!()
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let func = registry.get_function("get_user").unwrap();
        assert_eq!(func.name, "get_user");
        assert_eq!(func.kind, FunctionKind::Query);
        assert!(func.is_async);
        // The leading context parameter is not exposed as an argument.
        assert_eq!(func.args.len(), 1);
        assert_eq!(func.args[0].name, "id");
    }

    #[test]
    fn test_parse_mutation_function() {
        let source = r#"
            #[mutation]
            async fn create_user(ctx: MutationContext, name: String, email: String) -> Result<User> {
                todo!()
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let func = registry.get_function("create_user").unwrap();
        assert_eq!(func.name, "create_user");
        assert_eq!(func.kind, FunctionKind::Mutation);
        assert_eq!(func.args.len(), 2);
        assert_eq!(func.args[0].name, "name");
        assert_eq!(func.args[1].name, "email");
    }
}