1use std::future::Future;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::thread;
5
6use crate::connection::collation::create_collation;
7use crate::connection::establish::EstablishParams;
8use crate::connection::ConnectionState;
9use crate::connection::{execute, ConnectionHandleRaw};
10use crate::{SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement};
11use either::Either;
12use futures_channel::oneshot;
13use futures_intrusive::sync::{Mutex, MutexGuard};
14use rbdc::error::Error;
15
/// Client-side handle to a SQLite connection that lives on a dedicated worker thread.
///
/// Work is sent to that thread as [`Command`]s through `command_tx`. The value itself is
/// built on the worker thread by `ConnectionWorker::establish` once the connection opens.
pub(crate) struct ConnectionWorker {
    /// Sending half of the bounded command channel that the worker thread reads from.
    command_tx: flume::Sender<Command>,
    /// Raw handle to the underlying connection, captured in `establish`.
    /// It is stored outside `shared.conn`, so using it does not take that mutex.
    /// NOTE(review): confirm every use of this handle is safe without the lock.
    pub(crate) handle_raw: ConnectionHandleRaw,
    /// State shared with the worker thread (connection mutex and cache-size counter).
    pub(crate) shared: Arc<WorkerSharedState>,
}
29
/// State shared between a [`ConnectionWorker`] handle and its worker thread.
pub(crate) struct WorkerSharedState {
    /// Number of entries in the connection's statement cache, as last stored by the
    /// worker thread (written with `Ordering::Release`).
    pub(crate) cached_statements_size: AtomicUsize,
    /// The connection state. The worker thread holds this lock while it serves
    /// commands and only releases it in response to `Command::UnlockDb` or on shutdown.
    pub(crate) conn: Mutex<ConnectionState>,
}
34
/// A request processed by the connection worker thread.
pub enum Command {
    /// Prepare `query` through the statement cache and send the resulting statement
    /// metadata (or the error) on `tx`.
    Prepare {
        query: Box<str>,
        tx: oneshot::Sender<Result<SqliteStatement, Error>>,
    },
    /// Execute `query` with the optional bound `arguments`. Each result item is sent on
    /// `tx`; the worker stops sending as soon as the receiver has been dropped.
    /// `persistent` is forwarded to `execute::iter` (presumably selects whether the
    /// statement is cached — verify).
    Execute {
        query: Box<str>,
        arguments: Option<SqliteArguments>,
        persistent: bool,
        tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
    },
    /// Run a closure against the connection state on the worker thread.
    /// There is no reply channel: errors are only logged by the worker.
    CreateCollation {
        create_collation:
            Box<dyn FnOnce(&mut ConnectionState) -> Result<(), Error> + Send + Sync + 'static>,
    },
    /// Make the worker release its lock on `WorkerSharedState::conn`, then block until
    /// it can lock it again.
    UnlockDb,
    /// Clear the statement cache, refresh the cached-size counter, and acknowledge on `tx`.
    ClearCache {
        tx: oneshot::Sender<()>,
    },
    /// Acknowledge on `tx` without doing any other work (liveness check).
    Ping {
        tx: oneshot::Sender<()>,
    },
    /// Stop the worker thread. It drops its connection references first and then
    /// acknowledges on `tx`.
    Shutdown {
        tx: oneshot::Sender<()>,
    },
}
61
62impl ConnectionWorker {
63 pub(crate) async fn establish(params: EstablishParams) -> Result<Self, Error> {
64 let (establish_tx, establish_rx) = oneshot::channel();
65
66 thread::Builder::new()
67 .name(params.thread_name.clone())
68 .spawn(move || {
69 let (command_tx, command_rx) = flume::bounded(params.command_channel_size);
70
71 let conn = match params.establish() {
72 Ok(conn) => conn,
73 Err(e) => {
74 establish_tx.send(Err(e)).ok();
75 return;
76 }
77 };
78
79 let shared = Arc::new(WorkerSharedState {
80 cached_statements_size: AtomicUsize::new(0),
81 conn: Mutex::new(conn, true),
85 });
86 let mut conn = shared.conn.try_lock().unwrap();
87
88 if establish_tx
89 .send(Ok(Self {
90 command_tx,
91 handle_raw: conn.handle.to_raw(),
92 shared: Arc::clone(&shared),
93 }))
94 .is_err()
95 {
96 return;
97 }
98
99 for cmd in command_rx {
100 match cmd {
101 Command::Prepare { query, tx } => {
102 tx.send(prepare(&mut conn, &query).map(|prepared| {
103 update_cached_statements_size(
104 &conn,
105 &shared.cached_statements_size,
106 );
107 prepared
108 }))
109 .ok();
110 }
111 Command::Execute {
112 query,
113 arguments,
114 persistent,
115 tx,
116 } => {
117 let iter = match execute::iter(&mut conn, &query, arguments, persistent)
118 {
119 Ok(iter) => iter,
120 Err(e) => {
121 tx.send(Err(e)).ok();
122 continue;
123 }
124 };
125
126 for res in iter {
127 if tx.send(res).is_err() {
128 break;
129 }
130 }
131
132 update_cached_statements_size(&conn, &shared.cached_statements_size);
133 }
134 Command::CreateCollation { create_collation } => {
135 if let Err(e) = (create_collation)(&mut conn) {
136 log::warn!("error applying collation in background worker: {}", e);
137 }
138 }
139 Command::ClearCache { tx } => {
140 conn.statements.clear();
141 update_cached_statements_size(&conn, &shared.cached_statements_size);
142 tx.send(()).ok();
143 }
144 Command::UnlockDb => {
145 drop(conn);
146 conn = futures_executor::block_on(shared.conn.lock());
147 }
148 Command::Ping { tx } => {
149 tx.send(()).ok();
150 }
151 Command::Shutdown { tx } => {
152 drop(conn);
155 drop(shared);
156 let _ = tx.send(());
157 return;
158 }
159 }
160 }
161 })?;
162
163 establish_rx
164 .await
165 .map_err(|_| Error::from("WorkerCrashed"))?
166 }
167
168 pub(crate) async fn prepare(&mut self, query: &str) -> Result<SqliteStatement, Error> {
169 self.oneshot_cmd(|tx| Command::Prepare {
170 query: query.into(),
171 tx,
172 })
173 .await?
174 }
175
176 pub(crate) async fn execute(
177 &mut self,
178 query: String,
179 args: Option<SqliteArguments>,
180 chan_size: usize,
181 persistent: bool,
182 ) -> Result<flume::Receiver<Result<Either<SqliteQueryResult, SqliteRow>, Error>>, Error> {
183 let (tx, rx) = flume::bounded(chan_size);
184
185 self.command_tx
186 .send_async(Command::Execute {
187 query: query.into(),
188 arguments: args.map(SqliteArguments::into_static),
189 persistent,
190 tx,
191 })
192 .await
193 .map_err(|_| Error::from("WorkerCrashed"))?;
194
195 Ok(rx)
196 }
197
198 pub(crate) async fn ping(&mut self) -> Result<(), Error> {
199 self.oneshot_cmd(|tx| Command::Ping { tx }).await
200 }
201
202 pub(crate) async fn oneshot_cmd<F, T>(&mut self, command: F) -> Result<T, Error>
203 where
204 F: FnOnce(oneshot::Sender<T>) -> Command,
205 {
206 let (tx, rx) = oneshot::channel();
207
208 self.command_tx
209 .send_async(command(tx))
210 .await
211 .map_err(|_| Error::from("WorkerCrashed"))?;
212
213 rx.await.map_err(|_| Error::from("WorkerCrashed"))
214 }
215
216 pub fn create_collation(
217 &mut self,
218 name: &str,
219 compare: impl Fn(&str, &str) -> std::cmp::Ordering + Send + Sync + 'static,
220 ) -> Result<(), Error> {
221 let name = name.to_string();
222
223 self.command_tx
224 .send(Command::CreateCollation {
225 create_collation: Box::new(move |conn| {
226 create_collation(&mut conn.handle, &name, compare)
227 }),
228 })
229 .map_err(|_| Error::from("WorkerCrashed"))?;
230 Ok(())
231 }
232
233 pub(crate) async fn clear_cache(&mut self) -> Result<(), Error> {
234 self.oneshot_cmd(|tx| Command::ClearCache { tx }).await
235 }
236
237 pub(crate) async fn unlock_db(&mut self) -> Result<MutexGuard<'_, ConnectionState>, Error> {
238 let (guard, res) = futures_util::future::join(
239 self.shared.conn.lock(),
241 self.command_tx.send_async(Command::UnlockDb),
242 )
243 .await;
244
245 res.map_err(|_| Error::from("WorkerCrashed"))?;
246
247 Ok(guard)
248 }
249
250 pub(crate) fn shutdown(&mut self) -> impl Future<Output = Result<(), Error>> {
254 let (tx, rx) = oneshot::channel();
255
256 let send_res = self
257 .command_tx
258 .send(Command::Shutdown { tx })
259 .map_err(|_| Error::from("WorkerCrashed"));
260
261 async move {
262 send_res?;
263
264 rx.await.map_err(|_| Error::from("WorkerCrashed"))
266 }
267 }
268}
269
270fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement, Error> {
271 let statement = conn.statements.get(query, true)?;
273
274 let mut parameters = 0;
275 let mut columns = None;
276 let mut column_names = None;
277
278 while let Some(statement) = statement.prepare_next(&mut conn.handle)? {
279 parameters += statement.handle.bind_parameter_count();
280
281 if !statement.columns.is_empty() && columns.is_none() {
283 columns = Some(Arc::clone(statement.columns));
284 column_names = Some(Arc::clone(statement.column_names));
285 }
286 }
287
288 Ok(SqliteStatement {
289 sql: query.to_string(),
290 columns: columns.unwrap_or_default(),
291 column_names: column_names.unwrap_or_default(),
292 parameters,
293 })
294}
295
296fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
297 size.store(conn.statements.len(), Ordering::Release);
298}