tank_sqlite/
connection.rs

1use crate::{
2    CBox, SQLiteDriver, SQLitePrepared, SQLiteTransaction, error_message_from_ptr,
3    extract::{extract_name, extract_value},
4};
5use async_stream::try_stream;
6use flume::Sender;
7use libsqlite3_sys::*;
8use std::{
9    borrow::Cow,
10    ffi::{CStr, CString, c_char, c_int},
11    mem, ptr,
12    sync::{
13        Arc,
14        atomic::{AtomicPtr, Ordering},
15    },
16};
17use tank_core::{
18    AsQuery, Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result,
19    RowLabeled, RowsAffected, send_value, stream::Stream, truncate_long,
20};
21use tokio::task::spawn_blocking;
22
23pub struct SQLiteConnection {
24    pub(crate) connection: CBox<*mut sqlite3>,
25    pub(crate) _transaction: bool,
26}
27
28impl SQLiteConnection {
29    pub(crate) fn do_run_prepared(
30        connection: *mut sqlite3,
31        statement: *mut sqlite3_stmt,
32        tx: Sender<Result<QueryResult>>,
33    ) {
34        unsafe {
35            let count = sqlite3_column_count(statement);
36            let labels = match (0..count)
37                .map(|i| extract_name(statement, i))
38                .collect::<Result<Arc<[_]>>>()
39            {
40                Ok(labels) => labels,
41                Err(error) => {
42                    send_value!(tx, Err(error.into()));
43                    return;
44                }
45            };
46            loop {
47                match sqlite3_step(statement) {
48                    SQLITE_BUSY => {
49                        continue;
50                    }
51                    SQLITE_DONE => {
52                        if sqlite3_stmt_readonly(statement) == 0 {
53                            send_value!(
54                                tx,
55                                Ok(QueryResult::Affected(RowsAffected {
56                                    rows_affected: sqlite3_changes64(connection) as u64,
57                                    last_affected_id: Some(sqlite3_last_insert_rowid(connection)),
58                                }))
59                            );
60                        }
61                        break;
62                    }
63                    SQLITE_ROW => {
64                        let values = match (0..count)
65                            .map(|i| extract_value(statement, i))
66                            .collect::<Result<_>>()
67                        {
68                            Ok(value) => value,
69                            Err(error) => {
70                                send_value!(tx, Err(error));
71                                return;
72                            }
73                        };
74                        send_value!(
75                            tx,
76                            Ok(QueryResult::Row(RowLabeled {
77                                labels: labels.clone(),
78                                values: values,
79                            }))
80                        )
81                    }
82                    _ => {
83                        send_value!(
84                            tx,
85                            Err(Error::msg(
86                                error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(
87                                    statement,
88                                )))
89                                .to_string(),
90                            ))
91                        );
92                        return;
93                    }
94                }
95            }
96        }
97    }
98
99    pub(crate) fn do_run_unprepared(
100        connection: *mut sqlite3,
101        sql: &str,
102        tx: Sender<Result<QueryResult>>,
103    ) {
104        unsafe {
105            let sql = sql.trim();
106            let mut it = sql.as_ptr() as *const c_char;
107            let mut len = sql.len();
108            loop {
109                let (statement, tail) = {
110                    let mut statement = SQLitePrepared::new(CBox::new(ptr::null_mut(), |p| {
111                        sqlite3_finalize(p);
112                    }));
113                    let mut sql_tail = ptr::null();
114                    let rc = sqlite3_prepare_v2(
115                        connection,
116                        it,
117                        len as c_int,
118                        &mut *statement.statement,
119                        &mut sql_tail,
120                    );
121                    if rc != SQLITE_OK {
122                        send_value!(
123                            tx,
124                            Err(Error::msg(
125                                error_message_from_ptr(&sqlite3_errmsg(connection)).to_string(),
126                            ))
127                        );
128                        return;
129                    }
130                    (statement, sql_tail)
131                };
132                Self::do_run_prepared(connection, statement.statement(), tx.clone());
133                len = if tail != ptr::null() {
134                    len - tail.offset_from_unsigned(it)
135                } else {
136                    0
137                };
138                if len == 0 {
139                    break;
140                }
141                it = tail;
142            }
143        };
144    }
145}
146
147impl Executor for SQLiteConnection {
148    type Driver = SQLiteDriver;
149
150    fn driver(&self) -> &Self::Driver {
151        &SQLiteDriver {}
152    }
153
154    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
155        let connection = AtomicPtr::new(*self.connection);
156        let context = format!(
157            "While preparing the query:\n{}",
158            truncate_long!(sql.as_str())
159        );
160        let prepared = spawn_blocking(move || unsafe {
161            let connection = connection.load(Ordering::Relaxed);
162            let len = sql.len();
163            let sql = CString::new(sql.as_bytes())?;
164            let mut statement = CBox::new(ptr::null_mut(), |p| {
165                sqlite3_finalize(p);
166            });
167            let mut tail = ptr::null();
168            let rc = sqlite3_prepare_v2(
169                connection,
170                sql.as_ptr(),
171                len as c_int,
172                &mut *statement,
173                &mut tail,
174            );
175            if rc != SQLITE_OK {
176                let error =
177                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
178                        .context(context);
179                log::error!("{:#}", error);
180                return Err(error);
181            }
182            if tail != ptr::null() && *tail != '\0' as i8 {
183                let error = Error::msg(format!(
184                    "Cannot prepare more than one statement at a time (remaining: {})",
185                    CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
186                ))
187                .context(context);
188                log::error!("{:#}", error);
189                return Err(error);
190            }
191            Ok(statement)
192        })
193        .await?;
194        Ok(SQLitePrepared::new(prepared?).into())
195    }
196
197    fn run<'s>(
198        &'s mut self,
199        query: impl AsQuery<Self::Driver> + 's,
200    ) -> impl Stream<Item = Result<QueryResult>> + Send {
201        let mut query = query.as_query();
202        let context = Arc::new(format!("While executing the query:\n{}", query.as_mut()));
203        let (tx, rx) = flume::unbounded::<Result<QueryResult>>();
204        let connection = AtomicPtr::new(*self.connection);
205        let mut owned = mem::take(query.as_mut());
206        let join = spawn_blocking(move || {
207            match &mut owned {
208                Query::Raw(query) => {
209                    Self::do_run_unprepared(connection.load(Ordering::Relaxed), query, tx);
210                }
211                Query::Prepared(prepared) => Self::do_run_prepared(
212                    connection.load(Ordering::Relaxed),
213                    prepared.statement(),
214                    tx,
215                ),
216            }
217            owned
218        });
219        try_stream! {
220            while let Ok(result) = rx.recv_async().await {
221                yield result.map_err(|e| {
222                    let error = e.context(context.clone());
223                    log::error!("{:#}", error);
224                    error
225                })?;
226            }
227            *query.as_mut() = mem::take(&mut join.await?);
228        }
229    }
230}
231
232impl Connection for SQLiteConnection {
233    #[allow(refining_impl_trait)]
234    async fn connect(url: Cow<'static, str>) -> Result<SQLiteConnection> {
235        let context = || format!("While trying to connect to `{}`", truncate_long!(url));
236        let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
237        if !url.starts_with(&prefix) {
238            let error = Error::msg(format!(
239                "SQLite connection url must start with `{}`",
240                &prefix
241            ))
242            .context(context());
243            log::error!("{:#}", error);
244            return Err(error);
245        }
246        let url = CString::new(format!("file:{}", url.trim_start_matches(&prefix)))
247            .with_context(context)?;
248        let mut connection;
249        unsafe {
250            connection = CBox::new(ptr::null_mut(), |p| {
251                if sqlite3_close(p) != SQLITE_OK {
252                    log::error!("Could not close sqlite connection")
253                }
254            });
255            let rc = sqlite3_open_v2(
256                url.as_ptr(),
257                &mut *connection,
258                SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
259                ptr::null(),
260            );
261            if rc != SQLITE_OK {
262                let error =
263                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
264                        .context(context());
265                log::error!("{:#}", error);
266                return Err(error);
267            }
268        }
269        Ok(Self {
270            connection,
271            _transaction: false,
272        })
273    }
274
275    #[allow(refining_impl_trait)]
276    fn begin(&mut self) -> impl Future<Output = Result<SQLiteTransaction<'_>>> {
277        SQLiteTransaction::new(self)
278    }
279}