1use std::{
2 num::NonZeroUsize,
3 path::{Path, PathBuf},
4 sync::{
5 atomic::{AtomicU64, Ordering::Relaxed},
6 Arc,
7 },
8 thread::available_parallelism,
9};
10
11use crate::{Client, ClientBuilder, Error, JournalMode};
12
13use futures_util::future::join_all;
14use rusqlite::{Connection, OpenFlags};
15
#[derive(Clone, Debug, Default)]
/// Builder for a [`Pool`] of SQLite connections.
///
/// Every setting here is copied into a [`ClientBuilder`] for each connection
/// the pool opens, so all pooled connections share the same configuration.
pub struct PoolBuilder {
    /// Database path passed to each connection; `None` until
    /// [`PoolBuilder::path`] is called.
    path: Option<PathBuf>,
    /// SQLite open flags forwarded to each connection.
    flags: OpenFlags,
    /// Journal mode forwarded to each connection, if one was set.
    journal_mode: Option<JournalMode>,
    /// Name of the SQLite VFS forwarded to each connection, if one was set.
    vfs: Option<String>,
    /// Number of connections to open. When `None`, the pool falls back to
    /// the machine's available parallelism (at least 1).
    num_conns: Option<usize>,
}
42
43impl PoolBuilder {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
53 self.path = Some(path.as_ref().into());
54 self
55 }
56
57 pub fn flags(mut self, flags: OpenFlags) -> Self {
61 self.flags = flags;
62 self
63 }
64
65 pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
69 self.journal_mode = Some(journal_mode);
70 self
71 }
72
73 pub fn vfs(mut self, vfs: &str) -> Self {
75 self.vfs = Some(vfs.to_owned());
76 self
77 }
78
79 pub fn num_conns(mut self, num_conns: usize) -> Self {
90 self.num_conns = Some(num_conns.max(1));
91 self
92 }
93
94 pub async fn open(self) -> Result<Pool, Error> {
106 let num_conns = self.get_num_conns();
107 let opens = (0..num_conns).map(|_| {
108 ClientBuilder {
109 path: self.path.clone(),
110 flags: self.flags,
111 journal_mode: self.journal_mode,
112 vfs: self.vfs.clone(),
113 }
114 .open()
115 });
116 let clients = join_all(opens)
117 .await
118 .into_iter()
119 .collect::<Result<Vec<Client>, Error>>()?;
120 Ok(Pool {
121 state: Arc::new(State {
122 clients,
123 counter: AtomicU64::new(0),
124 }),
125 })
126 }
127
128 pub fn open_blocking(self) -> Result<Pool, Error> {
141 let num_conns = self.get_num_conns();
142 let clients = (0..num_conns)
143 .map(|_| {
144 ClientBuilder {
145 path: self.path.clone(),
146 flags: self.flags,
147 journal_mode: self.journal_mode,
148 vfs: self.vfs.clone(),
149 }
150 .open_blocking()
151 })
152 .collect::<Result<Vec<Client>, Error>>()?;
153 Ok(Pool {
154 state: Arc::new(State {
155 clients,
156 counter: AtomicU64::new(0),
157 }),
158 })
159 }
160
161 fn get_num_conns(&self) -> usize {
162 self.num_conns.unwrap_or_else(|| {
163 available_parallelism()
164 .unwrap_or_else(|_| NonZeroUsize::new(1).unwrap())
165 .into()
166 })
167 }
168}
169
#[derive(Clone)]
/// A fixed set of SQLite [`Client`]s that requests are spread across
/// round-robin.
///
/// The pool exposes the same kinds of calls as a single [`Client`] (`conn`,
/// `conn_mut`, `close`, and their blocking variants). Cloning a `Pool` is
/// cheap: clones share the same clients through an [`Arc`].
pub struct Pool {
    /// Shared clients and round-robin counter.
    state: Arc<State>,
}
177
/// State shared by every clone of a [`Pool`].
struct State {
    /// The pooled clients. Never empty when built by [`PoolBuilder`], since
    /// the connection count is at least 1.
    clients: Vec<Client>,
    /// Incrementing counter used to pick the next client round-robin.
    counter: AtomicU64,
}
182
183impl Pool {
184 pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
186 where
187 F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
188 T: Send + 'static,
189 {
190 self.get().conn(func).await
191 }
192
193 pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
195 where
196 F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
197 T: Send + 'static,
198 {
199 self.get().conn_mut(func).await
200 }
201
202 pub async fn close(&self) -> Result<(), Error> {
207 let closes = self.state.clients.iter().map(|client| client.close());
208 let res = join_all(closes).await;
209 res.into_iter().collect::<Result<Vec<_>, Error>>()?;
210 Ok(())
211 }
212
213 pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
216 where
217 F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
218 T: Send + 'static,
219 {
220 self.get().conn_blocking(func)
221 }
222
223 pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
226 where
227 F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
228 T: Send + 'static,
229 {
230 self.get().conn_mut_blocking(func)
231 }
232
233 pub fn close_blocking(&self) -> Result<(), Error> {
238 self.state
239 .clients
240 .iter()
241 .try_for_each(|client| client.close_blocking())
242 }
243
244 fn get(&self) -> &Client {
245 let n = self.state.counter.fetch_add(1, Relaxed);
246 &self.state.clients[n as usize % self.state.clients.len()]
247 }
248
249 pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
253 where
254 F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
255 T: Send + 'static,
256 {
257 let func = Arc::new(func);
258 let futures = self.state.clients.iter().map(|client| {
259 let func = func.clone();
260 async move { client.conn(move |conn| func(conn)).await }
261 });
262 join_all(futures).await
263 }
264
265 pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
267 where
268 F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
269 T: Send + 'static,
270 {
271 let func = Arc::new(func);
272 self.state
273 .clients
274 .iter()
275 .map(|client| {
276 let func = func.clone();
277 client.conn_blocking(move |conn| func(conn))
278 })
279 .collect()
280 }
281}