Skip to main content

atrg_codegen/
generator.rs

1//! Code generation from lexicon definitions.
2
3use std::path::Path;
4
5use convert_case::{Case, Casing};
6
7use crate::lexicon::{self, LexiconDef, LexiconDoc, LexiconProperty};
8
/// Switches controlling which artifacts [`generate`] emits.
///
/// [`Default`] turns every switch on.
#[derive(Debug, Clone)]
pub struct GenOptions {
    /// Produce handler stubs for `main` queries and procedures (default: `true`).
    pub generate_stubs: bool,
    /// Write the XRPC route-wiring module (default: `true`).
    pub generate_routes: bool,
}

impl Default for GenOptions {
    fn default() -> Self {
        let enabled = true;
        GenOptions {
            generate_routes: enabled,
            generate_stubs: enabled,
        }
    }
}
26
/// Summary of what a single [`generate`] call produced.
#[derive(Debug)]
pub struct GenReport {
    /// How many lexicon files were parsed successfully.
    pub files_processed: usize,
    /// How many Rust types were emitted across all lexicons.
    pub types_generated: usize,
    /// How many handler stubs were emitted.
    pub stubs_generated: usize,
    /// Display paths of every file written to the output directory.
    pub output_files: Vec<String>,
}
39
40/// Generate Rust code from lexicon JSON files.
41///
42/// Walks `input_dir` for `*.json` files, parses each as a lexicon,
43/// and writes generated Rust code to `output_dir`.
44pub fn generate(
45    input_dir: &Path,
46    output_dir: &Path,
47    opts: GenOptions,
48) -> anyhow::Result<GenReport> {
49    let mut report = GenReport {
50        files_processed: 0,
51        types_generated: 0,
52        stubs_generated: 0,
53        output_files: vec![],
54    };
55
56    // Collect all lexicon files
57    let mut lexicons: Vec<LexiconDoc> = Vec::new();
58
59    for entry in walkdir::WalkDir::new(input_dir)
60        .into_iter()
61        .filter_map(|e| e.ok())
62        .filter(|e| e.path().extension().is_some_and(|ext| ext == "json"))
63    {
64        let content = std::fs::read_to_string(entry.path())?;
65        match lexicon::parse_lexicon(&content) {
66            Ok(doc) => {
67                tracing::debug!(id = %doc.id, path = %entry.path().display(), "parsed lexicon");
68                lexicons.push(doc);
69                report.files_processed += 1;
70            }
71            Err(e) => {
72                anyhow::bail!(
73                    "Failed to parse lexicon at {}: {}",
74                    entry.path().display(),
75                    e
76                );
77            }
78        }
79    }
80
81    if lexicons.is_empty() {
82        tracing::warn!(dir = %input_dir.display(), "no lexicon JSON files found");
83        return Ok(report);
84    }
85
86    // Generate code for each lexicon
87    std::fs::create_dir_all(output_dir)?;
88
89    let mut all_types = String::new();
90    let mut all_routes = Vec::new();
91
92    for doc in &lexicons {
93        let (types_code, route_entries, type_count, stub_count) = generate_for_lexicon(doc, &opts)?;
94        all_types.push_str(&types_code);
95        all_types.push('\n');
96        all_routes.extend(route_entries);
97        report.types_generated += type_count;
98        report.stubs_generated += stub_count;
99    }
100
101    // Write types module
102    let types_path = output_dir.join("types.rs");
103    let types_content = format!(
104        "//! Generated types from AT Protocol lexicons.\n\
105         //!\n\
106         //! DO NOT EDIT — this file is generated by `atrg generate`.\n\n\
107         use serde::{{Deserialize, Serialize}};\n\n\
108         {all_types}"
109    );
110    let formatted = format_code(&types_content);
111    std::fs::write(&types_path, &formatted)?;
112    report.output_files.push(types_path.display().to_string());
113
114    // Write routes module (if enabled)
115    if opts.generate_routes && !all_routes.is_empty() {
116        let routes_code = generate_routes_module(&all_routes);
117        let routes_path = output_dir.join("routes.rs");
118        let formatted = format_code(&routes_code);
119        std::fs::write(&routes_path, &formatted)?;
120        report.output_files.push(routes_path.display().to_string());
121    }
122
123    // Write mod.rs
124    let mod_path = output_dir.join("mod.rs");
125    let mut mod_content = String::from(
126        "//! Generated code from AT Protocol lexicons.\n\
127         //!\n\
128         //! DO NOT EDIT — this file is generated by `atrg generate`.\n\n\
129         pub mod types;\n",
130    );
131    if opts.generate_routes && !all_routes.is_empty() {
132        mod_content.push_str("pub mod routes;\n");
133    }
134    std::fs::write(&mod_path, &mod_content)?;
135    report.output_files.push(mod_path.display().to_string());
136
137    tracing::info!(
138        files = report.files_processed,
139        types = report.types_generated,
140        stubs = report.stubs_generated,
141        "code generation complete"
142    );
143
144    Ok(report)
145}
146
/// One XRPC endpoint to be mounted by the generated routes module.
struct RouteEntry {
    /// Full lexicon NSID; the route is mounted at `/xrpc/<nsid>`.
    nsid: String,
    /// Name of the axum routing function: `"get"` for queries, `"post"` for
    /// procedures.
    method: &'static str,
    /// Identifier of the generated stub handler function.
    handler_name: String,
}
153
154fn generate_for_lexicon(
155    doc: &LexiconDoc,
156    opts: &GenOptions,
157) -> anyhow::Result<(String, Vec<RouteEntry>, usize, usize)> {
158    let mut code = String::new();
159    let mut routes = Vec::new();
160    let mut type_count = 0;
161    let mut stub_count = 0;
162
163    let type_prefix = nsid_to_type_prefix(&doc.id);
164
165    for (def_name, def) in &doc.defs {
166        match def {
167            LexiconDef::Record {
168                description,
169                record: Some(obj),
170                ..
171            } => {
172                let struct_name = if def_name == "main" {
173                    format!("{type_prefix}Record")
174                } else {
175                    format!("{type_prefix}{}", def_name.to_case(Case::Pascal))
176                };
177                code.push_str(&generate_struct(&struct_name, description.as_deref(), obj));
178                type_count += 1;
179            }
180            LexiconDef::Object(obj) => {
181                let struct_name = if def_name == "main" {
182                    type_prefix.clone()
183                } else {
184                    format!("{type_prefix}{}", def_name.to_case(Case::Pascal))
185                };
186                code.push_str(&generate_struct(
187                    &struct_name,
188                    obj.description.as_deref(),
189                    obj,
190                ));
191                type_count += 1;
192            }
193            LexiconDef::Query {
194                description: _,
195                parameters,
196                output,
197            } => {
198                // Generate params struct
199                if let Some(params) = parameters {
200                    let name = format!("{type_prefix}Params");
201                    code.push_str(&generate_struct(&name, None, params));
202                    type_count += 1;
203                }
204                // Generate output struct
205                if let Some(out) = output {
206                    if let Some(schema) = &out.schema {
207                        let name = format!("{type_prefix}Output");
208                        code.push_str(&generate_struct(&name, None, schema));
209                        type_count += 1;
210                    }
211                }
212                if opts.generate_stubs && def_name == "main" {
213                    let handler = nsid_to_handler_name(&doc.id);
214                    routes.push(RouteEntry {
215                        nsid: doc.id.clone(),
216                        method: "get",
217                        handler_name: handler,
218                    });
219                    stub_count += 1;
220                }
221            }
222            LexiconDef::Procedure {
223                description: _,
224                input,
225                output,
226            } => {
227                // Generate input struct
228                if let Some(inp) = input {
229                    if let Some(schema) = &inp.schema {
230                        let name = format!("{type_prefix}Input");
231                        code.push_str(&generate_struct(&name, None, schema));
232                        type_count += 1;
233                    }
234                }
235                // Generate output struct
236                if let Some(out) = output {
237                    if let Some(schema) = &out.schema {
238                        let name = format!("{type_prefix}Output");
239                        code.push_str(&generate_struct(&name, None, schema));
240                        type_count += 1;
241                    }
242                }
243                if opts.generate_stubs && def_name == "main" {
244                    let handler = nsid_to_handler_name(&doc.id);
245                    routes.push(RouteEntry {
246                        nsid: doc.id.clone(),
247                        method: "post",
248                        handler_name: handler,
249                    });
250                    stub_count += 1;
251                }
252            }
253            _ => {}
254        }
255    }
256
257    Ok((code, routes, type_count, stub_count))
258}
259
260fn generate_struct(name: &str, description: Option<&str>, obj: &lexicon::LexiconObject) -> String {
261    let mut s = String::new();
262
263    if let Some(desc) = description {
264        s.push_str(&format!("/// {desc}\n"));
265    }
266    s.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
267    s.push_str(&format!("pub struct {name} {{\n"));
268
269    // Sort properties for deterministic output
270    let mut props: Vec<_> = obj.properties.iter().collect();
271    props.sort_by_key(|(k, _)| *k);
272
273    for (field_name, prop) in &props {
274        let rust_name = field_name.to_case(Case::Snake);
275        let rust_type = property_to_rust_type(prop, obj.required.contains(*field_name));
276
277        if let Some(desc) = &prop.description {
278            s.push_str(&format!("    /// {desc}\n"));
279        }
280
281        if rust_name != **field_name {
282            s.push_str(&format!("    #[serde(rename = \"{field_name}\")]\n"));
283        }
284
285        if !obj.required.contains(*field_name) {
286            s.push_str("    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n");
287        }
288
289        s.push_str(&format!("    pub {rust_name}: {rust_type},\n"));
290    }
291
292    s.push_str("}\n\n");
293    s
294}
295
296fn property_to_rust_type(prop: &LexiconProperty, required: bool) -> String {
297    let base = match prop.prop_type.as_str() {
298        "string" => "String".to_string(),
299        "integer" => "i64".to_string(),
300        "boolean" => "bool".to_string(),
301        "blob" => "serde_json::Value".to_string(),
302        "unknown" => "serde_json::Value".to_string(),
303        "cid-link" => "String".to_string(),
304        "array" => {
305            if let Some(items) = &prop.items {
306                format!("Vec<{}>", property_to_rust_type(items, true))
307            } else {
308                "Vec<serde_json::Value>".to_string()
309            }
310        }
311        "ref" | "union" => "serde_json::Value".to_string(),
312        _ => "serde_json::Value".to_string(),
313    };
314
315    if required {
316        base
317    } else {
318        format!("Option<{base}>")
319    }
320}
321
322fn generate_routes_module(routes: &[RouteEntry]) -> String {
323    let mut s = String::from(
324        "//! Generated XRPC route wiring.\n\
325         //!\n\
326         //! DO NOT EDIT — this file is generated by `atrg generate`.\n\n\
327         use axum::{Router, routing::{get, post}, Json};\n\
328         use atrg_core::AppState;\n\
329         use atrg_xrpc::XrpcError;\n\n\
330         /// Mount all generated XRPC routes.\n\
331         pub fn xrpc_routes() -> Router<AppState> {\n\
332         \x20   atrg_xrpc::xrpc_router()\n",
333    );
334
335    for route in routes {
336        let method = route.method;
337        s.push_str(&format!(
338            "        .route(\"/xrpc/{}\", {method}({}))\n",
339            route.nsid, route.handler_name
340        ));
341    }
342
343    s.push_str("}\n\n");
344
345    // Generate stub handlers
346    for route in routes {
347        s.push_str(&format!(
348            "/// Stub handler for `{}`.\n\
349             ///\n\
350             /// TODO: Implement this handler.\n\
351             async fn {}() -> Result<Json<serde_json::Value>, XrpcError> {{\n\
352             \x20   todo!(\"implement {}\")\n\
353             }}\n\n",
354            route.nsid, route.handler_name, route.nsid
355        ));
356    }
357
358    s
359}
360
361fn nsid_to_type_prefix(nsid: &str) -> String {
362    nsid.split('.')
363        .map(|s| s.to_case(Case::Pascal))
364        .collect::<Vec<_>>()
365        .join("")
366}
367
368fn nsid_to_handler_name(nsid: &str) -> String {
369    nsid.split('.')
370        .next_back()
371        .unwrap_or(nsid)
372        .to_case(Case::Snake)
373}
374
375fn format_code(code: &str) -> String {
376    match syn::parse_file(code) {
377        Ok(syntax_tree) => prettyplease::unparse(&syntax_tree),
378        Err(_) => {
379            tracing::warn!("generated code could not be parsed by syn; skipping formatting");
380            code.to_string()
381        }
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Create `dir` (if needed) and fill it with `(file name, contents)` pairs.
    fn setup_fixture(dir: &Path, files: &[(&str, &str)]) {
        fs::create_dir_all(dir).unwrap();
        files
            .iter()
            .for_each(|(file_name, contents)| fs::write(dir.join(file_name), contents).unwrap());
    }

    /// Fresh `(input, output)` temporary directories for a single test.
    fn temp_dirs() -> (tempfile::TempDir, tempfile::TempDir) {
        (tempfile::tempdir().unwrap(), tempfile::tempdir().unwrap())
    }

    #[test]
    fn generate_from_query_lexicon() {
        let (src, out) = temp_dirs();

        let ping = r#"{
            "lexicon": 1,
            "id": "com.atrg.test.ping",
            "defs": {
                "main": {
                    "type": "query",
                    "description": "Test ping",
                    "output": {
                        "encoding": "application/json",
                        "schema": {
                            "type": "object",
                            "required": ["pong"],
                            "properties": {
                                "pong": { "type": "boolean" },
                                "echo": { "type": "string" }
                            }
                        }
                    }
                }
            }
        }"#;
        setup_fixture(src.path(), &[("ping.json", ping)]);

        let report = generate(src.path(), out.path(), GenOptions::default()).unwrap();
        assert_eq!(report.files_processed, 1);
        assert!(report.types_generated >= 1);
        assert_eq!(report.stubs_generated, 1);

        // The query's output schema should have become a struct in types.rs.
        let types_rs = fs::read_to_string(out.path().join("types.rs")).unwrap();
        assert!(types_rs.contains("ComAtrgTestPingOutput"));
        assert!(types_rs.contains("pub pong: bool"));
    }

    #[test]
    fn generate_from_record_lexicon() {
        let (src, out) = temp_dirs();

        let post = r#"{
            "lexicon": 1,
            "id": "com.atrg.test.post",
            "defs": {
                "main": {
                    "type": "record",
                    "description": "A test post",
                    "key": "tid",
                    "record": {
                        "type": "object",
                        "required": ["text", "createdAt"],
                        "properties": {
                            "text": { "type": "string", "max_length": 3000 },
                            "createdAt": { "type": "string", "format": "datetime" }
                        }
                    }
                }
            }
        }"#;
        setup_fixture(src.path(), &[("post.json", post)]);

        let report = generate(src.path(), out.path(), GenOptions::default()).unwrap();
        assert_eq!(report.files_processed, 1);
        assert!(report.types_generated >= 1);

        // A `main` record def is emitted as `<Prefix>Record`.
        let types_rs = fs::read_to_string(out.path().join("types.rs")).unwrap();
        assert!(types_rs.contains("ComAtrgTestPostRecord"));
        assert!(types_rs.contains("pub text: String"));
    }

    #[test]
    fn malformed_lexicon_gives_error() {
        let (src, out) = temp_dirs();
        setup_fixture(src.path(), &[("bad.json", "not valid json")]);

        let err = generate(src.path(), out.path(), GenOptions::default())
            .expect_err("parsing invalid JSON must fail")
            .to_string();
        assert!(
            err.contains("bad.json"),
            "error should mention the file: {err}"
        );
    }

    #[test]
    fn empty_dir_produces_empty_report() {
        let (src, out) = temp_dirs();

        let report = generate(src.path(), out.path(), GenOptions::default()).unwrap();
        assert_eq!(report.files_processed, 0);
    }
}