tank_sqlite/
connection.rs

1use crate::{
2    CBox, SqliteDriver, SqlitePrepared, SqliteTransaction, error_message_from_ptr,
3    extract::{extract_name, extract_value},
4};
5use async_stream::{stream, try_stream};
6use libsqlite3_sys::*;
7use std::{
8    borrow::Cow,
9    ffi::{CStr, CString, c_char, c_int},
10    pin::pin,
11    ptr,
12    sync::{
13        Arc,
14        atomic::{AtomicPtr, Ordering},
15    },
16};
17use tank_core::{
18    Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result, RowLabeled,
19    RowsAffected,
20    future::Either,
21    truncate_long,
22    stream::{Stream, StreamExt, TryStreamExt},
23};
24use tokio::task::spawn_blocking;
25
26pub struct SqliteConnection {
27    pub(crate) connection: CBox<*mut sqlite3>,
28    pub(crate) _transaction: bool,
29}
30
31impl SqliteConnection {
32    pub(crate) fn run_prepared(
33        &mut self,
34        statement: CBox<*mut sqlite3_stmt>,
35    ) -> impl Stream<Item = Result<QueryResult>> {
36        unsafe {
37            stream! {
38                let count = sqlite3_column_count(*statement);
39                let labels = (0..count)
40                    .map(|i| extract_name(*statement, i))
41                    .collect::<Result<Arc<[_]>>>()?;
42                loop {
43                    match sqlite3_step(*statement) {
44                        SQLITE_BUSY => {
45                            continue;
46                        }
47                        SQLITE_DONE => {
48                            if sqlite3_stmt_readonly(*statement) == 0 {
49                                yield Ok(QueryResult::Affected(RowsAffected {
50                                    rows_affected: sqlite3_changes64(*self.connection) as u64,
51                                    last_affected_id: Some(sqlite3_last_insert_rowid(*self.connection)),
52                                }))
53                            }
54                            break;
55                        }
56                        SQLITE_ROW => {
57                            yield Ok(QueryResult::Row(RowLabeled {
58                                labels: labels.clone(),
59                                values: (0..count).map(|i| extract_value(*statement, i)).collect()?,
60                            }))
61                        }
62                        _ => {
63                            let error = Error::msg(
64                                error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(*statement)))
65                                    .to_string(),
66                            );
67                            yield Err(error);
68                        }
69                    }
70                }
71            }
72        }
73    }
74
75    pub(crate) fn run_unprepared(
76        &mut self,
77        sql: String,
78    ) -> impl Stream<Item = Result<QueryResult>> {
79        try_stream! {
80            let mut len = sql.trim_end().len();
81            let buff = sql.into_bytes();
82            let mut it = CBox::new(buff.as_ptr() as *const c_char, |_| {});
83            loop {
84                let connection = CBox::new(*self.connection, |_| {});
85                let sql = CBox::new(*it, |_| {});
86                let (statement, tail) = spawn_blocking(move || unsafe {
87                    let mut statement = CBox::new(ptr::null_mut(), |p| {
88                        sqlite3_finalize(p);
89                    });
90                    let mut sql_tail = CBox::new(ptr::null(), |_| {});
91                    let rc = sqlite3_prepare_v2(
92                        *connection,
93                        *sql,
94                        len as c_int,
95                        &mut *statement,
96                        &mut *sql_tail,
97                    );
98                    if rc != SQLITE_OK {
99                        return Err(Error::msg(
100                            error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string(),
101                        ));
102                    }
103                    Ok((statement, sql_tail))
104                })
105                .await??;
106                let mut stream = pin!(self.run_prepared(statement));
107                while let Some(value) = stream.next().await {
108                    yield value?
109                }
110                unsafe {
111                    len = if *tail != ptr::null() {
112                        len - tail.offset_from_unsigned(*it)
113                    } else {
114                        0
115                    };
116                    if len == 0 {
117                        break;
118                    }
119                }
120                *it = *tail;
121            }
122        }
123    }
124}
125
126impl Executor for SqliteConnection {
127    type Driver = SqliteDriver;
128
129    fn driver(&self) -> &Self::Driver {
130        &SqliteDriver {}
131    }
132
133    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
134        let connection = AtomicPtr::new(*self.connection);
135        let context = format!(
136            "While preparing the query:\n{}",
137            truncate_long!(sql.as_str())
138        );
139        let prepared = spawn_blocking(move || unsafe {
140            let connection = connection.load(Ordering::Relaxed);
141            let len = sql.len();
142            let sql = CString::new(sql.as_bytes())?;
143            let mut statement = CBox::new(ptr::null_mut(), |p| {
144                sqlite3_finalize(p);
145            });
146            let mut tail = ptr::null();
147            let rc = sqlite3_prepare_v2(
148                connection,
149                sql.as_ptr(),
150                len as c_int,
151                &mut *statement,
152                &mut tail,
153            );
154            if rc != SQLITE_OK {
155                let error =
156                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
157                        .context(context);
158                log::error!("{:#}", error);
159                return Err(error);
160            }
161            if tail != ptr::null() && *tail != '\0' as i8 {
162                let error = Error::msg(format!(
163                    "Cannot prepare more than one statement at a time (remaining: {})",
164                    CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
165                ))
166                .context(context);
167                log::error!("{:#}", error);
168                return Err(error);
169            }
170            Ok(statement)
171        })
172        .await?;
173        Ok(SqlitePrepared::new(prepared?).into())
174    }
175
176    fn run(
177        &mut self,
178        query: Query<Self::Driver>,
179    ) -> impl Stream<Item = Result<QueryResult>> + Send {
180        let context = Arc::new(format!("While executing the query:\n{}", query));
181        match query {
182            Query::Raw(sql) => Either::Left(self.run_unprepared(sql)),
183            Query::Prepared(prepared) => Either::Right(self.run_prepared(prepared.statement)),
184        }
185        .map_err(move |e| {
186            let e = e.context(context.clone());
187            log::error!("{:#}", e);
188            e
189        })
190    }
191}
192
193impl Connection for SqliteConnection {
194    #[allow(refining_impl_trait)]
195    async fn connect(url: Cow<'static, str>) -> Result<SqliteConnection> {
196        let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
197        if !url.starts_with(&prefix) {
198            let error = Error::msg(format!(
199                "Sqlite connection url must start with `{}`",
200                &prefix
201            ));
202            log::error!("{:#}", error);
203            return Err(error);
204        }
205        let url = CString::new(format!("file:{}", url.trim_start_matches(&prefix)))
206            .with_context(|| format!("Invalid database url `{}`", url))?;
207        let mut connection;
208        unsafe {
209            connection = CBox::new(ptr::null_mut(), |p| {
210                if sqlite3_close(p) != SQLITE_OK {
211                    log::error!("Could not close sqlite connection")
212                }
213            });
214            let rc = sqlite3_open_v2(
215                url.as_ptr(),
216                &mut *connection,
217                SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
218                ptr::null(),
219            );
220            if rc != SQLITE_OK {
221                let error =
222                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
223                        .context(format!(
224                            "Failed to connect to database url `{}`",
225                            url.to_str().unwrap_or("unprintable value")
226                        ));
227                log::error!("{:#}", error);
228                return Err(error);
229            }
230        }
231        Ok(Self {
232            connection,
233            _transaction: false,
234        })
235    }
236
237    #[allow(refining_impl_trait)]
238    fn begin(&mut self) -> impl Future<Output = Result<SqliteTransaction<'_>>> {
239        SqliteTransaction::new(self)
240    }
241}