1use std::path::{Path, PathBuf};
7
8use forge_core::schema::{
9 EnumDef, EnumVariant, FieldDef, FunctionArg, FunctionDef, FunctionKind, RustType,
10 SchemaRegistry, TableDef,
11};
12use forge_core::util::to_snake_case;
13use quote::ToTokens;
14use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType};
15
16use crate::Error;
17
/// Recursively appends every `*.rs` file found under `dir` to `out`.
///
/// The walk is best-effort: directories that cannot be read and entries that
/// fail to load are skipped silently. Symlinked directories are still followed,
/// but each directory is walked at most once (keyed by its canonical path). A
/// symlink cycle therefore cannot make the walk re-enter a tree and push
/// duplicate paths. Paths are pushed in `read_dir` order, which is unspecified,
/// so callers that need a stable order must sort `out` themselves.
fn collect_rs_files(dir: &Path, out: &mut Vec<PathBuf>) {
    let mut visited = std::collections::HashSet::new();
    collect_rs_files_into(dir, out, &mut visited);
}

/// Recursive worker for `collect_rs_files`; `visited` holds the canonical paths
/// of directories that have already been walked.
fn collect_rs_files_into(
    dir: &Path,
    out: &mut Vec<PathBuf>,
    visited: &mut std::collections::HashSet<PathBuf>,
) {
    // Canonicalisation resolves symlinks, so every alias of a directory maps to
    // one key. If it fails, fall back to the path as given; read_dir below will
    // most likely fail too and end this branch.
    let key = std::fs::canonicalize(dir).unwrap_or_else(|_| dir.to_path_buf());
    if !visited.insert(key) {
        return;
    }

    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            collect_rs_files_into(&path, out, visited);
        } else if path.extension().is_some_and(|ext| ext == "rs") {
            out.push(path);
        }
    }
}
32
33pub fn parse_project(src_dir: &Path) -> Result<SchemaRegistry, Error> {
35 let registry = SchemaRegistry::new();
36
37 let mut files = Vec::new();
38 collect_rs_files(src_dir, &mut files);
39 files.sort();
40
41 for path in &files {
42 let content = std::fs::read_to_string(path)?;
43 if let Err(e) = parse_file(&content, ®istry) {
44 tracing::debug!(file = ?path, error = %e, "Failed to parse file");
45 }
46 }
47
48 Ok(registry)
49}
50
51fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> {
53 let file = syn::parse_file(content).map_err(|e| Error::Template(e.to_string()))?;
54
55 for item in file.items {
56 match item {
57 syn::Item::Struct(item_struct) => {
58 if has_forge_model_attr(&item_struct.attrs) {
59 if let Some(table) = parse_model(&item_struct) {
60 registry.register_table(table);
61 }
62 } else if has_serde_derive(&item_struct.attrs) {
63 if let Some(table) = parse_dto_struct(&item_struct) {
65 registry.register_table(table);
66 }
67 }
68 }
69 syn::Item::Enum(item_enum) => {
70 if has_forge_enum_attr(&item_enum.attrs) {
71 if let Some(enum_def) = parse_enum(&item_enum) {
72 registry.register_enum(enum_def);
73 }
74 } else if has_serde_derive(&item_enum.attrs) {
75 if let Some(enum_def) = parse_enum(&item_enum) {
77 registry.register_enum(enum_def);
78 }
79 }
80 }
81 syn::Item::Fn(item_fn) => {
82 if let Some(func) = parse_function(&item_fn) {
83 registry.register_function(func);
84 }
85 }
86 _ => {}
87 }
88 }
89
90 Ok(())
91}
92
93fn has_forge_model_attr(attrs: &[Attribute]) -> bool {
95 attrs.iter().any(|attr| {
96 let path = attr.path();
97 path.is_ident("model")
98 || path.segments.len() == 2
99 && path.segments[0].ident == "forge"
100 && path.segments[1].ident == "model"
101 })
102}
103
104fn has_forge_enum_attr(attrs: &[Attribute]) -> bool {
106 attrs.iter().any(|attr| {
107 let path = attr.path();
108 path.is_ident("forge_enum")
109 || path.is_ident("enum_type")
110 || path.segments.len() == 2
111 && path.segments[0].ident == "forge"
112 && (path.segments[1].ident == "enum_type" || path.segments[1].ident == "forge_enum")
113 })
114}
115
116fn has_serde_derive(attrs: &[Attribute]) -> bool {
118 attrs.iter().any(|attr| {
119 if !attr.path().is_ident("derive") {
120 return false;
121 }
122 let tokens = attr.meta.to_token_stream().to_string();
123 tokens.contains("Serialize") || tokens.contains("Deserialize")
124 })
125}
126
127fn parse_dto_struct(item: &syn::ItemStruct) -> Option<TableDef> {
129 let struct_name = item.ident.to_string();
130
131 let mut table = TableDef::new(&struct_name, &struct_name);
133
134 table.is_dto = true;
136
137 table.doc = get_doc_comment(&item.attrs);
139
140 if let Fields::Named(fields) = &item.fields {
142 for field in &fields.named {
143 if let Some(field_name) = &field.ident {
144 let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
145 table.fields.push(field_def);
146 }
147 }
148 }
149
150 Some(table)
151}
152
153fn parse_model(item: &syn::ItemStruct) -> Option<TableDef> {
155 let struct_name = item.ident.to_string();
156 let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| {
157 let snake = to_snake_case(&struct_name);
158 pluralize(&snake)
159 });
160
161 let mut table = TableDef::new(&table_name, &struct_name);
162
163 table.doc = get_doc_comment(&item.attrs);
165
166 if let Fields::Named(fields) = &item.fields {
168 for field in &fields.named {
169 if let Some(field_name) = &field.ident {
170 let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
171 table.fields.push(field_def);
172 }
173 }
174 }
175
176 Some(table)
177}
178
179fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef {
181 let rust_type = type_to_rust_type(ty);
182 let mut field = FieldDef::new(&name, rust_type);
183 field.column_name = to_snake_case(&name);
184 field.doc = get_doc_comment(attrs);
185 field
186}
187
188fn parse_enum(item: &syn::ItemEnum) -> Option<EnumDef> {
190 let enum_name = item.ident.to_string();
191 let mut enum_def = EnumDef::new(&enum_name);
192 enum_def.doc = get_doc_comment(&item.attrs);
193
194 for variant in &item.variants {
195 let variant_name = variant.ident.to_string();
196 let mut enum_variant = EnumVariant::new(&variant_name);
197 enum_variant.doc = get_doc_comment(&variant.attrs);
198
199 if let Some((_, Expr::Lit(lit))) = &variant.discriminant
201 && let Lit::Int(int_lit) = &lit.lit
202 && let Ok(value) = int_lit.base10_parse::<i32>()
203 {
204 enum_variant.int_value = Some(value);
205 }
206
207 enum_def.variants.push(enum_variant);
208 }
209
210 Some(enum_def)
211}
212
213fn parse_function(item: &syn::ItemFn) -> Option<FunctionDef> {
215 let kind = get_function_kind(&item.attrs)?;
216 let func_name = item.sig.ident.to_string();
217
218 let return_type = match &item.sig.output {
220 ReturnType::Default => RustType::Custom("()".to_string()),
221 ReturnType::Type(_, ty) => extract_result_type(ty),
222 };
223
224 let mut func = FunctionDef::new(&func_name, kind, return_type);
225 func.doc = get_doc_comment(&item.attrs);
226 func.is_async = item.sig.asyncness.is_some();
227
228 let mut skip_first = true;
230 for arg in &item.sig.inputs {
231 if let FnArg::Typed(pat_type) = arg {
232 if skip_first {
234 skip_first = false;
235 let type_str = quote::quote!(#pat_type.ty).to_string();
237 if type_str.contains("Context")
238 || type_str.contains("QueryContext")
239 || type_str.contains("MutationContext")
240 {
241 continue;
242 }
243 }
244
245 if let Pat::Ident(pat_ident) = &*pat_type.pat {
247 let arg_name = pat_ident.ident.to_string();
248 let arg_type = type_to_rust_type(&pat_type.ty);
249 func.args.push(FunctionArg::new(arg_name, arg_type));
250 }
251 }
252 }
253
254 Some(func)
255}
256
257fn get_function_kind(attrs: &[Attribute]) -> Option<FunctionKind> {
259 for attr in attrs {
260 let path = attr.path();
261 let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
262
263 let kind_str = if segments.len() == 2 && segments[0] == "forge" {
265 Some(segments[1].as_str())
266 } else if segments.len() == 1 {
267 Some(segments[0].as_str())
268 } else {
269 None
270 };
271
272 if let Some(kind) = kind_str {
273 match kind {
274 "query" => return Some(FunctionKind::Query),
275 "mutation" => return Some(FunctionKind::Mutation),
276 "job" => return Some(FunctionKind::Job),
277 "cron" => return Some(FunctionKind::Cron),
278 "workflow" => return Some(FunctionKind::Workflow),
279 _ => {}
280 }
281 }
282 }
283 None
284}
285
286fn extract_result_type(ty: &syn::Type) -> RustType {
288 let type_str = quote::quote!(#ty).to_string().replace(' ', "");
289
290 if let Some(rest) = type_str.strip_prefix("Result<") {
292 let mut depth = 0;
294 let mut end_idx = 0;
295 for (i, c) in rest.chars().enumerate() {
296 match c {
297 '<' => depth += 1,
298 '>' => {
299 if depth == 0 {
300 end_idx = i;
301 break;
302 }
303 depth -= 1;
304 }
305 ',' if depth == 0 => {
306 end_idx = i;
307 break;
308 }
309 _ => {}
310 }
311 }
312 let inner = &rest[..end_idx];
313 return type_to_rust_type(
314 &syn::parse_str(inner)
315 .unwrap_or_else(|_| syn::parse_str::<syn::Type>("String").unwrap()),
316 );
317 }
318
319 type_to_rust_type(ty)
320}
321
322fn type_to_rust_type(ty: &syn::Type) -> RustType {
324 let type_str = quote::quote!(#ty).to_string().replace(' ', "");
325
326 match type_str.as_str() {
328 "String" | "&str" => RustType::String,
329 "i32" => RustType::I32,
330 "i64" => RustType::I64,
331 "f32" => RustType::F32,
332 "f64" => RustType::F64,
333 "bool" => RustType::Bool,
334 "Uuid" | "uuid::Uuid" => RustType::Uuid,
335 "DateTime<Utc>" | "chrono::DateTime<Utc>" | "chrono::DateTime<chrono::Utc>" => {
336 RustType::DateTime
337 }
338 "NaiveDate" | "chrono::NaiveDate" => RustType::Date,
339 "NaiveTime" | "chrono::NaiveTime" => RustType::Custom("NaiveTime".to_string()),
340 "serde_json::Value" | "Value" => RustType::Json,
341 "Vec<u8>" => RustType::Bytes,
342 _ => {
343 if let Some(inner) = type_str
345 .strip_prefix("Option<")
346 .and_then(|s| s.strip_suffix('>'))
347 {
348 let inner_ty: syn::Type =
349 syn::parse_str(inner).unwrap_or_else(|_| syn::parse_str("String").unwrap());
350 return RustType::Option(Box::new(type_to_rust_type(&inner_ty)));
351 }
352
353 if let Some(inner) = type_str
355 .strip_prefix("Vec<")
356 .and_then(|s| s.strip_suffix('>'))
357 {
358 if inner == "u8" {
359 return RustType::Bytes;
360 }
361 let inner_ty: syn::Type =
362 syn::parse_str(inner).unwrap_or_else(|_| syn::parse_str("String").unwrap());
363 return RustType::Vec(Box::new(type_to_rust_type(&inner_ty)));
364 }
365
366 RustType::Custom(type_str)
368 }
369 }
370}
371
372fn get_table_name_from_attrs(attrs: &[Attribute]) -> Option<String> {
374 for attr in attrs {
375 if attr.path().is_ident("table")
376 && let Meta::List(list) = &attr.meta
377 {
378 let tokens = list.tokens.to_string();
379 if let Some(value) = extract_name_value(&tokens) {
380 return Some(value);
381 }
382 }
383 }
384 None
385}
386
387fn get_attribute_string_value(attr: &Attribute) -> Option<String> {
389 if let Meta::NameValue(nv) = &attr.meta
390 && let Expr::Lit(lit) = &nv.value
391 && let Lit::Str(s) = &lit.lit
392 {
393 return Some(s.value());
394 }
395 None
396}
397
398fn get_doc_comment(attrs: &[Attribute]) -> Option<String> {
400 let docs: Vec<String> = attrs
401 .iter()
402 .filter_map(|attr| {
403 if attr.path().is_ident("doc") {
404 get_attribute_string_value(attr)
405 } else {
406 None
407 }
408 })
409 .collect();
410
411 if docs.is_empty() {
412 None
413 } else {
414 Some(
415 docs.into_iter()
416 .map(|s| s.trim().to_string())
417 .collect::<Vec<_>>()
418 .join("\n"),
419 )
420 }
421}
422
/// Extracts the string literal that follows the first `=` in a rendered
/// attribute token list. For example, it returns `users` from `name = "users"`.
///
/// Only the first complete double-quoted literal is returned. Trailing arguments
/// such as `name = "users", schema = "public"` are therefore ignored instead of
/// leaking into the value. The key before `=` is not inspected, and escape
/// sequences are returned exactly as written (not unescaped).
///
/// Returns `None` when any of these hold:
/// - there is no `=`;
/// - the value is not a double-quoted literal;
/// - the literal is unterminated.
fn extract_name_value(s: &str) -> Option<String> {
    let (_, value) = s.split_once('=')?;
    let literal_body = value.trim_start().strip_prefix('"')?;

    let mut escaped = false;
    for (idx, ch) in literal_body.char_indices() {
        match ch {
            '\\' if !escaped => escaped = true,
            '"' if !escaped => return Some(literal_body[..idx].to_string()),
            _ => escaped = false,
        }
    }
    None
}
434
/// Forms a naive English plural for a snake_case table name.
///
/// - Words ending in `s`, `sh`, `ch`, `x` or `z` take `es`.
/// - Words ending in a consonant + `y` swap the `y` for `ies`.
/// - Everything else takes `s`.
///
/// Irregular plurals are not handled.
fn pluralize(s: &str) -> String {
    const SIBILANT_ENDINGS: [&str; 5] = ["s", "sh", "ch", "x", "z"];

    if SIBILANT_ENDINGS.iter().any(|&ending| s.ends_with(ending)) {
        return format!("{s}es");
    }

    match s.strip_suffix('y') {
        // A vowel before the `y` (day, key, toy, guy) keeps the plain `s`.
        Some(stem) if !stem.ends_with(['a', 'e', 'o', 'u']) => format!("{stem}ies"),
        _ => format!("{s}s"),
    }
}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` into a fresh registry, panicking if it is not valid Rust.
    fn registry_from(source: &str) -> SchemaRegistry {
        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();
        registry
    }

    #[test]
    fn test_parse_model_source() {
        let registry = registry_from(
            r#"
            #[model]
            struct User {
                #[id]
                id: Uuid,
                email: String,
                name: Option<String>,
                #[indexed]
                created_at: DateTime<Utc>,
            }
            "#,
        );

        let table = registry.get_table("users").unwrap();
        assert_eq!(table.struct_name, "User");
        assert_eq!(table.fields.len(), 4);
    }

    #[test]
    fn test_parse_enum_source() {
        let registry = registry_from(
            r#"
            #[forge_enum]
            enum ProjectStatus {
                Draft,
                Active,
                Completed,
            }
            "#,
        );

        let enum_def = registry.get_enum("ProjectStatus").unwrap();
        assert_eq!(enum_def.variants.len(), 3);
    }

    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("UserProfile", "user_profile"),
            ("ID", "i_d"),
            ("createdAt", "created_at"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_pluralize() {
        let cases = [
            ("user", "users"),
            ("category", "categories"),
            ("box", "boxes"),
            ("address", "addresses"),
        ];
        for (singular, plural) in cases {
            assert_eq!(pluralize(singular), plural);
        }
    }

    #[test]
    fn test_parse_query_function() {
        let registry = registry_from(
            r#"
            #[query]
            async fn get_user(ctx: QueryContext, id: Uuid) -> Result<User> {
                todo!()
            }
            "#,
        );

        let func = registry.get_function("get_user").unwrap();
        assert_eq!(func.name, "get_user");
        assert_eq!(func.kind, FunctionKind::Query);
        assert!(func.is_async);
    }

    #[test]
    fn test_parse_mutation_function() {
        let registry = registry_from(
            r#"
            #[mutation]
            async fn create_user(ctx: MutationContext, name: String, email: String) -> Result<User> {
                todo!()
            }
            "#,
        );

        let func = registry.get_function("create_user").unwrap();
        assert_eq!(func.name, "create_user");
        assert_eq!(func.kind, FunctionKind::Mutation);
        assert_eq!(func.args.len(), 2);
        assert_eq!(func.args[0].name, "name");
        assert_eq!(func.args[1].name, "email");
    }
}