use self::{
interval::PgInterval,
numeric::PgNumeric,
};
#[cfg(feature = "db-auth")]
use crate::db_auth::{
Role,
User,
};
use crate::{
error::{
DataOpError,
PlatformError,
},
table::SchemaContent,
DbError,
TableDef,
TableName,
Value,
*,
};
use bigdecimal::BigDecimal;
use bytes::BytesMut;
use geo_types::Point;
use log::*;
use postgres::{
self,
types::{
to_sql_checked,
FromSql,
IsNull,
Kind,
ToSql,
Type,
},
NoTls,
};
use r2d2::{
self,
ManageConnection,
};
use clia_rustorm_dao::{
value::Array,
Interval,
Rows,
};
use std::{
error::Error,
fmt,
string::FromUtf8Error,
};
use thiserror::Error;
use url::Url;
mod column_info;
#[allow(unused)]
mod interval;
mod numeric;
mod table_info;
/// Creates an r2d2 connection pool for the PostgreSQL server described by `db_url`.
///
/// `db_url` is expected to look like `postgres://user[:password]@host[:port][/dbname]`.
/// Host, user, password, port and database name are copied into a [`postgres::Config`].
/// One connection is then opened and validated via [`test_connection`] before the pool is
/// built, so a bad configuration is reported here rather than on first use.
///
/// # Errors
/// Returns [`PostgresError::Sql`] when the test connection fails, and
/// [`PostgresError::PoolInitialization`] when r2d2 cannot create the pool.
///
/// # Panics
/// Panics when `db_url` is not a valid URL or has no host.
pub fn init_pool(
    db_url: &str,
) -> Result<r2d2::Pool<r2d2_postgres::PostgresConnectionManager<NoTls>>, PostgresError> {
    let url = Url::parse(db_url).expect("Invalid DB url");
    let mut config = postgres::Config::new();
    config
        .host(url.host_str().expect("invalid DB URL"))
        .user(url.username());
    // NOTE(review): `Url::username` and `Url::password` return the percent-encoded form,
    // so credentials containing reserved characters reach the server still encoded.
    if let Some(password) = url.password() {
        config.password(password);
    }
    // Copy an explicit port; otherwise the driver silently falls back to its default port.
    if let Some(port) = url.port() {
        config.port(port);
    }
    // `postgres://host/` has a single empty path segment; treat it as "no database given"
    // instead of asking the server for a database named "".
    if let Some(database) = url
        .path_segments()
        .into_iter()
        .flatten()
        .next()
        .filter(|segment| !segment.is_empty())
    {
        config.dbname(database);
    }
    test_connection(config.clone())?;
    let manager = r2d2_postgres::PostgresConnectionManager::new(config, NoTls);
    let pool = r2d2::Pool::new(manager)?;
    Ok(pool)
}
pub fn test_connection(config: postgres::Config) -> Result<(), PostgresError> {
let manager = r2d2_postgres::PostgresConnectionManager::new(config, NoTls);
let mut conn = manager
.connect()
.map_err(|e| PostgresError::Sql(e, "Connect Error".into()))?;
manager
.is_valid(&mut conn)
.map_err(|e| PostgresError::Sql(e, "Invalid Connection".into()))?;
Ok(())
}
/// A PostgreSQL connection checked out of an r2d2 pool.
pub struct PostgresDB(pub r2d2::PooledConnection<r2d2_postgres::PostgresConnectionManager<NoTls>>);

