1#[cfg(not(any(
2 feature = "strict-postgres",
3 feature = "strict-mysql",
4 feature = "strict-sqlite"
5)))]
6pub use sqlx::AnyPool as RullstPool;
7
8#[cfg(not(any(
9 feature = "strict-postgres",
10 feature = "strict-mysql",
11 feature = "strict-sqlite"
12)))]
13pub use sqlx::any::AnyPoolOptions as RullstPoolOptions;
14
15#[cfg(feature = "strict-postgres")]
16pub use sqlx::PgPool as RullstPool;
17
18#[cfg(feature = "strict-postgres")]
19pub use sqlx::postgres::PgPoolOptions as RullstPoolOptions;
20
21#[cfg(all(feature = "strict-mysql", not(feature = "strict-postgres")))]
22pub use sqlx::MySqlPool as RullstPool;
23
24#[cfg(all(feature = "strict-mysql", not(feature = "strict-postgres")))]
25pub use sqlx::mysql::MySqlPoolOptions as RullstPoolOptions;
26
27#[cfg(all(
28 feature = "strict-sqlite",
29 not(feature = "strict-postgres"),
30 not(feature = "strict-mysql")
31))]
32pub use sqlx::SqlitePool as RullstPool;
33
34#[cfg(all(
35 feature = "strict-sqlite",
36 not(feature = "strict-postgres"),
37 not(feature = "strict-mysql")
38))]
39pub use sqlx::sqlite::SqlitePoolOptions as RullstPoolOptions;
40
41#[cfg(not(any(
42 feature = "strict-postgres",
43 feature = "strict-mysql",
44 feature = "strict-sqlite"
45)))]
46use sqlx::any::install_default_drivers;
47
48use std::sync::OnceLock;
49use std::sync::atomic::{AtomicUsize, Ordering};
50
51#[doc(hidden)]
53pub use futures as _futures;
54#[doc(hidden)]
55pub use serde as _serde;
56#[doc(hidden)]
57pub use serde_json as _serde_json;
58#[doc(hidden)]
59pub use sqlx as _sqlx;
60
61#[cfg(feature = "redis")]
62#[doc(hidden)]
63pub use redis as _redis;
64pub mod admin;
65pub mod audit;
66pub mod collection;
67pub mod database;
68pub mod db;
69pub mod error;
70pub mod resource;
71pub mod schema;
72pub mod scout;
73pub mod tenant;
74pub mod types;
75
76pub use error::RullstError as Error;
78
79pub use _sqlx::FromRow;
81pub use admin::dashboard_html;
82pub use collection::RullstCollection;
83pub use database::RullstDatabase;
84pub use resource::{ApiResource, JsonResource, ResourceCollection};
85pub use rullst_orm_macros::Orm;
86pub use scout::{SearchEngine, get_search_engine, set_search_engine};
87pub use tenant::{get_tenant_id, with_tenant};
88pub use types::Json;
89
90pub use async_trait::async_trait;
92
93pub use schema::{JoinClause, SubqueryBuilder};
95
/// Primary connection pool; set exactly once by one of the `Orm::init*` functions and read via
/// `Orm::pool()`.
static DB_POOL: OnceLock<RullstPool> = OnceLock::new();

/// Dialect name derived from the primary URL at init time; read via `Orm::driver()`.
static DB_DRIVER: OnceLock<String> = OnceLock::new();

/// Read-replica pools. `init`/`init_with_options` store an empty `Vec`, in which case
/// `Orm::read_pool()` falls back to `DB_POOL`.
static REPLICA_POOLS: OnceLock<Vec<RullstPool>> = OnceLock::new();

/// Counter used by `Orm::read_pool()` to round-robin across `REPLICA_POOLS` (`fetch_add`
/// wraps on overflow). `Relaxed` is enough because the value only spreads load and does not
/// guard any other memory.
static REPLICA_INDEX: AtomicUsize = AtomicUsize::new(0);

