1use std::collections::HashMap;
11use std::sync::Arc;
12
13use parking_lot::RwLock;
14
15use crate::Error;
16
/// Database backend a pool or connection talks to.
///
/// The type is a plain tag: it is `Copy`, comparable and hashable, so it can
/// be passed around by value and used as a map key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Driver {
    /// PostgreSQL backend.
    Postgres,
    /// MySQL backend (also selected for `mariadb://` URLs).
    MySql,
    /// SQLite backend.
    Sqlite,
}
24
25impl Driver {
26 pub fn from_url(url: &str) -> Result<Self, Error> {
28 let lower = url.trim().to_ascii_lowercase();
29 if lower.starts_with("postgres://") || lower.starts_with("postgresql://") {
30 Ok(Driver::Postgres)
31 } else if lower.starts_with("mysql://") || lower.starts_with("mariadb://") {
32 Ok(Driver::MySql)
33 } else if lower.starts_with("sqlite:") {
34 Ok(Driver::Sqlite)
35 } else {
36 Err(Error::Internal(format!(
37 "unknown database URL scheme: {url}"
38 )))
39 }
40 }
41
42 pub fn name(&self) -> &'static str {
43 match self {
44 Driver::Postgres => "postgres",
45 Driver::MySql => "mysql",
46 Driver::Sqlite => "sqlite",
47 }
48 }
49}
50
51#[derive(Clone)]
53pub enum Pool {
54 Postgres(sqlx::PgPool),
55 MySql(sqlx::MySqlPool),
56 Sqlite(sqlx::SqlitePool),
57}
58
59impl Pool {
60 pub fn driver(&self) -> Driver {
61 match self {
62 Pool::Postgres(_) => Driver::Postgres,
63 Pool::MySql(_) => Driver::MySql,
64 Pool::Sqlite(_) => Driver::Sqlite,
65 }
66 }
67
68 pub fn as_postgres(&self) -> Option<&sqlx::PgPool> {
69 match self {
70 Pool::Postgres(p) => Some(p),
71 _ => None,
72 }
73 }
74
75 pub fn as_mysql(&self) -> Option<&sqlx::MySqlPool> {
76 match self {
77 Pool::MySql(p) => Some(p),
78 _ => None,
79 }
80 }
81
82 pub fn as_sqlite(&self) -> Option<&sqlx::SqlitePool> {
83 match self {
84 Pool::Sqlite(p) => Some(p),
85 _ => None,
86 }
87 }
88
89 pub fn expect_pg(&self) -> &sqlx::PgPool {
93 self.as_postgres().unwrap_or_else(|| {
94 panic!(
95 "Cast::Model query builder requires a Postgres pool in v0.1 (got {:?}). \
96 Use raw sqlx::query against c.pool().as_mysql()/as_sqlite() for now.",
97 self.driver()
98 )
99 })
100 }
101
102 pub async fn execute(&self, sql: &str) -> Result<u64, Error> {
104 Ok(match self {
105 Pool::Postgres(p) => sqlx::query(sql).execute(p).await?.rows_affected(),
106 Pool::MySql(p) => sqlx::query(sql).execute(p).await?.rows_affected(),
107 Pool::Sqlite(p) => sqlx::query(sql).execute(p).await?.rows_affected(),
108 })
109 }
110}
111
112impl<'a> From<&'a Pool> for Option<&'a sqlx::PgPool> {
118 fn from(pool: &'a Pool) -> Self {
119 pool.as_postgres()
120 }
121}
122
123pub async fn connect(url: &str, max_connections: u32) -> Result<Pool, Error> {
125 let driver = Driver::from_url(url)?;
126 match driver {
127 Driver::Postgres => {
128 let pool = sqlx::postgres::PgPoolOptions::new()
129 .max_connections(max_connections)
130 .connect(url)
131 .await?;
132 Ok(Pool::Postgres(pool))
133 }
134 Driver::MySql => {
135 let pool = sqlx::mysql::MySqlPoolOptions::new()
136 .max_connections(max_connections)
137 .connect(url)
138 .await?;
139 Ok(Pool::MySql(pool))
140 }
141 Driver::Sqlite => {
142 use sqlx::ConnectOptions;
145 use std::str::FromStr;
146 let opts = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
147 .create_if_missing(true)
148 .log_statements(tracing::log::LevelFilter::Debug);
149 let pool = sqlx::sqlite::SqlitePoolOptions::new()
150 .max_connections(max_connections.max(1))
151 .connect_with(opts)
152 .await?;
153 Ok(Pool::Sqlite(pool))
154 }
155 }
156}
157
158#[derive(Clone)]
160pub struct Connection {
161 pub name: String,
162 pub write: Pool,
163 pub reads: Vec<Pool>,
164}
165
166impl Connection {
167 pub fn driver(&self) -> Driver {
168 self.write.driver()
169 }
170
171 pub fn writer(&self) -> &Pool {
172 &self.write
173 }
174
175 pub fn reader(&self) -> &Pool {
176 if self.reads.is_empty() {
177 &self.write
178 } else {
179 use std::sync::atomic::{AtomicUsize, Ordering};
180 static CURSOR: AtomicUsize = AtomicUsize::new(0);
181 let idx = CURSOR.fetch_add(1, Ordering::Relaxed) % self.reads.len();
182 &self.reads[idx]
183 }
184 }
185}
186
187#[derive(Clone)]
189pub struct ConnectionManager {
190 inner: Arc<ManagerInner>,
191}
192
193struct ManagerInner {
194 default: String,
195 connections: RwLock<HashMap<String, Connection>>,
196}
197
198impl ConnectionManager {
199 pub fn from_pool(pool: Pool) -> Self {
200 let mut map = HashMap::new();
201 map.insert(
202 "default".to_string(),
203 Connection {
204 name: "default".to_string(),
205 write: pool,
206 reads: Vec::new(),
207 },
208 );
209 Self {
210 inner: Arc::new(ManagerInner {
211 default: "default".to_string(),
212 connections: RwLock::new(map),
213 }),
214 }
215 }
216
217 pub fn from_connections(
218 default: impl Into<String>,
219 connections: HashMap<String, Connection>,
220 ) -> Self {
221 Self {
222 inner: Arc::new(ManagerInner {
223 default: default.into(),
224 connections: RwLock::new(connections),
225 }),
226 }
227 }
228
229 pub fn get(&self, name: &str) -> Option<Connection> {
230 self.inner.connections.read().get(name).cloned()
231 }
232
233 pub fn default_connection(&self) -> Connection {
234 let map = self.inner.connections.read();
235 map.get(&self.inner.default)
236 .or_else(|| map.values().next())
237 .cloned()
238 .expect("no connections configured")
239 }
240
241 pub fn default_pool(&self) -> Pool {
242 self.default_connection().write
243 }
244
245 pub fn default_driver(&self) -> Driver {
246 self.default_pool().driver()
247 }
248
249 pub fn pool(&self, name: &str) -> Option<Pool> {
250 self.get(name).map(|c| c.write)
251 }
252
253 pub fn insert(&self, conn: Connection) {
254 self.inner
255 .connections
256 .write()
257 .insert(conn.name.clone(), conn);
258 }
259
260 pub fn names(&self) -> Vec<String> {
261 self.inner.connections.read().keys().cloned().collect()
262 }
263
264 pub fn default_name(&self) -> &str {
265 &self.inner.default
266 }
267}