use tokio_postgres::{NoTls, Error};
use convert_case::{Case, Casing};
use chrono::NaiveDate;
use std::collections::HashMap;
/// Fetches the names of every relation that `information_schema.tables`
/// reports for the `public` schema.
///
/// # Errors
/// Returns the `tokio_postgres` error if the query fails.
///
/// # Panics
/// Panics if the first result column cannot be read as a `String`
/// (behavior of `Row::get`).
async fn get_tables(client: &tokio_postgres::Client) -> Result<Vec<String>, Error> {
    const TABLES_QUERY: &str =
        "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'";

    let result_rows = client.query(TABLES_QUERY, &[]).await?;

    let mut table_names: Vec<String> = Vec::with_capacity(result_rows.len());
    for record in &result_rows {
        table_names.push(record.get(0));
    }
    Ok(table_names)
}
/// Looks up the columns of `table_name` in the `public` schema.
///
/// Returns a map from column name to the `data_type` string reported by
/// `information_schema.columns`.
///
/// The lookup is restricted to `table_schema = 'public'` so that it agrees
/// with `get_tables`; without that filter, a table with the same name in
/// another schema would contribute its columns to the result as well.
///
/// # Errors
/// Returns the `tokio_postgres` error if the query fails.
///
/// # Panics
/// Panics if either result column cannot be read as a `String`
/// (behavior of `Row::get`).
async fn get_columns(client: &tokio_postgres::Client, table_name: &str) -> Result<HashMap<String, String>, Error> {
    // A plain literal: the statement has no interpolated parts, the table
    // name is passed as the bound parameter `$1`.
    let query = "SELECT column_name, data_type FROM information_schema.columns \
                 WHERE table_schema = 'public' AND table_name = $1";
    let rows = client.query(query, &[&table_name]).await?;
    Ok(rows.iter().map(|row| (row.get(0), row.get(1))).collect())
}
/// Connects to the database at `database_url` (without TLS) and prints one
/// generated struct definition per `public` table to standard output.
///
/// The connection object is driven on a spawned Tokio task; if it terminates
/// with an error, that error is written to standard error.
///
/// # Errors
/// Returns the `tokio_postgres` error if connecting fails or if listing the
/// tables or columns fails.
pub async fn generate_structs(database_url: &str) -> Result<(), Error> {
    let (client, connection) = tokio_postgres::connect(database_url, NoTls).await?;

    tokio::spawn(async move {
        match connection.await {
            Ok(()) => {}
            Err(e) => eprintln!("connection error: {}", e),
        }
    });

    // Fixed banner metadata stamped onto every generated struct.
    let author = "Tom Blanchard";
    let github_link = "https://github.com/tomblanchard312/rust_orm_gen";
    let generated_on = NaiveDate::from_ymd_opt(2024, 7, 24).unwrap();

    for table in get_tables(&client).await? {
        let table_columns = get_columns(&client, &table).await?;
        let rendered = generate_struct(&table, table_columns, author, github_link, generated_on);
        println!("{}", rendered);
    }

    Ok(())
}
/// Renders the Rust source text of a struct for one database table.
///
/// The output starts with a block-comment banner containing `github_link`,
/// `date` (formatted as `%Y-%m-%d`) and `author`, followed by
/// `#[derive(Debug, Serialize, Deserialize)]` and a `pub struct` named after
/// `table_name` converted to PascalCase.
///
/// Fields are emitted in ascending column-name order, so the result does not
/// depend on `HashMap` iteration order. Each field carries
/// `#[serde(rename = "<column name>")]` with the original column name, while
/// the Rust field name is derived from it (spaces become underscores and
/// Rust keywords are escaped so the generated code compiles).
///
/// # Arguments
/// * `table_name` - table name; used for the struct name.
/// * `columns` - map from column name to SQL data type name.
/// * `author`, `github_link`, `date` - values written into the banner.
///
/// # Returns
/// The generated source as a `String`.
pub fn generate_struct(table_name: &str, columns: HashMap<String, String>, author: &str, github_link: &str, date: NaiveDate) -> String {
    let header = format!(
        "/*\n * This code was generated by rust_orm_gen.\n * GitHub: {}\n * Date: {}\n * Author: {}\n */\n\n",
        github_link, date.format("%Y-%m-%d"), author
    );
    let struct_name = table_name.to_case(Case::Pascal);
    let mut struct_def = format!("{}#[derive(Debug, Serialize, Deserialize)]\npub struct {} {{\n", header, struct_name);

    // Column names are unique map keys, so an unstable sort yields the same
    // order as a stable one.
    let mut sorted_columns: Vec<_> = columns.into_iter().collect();
    sorted_columns.sort_unstable_by(|a, b| a.0.cmp(&b.0));

    for (col_name, data_type) in sorted_columns {
        let rust_field_name = rust_field_ident(&col_name);
        let rust_type = map_data_type(&data_type);
        struct_def.push_str(&format!(
            "    #[serde(rename = \"{}\")] pub {}: {},\n",
            col_name, rust_field_name, rust_type
        ));
    }
    struct_def.push_str("}\n");
    struct_def
}

