use crate::logger::{
LogEntry, LOG_ENTRY_MAX_FILENAME_LENGTH, LOG_ENTRY_MAX_HOSTNAME_LENGTH,
LOG_ENTRY_MAX_MESSAGE_LENGTH, LOG_ENTRY_MAX_MODULE_LENGTH,
};
use crate::{truncate_option_str, Connection, Db, Result};
use futures::TryStreamExt;
use sqlx::postgres::{PgConnectOptions, PgPool};
use sqlx::Row;
use std::convert::TryFrom;
use std::env;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use time::OffsetDateTime;
/// Schema script executed by `create_schema`, embedded at compile time from
/// `../schemas/postgres.sql` (relative to this source file).
///
/// Before execution its `--` comments are stripped and it is split on `;`, so it must not
/// contain `;` or `--` inside string literals.
const SCHEMA: &str = include_str!("../schemas/postgres.sql");
/// Removes `--` line comments from SQL text.
///
/// Everything from a `--` sequence up to (but not including) the next newline is dropped, and
/// carriage returns are removed everywhere. Lone `-` characters are preserved, including one
/// that is the very last character of the input.
///
/// This is a purely lexical pass: it does not understand string literals or quoted identifiers,
/// so a `--` inside them is also treated as the start of a comment.
fn strip_sql_comments(input: &str) -> String {
    // Scanner state.
    #[derive(Clone, Copy, PartialEq)]
    enum State {
        // Plain SQL text.
        Code,
        // One `-` seen and withheld: it only starts a comment if another `-` follows.
        Dash,
        // Inside a `--` comment: drop everything until the end of the line.
        Comment,
    }

    let mut output = String::with_capacity(input.len());
    let mut state = State::Code;
    for ch in input.chars() {
        match (state, ch) {
            (_, '\r') => (),
            (_, '\n') => {
                if state == State::Dash {
                    output.push('-');
                }
                state = State::Code;
                output.push('\n');
            }
            (State::Comment, _) => (),
            (State::Code, '-') => state = State::Dash,
            (State::Dash, '-') => state = State::Comment,
            (State::Dash, ch) => {
                output.push('-');
                output.push(ch);
                state = State::Code;
            }
            (State::Code, ch) => output.push(ch),
        }
    }
    // A dash withheld at end of input never started a comment; emit it instead of losing it.
    if state == State::Dash {
        output.push('-');
    }
    output
}
/// Parameters needed to open a connection to a PostgreSQL server.
#[derive(Default)]
#[cfg_attr(test, derive(PartialEq))]
pub struct ConnectionOptions {
    /// Server host handed to the PostgreSQL driver.
    pub host: String,
    /// TCP port of the server.
    pub port: u16,
    /// Name of the database to connect to.
    pub database: String,
    /// User to authenticate as.
    pub username: String,
    /// Password for `username`; replaced by a placeholder in the test-only `Debug` output.
    pub password: String,
}