impl PostgresDB {
    /// Prepares and runs `sql` with `param` bound as positional parameters, and converts
    /// every returned row into a record of [`Value`]s.
    ///
    /// Column names are taken from the prepared statement, so they are present even when
    /// the query returns no rows. SQL `NULL` becomes `Value::Nil`.
    ///
    /// # Errors
    /// Returns the driver error when preparing or executing the statement fails, or when a
    /// column value cannot be decoded. Decode failures are propagated, not panicked on.
    fn pg_execute_sql_with_return(
        &mut self,
        sql: &str,
        param: &[&Value],
    ) -> Result<Rows, postgres::Error> {
        let stmt = self.0.prepare(sql)?;
        let pg_values = to_pg_values(param);
        let sql_types = to_sql_types(&pg_values);
        let rows = self.0.query(&stmt, &*sql_types)?;
        let column_names: Vec<String> = stmt
            .columns()
            .iter()
            .map(|c| c.name().to_string())
            .collect();
        let column_count = column_names.len();
        let mut records = Rows::new(column_names);
        for row in &rows {
            // `OwnedPgValue::from_sql_nullable` maps SQL NULL to `Value::Nil`.
            let record = (0..column_count)
                .map(|index| row.try_get::<_, OwnedPgValue>(index).map(|value| value.0))
                .collect::<Result<Vec<Value>, _>>()?;
            records.push(record);
        }
        Ok(records)
    }
}
impl Database for PostgresDB {
    fn begin_transaction(&mut self) -> Result<(), DbError> {
        self.execute_sql_with_return("BEGIN TRANSACTION", &[])?;
        Ok(())
    }

    fn commit_transaction(&mut self) -> Result<(), DbError> {
        self.execute_sql_with_return("COMMIT TRANSACTION", &[])?;
        Ok(())
    }

    fn rollback_transaction(&mut self) -> Result<(), DbError> {
        self.execute_sql_with_return("ROLLBACK TRANSACTION", &[])?;
        Ok(())
    }

    /// Runs `sql` with `param` and wraps any driver failure, together with the SQL text,
    /// into a [`DbError`].
    fn execute_sql_with_return(&mut self, sql: &str, param: &[&Value]) -> Result<Rows, DbError> {
        self.pg_execute_sql_with_return(sql, param).map_err(|e| {
            Into::<DataOpError>::into(PlatformError::PostgresError(PostgresError::Sql(
                e,
                sql.to_string(),
            )))
            .into()
        })
    }

    fn get_table(&mut self, table_name: &TableName) -> Result<Option<TableDef>, DbError> {
        table_info::get_table(&mut *self, table_name)
    }

    /// Sets the sequence behind the table's autoincrement primary key to `sequence_value`,
    /// and returns the value reported back by `setval`.
    ///
    /// Returns `Ok(None)` when the primary column is not backed by a sequence.
    ///
    /// # Errors
    /// `DataError::TableNameNotFound` when the table does not exist, or any query error.
    ///
    /// # Panics
    /// When the table does not have exactly one primary column.
    fn set_autoincrement_value(
        &mut self,
        table_name: &TableName,
        sequence_value: i64,
    ) -> Result<Option<i64>, DbError> {
        let sequence_name = match self.pk_autoincrement_sequence(table_name)? {
            Some(name) => name,
            None => return Ok(None),
        };
        let sql = format!("SELECT setval('{}',$1) AS value", sequence_name);
        let rows = self.execute_sql_with_return(&sql, &[&sequence_value.to_value()])?;
        let row = rows.iter().next().expect("must have 1 row");
        let value = row.get("value").expect("value");
        Ok(Some(value))
    }

    /// Reads `last_value` from the sequence behind the table's autoincrement primary key.
    ///
    /// Returns `Ok(None)` when the primary column is not backed by a sequence.
    ///
    /// # Errors
    /// `DataError::TableNameNotFound` when the table does not exist, or any query error.
    ///
    /// # Panics
    /// When the table does not have exactly one primary column.
    fn get_autoincrement_last_value(
        &mut self,
        table_name: &TableName,
    ) -> Result<Option<i64>, DbError> {
        let sequence_name = match self.pk_autoincrement_sequence(table_name)? {
            Some(name) => name,
            None => return Ok(None),
        };
        let sql = format!("SELECT last_value FROM {}", sequence_name);
        let rows = self.execute_sql_with_return(&sql, &[])?;
        let row = rows.iter().next().expect("must have 1 row");
        let last_value = row.get("last_value").expect("must have a last_value");
        Ok(Some(last_value))
    }

    fn get_all_tables(&mut self) -> Result<Vec<TableDef>, DbError> {
        table_info::get_all_tables(&mut *self)
    }

    fn get_tablenames(&mut self) -> Result<Vec<TableName>, DbError> {
        table_info::get_tablenames(&mut *self)
    }

