llkv_sql/
slt.rs

1use libtest_mimic::{Arguments, Conclusion, Failed, Trial};
2use llkv_result::Error;
3use sqllogictest::{AsyncDB, DefaultColumnType, Runner};
4use std::path::{Component, Path};
5
6/// Run a single slt file using the provided AsyncDB factory. The factory is
7/// a closure that returns a future resolving to a new DB instance for the
8/// runner. This mirrors sqllogictest's Runner::new signature and behavior.
9pub async fn run_slt_file_with_factory<F, Fut, D, E>(path: &Path, factory: F) -> Result<(), Error>
10where
11    F: Fn() -> Fut + Send + Sync + 'static,
12    Fut: std::future::Future<Output = Result<D, E>> + Send,
13    D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
14    E: std::fmt::Debug,
15{
16    let text = std::fs::read_to_string(path)
17        .map_err(|e| Error::Internal(format!("failed to read slt file: {}", e)))?;
18    let raw_lines: Vec<String> = text.lines().map(|l| l.to_string()).collect();
19    let (expanded_lines, mapping) = expand_loops_with_mapping(&raw_lines, 0)?;
20    let (expanded_lines, mapping) = {
21        let mut filtered_lines = Vec::with_capacity(expanded_lines.len());
22        let mut filtered_mapping = Vec::with_capacity(mapping.len());
23        for (line, orig_line) in expanded_lines.into_iter().zip(mapping.into_iter()) {
24            if line.trim_start().starts_with("load ") {
25                tracing::warn!(
26                    "Ignoring unsupported SLT directive `load`: {}:{} -> {}",
27                    path.display(),
28                    orig_line,
29                    line.trim()
30                );
31                continue;
32            }
33            filtered_lines.push(line);
34            filtered_mapping.push(orig_line);
35        }
36        (filtered_lines, filtered_mapping)
37    };
38    let (normalized_lines, mapping) = normalize_inline_connections(expanded_lines, mapping);
39
40    // TODO: Remove this check once the harness implements dialects (https://github.com/jzombie/rust-llkv/issues/111)
41    // DuckDB-specific fix: add extra blank line after plain text error messages
42    let is_duckdb_suite = path
43        .components()
44        .any(|component| matches!(component, Component::Normal(name) if name == "duckdb"));
45    let normalized_lines = if is_duckdb_suite {
46        fix_error_message_spacing(normalized_lines)
47    } else {
48        normalized_lines
49    };
50
51    let expanded_text = normalized_lines.join("\n");
52    let mut named = tempfile::NamedTempFile::new()
53        .map_err(|e| Error::Internal(format!("failed to create temp slt file: {}", e)))?;
54    use std::io::Write as _;
55    named
56        .write_all(expanded_text.as_bytes())
57        .map_err(|e| Error::Internal(format!("failed to write temp slt file: {}", e)))?;
58    if std::env::var("LLKV_DUMP_SLT").is_ok() {
59        let dump_path = std::path::Path::new("target/normalized.slt");
60        if let Some(parent) = dump_path.parent() {
61            let _ = std::fs::create_dir_all(parent);
62        }
63        if let Err(e) = std::fs::write(dump_path, &expanded_text) {
64            tracing::warn!("failed to dump normalized slt file: {}", e);
65        }
66    }
67    let tmp = named.path().to_path_buf();
68
69    let mut runner = Runner::new(|| async {
70        factory()
71            .await
72            .map_err(|e| Error::Internal(format!("factory error: {:?}", e)))
73    });
74    if let Err(e) = runner.run_file_async(&tmp).await {
75        let (mapped, opt_orig_line) =
76            map_temp_error_message(&format!("{}", e), &tmp, &normalized_lines, &mapping, path);
77        if let Some(orig_line) = opt_orig_line
78            && let Ok(text) = std::fs::read_to_string(path)
79            && let Some(line) = text.lines().nth(orig_line - 1)
80        {
81            eprintln!(
82                "[llkv-slt] original {}:{}: {}",
83                path.display(),
84                orig_line,
85                line.trim()
86            );
87        }
88        drop(named);
89        return Err(Error::Internal(format!("slt runner failed: {}", mapped)));
90    }
91
92    drop(named);
93    Ok(())
94}
95
96/// Discover `.slt` files under the given directory and run them as
97/// libtest_mimic trials using the provided AsyncDB factory constructor.
98///
99/// The `factory_factory` closure is called once per test file and should return
100/// a factory closure that creates DB instances. This allows each test file to
101/// have isolated state while enabling multiple connections within a test to
102/// share state. This keeps the harness engine-agnostic so different crates
103/// can provide their own engine adapters.
104pub fn run_slt_harness<FF, F, Fut, D, E>(slt_dir: &str, factory_factory: FF)
105where
106    FF: Fn() -> F + Send + Sync + 'static + Clone,
107    F: Fn() -> Fut + Send + Sync + 'static,
108    Fut: std::future::Future<Output = Result<D, E>> + Send + 'static,
109    D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
110    E: std::fmt::Debug + Send + 'static,
111{
112    let args = Arguments::from_args();
113    let conclusion = run_slt_harness_with_args(slt_dir, factory_factory, args);
114    if conclusion.has_failed() {
115        panic!(
116            "SLT harness reported {} failed test(s)",
117            conclusion.num_failed
118        );
119    }
120}
121
122/// Same as [`run_slt_harness`], but accepts pre-parsed [`Arguments`] so callers
123/// can control CLI parsing (e.g. custom binaries).
124pub fn run_slt_harness_with_args<FF, F, Fut, D, E>(
125    slt_dir: &str,
126    factory_factory: FF,
127    args: Arguments,
128) -> Conclusion
129where
130    FF: Fn() -> F + Send + Sync + 'static + Clone,
131    F: Fn() -> Fut + Send + Sync + 'static,
132    Fut: std::future::Future<Output = Result<D, E>> + Send + 'static,
133    D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
134    E: std::fmt::Debug + Send + 'static,
135{
136    let base = std::path::Path::new(slt_dir);
137    // Discover files
138    let files = {
139        let mut out = Vec::new();
140        if base.exists() {
141            let mut stack = vec![base.to_path_buf()];
142            while let Some(p) = stack.pop() {
143                if p.is_dir() {
144                    if let Ok(read) = std::fs::read_dir(&p) {
145                        for entry in read.flatten() {
146                            stack.push(entry.path());
147                        }
148                    }
149                } else if let Some(ext) = p.extension()
150                    && ext == "slt"
151                {
152                    out.push(p);
153                }
154            }
155        }
156        out.sort();
157        out
158    };
159
160    let base_parent = base.parent();
161    let mut trials: Vec<Trial> = Vec::new();
162    for f in files {
163        let name_path = base_parent
164            .and_then(|parent| f.strip_prefix(parent).ok())
165            .or_else(|| f.strip_prefix(base).ok())
166            .unwrap_or(&f);
167        let mut name = name_path.to_string_lossy().to_string();
168        if std::path::MAIN_SEPARATOR != '/' {
169            name = name.replace(std::path::MAIN_SEPARATOR, "/");
170        }
171        let name = name.trim_start_matches(&['/', '\\'][..]).to_string();
172        let path_clone = f.clone();
173        let factory_factory_clone = factory_factory.clone();
174        trials.push(Trial::test(name, move || {
175            let p = path_clone.clone();
176            // Call the factory_factory to get a fresh factory for this test file
177            let fac = factory_factory_clone();
178            let rt = tokio::runtime::Builder::new_current_thread()
179                .enable_all()
180                .build()
181                .map_err(|e| Failed::from(format!("failed to build tokio runtime: {e}")))?;
182            let res: Result<(), Error> =
183                rt.block_on(async move { run_slt_file_with_factory(&p, fac).await });
184            res.map_err(|e| Failed::from(format!("slt runner error: {e}")))
185        }));
186    }
187
188    libtest_mimic::run(&args, trials)
189}
190
191/// Expand `loop var start count` directives, returning the expanded lines and
192/// a mapping from expanded line index to the original 1-based source line.
193pub fn expand_loops_with_mapping(
194    lines: &[String],
195    base_index: usize,
196) -> Result<(Vec<String>, Vec<usize>), Error> {
197    let mut out_lines: Vec<String> = Vec::new();
198    let mut out_map: Vec<usize> = Vec::new();
199    let mut i = 0usize;
200    while i < lines.len() {
201        let line = lines[i].trim_start().to_string();
202        if line.starts_with("loop ") {
203            let parts: Vec<&str> = line.split_whitespace().collect();
204            if parts.len() < 4 {
205                return Err(Error::Internal(format!(
206                    "malformed loop directive: {}",
207                    line
208                )));
209            }
210            let var = parts[1];
211            let start: i64 = parts[2]
212                .parse()
213                .map_err(|e| Error::Internal(format!("invalid loop start: {}", e)))?;
214            let count: i64 = parts[3]
215                .parse()
216                .map_err(|e| Error::Internal(format!("invalid loop count: {}", e)))?;
217
218            let mut j = i + 1;
219            while j < lines.len() && lines[j].trim_start() != "endloop" {
220                j += 1;
221            }
222            if j >= lines.len() {
223                return Err(Error::Internal("unterminated loop in slt".to_string()));
224            }
225
226            let inner = &lines[i + 1..j];
227            let (expanded_inner, inner_map) = expand_loops_with_mapping(inner, base_index + i + 1)?;
228
229            for k in 0..count {
230                let val = (start + k).to_string();
231                for (s, &orig_line) in expanded_inner.iter().zip(inner_map.iter()) {
232                    let substituted = s.replace(&format!("${}", var), &val);
233                    out_lines.push(substituted);
234                    out_map.push(orig_line);
235                }
236            }
237
238            i = j + 1;
239        } else {
240            out_lines.push(lines[i].clone());
241            out_map.push(base_index + i + 1);
242            i += 1;
243        }
244    }
245    Ok((out_lines, out_map))
246}
247
/// Convert legacy sqllogictest inline connection syntax (e.g. `statement ok con1`)
/// into explicit `connection` records so the upstream parser can understand them.
///
/// `statement error` records are also normalized:
/// - a `<REGEX>:pattern` line after `----` is hoisted into the header as
///   `statement error pattern`;
/// - a `----` separator with no expected message is dropped;
/// - every such record is terminated with a blank line.
///
/// A `statement error` record without `----` ends at its first blank line.
///
/// `lines` and `mapping` must have the same length; `mapping[i]` is the original
/// 1-based source line of `lines[i]`. The returned mapping has the same meaning
/// for the returned lines. Synthesized lines (`connection`, `----`, the
/// terminating blank) map to the line of the record header they came from.
fn normalize_inline_connections(
    lines: Vec<String>,
    mapping: Vec<usize>,
) -> (Vec<String>, Vec<usize>) {
    /// The parts of a `statement error` record that follow its header line.
    struct StatementErrorBlock {
        /// SQL lines paired with their original line numbers.
        sql_lines: Vec<(String, usize)>,
        /// Pattern from a `<REGEX>:` expected-error line, if present.
        regex_pattern: Option<String>,
        /// Plain-text expected-error lines (cleared when a regex is found).
        message_lines: Vec<(String, usize)>,
        /// Whether a `----` separator was seen.
        saw_separator: bool,
        /// Index of the first line after the record.
        next_idx: usize,
    }

    /// Scan a `statement error` record body starting at `start` (the line after
    /// the header).
    fn collect_statement_error_block(
        lines: &[String],
        mapping: &[usize],
        start: usize,
    ) -> StatementErrorBlock {
        let mut sql_lines = Vec::new();
        let mut message_lines = Vec::new();
        let mut regex_pattern = None;
        let mut saw_separator = false;
        let mut idx = start;

        // SQL runs until `----` or the blank line that ends the record. Without
        // the blank-line stop, a separator-less record would scan ahead into later
        // records and claim their `----` section (or `<REGEX>:` pattern) as its own.
        while let Some(line) = lines.get(idx) {
            let trimmed = line.trim_start();
            if trimmed == "----" || trimmed.is_empty() {
                saw_separator = trimmed == "----";
                idx += 1;
                break;
            }
            sql_lines.push((line.clone(), mapping[idx]));
            idx += 1;
        }

        if saw_separator {
            while let Some(line) = lines.get(idx) {
                let trimmed_full = line.trim();
                if trimmed_full.is_empty() {
                    idx += 1;
                    break;
                }
                if let Some(pattern) = trimmed_full.strip_prefix("<REGEX>:") {
                    regex_pattern = Some(pattern.to_string());
                    message_lines.clear();
                    idx += 1;
                    // Consume the blank line(s) terminating the record.
                    while lines.get(idx).is_some_and(|l| l.trim().is_empty()) {
                        idx += 1;
                    }
                    break;
                }
                message_lines.push((line.clone(), mapping[idx]));
                idx += 1;
            }
        }

        StatementErrorBlock {
            sql_lines,
            regex_pattern,
            message_lines,
            saw_separator,
            next_idx: idx,
        }
    }

    fn is_connection_token(token: &str) -> bool {
        token
            .strip_prefix("con")
            .map(|suffix| !suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit()))
            .unwrap_or(false)
    }

    /// For a `statement ...`/`query ...` directive with at least three tokens
    /// ending in `conN`, return `conN` and the directive without it (tokens
    /// re-joined with single spaces).
    fn split_inline_connection(directive: &str) -> Option<(&str, String)> {
        if !(directive.starts_with("statement ") || directive.starts_with("query ")) {
            return None;
        }
        let mut tokens: Vec<&str> = directive.split_whitespace().collect();
        if tokens.len() < 3 {
            return None;
        }
        let conn = tokens.pop().filter(|last| is_connection_token(last))?;
        Some((conn, tokens.join(" ")))
    }

    /// Append a normalized `statement error` record: header (with any regex
    /// pattern hoisted into it), SQL, the `----` section when it carries a
    /// message, and a terminating blank line.
    fn emit_statement_error_block(
        out: &mut Vec<(String, usize)>,
        indent: &str,
        header: String,
        header_line: usize,
        block: StatementErrorBlock,
    ) {
        let header = match block.regex_pattern {
            Some(pattern) => format!("{indent}statement error {pattern}"),
            None => header,
        };
        out.push((header, header_line));
        out.extend(block.sql_lines);
        if block.saw_separator && !block.message_lines.is_empty() {
            out.push((format!("{indent}----"), header_line));
            out.extend(block.message_lines);
        }
        out.push((String::new(), header_line));
    }

    let mut out: Vec<(String, usize)> = Vec::with_capacity(lines.len());
    let mut i = 0usize;
    while let Some(line) = lines.get(i) {
        let orig = mapping[i];
        let trimmed = line.trim_start();
        let indent = &line[..line.len() - trimmed.len()];

        // `statement ... conN` / `query ... conN` becomes a single
        // `connection conN` record followed by the directive without the token.
        let header = match split_inline_connection(trimmed) {
            Some((conn, directive)) => {
                out.push((format!("{indent}connection {conn}"), orig));
                format!("{indent}{directive}")
            }
            None => line.clone(),
        };

        if header.trim_start().starts_with("statement error") {
            let block = collect_statement_error_block(&lines, &mapping, i + 1);
            i = block.next_idx;
            emit_statement_error_block(&mut out, indent, header, orig, block);
        } else {
            out.push((header, orig));
            i += 1;
        }
    }

    out.into_iter().unzip()
}
417
/// Rewrite the first `<tmp_path>:<line>` location in `err_msg` so it points at
/// `orig_path` and the corresponding original source line.
///
/// `expanded_lines[i]` is line `i + 1` of the temporary file and `mapping[i]`
/// is its original 1-based line. The reported line's neighbours are tried in
/// the order *next*, *same*, *previous*, skipping blank lines. The first hit is
/// used. Returns the (possibly rewritten) message and the original line number
/// when one was found; otherwise the message is returned unchanged with `None`.
pub fn map_temp_error_message(
    err_msg: &str,
    tmp_path: &Path,
    expanded_lines: &[String],
    mapping: &[usize],
    orig_path: &Path,
) -> (String, Option<usize>) {
    let tmp_str = tmp_path.to_string_lossy().to_string();
    let unmapped = || -> (String, Option<usize>) { (err_msg.to_string(), None) };

    let Some(pos) = err_msg.find(&tmp_str) else {
        return unmapped();
    };
    let Some(rest) = err_msg[pos + tmp_str.len()..].strip_prefix(':') else {
        return unmapped();
    };
    // ASCII digits are one byte each, so the char count is also a byte offset.
    let digit_count = rest.chars().take_while(|ch| ch.is_ascii_digit()).count();
    let Ok(expanded_line) = rest[..digit_count].parse::<usize>() else {
        return unmapped();
    };

    let hit = [1isize, 0, -1].into_iter().find_map(|offset| {
        let idx = usize::try_from(expanded_line as isize - 1 + offset).ok()?;
        let orig_line = *mapping.get(idx)?;
        let text = expanded_lines.get(idx).map_or("", |line| line.trim());
        (!text.is_empty()).then_some(orig_line)
    });

    match hit {
        Some(orig_line) => {
            let rewritten = err_msg.replacen(
                &format!("{}:{}", tmp_str, expanded_line),
                &format!("{}:{}", orig_path.display(), orig_line),
                1,
            );
            (rewritten, Some(orig_line))
        }
        None => unmapped(),
    }
}
466
/// Fix error message spacing to prevent sqllogictest multiline interpretation.
///
/// Plain-text `statement error` records are those whose header does not contain
/// `<REGEX>:`. For each one that has a `----` section followed by an expected
/// message, the message lines and their terminating blank line are kept, and
/// one EXTRA blank line is appended. Records without `----` end at their first
/// blank line and pass through unchanged, as do all other lines.
fn fix_error_message_spacing(lines: Vec<String>) -> Vec<String> {
    fn is_blank(line: &str) -> bool {
        line.trim().is_empty()
    }
    fn is_separator(line: &str) -> bool {
        line.trim() == "----"
    }

    let mut out_lines = Vec::with_capacity(lines.len() + 10);

    let mut i = 0;
    while i < lines.len() {
        let trimmed = lines[i].trim();
        let is_plain_error_header =
            trimmed.starts_with("statement error") && !trimmed.contains("<REGEX>:");
        out_lines.push(lines[i].clone());
        i += 1;
        if !is_plain_error_header {
            continue;
        }

        // SQL lines run until `----` or the blank line that ends the record.
        // Stopping at the blank line keeps us from reaching the `----` of a later,
        // unrelated record (e.g. a `query` result) and re-spacing its output.
        while i < lines.len() && !is_separator(&lines[i]) && !is_blank(&lines[i]) {
            out_lines.push(lines[i].clone());
            i += 1;
        }
        if i >= lines.len() || !is_separator(&lines[i]) {
            // No expected-error section; the main loop passes the blank line through.
            continue;
        }
        out_lines.push(lines[i].clone());
        i += 1;

        // Expected error message: every non-blank line up to the terminating blank.
        let message_start = i;
        while i < lines.len() && !is_blank(&lines[i]) {
            out_lines.push(lines[i].clone());
            i += 1;
        }
        let has_message = i > message_start;

        // Pass through the blank line that terminates the record, if present.
        if i < lines.len() {
            out_lines.push(lines[i].clone());
            i += 1;
        }

        // Add EXTRA blank line to prevent multiline interpretation
        if has_message {
            out_lines.push(String::new());
        }
    }

    out_lines
}