Skip to main content

spawn_db/sqltest/
mod.rs

1use crate::config;
2use crate::engine::EngineError;
3use crate::template;
4use console::{style, Style};
5
/// Initial contents written to `test.sql` when a new test is scaffolded:
/// a comment header followed by a trivial query, newline-terminated.
static BASE_TEST: &str = "-- Test file\nSELECT 1;\n";
9
10use similar::{ChangeTag, TextDiff};
11use std::fmt;
12use std::str;
13use std::sync::{Arc, Mutex};
14
15use anyhow::{Context, Result};
16
17pub struct Tester {
18    config: config::Config,
19    script_path: String,
20}
21
/// Result of comparing a test run's output against its recorded expectations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestOutcome {
    /// Rendered diff between the expected and actual output, or `None` when
    /// the output matched exactly.
    pub diff: Option<String>,
}

impl TestOutcome {
    /// Returns `true` when the run produced exactly the expected output
    /// (i.e. there is no diff).
    #[must_use]
    pub fn passed(&self) -> bool {
        self.diff.is_none()
    }
}
26
27impl Tester {
28    pub fn new(config: &config::Config, script_path: &str) -> Self {
29        Tester {
30            config: config.clone(),
31            script_path: script_path.to_string(),
32        }
33    }
34
35    pub fn test_folder(&self) -> String {
36        let mut s = self.config.pather().tests_folder();
37        s.push('/');
38        s.push_str(&self.script_path);
39        s
40    }
41
42    pub fn test_file_path(&self) -> String {
43        let mut s = self.test_folder();
44        s.push_str("/test.sql");
45        s
46    }
47
48    pub fn expected_file_path(&self) -> String {
49        format!("{}/expected", self.test_folder())
50    }
51
52    /// Opens the specified script file and generates a test script, compiled
53    /// using minijinja.
54    pub async fn generate(&self, variables: Option<crate::variables::Variables>) -> Result<String> {
55        let lock_file = None;
56
57        let gen = template::generate_streaming(
58            &self.config,
59            lock_file,
60            &self.test_file_path(),
61            variables,
62        )
63        .await?;
64
65        let mut buffer = Vec::new();
66        gen.render_to_writer(&mut buffer)
67            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
68        let content = String::from_utf8(buffer)?;
69
70        Ok(content)
71    }
72
73    // Runs the test and compares the actual output to expected.
74    pub async fn run(&self, variables: Option<crate::variables::Variables>) -> Result<String> {
75        let content = self.generate(variables.clone()).await?;
76
77        let engine = self.config.new_engine().await?;
78
79        // Create a shared buffer to capture stdout
80        let stdout_buf = Arc::new(Mutex::new(Vec::new()));
81        let stdout_buf_clone = stdout_buf.clone();
82
83        match engine
84            .execute_with_writer(
85                Box::new(move |writer| {
86                    writer.write_all(content.as_bytes())?;
87                    Ok(())
88                }),
89                Some(Box::new(SharedBufWriter(stdout_buf_clone))),
90                true, // Merge stderr into stdout for tests
91            )
92            .await
93        {
94            Ok(()) => {}
95            Err(EngineError::ExecutionFailed { .. }) => {
96                // psql exited non-zero (e.g. ON_ERROR_STOP triggered).
97                // The combined output buffer already has the error output,
98                // so we just continue and return it.
99            }
100            Err(e) => return Err(e).context("failed to write content to test db"),
101        }
102
103        let buf = stdout_buf.lock().unwrap();
104        let generated = String::from_utf8_lossy(&buf).to_string();
105
106        Ok(generated)
107    }
108
109    pub async fn run_compare(
110        &self,
111        variables: Option<crate::variables::Variables>,
112    ) -> Result<TestOutcome> {
113        let generated = self.run(variables).await?;
114        let expected_bytes = self
115            .config
116            .operator()
117            .read(&self.expected_file_path())
118            .await
119            .context("unable to read expectations file")?
120            .to_bytes();
121        let expected = String::from_utf8(expected_bytes.to_vec())
122            .context("expected file is not valid UTF-8")?;
123
124        let outcome = match self.compare(&generated, &expected) {
125            Ok(()) => TestOutcome { diff: None },
126            Err(differences) => TestOutcome {
127                diff: Some(differences.to_string()),
128            },
129        };
130
131        Ok(outcome)
132    }
133
134    pub async fn save_expected(
135        &self,
136        variables: Option<crate::variables::Variables>,
137    ) -> Result<()> {
138        let content = self.run(variables).await?;
139        self.config
140            .operator()
141            .write(&self.expected_file_path(), content)
142            .await
143            .context("unable to write expectation file")?;
144
145        Ok(())
146    }
147
148    /// Creates a new test folder with a blank test.sql file.
149    pub async fn create_test(&self) -> Result<String> {
150        let script_path = self.test_file_path();
151        println!("creating test at {}", &script_path);
152        self.config
153            .operator()
154            .write(&script_path, BASE_TEST)
155            .await?;
156
157        Ok(self.script_path.clone())
158    }
159
160    pub fn compare(&self, generated: &str, expected: &str) -> std::result::Result<(), String> {
161        let diff = TextDiff::from_lines(expected, generated);
162
163        let mut diff_display = String::new();
164
165        for (idx, group) in diff.grouped_ops(3).iter().enumerate() {
166            if idx > 0 {
167                diff_display.push_str(&format!("{:-^1$}", "-", 80));
168            }
169            for op in group {
170                for change in diff.iter_inline_changes(op) {
171                    let (sign, s) = match change.tag() {
172                        ChangeTag::Delete => ("-", Style::new().red()),
173                        ChangeTag::Insert => ("+", Style::new().green()),
174                        ChangeTag::Equal => (" ", Style::new().dim()),
175                    };
176                    diff_display.push_str(&format!(
177                        "{}{} |{}",
178                        style(Line(change.old_index())).dim(),
179                        style(Line(change.new_index())).dim(),
180                        s.apply_to(sign).bold(),
181                    ));
182                    for (emphasized, value) in change.iter_strings_lossy() {
183                        if emphasized {
184                            diff_display.push_str(&format!(
185                                "{}",
186                                s.apply_to(value).underlined().on_black()
187                            ));
188                        } else {
189                            diff_display.push_str(&format!("{}", s.apply_to(value)));
190                        }
191                    }
192                    if change.missing_newline() {
193                        diff_display.push('\n');
194                    }
195                }
196            }
197        }
198
199        if !diff_display.is_empty() {
200            return Err(diff_display);
201        }
202
203        Ok(())
204    }
205}
206
/// Gutter cell for a diff line: a 1-based line number left-aligned in four
/// columns, or four spaces when the line has no index on that side of the diff.
struct Line(Option<usize>);

