use crate::error::{ContextualError, InnerError};
use crate::script_discovery::ReadFrom;
use crate::sqltext::sql_statements_with_line_no;
use crate::tracing::TxLockTracer;
use postgres::{Client, NoTls, Transaction};
use regex::Regex;
use std::collections::HashMap;
use tracing::trace_transaction;
pub mod hint_data;
pub mod hints;
pub mod lints;
pub mod output;
pub mod pg_types;
pub mod pgpass;
pub mod sqltext;
pub mod tracing;
pub mod script_discovery;
pub(crate) mod comments;
#[cfg(test)]
mod render_doc_snapshots;
pub mod parse_scripts;
pub mod tempserver;
pub mod error;
pub mod git;
pub mod utils {
    use std::path::Path;

    /// Flushes a directory to stable storage.
    ///
    /// Implemented for everything that can be viewed as a [`Path`], so it can be
    /// called directly on `PathBuf`, `String`, and similar path-like values.
    pub trait FsyncDir {
        /// Opens the path and calls `sync_all` on the resulting handle.
        ///
        /// # Errors
        ///
        /// Returns the I/O error from opening the path or from syncing it.
        /// On Windows this is a no-op that always returns `Ok(())`.
        fn fsync(&self) -> std::io::Result<()>;
    }

    impl<P: AsRef<Path>> FsyncDir for P {
        #[cfg(not(windows))]
        fn fsync(&self) -> std::io::Result<()> {
            let location: &Path = self.as_ref();
            std::fs::File::open(location)?.sync_all()
        }

        #[cfg(windows)]
        fn fsync(&self) -> std::io::Result<()> {
            Ok(())
        }
    }
}
/// A SQL script paired with the name it is identified by.
pub struct SqlScript {
/// Name of the script; `read_script` fills it from `ReadFrom::name()`.
pub name: String,
/// The SQL text; `read_script` stores the text after placeholder resolution.
pub sql: String,
}
/// Crate-wide result alias whose error type is [`error::Error`].
pub type Result<T> = std::result::Result<T, error::Error>;
/// Loads a script from `read_from` and substitutes `placeholders` into its text.
///
/// The returned [`SqlScript`] carries the name reported by `read_from` and the
/// SQL text after placeholder resolution.
///
/// # Errors
///
/// Propagates any error from reading the source or from
/// `sqltext::resolve_placeholders`.
pub fn read_script(read_from: &ReadFrom, placeholders: &HashMap<&str, &str>) -> Result<SqlScript> {
    let raw_text = read_from.read()?;
    let resolved = sqltext::resolve_placeholders(&raw_text, placeholders)?;
    let name = read_from.name().to_string();
    Ok(SqlScript {
        name,
        sql: resolved,
    })
}
/// Connection settings for a PostgreSQL server plus a cached client.
///
/// The client is not created by `new`; the `WithClient` implementation
/// connects on first use and keeps the connection in `client` afterwards.
pub struct ClientSource {
// Role to connect as.
user: String,
// Database name, emitted as `dbname` in the connection string.
database: String,
// Server host.
host: String,
// Server port.
port: u16,
// Password, kept in memory as plain text and embedded in the connection string.
password: String,
// `None` until the first `with_client` call has connected successfully.
client: Option<Client>,
}
impl ClientSource {
    /// Renders the settings as a key/value connection string of the form
    /// `host=… user=… dbname=… port=… password=…`.
    ///
    /// Values that are empty or contain whitespace, `'` or `\` are wrapped in
    /// single quotes with `'` and `\` backslash-escaped, so that for example a
    /// password containing a space still yields a parseable string. All other
    /// values are emitted verbatim.
    ///
    /// The result contains the password in clear text; avoid logging it.
    pub fn connection_string(&self) -> String {
        format!(
            "host={} user={} dbname={} port={} password={}",
            Self::conninfo_value(&self.host),
            Self::conninfo_value(&self.user),
            Self::conninfo_value(&self.database),
            self.port,
            Self::conninfo_value(&self.password)
        )
    }

    /// Creates a source from the given settings without connecting.
    pub fn new(user: String, database: String, host: String, port: u16, password: String) -> Self {
        ClientSource {
            user,
            database,
            host,
            port,
            password,
            client: None,
        }
    }

