rdbc_sqlite3/
sqlite3.rs

//! sqlite3 C API wrapper module.
//!
3use std::{
4    ffi::{c_void, CStr, CString},
5    os::raw::c_char,
6    ptr::{null, null_mut},
7    slice::from_raw_parts,
8};
9
10use super::error;
11
12use sqlite3_sys::*;
13
14use anyhow::{Ok, Result};
15
16use rdbc_rs::driver::{self, callback::BoxedCallback, RDBCError};
17
18pub fn colunm_decltype(
19    stmt: *mut sqlite3_stmt,
20    i: i32,
21) -> (driver::ColumnType, String, Option<u64>) {
22    let decltype = unsafe { CStr::from_ptr(sqlite3_column_decltype(stmt, i)) }.to_string_lossy();
23
24    match decltype.as_ref() {
25        "INT" | "INTEGER" | "TINYINT" | "SMALLINT" | "MEDIUMINT" | "BIGINT"
26        | "UNSIGNED BIG INT" | "INT2" | "INT8" => {
27            (driver::ColumnType::I64, decltype.to_string(), Some(8))
28        }
29        "CHARACTER(20)"
30        | "VARCHAR(255)"
31        | "VARYING CHARACTER(255)"
32        | "NCHAR(55)"
33        | "NATIVE CHARACTER(70)"
34        | "NVARCHAR(100)"
35        | "TEXT"
36        | "CLOB" => (driver::ColumnType::String, decltype.to_string(), None),
37        "BLOB" => (driver::ColumnType::Bytes, decltype.to_string(), None),
38        "REAL" | "DOUBLE" | "DOUBLE PRECISION" | "FLOAT" => {
39            (driver::ColumnType::F64, decltype.to_string(), Some(8))
40        }
41        _ => (driver::ColumnType::String, decltype.to_string(), None),
42    }
43}
44
45pub fn stmt_sql(stmt: *mut sqlite3_stmt) -> String {
46    unsafe {
47        CStr::from_ptr(sqlite3_expanded_sql(stmt))
48            .to_string_lossy()
49            .to_owned()
50            .to_string()
51    }
52}
53
54pub fn stmt_original_sql(stmt: *mut sqlite3_stmt) -> String {
55    unsafe {
56        CStr::from_ptr(sqlite3_sql(stmt))
57            .to_string_lossy()
58            .to_owned()
59            .to_string()
60    }
61}
62
63pub struct Sqlite3Driver {}
64
65impl driver::Driver for Sqlite3Driver {
66    fn open(&mut self, name: &str) -> Result<Box<dyn driver::Connection>> {
67        let conn = Connection::new(name)?;
68
69        Ok(Box::new(conn))
70    }
71}
72
73/// sqlite connection object
74pub struct Connection {
75    db: *mut sqlite3,
76    _id: String,
77}
78
79unsafe impl Send for Connection {}
80
81impl Connection {
82    fn new(name: &str) -> Result<Self> {
83        unsafe {
84            assert!(
85                sqlite3_threadsafe() != 0,
86                "Sqlite3 must be compiled in thread safe mode."
87            );
88        }
89
90        let mut db = std::ptr::null_mut();
91
92        let flags =
93            SQLITE_OPEN_URI | SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE | SQLITE_OPEN_NOMUTEX;
94
95        log::trace!("open sqlite3 database: {} {:X}", name, flags);
96
97        let c_name = CString::new(name)?;
98
99        unsafe {
100            let r = sqlite3_open_v2(c_name.as_ptr(), &mut db, flags, std::ptr::null());
101
102            if r != SQLITE_OK {
103                let e = if db.is_null() {
104                    error::native_error(r, format!("open sqlite {} failure", name))
105                } else {
106                    let e = error::db_native_error(db, r);
107
108                    let r = sqlite3_close(db); // ignore result .
109
110                    // debug output
111                    if r != SQLITE_OK {
112                        log::error!("close sqlite3 conn failed: code({})", r);
113                    }
114
115                    e
116                };
117
118                return Err(e);
119            } else {
120                log::trace!("create connection {:?}", db);
121
122                let c_str = CString::new("PRAGMA foreign_keys = ON;").unwrap();
123
124                let rc = sqlite3_exec(
125                    db,
126                    c_str.as_ptr(),
127                    None,
128                    null_mut::<c_void>(),
129                    null_mut::<*mut i8>(),
130                );
131
132                if rc != SQLITE_OK {
133                    return Err(error::db_native_error(db, rc));
134                }
135
136                return Ok(Self {
137                    db,
138                    _id: format!("{:?}", db),
139                });
140            }
141        }
142    }
143
144    fn _begin(&mut self) -> anyhow::Result<Box<dyn driver::Transaction>> {
145        let rc = unsafe {
146            let c_str = CString::new("BEGIN").unwrap();
147
148            sqlite3_exec(
149                self.db,
150                c_str.as_ptr(),
151                None,
152                null_mut::<c_void>(),
153                null_mut::<*mut i8>(),
154            )
155        };
156
157        if rc != SQLITE_OK {
158            return Err(error::db_native_error(self.db, rc));
159        }
160
161        Ok(Box::new(Transaction {
162            conn: Connection {
163                db: self.db,
164                _id: self._id.clone(),
165            },
166            finished: false,
167            id: uuid::Uuid::new_v4().to_string(), // Use the randomly generated uuid as tx id
168        }))
169    }
170
171    fn _prepare(&mut self, query: String) -> Result<Box<dyn driver::Statement>> {
172        let sqlite3_query = CString::new(query.clone())?;
173
174        let mut stmt = null_mut();
175
176        let rc = unsafe {
177            sqlite3_prepare_v2(
178                self.db,
179                sqlite3_query.as_ptr(),
180                sqlite3_query.as_bytes().len() as i32,
181                &mut stmt,
182                null_mut::<*const c_char>(),
183            )
184        };
185
186        if rc != SQLITE_OK {
187            return Err(error::error_with_sql(self.db, rc, &query));
188        }
189
190        // If the input text contains no SQL (if the input is an empty string or a comment) then *ppStmt is set to NULL.
191        if stmt.is_null() {
192            return Err(anyhow::anyhow!("invalid input sql {}", query));
193        }
194
195        Ok(Box::new(Statement {
196            db: self.db,
197            stmt,
198            id: format!("{:?}", stmt),
199        }))
200    }
201}
202
203impl driver::Connection for Connection {
204    fn conn_status(&self) -> driver::ConnStatus {
205        driver::ConnStatus::Connected
206    }
207
208    fn id(&self) -> &str {
209        &self._id
210    }
211
212    fn begin(&mut self, callback: BoxedCallback<Box<dyn driver::Transaction>>) {
213        callback.invoke(self._begin())
214    }
215
216    fn prepare(&mut self, query: String, callback: BoxedCallback<Box<dyn driver::Statement>>) {
217        callback.invoke(self._prepare(query))
218    }
219}
220
221impl Drop for Connection {
222    fn drop(&mut self) {
223        if !self.db.is_null() {
224            log::trace!("drop connection {:?}", self.db);
225
226            let r = unsafe { sqlite3_close(self.db) }; // ignore result .
227
228            self.db = std::ptr::null_mut(); // set db ptr to null to preventing twice drop
229
230            // debug output
231            if r != SQLITE_OK {
232                log::error!("close sqlite3 conn failed: code({})", r);
233            }
234        }
235    }
236}
237
/// A prepared SQLite statement.
///
/// The wrapped `sqlite3_stmt` is finalized when the value is dropped.
pub struct Statement {
    // Handle of the connection the statement was prepared on; used for error
    // reporting and for reading the last insert rowid / change count.
    db: *mut sqlite3,
    // Raw prepared statement handle.
    stmt: *mut sqlite3_stmt,
    // Statement identifier; set to the debug formatting of the statement
    // handle address when the statement is prepared.
    pub id: String,
}

