llkv_sql/
slt.rs

1use libtest_mimic::{Arguments, Conclusion, Failed, Trial};
2use llkv_result::Error;
3use regex::escape;
4use sqllogictest::{AsyncDB, DefaultColumnType, Runner};
5use std::path::Path;
6
7/// Run a single slt file using the provided AsyncDB factory. The factory is
8/// a closure that returns a future resolving to a new DB instance for the
9/// runner. This mirrors sqllogictest's Runner::new signature and behavior.
10pub async fn run_slt_file_with_factory<F, Fut, D, E>(path: &Path, factory: F) -> Result<(), Error>
11where
12    F: Fn() -> Fut + Send + Sync + 'static,
13    Fut: std::future::Future<Output = Result<D, E>> + Send,
14    D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
15    E: std::fmt::Debug,
16{
17    let text = std::fs::read_to_string(path)
18        .map_err(|e| Error::Internal(format!("failed to read slt file: {}", e)))?;
19    let raw_lines: Vec<String> = text.lines().map(|l| l.to_string()).collect();
20    let (expanded_lines, mapping) = expand_loops_with_mapping(&raw_lines, 0)?;
21    let (expanded_lines, mapping) = {
22        let mut filtered_lines = Vec::with_capacity(expanded_lines.len());
23        let mut filtered_mapping = Vec::with_capacity(mapping.len());
24        for (line, orig_line) in expanded_lines.into_iter().zip(mapping.into_iter()) {
25            if line.trim_start().starts_with("load ") {
26                tracing::warn!(
27                    "Ignoring unsupported SLT directive `load`: {}:{} -> {}",
28                    path.display(),
29                    orig_line,
30                    line.trim()
31                );
32                continue;
33            }
34            filtered_lines.push(line);
35            filtered_mapping.push(orig_line);
36        }
37        (filtered_lines, filtered_mapping)
38    };
39    let (normalized_lines, mapping) = normalize_inline_connections(expanded_lines, mapping);
40
41    let expanded_text = normalized_lines.join("\n");
42    let mut named = tempfile::NamedTempFile::new()
43        .map_err(|e| Error::Internal(format!("failed to create temp slt file: {}", e)))?;
44    use std::io::Write as _;
45    named
46        .write_all(expanded_text.as_bytes())
47        .map_err(|e| Error::Internal(format!("failed to write temp slt file: {}", e)))?;
48    if std::env::var("LLKV_DUMP_SLT").is_ok() {
49        let dump_path = std::path::Path::new("target/normalized.slt");
50        if let Some(parent) = dump_path.parent() {
51            let _ = std::fs::create_dir_all(parent);
52        }
53        if let Err(e) = std::fs::write(dump_path, &expanded_text) {
54            tracing::warn!("failed to dump normalized slt file: {}", e);
55        }
56    }
57    let tmp = named.path().to_path_buf();
58
59    let mut runner = Runner::new(|| async {
60        factory()
61            .await
62            .map_err(|e| Error::Internal(format!("factory error: {:?}", e)))
63    });
64
65    // Align with the canonical sqllogictest harness default (256 values) so large
66    // result sets compare by MD5 digest instead of dumping every row.
67    runner.with_hash_threshold(256);
68
69    if let Err(e) = runner.run_file_async(&tmp).await {
70        let (mapped, opt_orig_line) =
71            map_temp_error_message(&format!("{}", e), &tmp, &normalized_lines, &mapping, path);
72        if let Some(orig_line) = opt_orig_line
73            && let Ok(text) = std::fs::read_to_string(path)
74            && let Some(line) = text.lines().nth(orig_line - 1)
75        {
76            eprintln!(
77                "[llkv-slt] original {}:{}: {}",
78                path.display(),
79                orig_line,
80                line.trim()
81            );
82        }
83        drop(named);
84        return Err(Error::Internal(format!("slt runner failed: {}", mapped)));
85    }
86
87    drop(named);
88    Ok(())
89}
90
91/// Discover `.slt` files under the given directory and run them as
92/// libtest_mimic trials using the provided AsyncDB factory constructor.
93///
94/// The `factory_factory` closure is called once per test file and should return
95/// a factory closure that creates DB instances. This allows each test file to
96/// have isolated state while enabling multiple connections within a test to
97/// share state. This keeps the harness engine-agnostic so different crates
98/// can provide their own engine adapters.
99pub fn run_slt_harness<FF, F, Fut, D, E>(slt_dir: &str, factory_factory: FF)
100where
101    FF: Fn() -> F + Send + Sync + 'static + Clone,
102    F: Fn() -> Fut + Send + Sync + 'static,
103    Fut: std::future::Future<Output = Result<D, E>> + Send + 'static,
104    D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
105    E: std::fmt::Debug + Send + 'static,
106{
107    let args = Arguments::from_args();
108    let conclusion = run_slt_harness_with_args(slt_dir, factory_factory, args);
109    if conclusion.has_failed() {
110        panic!(
111            "SLT harness reported {} failed test(s)",
112            conclusion.num_failed
113        );
114    }
115}
116
117/// Same as [`run_slt_harness`], but accepts pre-parsed [`Arguments`] so callers
118/// can control CLI parsing (e.g. custom binaries).
119pub fn run_slt_harness_with_args<FF, F, Fut, D, E>(
120    slt_dir: &str,
121    factory_factory: FF,
122    args: Arguments,
123) -> Conclusion
124where
125    FF: Fn() -> F + Send + Sync + 'static + Clone,
126    F: Fn() -> Fut + Send + Sync + 'static,
127    Fut: std::future::Future<Output = Result<D, E>> + Send + 'static,
128    D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
129    E: std::fmt::Debug + Send + 'static,
130{
131    let base = std::path::Path::new(slt_dir);
132    // Discover files
133    let files = {
134        let mut out = Vec::new();
135        if base.exists() {
136            let mut stack = vec![base.to_path_buf()];
137            while let Some(p) = stack.pop() {
138                if p.is_dir() {
139                    if let Ok(read) = std::fs::read_dir(&p) {
140                        for entry in read.flatten() {
141                            stack.push(entry.path());
142                        }
143                    }
144                } else if let Some(ext) = p.extension()
145                    && ext == "slt"
146                {
147                    out.push(p);
148                }
149            }
150        }
151        out.sort();
152        out
153    };
154
155    let base_parent = base.parent();
156    let mut trials: Vec<Trial> = Vec::new();
157    for f in files {
158        let name_path = base_parent
159            .and_then(|parent| f.strip_prefix(parent).ok())
160            .or_else(|| f.strip_prefix(base).ok())
161            .unwrap_or(&f);
162        let mut name = name_path.to_string_lossy().to_string();
163        if std::path::MAIN_SEPARATOR != '/' {
164            name = name.replace(std::path::MAIN_SEPARATOR, "/");
165        }
166        let name = name.trim_start_matches(&['/', '\\'][..]).to_string();
167        let path_clone = f.clone();
168        let factory_factory_clone = factory_factory.clone();
169        trials.push(Trial::test(name, move || {
170            let p = path_clone.clone();
171            // Call the factory_factory to get a fresh factory for this test file
172            let fac = factory_factory_clone();
173            let rt = tokio::runtime::Builder::new_current_thread()
174                .enable_all()
175                .build()
176                .map_err(|e| Failed::from(format!("failed to build tokio runtime: {e}")))?;
177            let res: Result<(), Error> =
178                rt.block_on(async move { run_slt_file_with_factory(&p, fac).await });
179            res.map_err(|e| Failed::from(format!("slt runner error: {e}")))
180        }));
181    }
182
183    libtest_mimic::run(&args, trials)
184}
185
186/// Expand `loop var start count` directives, returning the expanded lines and
187/// a mapping from expanded line index to the original 1-based source line.
188pub fn expand_loops_with_mapping(
189    lines: &[String],
190    base_index: usize,
191) -> Result<(Vec<String>, Vec<usize>), Error> {
192    let mut out_lines: Vec<String> = Vec::new();
193    let mut out_map: Vec<usize> = Vec::new();
194    let mut i = 0usize;
195    while i < lines.len() {
196        let line = lines[i].trim_start().to_string();
197        if line.starts_with("loop ") {
198            let parts: Vec<&str> = line.split_whitespace().collect();
199            if parts.len() < 4 {
200                return Err(Error::Internal(format!(
201                    "malformed loop directive: {}",
202                    line
203                )));
204            }
205            let var = parts[1];
206            let start: i64 = parts[2]
207                .parse()
208                .map_err(|e| Error::Internal(format!("invalid loop start: {}", e)))?;
209            let count: i64 = parts[3]
210                .parse()
211                .map_err(|e| Error::Internal(format!("invalid loop count: {}", e)))?;
212
213            let mut j = i + 1;
214            while j < lines.len() && lines[j].trim_start() != "endloop" {
215                j += 1;
216            }
217            if j >= lines.len() {
218                return Err(Error::Internal("unterminated loop in slt".to_string()));
219            }
220
221            let inner = &lines[i + 1..j];
222            let (expanded_inner, inner_map) = expand_loops_with_mapping(inner, base_index + i + 1)?;
223
224            for k in 0..count {
225                let val = (start + k).to_string();
226                let token_plain = format!("${}", var);
227                let token_braced = format!("${{{}}}", var);
228                for (s, &orig_line) in expanded_inner.iter().zip(inner_map.iter()) {
229                    let substituted = s.replace(&token_braced, &val).replace(&token_plain, &val);
230                    out_lines.push(substituted);
231                    out_map.push(orig_line);
232                }
233            }
234
235            i = j + 1;
236        } else {
237            out_lines.push(lines[i].clone());
238            out_map.push(base_index + i + 1);
239            i += 1;
240        }
241    }
242    Ok((out_lines, out_map))
243}
244
245/// Convert sqllogictest inline connection syntax (e.g. `statement ok con1`)
246/// into explicit `connection` records so the upstream parser can understand them.
247/// Also ensures proper termination of statement error blocks by adding a blank line
248/// after ---- when there's no expected error pattern.
249#[allow(clippy::type_complexity)] // NOTE: Helper returns structured parsing results; refactor once sqllogictest parser is extracted.
250fn normalize_inline_connections(
251    lines: Vec<String>,
252    mapping: Vec<usize>,
253) -> (Vec<String>, Vec<usize>) {
254    fn collect_statement_error_block(
255        lines: &[String],
256        mapping: &[usize],
257        start: usize,
258    ) -> (
259        Vec<(String, usize)>,
260        Option<String>,
261        Vec<(String, usize)>,
262        bool,
263        usize,
264    ) {
265        let mut sql_lines = Vec::new();
266        let mut message_lines = Vec::new();
267        let mut regex_pattern = None;
268        let mut idx = start;
269        let mut saw_separator = false;
270
271        while idx < lines.len() {
272            let line = &lines[idx];
273            let trimmed = line.trim_start();
274            if trimmed == "----" {
275                saw_separator = true;
276                idx += 1;
277                break;
278            }
279            sql_lines.push((line.clone(), mapping[idx]));
280            idx += 1;
281        }
282
283        if saw_separator {
284            while idx < lines.len() {
285                let line = &lines[idx];
286                let trimmed_full = line.trim();
287                if trimmed_full.is_empty() {
288                    idx += 1;
289                    break;
290                }
291                if let Some(pattern) = trimmed_full.strip_prefix("<REGEX>:") {
292                    regex_pattern = Some(pattern.to_string());
293                    idx += 1;
294                    while idx < lines.len() && lines[idx].trim().is_empty() {
295                        idx += 1;
296                    }
297                    message_lines.clear();
298                    break;
299                }
300                message_lines.push((line.clone(), mapping[idx]));
301                idx += 1;
302            }
303
304            if regex_pattern.is_none()
305                && !message_lines.is_empty()
306                && let Some((first_line, _)) = message_lines.first()
307            {
308                let trimmed_first = first_line.trim();
309                if !trimmed_first.is_empty() {
310                    let escaped = escape(trimmed_first);
311                    regex_pattern = Some(format!(".*{}.*", escaped));
312                    message_lines.clear();
313                }
314            }
315        }
316
317        (sql_lines, regex_pattern, message_lines, saw_separator, idx)
318    }
319
320    fn is_connection_token(token: &str) -> bool {
321        token
322            .strip_prefix("con")
323            .map(|suffix| !suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit()))
324            .unwrap_or(false)
325    }
326
327    let mut out_lines = Vec::with_capacity(lines.len());
328    let mut out_map = Vec::with_capacity(mapping.len());
329
330    let mut i = 0usize;
331    while i < lines.len() {
332        let line = &lines[i];
333        let orig = mapping[i];
334        let trimmed = line.trim_start();
335
336        // Handle connection syntax normalization
337        if trimmed.starts_with("statement ") || trimmed.starts_with("query ") {
338            let mut tokens: Vec<&str> = trimmed.split_whitespace().collect();
339            if tokens.len() >= 3 && tokens.last().is_some_and(|last| is_connection_token(last)) {
340                let conn = tokens.pop().unwrap();
341                let indent_len = line.len() - trimmed.len();
342                let indent = &line[..indent_len];
343
344                out_lines.push(format!("{indent}connection {conn}"));
345                out_map.push(orig);
346
347                let normalized = format!("{indent}{}", tokens.join(" "));
348                let normalized_trimmed = normalized.trim_start();
349                if normalized_trimmed.starts_with("statement error") {
350                    let (sql_lines, regex_pattern, message_lines, saw_separator, new_idx) =
351                        collect_statement_error_block(&lines, &mapping, i + 1);
352                    i = new_idx;
353
354                    let has_regex = regex_pattern.is_some();
355                    if let Some(pattern) = regex_pattern {
356                        out_lines.push(format!("{indent}connection {conn}"));
357                        out_map.push(orig);
358                        out_lines.push(format!("{indent}statement error {}", pattern));
359                        out_map.push(orig);
360                    } else {
361                        out_lines.push(normalized.clone());
362                        out_map.push(orig);
363                    }
364                    for (sql_line, sql_map) in sql_lines {
365                        out_lines.push(sql_line);
366                        out_map.push(sql_map);
367                    }
368                    // Only output ---- when there are actual error message lines
369                    if saw_separator && !has_regex && !message_lines.is_empty() {
370                        out_lines.push(format!("{indent}----"));
371                        out_map.push(orig);
372                        for (msg_line, msg_map) in message_lines {
373                            out_lines.push(msg_line);
374                            out_map.push(msg_map);
375                        }
376                        // Add extra blank line after plain text error messages to prevent multiline interpretation
377                        out_lines.push(String::new());
378                        out_map.push(orig);
379                    }
380                    out_lines.push(String::new());
381                    out_map.push(orig);
382                    continue;
383                } else {
384                    // Not a statement error, just output the normalized line
385                    out_lines.push(normalized);
386                    out_map.push(orig);
387                    i += 1;
388                    continue;
389                }
390            }
391        }
392
393        // Check if this is a statement error (without inline connection) followed by ----
394        if trimmed.starts_with("statement error") {
395            let indent = &line[..line.len() - trimmed.len()];
396            let (sql_lines, regex_pattern, message_lines, saw_separator, new_idx) =
397                collect_statement_error_block(&lines, &mapping, i + 1);
398            i = new_idx;
399
400            let has_regex = regex_pattern.is_some();
401            if let Some(pattern) = regex_pattern {
402                out_lines.push(format!("{indent}statement error {}", pattern));
403                out_map.push(orig);
404            } else {
405                out_lines.push(line.clone());
406                out_map.push(orig);
407            }
408            for (sql_line, sql_map) in sql_lines {
409                out_lines.push(sql_line);
410                out_map.push(sql_map);
411            }
412            // Only output ---- when there are actual error message lines
413            if saw_separator && !has_regex && !message_lines.is_empty() {
414                out_lines.push(format!("{indent}----"));
415                out_map.push(orig);
416                for (msg_line, msg_map) in message_lines {
417                    out_lines.push(msg_line);
418                    out_map.push(msg_map);
419                }
420                // Add extra blank line after plain text error messages to prevent multiline interpretation
421                out_lines.push(String::new());
422                out_map.push(orig);
423            }
424            out_lines.push(String::new());
425            out_map.push(orig);
426            continue;
427        }
428
429        out_lines.push(line.clone());
430        out_map.push(orig);
431        i += 1;
432    }
433
434    (out_lines, out_map)
435}
436
/// Rewrite an error message that points into the temporary expanded file so
/// that it points at the original file instead.
///
/// The first occurrence of `tmp_path` in `err_msg` is inspected; if it is
/// followed by `:<line>`, the 1-based `<line>` is looked up in `mapping`.
/// The lines at offsets `+1`, `0` and `-1` from the reported line are tried in
/// that order, and the first one whose expanded text is non-blank supplies the
/// original line number.
///
/// Returns the (possibly rewritten) message and the original line number when
/// one was found; otherwise the message is returned unchanged with `None`.
pub fn map_temp_error_message(
    err_msg: &str,
    tmp_path: &Path,
    expanded_lines: &[String],
    mapping: &[usize],
    orig_path: &Path,
) -> (String, Option<usize>) {
    let tmp_name = tmp_path.to_string_lossy().to_string();

    // Line number reported right after the temp path, e.g. `/tmp/x:42`.
    let reported_line: Option<usize> = err_msg
        .find(&tmp_name)
        .and_then(|pos| err_msg[pos + tmp_name.len()..].strip_prefix(':'))
        .and_then(|rest| {
            let digits_end = rest
                .find(|ch: char| !ch.is_ascii_digit())
                .unwrap_or(rest.len());
            rest[..digits_end].parse::<usize>().ok()
        });

    let resolved = reported_line.and_then(|line_no| {
        [1isize, 0, -1].iter().find_map(|&delta| {
            let shifted = (line_no as isize - 1) + delta;
            if shifted < 0 {
                return None;
            }
            let slot = shifted as usize;
            let source_line = *mapping.get(slot)?;
            // Blank expanded lines carry no useful location; try the next offset.
            let is_blank = expanded_lines
                .get(slot)
                .map(|text| text.trim().is_empty())
                .unwrap_or(true);
            if is_blank {
                None
            } else {
                Some((line_no, source_line))
            }
        })
    });

    match resolved {
        Some((line_no, source_line)) => {
            let needle = format!("{}:{}", tmp_name, line_no);
            let replacement = format!("{}:{}", orig_path.display(), source_line);
            (err_msg.replacen(&needle, &replacement, 1), Some(source_line))
        }
        None => (err_msg.to_string(), None),
    }
}