/// Test-only `Debug` that prints every field except the password, which is shown as a fixed
/// placeholder so credentials never end up in test output.
#[cfg(test)]
impl std::fmt::Debug for ConnectionOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ConnectionOptions {
            host,
            port,
            database,
            username,
            password: _,
        } = self;
        let mut out = f.debug_struct("ConnectionOptions");
        out.field("host", host);
        out.field("port", port);
        out.field("database", database);
        out.field("username", username);
        out.field("password", &"scrubbed");
        out.finish()
    }
}
impl ConnectionOptions {
    /// Loads connection options from the environment variables `<prefix>_HOST`,
    /// `<prefix>_PORT`, `<prefix>_DATABASE`, `<prefix>_USERNAME` and `<prefix>_PASSWORD`.
    ///
    /// # Errors
    ///
    /// Fails, naming the offending variable, if any variable is missing or is not valid
    /// Unicode, or if the port is not a valid `u16`. The port error starts with
    /// `Invalid port number`.
    pub fn from_env(prefix: &str) -> Result<ConnectionOptions> {
        // Reads `<prefix>_<suffix>`, turning both `VarError` cases into descriptive messages.
        fn get_required_var(prefix: &str, suffix: &str) -> Result<String> {
            let name = format!("{}_{}", prefix, suffix);
            env::var(&name).map_err(|e| match e {
                env::VarError::NotPresent => {
                    format!("Required environment variable {} not present", name)
                }
                env::VarError::NotUnicode(_) => {
                    format!("Invalid value in environment variable {}", name)
                }
            })
        }

        // Fields are evaluated in order, so the first missing variable is the one reported.
        Ok(ConnectionOptions {
            host: get_required_var(prefix, "HOST")?,
            port: get_required_var(prefix, "PORT")?
                .parse::<u16>()
                .map_err(|e| format!("Invalid port number in {}_PORT: {}", prefix, e))?,
            database: get_required_var(prefix, "DATABASE")?,
            username: get_required_var(prefix, "USERNAME")?,
            password: get_required_var(prefix, "PASSWORD")?,
        })
    }
}
/// Returns a connection to the PostgreSQL database described by `opts`.
///
/// No network activity happens here: the underlying pool only connects when first used.
pub fn connect_lazy(opts: ConnectionOptions) -> Connection {
    let db = PostgresDb::connect_lazy(opts, None);
    Connection(Arc::new(db))
}
/// Creates a test connection backed by a fresh, randomly-suffixed copy of the schema.
///
/// The suffixed table and index are dropped when the wrapped test database is dropped.
///
/// # Panics
///
/// Panics if the schema cannot be created.
pub async fn setup_test(opts: ConnectionOptions) -> Connection {
    let test_db = PostgresTestDb::setup_test(opts).await;
    Connection(Arc::new(test_db))
}
#[derive(Clone)]
struct PostgresDb {
pool: PgPool,
suffix: Option<u32>,
log_sequence: Arc<AtomicU32>,
}
impl PostgresDb {
fn connect_lazy(opts: ConnectionOptions, suffix: Option<u32>) -> Self {
let options = PgConnectOptions::new()
.host(&opts.host)
.port(opts.port)
.database(&opts.database)
.username(&opts.username)
.password(&opts.password);
Self {
pool: PgPool::connect_lazy_with(options),
suffix,
log_sequence: Arc::from(AtomicU32::new(0)),
}
}
fn patch_query(&self, query: &str) -> String {
match self.suffix {
None => query.to_owned(),
Some(suffix) => query.replace(" logs", &format!(" logs_{}", suffix)),
}
}
}
/// Shortens `s` to at most `max_len` bytes without splitting a multi-byte UTF-8 character.
///
/// `String::truncate` panics when the cut point falls inside a character, which a non-ASCII
/// hostname or message longer than the limit would otherwise trigger.
fn truncate_at_char_boundary(s: &mut String, max_len: usize) {
    if s.len() <= max_len {
        return;
    }
    let mut end = max_len;
    // Index 0 is always a char boundary, so this loop terminates.
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    s.truncate(end);
}

#[async_trait::async_trait]
impl Db for PostgresDb {
    /// Executes the embedded `SCHEMA` against the database.
    ///
    /// Comments are stripped, names are rewritten by `patch_query`, and the `;`-separated
    /// statements run inside one transaction, which is only committed if all of them succeed.
    async fn create_schema(&self) -> Result<()> {
        let schema = self.patch_query(&strip_sql_comments(SCHEMA));
        let mut tx = self.pool.begin().await.map_err(|e| e.to_string())?;
        // Splitting on `;` leaves a whitespace-only fragment after the last statement; skip such
        // fragments instead of sending empty queries to the server.
        for query_str in schema.split(';').filter(|s| !s.trim().is_empty()) {
            sqlx::query(query_str).execute(&mut tx).await.map_err(|e| e.to_string())?;
        }
        tx.commit().await.map_err(|e| e.to_string())
    }

