use std::collections::HashSet;
use std::future::Future;
use std::pin::Pin;
use axum::Router;
use axum::response::Response;
use sea_orm::sea_query::{Alias, ColumnDef, Query, Table};
use sea_orm::{ConnectionTrait, DatabaseConnection, DbErr};
use sea_orm_migration::{MigrationTrait, SchemaManager};
use utoipa::openapi::OpenApi;
use utoipa_axum::router::OpenApiRouter;
use esylla_error::EsyllaError;
pub trait Module<S: Clone + Send + Sync + 'static>: Send + Sync {
fn name(&self) -> &'static str;
fn routes(&self) -> OpenApiRouter<S> {
OpenApiRouter::new()
}
fn migrations(&self) -> Vec<Box<dyn MigrationTrait>> {
Vec::new()
}
fn on_init(&self, _state: &S) -> impl Future<Output = anyhow::Result<()>> + Send {
async { Ok(()) }
}
}
/// Name of the tracking table that `run_migrations` creates (if missing) and
/// uses to record the `name` of every migration it has applied.
const MIGRATION_TABLE: &str = "esylla_migrations";
#[tracing::instrument(skip_all, fields(pending = migrations.len()))]
pub async fn run_migrations(
db: &DatabaseConnection,
migrations: &[Box<dyn MigrationTrait>],
) -> Result<(), DbErr> {
let create_tracking = Table::create()
.table(Alias::new(MIGRATION_TABLE))
.if_not_exists()
.col(
ColumnDef::new(Alias::new("name"))
.string()
.not_null()
.primary_key(),
)
.to_owned();
db.execute(&create_tracking).await?;
let select_applied = Query::select()
.column(Alias::new("name"))
.from(Alias::new(MIGRATION_TABLE))
.to_owned();
let applied: HashSet<String> = db
.query_all(&select_applied)
.await?
.iter()
.filter_map(|row| row.try_get::<String>("", "name").ok())
.collect();
let manager = SchemaManager::new(db);
let mut applied_now = 0u32;
for migration in migrations {
let name = migration.name().to_string();
if applied.contains(&name) {
continue;
}
tracing::info!(migration = %name, "applying migration");
migration.up(&manager).await?;
let record = Query::insert()
.into_table(Alias::new(MIGRATION_TABLE))
.columns([Alias::new("name")])
.values_panic([name.into()])
.to_owned();
db.execute(&record).await?;
applied_now += 1;
}
tracing::info!(applied = applied_now, "migrations up to date");
Ok(())
}
/// A type-erased, one-shot module initializer.
///
/// It consumes an owned copy of the application state and hands back a boxed,
/// pinned, `Send` future that yields the initializer's result.
type InitFn<S> = Box<
    dyn Send + FnOnce(S) -> Pin<Box<dyn Send + Future<Output = anyhow::Result<()>>>>,
>;
/// URL paths where the generated API documentation is mounted.
struct DocsPaths {
    /// Path of the interactive docs UI (only used when the `swagger-ui`
    /// feature is enabled).
    ui: String,
    /// Path at which the OpenAPI JSON document is served.
    spec: String,
}

impl Default for DocsPaths {
    /// `/docs` for the UI and `/openapi.json` for the spec document.
    fn default() -> Self {
        let ui = String::from("/docs");
        let spec = String::from("/openapi.json");
        Self { ui, spec }
    }
}
/// Application builder that aggregates [`Module`]s into one router, one
/// migration list, and one queue of start-up hooks sharing the state `S`.
pub struct Esylla<S> {
    /// Application state; cloned into each init hook and attached to the final
    /// router via `with_state`.
    state: S,
    /// Routes (with OpenAPI metadata) merged from every registered module.
    router: OpenApiRouter<S>,
    /// Migrations from every registered module, in registration order.
    migrations: Vec<Box<dyn MigrationTrait>>,
    /// Pending module init hooks, paired with the module name for logging;
    /// drained by `init`.
    inits: Vec<(&'static str, InitFn<S>)>,
    /// Whether `into_router` mounts the API documentation routes.
    docs: bool,
    /// Where the documentation routes are mounted.
    docs_paths: DocsPaths,
}
impl<S> Esylla<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Creates a builder around `state` with no modules registered.
    ///
    /// API documentation is enabled by default, with the UI at `/docs` and the
    /// spec at `/openapi.json`.
    #[must_use]
    pub fn new(state: S) -> Self {
        Esylla {
            state,
            router: OpenApiRouter::new(),
            migrations: Vec::new(),
            inits: Vec::new(),
            docs: true,
            docs_paths: DocsPaths::default(),
        }
    }

