use crate::{BannerVersion, BigIntegerType, DateTimeCrate, GenerateSubcommands};
use core::time;
use sea_orm_codegen::{
BannerVersion as CodegenBannerVersion, BigIntegerType as CodegenBigIntegerType,
DateTimeCrate as CodegenDateTimeCrate, EntityFormat, EntityTransformer, EntityWriterContext,
MergeReport, OutputFile, WithPrelude, WithSerde, merge_entity_files,
};
use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
use tracing_subscriber::{EnvFilter, prelude::*};
use url::Url;
pub async fn run_generate_command(
command: GenerateSubcommands,
verbose: bool,
) -> Result<(), Box<dyn Error>> {
match command {
GenerateSubcommands::Entity {
entity_format,
compact_format: _,
expanded_format,
frontend_format,
include_hidden_tables,
tables,
ignore_tables,
max_connections,
acquire_timeout,
output_dir,
database_schema,
database_url,
with_prelude,
with_serde,
serde_skip_deserializing_primary_key,
serde_skip_hidden_column,
with_copy_enums,
date_time_crate,
big_integer_type,
lib,
model_extra_derives,
model_extra_attributes,
enum_extra_derives,
enum_extra_attributes,
column_extra_derives,
seaography,
impl_active_model_behavior,
preserve_user_modifications,
banner_version,
er_diagram,
} => {
if verbose {
let _ = tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_test_writer()
.try_init();
} else {
let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
let fmt_layer = tracing_subscriber::fmt::layer()
.with_target(false)
.with_level(false)
.without_time();
let _ = tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.try_init();
}
let url = Url::parse(&database_url)?;
let is_sqlite = url.scheme() == "sqlite";
let filter_tables =
|table: &String| -> bool { tables.is_empty() || tables.contains(table) };
let filter_hidden_tables = |table: &str| -> bool {
if include_hidden_tables {
true
} else {
!table.starts_with('_')
}
};
let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
let _database_name = if !is_sqlite {
let database_name = url
.path_segments()
.unwrap_or_else(|| {
panic!(
"There is no database name as part of the url path: {}",
url.as_str()
)
})
.next()
.unwrap();
if database_name.is_empty() {
panic!(
"There is no database name as part of the url path: {}",
url.as_str()
);
}
database_name
} else {
Default::default()
};
let (schema_name, table_stmts) = match url.scheme() {
"mysql" => {
#[cfg(not(feature = "sqlx-mysql"))]
{
panic!("mysql feature is off")
}
#[cfg(feature = "sqlx-mysql")]
{
use sea_schema::mysql::discovery::SchemaDiscovery;
use sqlx::MySql;
println!("Connecting to MySQL ...");
let connection = sqlx_connect::<MySql>(
max_connections,
acquire_timeout,
url.as_str(),
None,
)
.await?;
println!("Discovering schema ...");
let schema_discovery = SchemaDiscovery::new(connection, _database_name);
let schema = schema_discovery.discover().await?;
let table_stmts = schema
.tables
.into_iter()
.filter(|schema| filter_tables(&schema.info.name))
.filter(|schema| filter_hidden_tables(&schema.info.name))
.filter(|schema| filter_skip_tables(&schema.info.name))
.map(|schema| schema.write())
.collect();
(None, table_stmts)
}
}
"sqlite" => {
#[cfg(not(feature = "sqlx-sqlite"))]
{
panic!("sqlite feature is off")
}
#[cfg(feature = "sqlx-sqlite")]
{
use sea_schema::sqlite::discovery::SchemaDiscovery;
use sqlx::Sqlite;
println!("Connecting to SQLite ...");
let connection = sqlx_connect::<Sqlite>(
max_connections,
acquire_timeout,
url.as_str(),
None,
)
.await?;
println!("Discovering schema ...");
let schema_discovery = SchemaDiscovery::new(connection);
let schema = schema_discovery
.discover()
.await?
.merge_indexes_into_table();
let table_stmts = schema
.tables
.into_iter()
.filter(|schema| filter_tables(&schema.name))
.filter(|schema| filter_hidden_tables(&schema.name))
.filter(|schema| filter_skip_tables(&schema.name))
.map(|schema| schema.write())
.collect();
(None, table_stmts)
}
}
"postgres" | "postgresql" => {
#[cfg(not(feature = "sqlx-postgres"))]
{
panic!("postgres feature is off")
}
#[cfg(feature = "sqlx-postgres")]
{
use sea_schema::postgres::discovery::SchemaDiscovery;
use sqlx::Postgres;
println!("Connecting to Postgres ...");
let schema = database_schema.as_deref().unwrap_or("public");
let connection = sqlx_connect::<Postgres>(
max_connections,
acquire_timeout,
url.as_str(),
Some(schema),
)
.await?;
println!("Discovering schema ...");
let schema_discovery = SchemaDiscovery::new(connection, schema);
let schema = schema_discovery.discover().await?;
let table_stmts = schema
.tables
.into_iter()
.filter(|schema| filter_tables(&schema.info.name))
.filter(|schema| filter_hidden_tables(&schema.info.name))
.filter(|schema| filter_skip_tables(&schema.info.name))
.map(|schema| schema.write())
.collect();
(database_schema, table_stmts)
}
}
_ => unimplemented!("{} is not supported", url.scheme()),
};
println!("... discovered.");
let writer_context = EntityWriterContext::new(
if expanded_format {
EntityFormat::Expanded
} else if frontend_format {
EntityFormat::Frontend
} else if let Some(entity_format) = entity_format {
EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
} else {
EntityFormat::default()
},
WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
with_copy_enums,
date_time_crate.into(),
big_integer_type.into(),
schema_name,
lib,
serde_skip_deserializing_primary_key,
serde_skip_hidden_column,
model_extra_derives,
model_extra_attributes,
enum_extra_derives,
enum_extra_attributes,
column_extra_derives,
seaography,
impl_active_model_behavior,
banner_version.into(),
);
let entity_writer = EntityTransformer::transform(table_stmts)?;
let dir = Path::new(&output_dir);
fs::create_dir_all(dir)?;
if er_diagram {
let diagram = entity_writer.generate_er_diagram();
let diagram_path = dir.join("entities.mermaid");
fs::write(&diagram_path, &diagram)?;
println!("Writing {}", diagram_path.display());
}
let output = entity_writer.generate(&writer_context);
let mut merge_fallback_files: Vec<String> = Vec::new();
for OutputFile { name, content } in output.files.iter() {
let file_path = dir.join(name);
println!("Writing {}", file_path.display());
if !matches!(
name.as_str(),
"mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
) && file_path.exists()
&& preserve_user_modifications
{
let prev_content = fs::read_to_string(&file_path)?;
match merge_entity_files(&prev_content, content) {
Ok(merged) => {
fs::write(file_path, merged)?;
}
Err(MergeReport {
output,
warnings,
fallback_applied,
}) => {
for message in warnings {
eprintln!("{message}");
}
fs::write(file_path, output)?;
if fallback_applied {
merge_fallback_files.push(name.clone());
}
}
}
} else {
fs::write(file_path, content)?;
};
}
for OutputFile { name, .. } in output.files.iter() {
let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; if !exit_status.success() {
return Err(format!("Fail to format file `{name}`").into());
}
}
if merge_fallback_files.is_empty() {
println!("... Done.");
} else {
return Err(format!(
"Merge fallback applied for {} file(s): \n{}",
merge_fallback_files.len(),
merge_fallback_files.join("\n")
)
.into());
}
}
}
Ok(())
}
/// Builds a `sqlx` connection pool for `url`.
///
/// The pool holds at most `max_connections` connections and waits at most
/// `acquire_timeout` seconds to acquire one. When `schema` is `Some`, every new
/// connection runs `SET search_path = '<schema>'` right after connecting. This
/// is used for Postgres. Single quotes in `schema` are doubled, so the name is
/// always sent as one well-formed SQL string literal rather than breaking out
/// of it.
///
/// # Errors
///
/// Returns any error produced by `PoolOptions::connect`.
async fn sqlx_connect<DB>(
    max_connections: u32,
    acquire_timeout: u64,
    url: &str,
    schema: Option<&str>,
) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
where
    DB: sqlx::Database,
    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
{
    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
        .max_connections(max_connections)
        .acquire_timeout(time::Duration::from_secs(acquire_timeout));
    if let Some(schema) = schema {
        // The schema comes from the command line; escape it as an SQL string
        // literal (`'` -> `''`) before interpolating.
        let escaped_schema = schema.replace('\'', "''");
        let sql = format!("SET search_path = '{escaped_schema}'");
        pool_options = pool_options.after_connect(move |conn, _| {
            // The hook may run once per connection, so each run gets its own copy.
            let sql = sql.clone();
            Box::pin(async move {
                sqlx::Executor::execute(conn, sqlx::AssertSqlSafe(sql))
                    .await
                    .map(|_| ())
            })
        });
    }
    pool_options.connect(url).await.map_err(Into::into)
}
/// Maps the CLI's `--date-time-crate` choice onto the codegen enum of the same shape.
impl From<DateTimeCrate> for CodegenDateTimeCrate {
    fn from(choice: DateTimeCrate) -> Self {
        match choice {
            DateTimeCrate::Time => Self::Time,
            DateTimeCrate::Chrono => Self::Chrono,
        }
    }
}
/// Maps the CLI's big-integer type choice onto the codegen enum of the same shape.
impl From<BigIntegerType> for CodegenBigIntegerType {
    fn from(big_integer_type: BigIntegerType) -> Self {
        match big_integer_type {
            BigIntegerType::I64 => Self::I64,
            BigIntegerType::I32 => Self::I32,
        }
    }
}
/// Maps the CLI's banner-version choice onto the codegen enum of the same shape.
impl From<BannerVersion> for CodegenBannerVersion {
    fn from(version: BannerVersion) -> Self {
        match version {
            BannerVersion::Patch => Self::Patch,
            BannerVersion::Minor => Self::Minor,
            BannerVersion::Major => Self::Major,
            BannerVersion::Off => Self::Off,
        }
    }
}
#[cfg(test)]
mod tests {
    use clap::Parser;
    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>` and
    /// runs it to completion. The result is unwrapped so that any error surfaces
    /// as a panic that `#[should_panic(expected = ...)]` can match on.
    fn generate_entity_with_url(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);
        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_with_url("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_with_url("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_with_url("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_with_url("postgres://root:root@/database");
    }
}