1use std::time::Duration;
4
5use async_trait::async_trait;
6use sqlx::sqlite::SqlitePoolOptions;
7use sqlx::{Column, Row};
8
9use dbrest_core::backend::{DatabaseBackend, DbVersion, StatementResult};
10use dbrest_core::error::Error;
11use dbrest_core::query::sql_builder::{SqlBuilder, SqlParam};
12use dbrest_core::schema_cache::db::DbIntrospector;
13
14use crate::introspector::SqliteIntrospector;
15
16pub struct SqliteBackend {
18 pool: sqlx::SqlitePool,
19}
20
21impl SqliteBackend {
22 pub fn pool(&self) -> &sqlx::SqlitePool {
24 &self.pool
25 }
26
27 pub fn from_pool(pool: sqlx::SqlitePool) -> Self {
29 Self { pool }
30 }
31
32 async fn ensure_vars_table(conn: &mut sqlx::SqliteConnection) -> Result<(), Error> {
34 sqlx::query("CREATE TEMP TABLE IF NOT EXISTS _dbrest_vars(key TEXT PRIMARY KEY, val TEXT)")
35 .execute(&mut *conn)
36 .await
37 .map_err(map_sqlx_error)?;
38 Ok(())
39 }
40}
41
42fn bind_params<'q>(
47 mut q: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
48 params: &'q [SqlParam],
49) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
50 for p in params {
51 match p {
52 SqlParam::Text(t) => q = q.bind(t.as_str()),
53 SqlParam::Json(j) => q = q.bind(String::from_utf8_lossy(j).into_owned()),
54 SqlParam::Binary(b) => q = q.bind(b.to_vec()),
55 SqlParam::Null => q = q.bind(Option::<String>::None),
56 }
57 }
58 q
59}
60
61pub fn map_sqlx_error(e: sqlx::Error) -> Error {
63 let (code, message) = match &e {
64 sqlx::Error::Database(db_err) => {
65 let code = db_err.code().map(|c| c.to_string());
66 let message = db_err.message().to_string();
67 (code, message)
68 }
69 _ => {
70 return Error::Database {
71 code: None,
72 message: e.to_string(),
73 detail: None,
74 hint: None,
75 };
76 }
77 };
78
79 match code.as_deref() {
81 Some("2067") | Some("1555") => Error::UniqueViolation(message),
83 Some("787") => Error::ForeignKeyViolation(message),
85 Some("275") => Error::CheckViolation(message),
87 Some("1299") => Error::NotNullViolation(message),
89 _ => Error::Database {
90 code,
91 message,
92 detail: None,
93 hint: None,
94 },
95 }
96}
97
98fn parse_statement_row(row: &sqlx::sqlite::SqliteRow) -> StatementResult {
103 let total: Option<i64> = row
104 .try_get::<String, _>("total_result_set")
105 .ok()
106 .and_then(|s| s.parse::<i64>().ok());
107
108 let page_total: i64 = row.try_get("page_total").unwrap_or(0);
109
110 let body_str: String = row.try_get("body").unwrap_or_else(|_| "[]".to_string());
111
112 let response_headers: Option<serde_json::Value> = row
113 .try_get::<Option<String>, _>("response_headers")
114 .ok()
115 .flatten()
116 .and_then(|s| {
117 if s.is_empty() {
118 None
119 } else {
120 serde_json::from_str(&s).ok()
121 }
122 });
123
124 let response_status: Option<i32> = row
125 .try_get::<Option<String>, _>("response_status")
126 .ok()
127 .flatten()
128 .and_then(|s| {
129 if s.is_empty() {
130 None
131 } else {
132 s.parse::<i32>().ok()
133 }
134 });
135
136 StatementResult {
137 total,
138 page_total,
139 body: body_str,
140 response_headers,
141 response_status,
142 }
143}
144
145#[async_trait]
150impl DatabaseBackend for SqliteBackend {
151 async fn connect(
152 uri: &str,
153 pool_size: u32,
154 acquire_timeout_secs: u64,
155 max_lifetime_secs: u64,
156 idle_timeout_secs: u64,
157 ) -> Result<Self, Error> {
158 let pool = SqlitePoolOptions::new()
159 .max_connections(pool_size)
160 .acquire_timeout(Duration::from_secs(acquire_timeout_secs))
161 .max_lifetime(Duration::from_secs(max_lifetime_secs))
162 .idle_timeout(Duration::from_secs(idle_timeout_secs))
163 .connect(uri)
164 .await
165 .map_err(|e| Error::DbConnection(e.to_string()))?;
166
167 sqlx::query("PRAGMA journal_mode=WAL")
169 .execute(&pool)
170 .await
171 .map_err(map_sqlx_error)?;
172 sqlx::query("PRAGMA foreign_keys=ON")
173 .execute(&pool)
174 .await
175 .map_err(map_sqlx_error)?;
176
177 Ok(Self { pool })
178 }
179
180 async fn version(&self) -> Result<DbVersion, Error> {
181 let row: (String,) = sqlx::query_as("SELECT sqlite_version()")
182 .fetch_one(&self.pool)
183 .await
184 .map_err(|e| Error::DbConnection(format!("Failed to query SQLite version: {}", e)))?;
185
186 let version_str = &row.0;
187 let parts: Vec<&str> = version_str.split('.').collect();
188 Ok(DbVersion {
189 major: parts.first().and_then(|s| s.parse().ok()).unwrap_or(0),
190 minor: parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0),
191 patch: parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0),
192 engine: "SQLite".to_string(),
193 })
194 }
195
196 fn min_version(&self) -> (u32, u32) {
197 (3, 35)
199 }
200
201 async fn exec_raw(&self, sql: &str, params: &[SqlParam]) -> Result<(), Error> {
202 let q = sqlx::query(sql);
203 let q = bind_params(q, params);
204 q.execute(&self.pool).await.map_err(map_sqlx_error)?;
205 Ok(())
206 }
207
208 async fn exec_statement(
209 &self,
210 sql: &str,
211 params: &[SqlParam],
212 ) -> Result<StatementResult, Error> {
213 let q = sqlx::query(sql);
214 let q = bind_params(q, params);
215 let rows = q.fetch_all(&self.pool).await.map_err(map_sqlx_error)?;
216
217 if rows.is_empty() {
218 return Ok(StatementResult::empty());
219 }
220
221 Ok(parse_statement_row(&rows[0]))
222 }
223
224 async fn exec_in_transaction(
225 &self,
226 tx_vars: Option<&SqlBuilder>,
227 pre_req: Option<&SqlBuilder>,
228 mutation: Option<&SqlBuilder>,
229 main: Option<&SqlBuilder>,
230 ) -> Result<StatementResult, Error> {
231 let mut tx = self.pool.begin().await.map_err(|e| Error::Database {
232 code: None,
233 message: e.to_string(),
234 detail: None,
235 hint: None,
236 })?;
237
238 Self::ensure_vars_table(&mut tx).await?;
240
241 if let Some(tv) = tx_vars {
243 let q = sqlx::query(tv.sql());
244 let q = bind_params(q, tv.params());
245 q.execute(&mut *tx).await.map_err(map_sqlx_error)?;
246 }
247
248 if let Some(pr) = pre_req {
250 let q = sqlx::query(pr.sql());
251 let q = bind_params(q, pr.params());
252 q.execute(&mut *tx).await.map_err(map_sqlx_error)?;
253 }
254
255 if let Some(mut_q) = mutation {
257 let q = sqlx::query(mut_q.sql());
259 let q = bind_params(q, mut_q.params());
260 let rows = q.fetch_all(&mut *tx).await.map_err(map_sqlx_error)?;
261
262 if !rows.is_empty() {
265 let ncols = rows[0].len();
266 let columns: Vec<String> = (0..ncols)
269 .map(|i| rows[0].column(i).name().to_string())
270 .collect();
271
272 let mut create_sql = String::from("CREATE TEMP TABLE IF NOT EXISTS _dbrst_mut(");
273 for (i, col) in columns.iter().enumerate() {
274 if i > 0 {
275 create_sql.push_str(", ");
276 }
277 create_sql.push('"');
278 create_sql.push_str(&col.replace('"', "\"\""));
279 create_sql.push_str("\" TEXT");
280 }
281 create_sql.push(')');
282 sqlx::query(&create_sql)
283 .execute(&mut *tx)
284 .await
285 .map_err(map_sqlx_error)?;
286
287 for row in &rows {
289 let mut insert_sql = String::from("INSERT INTO _dbrst_mut VALUES(");
290 for i in 0..ncols {
291 if i > 0 {
292 insert_sql.push_str(", ");
293 }
294 insert_sql.push('?');
295 }
296 insert_sql.push(')');
297
298 let mut q = sqlx::query(&insert_sql);
299 for i in 0..ncols {
300 let val: Option<String> = row.try_get(i).ok();
302 q = q.bind(val);
303 }
304 q.execute(&mut *tx).await.map_err(map_sqlx_error)?;
305 }
306 } else {
307 sqlx::query("CREATE TEMP TABLE IF NOT EXISTS _dbrst_mut(__dummy TEXT)")
309 .execute(&mut *tx)
310 .await
311 .map_err(map_sqlx_error)?;
312 }
313 }
314
315 let result = if let Some(main_q) = main {
317 let q = sqlx::query(main_q.sql());
318 let q = bind_params(q, main_q.params());
319 let rows = q.fetch_all(&mut *tx).await.map_err(map_sqlx_error)?;
320
321 if rows.is_empty() {
322 StatementResult::empty()
323 } else {
324 parse_statement_row(&rows[0])
325 }
326 } else {
327 StatementResult::empty()
328 };
329
330 if mutation.is_some() {
332 let _ = sqlx::query("DROP TABLE IF EXISTS _dbrst_mut")
333 .execute(&mut *tx)
334 .await;
335 }
336
337 tx.commit().await.map_err(|e| Error::Database {
338 code: None,
339 message: e.to_string(),
340 detail: None,
341 hint: None,
342 })?;
343
344 Ok(result)
345 }
346
347 fn introspector(&self) -> Box<dyn DbIntrospector + '_> {
348 Box::new(SqliteIntrospector::new(&self.pool))
349 }
350
351 async fn start_listener(
352 &self,
353 _channel: &str,
354 _cancel: tokio::sync::watch::Receiver<bool>,
355 _on_event: std::sync::Arc<dyn Fn(String) + Send + Sync>,
356 ) -> Result<(), Error> {
357 tracing::info!("SQLite does not support LISTEN/NOTIFY — schema change listener disabled");
361 Ok(())
362 }
363
364 fn map_error(&self, err: Box<dyn std::error::Error + Send + Sync>) -> Error {
365 if let Ok(sqlx_err) = err.downcast::<sqlx::Error>() {
366 map_sqlx_error(*sqlx_err)
367 } else {
368 Error::Internal("Unknown database error".to_string())
369 }
370 }
371}