pub extern crate r2d2;
#[cfg(any(feature = "diesel_sqlite_pool",
feature = "diesel_postgres_pool",
feature = "diesel_mysql_pool"))]
pub extern crate diesel;
use std::collections::BTreeMap;
use std::fmt::{self, Display, Formatter};
use std::marker::{Send, Sized};
use rocket::config::{self, Value};
use self::r2d2::ManageConnection;
#[doc(hidden)] pub use rocket_contrib_codegen::*;
#[cfg(feature = "postgres_pool")] pub extern crate postgres;
#[cfg(feature = "postgres_pool")] pub extern crate r2d2_postgres;
#[cfg(feature = "mysql_pool")] pub extern crate mysql;
#[cfg(feature = "mysql_pool")] pub extern crate r2d2_mysql;
#[cfg(feature = "sqlite_pool")] pub extern crate rusqlite;
#[cfg(feature = "sqlite_pool")] pub extern crate r2d2_sqlite;
#[cfg(feature = "cypher_pool")] pub extern crate rusted_cypher;
#[cfg(feature = "cypher_pool")] pub extern crate r2d2_cypher;
#[cfg(feature = "redis_pool")] pub extern crate redis;
#[cfg(feature = "redis_pool")] pub extern crate r2d2_redis;
#[cfg(feature = "mongodb_pool")] pub extern crate mongodb;
#[cfg(feature = "mongodb_pool")] pub extern crate r2d2_mongodb;
#[cfg(feature = "memcache_pool")] pub extern crate memcache;
#[cfg(feature = "memcache_pool")] pub extern crate r2d2_memcache;
/// Connection settings for a single named database, as produced by
/// `database_config` from the `databases` table of a Rocket configuration.
///
/// The `'a` lifetime ties `url` to the configuration it was read from.
#[derive(Debug, Clone, PartialEq)]
pub struct DatabaseConfig<'a> {
/// The connection URL, taken from the entry's `url` key.
pub url: &'a str,
/// Maximum number of pooled connections. `database_config` reads it from the
/// entry's `pool_size` key, falling back to the configured worker count.
pub pool_size: u32,
/// Every other key/value pair of the entry, i.e. everything except `url`
/// and `pool_size`.
pub extras: BTreeMap<String, Value>,
}
/// Error type for `Poolable` implementations that need to report either a
/// backend-specific error or an r2d2 pool-construction error.
#[derive(Debug)]
pub enum DbError<T> {
/// A backend-specific error, e.g. one raised while constructing the
/// connection manager from the configured URL.
Custom(T),
/// An error returned by r2d2 while building the pool.
PoolError(r2d2::Error),
}
/// Reasons `database_config` can fail to extract a database's settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigError {
/// Reading the `databases` table from the configuration failed (any error
/// from `get_table("databases")` maps to this variant).
MissingTable,
/// The `databases` table has no entry for the requested database name.
MissingKey,
/// The entry for the requested name exists but is not a table.
MalformedConfiguration,
/// The database's entry has no `url` key.
MissingUrl,
/// The database's `url` value is not a string.
MalformedUrl,
/// The effective pool size (explicit or defaulted from the worker count) is
/// below 1 or larger than `u32::max_value()`; carries the offending value.
InvalidPoolSize(i64),
}
pub fn database_config<'a>(
name: &str,
from: &'a config::Config
) -> Result<DatabaseConfig<'a>, ConfigError> {
let connection_config = from.get_table("databases")
.map_err(|_| ConfigError::MissingTable)?
.get(name)
.ok_or(ConfigError::MissingKey)?
.as_table()
.ok_or(ConfigError::MalformedConfiguration)?;
let maybe_url = connection_config.get("url")
.ok_or(ConfigError::MissingUrl)?;
let url = maybe_url.as_str().ok_or(ConfigError::MalformedUrl)?;
let pool_size = connection_config.get("pool_size")
.and_then(Value::as_integer)
.unwrap_or(from.workers as i64);
if pool_size < 1 || pool_size > u32::max_value() as i64 {
return Err(ConfigError::InvalidPoolSize(pool_size));
}
let mut extras = connection_config.clone();
extras.remove("url");
extras.remove("pool_size");
Ok(DatabaseConfig { url, pool_size: pool_size as u32, extras: extras })
}
/// Human-readable description of each configuration error.
///
/// The messages name the offending part of the `databases` configuration;
/// `InvalidPoolSize` also echoes the rejected value.
impl Display for ConfigError {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        // `ConfigError` is `Copy`, so matching on `*self` binds the payload by value.
        match *self {
            ConfigError::MissingTable => {
                write!(f, "A table named `databases` was not found for this configuration")
            }
            ConfigError::MissingKey => {
                write!(f, "An entry in the `databases` table was not found for this key")
            }
            ConfigError::MalformedConfiguration => {
                write!(f, "The configuration for this database is malformed")
            }
            ConfigError::MissingUrl => {
                write!(f, "The connection URL is missing for this database")
            }
            ConfigError::MalformedUrl => {
                write!(f, "The specified connection URL is malformed")
            }
            ConfigError::InvalidPoolSize(invalid_size) => {
                write!(f, "'{}' is not a valid value for `pool_size`", invalid_size)
            }
        }
    }
}
/// A connection type that can be pooled with r2d2.
///
/// An implementor names the r2d2 `ManageConnection` type that produces it and
/// provides `pool`, which builds an `r2d2::Pool` from a `DatabaseConfig`.
pub trait Poolable: Send + Sized + 'static {
/// The r2d2 connection manager whose `Connection` type is `Self`.
type Manager: ManageConnection<Connection=Self>;
/// The error returned by `pool` when the pool cannot be created.
type Error;
/// Creates a connection pool configured from `config`.
fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error>;
}
#[cfg(feature = "diesel_sqlite_pool")]
impl Poolable for diesel::SqliteConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::SqliteConnection>;
    type Error = r2d2::Error;

    /// Builds a pool of Diesel SQLite connections to `config.url`, holding at
    /// most `config.pool_size` connections.
    fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error> {
        let DatabaseConfig { url, pool_size, .. } = config;
        let manager = diesel::r2d2::ConnectionManager::<diesel::SqliteConnection>::new(url);
        let builder = r2d2::Pool::builder().max_size(pool_size);
        builder.build(manager)
    }
}
#[cfg(feature = "diesel_postgres_pool")]
impl Poolable for diesel::PgConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::PgConnection>;
    type Error = r2d2::Error;

    /// Builds a pool of Diesel PostgreSQL connections to `config.url`,
    /// holding at most `config.pool_size` connections.
    fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error> {
        let manager = diesel::r2d2::ConnectionManager::<diesel::PgConnection>::new(config.url);
        r2d2::Pool::builder()
            .max_size(config.pool_size)
            .build(manager)
    }
}
#[cfg(feature = "diesel_mysql_pool")]
impl Poolable for diesel::MysqlConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::MysqlConnection>;
    type Error = r2d2::Error;

    /// Builds a pool of Diesel MySQL connections to `config.url`, holding at
    /// most `config.pool_size` connections.
    fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error> {
        let DatabaseConfig { url, pool_size, .. } = config;
        let manager = diesel::r2d2::ConnectionManager::<diesel::MysqlConnection>::new(url);
        r2d2::Pool::builder().max_size(pool_size).build(manager)
    }
}
#[cfg(feature = "postgres_pool")]
impl Poolable for postgres::Connection {
    type Manager = r2d2_postgres::PostgresConnectionManager;
    type Error = DbError<postgres::Error>;

