1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
use diesel::{
r2d2,
r2d2::{ConnectionManager, PooledConnection},
result::Error as DieselError,
PgConnection,
};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness};
/// r2d2 pool of PostgreSQL connections managed by diesel's `ConnectionManager`.
pub type PgPool = r2d2::Pool<ConnectionManager<PgConnection>>;
/// A PostgreSQL connection checked out of a [`PgPool`].
pub type PgPooledConnection = PooledConnection<ConnectionManager<PgConnection>>;
/// Errors produced while building the connection pool, checking migrations, or running
/// queries through [`Pool`].
///
/// Underlying errors are stored as their `Display` text, not as a `source`.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The r2d2 pool builder failed to build the pool.
    #[error("create pool: {0}")]
    CreatePool(String),
    /// A connection could not be checked out of the pool.
    #[error("get connection from pool: {0}")]
    GetConnection(String),
    /// Reading the list of pending migrations failed.
    #[error("database migration: {0}")]
    GetMigration(String),
    /// The database has this many migrations that have not been applied yet.
    #[error("there are {0} pending migrations")]
    PendingMigration(usize),
    /// The blocking task that ran a query did not complete (e.g. the query closure panicked).
    #[error("task failed: {0}")]
    Task(String),
}
#[derive(Clone)]
pub struct Pool {
pool: PgPool,
}
impl Pool {
pub fn new(connection_url: String, migrations: EmbeddedMigrations) -> Result<Self, Error> {
let pool = r2d2::Pool::builder()
.test_on_check_out(true)
.build(ConnectionManager::<PgConnection>::new(connection_url))
.map_err(|e| Error::CreatePool(e.to_string()))?;
check_pending_migrations(&pool, migrations)?;
Ok(Self { pool })
}
pub async fn execute<T, Q>(&self, query: Q) -> Result<Result<T, DieselError>, Error>
where
T: Send + 'static,
Q: FnOnce(PgPooledConnection) -> Result<T, DieselError> + Send + 'static,
{
let conn = self.connection()?;
tokio::task::spawn_blocking(|| query(conn))
.await
.map_err(|e| Error::Task(e.to_string()))
}
fn connection(&self) -> Result<PgPooledConnection, Error> {
self.pool
.get()
.map_err(|e| Error::GetConnection(e.to_string()))
}
}
fn check_pending_migrations(pool: &PgPool, migrations: EmbeddedMigrations) -> Result<(), Error> {
match count_pending_migrations(pool, migrations)? {
0 => Ok(()),
n => Err(Error::PendingMigration(n)),
}
}
fn count_pending_migrations(pool: &PgPool, migrations: EmbeddedMigrations) -> Result<usize, Error> {
let count_pending_migrations = MigrationHarness::pending_migrations(
&mut pool
.get()
.map_err(|e| Error::GetConnection(e.to_string()))?,
migrations,
)
.map_err(|e| Error::GetMigration(e.to_string()))?
.len();
Ok(count_pending_migrations)
}