// edgehog_device_runtime_store/db.rs
// This file is part of Edgehog.
//
// Copyright 2024 - 2025 SECO Mind Srl
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

//! Structure to handle the SQLite store.
//!
//! ## Concurrency
//!
//! It handles concurrency by having a shared Mutex for the writer part and a per instance reader.
//! To have a new reader you need to open a new connection to the database.
//!
//! We pass a mutable reference to the connection to a [`FnOnce`]. If the closure panics the
//! connection will be lost and needs to be recreated.
29use std::{
30    error::Error,
31    fmt::Debug,
32    num::NonZeroUsize,
33    path::{Path, PathBuf},
34    sync::Arc,
35    time::Duration,
36};
37
38use deadpool::managed::{BuildError, Pool, PoolError};
39use diesel::{connection::SimpleConnection, Connection, ConnectionError, SqliteConnection};
40use tokio::{sync::Mutex, task::JoinError};
41
/// Type-erased, thread-safe boxed error used to wrap application and migration errors.
type DynError = Box<dyn Error + Send + Sync>;
/// Result for the [`HandleError`] returned by the [`Handle`].
pub type Result<T> = std::result::Result<T, HandleError>;
45
/// Handler error
// NOTE: `displaydoc::Display` derives the `Display` implementation from the `///`
// doc comments on each variant below, so those comments are runtime error
// messages — do not reword them for documentation purposes.
#[derive(Debug, thiserror::Error, displaydoc::Display)]
pub enum HandleError {
    /// couldn't open database with non UTF-8 path: {0}
    NonUtf8Path(PathBuf),
    /// couldn't join database task
    Join(#[from] JoinError),
    /// error returned while building the reader pool
    PoolBuilder(#[from] BuildError),
    /// error returned while creating the writer connection
    Writer(#[from] ManagerError),
    /// error returned while getting a reader connection
    Reader(#[from] PoolError<ManagerError>),
    /// couldn't execute the query
    Query(#[from] diesel::result::Error),
    /// couldn't run pending migrations
    Migrations(#[source] DynError),
    /// wrong number of rows updated, expected {exp} but modified {modified}
    UpdateRows {
        /// Number of rows modified.
        modified: usize,
        /// Expected number or rows.
        exp: usize,
    },
    /// error returned by the application
    // Uses thiserror's transparent forwarding instead of a displaydoc message so the
    // inner error's own `Display` and `source` are exposed directly.
    #[error(transparent)]
    Application(DynError),
}
74
75impl HandleError {
76    /// Creates an [`HandleError::Application`] error.
77    pub fn from_app(error: impl Into<DynError>) -> Self {
78        Self::Application(error.into())
79    }
80}
81
82impl HandleError {
83    /// Check the result of the number of rows a query modified
84    pub fn check_modified(modified: usize, exp: usize) -> Result<()> {
85        if modified != exp {
86            Err(HandleError::UpdateRows { exp, modified })
87        } else {
88            Ok(())
89        }
90    }
91}
92
/// Read and write connection to the database
///
/// Cloning is cheap: the writer is shared behind an [`Arc`] and the reader pool is
/// reference counted internally.
#[derive(Clone)]
pub struct Handle {
    /// Write handle to the database
    ///
    /// A single shared writer serialized through the [`Mutex`], per the module's
    /// concurrency model (see the module docs).
    writer: Arc<Mutex<SqliteConnection>>,
    /// Per task/thread reader
    readers: Pool<Manager>,
}
101
102impl Handle {
103    /// Create a new instance by connecting to the file with default options
104    pub async fn open(db_file: impl AsRef<Path>) -> Result<Self> {
105        Self::with_options(db_file, SqliteOpts::default()).await
106    }
107
108    /// Create a new instance by connecting to the file
109    pub async fn with_options(db_file: impl AsRef<Path>, options: SqliteOpts) -> Result<Self> {
110        let db_path = db_file.as_ref();
111        let db_str: String = db_path
112            .to_str()
113            .ok_or_else(|| HandleError::NonUtf8Path(db_path.to_path_buf()))
114            .map(str::to_string)?;
115
116        let manager = Manager {
117            db_file: db_str,
118            options,
119        };
120
121        let writer = manager.establish(false).await?;
122        // We don't have migrations other than the containers for now
123        #[cfg(feature = "containers")]
124        let mut writer = writer;
125
126        let writer = tokio::task::spawn_blocking(move || -> Result<SqliteConnection> {
127            #[cfg(feature = "containers")]
128            {
129                use diesel_migrations::MigrationHarness;
130                writer
131                    .run_pending_migrations(crate::schema::CONTAINER_MIGRATIONS)
132                    .map_err(HandleError::Migrations)?;
133            }
134
135            Ok(writer)
136        })
137        .await??;
138
139        let readers = Pool::builder(manager)
140            .max_size(options.max_pool_size.get())
141            .build()?;
142
143        Ok(Self {
144            writer: Arc::new(Mutex::new(writer)),
145            readers,
146        })
147    }
148
149    /// Passes the reader to a callback to execute a query.
150    pub async fn for_read<F, O>(&self, f: F) -> Result<O>
151    where
152        F: FnOnce(&mut SqliteConnection) -> Result<O> + Send + 'static,
153        O: Send + 'static,
154    {
155        let mut reader = self.readers.get().await?;
156
157        // If this task panics (the error is returned) the connection would still be null
158        let res = tokio::task::spawn_blocking(move || (f)(&mut reader)).await?;
159
160        res
161    }
162
163    /// Passes the writer to a callback with a transaction already started.
164    pub async fn for_write<F, O>(&self, f: F) -> Result<O>
165    where
166        F: FnOnce(&mut SqliteConnection) -> Result<O> + Send + 'static,
167        O: Send + 'static,
168    {
169        let mut writer = Arc::clone(&self.writer).lock_owned().await;
170
171        tokio::task::spawn_blocking(move || writer.transaction(|writer| (f)(writer))).await?
172    }
173}
174
175impl Debug for Handle {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        f.debug_struct("Handle")
178            .field("db_path", &self.readers.manager().db_file)
179            .finish_non_exhaustive()
180    }
181}
182
/// Options for the SQLite connection
///
/// Each value is applied as a `PRAGMA` statement when a connection is established.
#[derive(Debug, Clone, Copy)]
pub struct SqliteOpts {
    /// Maximum number of reader connections kept in the pool.
    max_pool_size: NonZeroUsize,
    /// How long SQLite waits on a locked database before failing (`PRAGMA busy_timeout`).
    busy_timeout: Duration,
    /// `PRAGMA cache_size` value; negative values are interpreted by SQLite as KiB.
    cache_size: i16,
    /// Maximum number of pages in the database file (`PRAGMA max_page_count`).
    max_page_count: u32,
    /// Maximum size of the journal/WAL in bytes (`PRAGMA journal_size_limit`).
    journal_size_limit: u64,
    /// WAL checkpoint threshold in pages (`PRAGMA wal_autocheckpoint`).
    wal_autocheckpoint: u32,
}
193
194impl SqliteOpts {
195    /// Setter for the max pool size
196    pub fn set_max_pool_size(&mut self, max_pool_size: NonZeroUsize) {
197        self.max_pool_size = max_pool_size;
198    }
199
200    /// Setter for the busy timeout
201    pub fn set_busy_timeout(&mut self, busy_timeout: Duration) {
202        self.busy_timeout = busy_timeout;
203    }
204
205    /// Setter for the max page count
206    pub fn set_max_page_count(&mut self, max_page_count: u32) {
207        self.max_page_count = max_page_count;
208    }
209
210    /// Setter for the journal size limit
211    pub fn set_journal_size_limit(&mut self, journal_size_limit: u64) {
212        self.journal_size_limit = journal_size_limit;
213    }
214
215    /// Setter for the WAL auto-checkpoint
216    pub fn set_wal_autocheckpoint(&mut self, wal_autocheckpoint: u32) {
217        self.wal_autocheckpoint = wal_autocheckpoint;
218    }
219}
220
impl Default for SqliteOpts {
    fn default() -> Self {
        // Built with a `match` because `Option::unwrap` was not usable in const
        // context; the `None` arm can never be taken for the literal 4.
        const DEFAULT_POOL_SIZE: NonZeroUsize = match NonZeroUsize::new(4) {
            Some(size) => size,
            None => unreachable!(),
        };
        // 2 gib (assumes 4096 page size)
        const DEFAULT_MAX_PAGE_COUNT: u32 = 2 * (1024 * 1024 * 1024) / 4096;

        Self {
            // One reader per available core, falling back to 4 when unknown.
            max_pool_size: std::thread::available_parallelism().unwrap_or(DEFAULT_POOL_SIZE),
            busy_timeout: Duration::from_secs(5),
            // -2048: SQLite treats negative cache_size values as KiB, so this is
            // 2048 KiB = 2 MiB of page cache (the previous "2 kib" note was wrong).
            cache_size: -2 * 1024,
            // 2 gib (assumes 4096 page size)
            max_page_count: DEFAULT_MAX_PAGE_COUNT,
            // 64 mib
            journal_size_limit: 64 * 1024 * 1024,
            // 1000 pages
            wal_autocheckpoint: 1000,
        }
    }
}
244
/// Connection factory for the database.
///
/// Implements [`deadpool::managed::Manager`] to create the read-only pooled
/// connections, and is also used directly to establish the single writer.
struct Manager {
    /// UTF-8 path to the database file.
    db_file: String,
    /// Pragma options applied to every new connection.
    options: SqliteOpts,
}
249
250impl Manager {
251    async fn establish(&self, reader: bool) -> std::result::Result<SqliteConnection, ManagerError> {
252        let options = self.options;
253        let db_file = self.db_file.clone();
254        tokio::task::spawn_blocking(move || {
255            let mut conn =
256                SqliteConnection::establish(&db_file).map_err(|err| ManagerError::Connection {
257                    db_file: db_file.to_string(),
258                    backtrace: err,
259                })?;
260
261            conn.batch_execute("PRAGMA journal_mode = wal;")?;
262            conn.batch_execute("PRAGMA foreign_keys = true;")?;
263            conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
264            conn.batch_execute("PRAGMA auto_vacuum = INCREMENTAL;")?;
265            conn.batch_execute("PRAGMA temp_store = MEMORY;")?;
266            // NOTE: Safe to format since we handle the options, do not pass strings.
267            conn.batch_execute(&format!(
268                "PRAGMA busy_timeout = {};",
269                options.busy_timeout.as_millis()
270            ))?;
271            conn.batch_execute(&format!("PRAGMA cache_size = {};", options.cache_size))?;
272            conn.batch_execute(&format!(
273                "PRAGMA max_page_count = {};",
274                options.max_page_count
275            ))?;
276            conn.batch_execute(&format!(
277                "PRAGMA journal_size_limit = {};",
278                options.journal_size_limit
279            ))?;
280            conn.batch_execute(&format!(
281                "PRAGMA wal_autocheckpoint = {};",
282                options.wal_autocheckpoint
283            ))?;
284
285            if reader {
286                conn.batch_execute("PRAGMA query_only = ON;")?;
287            }
288
289            Ok(conn)
290        })
291        .await?
292    }
293}
294
impl deadpool::managed::Manager for Manager {
    type Type = diesel::sqlite::SqliteConnection;

    type Error = ManagerError;

    /// Creates a new read-only connection for the pool.
    async fn create(&self) -> std::result::Result<Self::Type, Self::Error> {
        self.establish(true).await
    }

    /// No-op recycle: connections are handed back to the pool as-is.
    async fn recycle(
        &self,
        _obj: &mut Self::Type,
        _metrics: &deadpool::managed::Metrics,
    ) -> deadpool::managed::RecycleResult<Self::Error> {
        Ok(())
    }
}
312
/// Error returned while creating a connection
// NOTE: `displaydoc::Display` turns the variant-level doc comments below into the
// `Display` messages, so they are part of the runtime output.
#[derive(Debug, thiserror::Error, displaydoc::Display)]
#[non_exhaustive]
pub enum ManagerError {
    /// couldn't connect to the database {db_file}
    Connection {
        /// Connection to the database file
        db_file: String,
        /// Underlying connection error
        #[source]
        backtrace: ConnectionError,
    },
    /// couldn't join database task
    Join(#[from] JoinError),
    /// couldn't execute the query
    Query(#[from] diesel::result::Error),
}
330
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::*;

    /// Smoke test: opening a database in a fresh directory succeeds.
    #[tokio::test]
    async fn should_open_db() {
        let dir = TempDir::with_prefix("should_open").unwrap();
        let db_path = dir.path().join("database.db");

        Handle::open(db_path).await.unwrap();
    }
}
343}