    fn get_grouped_tables(&mut self) -> Result<Vec<SchemaContent>, DbError> {
        table_info::get_organized_tables(&mut *self)
    }

    /// Lists every role in `pg_authid`.
    #[cfg(feature = "db-auth")]
    fn get_users(&mut self) -> Result<Vec<User>, DbError> {
        let rows = self.execute_sql_with_return(USERS_SQL, &[])?;
        Ok(users_from_rows(rows))
    }

    /// Lists the role(s) in `pg_authid` whose name equals `username`.
    #[cfg(feature = "db-auth")]
    fn get_user_detail(&mut self, username: &str) -> Result<Vec<User>, DbError> {
        let sql = format!("{}\n WHERE rolname = $1", USERS_SQL);
        let rows = self.execute_sql_with_return(&sql, &[&username.to_value()])?;
        Ok(users_from_rows(rows))
    }

    /// Lists the roles that `username` is a member of, according to `pg_auth_members`.
    #[cfg(feature = "db-auth")]
    fn get_roles(&mut self, username: &str) -> Result<Vec<Role>, DbError> {
        let sql = "SELECT
                (SELECT rolname FROM pg_roles WHERE oid = m.roleid) AS role_name
                FROM pg_auth_members m
                LEFT JOIN pg_roles
                ON m.member = pg_roles.oid
                WHERE pg_roles.rolname = $1
        ";
        let rows = self.execute_sql_with_return(sql, &[&username.to_value()])?;
        Ok(rows
            .iter()
            .map(|row| Role {
                role_name: row.get("role_name").expect("role_name"),
            })
            .collect())
    }

    /// Returns the name of the current database, together with its comment if one is set.
    fn get_database_name(&mut self) -> Result<Option<DatabaseName>, DbError> {
        // `shobj_description` looks the comment up in `pg_shdescription` restricted to the
        // `pg_database` catalog, so comments on other shared objects cannot match.
        let sql = "SELECT current_database() AS name,
                        shobj_description(oid, 'pg_database') AS description
                   FROM pg_database
                   WHERE datname = current_database()";
        let rows = self.execute_sql_with_return(sql, &[])?;
        let database_name = rows.iter().next().and_then(|row| {
            row.get_opt("name")
                .expect("must not error")
                .map(|name| DatabaseName {
                    name,
                    // The description was selected before but always discarded.
                    description: row.get_opt("description").expect("must not error"),
                })
        });
        Ok(database_name)
    }
}

impl PostgresDB {
    /// Finds the sequence backing the single primary-key column of `table_name`.
    ///
    /// Returns `Ok(None)` when that column has no autoincrement sequence.
    ///
    /// # Errors
    /// `DataError::TableNameNotFound` when the table does not exist.
    ///
    /// # Panics
    /// When the table does not have exactly one primary column.
    fn pk_autoincrement_sequence(
        &mut self,
        table_name: &TableName,
    ) -> Result<Option<String>, DbError> {
        let table = match self.get_table(table_name)? {
            Some(table) => table,
            None => {
                return Err(DbError::DataError(DataError::TableNameNotFound(
                    table_name.complete_name(),
                )))
            }
        };
        let pk = table.get_primary_columns();
        assert_eq!(
            pk.len(),
            1,
            "auto increment only supports 1 primary column table"
        );
        let pk_column = pk.get(0).expect("must have a primary column");
        Ok(pk_column
            .autoincrement_sequence_name()
            .map(|name| name.to_string()))
    }
}

/// Role attributes read from `pg_authid`. A negative `rolconnlimit` and an `'infinity'`
/// `rolvaliduntil` are both returned as NULL.
#[cfg(feature = "db-auth")]
const USERS_SQL: &str = "SELECT oid::int AS sysid,
                rolname AS username,
                rolsuper AS is_superuser,
                rolinherit AS is_inherit,
                rolcreaterole AS can_create_role,
                rolcreatedb AS can_create_db,
                rolcanlogin AS can_login,
                rolreplication AS can_do_replication,
                rolbypassrls AS can_bypass_rls,
                CASE WHEN rolconnlimit < 0 THEN NULL
                     ELSE rolconnlimit END AS conn_limit,
                CASE WHEN rolvaliduntil = 'infinity'::timestamp THEN NULL
                    ELSE rolvaliduntil
                    END AS valid_until
         FROM pg_authid";

