1use std::{
4 ffi::{c_void, CStr, CString},
5 os::raw::c_char,
6 ptr::{null, null_mut},
7 slice::from_raw_parts,
8};
9
10use super::error;
11
12use sqlite3_sys::*;
13
14use anyhow::{Ok, Result};
15
16use rdbc_rs::driver::{self, callback::BoxedCallback, RDBCError};
17
18pub fn colunm_decltype(
19 stmt: *mut sqlite3_stmt,
20 i: i32,
21) -> (driver::ColumnType, String, Option<u64>) {
22 let decltype = unsafe { CStr::from_ptr(sqlite3_column_decltype(stmt, i)) }.to_string_lossy();
23
24 match decltype.as_ref() {
25 "INT" | "INTEGER" | "TINYINT" | "SMALLINT" | "MEDIUMINT" | "BIGINT"
26 | "UNSIGNED BIG INT" | "INT2" | "INT8" => {
27 (driver::ColumnType::I64, decltype.to_string(), Some(8))
28 }
29 "CHARACTER(20)"
30 | "VARCHAR(255)"
31 | "VARYING CHARACTER(255)"
32 | "NCHAR(55)"
33 | "NATIVE CHARACTER(70)"
34 | "NVARCHAR(100)"
35 | "TEXT"
36 | "CLOB" => (driver::ColumnType::String, decltype.to_string(), None),
37 "BLOB" => (driver::ColumnType::Bytes, decltype.to_string(), None),
38 "REAL" | "DOUBLE" | "DOUBLE PRECISION" | "FLOAT" => {
39 (driver::ColumnType::F64, decltype.to_string(), Some(8))
40 }
41 _ => (driver::ColumnType::String, decltype.to_string(), None),
42 }
43}
44
45pub fn stmt_sql(stmt: *mut sqlite3_stmt) -> String {
46 unsafe {
47 CStr::from_ptr(sqlite3_expanded_sql(stmt))
48 .to_string_lossy()
49 .to_owned()
50 .to_string()
51 }
52}
53
54pub fn stmt_original_sql(stmt: *mut sqlite3_stmt) -> String {
55 unsafe {
56 CStr::from_ptr(sqlite3_sql(stmt))
57 .to_string_lossy()
58 .to_owned()
59 .to_string()
60 }
61}
62
63pub struct Sqlite3Driver {}
64
65impl driver::Driver for Sqlite3Driver {
66 fn open(&mut self, name: &str) -> Result<Box<dyn driver::Connection>> {
67 let conn = Connection::new(name)?;
68
69 Ok(Box::new(conn))
70 }
71}
72
73pub struct Connection {
75 db: *mut sqlite3,
76 _id: String,
77}
78
79unsafe impl Send for Connection {}
80
81impl Connection {
82 fn new(name: &str) -> Result<Self> {
83 unsafe {
84 assert!(
85 sqlite3_threadsafe() != 0,
86 "Sqlite3 must be compiled in thread safe mode."
87 );
88 }
89
90 let mut db = std::ptr::null_mut();
91
92 let flags =
93 SQLITE_OPEN_URI | SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE | SQLITE_OPEN_NOMUTEX;
94
95 log::trace!("open sqlite3 database: {} {:X}", name, flags);
96
97 let c_name = CString::new(name)?;
98
99 unsafe {
100 let r = sqlite3_open_v2(c_name.as_ptr(), &mut db, flags, std::ptr::null());
101
102 if r != SQLITE_OK {
103 let e = if db.is_null() {
104 error::native_error(r, format!("open sqlite {} failure", name))
105 } else {
106 let e = error::db_native_error(db, r);
107
108 let r = sqlite3_close(db); if r != SQLITE_OK {
112 log::error!("close sqlite3 conn failed: code({})", r);
113 }
114
115 e
116 };
117
118 return Err(e);
119 } else {
120 log::trace!("create connection {:?}", db);
121
122 let c_str = CString::new("PRAGMA foreign_keys = ON;").unwrap();
123
124 let rc = sqlite3_exec(
125 db,
126 c_str.as_ptr(),
127 None,
128 null_mut::<c_void>(),
129 null_mut::<*mut i8>(),
130 );
131
132 if rc != SQLITE_OK {
133 return Err(error::db_native_error(db, rc));
134 }
135
136 return Ok(Self {
137 db,
138 _id: format!("{:?}", db),
139 });
140 }
141 }
142 }
143
144 fn _begin(&mut self) -> anyhow::Result<Box<dyn driver::Transaction>> {
145 let rc = unsafe {
146 let c_str = CString::new("BEGIN").unwrap();
147
148 sqlite3_exec(
149 self.db,
150 c_str.as_ptr(),
151 None,
152 null_mut::<c_void>(),
153 null_mut::<*mut i8>(),
154 )
155 };
156
157 if rc != SQLITE_OK {
158 return Err(error::db_native_error(self.db, rc));
159 }
160
161 Ok(Box::new(Transaction {
162 conn: Connection {
163 db: self.db,
164 _id: self._id.clone(),
165 },
166 finished: false,
167 id: uuid::Uuid::new_v4().to_string(), }))
169 }
170
171 fn _prepare(&mut self, query: String) -> Result<Box<dyn driver::Statement>> {
172 let sqlite3_query = CString::new(query.clone())?;
173
174 let mut stmt = null_mut();
175
176 let rc = unsafe {
177 sqlite3_prepare_v2(
178 self.db,
179 sqlite3_query.as_ptr(),
180 sqlite3_query.as_bytes().len() as i32,
181 &mut stmt,
182 null_mut::<*const c_char>(),
183 )
184 };
185
186 if rc != SQLITE_OK {
187 return Err(error::error_with_sql(self.db, rc, &query));
188 }
189
190 if stmt.is_null() {
192 return Err(anyhow::anyhow!("invalid input sql {}", query));
193 }
194
195 Ok(Box::new(Statement {
196 db: self.db,
197 stmt,
198 id: format!("{:?}", stmt),
199 }))
200 }
201}
202
203impl driver::Connection for Connection {
204 fn conn_status(&self) -> driver::ConnStatus {
205 driver::ConnStatus::Connected
206 }
207
208 fn id(&self) -> &str {
209 &self._id
210 }
211
212 fn begin(&mut self, callback: BoxedCallback<Box<dyn driver::Transaction>>) {
213 callback.invoke(self._begin())
214 }
215
216 fn prepare(&mut self, query: String, callback: BoxedCallback<Box<dyn driver::Statement>>) {
217 callback.invoke(self._prepare(query))
218 }
219}
220
221impl Drop for Connection {
222 fn drop(&mut self) {
223 if !self.db.is_null() {
224 log::trace!("drop connection {:?}", self.db);
225
226 let r = unsafe { sqlite3_close(self.db) }; self.db = std::ptr::null_mut(); if r != SQLITE_OK {
232 log::error!("close sqlite3 conn failed: code({})", r);
233 }
234 }
235 }
236}
237
238pub struct Statement {
239 db: *mut sqlite3,
240 stmt: *mut sqlite3_stmt,
241 pub id: String,
242}
243
244unsafe impl Send for Statement {}
245
246fn get_bind_index(stmt: *mut sqlite3_stmt, pos: driver::ArgName) -> anyhow::Result<i32> {
247 let index = match &pos {
248 driver::ArgName::Offset(index) => *index as i32,
249 driver::ArgName::String(name) => {
250 let c_named = CString::new(name.as_str())?;
251 unsafe { sqlite3_bind_parameter_index(stmt, c_named.as_ptr()) }
252 }
253 };
254
255 if index == 0 {
256 return Err(anyhow::format_err!(
257 "arg name({:?}) not found, {}",
258 pos,
259 stmt_original_sql(stmt),
260 ));
261 }
262
263 return Ok(index);
264}
265
266impl Statement {
267 unsafe fn bind_args(&mut self, args: Vec<rdbc_rs::driver::Argument>) -> anyhow::Result<()> {
268 sqlite3_reset(self.stmt);
270
271 log::trace!("execute sql {} with args {:?}", stmt_sql(self.stmt), args);
272
273 for arg in args {
274 let index = get_bind_index(self.stmt, arg.name)?;
275
276 let rc = match arg.value {
277 driver::ArgValue::Bytes(bytes) => {
278 let ptr = bytes.as_ptr();
279 let len = bytes.len();
280 sqlite3_bind_blob(
281 self.stmt,
282 index,
283 ptr as *const c_void,
284 len as i32,
285 Some(std::mem::transmute(SQLITE_TRANSIENT as usize)),
286 )
287 }
288 driver::ArgValue::F64(f64) => sqlite3_bind_double(self.stmt, index, f64),
289
290 driver::ArgValue::I64(i64) => sqlite3_bind_int64(self.stmt, index, i64),
291
292 driver::ArgValue::String(str) => {
293 let str = CString::new(str)?;
294
295 let ptr = str.as_ptr();
296 let len = str.as_bytes().len() as i32;
297
298 sqlite3_bind_text(
299 self.stmt,
300 index,
301 ptr,
302 len,
303 Some(std::mem::transmute(SQLITE_TRANSIENT as usize)),
304 )
305 }
306
307 driver::ArgValue::Null => SQLITE_OK,
308 };
309
310 if rc != SQLITE_OK {
311 return Err(error::db_native_error(self.db, rc));
312 }
313 }
314
315 Ok(())
316 }
317
318 fn _query(&mut self, args: Vec<rdbc_rs::Argument>) -> Result<Box<dyn driver::Rows>> {
319 unsafe { self.bind_args(args) }?;
320
321 return Ok(Box::new(Rows {
322 db: self.db,
323 stmt: self.stmt,
324 columns: None,
325 has_next: false,
326 id: uuid::Uuid::new_v4().to_string(),
327 }));
328 }
329}
330
331impl driver::Statement for Statement {
332 fn execute(
333 &mut self,
334 args: Vec<rdbc_rs::Argument>,
335 callback: BoxedCallback<driver::ExecResult>,
336 ) {
337 let exec = || {
338 unsafe { self.bind_args(args) }?;
339
340 let rc = unsafe { sqlite3_step(self.stmt) };
341
342 match rc {
345 SQLITE_DONE => {
346 let last_insert_id = unsafe { sqlite3_last_insert_rowid(self.db) } as u64;
347 let raws_affected = unsafe { sqlite3_changes(self.db) } as u64;
348
349 return Ok(driver::ExecResult {
350 last_insert_id,
351 raws_affected,
352 });
353 }
354 SQLITE_ROW => {
355 return Err(anyhow::Error::new(driver::RDBCError::UnexpectRows));
356 }
357 _ => {
358 return Err(error::db_native_error(self.db, rc));
359 }
360 };
361 };
362
363 callback.invoke(exec())
364 }
365
366 fn num_input(&self, callback: BoxedCallback<Option<usize>>) {
367 callback.invoke(Ok(Some(unsafe {
368 sqlite3_bind_parameter_count(self.stmt) as usize
369 })))
370 }
371
372 fn query(
373 &mut self,
374 args: Vec<rdbc_rs::Argument>,
375 callback: BoxedCallback<Box<dyn driver::Rows>>,
376 ) {
377 callback.invoke(self._query(args))
378 }
379}
380
381impl Drop for Statement {
382 fn drop(&mut self) {
383 if !self.stmt.is_null() {
384 log::trace!("drop stmt: {}", stmt_sql(self.stmt));
385 unsafe { sqlite3_finalize(self.stmt) };
386 self.stmt = null_mut();
387 }
388 }
389}
390
391pub struct Transaction {
392 conn: Connection,
393 finished: bool,
394 pub id: String,
395}
396
397impl Transaction {
398 fn _rollback(&self) -> anyhow::Result<()> {
399 let rc = unsafe {
400 let c_str = CString::new("ROLLBACK").unwrap();
401
402 sqlite3_exec(
403 self.conn.db,
404 c_str.as_ptr(),
405 None,
406 null_mut::<c_void>(),
407 null_mut::<*mut i8>(),
408 )
409 };
410
411 if rc != SQLITE_OK {
412 return Err(error::error_with_sql(self.conn.db, rc, "ROLLBACK"));
413 }
414
415 Ok(())
416 }
417}
418
419impl driver::Transaction for Transaction {
420 fn commit(&mut self, callback: BoxedCallback<()>) {
421 let mut invoke = || {
422 let rc = unsafe {
423 let c_str = CString::new("COMMIT").unwrap();
424
425 sqlite3_exec(
426 self.conn.db,
427 c_str.as_ptr(),
428 None,
429 null_mut::<c_void>(),
430 null_mut::<*mut i8>(),
431 )
432 };
433
434 self.finished = true;
435
436 if rc != SQLITE_OK {
437 Err(error::error_with_sql(self.conn.db, rc, "COMMIT"))
438 } else {
439 Ok(())
440 }
441 };
442
443 callback.invoke(invoke());
444 }
445
446 fn prepare(&mut self, query: String, callback: BoxedCallback<Box<dyn driver::Statement>>) {
447 use driver::Connection;
448
449 self.conn.prepare(query, callback)
450 }
451
452 fn rollback(&mut self, callback: BoxedCallback<()>) {
453 let mut invoke = || {
454 self.finished = true;
455
456 self._rollback()
457 };
458
459 callback.invoke(invoke())
460 }
461}
462
463impl Drop for Transaction {
464 fn drop(&mut self) {
465 if !self.finished {
467 _ = self._rollback();
468 self.finished = true;
469 }
470
471 self.conn.db = null_mut();
472 }
473}
474
475pub struct Rows {
476 db: *mut sqlite3,
477 stmt: *mut sqlite3_stmt,
478 columns: Option<Vec<driver::Column>>,
479 has_next: bool,
480 pub id: String,
481}
482
483impl Rows {
484 fn _columns(&mut self) -> Result<&Vec<driver::Column>> {
485 if self.columns.is_none() {
486 let mut columns = vec![];
487
488 unsafe {
489 let count = sqlite3_column_count(self.stmt);
490
491 for i in 0..count {
492 let name = sqlite3_column_name(self.stmt, i);
493
494 let (_, decltype, len) = colunm_decltype(self.stmt, i);
495
496 columns.push(driver::Column {
497 column_index: i as u64,
498 column_name: CStr::from_ptr(name).to_string_lossy().to_string(),
499 column_decltype: decltype,
500 column_decltype_len: len,
501 })
502 }
503 };
504
505 self.columns = Some(columns);
506 }
507
508 Ok(self.columns.as_ref().unwrap())
509 }
510
511 fn _get(
512 &mut self,
513 name: driver::ArgName,
514 column_type: driver::ColumnType,
515 ) -> Result<Option<driver::ArgValue>> {
516 log::trace!(
517 "{} :get column({:?},{:?})",
518 stmt_sql(self.stmt),
519 name,
520 column_type
521 );
522
523 let index = match name {
524 driver::ArgName::Offset(index) => index as i32,
525 driver::ArgName::String(name) => {
526 let columns = self._columns()?;
527
528 let col = columns
529 .iter()
530 .find(|column| column.column_name.to_uppercase() == name.to_uppercase())
531 .map(|c| c.column_index as i32);
532
533 if let Some(index) = col {
534 index
535 } else {
536 return Ok(None);
537 }
538 }
539 };
540
541 let max_index = unsafe { sqlite3_column_count(self.stmt) };
542
543 if index >= max_index {
544 return Err(anyhow::Error::new(RDBCError::OutOfRange(index as u64)));
545 }
546
547 if !self.has_next {
548 return Err(anyhow::Error::new(RDBCError::NextDataError));
549 }
550
551 let value = unsafe {
552 match column_type {
553 driver::ColumnType::Bytes => {
554 let len = sqlite3_column_bytes(self.stmt, index);
555 let data = sqlite3_column_blob(self.stmt, index) as *const u8;
556 let data = from_raw_parts(data, len as usize).to_owned();
557
558 driver::ArgValue::Bytes(data)
559 }
560 driver::ColumnType::I64 => {
561 driver::ArgValue::I64(sqlite3_column_int64(self.stmt, index))
562 }
563 driver::ColumnType::F64 => {
564 driver::ArgValue::F64(sqlite3_column_double(self.stmt, index))
565 }
566 driver::ColumnType::String => {
567 let data = sqlite3_column_text(self.stmt, index) as *const i8;
568
569 if data != null() {
570 driver::ArgValue::String(CStr::from_ptr(data).to_string_lossy().to_string())
571 } else {
572 driver::ArgValue::String("".to_owned())
573 }
574 }
575 driver::ColumnType::Null => driver::ArgValue::Null,
576 }
577 };
578
579 Ok(Some(value))
580 }
581
582 fn _next(&mut self) -> Result<bool> {
583 match unsafe { sqlite3_step(self.stmt) } {
584 SQLITE_DONE => {
585 self.has_next = false;
586 Ok(false)
587 }
588
589 SQLITE_ROW => {
590 self.has_next = true;
591 Ok(true)
592 }
593
594 rc => {
595 self.has_next = false;
596 Err(error::db_native_error(self.db, rc))
597 }
598 }
599 }
600}
601
602unsafe impl Send for Rows {}
603
604impl driver::Rows for Rows {
605 fn colunms(&mut self, callback: BoxedCallback<Vec<driver::Column>>) {
606 callback.invoke(self._columns().map(|c| c.clone()))
607 }
608
609 fn next(&mut self, callback: BoxedCallback<bool>) {
610 callback.invoke(self._next())
611 }
612
613 fn get(
614 &mut self,
615 name: driver::ArgName,
616 column_type: driver::ColumnType,
617 callback: BoxedCallback<Option<driver::ArgValue>>,
618 ) {
619 callback.invoke(self._get(name, column_type))
620 }
621}
622
623impl Drop for Rows {
624 fn drop(&mut self) {
625 log::trace!("reset stmt {}", stmt_sql(self.stmt));
626 unsafe { sqlite3_reset(self.stmt) };
627 }
628}