// db-derive 0.1.8
//
// PostgreSQL/SQLite pooling derive system.
use {
    crate::{table::Schema, Error, Transaction},
    std::path::Path,
};

/// A connection pool for one of the supported database backends.
///
/// Which variants exist depends on the enabled cargo features
/// (`postgresql` and/or `sqlite`).
#[derive(Clone, Debug)]
pub enum Pool {
    /// Pool of `postgres` clients; connections are made without TLS (`NoTls`).
    #[cfg(feature = "postgresql")]
    PostgreSQL(r2d2::Pool<self::postgres::PostgresConnectionManager<::postgres::NoTls>>),
    /// Pool of `rusqlite` connections, all opened against one database file path.
    #[cfg(feature = "sqlite")]
    SQLite(r2d2::Pool<self::sqlite::SqliteConnectionManager>),
}

impl Pool {
    /// Creates a PostgreSQL-backed pool from `config`.
    ///
    /// Connections are established without TLS (`postgres::NoTls`).
    ///
    /// # Errors
    ///
    /// Fails if `r2d2::Pool::new` cannot build the pool; that error is
    /// converted into the crate's [`Error`].
    #[cfg(feature = "postgresql")]
    pub fn postgres(config: ::postgres::Config) -> Result<Self, Error> {
        let manager =
            crate::pool::postgres::PostgresConnectionManager::new(config, ::postgres::NoTls);
        let pool = r2d2::Pool::new(manager)?;

        Ok(Self::PostgreSQL(pool))
    }

    /// Creates a SQLite-backed pool whose connections all open the database
    /// file at `path`.
    ///
    /// # Errors
    ///
    /// Fails if `r2d2::Pool::new` cannot build the pool; that error is
    /// converted into the crate's [`Error`].
    #[cfg(feature = "sqlite")]
    pub fn sqlite(path: impl AsRef<Path>) -> Result<Self, Error> {
        let manager = crate::pool::sqlite::SqliteConnectionManager::file(path);
        let pool = r2d2::Pool::new(manager)?;

        Ok(Self::SQLite(pool))
    }

    /// Returns which backend this pool talks to.
    pub fn as_kind(&self) -> PoolKind {
        match self {
            #[cfg(feature = "postgresql")]
            Self::PostgreSQL(_) => PoolKind::PostgreSQL,
            #[cfg(feature = "sqlite")]
            Self::SQLite(_) => PoolKind::SQLite,
        }
    }

    /// Checks out a connection and runs `exec` on it as a batch of SQL
    /// statements. The connection goes back to the pool when the batch ends.
    ///
    /// # Errors
    ///
    /// Fails if no connection can be checked out or if the batch fails.
    pub fn batch_execute(&self, exec: impl AsRef<str>) -> Result<(), Error> {
        match self {
            #[cfg(feature = "postgresql")]
            Self::PostgreSQL(pool) => pool.get()?.batch_execute(exec.as_ref())?,
            #[cfg(feature = "sqlite")]
            Self::SQLite(pool) => pool.get()?.execute_batch(exec.as_ref())?,
        }

        Ok(())
    }

    /// Checks out a connection, begins a transaction on it, and hands that
    /// transaction to `run` by value.
    ///
    /// This method never commits on its own: the transaction is moved into
    /// `run`. NOTE(review): `postgres` and `rusqlite` both roll back a
    /// transaction that is dropped uncommitted, so `run` presumably has to
    /// commit explicitly — confirm against [`Transaction`].
    ///
    /// # Errors
    ///
    /// Fails if no connection can be checked out, if the transaction cannot
    /// be started, or with whatever error `run` returns.
    pub fn transaction(
        &self,
        run: impl FnOnce(Transaction<'_>) -> Result<(), Error>,
    ) -> Result<(), Error> {
        match self {
            #[cfg(feature = "postgresql")]
            Self::PostgreSQL(pool) => {
                let mut client = pool.get()?;
                run(Transaction::PostgreSQL(client.transaction()?))?;
            }
            #[cfg(feature = "sqlite")]
            Self::SQLite(pool) => {
                let mut connection = pool.get()?;
                run(Transaction::SQLite(connection.transaction()?))?;
            }
        }

        Ok(())
    }

    /// Executes `T`'s schema SQL for this pool's backend as one batch:
    /// `T::schema_postgres()` for PostgreSQL, `T::schema_sqlite()` for SQLite.
    ///
    /// # Errors
    ///
    /// Fails if no connection can be checked out or if the batch fails.
    pub fn schema<T: Schema>(&self) -> Result<(), Error> {
        match self {
            #[cfg(feature = "postgresql")]
            Self::PostgreSQL(pool) => pool.get()?.batch_execute(T::schema_postgres())?,
            #[cfg(feature = "sqlite")]
            Self::SQLite(pool) => pool.get()?.execute_batch(T::schema_sqlite())?,
        }

        Ok(())
    }
}

/// Names the database backend behind a [`Pool`] without holding a pool.
///
/// Variants are gated on the same cargo features as `Pool`'s variants.
// NOTE: the derived `PartialOrd`/`Ord` follow declaration order, so
// reordering the variants changes how kinds compare.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PoolKind {
    /// The backend of `Pool::PostgreSQL`.
    #[cfg(feature = "postgresql")]
    PostgreSQL,
    /// The backend of `Pool::SQLite`.
    #[cfg(feature = "sqlite")]
    SQLite,
}

