Skip to main content

spawn_db/sqltest/
mod.rs

1use crate::config;
2use crate::template;
3use console::{style, Style};
4
/// Starter contents written to a freshly created `test.sql`: a comment header
/// followed by a trivial query, terminated by a newline.
static BASE_TEST: &str = "-- Test file\nSELECT 1;\n";
8
9use similar::{ChangeTag, TextDiff};
10use std::fmt;
11use std::str;
12use std::sync::{Arc, Mutex};
13
14use anyhow::{Context, Result};
15
16pub struct Tester {
17    config: config::Config,
18    script_path: String,
19}
20
/// Result of comparing a test run's output against its stored expectation.
///
/// `Default` is the passing outcome (`diff: None`).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TestOutcome {
    /// `None` when the generated output matched the expected file exactly;
    /// otherwise the rendered diff between the two.
    pub diff: Option<String>,
}
25
26impl Tester {
27    pub fn new(config: &config::Config, script_path: &str) -> Self {
28        Tester {
29            config: config.clone(),
30            script_path: script_path.to_string(),
31        }
32    }
33
34    pub fn test_folder(&self) -> String {
35        let mut s = self.config.pather().tests_folder();
36        s.push('/');
37        s.push_str(&self.script_path);
38        s
39    }
40
41    pub fn test_file_path(&self) -> String {
42        let mut s = self.test_folder();
43        s.push_str("/test.sql");
44        s
45    }
46
47    pub fn expected_file_path(&self) -> String {
48        format!("{}/expected", self.test_folder())
49    }
50
51    /// Opens the specified script file and generates a test script, compiled
52    /// using minijinja.
53    pub async fn generate(&self, variables: Option<crate::variables::Variables>) -> Result<String> {
54        let lock_file = None;
55
56        let gen = template::generate_streaming(
57            &self.config,
58            lock_file,
59            &self.test_file_path(),
60            variables,
61        )
62        .await?;
63
64        let mut buffer = Vec::new();
65        gen.render_to_writer(&mut buffer)
66            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
67        let content = String::from_utf8(buffer)?;
68
69        Ok(content)
70    }
71
72    // Runs the test and compares the actual output to expected.
73    pub async fn run(&self, variables: Option<crate::variables::Variables>) -> Result<String> {
74        let content = self.generate(variables.clone()).await?;
75
76        let engine = self.config.new_engine().await?;
77
78        // Create a shared buffer to capture stdout
79        let stdout_buf = Arc::new(Mutex::new(Vec::new()));
80        let stdout_buf_clone = stdout_buf.clone();
81
82        engine
83            .execute_with_writer(
84                Box::new(move |writer| {
85                    writer.write_all(content.as_bytes())?;
86                    Ok(())
87                }),
88                Some(Box::new(SharedBufWriter(stdout_buf_clone))),
89            )
90            .await
91            .context("failed to write content to test db")?;
92
93        let buf = stdout_buf.lock().unwrap();
94        let generated = String::from_utf8_lossy(&buf).to_string();
95
96        Ok(generated)
97    }
98
99    pub async fn run_compare(
100        &self,
101        variables: Option<crate::variables::Variables>,
102    ) -> Result<TestOutcome> {
103        let generated = self.run(variables).await?;
104        let expected_bytes = self
105            .config
106            .operator()
107            .read(&self.expected_file_path())
108            .await
109            .context("unable to read expectations file")?
110            .to_bytes();
111        let expected = String::from_utf8(expected_bytes.to_vec())
112            .context("expected file is not valid UTF-8")?;
113
114        let outcome = match self.compare(&generated, &expected) {
115            Ok(()) => TestOutcome { diff: None },
116            Err(differences) => TestOutcome {
117                diff: Some(differences.to_string()),
118            },
119        };
120
121        Ok(outcome)
122    }
123
124    pub async fn save_expected(
125        &self,
126        variables: Option<crate::variables::Variables>,
127    ) -> Result<()> {
128        let content = self.run(variables).await?;
129        self.config
130            .operator()
131            .write(&self.expected_file_path(), content)
132            .await
133            .context("unable to write expectation file")?;
134
135        Ok(())
136    }
137
138    /// Creates a new test folder with a blank test.sql file.
139    pub async fn create_test(&self) -> Result<String> {
140        let script_path = self.test_file_path();
141        println!("creating test at {}", &script_path);
142        self.config
143            .operator()
144            .write(&script_path, BASE_TEST)
145            .await?;
146
147        Ok(self.script_path.clone())
148    }
149
150    pub fn compare(&self, generated: &str, expected: &str) -> std::result::Result<(), String> {
151        let diff = TextDiff::from_lines(expected, generated);
152
153        let mut diff_display = String::new();
154
155        for (idx, group) in diff.grouped_ops(3).iter().enumerate() {
156            if idx > 0 {
157                diff_display.push_str(&format!("{:-^1$}", "-", 80));
158            }
159            for op in group {
160                for change in diff.iter_inline_changes(op) {
161                    let (sign, s) = match change.tag() {
162                        ChangeTag::Delete => ("-", Style::new().red()),
163                        ChangeTag::Insert => ("+", Style::new().green()),
164                        ChangeTag::Equal => (" ", Style::new().dim()),
165                    };
166                    diff_display.push_str(&format!(
167                        "{}{} |{}",
168                        style(Line(change.old_index())).dim(),
169                        style(Line(change.new_index())).dim(),
170                        s.apply_to(sign).bold(),
171                    ));
172                    for (emphasized, value) in change.iter_strings_lossy() {
173                        if emphasized {
174                            diff_display.push_str(&format!(
175                                "{}",
176                                s.apply_to(value).underlined().on_black()
177                            ));
178                        } else {
179                            diff_display.push_str(&format!("{}", s.apply_to(value)));
180                        }
181                    }
182                    if change.missing_newline() {
183                        diff_display.push('\n');
184                    }
185                }
186            }
187        }
188
189        if !diff_display.is_empty() {
190            return Err(diff_display);
191        }
192
193        Ok(())
194    }
195}
196
/// Gutter cell for one side of a diff line.
///
/// Shows the 1-based line number left-aligned in a 4-column field, or four
/// spaces when the line does not exist on that side.
struct Line(Option<usize>);

impl fmt::Display for Line {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if let Some(zero_based) = self.0 {
            let one_based = zero_based + 1;
            write!(f, "{:<4}", one_based)
        } else {
            f.write_str("    ")
        }
    }
}
207
208/// A simple AsyncWrite implementation that appends to a shared Vec<u8>
209struct SharedBufWriter(Arc<Mutex<Vec<u8>>>);
210
211impl tokio::io::AsyncWrite for SharedBufWriter {
212    fn poll_write(
213        self: std::pin::Pin<&mut Self>,
214        _cx: &mut std::task::Context<'_>,
215        buf: &[u8],
216    ) -> std::task::Poll<std::io::Result<usize>> {
217        self.0.lock().unwrap().extend_from_slice(buf);
218        std::task::Poll::Ready(Ok(buf.len()))
219    }
220
221    fn poll_flush(
222        self: std::pin::Pin<&mut Self>,
223        _cx: &mut std::task::Context<'_>,
224    ) -> std::task::Poll<std::io::Result<()>> {
225        std::task::Poll::Ready(Ok(()))
226    }
227
228    fn poll_shutdown(
229        self: std::pin::Pin<&mut Self>,
230        _cx: &mut std::task::Context<'_>,
231    ) -> std::task::Poll<std::io::Result<()>> {
232        std::task::Poll::Ready(Ok(()))
233    }
234}