/// Redis client stored by `Orm::init_redis()`.
#[cfg(feature = "redis")]
static REDIS_CLIENT: OnceLock<_redis::Client> = OnceLock::new();

/// Redis connection manager stored by `Orm::init_redis()`; `Orm::redis_manager()` hands out
/// clones of it.
#[cfg(feature = "redis")]
static REDIS_MANAGER: OnceLock<_redis::aio::ConnectionManager> = OnceLock::new();
113
/// A small tagged union of scalar values: text, `i32`, `f64` or `bool`.
///
/// Build one with `From`/`.into()` and extract the payload with `TryFrom`. Extraction fails
/// with a short static message when the variant does not match the requested type. Equality
/// (`PartialEq`) compares variant and payload; there is no cross-variant coercion, so
/// `Int(1) != Float(1.0)`.
#[derive(Clone, Debug, PartialEq)]
pub enum RullstValue {
    /// Owned text.
    String(String),
    /// 32-bit signed integer.
    Int(i32),
    /// 64-bit floating point number.
    Float(f64),
    /// Boolean flag.
    Bool(bool),
}

impl From<&str> for RullstValue {
    fn from(s: &str) -> Self {
        Self::String(s.to_owned())
    }
}

impl From<String> for RullstValue {
    fn from(s: String) -> Self {
        Self::String(s)
    }
}

impl From<i32> for RullstValue {
    fn from(i: i32) -> Self {
        Self::Int(i)
    }
}

impl From<f64> for RullstValue {
    fn from(f: f64) -> Self {
        Self::Float(f)
    }
}

impl From<bool> for RullstValue {
    fn from(b: bool) -> Self {
        Self::Bool(b)
    }
}

impl TryFrom<RullstValue> for String {
    type Error = &'static str;

    /// Succeeds only for [`RullstValue::String`].
    fn try_from(val: RullstValue) -> Result<Self, Self::Error> {
        if let RullstValue::String(s) = val {
            Ok(s)
        } else {
            Err("Not a string")
        }
    }
}

impl TryFrom<RullstValue> for i32 {
    type Error = &'static str;

    /// Succeeds only for [`RullstValue::Int`].
    fn try_from(val: RullstValue) -> Result<Self, Self::Error> {
        if let RullstValue::Int(i) = val {
            Ok(i)
        } else {
            Err("Not an i32")
        }
    }
}

impl TryFrom<RullstValue> for f64 {
    type Error = &'static str;

    /// Succeeds only for [`RullstValue::Float`]; integers are not widened.
    fn try_from(val: RullstValue) -> Result<Self, Self::Error> {
        if let RullstValue::Float(f) = val {
            Ok(f)
        } else {
            Err("Not an f64")
        }
    }
}

impl TryFrom<RullstValue> for bool {
    type Error = &'static str;

