1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
use diesel::{
    r2d2,
    r2d2::{ConnectionManager, PooledConnection},
    result::Error as DieselError,
    PgConnection,
};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness};

/// Connection manager shared by the pool aliases below.
type PgManager = ConnectionManager<PgConnection>;

/// r2d2 pool handing out PostgreSQL connections.
pub type PgPool = r2d2::Pool<PgManager>;

/// A PostgreSQL connection checked out of a [`PgPool`].
pub type PgPooledConnection = PooledConnection<PgManager>;

#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("create pool: {0}")]
    CreatePool(String),
    #[error("get connection from pool: {0}")]
    GetConnection(String),
    #[error("database migration: {0}")]
    GetMigration(String),
    #[error("there are {0} pending migrations")]
    PendingMigration(usize),
    #[error("task failed: {0}")]
    Task(String),
}

#[derive(Clone)]
pub struct Pool {
    pool: PgPool,
}

impl Pool {
    pub fn new(connection_url: String, migrations: EmbeddedMigrations) -> Result<Self, Error> {
        let pool = r2d2::Pool::builder()
            .test_on_check_out(true)
            .build(ConnectionManager::<PgConnection>::new(connection_url))
            .map_err(|e| Error::CreatePool(e.to_string()))?;
        check_pending_migrations(&pool, migrations)?;
        Ok(Self { pool })
    }

    pub async fn execute<T, Q>(&self, query: Q) -> Result<Result<T, DieselError>, Error>
    where
        T: Send + 'static,
        Q: FnOnce(PgPooledConnection) -> Result<T, DieselError> + Send + 'static,
    {
        let conn = self.connection()?;
        tokio::task::spawn_blocking(|| query(conn))
            .await
            .map_err(|e| Error::Task(e.to_string()))
    }

    fn connection(&self) -> Result<PgPooledConnection, Error> {
        self.pool
            .get()
            .map_err(|e| Error::GetConnection(e.to_string()))
    }
}

fn check_pending_migrations(pool: &PgPool, migrations: EmbeddedMigrations) -> Result<(), Error> {
    match count_pending_migrations(pool, migrations)? {
        0 => Ok(()),
        n => Err(Error::PendingMigration(n)),
    }
}

fn count_pending_migrations(pool: &PgPool, migrations: EmbeddedMigrations) -> Result<usize, Error> {
    let count_pending_migrations = MigrationHarness::pending_migrations(
        &mut pool
            .get()
            .map_err(|e| Error::GetConnection(e.to_string()))?,
        migrations,
    )
    .map_err(|e| Error::GetMigration(e.to_string()))?
    .len();
    Ok(count_pending_migrations)
}