/// Maps rows produced by [`USERS_SQL`] to [`User`]s.
///
/// # Panics
/// When a column is missing or cannot be converted to the matching `User` field.
#[cfg(feature = "db-auth")]
fn users_from_rows(rows: Rows) -> Vec<User> {
    rows.iter()
        .map(|row| User {
            sysid: row.get("sysid").expect("sysid"),
            username: row.get("username").expect("username"),
            is_superuser: row.get("is_superuser").expect("is_superuser"),
            is_inherit: row.get("is_inherit").expect("is_inherit"),
            can_create_db: row.get("can_create_db").expect("can_create_db"),
            can_create_role: row.get("can_create_role").expect("can_create_role"),
            can_login: row.get("can_login").expect("can_login"),
            can_do_replication: row
                .get("can_do_replication")
                .expect("can_do_replication"),
            can_bypass_rls: row.get("can_bypass_rls").expect("can_bypass_rls"),
            valid_until: row.get("valid_until").expect("valid_until"),
            conn_limit: row.get("conn_limit").expect("conn_limit"),
        })
        .collect()
}
/// Wraps each borrowed [`Value`] in a [`PgValue`] so it can be bound as a query parameter.
/// Order is preserved.
fn to_pg_values<'a>(values: &[&'a Value]) -> Vec<PgValue<'a>> {
    let mut wrapped = Vec::with_capacity(values.len());
    for &value in values {
        wrapped.push(PgValue(value));
    }
    wrapped
}
/// Borrows each [`PgValue`] as a `ToSql + Sync` trait object, in order, ready to be passed
/// as the parameter slice of a query.
fn to_sql_types<'a>(values: &'a [PgValue]) -> Vec<&'a (dyn ToSql + Sync)> {
    values
        .iter()
        .map(|value| value as &(dyn ToSql + Sync))
        .collect()
}
/// Borrowed [`Value`] adapter implementing `ToSql`, used to bind query parameters.
#[derive(Debug)]
pub struct PgValue<'a>(&'a Value);
/// Owned [`Value`] adapter implementing `FromSql`, used to decode result columns.
#[derive(Debug)]
pub struct OwnedPgValue(Value);
impl<'a> ToSql for PgValue<'a> {
    to_sql_checked!();

