1use std::{
2 path::{Path, PathBuf},
3 thread,
4};
5
6use crate::Error;
7
8use crossbeam_channel::{bounded, unbounded, Receiver, Sender, TrySendError};
9use futures_channel::oneshot;
10use rusqlite::{Connection, OpenFlags};
11
12#[derive(Clone, Debug, Default)]
32pub struct ClientBuilder {
33 pub(crate) path: Option<PathBuf>,
34 pub(crate) flags: OpenFlags,
35 pub(crate) journal_mode: Option<JournalMode>,
36 pub(crate) vfs: Option<String>,
37 pub(crate) queue_capacity: Option<usize>,
38}
39
40impl ClientBuilder {
41 pub fn new() -> Self {
43 Self::default()
44 }
45
46 pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
50 self.path = Some(path.as_ref().into());
51 self
52 }
53
54 pub fn flags(mut self, flags: OpenFlags) -> Self {
58 self.flags = flags;
59 self
60 }
61
62 pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
66 self.journal_mode = Some(journal_mode);
67 self
68 }
69
70 pub fn vfs(mut self, vfs: &str) -> Self {
72 self.vfs = Some(vfs.to_owned());
73 self
74 }
75
76 pub fn queue_capacity(mut self, queue_capacity: usize) -> Self {
83 self.queue_capacity = Some(queue_capacity);
84 self
85 }
86
87 pub async fn open(self) -> Result<Client, Error> {
99 Client::open_async(self).await
100 }
101
102 pub fn open_blocking(self) -> Result<Client, Error> {
115 Client::open_blocking(self)
116 }
117}
118
119enum Command {
120 Func(Box<dyn QueuedFunc>),
121 Shutdown(Box<dyn QueuedShutdown>),
122}
123
124trait QueuedFunc: Send {
125 fn is_canceled(&self) -> bool;
126 fn execute(self: Box<Self>, conn: &mut Connection);
127}
128
129struct AsyncFunc<F, T, E> {
130 tx: oneshot::Sender<Result<T, E>>,
131 func: F,
132}
133
134impl<F, T, E> QueuedFunc for AsyncFunc<F, T, E>
135where
136 F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
137 T: Send + 'static,
138 E: Send + 'static,
139{
140 fn is_canceled(&self) -> bool {
141 self.tx.is_canceled()
142 }
143
144 fn execute(self: Box<Self>, conn: &mut Connection) {
145 let Self { tx, func } = *self;
146 _ = tx.send(func(conn));
147 }
148}
149
150struct BlockingFunc<F, T, E> {
151 tx: Sender<Result<T, E>>,
152 func: F,
153}
154
155impl<F, T, E> QueuedFunc for BlockingFunc<F, T, E>
156where
157 F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
158 T: Send + 'static,
159 E: Send + 'static,
160{
161 fn is_canceled(&self) -> bool {
162 false
163 }
164
165 fn execute(self: Box<Self>, conn: &mut Connection) {
166 let Self { tx, func } = *self;
167 _ = tx.send(func(conn));
168 }
169}
170
171trait QueuedShutdown: Send {
172 fn is_canceled(&self) -> bool;
173 fn respond(self: Box<Self>, res: Result<(), Error>);
174}
175
176struct AsyncShutdown {
177 tx: oneshot::Sender<Result<(), Error>>,
178}
179
180impl QueuedShutdown for AsyncShutdown {
181 fn is_canceled(&self) -> bool {
182 self.tx.is_canceled()
183 }
184
185 fn respond(self: Box<Self>, res: Result<(), Error>) {
186 _ = self.tx.send(res);
187 }
188}
189
190struct BlockingShutdown {
191 tx: Sender<Result<(), Error>>,
192}
193
194impl QueuedShutdown for BlockingShutdown {
195 fn is_canceled(&self) -> bool {
196 false
197 }
198
199 fn respond(self: Box<Self>, res: Result<(), Error>) {
200 _ = self.tx.send(res);
201 }
202}
203
204fn run_catching<F, T>(conn: &mut Connection, func: F) -> Result<T, Error>
205where
206 F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error>,
207{
208 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| func(conn))) {
209 Ok(res) => res.map_err(Error::from),
210 Err(p) => {
211 rollback_if_needed(conn);
212 Err(Error::Panic {
213 message: panic_message(&*p),
214 })
215 }
216 }
217}
218
219fn run_catching_and_then<F, T, E>(conn: &mut Connection, func: F) -> Result<T, E>
220where
221 F: FnOnce(&mut Connection) -> Result<T, E>,
222 E: From<Error>,
223{
224 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| func(conn))) {
225 Ok(res) => res,
226 Err(p) => {
227 rollback_if_needed(conn);
228 Err(E::from(Error::Panic {
229 message: panic_message(&*p),
230 }))
231 }
232 }
233}
234
235fn rollback_if_needed(conn: &mut Connection) {
236 if !conn.is_autocommit() {
237 let _ = conn.execute_batch("ROLLBACK");
238 }
239}
240
/// Extracts a human-readable message from a panic payload.
///
/// Payloads that are a `&'static str` or a `String` are returned as text;
/// any other payload type yields the literal `"panic"`.
fn panic_message(p: &(dyn std::any::Any + Send)) -> String {
    p.downcast_ref::<&'static str>()
        .map(|text| (*text).to_owned())
        .or_else(|| p.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "panic".to_owned())
}
250
251#[derive(Clone)]
254pub struct Client {
255 conn_tx: Sender<Command>,
256}
257
258impl Client {
259 async fn open_async(builder: ClientBuilder) -> Result<Self, Error> {
260 let (open_tx, open_rx) = oneshot::channel();
261 Self::open(builder, |res| {
262 _ = open_tx.send(res);
263 });
264 open_rx.await?
265 }
266
267 fn open_blocking(builder: ClientBuilder) -> Result<Self, Error> {
268 let (conn_tx, conn_rx) = bounded(1);
269 Self::open(builder, move |res| {
270 _ = conn_tx.send(res);
271 });
272 conn_rx.recv()?
273 }
274
275 fn open<F>(builder: ClientBuilder, func: F)
276 where
277 F: FnOnce(Result<Self, Error>) + Send + 'static,
278 {
279 thread::spawn(move || {
280 let (conn_tx, conn_rx) = match builder.queue_capacity {
281 Some(queue_capacity) => bounded(queue_capacity),
282 None => unbounded(),
283 };
284
285 let mut conn = match Client::create_conn(builder) {
286 Ok(conn) => conn,
287 Err(err) => {
288 func(Err(err));
289 return;
290 }
291 };
292
293 let client = Self { conn_tx };
294 func(Ok(client));
295
296 while let Ok(cmd) = conn_rx.recv() {
297 match cmd {
298 Command::Func(func) => {
299 if !func.is_canceled() {
300 func.execute(&mut conn);
301 }
302 }
303 Command::Shutdown(func) => {
304 if !func.is_canceled() {
305 match conn.close() {
306 Ok(()) => {
307 func.respond(Ok(()));
308 return;
309 }
310 Err((c, e)) => {
311 conn = c;
312 func.respond(Err(e.into()));
313 }
314 }
315 }
316 }
317 }
318 }
319 });
320 }
321
322 fn create_conn(mut builder: ClientBuilder) -> Result<Connection, Error> {
323 let path = builder.path.take().unwrap_or_else(|| ":memory:".into());
324 let conn = if let Some(vfs) = builder.vfs.take() {
325 Connection::open_with_flags_and_vfs(path, builder.flags, vfs.as_str())?
326 } else {
327 Connection::open_with_flags(path, builder.flags)?
328 };
329
330 if let Some(journal_mode) = builder.journal_mode.take() {
331 let val = journal_mode.as_str();
332 let out: String =
333 conn.pragma_update_and_check(None, "journal_mode", val, |row| row.get(0))?;
334 if !out.eq_ignore_ascii_case(val) {
335 return Err(Error::PragmaUpdate {
336 name: "journal_mode",
337 exp: val,
338 got: out,
339 });
340 }
341 }
342
343 Ok(conn)
344 }
345
346 fn enqueue_async<F, T, E>(
347 &self,
348 func: F,
349 ) -> Result<oneshot::Receiver<Result<T, E>>, TrySendError<Command>>
350 where
351 F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
352 T: Send + 'static,
353 E: Send + 'static,
354 {
355 let (tx, rx) = oneshot::channel();
356 self.conn_tx
357 .try_send(Command::Func(Box::new(AsyncFunc { tx, func })))?;
358 Ok(rx)
359 }
360
361 fn enqueue_blocking<F, T, E>(
362 &self,
363 func: F,
364 ) -> Result<Receiver<Result<T, E>>, TrySendError<Command>>
365 where
366 F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
367 T: Send + 'static,
368 E: Send + 'static,
369 {
370 let (tx, rx) = bounded(1);
371 self.conn_tx
372 .try_send(Command::Func(Box::new(BlockingFunc { tx, func })))?;
373 Ok(rx)
374 }
375
376 pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
378 where
379 F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
380 T: Send + 'static,
381 {
382 let rx = self
383 .enqueue_async(move |conn| run_catching(conn, |conn| func(conn)))
384 .map_err(Error::from)?;
385 rx.await?
386 }
387
388 pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
390 where
391 F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
392 T: Send + 'static,
393 {
394 let rx = self
395 .enqueue_async(move |conn| run_catching(conn, func))
396 .map_err(Error::from)?;
397 rx.await?
398 }
399
400 pub async fn conn_and_then<F, T, E>(&self, func: F) -> Result<T, E>
405 where
406 F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
407 T: Send + 'static,
408 E: From<rusqlite::Error> + From<Error> + Send + 'static,
409 {
410 let rx = self
411 .enqueue_async(move |conn| run_catching_and_then(conn, |conn| func(conn)))
412 .map_err(Error::from)?;
413 rx.await.map_err(Error::from)?
414 }
415
416 pub async fn conn_mut_and_then<F, T, E>(&self, func: F) -> Result<T, E>
421 where
422 F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
423 T: Send + 'static,
424 E: From<rusqlite::Error> + From<Error> + Send + 'static,
425 {
426 let rx = self
427 .enqueue_async(move |conn| run_catching_and_then(conn, func))
428 .map_err(Error::from)?;
429 rx.await.map_err(Error::from)?
430 }
431
432 pub async fn close(&self) -> Result<(), Error> {
437 let (tx, rx) = oneshot::channel();
438 match self
439 .conn_tx
440 .try_send(Command::Shutdown(Box::new(AsyncShutdown { tx })))
441 {
442 Ok(()) => {}
443 Err(TrySendError::Disconnected(_)) => {
444 return Ok(());
446 }
447 Err(err) => return Err(err.into()),
448 }
449 rx.await.unwrap_or(Ok(()))
451 }
452
453 pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
456 where
457 F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
458 T: Send + 'static,
459 {
460 let rx = self
461 .enqueue_blocking(move |conn| run_catching(conn, |conn| func(conn)))
462 .map_err(Error::from)?;
463 rx.recv()?
464 }
465
466 pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
469 where
470 F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
471 T: Send + 'static,
472 {
473 let rx = self
474 .enqueue_blocking(move |conn| run_catching(conn, func))
475 .map_err(Error::from)?;
476 rx.recv()?
477 }
478
479 pub fn conn_and_then_blocking<F, T, E>(&self, func: F) -> Result<T, E>
485 where
486 F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
487 T: Send + 'static,
488 E: From<rusqlite::Error> + From<Error> + Send + 'static,
489 {
490 let rx = self
491 .enqueue_blocking(move |conn| run_catching_and_then(conn, |conn| func(conn)))
492 .map_err(Error::from)?;
493 rx.recv().map_err(Error::from)?
494 }
495
496 pub fn conn_mut_and_then_blocking<F, T, E>(&self, func: F) -> Result<T, E>
502 where
503 F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
504 T: Send + 'static,
505 E: From<rusqlite::Error> + From<Error> + Send + 'static,
506 {
507 let rx = self
508 .enqueue_blocking(move |conn| run_catching_and_then(conn, func))
509 .map_err(Error::from)?;
510 rx.recv().map_err(Error::from)?
511 }
512
513 pub fn close_blocking(&self) -> Result<(), Error> {
519 let (tx, rx) = bounded(1);
520 match self
521 .conn_tx
522 .try_send(Command::Shutdown(Box::new(BlockingShutdown { tx })))
523 {
524 Ok(()) => {}
525 Err(TrySendError::Disconnected(_)) => return Ok(()),
526 Err(err) => return Err(err.into()),
527 }
528 rx.recv().unwrap_or(Ok(()))
530 }
531}
532
/// Journal modes that can be requested through `PRAGMA journal_mode`.
#[derive(Clone, Copy, Debug)]
pub enum JournalMode {
    /// `PRAGMA journal_mode = DELETE`
    Delete,
    /// `PRAGMA journal_mode = TRUNCATE`
    Truncate,
    /// `PRAGMA journal_mode = PERSIST`
    Persist,
    /// `PRAGMA journal_mode = MEMORY`
    Memory,
    /// `PRAGMA journal_mode = WAL`
    Wal,
    /// `PRAGMA journal_mode = OFF`
    Off,
}

impl JournalMode {
    /// Returns the upper-case keyword that names this mode in the pragma.
    pub fn as_str(&self) -> &'static str {
        match *self {
            JournalMode::Delete => "DELETE",
            JournalMode::Truncate => "TRUNCATE",
            JournalMode::Persist => "PERSIST",
            JournalMode::Memory => "MEMORY",
            JournalMode::Wal => "WAL",
            JournalMode::Off => "OFF",
        }
    }
}