use std::fmt::Display;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use std::vec;
use async_trait::async_trait;
use difference::Difference;
use futures::executor::block_on;
use futures::{stream, Future, StreamExt};
use itertools::Itertools;
use owo_colors::OwoColorize;
use regex::Regex;
use tempfile::{tempdir, TempDir};
use crate::parser::*;
/// The declared type of one result column, as written in a `query` record's type string.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[non_exhaustive]
pub enum ColumnType {
    /// Rendered as `T`.
    Text,
    /// Rendered as `I`.
    Integer,
    /// Rendered as `R`.
    FloatingPoint,
    /// Rendered as `?`.
    Any,
    /// Any other type character, kept verbatim so it renders back unchanged.
    Unknown(char),
}

impl Display for ColumnType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each column type is rendered as its single type-string character.
        let symbol = match *self {
            ColumnType::Text => 'T',
            ColumnType::Integer => 'I',
            ColumnType::FloatingPoint => 'R',
            ColumnType::Any => '?',
            ColumnType::Unknown(c) => c,
        };
        write!(f, "{}", symbol)
    }
}
impl TryFrom<char> for ColumnType {
    type Error = ParseErrorKind;

    // Never actually fails: characters without a dedicated variant become `Unknown`.
    fn try_from(c: char) -> Result<Self, Self::Error> {
        let column_type = match c {
            'T' => Self::Text,
            'I' => Self::Integer,
            'R' => Self::FloatingPoint,
            '?' => Self::Any,
            other => Self::Unknown(other),
        };
        Ok(column_type)
    }
}
/// The raw outcome of applying one record (see `Runner::apply_record`), before it is
/// compared against the record's expectations.
#[derive(Debug, Clone)]
pub enum RecordOutput {
    /// The record did not touch the database (skipped, comment, control record, ...).
    Nothing,
    /// The database returned rows, or a query failed with `error`.
    Query {
        types: Vec<ColumnType>,
        rows: Vec<Vec<String>>,
        error: Option<Arc<dyn std::error::Error + Send + Sync>>,
    },
    /// The database reported a completed statement, or a statement failed with `error`.
    Statement {
        count: u64,
        error: Option<Arc<dyn std::error::Error + Send + Sync>>,
    },
}

/// What a database returns for one successfully executed SQL string.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum DBOutput {
    /// A result set: column types and the rows, each row a list of rendered values.
    Rows {
        types: Vec<ColumnType>,
        rows: Vec<Vec<String>>,
    },
    /// A statement that completed, with the number of rows it affected.
    StatementComplete(u64),
}
#[async_trait]
pub trait AsyncDB: Send {
type Error: std::error::Error + Send + Sync + 'static;
async fn run(&mut self, sql: &str) -> Result<DBOutput, Self::Error>;
fn engine_name(&self) -> &str {
""
}
async fn sleep(dur: Duration) {
std::thread::sleep(dur);
}
}
/// The synchronous counterpart of `AsyncDB`.
///
/// Every `DB` also implements `AsyncDB` through a blanket implementation in this module. This
/// lets a blocking driver be used with `Runner` directly.
pub trait DB: Send {
    /// The error type returned by [`DB::run`].
    type Error: std::error::Error + Send + Sync + 'static;
    /// Executes one SQL string, blocking until it completes.
    fn run(&mut self, sql: &str) -> Result<DBOutput, Self::Error>;
    /// Engine name passed to record conditions' `should_skip` check; defaults to `""`.
    fn engine_name(&self) -> &str {
        ""
    }
}
#[async_trait]
impl<D> AsyncDB for D
where
D: DB,
{
type Error = <D as DB>::Error;
async fn run(&mut self, sql: &str) -> Result<DBOutput, Self::Error> {
<D as DB>::run(self, sql)
}
fn engine_name(&self) -> &str {
<D as DB>::engine_name(self)
}
}
/// A test failure: what went wrong and where in the test script it was reported.
#[derive(thiserror::Error, Clone)]
#[error("{kind}\nat {loc}\n")]
pub struct TestError {
    /// What went wrong.
    kind: TestErrorKind,
    /// Where it was reported: the failing record's location, or the parse error's location.
    loc: Location,
}
impl TestError {
    /// Returns a formatter for this error, optionally using terminal colors.
    pub fn display(&self, colorize: bool) -> TestErrorDisplay<'_> {
        TestErrorDisplay {
            colorize,
            err: self,
        }
    }
}

/// Formats a [`TestError`]. With `colorize` set, highlighted parts use terminal colors;
/// otherwise the output is the same as the error's plain `Display`.
pub struct TestErrorDisplay<'a> {
    err: &'a TestError,
    colorize: bool,
}

impl Display for TestErrorDisplay<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let TestErrorDisplay { err, colorize } = self;
        writeln!(f, "{}", err.kind.display(*colorize))?;
        writeln!(f, "at {}", err.loc)
    }
}
/// The combined failures of a parallel run: one [`TestError`] for each test file that failed.
#[derive(Clone, Debug, thiserror::Error)]
pub struct ParallelTestError {
    errors: Vec<TestError>,
}

