1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::thread;
4
5use crate::connection::collation::create_collation;
6use crate::connection::establish::EstablishParams;
7use crate::connection::ConnectionState;
8use crate::connection::{execute, ConnectionHandleRaw};
9use crate::{SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement};
10use either::Either;
11use futures_channel::oneshot;
12use futures_intrusive::sync::{Mutex, MutexGuard};
13use rbdc::error::Error;
14use crossfire::{spsc, AsyncTx};
15
16pub(crate) struct ConnectionWorker {
23 command_tx: AsyncTx<crossfire::spsc::Array<Command>>,
24 pub(crate) handle_raw: ConnectionHandleRaw,
26 pub(crate) shared: Arc<WorkerSharedState>,
28}
29
30pub(crate) struct WorkerSharedState {
31 pub(crate) cached_statements_size: AtomicUsize,
32 pub(crate) conn: Mutex<ConnectionState>,
33}
34
35pub enum Command {
36 Prepare {
37 query: Box<str>,
38 tx: oneshot::Sender<Result<SqliteStatement, Error>>,
39 },
40 Execute {
41 query: Box<str>,
42 arguments: Option<SqliteArguments>,
43 persistent: bool,
44 tx: crossfire::Tx<crossfire::spsc::Array<Result<Either<SqliteQueryResult, SqliteRow>, Error>>>,
45 },
46 CreateCollation {
47 create_collation:
48 Box<dyn FnOnce(&mut ConnectionState) -> Result<(), Error> + Send + Sync + 'static>,
49 },
50 UnlockDb,
51 ClearCache {
52 tx: oneshot::Sender<()>,
53 },
54 Ping {
55 tx: oneshot::Sender<()>,
56 },
57 Shutdown {
58 tx: oneshot::Sender<()>,
59 },
60}
61
62impl ConnectionWorker {
63 pub(crate) async fn establish(params: EstablishParams) -> Result<Self, Error> {
64 let (establish_tx, establish_rx) = oneshot::channel();
65
66 thread::Builder::new()
67 .name(params.thread_name.clone())
68 .spawn(move || {
69 let (command_tx, command_rx) = spsc::bounded_async_blocking(params.command_channel_size);
70
71 let conn = match params.establish() {
72 Ok(conn) => conn,
73 Err(e) => {
74 establish_tx.send(Err(e)).ok();
75 return;
76 }
77 };
78
79 let shared = Arc::new(WorkerSharedState {
80 cached_statements_size: AtomicUsize::new(0),
81 conn: Mutex::new(conn, true),
85 });
86 let mut conn = shared.conn.try_lock().unwrap();
87
88 if establish_tx
89 .send(Ok(Self {
90 command_tx,
91 handle_raw: conn.handle.to_raw(),
92 shared: Arc::clone(&shared),
93 }))
94 .is_err()
95 {
96 return;
97 }
98
99 loop {
101 let cmd = match command_rx.recv() {
102 Ok(cmd) => cmd,
103 Err(_) => break, };
105
106 match cmd {
107 Command::Prepare { query, tx } => {
108 tx.send(prepare(&mut conn, &query).map(|prepared| {
109 update_cached_statements_size(
110 &conn,
111 &shared.cached_statements_size,
112 );
113 prepared
114 }))
115 .ok();
116 }
117 Command::Execute {
118 query,
119 arguments,
120 persistent,
121 tx,
122 } => {
123 let iter = match execute::iter(&mut conn, &query, arguments, persistent)
124 {
125 Ok(iter) => iter,
126 Err(e) => {
127 tx.send(Err(e)).ok();
128 continue;
129 }
130 };
131
132 for res in iter {
133 if tx.send(res).is_err() {
134 break;
135 }
136 }
137
138 update_cached_statements_size(&conn, &shared.cached_statements_size);
139 }
140 Command::CreateCollation { create_collation } => {
141 if let Err(e) = (create_collation)(&mut conn) {
142 log::warn!("error applying collation in background worker: {}", e);
143 }
144 }
145 Command::ClearCache { tx } => {
146 conn.statements.clear();
147 update_cached_statements_size(&conn, &shared.cached_statements_size);
148 tx.send(()).ok();
149 }
150 Command::UnlockDb => {
151 drop(conn);
152 conn = futures_executor::block_on(shared.conn.lock());
153 }
154 Command::Ping { tx } => {
155 tx.send(()).ok();
156 }
157 Command::Shutdown { tx } => {
158 drop(conn);
161 drop(shared);
162 let _ = tx.send(());
163 return;
164 }
165 }
166 }
167 })?;
168
169 establish_rx
170 .await
171 .map_err(|_| Error::from("WorkerCrashed"))?
172 }
173
174 pub(crate) async fn prepare(&mut self, query: &str) -> Result<SqliteStatement, Error> {
175 self.oneshot_cmd(|tx| Command::Prepare {
176 query: query.into(),
177 tx,
178 })
179 .await?
180 }
181
182 pub(crate) async fn execute(
183 &mut self,
184 query: String,
185 args: Option<SqliteArguments>,
186 chan_size: usize,
187 persistent: bool,
188 ) -> Result<crossfire::AsyncRx<crossfire::spsc::Array<Result<Either<SqliteQueryResult, SqliteRow>, Error>>>, Error> {
189 let (tx, rx) = spsc::bounded_blocking_async(chan_size);
190
191 self.command_tx
192 .send(Command::Execute {
193 query: query.into(),
194 arguments: args.map(SqliteArguments::into_static),
195 persistent,
196 tx,
197 })
198 .await
199 .map_err(|_| Error::from("WorkerCrashed"))?;
200
201 Ok(rx)
202 }
203
204 pub(crate) async fn ping(&mut self) -> Result<(), Error> {
205 self.oneshot_cmd(|tx| Command::Ping { tx }).await
206 }
207
208 pub(crate) async fn oneshot_cmd<F, T>(&mut self, command: F) -> Result<T, Error>
209 where
210 F: FnOnce(oneshot::Sender<T>) -> Command,
211 {
212 let (tx, rx) = oneshot::channel();
213
214 self.command_tx
215 .send(command(tx))
216 .await
217 .map_err(|_| Error::from("WorkerCrashed"))?;
218
219 rx.await.map_err(|_| Error::from("WorkerCrashed"))
220 }
221
222 pub async fn create_collation(
223 &mut self,
224 name: &str,
225 compare: impl Fn(&str, &str) -> std::cmp::Ordering + Send + Sync + 'static,
226 ) -> Result<(), Error> {
227 let name = name.to_string();
228
229 self.command_tx
230 .send(Command::CreateCollation {
231 create_collation: Box::new(move |conn| {
232 create_collation(&mut conn.handle, &name, compare)
233 }),
234 })
235 .await
236 .map_err(|_| Error::from("WorkerCrashed"))?;
237 Ok(())
238 }
239
240 pub(crate) async fn clear_cache(&mut self) -> Result<(), Error> {
241 self.oneshot_cmd(|tx| Command::ClearCache { tx }).await
242 }
243
244 pub(crate) async fn unlock_db(&mut self) -> Result<MutexGuard<'_, ConnectionState>, Error> {
245 let (guard, res) = futures_util::future::join(
246 self.shared.conn.lock(),
248 self.command_tx.send(Command::UnlockDb),
249 )
250 .await;
251
252 res.map_err(|_| Error::from("WorkerCrashed"))?;
253
254 Ok(guard)
255 }
256
257 pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
261 let (tx, rx) = oneshot::channel();
262
263 self.command_tx
264 .send(Command::Shutdown { tx })
265 .await
266 .map_err(|_| Error::from("WorkerCrashed"))?;
267
268 rx.await.map_err(|_| Error::from("WorkerCrashed"))
270 }
271}
272
273fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement, Error> {
274 let statement = conn.statements.get(query, true)?;
276
277 let mut parameters = 0;
278 let mut columns = None;
279 let mut column_names = None;
280
281 while let Some(statement) = statement.prepare_next(&mut conn.handle)? {
282 parameters += statement.handle.bind_parameter_count();
283
284 if !statement.columns.is_empty() && columns.is_none() {
286 columns = Some(Arc::clone(statement.columns));
287 column_names = Some(Arc::clone(statement.column_names));
288 }
289 }
290
291 Ok(SqliteStatement {
292 sql: query.to_string(),
293 columns: columns.unwrap_or_default(),
294 column_names: column_names.unwrap_or_default(),
295 parameters,
296 })
297}
298
299fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
300 size.store(conn.statements.len(), Ordering::Release);
301}