1use crate::config;
2use crate::engine::EngineError;
3use crate::template;
4use console::{style, Style};
5
/// Starter SQL written into a test's `test.sql` when a new test is created.
static BASE_TEST: &str = "-- Test file\nSELECT 1;\n";
9
10use similar::{ChangeTag, TextDiff};
11use std::fmt;
12use std::str;
13use std::sync::{Arc, Mutex};
14
15use anyhow::{Context, Result};
16
17pub struct Tester {
18 config: config::Config,
19 script_path: String,
20}
21
/// Result of comparing a test run against its stored expectation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestOutcome {
    /// `None` when the output matched the expectation; otherwise the rendered
    /// diff between the expected and the generated output.
    pub diff: Option<String>,
}
26
27impl Tester {
28 pub fn new(config: &config::Config, script_path: &str) -> Self {
29 Tester {
30 config: config.clone(),
31 script_path: script_path.to_string(),
32 }
33 }
34
35 pub fn test_folder(&self) -> String {
36 let mut s = self.config.pather().tests_folder();
37 s.push('/');
38 s.push_str(&self.script_path);
39 s
40 }
41
42 pub fn test_file_path(&self) -> String {
43 let mut s = self.test_folder();
44 s.push_str("/test.sql");
45 s
46 }
47
48 pub fn expected_file_path(&self) -> String {
49 format!("{}/expected", self.test_folder())
50 }
51
52 pub async fn generate(&self, variables: Option<crate::variables::Variables>) -> Result<String> {
55 let lock_file = None;
56
57 let gen = template::generate_streaming(
58 &self.config,
59 lock_file,
60 &self.test_file_path(),
61 variables,
62 )
63 .await?;
64
65 let mut buffer = Vec::new();
66 gen.render_to_writer(&mut buffer)
67 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
68 let content = String::from_utf8(buffer)?;
69
70 Ok(content)
71 }
72
73 pub async fn run(&self, variables: Option<crate::variables::Variables>) -> Result<String> {
75 let content = self.generate(variables.clone()).await?;
76
77 let engine = self.config.new_engine().await?;
78
79 let stdout_buf = Arc::new(Mutex::new(Vec::new()));
81 let stdout_buf_clone = stdout_buf.clone();
82
83 match engine
84 .execute_with_writer(
85 Box::new(move |writer| {
86 writer.write_all(content.as_bytes())?;
87 Ok(())
88 }),
89 Some(Box::new(SharedBufWriter(stdout_buf_clone))),
90 true, )
92 .await
93 {
94 Ok(()) => {}
95 Err(EngineError::ExecutionFailed { .. }) => {
96 }
100 Err(e) => return Err(e).context("failed to write content to test db"),
101 }
102
103 let buf = stdout_buf.lock().unwrap();
104 let generated = String::from_utf8_lossy(&buf).to_string();
105
106 Ok(generated)
107 }
108
109 pub async fn run_compare(
110 &self,
111 variables: Option<crate::variables::Variables>,
112 ) -> Result<TestOutcome> {
113 let generated = self.run(variables).await?;
114 let expected_bytes = self
115 .config
116 .operator()
117 .read(&self.expected_file_path())
118 .await
119 .context("unable to read expectations file")?
120 .to_bytes();
121 let expected = String::from_utf8(expected_bytes.to_vec())
122 .context("expected file is not valid UTF-8")?;
123
124 let outcome = match self.compare(&generated, &expected) {
125 Ok(()) => TestOutcome { diff: None },
126 Err(differences) => TestOutcome {
127 diff: Some(differences.to_string()),
128 },
129 };
130
131 Ok(outcome)
132 }
133
134 pub async fn save_expected(
135 &self,
136 variables: Option<crate::variables::Variables>,
137 ) -> Result<()> {
138 let content = self.run(variables).await?;
139 self.config
140 .operator()
141 .write(&self.expected_file_path(), content)
142 .await
143 .context("unable to write expectation file")?;
144
145 Ok(())
146 }
147
148 pub async fn create_test(&self) -> Result<String> {
150 let script_path = self.test_file_path();
151 println!("creating test at {}", &script_path);
152 self.config
153 .operator()
154 .write(&script_path, BASE_TEST)
155 .await?;
156
157 Ok(self.script_path.clone())
158 }
159
160 pub fn compare(&self, generated: &str, expected: &str) -> std::result::Result<(), String> {
161 let diff = TextDiff::from_lines(expected, generated);
162
163 let mut diff_display = String::new();
164
165 for (idx, group) in diff.grouped_ops(3).iter().enumerate() {
166 if idx > 0 {
167 diff_display.push_str(&format!("{:-^1$}", "-", 80));
168 }
169 for op in group {
170 for change in diff.iter_inline_changes(op) {
171 let (sign, s) = match change.tag() {
172 ChangeTag::Delete => ("-", Style::new().red()),
173 ChangeTag::Insert => ("+", Style::new().green()),
174 ChangeTag::Equal => (" ", Style::new().dim()),
175 };
176 diff_display.push_str(&format!(
177 "{}{} |{}",
178 style(Line(change.old_index())).dim(),
179 style(Line(change.new_index())).dim(),
180 s.apply_to(sign).bold(),
181 ));
182 for (emphasized, value) in change.iter_strings_lossy() {
183 if emphasized {
184 diff_display.push_str(&format!(
185 "{}",
186 s.apply_to(value).underlined().on_black()
187 ));
188 } else {
189 diff_display.push_str(&format!("{}", s.apply_to(value)));
190 }
191 }
192 if change.missing_newline() {
193 diff_display.push('\n');
194 }
195 }
196 }
197 }
198
199 if !diff_display.is_empty() {
200 return Err(diff_display);
201 }
202
203 Ok(())
204 }
205}
206
/// Display helper that renders an optional zero-based line index as a
/// one-based line number, left-aligned in a fixed 4-character column.
///
/// `None` (a line absent on that side of the diff) renders as blank padding of
/// the same width, so the diff's number columns stay aligned.
struct Line(Option<usize>);

impl fmt::Display for Line {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.0 {
            // Same width as the numbered case so columns line up.
            None => write!(f, "{:<4}", ""),
            Some(idx) => write!(f, "{:<4}", idx + 1),
        }
    }
}
216}
217
/// Writer handle over a byte buffer shared behind `Arc<Mutex<..>>`, so the
/// bytes written through it can still be read by another owner of the `Arc`.
struct SharedBufWriter(
    /// The shared destination buffer.
    Arc<Mutex<Vec<u8>>>,
);
220
221impl tokio::io::AsyncWrite for SharedBufWriter {
222 fn poll_write(
223 self: std::pin::Pin<&mut Self>,
224 _cx: &mut std::task::Context<'_>,
225 buf: &[u8],
226 ) -> std::task::Poll<std::io::Result<usize>> {
227 self.0.lock().unwrap().extend_from_slice(buf);
228 std::task::Poll::Ready(Ok(buf.len()))
229 }
230
231 fn poll_flush(
232 self: std::pin::Pin<&mut Self>,
233 _cx: &mut std::task::Context<'_>,
234 ) -> std::task::Poll<std::io::Result<()>> {
235 std::task::Poll::Ready(Ok(()))
236 }
237
238 fn poll_shutdown(
239 self: std::pin::Pin<&mut Self>,
240 _cx: &mut std::task::Context<'_>,
241 ) -> std::task::Poll<std::io::Result<()>> {
242 std::task::Poll::Ready(Ok(()))
243 }
244}