1#[cfg(not(any(
2 feature = "strict-postgres",
3 feature = "strict-mysql",
4 feature = "strict-sqlite"
5)))]
6pub use sqlx::AnyPool as RullstPool;
7
8#[cfg(feature = "strict-postgres")]
9pub use sqlx::PgPool as RullstPool;
10
11#[cfg(all(feature = "strict-mysql", not(feature = "strict-postgres")))]
12pub use sqlx::MySqlPool as RullstPool;
13
14#[cfg(all(
15 feature = "strict-sqlite",
16 not(feature = "strict-postgres"),
17 not(feature = "strict-mysql")
18))]
19pub use sqlx::SqlitePool as RullstPool;
20
21#[cfg(not(any(
22 feature = "strict-postgres",
23 feature = "strict-mysql",
24 feature = "strict-sqlite"
25)))]
26use sqlx::any::install_default_drivers;
27
28use std::sync::OnceLock;
29use std::sync::atomic::{AtomicUsize, Ordering};
30
31pub use futures;
33pub use serde;
34pub use serde_json;
35pub use sqlx;
36
37#[cfg(feature = "redis")]
38pub use redis;
39pub mod admin;
40pub mod audit;
41pub mod collection;
42pub mod database;
43pub mod resource;
44pub mod schema;
45pub mod scout;
46pub mod tenant;
47pub mod types;
48
49pub use admin::dashboard_html;
51pub use collection::RullstCollection;
52pub use database::RullstDatabase;
53pub use resource::{ApiResource, JsonResource, ResourceCollection};
54pub use rullst_orm_macros::Orm;
55pub use scout::{SearchEngine, get_search_engine, set_search_engine};
56pub use tenant::{get_tenant_id, with_tenant};
57pub use types::Json;
58
59pub use async_trait::async_trait;
61
62pub use schema::{JoinClause, SubqueryBuilder};
64pub use sqlx::FromRow;
65
/// Primary connection pool; stored once by `Orm::init` or `Orm::init_with_replicas`
/// and read back through `Orm::pool`.
static DB_POOL: OnceLock<RullstPool> = OnceLock::new();

/// Driver name detected from the primary database URL during initialization;
/// read back through `Orm::driver`.
static DB_DRIVER: OnceLock<String> = OnceLock::new();

/// Read-replica pools. `Orm::init` stores an empty list here, while
/// `Orm::init_with_replicas` stores the connected replicas.
static REPLICA_POOLS: OnceLock<Vec<RullstPool>> = OnceLock::new();

/// Monotonically increasing cursor that `Orm::read_pool` reduces modulo the
/// replica count to pick replicas round-robin.
static REPLICA_INDEX: AtomicUsize = AtomicUsize::new(0);

/// Redis client stored by `Orm::init_redis`.
#[cfg(feature = "redis")]
static REDIS_CLIENT: OnceLock<redis::Client> = OnceLock::new();

