// cast_core/pool.rs

//! Database connection pool(s).
//!
//! Cast supports Postgres, MySQL, and SQLite via per-driver pool variants.
//! The URL scheme determines the driver at `connect()` time:
//!
//! - `postgres://...` / `postgresql://...` → `Driver::Postgres`
//! - `mysql://...` / `mariadb://...`       → `Driver::MySql`
//! - `sqlite://...` / `sqlite:...`         → `Driver::Sqlite`

10use std::collections::HashMap;
11use std::sync::Arc;
12
13use parking_lot::RwLock;
14
15use crate::Error;
16
17/// Which database engine a connection is talking to.
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
19pub enum Driver {
20    Postgres,
21    MySql,
22    Sqlite,
23}
24
25impl Driver {
26    /// Infer the driver from a connection URL.
27    pub fn from_url(url: &str) -> Result<Self, Error> {
28        let lower = url.trim().to_ascii_lowercase();
29        if lower.starts_with("postgres://") || lower.starts_with("postgresql://") {
30            Ok(Driver::Postgres)
31        } else if lower.starts_with("mysql://") || lower.starts_with("mariadb://") {
32            Ok(Driver::MySql)
33        } else if lower.starts_with("sqlite:") {
34            Ok(Driver::Sqlite)
35        } else {
36            Err(Error::Internal(format!(
37                "unknown database URL scheme: {url}"
38            )))
39        }
40    }
41
42    pub fn name(&self) -> &'static str {
43        match self {
44            Driver::Postgres => "postgres",
45            Driver::MySql => "mysql",
46            Driver::Sqlite => "sqlite",
47        }
48    }
49}
50
51/// Driver-tagged pool. Variant tells you which backing sqlx type is live.
52#[derive(Clone)]
53pub enum Pool {
54    Postgres(sqlx::PgPool),
55    MySql(sqlx::MySqlPool),
56    Sqlite(sqlx::SqlitePool),
57}
58
59impl Pool {
60    pub fn driver(&self) -> Driver {
61        match self {
62            Pool::Postgres(_) => Driver::Postgres,
63            Pool::MySql(_) => Driver::MySql,
64            Pool::Sqlite(_) => Driver::Sqlite,
65        }
66    }
67
68    pub fn as_postgres(&self) -> Option<&sqlx::PgPool> {
69        match self {
70            Pool::Postgres(p) => Some(p),
71            _ => None,
72        }
73    }
74
75    pub fn as_mysql(&self) -> Option<&sqlx::MySqlPool> {
76        match self {
77            Pool::MySql(p) => Some(p),
78            _ => None,
79        }
80    }
81
82    pub fn as_sqlite(&self) -> Option<&sqlx::SqlitePool> {
83        match self {
84            Pool::Sqlite(p) => Some(p),
85            _ => None,
86        }
87    }
88
89    /// Panic with a clear message if the pool isn't Postgres. The Cast `#[derive(Model)]`
90    /// query builder + relations target Postgres only in v0.1; use this internally to
91    /// extract the typed pool. v0.2 lifts the restriction.
92    pub fn expect_pg(&self) -> &sqlx::PgPool {
93        self.as_postgres().unwrap_or_else(|| {
94            panic!(
95                "Cast::Model query builder requires a Postgres pool in v0.1 (got {:?}). \
96                 Use raw sqlx::query against c.pool().as_mysql()/as_sqlite() for now.",
97                self.driver()
98            )
99        })
100    }
101
102    /// Execute a `&str` against whichever driver is live. Returns rows affected.
103    pub async fn execute(&self, sql: &str) -> Result<u64, Error> {
104        Ok(match self {
105            Pool::Postgres(p) => sqlx::query(sql).execute(p).await?.rows_affected(),
106            Pool::MySql(p) => sqlx::query(sql).execute(p).await?.rows_affected(),
107            Pool::Sqlite(p) => sqlx::query(sql).execute(p).await?.rows_affected(),
108        })
109    }
110}
111
112/// Backward-compat: many call sites pass `&Pool` to sqlx via the Postgres path.
113/// Users who need the bare PgPool can call `pool.as_postgres().expect("...")`.
114///
115/// To keep the v0.1 surface alive, this `Deref`-style extraction is available
116/// via `From<&Pool>`.
117impl<'a> From<&'a Pool> for Option<&'a sqlx::PgPool> {
118    fn from(pool: &'a Pool) -> Self {
119        pool.as_postgres()
120    }
121}
122
123/// Connect to a database, dispatching by URL scheme.
124pub async fn connect(url: &str, max_connections: u32) -> Result<Pool, Error> {
125    let driver = Driver::from_url(url)?;
126    match driver {
127        Driver::Postgres => {
128            let pool = sqlx::postgres::PgPoolOptions::new()
129                .max_connections(max_connections)
130                .connect(url)
131                .await?;
132            Ok(Pool::Postgres(pool))
133        }
134        Driver::MySql => {
135            let pool = sqlx::mysql::MySqlPoolOptions::new()
136                .max_connections(max_connections)
137                .connect(url)
138                .await?;
139            Ok(Pool::MySql(pool))
140        }
141        Driver::Sqlite => {
142            // `sqlite:foo.db` and `sqlite://foo.db` are both fine — sqlx parses both.
143            // Use ConnectOptions so we can `create_if_missing(true)` for dev/test files.
144            use sqlx::ConnectOptions;
145            use std::str::FromStr;
146            let opts = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
147                .create_if_missing(true)
148                .log_statements(tracing::log::LevelFilter::Debug);
149            let pool = sqlx::sqlite::SqlitePoolOptions::new()
150                .max_connections(max_connections.max(1))
151                .connect_with(opts)
152                .await?;
153            Ok(Pool::Sqlite(pool))
154        }
155    }
156}
157
158/// One named connection — a write pool plus zero-or-more read replicas.
159#[derive(Clone)]
160pub struct Connection {
161    pub name: String,
162    pub write: Pool,
163    pub reads: Vec<Pool>,
164}
165
166impl Connection {
167    pub fn driver(&self) -> Driver {
168        self.write.driver()
169    }
170
171    pub fn writer(&self) -> &Pool {
172        &self.write
173    }
174
175    pub fn reader(&self) -> &Pool {
176        if self.reads.is_empty() {
177            &self.write
178        } else {
179            use std::sync::atomic::{AtomicUsize, Ordering};
180            static CURSOR: AtomicUsize = AtomicUsize::new(0);
181            let idx = CURSOR.fetch_add(1, Ordering::Relaxed) % self.reads.len();
182            &self.reads[idx]
183        }
184    }
185}
186
187/// Resolves named connections — the centerpiece of Cast's multi-database support.
188#[derive(Clone)]
189pub struct ConnectionManager {
190    inner: Arc<ManagerInner>,
191}
192
193struct ManagerInner {
194    default: String,
195    connections: RwLock<HashMap<String, Connection>>,
196}
197
198impl ConnectionManager {
199    pub fn from_pool(pool: Pool) -> Self {
200        let mut map = HashMap::new();
201        map.insert(
202            "default".to_string(),
203            Connection {
204                name: "default".to_string(),
205                write: pool,
206                reads: Vec::new(),
207            },
208        );
209        Self {
210            inner: Arc::new(ManagerInner {
211                default: "default".to_string(),
212                connections: RwLock::new(map),
213            }),
214        }
215    }
216
217    pub fn from_connections(
218        default: impl Into<String>,
219        connections: HashMap<String, Connection>,
220    ) -> Self {
221        Self {
222            inner: Arc::new(ManagerInner {
223                default: default.into(),
224                connections: RwLock::new(connections),
225            }),
226        }
227    }
228
229    pub fn get(&self, name: &str) -> Option<Connection> {
230        self.inner.connections.read().get(name).cloned()
231    }
232
233    pub fn default_connection(&self) -> Connection {
234        let map = self.inner.connections.read();
235        map.get(&self.inner.default)
236            .or_else(|| map.values().next())
237            .cloned()
238            .expect("no connections configured")
239    }
240
241    pub fn default_pool(&self) -> Pool {
242        self.default_connection().write
243    }
244
245    pub fn default_driver(&self) -> Driver {
246        self.default_pool().driver()
247    }
248
249    pub fn pool(&self, name: &str) -> Option<Pool> {
250        self.get(name).map(|c| c.write)
251    }
252
253    pub fn insert(&self, conn: Connection) {
254        self.inner
255            .connections
256            .write()
257            .insert(conn.name.clone(), conn);
258    }
259
260    pub fn names(&self) -> Vec<String> {
261        self.inner.connections.read().keys().cloned().collect()
262    }
263
264    pub fn default_name(&self) -> &str {
265        &self.inner.default
266    }
267}