use std::{
collections::HashMap,
fmt::Display,
fs::File,
io::{BufRead, BufReader, Write},
path::Path,
};
use crate::{
config::Config,
error::Result,
interceptor::{InterceptorRef, Registry},
Database, SqlnessError,
};
const COMMENT_PREFIX: &str = "--";
const QUERY_DELIMITER: char = ';';
/// A test case loaded from a single file: a display name plus the queries
/// parsed from that file, in file order.
pub(crate) struct TestCase {
/// Path of the source file rendered as a string; this is what `Display` prints.
name: String,
/// Queries in the order they were read from the file.
queries: Vec<Query>,
}
impl TestCase {
    /// Parses the file at `path` into a [`TestCase`].
    ///
    /// The file is processed line by line:
    /// - a line starting with `--` is recorded as a comment of the query currently being
    ///   built; if it also starts with `cfg.interceptor_prefix`, an interceptor is created
    ///   from it as well;
    /// - an empty line is skipped;
    /// - any other line is appended to the current query, and a line ending with `;`
    ///   finishes that query and starts a new one.
    ///
    /// Anything collected after the last `;`-terminated line is not turned into a query.
    ///
    /// # Errors
    ///
    /// Returns [`SqlnessError::ReadPath`] when the file cannot be opened, and propagates
    /// errors from reading a line or from creating an interceptor.
    pub(crate) fn from_file<P: AsRef<Path>>(path: P, cfg: &Config) -> Result<Self> {
        let path = path.as_ref();
        let file = File::open(path).map_err(|e| SqlnessError::ReadPath {
            source: e,
            path: path.to_path_buf(),
        })?;

        let mut queries = vec![];
        let mut query = Query::with_interceptor_factories(cfg.interceptor_registry.clone());

        let reader = BufReader::new(file);
        for line in reader.lines() {
            let line = line?;

            if line.starts_with(COMMENT_PREFIX) {
                // Interceptor lines are comments too: they are kept in the comment list and
                // additionally registered as an interceptor. Only that case needs a copy.
                if line.starts_with(&cfg.interceptor_prefix) {
                    query.push_comment(line.clone());
                    query.push_interceptor(&cfg.interceptor_prefix, line)?;
                } else {
                    query.push_comment(line);
                }
                continue;
            }

            if line.is_empty() {
                continue;
            }

            query.append_query_line(&line);

            if line.ends_with(QUERY_DELIMITER) {
                // The statement is complete; start collecting the next one.
                queries.push(query);
                query = Query::with_interceptor_factories(cfg.interceptor_registry.clone());
            } else {
                // `lines()` strips the line terminator; put it back for multi-line statements.
                query.append_query_line("\n");
            }
        }

        Ok(Self {
            // Lossy conversion: a non-UTF-8 path must not panic just to build a display name.
            name: path.to_string_lossy().into_owned(),
            queries,
        })
    }

    /// Executes every query of this case in order against `db`, writing the output of
    /// each to `writer`.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by a query.
    pub(crate) async fn execute<W>(&mut self, db: &dyn Database, writer: &mut W) -> Result<()>
    where
        W: Write,
    {
        for query in &mut self.queries {
            query.execute(db, writer).await?;
        }
        Ok(())
    }
}
impl Display for TestCase {
    /// Writes the test case's name verbatim; formatter flags such as width are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { name, .. } = self;
        f.write_str(name)
    }
}
/// Context attached to the execution of a query.
///
/// In this file an empty context is created for each query execution, handed
/// mutably to every interceptor before execution, and then a clone of it is
/// passed to `Database::query` together with each statement.
#[derive(Default, Debug, Clone)]
pub struct QueryContext {
/// String key/value pairs. The meaning of individual keys is not defined in
/// this file.
pub context: HashMap<String, String>,
}
/// One query read from a test file, together with the comments and
/// interceptors collected for it.
#[derive(Default)]
struct Query {
/// Comment lines collected for this query; written to the output first,
/// one per line.
comment_lines: Vec<String>,
/// Query text fragments as read from the file; written to the output as-is.
display_query: Vec<String>,
/// Query text fragments that are actually executed. Starts out equal to
/// `display_query`; interceptors receive it mutably before execution.
execute_query: Vec<String>,
/// Registry used to create interceptors from interceptor lines.
interceptor_registry: Registry,
/// Interceptors attached to this query, in the order they were added.
interceptors: Vec<InterceptorRef>,
}
impl Query {
    /// Creates an empty query whose interceptors will be built from `interceptor_registry`.
    pub fn with_interceptor_factories(interceptor_registry: Registry) -> Self {
        Self {
            interceptor_registry,
            ..Default::default()
        }
    }

