1use std::path::Path;
7
8use forge_core::schema::{
9 EnumDef, EnumVariant, FieldDef, FunctionArg, FunctionDef, FunctionKind, RustType,
10 SchemaRegistry, TableDef,
11};
12use quote::ToTokens;
13use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType};
14use walkdir::WalkDir;
15
16use crate::Error;
17
18pub fn parse_project(src_dir: &Path) -> Result<SchemaRegistry, Error> {
20 let registry = SchemaRegistry::new();
21
22 for entry in WalkDir::new(src_dir)
23 .into_iter()
24 .filter_map(|e| e.ok())
25 .filter(|e| e.path().extension().map(|ext| ext == "rs").unwrap_or(false))
26 {
27 let content = std::fs::read_to_string(entry.path())?;
28 if let Err(e) = parse_file(&content, ®istry) {
29 tracing::debug!(file = ?entry.path(), error = %e, "Failed to parse file");
30 }
31 }
32
33 Ok(registry)
34}
35
36fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> {
38 let file = syn::parse_file(content).map_err(|e| Error::Template(e.to_string()))?;
39
40 for item in file.items {
41 match item {
42 syn::Item::Struct(item_struct) => {
43 if has_forge_model_attr(&item_struct.attrs) {
44 if let Some(table) = parse_model(&item_struct) {
45 registry.register_table(table);
46 }
47 } else if has_serde_derive(&item_struct.attrs) {
48 if let Some(table) = parse_dto_struct(&item_struct) {
50 registry.register_table(table);
51 }
52 }
53 }
54 syn::Item::Enum(item_enum) => {
55 if has_forge_enum_attr(&item_enum.attrs) {
56 if let Some(enum_def) = parse_enum(&item_enum) {
57 registry.register_enum(enum_def);
58 }
59 } else if has_serde_derive(&item_enum.attrs) {
60 if let Some(enum_def) = parse_enum(&item_enum) {
62 registry.register_enum(enum_def);
63 }
64 }
65 }
66 syn::Item::Fn(item_fn) => {
67 if let Some(func) = parse_function(&item_fn) {
68 registry.register_function(func);
69 }
70 }
71 _ => {}
72 }
73 }
74
75 Ok(())
76}
77
78fn has_forge_model_attr(attrs: &[Attribute]) -> bool {
80 attrs.iter().any(|attr| {
81 let path = attr.path();
82 path.is_ident("model")
83 || path.segments.len() == 2
84 && path.segments[0].ident == "forge"
85 && path.segments[1].ident == "model"
86 })
87}
88
89fn has_forge_enum_attr(attrs: &[Attribute]) -> bool {
91 attrs.iter().any(|attr| {
92 let path = attr.path();
93 path.is_ident("forge_enum")
94 || path.is_ident("enum_type")
95 || path.segments.len() == 2
96 && path.segments[0].ident == "forge"
97 && path.segments[1].ident == "enum_type"
98 })
99}
100
101fn has_serde_derive(attrs: &[Attribute]) -> bool {
103 attrs.iter().any(|attr| {
104 if !attr.path().is_ident("derive") {
105 return false;
106 }
107 let tokens = attr.meta.to_token_stream().to_string();
108 tokens.contains("Serialize") || tokens.contains("Deserialize")
109 })
110}
111
112fn parse_dto_struct(item: &syn::ItemStruct) -> Option<TableDef> {
114 let struct_name = item.ident.to_string();
115
116 let mut table = TableDef::new(&struct_name, &struct_name);
118
119 table.is_dto = true;
121
122 table.doc = get_doc_comment(&item.attrs);
124
125 if let Fields::Named(fields) = &item.fields {
127 for field in &fields.named {
128 if let Some(field_name) = &field.ident {
129 let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
130 table.fields.push(field_def);
131 }
132 }
133 }
134
135 Some(table)
136}
137
138fn parse_model(item: &syn::ItemStruct) -> Option<TableDef> {
140 let struct_name = item.ident.to_string();
141 let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| {
142 let snake = to_snake_case(&struct_name);
143 pluralize(&snake)
144 });
145
146 let mut table = TableDef::new(&table_name, &struct_name);
147
148 table.doc = get_doc_comment(&item.attrs);
150
151 if let Fields::Named(fields) = &item.fields {
153 for field in &fields.named {
154 if let Some(field_name) = &field.ident {
155 let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
156 table.fields.push(field_def);
157 }
158 }
159 }
160
161 Some(table)
162}
163
164fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef {
166 let rust_type = type_to_rust_type(ty);
167 let mut field = FieldDef::new(&name, rust_type);
168 field.column_name = to_snake_case(&name);
169 field.doc = get_doc_comment(attrs);
170 field
171}
172
173fn parse_enum(item: &syn::ItemEnum) -> Option<EnumDef> {
175 let enum_name = item.ident.to_string();
176 let mut enum_def = EnumDef::new(&enum_name);
177 enum_def.doc = get_doc_comment(&item.attrs);
178
179 for variant in &item.variants {
180 let variant_name = variant.ident.to_string();
181 let mut enum_variant = EnumVariant::new(&variant_name);
182 enum_variant.doc = get_doc_comment(&variant.attrs);
183
184 if let Some((_, Expr::Lit(lit))) = &variant.discriminant {
186 if let Lit::Int(int_lit) = &lit.lit {
187 if let Ok(value) = int_lit.base10_parse::<i32>() {
188 enum_variant.int_value = Some(value);
189 }
190 }
191 }
192
193 enum_def.variants.push(enum_variant);
194 }
195
196 Some(enum_def)
197}
198
199fn parse_function(item: &syn::ItemFn) -> Option<FunctionDef> {
201 let kind = get_function_kind(&item.attrs)?;
202 let func_name = item.sig.ident.to_string();
203
204 let return_type = match &item.sig.output {
206 ReturnType::Default => RustType::Custom("()".to_string()),
207 ReturnType::Type(_, ty) => extract_result_type(ty),
208 };
209
210 let mut func = FunctionDef::new(&func_name, kind, return_type);
211 func.doc = get_doc_comment(&item.attrs);
212 func.is_async = item.sig.asyncness.is_some();
213
214 let mut skip_first = true;
216 for arg in &item.sig.inputs {
217 if let FnArg::Typed(pat_type) = arg {
218 if skip_first {
220 skip_first = false;
221 let type_str = quote::quote!(#pat_type.ty).to_string();
223 if type_str.contains("Context")
224 || type_str.contains("QueryContext")
225 || type_str.contains("MutationContext")
226 || type_str.contains("ActionContext")
227 {
228 continue;
229 }
230 }
231
232 if let Pat::Ident(pat_ident) = &*pat_type.pat {
234 let arg_name = pat_ident.ident.to_string();
235 let arg_type = type_to_rust_type(&pat_type.ty);
236 func.args.push(FunctionArg::new(arg_name, arg_type));
237 }
238 }
239 }
240
241 Some(func)
242}
243
244fn get_function_kind(attrs: &[Attribute]) -> Option<FunctionKind> {
246 for attr in attrs {
247 let path = attr.path();
248 let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
249
250 let kind_str = if segments.len() == 2 && segments[0] == "forge" {
252 Some(segments[1].as_str())
253 } else if segments.len() == 1 {
254 Some(segments[0].as_str())
255 } else {
256 None
257 };
258
259 if let Some(kind) = kind_str {
260 match kind {
261 "query" => return Some(FunctionKind::Query),
262 "mutation" => return Some(FunctionKind::Mutation),
263 "action" => return Some(FunctionKind::Action),
264 "job" => return Some(FunctionKind::Job),
265 "cron" => return Some(FunctionKind::Cron),
266 "workflow" => return Some(FunctionKind::Workflow),
267 _ => {}
268 }
269 }
270 }
271 None
272}
273
274fn extract_result_type(ty: &syn::Type) -> RustType {
276 let type_str = quote::quote!(#ty).to_string().replace(' ', "");
277
278 if let Some(rest) = type_str.strip_prefix("Result<") {
280 let mut depth = 0;
282 let mut end_idx = 0;
283 for (i, c) in rest.chars().enumerate() {
284 match c {
285 '<' => depth += 1,
286 '>' => {
287 if depth == 0 {
288 end_idx = i;
289 break;
290 }
291 depth -= 1;
292 }
293 ',' if depth == 0 => {
294 end_idx = i;
295 break;
296 }
297 _ => {}
298 }
299 }
300 let inner = &rest[..end_idx];
301 return type_to_rust_type(
302 &syn::parse_str(inner)
303 .unwrap_or_else(|_| syn::parse_str::<syn::Type>("String").unwrap()),
304 );
305 }
306
307 type_to_rust_type(ty)
308}
309
310fn type_to_rust_type(ty: &syn::Type) -> RustType {
312 let type_str = quote::quote!(#ty).to_string().replace(' ', "");
313
314 match type_str.as_str() {
316 "String" | "&str" => RustType::String,
317 "i32" => RustType::I32,
318 "i64" => RustType::I64,
319 "f32" => RustType::F32,
320 "f64" => RustType::F64,
321 "bool" => RustType::Bool,
322 "Uuid" | "uuid::Uuid" => RustType::Uuid,
323 "DateTime<Utc>" | "chrono::DateTime<Utc>" | "chrono::DateTime<chrono::Utc>" => {
324 RustType::DateTime
325 }
326 "NaiveDate" | "chrono::NaiveDate" => RustType::Date,
327 "NaiveTime" | "chrono::NaiveTime" => RustType::Custom("NaiveTime".to_string()),
328 "serde_json::Value" | "Value" => RustType::Json,
329 "Vec<u8>" => RustType::Bytes,
330 _ => {
331 if let Some(inner) = type_str
333 .strip_prefix("Option<")
334 .and_then(|s| s.strip_suffix('>'))
335 {
336 let inner_type = match inner {
337 "String" => RustType::String,
338 "i32" => RustType::I32,
339 "i64" => RustType::I64,
340 "f64" => RustType::F64,
341 "bool" => RustType::Bool,
342 "Uuid" => RustType::Uuid,
343 _ => RustType::String, };
345 return RustType::Option(Box::new(inner_type));
346 }
347
348 if let Some(inner) = type_str
350 .strip_prefix("Vec<")
351 .and_then(|s| s.strip_suffix('>'))
352 {
353 let inner_type = match inner {
354 "String" => RustType::String,
355 "i32" => RustType::I32,
356 "u8" => return RustType::Bytes,
357 _ => RustType::String,
358 };
359 return RustType::Vec(Box::new(inner_type));
360 }
361
362 RustType::Custom(type_str)
364 }
365 }
366}
367
368fn get_table_name_from_attrs(attrs: &[Attribute]) -> Option<String> {
370 for attr in attrs {
371 if attr.path().is_ident("table") {
372 if let Meta::List(list) = &attr.meta {
373 let tokens = list.tokens.to_string();
374 if let Some(value) = extract_name_value(&tokens) {
375 return Some(value);
376 }
377 }
378 }
379 }
380 None
381}
382
383fn get_attribute_string_value(attr: &Attribute) -> Option<String> {
385 if let Meta::NameValue(nv) = &attr.meta {
386 if let Expr::Lit(lit) = &nv.value {
387 if let Lit::Str(s) = &lit.lit {
388 return Some(s.value());
389 }
390 }
391 }
392 None
393}
394
395fn get_doc_comment(attrs: &[Attribute]) -> Option<String> {
397 let docs: Vec<String> = attrs
398 .iter()
399 .filter_map(|attr| {
400 if attr.path().is_ident("doc") {
401 get_attribute_string_value(attr)
402 } else {
403 None
404 }
405 })
406 .collect();
407
408 if docs.is_empty() {
409 None
410 } else {
411 Some(
412 docs.into_iter()
413 .map(|s| s.trim().to_string())
414 .collect::<Vec<_>>()
415 .join("\n"),
416 )
417 }
418}
419
/// Extracts a double-quoted value from the token text of an attribute list, such as
/// `name = "users"`.
///
/// - If the list holds a single `key = "value"` pair, its value is returned whatever the
///   key is.
/// - If it holds several comma-separated pairs, the value of the `name` pair is returned.
///
/// Returns `None` when there is no suitable pair or the value is not a plain
/// double-quoted string. Escape sequences inside the string are returned as written.
fn extract_name_value(s: &str) -> Option<String> {
    let pairs: Vec<(&str, &str)> = split_outside_quotes(s, ',')
        .into_iter()
        .filter_map(|part| {
            let (key, value) = part.split_once('=')?;
            Some((key.trim(), value.trim()))
        })
        .collect();

    let value = match pairs.as_slice() {
        // Single pair: keep accepting any key, as before.
        [(_, only)] => *only,
        // Several pairs (e.g. `schema = "app", name = "users"`): select `name` instead of
        // treating everything after the first `=` as one value.
        many => many
            .iter()
            .find(|(key, _)| *key == "name")
            .map(|(_, value)| *value)?,
    };

    value
        .strip_prefix('"')
        .and_then(|v| v.strip_suffix('"'))
        .map(str::to_string)
}

/// Splits `s` on `sep`, ignoring separators inside double-quoted strings (backslash
/// escapes inside the quotes are honoured).
fn split_outside_quotes(s: &str, sep: char) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut start = 0;
    let mut in_string = false;
    let mut escaped = false;

    for (i, c) in s.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
        } else if c == '"' {
            in_string = true;
        } else if c == sep {
            parts.push(&s[start..i]);
            start = i + c.len_utf8();
        }
    }

    parts.push(&s[start..]);
    parts
}
431
/// Converts a `CamelCase` / `camelCase` identifier to `snake_case`.
///
/// Every uppercase character except a leading one is preceded by `_`, so acronyms are
/// split letter by letter (`ID` -> `i_d`). Lowercasing uses the full Unicode mapping,
/// which may expand one character into several.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` can yield several chars (e.g. 'İ' -> "i\u{307}"); keep all
            // of them rather than only the first.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
