1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::thread;
4
5use crate::connection::collation::create_collation;
6use crate::connection::establish::EstablishParams;
7use crate::connection::ConnectionState;
8use crate::connection::{execute, ConnectionHandleRaw};
9use crate::{SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement};
10use crossfire::{spsc, AsyncTx};
11use either::Either;
12use futures_channel::oneshot;
13use futures_intrusive::sync::{Mutex, MutexGuard};
14use rbdc::error::Error;
15
16pub(crate) struct ConnectionWorker {
23 command_tx: AsyncTx<crossfire::spsc::Array<Command>>,
24 pub(crate) handle_raw: ConnectionHandleRaw,
26 pub(crate) shared: Arc<WorkerSharedState>,
28}
29
30pub(crate) struct WorkerSharedState {
31 pub(crate) cached_statements_size: AtomicUsize,
32 pub(crate) conn: Mutex<ConnectionState>,
33}
34
35pub enum Command {
36 Prepare {
37 query: Box<str>,
38 tx: oneshot::Sender<Result<SqliteStatement, Error>>,
39 },
40 Execute {
41 query: Box<str>,
42 arguments: Option<SqliteArguments>,
43 persistent: bool,
44 tx: crossfire::Tx<
45 crossfire::spsc::Array<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
46 >,
47 },
48 CreateCollation {
49 create_collation:
50 Box<dyn FnOnce(&mut ConnectionState) -> Result<(), Error> + Send + Sync + 'static>,
51 },
52 UnlockDb,
53 ClearCache {
54 tx: oneshot::Sender<()>,
55 },
56 Ping {
57 tx: oneshot::Sender<()>,
58 },
59 Shutdown {
60 tx: oneshot::Sender<()>,
61 },
62}
63
64impl ConnectionWorker {
65 pub(crate) async fn establish(params: EstablishParams) -> Result<Self, Error> {
66 let (establish_tx, establish_rx) = oneshot::channel();
67
68 thread::Builder::new()
69 .name(params.thread_name.clone())
70 .spawn(move || {
71 let (command_tx, command_rx) =
72 spsc::bounded_async_blocking(params.command_channel_size);
73
74 let conn = match params.establish() {
75 Ok(conn) => conn,
76 Err(e) => {
77 establish_tx.send(Err(e)).ok();
78 return;
79 }
80 };
81
82 let shared = Arc::new(WorkerSharedState {
83 cached_statements_size: AtomicUsize::new(0),
84 conn: Mutex::new(conn, true),
88 });
89 let mut conn = shared.conn.try_lock().unwrap();
90
91 if establish_tx
92 .send(Ok(Self {
93 command_tx,
94 handle_raw: conn.handle.to_raw(),
95 shared: Arc::clone(&shared),
96 }))
97 .is_err()
98 {
99 return;
100 }
101
102 loop {
104 let cmd = match command_rx.recv() {
105 Ok(cmd) => cmd,
106 Err(_) => break, };
108
109 match cmd {
110 Command::Prepare { query, tx } => {
111 tx.send(prepare(&mut conn, &query).map(|prepared| {
112 update_cached_statements_size(
113 &conn,
114 &shared.cached_statements_size,
115 );
116 prepared
117 }))
118 .ok();
119 }
120 Command::Execute {
121 query,
122 arguments,
123 persistent,
124 tx,
125 } => {
126 let iter = match execute::iter(&mut conn, &query, arguments, persistent)
127 {
128 Ok(iter) => iter,
129 Err(e) => {
130 tx.send(Err(e)).ok();
131 continue;
132 }
133 };
134
135 for res in iter {
136 if tx.send(res).is_err() {
137 break;
138 }
139 }
140
141 update_cached_statements_size(&conn, &shared.cached_statements_size);
142 }
143 Command::CreateCollation { create_collation } => {
144 if let Err(e) = (create_collation)(&mut conn) {
145 log::warn!("error applying collation in background worker: {}", e);
146 }
147 }
148 Command::ClearCache { tx } => {
149 conn.statements.clear();
150 update_cached_statements_size(&conn, &shared.cached_statements_size);
151 tx.send(()).ok();
152 }
153 Command::UnlockDb => {
154 drop(conn);
155 conn = futures_executor::block_on(shared.conn.lock());
156 }
157 Command::Ping { tx } => {
158 tx.send(()).ok();
159 }
160 Command::Shutdown { tx } => {
161 drop(conn);
164 drop(shared);
165 let _ = tx.send(());
166 return;
167 }
168 }
169 }
170 })?;
171
172 establish_rx
173 .await
174 .map_err(|_| Error::from("WorkerCrashed"))?
175 }
176
177 pub(crate) async fn prepare(&mut self, query: &str) -> Result<SqliteStatement, Error> {
178 self.oneshot_cmd(|tx| Command::Prepare {
179 query: query.into(),
180 tx,
181 })
182 .await?
183 }
184
185 pub(crate) async fn execute(
186 &mut self,
187 query: String,
188 args: Option<SqliteArguments>,
189 chan_size: usize,
190 persistent: bool,
191 ) -> Result<
192 crossfire::AsyncRx<
193 crossfire::spsc::Array<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
194 >,
195 Error,
196 > {
197 let (tx, rx) = spsc::bounded_blocking_async(chan_size);
198
199 self.command_tx
200 .send(Command::Execute {
201 query: query.into(),
202 arguments: args.map(SqliteArguments::into_static),
203 persistent,
204 tx,
205 })
206 .await
207 .map_err(|_| Error::from("WorkerCrashed"))?;
208
209 Ok(rx)
210 }
211
212 pub(crate) async fn ping(&mut self) -> Result<(), Error> {
213 self.oneshot_cmd(|tx| Command::Ping { tx }).await
214 }
215
216 pub(crate) async fn oneshot_cmd<F, T>(&mut self, command: F) -> Result<T, Error>
217 where
218 F: FnOnce(oneshot::Sender<T>) -> Command,
219 {
220 let (tx, rx) = oneshot::channel();
221
222 self.command_tx
223 .send(command(tx))
224 .await
225 .map_err(|_| Error::from("WorkerCrashed"))?;
226
227 rx.await.map_err(|_| Error::from("WorkerCrashed"))
228 }
229
230 pub async fn create_collation(
231 &mut self,
232 name: &str,
233 compare: impl Fn(&str, &str) -> std::cmp::Ordering + Send + Sync + 'static,
234 ) -> Result<(), Error> {
235 let name = name.to_string();
236
237 self.command_tx
238 .send(Command::CreateCollation {
239 create_collation: Box::new(move |conn| {
240 create_collation(&mut conn.handle, &name, compare)
241 }),
242 })
243 .await
244 .map_err(|_| Error::from("WorkerCrashed"))?;
245 Ok(())
246 }
247
248 pub(crate) async fn clear_cache(&mut self) -> Result<(), Error> {
249 self.oneshot_cmd(|tx| Command::ClearCache { tx }).await
250 }
251
252 pub(crate) async fn unlock_db(&mut self) -> Result<MutexGuard<'_, ConnectionState>, Error> {
253 let (guard, res) = futures_util::future::join(
254 self.shared.conn.lock(),
256 self.command_tx.send(Command::UnlockDb),
257 )
258 .await;
259
260 res.map_err(|_| Error::from("WorkerCrashed"))?;
261
262 Ok(guard)
263 }
264
265 pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
269 let (tx, rx) = oneshot::channel();
270
271 self.command_tx
272 .send(Command::Shutdown { tx })
273 .await
274 .map_err(|_| Error::from("WorkerCrashed"))?;
275
276 rx.await.map_err(|_| Error::from("WorkerCrashed"))
278 }
279}
280
281fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement, Error> {
282 let statement = conn.statements.get(query, true)?;
284
285 let mut parameters = 0;
286 let mut columns = None;
287 let mut column_names = None;
288
289 while let Some(statement) = statement.prepare_next(&mut conn.handle)? {
290 parameters += statement.handle.bind_parameter_count();
291
292 if !statement.columns.is_empty() && columns.is_none() {
294 columns = Some(Arc::clone(statement.columns));
295 column_names = Some(Arc::clone(statement.column_names));
296 }
297 }
298
299 Ok(SqliteStatement {
300 sql: query.to_string(),
301 columns: columns.unwrap_or_default(),
302 column_names: column_names.unwrap_or_default(),
303 parameters,
304 })
305}
306
307fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
308 size.store(conn.statements.len(), Ordering::Release);
309}