use std::sync::Arc;
use std::sync::OnceLock;
use nestrs::prelude::*;
#[cfg(feature = "sqlx")]
use tokio::sync::OnceCell;
/// Schema location used when no explicit `schema_path` has been configured (a relative path).
pub const DEFAULT_SCHEMA_PATH: &str = "prisma/schema.prisma";

/// Builds the `cargo prisma generate` command line for the schema at `schema_path`.
///
/// The path is appended verbatim (no shell quoting), so the result is meant as a human-readable
/// hint rather than something to hand to a shell unchanged.
pub fn prisma_generate_command(schema_path: &str) -> String {
    let mut command = String::from("cargo prisma generate --schema ");
    command.push_str(schema_path);
    command
}
/// Connection and code-generation settings for [`PrismaModule`].
///
/// Create with [`PrismaOptions::from_url`] and refine with the chained setters. No validation
/// is performed here; for example, `pool_min` may end up larger than `pool_max`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrismaOptions {
    /// Database connection URL (handed to sqlx when the `sqlx` feature is enabled).
    pub database_url: String,
    /// Minimum number of pooled connections (defaults to `2`).
    pub pool_min: u32,
    /// Maximum number of pooled connections (defaults to `20`).
    pub pool_max: u32,
    /// Path to the Prisma schema used for `cargo prisma generate` hints.
    pub schema_path: String,
}

impl PrismaOptions {
    /// Creates options for `database_url` with a 2..=20 connection pool and the schema at
    /// [`DEFAULT_SCHEMA_PATH`].
    #[must_use]
    pub fn from_url(database_url: impl Into<String>) -> Self {
        Self {
            database_url: database_url.into(),
            pool_min: 2,
            pool_max: 20,
            schema_path: DEFAULT_SCHEMA_PATH.to_string(),
        }
    }

    /// Sets the minimum pool size and returns the updated options.
    // Consuming builder: dropping the return value discards the change, hence `#[must_use]`.
    #[must_use]
    pub fn pool_min(mut self, value: u32) -> Self {
        self.pool_min = value;
        self
    }

    /// Sets the maximum pool size and returns the updated options.
    #[must_use]
    pub fn pool_max(mut self, value: u32) -> Self {
        self.pool_max = value;
        self
    }

