edgehog_device_runtime_store/
db.rs1use std::{
30 error::Error,
31 fmt::Debug,
32 num::NonZeroUsize,
33 path::{Path, PathBuf},
34 sync::Arc,
35 time::Duration,
36};
37
38use deadpool::managed::{BuildError, Pool, PoolError};
39use diesel::{connection::SimpleConnection, Connection, ConnectionError, SqliteConnection};
40use tokio::{sync::Mutex, task::JoinError};
41
42type DynError = Box<dyn Error + Send + Sync>;
43pub type Result<T> = std::result::Result<T, HandleError>;
45
46#[derive(Debug, thiserror::Error, displaydoc::Display)]
48pub enum HandleError {
49 NonUtf8Path(PathBuf),
51 Join(#[from] JoinError),
53 PoolBuilder(#[from] BuildError),
55 Writer(#[from] ManagerError),
57 Reader(#[from] PoolError<ManagerError>),
59 Query(#[from] diesel::result::Error),
61 Migrations(#[source] DynError),
63 UpdateRows {
65 modified: usize,
67 exp: usize,
69 },
70 #[error(transparent)]
72 Application(DynError),
73}
74
75impl HandleError {
76 pub fn from_app(error: impl Into<DynError>) -> Self {
78 Self::Application(error.into())
79 }
80}
81
82impl HandleError {
83 pub fn check_modified(modified: usize, exp: usize) -> Result<()> {
85 if modified != exp {
86 Err(HandleError::UpdateRows { exp, modified })
87 } else {
88 Ok(())
89 }
90 }
91}
92
93#[derive(Clone)]
95pub struct Handle {
96 writer: Arc<Mutex<SqliteConnection>>,
98 readers: Pool<Manager>,
100}
101
102impl Handle {
103 pub async fn open(db_file: impl AsRef<Path>) -> Result<Self> {
105 Self::with_options(db_file, SqliteOpts::default()).await
106 }
107
108 pub async fn with_options(db_file: impl AsRef<Path>, options: SqliteOpts) -> Result<Self> {
110 let db_path = db_file.as_ref();
111 let db_str: String = db_path
112 .to_str()
113 .ok_or_else(|| HandleError::NonUtf8Path(db_path.to_path_buf()))
114 .map(str::to_string)?;
115
116 let manager = Manager {
117 db_file: db_str,
118 options,
119 };
120
121 let writer = manager.establish(false).await?;
122 #[cfg(feature = "containers")]
124 let mut writer = writer;
125
126 let writer = tokio::task::spawn_blocking(move || -> Result<SqliteConnection> {
127 #[cfg(feature = "containers")]
128 {
129 use diesel_migrations::MigrationHarness;
130 writer
131 .run_pending_migrations(crate::schema::CONTAINER_MIGRATIONS)
132 .map_err(HandleError::Migrations)?;
133 }
134
135 Ok(writer)
136 })
137 .await??;
138
139 let readers = Pool::builder(manager)
140 .max_size(options.max_pool_size.get())
141 .build()?;
142
143 Ok(Self {
144 writer: Arc::new(Mutex::new(writer)),
145 readers,
146 })
147 }
148
149 pub async fn for_read<F, O>(&self, f: F) -> Result<O>
151 where
152 F: FnOnce(&mut SqliteConnection) -> Result<O> + Send + 'static,
153 O: Send + 'static,
154 {
155 let mut reader = self.readers.get().await?;
156
157 let res = tokio::task::spawn_blocking(move || (f)(&mut reader)).await?;
159
160 res
161 }
162
163 pub async fn for_write<F, O>(&self, f: F) -> Result<O>
165 where
166 F: FnOnce(&mut SqliteConnection) -> Result<O> + Send + 'static,
167 O: Send + 'static,
168 {
169 let mut writer = Arc::clone(&self.writer).lock_owned().await;
170
171 tokio::task::spawn_blocking(move || writer.transaction(|writer| (f)(writer))).await?
172 }
173}
174
175impl Debug for Handle {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 f.debug_struct("Handle")
178 .field("db_path", &self.readers.manager().db_file)
179 .finish_non_exhaustive()
180 }
181}
182
/// SQLite tuning options applied to every connection opened by a [`Handle`].
///
/// Start from [`SqliteOpts::default`] and adjust the values with the setters.
#[derive(Debug, Clone, Copy)]
pub struct SqliteOpts {
    /// Maximum number of read-only connections in the reader pool.
    max_pool_size: NonZeroUsize,
    /// Time a connection waits for a lock before failing (`PRAGMA busy_timeout`, in ms).
    busy_timeout: Duration,
    /// Page cache size (`PRAGMA cache_size`): a page count if positive, KiB if negative.
    cache_size: i16,
    /// Maximum number of pages of the database file (`PRAGMA max_page_count`).
    max_page_count: u32,
    /// Size limit, in bytes, of the journal/WAL file (`PRAGMA journal_size_limit`).
    journal_size_limit: u64,
    /// WAL auto-checkpoint threshold, in pages (`PRAGMA wal_autocheckpoint`).
    wal_autocheckpoint: u32,
}

impl SqliteOpts {
    /// Sets the maximum number of read-only connections kept in the reader pool.
    pub fn set_max_pool_size(&mut self, max_pool_size: NonZeroUsize) {
        self.max_pool_size = max_pool_size;
    }

    /// Sets how long a connection waits for a database lock before failing.
    ///
    /// Applied as `PRAGMA busy_timeout` with millisecond precision.
    pub fn set_busy_timeout(&mut self, busy_timeout: Duration) {
        self.busy_timeout = busy_timeout;
    }

    /// Sets the page cache size (`PRAGMA cache_size`).
    ///
    /// Positive values are a number of pages; negative values are a size in KiB.
    pub fn set_cache_size(&mut self, cache_size: i16) {
        self.cache_size = cache_size;
    }

    /// Sets the maximum number of pages of the database file (`PRAGMA max_page_count`).
    pub fn set_max_page_count(&mut self, max_page_count: u32) {
        self.max_page_count = max_page_count;
    }

    /// Sets the size limit, in bytes, of the journal/WAL file (`PRAGMA journal_size_limit`).
    pub fn set_journal_size_limit(&mut self, journal_size_limit: u64) {
        self.journal_size_limit = journal_size_limit;
    }

    /// Sets the WAL auto-checkpoint threshold, in pages (`PRAGMA wal_autocheckpoint`).
    pub fn set_wal_autocheckpoint(&mut self, wal_autocheckpoint: u32) {
        self.wal_autocheckpoint = wal_autocheckpoint;
    }
}

impl Default for SqliteOpts {
    /// Returns the default options:
    ///
    /// - a reader pool sized by `std::thread::available_parallelism` (4 if unavailable);
    /// - a 5 second busy timeout;
    /// - a 2 MiB page cache;
    /// - a 2 GiB maximum database size, assuming 4 KiB pages;
    /// - a 64 MiB journal size limit;
    /// - a WAL auto-checkpoint every 1000 pages.
    fn default() -> Self {
        const DEFAULT_POOL_SIZE: NonZeroUsize = match NonZeroUsize::new(4) {
            Some(size) => size,
            None => unreachable!(),
        };
        // Assumes SQLite's default page size of 4 KiB.
        const PAGE_SIZE: u32 = 4096;
        // 2 GiB worth of pages.
        const DEFAULT_MAX_PAGE_COUNT: u32 = 2 * 1024 * 1024 * 1024 / PAGE_SIZE;

        Self {
            max_pool_size: std::thread::available_parallelism().unwrap_or(DEFAULT_POOL_SIZE),
            busy_timeout: Duration::from_secs(5),
            // Negative, so SQLite reads it as KiB: 2 MiB of cache.
            cache_size: -2 * 1024,
            max_page_count: DEFAULT_MAX_PAGE_COUNT,
            // 64 MiB
            journal_size_limit: 64 * 1024 * 1024,
            wal_autocheckpoint: 1000,
        }
    }
}

245struct Manager {
246 db_file: String,
247 options: SqliteOpts,
248}
249
250impl Manager {
251 async fn establish(&self, reader: bool) -> std::result::Result<SqliteConnection, ManagerError> {
252 let options = self.options;
253 let db_file = self.db_file.clone();
254 tokio::task::spawn_blocking(move || {
255 let mut conn =
256 SqliteConnection::establish(&db_file).map_err(|err| ManagerError::Connection {
257 db_file: db_file.to_string(),
258 backtrace: err,
259 })?;
260
261 conn.batch_execute("PRAGMA journal_mode = wal;")?;
262 conn.batch_execute("PRAGMA foreign_keys = true;")?;
263 conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
264 conn.batch_execute("PRAGMA auto_vacuum = INCREMENTAL;")?;
265 conn.batch_execute("PRAGMA temp_store = MEMORY;")?;
266 conn.batch_execute(&format!(
268 "PRAGMA busy_timeout = {};",
269 options.busy_timeout.as_millis()
270 ))?;
271 conn.batch_execute(&format!("PRAGMA cache_size = {};", options.cache_size))?;
272 conn.batch_execute(&format!(
273 "PRAGMA max_page_count = {};",
274 options.max_page_count
275 ))?;
276 conn.batch_execute(&format!(
277 "PRAGMA journal_size_limit = {};",
278 options.journal_size_limit
279 ))?;
280 conn.batch_execute(&format!(
281 "PRAGMA wal_autocheckpoint = {};",
282 options.wal_autocheckpoint
283 ))?;
284
285 if reader {
286 conn.batch_execute("PRAGMA query_only = ON;")?;
287 }
288
289 Ok(conn)
290 })
291 .await?
292 }
293}
294
295impl deadpool::managed::Manager for Manager {
296 type Type = diesel::sqlite::SqliteConnection;
297
298 type Error = ManagerError;
299
300 async fn create(&self) -> std::result::Result<Self::Type, Self::Error> {
301 self.establish(true).await
302 }
303
304 async fn recycle(
305 &self,
306 _obj: &mut Self::Type,
307 _metrics: &deadpool::managed::Metrics,
308 ) -> deadpool::managed::RecycleResult<Self::Error> {
309 Ok(())
310 }
311}
312
313#[derive(Debug, thiserror::Error, displaydoc::Display)]
315#[non_exhaustive]
316pub enum ManagerError {
317 Connection {
319 db_file: String,
321 #[source]
323 backtrace: ConnectionError,
324 },
325 Join(#[from] JoinError),
327 Query(#[from] diesel::result::Error),
329}
330
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::*;

    #[tokio::test]
    async fn should_open_db() {
        let tmp = TempDir::with_prefix("should_open").unwrap();

        Handle::open(tmp.path().join("database.db")).await.unwrap();
    }

    #[tokio::test]
    async fn should_write_with_writer_and_reject_writes_on_reader() {
        let tmp = TempDir::with_prefix("should_write").unwrap();
        let handle = Handle::open(tmp.path().join("database.db")).await.unwrap();

        handle
            .for_write(|conn| {
                conn.batch_execute("CREATE TABLE handle_test (id INTEGER PRIMARY KEY);")?;
                conn.batch_execute("INSERT INTO handle_test (id) VALUES (1);")?;
                Ok(())
            })
            .await
            .unwrap();

        // Reader connections are opened with `PRAGMA query_only = ON`, so writes must fail.
        let write_failed = handle
            .for_read(|conn| {
                Ok(conn
                    .batch_execute("INSERT INTO handle_test (id) VALUES (2);")
                    .is_err())
            })
            .await
            .unwrap();

        assert!(write_failed, "reader connections must reject writes");
    }
}