1use std::path::Path;
4
5use convert_case::{Case, Casing};
6
7use crate::lexicon::{self, LexiconDef, LexiconDoc, LexiconProperty};
8
/// Options controlling what the `generate` entry point emits.
#[derive(Debug, Clone)]
pub struct GenOptions {
    /// Emit stub handler functions for queries/procedures (and collect the
    /// route entries that wire them up).
    pub generate_stubs: bool,
    /// Emit a `routes.rs` module mounting the generated handlers on a router.
    pub generate_routes: bool,
}
17
18impl Default for GenOptions {
19 fn default() -> Self {
20 Self {
21 generate_stubs: true,
22 generate_routes: true,
23 }
24 }
25}
26
/// Summary of a code-generation run, returned by `generate`.
#[derive(Debug)]
pub struct GenReport {
    /// Number of lexicon JSON files successfully parsed.
    pub files_processed: usize,
    /// Number of Rust struct definitions emitted into `types.rs`.
    pub types_generated: usize,
    /// Number of stub handlers emitted (one per query/procedure `main` def).
    pub stubs_generated: usize,
    /// Display paths of every file written to the output directory.
    pub output_files: Vec<String>,
}
39
40pub fn generate(
45 input_dir: &Path,
46 output_dir: &Path,
47 opts: GenOptions,
48) -> anyhow::Result<GenReport> {
49 let mut report = GenReport {
50 files_processed: 0,
51 types_generated: 0,
52 stubs_generated: 0,
53 output_files: vec![],
54 };
55
56 let mut lexicons: Vec<LexiconDoc> = Vec::new();
58
59 for entry in walkdir::WalkDir::new(input_dir)
60 .into_iter()
61 .filter_map(|e| e.ok())
62 .filter(|e| e.path().extension().is_some_and(|ext| ext == "json"))
63 {
64 let content = std::fs::read_to_string(entry.path())?;
65 match lexicon::parse_lexicon(&content) {
66 Ok(doc) => {
67 tracing::debug!(id = %doc.id, path = %entry.path().display(), "parsed lexicon");
68 lexicons.push(doc);
69 report.files_processed += 1;
70 }
71 Err(e) => {
72 anyhow::bail!(
73 "Failed to parse lexicon at {}: {}",
74 entry.path().display(),
75 e
76 );
77 }
78 }
79 }
80
81 if lexicons.is_empty() {
82 tracing::warn!(dir = %input_dir.display(), "no lexicon JSON files found");
83 return Ok(report);
84 }
85
86 std::fs::create_dir_all(output_dir)?;
88
89 let mut all_types = String::new();
90 let mut all_routes = Vec::new();
91
92 for doc in &lexicons {
93 let (types_code, route_entries, type_count, stub_count) = generate_for_lexicon(doc, &opts)?;
94 all_types.push_str(&types_code);
95 all_types.push('\n');
96 all_routes.extend(route_entries);
97 report.types_generated += type_count;
98 report.stubs_generated += stub_count;
99 }
100
101 let types_path = output_dir.join("types.rs");
103 let types_content = format!(
104 "//! Generated types from AT Protocol lexicons.\n\
105 //!\n\
106 //! DO NOT EDIT — this file is generated by `atrg generate`.\n\n\
107 use serde::{{Deserialize, Serialize}};\n\n\
108 {all_types}"
109 );
110 let formatted = format_code(&types_content);
111 std::fs::write(&types_path, &formatted)?;
112 report.output_files.push(types_path.display().to_string());
113
114 if opts.generate_routes && !all_routes.is_empty() {
116 let routes_code = generate_routes_module(&all_routes);
117 let routes_path = output_dir.join("routes.rs");
118 let formatted = format_code(&routes_code);
119 std::fs::write(&routes_path, &formatted)?;
120 report.output_files.push(routes_path.display().to_string());
121 }
122
123 let mod_path = output_dir.join("mod.rs");
125 let mut mod_content = String::from(
126 "//! Generated code from AT Protocol lexicons.\n\
127 //!\n\
128 //! DO NOT EDIT — this file is generated by `atrg generate`.\n\n\
129 pub mod types;\n",
130 );
131 if opts.generate_routes && !all_routes.is_empty() {
132 mod_content.push_str("pub mod routes;\n");
133 }
134 std::fs::write(&mod_path, &mod_content)?;
135 report.output_files.push(mod_path.display().to_string());
136
137 tracing::info!(
138 files = report.files_processed,
139 types = report.types_generated,
140 stubs = report.stubs_generated,
141 "code generation complete"
142 );
143
144 Ok(report)
145}
146
/// One XRPC endpoint to wire into the generated router.
struct RouteEntry {
    /// Full NSID of the lexicon, e.g. `com.atrg.test.ping`.
    nsid: String,
    /// Axum routing constructor to use in the generated code: `"get"` or `"post"`.
    method: &'static str,
    /// Name of the generated stub handler function.
    handler_name: String,
}
153
/// Generate the struct definitions and (optional) route entries for a single
/// lexicon document.
///
/// Returns `(types_code, route_entries, type_count, stub_count)`.
///
/// NOTE(review): route entries are gated on `opts.generate_stubs`, so a
/// caller setting `generate_routes: true` with `generate_stubs: false` gets
/// no routes at all — confirm this coupling is intended.
fn generate_for_lexicon(
    doc: &LexiconDoc,
    opts: &GenOptions,
) -> anyhow::Result<(String, Vec<RouteEntry>, usize, usize)> {
    let mut code = String::new();
    let mut routes = Vec::new();
    let mut type_count = 0;
    let mut stub_count = 0;

    // e.g. "com.atrg.test.ping" -> "ComAtrgTestPing"
    let type_prefix = nsid_to_type_prefix(&doc.id);

    for (def_name, def) in &doc.defs {
        match def {
            // A record's object schema: the "main" def becomes `<Prefix>Record`,
            // any other def becomes `<Prefix><PascalName>`.
            // Records with no `record` object fall through to the catch-all.
            LexiconDef::Record {
                description,
                record: Some(obj),
                ..
            } => {
                let struct_name = if def_name == "main" {
                    format!("{type_prefix}Record")
                } else {
                    format!("{type_prefix}{}", def_name.to_case(Case::Pascal))
                };
                code.push_str(&generate_struct(&struct_name, description.as_deref(), obj));
                type_count += 1;
            }
            // A bare object def; "main" uses the prefix itself as the name.
            LexiconDef::Object(obj) => {
                let struct_name = if def_name == "main" {
                    type_prefix.clone()
                } else {
                    format!("{type_prefix}{}", def_name.to_case(Case::Pascal))
                };
                code.push_str(&generate_struct(
                    &struct_name,
                    obj.description.as_deref(),
                    obj,
                ));
                type_count += 1;
            }
            // XRPC query: optional params struct, optional output struct, and
            // a GET route stub for the "main" def.
            LexiconDef::Query {
                description: _,
                parameters,
                output,
            } => {
                // NOTE(review): non-"main" query defs would all reuse the same
                // `<Prefix>Params`/`<Prefix>Output` names — assumed queries
                // only appear as "main"; confirm against the lexicon corpus.
                if let Some(params) = parameters {
                    let name = format!("{type_prefix}Params");
                    code.push_str(&generate_struct(&name, None, params));
                    type_count += 1;
                }
                if let Some(out) = output {
                    if let Some(schema) = &out.schema {
                        let name = format!("{type_prefix}Output");
                        code.push_str(&generate_struct(&name, None, schema));
                        type_count += 1;
                    }
                }
                if opts.generate_stubs && def_name == "main" {
                    let handler = nsid_to_handler_name(&doc.id);
                    routes.push(RouteEntry {
                        nsid: doc.id.clone(),
                        method: "get",
                        handler_name: handler,
                    });
                    stub_count += 1;
                }
            }
            // XRPC procedure: optional input/output structs and a POST route
            // stub for the "main" def.
            LexiconDef::Procedure {
                description: _,
                input,
                output,
            } => {
                if let Some(inp) = input {
                    if let Some(schema) = &inp.schema {
                        let name = format!("{type_prefix}Input");
                        code.push_str(&generate_struct(&name, None, schema));
                        type_count += 1;
                    }
                }
                if let Some(out) = output {
                    if let Some(schema) = &out.schema {
                        let name = format!("{type_prefix}Output");
                        code.push_str(&generate_struct(&name, None, schema));
                        type_count += 1;
                    }
                }
                if opts.generate_stubs && def_name == "main" {
                    let handler = nsid_to_handler_name(&doc.id);
                    routes.push(RouteEntry {
                        nsid: doc.id.clone(),
                        method: "post",
                        handler_name: handler,
                    });
                    stub_count += 1;
                }
            }
            // Other def kinds (tokens, strings, arrays, ...) generate nothing.
            _ => {}
        }
    }

    Ok((code, routes, type_count, stub_count))
}
259
260fn generate_struct(name: &str, description: Option<&str>, obj: &lexicon::LexiconObject) -> String {
261 let mut s = String::new();
262
263 if let Some(desc) = description {
264 s.push_str(&format!("/// {desc}\n"));
265 }
266 s.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
267 s.push_str(&format!("pub struct {name} {{\n"));
268
269 let mut props: Vec<_> = obj.properties.iter().collect();
271 props.sort_by_key(|(k, _)| *k);
272
273 for (field_name, prop) in &props {
274 let rust_name = field_name.to_case(Case::Snake);
275 let rust_type = property_to_rust_type(prop, obj.required.contains(*field_name));
276
277 if let Some(desc) = &prop.description {
278 s.push_str(&format!(" /// {desc}\n"));
279 }
280
281 if rust_name != **field_name {
282 s.push_str(&format!(" #[serde(rename = \"{field_name}\")]\n"));
283 }
284
285 if !obj.required.contains(*field_name) {
286 s.push_str(" #[serde(default, skip_serializing_if = \"Option::is_none\")]\n");
287 }
288
289 s.push_str(&format!(" pub {rust_name}: {rust_type},\n"));
290 }
291
292 s.push_str("}\n\n");
293 s
294}
295
296fn property_to_rust_type(prop: &LexiconProperty, required: bool) -> String {
297 let base = match prop.prop_type.as_str() {
298 "string" => "String".to_string(),
299 "integer" => "i64".to_string(),
300 "boolean" => "bool".to_string(),
301 "blob" => "serde_json::Value".to_string(),
302 "unknown" => "serde_json::Value".to_string(),
303 "cid-link" => "String".to_string(),
304 "array" => {
305 if let Some(items) = &prop.items {
306 format!("Vec<{}>", property_to_rust_type(items, true))
307 } else {
308 "Vec<serde_json::Value>".to_string()
309 }
310 }
311 "ref" | "union" => "serde_json::Value".to_string(),
312 _ => "serde_json::Value".to_string(),
313 };
314
315 if required {
316 base
317 } else {
318 format!("Option<{base}>")
319 }
320}
321
322fn generate_routes_module(routes: &[RouteEntry]) -> String {
323 let mut s = String::from(
324 "//! Generated XRPC route wiring.\n\
325 //!\n\
326 //! DO NOT EDIT — this file is generated by `atrg generate`.\n\n\
327 use axum::{Router, routing::{get, post}, Json};\n\
328 use atrg_core::AppState;\n\
329 use atrg_xrpc::XrpcError;\n\n\
330 /// Mount all generated XRPC routes.\n\
331 pub fn xrpc_routes() -> Router<AppState> {\n\
332 \x20 atrg_xrpc::xrpc_router()\n",
333 );
334
335 for route in routes {
336 let method = route.method;
337 s.push_str(&format!(
338 " .route(\"/xrpc/{}\", {method}({}))\n",
339 route.nsid, route.handler_name
340 ));
341 }
342
343 s.push_str("}\n\n");
344
345 for route in routes {
347 s.push_str(&format!(
348 "/// Stub handler for `{}`.\n\
349 ///\n\
350 /// TODO: Implement this handler.\n\
351 async fn {}() -> Result<Json<serde_json::Value>, XrpcError> {{\n\
352 \x20 todo!(\"implement {}\")\n\
353 }}\n\n",
354 route.nsid, route.handler_name, route.nsid
355 ));
356 }
357
358 s
359}
360
361fn nsid_to_type_prefix(nsid: &str) -> String {
362 nsid.split('.')
363 .map(|s| s.to_case(Case::Pascal))
364 .collect::<Vec<_>>()
365 .join("")
366}
367
368fn nsid_to_handler_name(nsid: &str) -> String {
369 nsid.split('.')
370 .next_back()
371 .unwrap_or(nsid)
372 .to_case(Case::Snake)
373}
374
375fn format_code(code: &str) -> String {
376 match syn::parse_file(code) {
377 Ok(syntax_tree) => prettyplease::unparse(&syntax_tree),
378 Err(_) => {
379 tracing::warn!("generated code could not be parsed by syn; skipping formatting");
380 code.to_string()
381 }
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Write the given `(filename, content)` pairs into `dir`, creating it first.
    fn setup_fixture(dir: &Path, files: &[(&str, &str)]) {
        fs::create_dir_all(dir).unwrap();
        for (name, content) in files {
            fs::write(dir.join(name), content).unwrap();
        }
    }

    /// A query lexicon should yield an `…Output` type and exactly one stub route.
    #[test]
    fn generate_from_query_lexicon() {
        let input = tempfile::tempdir().unwrap();
        let output = tempfile::tempdir().unwrap();

        let lexicon = r#"{
            "lexicon": 1,
            "id": "com.atrg.test.ping",
            "defs": {
                "main": {
                    "type": "query",
                    "description": "Test ping",
                    "output": {
                        "encoding": "application/json",
                        "schema": {
                            "type": "object",
                            "required": ["pong"],
                            "properties": {
                                "pong": { "type": "boolean" },
                                "echo": { "type": "string" }
                            }
                        }
                    }
                }
            }
        }"#;

        setup_fixture(input.path(), &[("ping.json", lexicon)]);

        let report = generate(input.path(), output.path(), GenOptions::default()).unwrap();
        assert_eq!(report.files_processed, 1);
        assert!(report.types_generated >= 1);
        assert_eq!(report.stubs_generated, 1);

        // `pong` is required, so it must be emitted as a bare `bool`
        // (not `Option<bool>`).
        let types = fs::read_to_string(output.path().join("types.rs")).unwrap();
        assert!(types.contains("ComAtrgTestPingOutput"));
        assert!(types.contains("pub pong: bool"));
    }

    /// A record lexicon's `main` def should become `<Prefix>Record`.
    #[test]
    fn generate_from_record_lexicon() {
        let input = tempfile::tempdir().unwrap();
        let output = tempfile::tempdir().unwrap();

        // NOTE(review): the AT Protocol spec spells the constraint "maxLength";
        // "max_length" here presumably rides on the parser ignoring unknown
        // keys — confirm against `lexicon::parse_lexicon`.
        let lexicon = r#"{
            "lexicon": 1,
            "id": "com.atrg.test.post",
            "defs": {
                "main": {
                    "type": "record",
                    "description": "A test post",
                    "key": "tid",
                    "record": {
                        "type": "object",
                        "required": ["text", "createdAt"],
                        "properties": {
                            "text": { "type": "string", "max_length": 3000 },
                            "createdAt": { "type": "string", "format": "datetime" }
                        }
                    }
                }
            }
        }"#;

        setup_fixture(input.path(), &[("post.json", lexicon)]);

        let report = generate(input.path(), output.path(), GenOptions::default()).unwrap();
        assert_eq!(report.files_processed, 1);
        assert!(report.types_generated >= 1);

        let types = fs::read_to_string(output.path().join("types.rs")).unwrap();
        assert!(types.contains("ComAtrgTestPostRecord"));
        assert!(types.contains("pub text: String"));
    }

    /// Invalid JSON must abort generation with an error naming the bad file.
    #[test]
    fn malformed_lexicon_gives_error() {
        let input = tempfile::tempdir().unwrap();
        let output = tempfile::tempdir().unwrap();

        setup_fixture(input.path(), &[("bad.json", "not valid json")]);

        let result = generate(input.path(), output.path(), GenOptions::default());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("bad.json"),
            "error should mention the file: {err}"
        );
    }

    /// No lexicons at all is not an error — just an empty report.
    #[test]
    fn empty_dir_produces_empty_report() {
        let input = tempfile::tempdir().unwrap();
        let output = tempfile::tempdir().unwrap();

        let report = generate(input.path(), output.path(), GenOptions::default()).unwrap();
        assert_eq!(report.files_processed, 0);
    }
}
497}