impl Display for ParallelTestError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The uncolored rendering is the plain `Display` output; keep a single implementation.
        write!(f, "{}", self.display(false))
    }
}

/// Formats a [`ParallelTestError`], optionally with terminal colors.
pub struct ParallelTestErrorDisplay<'a> {
    err: &'a ParallelTestError,
    colorize: bool,
}

impl<'a> Display for ParallelTestErrorDisplay<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "parallel test failed")?;
        // End the header line so the first error starts on its own line, like the others.
        writeln!(f, "Caused by:")?;
        for error in &self.err.errors {
            writeln!(f, "{}", error.display(self.colorize))?;
        }
        Ok(())
    }
}

impl ParallelTestError {
    /// Returns a formatter for these errors, optionally using terminal colors.
    pub fn display(&self, colorize: bool) -> ParallelTestErrorDisplay<'_> {
        ParallelTestErrorDisplay {
            err: self,
            colorize,
        }
    }
}
impl std::fmt::Debug for TestError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self)
}
}
impl TestError {
pub fn kind(&self) -> TestErrorKind {
self.kind.clone()
}
pub fn location(&self) -> Location {
self.loc.clone()
}
}
/// Which kind of record an error refers to.
#[derive(Debug, Clone)]
pub enum RecordKind {
    Statement,
    Query,
}

impl std::fmt::Display for RecordKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Rendered as the record keyword used in test scripts.
        let keyword = match self {
            Self::Statement => "statement",
            Self::Query => "query",
        };
        f.write_str(keyword)
    }
}
/// The ways a sqllogictest record can fail.
#[derive(thiserror::Error, Debug, Clone)]
pub enum TestErrorKind {
    /// The test script could not be parsed.
    #[error("parse error: {0}")]
    ParseError(ParseErrorKind),
    /// A record expected to fail completed successfully.
    #[error("{kind} is expected to fail, but actually succeed:\n[SQL] {sql}")]
    Ok { sql: String, kind: RecordKind },
    /// A record failed although no error was expected.
    #[error("{kind} failed: {err}\n[SQL] {sql}")]
    Fail {
        sql: String,
        err: Arc<dyn std::error::Error + Send + Sync>,
        kind: RecordKind,
    },
    /// A record failed, but its error message did not match the expected pattern.
    #[error("{kind} is expected to fail with error:\n\t{expected_err}\nbut got error:\n\t{err}\n[SQL] {sql}")]
    ErrorMismatch {
        sql: String,
        err: Arc<dyn std::error::Error + Send + Sync>,
        expected_err: String,
        kind: RecordKind,
    },
    /// A statement's affected-row count differed from the expected count.
    #[error("statement is expected to affect {expected} rows, but actually {actual}\n[SQL] {sql}")]
    StatementResultMismatch {
        sql: String,
        expected: u64,
        actual: String,
    },
    /// Query output differed from the expected results. The message includes a line diff
    /// rendered by `format_diff`.
    #[error(
        "query result mismatch:\n[SQL] {sql}\n[Diff] (-expected|+actual)\n{}",
        difference::Changeset::new(.expected, .actual, "\n").diffs.iter().format_with("\n", |diff, f| format_diff(diff, f, false))
    )]
    QueryResultMismatch {
        sql: String,
        expected: String,
        actual: String,
    },
    /// The expected results declare a different number of columns than the query returned.
    /// NOTE(review): not constructed anywhere in the visible runner code.
    #[error("expected results are invalid: expected {expected} columns, got {actual} columns\n[SQL] {sql}")]
    QueryResultColumnCountMismatch {
        sql: String,
        expected: usize,
        actual: usize,
    },
}
impl From<ParseError> for TestError {
    // A parse failure is reported at the location where parsing stopped.
    fn from(e: ParseError) -> Self {
        let kind = TestErrorKind::ParseError(e.kind());
        kind.at(e.location())
    }
}
impl TestErrorKind {
    /// Attaches a source location, turning this kind into a full [`TestError`].
    fn at(self, loc: Location) -> TestError {
        TestError { loc, kind: self }
    }

    /// Returns a formatter for this error kind, optionally using terminal colors.
    pub fn display(&self, colorize: bool) -> TestErrorKindDisplay<'_> {
        TestErrorKindDisplay {
            colorize,
            error: self,
        }
    }
}
/// Formats a [`TestErrorKind`].
///
/// With `colorize` set, error mismatches and result diffs are highlighted with terminal colors.
/// Otherwise, and for all other kinds, the output is the plain `Display` text.
pub struct TestErrorKindDisplay<'a> {
    error: &'a TestErrorKind,
    colorize: bool,
}

