1use async_trait::async_trait;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use nestrs::core::DatabasePing;
6use nestrs::prelude::*;
7
8pub mod client;
9pub mod deployment;
10pub mod error;
11pub mod index_ddl;
12mod macros;
13mod macros_enum;
14mod macros_index;
15mod macros_relation;
16mod macros_where;
17pub mod mapping;
18pub mod query_optimization;
19pub mod relation_queries;
20pub mod relations;
21pub mod schema_bridge;
22pub mod transaction;
23
24#[doc(hidden)]
25pub use paste;
26
27#[cfg(feature = "sqlx")]
28#[doc(hidden)]
29pub use sqlx;
30
/// SQLx database driver selected at compile time: PostgreSQL whenever `sqlx-postgres` is on.
#[cfg(all(feature = "sqlx", feature = "sqlx-postgres"))]
pub type SqlxDb = sqlx::Postgres;

/// SQLx database driver selected at compile time: MySQL when `sqlx-mysql` is on and
/// `sqlx-postgres` is off.
#[cfg(all(feature = "sqlx", feature = "sqlx-mysql", not(feature = "sqlx-postgres")))]
pub type SqlxDb = sqlx::MySql;

/// SQLx database driver selected at compile time: SQLite when neither `sqlx-postgres` nor
/// `sqlx-mysql` is on.
#[cfg(all(feature = "sqlx", not(any(feature = "sqlx-postgres", feature = "sqlx-mysql"))))]
pub type SqlxDb = sqlx::Sqlite;

/// Connection pool type for the compile-time selected [`SqlxDb`] driver.
#[cfg(feature = "sqlx")]
pub type SqlxPool = sqlx::Pool<SqlxDb>;
52
53#[cfg(feature = "sqlx")]
54pub use client::ModelRepository;
55pub use client::SortOrder;
56pub use error::PrismaError;
57
58#[cfg(feature = "sqlx")]
59use tokio::sync::OnceCell;
60
/// Schema location used when no explicit `schema_path` has been configured.
pub const DEFAULT_SCHEMA_PATH: &str = "prisma/schema.prisma";

/// Conventional location of the Prisma migrations directory.
pub const DEFAULT_MIGRATIONS_PATH: &str = "prisma/migrations";

/// Returns the command line that runs the Prisma client generator for `schema_path`.
///
/// The path is appended verbatim after `--schema`.
pub fn prisma_generate_command(schema_path: &str) -> String {
    let mut command = String::from("cargo prisma generate --schema ");
    command.push_str(schema_path);
    command
}

/// Returns the command line that applies pending Prisma migrations.
pub fn prisma_migrate_deploy_command() -> &'static str {
    const DEPLOY: &str = "npx prisma migrate deploy";
    DEPLOY
}

/// Returns the command line that pushes the Prisma schema straight to the database.
pub fn prisma_db_push_command() -> &'static str {
    const PUSH: &str = "npx prisma db push";
    PUSH
}
83
84#[derive(Debug, Clone)]
85pub struct PrismaOptions {
86 pub database_url: String,
87 pub pool_min: u32,
88 pub pool_max: u32,
89 pub schema_path: String,
90}
91
92impl PrismaOptions {
93 pub fn from_url(database_url: impl Into<String>) -> Self {
94 Self {
95 database_url: database_url.into(),
96 pool_min: 2,
97 pool_max: 20,
98 schema_path: DEFAULT_SCHEMA_PATH.to_string(),
99 }
100 }
101
102 pub fn pool_min(mut self, value: u32) -> Self {
103 self.pool_min = value;
104 self
105 }
106
107 pub fn pool_max(mut self, value: u32) -> Self {
108 self.pool_max = value;
109 self
110 }
111
112 pub fn schema_path(mut self, value: impl Into<String>) -> Self {
113 self.schema_path = value.into();
114 self
115 }
116}
117
118static PRISMA_OPTIONS: OnceLock<PrismaOptions> = OnceLock::new();
119
120#[cfg(feature = "sqlx")]
121static SQLX_POOL: OnceCell<SqlxPool> = OnceCell::const_new();
122
123#[cfg(feature = "sqlx")]
125pub async fn sqlx_pool() -> Result<&'static SqlxPool, PrismaError> {
126 ensure_sqlx_pool().await.map_err(PrismaError::PoolInit)
127}
128
129#[cfg(feature = "sqlx")]
130async fn ensure_sqlx_pool() -> Result<&'static SqlxPool, String> {
131 SQLX_POOL
132 .get_or_try_init(|| async {
133 let opts = PRISMA_OPTIONS.get().cloned().ok_or_else(|| {
134 "PrismaModule::for_root / for_root_with_options must be called before SQL connectivity"
135 .to_string()
136 })?;
137 sqlx::pool::PoolOptions::<SqlxDb>::new()
138 .max_connections(opts.pool_max)
139 .min_connections(opts.pool_min)
140 .connect(&opts.database_url)
141 .await
142 .map_err(|e| format!("sqlx connect: {e}"))
143 })
144 .await
145}
146
/// Lightweight description of the database target a [`PrismaService`] points at.
///
/// The `Debug` output masks the password component of `database_url` so the handle can be
/// logged without leaking credentials.
#[derive(Clone)]
pub struct PrismaClientHandle {
    /// Database connection string.
    pub database_url: String,
    /// Path to the Prisma schema file.
    pub schema_path: String,
}

