Skip to main content

rbdc_sqlite/connection/
worker.rs

1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::thread;
4
5use crate::connection::collation::create_collation;
6use crate::connection::establish::EstablishParams;
7use crate::connection::ConnectionState;
8use crate::connection::{execute, ConnectionHandleRaw};
9use crate::{SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement};
10use crossfire::{spsc, AsyncTx};
11use either::Either;
12use futures_channel::oneshot;
13use futures_intrusive::sync::{Mutex, MutexGuard};
14use rbdc::error::Error;
15
16// Each SQLite connection has a dedicated thread.
17
18// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce
19//       OS resource usage. Low priority because a high concurrent load for SQLite3 is very
20//       unlikely.
21
22pub(crate) struct ConnectionWorker {
23    command_tx: AsyncTx<crossfire::spsc::Array<Command>>,
24    /// The `sqlite3` pointer. NOTE: access is unsynchronized!
25    pub(crate) handle_raw: ConnectionHandleRaw,
26    /// Mutex for locking access to the database.
27    pub(crate) shared: Arc<WorkerSharedState>,
28}
29
30pub(crate) struct WorkerSharedState {
31    pub(crate) cached_statements_size: AtomicUsize,
32    pub(crate) conn: Mutex<ConnectionState>,
33}
34
35pub enum Command {
36    Prepare {
37        query: Box<str>,
38        tx: oneshot::Sender<Result<SqliteStatement, Error>>,
39    },
40    Execute {
41        query: Box<str>,
42        arguments: Option<SqliteArguments>,
43        persistent: bool,
44        tx: crossfire::Tx<
45            crossfire::spsc::Array<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
46        >,
47    },
48    CreateCollation {
49        create_collation:
50            Box<dyn FnOnce(&mut ConnectionState) -> Result<(), Error> + Send + Sync + 'static>,
51    },
52    UnlockDb,
53    ClearCache {
54        tx: oneshot::Sender<()>,
55    },
56    Ping {
57        tx: oneshot::Sender<()>,
58    },
59    Shutdown {
60        tx: oneshot::Sender<()>,
61    },
62}
63
64impl ConnectionWorker {
65    pub(crate) async fn establish(params: EstablishParams) -> Result<Self, Error> {
66        let (establish_tx, establish_rx) = oneshot::channel();
67
68        thread::Builder::new()
69            .name(params.thread_name.clone())
70            .spawn(move || {
71                let (command_tx, command_rx) =
72                    spsc::bounded_async_blocking(params.command_channel_size);
73
74                let conn = match params.establish() {
75                    Ok(conn) => conn,
76                    Err(e) => {
77                        establish_tx.send(Err(e)).ok();
78                        return;
79                    }
80                };
81
82                let shared = Arc::new(WorkerSharedState {
83                    cached_statements_size: AtomicUsize::new(0),
84                    // note: must be fair because in `Command::UnlockDb` we unlock the mutex
85                    // and then immediately try to relock it; an unfair mutex would immediately
86                    // grant us the lock even if another task is waiting.
87                    conn: Mutex::new(conn, true),
88                });
89                let mut conn = shared.conn.try_lock().unwrap();
90
91                if establish_tx
92                    .send(Ok(Self {
93                        command_tx,
94                        handle_raw: conn.handle.to_raw(),
95                        shared: Arc::clone(&shared),
96                    }))
97                    .is_err()
98                {
99                    return;
100                }
101
102                // Use blocking receiver in sync thread
103                loop {
104                    let cmd = match command_rx.recv() {
105                        Ok(cmd) => cmd,
106                        Err(_) => break, // channel closed
107                    };
108
109                    match cmd {
110                        Command::Prepare { query, tx } => {
111                            tx.send(prepare(&mut conn, &query).map(|prepared| {
112                                update_cached_statements_size(
113                                    &conn,
114                                    &shared.cached_statements_size,
115                                );
116                                prepared
117                            }))
118                            .ok();
119                        }
120                        Command::Execute {
121                            query,
122                            arguments,
123                            persistent,
124                            tx,
125                        } => {
126                            let iter = match execute::iter(&mut conn, &query, arguments, persistent)
127                            {
128                                Ok(iter) => iter,
129                                Err(e) => {
130                                    tx.send(Err(e)).ok();
131                                    continue;
132                                }
133                            };
134
135                            for res in iter {
136                                if tx.send(res).is_err() {
137                                    break;
138                                }
139                            }
140
141                            update_cached_statements_size(&conn, &shared.cached_statements_size);
142                        }
143                        Command::CreateCollation { create_collation } => {
144                            if let Err(e) = (create_collation)(&mut conn) {
145                                log::warn!("error applying collation in background worker: {}", e);
146                            }
147                        }
148                        Command::ClearCache { tx } => {
149                            conn.statements.clear();
150                            update_cached_statements_size(&conn, &shared.cached_statements_size);
151                            tx.send(()).ok();
152                        }
153                        Command::UnlockDb => {
154                            drop(conn);
155                            conn = futures_executor::block_on(shared.conn.lock());
156                        }
157                        Command::Ping { tx } => {
158                            tx.send(()).ok();
159                        }
160                        Command::Shutdown { tx } => {
161                            // drop the connection references before sending confirmation
162                            // and ending the command loop
163                            drop(conn);
164                            drop(shared);
165                            let _ = tx.send(());
166                            return;
167                        }
168                    }
169                }
170            })?;
171
172        establish_rx
173            .await
174            .map_err(|_| Error::from("WorkerCrashed"))?
175    }
176
177    pub(crate) async fn prepare(&mut self, query: &str) -> Result<SqliteStatement, Error> {
178        self.oneshot_cmd(|tx| Command::Prepare {
179            query: query.into(),
180            tx,
181        })
182        .await?
183    }
184
185    pub(crate) async fn execute(
186        &mut self,
187        query: String,
188        args: Option<SqliteArguments>,
189        chan_size: usize,
190        persistent: bool,
191    ) -> Result<
192        crossfire::AsyncRx<
193            crossfire::spsc::Array<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
194        >,
195        Error,
196    > {
197        let (tx, rx) = spsc::bounded_blocking_async(chan_size);
198
199        self.command_tx
200            .send(Command::Execute {
201                query: query.into(),
202                arguments: args.map(SqliteArguments::into_static),
203                persistent,
204                tx,
205            })
206            .await
207            .map_err(|_| Error::from("WorkerCrashed"))?;
208
209        Ok(rx)
210    }
211
212    pub(crate) async fn ping(&mut self) -> Result<(), Error> {
213        self.oneshot_cmd(|tx| Command::Ping { tx }).await
214    }
215
216    pub(crate) async fn oneshot_cmd<F, T>(&mut self, command: F) -> Result<T, Error>
217    where
218        F: FnOnce(oneshot::Sender<T>) -> Command,
219    {
220        let (tx, rx) = oneshot::channel();
221
222        self.command_tx
223            .send(command(tx))
224            .await
225            .map_err(|_| Error::from("WorkerCrashed"))?;
226
227        rx.await.map_err(|_| Error::from("WorkerCrashed"))
228    }
229
230    pub async fn create_collation(
231        &mut self,
232        name: &str,
233        compare: impl Fn(&str, &str) -> std::cmp::Ordering + Send + Sync + 'static,
234    ) -> Result<(), Error> {
235        let name = name.to_string();
236
237        self.command_tx
238            .send(Command::CreateCollation {
239                create_collation: Box::new(move |conn| {
240                    create_collation(&mut conn.handle, &name, compare)
241                }),
242            })
243            .await
244            .map_err(|_| Error::from("WorkerCrashed"))?;
245        Ok(())
246    }
247
248    pub(crate) async fn clear_cache(&mut self) -> Result<(), Error> {
249        self.oneshot_cmd(|tx| Command::ClearCache { tx }).await
250    }
251
252    pub(crate) async fn unlock_db(&mut self) -> Result<MutexGuard<'_, ConnectionState>, Error> {
253        let (guard, res) = futures_util::future::join(
254            // we need to join the wait queue for the lock before we send the message
255            self.shared.conn.lock(),
256            self.command_tx.send(Command::UnlockDb),
257        )
258        .await;
259
260        res.map_err(|_| Error::from("WorkerCrashed"))?;
261
262        Ok(guard)
263    }
264
265    /// Send a command to the worker to shut down the processing thread.
266    ///
267    /// A `WorkerCrashed` error may be returned if the thread has already stopped.
268    pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
269        let (tx, rx) = oneshot::channel();
270
271        self.command_tx
272            .send(Command::Shutdown { tx })
273            .await
274            .map_err(|_| Error::from("WorkerCrashed"))?;
275
276        // wait for the response
277        rx.await.map_err(|_| Error::from("WorkerCrashed"))
278    }
279}
280
281fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement, Error> {
282    // prepare statement object (or checkout from cache)
283    let statement = conn.statements.get(query, true)?;
284
285    let mut parameters = 0;
286    let mut columns = None;
287    let mut column_names = None;
288
289    while let Some(statement) = statement.prepare_next(&mut conn.handle)? {
290        parameters += statement.handle.bind_parameter_count();
291
292        // the first non-empty statement is chosen as the statement we pull columns from
293        if !statement.columns.is_empty() && columns.is_none() {
294            columns = Some(Arc::clone(statement.columns));
295            column_names = Some(Arc::clone(statement.column_names));
296        }
297    }
298
299    Ok(SqliteStatement {
300        sql: query.to_string(),
301        columns: columns.unwrap_or_default(),
302        column_names: column_names.unwrap_or_default(),
303        parameters,
304    })
305}
306
307fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
308    size.store(conn.statements.len(), Ordering::Release);
309}