// rbdc_sqlite/connection/worker.rs

1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::thread;
4
5use crate::connection::collation::create_collation;
6use crate::connection::establish::EstablishParams;
7use crate::connection::ConnectionState;
8use crate::connection::{execute, ConnectionHandleRaw};
9use crate::{SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement};
10use either::Either;
11use futures_channel::oneshot;
12use futures_intrusive::sync::{Mutex, MutexGuard};
13use rbdc::error::Error;
14use crossfire::{spsc, AsyncTx};

// Each SQLite connection has a dedicated worker thread.

// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce
//       OS resource usage. Low priority because a high concurrent load for SQLite3 is very
//       unlikely.

/// Caller-side handle to a connection's dedicated worker thread.
///
/// Requests are sent to the thread as [`Command`]s over `command_tx`; while it is
/// processing commands the worker thread holds the lock on `shared.conn`.
pub(crate) struct ConnectionWorker {
    /// Async sending half of the command channel; the worker thread owns the
    /// (blocking) receiving half.
    command_tx: AsyncTx<crossfire::spsc::Array<Command>>,
    /// The `sqlite3` pointer. NOTE: access is unsynchronized!
    pub(crate) handle_raw: ConnectionHandleRaw,
    /// Mutex for locking access to the database.
    pub(crate) shared: Arc<WorkerSharedState>,
}

/// State shared between a [`ConnectionWorker`] handle and its worker thread.
pub(crate) struct WorkerSharedState {
    /// Mirror of `conn.statements.len()`, stored by the worker thread (with
    /// `Release` ordering) after commands that use the statement cache.
    pub(crate) cached_statements_size: AtomicUsize,
    /// The connection state. Created as a fair mutex; the worker thread holds the
    /// lock while it processes commands and gives it up temporarily on `UnlockDb`.
    pub(crate) conn: Mutex<ConnectionState>,
}

/// A request sent to a connection's worker thread; commands are handled one at a
/// time in the order they are received.
pub enum Command {
    /// Prepare `query` and reply on `tx` with its description (or the error).
    Prepare {
        query: Box<str>,
        tx: oneshot::Sender<Result<SqliteStatement, Error>>,
    },
    /// Execute `query` with the optional `arguments`, streaming every result/row
    /// (or the error) on `tx`. `persistent` is forwarded to `execute::iter`.
    /// The worker stops streaming as soon as a send on `tx` fails.
    Execute {
        query: Box<str>,
        arguments: Option<SqliteArguments>,
        persistent: bool,
        tx: crossfire::Tx<crossfire::spsc::Array<Result<Either<SqliteQueryResult, SqliteRow>, Error>>>,
    },
    /// Run `create_collation` against the connection state. There is no reply;
    /// an error returned by the closure is only logged by the worker.
    CreateCollation {
        create_collation:
            Box<dyn FnOnce(&mut ConnectionState) -> Result<(), Error> + Send + Sync + 'static>,
    },
    /// Release the connection mutex and block until it can be re-acquired, letting
    /// a task already waiting on the lock use the connection in the meantime.
    UnlockDb,
    /// Empty the statement cache, then acknowledge on `tx`.
    ClearCache {
        tx: oneshot::Sender<()>,
    },
    /// No-op round trip; the worker just acknowledges on `tx`.
    Ping {
        tx: oneshot::Sender<()>,
    },
    /// Release the connection and stop the worker thread, acknowledging on `tx`.
    Shutdown {
        tx: oneshot::Sender<()>,
    },
}