447
/// Returns a naive English plural of `s`, used to derive default table names.
///
/// - Words ending in `s`, `sh`, `ch`, `x` or `z` get `es`.
/// - A trailing `y` not preceded by `a`, `e`, `o` or `u` becomes `ies`.
/// - Everything else gets `s`.
fn pluralize(s: &str) -> String {
    let takes_es = ["s", "sh", "ch", "x", "z"]
        .iter()
        .any(|suffix| s.ends_with(*suffix));
    if takes_es {
        return format!("{}es", s);
    }

    match s.strip_suffix('y') {
        Some(stem) if !['a', 'e', 'o', 'u'].iter().any(|v| stem.ends_with(*v)) => {
            format!("{}ies", stem)
        }
        _ => format!("{}s", s),
    }
}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` into a fresh registry, panicking on syntax errors.
    fn parse(source: &str) -> SchemaRegistry {
        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();
        registry
    }

    #[test]
    fn test_parse_model_source() {
        let registry = parse(
            r#"
            #[model]
            struct User {
                #[id]
                id: Uuid,
                email: String,
                name: Option<String>,
                #[indexed]
                created_at: DateTime<Utc>,
            }
            "#,
        );

        let table = registry.get_table("users").unwrap();
        assert_eq!(table.struct_name, "User");
        assert_eq!(table.fields.len(), 4);
        assert_eq!(table.fields[3].column_name, "created_at");
    }

    #[test]
    fn test_parse_model_with_explicit_table_name() {
        let registry = parse(
            r#"
            #[model]
            #[table(name = "people")]
            struct Person {
                id: Uuid,
            }
            "#,
        );

        let table = registry.get_table("people").unwrap();
        assert_eq!(table.struct_name, "Person");
    }

    #[test]
    fn test_parse_enum_source() {
        let registry = parse(
            r#"
            #[forge_enum]
            enum ProjectStatus {
                Draft,
                Active,
                Completed,
            }
            "#,
        );

        let enum_def = registry.get_enum("ProjectStatus").unwrap();
        assert_eq!(enum_def.variants.len(), 3);
    }

    #[test]
    fn test_parse_enum_discriminants() {
        let registry = parse(
            r#"
            #[forge_enum]
            enum Priority {
                Low = 1,
                High = 10,
            }
            "#,
        );

        let enum_def = registry.get_enum("Priority").unwrap();
        assert_eq!(enum_def.variants[0].int_value, Some(1));
        assert_eq!(enum_def.variants[1].int_value, Some(10));
    }

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("UserProfile"), "user_profile");
        assert_eq!(to_snake_case("ID"), "i_d");
        assert_eq!(to_snake_case("createdAt"), "created_at");
    }

    #[test]
    fn test_pluralize() {
        assert_eq!(pluralize("user"), "users");
        assert_eq!(pluralize("category"), "categories");
        assert_eq!(pluralize("box"), "boxes");
        assert_eq!(pluralize("address"), "addresses");
        assert_eq!(pluralize("day"), "days");
    }

    #[test]
    fn test_extract_name_value_single_pair() {
        assert_eq!(extract_name_value(r#"name = "users""#), Some("users".to_string()));
        assert_eq!(extract_name_value("name"), None);
    }

    #[test]
    fn test_parse_query_function() {
        let registry = parse(
            r#"
            #[query]
            async fn get_user(ctx: QueryContext, id: Uuid) -> Result<User> {
                todo!()
            }
            "#,
        );

        let func = registry.get_function("get_user").unwrap();
        assert_eq!(func.name, "get_user");
        assert_eq!(func.kind, FunctionKind::Query);
        assert!(func.is_async);
        // The leading context parameter is not a client-facing argument.
        assert_eq!(func.args.len(), 1);
        assert_eq!(func.args[0].name, "id");
    }

    #[test]
    fn test_parse_mutation_function() {
        let registry = parse(
            r#"
            #[mutation]
            async fn create_user(ctx: MutationContext, name: String, email: String) -> Result<User> {
                todo!()
            }
            "#,
        );

        let func = registry.get_function("create_user").unwrap();
        assert_eq!(func.name, "create_user");
        assert_eq!(func.kind, FunctionKind::Mutation);
        assert_eq!(func.args.len(), 2);
        assert_eq!(func.args[0].name, "name");
        assert_eq!(func.args[1].name, "email");
    }

    #[test]
    fn test_parse_action_function() {
        let registry = parse(
            r#"
            #[action]
            async fn send_notification(ctx: ActionContext, message: String) -> Result<()> {
                todo!()
            }
            "#,
        );

        let func = registry.get_function("send_notification").unwrap();
        assert_eq!(func.name, "send_notification");
        assert_eq!(func.kind, FunctionKind::Action);
        assert_eq!(func.args.len(), 1);
        assert_eq!(func.args[0].name, "message");
    }

    #[test]
    fn test_unannotated_function_is_ignored() {
        let registry = parse(
            r#"
            fn helper(x: i32) -> i32 { x }
            "#,
        );

        assert!(registry.get_function("helper").is_none());
    }
}