use heck::{ToPascalCase, ToSnakeCase};
use reqwest::blocking::get;
use sqlparser::ast::{ColumnOption, DataType, Statement};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
/// Raw GitHub URL of the NAHPU Drift schema that this build script downloads.
const NAHPU_TABLES_URL: &str = concat!(
    "https://raw.githubusercontent.com/nahpu/nahpu/main/",
    "lib/services/database/tables.drift"
);
/// File name (inside `OUT_DIR`) that receives the generated Rust code.
const GENERATED_FILE: &str = "nahpu_sqlite.rs";
/// Maps a SQL column type to the Rust type used for the generated struct field.
///
/// When `is_nullable` is true the result is wrapped in `Option<…>`.
/// Types without a dedicated mapping fall back to `String`. A Cargo warning is
/// emitted for each fallback so it is visible in the build output.
fn sql_type_to_rust(sql_type: &DataType, is_nullable: bool) -> String {
    let base_type = match sql_type {
        DataType::Text | DataType::Varchar(_) | DataType::String(_) => "String",
        DataType::Int(_) | DataType::Integer(_) => "i32",
        // BIGINT values do not fit a `String` field when deserialized; use a 64-bit int.
        DataType::BigInt(_) => "i64",
        DataType::Real | DataType::Float(_) | DataType::Double(_) => "f64",
        DataType::Boolean => "bool",
        // Binary data cannot be represented faithfully as a UTF-8 `String`.
        DataType::Blob(_) => "Vec<u8>",
        _ => {
            // Build-script stderr is only shown when the build fails, so an
            // `eprintln!` warning would be invisible; `cargo:warning=` on stdout
            // is surfaced by Cargo.
            println!(
                "cargo:warning=Unknown SQL type {:?}, defaulting to String.",
                sql_type
            );
            "String"
        }
    };
    if is_nullable {
        format!("Option<{base_type}>")
    } else {
        base_type.to_string()
    }
}
/// Extracts only the `CREATE TABLE` statements from a Drift schema file.
///
/// Drift files may also contain `import` directives, named queries and other
/// statements that the SQL parser cannot handle, so only `CREATE TABLE`
/// statements are kept. Each kept statement is trimmed and terminated with
/// `;`. The statements are concatenated without any separator.
///
/// SQL comments (`-- …` and `/* … */`) are removed, and a `;` inside quoted
/// text is not treated as a terminator. As a result, a comment placed before
/// a table cannot cause that table to be dropped. A default value containing
/// `;` cannot truncate its statement.
fn clean_drift_content(content: &str) -> String {
    split_sql_statements(content)
        .iter()
        .map(|stmt| stmt.trim())
        .filter(|stmt| is_create_table(stmt))
        .map(|stmt| format!("{stmt};"))
        .collect()
}

/// Splits SQL text on `;` terminators.
///
/// Skips `--` and `/* */` comments and ignores any `;` inside quoted strings
/// or identifiers.
fn split_sql_statements(content: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    // Quote character of the literal/identifier that is currently open, if any.
    let mut open_quote: Option<char> = None;
    let mut chars = content.chars().peekable();

    while let Some(c) = chars.next() {
        if let Some(quote) = open_quote {
            current.push(c);
            // An escaped quote (`''`) closes and immediately reopens the
            // literal, so the escaped character stays inside the string.
            if c == quote {
                open_quote = None;
            }
            continue;
        }
        let next = chars.peek().copied();
        match (c, next) {
            ('\'' | '"' | '`', _) => {
                open_quote = Some(c);
                current.push(c);
            }
            ('-', Some('-')) => {
                // Line comment: skip through the newline, then re-add the
                // newline so the surrounding tokens stay separated.
                if chars.by_ref().any(|n| n == '\n') {
                    current.push('\n');
                }
            }
            ('/', Some('*')) => {
                chars.next(); // consume the '*'
                let mut prev = '\0';
                for n in chars.by_ref() {
                    if prev == '*' && n == '/' {
                        break;
                    }
                    prev = n;
                }
                // A block comment behaves like whitespace between tokens.
                current.push(' ');
            }
            (';', _) => statements.push(std::mem::take(&mut current)),
            _ => current.push(c),
        }
    }
    statements.push(current);
    statements
}

