Skip to main content

spawn_db/sqltest/
mod.rs

1use crate::config;
2use crate::template;
3use console::{style, Style};
4
/// Starter contents written to `test.sql` when a new test is created:
/// a comment header followed by a trivial query.
static BASE_TEST: &str = "-- Test file\nSELECT 1;\n";
8
9use similar::{ChangeTag, TextDiff};
10use std::fmt;
11use std::str;
12use std::sync::{Arc, Mutex};
13
14use anyhow::{Context, Result};
15
16pub struct Tester {
17    config: config::Config,
18    script_path: String,
19}
20
/// Result of comparing a test run against its expected output.
///
/// The default value is a passing outcome (no diff).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TestOutcome {
    /// `None` when the generated output matched the expected file; otherwise
    /// the rendered (possibly terminal-styled) diff between the two.
    pub diff: Option<String>,
}
25
26impl Tester {
27    pub fn new(config: &config::Config, script_path: &str) -> Self {
28        Tester {
29            config: config.clone(),
30            script_path: script_path.to_string(),
31        }
32    }
33
34    pub fn test_folder(&self) -> String {
35        let mut s = self.config.pather().tests_folder();
36        s.push('/');
37        s.push_str(&self.script_path);
38        s
39    }
40
41    pub fn test_file_path(&self) -> String {
42        let mut s = self.test_folder();
43        s.push_str("/test.sql");
44        s
45    }
46
47    pub fn expected_file_path(&self) -> String {
48        format!("{}/expected", self.test_folder())
49    }
50
51    /// Opens the specified script file and generates a test script, compiled
52    /// using minijinja.
53    pub async fn generate(&self, variables: Option<crate::variables::Variables>) -> Result<String> {
54        let lock_file = None;
55
56        let gen = template::generate_streaming(
57            &self.config,
58            lock_file,
59            &self.test_file_path(),
60            variables,
61        )
62        .await?;
63
64        let mut buffer = Vec::new();
65        gen.render_to_writer(&mut buffer)
66            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
67        let content = String::from_utf8(buffer)?;
68
69        Ok(content)
70    }
71
72    // Runs the test and compares the actual output to expected.
73    pub async fn run(&self, variables: Option<crate::variables::Variables>) -> Result<String> {
74        let content = self.generate(variables.clone()).await?;
75
76        let engine = self.config.new_engine().await?;
77
78        // Create a shared buffer to capture stdout
79        let stdout_buf = Arc::new(Mutex::new(Vec::new()));
80        let stdout_buf_clone = stdout_buf.clone();
81
82        engine
83            .execute_with_writer(
84                Box::new(move |writer| {
85                    writer.write_all(content.as_bytes())?;
86                    Ok(())
87                }),
88                Some(Box::new(SharedBufWriter(stdout_buf_clone))),
89                true, // Merge stderr into stdout for tests
90            )
91            .await
92            .context("failed to write content to test db")?;
93
94        let buf = stdout_buf.lock().unwrap();
95        let generated = String::from_utf8_lossy(&buf).to_string();
96
97        Ok(generated)
98    }
99
100    pub async fn run_compare(
101        &self,
102        variables: Option<crate::variables::Variables>,
103    ) -> Result<TestOutcome> {
104        let generated = self.run(variables).await?;
105        let expected_bytes = self
106            .config
107            .operator()
108            .read(&self.expected_file_path())
109            .await
110            .context("unable to read expectations file")?
111            .to_bytes();
112        let expected = String::from_utf8(expected_bytes.to_vec())
113            .context("expected file is not valid UTF-8")?;
114
115        let outcome = match self.compare(&generated, &expected) {
116            Ok(()) => TestOutcome { diff: None },
117            Err(differences) => TestOutcome {
118                diff: Some(differences.to_string()),
119            },
120        };
121
122        Ok(outcome)
123    }
124
125    pub async fn save_expected(
126        &self,
127        variables: Option<crate::variables::Variables>,
128    ) -> Result<()> {
129        let content = self.run(variables).await?;
130        self.config
131            .operator()
132            .write(&self.expected_file_path(), content)
133            .await
134            .context("unable to write expectation file")?;
135
136        Ok(())
137    }
138
139    /// Creates a new test folder with a blank test.sql file.
140    pub async fn create_test(&self) -> Result<String> {
141        let script_path = self.test_file_path();
142        println!("creating test at {}", &script_path);
143        self.config
144            .operator()
145            .write(&script_path, BASE_TEST)
146            .await?;
147
148        Ok(self.script_path.clone())
149    }
150
151    pub fn compare(&self, generated: &str, expected: &str) -> std::result::Result<(), String> {
152        let diff = TextDiff::from_lines(expected, generated);
153
154        let mut diff_display = String::new();
155
156        for (idx, group) in diff.grouped_ops(3).iter().enumerate() {
157            if idx > 0 {
158                diff_display.push_str(&format!("{:-^1$}", "-", 80));
159            }
160            for op in group {
161                for change in diff.iter_inline_changes(op) {
162                    let (sign, s) = match change.tag() {
163                        ChangeTag::Delete => ("-", Style::new().red()),
164                        ChangeTag::Insert => ("+", Style::new().green()),
165                        ChangeTag::Equal => (" ", Style::new().dim()),
166                    };
167                    diff_display.push_str(&format!(
168                        "{}{} |{}",
169                        style(Line(change.old_index())).dim(),
170                        style(Line(change.new_index())).dim(),
171                        s.apply_to(sign).bold(),
172                    ));
173                    for (emphasized, value) in change.iter_strings_lossy() {
174                        if emphasized {
175                            diff_display.push_str(&format!(
176                                "{}",
177                                s.apply_to(value).underlined().on_black()
178                            ));
179                        } else {
180                            diff_display.push_str(&format!("{}", s.apply_to(value)));
181                        }
182                    }
183                    if change.missing_newline() {
184                        diff_display.push('\n');
185                    }
186                }
187            }
188        }
189
190        if !diff_display.is_empty() {
191            return Err(diff_display);
192        }
193
194        Ok(())
195    }
196}
197
/// Gutter cell for a diff line: an optional 0-based line index, displayed
/// 1-based and left-aligned in a 4-column field (blank when absent).
struct Line(Option<usize>);

impl fmt::Display for Line {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.0 {
            // Shown 1-based for humans; wider numbers simply overflow the field.
            Some(zero_based) => write!(f, "{:<4}", zero_based + 1),
            // Keep the column aligned when the line only exists on one side.
            None => f.write_str("    "),
        }
    }
}
208
209/// A simple AsyncWrite implementation that appends to a shared Vec<u8>
210struct SharedBufWriter(Arc<Mutex<Vec<u8>>>);
211
212impl tokio::io::AsyncWrite for SharedBufWriter {
213    fn poll_write(
214        self: std::pin::Pin<&mut Self>,
215        _cx: &mut std::task::Context<'_>,
216        buf: &[u8],
217    ) -> std::task::Poll<std::io::Result<usize>> {
218        self.0.lock().unwrap().extend_from_slice(buf);
219        std::task::Poll::Ready(Ok(buf.len()))
220    }
221
222    fn poll_flush(
223        self: std::pin::Pin<&mut Self>,
224        _cx: &mut std::task::Context<'_>,
225    ) -> std::task::Poll<std::io::Result<()>> {
226        std::task::Poll::Ready(Ok(()))
227    }
228
229    fn poll_shutdown(
230        self: std::pin::Pin<&mut Self>,
231        _cx: &mut std::task::Context<'_>,
232    ) -> std::task::Poll<std::io::Result<()>> {
233        std::task::Poll::Ready(Ok(()))
234    }
235}