    /// Sets the Prisma schema path and returns the updated options.
    #[must_use]
    pub fn schema_path(mut self, value: impl Into<String>) -> Self {
        self.schema_path = value.into();
        self
    }
}
/// Process-wide options registered by `PrismaModule::for_root` / `for_root_with_options`.
/// `OnceLock::set` only succeeds once, so the first registration wins.
static PRISMA_OPTIONS: OnceLock<PrismaOptions> = OnceLock::new();
/// Process-wide sqlx pool, created lazily on first SQL use from `PRISMA_OPTIONS`.
#[cfg(feature = "sqlx")]
static SQLX_POOL: OnceCell<sqlx::AnyPool> = OnceCell::const_new();
/// Returns the process-wide sqlx pool, creating it on first use.
///
/// Connection settings come exclusively from the options registered via
/// `PrismaModule::for_root` / `for_root_with_options`. Concurrent first callers share a single
/// initialisation; if it fails, the cell stays empty and a later call retries.
///
/// # Errors
///
/// Returns a message when no options have been registered yet, or when sqlx fails to connect.
#[cfg(feature = "sqlx")]
async fn ensure_sqlx_pool() -> Result<&'static sqlx::AnyPool, String> {
    SQLX_POOL
        .get_or_try_init(|| async {
            // The options live in a `'static` OnceLock, so borrowing is enough (no clone).
            let opts = PRISMA_OPTIONS.get().ok_or_else(|| {
                "PrismaModule::for_root / for_root_with_options must be called before SQL connectivity"
                    .to_string()
            })?;
            // `AnyPool` picks the backend from the URL scheme via a global driver registry that
            // is empty until drivers are installed; connecting without this panics in sqlx 0.7+.
            // The call is idempotent (per sqlx docs it only panics if a *different* driver set
            // was installed earlier via `install_drivers`).
            sqlx::any::install_default_drivers();
            // Keep the bounds coherent: a zero-sized pool can never serve a query, and e.g.
            // `from_url(..).pool_max(1)` would otherwise leave the default min (2) above max.
            let max_connections = opts.pool_max.max(1);
            let min_connections = opts.pool_min.min(max_connections);
            sqlx::any::AnyPoolOptions::new()
                .max_connections(max_connections)
                .min_connections(min_connections)
                .connect(&opts.database_url)
                .await
                .map_err(|e| format!("sqlx connect: {e}"))
        })
        .await
}
/// Lightweight description of the database a `PrismaService` targets.
///
/// Holds owned copies of the URL and schema path; it does not hold a connection itself.
#[derive(Debug, Clone)]
pub struct PrismaClientHandle {
    /// Connection URL, copied from the service's `PrismaOptions::database_url`.
pub database_url: String,
    /// Schema path, copied from the service's `PrismaOptions::schema_path`.
pub schema_path: String,
}
/// Injectable service exposing the resolved Prisma options, a client handle and, with the
/// `sqlx` feature, raw SQL helpers backed by a process-wide pool.
pub struct PrismaService {
    /// Options resolved when the service was constructed.
options: PrismaOptions,
    /// Handle mirroring `options.database_url` / `options.schema_path`.
client: PrismaClientHandle,
}
impl PrismaService {
pub fn client(&self) -> &PrismaClientHandle {
&self.client
}
pub fn options(&self) -> &PrismaOptions {
&self.options
}
pub fn health(&self) -> &'static str {
"ok"
}
#[cfg(feature = "sqlx")]
pub async fn query_scalar(&self, sql: &str) -> Result<String, String> {
let pool = ensure_sqlx_pool().await?;
let v: i64 = sqlx::query_scalar(sql)
.fetch_one(pool)
.await
.map_err(|e| format!("sqlx query: {e}"))?;
Ok(v.to_string())
}
#[cfg(feature = "sqlx")]
pub async fn query_all_as<T>(&self, sql: &str) -> Result<Vec<T>, String>
where
for<'r> T: sqlx::FromRow<'r, sqlx::any::AnyRow> + Send + Unpin,
{
let pool = ensure_sqlx_pool().await?;
sqlx::query_as::<_, T>(sql)
.fetch_all(pool)
.await
.map_err(|e| format!("sqlx query: {e}"))
}
#[cfg(feature = "sqlx")]
pub async fn execute(&self, sql: &str) -> Result<u64, String> {
let pool = ensure_sqlx_pool().await?;
sqlx::query(sql)
.execute(pool)
.await
.map_err(|e| format!("sqlx execute: {e}"))
.map(|r| r.rows_affected())
}
#[cfg(feature = "sqlx")]
pub async fn ping(&self) -> Result<(), String> {
let pool = ensure_sqlx_pool().await?;
sqlx::query("SELECT 1")
.execute(pool)
.await
.map_err(|e| format!("sqlx ping: {e}"))?;
Ok(())
}
#[cfg(not(feature = "sqlx"))]
pub fn query_raw(&self, sql: &str) -> String {
format!("query accepted by prisma stub (enable nestrs-prisma/sqlx): {sql}")
}
pub fn mapping_guidance(&self) -> &'static str {
"Prefer `From<ModelData>` / `TryFrom<ModelData>` impls for response DTOs; avoid returning generated Prisma model types directly from controllers."
}
}
impl Default for PrismaService {
    /// Builds a service from the first available configuration source:
    ///
    /// 1. options registered via `PrismaModule::for_root` / `for_root_with_options`;
    /// 2. a non-blank `DATABASE_URL` environment variable (with default pool/schema settings);
    /// 3. the local fallback `file:./dev.db`.
    ///
    /// Note: the sqlx-backed helpers read only the registered options (source 1), so sources
    /// 2 and 3 affect `options()` / `client()` but do not configure the SQL pool.
    fn default() -> Self {
        let options = PRISMA_OPTIONS
            .get()
            .cloned()
            .or_else(|| {
                std::env::var("DATABASE_URL")
                    .ok()
                    // An exported-but-empty variable is treated as unset instead of producing an
                    // unusable empty connection string.
                    .filter(|url| !url.trim().is_empty())
                    .map(PrismaOptions::from_url)
            })
            .unwrap_or_else(|| PrismaOptions::from_url("file:./dev.db"));
        let client = PrismaClientHandle {
            database_url: options.database_url.clone(),
            schema_path: options.schema_path.clone(),
        };
        Self { options, client }
    }
}
/// DI hook: builds `PrismaService` through its `Default` impl. The registry is not consulted
/// because configuration is resolved from module-level state and the environment.
impl Injectable for PrismaService {
    fn construct(_registry: &ProviderRegistry) -> Arc<Self> {
        let service: PrismaService = Default::default();
        Arc::new(service)
    }
}
/// Module that provides and exports `PrismaService`.
///
/// Configure it with `PrismaModule::for_root` or `PrismaModule::for_root_with_options`. A
/// `PrismaService` reads the registered options when it is constructed, so register them first.
#[module(
providers = [PrismaService],
exports = [PrismaService],
)]
pub struct PrismaModule;
impl PrismaModule {
    /// Registers `database_url` (default pool sizes and schema path) as the process-wide
    /// Prisma configuration and returns the module.
    ///
    /// Only the first registration in a process takes effect; later calls are ignored.
    pub fn for_root(database_url: impl Into<String>) -> Self {
        Self::for_root_with_options(PrismaOptions::from_url(database_url))
    }

    /// Registers fully custom `options` as the process-wide Prisma configuration and returns
    /// the module.
    ///
    /// Only the first registration in a process takes effect; later calls are ignored.
    pub fn for_root_with_options(options: PrismaOptions) -> Self {
        // `OnceLock::set` hands the value back only when already initialised; first one wins.
        let _ = PRISMA_OPTIONS.set(options);
        Self
    }

    /// Returns the `cargo prisma generate` command for the registered schema path, or for
    /// `DEFAULT_SCHEMA_PATH` when nothing has been registered yet.
    pub fn generate_command_hint() -> String {
        let schema_path = match PRISMA_OPTIONS.get() {
            Some(registered) => registered.schema_path.as_str(),
            None => DEFAULT_SCHEMA_PATH,
        };
        prisma_generate_command(schema_path)
    }
}