1use std::{
2 num::NonZeroUsize,
3 path::{Path, PathBuf},
4 sync::{
5 atomic::{AtomicU64, Ordering::Relaxed},
6 Arc,
7 },
8 thread::available_parallelism,
9};
10
11use crate::{Client, ClientBuilder, Error, JournalMode};
12
13use futures_util::future::join_all;
14use rusqlite::{Connection, OpenFlags};
15
/// Builder for a [`Pool`] of SQLite clients.
///
/// Options are set with the chained setter methods. The pool is created with
/// [`PoolBuilder::open`] or [`PoolBuilder::open_blocking`]. Every option
/// starts at its `Default` value.
#[derive(Clone, Debug, Default)]
pub struct PoolBuilder {
    // Database path handed to every `ClientBuilder`; `None` until `path` is called.
    path: Option<PathBuf>,
    // Open flags passed to every connection; starts as `OpenFlags::default()`.
    flags: OpenFlags,
    // Journal mode requested only when the first connection is opened.
    journal_mode: Option<JournalMode>,
    // Optional VFS name passed to every connection.
    vfs: Option<String>,
    // Requested pool size (already clamped to >= 1); `None` means "use available parallelism".
    num_conns: Option<usize>,
}
42
43impl PoolBuilder {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
53 self.path = Some(path.as_ref().into());
54 self
55 }
56
57 pub fn flags(mut self, flags: OpenFlags) -> Self {
61 self.flags = flags;
62 self
63 }
64
65 pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
69 self.journal_mode = Some(journal_mode);
70 self
71 }
72
73 pub fn vfs(mut self, vfs: &str) -> Self {
75 self.vfs = Some(vfs.to_owned());
76 self
77 }
78
79 pub fn num_conns(mut self, num_conns: usize) -> Self {
90 self.num_conns = Some(num_conns.max(1));
91 self
92 }
93
94 pub async fn open(self) -> Result<Pool, Error> {
106 let num_conns = self.get_num_conns();
107
108 let first = ClientBuilder {
112 path: self.path.clone(),
113 flags: self.flags,
114 journal_mode: self.journal_mode,
115 vfs: self.vfs.clone(),
116 }
117 .open()
118 .await?;
119
120 let opens = (1..num_conns).map(|_| {
123 ClientBuilder {
124 path: self.path.clone(),
125 flags: self.flags,
126 journal_mode: None,
127 vfs: self.vfs.clone(),
128 }
129 .open()
130 });
131 let mut clients = vec![first];
132 clients.extend(
133 join_all(opens)
134 .await
135 .into_iter()
136 .collect::<Result<Vec<Client>, Error>>()?,
137 );
138
139 Ok(Pool {
140 state: Arc::new(State {
141 clients,
142 counter: AtomicU64::new(0),
143 }),
144 })
145 }
146
147 pub fn open_blocking(self) -> Result<Pool, Error> {
160 let num_conns = self.get_num_conns();
161
162 let first = ClientBuilder {
164 path: self.path.clone(),
165 flags: self.flags,
166 journal_mode: self.journal_mode,
167 vfs: self.vfs.clone(),
168 }
169 .open_blocking()?;
170
171 let mut clients = vec![first];
174 clients.extend(
175 (1..num_conns)
176 .map(|_| {
177 ClientBuilder {
178 path: self.path.clone(),
179 flags: self.flags,
180 journal_mode: None,
181 vfs: self.vfs.clone(),
182 }
183 .open_blocking()
184 })
185 .collect::<Result<Vec<Client>, Error>>()?,
186 );
187
188 Ok(Pool {
189 state: Arc::new(State {
190 clients,
191 counter: AtomicU64::new(0),
192 }),
193 })
194 }
195
196 fn get_num_conns(&self) -> usize {
197 self.num_conns.unwrap_or_else(|| {
198 available_parallelism()
199 .unwrap_or_else(|_| NonZeroUsize::new(1).unwrap())
200 .into()
201 })
202 }
203}
204
/// A fixed set of SQLite clients; single-connection calls are spread across
/// them round-robin.
///
/// Cloning a `Pool` only clones an [`Arc`], so all clones share the same
/// clients and the same round-robin counter.
#[derive(Clone)]
pub struct Pool {
    // Shared clients and counter; see `State`.
    state: Arc<State>,
}
212
/// State shared by every clone of a [`Pool`].
struct State {
    // The pool's clients. `PoolBuilder` always creates at least one, which
    // keeps the modulo in `Pool::get` well-defined.
    clients: Vec<Client>,
    // Incremented on every `Pool::get`; its value picks the next client.
    counter: AtomicU64,
}
217
218impl Pool {
219 pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
221 where
222 F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
223 T: Send + 'static,
224 {
225 self.get().conn(func).await
226 }
227
228 pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
230 where
231 F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
232 T: Send + 'static,
233 {
234 self.get().conn_mut(func).await
235 }
236
237 pub async fn close(&self) -> Result<(), Error> {
242 let closes = self.state.clients.iter().map(|client| client.close());
243 let res = join_all(closes).await;
244 res.into_iter().collect::<Result<Vec<_>, Error>>()?;
245 Ok(())
246 }
247
248 pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
251 where
252 F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
253 T: Send + 'static,
254 {
255 self.get().conn_blocking(func)
256 }
257
258 pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
261 where
262 F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
263 T: Send + 'static,
264 {
265 self.get().conn_mut_blocking(func)
266 }
267
268 pub fn close_blocking(&self) -> Result<(), Error> {
273 let mut first_err = None;
274 for client in self.state.clients.iter() {
275 if let Err(e) = client.close_blocking() {
276 if first_err.is_none() {
277 first_err = Some(e);
278 }
279 }
280 }
281 match first_err {
282 Some(e) => Err(e),
283 None => Ok(()),
284 }
285 }
286
287 fn get(&self) -> &Client {
288 let n = self.state.counter.fetch_add(1, Relaxed);
289 &self.state.clients[n as usize % self.state.clients.len()]
290 }
291
292 pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
296 where
297 F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
298 T: Send + 'static,
299 {
300 let func = Arc::new(func);
301 let futures = self.state.clients.iter().map(|client| {
302 let func = func.clone();
303 async move { client.conn(move |conn| func(conn)).await }
304 });
305 join_all(futures).await
306 }
307
308 pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
310 where
311 F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
312 T: Send + 'static,
313 {
314 let func = Arc::new(func);
315 self.state
316 .clients
317 .iter()
318 .map(|client| {
319 let func = func.clone();
320 client.conn_blocking(move |conn| func(conn))
321 })
322 .collect()
323 }
324}