    /// Succeeds only for [`RullstValue::Bool`].
    fn try_from(val: RullstValue) -> Result<Self, Self::Error> {
        if let RullstValue::Bool(b) = val {
            Ok(b)
        } else {
            Err("Not a bool")
        }
    }
}
185
186pub struct Orm;
188
189impl Orm {
190 pub async fn init(database_url: &str) -> Result<(), crate::Error> {
192 Self::validate_dsn(database_url);
193
194 #[cfg(not(any(
195 feature = "strict-postgres",
196 feature = "strict-mysql",
197 feature = "strict-sqlite"
198 )))]
199 install_default_drivers();
200
201 let pool = RullstPool::connect(database_url).await?;
202
203 if DB_POOL.set(pool).is_err() {
204 return Err(crate::Error::Internal(
205 "Orm has already been initialized".to_string(),
206 ));
207 }
208
209 let driver = if database_url.starts_with("postgres") {
210 "postgres"
211 } else if database_url.starts_with("mysql") {
212 "mysql"
213 } else {
214 "sqlite"
215 };
216
217 let _ = DB_DRIVER.set(driver.to_string());
218 let _ = REPLICA_POOLS.set(vec![]);
219
220 Ok(())
221 }
222
223 pub async fn init_with_options(
225 database_url: &str,
226 max_connections: u32,
227 acquire_timeout_secs: u64,
228 ) -> Result<(), crate::Error> {
229 Self::validate_dsn(database_url);
230
231 #[cfg(not(any(
232 feature = "strict-postgres",
233 feature = "strict-mysql",
234 feature = "strict-sqlite"
235 )))]
236 install_default_drivers();
237
238 let pool = RullstPoolOptions::new()
239 .max_connections(max_connections)
240 .acquire_timeout(std::time::Duration::from_secs(acquire_timeout_secs))
241 .connect(database_url)
242 .await?;
243
244 if DB_POOL.set(pool).is_err() {
245 return Err(crate::Error::Internal(
246 "Orm has already been initialized".to_string(),
247 ));
248 }
249
250 let driver = if database_url.starts_with("postgres") {
251 "postgres"
252 } else if database_url.starts_with("mysql") {
253 "mysql"
254 } else {
255 "sqlite"
256 };
257
258 let _ = DB_DRIVER.set(driver.to_string());
259 let _ = REPLICA_POOLS.set(vec![]);
260
261 Ok(())
262 }
263
264 #[cfg_attr(test, mutants::skip)]
265 fn validate_dsn(database_url: &str) {
266 if database_url.contains("sslmode=disable")
267 && !database_url.contains("localhost")
268 && !database_url.contains("127.0.0.1")
269 {
270 eprintln!(
271 "⚠️ [SECURITY WARNING] Rullst ORM: TLS/SSL disabled on external database connection! This is highly discouraged in production environments."
272 );
273 }
274 }
275
276 pub async fn init_with_replicas(
278 primary_url: &str,
279 replica_urls: Vec<&str>,
280 ) -> Result<(), crate::Error> {
281 #[cfg(not(any(
282 feature = "strict-postgres",
283 feature = "strict-mysql",
284 feature = "strict-sqlite"
285 )))]
286 install_default_drivers();
287
288 let pool = RullstPool::connect(primary_url).await?;
289
290 if DB_POOL.set(pool).is_err() {
291 return Err(crate::Error::Internal(
292 "Orm has already been initialized".to_string(),
293 ));
294 }
295
296 let driver = if primary_url.starts_with("postgres") {
297 "postgres"
298 } else if primary_url.starts_with("mysql") {
299 "mysql"
300 } else {
301 "sqlite"
302 };
303
304 let _ = DB_DRIVER.set(driver.to_string());
305
306 let replica_futures: Vec<_> = replica_urls.into_iter().map(RullstPool::connect).collect();
308 let replicas = futures::future::try_join_all(replica_futures).await?;
309 let _ = REPLICA_POOLS.set(replicas);
310
311 Ok(())
312 }
313
314 pub fn pool() -> &'static RullstPool {
316 DB_POOL
317 .get()
318 .expect("Orm must be initialized before querying")
319 }
320
321 #[cfg_attr(test, mutants::skip)]
324 pub fn read_pool() -> &'static RullstPool {
325 if let Some(replicas) = REPLICA_POOLS.get()
326 && !replicas.is_empty()
327 {
328 let idx = REPLICA_INDEX.fetch_add(1, Ordering::Relaxed) % replicas.len();
329 return &replicas[idx];
330 }
331 Self::pool()
332 }
333
334 pub fn driver() -> &'static str {
336 DB_DRIVER
337 .get()
338 .expect("Orm must be initialized before querying")
339 .as_str()
340 }
341
342 pub async fn begin_transaction() -> Result<crate::db::Transaction<'static>, crate::Error> {
343 let pool = Self::pool();
344 pool.begin().await.map_err(Into::into)
345 }
346
347 #[cfg_attr(test, mutants::skip)]
349 pub async fn seed(seeders: Vec<Box<dyn Seeder>>) -> Result<(), crate::Error> {
350 for seeder in seeders {
351 seeder.run().await?;
352 }
353 Ok(())
354 }
355
356 pub fn enable_query_log() {
358 crate::schema::enable_query_log();
359 }
360
361 pub fn disable_query_log() {
363 crate::schema::disable_query_log();
364 }
365
366 pub fn set_max_query_limit(limit: usize) {
368 crate::schema::set_max_query_limit(limit);
369 }
370
371 pub fn set_query_timeout(secs: u64) {
373 crate::schema::set_query_timeout(secs);
374 }
375
376 #[cfg(feature = "redis")]
378 #[cfg_attr(test, mutants::skip)]
379 pub async fn init_redis(redis_url: &str) -> Result<(), crate::Error> {
380 let client = _redis::Client::open(redis_url)?;
381 let manager = _redis::aio::ConnectionManager::new(client.clone()).await?;
382 let _ = REDIS_CLIENT.set(client);
383 let _ = REDIS_MANAGER.set(manager);
384 Ok(())
385 }
386
387 #[cfg(feature = "redis")]
389 #[cfg_attr(test, mutants::skip)]
390 pub fn redis_client() -> Result<&'static _redis::Client, crate::Error> {
391 REDIS_CLIENT.get().ok_or_else(|| {
392 crate::Error::Internal(
393 "Orm::init_redis() must be called before using cache features".to_string(),
394 )
395 })
396 }
397
398 #[cfg(feature = "redis")]
400 #[cfg_attr(test, mutants::skip)]
401 pub fn redis_manager() -> Result<_redis::aio::ConnectionManager, crate::Error> {
402 REDIS_MANAGER.get().cloned().ok_or_else(|| {
403 crate::Error::Internal(
404 "Orm::init_redis() must be called before using cache features".to_string(),
405 )
406 })
407 }
408}
409
410#[async_trait]
412pub trait Seeder: Send + Sync {
413 async fn run(&self) -> Result<(), crate::Error>;
414}
415
/// Associates a model type with the database table that backs it.
///
/// Every item is synchronous, so no `#[async_trait]` expansion is needed. Impls that still
/// carry `#[async_trait]` remain valid.
pub trait RullstModel {
    /// Name of the table this model maps to.
    fn table_name() -> &'static str;
}
421
/// One page of results together with its pagination bookkeeping.
///
/// Equality compares every field, including the page data (requires `T: PartialEq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PaginationResult<T> {
    /// Items on this page.
    pub data: Vec<T>,
    /// Total item count across all pages (presumably from a COUNT query — confirm in caller).
    pub total: i64,
    /// Requested page size.
    pub per_page: usize,
    /// Page number that `data` belongs to (assumed 1-based — TODO confirm).
    pub current_page: usize,
    /// Number of the last available page.
    pub last_page: usize,
}
431
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // The query-log flag, max query limit and query timeout live in `crate::schema`. Assuming
    // that storage is process-wide rather than thread-local (TODO confirm), tests that flip
    // these settings race under the default parallel test harness. For example,
    // `test_disable_query_log_wrapper` could disable logging between the "enable" and
    // "assert enabled" steps of `test_enable_query_log_wrapper`. Every test in this module
    // that touches those settings holds this lock for its whole body. Tests in other modules
    // are not serialised by it.
    static GLOBAL_SETTINGS_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `GLOBAL_SETTINGS_LOCK`, recovering from poisoning. The guarded data is `()`,
    /// so a panicking test cannot leave it inconsistent.
    fn lock_global_settings() -> MutexGuard<'static, ()> {
        GLOBAL_SETTINGS_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn test_pagination_result() {
        let mut pr = PaginationResult {
            data: vec![1, 2, 3],
            total: 3,
            per_page: 10,
            current_page: 1,
            last_page: 1,
        };
        assert_eq!(pr.data.len(), 3);
        assert_eq!(pr.total, 3);
        pr.data.push(4);
        assert_eq!(pr.data.len(), 4);
    }

    #[test]
    fn test_rullst_value_conversions() {
        let v: RullstValue = "test".into();
        assert!(matches!(v, RullstValue::String(_)));
        let v_string: RullstValue = "test".to_string().into();
        assert!(matches!(v_string, RullstValue::String(_)));
        let v_int: RullstValue = 100.into();
        assert!(matches!(v_int, RullstValue::Int(100)));
        let v_bool: RullstValue = false.into();
        assert!(matches!(v_bool, RullstValue::Bool(false)));
        let v_float: RullstValue = std::f64::consts::PI.into();
        assert!(matches!(v_float, RullstValue::Float(_)));

        let v_str_conv = RullstValue::String("hello".to_string());
        assert_eq!(String::try_from(v_str_conv).unwrap(), "hello");
        assert!(String::try_from(RullstValue::Int(10)).is_err());

        let v_int_conv = RullstValue::Int(42);
        assert_eq!(i32::try_from(v_int_conv).unwrap(), 42);
        assert!(i32::try_from(RullstValue::Bool(true)).is_err());

        let v_float_conv = RullstValue::Float(2.71);
        assert_eq!(f64::try_from(v_float_conv).unwrap(), 2.71);
        assert!(f64::try_from(RullstValue::Int(10)).is_err());

        let v_bool_conv = RullstValue::Bool(true);
        assert!(bool::try_from(v_bool_conv).unwrap());
        assert!(bool::try_from(RullstValue::Int(10)).is_err());
    }

    #[test]
    fn test_enable_query_log_wrapper() {
        let _guard = lock_global_settings();
        Orm::disable_query_log();
        assert!(!crate::schema::is_query_log_enabled());
        Orm::enable_query_log();
        assert!(crate::schema::is_query_log_enabled());
        Orm::disable_query_log();
        assert!(!crate::schema::is_query_log_enabled());
    }

    #[test]
    fn test_disable_query_log_wrapper() {
        let _guard = lock_global_settings();
        Orm::enable_query_log();
        Orm::disable_query_log();
        assert!(!crate::schema::is_query_log_enabled());
    }

    #[cfg(feature = "redis")]
    #[test]
    fn test_redis_client_uninitialized() {
        let err = Orm::redis_client().unwrap_err();
        assert!(matches!(err, crate::Error::Internal(_)));
    }

    #[cfg(feature = "redis")]
    #[test]
    fn test_redis_manager_uninitialized() {
        let err = Orm::redis_manager().unwrap_err();
        assert!(matches!(err, crate::Error::Internal(_)));
    }

    #[test]
    #[should_panic(expected = "Orm must be initialized before querying")]
    fn test_pool_uninitialized() {
        let _ = Orm::pool();
    }

    #[test]
    #[should_panic(expected = "Orm must be initialized before querying")]
    fn test_driver_uninitialized() {
        let _ = Orm::driver();
    }

    #[test]
    #[should_panic(expected = "Orm must be initialized before querying")]
    fn test_read_pool_uninitialized() {
        let _ = Orm::read_pool();
    }

    #[test]
    fn test_validate_dsn() {
        // Local / non-TLS-relevant URL: no warning expected.
        Orm::validate_dsn("sqlite::memory:");
        // External host with TLS disabled: prints the warning but must not panic.
        Orm::validate_dsn("postgres://external-db.com/mydb?sslmode=disable");
    }

    #[cfg(feature = "redis")]
    #[tokio::test]
    async fn test_init_redis_failure() {
        let err = Orm::init_redis("redis://127.0.0.1:0").await.unwrap_err();
        assert!(matches!(err, crate::Error::CacheError(_)));
    }

    #[test]
    fn test_orm_max_query_limit_and_timeout() {
        let _guard = lock_global_settings();
        Orm::set_max_query_limit(15);
        assert_eq!(crate::schema::get_max_query_limit(), Some(15));
        Orm::set_max_query_limit(0);
        assert_eq!(crate::schema::get_max_query_limit(), None);

        Orm::set_query_timeout(5);
        assert_eq!(
            crate::schema::get_query_timeout(),
            Some(std::time::Duration::from_secs(5))
        );
        Orm::set_query_timeout(0);
        assert_eq!(crate::schema::get_query_timeout(), None);
    }
}