Skip to main content

miden_node_db/
lib.rs

1mod conv;
2mod errors;
3mod manager;
4pub mod migration;
5
6use std::num::NonZeroUsize;
7use std::path::Path;
8
9pub use conv::{DatabaseTypeConversionError, SqlTypeConvert};
10use diesel::{RunQueryDsl, SqliteConnection};
11pub use errors::{DatabaseError, SchemaVerificationError};
12pub use manager::{ConnectionManager, ConnectionManagerError, configure_connection_on_creation};
13use tracing::Instrument;
14
15pub type Result<T, E = DatabaseError> = std::result::Result<T, E>;
16
/// Returns the default SQLite connection pool size.
///
/// The pool gets two connections per unit of available CPU parallelism. When the OS cannot
/// report its parallelism, a single core is assumed, which yields two connections.
pub fn default_connection_pool_size() -> NonZeroUsize {
    const CONNECTIONS_PER_CORE: usize = 2;

    let cores = match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 1,
    };

    // `cores` is at least 1 and saturating multiplication never wraps to zero, so the product
    // is always non-zero.
    NonZeroUsize::new(cores.saturating_mul(CONNECTIONS_PER_CORE))
        .expect("connection count must be non-zero")
}
26
27/// Database handle that provides fundamental operations that various components of Miden Node can
28/// utililze for their storage needs.
29#[derive(Clone)]
30pub struct Db {
31    pool: deadpool_diesel::Pool<ConnectionManager, deadpool::managed::Object<ConnectionManager>>,
32}
33
34impl Db {
35    /// Creates a new database instance with the provided connection pool.
36    pub fn new(database_filepath: &Path) -> Result<Self, DatabaseError> {
37        Self::new_with_pool_size(database_filepath, default_connection_pool_size())
38    }
39
40    /// Creates a new database instance with the provided connection pool size.
41    pub fn new_with_pool_size(
42        database_filepath: &Path,
43        connection_pool_size: NonZeroUsize,
44    ) -> Result<Self, DatabaseError> {
45        let manager = ConnectionManager::new(database_filepath.to_str().unwrap());
46        let pool = deadpool_diesel::Pool::builder(manager)
47            .max_size(connection_pool_size.get())
48            .build()?;
49        Ok(Self { pool })
50    }
51
52    /// Checks out a connection from the pool and pins it for the caller's exclusive, long-lived
53    /// use. See [`PinnedConnection`].
54    ///
55    /// This removes one connection from the shared pool for the lifetime of the returned handle,
56    /// so the pool must be sized to leave at least one connection for other users.
57    pub async fn pinned_connection(&self) -> Result<PinnedConnection, DatabaseError> {
58        let conn = self
59            .pool
60            .get()
61            .in_current_span()
62            .await
63            .map_err(|e| DatabaseError::ConnectionPoolObtainError(Box::new(e)))?;
64        Ok(PinnedConnection { conn })
65    }
66
67    /// Create and commit a transaction with the queries added in the provided closure
68    pub async fn transact<R, E, Q, M>(&self, msg: M, query: Q) -> std::result::Result<R, E>
69    where
70        Q: Send
71            + for<'a, 't> FnOnce(&'a mut SqliteConnection) -> std::result::Result<R, E>
72            + 'static,
73        R: Send + 'static,
74        M: Send + ToString,
75        E: From<diesel::result::Error>,
76        E: From<DatabaseError>,
77        E: std::error::Error + Send + Sync + 'static,
78    {
79        self.pinned_connection().await.map_err(E::from)?.transact(msg, query).await
80    }
81
82    /// Run the query _without_ a transaction
83    pub async fn query<R, E, Q, M>(&self, msg: M, query: Q) -> std::result::Result<R, E>
84    where
85        Q: Send + FnOnce(&mut SqliteConnection) -> std::result::Result<R, E> + 'static,
86        R: Send + 'static,
87        M: Send + ToString,
88        E: From<DatabaseError>,
89        E: std::error::Error + Send + Sync + 'static,
90    {
91        self.pinned_connection().await.map_err(E::from)?.query(msg, query).await
92    }
93}
94
95/// A connection checked out of [`Db`]'s pool and held for the caller's exclusive, long-lived use.
96///
97/// A hot event loop can pin a connection so its queries never wait on the shared pool even when
98/// many concurrent tasks are saturating it. `transact`/`query` mirror [`Db`]'s, but run on the
99/// pinned connection rather than acquiring one per call. The connection is returned to the pool
100/// when the `PinnedConnection` is dropped.
101pub struct PinnedConnection {
102    conn: deadpool::managed::Object<ConnectionManager>,
103}
104
105impl PinnedConnection {
106    /// Create and commit a transaction with the queries added in the provided closure, running on
107    /// the pinned connection.
108    pub async fn transact<R, E, Q, M>(&self, msg: M, query: Q) -> std::result::Result<R, E>
109    where
110        Q: Send
111            + for<'a, 't> FnOnce(&'a mut SqliteConnection) -> std::result::Result<R, E>
112            + 'static,
113        R: Send + 'static,
114        M: Send + ToString,
115        E: From<diesel::result::Error>,
116        E: From<DatabaseError>,
117        E: std::error::Error + Send + Sync + 'static,
118    {
119        let span = tracing::Span::current();
120        self.conn
121            .interact(move |conn| {
122                let _guard = span.enter();
123                <_ as diesel::Connection>::transaction::<R, E, Q>(conn, query)
124            })
125            .await
126            .map_err(|err| E::from(DatabaseError::interact(&msg.to_string(), &err)))?
127    }
128
129    /// Run the query _without_ a transaction on the pinned connection.
130    pub async fn query<R, E, Q, M>(&self, msg: M, query: Q) -> std::result::Result<R, E>
131    where
132        Q: Send + FnOnce(&mut SqliteConnection) -> std::result::Result<R, E> + 'static,
133        R: Send + 'static,
134        M: Send + ToString,
135        E: From<DatabaseError>,
136        E: std::error::Error + Send + Sync + 'static,
137    {
138        let span = tracing::Span::current();
139        self.conn
140            .interact(move |conn| {
141                let _guard = span.enter();
142                query(conn)
143            })
144            .await
145            .map_err(|err| E::from(DatabaseError::interact(&msg.to_string(), &err)))?
146    }
147}