    /// Builds a pool of `postgres` connections to `config.url` (TLS disabled),
    /// holding at most `config.pool_size` connections.
    ///
    /// A failure to construct the connection manager is reported as
    /// `DbError::Custom`; a pool-building failure as `DbError::PoolError`.
    fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error> {
        let pool_size = config.pool_size;
        r2d2_postgres::PostgresConnectionManager::new(config.url, r2d2_postgres::TlsMode::None)
            .map_err(DbError::Custom)
            .and_then(|manager| {
                r2d2::Pool::builder()
                    .max_size(pool_size)
                    .build(manager)
                    .map_err(DbError::PoolError)
            })
    }
}
#[cfg(feature = "mysql_pool")]
impl Poolable for mysql::Conn {
    type Manager = r2d2_mysql::MysqlConnectionManager;
    type Error = r2d2::Error;

    /// Builds a pool of `mysql` connections configured from `config.url`,
    /// holding at most `config.pool_size` connections.
    fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error> {
        // NOTE(review): converting the URL into connection options happens
        // inside the mysql crate and is not reported through `Self::Error`;
        // confirm how that conversion handles a malformed URL.
        let manager = r2d2_mysql::MysqlConnectionManager::new(
            mysql::OptsBuilder::from_opts(config.url),
        );
        let builder = r2d2::Pool::builder().max_size(config.pool_size);
        builder.build(manager)
    }
}
#[cfg(feature = "sqlite_pool")]
impl Poolable for rusqlite::Connection {
    type Manager = r2d2_sqlite::SqliteConnectionManager;
    type Error = r2d2::Error;