impl std::fmt::Debug for PrismaClientHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Replaces the password in `scheme://user:password@host...` with `***`.
        ///
        /// URLs without a `://` authority or without a `user:password@` userinfo are returned
        /// unchanged. Passwords are expected to be percent-encoded; an unencoded `/`, `?` or
        /// `#` inside a password ends the authority early and defeats the masking.
        fn mask_password(url: &str) -> String {
            let Some(scheme_end) = url.find("://") else {
                return url.to_owned();
            };
            let authority_start = scheme_end + 3;
            let after_scheme = &url[authority_start..];
            let authority_len = after_scheme
                .find(['/', '?', '#'])
                .unwrap_or(after_scheme.len());
            let authority = &after_scheme[..authority_len];
            // The last '@' separates userinfo from host, even if the password contains '@'.
            let Some(at) = authority.rfind('@') else {
                return url.to_owned();
            };
            let Some(colon) = authority[..at].find(':') else {
                return url.to_owned();
            };
            let password_start = authority_start + colon + 1;
            let password_end = authority_start + at;
            format!("{}***{}", &url[..password_start], &url[password_end..])
        }

        f.debug_struct("PrismaClientHandle")
            .field("database_url", &mask_password(&self.database_url))
            .field("schema_path", &self.schema_path)
            .finish()
    }
}
152
153pub struct PrismaService {
158 options: PrismaOptions,
159 client: PrismaClientHandle,
160}
161
162impl PrismaService {
163 pub fn client(&self) -> &PrismaClientHandle {
164 &self.client
165 }
166
167 pub fn options(&self) -> &PrismaOptions {
168 &self.options
169 }
170
171 pub fn health(&self) -> &'static str {
173 "ok"
174 }
175
176 #[cfg(feature = "sqlx")]
178 pub async fn query_scalar(&self, sql: &str) -> Result<String, String> {
179 let pool = ensure_sqlx_pool().await?;
180 let v: i64 = sqlx::query_scalar(sql)
181 .fetch_one(pool)
182 .await
183 .map_err(|e| format!("sqlx query: {e}"))?;
184 Ok(v.to_string())
185 }
186
187 #[cfg(feature = "sqlx")]
189 pub async fn query_all_as<T>(&self, sql: &str) -> Result<Vec<T>, String>
190 where
191 for<'r> T: sqlx::FromRow<'r, <SqlxDb as sqlx::Database>::Row> + Send + Unpin,
192 {
193 let pool = ensure_sqlx_pool().await?;
194 sqlx::query_as::<_, T>(sql)
195 .fetch_all(pool)
196 .await
197 .map_err(|e| format!("sqlx query: {e}"))
198 }
199
200 #[cfg(feature = "sqlx")]
202 pub async fn execute(&self, sql: &str) -> Result<u64, String> {
203 let pool = ensure_sqlx_pool().await?;
204 sqlx::query(sql)
205 .execute(pool)
206 .await
207 .map_err(|e| format!("sqlx execute: {e}"))
208 .map(|r| r.rows_affected())
209 }
210
211 #[cfg(feature = "sqlx")]
213 pub async fn ping(&self) -> Result<(), String> {
214 let pool = ensure_sqlx_pool().await?;
215 sqlx::query("SELECT 1")
216 .execute(pool)
217 .await
218 .map_err(|e| format!("sqlx ping: {e}"))?;
219 Ok(())
220 }
221
222 #[cfg(not(feature = "sqlx"))]
224 pub fn query_raw(&self, sql: &str) -> String {
225 format!("query accepted by prisma stub (enable nestrs-prisma/sqlx): {sql}")
226 }
227
228 pub fn mapping_guidance(&self) -> &'static str {
229 "Prefer `From<ModelData>` / `TryFrom<ModelData>` impls for response DTOs; avoid returning generated Prisma model types directly from controllers."
230 }
231}
232
233#[async_trait]
234impl DatabasePing for PrismaService {
235 async fn ping_database(&self) -> Result<(), String> {
236 #[cfg(feature = "sqlx")]
237 {
238 self.ping().await
239 }
240 #[cfg(not(feature = "sqlx"))]
241 {
242 Ok(())
243 }
244 }
245}
246
247impl Default for PrismaService {
248 fn default() -> Self {
249 let options = PRISMA_OPTIONS
250 .get()
251 .cloned()
252 .or_else(|| {
253 std::env::var("DATABASE_URL")
254 .ok()
255 .map(PrismaOptions::from_url)
256 })
257 .unwrap_or_else(|| PrismaOptions::from_url("file:./dev.db"));
258
259 let client = PrismaClientHandle {
260 database_url: options.database_url.clone(),
261 schema_path: options.schema_path.clone(),
262 };
263
264 Self { options, client }
265 }
266}
267
268impl Injectable for PrismaService {
269 fn construct(_registry: &ProviderRegistry) -> Arc<Self> {
270 Arc::new(Self::default())
271 }
272}
273
/// Module that provides and exports [`PrismaService`].
///
/// Register the connection settings with `PrismaModule::for_root` or
/// `PrismaModule::for_root_with_options` before any SQL is issued; the SQLx pool refuses to
/// initialize without them.
#[module(
    providers = [PrismaService],
    exports = [PrismaService],
)]
pub struct PrismaModule;
279
280impl PrismaModule {
281 pub fn for_root(database_url: impl Into<String>) -> Self {
282 let _ = PRISMA_OPTIONS.set(PrismaOptions::from_url(database_url));
283 Self
284 }
285
286 pub fn for_root_with_options(options: PrismaOptions) -> Self {
287 let _ = PRISMA_OPTIONS.set(options);
288 Self
289 }
290
291 pub fn generate_command_hint() -> String {
292 let schema_path = PRISMA_OPTIONS
293 .get()
294 .map(|o| o.schema_path.as_str())
295 .unwrap_or(DEFAULT_SCHEMA_PATH);
296 prisma_generate_command(schema_path)
297 }
298
299 pub fn deploy_command_hint() -> &'static str {
300 prisma_migrate_deploy_command()
301 }
302}