impl Display for TestErrorKindDisplay<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match (self.colorize, self.error) {
            (
                true,
                TestErrorKind::ErrorMismatch {
                    sql,
                    err,
                    expected_err,
                    kind,
                },
            ) => write!(
                f,
                "{kind} is expected to fail with error:\n\t{}\nbut got error:\n\t{}\n[SQL] {sql}",
                expected_err.bright_green(),
                err.bright_red(),
            ),
            (
                true,
                TestErrorKind::QueryResultMismatch {
                    sql,
                    expected,
                    actual,
                },
            ) => {
                let changeset = difference::Changeset::new(expected, actual, "\n");
                write!(
                    f,
                    "query result mismatch:\n[SQL] {sql}\n[Diff] ({}|{})\n{}",
                    "-expected".bright_red(),
                    "+actual".bright_green(),
                    changeset
                        .diffs
                        .iter()
                        .format_with("\n", |hunk, emit| format_diff(hunk, emit, true))
                )
            }
            // Uncolored output, and every other kind, is the plain `Display` text.
            _ => write!(f, "{}", self.error),
        }
    }
}
/// Renders one diff hunk line by line.
///
/// Unchanged lines are prefixed with a space, additions with `+ ` and removals with `- `. When
/// `colorize` is set, additions are bright green and removals bright red.
fn format_diff(
    diff: &Difference,
    f: &mut dyn FnMut(&dyn std::fmt::Display) -> std::fmt::Result,
    colorize: bool,
) -> std::fmt::Result {
    let (text, marker) = match diff {
        Difference::Same(text) => (text, " "),
        Difference::Add(text) => (text, "+ "),
        Difference::Rem(text) => (text, "- "),
    };
    f(&text.lines().format_with("\n", |line, emit| match diff {
        Difference::Add(_) if colorize => emit(&format_args!("{marker}{line}").bright_green()),
        Difference::Rem(_) if colorize => emit(&format_args!("{marker}{line}").bright_red()),
        _ => emit(&format_args!("{marker}{line}")),
    }))
}
/// Decides whether actual query rows match the expected result lines.
///
/// `actual` holds one list of column values per returned row. `expected` holds the expected
/// lines from the test file.
pub type Validator = fn(actual: &[Vec<String>], expected: &[String]) -> bool;

