Skip to main content

tank_sqlite/
connection.rs

1use crate::{
2    CBox, SQLiteDriver, SQLitePrepared, SQLiteTransaction,
3    extract::{extract_name, extract_value},
4};
5use async_stream::try_stream;
6use flume::Sender;
7use libsqlite3_sys::*;
8use std::{
9    borrow::Cow,
10    ffi::{CStr, CString, c_char, c_int},
11    mem, ptr,
12    str::FromStr,
13    sync::{
14        Arc,
15        atomic::{AtomicPtr, Ordering},
16    },
17};
18use tank_core::{
19    AsQuery, Connection, Error, ErrorContext, Executor, Prepared, Query, QueryResult, RawQuery,
20    Result, Row, RowsAffected, error_message_from_ptr, send_value, stream::Stream,
21    truncate_long,
22};
23use tokio::task::spawn_blocking;
24
25/// Wrapper for a SQLite `sqlite3` connection pointer used by the SQLite driver.
26///
27/// Provides helpers to prepare/execute statements and stream results into `tank_core` result types.
28pub struct SQLiteConnection {
29    pub(crate) connection: CBox<*mut sqlite3>,
30}
31
32impl SQLiteConnection {
33    pub fn last_error(&self) -> String {
34        unsafe {
35            let errcode = sqlite3_errcode(*self.connection);
36            format!(
37                "Error ({errcode}): {}",
38                error_message_from_ptr(&sqlite3_errmsg(*self.connection)),
39            )
40        }
41    }
42
43    pub(crate) fn do_run_prepared(
44        connection: *mut sqlite3,
45        statement: *mut sqlite3_stmt,
46        tx: Sender<Result<QueryResult>>,
47    ) {
48        unsafe {
49            let count = sqlite3_column_count(statement);
50            let labels = match (0..count)
51                .map(|i| extract_name(statement, i))
52                .collect::<Result<Arc<[_]>>>()
53            {
54                Ok(labels) => labels,
55                Err(error) => {
56                    send_value!(tx, Err(error.into()));
57                    return;
58                }
59            };
60            loop {
61                match sqlite3_step(statement) {
62                    SQLITE_BUSY => {
63                        continue;
64                    }
65                    SQLITE_DONE => {
66                        if sqlite3_stmt_readonly(statement) == 0 {
67                            send_value!(
68                                tx,
69                                Ok(QueryResult::Affected(RowsAffected {
70                                    rows_affected: Some(sqlite3_changes64(connection) as _),
71                                    last_affected_id: Some(sqlite3_last_insert_rowid(connection)),
72                                }))
73                            );
74                        }
75                        break;
76                    }
77                    SQLITE_ROW => {
78                        let values = match (0..count)
79                            .map(|i| extract_value(statement, i))
80                            .collect::<Result<_>>()
81                        {
82                            Ok(value) => value,
83                            Err(error) => {
84                                send_value!(tx, Err(error));
85                                return;
86                            }
87                        };
88                        send_value!(
89                            tx,
90                            Ok(QueryResult::Row(Row {
91                                labels: labels.clone(),
92                                values: values,
93                            }))
94                        )
95                    }
96                    _ => {
97                        send_value!(
98                            tx,
99                            Err(Error::msg(
100                                error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(
101                                    statement,
102                                )))
103                                .to_string(),
104                            ))
105                        );
106                        break;
107                    }
108                }
109            }
110        }
111    }
112
113    pub(crate) fn do_run_unprepared(
114        connection: *mut sqlite3,
115        sql: &str,
116        tx: Sender<Result<QueryResult>>,
117    ) {
118        unsafe {
119            let sql = sql.trim();
120            let mut it = sql.as_ptr() as *const c_char;
121            let mut len = sql.len();
122            loop {
123                let (statement, tail) = {
124                    let mut statement = SQLitePrepared::new(CBox::new(ptr::null_mut(), |p| {
125                        let rc = sqlite3_finalize(p);
126                        if rc != SQLITE_OK {
127                            return;
128                        }
129                    }));
130                    let mut sql_tail = ptr::null();
131                    let rc = sqlite3_prepare_v2(
132                        connection,
133                        it,
134                        len as c_int,
135                        &mut *statement.statement,
136                        &mut sql_tail,
137                    );
138                    if rc != SQLITE_OK {
139                        send_value!(
140                            tx,
141                            Err(Error::msg(
142                                error_message_from_ptr(&sqlite3_errmsg(connection)).to_string(),
143                            ))
144                        );
145                        return;
146                    }
147                    (statement, sql_tail)
148                };
149                Self::do_run_prepared(connection, statement.statement(), tx.clone());
150                len = if tail != ptr::null() {
151                    len - tail.offset_from_unsigned(it)
152                } else {
153                    0
154                };
155                if len == 0 {
156                    break;
157                }
158                it = tail;
159            }
160        };
161    }
162}
163
164impl Executor for SQLiteConnection {
165    type Driver = SQLiteDriver;
166
167    async fn do_prepare(&mut self, sql: String) -> Result<Query<SQLiteDriver>> {
168        let connection = AtomicPtr::new(*self.connection);
169        let context = format!("While preparing the query:\n{}", truncate_long!(sql));
170        let prepared = spawn_blocking(move || unsafe {
171            let connection = connection.load(Ordering::Relaxed);
172            let len = sql.len();
173            let sql = CString::new(sql.into_bytes())?;
174            let mut statement = CBox::new(ptr::null_mut(), |p| {
175                let db = sqlite3_db_handle(p);
176                let rc = sqlite3_finalize(p);
177                if rc != SQLITE_OK {
178                    let error = Error::msg(error_message_from_ptr(&sqlite3_errmsg(db)).to_string())
179                        .context("While finalizing a prepared statement");
180                    log::error!("{error:#}");
181                }
182            });
183            let mut tail = ptr::null();
184            let rc = sqlite3_prepare_v2(
185                connection,
186                sql.as_ptr(),
187                len as c_int,
188                &mut *statement,
189                &mut tail,
190            );
191            if rc != SQLITE_OK {
192                let error =
193                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
194                        .context(context);
195                log::error!("{:#}", error);
196                return Err(error);
197            }
198            if tail != ptr::null() && *tail != '\0' as i8 {
199                let error = Error::msg(format!(
200                    "Cannot prepare more than one statement at a time (remaining: {})",
201                    CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
202                ))
203                .context(context);
204                log::error!("{:#}", error);
205                return Err(error);
206            }
207            Ok(statement)
208        })
209        .await?;
210        Ok(SQLitePrepared::new(prepared?).into())
211    }
212
213    fn run<'s>(
214        &'s mut self,
215        query: impl AsQuery<SQLiteDriver> + 's,
216    ) -> impl Stream<Item = Result<QueryResult>> + Send {
217        let mut query = query.as_query();
218        let context = Arc::new(format!("While running the query:\n{}", query.as_mut()));
219        let (tx, rx) = flume::unbounded::<Result<QueryResult>>();
220        let connection = AtomicPtr::new(*self.connection);
221        let mut owned = mem::take(query.as_mut());
222        let join = spawn_blocking(move || {
223            match &mut owned {
224                Query::Raw(RawQuery(sql)) => {
225                    Self::do_run_unprepared(connection.load(Ordering::Relaxed), sql, tx);
226                }
227                Query::Prepared(prepared) => {
228                    Self::do_run_prepared(
229                        connection.load(Ordering::Relaxed),
230                        prepared.statement(),
231                        tx,
232                    );
233                    let _ = prepared.clear_bindings();
234                }
235            }
236            owned
237        });
238        try_stream! {
239            while let Ok(result) = rx.recv_async().await {
240                yield result.map_err(|e| {
241                    let error = e.context(context.clone());
242                    log::error!("{:#}", error);
243                    error
244                })?;
245            }
246            *query.as_mut() = mem::take(&mut join.await?);
247        }
248    }
249}
250
251impl Connection for SQLiteConnection {
252    async fn connect(driver: &SQLiteDriver, url: Cow<'static, str>) -> Result<Self> {
253        let context = format!("While trying to connect to `{}`", truncate_long!(url));
254        let url = Self::sanitize_url(driver, url)?;
255        let url = CString::from_str(&url.as_str().replacen("sqlite://", "file:", 1))
256            .with_context(|| context.clone())?;
257        let mut connection;
258        unsafe {
259            connection = CBox::new(ptr::null_mut(), |p| {
260                if sqlite3_close(p) != SQLITE_OK {
261                    let error = Error::msg(error_message_from_ptr(&sqlite3_errmsg(p)).to_string())
262                        .context("While closing the sqlite connection");
263                    log::error!("{error:#}");
264                }
265            });
266            let rc = sqlite3_open_v2(
267                url.as_ptr(),
268                &mut *connection,
269                SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
270                ptr::null(),
271            );
272            if rc != SQLITE_OK {
273                let error =
274                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
275                        .context(context);
276                log::error!("{:#}", error);
277                return Err(error);
278            }
279        }
280        Ok(Self { connection })
281    }
282
283    fn begin(&mut self) -> impl Future<Output = Result<SQLiteTransaction<'_>>> {
284        SQLiteTransaction::new(self)
285    }
286}