use std::collections::HashMap;
use diesel::result::Error::NotFound;
use super::data_structures::*;
use super::table_data::*;
use crate::config::{Filtering, PrintSchema};
use crate::database::InferConnection;
use crate::print_schema::{ColumnSorting, DocConfig};
/// SQL names that must not be used verbatim as Rust identifiers in the
/// generated schema.
///
/// The bulk of the list is Rust keywords (strict, reserved, and historically
/// reserved ones). `async`, `await`, `dyn` and `try` are keywords starting with
/// the 2018 edition and are listed so that columns with those names get escaped
/// as well.
///
/// The trailing entries (`bool`, `table`, `columns`, `is_nullable`) are not
/// keywords; presumably they clash with items emitted by the schema macros —
/// NOTE(review): confirm against the `table!` implementation.
static RESERVED_NAMES: &[&str] = &[
    "abstract",
    "alignof",
    "as",
    "async",
    "await",
    "become",
    "box",
    "break",
    "const",
    "continue",
    "crate",
    "do",
    "dyn",
    "else",
    "enum",
    "extern",
    "false",
    "final",
    "fn",
    "for",
    "if",
    "impl",
    "in",
    "let",
    "loop",
    "macro",
    "match",
    "mod",
    "move",
    "mut",
    "offsetof",
    "override",
    "priv",
    "proc",
    "pub",
    "pure",
    "ref",
    "return",
    "Self",
    "self",
    "sizeof",
    "static",
    "struct",
    "super",
    "trait",
    "true",
    "try",
    "type",
    "typeof",
    "unsafe",
    "unsized",
    "use",
    "virtual",
    "where",
    "while",
    "yield",
    "bool",
    "table",
    "columns",
    "is_nullable",
];
/// Returns `true` when `name` cannot be used as-is for a generated Rust item.
///
/// A name is reserved when it is listed in `RESERVED_NAMES`, or when it starts
/// with a double underscore. Double-underscore names that already end in an
/// underscore are exempt, so appending a single `_` to a reserved name always
/// produces a name this function accepts.
fn is_reserved_name(name: &str) -> bool {
    if RESERVED_NAMES.iter().any(|&reserved| reserved == name) {
        return true;
    }
    let dunder_prefixed = name.starts_with("__");
    let underscore_terminated = name.ends_with('_');
    dunder_prefixed && !underscore_terminated
}
/// Returns `true` if `name` holds at least one character that is neither an
/// ASCII letter or digit nor an underscore.
fn contains_unmappable_chars(name: &str) -> bool {
    name.chars()
        .any(|c| !c.is_ascii_alphanumeric() && c != '_')
}
/// Derives the Rust identifier used for a SQL table or column name.
///
/// The name is processed in two steps:
///
/// 1. If it contains a character other than an ASCII letter, ASCII digit or
///    `_`, every such character is replaced by `_` and each run of
///    underscores is collapsed into a single one. Names without such
///    characters are left untouched (including any `__` they contain).
/// 2. If the result is a reserved name, a trailing `_` is appended.
///
/// Sanitising *before* the reserved-name check matters: a name such as
/// `__a b` is both "reserved" (double-underscore prefix) and unmappable, and
/// checking first would return it with the space still in place. It also
/// ensures a sanitised name that happens to land on a reserved one
/// (`is nullable` -> `is_nullable`) still gets escaped.
pub fn rust_name_for_sql_name(sql_name: &str) -> String {
    let sanitized = if contains_unmappable_chars(sql_name) {
        let mut collapsed = String::with_capacity(sql_name.len());
        for c in sql_name.chars() {
            if c.is_ascii_alphanumeric() {
                collapsed.push(c);
            } else if !collapsed.ends_with('_') {
                // Underscores and unmappable characters both become `_`;
                // skipping when the previous output is already `_` collapses runs.
                collapsed.push('_');
            }
        }
        collapsed
    } else {
        sql_name.to_string()
    };

    if is_reserved_name(&sanitized) {
        format!("{sanitized}_")
    } else {
        sanitized
    }
}
/// Loads the names of all tables visible through `connection`, optionally
/// restricted to `schema_name`.
///
/// SQLite uses its own loader; PostgreSQL and MySQL share the
/// `information_schema` based one. The loaded list is logged at `info` level
/// before being returned.
#[tracing::instrument(skip(connection))]
pub fn load_table_names(
    connection: &mut InferConnection,
    schema_name: Option<&str>,
) -> Result<Vec<TableName>, crate::errors::Error> {
    let loaded = match *connection {
        #[cfg(feature = "sqlite")]
        InferConnection::Sqlite(ref mut sqlite_conn) => {
            super::sqlite::load_table_names(sqlite_conn, schema_name)
        }
        #[cfg(feature = "postgres")]
        InferConnection::Pg(ref mut pg_conn) => {
            super::information_schema::load_table_names(pg_conn, schema_name)
        }
        #[cfg(feature = "mysql")]
        InferConnection::Mysql(ref mut mysql_conn) => {
            super::information_schema::load_table_names(mysql_conn, schema_name)
        }
    };
    let tables = loaded?;
    tracing::info!(?tables, "Loaded tables");
    Ok(tables)
}
/// Drops every table that `table_filter` says should be ignored, keeping the
/// remaining names in their original order.
pub fn filter_table_names(table_names: Vec<TableName>, table_filter: &Filtering) -> Vec<TableName> {
    let mut kept = table_names;
    kept.retain(|table| !table_filter.should_ignore_table(table));
    kept
}
#[tracing::instrument(skip(conn))]
fn get_table_comment(
conn: &mut InferConnection,
table: &TableName,
) -> Result<Option<String>, crate::errors::Error> {
let table_comment = match *conn {
#[cfg(feature = "sqlite")]
InferConnection::Sqlite(_) => Ok(None),
#[cfg(feature = "postgres")]
InferConnection::Pg(ref mut c) => super::pg::get_table_comment(c, table),
#[cfg(feature = "mysql")]
InferConnection::Mysql(ref mut c) => super::mysql::get_table_comment(c, table),
};
if let Err(NotFound) = table_comment {
Err(crate::errors::Error::NoTableFound(table.clone()))
} else {
let table_comment = table_comment?;
tracing::info!(?table_comment, "Load table comments for {table}");
Ok(table_comment)
}
}
fn get_column_information(
conn: &mut InferConnection,
table: &TableName,
column_sorting: &ColumnSorting,
pg_domains_as_custom_types: &[®ex::Regex],
) -> Result<Vec<ColumnInformation>, crate::errors::Error> {
#[cfg(not(feature = "postgres"))]
let _ = pg_domains_as_custom_types;
let column_info = match *conn {
#[cfg(feature = "sqlite")]
InferConnection::Sqlite(ref mut c) => {
super::sqlite::get_table_data(c, table, column_sorting)
}
#[cfg(feature = "postgres")]
InferConnection::Pg(ref mut c) => {
super::pg::get_table_data(c, table, column_sorting, pg_domains_as_custom_types)
}
#[cfg(feature = "mysql")]
InferConnection::Mysql(ref mut c) => super::mysql::get_table_data(c, table, column_sorting),
};
if let Err(NotFound) = column_info {
Err(crate::errors::Error::NoTableFound(table.clone()))
} else {
let column_info = column_info?;
tracing::info!(?column_info, "Load column information for table {table}");
Ok(column_info)
}
}
/// Resolves the [`ColumnType`] for the column described by `attr` by
/// delegating to the implementation matching the connection's backend.
///
/// `table`, `primary_keys`, `foreign_keys` and `config` are only consumed by
/// the SQLite implementation, hence the `unused_variables` allowances for
/// builds where that feature is disabled. The PostgreSQL implementation
/// additionally receives the connection's default schema, which is looked up
/// first (a failure there is propagated). MySQL needs nothing but `attr`.
fn determine_column_type(
    attr: &ColumnInformation,
    conn: &mut InferConnection,
    #[allow(unused_variables)] table: &TableName,
    #[allow(unused_variables)] primary_keys: &[String],
    #[allow(unused_variables)] foreign_keys: &HashMap<String, ForeignKeyConstraint>,
    #[allow(unused_variables)] config: &PrintSchema,
) -> Result<ColumnType, crate::errors::Error> {
    match *conn {
        #[cfg(feature = "sqlite")]
        InferConnection::Sqlite(ref mut sqlite_conn) => super::sqlite::determine_column_type(
            sqlite_conn,
            attr,
            table,
            primary_keys,
            foreign_keys,
            config,
        ),
        #[cfg(feature = "postgres")]
        InferConnection::Pg(ref mut pg_conn) => {
            use crate::infer_schema_internals::information_schema::DefaultSchema;

            let default_schema = diesel::pg::Pg::default_schema(pg_conn)?;
            super::pg::determine_column_type(attr, default_schema)
        }
        #[cfg(feature = "mysql")]
        InferConnection::Mysql(_) => super::mysql::determine_column_type(attr),
    }
}
#[tracing::instrument(skip(conn))]
pub(crate) fn get_primary_keys(
conn: &mut InferConnection,
table: &TableName,
) -> Result<Vec<String>, crate::errors::Error> {
let primary_keys: Vec<String> = match *conn {
#[cfg(feature = "sqlite")]
InferConnection::Sqlite(ref mut c) => super::sqlite::get_primary_keys(c, table),
#[cfg(feature = "postgres")]
InferConnection::Pg(ref mut c) => super::information_schema::get_primary_keys(c, table),
#[cfg(feature = "mysql")]
InferConnection::Mysql(ref mut c) => super::information_schema::get_primary_keys(c, table),
}?;
if primary_keys.is_empty() {
Err(crate::errors::Error::NoPrimaryKeyFound(table.clone()))
} else {
tracing::info!(?primary_keys, "Load primary keys for table {table}");
Ok(primary_keys)
}
}
#[tracing::instrument(skip(connection))]
pub fn load_foreign_key_constraints(
connection: &mut InferConnection,
schema_name: Option<&str>,
) -> Result<Vec<ForeignKeyConstraint>, crate::errors::Error> {
let constraints = match connection {
#[cfg(feature = "sqlite")]
InferConnection::Sqlite(ref mut c) => {
super::sqlite::load_foreign_key_constraints(c, schema_name)
}
#[cfg(feature = "postgres")]
InferConnection::Pg(ref mut c) => {
super::pg::load_foreign_key_constraints(c, schema_name).map_err(Into::into)
}
#[cfg(feature = "mysql")]
InferConnection::Mysql(ref mut c) => {
super::mysql::load_foreign_key_constraints(c, schema_name).map_err(Into::into)
}
};
constraints.map(|mut ct| {
ct.sort();
ct.iter_mut().for_each(|foreign_key_constraint| {
for name in &mut foreign_key_constraint.foreign_key_columns_rust {
if is_reserved_name(name) {
*name = format!("{name}_");
}
}
});
tracing::info!(?ct, "Loaded foreign key constraints");
ct
})
}
#[tracing::instrument(skip(connection))]
pub fn load_table_data(
connection: &mut InferConnection,
name: TableName,
config: &PrintSchema,
) -> Result<TableData, crate::errors::Error> {
let table_comment = match config.with_docs {
DocConfig::NoDocComments => None,
DocConfig::OnlyDatabaseComments
| DocConfig::DatabaseCommentsFallbackToAutoGeneratedDocComment => {
get_table_comment(connection, &name)?
}
};
let primary_key = get_primary_keys(connection, &name)?;
let foreign_keys = load_foreign_key_constraints(connection, name.schema.as_deref())?
.into_iter()
.filter_map(|c| {
if c.child_table == name && c.foreign_key_columns.len() == 1 {
Some((c.foreign_key_columns_rust[0].clone(), c))
} else {
None
}
})
.collect();
let pg_domains_as_custom_types = config
.pg_domains_as_custom_types
.iter()
.map(|regex| regex as ®ex::Regex)
.collect::<Vec<_>>();
let column_data = get_column_information(
connection,
&name,
&config.column_sorting,
&pg_domains_as_custom_types,
)?
.into_iter()
.map(|c| {
let ty = determine_column_type(&c, connection, &name, &primary_key, &foreign_keys, config)?;
let ColumnInformation {
column_name,
comment,
..
} = c;
let rust_name = rust_name_for_sql_name(&column_name);
Ok(ColumnDefinition {
sql_name: column_name,
ty,
rust_name,
comment,
})
})
.collect::<Result<_, crate::errors::Error>>()?;
let primary_key = primary_key
.iter()
.map(|k| rust_name_for_sql_name(k))
.collect::<Vec<_>>();
Ok(TableData {
name,
primary_key,
column_data,
comment: table_comment,
})
}