1use crate::{
2 CBox, SQLiteDriver, SQLitePrepared, SQLiteTransaction, error_message_from_ptr,
3 extract::{extract_name, extract_value},
4};
5use async_stream::try_stream;
6use flume::Sender;
7use libsqlite3_sys::*;
8use std::{
9 borrow::Cow,
10 ffi::{CStr, CString, c_char, c_int},
11 mem, ptr,
12 sync::{
13 Arc,
14 atomic::{AtomicPtr, Ordering},
15 },
16};
17use tank_core::{
18 AsQuery, Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result,
19 RowLabeled, RowsAffected, send_value, stream::Stream, truncate_long,
20};
21use tokio::task::spawn_blocking;
22
/// A connection to a SQLite database.
///
/// The raw `sqlite3` handle is owned through a `CBox`, which runs the release
/// routine it was constructed with when the connection is dropped.
pub struct SQLiteConnection {
    /// Owned raw SQLite database handle.
    pub(crate) connection: CBox<*mut sqlite3>,
    /// Transaction flag; initialized to `false` by `Connection::connect`.
    /// NOTE(review): not read anywhere in this file — confirm its purpose.
    pub(crate) _transaction: bool,
}
27
28impl SQLiteConnection {
29 pub(crate) fn do_run_prepared(
30 connection: *mut sqlite3,
31 statement: *mut sqlite3_stmt,
32 tx: Sender<Result<QueryResult>>,
33 ) {
34 unsafe {
35 let count = sqlite3_column_count(statement);
36 let labels = match (0..count)
37 .map(|i| extract_name(statement, i))
38 .collect::<Result<Arc<[_]>>>()
39 {
40 Ok(labels) => labels,
41 Err(error) => {
42 send_value!(tx, Err(error.into()));
43 return;
44 }
45 };
46 loop {
47 match sqlite3_step(statement) {
48 SQLITE_BUSY => {
49 continue;
50 }
51 SQLITE_DONE => {
52 if sqlite3_stmt_readonly(statement) == 0 {
53 send_value!(
54 tx,
55 Ok(QueryResult::Affected(RowsAffected {
56 rows_affected: sqlite3_changes64(connection) as u64,
57 last_affected_id: Some(sqlite3_last_insert_rowid(connection)),
58 }))
59 );
60 }
61 break;
62 }
63 SQLITE_ROW => {
64 let values = match (0..count)
65 .map(|i| extract_value(statement, i))
66 .collect::<Result<_>>()
67 {
68 Ok(value) => value,
69 Err(error) => {
70 send_value!(tx, Err(error));
71 return;
72 }
73 };
74 send_value!(
75 tx,
76 Ok(QueryResult::Row(RowLabeled {
77 labels: labels.clone(),
78 values: values,
79 }))
80 )
81 }
82 _ => {
83 send_value!(
84 tx,
85 Err(Error::msg(
86 error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(
87 statement,
88 )))
89 .to_string(),
90 ))
91 );
92 return;
93 }
94 }
95 }
96 }
97 }
98
99 pub(crate) fn do_run_unprepared(
100 connection: *mut sqlite3,
101 sql: &str,
102 tx: Sender<Result<QueryResult>>,
103 ) {
104 unsafe {
105 let sql = sql.trim();
106 let mut it = sql.as_ptr() as *const c_char;
107 let mut len = sql.len();
108 loop {
109 let (statement, tail) = {
110 let mut statement = SQLitePrepared::new(CBox::new(ptr::null_mut(), |p| {
111 sqlite3_finalize(p);
112 }));
113 let mut sql_tail = ptr::null();
114 let rc = sqlite3_prepare_v2(
115 connection,
116 it,
117 len as c_int,
118 &mut *statement.statement,
119 &mut sql_tail,
120 );
121 if rc != SQLITE_OK {
122 send_value!(
123 tx,
124 Err(Error::msg(
125 error_message_from_ptr(&sqlite3_errmsg(connection)).to_string(),
126 ))
127 );
128 return;
129 }
130 (statement, sql_tail)
131 };
132 Self::do_run_prepared(connection, statement.statement(), tx.clone());
133 len = if tail != ptr::null() {
134 len - tail.offset_from_unsigned(it)
135 } else {
136 0
137 };
138 if len == 0 {
139 break;
140 }
141 it = tail;
142 }
143 };
144 }
145}
146
147impl Executor for SQLiteConnection {
148 type Driver = SQLiteDriver;
149
150 fn driver(&self) -> &Self::Driver {
151 &SQLiteDriver {}
152 }
153
154 async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
155 let connection = AtomicPtr::new(*self.connection);
156 let context = format!(
157 "While preparing the query:\n{}",
158 truncate_long!(sql.as_str())
159 );
160 let prepared = spawn_blocking(move || unsafe {
161 let connection = connection.load(Ordering::Relaxed);
162 let len = sql.len();
163 let sql = CString::new(sql.as_bytes())?;
164 let mut statement = CBox::new(ptr::null_mut(), |p| {
165 sqlite3_finalize(p);
166 });
167 let mut tail = ptr::null();
168 let rc = sqlite3_prepare_v2(
169 connection,
170 sql.as_ptr(),
171 len as c_int,
172 &mut *statement,
173 &mut tail,
174 );
175 if rc != SQLITE_OK {
176 let error =
177 Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
178 .context(context);
179 log::error!("{:#}", error);
180 return Err(error);
181 }
182 if tail != ptr::null() && *tail != '\0' as i8 {
183 let error = Error::msg(format!(
184 "Cannot prepare more than one statement at a time (remaining: {})",
185 CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
186 ))
187 .context(context);
188 log::error!("{:#}", error);
189 return Err(error);
190 }
191 Ok(statement)
192 })
193 .await?;
194 Ok(SQLitePrepared::new(prepared?).into())
195 }
196
197 fn run<'s>(
198 &'s mut self,
199 query: impl AsQuery<Self::Driver> + 's,
200 ) -> impl Stream<Item = Result<QueryResult>> + Send {
201 let mut query = query.as_query();
202 let context = Arc::new(format!("While executing the query:\n{}", query.as_mut()));
203 let (tx, rx) = flume::unbounded::<Result<QueryResult>>();
204 let connection = AtomicPtr::new(*self.connection);
205 let mut owned = mem::take(query.as_mut());
206 let join = spawn_blocking(move || {
207 match &mut owned {
208 Query::Raw(query) => {
209 Self::do_run_unprepared(connection.load(Ordering::Relaxed), query, tx);
210 }
211 Query::Prepared(prepared) => Self::do_run_prepared(
212 connection.load(Ordering::Relaxed),
213 prepared.statement(),
214 tx,
215 ),
216 }
217 owned
218 });
219 try_stream! {
220 while let Ok(result) = rx.recv_async().await {
221 yield result.map_err(|e| {
222 let error = e.context(context.clone());
223 log::error!("{:#}", error);
224 error
225 })?;
226 }
227 *query.as_mut() = mem::take(&mut join.await?);
228 }
229 }
230}
231
232impl Connection for SQLiteConnection {
233 #[allow(refining_impl_trait)]
234 async fn connect(url: Cow<'static, str>) -> Result<SQLiteConnection> {
235 let context = || format!("While trying to connect to `{}`", truncate_long!(url));
236 let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
237 if !url.starts_with(&prefix) {
238 let error = Error::msg(format!(
239 "SQLite connection url must start with `{}`",
240 &prefix
241 ))
242 .context(context());
243 log::error!("{:#}", error);
244 return Err(error);
245 }
246 let url = CString::new(format!("file:{}", url.trim_start_matches(&prefix)))
247 .with_context(context)?;
248 let mut connection;
249 unsafe {
250 connection = CBox::new(ptr::null_mut(), |p| {
251 if sqlite3_close(p) != SQLITE_OK {
252 log::error!("Could not close sqlite connection")
253 }
254 });
255 let rc = sqlite3_open_v2(
256 url.as_ptr(),
257 &mut *connection,
258 SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
259 ptr::null(),
260 );
261 if rc != SQLITE_OK {
262 let error =
263 Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
264 .context(context());
265 log::error!("{:#}", error);
266 return Err(error);
267 }
268 }
269 Ok(Self {
270 connection,
271 _transaction: false,
272 })
273 }
274
275 #[allow(refining_impl_trait)]
276 fn begin(&mut self) -> impl Future<Output = Result<SQLiteTransaction<'_>>> {
277 SQLiteTransaction::new(self)
278 }
279}