    /// Creates an interceptor from the text following the first occurrence of
    /// `interceptor_prefix` in `interceptor_line` and attaches it to this query.
    ///
    /// # Errors
    ///
    /// Returns [`SqlnessError::MissingPrefix`] when the line does not contain the prefix,
    /// and propagates any error from the registry.
    fn push_interceptor(
        &mut self,
        interceptor_prefix: &str,
        interceptor_line: String,
    ) -> Result<()> {
        match interceptor_line.split_once(interceptor_prefix) {
            Some((_, remaining)) => {
                let interceptor = self.interceptor_registry.create(remaining)?;
                self.interceptors.push(interceptor);
                Ok(())
            }
            None => Err(SqlnessError::MissingPrefix {
                line: interceptor_line,
            }),
        }
    }

    /// Records a comment line to be echoed before the query in the output.
    fn push_comment(&mut self, comment_line: String) {
        self.comment_lines.push(comment_line);
    }

    /// Appends `line` to both the displayed and the executed query text.
    fn append_query_line(&mut self, line: &str) {
        self.display_query.push(line.to_string());
        self.execute_query.push(line.to_string());
    }

    /// Runs this query against `db` and writes the outcome to `writer`.
    ///
    /// Output layout: every comment line followed by a newline, then the display query
    /// fragments back to back, then a blank line. The executed text is then split on the
    /// template delimiter; every non-blank statement is sent to the database (with `;`
    /// appended when missing), its result is passed through the interceptors and written
    /// followed by a blank line.
    ///
    /// # Errors
    ///
    /// Returns an error when writing to `writer` fails.
    async fn execute<W>(&mut self, db: &dyn Database, writer: &mut W) -> Result<()>
    where
        W: Write,
    {
        // Interceptors may rewrite `execute_query`; `display_query` stays as read from file.
        let context = self.before_execute_intercept().await;

        for comment in &self.comment_lines {
            writer.write_all(comment.as_bytes())?;
            writer.write_all(b"\n")?;
        }
        for line in &self.display_query {
            writer.write_all(line.as_bytes())?;
        }
        writer.write_all(b"\n\n")?;

        let sql = self.concat_query_lines();
        for sql in sql.split(crate::interceptor::template::DELIMITER) {
            if !sql.trim().is_empty() {
                let sql = if sql.ends_with(QUERY_DELIMITER) {
                    sql.to_string()
                } else {
                    format!("{sql};")
                };
                let mut result = db.query(context.clone(), sql).await.to_string();
                self.after_execute_intercept(&mut result).await;
                self.write_result(writer, result)?;
            }
        }

        Ok(())
    }

    /// Runs every interceptor's pre-execution hook in order and returns the context
    /// they populated.
    async fn before_execute_intercept(&mut self) -> QueryContext {
        let mut context = QueryContext::default();
        for interceptor in &self.interceptors {
            interceptor
                .before_execute_async(&mut self.execute_query, &mut context)
                .await;
        }
        context
    }

    /// Runs every interceptor's post-execution hook in order on `result`.
    async fn after_execute_intercept(&mut self, result: &mut String) {
        for interceptor in &self.interceptors {
            interceptor.after_execute_async(result).await;
        }
    }

    /// Joins the executed query fragments into one string with leading whitespace removed.
    fn concat_query_lines(&self) -> String {
        self.execute_query.concat().trim_start().to_string()
    }

    /// Writes `result` followed by a blank line.
    ///
    /// # Errors
    ///
    /// Returns an error when writing to `writer` fails.
    fn write_result<W>(&self, writer: &mut W, result: String) -> Result<()>
    where
        W: Write,
    {
        writer.write_all(result.as_bytes())?;
        // `write_all` rather than `write`: a short write must not silently drop the separator.
        writer.write_all(b"\n\n")?;
        Ok(())
    }
}