    /// Writes the wrapped [`Value`] into `out` using the `ToSql` impl of its payload.
    ///
    /// The encoding is chosen by the `Value` variant alone. `accepts` below returns `true`
    /// for every type, and the inner impls are invoked through `to_sql`, which does not
    /// re-check that they accept `ty`. `Value::Nil` is sent as SQL `NULL`.
    ///
    /// # Errors
    /// Returns an error for `Value::Interval` (storing intervals is not supported), rather
    /// than panicking. Also propagates any encoding error from the inner impls.
    fn to_sql(
        &self,
        ty: &Type,
        out: &mut BytesMut,
    ) -> Result<IsNull, Box<dyn Error + 'static + Sync + Send>> {
        match *self.0 {
            Value::Bool(ref v) => v.to_sql(ty, out),
            Value::Tinyint(ref v) => v.to_sql(ty, out),
            Value::Smallint(ref v) => v.to_sql(ty, out),
            Value::Int(ref v) => v.to_sql(ty, out),
            Value::Bigint(ref v) => v.to_sql(ty, out),
            Value::Float(ref v) => v.to_sql(ty, out),
            Value::Double(ref v) => v.to_sql(ty, out),
            Value::Blob(ref v) => v.to_sql(ty, out),
            // A single char is sent as a one-character string.
            Value::Char(ref v) => v.to_string().to_sql(ty, out),
            Value::Text(ref v) => v.to_sql(ty, out),
            Value::Uuid(ref v) => v.to_sql(ty, out),
            Value::Date(ref v) => v.to_sql(ty, out),
            Value::Timestamp(ref v) => v.to_sql(ty, out),
            Value::DateTime(ref v) => v.to_sql(ty, out),
            Value::Time(ref v) => v.to_sql(ty, out),
            Value::Interval(_) => Err("storing interval in DB is not supported".into()),
            Value::BigDecimal(ref v) => {
                let numeric: PgNumeric = v.into();
                numeric.to_sql(ty, out)
            }
            Value::Json(ref v) => v.to_sql(ty, out),
            Value::Point(ref v) => v.to_sql(ty, out),
            Value::Array(Array::Text(ref av)) => av.to_sql(ty, out),
            Value::Array(Array::Int(ref av)) => av.to_sql(ty, out),
            Value::Array(Array::Float(ref av)) => av.to_sql(ty, out),
            Value::Nil => Ok(IsNull::Yes),
        }
    }

    /// Accepts every PostgreSQL type; see `to_sql` for how the encoding is chosen.
    fn accepts(_ty: &Type) -> bool { true }
}
impl<'b> FromSql<'b> for OwnedPgValue {
fn from_sql(ty: &Type, raw: &'b [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
macro_rules! match_type {
($variant:ident) => {
FromSql::from_sql(ty, raw).map(|v| OwnedPgValue(Value::$variant(v)))
};
}
let kind = ty.kind();
match *kind {
Kind::Enum(_) => match_type!(Text),
Kind::Array(ref array_type) => {
let array_type_kind = array_type.kind();
match *array_type_kind {
Kind::Enum(_) => {
FromSql::from_sql(ty, raw)
.map(|v| OwnedPgValue(Value::Array(Array::Text(v))))
}
_ => {
match *ty {
Type::TEXT_ARRAY | Type::NAME_ARRAY | Type::VARCHAR_ARRAY => {
FromSql::from_sql(ty, raw)
.map(|v| OwnedPgValue(Value::Array(Array::Text(v))))
}
Type::INT4_ARRAY => {
FromSql::from_sql(ty, raw)
.map(|v| OwnedPgValue(Value::Array(Array::Int(v))))
}
Type::FLOAT4_ARRAY => {
FromSql::from_sql(ty, raw)
.map(|v| OwnedPgValue(Value::Array(Array::Float(v))))
}
_ => panic!("Array type {:?} is not yet covered", array_type),
}
}
}
}
Kind::Simple => {
match *ty {
Type::BOOL => match_type!(Bool),
Type::INT2 => match_type!(Smallint),
Type::INT4 => match_type!(Int),
Type::INT8 => match_type!(Bigint),
Type::FLOAT4 => match_type!(Float),
Type::FLOAT8 => match_type!(Double),
Type::TEXT | Type::VARCHAR | Type::NAME | Type::UNKNOWN => {
match_type!(Text)
}
Type::TS_VECTOR => {
let text = String::from_utf8(raw.to_owned());
match text {
Ok(text) => Ok(OwnedPgValue(Value::Text(text))),
Err(e) => Err(Box::new(PostgresError::Utf8(e))),
}
}
Type::BPCHAR => {
let v: Result<String, _> = FromSql::from_sql(&Type::TEXT, raw);
match v {
Ok(v) => {
if v.chars().count() == 1 {
Ok(OwnedPgValue(Value::Char(v.chars().next().unwrap())))
} else {
FromSql::from_sql(ty, raw).map(|v: String| {
let value_string: String = v.trim_end().to_string();
OwnedPgValue(Value::Text(value_string))
})
}
}
Err(e) => Err(e),
}
}
Type::UUID => match_type!(Uuid),
Type::DATE => match_type!(Date),
Type::TIMESTAMPTZ | Type::TIMESTAMP => match_type!(Timestamp),
Type::TIME | Type::TIMETZ => match_type!(Time),
Type::BYTEA => match_type!(Blob),
Type::NUMERIC => {
let numeric: PgNumeric = FromSql::from_sql(ty, raw)?;
let bigdecimal = BigDecimal::from(numeric);
Ok(OwnedPgValue(Value::BigDecimal(bigdecimal)))
}
Type::JSON | Type::JSONB => {
let value: serde_json::Value = FromSql::from_sql(ty, raw)?;
let text = serde_json::to_string(&value).unwrap();
Ok(OwnedPgValue(Value::Json(text)))
}
Type::INTERVAL => {
let pg_interval: PgInterval = FromSql::from_sql(ty, raw)?;
let interval = Interval::new(
pg_interval.microseconds,
pg_interval.days,
pg_interval.months,
);
Ok(OwnedPgValue(Value::Interval(interval)))
}
Type::POINT => {
let p: Point<f64> = FromSql::from_sql(ty, raw)?;
Ok(OwnedPgValue(Value::Point(p)))
}
Type::INET => {
info!("inet raw:{:?}", raw);
match_type!(Text)
}
_ => panic!("unable to convert from {:?}", ty),
}
}
_ => panic!("not yet handling this kind: {:?}", kind),
}
}
fn accepts(_ty: &Type) -> bool { true }
fn from_sql_null(_ty: &Type) -> Result<Self, Box<dyn Error + Sync + Send>> {
Ok(OwnedPgValue(Value::Nil))
}
fn from_sql_nullable(
ty: &Type,
raw: Option<&[u8]>,
) -> Result<Self, Box<dyn Error + Sync + Send>> {
match raw {
Some(raw) => Self::from_sql(ty, raw),
None => Self::from_sql_null(ty),
}
}
}
#[derive(Debug, Error)]
pub enum PostgresError {
Sql(postgres::Error, String),
Utf8(#[from] FromUtf8Error),
PoolInitialization(#[from] r2d2::Error),
}
impl fmt::Display for PostgresError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{:#?}", self) }
}
#[cfg(test)]
mod test {
    //! Integration tests for the PostgreSQL backend.
    //!
    //! They need a running server reachable through `DB_URL`, with the `sakila` sample
    //! database loaded; they are not self-contained unit tests.
    use crate::{
        pool::*,
        Pool,
        *,
    };
    use log::*;
    use std::ops::DerefMut;

    /// Connection string shared by every test in this module (5432 is the default port).
    const DB_URL: &str = "postgres://postgres:p0stgr3s@localhost:5432/sakila";

    #[test]
    fn test_character_array_data_type() {
        let mut pool = Pool::new();
        let mut dm = pool.dm(DB_URL).unwrap();
        let sql = "SELECT language_id, name FROM language";
        let languages: Result<Rows, DbError> = dm.execute_sql_with_return(sql, &[]);
        println!("languages: {:#?}", languages);
        assert!(languages.is_ok());
    }

    #[test]
    fn test_advancing_autoincrement_primary_column() {
        let mut pool = Pool::new();
        let mut em = pool.em(DB_URL).unwrap();
        let actor_table = TableName::from("public.actor");
        let last_value = em
            .get_autoincrement_last_value(&actor_table)
            .unwrap()
            .unwrap();
        let result = em
            .set_autoincrement_value(&actor_table, last_value + 1)
            .unwrap_or_else(|e| panic!("{}", e));
        println!("result: {:?}", result);
        assert_eq!(result, Some(last_value + 1));
    }

    #[test]
    fn test_ts_vector() {
        let mut pool = Pool::new();
        let mut dm = pool.dm(DB_URL).unwrap();
        let sql = "SELECT film_id, title, fulltext::text FROM film LIMIT 40";
        let films: Result<Rows, DbError> = dm.execute_sql_with_return(sql, &[]);
        println!("film: {:#?}", films);
        assert!(films.is_ok());
    }

    #[test]
    fn connect_test_query() {
        let mut pool = Pool::new();
        let conn = pool.connect(DB_URL);
        assert!(conn.is_ok());
        let mut conn: PooledConn = conn.unwrap();
        match conn {
            PooledConn::PooledPg(ref mut pooled_pg) => {
                let rows = pooled_pg.query("select 42, 'life'", &[]).unwrap();
                for row in rows.iter() {
                    let n: i32 = row.get(0);
                    let l: String = row.get(1);
                    assert_eq!(n, 42);
                    assert_eq!(l, "life");
                }
            }
            #[cfg(any(feature = "with-sqlite", feature = "with-mysql"))]
            _ => unreachable!(),
        }
    }

    #[test]
    fn connect_test_query_explicit_deref() {
        let mut pool = Pool::new();
        let conn = pool.connect(DB_URL);
        assert!(conn.is_ok());
        let mut conn: PooledConn = conn.unwrap();
        match conn {
            PooledConn::PooledPg(ref mut pooled_pg) => {
                let c = pooled_pg.deref_mut();
                let rows = c.query("select 42, 'life'", &[]).unwrap();
                for row in rows.iter() {
                    let n: i32 = row.get(0);
                    let l: String = row.get(1);
                    assert_eq!(n, 42);
                    assert_eq!(l, "life");
                }
            }
            #[cfg(any(feature = "with-sqlite", feature = "with-mysql"))]
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_unknown_type() {
        let mut pool = Pool::new();
        let mut db = pool.db(DB_URL).unwrap();
        let values: Vec<Value> = vec!["hi".into(), true.into(), 42.into(), 1.0.into()];
        let bvalues: Vec<&Value> = values.iter().collect();
        let rows: Result<Rows, DbError> = db.execute_sql_with_return(
            "select 'Hello', $1::TEXT, $2::BOOL, $3::INT, $4::FLOAT",
            &bvalues,
        );
        info!("rows: {:#?}", rows);
        assert!(rows.is_ok());
    }

    #[test]
    fn test_unknown_type_i32_f32() {
        let mut pool = Pool::new();
        let mut db = pool.db(DB_URL).unwrap();
        let values: Vec<Value> = vec![42.into(), 1.0.into()];
        let bvalues: Vec<&Value> = values.iter().collect();
        // The placeholders carry no cast, so the server picks their types itself
        // (presumably text); binding binary int/float values to them is expected to fail.
        let rows: Result<Rows, DbError> = db.execute_sql_with_return("select $1, $2", &bvalues);
        info!("rows: {:#?}", rows);
        assert!(rows.is_err());
    }

    #[test]
    fn using_values() {
        let mut pool = Pool::new();
        let mut db = pool.db(DB_URL).unwrap();
        let values: Vec<Value> = vec!["hi".into(), true.into(), 42.into(), 1.0.into()];
        let bvalues: Vec<&Value> = values.iter().collect();
        let rows: Result<Rows, DbError> = db.execute_sql_with_return(
            "select 'Hello'::TEXT, $1::TEXT, $2::BOOL, $3::INT, $4::FLOAT",
            &bvalues,
        );
        info!("columns: {:#?}", rows);
        let rows = rows.expect("query with typed placeholders must succeed");
        for row in rows.iter() {
            info!("row {:?}", row);
            let v4: Result<f64, _> = row.get("float8");
            assert_eq!(v4.unwrap(), 1.0f64);
            let v3: Result<i32, _> = row.get("int4");
            assert_eq!(v3.unwrap(), 42i32);
            let hi: Result<String, _> = row.get("text");
            assert_eq!(hi.unwrap(), "hi");
            let b: Result<bool, _> = row.get("bool");
            assert!(b.unwrap());
        }
    }

    #[test]
    fn with_nulls() {
        let mut pool = Pool::new();
        let mut db = pool.db(DB_URL).unwrap();
        let rows: Result<Rows, DbError> = db.execute_sql_with_return(
            "select 'rust'::TEXT AS name, NULL::TEXT AS schedule, NULL::TEXT AS specialty from actor",
            &[],
        );
        info!("columns: {:#?}", rows);
        let rows = rows.expect("query must succeed");
        for row in rows.iter() {
            info!("row {:?}", row);
            let name: Result<Option<String>, _> = row.get("name");
            info!("name: {:?}", name);
            assert_eq!(name.unwrap().unwrap(), "rust");
            let schedule: Result<Option<String>, _> = row.get("schedule");
            info!("schedule: {:?}", schedule);
            assert_eq!(schedule.unwrap(), None);
            let specialty: Result<Option<String>, _> = row.get("specialty");
            info!("specialty: {:?}", specialty);
            assert_eq!(specialty.unwrap(), None);
        }
    }

    #[test]
    #[cfg(feature = "db-auth")]
    fn test_get_users() {
        let mut pool = Pool::new();
        let mut em = pool.em(DB_URL).unwrap();
        let users = em.get_users();
        info!("users: {:#?}", users);
        assert!(users.is_ok());
    }
}