    /// Builds a pool of `rusqlite` connections to the database file named by
    /// `config.url`, holding at most `config.pool_size` connections.
    fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error> {
        let DatabaseConfig { url, pool_size, .. } = config;
        let manager = r2d2_sqlite::SqliteConnectionManager::file(url);
        r2d2::Pool::builder()
            .max_size(pool_size)
            .build(manager)
    }
}
#[cfg(feature = "cypher_pool")]
impl Poolable for rusted_cypher::GraphClient {
    type Manager = r2d2_cypher::CypherConnectionManager;
    type Error = r2d2::Error;

    /// Builds a pool of Cypher graph clients for `config.url`, holding at most
    /// `config.pool_size` clients.
    fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error> {
        let DatabaseConfig { url, pool_size, .. } = config;
        let manager = r2d2_cypher::CypherConnectionManager { url: String::from(url) };
        let builder = r2d2::Pool::builder().max_size(pool_size);
        builder.build(manager)
    }
}
#[cfg(feature = "redis_pool")]
impl Poolable for redis::Connection {
    type Manager = r2d2_redis::RedisConnectionManager;
    type Error = DbError<redis::RedisError>;

    /// Builds a pool of Redis connections to `config.url`, holding at most
    /// `config.pool_size` connections.
    ///
    /// A failure to construct the connection manager is reported as
    /// `DbError::Custom`; a pool-building failure as `DbError::PoolError`.
    fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error> {
        let manager = match r2d2_redis::RedisConnectionManager::new(config.url) {
            Ok(manager) => manager,
            Err(err) => return Err(DbError::Custom(err)),
        };
        r2d2::Pool::builder()
            .max_size(config.pool_size)
            .build(manager)
            .map_err(DbError::PoolError)
    }
}
#[cfg(feature = "mongodb_pool")]
impl Poolable for mongodb::db::Database {
    type Manager = r2d2_mongodb::MongodbConnectionManager;
    type Error = DbError<mongodb::Error>;

    /// Builds a pool of MongoDB database handles for the URI `config.url`,
    /// holding at most `config.pool_size` handles.
    ///
    /// A failure to construct the connection manager is reported as
    /// `DbError::Custom`; a pool-building failure as `DbError::PoolError`.
    fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error> {
        match r2d2_mongodb::MongodbConnectionManager::new_with_uri(config.url) {
            Ok(manager) => r2d2::Pool::builder()
                .max_size(config.pool_size)
                .build(manager)
                .map_err(DbError::PoolError),
            Err(error) => Err(DbError::Custom(error)),
        }
    }
}
#[cfg(feature = "memcache_pool")]
impl Poolable for memcache::Client {
    type Manager = r2d2_memcache::MemcacheConnectionManager;
    type Error = DbError<memcache::MemcacheError>;

    /// Builds a pool of memcache clients for `config.url`, holding at most
    /// `config.pool_size` clients. Only pool-building failures are reported
    /// (as `DbError::PoolError`).
    fn pool(config: DatabaseConfig) -> Result<r2d2::Pool<Self::Manager>, Self::Error> {
        let DatabaseConfig { url, pool_size, .. } = config;
        let manager = r2d2_memcache::MemcacheConnectionManager::new(url);
        let built = r2d2::Pool::builder().max_size(pool_size).build(manager);
        built.map_err(DbError::PoolError)
    }
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;
    use rocket::{Config, config::{Environment, Value}};
    use super::{ConfigError::*, database_config};

    /// Builds a development `Config` whose `databases` table contains exactly
    /// one entry, `name`, set to `connection`.
    fn config_with_database<V: Into<Value>>(name: &str, connection: V) -> Config {
        let mut databases = BTreeMap::new();
        databases.insert(name.to_string(), connection.into());
        Config::build(Environment::Development)
            .extra("databases", databases)
            .finalize()
            .unwrap()
    }

    /// Builds a connection table with the given `url` value and, when
    /// provided, a `pool_size` value.
    fn connection_with(url: Value, pool_size: Option<Value>) -> BTreeMap<String, Value> {
        let mut connection = BTreeMap::new();
        connection.insert("url".to_string(), url);
        if let Some(size) = pool_size {
            connection.insert("pool_size".to_string(), size);
        }
        connection
    }

