use std::path::{Path, PathBuf};
use sqlx::{PgPool, Row, SqlitePool};
use umbral_casing::{pascal_case_from_table, to_snake_case};
use crate::migrate::{self, Column, MigrationFile, ModelMeta, Operation, Snapshot};
use crate::orm::SqlType;
/// Plugin name the generated migration is filed under (both in the
/// `MigrationFile` and in the on-disk `migrations/<plugin>/` directory).
/// It is an alias of `migrate::APP_PLUGIN_NAME`.
pub const INSPECTED_PLUGIN_NAME: &str = migrate::APP_PLUGIN_NAME;
/// Id assigned to the migration produced by `inspectdb`; it is also used as
/// the stem of the JSON file the migration is written to.
pub const INITIAL_MIGRATION_ID: &str = "0001_initial";
/// Everything `inspectdb` learned about a database: the list of tables it
/// found, each with its columns.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntrospectedSchema {
    /// One entry per introspected table.
    pub tables: Vec<IntrospectedTable>,
}
/// A single table discovered in the database.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntrospectedTable {
    /// Table name exactly as stored in the database.
    pub table: String,
    /// Name used for the generated Rust struct.
    pub name: String,
    /// Columns of the table, in the order introspection produced them.
    pub columns: Vec<IntrospectedColumn>,
}
/// A single column of an [`IntrospectedTable`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntrospectedColumn {
    /// Column name as stored in the database.
    pub name: String,
    /// Framework type the database's declared type was mapped to.
    pub ty: SqlType,
    /// Whether the column is part of the table's primary key.
    pub primary_key: bool,
    /// Whether the column accepts NULL. The introspection functions in this
    /// module always report primary-key columns as non-nullable.
    pub nullable: bool,
}
/// Failures that `inspectdb` and its helpers can report.
#[derive(Debug)]
pub enum InspectError {
    /// Filesystem error while writing the generated files.
    Io(std::io::Error),
    /// JSON serialisation error.
    Json(serde_json::Error),
    /// Error returned by sqlx while querying the database.
    Sqlx(sqlx::Error),
    /// Introspection found no tables, so there is nothing to generate.
    NoTables,
    /// A column's declared SQL type could not be mapped to a `SqlType`.
    UnsupportedColumnType {
        /// Table that owns the offending column.
        table: String,
        /// Name of the offending column.
        column: String,
        /// The SQL type text that could not be mapped.
        sql_type: String,
    },
    /// Error returned by the migration subsystem.
    Migrate(migrate::MigrateError),
}
impl std::fmt::Display for InspectError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
InspectError::Io(e) => write!(f, "umbral inspectdb: io: {e}"),
InspectError::Json(e) => write!(f, "umbral inspectdb: json: {e}"),
InspectError::Sqlx(e) => write!(f, "umbral inspectdb: sqlx: {e}"),
InspectError::NoTables => write!(
f,
"umbral inspectdb: no tables found in the database (nothing to import)"
),
InspectError::UnsupportedColumnType {
table,
column,
sql_type,
} => write!(
f,
"umbral inspectdb: column `{table}.{column}` has unsupported SQL type `{sql_type}`; \
add a matching SqlType variant or edit the generated model by hand"
),
InspectError::Migrate(e) => write!(f, "umbral inspectdb: migrate: {e}"),
}
}
}
// Only the trait's default methods are used: `source()` is not overridden,
// so wrapped errors are visible solely through the `Display` output above.
impl std::error::Error for InspectError {}
impl From<std::io::Error> for InspectError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<sqlx::Error> for InspectError {
fn from(e: sqlx::Error) -> Self {
Self::Sqlx(e)
}
}
impl From<serde_json::Error> for InspectError {
fn from(e: serde_json::Error) -> Self {
Self::Json(e)
}
}
impl From<migrate::MigrateError> for InspectError {
fn from(e: migrate::MigrateError) -> Self {
Self::Migrate(e)
}
}
/// Caller-supplied settings for [`inspectdb`].
#[derive(Debug, Clone)]
pub struct InspectOptions {
    /// Directory that receives `models.rs` and the `migrations/` tree.
    pub output: PathBuf,
    /// When `true`, the generated migration is also recorded as already
    /// applied once the files have been written.
    pub mark_applied: bool,
}
/// Summary of what an `inspectdb` run wrote to disk.
#[derive(Debug, Clone, Default)]
pub struct InspectReport {
    /// Number of tables covered by the generated migration.
    pub tables: usize,
    /// Total number of columns across those tables.
    pub columns: usize,
    /// Path of the generated Rust models file.
    pub models_path: PathBuf,
    /// Path of the generated migration JSON file.
    pub migration_path: PathBuf,
}
pub async fn inspectdb(opts: InspectOptions) -> Result<InspectReport, InspectError> {
let schema = match crate::db::pool_dispatched() {
crate::db::DbPool::Sqlite(pool) => introspect_pool(pool).await?,
crate::db::DbPool::Postgres(pool) => introspect_pool_pg(pool).await?,
};
if schema.tables.is_empty() {
return Err(InspectError::NoTables);
}
let models_src = render_models(&schema);
let migration = render_initial_migration(&schema);
let report = write_outputs(&opts.output, &models_src, &migration).await?;
if opts.mark_applied {
let hash = migration.snapshot_after.hash();
migrate::record_applied(&migration.plugin, &migration.id, &hash).await?;
}
Ok(report)
}
/// Introspects every user table of a SQLite database.
///
/// Tables are read from `sqlite_master` in name order. SQLite's internal
/// `sqlite_*` tables and the `umbral_migrations` bookkeeping table are
/// skipped. Each remaining table is paired with a PascalCase struct name and
/// its columns.
///
/// # Errors
///
/// Propagates sqlx failures and any [`InspectError::UnsupportedColumnType`]
/// raised while mapping a column.
pub async fn introspect_pool(pool: &SqlitePool) -> Result<IntrospectedSchema, InspectError> {
    // `_` is a single-character wildcard in LIKE. Unescaped, the pattern
    // `sqlite_%` would also hide user tables such as `sqlitedata`, so the
    // underscore is escaped to match only the reserved `sqlite_` prefix.
    let table_rows = sqlx::query(
        "SELECT name FROM sqlite_master \
         WHERE type = 'table' \
         AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\' \
         AND name <> 'umbral_migrations' \
         ORDER BY name",
    )
    .fetch_all(pool)
    .await?;

    let mut tables: Vec<IntrospectedTable> = Vec::with_capacity(table_rows.len());
    for row in table_rows {
        let table: String = row.try_get("name")?;
        let columns = introspect_columns(pool, &table).await?;
        tables.push(IntrospectedTable {
            name: pascal_case_from_table(&table),
            table,
            columns,
        });
    }
    Ok(IntrospectedSchema { tables })
}
/// Introspects every base table in the `public` schema of a Postgres
/// database.
///
/// Tables come from `information_schema.tables` in name order, with the
/// `umbral_migrations` bookkeeping table left out. Each table is paired with
/// a PascalCase struct name and its columns.
///
/// # Errors
///
/// Propagates sqlx failures and any [`InspectError::UnsupportedColumnType`]
/// raised while mapping a column.
pub async fn introspect_pool_pg(pool: &PgPool) -> Result<IntrospectedSchema, InspectError> {
    const LIST_TABLES: &str = "SELECT table_name FROM information_schema.tables \
                               WHERE table_schema = 'public' \
                               AND table_type = 'BASE TABLE' \
                               AND table_name <> 'umbral_migrations' \
                               ORDER BY table_name";

    let names = sqlx::query_as::<_, (String,)>(LIST_TABLES)
        .fetch_all(pool)
        .await?;

    let mut schema = IntrospectedSchema {
        tables: Vec::with_capacity(names.len()),
    };
    for (table_name,) in names {
        let columns = introspect_columns_pg(pool, &table_name).await?;
        let struct_name = pascal_case_from_table(&table_name);
        schema.tables.push(IntrospectedTable {
            table: table_name,
            name: struct_name,
            columns,
        });
    }
    Ok(schema)
}
/// Reads the columns of one table in the Postgres `public` schema.
///
/// Two `information_schema` queries are issued: one collects the names of the
/// primary-key columns, the other lists every column in ordinal order with
/// its `data_type`, `is_nullable` and `udt_name`.
///
/// Type resolution:
/// * `data_type = ARRAY`: the element type is taken from `udt_name` with its
///   leading `_` removed.
/// * `data_type = USER-DEFINED`: Postgres reports every non-built-in type
///   (for example the `ltree` extension type) this way and stores the real
///   type name in `udt_name`, so `udt_name` is mapped instead.
/// * anything else: `data_type` is mapped directly.
///
/// Primary-key columns are always reported as non-nullable.
///
/// # Errors
///
/// Returns [`InspectError::UnsupportedColumnType`] for a type with no
/// `SqlType` mapping, and propagates sqlx failures.
async fn introspect_columns_pg(
    pool: &PgPool,
    table: &str,
) -> Result<Vec<IntrospectedColumn>, InspectError> {
    let pk_rows: Vec<(String,)> = sqlx::query_as(
        "SELECT kcu.column_name \
         FROM information_schema.table_constraints tc \
         JOIN information_schema.key_column_usage kcu \
         ON tc.constraint_name = kcu.constraint_name \
         AND tc.table_schema = kcu.table_schema \
         WHERE tc.constraint_type = 'PRIMARY KEY' \
         AND tc.table_schema = 'public' \
         AND tc.table_name = $1",
    )
    .bind(table)
    .fetch_all(pool)
    .await?;
    let pk_columns: std::collections::HashSet<String> =
        pk_rows.into_iter().map(|(c,)| c).collect();

    let column_rows: Vec<(String, String, String, String)> = sqlx::query_as(
        "SELECT column_name, data_type, is_nullable, udt_name \
         FROM information_schema.columns \
         WHERE table_schema = 'public' AND table_name = $1 \
         ORDER BY ordinal_position",
    )
    .bind(table)
    .fetch_all(pool)
    .await?;

    let mut columns: Vec<IntrospectedColumn> = Vec::with_capacity(column_rows.len());
    for (name, data_type, is_nullable, udt_name) in column_rows {
        let ty = if data_type.eq_ignore_ascii_case("ARRAY") {
            // Array types are named after their element type with a `_` prefix.
            let elem_name = udt_name.strip_prefix('_').unwrap_or(udt_name.as_str());
            map_postgres_array_element(elem_name).ok_or_else(|| {
                InspectError::UnsupportedColumnType {
                    table: table.to_string(),
                    column: name.clone(),
                    sql_type: format!("ARRAY of {elem_name}"),
                }
            })?
        } else {
            // For non-built-in types `data_type` is the literal string
            // USER-DEFINED; the actual type name lives in `udt_name`.
            let type_name = if data_type.eq_ignore_ascii_case("USER-DEFINED") {
                udt_name.as_str()
            } else {
                data_type.as_str()
            };
            map_postgres_type(type_name).ok_or_else(|| InspectError::UnsupportedColumnType {
                table: table.to_string(),
                column: name.clone(),
                sql_type: type_name.to_string(),
            })?
        };

        let primary_key = pk_columns.contains(&name);
        // A primary key can never hold NULL, whatever `is_nullable` says.
        let nullable = !primary_key && is_nullable.eq_ignore_ascii_case("YES");
        columns.push(IntrospectedColumn {
            name,
            ty,
            primary_key,
            nullable,
        });
    }
    Ok(columns)
}
/// Maps a Postgres array element type name (the `udt_name` without its
/// leading `_`) to `SqlType::Array`.
///
/// Matching ignores surrounding whitespace and ASCII case. Returns `None`
/// for element types that have no `ArrayElement` counterpart.
fn map_postgres_array_element(elem: &str) -> Option<SqlType> {
    use crate::orm::ArrayElement;

    let udt = elem.trim().to_ascii_lowercase();
    let element = match udt.as_str() {
        "bool" => Some(ArrayElement::Boolean),
        "int2" => Some(ArrayElement::SmallInt),
        "int4" => Some(ArrayElement::Integer),
        "int8" => Some(ArrayElement::BigInt),
        "float4" => Some(ArrayElement::Real),
        "float8" => Some(ArrayElement::Double),
        "uuid" => Some(ArrayElement::Uuid),
        "text" | "varchar" | "bpchar" => Some(ArrayElement::Text),
        _ => None,
    };
    element.map(SqlType::Array)
}
/// Maps a Postgres type name to a `SqlType`.
///
/// Matching ignores surrounding whitespace and ASCII case. Returns `None`
/// for any name without a mapping.
fn map_postgres_type(raw: &str) -> Option<SqlType> {
    let key = raw.trim().to_ascii_lowercase();
    let ty = match key.as_str() {
        "smallint" => SqlType::SmallInt,
        "integer" => SqlType::Integer,
        "bigint" => SqlType::BigInt,
        "real" => SqlType::Real,
        "double precision" => SqlType::Double,
        "boolean" => SqlType::Boolean,
        "text" | "character varying" | "character" => SqlType::Text,
        "date" => SqlType::Date,
        "time without time zone" | "time with time zone" => SqlType::Time,
        "timestamp without time zone" | "timestamp with time zone" => SqlType::Timestamptz,
        "uuid" => SqlType::Uuid,
        "json" | "jsonb" => SqlType::Json,
        "inet" => SqlType::Inet,
        "cidr" => SqlType::Cidr,
        "macaddr" => SqlType::MacAddr,
        "xml" => SqlType::Xml,
        "ltree" => SqlType::Ltree,
        "bit" | "bit varying" | "varbit" => SqlType::Bit,
        "tsvector" => SqlType::FullText,
        "bytea" => SqlType::Bytes,
        _ => return None,
    };
    Some(ty)
}
/// Reads the columns of one SQLite table via `PRAGMA table_info`.
///
/// The table name is embedded in the PRAGMA as a double-quoted identifier,
/// with embedded `"` doubled. Rows are ordered by their `cid`; a row whose
/// `cid` cannot be read sorts as `0`. Primary-key columns (`pk != 0`) are
/// always reported as non-nullable.
///
/// # Errors
///
/// Returns [`InspectError::UnsupportedColumnType`] when a declared type has
/// no mapping, and propagates sqlx failures.
async fn introspect_columns(
    pool: &SqlitePool,
    table: &str,
) -> Result<Vec<IntrospectedColumn>, InspectError> {
    let escaped = table.replace('"', "\"\"");
    let pragma = format!("PRAGMA table_info(\"{escaped}\")");
    let mut info_rows = sqlx::query(&pragma).fetch_all(pool).await?;
    info_rows.sort_by_key(|row| row.try_get::<i64, _>("cid").unwrap_or(0));

    info_rows
        .iter()
        .map(|row| -> Result<IntrospectedColumn, InspectError> {
            let name: String = row.try_get("name")?;
            let declared: String = row.try_get("type")?;
            let notnull: i64 = row.try_get("notnull")?;
            let pk: i64 = row.try_get("pk")?;

            let ty = match map_sqlite_type(&declared) {
                Some(mapped) => mapped,
                None => {
                    return Err(InspectError::UnsupportedColumnType {
                        table: table.to_string(),
                        column: name,
                        sql_type: declared,
                    })
                }
            };

            let primary_key = pk != 0;
            Ok(IntrospectedColumn {
                name,
                ty,
                primary_key,
                nullable: !primary_key && notnull == 0,
            })
        })
        .collect()
}
/// Maps a SQLite declared column type to a `SqlType`.
///
/// Anything from the first `(` onwards (such as a length in `VARCHAR(255)`)
/// is discarded; the rest is trimmed and compared ignoring ASCII case.
/// Returns `None` when the declared type has no mapping.
fn map_sqlite_type(raw: &str) -> Option<SqlType> {
    // `split` always yields at least one piece: the text before the first `(`.
    let base = raw.split('(').next().unwrap_or(raw);
    let key = base.trim().to_ascii_lowercase();
    let ty = match key.as_str() {
        "smallint" | "int2" => SqlType::SmallInt,
        "int" | "integer" | "int4" => SqlType::Integer,
        "bigint" | "int8" => SqlType::BigInt,
        "real" | "float" | "float4" => SqlType::Real,
        "double" | "double precision" | "float8" => SqlType::Double,
        "boolean" | "bool" => SqlType::Boolean,
        "text" | "varchar" | "char" | "clob" | "character" | "varying character" | "nchar"
        | "nvarchar" => SqlType::Text,
        "date" => SqlType::Date,
        "time" => SqlType::Time,
        "timestamp" | "timestamptz" | "datetime" => SqlType::Timestamptz,
        "uuid" => SqlType::Uuid,
        "json" | "jsonb" => SqlType::Json,
        "blob" | "bytea" => SqlType::Bytes,
        _ => return None,
    };
    Some(ty)
}
/// Renders the Rust source of the generated models file.
///
/// The output starts with `HEADER`, followed by one struct per table. Structs
/// are ordered by struct name and each is preceded by a blank line. The input
/// schema is left untouched.
pub fn render_models(schema: &IntrospectedSchema) -> String {
    let mut ordered: Vec<&IntrospectedTable> = schema.tables.iter().collect();
    ordered.sort_by(|lhs, rhs| lhs.name.cmp(&rhs.name));

    ordered
        .into_iter()
        .fold(String::from(HEADER), |mut source, table| {
            source.push('\n');
            source.push_str(&render_one_struct(table));
            source
        })
}
/// Preamble emitted at the top of the generated models file: a module-level
/// doc comment warning that the file is regenerated by `inspectdb`, followed
/// by the prelude import. The literal's layout is significant, since every
/// line break inside it ends up in the generated file.
const HEADER: &str = "\
//! Generated by `umbral inspectdb`. Wire each struct into your App
//! builder with `.model::<StructName>()`. Re-run `inspectdb` to
//! regenerate; edits made by hand will be lost.
use umbral::prelude::*;
";
/// Renders the struct definition for one table.
///
/// The `#[umbral(table = "...")]` attribute is emitted only when
/// snake-casing the struct name does not reproduce the database table name.
/// Each column becomes a `pub` field whose type comes from
/// `render_field_type`.
///
/// NOTE(review): table and column names are written out verbatim. A column
/// named after a Rust keyword (e.g. `type`), or a table name containing `"`,
/// yields source that does not compile. Confirm how the `Model` derive treats
/// raw identifiers before escaping them here.
fn render_one_struct(table: &IntrospectedTable) -> String {
    let mut source = String::from("#[derive(Debug, Clone, sqlx::FromRow, Model)]\n");

    let needs_table_attr = to_snake_case(&table.name) != table.table;
    if needs_table_attr {
        source += &format!("#[umbral(table = \"{}\")]\n", table.table);
    }

    source += &format!("pub struct {} {{\n", table.name);
    let fields: String = table
        .columns
        .iter()
        .map(|field| {
            format!(
                " pub {}: {},\n",
                field.name,
                render_field_type(field.ty, field.nullable),
            )
        })
        .collect();
    source += &fields;
    source += "}\n";
    source
}
/// Returns the Rust type, as source text, used for a field of SQL type `ty`.
///
/// Array types render as `Vec<..>` around their (non-nullable) element type.
/// When `nullable` is set the whole type is wrapped in `Option<..>`.
fn render_field_type(ty: SqlType, nullable: bool) -> String {
    // Applies the `Option<..>` wrapper shared by every arm below.
    let wrap = |base: &str| -> String {
        if nullable {
            format!("Option<{base}>")
        } else {
            base.to_string()
        }
    };

    match ty {
        SqlType::Array(elem) => {
            let element = render_field_type(elem.to_sql_type(), false);
            wrap(&format!("Vec<{element}>"))
        }
        SqlType::SmallInt => wrap("i16"),
        SqlType::Integer => wrap("i32"),
        SqlType::BigInt | SqlType::ForeignKey => wrap("i64"),
        SqlType::Real => wrap("f32"),
        SqlType::Double => wrap("f64"),
        SqlType::Boolean => wrap("bool"),
        SqlType::Text | SqlType::Xml | SqlType::Ltree | SqlType::Bit => wrap("String"),
        SqlType::Date => wrap("chrono::NaiveDate"),
        SqlType::Time => wrap("chrono::NaiveTime"),
        SqlType::Timestamptz => wrap("chrono::DateTime<chrono::Utc>"),
        SqlType::Uuid => wrap("uuid::Uuid"),
        SqlType::Json => wrap("serde_json::Value"),
        SqlType::Inet | SqlType::Cidr => wrap("ipnetwork::IpNetwork"),
        SqlType::MacAddr => wrap("mac_address::MacAddress"),
        SqlType::FullText => wrap("umbral::orm::TsVector"),
        SqlType::Bytes => wrap("Vec<u8>"),
        SqlType::Decimal => wrap("rust_decimal::Decimal"),
    }
}
/// Builds the initial migration that describes the introspected schema.
///
/// The migration holds one `CreateTable` operation per table, in schema
/// order, and a snapshot with one `ModelMeta` per table, sorted by model
/// name. Its id is `INITIAL_MIGRATION_ID`, its plugin is
/// `INSPECTED_PLUGIN_NAME`, and it has no dependencies.
pub fn render_initial_migration(schema: &IntrospectedSchema) -> MigrationFile {
    let operations = schema
        .tables
        .iter()
        .map(|source| {
            let columns = source.columns.iter().map(Column::from).collect();
            Operation::CreateTable {
                table: source.table.clone(),
                columns,
                unique_together: Vec::new(),
                indexes: Vec::new(),
            }
        })
        .collect();

    let mut models: Vec<ModelMeta> = Vec::with_capacity(schema.tables.len());
    for source in &schema.tables {
        let fields = source.columns.iter().map(Column::from).collect();
        models.push(ModelMeta {
            name: source.name.clone(),
            table: source.table.clone(),
            fields,
            display: source.name.clone(),
            icon: "database".to_string(),
            database: None,
            singleton: false,
            unique_together: Vec::new(),
            indexes: Vec::new(),
            ordering: Vec::new(),
            m2m_relations: Vec::new(),
            soft_delete: false,
            app_label: "app".to_string(),
        });
    }
    models.sort_by(|lhs, rhs| lhs.name.cmp(&rhs.name));

    MigrationFile {
        id: INITIAL_MIGRATION_ID.to_string(),
        plugin: INSPECTED_PLUGIN_NAME.to_string(),
        depends_on: Vec::new(),
        operations,
        snapshot_after: Snapshot { models },
    }
}
/// Writes the generated models file and migration JSON under `output`.
///
/// Layout produced:
/// * `<output>/models.rs`, containing `models_src`
/// * `<output>/migrations/<INSPECTED_PLUGIN_NAME>/<migration id>.json`,
///   containing the pretty-printed migration
///
/// Missing directories are created and existing files are overwritten. The
/// returned report counts the tables and columns created by the migration's
/// operations; an M2M table counts as one table with two columns.
///
/// # Errors
///
/// Propagates filesystem and JSON serialisation failures.
pub async fn write_outputs(
    output: &Path,
    models_src: &str,
    migration: &MigrationFile,
) -> Result<InspectReport, InspectError> {
    std::fs::create_dir_all(output)?;
    let models_path = output.join("models.rs");
    std::fs::write(&models_path, models_src)?;

    // Build the migration path step by step: first the plugin directory,
    // then the file name inside it.
    let mut migration_path = output.join("migrations");
    migration_path.push(INSPECTED_PLUGIN_NAME);
    std::fs::create_dir_all(&migration_path)?;
    migration_path.push(format!("{}.json", migration.id));
    let json = serde_json::to_string_pretty(migration)?;
    std::fs::write(&migration_path, json)?;

    let mut tables = 0usize;
    let mut columns = 0usize;
    for op in migration.operations.iter() {
        match op {
            Operation::CreateTable { columns: created, .. } => {
                tables += 1;
                columns += created.len();
            }
            Operation::CreateM2MTable { .. } => {
                tables += 1;
                columns += 2;
            }
            Operation::DropTable { .. }
            | Operation::DropM2MTable { .. }
            | Operation::AddColumn { .. }
            | Operation::DropColumn { .. }
            | Operation::AlterColumn { .. }
            | Operation::RenameTable { .. }
            | Operation::RenameColumn { .. }
            | Operation::RunSql { .. } => {}
        }
    }

    Ok(InspectReport {
        tables,
        columns,
        models_path,
        migration_path,
    })
}
impl From<&IntrospectedColumn> for Column {
fn from(c: &IntrospectedColumn) -> Self {
Self {
name: c.name.clone(),
ty: c.ty,
primary_key: c.primary_key,
nullable: c.nullable,
fk_target: None,
noform: false,
db_constraint: true,
noedit: false,
is_string_repr: false,
max_length: 0,
choices: Vec::new(),
choice_labels: Vec::new(),
default: String::new(),
is_multichoice: false,
unique: false,
on_delete: crate::orm::FkAction::NoAction,
on_update: crate::orm::FkAction::NoAction,
index: false,
auto_now_add: false,
auto_now: false,
help: String::new(),
example: String::new(),
widget: None,
supported_backends: Vec::new(),
min: None,
max: None,
text_format: ::core::option::Option::None,
slug_from: ::core::option::Option::None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for a column fixture.
    fn col(name: &str, ty: SqlType, primary_key: bool, nullable: bool) -> IntrospectedColumn {
        IntrospectedColumn {
            name: name.to_owned(),
            ty,
            primary_key,
            nullable,
        }
    }

    /// The bigint primary-key column shared by every fixture table.
    fn id_col() -> IntrospectedColumn {
        col("id", SqlType::BigInt, true, false)
    }

    /// Shorthand constructor for a table fixture.
    fn table(table: &str, name: &str, columns: Vec<IntrospectedColumn>) -> IntrospectedTable {
        IntrospectedTable {
            table: table.to_owned(),
            name: name.to_owned(),
            columns,
        }
    }

    /// Renders the models source for a schema made of `tables`.
    fn render(tables: Vec<IntrospectedTable>) -> String {
        render_models(&IntrospectedSchema { tables })
    }

    #[test]
    fn empty_schema_renders_header_only() {
        let out = render(Vec::new());
        assert_eq!(out, HEADER);
    }

    #[test]
    fn snake_case_table_skips_attribute_when_derive_round_trips() {
        let out = render(vec![table(
            "blog_post",
            "BlogPost",
            vec![id_col(), col("title", SqlType::Text, false, false)],
        )]);
        assert!(!out.contains("#[umbral(table"));
        assert!(out.contains("pub struct BlogPost {"));
        assert!(out.contains("pub id: i64,"));
        assert!(out.contains("pub title: String,"));
    }

    #[test]
    fn lowercase_single_word_table_skips_attribute() {
        let out = render(vec![table("post", "Post", vec![id_col()])]);
        assert!(!out.contains("#[umbral(table"));
        assert!(out.contains("pub struct Post {"));
    }

    #[test]
    fn non_round_tripping_table_name_keeps_attribute() {
        let out = render(vec![table("POSTS", "Posts", vec![id_col()])]);
        assert!(out.contains("#[umbral(table = \"POSTS\")]"));
    }

    #[test]
    fn nullable_column_wraps_in_option() {
        let out = render(vec![table(
            "post",
            "Post",
            vec![
                id_col(),
                col("published_at", SqlType::Timestamptz, false, true),
            ],
        )]);
        assert!(out.contains("pub published_at: Option<chrono::DateTime<chrono::Utc>>,"));
    }

    #[test]
    fn type_catalogue_renders_each_sql_type() {
        // (column name, SQL type, is primary key, expected rendered field)
        let catalogue = [
            ("id", SqlType::BigInt, true, "pub id: i64,"),
            ("small", SqlType::SmallInt, false, "pub small: i16,"),
            ("medium", SqlType::Integer, false, "pub medium: i32,"),
            ("real_v", SqlType::Real, false, "pub real_v: f32,"),
            ("double_v", SqlType::Double, false, "pub double_v: f64,"),
            ("flag", SqlType::Boolean, false, "pub flag: bool,"),
            ("note", SqlType::Text, false, "pub note: String,"),
            ("day", SqlType::Date, false, "pub day: chrono::NaiveDate,"),
            ("clock", SqlType::Time, false, "pub clock: chrono::NaiveTime,"),
            (
                "at",
                SqlType::Timestamptz,
                false,
                "pub at: chrono::DateTime<chrono::Utc>,",
            ),
            ("uid", SqlType::Uuid, false, "pub uid: uuid::Uuid,"),
        ];
        let columns = catalogue
            .iter()
            .map(|&(name, ty, primary_key, _)| col(name, ty, primary_key, false))
            .collect();
        let out = render(vec![table("kitchen_sink", "KitchenSink", columns)]);
        for (_, _, _, expected) in catalogue {
            assert!(out.contains(expected), "missing field render: {expected}");
        }
    }

    #[test]
    fn structs_are_sorted_by_name() {
        let out = render(vec![
            table("zebra", "Zebra", vec![id_col()]),
            table("antelope", "Antelope", vec![id_col()]),
        ]);
        let antelope_at = out.find("struct Antelope").expect("Antelope rendered");
        let zebra_at = out.find("struct Zebra").expect("Zebra rendered");
        assert!(antelope_at < zebra_at);
    }

    #[test]
    fn header_carries_the_regen_warning_and_facade_import() {
        let out = render(Vec::new());
        assert!(out.contains("Generated by `umbral inspectdb`"));
        assert!(out.contains("edits made by hand will be lost"));
        assert!(out.contains("use umbral::prelude::*;"));
    }

    #[test]
    fn map_postgres_type_covers_the_full_catalogue() {
        // (Postgres type name, expected mapping, failure note)
        let catalogue = [
            ("smallint", SqlType::SmallInt, ""),
            ("integer", SqlType::Integer, ""),
            ("bigint", SqlType::BigInt, ""),
            ("real", SqlType::Real, ""),
            ("double precision", SqlType::Double, ""),
            ("boolean", SqlType::Boolean, ""),
            ("text", SqlType::Text, ""),
            ("character varying", SqlType::Text, "VARCHAR maps to Text"),
            ("character", SqlType::Text, "CHAR maps to Text"),
            ("date", SqlType::Date, ""),
            ("time without time zone", SqlType::Time, ""),
            ("time with time zone", SqlType::Time, ""),
            ("timestamp without time zone", SqlType::Timestamptz, ""),
            ("timestamp with time zone", SqlType::Timestamptz, ""),
            ("uuid", SqlType::Uuid, ""),
            ("json", SqlType::Json, ""),
            ("jsonb", SqlType::Json, ""),
            ("inet", SqlType::Inet, ""),
            ("cidr", SqlType::Cidr, ""),
            ("macaddr", SqlType::MacAddr, ""),
            ("bytea", SqlType::Bytes, ""),
        ];
        for (raw, expected, note) in catalogue {
            assert_eq!(map_postgres_type(raw), Some(expected), "{raw} {note}");
        }
    }

    #[test]
    fn map_postgres_type_returns_none_for_postgres_only_types() {
        for raw in ["numeric", "ARRAY"] {
            assert_eq!(map_postgres_type(raw), None);
        }
    }

    #[test]
    fn map_postgres_type_is_case_insensitive_on_input() {
        let cases = [
            ("INTEGER", SqlType::Integer),
            ("Bigint", SqlType::BigInt),
            ("UUID", SqlType::Uuid),
        ];
        for (raw, expected) in cases {
            assert_eq!(map_postgres_type(raw), Some(expected));
        }
    }

    #[test]
    fn map_postgres_type_trims_whitespace() {
        assert_eq!(map_postgres_type(" bigint "), Some(SqlType::BigInt));
    }
}