// NOTE(review): `db` aliases the handle owned by a `Connection`; sending the
// statement to another thread presumably relies on callers not using both
// concurrently — confirm.
unsafe impl Send for Statement {}
245
246fn get_bind_index(stmt: *mut sqlite3_stmt, pos: driver::ArgName) -> anyhow::Result<i32> {
247    let index = match &pos {
248        driver::ArgName::Offset(index) => *index as i32,
249        driver::ArgName::String(name) => {
250            let c_named = CString::new(name.as_str())?;
251            unsafe { sqlite3_bind_parameter_index(stmt, c_named.as_ptr()) }
252        }
253    };
254
255    if index == 0 {
256        return Err(anyhow::format_err!(
257            "arg name({:?}) not found, {}",
258            pos,
259            stmt_original_sql(stmt),
260        ));
261    }
262
263    return Ok(index);
264}
265
266impl Statement {
267    unsafe fn bind_args(&mut self, args: Vec<rdbc_rs::driver::Argument>) -> anyhow::Result<()> {
268        // sqlite3_clear_bindings(self.stmt);
269        sqlite3_reset(self.stmt);
270
271        log::trace!("execute sql {} with args {:?}", stmt_sql(self.stmt), args);
272
273        for arg in args {
274            let index = get_bind_index(self.stmt, arg.name)?;
275
276            let rc = match arg.value {
277                driver::ArgValue::Bytes(bytes) => {
278                    let ptr = bytes.as_ptr();
279                    let len = bytes.len();
280                    sqlite3_bind_blob(
281                        self.stmt,
282                        index,
283                        ptr as *const c_void,
284                        len as i32,
285                        Some(std::mem::transmute(SQLITE_TRANSIENT as usize)),
286                    )
287                }
288                driver::ArgValue::F64(f64) => sqlite3_bind_double(self.stmt, index, f64),
289
290                driver::ArgValue::I64(i64) => sqlite3_bind_int64(self.stmt, index, i64),
291
292                driver::ArgValue::String(str) => {
293                    let str = CString::new(str)?;
294
295                    let ptr = str.as_ptr();
296                    let len = str.as_bytes().len() as i32;
297
298                    sqlite3_bind_text(
299                        self.stmt,
300                        index,
301                        ptr,
302                        len,
303                        Some(std::mem::transmute(SQLITE_TRANSIENT as usize)),
304                    )
305                }
306
307                driver::ArgValue::Null => SQLITE_OK,
308            };
309
310            if rc != SQLITE_OK {
311                return Err(error::db_native_error(self.db, rc));
312            }
313        }
314
315        Ok(())
316    }
317
318    fn _query(&mut self, args: Vec<rdbc_rs::Argument>) -> Result<Box<dyn driver::Rows>> {
319        unsafe { self.bind_args(args) }?;
320
321        return Ok(Box::new(Rows {
322            db: self.db,
323            stmt: self.stmt,
324            columns: None,
325            has_next: false,
326            id: uuid::Uuid::new_v4().to_string(),
327        }));
328    }
329}
330
331impl driver::Statement for Statement {
332    fn execute(
333        &mut self,
334        args: Vec<rdbc_rs::Argument>,
335        callback: BoxedCallback<driver::ExecResult>,
336    ) {
337        let exec = || {
338            unsafe { self.bind_args(args) }?;
339
340            let rc = unsafe { sqlite3_step(self.stmt) };
341
342            // unsafe { sqlite3_reset(self.stmt) };
343
344            match rc {
345                SQLITE_DONE => {
346                    let last_insert_id = unsafe { sqlite3_last_insert_rowid(self.db) } as u64;
347                    let raws_affected = unsafe { sqlite3_changes(self.db) } as u64;
348
349                    return Ok(driver::ExecResult {
350                        last_insert_id,
351                        raws_affected,
352                    });
353                }
354                SQLITE_ROW => {
355                    return Err(anyhow::Error::new(driver::RDBCError::UnexpectRows));
356                }
357                _ => {
358                    return Err(error::db_native_error(self.db, rc));
359                }
360            };
361        };
362
363        callback.invoke(exec())
364    }
365
366    fn num_input(&self, callback: BoxedCallback<Option<usize>>) {
367        callback.invoke(Ok(Some(unsafe {
368            sqlite3_bind_parameter_count(self.stmt) as usize
369        })))
370    }
371
372    fn query(
373        &mut self,
374        args: Vec<rdbc_rs::Argument>,
375        callback: BoxedCallback<Box<dyn driver::Rows>>,
376    ) {
377        callback.invoke(self._query(args))
378    }
379}
380
381impl Drop for Statement {
382    fn drop(&mut self) {
383        if !self.stmt.is_null() {
384            log::trace!("drop stmt: {}", stmt_sql(self.stmt));
385            unsafe { sqlite3_finalize(self.stmt) };
386            self.stmt = null_mut();
387        }
388    }
389}
390
391pub struct Transaction {
392    conn: Connection,
393    finished: bool,
394    pub id: String,
395}
396
397impl Transaction {
398    fn _rollback(&self) -> anyhow::Result<()> {
399        let rc = unsafe {
400            let c_str = CString::new("ROLLBACK").unwrap();
401
402            sqlite3_exec(
403                self.conn.db,
404                c_str.as_ptr(),
405                None,
406                null_mut::<c_void>(),
407                null_mut::<*mut i8>(),
408            )
409        };
410
411        if rc != SQLITE_OK {
412            return Err(error::error_with_sql(self.conn.db, rc, "ROLLBACK"));
413        }
414
415        Ok(())
416    }
417}
418
419impl driver::Transaction for Transaction {
420    fn commit(&mut self, callback: BoxedCallback<()>) {
421        let mut invoke = || {
422            let rc = unsafe {
423                let c_str = CString::new("COMMIT").unwrap();
424
425                sqlite3_exec(
426                    self.conn.db,
427                    c_str.as_ptr(),
428                    None,
429                    null_mut::<c_void>(),
430                    null_mut::<*mut i8>(),
431                )
432            };
433
434            self.finished = true;
435
436            if rc != SQLITE_OK {
437                Err(error::error_with_sql(self.conn.db, rc, "COMMIT"))
438            } else {
439                Ok(())
440            }
441        };
442
443        callback.invoke(invoke());
444    }
445
446    fn prepare(&mut self, query: String, callback: BoxedCallback<Box<dyn driver::Statement>>) {
447        use driver::Connection;
448
449        self.conn.prepare(query, callback)
450    }
451
452    fn rollback(&mut self, callback: BoxedCallback<()>) {
453        let mut invoke = || {
454            self.finished = true;
455
456            self._rollback()
457        };
458
459        callback.invoke(invoke())
460    }
461}
462
463impl Drop for Transaction {
464    fn drop(&mut self) {
465        // default to rollback all stmt .
466        if !self.finished {
467            _ = self._rollback();
468            self.finished = true;
469        }
470
471        self.conn.db = null_mut();
472    }
473}
474
475pub struct Rows {
476    db: *mut sqlite3,
477    stmt: *mut sqlite3_stmt,
478    columns: Option<Vec<driver::Column>>,
479    has_next: bool,
480    pub id: String,
481}
482
483impl Rows {
484    fn _columns(&mut self) -> Result<&Vec<driver::Column>> {
485        if self.columns.is_none() {
486            let mut columns = vec![];
487
488            unsafe {
489                let count = sqlite3_column_count(self.stmt);
490
491                for i in 0..count {
492                    let name = sqlite3_column_name(self.stmt, i);
493
494                    let (_, decltype, len) = colunm_decltype(self.stmt, i);
495
496                    columns.push(driver::Column {
497                        column_index: i as u64,
498                        column_name: CStr::from_ptr(name).to_string_lossy().to_string(),
499                        column_decltype: decltype,
500                        column_decltype_len: len,
501                    })
502                }
503            };
504
505            self.columns = Some(columns);
506        }
507
508        Ok(self.columns.as_ref().unwrap())
509    }
510
511    fn _get(
512        &mut self,
513        name: driver::ArgName,
514        column_type: driver::ColumnType,
515    ) -> Result<Option<driver::ArgValue>> {
516        log::trace!(
517            "{} :get column({:?},{:?})",
518            stmt_sql(self.stmt),
519            name,
520            column_type
521        );
522
523        let index = match name {
524            driver::ArgName::Offset(index) => index as i32,
525            driver::ArgName::String(name) => {
526                let columns = self._columns()?;
527
528                let col = columns
529                    .iter()
530                    .find(|column| column.column_name.to_uppercase() == name.to_uppercase())
531                    .map(|c| c.column_index as i32);
532
533                if let Some(index) = col {
534                    index
535                } else {
536                    return Ok(None);
537                }
538            }
539        };
540
541        let max_index = unsafe { sqlite3_column_count(self.stmt) };
542
543        if index >= max_index {
544            return Err(anyhow::Error::new(RDBCError::OutOfRange(index as u64)));
545        }
546
547        if !self.has_next {
548            return Err(anyhow::Error::new(RDBCError::NextDataError));
549        }
550
551        let value = unsafe {
552            match column_type {
553                driver::ColumnType::Bytes => {
554                    let len = sqlite3_column_bytes(self.stmt, index);
555                    let data = sqlite3_column_blob(self.stmt, index) as *const u8;
556                    let data = from_raw_parts(data, len as usize).to_owned();
557
558                    driver::ArgValue::Bytes(data)
559                }
560                driver::ColumnType::I64 => {
561                    driver::ArgValue::I64(sqlite3_column_int64(self.stmt, index))
562                }
563                driver::ColumnType::F64 => {
564                    driver::ArgValue::F64(sqlite3_column_double(self.stmt, index))
565                }
566                driver::ColumnType::String => {
567                    let data = sqlite3_column_text(self.stmt, index) as *const i8;
568
569                    if data != null() {
570                        driver::ArgValue::String(CStr::from_ptr(data).to_string_lossy().to_string())
571                    } else {
572                        driver::ArgValue::String("".to_owned())
573                    }
574                }
575                driver::ColumnType::Null => driver::ArgValue::Null,
576            }
577        };
578
579        Ok(Some(value))
580    }
581
582    fn _next(&mut self) -> Result<bool> {
583        match unsafe { sqlite3_step(self.stmt) } {
584            SQLITE_DONE => {
585                self.has_next = false;
586                Ok(false)
587            }
588
589            SQLITE_ROW => {
590                self.has_next = true;
591                Ok(true)
592            }
593
594            rc => {
595                self.has_next = false;
596                Err(error::db_native_error(self.db, rc))
597            }
598        }
599    }
600}
601
602unsafe impl Send for Rows {}
603
604impl driver::Rows for Rows {
605    fn colunms(&mut self, callback: BoxedCallback<Vec<driver::Column>>) {
606        callback.invoke(self._columns().map(|c| c.clone()))
607    }
608
609    fn next(&mut self, callback: BoxedCallback<bool>) {
610        callback.invoke(self._next())
611    }
612
613    fn get(
614        &mut self,
615        name: driver::ArgName,
616        column_type: driver::ColumnType,
617        callback: BoxedCallback<Option<driver::ArgValue>>,
618    ) {
619        callback.invoke(self._get(name, column_type))
620    }
621}
622
623impl Drop for Rows {
624    fn drop(&mut self) {
625        log::trace!("reset stmt {}", stmt_sql(self.stmt));
626        unsafe { sqlite3_reset(self.stmt) };
627    }
628}