    /// Registers `module`.
    ///
    /// Its routes are merged into the router, and its migrations are appended
    /// after those already collected. Its `on_init` hook is queued; queued hooks
    /// run in registration order when [`Esylla::init`] is called.
    #[must_use = "builder methods return the updated builder"]
    pub fn module<M: Module<S> + 'static>(mut self, module: M) -> Self {
        self.router = self.router.merge(module.routes());
        self.migrations.extend(module.migrations());
        let name = module.name();
        self.inits.push((
            name,
            // The module is moved into the hook so it outlives the builder call.
            Box::new(move |state: S| Box::pin(async move { module.on_init(&state).await })),
        ));
        self
    }

    /// Enables or disables mounting of the API documentation routes in
    /// [`Esylla::into_router`].
    #[must_use = "builder methods return the updated builder"]
    pub fn docs(mut self, enabled: bool) -> Self {
        self.docs = enabled;
        self
    }

    /// Overrides the documentation paths.
    ///
    /// `ui` is the docs UI path and `spec` is the OpenAPI JSON path. Without the
    /// `swagger-ui` feature, only `spec` is mounted.
    #[must_use = "builder methods return the updated builder"]
    pub fn docs_paths(mut self, ui: impl Into<String>, spec: impl Into<String>) -> Self {
        self.docs_paths = DocsPaths {
            ui: ui.into(),
            spec: spec.into(),
        };
        self
    }

    /// Wraps the routes registered *so far* in an HTTP `TraceLayer`.
    ///
    /// The layer is applied to the router at call time. Routes added by later
    /// [`Esylla::module`] calls are not traced, and neither are the docs routes
    /// mounted by [`Esylla::into_router`]. Call this after registering every
    /// module you want traced.
    #[must_use = "builder methods return the updated builder"]
    pub fn trace_requests(mut self) -> Self {
        self.router = self
            .router
            .layer(tower_http::trace::TraceLayer::new_for_http());
        self
    }

    /// Installs `f` as the error renderer via `esylla_error::set_error_renderer`.
    ///
    /// NOTE(review): this delegates to a free function rather than storing `f`
    /// on the builder, so the setting is presumably global and not scoped to
    /// this instance — confirm in `esylla_error`.
    #[must_use = "builder methods return the updated builder"]
    pub fn on_error<F>(self, f: F) -> Self
    where
        F: Fn(&dyn EsyllaError) -> Response + Send + Sync + 'static,
    {
        esylla_error::set_error_renderer(f);
        self
    }

    /// Migrations collected from all registered modules, in registration order.
    pub fn collected_migrations(&self) -> &[Box<dyn MigrationTrait>] {
        &self.migrations
    }

    /// Returns a copy of the OpenAPI document for the routes registered so far.
    pub fn openapi(&self) -> OpenApi {
        self.router.get_openapi().clone()
    }

    /// Applies the collected migrations to `db`; see `run_migrations`.
    pub async fn migrate(&self, db: &DatabaseConnection) -> Result<(), DbErr> {
        run_migrations(db, &self.migrations).await
    }

    /// Runs every queued module init hook, in registration order, with a clone
    /// of the state.
    ///
    /// The queue is drained on entry. Calling this again does nothing, and if a
    /// hook fails, the hooks after it are discarded without running.
    ///
    /// # Errors
    ///
    /// Returns the first hook error, with context naming the failing module.
    pub async fn init(&mut self) -> anyhow::Result<()> {
        for (name, init) in std::mem::take(&mut self.inits) {
            tracing::debug!(module = name, "module init");
            init(self.state.clone())
                .await
                .map_err(|err| err.context(format!("module `{name}` failed to initialize")))?;
        }
        Ok(())
    }

    /// Finalizes the builder into an axum [`Router`] with the state attached.
    ///
    /// The docs routes are mounted if enabled. This does not run the init
    /// hooks; call [`Esylla::init`] first, because any hooks still queued are
    /// dropped here.
    #[must_use]
    pub fn into_router(self) -> Router {
        let (router, api) = self.router.split_for_parts();
        let router = if self.docs {
            mount_docs(router, api, &self.docs_paths)
        } else {
            router
        };
        router.with_state(self.state)
    }
}
#[cfg(feature = "swagger-ui")]
/// Mounts the Swagger UI at `paths.ui`, backed by `api` served at `paths.spec`.
fn mount_docs<S>(router: Router<S>, api: OpenApi, paths: &DocsPaths) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    let swagger = utoipa_swagger_ui::SwaggerUi::new(paths.ui.clone());
    let docs = swagger.url(paths.spec.clone(), api);
    router.merge(docs)
}
#[cfg(not(feature = "swagger-ui"))]
/// Serves `api` as JSON on `GET paths.spec`.
///
/// Without the `swagger-ui` feature there is no UI, so `paths.ui` is ignored.
fn mount_docs<S>(router: Router<S>, api: OpenApi, paths: &DocsPaths) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    // The handler owns the document and hands each request its own copy.
    let serve_spec = move || {
        let doc = api.clone();
        async move { axum::Json(doc) }
    };
    router.route(&paths.spec, axum::routing::get(serve_spec))
}