/// The default [`Validator`].
///
/// Each row's values are joined with single spaces and compared with the expected line at the
/// same position. Leading and trailing whitespace is trimmed and runs of ASCII whitespace are
/// collapsed on both sides first.
pub fn default_validator(actual: &[Vec<String>], expected: &[String]) -> bool {
    // Trim, then collapse internal ASCII whitespace runs into single spaces.
    fn squash(text: &str) -> String {
        text.trim()
            .split_ascii_whitespace()
            .collect::<Vec<_>>()
            .join(" ")
    }

    actual.len() == expected.len()
        && actual.iter().zip(expected).all(|(row, line)| {
            let joined = row
                .iter()
                .map(|value| squash(value))
                .collect::<Vec<_>>()
                .join(" ");
            joined == squash(line)
        })
}
/// Executes sqllogictest records against a database and checks the results.
pub struct Runner<D: AsyncDB> {
    /// The connection every record is executed against.
    db: D,
    /// Decides whether actual query rows match a record's expected results.
    validator: Validator,
    /// Once enabled, the temporary directory whose path replaces `__TEST_DIR__` in SQL.
    testdir: Option<TempDir>,
    /// Sort mode set by a `Control::SortMode` record; used for queries that set none themselves.
    sort_mode: Option<SortMode>,
    /// Queries returning more values than this are reported as an MD5 hash; `0` disables hashing.
    hash_threshold: usize,
}
impl<D: AsyncDB> Runner<D> {
pub fn new(db: D) -> Self {
Runner {
db,
validator: default_validator,
testdir: None,
sort_mode: None,
hash_threshold: 0,
}
}
/// Creates a fresh temporary directory whose path replaces `__TEST_DIR__` in SQL.
///
/// # Panics
///
/// Panics if the temporary directory cannot be created.
pub fn enable_testdir(&mut self) {
    let dir = tempdir().expect("failed to create testdir");
    self.testdir = Some(dir);
}

/// Replaces the function used to compare query rows with expected results.
pub fn with_validator(&mut self, validator: Validator) {
    self.validator = validator;
}

/// Sets the number of result values above which query output is compared by hash.
pub fn with_hash_threshold(&mut self, hash_threshold: usize) {
    self.hash_threshold = hash_threshold;
}
pub async fn apply_record(&mut self, record: Record) -> RecordOutput {
match record {
Record::Statement { conditions, .. } if self.should_skip(&conditions) => {
RecordOutput::Nothing
}
Record::Statement {
conditions: _,
sql,
expected_error: _,
expected_count: _,
loc: _,
} => {
let sql = self.replace_keywords(sql);
let ret = self.db.run(&sql).await;
match ret {
Ok(out) => match out {
DBOutput::Rows { types, rows } => RecordOutput::Query {
types,
rows,
error: None,
},
DBOutput::StatementComplete(count) => {
RecordOutput::Statement { count, error: None }
}
},
Err(e) => RecordOutput::Statement {
count: 0,
error: Some(Arc::new(e)),
},
}
}
Record::Query { conditions, .. } if self.should_skip(&conditions) => {
RecordOutput::Nothing
}
Record::Query {
conditions: _,
sql,
sort_mode,
type_string: _,
expected_error: _,
expected_results: _,
loc: _,
label: _,
} => {
let sql = self.replace_keywords(sql);
let (types, mut rows) = match self.db.run(&sql).await {
Ok(out) => match out {
DBOutput::Rows { types, rows } => (types, rows),
DBOutput::StatementComplete(count) => {
return RecordOutput::Statement { count, error: None }
}
},
Err(e) => {
return RecordOutput::Query {
error: Some(Arc::new(e)),
types: vec![],
rows: vec![],
}
}
};
match sort_mode.as_ref().or(self.sort_mode.as_ref()) {
None | Some(SortMode::NoSort) => {}
Some(SortMode::RowSort) => {
rows.sort_unstable();
}
Some(SortMode::ValueSort) => todo!("value sort"),
};
if self.hash_threshold > 0 && rows.len() * types.len() > self.hash_threshold {
let mut md5 = md5::Context::new();
for line in &rows {
for value in line {
md5.consume(value.as_bytes());
md5.consume(b"\n");
}
}
let hash = md5.compute();
rows = vec![vec![format!(
"{} values hashing to {:?}",
rows.len() * rows[0].len(),
hash
)]];
}
RecordOutput::Query {
error: None,
types,
rows,
}
}
Record::Sleep { duration, .. } => {
D::sleep(duration).await;
RecordOutput::Nothing
}
Record::Control(control) => match control {
Control::SortMode(sort_mode) => {
self.sort_mode = Some(sort_mode);
RecordOutput::Nothing
}
},
Record::HashThreshold { loc: _, threshold } => {
self.hash_threshold = threshold as usize;
RecordOutput::Nothing
}
Record::Include { .. }
| Record::Comment(_)
| Record::Newline
| Record::Subtest { .. }
| Record::Halt { .. }
| Record::Injected(_)
| Record::Condition(_) => RecordOutput::Nothing,
}
}
pub async fn run_async(&mut self, record: Record) -> Result<(), TestError> {
tracing::info!(?record, "testing");
match (record.clone(), self.apply_record(record).await) {
(_, RecordOutput::Nothing) => {}
(Record::Statement { .. }, RecordOutput::Query { error: None, .. }) => {}
(
Record::Query {
expected_results,
loc,
sql,
..
},
RecordOutput::Statement { error: None, .. },
) => {
if !expected_results.is_empty() {
return Err(TestErrorKind::QueryResultMismatch {
sql,
expected: expected_results.join("\n"),
actual: "".to_string(),
}
.at(loc));
}
}
(
Record::Statement {
loc,
conditions: _,
expected_error,
sql,
expected_count,
},
RecordOutput::Statement { count, error },
) => match (error, expected_error) {
(None, Some(_)) => {
return Err(TestErrorKind::Ok {
sql,
kind: RecordKind::Statement,
}
.at(loc))
}
(None, None) => {
if let Some(expected_count) = expected_count {
if expected_count != count {
return Err(TestErrorKind::StatementResultMismatch {
sql,
expected: expected_count,
actual: format!("affected {count} rows"),
}
.at(loc));
}
}
}
(Some(e), Some(expected_error)) => {
if !expected_error.is_match(&e.to_string()) {
return Err(TestErrorKind::ErrorMismatch {
sql,
err: Arc::new(e),
expected_err: expected_error.to_string(),
kind: RecordKind::Statement,
}
.at(loc));
}
}
(Some(e), None) => {
return Err(TestErrorKind::Fail {
sql,
err: Arc::new(e),
kind: RecordKind::Statement,
}
.at(loc));
}
},
(
Record::Query {
loc,
conditions: _,
type_string,
sort_mode: _,
label: _,
expected_error,
sql,
expected_results,
},
RecordOutput::Query { types, rows, error },
) => {
match (error, expected_error) {
(None, Some(_)) => {
return Err(TestErrorKind::Ok {
sql,
kind: RecordKind::Query,
}
.at(loc))
}
(None, None) => {}
(Some(e), Some(expected_error)) => {
if !expected_error.is_match(&e.to_string()) {
return Err(TestErrorKind::ErrorMismatch {
sql,
err: Arc::new(e),
expected_err: expected_error.to_string(),
kind: RecordKind::Query,
}
.at(loc));
}
return Ok(());
}
(Some(e), None) => {
return Err(TestErrorKind::Fail {
sql,
err: Arc::new(e),
kind: RecordKind::Query,
}
.at(loc));
}
};
if types.len() != type_string.len() {
}
for (t_actual, t_expected) in types.iter().zip(type_string.iter()) {
if t_actual != &ColumnType::Any
&& t_expected != &ColumnType::Any
&& t_actual != t_expected
{
}
}
if !(self.validator)(&rows, &expected_results) {
let output_rows = rows
.into_iter()
.map(|strs| strs.iter().join(" "))
.collect_vec();
return Err(TestErrorKind::QueryResultMismatch {
sql,
expected: expected_results.join("\n"),
actual: output_rows.join("\n"),
}
.at(loc));
}
}
_ => unreachable!(),
}
Ok(())
}
pub fn run(&mut self, record: Record) -> Result<(), TestError> {
futures::executor::block_on(self.run_async(record))
}
/// Runs `records` in order, stopping at the first failure or at the first `halt` record.
///
/// # Errors
///
/// Returns the first [`TestError`] produced by [`Runner::run_async`].
pub async fn run_multi_async(
    &mut self,
    records: impl IntoIterator<Item = Record>,
) -> Result<(), TestError> {
    let until_halt = records
        .into_iter()
        .take_while(|record| !matches!(record, Record::Halt { .. }));
    for record in until_halt {
        self.run_async(record).await?;
    }
    Ok(())
}

/// Blocking version of [`Runner::run_multi_async`].
pub fn run_multi(
    &mut self,
    records: impl IntoIterator<Item = Record>,
) -> Result<(), TestError> {
    let pending = self.run_multi_async(records);
    block_on(pending)
}
/// Parses `script` as sqllogictest and runs its records.
///
/// # Errors
///
/// Returns a [`TestErrorKind::ParseError`] if the script cannot be parsed, or the first
/// failure from running its records. A parse error is returned, not panicked on, matching
/// [`Runner::run_file_async`].
pub async fn run_script_async(&mut self, script: &str) -> Result<(), TestError> {
    let records = parse(script)?;
    self.run_multi_async(records).await
}

/// Parses the sqllogictest file at `filename` and runs its records.
///
/// # Errors
///
/// Returns a parse error or the first failure from running its records.
pub async fn run_file_async(&mut self, filename: impl AsRef<Path>) -> Result<(), TestError> {
    let records = parse_file(filename)?;
    self.run_multi_async(records).await
}

/// Blocking version of [`Runner::run_script_async`].
pub fn run_script(&mut self, script: &str) -> Result<(), TestError> {
    block_on(self.run_script_async(script))
}

/// Blocking version of [`Runner::run_file_async`].
pub fn run_file(&mut self, filename: impl AsRef<Path>) -> Result<(), TestError> {
    block_on(self.run_file_async(filename))
}
/// Runs every test file matching `glob` concurrently, with at most `jobs` files in flight.
///
/// For each file, a database named after the file (spaces, dots and dashes replaced with
/// `_`) is first created through this runner's connection. `conn_builder(host, db_name)` then
/// opens a dedicated connection, with hosts assigned round-robin. Each file runs in its own
/// [`Runner`], which inherits this runner's validator and hash threshold, and gets its own test
/// directory if this runner has one enabled.
///
/// # Errors
///
/// Returns a [`ParallelTestError`] collecting the failure of every file that failed.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the glob pattern or one of its entries cannot be read;
/// - a file name is not valid UTF-8;
/// - a database cannot be created;
/// - a file matches but `hosts` is empty.
pub async fn run_parallel_async<Fut>(
    &mut self,
    glob: &str,
    hosts: Vec<String>,
    conn_builder: fn(String, String) -> Fut,
    jobs: usize,
) -> Result<(), ParallelTestError>
where
    Fut: Future<Output = D>,
{
    let files = glob::glob(glob).expect("failed to read glob pattern");
    // Settings configured on this runner are forwarded to every per-file runner.
    let validator = self.validator;
    let hash_threshold = self.hash_threshold;
    let use_testdir = self.testdir.is_some();
    let mut tasks = vec![];
    for (idx, file) in files.enumerate() {
        let file = file.expect("failed to read glob entry");
        let db_name = file
            .file_name()
            .expect("not a valid filename")
            .to_str()
            .expect("not a UTF-8 filename");
        let db_name = db_name.replace([' ', '.', '-'], "_");
        assert!(!hosts.is_empty(), "run_parallel requires at least one host");
        self.db
            .run(&format!("CREATE DATABASE {};", db_name))
            .await
            .expect("create db failed");
        let target = hosts[idx % hosts.len()].clone();
        tasks.push(async move {
            let db = conn_builder(target, db_name).await;
            let mut tester = Runner::new(db);
            tester.with_validator(validator);
            tester.with_hash_threshold(hash_threshold);
            if use_testdir {
                tester.enable_testdir();
            }
            let filename = file.to_string_lossy().to_string();
            tester.run_file_async(filename).await
        })
    }
    // `buffer_unordered(0)` would never poll a single task and hang forever.
    let tasks = stream::iter(tasks).buffer_unordered(jobs.max(1));
    let errors: Vec<_> = tasks
        .filter_map(|result| async { result.err() })
        .collect()
        .await;
    if errors.is_empty() {
        Ok(())
    } else {
        Err(ParallelTestError { errors })
    }
}

/// Blocking version of [`Runner::run_parallel_async`].
pub fn run_parallel<Fut>(
    &mut self,
    glob: &str,
    hosts: Vec<String>,
    conn_builder: fn(String, String) -> Fut,
    jobs: usize,
) -> Result<(), ParallelTestError>
where
    Fut: Future<Output = D>,
{
    block_on(self.run_parallel_async(glob, hosts, conn_builder, jobs))
}
/// Substitutes the test directory's path for `__TEST_DIR__` once a test directory is
/// enabled. Otherwise `sql` is returned untouched.
fn replace_keywords(&self, sql: String) -> String {
    match &self.testdir {
        Some(dir) => sql.replace("__TEST_DIR__", dir.path().to_str().unwrap()),
        None => sql,
    }
}

/// Returns whether any of `conditions` asks to skip the record for this database engine.
fn should_skip(&self, conditions: &[Condition]) -> bool {
    for condition in conditions {
        if condition.should_skip(self.db.engine_name()) {
            return true;
        }
    }
    false
}
}
/// Trims `s` and collapses every run of ASCII whitespace inside it into one space.
// Takes `&String` rather than `&str` so it can be passed directly to `.map()` over `&String` items.
#[allow(clippy::ptr_arg)]
fn normalize_string(s: &String) -> String {
    let words: Vec<&str> = s.trim().split_ascii_whitespace().collect();
    words.join(" ")
}
pub fn update_record_with_output(
record: &Record,
record_output: &RecordOutput,
col_separator: &str,
validator: Validator,
) -> Option<Record> {
match (record.clone(), record_output) {
(_, RecordOutput::Nothing) => None,
(
Record::Statement {
sql,
loc,
conditions,
expected_error: None,
expected_count,
},
RecordOutput::Query { error: None, .. },
) => {
Some(Record::Statement {
sql,
expected_error: None,
loc,
conditions,
expected_count,
})
}
(
Record::Query {
sql,
loc,
conditions,
..
},
RecordOutput::Statement { error: None, .. },
) => Some(Record::Statement {
sql,
expected_error: None,
loc,
conditions,
expected_count: None,
}),
(
Record::Statement {
loc,
conditions,
expected_error,
sql,
expected_count,
},
RecordOutput::Statement { count, error },
) => match (error, expected_error) {
(None, _) => Some(Record::Statement {
sql,
expected_error: None,
loc,
conditions,
expected_count: expected_count.map(|_| *count),
}),
(Some(e), Some(expected_error)) if expected_error.is_match(&e.to_string()) => {
Some(Record::Statement {
sql,
expected_error: Some(expected_error),
loc,
conditions,
expected_count: None,
})
}
(Some(e), _) => Some(Record::Statement {
sql,
expected_error: Some(Regex::new(&e.to_string()).unwrap()),
loc,
conditions,
expected_count: None,
}),
},
(
Record::Query {
loc,
conditions,
type_string,
sort_mode,
label,
expected_error,
sql,
expected_results,
},
RecordOutput::Query {
types: _,
rows,
error,
},
) => {
match (error, expected_error) {
(None, _) => {}
(Some(e), Some(expected_error)) if expected_error.is_match(&e.to_string()) => {
return Some(Record::Query {
sql,
expected_error: Some(expected_error),
loc,
conditions,
type_string: vec![],
sort_mode,
label,
expected_results: vec![],
})
}
(Some(e), _) => {
return Some(Record::Query {
sql,
expected_error: Some(Regex::new(&e.to_string()).unwrap()),
loc,
conditions,
type_string: vec![],
sort_mode,
label,
expected_results: vec![],
})
}
};
let results = if validator(rows, &expected_results) {
expected_results
} else {
rows.iter().map(|cols| cols.join(col_separator)).collect()
};
Some(Record::Query {
sql,
expected_error: None,
loc,
conditions,
type_string,
sort_mode,
label,
expected_results: results,
})
}
_ => None,
}
}
/// Re-runs the test file at `filename` and rewrites it (and any included files) so that each
/// record's expectations match the actual output.
///
/// Each file is written to a `<name>.temp` sibling first and then renamed over the original.
/// Runs of trailing newlines are collapsed to a single newline. Records after a `halt` are
/// copied verbatim without being run.
///
/// # Errors
///
/// Returns parse and I/O errors.
pub async fn update_test_file<D: AsyncDB>(
    filename: impl AsRef<Path>,
    runner: &mut Runner<D>,
    col_separator: &str,
    validator: Validator,
) -> Result<(), Box<dyn std::error::Error>> {
    use std::io::{Read, Seek, SeekFrom, Write};
    use std::path::PathBuf;

    use fs_err::{File, OpenOptions};

    /// Creates (or truncates) `<filename>.temp` next to `filename`, opened for reading and
    /// writing.
    fn create_outfile(filename: impl AsRef<Path>) -> std::io::Result<(PathBuf, File)> {
        let filename = filename.as_ref();
        let mut outname = filename
            .file_name()
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("not a file path: {}", filename.display()),
                )
            })?
            .to_os_string();
        outname.push(".temp");
        let outfilename = filename.with_file_name(outname);
        let outfile = OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .read(true)
            .open(&outfilename)?;
        Ok((outfilename, outfile))
    }

    /// Collapses any run of trailing newlines in `outfile` to a single newline, then moves the
    /// temporary file over `filename`.
    fn override_with_outfile(
        filename: &str,
        outfilename: &Path,
        outfile: &mut File,
    ) -> std::io::Result<()> {
        const N: u64 = 8;
        let mut buf = [0u8; N as usize];
        loop {
            let len = outfile.metadata()?.len();
            // Inspect at most the last N bytes; shorter files are read whole.
            let chunk = len.min(N);
            if chunk == 0 {
                break;
            }
            let tail = &mut buf[..chunk as usize];
            outfile.seek(SeekFrom::End(-(chunk as i64)))?;
            outfile.read_exact(tail)?;
            let num_newlines = tail.iter().rev().take_while(|&&b| b == b'\n').count() as u64;
            if num_newlines > 1 {
                outfile.set_len(len - num_newlines + 1)?;
            }
            // Continue only if the whole chunk was newlines: the run may extend further back.
            if num_newlines <= 1 || num_newlines < chunk {
                break;
            }
        }
        outfile.flush()?;
        fs_err::rename(outfilename, filename)?;
        Ok(())
    }

    /// One file being rewritten; included files are pushed on top of their includer.
    struct Item {
        filename: String,
        outfilename: PathBuf,
        outfile: File,
        halt: bool,
    }

    let filename = filename.as_ref();
    let records = parse_file(filename)?;
    let (outfilename, outfile) = create_outfile(filename)?;
    let mut stack = vec![Item {
        filename: filename.to_string_lossy().to_string(),
        outfilename,
        outfile,
        halt: false,
    }];

    for record in records {
        let Item {
            filename,
            outfilename,
            outfile,
            halt,
        } = stack.last_mut().unwrap();
        match &record {
            Record::Injected(Injected::BeginInclude(filename)) => {
                let (outfilename, outfile) = create_outfile(filename)?;
                stack.push(Item {
                    filename: filename.clone(),
                    outfilename,
                    outfile,
                    halt: false,
                });
            }
            Record::Injected(Injected::EndInclude(_)) => {
                override_with_outfile(filename, outfilename, outfile)?;
                stack.pop();
            }
            _ => {
                // After a `halt`, the rest of the file is copied without being run.
                if *halt {
                    writeln!(outfile, "{}", record)?;
                    continue;
                }
                if matches!(record, Record::Halt { .. }) {
                    *halt = true;
                    writeln!(outfile, "{}", record)?;
                    continue;
                }
                let record_output = runner.apply_record(record.clone()).await;
                let record =
                    update_record_with_output(&record, &record_output, col_separator, validator)
                        .unwrap_or(record);
                writeln!(outfile, "{}", record)?;
            }
        }
    }

    let Item {
        filename,
        outfilename,
        outfile,
        halt: _,
    } = stack.last_mut().unwrap();
    override_with_outfile(filename, outfilename, outfile)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    //! Tests for `update_record_with_output`. Each case parses one record, pairs it with a
    //! synthetic output and checks the rewritten record.

    use super::*;

    #[test]
    fn test_query_replacement() {
        TestCase {
            input: "query III\n\
                    select * from foo;\n\
                    ----\n\
                    1 2",
            record_output: query_output(&[&["3", "4"]]),
            expected: Some(
                "query III\n\
                 select * from foo;\n\
                 ----\n\
                 3 4",
            ),
        }
        .run()
    }

    #[test]
    fn test_query_replacement_no_input() {
        TestCase {
            input: "query III\n\
                    select * from foo;\n\
                    ----",
            record_output: query_output(&[&["3", "4"]]),
            expected: Some(
                "query III\n\
                 select * from foo;\n\
                 ----\n\
                 3 4",
            ),
        }
        .run()
    }

    #[test]
    fn test_query_replacement_no_output() {
        TestCase {
            input: "query III\n\
                    select * from foo;\n\
                    ----",
            record_output: RecordOutput::Nothing,
            expected: None,
        }
        .run()
    }

    #[test]
    fn test_query_replacement_error() {
        TestCase {
            input: "query III\n\
                    select * from foo;\n\
                    ----",
            record_output: query_output_error("MyAwesomeDB Error"),
            expected: Some(
                "query error TestError: MyAwesomeDB Error\n\
                 select * from foo;\n",
            ),
        }
        .run()
    }

    #[test]
    fn test_statement_query_output() {
        TestCase {
            input: "statement ok\n\
                    create table foo;",
            record_output: query_output(&[&["3", "4"]]),
            expected: Some(
                "statement ok\n\
                 create table foo;",
            ),
        }
        .run()
    }

    #[test]
    fn test_query_statement_output() {
        TestCase {
            input: "query III\n\
                    select * from foo;\n\
                    ----",
            record_output: statement_output(3),
            expected: Some(
                "statement ok\n\
                 select * from foo;",
            ),
        }
        .run()
    }

    #[test]
    fn test_statement_output() {
        TestCase {
            input: "statement ok\n\
                    insert into foo values(2);",
            record_output: statement_output(3),
            expected: Some(
                "statement ok\n\
                 insert into foo values(2);",
            ),
        }
        .run()
    }

    #[test]
    fn test_statement_error_to_ok() {
        TestCase {
            input: "statement error\n\
                    insert into foo values(2);",
            record_output: statement_output(3),
            expected: Some(
                "statement ok\n\
                 insert into foo values(2);",
            ),
        }
        .run()
    }

    #[test]
    fn test_statement_error_no_error() {
        TestCase {
            input: "statement error\n\
                    insert into foo values(2);",
            record_output: statement_output_error("foo"),
            expected: Some(
                "statement error\n\
                 insert into foo values(2);",
            ),
        }
        .run()
    }

    #[test]
    fn test_statement_error_new_error() {
        TestCase {
            input: "statement error bar\n\
                    insert into foo values(2);",
            record_output: statement_output_error("foo"),
            expected: Some(
                "statement error TestError: foo\n\
                 insert into foo values(2);",
            ),
        }
        .run()
    }

    #[test]
    fn test_statement_error_ok_to_error() {
        TestCase {
            input: "statement ok\n\
                    insert into foo values(2);",
            record_output: statement_output_error("foo"),
            expected: Some(
                "statement error TestError: foo\n\
                 insert into foo values(2);",
            ),
        }
        .run()
    }

    /// One rewrite scenario: an input record, the output to apply and the expected record
    /// (`None` when no update should happen).
    #[derive(Debug)]
    struct TestCase {
        input: &'static str,
        record_output: RecordOutput,
        expected: Option<&'static str>,
    }

    impl TestCase {
        /// Applies `record_output` to `input` and asserts the result equals `expected`.
        fn run(self) {
            let Self {
                input,
                record_output,
                expected,
            } = self;
            println!("TestCase");
            println!("**input:\n{input}\n");
            println!("**record_output:\n{record_output:#?}\n");
            println!("**expected:\n{}\n", expected.unwrap_or(""));
            let input = parse_to_record(input);
            let expected = expected.map(parse_to_record);
            let output = update_record_with_output(&input, &record_output, " ", default_validator);
            assert_eq!(
                &output,
                &expected,
                "\n\noutput:\n\n{}\n\nexpected:\n\n{}",
                output
                    .as_ref()
                    .map(|r| r.to_string())
                    .unwrap_or_else(|| "None".into()),
                expected
                    .as_ref()
                    .map(|r| r.to_string())
                    .unwrap_or_else(|| "None".into()),
            );
        }
    }

    /// Parses a script that must contain exactly one record.
    fn parse_to_record(s: &str) -> Record {
        let mut records = parse(s).unwrap();
        assert_eq!(records.len(), 1);
        records.pop().unwrap()
    }

    /// A successful query output with the given rows and one `Any` type per column.
    fn query_output(rows: &[&[&str]]) -> RecordOutput {
        let rows = rows
            .iter()
            .map(|cols| cols.iter().map(|c| c.to_string()).collect::<Vec<_>>())
            .collect::<Vec<_>>();
        // One type per column (taken from the first row), not one per row.
        let column_count = rows.first().map_or(0, Vec::len);
        let types = vec![ColumnType::Any; column_count];
        RecordOutput::Query {
            types,
            rows,
            error: None,
        }
    }

    /// A failed query output carrying `error_message`.
    fn query_output_error(error_message: &str) -> RecordOutput {
        RecordOutput::Query {
            types: vec![],
            rows: vec![],
            error: Some(Arc::new(TestError(error_message.to_string()))),
        }
    }

    /// A successful statement output that affected `count` rows.
    fn statement_output(count: u64) -> RecordOutput {
        RecordOutput::Statement { count, error: None }
    }

    /// A failed statement output carrying `error_message`.
    fn statement_output_error(error_message: &str) -> RecordOutput {
        RecordOutput::Statement {
            count: 0,
            error: Some(Arc::new(TestError(error_message.to_string()))),
        }
    }

    /// A stand-in database error whose message is `TestError: <text>`.
    #[derive(Debug)]
    struct TestError(String);

    impl std::error::Error for TestError {}

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "TestError: {}", self.0)
        }
    }
}