impl fmt::Display for Line {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if let Some(zero_based) = self.0 {
            write!(f, "{:<4}", zero_based + 1)
        } else {
            f.write_str("    ")
        }
    }
}
217
218/// A simple AsyncWrite implementation that appends to a shared Vec<u8>
219struct SharedBufWriter(Arc<Mutex<Vec<u8>>>);
220
221impl tokio::io::AsyncWrite for SharedBufWriter {
222    fn poll_write(
223        self: std::pin::Pin<&mut Self>,
224        _cx: &mut std::task::Context<'_>,
225        buf: &[u8],
226    ) -> std::task::Poll<std::io::Result<usize>> {
227        self.0.lock().unwrap().extend_from_slice(buf);
228        std::task::Poll::Ready(Ok(buf.len()))
229    }
230
231    fn poll_flush(
232        self: std::pin::Pin<&mut Self>,
233        _cx: &mut std::task::Context<'_>,
234    ) -> std::task::Poll<std::io::Result<()>> {
235        std::task::Poll::Ready(Ok(()))
236    }
237
238    fn poll_shutdown(
239        self: std::pin::Pin<&mut Self>,
240        _cx: &mut std::task::Context<'_>,
241    ) -> std::task::Poll<std::io::Result<()>> {
242        std::task::Poll::Ready(Ok(()))
243    }
244}