    /// Returns `raw` in a form that is safe to place after `key=` in a
    /// key/value connection string, borrowing when no quoting is required.
    fn conninfo_value(raw: &str) -> std::borrow::Cow<'_, str> {
        let needs_quoting = raw.is_empty()
            || raw
                .chars()
                .any(|c| c.is_whitespace() || c == '\'' || c == '\\');
        if !needs_quoting {
            return std::borrow::Cow::Borrowed(raw);
        }

        // Two extra bytes for the surrounding quotes; escapes may grow it further.
        let mut quoted = String::with_capacity(raw.len() + 2);
        quoted.push('\'');
        for c in raw.chars() {
            if c == '\'' || c == '\\' {
                quoted.push('\\');
            }
            quoted.push(c);
        }
        quoted.push('\'');
        std::borrow::Cow::Owned(quoted)
    }
}
/// Something that can lend out a PostgreSQL [`Client`] to a closure.
pub trait WithClient {
    /// Runs `f` with mutable access to a client and returns its result.
    fn with_client<T>(&mut self, f: impl FnOnce(&mut Client) -> Result<T>) -> Result<T>;

    /// Runs `f` inside a transaction obtained through [`WithClient::with_client`].
    ///
    /// When `f` succeeds the transaction is committed if `commit` is true and
    /// rolled back otherwise; `RESET ALL` is then executed on the client before
    /// the value produced by `f` is returned.
    ///
    /// # Errors
    ///
    /// Returns the first error from starting the transaction, from `f`, from
    /// the commit/rollback, or from `RESET ALL`. If `f` fails, neither the
    /// explicit commit/rollback nor `RESET ALL` is executed; the transaction
    /// value is simply dropped.
    fn in_transaction<T>(
        &mut self,
        commit: bool,
        f: impl FnOnce(&mut Transaction) -> Result<T>,
    ) -> Result<T> {
        self.with_client(|client| {
            let mut transaction = client.transaction()?;
            let outcome = f(&mut transaction)?;
            let finished = if commit {
                transaction.commit()
            } else {
                transaction.rollback()
            };
            finished?;
            client.execute("RESET ALL", &[])?;
            Ok(outcome)
        })
    }
}
impl WithClient for ClientSource {
    /// Runs `f` against the cached client, connecting first if there is none.
    ///
    /// The connection is opened without TLS using
    /// [`ClientSource::connection_string`] and kept for later calls.
    ///
    /// # Errors
    ///
    /// Returns the connection error if connecting fails (nothing is cached in
    /// that case), otherwise whatever `f` returns.
    fn with_client<T>(&mut self, f: impl FnOnce(&mut Client) -> Result<T>) -> Result<T> {
        if self.client.is_none() {
            let connection = Client::connect(self.connection_string().as_str(), NoTls)?;
            self.client = Some(connection);
        }
        // The branch above guarantees a client is present at this point.
        let client = self.client.as_mut().unwrap();
        f(client)
    }
}
/// Parses `name=value` strings into a map from placeholder name to value.
///
/// Each entry is split at its first `=`, so the value may itself contain `=`.
/// If a name occurs more than once, the last entry wins. The returned map
/// borrows from `placeholders`.
///
/// # Errors
///
/// Returns a `PlaceholderSyntaxError` with context naming the first entry that
/// contains no `=`.
pub fn parse_placeholders(placeholders: &[String]) -> Result<HashMap<&str, &str>> {
    placeholders
        .iter()
        .map(|entry| {
            entry.split_once('=').ok_or_else(|| {
                InnerError::PlaceholderSyntaxError.with_context(format!(
                    "Placeholder '{}' must be in the form name=value",
                    entry
                ))
            })
        })
        .collect()
}
/// Executes `script` against the database and returns a tracer describing it.
///
/// The script is split into statements with their line numbers. Two paths
/// exist:
///
/// * If `commit` is true and every statement satisfies
///   `sqltext::is_concurrently`, the statements are executed one by one
///   directly on the client, without the transaction wrapper, and the tracer
///   is built by `TxLockTracer::tracer_for_concurrently`. (Presumably because
///   such statements cannot run inside a transaction block — confirm against
///   `sqltext::is_concurrently`.) Statements matching any regex in `skip` are
///   not executed on this path.
/// * Otherwise the statements are handed to `trace_transaction` inside
///   `in_transaction`, which commits or rolls back according to `commit`.
///
/// # Errors
///
/// Propagates errors from splitting the script, from executing statements,
/// and from the transaction handling.
pub fn perform_trace<'a, T: WithClient>(
    script: &SqlScript,
    connection_settings: &mut T,
    ignored_hints: &'a [&'a str],
    commit: bool,
    skip: &[Regex],
    is_final: bool,
) -> Result<TxLockTracer<'a>> {
    let sql_statements = sql_statements_with_line_no(script.sql.as_str())?;
    let all_concurrently = sql_statements
        .iter()
        .all(|(_, sql)| sqltext::is_concurrently(sql));

    if !(all_concurrently && commit) {
        return connection_settings.in_transaction(commit, |tx| {
            trace_transaction(
                Some(script.name.clone()),
                tx,
                sql_statements.iter().copied(),
                ignored_hints,
                skip,
                is_final,
            )
        });
    }

    connection_settings.with_client(|client| {
        for (_, statement) in sql_statements.iter() {
            if skip.iter().any(|pattern| pattern.is_match(statement)) {
                continue;
            }
            client.execute(*statement, &[])?;
        }
        Ok(())
    })?;

    Ok(TxLockTracer::tracer_for_concurrently(
        Some(script.name.clone()),
        sql_statements.iter().copied(),
        ignored_hints,
    ))
}
/// Creates a uniquely named database from the `test_db` template and returns
/// its name.
///
/// Connects to the `postgres` database on localhost, records the new name in
/// the `test_dbs` bookkeeping table, and drops every recorded database whose
/// entry is older than 15 minutes before creating the new one.
///
/// # Panics
///
/// Panics if connecting fails or if any statement other than the initial
/// `CREATE TABLE IF NOT EXISTS` fails.
#[cfg(test)]
pub fn generate_new_test_db() -> String {
    let mut admin = Client::connect(
        "host=localhost dbname=postgres password=postgres user=postgres",
        NoTls,
    )
    .unwrap();

    // Best effort: a failure to create the bookkeeping table is ignored here.
    let _ = admin.execute(
        "CREATE TABLE IF NOT EXISTS test_dbs(\
         name text PRIMARY KEY, time timestamptz default now());",
        &[],
    );

    let unique_suffix = uuid::Uuid::new_v4().to_string().replace('-', "_");
    let fresh_name = format!("eugene_testdb_{}", unique_suffix);
    admin
        .execute(
            "INSERT INTO test_dbs(name) VALUES($1);",
            &[&fresh_name.as_str()],
        )
        .unwrap();

    let expired = admin
        .query(
            "SELECT name FROM test_dbs WHERE time < now() - interval '15 minutes';",
            &[],
        )
        .unwrap();
    for row in expired {
        let stale_name: String = row.get(0);
        let drop_statement = format!("DROP DATABASE IF EXISTS {}", stale_name);
        admin.execute(&drop_statement, &[]).unwrap();
        admin
            .execute(
                "DELETE FROM test_dbs WHERE name = $1;",
                &[&stale_name.as_str()],
            )
            .unwrap();
    }

    let create_statement = format!("CREATE DATABASE {} TEMPLATE test_db", fresh_name);
    admin.execute(&create_statement, &[]).unwrap();
    fresh_name
}
#[cfg(test)]
mod tests {
// Both tests lint a two-statement script with `lints::anon_lint` and check the
// `line_number` reported for each statement. The script literals are
// line-sensitive: every newline inside them is part of the test input.
#[test]
fn lint_line_numbers_should_make_sense_ex2() {
let script = "ALTER TABLE foo ADD a text;
-- A comment
CREATE UNIQUE INDEX my_index ON foo (a);";
let report = super::lints::anon_lint(script).unwrap();
assert_eq!(report.statements[0].line_number, 1);
// NOTE(review): expecting line 5 presumes blank lines between the first
// statement and the comment inside the literal above — confirm they are
// present in the file.
assert_eq!(report.statements[1].line_number, 5);
}
#[test]
fn lint_line_numbers_should_make_sense_ex1() {
let script = "ALTER TABLE
foo
ADD
a text;
CREATE UNIQUE INDEX
my_index ON foo (a);";
let report = super::lints::anon_lint(script).unwrap();
assert_eq!(report.statements[0].line_number, 1);
// NOTE(review): expecting line 6 presumes a blank line between the two
// statements inside the literal above — confirm it is present in the file.
assert_eq!(report.statements[1].line_number, 6);
}
}