Skip to main content

dbrest_sqlite/
executor.rs

1//! SQLite backend executor — implements [`DatabaseBackend`] for `sqlx::SqlitePool`.
2
3use std::time::Duration;
4
5use async_trait::async_trait;
6use sqlx::sqlite::SqlitePoolOptions;
7use sqlx::{Column, Row};
8
9use dbrest_core::backend::{DatabaseBackend, DbVersion, StatementResult};
10use dbrest_core::error::Error;
11use dbrest_core::query::sql_builder::{SqlBuilder, SqlParam};
12use dbrest_core::schema_cache::db::DbIntrospector;
13
14use crate::introspector::SqliteIntrospector;
15
16/// SQLite backend backed by `sqlx::SqlitePool`.
17pub struct SqliteBackend {
18    pool: sqlx::SqlitePool,
19}
20
21impl SqliteBackend {
22    /// Get a reference to the underlying pool.
23    pub fn pool(&self) -> &sqlx::SqlitePool {
24        &self.pool
25    }
26
27    /// Create from an existing pool (useful for tests).
28    pub fn from_pool(pool: sqlx::SqlitePool) -> Self {
29        Self { pool }
30    }
31
32    /// Ensure the session vars temp table exists on a connection.
33    async fn ensure_vars_table(conn: &mut sqlx::SqliteConnection) -> Result<(), Error> {
34        sqlx::query("CREATE TEMP TABLE IF NOT EXISTS _dbrest_vars(key TEXT PRIMARY KEY, val TEXT)")
35            .execute(&mut *conn)
36            .await
37            .map_err(map_sqlx_error)?;
38        Ok(())
39    }
40}
41
42// --------------------------------------------------------------------------
43// Helper: bind SqlParam values to a sqlx query
44// --------------------------------------------------------------------------
45
46fn bind_params<'q>(
47    mut q: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
48    params: &'q [SqlParam],
49) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
50    for p in params {
51        match p {
52            SqlParam::Text(t) => q = q.bind(t.as_str()),
53            SqlParam::Json(j) => q = q.bind(String::from_utf8_lossy(j).into_owned()),
54            SqlParam::Binary(b) => q = q.bind(b.to_vec()),
55            SqlParam::Null => q = q.bind(Option::<String>::None),
56        }
57    }
58    q
59}
60
61/// Map a `sqlx::Error` to our `Error` type for SQLite.
62pub fn map_sqlx_error(e: sqlx::Error) -> Error {
63    let (code, message) = match &e {
64        sqlx::Error::Database(db_err) => {
65            let code = db_err.code().map(|c| c.to_string());
66            let message = db_err.message().to_string();
67            (code, message)
68        }
69        _ => {
70            return Error::Database {
71                code: None,
72                message: e.to_string(),
73                detail: None,
74                hint: None,
75            };
76        }
77    };
78
79    // Map common SQLite error codes
80    match code.as_deref() {
81        // UNIQUE constraint
82        Some("2067") | Some("1555") => Error::UniqueViolation(message),
83        // FOREIGN KEY constraint
84        Some("787") => Error::ForeignKeyViolation(message),
85        // CHECK constraint
86        Some("275") => Error::CheckViolation(message),
87        // NOT NULL constraint
88        Some("1299") => Error::NotNullViolation(message),
89        _ => Error::Database {
90            code,
91            message,
92            detail: None,
93            hint: None,
94        },
95    }
96}
97
98// --------------------------------------------------------------------------
99// Parse the standard 5-column result set
100// --------------------------------------------------------------------------
101
102fn parse_statement_row(row: &sqlx::sqlite::SqliteRow) -> StatementResult {
103    let total: Option<i64> = row
104        .try_get::<String, _>("total_result_set")
105        .ok()
106        .and_then(|s| s.parse::<i64>().ok());
107
108    let page_total: i64 = row.try_get("page_total").unwrap_or(0);
109
110    let body_str: String = row.try_get("body").unwrap_or_else(|_| "[]".to_string());
111
112    let response_headers: Option<serde_json::Value> = row
113        .try_get::<Option<String>, _>("response_headers")
114        .ok()
115        .flatten()
116        .and_then(|s| {
117            if s.is_empty() {
118                None
119            } else {
120                serde_json::from_str(&s).ok()
121            }
122        });
123
124    let response_status: Option<i32> = row
125        .try_get::<Option<String>, _>("response_status")
126        .ok()
127        .flatten()
128        .and_then(|s| {
129            if s.is_empty() {
130                None
131            } else {
132                s.parse::<i32>().ok()
133            }
134        });
135
136    StatementResult {
137        total,
138        page_total,
139        body: body_str,
140        response_headers,
141        response_status,
142    }
143}
144
145// --------------------------------------------------------------------------
146// DatabaseBackend implementation
147// --------------------------------------------------------------------------
148
149#[async_trait]
150impl DatabaseBackend for SqliteBackend {
151    async fn connect(
152        uri: &str,
153        pool_size: u32,
154        acquire_timeout_secs: u64,
155        max_lifetime_secs: u64,
156        idle_timeout_secs: u64,
157    ) -> Result<Self, Error> {
158        let pool = SqlitePoolOptions::new()
159            .max_connections(pool_size)
160            .acquire_timeout(Duration::from_secs(acquire_timeout_secs))
161            .max_lifetime(Duration::from_secs(max_lifetime_secs))
162            .idle_timeout(Duration::from_secs(idle_timeout_secs))
163            .connect(uri)
164            .await
165            .map_err(|e| Error::DbConnection(e.to_string()))?;
166
167        // Enable WAL mode and foreign keys for better concurrency
168        sqlx::query("PRAGMA journal_mode=WAL")
169            .execute(&pool)
170            .await
171            .map_err(map_sqlx_error)?;
172        sqlx::query("PRAGMA foreign_keys=ON")
173            .execute(&pool)
174            .await
175            .map_err(map_sqlx_error)?;
176
177        Ok(Self { pool })
178    }
179
180    async fn version(&self) -> Result<DbVersion, Error> {
181        let row: (String,) = sqlx::query_as("SELECT sqlite_version()")
182            .fetch_one(&self.pool)
183            .await
184            .map_err(|e| Error::DbConnection(format!("Failed to query SQLite version: {}", e)))?;
185
186        let version_str = &row.0;
187        let parts: Vec<&str> = version_str.split('.').collect();
188        Ok(DbVersion {
189            major: parts.first().and_then(|s| s.parse().ok()).unwrap_or(0),
190            minor: parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0),
191            patch: parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0),
192            engine: "SQLite".to_string(),
193        })
194    }
195
196    fn min_version(&self) -> (u32, u32) {
197        // Require SQLite 3.35+ for RETURNING support
198        (3, 35)
199    }
200
201    async fn exec_raw(&self, sql: &str, params: &[SqlParam]) -> Result<(), Error> {
202        let q = sqlx::query(sql);
203        let q = bind_params(q, params);
204        q.execute(&self.pool).await.map_err(map_sqlx_error)?;
205        Ok(())
206    }
207
208    async fn exec_statement(
209        &self,
210        sql: &str,
211        params: &[SqlParam],
212    ) -> Result<StatementResult, Error> {
213        let q = sqlx::query(sql);
214        let q = bind_params(q, params);
215        let rows = q.fetch_all(&self.pool).await.map_err(map_sqlx_error)?;
216
217        if rows.is_empty() {
218            return Ok(StatementResult::empty());
219        }
220
221        Ok(parse_statement_row(&rows[0]))
222    }
223
224    async fn exec_in_transaction(
225        &self,
226        tx_vars: Option<&SqlBuilder>,
227        pre_req: Option<&SqlBuilder>,
228        mutation: Option<&SqlBuilder>,
229        main: Option<&SqlBuilder>,
230    ) -> Result<StatementResult, Error> {
231        let mut tx = self.pool.begin().await.map_err(|e| Error::Database {
232            code: None,
233            message: e.to_string(),
234            detail: None,
235            hint: None,
236        })?;
237
238        // Ensure the session vars temp table exists
239        Self::ensure_vars_table(&mut tx).await?;
240
241        // 1. Set session variables
242        if let Some(tv) = tx_vars {
243            let q = sqlx::query(tv.sql());
244            let q = bind_params(q, tv.params());
245            q.execute(&mut *tx).await.map_err(map_sqlx_error)?;
246        }
247
248        // 2. Call pre-request function
249        if let Some(pr) = pre_req {
250            let q = sqlx::query(pr.sql());
251            let q = bind_params(q, pr.params());
252            q.execute(&mut *tx).await.map_err(map_sqlx_error)?;
253        }
254
255        // 3. If there's a split mutation, execute it and bridge results via temp table
256        if let Some(mut_q) = mutation {
257            // Execute mutation with RETURNING and collect rows
258            let q = sqlx::query(mut_q.sql());
259            let q = bind_params(q, mut_q.params());
260            let rows = q.fetch_all(&mut *tx).await.map_err(map_sqlx_error)?;
261
262            // Create temp table and insert RETURNING rows
263            // We need to know column names/count — extract from first row
264            if !rows.is_empty() {
265                let ncols = rows[0].len();
266                // Build CREATE TEMP TABLE with generic column names
267                // Then use the actual column names from the row metadata
268                let columns: Vec<String> = (0..ncols)
269                    .map(|i| rows[0].column(i).name().to_string())
270                    .collect();
271
272                let mut create_sql = String::from("CREATE TEMP TABLE IF NOT EXISTS _dbrst_mut(");
273                for (i, col) in columns.iter().enumerate() {
274                    if i > 0 {
275                        create_sql.push_str(", ");
276                    }
277                    create_sql.push('"');
278                    create_sql.push_str(&col.replace('"', "\"\""));
279                    create_sql.push_str("\" TEXT");
280                }
281                create_sql.push(')');
282                sqlx::query(&create_sql)
283                    .execute(&mut *tx)
284                    .await
285                    .map_err(map_sqlx_error)?;
286
287                // Insert each row
288                for row in &rows {
289                    let mut insert_sql = String::from("INSERT INTO _dbrst_mut VALUES(");
290                    for i in 0..ncols {
291                        if i > 0 {
292                            insert_sql.push_str(", ");
293                        }
294                        insert_sql.push('?');
295                    }
296                    insert_sql.push(')');
297
298                    let mut q = sqlx::query(&insert_sql);
299                    for i in 0..ncols {
300                        // Try to get as string; fall back to NULL
301                        let val: Option<String> = row.try_get(i).ok();
302                        q = q.bind(val);
303                    }
304                    q.execute(&mut *tx).await.map_err(map_sqlx_error)?;
305                }
306            } else {
307                // No rows returned — still create the temp table with a dummy schema
308                sqlx::query("CREATE TEMP TABLE IF NOT EXISTS _dbrst_mut(__dummy TEXT)")
309                    .execute(&mut *tx)
310                    .await
311                    .map_err(map_sqlx_error)?;
312            }
313        }
314
315        // 4. Execute the main query (aggregation SELECT)
316        let result = if let Some(main_q) = main {
317            let q = sqlx::query(main_q.sql());
318            let q = bind_params(q, main_q.params());
319            let rows = q.fetch_all(&mut *tx).await.map_err(map_sqlx_error)?;
320
321            if rows.is_empty() {
322                StatementResult::empty()
323            } else {
324                parse_statement_row(&rows[0])
325            }
326        } else {
327            StatementResult::empty()
328        };
329
330        // 5. Clean up temp table if we created one
331        if mutation.is_some() {
332            let _ = sqlx::query("DROP TABLE IF EXISTS _dbrst_mut")
333                .execute(&mut *tx)
334                .await;
335        }
336
337        tx.commit().await.map_err(|e| Error::Database {
338            code: None,
339            message: e.to_string(),
340            detail: None,
341            hint: None,
342        })?;
343
344        Ok(result)
345    }
346
347    fn introspector(&self) -> Box<dyn DbIntrospector + '_> {
348        Box::new(SqliteIntrospector::new(&self.pool))
349    }
350
351    async fn start_listener(
352        &self,
353        _channel: &str,
354        _cancel: tokio::sync::watch::Receiver<bool>,
355        _on_event: std::sync::Arc<dyn Fn(String) + Send + Sync>,
356    ) -> Result<(), Error> {
357        // SQLite has no LISTEN/NOTIFY mechanism.
358        // We simply return Ok — schema reload must be triggered differently
359        // (e.g., by file watch or timer in the caller).
360        tracing::info!("SQLite does not support LISTEN/NOTIFY — schema change listener disabled");
361        Ok(())
362    }
363
364    fn map_error(&self, err: Box<dyn std::error::Error + Send + Sync>) -> Error {
365        if let Ok(sqlx_err) = err.downcast::<sqlx::Error>() {
366            map_sqlx_error(*sqlx_err)
367        } else {
368            Error::Internal("Unknown database error".to_string())
369        }
370    }
371}