1use std::{
2 ffi::{CStr, CString},
3 io::{Error, ErrorKind, Result},
4 os::raw::c_void,
5 ptr::{null, null_mut},
6 slice::from_raw_parts,
7 str::{from_utf8_unchecked, FromStr},
8 sync::{
9 atomic::{AtomicBool, Ordering},
10 Arc,
11 },
12 task::Poll,
13};
14
15use rasi::rdbc::*;
16use sqlite3_sys as ffi;
17
/// The rdbc driver for SQLite3; `register_sqlite3` registers it under the name `"sqlite3"`.
struct Sqlite3Driver;
19
20unsafe fn db_error(db: *mut ffi::sqlite3) -> Error {
21 Error::new(
22 ErrorKind::Other,
23 format!(
24 "sqlite3: code={}, error={}",
25 ffi::sqlite3_errcode(db),
26 from_utf8_unchecked(CStr::from_ptr(ffi::sqlite3_errmsg(db)).to_bytes())
27 ),
28 )
29}
30
31impl syscall::Driver for Sqlite3Driver {
32 fn create_connection(
33 &self,
34 driver_name: &str,
35 source_name: &str,
36 ) -> std::io::Result<Connection> {
37 let mut db = null_mut();
38
39 unsafe {
40 let rc = ffi::sqlite3_open_v2(
41 CString::new(source_name)?.as_ptr(),
42 &mut db,
43 ffi::SQLITE_OPEN_CREATE
44 | ffi::SQLITE_OPEN_READWRITE
45 | ffi::SQLITE_OPEN_URI
46 | ffi::SQLITE_OPEN_FULLMUTEX,
47 null_mut(),
48 );
49
50 if rc != ffi::SQLITE_OK {
51 return Err(db_error(db));
52 }
53 }
54
55 let conn = Sqlite3Conn(Arc::new(RawConn(db)));
56
57 Ok((driver_name.to_owned(), conn).into())
58 }
59}
60
61struct RawConn(*mut ffi::sqlite3);
62
63unsafe impl Send for RawConn {}
64unsafe impl Sync for RawConn {}
65
66impl Drop for RawConn {
67 fn drop(&mut self) {
68 unsafe {
69 ffi::sqlite3_close(self.0);
70 }
71 }
72}
73
74struct RawStmt(*mut ffi::sqlite3_stmt);
75
76unsafe impl Send for RawStmt {}
77unsafe impl Sync for RawStmt {}
78
79impl Drop for RawStmt {
80 fn drop(&mut self) {
81 unsafe {
82 ffi::sqlite3_finalize(self.0);
83 }
84 }
85}
86
/// Driver-side connection. The `Arc` is cloned into every transaction, prepared statement,
/// query cursor and row created from it, so the handle stays open while any of them is alive.
struct Sqlite3Conn(Arc<RawConn>);
88
89fn exec(conn: &RawConn, sql: &CStr) -> Result<()> {
90 unsafe {
91 let rc = ffi::sqlite3_exec(conn.0, sql.as_ptr(), None, null_mut(), null_mut());
92
93 if rc != ffi::SQLITE_OK {
94 return Err(db_error(conn.0));
95 }
96 }
97
98 Ok(())
99}
100
101fn prepare(conn: Arc<RawConn>, sql: &CStr) -> Result<Prepare> {
102 let mut c_stmt = null_mut();
103
104 unsafe {
105 let rc = ffi::sqlite3_prepare_v2(conn.0, sql.as_ptr(), -1, &mut c_stmt, null_mut());
106
107 if rc != ffi::SQLITE_OK {
108 return Err(db_error(conn.0));
109 }
110 }
111
112 Ok(Sqlite3Prepare {
113 conn,
114 stmt: Arc::new(RawStmt(c_stmt)),
115 }
116 .into())
117}
118
119impl syscall::DriverConn for Sqlite3Conn {
120 fn poll_ready(&self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<std::io::Result<()>> {
121 Poll::Ready(Ok(()))
122 }
123
124 fn begin(&self) -> std::io::Result<Transaction> {
125 exec(&self.0, c"BEGIN;")?;
126
127 Ok(Sqlite3Tx(self.0.clone(), AtomicBool::new(false)).into())
128 }
129
130 fn prepare(&self, query: &str) -> std::io::Result<Prepare> {
131 prepare(self.0.clone(), CString::new(query)?.as_ref())
132 }
133
134 fn exec(&self, query: &str, params: &[SqlParameter<'_>]) -> std::io::Result<Update> {
135 self.prepare(query)?.as_driver_query().exec(params)
136 }
137
138 fn query(&self, query: &str, params: &[SqlParameter<'_>]) -> std::io::Result<Query> {
139 self.prepare(query)?.as_driver_query().query(params)
140 }
141}
142
143struct Sqlite3Tx(Arc<RawConn>, AtomicBool);
144
145impl Drop for Sqlite3Tx {
146 fn drop(&mut self) {
147 if self
148 .1
149 .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
150 .is_ok()
151 {
152 if let Err(err) = exec(&self.0, c"COMMIT;") {
153 log::error!(target:"Sqlite3Tx","auto commit failed, {}",err);
154 }
155 }
156 }
157}
158
159impl syscall::DriverTx for Sqlite3Tx {
160 fn poll_ready(&self, _cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
161 Poll::Ready(Ok(()))
162 }
163
164 fn poll_rollback(&self, _cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
165 if self
166 .1
167 .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
168 .is_ok()
169 {
170 exec(&self.0, c"ROLLBACK;")?;
171
172 Poll::Ready(Ok(()))
173 } else {
174 Poll::Ready(Err(Error::new(ErrorKind::Other, "Call rollback twice")))
175 }
176 }
177
178 fn prepare(&self, query: &str) -> Result<Prepare> {
179 prepare(self.0.clone(), CString::new(query)?.as_ref())
180 }
181
182 fn exec(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Update> {
183 self.prepare(query)?.as_driver_query().exec(params)
184 }
185
186 fn query(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Query> {
187 self.prepare(query)?.as_driver_query().query(params)
188 }
189}
190
191struct Sqlite3Prepare {
192 conn: Arc<RawConn>,
193 stmt: Arc<RawStmt>,
194}
195
196impl syscall::DriverPrepare for Sqlite3Prepare {
197 fn poll_ready(&self, _cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
198 Poll::Ready(Ok(()))
199 }
200
201 fn exec(&self, params: &[SqlParameter<'_>]) -> Result<Update> {
202 self.bind_params(params)?;
203
204 let rc = unsafe { ffi::sqlite3_step(self.stmt.0) };
205
206 match rc {
207 ffi::SQLITE_DONE => {
208 let last_insert_id = unsafe { ffi::sqlite3_last_insert_rowid(self.conn.0) } as i64;
209 let raws_affected = unsafe { ffi::sqlite3_changes(self.conn.0) } as i64;
210
211 return Ok(Sqlite3Update(last_insert_id, raws_affected).into());
212 }
213 ffi::SQLITE_ROW => {
214 return Err(Error::new(
215 ErrorKind::Unsupported,
216 "Call exec on query statement.",
217 ))
218 }
219 _ => return Err(unsafe { db_error(self.conn.0) }),
220 }
221 }
222
223 fn query(&self, params: &[SqlParameter<'_>]) -> Result<Query> {
224 self.bind_params(params)?;
225
226 Ok(Sqlite3Query {
227 conn: self.conn.clone(),
228 stmt: self.stmt.clone(),
229 }
230 .into())
231 }
232}
233
234impl Sqlite3Prepare {
235 fn bind_params(&self, params: &[SqlParameter]) -> Result<()> {
236 unsafe {
237 if ffi::SQLITE_OK != ffi::sqlite3_reset(self.stmt.0) {
238 return Err(db_error(self.conn.0));
239 }
240 }
241
242 let mut named_params = 0;
243
244 for (index, param) in params.iter().enumerate() {
245 let (index, value) = match param {
246 SqlParameter::Named(name, value) => unsafe {
247 let index = ffi::sqlite3_bind_parameter_index(
248 self.stmt.0,
249 CString::new(name.as_ref())?.as_ptr(),
250 );
251
252 if index == 0 {
253 return Err(Error::new(
254 ErrorKind::NotFound,
255 format!("no matching parameter is found: {}", name),
256 ));
257 }
258
259 named_params += 1;
260
261 (index, value)
262 },
263 SqlParameter::Offset(value) => (index as i32 + 1 - named_params, value),
264 };
265
266 let rc = match value {
267 SqlValue::Bool(value) => {
268 let value = if *value { 1 } else { 0 };
269
270 unsafe { ffi::sqlite3_bind_int(self.stmt.0, index, value) }
271 }
272 SqlValue::Int(value) => unsafe {
273 ffi::sqlite3_bind_int64(self.stmt.0, index, *value)
274 },
275 SqlValue::BigInt(value) => unsafe {
276 let value = CString::new(format!("{value}"))?.as_ptr();
277
278 ffi::sqlite3_bind_text(
279 self.stmt.0,
280 index,
281 value,
282 -1,
283 Some(std::mem::transmute(-1isize)),
284 )
285 },
286 SqlValue::Float(value) => unsafe {
287 ffi::sqlite3_bind_double(self.stmt.0, index, *value)
288 },
289
290 SqlValue::Decimal(value) => unsafe {
291 let value = CString::new(format!("{value}"))?.as_ptr();
292
293 ffi::sqlite3_bind_text(
294 self.stmt.0,
295 index,
296 value,
297 -1,
298 Some(std::mem::transmute(-1isize)),
299 )
300 },
301 SqlValue::Binary(value) => unsafe {
302 ffi::sqlite3_bind_blob(
303 self.stmt.0,
304 index,
305 value.as_ptr() as *const c_void,
306 value.len() as i32,
307 Some(std::mem::transmute(-1isize)),
308 )
309 },
310 SqlValue::String(value) => unsafe {
311 ffi::sqlite3_bind_text(
312 self.stmt.0,
313 index,
314 CString::new(value.as_ref())?.as_ptr(),
315 -1,
316 Some(std::mem::transmute(-1isize)),
317 )
318 },
319 SqlValue::Null => unsafe { ffi::sqlite3_bind_null(self.stmt.0, index) },
320 };
321
322 if rc != ffi::SQLITE_OK {
323 return Err(unsafe { db_error(self.conn.0) });
324 }
325 }
326
327 Ok(())
328 }
329}
330
331struct Sqlite3Update(i64, i64);
332
333impl syscall::DriverUpdate for Sqlite3Update {
334 fn poll_ready(&self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(i64, i64)>> {
335 Poll::Ready(Ok((self.0, self.1)))
336 }
337}
338
339struct Sqlite3Query {
340 conn: Arc<RawConn>,
341 stmt: Arc<RawStmt>,
342}
343
344impl syscall::DriverQuery for Sqlite3Query {
345 fn poll_next(&self, _cx: &mut std::task::Context<'_>) -> Poll<Result<Option<Row>>> {
346 unsafe {
347 match ffi::sqlite3_step(self.stmt.0) {
348 ffi::SQLITE_DONE => Poll::Ready(Ok(None)),
349
350 ffi::SQLITE_ROW => Poll::Ready(Ok(Some(
351 Sqlite3Row {
352 conn: self.conn.clone(),
353 stmt: self.stmt.clone(),
354 }
355 .into(),
356 ))),
357
358 _ => Poll::Ready(Err(db_error(self.conn.0))),
359 }
360 }
361 }
362}
363
364impl syscall::DriverTableMetadata for Sqlite3Query {
365 fn cols(&self) -> Result<usize> {
366 let count = unsafe { ffi::sqlite3_column_count(self.stmt.0) };
367
368 Ok(count as usize)
369 }
370
371 fn col_name(&self, offset: usize) -> Result<&str> {
372 unsafe {
373 let name = ffi::sqlite3_column_name(self.stmt.0, offset as i32);
374
375 Ok(from_utf8_unchecked(CStr::from_ptr(name).to_bytes()))
376 }
377 }
378
379 fn col_type(&self, _offset: usize) -> Result<Option<SqlType>> {
380 Ok(None)
381 }
382
383 fn col_size(&self, _offset: usize) -> Result<Option<usize>> {
384 Ok(None)
385 }
386}
387
388struct Sqlite3Row {
389 #[allow(unused)]
390 conn: Arc<RawConn>,
391 stmt: Arc<RawStmt>,
392}
393
394impl syscall::DriverRow for Sqlite3Row {
395 fn get(&self, index: usize, sql_type: &SqlType) -> Result<SqlValue<'static>> {
396 let col = index as i32;
397
398 match sql_type {
399 SqlType::Bool => unsafe {
400 if 1 == ffi::sqlite3_column_int(self.stmt.0, col) {
401 Ok(SqlValue::Bool(true))
402 } else {
403 Ok(SqlValue::Bool(false))
404 }
405 },
406 SqlType::Int => unsafe {
407 Ok(SqlValue::Int(ffi::sqlite3_column_int64(self.stmt.0, col)))
408 },
409 SqlType::BigInt => unsafe {
410 let data = ffi::sqlite3_column_text(self.stmt.0, col) as *const i8;
411
412 if data != null() {
413 let value = from_utf8_unchecked(CStr::from_ptr(data).to_bytes());
414
415 Ok(SqlValue::BigInt(value.parse().map_err(|err| {
416 Error::new(
417 ErrorKind::InvalidData,
418 format!(
419 "Convert column value({}) to BigInt with error: {}",
420 value, err
421 ),
422 )
423 })?))
424 } else {
425 Ok(SqlValue::Null)
426 }
427 },
428 SqlType::Float => unsafe {
429 Ok(SqlValue::Float(ffi::sqlite3_column_double(
430 self.stmt.0,
431 col,
432 )))
433 },
434
435 SqlType::Decimal => unsafe {
436 let data = ffi::sqlite3_column_text(self.stmt.0, col) as *const i8;
437
438 if data != null() {
439 let value = from_utf8_unchecked(CStr::from_ptr(data).to_bytes());
440
441 Ok(SqlValue::Decimal(BigDecimal::from_str(value).map_err(
442 |err| {
443 Error::new(
444 ErrorKind::InvalidData,
445 format!(
446 "Convert column value({}) to Decimal with error: {}",
447 value, err
448 ),
449 )
450 },
451 )?))
452 } else {
453 Ok(SqlValue::Null)
454 }
455 },
456 SqlType::Binary => unsafe {
457 let len = ffi::sqlite3_column_bytes(self.stmt.0, col);
458 let data = ffi::sqlite3_column_blob(self.stmt.0, col) as *const u8;
459 let data = from_raw_parts(data, len as usize).to_owned();
460
461 Ok(SqlValue::Binary(data.into()))
462 },
463 SqlType::String => unsafe {
464 let data = ffi::sqlite3_column_text(self.stmt.0, col) as *const i8;
465
466 if data != null() {
467 let value = CStr::from_ptr(data);
468
469 Ok(SqlValue::String(
470 from_utf8_unchecked(value.to_bytes()).into(),
471 ))
472 } else {
473 Ok(SqlValue::Null)
474 }
475 },
476 SqlType::Null => Err(Error::new(
477 ErrorKind::InvalidInput,
478 "Call result get with SqlType::Null",
479 )),
480 }
481 }
482}
483pub fn register_sqlite3() {
485 register_rdbc_driver("sqlite3", Sqlite3Driver).unwrap();
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the shared `rasi_spec` rdbc suite against connections opened with an empty
    /// source name.
    #[futures_test::test]
    async fn test_sqlite3_spec() {
        register_sqlite3();

        let connect = || async { open("sqlite3", "").await.unwrap() };

        rasi_spec::rdbc::run(connect).await;
    }
}