/// Turns a column name into a Rust field identifier: spaces become
/// underscores and names that collide with Rust keywords are escaped.
fn rust_field_ident(col_name: &str) -> String {
    let ident = col_name.replace(' ', "_");
    match ident.as_str() {
        // These four cannot be written as raw identifiers, so they get a
        // trailing underscore instead.
        "self" | "Self" | "super" | "crate" => format!("{}_", ident),
        // Strict and reserved keywords: emit a raw identifier.
        "as" | "async" | "await" | "break" | "const" | "continue" | "dyn" | "else" | "enum"
        | "extern" | "false" | "fn" | "for" | "if" | "impl" | "in" | "let" | "loop" | "match"
        | "mod" | "move" | "mut" | "pub" | "ref" | "return" | "static" | "struct" | "trait"
        | "true" | "type" | "unsafe" | "use" | "where" | "while" | "abstract" | "become"
        | "box" | "do" | "final" | "gen" | "macro" | "override" | "priv" | "try" | "typeof"
        | "unsized" | "virtual" | "yield" => format!("r#{}", ident),
        _ => ident,
    }
}
/// Maps a PostgreSQL data type name to the Rust type path written into the
/// generated struct.
///
/// Both the short aliases (`varchar`, `timestamptz`, `float8`, ...) and the
/// long-form names that `information_schema.columns.data_type` reports
/// (`character varying`, `timestamp with time zone`, `double precision`, ...)
/// are recognized; each long form maps to the same Rust type as its alias.
///
/// Any name that is not listed falls back to `"String"`.
fn map_data_type(data_type: &str) -> &'static str {
    match data_type {
        "integer" | "serial" => "i32",
        "bigint" | "bigserial" => "i64",
        "smallint" => "i16",
        "boolean" => "bool",
        "text" | "varchar" | "char" | "character varying" | "character" => "String",
        "date" => "chrono::NaiveDate",
        "timestamp" | "timestamp without time zone" => "chrono::NaiveDateTime",
        "timestamptz" | "timestamp with time zone" | "timetz" | "time with time zone" => {
            "chrono::DateTime<chrono::Utc>"
        }
        "time" | "time without time zone" => "chrono::NaiveTime",
        "float4" | "real" => "f32",
        "float8" | "double precision" => "f64",
        "numeric" => "bigdecimal::BigDecimal",
        "uuid" => "uuid::Uuid",
        "json" | "jsonb" => "serde_json::Value",
        "bytea" => "Vec<u8>",
        // Unknown or unsupported types are represented as text.
        _ => "String",
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use std::collections::HashMap;

    /// Generates a struct for a small `users` table and checks that each
    /// column shows up as a field with the expected Rust type, including a
    /// column whose name contains a space.
    #[test]
    fn test_generate_struct() {
        let fixture = [("id", "integer"), ("name", "text"), ("zip code", "text")];
        let columns: HashMap<String, String> = fixture
            .iter()
            .map(|&(column, sql_type)| (column.to_owned(), sql_type.to_owned()))
            .collect();

        let generated = generate_struct(
            "users",
            columns,
            "Tom Blanchard",
            "https://github.com/tomblanchard312/rust_orm_gen",
            NaiveDate::from_ymd_opt(2024, 7, 24).unwrap(),
        );

        assert!(
            generated.contains("pub id: i32,"),
            "Type conversion for 'id' is incorrect or missing"
        );
        assert!(
            generated.contains("pub name: String,"),
            "Type conversion for 'name' is incorrect or missing"
        );
        assert!(
            generated.contains("pub zip_code: String,"),
            "Type conversion for 'zip code' is incorrect or missing"
        );
    }
}