Skip to main content

prax_sqlite/
connection.rs

1//! SQLite connection wrapper.
2
3use std::collections::VecDeque;
4use std::sync::Arc;
5use std::time::Instant;
6
7use parking_lot::Mutex;
8use tokio::sync::OwnedSemaphorePermit;
9use tokio_rusqlite::Connection;
10use tracing::{debug, trace};
11
12use crate::error::{SqliteError, SqliteResult};
13
14/// A pooled connection for returning to the pool.
15pub(crate) struct PooledConnection {
16    /// The underlying connection.
17    pub conn: Connection,
18    /// When this connection was created.
19    pub created_at: Instant,
20    /// When this connection was last used.
21    pub last_used: Instant,
22}
23
24impl PooledConnection {
25    pub fn new(conn: Connection) -> Self {
26        let now = Instant::now();
27        Self {
28            conn,
29            created_at: now,
30            last_used: now,
31        }
32    }
33}
34
/// A wrapper around a SQLite connection.
///
/// Holds a `tokio_rusqlite` connection together with a semaphore permit. When
/// `return_to_pool` is set, `Drop` pushes the connection back onto that shared
/// idle queue instead of dropping it.
///
/// Field order matters: Rust drops fields in declaration order after `Drop::drop`
/// runs, so a connection still held in `conn` is dropped before `permit` is.
pub struct SqliteConnection {
    /// The underlying connection; set to `Some` by the constructors and only ever
    /// taken by `Drop::drop`.
    conn: Option<Connection>,
    // Never read: the permit is only held for the wrapper's lifetime (presumably so
    // the pool's semaphore slot stays occupied until this wrapper is dropped), which
    // is why `dead_code` is allowed here.
    #[allow(dead_code)]
    permit: OwnedSemaphorePermit,
    /// Shared idle queue that `Drop` pushes the connection onto; `None` for
    /// non-pooled wrappers.
    return_to_pool: Option<Arc<Mutex<VecDeque<PooledConnection>>>>,
    /// When this wrapper was created; copied into `PooledConnection::created_at`
    /// when the connection is returned to the pool.
    created_at: Instant,
}
45
46impl SqliteConnection {
47    /// Create a new connection wrapper (non-pooled).
48    pub fn new(conn: Connection, permit: OwnedSemaphorePermit) -> Self {
49        Self {
50            conn: Some(conn),
51            permit,
52            return_to_pool: None,
53            created_at: Instant::now(),
54        }
55    }
56
57    /// Create a new pooled connection wrapper.
58    pub(crate) fn new_pooled(
59        conn: Connection,
60        permit: OwnedSemaphorePermit,
61        return_to_pool: Option<Arc<Mutex<VecDeque<PooledConnection>>>>,
62    ) -> Self {
63        Self {
64            conn: Some(conn),
65            permit,
66            return_to_pool,
67            created_at: Instant::now(),
68        }
69    }
70
71    /// Get the inner connection reference.
72    fn conn(&self) -> &Connection {
73        self.conn.as_ref().expect("Connection already taken")
74    }
75
76    /// Execute a query and return all rows as JSON values.
77    pub async fn query(&self, sql: &str) -> SqliteResult<Vec<serde_json::Value>> {
78        let sql = sql.to_string();
79        debug!(sql = %sql, "Executing query");
80
81        self.conn()
82            .call(move |conn| {
83                let mut stmt = conn.prepare(&sql)?;
84                let columns: Vec<String> =
85                    stmt.column_names().iter().map(|s| s.to_string()).collect();
86
87                let rows = stmt.query_map([], |row| {
88                    let mut map = serde_json::Map::new();
89                    for (i, col) in columns.iter().enumerate() {
90                        let value = crate::types::get_value_at_index(row, i);
91                        map.insert(col.clone(), value);
92                    }
93                    Ok(serde_json::Value::Object(map))
94                })?;
95
96                let results: Result<Vec<_>, _> = rows.collect();
97                Ok(results?)
98            })
99            .await
100            .map_err(SqliteError::from)
101    }
102
103    /// Execute a query with parameters and return all rows.
104    pub async fn query_params(
105        &self,
106        sql: &str,
107        params: Vec<rusqlite::types::Value>,
108    ) -> SqliteResult<Vec<serde_json::Value>> {
109        let sql = sql.to_string();
110        debug!(sql = %sql, "Executing parameterized query");
111
112        self.conn()
113            .call(move |conn| {
114                let mut stmt = conn.prepare(&sql)?;
115                let columns: Vec<String> =
116                    stmt.column_names().iter().map(|s| s.to_string()).collect();
117
118                let params_ref: Vec<&dyn rusqlite::ToSql> =
119                    params.iter().map(|v| v as &dyn rusqlite::ToSql).collect();
120
121                let rows = stmt.query_map(params_ref.as_slice(), |row| {
122                    let mut map = serde_json::Map::new();
123                    for (i, col) in columns.iter().enumerate() {
124                        let value = crate::types::get_value_at_index(row, i);
125                        map.insert(col.clone(), value);
126                    }
127                    Ok(serde_json::Value::Object(map))
128                })?;
129
130                let results: Result<Vec<_>, _> = rows.collect();
131                Ok(results?)
132            })
133            .await
134            .map_err(SqliteError::from)
135    }
136
137    /// Execute a query and return a single row.
138    pub async fn query_one(&self, sql: &str) -> SqliteResult<serde_json::Value> {
139        let sql = sql.to_string();
140        debug!(sql = %sql, "Executing query_one");
141
142        self.conn()
143            .call(move |conn| {
144                let mut stmt = conn.prepare(&sql)?;
145                let columns: Vec<String> =
146                    stmt.column_names().iter().map(|s| s.to_string()).collect();
147
148                Ok(stmt.query_row([], |row| {
149                    let mut map = serde_json::Map::new();
150                    for (i, col) in columns.iter().enumerate() {
151                        let value = crate::types::get_value_at_index(row, i);
152                        map.insert(col.clone(), value);
153                    }
154                    Ok(serde_json::Value::Object(map))
155                })?)
156            })
157            .await
158            .map_err(SqliteError::from)
159    }
160
161    /// Execute a query and return an optional row.
162    pub async fn query_optional(&self, sql: &str) -> SqliteResult<Option<serde_json::Value>> {
163        let sql = sql.to_string();
164        debug!(sql = %sql, "Executing query_optional");
165
166        self.conn()
167            .call(move |conn| {
168                let mut stmt = conn.prepare(&sql)?;
169                let columns: Vec<String> =
170                    stmt.column_names().iter().map(|s| s.to_string()).collect();
171
172                let result = stmt.query_row([], |row| {
173                    let mut map = serde_json::Map::new();
174                    for (i, col) in columns.iter().enumerate() {
175                        let value = crate::types::get_value_at_index(row, i);
176                        map.insert(col.clone(), value);
177                    }
178                    Ok(serde_json::Value::Object(map))
179                });
180
181                match result {
182                    Ok(row) => Ok(Some(row)),
183                    Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
184                    Err(e) => Err(tokio_rusqlite::Error::Rusqlite(e)),
185                }
186            })
187            .await
188            .map_err(SqliteError::from)
189    }
190
191    /// Execute a statement and return the number of affected rows.
192    pub async fn execute(&self, sql: &str) -> SqliteResult<usize> {
193        let sql = sql.to_string();
194        debug!(sql = %sql, "Executing statement");
195
196        self.conn()
197            .call(move |conn| Ok(conn.execute(&sql, [])?))
198            .await
199            .map_err(SqliteError::from)
200    }
201
202    /// Execute a statement with parameters and return the number of affected rows.
203    pub async fn execute_params(
204        &self,
205        sql: &str,
206        params: Vec<rusqlite::types::Value>,
207    ) -> SqliteResult<usize> {
208        let sql = sql.to_string();
209        debug!(sql = %sql, "Executing parameterized statement");
210
211        self.conn()
212            .call(move |conn| {
213                let params_ref: Vec<&dyn rusqlite::ToSql> =
214                    params.iter().map(|v| v as &dyn rusqlite::ToSql).collect();
215                Ok(conn.execute(&sql, params_ref.as_slice())?)
216            })
217            .await
218            .map_err(SqliteError::from)
219    }
220
221    /// Execute a statement and return the last insert rowid.
222    pub async fn execute_insert(&self, sql: &str) -> SqliteResult<i64> {
223        let sql = sql.to_string();
224        debug!(sql = %sql, "Executing insert");
225
226        self.conn()
227            .call(move |conn| {
228                conn.execute(&sql, [])?;
229                Ok(conn.last_insert_rowid())
230            })
231            .await
232            .map_err(SqliteError::from)
233    }
234
235    /// Execute a statement with parameters and return the last insert rowid.
236    pub async fn execute_insert_params(
237        &self,
238        sql: &str,
239        params: Vec<rusqlite::types::Value>,
240    ) -> SqliteResult<i64> {
241        let sql = sql.to_string();
242        debug!(sql = %sql, "Executing parameterized insert");
243
244        self.conn()
245            .call(move |conn| {
246                let params_ref: Vec<&dyn rusqlite::ToSql> =
247                    params.iter().map(|v| v as &dyn rusqlite::ToSql).collect();
248                conn.execute(&sql, params_ref.as_slice())?;
249                Ok(conn.last_insert_rowid())
250            })
251            .await
252            .map_err(SqliteError::from)
253    }
254
255    /// Execute multiple statements in a batch.
256    pub async fn execute_batch(&self, sql: &str) -> SqliteResult<()> {
257        let sql = sql.to_string();
258        debug!(sql = %sql, "Executing batch");
259
260        self.conn()
261            .call(move |conn| Ok(conn.execute_batch(&sql)?))
262            .await
263            .map_err(SqliteError::from)
264    }
265
266    /// Get the inner connection.
267    pub fn inner(&self) -> &Connection {
268        self.conn()
269    }
270}
271
272impl Drop for SqliteConnection {
273    fn drop(&mut self) {
274        // Return the connection to the pool if possible
275        if let Some(pool) = self.return_to_pool.take() {
276            if let Some(conn) = self.conn.take() {
277                trace!("Returning connection to pool");
278                let mut idle: parking_lot::MutexGuard<'_, VecDeque<PooledConnection>> = pool.lock();
279                idle.push_back(PooledConnection {
280                    conn,
281                    created_at: self.created_at,
282                    last_used: Instant::now(),
283                });
284            }
285        }
286        // Otherwise, the connection is just dropped
287    }
288}