/// Redis connection manager stored by `Orm::init_redis`; `Orm::redis_manager`
/// hands out clones of it.
#[cfg(feature = "redis")]
static REDIS_MANAGER: OnceLock<redis::aio::ConnectionManager> = OnceLock::new();
83
/// A dynamically typed scalar value.
///
/// Built from Rust primitives with `From` and converted back with `TryFrom`.
/// It implements `PartialEq` so values can be compared directly; `Eq` is
/// intentionally absent because of the `f64` variant.
#[derive(Clone, Debug, PartialEq)]
pub enum RullstValue {
    /// Owned text.
    String(String),
    /// 32-bit signed integer.
    Int(i32),
    /// 64-bit floating-point number.
    Float(f64),
    /// Boolean flag.
    Bool(bool),
}
92
93impl From<&str> for RullstValue {
94 fn from(s: &str) -> Self {
95 RullstValue::String(s.to_string())
96 }
97}
98impl From<String> for RullstValue {
99 fn from(s: String) -> Self {
100 RullstValue::String(s)
101 }
102}
103impl From<i32> for RullstValue {
104 fn from(i: i32) -> Self {
105 RullstValue::Int(i)
106 }
107}
108impl From<f64> for RullstValue {
109 fn from(f: f64) -> Self {
110 RullstValue::Float(f)
111 }
112}
113impl From<bool> for RullstValue {
114 fn from(b: bool) -> Self {
115 RullstValue::Bool(b)
116 }
117}
118
119impl TryFrom<RullstValue> for String {
120 type Error = &'static str;
121 fn try_from(val: RullstValue) -> Result<Self, Self::Error> {
122 match val {
123 RullstValue::String(s) => Ok(s),
124 _ => Err("Not a string"),
125 }
126 }
127}
128impl TryFrom<RullstValue> for i32 {
129 type Error = &'static str;
130 fn try_from(val: RullstValue) -> Result<Self, Self::Error> {
131 match val {
132 RullstValue::Int(i) => Ok(i),
133 _ => Err("Not an i32"),
134 }
135 }
136}
137impl TryFrom<RullstValue> for f64 {
138 type Error = &'static str;
139 fn try_from(val: RullstValue) -> Result<Self, Self::Error> {
140 match val {
141 RullstValue::Float(f) => Ok(f),
142 _ => Err("Not an f64"),
143 }
144 }
145}
146impl TryFrom<RullstValue> for bool {
147 type Error = &'static str;
148 fn try_from(val: RullstValue) -> Result<Self, Self::Error> {
149 match val {
150 RullstValue::Bool(b) => Ok(b),
151 _ => Err("Not a bool"),
152 }
153 }
154}
155
156pub struct Orm;
158
159impl Orm {
160 pub async fn init(database_url: &str) -> Result<(), sqlx::Error> {
162 #[cfg(not(any(
163 feature = "strict-postgres",
164 feature = "strict-mysql",
165 feature = "strict-sqlite"
166 )))]
167 install_default_drivers();
168
169 let pool = RullstPool::connect(database_url).await?;
170
171 if DB_POOL.set(pool).is_err() {
172 panic!("Orm has already been initialized");
173 }
174
175 let driver = if database_url.starts_with("postgres") {
176 "postgres"
177 } else if database_url.starts_with("mysql") {
178 "mysql"
179 } else {
180 "sqlite"
181 };
182
183 let _ = DB_DRIVER.set(driver.to_string());
184 let _ = REPLICA_POOLS.set(vec![]);
185
186 Ok(())
187 }
188
189 pub async fn init_with_replicas(
191 primary_url: &str,
192 replica_urls: Vec<&str>,
193 ) -> Result<(), sqlx::Error> {
194 #[cfg(not(any(
195 feature = "strict-postgres",
196 feature = "strict-mysql",
197 feature = "strict-sqlite"
198 )))]
199 install_default_drivers();
200
201 let pool = RullstPool::connect(primary_url).await?;
202
203 if DB_POOL.set(pool).is_err() {
204 panic!("Orm has already been initialized");
205 }
206
207 let driver = if primary_url.starts_with("postgres") {
208 "postgres"
209 } else if primary_url.starts_with("mysql") {
210 "mysql"
211 } else {
212 "sqlite"
213 };
214
215 let _ = DB_DRIVER.set(driver.to_string());
216
217 let replica_futures: Vec<_> = replica_urls.into_iter().map(RullstPool::connect).collect();
219 let replicas = futures::future::try_join_all(replica_futures).await?;
220 let _ = REPLICA_POOLS.set(replicas);
221
222 Ok(())
223 }
224
225 pub fn pool() -> &'static RullstPool {
227 DB_POOL
228 .get()
229 .expect("Orm must be initialized before querying")
230 }
231
232 pub fn read_pool() -> &'static RullstPool {
235 if let Some(replicas) = REPLICA_POOLS.get()
236 && !replicas.is_empty()
237 {
238 let idx = REPLICA_INDEX.fetch_add(1, Ordering::Relaxed) % replicas.len();
239 return &replicas[idx];
240 }
241 Self::pool()
242 }
243
244 pub fn driver() -> &'static str {
246 DB_DRIVER
247 .get()
248 .expect("Orm must be initialized before querying")
249 .as_str()
250 }
251
252 #[cfg(not(any(
254 feature = "strict-postgres",
255 feature = "strict-mysql",
256 feature = "strict-sqlite"
257 )))]
258 pub async fn begin_transaction() -> Result<sqlx::Transaction<'static, sqlx::Any>, sqlx::Error> {
259 let pool = Self::pool();
260 pool.begin().await
261 }
262
263 #[cfg(feature = "strict-postgres")]
264 pub async fn begin_transaction()
265 -> Result<sqlx::Transaction<'static, sqlx::Postgres>, sqlx::Error> {
266 let pool = Self::pool();
267 pool.begin().await
268 }
269
270 #[cfg(all(feature = "strict-mysql", not(feature = "strict-postgres")))]
271 pub async fn begin_transaction() -> Result<sqlx::Transaction<'static, sqlx::MySql>, sqlx::Error>
272 {
273 let pool = Self::pool();
274 pool.begin().await
275 }
276
277 #[cfg(all(
278 feature = "strict-sqlite",
279 not(feature = "strict-postgres"),
280 not(feature = "strict-mysql")
281 ))]
282 pub async fn begin_transaction() -> Result<sqlx::Transaction<'static, sqlx::Sqlite>, sqlx::Error>
283 {
284 let pool = Self::pool();
285 pool.begin().await
286 }
287
288 pub async fn seed(seeders: Vec<Box<dyn Seeder>>) -> Result<(), sqlx::Error> {
290 for seeder in seeders {
291 seeder.run().await?;
292 }
293 Ok(())
294 }
295
296 pub fn enable_query_log() {
298 crate::schema::enable_query_log();
299 }
300
301 pub fn disable_query_log() {
303 crate::schema::disable_query_log();
304 }
305
306 #[cfg(feature = "redis")]
308 pub async fn init_redis(redis_url: &str) -> Result<(), redis::RedisError> {
309 let client = redis::Client::open(redis_url)?;
310 let manager = redis::aio::ConnectionManager::new(client.clone()).await?;
311 let _ = REDIS_CLIENT.set(client);
312 let _ = REDIS_MANAGER.set(manager);
313 Ok(())
314 }
315
316 #[cfg(feature = "redis")]
318 pub fn redis_client() -> &'static redis::Client {
319 REDIS_CLIENT
320 .get()
321 .expect("Redis must be initialized before using cache features")
322 }
323
324 #[cfg(feature = "redis")]
326 pub fn redis_manager() -> redis::aio::ConnectionManager {
327 REDIS_MANAGER
328 .get()
329 .expect("Redis must be initialized before using cache features")
330 .clone()
331 }
332}
333
334#[async_trait]
336pub trait Seeder: Send + Sync {
337 async fn run(&self) -> Result<(), sqlx::Error>;
338}
339
/// Associates a model type with the name of its database table.
///
/// The trait has no async methods, so it carries no `#[async_trait]`
/// attribute. Implementations that are annotated with it still compile
/// unchanged.
pub trait RullstModel {
    /// Name of the table backing this model.
    fn table_name() -> &'static str;
}
345
/// One page of results together with its paging metadata.
#[derive(Clone, Debug)]
pub struct PaginationResult<T> {
    /// Items on this page.
    pub data: Vec<T>,
    /// Total record count (presumably across all pages — confirm with the paginating query).
    pub total: i64,
    /// Page size used to produce this result.
    pub per_page: usize,
    /// Page number of this result.
    pub current_page: usize,
    /// Page number of the last page (as computed by the caller that builds this value).
    pub last_page: usize,
}
355
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that toggle query logging. The flag lives in
    /// `crate::schema` (assumed process-global — confirm), and the default test
    /// harness runs tests in parallel. Without this lock, the enable/disable
    /// sequences of different tests can interleave and fail spuriously.
    static QUERY_LOG_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`QUERY_LOG_LOCK`]. A poisoned lock is recovered because it
    /// guards `()`, so there is no state that a panicking test could corrupt.
    fn lock_query_log() -> MutexGuard<'static, ()> {
        QUERY_LOG_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_rullst_value_conversions() {
        let v: RullstValue = "test".into();
        assert!(matches!(v, RullstValue::String(ref s) if s == "test"));
        let v_int: RullstValue = 100.into();
        assert!(matches!(v_int, RullstValue::Int(100)));
        let v_float: RullstValue = 1.5.into();
        assert!(matches!(v_float, RullstValue::Float(f) if f == 1.5));
        let v_bool: RullstValue = false.into();
        assert!(matches!(v_bool, RullstValue::Bool(false)));
    }

    #[test]
    fn test_rullst_value_try_from_round_trip() {
        assert_eq!(String::try_from(RullstValue::from("abc")), Ok("abc".to_string()));
        assert_eq!(i32::try_from(RullstValue::from(7)), Ok(7));
        assert_eq!(f64::try_from(RullstValue::from(2.5)), Ok(2.5));
        assert_eq!(bool::try_from(RullstValue::from(true)), Ok(true));
        assert_eq!(i32::try_from(RullstValue::from(true)), Err("Not an i32"));
    }

    #[test]
    fn test_enable_query_log_wrapper() {
        let _guard = lock_query_log();
        Orm::disable_query_log();
        assert!(!crate::schema::is_query_log_enabled());
        Orm::enable_query_log();
        assert!(crate::schema::is_query_log_enabled());
        Orm::disable_query_log();
        assert!(!crate::schema::is_query_log_enabled());
    }

    #[test]
    fn test_disable_query_log_wrapper() {
        let _guard = lock_query_log();
        Orm::enable_query_log();
        Orm::disable_query_log();
        assert!(!crate::schema::is_query_log_enabled());
    }
}