    /// Returns every stored log entry ordered by timestamp and then insertion sequence.
    ///
    /// Each entry is rendered as
    /// `<secs>.<nanos> <hostname> <level> <module> <filename>:<line> <message>`, with
    /// `NO-MODULE`, `NO-FILENAME` and `-1` standing in for missing values.
    async fn get_log_entries(&self) -> Result<Vec<String>> {
        let query_str = self.patch_query("SELECT * FROM logs ORDER BY timestamp, sequence");
        let mut rows = sqlx::query(&query_str).fetch(&self.pool);
        let mut entries = vec![];
        while let Some(row) = rows.try_next().await.map_err(|e| e.to_string())? {
            let timestamp: OffsetDateTime = row.try_get("timestamp").map_err(|e| e.to_string())?;
            let hostname: String = row.try_get("hostname").map_err(|e| e.to_string())?;
            let level: i16 = row.try_get("level").map_err(|e| e.to_string())?;
            let module: Option<String> = row.try_get("module").map_err(|e| e.to_string())?;
            let filename: Option<String> = row.try_get("filename").map_err(|e| e.to_string())?;
            let line: Option<i16> = row.try_get("line").map_err(|e| e.to_string())?;
            let message: String = row.try_get("message").map_err(|e| e.to_string())?;
            // NOTE(review): the nanoseconds are not zero-padded (1.000001s renders as "1.1000").
            // Presumably callers/tests compare against this exact format; confirm before
            // changing it.
            entries.push(format!(
                "{}.{} {} {} {} {}:{} {}",
                timestamp.unix_timestamp(),
                timestamp.unix_timestamp_nanos() % 1000000000,
                hostname,
                level,
                module.as_deref().unwrap_or("NO-MODULE"),
                filename.as_deref().unwrap_or("NO-FILENAME"),
                line.unwrap_or(-1),
                message
            ))
        }
        Ok(entries)
    }

    /// Inserts all `entries` with a single multi-row `INSERT`.
    ///
    /// String fields are truncated to their configured maximum lengths, and each entry gets a
    /// sequence number from a contiguous range reserved for this batch.
    ///
    /// # Errors
    ///
    /// Fails if there are more than `u32::MAX` entries, if a line number does not fit in `i16`,
    /// if the database rejects the insert, or if the affected row count is unexpected.
    async fn put_log_entries(&self, entries: Vec<LogEntry<'_, '_>>) -> Result<()> {
        let nentries = u32::try_from(entries.len())
            .map_err(|e| format!("Cannot insert {} log entries at once: {}", entries.len(), e))?;
        if nentries == 0 {
            return Ok(());
        }

        // Reserve a contiguous range of sequence numbers. `fetch_add` wraps on overflow, so the
        // per-entry increment below wraps too instead of panicking in debug builds.
        let mut sequence = self.log_sequence.fetch_add(nentries, Ordering::SeqCst);

        // NOTE(review): PostgreSQL caps a statement at 65535 bind parameters, so batches larger
        // than 65535 / NPARAMS entries will likely be rejected.
        const NPARAMS: usize = 8;
        let mut query_str = self.patch_query(
            "INSERT INTO logs\n(timestamp, sequence, hostname, level, module, filename, line, message)\nVALUES ",
        );
        for row in 0..entries.len() {
            if row > 0 {
                query_str.push(',');
            }
            let first = row * NPARAMS + 1;
            let placeholders: Vec<String> =
                (first..first + NPARAMS).map(|p| format!("${}", p)).collect();
            query_str.push('(');
            query_str.push_str(&placeholders.join(", "));
            query_str.push(')');
        }

        let mut query = sqlx::query(&query_str);
        for mut entry in entries {
            let module = truncate_option_str(entry.module, LOG_ENTRY_MAX_MODULE_LENGTH);
            let filename = truncate_option_str(entry.filename, LOG_ENTRY_MAX_FILENAME_LENGTH);
            truncate_at_char_boundary(&mut entry.hostname, LOG_ENTRY_MAX_HOSTNAME_LENGTH);
            truncate_at_char_boundary(&mut entry.message, LOG_ENTRY_MAX_MESSAGE_LENGTH);
            let line = match entry.line {
                Some(n) => Some(i16::try_from(n).map_err(|_| "line out of range".to_owned())?),
                None => None,
            };
            let level = i16::try_from(entry.level as usize).expect("Levels must fit in i16");
            // Bind order must match the column list in `query_str`.
            query = query
                .bind(entry.timestamp)
                .bind(sequence)
                .bind(entry.hostname)
                .bind(level)
                .bind(module)
                .bind(filename)
                .bind(line)
                .bind(entry.message);
            sequence = sequence.wrapping_add(1);
        }

        let done = query.execute(&self.pool).await.map_err(|e| e.to_string())?;
        if done.rows_affected() != u64::from(nentries) {
            return Err(format!(
                "Log entries insertion created {} rows but expected {}",
                done.rows_affected(),
                nentries
            ));
        }
        Ok(())
    }
}
/// Test-only database that owns a randomly-suffixed copy of the schema.
///
/// The suffixed index and table are dropped when this value is dropped. For that reason it
/// deliberately does not implement `Clone`: a clone would run the teardown a second time and
/// panic on the already-dropped table and the closed pool.
struct PostgresTestDb(PostgresDb);