62impl ConnectionWorker {
63    pub(crate) async fn establish(params: EstablishParams) -> Result<Self, Error> {
64        let (establish_tx, establish_rx) = oneshot::channel();
65
66        thread::Builder::new()
67            .name(params.thread_name.clone())
68            .spawn(move || {
69                let (command_tx, command_rx) = spsc::bounded_async_blocking(params.command_channel_size);
70
71                let conn = match params.establish() {
72                    Ok(conn) => conn,
73                    Err(e) => {
74                        establish_tx.send(Err(e)).ok();
75                        return;
76                    }
77                };
78
79                let shared = Arc::new(WorkerSharedState {
80                    cached_statements_size: AtomicUsize::new(0),
81                    // note: must be fair because in `Command::UnlockDb` we unlock the mutex
82                    // and then immediately try to relock it; an unfair mutex would immediately
83                    // grant us the lock even if another task is waiting.
84                    conn: Mutex::new(conn, true),
85                });
86                let mut conn = shared.conn.try_lock().unwrap();
87
88                if establish_tx
89                    .send(Ok(Self {
90                        command_tx,
91                        handle_raw: conn.handle.to_raw(),
92                        shared: Arc::clone(&shared),
93                    }))
94                    .is_err()
95                {
96                    return;
97                }
98
99                // Use blocking receiver in sync thread
100                loop {
101                    let cmd = match command_rx.recv() {
102                        Ok(cmd) => cmd,
103                        Err(_) => break, // channel closed
104                    };
105
106                    match cmd {
107                        Command::Prepare { query, tx } => {
108                            tx.send(prepare(&mut conn, &query).map(|prepared| {
109                                update_cached_statements_size(
110                                    &conn,
111                                    &shared.cached_statements_size,
112                                );
113                                prepared
114                            }))
115                            .ok();
116                        }
117                        Command::Execute {
118                            query,
119                            arguments,
120                            persistent,
121                            tx,
122                        } => {
123                            let iter = match execute::iter(&mut conn, &query, arguments, persistent)
124                            {
125                                Ok(iter) => iter,
126                                Err(e) => {
127                                    tx.send(Err(e)).ok();
128                                    continue;
129                                }
130                            };
131
132                            for res in iter {
133                                if tx.send(res).is_err() {
134                                    break;
135                                }
136                            }
137
138                            update_cached_statements_size(&conn, &shared.cached_statements_size);
139                        }
140                        Command::CreateCollation { create_collation } => {
141                            if let Err(e) = (create_collation)(&mut conn) {
142                                log::warn!("error applying collation in background worker: {}", e);
143                            }
144                        }
145                        Command::ClearCache { tx } => {
146                            conn.statements.clear();
147                            update_cached_statements_size(&conn, &shared.cached_statements_size);
148                            tx.send(()).ok();
149                        }
150                        Command::UnlockDb => {
151                            drop(conn);
152                            conn = futures_executor::block_on(shared.conn.lock());
153                        }
154                        Command::Ping { tx } => {
155                            tx.send(()).ok();
156                        }
157                        Command::Shutdown { tx } => {
158                            // drop the connection references before sending confirmation
159                            // and ending the command loop
160                            drop(conn);
161                            drop(shared);
162                            let _ = tx.send(());
163                            return;
164                        }
165                    }
166                }
167            })?;
168
169        establish_rx
170            .await
171            .map_err(|_| Error::from("WorkerCrashed"))?
172    }
173
174    pub(crate) async fn prepare(&mut self, query: &str) -> Result<SqliteStatement, Error> {
175        self.oneshot_cmd(|tx| Command::Prepare {
176            query: query.into(),
177            tx,
178        })
179        .await?
180    }
181
182    pub(crate) async fn execute(
183        &mut self,
184        query: String,
185        args: Option<SqliteArguments>,
186        chan_size: usize,
187        persistent: bool,
188    ) -> Result<crossfire::AsyncRx<crossfire::spsc::Array<Result<Either<SqliteQueryResult, SqliteRow>, Error>>>, Error> {
189        let (tx, rx) = spsc::bounded_blocking_async(chan_size);
190
191        self.command_tx
192            .send(Command::Execute {
193                query: query.into(),
194                arguments: args.map(SqliteArguments::into_static),
195                persistent,
196                tx,
197            })
198            .await
199            .map_err(|_| Error::from("WorkerCrashed"))?;
200
201        Ok(rx)
202    }
203
204    pub(crate) async fn ping(&mut self) -> Result<(), Error> {
205        self.oneshot_cmd(|tx| Command::Ping { tx }).await
206    }
207
208    pub(crate) async fn oneshot_cmd<F, T>(&mut self, command: F) -> Result<T, Error>
209    where
210        F: FnOnce(oneshot::Sender<T>) -> Command,
211    {
212        let (tx, rx) = oneshot::channel();
213
214        self.command_tx
215            .send(command(tx))
216            .await
217            .map_err(|_| Error::from("WorkerCrashed"))?;
218
219        rx.await.map_err(|_| Error::from("WorkerCrashed"))
220    }
221
222    pub async fn create_collation(
223        &mut self,
224        name: &str,
225        compare: impl Fn(&str, &str) -> std::cmp::Ordering + Send + Sync + 'static,
226    ) -> Result<(), Error> {
227        let name = name.to_string();
228
229        self.command_tx
230            .send(Command::CreateCollation {
231                create_collation: Box::new(move |conn| {
232                    create_collation(&mut conn.handle, &name, compare)
233                }),
234            })
235            .await
236            .map_err(|_| Error::from("WorkerCrashed"))?;
237        Ok(())
238    }
239
240    pub(crate) async fn clear_cache(&mut self) -> Result<(), Error> {
241        self.oneshot_cmd(|tx| Command::ClearCache { tx }).await
242    }
243
244    pub(crate) async fn unlock_db(&mut self) -> Result<MutexGuard<'_, ConnectionState>, Error> {
245        let (guard, res) = futures_util::future::join(
246            // we need to join the wait queue for the lock before we send the message
247            self.shared.conn.lock(),
248            self.command_tx.send(Command::UnlockDb),
249        )
250        .await;
251
252        res.map_err(|_| Error::from("WorkerCrashed"))?;
253
254        Ok(guard)
255    }
256
257    /// Send a command to the worker to shut down the processing thread.
258    ///
259    /// A `WorkerCrashed` error may be returned if the thread has already stopped.
260    pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
261        let (tx, rx) = oneshot::channel();
262
263        self.command_tx
264            .send(Command::Shutdown { tx })
265            .await
266            .map_err(|_| Error::from("WorkerCrashed"))?;
267
268        // wait for the response
269        rx.await.map_err(|_| Error::from("WorkerCrashed"))
270    }
271}
272
273fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement, Error> {
274    // prepare statement object (or checkout from cache)
275    let statement = conn.statements.get(query, true)?;
276
277    let mut parameters = 0;
278    let mut columns = None;
279    let mut column_names = None;
280
281    while let Some(statement) = statement.prepare_next(&mut conn.handle)? {
282        parameters += statement.handle.bind_parameter_count();
283
284        // the first non-empty statement is chosen as the statement we pull columns from
285        if !statement.columns.is_empty() && columns.is_none() {
286            columns = Some(Arc::clone(statement.columns));
287            column_names = Some(Arc::clone(statement.column_names));
288        }
289    }
290
291    Ok(SqliteStatement {
292        sql: query.to_string(),
293        columns: columns.unwrap_or_default(),
294        column_names: column_names.unwrap_or_default(),
295        parameters,
296    })
297}
298
299fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
300    size.store(conn.statements.len(), Ordering::Release);
301}