tank_sqlite/
connection.rs

1use crate::{
2    CBox, SQLiteDriver, SQLitePrepared, SQLiteTransaction, error_message_from_ptr,
3    extract::{extract_name, extract_value},
4};
5use async_stream::try_stream;
6use flume::Sender;
7use libsqlite3_sys::*;
8use std::{
9    borrow::Cow,
10    ffi::{CStr, CString, c_char, c_int},
11    mem, ptr,
12    sync::{
13        Arc,
14        atomic::{AtomicPtr, Ordering},
15    },
16};
17use tank_core::{
18    AsQuery, Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result,
19    RowLabeled, RowsAffected, send_value, stream::Stream, truncate_long,
20};
21use tokio::task::spawn_blocking;
22
23pub struct SQLiteConnection {
24    pub(crate) connection: CBox<*mut sqlite3>,
25    pub(crate) _transaction: bool,
26}
27
28impl SQLiteConnection {
29    pub fn last_error(&self) -> String {
30        unsafe {
31            let errcode = sqlite3_errcode(*self.connection);
32            format!(
33                "Error ({errcode}): {}",
34                error_message_from_ptr(&sqlite3_errmsg(*self.connection)).to_string(),
35            )
36        }
37    }
38
39    pub(crate) fn do_run_prepared(
40        connection: *mut sqlite3,
41        statement: *mut sqlite3_stmt,
42        tx: Sender<Result<QueryResult>>,
43    ) {
44        unsafe {
45            let count = sqlite3_column_count(statement);
46            let labels = match (0..count)
47                .map(|i| extract_name(statement, i))
48                .collect::<Result<Arc<[_]>>>()
49            {
50                Ok(labels) => labels,
51                Err(error) => {
52                    send_value!(tx, Err(error.into()));
53                    return;
54                }
55            };
56            loop {
57                match sqlite3_step(statement) {
58                    SQLITE_BUSY => {
59                        continue;
60                    }
61                    SQLITE_DONE => {
62                        if sqlite3_stmt_readonly(statement) == 0 {
63                            send_value!(
64                                tx,
65                                Ok(QueryResult::Affected(RowsAffected {
66                                    rows_affected: sqlite3_changes64(connection) as u64,
67                                    last_affected_id: Some(sqlite3_last_insert_rowid(connection)),
68                                }))
69                            );
70                        }
71                        break;
72                    }
73                    SQLITE_ROW => {
74                        let values = match (0..count)
75                            .map(|i| extract_value(statement, i))
76                            .collect::<Result<_>>()
77                        {
78                            Ok(value) => value,
79                            Err(error) => {
80                                send_value!(tx, Err(error));
81                                return;
82                            }
83                        };
84                        send_value!(
85                            tx,
86                            Ok(QueryResult::Row(RowLabeled {
87                                labels: labels.clone(),
88                                values: values,
89                            }))
90                        )
91                    }
92                    _ => {
93                        send_value!(
94                            tx,
95                            Err(Error::msg(
96                                error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(
97                                    statement,
98                                )))
99                                .to_string(),
100                            ))
101                        );
102                        return;
103                    }
104                }
105            }
106        }
107    }
108
109    pub(crate) fn do_run_unprepared(
110        connection: *mut sqlite3,
111        sql: &str,
112        tx: Sender<Result<QueryResult>>,
113    ) {
114        unsafe {
115            let sql = sql.trim();
116            let mut it = sql.as_ptr() as *const c_char;
117            let mut len = sql.len();
118            loop {
119                let (statement, tail) = {
120                    let mut statement = SQLitePrepared::new(CBox::new(ptr::null_mut(), |p| {
121                        sqlite3_finalize(p);
122                    }));
123                    let mut sql_tail = ptr::null();
124                    let rc = sqlite3_prepare_v2(
125                        connection,
126                        it,
127                        len as c_int,
128                        &mut *statement.statement,
129                        &mut sql_tail,
130                    );
131                    if rc != SQLITE_OK {
132                        send_value!(
133                            tx,
134                            Err(Error::msg(
135                                error_message_from_ptr(&sqlite3_errmsg(connection)).to_string(),
136                            ))
137                        );
138                        return;
139                    }
140                    (statement, sql_tail)
141                };
142                Self::do_run_prepared(connection, statement.statement(), tx.clone());
143                len = if tail != ptr::null() {
144                    len - tail.offset_from_unsigned(it)
145                } else {
146                    0
147                };
148                if len == 0 {
149                    break;
150                }
151                it = tail;
152            }
153        };
154    }
155}
156
157impl Executor for SQLiteConnection {
158    type Driver = SQLiteDriver;
159
160    fn driver(&self) -> &Self::Driver {
161        &SQLiteDriver {}
162    }
163
164    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
165        let connection = AtomicPtr::new(*self.connection);
166        let context = format!(
167            "While preparing the query:\n{}",
168            truncate_long!(sql.as_str())
169        );
170        let prepared = spawn_blocking(move || unsafe {
171            let connection = connection.load(Ordering::Relaxed);
172            let len = sql.len();
173            let sql = CString::new(sql.as_bytes())?;
174            let mut statement = CBox::new(ptr::null_mut(), |p| {
175                sqlite3_finalize(p);
176            });
177            let mut tail = ptr::null();
178            let rc = sqlite3_prepare_v2(
179                connection,
180                sql.as_ptr(),
181                len as c_int,
182                &mut *statement,
183                &mut tail,
184            );
185            if rc != SQLITE_OK {
186                let error =
187                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
188                        .context(context);
189                log::error!("{:#}", error);
190                return Err(error);
191            }
192            if tail != ptr::null() && *tail != '\0' as i8 {
193                let error = Error::msg(format!(
194                    "Cannot prepare more than one statement at a time (remaining: {})",
195                    CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
196                ))
197                .context(context);
198                log::error!("{:#}", error);
199                return Err(error);
200            }
201            Ok(statement)
202        })
203        .await?;
204        Ok(SQLitePrepared::new(prepared?).into())
205    }
206
207    fn run<'s>(
208        &'s mut self,
209        query: impl AsQuery<Self::Driver> + 's,
210    ) -> impl Stream<Item = Result<QueryResult>> + Send {
211        let mut query = query.as_query();
212        let context = Arc::new(format!("While executing the query:\n{}", query.as_mut()));
213        let (tx, rx) = flume::unbounded::<Result<QueryResult>>();
214        let connection = AtomicPtr::new(*self.connection);
215        let mut owned = mem::take(query.as_mut());
216        let join = spawn_blocking(move || {
217            match &mut owned {
218                Query::Raw(query) => {
219                    Self::do_run_unprepared(connection.load(Ordering::Relaxed), query, tx);
220                }
221                Query::Prepared(prepared) => Self::do_run_prepared(
222                    connection.load(Ordering::Relaxed),
223                    prepared.statement(),
224                    tx,
225                ),
226            }
227            owned
228        });
229        try_stream! {
230            while let Ok(result) = rx.recv_async().await {
231                yield result.map_err(|e| {
232                    let error = e.context(context.clone());
233                    log::error!("{:#}", error);
234                    error
235                })?;
236            }
237            *query.as_mut() = mem::take(&mut join.await?);
238        }
239    }
240}
241
242impl Connection for SQLiteConnection {
243    #[allow(refining_impl_trait)]
244    async fn connect(url: Cow<'static, str>) -> Result<SQLiteConnection> {
245        let context = || format!("While trying to connect to `{}`", truncate_long!(url));
246        let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
247        if !url.starts_with(&prefix) {
248            let error = Error::msg(format!(
249                "SQLite connection url must start with `{}`",
250                &prefix
251            ))
252            .context(context());
253            log::error!("{:#}", error);
254            return Err(error);
255        }
256        let url = CString::new(format!("file:{}", url.trim_start_matches(&prefix)))
257            .with_context(context)?;
258        let mut connection;
259        unsafe {
260            connection = CBox::new(ptr::null_mut(), |p| {
261                if sqlite3_close(p) != SQLITE_OK {
262                    log::error!("Could not close sqlite connection")
263                }
264            });
265            let rc = sqlite3_open_v2(
266                url.as_ptr(),
267                &mut *connection,
268                SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
269                ptr::null(),
270            );
271            if rc != SQLITE_OK {
272                let error =
273                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
274                        .context(context());
275                log::error!("{:#}", error);
276                return Err(error);
277            }
278        }
279        Ok(Self {
280            connection,
281            _transaction: false,
282        })
283    }
284
285    #[allow(refining_impl_trait)]
286    fn begin(&mut self) -> impl Future<Output = Result<SQLiteTransaction<'_>>> {
287        SQLiteTransaction::new(self)
288    }
289}