use std::collections::BTreeMap;
use crate::{
DatabaseError, DatabaseValue,
schema::{ColumnInfo, DataType, ForeignKeyInfo, IndexInfo, TableInfo},
};
use sqlx::{MySqlConnection, Row};
pub async fn mysql_sqlx_table_exists(
conn: &mut sqlx::MySqlConnection,
table_name: &str,
) -> Result<bool, DatabaseError> {
let query = "SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_schema = DATABASE() AND table_name = ?
)";
let row = sqlx::query(query)
.bind(table_name)
.fetch_one(&mut *conn)
.await
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let exists: i64 = row
.try_get(0)
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
Ok(exists != 0)
}
pub async fn mysql_sqlx_list_tables(
conn: &mut MySqlConnection,
) -> Result<Vec<String>, DatabaseError> {
let query = "SELECT CAST(TABLE_NAME AS CHAR) AS TABLE_NAME FROM information_schema.tables WHERE table_schema = DATABASE() ORDER BY TABLE_NAME";
let rows = sqlx::query(query)
.fetch_all(&mut *conn)
.await
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let mut tables = Vec::new();
for row in rows {
let table_name: String = row
.try_get("TABLE_NAME")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
tables.push(table_name);
}
Ok(tables)
}
pub async fn mysql_sqlx_get_table_columns(
conn: &mut sqlx::MySqlConnection,
table_name: &str,
) -> Result<Vec<ColumnInfo>, DatabaseError> {
let query = "SELECT
COLUMN_NAME,
CAST(DATA_TYPE AS CHAR) AS DATA_TYPE,
CAST(COLUMN_TYPE AS CHAR) AS COLUMN_TYPE,
CHARACTER_MAXIMUM_LENGTH,
CAST(IS_NULLABLE AS CHAR) AS IS_NULLABLE,
CAST(COLUMN_DEFAULT AS CHAR) AS COLUMN_DEFAULT,
CAST(COLUMN_KEY AS CHAR) AS COLUMN_KEY,
CAST(EXTRA AS CHAR) AS EXTRA,
ORDINAL_POSITION
FROM information_schema.columns
WHERE table_schema = DATABASE() AND table_name = ?
ORDER BY ORDINAL_POSITION";
let rows = sqlx::query(query)
.bind(table_name)
.fetch_all(&mut *conn)
.await
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let pk_query = "SELECT COLUMN_NAME
FROM information_schema.key_column_usage
WHERE table_schema = DATABASE()
AND table_name = ?
AND constraint_name = 'PRIMARY'";
let pk_rows = sqlx::query(pk_query)
.bind(table_name)
.fetch_all(&mut *conn)
.await
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let primary_key_columns: Vec<String> = pk_rows
.iter()
.map(|row| row.try_get::<String, _>("COLUMN_NAME").unwrap_or_default())
.collect();
let mut columns = Vec::new();
for row in rows {
let column_name: String = row
.try_get("COLUMN_NAME")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let data_type_str: String = row
.try_get("DATA_TYPE")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let column_type_str: String = row
.try_get("COLUMN_TYPE")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let char_max_length: Option<i64> = row.try_get("CHARACTER_MAXIMUM_LENGTH").ok();
let data_type =
mysql_column_type_to_data_type(&column_type_str, &data_type_str, char_max_length);
let is_nullable_str: String = row
.try_get("IS_NULLABLE")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let nullable = is_nullable_str.to_uppercase() == "YES";
let ordinal_position: u32 = row
.try_get::<u32, _>("ORDINAL_POSITION")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let default_value: Option<String> = row.try_get("COLUMN_DEFAULT").ok();
let parsed_default = default_value.as_deref().and_then(parse_mysql_default_value);
let extra: String = row.try_get("EXTRA").unwrap_or_default();
let auto_increment = extra.to_uppercase().contains("AUTO_INCREMENT");
let is_primary_key = primary_key_columns.contains(&column_name);
columns.push(ColumnInfo {
name: column_name,
data_type,
nullable,
is_primary_key,
auto_increment,
default_value: parsed_default,
ordinal_position,
});
}
Ok(columns)
}
pub async fn mysql_sqlx_column_exists(
conn: &mut sqlx::MySqlConnection,
table_name: &str,
column_name: &str,
) -> Result<bool, DatabaseError> {
let query = "SELECT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_schema = DATABASE()
AND table_name = ?
AND column_name = ?
)";
let row = sqlx::query(query)
.bind(table_name)
.bind(column_name)
.fetch_one(&mut *conn)
.await
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let exists: i64 = row
.try_get(0)
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
Ok(exists != 0)
}
pub async fn mysql_sqlx_get_table_info(
conn: &mut sqlx::MySqlConnection,
table_name: &str,
) -> Result<Option<TableInfo>, DatabaseError> {
if !mysql_sqlx_table_exists(conn, table_name).await? {
return Ok(None);
}
let columns = mysql_sqlx_get_table_columns(conn, table_name).await?;
let mut columns_map = BTreeMap::new();
for column in columns {
columns_map.insert(column.name.clone(), column);
}
let index_query = "SELECT INDEX_NAME, NON_UNIQUE, COLUMN_NAME
FROM information_schema.STATISTICS
WHERE table_schema = DATABASE() AND table_name = ?
ORDER BY INDEX_NAME, SEQ_IN_INDEX";
let index_rows = sqlx::query(index_query)
.bind(table_name)
.fetch_all(&mut *conn)
.await
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let mut indexes_map: BTreeMap<String, IndexInfo> = BTreeMap::new();
for row in index_rows {
let index_name: String = row
.try_get("INDEX_NAME")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let non_unique: i64 = row
.try_get("NON_UNIQUE")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let column_name: String = row
.try_get("COLUMN_NAME")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let is_primary = index_name == "PRIMARY";
let unique = non_unique == 0;
if let Some(existing_index) = indexes_map.get_mut(&index_name) {
existing_index.columns.push(column_name);
} else {
indexes_map.insert(
index_name.clone(),
IndexInfo {
name: index_name,
unique,
columns: vec![column_name],
is_primary,
},
);
}
}
let fk_query = "SELECT
CAST(kcu.CONSTRAINT_NAME AS CHAR) AS CONSTRAINT_NAME,
kcu.COLUMN_NAME,
CAST(kcu.REFERENCED_TABLE_NAME AS CHAR) AS REFERENCED_TABLE_NAME,
CAST(kcu.REFERENCED_COLUMN_NAME AS CHAR) AS REFERENCED_COLUMN_NAME,
CAST(rc.UPDATE_RULE AS CHAR) AS UPDATE_RULE,
CAST(rc.DELETE_RULE AS CHAR) AS DELETE_RULE
FROM information_schema.KEY_COLUMN_USAGE kcu
JOIN information_schema.REFERENTIAL_CONSTRAINTS rc
ON kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
AND kcu.CONSTRAINT_SCHEMA = rc.CONSTRAINT_SCHEMA
WHERE kcu.table_schema = DATABASE()
AND kcu.table_name = ?
AND kcu.REFERENCED_TABLE_NAME IS NOT NULL";
let fk_rows = sqlx::query(fk_query)
.bind(table_name)
.fetch_all(&mut *conn)
.await
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let mut foreign_keys_map = BTreeMap::new();
for row in fk_rows {
let constraint_name: String = row
.try_get("CONSTRAINT_NAME")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let column_name: String = row
.try_get("COLUMN_NAME")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let referenced_table: String = row
.try_get("REFERENCED_TABLE_NAME")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let referenced_column: String = row
.try_get("REFERENCED_COLUMN_NAME")
.map_err(|e| DatabaseError::MysqlSqlx(super::mysql::SqlxDatabaseError::from(e)))?;
let update_rule: Option<String> = row.try_get("UPDATE_RULE").ok();
let delete_rule: Option<String> = row.try_get("DELETE_RULE").ok();
foreign_keys_map.insert(
constraint_name.clone(),
ForeignKeyInfo {
name: constraint_name,
column: column_name,
referenced_table,
referenced_column,
on_update: update_rule,
on_delete: delete_rule,
},
);
}
Ok(Some(TableInfo {
name: table_name.to_string(),
columns: columns_map,
indexes: indexes_map,
foreign_keys: foreign_keys_map,
}))
}
/// Maps a bare MySQL `DATA_TYPE` name (matched case-insensitively) to a [`DataType`].
///
/// `char_max_length` sizes `CHAR` / `VARCHAR`. When it is absent or outside `1..=u16::MAX`,
/// the size falls back to 1 for `CHAR` and 255 for `VARCHAR`. `DECIMAL` / `NUMERIC` are
/// always reported as `Decimal(38, 10)`, since precision and scale are not passed in.
/// Unrecognised names become `DataType::Custom` holding the original, non-upper-cased text.
fn mysql_type_to_data_type(mysql_type: &str, char_max_length: Option<i64>) -> DataType {
    // Usable character length, or `fallback` when missing / non-positive / too large for u16.
    let sized = |fallback: u16| -> u16 {
        char_max_length
            .filter(|&len| len > 0)
            .and_then(|len| u16::try_from(len).ok())
            .unwrap_or(fallback)
    };

    let normalized = mysql_type.to_uppercase();
    match normalized.as_str() {
        "BOOLEAN" | "BOOL" => DataType::Bool,
        "TINYINT" | "SMALLINT" => DataType::SmallInt,
        "MEDIUMINT" | "INT" | "INTEGER" => DataType::Int,
        "BIGINT" => DataType::BigInt,
        "FLOAT" => DataType::Real,
        "DOUBLE" | "REAL" => DataType::Double,
        "DECIMAL" | "NUMERIC" => DataType::Decimal(38, 10),
        "CHAR" => DataType::Char(sized(1)),
        "VARCHAR" => DataType::VarChar(sized(255)),
        "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => DataType::Text,
        "DATE" => DataType::Date,
        "TIME" => DataType::Time,
        "DATETIME" => DataType::DateTime,
        "TIMESTAMP" => DataType::Timestamp,
        "TINYBLOB" | "BLOB" | "MEDIUMBLOB" | "LONGBLOB" => DataType::Blob,
        "BINARY" | "VARBINARY" => DataType::Binary(None),
        "JSON" => DataType::Json,
        _ => DataType::Custom(mysql_type.to_string()),
    }
}
/// Maps a full MySQL `COLUMN_TYPE` (e.g. `varchar(64)`) to a [`DataType`].
///
/// Rules, in order:
/// - `tinyint(1)` (any case) is treated as a boolean.
/// - For `varchar(N)` / `char(N)`, the length is taken from the declaration when `N` parses
///   as a `u16`.
/// - In every other case the result comes from [`mysql_type_to_data_type`], applied to the
///   bare `data_type` name and `char_max_length`.
fn mysql_column_type_to_data_type(
    column_type: &str,
    data_type: &str,
    char_max_length: Option<i64>,
) -> DataType {
    // Parses the text from byte offset `start` (just past the opening parenthesis) up to
    // the first ')' in `spec`.
    fn length_in_parens(spec: &str, start: usize) -> Option<u16> {
        let close = spec.find(')')?;
        spec[start..close].parse().ok()
    }

    let upper = column_type.to_uppercase();
    if upper == "TINYINT(1)" {
        return DataType::Bool;
    }

    let from_declaration = if upper.starts_with("VARCHAR(") {
        length_in_parens(column_type, "VARCHAR(".len()).map(DataType::VarChar)
    } else if upper.starts_with("CHAR(") {
        length_in_parens(column_type, "CHAR(".len()).map(DataType::Char)
    } else {
        None
    };

    from_declaration.unwrap_or_else(|| mysql_type_to_data_type(data_type, char_max_length))
}
/// Interprets an `information_schema.columns.COLUMN_DEFAULT` string as a [`DatabaseValue`].
///
/// Forms are tried in this order:
/// - empty, or `NULL` in any case → `None`;
/// - `CURRENT_TIMESTAMP` alone, or `CURRENT_TIMESTAMP` / `NOW` followed by `()` or
///   `(<digits>)`, in any case → [`DatabaseValue::Now`];
/// - text wrapped in a pair of single quotes → [`DatabaseValue::String`] with the outer
///   quotes removed (inner text is not unescaped);
/// - an `i64` literal → [`DatabaseValue::Int64`];
/// - a finite `f64` literal → [`DatabaseValue::Real64`].
///
/// Anything else yields `None`. This includes other expressions and string defaults that
/// are reported without quotes.
fn parse_mysql_default_value(default_str: &str) -> Option<DatabaseValue> {
    // True for `CURRENT_TIMESTAMP`, `CURRENT_TIMESTAMP()`, `CURRENT_TIMESTAMP(6)`, `NOW()`,
    // `NOW(6)`. `upper` must already be upper-cased. The fractional-seconds forms are what
    // `DEFAULT CURRENT_TIMESTAMP(6)` columns report, and MariaDB reports
    // `current_timestamp()`.
    fn is_current_timestamp(upper: &str) -> bool {
        let call_args = if let Some(rest) = upper.strip_prefix("CURRENT_TIMESTAMP") {
            if rest.is_empty() {
                return true;
            }
            rest
        } else if let Some(rest) = upper.strip_prefix("NOW") {
            rest
        } else {
            return false;
        };
        call_args
            .strip_prefix('(')
            .and_then(|args| args.strip_suffix(')'))
            .is_some_and(|fsp| fsp.bytes().all(|b| b.is_ascii_digit()))
    }

    let upper = default_str.to_uppercase();
    if upper.is_empty() || upper == "NULL" {
        return None;
    }
    if is_current_timestamp(&upper) {
        return Some(DatabaseValue::Now);
    }
    // Stripping each quote separately requires two distinct quote characters. A lone `'`
    // therefore no longer reaches the `[1..0]` slice that used to panic.
    if let Some(inner) = default_str
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''))
    {
        return Some(DatabaseValue::String(inner.to_string()));
    }
    if let Ok(int_val) = default_str.parse::<i64>() {
        return Some(DatabaseValue::Int64(int_val));
    }
    // `f64::from_str` also accepts "inf", "NaN", "infinity", etc. MySQL numeric column
    // defaults cannot be non-finite, so such text is treated as a non-numeric default.
    default_str
        .parse::<f64>()
        .ok()
        .filter(|value| value.is_finite())
        .map(DatabaseValue::Real64)
}