1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
use diesel::{connection::LoadConnection, migration::MigrationConnection, pg::Pg, Connection};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness};

#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
    #[error("diesel error: {0}")]
    Diesel(diesel::ConnectionError),
}

#[derive(thiserror::Error, Debug)]
pub enum MigrationError {
    #[error("database migration: {0}")]
    GetMigration(String),
    #[error("there are {0} pending migrations")]
    PendingMigration(usize),
}

/// Opens a PostgreSQL connection to `connection_url`.
///
/// The concrete connection type is hidden behind an opaque type that
/// supports querying, loading and migration set-up on the `Pg` backend.
///
/// # Errors
///
/// Returns [`ConnectionError::Diesel`] when diesel fails to establish the
/// connection.
pub fn new_connection(
    connection_url: &str,
) -> Result<
    impl Connection<Backend = Pg> + LoadConnection<Backend = Pg> + MigrationConnection<Backend = Pg>,
    ConnectionError,
> {
    let connection =
        diesel::PgConnection::establish(connection_url).map_err(ConnectionError::Diesel)?;
    Ok(connection)
}

/// Opens a PostgreSQL connection to `connection_url` using
/// `diesel_tracing`'s `InstrumentedPgConnection`.
///
/// Only compiled when the `tracing` feature is enabled.
///
/// Unlike [`new_connection`], the returned opaque type does not advertise
/// `MigrationConnection`, so it cannot be handed to
/// [`check_pending_migrations`].
///
/// # Errors
///
/// Returns [`ConnectionError::Diesel`] when the connection cannot be
/// established.
#[cfg(feature = "tracing")]
pub fn new_connection_with_tracing(
    connection_url: &str,
) -> Result<impl Connection<Backend = Pg> + LoadConnection<Backend = Pg>, ConnectionError> {
    let connection = diesel_tracing::pg::InstrumentedPgConnection::establish(connection_url)
        .map_err(ConnectionError::Diesel)?;
    Ok(connection)
}

/// Verifies that none of the given `migrations` is still pending on `conn`.
///
/// # Errors
///
/// * [`MigrationError::GetMigration`] when the pending migrations cannot be
///   listed.
/// * [`MigrationError::PendingMigration`] carrying the number of pending
///   migrations when that number is non-zero.
pub fn check_pending_migrations(
    conn: &mut (impl Connection<Backend = Pg>
              + LoadConnection<Backend = Pg>
              + MigrationConnection<Backend = Pg>
              + 'static),
    migrations: EmbeddedMigrations,
) -> Result<(), MigrationError> {
    let pending = count_pending_migrations(conn, migrations)?;
    if pending == 0 {
        Ok(())
    } else {
        Err(MigrationError::PendingMigration(pending))
    }
}

/// Returns how many of the given `migrations` are reported as pending for
/// `conn` by [`MigrationHarness::pending_migrations`].
///
/// # Errors
///
/// Returns [`MigrationError::GetMigration`], holding the underlying error
/// rendered as a string, when the pending migrations cannot be listed.
fn count_pending_migrations(
    conn: &mut (impl Connection<Backend = Pg>
              + LoadConnection<Backend = Pg>
              + MigrationConnection<Backend = Pg>
              + 'static),
    migrations: EmbeddedMigrations,
) -> Result<usize, MigrationError> {
    MigrationHarness::pending_migrations(conn, migrations)
        .map(|pending| pending.len())
        .map_err(|err| MigrationError::GetMigration(err.to_string()))
}