1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use diesel::r2d2::{ConnectionManager, PoolError};
use diesel::Connection;
use r2d2::{Builder, PooledConnection};
use roa::{async_trait, Context, State, Status};
use std::time::Duration;

/// An alias for r2d2::Pool<diesel::r2d2::ConnectionManager<Conn>>.
pub type Pool<Conn> = r2d2::Pool<ConnectionManager<Conn>>;

/// An alias for r2d2::PooledConnection<diesel::r2d2::ConnectionManager<Conn>>.
pub type WrapConnection<Conn> = PooledConnection<ConnectionManager<Conn>>;

/// Create a connection pool.
///
/// ### Example
///
/// ```
/// use roa_diesel::{make_pool, Pool};
/// use diesel::sqlite::SqliteConnection;
/// use std::error::Error;
///
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let pool: Pool<SqliteConnection> = make_pool(":memory:")?;
/// Ok(())
/// # }
/// ```
pub fn make_pool<Conn>(url: impl Into<String>) -> Result<Pool<Conn>, PoolError>
where
    Conn: Connection + 'static,
{
    r2d2::Pool::new(ConnectionManager::<Conn>::new(url))
}

/// Create a pool builder.
pub fn builder<Conn>() -> Builder<ConnectionManager<Conn>>
where
    Conn: Connection + 'static,
{
    r2d2::Pool::builder()
}

/// A context extension for accessing an r2d2 pool asynchronously.
#[async_trait]
pub trait AsyncPool<Conn: Connection + 'static> {
    /// Retrieves a connection from the pool.
    ///
    /// Waits for at most the pool's configured connection timeout; if no
    /// connection becomes available within it, an error is returned.
    ///
    /// ```
    /// use roa::{Context, Result};
    /// use diesel::sqlite::SqliteConnection;
    /// use roa_diesel::preload::AsyncPool;
    /// use roa_diesel::Pool;
    ///
    /// #[derive(Clone)]
    /// struct State(Pool<SqliteConnection>);
    ///
    /// impl AsRef<Pool<SqliteConnection>> for State {
    ///     fn as_ref(&self) -> &Pool<SqliteConnection> {
    ///         &self.0
    ///     }
    /// }
    ///
    /// async fn get(ctx: Context<State>) -> Result {
    ///     let _conn = ctx.get_conn().await?;
    ///     // handle conn
    ///     Ok(())
    /// }
    /// ```
    async fn get_conn(&self) -> Result<WrapConnection<Conn>, Status>;

    /// Retrieves a connection from the pool, waiting for at most `timeout`.
    ///
    /// `timeout` replaces the pool's configured connection timeout for this
    /// call only.
    async fn get_timeout(
        &self,
        timeout: Duration,
    ) -> Result<WrapConnection<Conn>, Status>;

    /// Returns information about the current state of the pool, as reported
    /// by r2d2.
    async fn pool_state(&self) -> r2d2::State;
}

#[async_trait]
impl<S, Conn> AsyncPool<Conn> for Context<S>
where
    S: State + AsRef<Pool<Conn>>,
    Conn: Connection + 'static,
{
    /// Checks a connection out of the pool held by the context state.
    ///
    /// `Pool::get` is synchronous and may wait up to the configured
    /// connection timeout, so it runs through the executor's
    /// `spawn_blocking` rather than on the async task itself. Pool errors
    /// are converted into [`Status`] by `?`.
    #[inline]
    async fn get_conn(&self) -> Result<WrapConnection<Conn>, Status> {
        // The `move` closure handed to the executor needs its own pool handle.
        let pool: Pool<Conn> = self.as_ref().clone();
        let conn = self.exec.spawn_blocking(move || pool.get()).await?;
        Ok(conn)
    }

    /// Like `get_conn`, but waits for at most `timeout` instead of the
    /// pool's configured connection timeout.
    #[inline]
    async fn get_timeout(
        &self,
        timeout: Duration,
    ) -> Result<WrapConnection<Conn>, Status> {
        // `timeout` is moved into the closure together with the pool handle.
        let pool: Pool<Conn> = self.as_ref().clone();
        let conn = self
            .exec
            .spawn_blocking(move || pool.get_timeout(timeout))
            .await?;
        Ok(conn)
    }

    /// Queries the pool state on the executor's blocking pool.
    #[inline]
    async fn pool_state(&self) -> r2d2::State {
        let pool: Pool<Conn> = self.as_ref().clone();
        let query_state = move || pool.state();
        self.exec.spawn_blocking(query_state).await
    }
}