tank_sqlite/
connection.rs

1use crate::{
2    CBox, SQLiteDriver, SQLitePrepared, SQLiteTransaction,
3    extract::{extract_name, extract_value},
4};
5use async_stream::try_stream;
6use flume::Sender;
7use libsqlite3_sys::*;
8use std::{
9    borrow::Cow,
10    ffi::{CStr, CString, c_char, c_int},
11    mem, ptr,
12    str::FromStr,
13    sync::{
14        Arc,
15        atomic::{AtomicPtr, Ordering},
16    },
17};
18use tank_core::{
19    AsQuery, Connection, Error, ErrorContext, Executor, Prepared, Query, QueryResult, Result,
20    RowLabeled, RowsAffected, error_message_from_ptr, send_value, stream::Stream, truncate_long,
21};
22use tokio::task::spawn_blocking;
23
24/// Wrapper for a SQLite `sqlite3` connection pointer used by the SQLite driver.
25///
26/// Provides helpers to prepare/execute statements and stream results into `tank_core` result types.
27pub struct SQLiteConnection {
28    pub(crate) connection: CBox<*mut sqlite3>,
29}
30
31impl SQLiteConnection {
32    pub fn last_error(&self) -> String {
33        unsafe {
34            let errcode = sqlite3_errcode(*self.connection);
35            format!(
36                "Error ({errcode}): {}",
37                error_message_from_ptr(&sqlite3_errmsg(*self.connection)),
38            )
39        }
40    }
41
42    pub(crate) fn do_run_prepared(
43        connection: *mut sqlite3,
44        statement: *mut sqlite3_stmt,
45        tx: Sender<Result<QueryResult>>,
46    ) {
47        unsafe {
48            let count = sqlite3_column_count(statement);
49            let labels = match (0..count)
50                .map(|i| extract_name(statement, i))
51                .collect::<Result<Arc<[_]>>>()
52            {
53                Ok(labels) => labels,
54                Err(error) => {
55                    send_value!(tx, Err(error.into()));
56                    return;
57                }
58            };
59            loop {
60                match sqlite3_step(statement) {
61                    SQLITE_BUSY => {
62                        continue;
63                    }
64                    SQLITE_DONE => {
65                        if sqlite3_stmt_readonly(statement) == 0 {
66                            send_value!(
67                                tx,
68                                Ok(QueryResult::Affected(RowsAffected {
69                                    rows_affected: Some(sqlite3_changes64(connection) as _),
70                                    last_affected_id: Some(sqlite3_last_insert_rowid(connection)),
71                                }))
72                            );
73                        }
74                        break;
75                    }
76                    SQLITE_ROW => {
77                        let values = match (0..count)
78                            .map(|i| extract_value(statement, i))
79                            .collect::<Result<_>>()
80                        {
81                            Ok(value) => value,
82                            Err(error) => {
83                                send_value!(tx, Err(error));
84                                return;
85                            }
86                        };
87                        send_value!(
88                            tx,
89                            Ok(QueryResult::Row(RowLabeled {
90                                labels: labels.clone(),
91                                values: values,
92                            }))
93                        )
94                    }
95                    _ => {
96                        send_value!(
97                            tx,
98                            Err(Error::msg(
99                                error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(
100                                    statement,
101                                )))
102                                .to_string(),
103                            ))
104                        );
105                        break;
106                    }
107                }
108            }
109        }
110    }
111
112    pub(crate) fn do_run_unprepared(
113        connection: *mut sqlite3,
114        sql: &str,
115        tx: Sender<Result<QueryResult>>,
116    ) {
117        unsafe {
118            let sql = sql.trim();
119            let mut it = sql.as_ptr() as *const c_char;
120            let mut len = sql.len();
121            loop {
122                let (statement, tail) = {
123                    let mut statement = SQLitePrepared::new(CBox::new(ptr::null_mut(), |p| {
124                        sqlite3_finalize(p);
125                    }));
126                    let mut sql_tail = ptr::null();
127                    let rc = sqlite3_prepare_v2(
128                        connection,
129                        it,
130                        len as c_int,
131                        &mut *statement.statement,
132                        &mut sql_tail,
133                    );
134                    if rc != SQLITE_OK {
135                        send_value!(
136                            tx,
137                            Err(Error::msg(
138                                error_message_from_ptr(&sqlite3_errmsg(connection)).to_string(),
139                            ))
140                        );
141                        return;
142                    }
143                    (statement, sql_tail)
144                };
145                Self::do_run_prepared(connection, statement.statement(), tx.clone());
146                len = if tail != ptr::null() {
147                    len - tail.offset_from_unsigned(it)
148                } else {
149                    0
150                };
151                if len == 0 {
152                    break;
153                }
154                it = tail;
155            }
156        };
157    }
158}
159
160impl Executor for SQLiteConnection {
161    type Driver = SQLiteDriver;
162
163    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
164        let connection = AtomicPtr::new(*self.connection);
165        let context = format!("While preparing the query:\n{}", truncate_long!(sql));
166        let prepared = spawn_blocking(move || unsafe {
167            let connection = connection.load(Ordering::Relaxed);
168            let len = sql.len();
169            let sql = CString::new(sql.into_bytes())?;
170            let mut statement = CBox::new(ptr::null_mut(), |p| {
171                sqlite3_finalize(p);
172            });
173            let mut tail = ptr::null();
174            let rc = sqlite3_prepare_v2(
175                connection,
176                sql.as_ptr(),
177                len as c_int,
178                &mut *statement,
179                &mut tail,
180            );
181            if rc != SQLITE_OK {
182                let error =
183                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
184                        .context(context);
185                log::error!("{:#}", error);
186                return Err(error);
187            }
188            if tail != ptr::null() && *tail != '\0' as i8 {
189                let error = Error::msg(format!(
190                    "Cannot prepare more than one statement at a time (remaining: {})",
191                    CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
192                ))
193                .context(context);
194                log::error!("{:#}", error);
195                return Err(error);
196            }
197            Ok(statement)
198        })
199        .await?;
200        Ok(SQLitePrepared::new(prepared?).into())
201    }
202
203    fn run<'s>(
204        &'s mut self,
205        query: impl AsQuery<Self::Driver> + 's,
206    ) -> impl Stream<Item = Result<QueryResult>> + Send {
207        let mut query = query.as_query();
208        let context = Arc::new(format!("While running the query:\n{}", query.as_mut()));
209        let (tx, rx) = flume::unbounded::<Result<QueryResult>>();
210        let connection = AtomicPtr::new(*self.connection);
211        let mut owned = mem::take(query.as_mut());
212        let join = spawn_blocking(move || {
213            match &mut owned {
214                Query::Raw(query) => {
215                    Self::do_run_unprepared(connection.load(Ordering::Relaxed), &query.sql, tx);
216                }
217                Query::Prepared(prepared) => {
218                    Self::do_run_prepared(
219                        connection.load(Ordering::Relaxed),
220                        prepared.statement(),
221                        tx,
222                    );
223                    let _ = prepared.clear_bindings();
224                }
225            }
226            owned
227        });
228        try_stream! {
229            while let Ok(result) = rx.recv_async().await {
230                yield result.map_err(|e| {
231                    let error = e.context(context.clone());
232                    log::error!("{:#}", error);
233                    error
234                })?;
235            }
236            *query.as_mut() = mem::take(&mut join.await?);
237        }
238    }
239}
240
241impl Connection for SQLiteConnection {
242    async fn connect(url: Cow<'static, str>) -> Result<SQLiteConnection> {
243        let context = format!("While trying to connect to `{}`", truncate_long!(url));
244        let url = Self::sanitize_url(url)?;
245        let url = CString::from_str(&url.as_str().replacen("sqlite://", "file:", 1))
246            .with_context(|| context.clone())?;
247        let mut connection;
248        unsafe {
249            connection = CBox::new(ptr::null_mut(), |p| {
250                if sqlite3_close(p) != SQLITE_OK {
251                    log::error!("Could not close sqlite connection")
252                }
253            });
254            let rc = sqlite3_open_v2(
255                url.as_ptr(),
256                &mut *connection,
257                SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
258                ptr::null(),
259            );
260            if rc != SQLITE_OK {
261                let error =
262                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
263                        .context(context);
264                log::error!("{:#}", error);
265                return Err(error);
266            }
267        }
268        Ok(Self { connection })
269    }
270
271    fn begin(&mut self) -> impl Future<Output = Result<SQLiteTransaction<'_>>> {
272        SQLiteTransaction::new(self)
273    }
274}