impl From<Pool> for PoolKind {
    /// Consumes the pool handle and reports which backend it used.
    fn from(pool: Pool) -> Self {
        Self::from(&pool)
    }
}

impl<'a> From<&'a Pool> for PoolKind {
    fn from(pool: &'a Pool) -> PoolKind {
        pool.as_kind()
    }
}

#[cfg(feature = "postgresql")]
pub mod postgres {
    //! An `r2d2` connection manager for synchronous `postgres` clients.

    use {
        postgres::{
            tls::{MakeTlsConnect, TlsConnect},
            Client, Config, Error, Socket,
        },
        r2d2::ManageConnection,
    };

    /// Opens `postgres::Client` connections on behalf of an `r2d2` pool.
    ///
    /// Every new connection is made from the stored `Config`, using a fresh
    /// clone of the stored TLS connector.
    #[derive(Debug)]
    pub struct PostgresConnectionManager<T> {
        config: Config,
        tls_connector: T,
    }

    impl<T> PostgresConnectionManager<T>
    where
        T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        /// Creates a new `PostgresConnectionManager`.
        ///
        /// No connection is opened here; `config` and `tls_connector` are
        /// only stored for later use by `connect`.
        pub fn new(config: Config, tls_connector: T) -> Self {
            Self {
                config,
                tls_connector,
            }
        }
    }

    impl<T> ManageConnection for PostgresConnectionManager<T>
    where
        T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        type Connection = Client;
        type Error = Error;

        /// Opens a new client from the stored config.
        fn connect(&self) -> Result<Self::Connection, Self::Error> {
            // `Config::connect` takes the connector by value, so hand it a clone.
            let tls = self.tls_connector.clone();
            self.config.connect(tls)
        }

        /// Sends an empty simple query to the server; any error is reported
        /// back to the pool as an invalid connection.
        fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
            conn.simple_query("")?;
            Ok(())
        }

        /// A client counts as broken once it reports itself closed.
        fn has_broken(&self, conn: &mut Self::Connection) -> bool {
            conn.is_closed()
        }
    }
}

#[cfg(feature = "sqlite")]
pub mod sqlite {
    //! An `r2d2` connection manager for file-backed `rusqlite` connections.

    use {
        rusqlite::{Connection, Error, OpenFlags},
        std::path::{Path, PathBuf},
    };

    /// Opens `rusqlite` connections to a single database file for an `r2d2` pool.
    ///
    /// `Debug` is derived and prints `SqliteConnectionManager { path: .. }`.
    #[derive(Debug)]
    pub struct SqliteConnectionManager {
        path: PathBuf,
    }

    impl SqliteConnectionManager {
        /// Creates a manager for the database file at `path`.
        ///
        /// The path is only stored here; nothing is opened until the pool
        /// requests a connection.
        pub fn file<P: AsRef<Path>>(path: P) -> Self {
            Self {
                path: path.as_ref().to_path_buf(),
            }
        }
    }

    impl r2d2::ManageConnection for SqliteConnectionManager {
        type Connection = Connection;
        type Error = Error;

        /// Opens a new connection to the stored path with `OpenFlags::default()`.
        fn connect(&self) -> Result<Connection, Error> {
            // `open_with_flags` already returns `rusqlite::Error`; no conversion needed.
            Connection::open_with_flags(&self.path, OpenFlags::default())
        }

        /// Runs an empty batch; an error marks the connection as invalid.
        fn is_valid(&self, conn: &mut Connection) -> Result<(), Error> {
            conn.execute_batch("")
        }

        /// Always `false`: no broken-connection check is performed; problems
        /// surface through `is_valid` or the queries themselves.
        fn has_broken(&self, _: &mut Connection) -> bool {
            false
        }
    }
}