rasi_sqlite3/
lib.rs

1use std::{
2    ffi::{CStr, CString},
3    io::{Error, ErrorKind, Result},
4    os::raw::c_void,
5    ptr::{null, null_mut},
6    slice::from_raw_parts,
7    str::{from_utf8_unchecked, FromStr},
8    sync::{
9        atomic::{AtomicBool, Ordering},
10        Arc,
11    },
12    task::Poll,
13};
14
15use rasi::rdbc::*;
16use sqlite3_sys as ffi;
17
/// The `sqlite3` RDBC driver; every `create_connection` call opens a new sqlite3 handle.
struct Sqlite3Driver;
19
20unsafe fn db_error(db: *mut ffi::sqlite3) -> Error {
21    Error::new(
22        ErrorKind::Other,
23        format!(
24            "sqlite3: code={}, error={}",
25            ffi::sqlite3_errcode(db),
26            from_utf8_unchecked(CStr::from_ptr(ffi::sqlite3_errmsg(db)).to_bytes())
27        ),
28    )
29}
30
31impl syscall::Driver for Sqlite3Driver {
32    fn create_connection(
33        &self,
34        driver_name: &str,
35        source_name: &str,
36    ) -> std::io::Result<Connection> {
37        let mut db = null_mut();
38
39        unsafe {
40            let rc = ffi::sqlite3_open_v2(
41                CString::new(source_name)?.as_ptr(),
42                &mut db,
43                ffi::SQLITE_OPEN_CREATE
44                    | ffi::SQLITE_OPEN_READWRITE
45                    | ffi::SQLITE_OPEN_URI
46                    | ffi::SQLITE_OPEN_FULLMUTEX,
47                null_mut(),
48            );
49
50            if rc != ffi::SQLITE_OK {
51                return Err(db_error(db));
52            }
53        }
54
55        let conn = Sqlite3Conn(Arc::new(RawConn(db)));
56
57        Ok((driver_name.to_owned(), conn).into())
58    }
59}
60
61struct RawConn(*mut ffi::sqlite3);
62
63unsafe impl Send for RawConn {}
64unsafe impl Sync for RawConn {}
65
66impl Drop for RawConn {
67    fn drop(&mut self) {
68        unsafe {
69            ffi::sqlite3_close(self.0);
70        }
71    }
72}
73
74struct RawStmt(*mut ffi::sqlite3_stmt);
75
76unsafe impl Send for RawStmt {}
77unsafe impl Sync for RawStmt {}
78
79impl Drop for RawStmt {
80    fn drop(&mut self) {
81        unsafe {
82            ffi::sqlite3_finalize(self.0);
83        }
84    }
85}
86
/// A connection produced by the `sqlite3` driver.
///
/// The underlying handle is reference-counted: transactions, prepared statements,
/// queries and rows each hold a clone of the `Arc`, so the handle stays open until all
/// of them have been dropped.
struct Sqlite3Conn(Arc<RawConn>);
88
89fn exec(conn: &RawConn, sql: &CStr) -> Result<()> {
90    unsafe {
91        let rc = ffi::sqlite3_exec(conn.0, sql.as_ptr(), None, null_mut(), null_mut());
92
93        if rc != ffi::SQLITE_OK {
94            return Err(db_error(conn.0));
95        }
96    }
97
98    Ok(())
99}
100
101fn prepare(conn: Arc<RawConn>, sql: &CStr) -> Result<Prepare> {
102    let mut c_stmt = null_mut();
103
104    unsafe {
105        let rc = ffi::sqlite3_prepare_v2(conn.0, sql.as_ptr(), -1, &mut c_stmt, null_mut());
106
107        if rc != ffi::SQLITE_OK {
108            return Err(db_error(conn.0));
109        }
110    }
111
112    Ok(Sqlite3Prepare {
113        conn,
114        stmt: Arc::new(RawStmt(c_stmt)),
115    }
116    .into())
117}
118
119impl syscall::DriverConn for Sqlite3Conn {
120    fn poll_ready(&self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<std::io::Result<()>> {
121        Poll::Ready(Ok(()))
122    }
123
124    fn begin(&self) -> std::io::Result<Transaction> {
125        exec(&self.0, c"BEGIN;")?;
126
127        Ok(Sqlite3Tx(self.0.clone(), AtomicBool::new(false)).into())
128    }
129
130    fn prepare(&self, query: &str) -> std::io::Result<Prepare> {
131        prepare(self.0.clone(), CString::new(query)?.as_ref())
132    }
133
134    fn exec(&self, query: &str, params: &[SqlParameter<'_>]) -> std::io::Result<Update> {
135        self.prepare(query)?.as_driver_query().exec(params)
136    }
137
138    fn query(&self, query: &str, params: &[SqlParameter<'_>]) -> std::io::Result<Query> {
139        self.prepare(query)?.as_driver_query().query(params)
140    }
141}
142
143struct Sqlite3Tx(Arc<RawConn>, AtomicBool);
144
145impl Drop for Sqlite3Tx {
146    fn drop(&mut self) {
147        if self
148            .1
149            .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
150            .is_ok()
151        {
152            if let Err(err) = exec(&self.0, c"COMMIT;") {
153                log::error!(target:"Sqlite3Tx","auto commit failed, {}",err);
154            }
155        }
156    }
157}
158
159impl syscall::DriverTx for Sqlite3Tx {
160    fn poll_ready(&self, _cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
161        Poll::Ready(Ok(()))
162    }
163
164    fn poll_rollback(&self, _cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
165        if self
166            .1
167            .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
168            .is_ok()
169        {
170            exec(&self.0, c"ROLLBACK;")?;
171
172            Poll::Ready(Ok(()))
173        } else {
174            Poll::Ready(Err(Error::new(ErrorKind::Other, "Call rollback twice")))
175        }
176    }
177
178    fn prepare(&self, query: &str) -> Result<Prepare> {
179        prepare(self.0.clone(), CString::new(query)?.as_ref())
180    }
181
182    fn exec(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Update> {
183        self.prepare(query)?.as_driver_query().exec(params)
184    }
185
186    fn query(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Query> {
187        self.prepare(query)?.as_driver_query().query(params)
188    }
189}
190
191struct Sqlite3Prepare {
192    conn: Arc<RawConn>,
193    stmt: Arc<RawStmt>,
194}
195
196impl syscall::DriverPrepare for Sqlite3Prepare {
197    fn poll_ready(&self, _cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
198        Poll::Ready(Ok(()))
199    }
200
201    fn exec(&self, params: &[SqlParameter<'_>]) -> Result<Update> {
202        self.bind_params(params)?;
203
204        let rc = unsafe { ffi::sqlite3_step(self.stmt.0) };
205
206        match rc {
207            ffi::SQLITE_DONE => {
208                let last_insert_id = unsafe { ffi::sqlite3_last_insert_rowid(self.conn.0) } as i64;
209                let raws_affected = unsafe { ffi::sqlite3_changes(self.conn.0) } as i64;
210
211                return Ok(Sqlite3Update(last_insert_id, raws_affected).into());
212            }
213            ffi::SQLITE_ROW => {
214                return Err(Error::new(
215                    ErrorKind::Unsupported,
216                    "Call exec on query statement.",
217                ))
218            }
219            _ => return Err(unsafe { db_error(self.conn.0) }),
220        }
221    }
222
223    fn query(&self, params: &[SqlParameter<'_>]) -> Result<Query> {
224        self.bind_params(params)?;
225
226        Ok(Sqlite3Query {
227            conn: self.conn.clone(),
228            stmt: self.stmt.clone(),
229        }
230        .into())
231    }
232}
233
234impl Sqlite3Prepare {
235    fn bind_params(&self, params: &[SqlParameter]) -> Result<()> {
236        unsafe {
237            if ffi::SQLITE_OK != ffi::sqlite3_reset(self.stmt.0) {
238                return Err(db_error(self.conn.0));
239            }
240        }
241
242        let mut named_params = 0;
243
244        for (index, param) in params.iter().enumerate() {
245            let (index, value) = match param {
246                SqlParameter::Named(name, value) => unsafe {
247                    let index = ffi::sqlite3_bind_parameter_index(
248                        self.stmt.0,
249                        CString::new(name.as_ref())?.as_ptr(),
250                    );
251
252                    if index == 0 {
253                        return Err(Error::new(
254                            ErrorKind::NotFound,
255                            format!("no matching parameter is found: {}", name),
256                        ));
257                    }
258
259                    named_params += 1;
260
261                    (index, value)
262                },
263                SqlParameter::Offset(value) => (index as i32 + 1 - named_params, value),
264            };
265
266            let rc = match value {
267                SqlValue::Bool(value) => {
268                    let value = if *value { 1 } else { 0 };
269
270                    unsafe { ffi::sqlite3_bind_int(self.stmt.0, index, value) }
271                }
272                SqlValue::Int(value) => unsafe {
273                    ffi::sqlite3_bind_int64(self.stmt.0, index, *value)
274                },
275                SqlValue::BigInt(value) => unsafe {
276                    let value = CString::new(format!("{value}"))?.as_ptr();
277
278                    ffi::sqlite3_bind_text(
279                        self.stmt.0,
280                        index,
281                        value,
282                        -1,
283                        Some(std::mem::transmute(-1isize)),
284                    )
285                },
286                SqlValue::Float(value) => unsafe {
287                    ffi::sqlite3_bind_double(self.stmt.0, index, *value)
288                },
289
290                SqlValue::Decimal(value) => unsafe {
291                    let value = CString::new(format!("{value}"))?.as_ptr();
292
293                    ffi::sqlite3_bind_text(
294                        self.stmt.0,
295                        index,
296                        value,
297                        -1,
298                        Some(std::mem::transmute(-1isize)),
299                    )
300                },
301                SqlValue::Binary(value) => unsafe {
302                    ffi::sqlite3_bind_blob(
303                        self.stmt.0,
304                        index,
305                        value.as_ptr() as *const c_void,
306                        value.len() as i32,
307                        Some(std::mem::transmute(-1isize)),
308                    )
309                },
310                SqlValue::String(value) => unsafe {
311                    ffi::sqlite3_bind_text(
312                        self.stmt.0,
313                        index,
314                        CString::new(value.as_ref())?.as_ptr(),
315                        -1,
316                        Some(std::mem::transmute(-1isize)),
317                    )
318                },
319                SqlValue::Null => unsafe { ffi::sqlite3_bind_null(self.stmt.0, index) },
320            };
321
322            if rc != ffi::SQLITE_OK {
323                return Err(unsafe { db_error(self.conn.0) });
324            }
325        }
326
327        Ok(())
328    }
329}
330
331struct Sqlite3Update(i64, i64);
332
333impl syscall::DriverUpdate for Sqlite3Update {
334    fn poll_ready(&self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(i64, i64)>> {
335        Poll::Ready(Ok((self.0, self.1)))
336    }
337}
338
339struct Sqlite3Query {
340    conn: Arc<RawConn>,
341    stmt: Arc<RawStmt>,
342}
343
344impl syscall::DriverQuery for Sqlite3Query {
345    fn poll_next(&self, _cx: &mut std::task::Context<'_>) -> Poll<Result<Option<Row>>> {
346        unsafe {
347            match ffi::sqlite3_step(self.stmt.0) {
348                ffi::SQLITE_DONE => Poll::Ready(Ok(None)),
349
350                ffi::SQLITE_ROW => Poll::Ready(Ok(Some(
351                    Sqlite3Row {
352                        conn: self.conn.clone(),
353                        stmt: self.stmt.clone(),
354                    }
355                    .into(),
356                ))),
357
358                _ => Poll::Ready(Err(db_error(self.conn.0))),
359            }
360        }
361    }
362}
363
364impl syscall::DriverTableMetadata for Sqlite3Query {
365    fn cols(&self) -> Result<usize> {
366        let count = unsafe { ffi::sqlite3_column_count(self.stmt.0) };
367
368        Ok(count as usize)
369    }
370
371    fn col_name(&self, offset: usize) -> Result<&str> {
372        unsafe {
373            let name = ffi::sqlite3_column_name(self.stmt.0, offset as i32);
374
375            Ok(from_utf8_unchecked(CStr::from_ptr(name).to_bytes()))
376        }
377    }
378
379    fn col_type(&self, _offset: usize) -> Result<Option<SqlType>> {
380        Ok(None)
381    }
382
383    fn col_size(&self, _offset: usize) -> Result<Option<usize>> {
384        Ok(None)
385    }
386}
387
388struct Sqlite3Row {
389    #[allow(unused)]
390    conn: Arc<RawConn>,
391    stmt: Arc<RawStmt>,
392}
393
394impl syscall::DriverRow for Sqlite3Row {
395    fn get(&self, index: usize, sql_type: &SqlType) -> Result<SqlValue<'static>> {
396        let col = index as i32;
397
398        match sql_type {
399            SqlType::Bool => unsafe {
400                if 1 == ffi::sqlite3_column_int(self.stmt.0, col) {
401                    Ok(SqlValue::Bool(true))
402                } else {
403                    Ok(SqlValue::Bool(false))
404                }
405            },
406            SqlType::Int => unsafe {
407                Ok(SqlValue::Int(ffi::sqlite3_column_int64(self.stmt.0, col)))
408            },
409            SqlType::BigInt => unsafe {
410                let data = ffi::sqlite3_column_text(self.stmt.0, col) as *const i8;
411
412                if data != null() {
413                    let value = from_utf8_unchecked(CStr::from_ptr(data).to_bytes());
414
415                    Ok(SqlValue::BigInt(value.parse().map_err(|err| {
416                        Error::new(
417                            ErrorKind::InvalidData,
418                            format!(
419                                "Convert column value({}) to BigInt with error: {}",
420                                value, err
421                            ),
422                        )
423                    })?))
424                } else {
425                    Ok(SqlValue::Null)
426                }
427            },
428            SqlType::Float => unsafe {
429                Ok(SqlValue::Float(ffi::sqlite3_column_double(
430                    self.stmt.0,
431                    col,
432                )))
433            },
434
435            SqlType::Decimal => unsafe {
436                let data = ffi::sqlite3_column_text(self.stmt.0, col) as *const i8;
437
438                if data != null() {
439                    let value = from_utf8_unchecked(CStr::from_ptr(data).to_bytes());
440
441                    Ok(SqlValue::Decimal(BigDecimal::from_str(value).map_err(
442                        |err| {
443                            Error::new(
444                                ErrorKind::InvalidData,
445                                format!(
446                                    "Convert column value({}) to Decimal with error: {}",
447                                    value, err
448                                ),
449                            )
450                        },
451                    )?))
452                } else {
453                    Ok(SqlValue::Null)
454                }
455            },
456            SqlType::Binary => unsafe {
457                let len = ffi::sqlite3_column_bytes(self.stmt.0, col);
458                let data = ffi::sqlite3_column_blob(self.stmt.0, col) as *const u8;
459                let data = from_raw_parts(data, len as usize).to_owned();
460
461                Ok(SqlValue::Binary(data.into()))
462            },
463            SqlType::String => unsafe {
464                let data = ffi::sqlite3_column_text(self.stmt.0, col) as *const i8;
465
466                if data != null() {
467                    let value = CStr::from_ptr(data);
468
469                    Ok(SqlValue::String(
470                        from_utf8_unchecked(value.to_bytes()).into(),
471                    ))
472                } else {
473                    Ok(SqlValue::Null)
474                }
475            },
476            SqlType::Null => Err(Error::new(
477                ErrorKind::InvalidInput,
478                "Call result get with SqlType::Null",
479            )),
480        }
481    }
482}
483/// Register sqlite3 database driver with name `sqlite3`.
484pub fn register_sqlite3() {
485    register_rdbc_driver("sqlite3", Sqlite3Driver).unwrap();
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the shared `rasi_spec` RDBC conformance suite against the sqlite3 driver.
    ///
    /// An empty source name makes SQLite create a private, temporary database for each
    /// connection the suite opens.
    #[futures_test::test]
    async fn test_sqlite3_spec() {
        register_sqlite3();

        let connect = || async { open("sqlite3", "").await.unwrap() };

        rasi_spec::rdbc::run(connect).await;
    }
}