    #[test]
    fn no_database_entry_in_config_returns_error() {
        let config = Config::build(Environment::Development)
            .finalize()
            .unwrap();
        assert_eq!(Err(MissingTable), database_config("dummy_db", &config));
    }

    #[test]
    fn no_matching_connection_returns_error() {
        let connection = connection_with(Value::from("dummy_db.sqlite"), Some(Value::from(10)));
        let config = config_with_database("dummy_db", connection);
        assert_eq!(Err(MissingKey), database_config("real_db", &config));
    }

    #[test]
    fn incorrectly_structured_config_returns_error() {
        let config = config_with_database("dummy_db", vec!["url", "dummy_db.sqlite"]);
        assert_eq!(Err(MalformedConfiguration), database_config("dummy_db", &config));
    }

    #[test]
    fn missing_connection_string_returns_error() {
        let config = config_with_database("dummy_db", BTreeMap::<String, Value>::new());
        assert_eq!(Err(MissingUrl), database_config("dummy_db", &config));
    }

    #[test]
    fn invalid_connection_string_returns_error() {
        let config = config_with_database("dummy_db", connection_with(Value::from(42), None));
        assert_eq!(Err(MalformedUrl), database_config("dummy_db", &config));
    }

    #[test]
    fn negative_pool_size_returns_error() {
        let connection = connection_with(Value::from("dummy_db.sqlite"), Some(Value::from(-1)));
        let config = config_with_database("dummy_db", connection);
        assert_eq!(Err(InvalidPoolSize(-1)), database_config("dummy_db", &config));
    }

    #[test]
    fn pool_size_beyond_u32_max_returns_error() {
        let over_max = (u32::max_value()) as i64 + 1;
        let connection = connection_with(
            Value::from("dummy_db.sqlite"),
            Some(Value::from(over_max)),
        );
        let config = config_with_database("dummy_db", connection);
        assert_eq!(Err(InvalidPoolSize(over_max)), database_config("dummy_db", &config));
    }

    #[test]
    fn missing_pool_size_defaults_to_workers() {
        let config = config_with_database(
            "dummy_db",
            connection_with(Value::from("dummy_db.sqlite"), None),
        );
        let database_config = database_config("dummy_db", &config).unwrap();
        assert_eq!(config.workers as u32, database_config.pool_size);
    }

    #[test]
    fn happy_path_database_config() {
        let url = "dummy_db.sqlite";
        let pool_size: u32 = 10;
        let config = config_with_database(
            "dummy_db",
            connection_with(Value::from(url), Some(Value::from(pool_size))),
        );
        let database_config = database_config("dummy_db", &config).unwrap();
        assert_eq!(url, database_config.url);
        assert_eq!(pool_size, database_config.pool_size);
        assert_eq!(0, database_config.extras.len());
    }

    #[test]
    fn extras_do_not_contain_required_keys() {
        let url = "dummy_db.sqlite";
        let pool_size: u32 = 10;
        let config = config_with_database(
            "dummy_db",
            connection_with(Value::from(url), Some(Value::from(pool_size))),
        );
        let database_config = database_config("dummy_db", &config).unwrap();
        assert_eq!(url, database_config.url);
        assert_eq!(pool_size, database_config.pool_size);
        assert!(!database_config.extras.contains_key("url"));
        assert!(!database_config.extras.contains_key("pool_size"));
    }

    #[test]
    fn extra_values_are_placed_in_extras_map() {
        let url = "dummy_db.sqlite";
        let pool_size: u32 = 10;
        let tls_cert = "certs.pem";
        let tls_key = "key.pem";
        let mut connection = connection_with(Value::from(url), Some(Value::from(pool_size)));
        connection.insert("certs".to_string(), Value::from(tls_cert));
        connection.insert("key".to_string(), Value::from(tls_key));
        let config = config_with_database("dummy_db", connection);
        let database_config = database_config("dummy_db", &config).unwrap();
        assert_eq!(url, database_config.url);
        assert_eq!(pool_size, database_config.pool_size);
        assert_eq!(2, database_config.extras.len());
        assert_eq!(Some(&Value::from(tls_cert)), database_config.extras.get("certs"));
        assert_eq!(Some(&Value::from(tls_key)), database_config.extras.get("key"));
    }
}