/// Returns `true` when `stmt` starts with the keywords `CREATE TABLE`, in any
/// case and with any whitespace between the two words.
fn is_create_table(stmt: &str) -> bool {
    let mut words = stmt.split_whitespace();
    matches!(
        (words.next(), words.next()),
        (Some(create), Some(table))
            if create.eq_ignore_ascii_case("CREATE") && table.eq_ignore_ascii_case("TABLE")
    )
}
use std::env;
/// Writes `content` to `$OUT_DIR/<GENERATED_FILE>`, preceded by a header
/// comment that names the schema source URL.
///
/// # Panics
///
/// Panics if `OUT_DIR` is not set, or if the file cannot be created, written
/// or flushed.
fn write_rust_file(content: &str) {
    let out_dir = env::var("OUT_DIR").expect("OUT_DIR environment variable is not set");
    let dest_path = Path::new(&out_dir).join(GENERATED_FILE);
    let file = File::create(&dest_path).expect("Unable to create file");
    let mut writer = BufWriter::new(file);
    let header = format!(
        r#"// This file is auto-generated by build.rs. Do not edit directly.
// It is derived from the NAHPU Drift schema (CREATE TABLE statements only).
// Source: {}
// Regenerate by running `cargo build`."#,
        NAHPU_TABLES_URL
    );
    writeln!(writer, "{}\n\n{}", header, content).expect("Unable to write data to file");
    // `BufWriter` flushes on drop but silently discards any error there, which
    // could leave a truncated generated file; flush explicitly so failures surface.
    writer.flush().expect("Unable to flush generated file");
    println!("Generated Rust code written to {:?}", dest_path);
}
/// Generates Rust source with one serde-serializable struct per
/// `CREATE TABLE` statement in `drift_content`.
///
/// - The output starts with a `use serde::{Deserialize, Serialize};` line.
/// - Each struct derives `Serialize, Deserialize, Debug, Clone`.
/// - Each struct carries `#[serde(rename_all = "camelCase")]`.
/// - Struct names are the PascalCase form of the table name.
/// - Field names are the snake_case form of each column name.
/// - A column becomes `Option<_>` unless it has a `NOT NULL` constraint.
/// - A field name that is a Rust keyword (e.g. `type`) is emitted with a
///   trailing `_` plus `#[serde(rename = "…")]`, so the generated file still
///   compiles and keeps the original serialized name.
///
/// # Panics
///
/// Panics if `drift_content` cannot be parsed with the generic SQL dialect.
/// Also panics if a table name's last part is missing or is not a plain
/// identifier.
fn create_rust_code(drift_content: &str) -> String {
    // Rust strict and reserved keywords (2018+) that cannot be used as plain
    // field names; all are lowercase because field names are snake_case.
    const RUST_KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "crate", "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "if",
        "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "self", "static", "struct", "super", "trait", "true",
        "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while",
        "yield",
    ];

    let dialect = GenericDialect {};
    let ast =
        Parser::parse_sql(&dialect, drift_content).expect("Failed to parse drift file as SQL");

    let mut rust_code = String::new();
    rust_code.push_str("use serde::{Deserialize, Serialize};\n\n");
    for statement in ast {
        let Statement::CreateTable(create_table) = statement else {
            continue;
        };
        let struct_name = match create_table.name.0.last() {
            Some(sqlparser::ast::ObjectNamePart::Identifier(ident)) => {
                ident.value.to_pascal_case()
            }
            _ => panic!("Expected identifier"),
        };
        rust_code.push_str("#[derive(Serialize, Deserialize, Debug, Clone)]\n");
        rust_code.push_str("#[serde(rename_all = \"camelCase\")]\n");
        rust_code.push_str(&format!("pub struct {} {{\n", struct_name));
        for col in &create_table.columns {
            let field_name = col.name.value.to_snake_case();
            let is_nullable = !col
                .options
                .iter()
                .any(|opt| matches!(opt.option, ColumnOption::NotNull));
            let rust_type = sql_type_to_rust(&col.data_type, is_nullable);
            if RUST_KEYWORDS.contains(&field_name.as_str()) {
                // Keywords cannot be field identifiers: suffix with `_` and keep
                // the original name on the wire via an explicit serde rename.
                rust_code.push_str(&format!(" #[serde(rename = \"{}\")]\n", field_name));
                rust_code.push_str(&format!(" pub {}_: {},\n", field_name, rust_type));
            } else {
                rust_code.push_str(&format!(" pub {}: {},\n", field_name, rust_type));
            }
        }
        rust_code.push_str("}\n\n");
    }
    rust_code
}
fn main() {
let drift_file_url =
"https://raw.githubusercontent.com/nahpu/nahpu/main/lib/services/database/tables.drift";
println!("cargo:rerun-if-changed=build.rs"); println!("Fetching drift file from: {}", drift_file_url);
let drift_content = match get(drift_file_url) {
Ok(response) => response.text().expect("Failed to read response text"),
Err(e) => {
println!(
"cargo:warning=Failed to fetch drift file: {}. Skipping generation.",
e
);
return;
}
};
let cleaned_content = clean_drift_content(&drift_content);
let rust_code = create_rust_code(&cleaned_content);
write_rust_file(&rust_code);
}