use std::path::Path;
use convert_case::{Case, Casing};
use crate::lexicon::{self, LexiconDef, LexiconDoc, LexiconProperty};
/// Options controlling what `generate` emits.
#[derive(Debug, Clone)]
pub struct GenOptions {
    /// When set, queries/procedures named `main` produce stub handler
    /// route entries (see `generate_for_lexicon`).
    pub generate_stubs: bool,
    /// When set (and any route entries exist), a `routes.rs` wiring
    /// module is written alongside `types.rs`.
    pub generate_routes: bool,
}

impl Default for GenOptions {
    /// Both stub generation and route generation are enabled by default.
    fn default() -> Self {
        GenOptions {
            generate_stubs: true,
            generate_routes: true,
        }
    }
}
/// Summary of a single code-generation run.
///
/// Deriving `Default` gives callers a convenient all-zero/empty report
/// (matching the one `generate` builds by hand) without changing the
/// public interface.
#[derive(Debug, Default)]
pub struct GenReport {
    /// Number of lexicon JSON files successfully parsed.
    pub files_processed: usize,
    /// Number of Rust structs emitted into `types.rs`.
    pub types_generated: usize,
    /// Number of stub handlers registered for route wiring.
    pub stubs_generated: usize,
    /// Display-formatted paths of every file written to the output dir.
    pub output_files: Vec<String>,
}
/// Generate Rust code from all lexicon JSON files under `input_dir`.
///
/// Walks `input_dir` recursively, parses every `*.json` file as a lexicon
/// document, then writes `types.rs`, optionally `routes.rs`, and `mod.rs`
/// into `output_dir` (created if missing). If no lexicon files are found,
/// returns an empty report without touching `output_dir`.
///
/// # Errors
///
/// Fails if a JSON file cannot be read, if any file fails to parse as a
/// lexicon (the error names the offending path), or if an output file
/// cannot be written. Directory-walk errors themselves are silently
/// skipped (`filter_map(|e| e.ok())`).
pub fn generate(
    input_dir: &Path,
    output_dir: &Path,
    opts: GenOptions,
) -> anyhow::Result<GenReport> {
    let mut report = GenReport {
        files_processed: 0,
        types_generated: 0,
        stubs_generated: 0,
        output_files: vec![],
    };
    let mut lexicons: Vec<LexiconDoc> = Vec::new();
    // Collect and parse every `*.json` file under the input tree.
    for entry in walkdir::WalkDir::new(input_dir)
        .into_iter()
        // NOTE(review): unreadable entries are dropped silently here;
        // confirm that's acceptable rather than a hard error.
        .filter_map(|e| e.ok())
        .filter(|e| e.path().extension().is_some_and(|ext| ext == "json"))
    {
        let content = std::fs::read_to_string(entry.path())?;
        match lexicon::parse_lexicon(&content) {
            Ok(doc) => {
                tracing::debug!(id = %doc.id, path = %entry.path().display(), "parsed lexicon");
                lexicons.push(doc);
                report.files_processed += 1;
            }
            Err(e) => {
                // One malformed lexicon aborts the whole run, naming the file.
                anyhow::bail!(
                    "Failed to parse lexicon at {}: {}",
                    entry.path().display(),
                    e
                );
            }
        }
    }
    // Nothing to generate: report zeros and leave the output dir untouched.
    if lexicons.is_empty() {
        tracing::warn!(dir = %input_dir.display(), "no lexicon JSON files found");
        return Ok(report);
    }
    std::fs::create_dir_all(output_dir)?;
    // Accumulate generated code across all documents.
    let mut all_types = String::new();
    let mut all_routes = Vec::new();
    for doc in &lexicons {
        let (types_code, route_entries, type_count, stub_count) = generate_for_lexicon(doc, &opts)?;
        all_types.push_str(&types_code);
        all_types.push('\n');
        all_routes.extend(route_entries);
        report.types_generated += type_count;
        report.stubs_generated += stub_count;
    }
    // `types.rs` is always written once at least one lexicon parsed, even
    // if it ends up containing only the generated header.
    let types_path = output_dir.join("types.rs");
    let types_content = format!(
        "//! Generated types from AT Protocol lexicons.\n\
        //!\n\
        //! DO NOT EDIT — this file is generated by `atrg generate`.\n\n\
        use serde::{{Deserialize, Serialize}};\n\n\
        {all_types}"
    );
    let formatted = format_code(&types_content);
    std::fs::write(&types_path, &formatted)?;
    report.output_files.push(types_path.display().to_string());
    // NOTE(review): route entries are only produced when `generate_stubs`
    // is set (see `generate_for_lexicon`), so `generate_routes: true` on
    // its own yields no routes.rs — confirm this coupling is intended.
    if opts.generate_routes && !all_routes.is_empty() {
        let routes_code = generate_routes_module(&all_routes);
        let routes_path = output_dir.join("routes.rs");
        let formatted = format_code(&routes_code);
        std::fs::write(&routes_path, &formatted)?;
        report.output_files.push(routes_path.display().to_string());
    }
    // `mod.rs` re-exports whatever modules were actually written.
    let mod_path = output_dir.join("mod.rs");
    let mut mod_content = String::from(
        "//! Generated code from AT Protocol lexicons.\n\
        //!\n\
        //! DO NOT EDIT — this file is generated by `atrg generate`.\n\n\
        pub mod types;\n",
    );
    if opts.generate_routes && !all_routes.is_empty() {
        mod_content.push_str("pub mod routes;\n");
    }
    std::fs::write(&mod_path, &mod_content)?;
    report.output_files.push(mod_path.display().to_string());
    tracing::info!(
        files = report.files_processed,
        types = report.types_generated,
        stubs = report.stubs_generated,
        "code generation complete"
    );
    Ok(report)
}
/// One XRPC route to wire into the generated router.
struct RouteEntry {
    /// Fully-qualified lexicon NSID, e.g. `com.atrg.test.ping`.
    nsid: String,
    /// Axum routing helper name: `"get"` for queries, `"post"` for procedures.
    method: &'static str,
    /// Snake-case stub handler function name derived from the NSID.
    handler_name: String,
}
/// Generate type definitions and (optionally) route-stub entries for one
/// parsed lexicon document.
///
/// Returns `(types_code, route_entries, types_generated, stubs_generated)`.
/// Route entries are produced only when `opts.generate_stubs` is set and
/// the def is the `main` query/procedure of the document.
fn generate_for_lexicon(
    doc: &LexiconDoc,
    opts: &GenOptions,
) -> anyhow::Result<(String, Vec<RouteEntry>, usize, usize)> {
    let mut code = String::new();
    let mut routes = Vec::new();
    let mut type_count = 0;
    let mut stub_count = 0;
    // e.g. "com.atrg.test.ping" -> "ComAtrgTestPing".
    let type_prefix = nsid_to_type_prefix(&doc.id);
    for (def_name, def) in &doc.defs {
        match def {
            // Record with an inline object schema -> one struct. A record
            // whose `record` field is `None` falls through to the catch-all
            // below and is skipped.
            LexiconDef::Record {
                description,
                record: Some(obj),
                ..
            } => {
                // The document's `main` record gets a "...Record" suffix;
                // other defs append their Pascal-cased name.
                let struct_name = if def_name == "main" {
                    format!("{type_prefix}Record")
                } else {
                    format!("{type_prefix}{}", def_name.to_case(Case::Pascal))
                };
                code.push_str(&generate_struct(&struct_name, description.as_deref(), obj));
                type_count += 1;
            }
            // Stand-alone object definition -> one struct.
            LexiconDef::Object(obj) => {
                let struct_name = if def_name == "main" {
                    type_prefix.clone()
                } else {
                    format!("{type_prefix}{}", def_name.to_case(Case::Pascal))
                };
                code.push_str(&generate_struct(
                    &struct_name,
                    obj.description.as_deref(),
                    obj,
                ));
                type_count += 1;
            }
            // Query (wired as HTTP GET): optional params struct plus
            // optional output struct.
            LexiconDef::Query {
                description: _,
                parameters,
                output,
            } => {
                // NOTE(review): these names ignore `def_name`, so a
                // non-`main` query def would collide with `main`'s
                // "...Params"/"...Output" structs — presumably queries are
                // always `main`; confirm.
                if let Some(params) = parameters {
                    let name = format!("{type_prefix}Params");
                    code.push_str(&generate_struct(&name, None, params));
                    type_count += 1;
                }
                if let Some(out) = output {
                    if let Some(schema) = &out.schema {
                        let name = format!("{type_prefix}Output");
                        code.push_str(&generate_struct(&name, None, schema));
                        type_count += 1;
                    }
                }
                // Only the `main` query of a document gets a stub route.
                if opts.generate_stubs && def_name == "main" {
                    let handler = nsid_to_handler_name(&doc.id);
                    routes.push(RouteEntry {
                        nsid: doc.id.clone(),
                        method: "get",
                        handler_name: handler,
                    });
                    stub_count += 1;
                }
            }
            // Procedure (wired as HTTP POST): optional input struct plus
            // optional output struct.
            LexiconDef::Procedure {
                description: _,
                input,
                output,
            } => {
                if let Some(inp) = input {
                    if let Some(schema) = &inp.schema {
                        let name = format!("{type_prefix}Input");
                        code.push_str(&generate_struct(&name, None, schema));
                        type_count += 1;
                    }
                }
                if let Some(out) = output {
                    if let Some(schema) = &out.schema {
                        let name = format!("{type_prefix}Output");
                        code.push_str(&generate_struct(&name, None, schema));
                        type_count += 1;
                    }
                }
                if opts.generate_stubs && def_name == "main" {
                    let handler = nsid_to_handler_name(&doc.id);
                    routes.push(RouteEntry {
                        nsid: doc.id.clone(),
                        method: "post",
                        handler_name: handler,
                    });
                    stub_count += 1;
                }
            }
            // Anything else (tokens, strings, record-without-schema, ...)
            // generates no code.
            _ => {}
        }
    }
    Ok((code, routes, type_count, stub_count))
}
/// Render one serde-enabled Rust struct from a lexicon object schema.
///
/// Fields are emitted in sorted key order for deterministic output.
/// Non-snake-case JSON names get a `#[serde(rename = "...")]`, and
/// non-required fields become `Option<T>` with a skip-if-none attribute.
fn generate_struct(name: &str, description: Option<&str>, obj: &lexicon::LexiconObject) -> String {
    // Deterministic field order regardless of map iteration order.
    let mut sorted: Vec<_> = obj.properties.iter().collect();
    sorted.sort_by_key(|&(key, _)| key);

    let mut out = String::new();
    if let Some(text) = description {
        out.push_str(&format!("/// {text}\n"));
    }
    out.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
    out.push_str(&format!("pub struct {name} {{\n"));
    for (json_name, prop) in sorted {
        let field = json_name.to_case(Case::Snake);
        let is_required = obj.required.contains(json_name);
        let ty = property_to_rust_type(prop, is_required);
        if let Some(text) = &prop.description {
            out.push_str(&format!("    /// {text}\n"));
        }
        // Only rename when snake-casing actually changed the name.
        if field != *json_name {
            out.push_str(&format!("    #[serde(rename = \"{json_name}\")]\n"));
        }
        if !is_required {
            out.push_str("    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n");
        }
        out.push_str(&format!("    pub {field}: {ty},\n"));
    }
    out.push_str("}\n\n");
    out
}
/// Map a lexicon property to a Rust type string, wrapping non-required
/// properties in `Option<...>`.
fn property_to_rust_type(prop: &LexiconProperty, required: bool) -> String {
    let base = match prop.prop_type.as_str() {
        // `cid-link` values are carried as their string representation.
        "string" | "cid-link" => "String".to_string(),
        "integer" => "i64".to_string(),
        "boolean" => "bool".to_string(),
        // Array items are always treated as required inside the Vec.
        "array" => match &prop.items {
            Some(items) => format!("Vec<{}>", property_to_rust_type(items, true)),
            None => "Vec<serde_json::Value>".to_string(),
        },
        // blob, unknown, ref, union, and anything unrecognized fall back
        // to an untyped JSON value.
        _ => "serde_json::Value".to_string(),
    };
    if required {
        base
    } else {
        format!("Option<{base}>")
    }
}
/// Render the `routes.rs` module source: a `xrpc_routes` constructor that
/// mounts every entry, followed by one `todo!()` stub handler per entry.
///
/// Indentation here is rough on purpose — `format_code` reflows the
/// result with prettyplease before it is written to disk.
fn generate_routes_module(routes: &[RouteEntry]) -> String {
    // File header plus the opening of the `xrpc_routes` constructor.
    let mut out = String::from(concat!(
        "//! Generated XRPC route wiring.\n",
        "//!\n",
        "//! DO NOT EDIT — this file is generated by `atrg generate`.\n",
        "\n",
        "use axum::{Router, routing::{get, post}, Json};\n",
        "use atrg_core::AppState;\n",
        "use atrg_xrpc::XrpcError;\n",
        "\n",
        "/// Mount all generated XRPC routes.\n",
        "pub fn xrpc_routes() -> Router<AppState> {\n",
        "  atrg_xrpc::xrpc_router()\n",
    ));
    // One `.route(...)` call per entry.
    for entry in routes {
        out.push_str(&format!(
            " .route(\"/xrpc/{}\", {}({}))\n",
            entry.nsid, entry.method, entry.handler_name
        ));
    }
    out.push_str("}\n\n");
    // One stub handler per entry, left for the user to implement.
    for entry in routes {
        out.push_str(&format!(
            "/// Stub handler for `{nsid}`.\n///\n/// TODO: Implement this handler.\nasync fn {handler}() -> Result<Json<serde_json::Value>, XrpcError> {{\n  todo!(\"implement {nsid}\")\n}}\n\n",
            nsid = entry.nsid,
            handler = entry.handler_name,
        ));
    }
    out
}
/// Pascal-case every dot-separated NSID segment and concatenate:
/// `com.atrg.test.ping` -> `ComAtrgTestPing`.
fn nsid_to_type_prefix(nsid: &str) -> String {
    let mut prefix = String::new();
    for segment in nsid.split('.') {
        prefix.push_str(&segment.to_case(Case::Pascal));
    }
    prefix
}
/// Snake-case the final NSID segment into a handler function name:
/// `com.atrg.test.ping` -> `ping`.
fn nsid_to_handler_name(nsid: &str) -> String {
    // `split` on a str always yields at least one item, so the fallback
    // is defensive only.
    let last = nsid.rsplit('.').next().unwrap_or(nsid);
    last.to_case(Case::Snake)
}
/// Pretty-print generated Rust source via `prettyplease`.
///
/// If `syn` cannot parse the input, the raw text is returned unchanged
/// (with a warning) rather than failing the generation run.
fn format_code(code: &str) -> String {
    if let Ok(syntax_tree) = syn::parse_file(code) {
        return prettyplease::unparse(&syntax_tree);
    }
    tracing::warn!("generated code could not be parsed by syn; skipping formatting");
    code.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Write each `(name, content)` pair as a file inside `dir`.
    fn setup_fixture(dir: &Path, files: &[(&str, &str)]) {
        fs::create_dir_all(dir).unwrap();
        for (name, content) in files {
            fs::write(dir.join(name), content).unwrap();
        }
    }

    /// A query lexicon should yield an output struct and one stub route.
    #[test]
    fn generate_from_query_lexicon() {
        let input = tempfile::tempdir().unwrap();
        let output = tempfile::tempdir().unwrap();
        let lexicon = r#"{
"lexicon": 1,
"id": "com.atrg.test.ping",
"defs": {
"main": {
"type": "query",
"description": "Test ping",
"output": {
"encoding": "application/json",
"schema": {
"type": "object",
"required": ["pong"],
"properties": {
"pong": { "type": "boolean" },
"echo": { "type": "string" }
}
}
}
}
}
}"#;
        setup_fixture(input.path(), &[("ping.json", lexicon)]);
        let report = generate(input.path(), output.path(), GenOptions::default()).unwrap();
        assert_eq!(report.files_processed, 1);
        assert!(report.types_generated >= 1);
        assert_eq!(report.stubs_generated, 1);
        let types = fs::read_to_string(output.path().join("types.rs")).unwrap();
        // Naming: NSID prefix + "Output"; required bool stays non-Option.
        assert!(types.contains("ComAtrgTestPingOutput"));
        assert!(types.contains("pub pong: bool"));
    }

    /// A record lexicon should yield a "...Record" struct with its
    /// required fields typed directly (no `Option`).
    #[test]
    fn generate_from_record_lexicon() {
        let input = tempfile::tempdir().unwrap();
        let output = tempfile::tempdir().unwrap();
        // NOTE(review): atproto lexicons use camelCase constraint keys
        // ("maxLength"), not "max_length" — presumably the parser ignores
        // unknown keys so this fixture still parses; confirm.
        let lexicon = r#"{
"lexicon": 1,
"id": "com.atrg.test.post",
"defs": {
"main": {
"type": "record",
"description": "A test post",
"key": "tid",
"record": {
"type": "object",
"required": ["text", "createdAt"],
"properties": {
"text": { "type": "string", "max_length": 3000 },
"createdAt": { "type": "string", "format": "datetime" }
}
}
}
}
}"#;
        setup_fixture(input.path(), &[("post.json", lexicon)]);
        let report = generate(input.path(), output.path(), GenOptions::default()).unwrap();
        assert_eq!(report.files_processed, 1);
        assert!(report.types_generated >= 1);
        let types = fs::read_to_string(output.path().join("types.rs")).unwrap();
        assert!(types.contains("ComAtrgTestPostRecord"));
        assert!(types.contains("pub text: String"));
    }

    /// An unparseable JSON file must fail the run and name the file.
    #[test]
    fn malformed_lexicon_gives_error() {
        let input = tempfile::tempdir().unwrap();
        let output = tempfile::tempdir().unwrap();
        setup_fixture(input.path(), &[("bad.json", "not valid json")]);
        let result = generate(input.path(), output.path(), GenOptions::default());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("bad.json"),
            "error should mention the file: {err}"
        );
    }

    /// No lexicon files at all: early return with an all-zero report.
    #[test]
    fn empty_dir_produces_empty_report() {
        let input = tempfile::tempdir().unwrap();
        let output = tempfile::tempdir().unwrap();
        let report = generate(input.path(), output.path(), GenOptions::default()).unwrap();
        assert_eq!(report.files_processed, 0);
    }
}