1use std::marker::PhantomData;
2
3pub use mobc;
4pub use sqlx;
5
6use mobc::{Manager, async_trait};
7use sqlx::{Connection as _, Database};
8
9pub struct SqlxConnectionManager<DB>
10where
11 DB: Database + Sync,
12{
13 url: String,
14 _phantom: PhantomData<DB>,
15}
16
17impl<DB> SqlxConnectionManager<DB>
18where
19 DB: Database + Sync,
20{
21 #[must_use]
22 pub fn new<S>(url: S) -> Self
23 where
24 S: ToString,
25 {
26 Self {
27 url: url.to_string(),
28 _phantom: PhantomData,
29 }
30 }
31}
32
33#[async_trait]
34impl<DB> Manager for SqlxConnectionManager<DB>
35where
36 DB: Database + Sync,
37{
38 type Connection = DB::Connection;
39 type Error = sqlx::Error;
40
41 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
42 Self::Connection::connect(&self.url).await
43 }
44
45 async fn check(
46 &self,
47 mut conn: Self::Connection,
48 ) -> Result<Self::Connection, Self::Error> {
49 conn.ping().await.map(|()| conn)
50 }
51}