impl PostgresTestDb {
    /// Connects using `opts` and creates the schema under a random numeric suffix.
    ///
    /// # Panics
    ///
    /// Panics if the schema cannot be created.
    async fn setup_test(opts: ConnectionOptions) -> Self {
        let db = PostgresDb::connect_lazy(opts, Some(rand::random()));
        db.create_schema().await.unwrap();
        PostgresTestDb(db)
    }

    /// Drops the suffixed index and table in one transaction, then closes the pool.
    ///
    /// # Panics
    ///
    /// Panics if the instance has no suffix or if any database operation fails.
    async fn teardown_test(&self) {
        let suffix = self.0.suffix.expect("This should only be called from tests");
        let mut tx = self.0.pool.begin().await.unwrap();
        for query_str in &[
            format!("DROP INDEX logs_{}_by_timestamp", suffix),
            format!("DROP TABLE logs_{}", suffix),
        ] {
            sqlx::query(query_str).execute(&mut tx).await.unwrap();
        }
        tx.commit().await.unwrap();
        self.0.pool.close().await;
    }
}

impl Drop for PostgresTestDb {
    fn drop(&mut self) {
        // `Drop` cannot be async, so a dedicated runtime is created to run the teardown.
        // NOTE(review): `#[tokio::main]` builds a runtime and blocks on it, which tokio rejects
        // from inside an existing runtime, so this value must be dropped outside async code.
        #[tokio::main]
        async fn cleanup(context: &mut PostgresTestDb) {
            context.teardown_test().await;
        }
        cleanup(self)
    }
}
#[async_trait::async_trait]
impl Db for PostgresTestDb {
async fn create_schema(&self) -> Result<()> {
self.0.create_schema().await
}
async fn get_log_entries(&self) -> Result<Vec<String>> {
self.0.get_log_entries().await
}
async fn put_log_entries(&self, entries: Vec<LogEntry<'_, '_>>) -> Result<()> {
self.0.put_log_entries(entries).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testutils;

    /// Returns a random environment-variable prefix so concurrently running tests never see
    /// each other's variables.
    fn random_prefix() -> String {
        format!("TEST_{}", rand::random::<u32>())
    }

    /// Sets `<prefix>_<suffix>` to `value` for every `(suffix, value)` pair.
    fn set_vars(prefix: &str, vars: &[(&str, &str)]) {
        for &(suffix, value) in vars {
            env::set_var(format!("{}_{}", prefix, suffix), value);
        }
    }

    #[test]
    fn test_strip_sql_comments() {
        let input = "first line\nsecond - line\n-third-line-is-here-\nfourth--li--ne\nfifth -- \nx";
        let expected = "first line\nsecond - line\n-third-line-is-here-\nfourth\nfifth \nx";
        assert_eq!(expected, &strip_sql_comments(input));
    }

    #[test]
    fn test_connectionoptions_from_env_ok() {
        let prefix = random_prefix();
        set_vars(
            &prefix,
            &[
                ("HOST", "the-host"),
                ("PORT", "1234"),
                ("DATABASE", "the-database"),
                ("USERNAME", "the-username"),
                ("PASSWORD", "the-password"),
            ],
        );
        let expected = ConnectionOptions {
            host: "the-host".to_owned(),
            port: 1234,
            database: "the-database".to_owned(),
            username: "the-username".to_owned(),
            password: "the-password".to_owned(),
        };
        assert_eq!(expected, ConnectionOptions::from_env(&prefix).unwrap());
    }

    /// Sets every connection variable except `<prefix>_<missing>` and checks that `from_env`
    /// reports exactly that variable as absent.
    fn do_connectionoptions_from_env_missing_test(missing: &str) {
        let prefix = random_prefix();
        let present: Vec<(&str, &str)> = [
            ("HOST", "host"),
            ("PORT", "5432"),
            ("DATABASE", "database"),
            ("USERNAME", "username"),
            ("PASSWORD", "password"),
        ]
        .iter()
        .copied()
        .filter(|&(suffix, _)| suffix != missing)
        .collect();
        set_vars(&prefix, &present);
        let err = match ConnectionOptions::from_env(&prefix) {
            Ok(_) => panic!("Should have failed"),
            Err(e) => e,
        };
        assert!(err.contains(&format!("{}_{} not present", prefix, missing)));
    }

    #[test]
    fn test_connectionoptions_from_env_missing_host() {
        do_connectionoptions_from_env_missing_test("HOST");
    }

    #[test]
    fn test_connectionoptions_from_env_missing_port() {
        do_connectionoptions_from_env_missing_test("PORT");
    }

    #[test]
    fn test_connectionoptions_from_env_missing_database() {
        do_connectionoptions_from_env_missing_test("DATABASE");
    }

    #[test]
    fn test_connectionoptions_from_env_missing_username() {
        do_connectionoptions_from_env_missing_test("USERNAME");
    }

    #[test]
    fn test_connectionoptions_from_env_missing_password() {
        do_connectionoptions_from_env_missing_test("PASSWORD");
    }

    #[test]
    fn test_connectionoptions_from_env_invalid_port() {
        let prefix = random_prefix();
        set_vars(
            &prefix,
            &[
                ("HOST", "host"),
                ("PORT", "abc"),
                ("DATABASE", "database"),
                ("USERNAME", "username"),
                ("PASSWORD", "password"),
            ],
        );
        let err = match ConnectionOptions::from_env(&prefix) {
            Ok(_) => panic!("Should have failed"),
            Err(e) => e,
        };
        assert!(err.contains("Invalid port number"));
    }

    /// Test context exposing a PostgreSQL test database to the shared `testutils` scenarios.
    struct PostgresTestContext {
        db: PostgresTestDb,
    }

    #[async_trait::async_trait]
    impl testutils::TestContext for PostgresTestContext {
        fn db(&self) -> &(dyn Db + Send + Sync) {
            &self.db
        }
    }

    /// Builds a test context backed by a fresh, suffixed schema in the database configured by
    /// the `POSTGRES_TEST_*` environment variables.
    fn setup() -> Box<dyn testutils::TestContext> {
        // Several tests may race to install the logger; only the first attempt succeeds.
        let _can_fail = env_logger::builder().is_test(true).try_init();

        #[tokio::main]
        async fn prepare() -> PostgresTestDb {
            let opts = ConnectionOptions::from_env("POSTGRES_TEST").unwrap();
            PostgresTestDb::setup_test(opts).await
        }

        Box::new(PostgresTestContext { db: prepare() })
    }

    #[test]
    #[ignore = "Requires environment configuration and is expensive"]
    fn test_postgresdb_log_entries_none() {
        testutils::test_log_entries_none(setup());
    }

    #[test]
    #[ignore = "Requires environment configuration and is expensive"]
    fn test_postgresdb_log_entries_individual() {
        testutils::test_log_entries_individual(setup());
    }

    #[test]
    #[ignore = "Requires environment configuration and is expensive"]
    fn test_postgresdb_log_entries_combined() {
        testutils::test_log_entries_combined(setup());
    }

    #[test]
    #[ignore = "Requires environment configuration and is expensive"]
    fn test_postgresdb_log_entries_long_strings() {
        testutils::test_log_entries_long_strings(setup());
    }
}