use libtest_mimic::{Arguments, Conclusion, Failed, Trial};
use llkv_result::Error;
use sqllogictest::{AsyncDB, DefaultColumnType, Runner};
use std::path::{Component, Path};

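/// Preprocess a single `.slt` script and execute it against a database produced by
/// `factory`.
///
/// Preprocessing expands `loop`/`endloop` blocks, drops unsupported `load` directives,
/// rewrites inline connection tokens, and, for DuckDB suites, adjusts error-message
/// spacing. The normalized script is written to a temporary file before it is handed to
/// the sqllogictest runner, and any runner error is mapped back to line numbers in the
/// original file. Setting `LLKV_DUMP_SLT` dumps the normalized script to
/// `target/normalized.slt` for inspection.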
pub async fn run_slt_file_with_factory<F, Fut, D, E>(path: &Path, factory: F) -> Result<(), Error>
where
    F: Fn() -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<D, E>> + Send,
    D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
    E: std::fmt::Debug,
{
    let text = std::fs::read_to_string(path)
        .map_err(|e| Error::Internal(format!("failed to read slt file: {}", e)))?;
    let raw_lines: Vec<String> = text.lines().map(|l| l.to_string()).collect();
    let (expanded_lines, mapping) = expand_loops_with_mapping(&raw_lines, 0)?;
    // Drop unsupported `load` directives while keeping the line mapping aligned.
    let (expanded_lines, mapping) = {
        let mut filtered_lines = Vec::with_capacity(expanded_lines.len());
        let mut filtered_mapping = Vec::with_capacity(mapping.len());
        for (line, orig_line) in expanded_lines.into_iter().zip(mapping.into_iter()) {
            if line.trim_start().starts_with("load ") {
                tracing::warn!(
                    "Ignoring unsupported SLT directive `load`: {}:{} -> {}",
                    path.display(),
                    orig_line,
                    line.trim()
                );
                continue;
            }
            filtered_lines.push(line);
            filtered_mapping.push(orig_line);
        }
        (filtered_lines, filtered_mapping)
    };
    let (normalized_lines, mapping) = normalize_inline_connections(expanded_lines, mapping);

    // DuckDB suites need extra blank-line spacing around literal error expectations.
    let is_duckdb_suite = path
        .components()
        .any(|component| matches!(component, Component::Normal(name) if name == "duckdb"));
    let normalized_lines = if is_duckdb_suite {
        fix_error_message_spacing(normalized_lines)
    } else {
        normalized_lines
    };

    let expanded_text = normalized_lines.join("\n");
    let mut named = tempfile::NamedTempFile::new()
        .map_err(|e| Error::Internal(format!("failed to create temp slt file: {}", e)))?;
    use std::io::Write as _;
    named
        .write_all(expanded_text.as_bytes())
        .map_err(|e| Error::Internal(format!("failed to write temp slt file: {}", e)))?;
    // Optionally dump the normalized script for debugging.
    if std::env::var("LLKV_DUMP_SLT").is_ok() {
        let dump_path = std::path::Path::new("target/normalized.slt");
        if let Some(parent) = dump_path.parent() {
            let _ = std::fs::create_dir_all(parent);
        }
        if let Err(e) = std::fs::write(dump_path, &expanded_text) {
            tracing::warn!("failed to dump normalized slt file: {}", e);
        }
    }
    let tmp = named.path().to_path_buf();

    let mut runner = Runner::new(|| async {
        factory()
            .await
            .map_err(|e| Error::Internal(format!("factory error: {:?}", e)))
    });
    if let Err(e) = runner.run_file_async(&tmp).await {
        // Point the error back at the original file and echo the offending line.
        let (mapped, opt_orig_line) =
            map_temp_error_message(&format!("{}", e), &tmp, &normalized_lines, &mapping, path);
        if let Some(orig_line) = opt_orig_line
            && let Ok(text) = std::fs::read_to_string(path)
            && let Some(line) = text.lines().nth(orig_line - 1)
        {
            eprintln!(
                "[llkv-slt] original {}:{}: {}",
                path.display(),
                orig_line,
                line.trim()
            );
        }
        drop(named);
        return Err(Error::Internal(format!("slt runner failed: {}", mapped)));
    }

    drop(named);
    Ok(())
}

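/// Discover every `.slt` file under `slt_dir` and run each one as a libtest-mimic trial,
/// panicking if any trial fails. Command-line arguments are parsed from the process
/// environment via `Arguments::from_args()`.
///
/// A minimal sketch of a harness entry point; `open_db` is a placeholder for whatever
/// constructor produces the caller's `AsyncDB` implementation:
///
/// ```ignore
/// fn main() {
///     // `open_db` is hypothetical: substitute the caller's async database constructor.
///     run_slt_harness("tests/slt", || || async { open_db().await });
/// }
/// ```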
pub fn run_slt_harness<FF, F, Fut, D, E>(slt_dir: &str, factory_factory: FF)
where
    FF: Fn() -> F + Send + Sync + 'static + Clone,
    F: Fn() -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<D, E>> + Send + 'static,
    D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
    E: std::fmt::Debug + Send + 'static,
{
    let args = Arguments::from_args();
    let conclusion = run_slt_harness_with_args(slt_dir, factory_factory, args);
    if conclusion.has_failed() {
        panic!(
            "SLT harness reported {} failed test(s)",
            conclusion.num_failed
        );
    }
}

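/// Like [`run_slt_harness`], but takes pre-parsed [`Arguments`] and returns the
/// [`Conclusion`] instead of panicking, so callers can apply filters or inspect the
/// results themselves. `.slt` files are collected recursively under `slt_dir`, trial
/// names are the file paths relative to the parent of `slt_dir` (with `/` separators on
/// every platform), and each file runs on its own current-thread Tokio runtime.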
pub fn run_slt_harness_with_args<FF, F, Fut, D, E>(
    slt_dir: &str,
    factory_factory: FF,
    args: Arguments,
) -> Conclusion
where
    FF: Fn() -> F + Send + Sync + 'static + Clone,
    F: Fn() -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<D, E>> + Send + 'static,
    D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
    E: std::fmt::Debug + Send + 'static,
{
    let base = std::path::Path::new(slt_dir);
    // Walk `slt_dir` recursively and collect every `.slt` file.
    let files = {
        let mut out = Vec::new();
        if base.exists() {
            let mut stack = vec![base.to_path_buf()];
            while let Some(p) = stack.pop() {
                if p.is_dir() {
                    if let Ok(read) = std::fs::read_dir(&p) {
                        for entry in read.flatten() {
                            stack.push(entry.path());
                        }
                    }
                } else if let Some(ext) = p.extension()
                    && ext == "slt"
                {
                    out.push(p);
                }
            }
        }
        out.sort();
        out
    };

    let base_parent = base.parent();
    let mut trials: Vec<Trial> = Vec::new();
    for f in files {
        // Name each trial by its path relative to the parent of `slt_dir`, using `/`
        // separators on every platform.
        let name_path = base_parent
            .and_then(|parent| f.strip_prefix(parent).ok())
            .or_else(|| f.strip_prefix(base).ok())
            .unwrap_or(&f);
        let mut name = name_path.to_string_lossy().to_string();
        if std::path::MAIN_SEPARATOR != '/' {
            name = name.replace(std::path::MAIN_SEPARATOR, "/");
        }
        let name = name.trim_start_matches(&['/', '\\'][..]).to_string();
        let path_clone = f.clone();
        let factory_factory_clone = factory_factory.clone();
        trials.push(Trial::test(name, move || {
            let p = path_clone.clone();
            let fac = factory_factory_clone();
            // Each trial runs its file on its own current-thread Tokio runtime.
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .map_err(|e| Failed::from(format!("failed to build tokio runtime: {e}")))?;
            let res: Result<(), Error> =
                rt.block_on(async move { run_slt_file_with_factory(&p, fac).await });
            res.map_err(|e| Failed::from(format!("slt runner error: {e}")))
        }));
    }

    libtest_mimic::run(&args, trials)
}

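/// Expand `loop <var> <start> <count>` / `endloop` blocks, substituting `$<var>` in the
/// body, and return the expanded lines together with their 1-based original line
/// numbers. Nested loops are expanded recursively; `base_index` is the offset of
/// `lines` within the original file (0 for a top-level call).
///
/// For example, this three-line script:
///
/// ```text
/// loop i 0 2
/// INSERT INTO t VALUES ($i)
/// endloop
/// ```
///
/// expands to `INSERT INTO t VALUES (0)` and `INSERT INTO t VALUES (1)`, and both
/// expanded lines map back to original line 2.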
pub fn expand_loops_with_mapping(
    lines: &[String],
    base_index: usize,
) -> Result<(Vec<String>, Vec<usize>), Error> {
    let mut out_lines: Vec<String> = Vec::new();
    let mut out_map: Vec<usize> = Vec::new();
    let mut i = 0usize;
    while i < lines.len() {
        let line = lines[i].trim_start().to_string();
        if line.starts_with("loop ") {
            let parts: Vec<&str> = line.split_whitespace().collect();
            if parts.len() < 4 {
                return Err(Error::Internal(format!(
                    "malformed loop directive: {}",
                    line
                )));
            }
            let var = parts[1];
            let start: i64 = parts[2]
                .parse()
                .map_err(|e| Error::Internal(format!("invalid loop start: {}", e)))?;
            let count: i64 = parts[3]
                .parse()
                .map_err(|e| Error::Internal(format!("invalid loop count: {}", e)))?;

            // Find the matching `endloop`, tracking depth so nested loops keep their
            // own terminators.
            let mut j = i + 1;
            let mut depth = 1usize;
            while j < lines.len() {
                let inner_trimmed = lines[j].trim_start();
                if inner_trimmed.starts_with("loop ") {
                    depth += 1;
                } else if inner_trimmed == "endloop" {
                    depth -= 1;
                    if depth == 0 {
                        break;
                    }
                }
                j += 1;
            }
            if j >= lines.len() {
                return Err(Error::Internal("unterminated loop in slt".to_string()));
            }

            // Expand the body recursively, then emit it once per iteration with `$var`
            // replaced by the current value.
            let inner = &lines[i + 1..j];
            let (expanded_inner, inner_map) = expand_loops_with_mapping(inner, base_index + i + 1)?;

            for k in 0..count {
                let val = (start + k).to_string();
                for (s, &orig_line) in expanded_inner.iter().zip(inner_map.iter()) {
                    let substituted = s.replace(&format!("${}", var), &val);
                    out_lines.push(substituted);
                    out_map.push(orig_line);
                }
            }

            i = j + 1;
        } else {
            out_lines.push(lines[i].clone());
            out_map.push(base_index + i + 1);
            i += 1;
        }
    }
    Ok((out_lines, out_map))
}

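/// Rewrite directives that end in an inline connection token (`con1`, `con2`, ...) into
/// an explicit `connection conN` line followed by the bare directive, and fold a
/// `<REGEX>:` expectation of a `statement error` block into the directive itself. For
/// example, a `statement error con1` record becomes:
///
/// ```text
/// connection con1
/// statement error
/// <sql>
/// ----
/// <expected message>
/// ```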
#[allow(clippy::type_complexity)]
fn normalize_inline_connections(
    lines: Vec<String>,
    mapping: Vec<usize>,
) -> (Vec<String>, Vec<usize>) {
    fn collect_statement_error_block(
        lines: &[String],
        mapping: &[usize],
        start: usize,
    ) -> (
        Vec<(String, usize)>,
        Option<String>,
        Vec<(String, usize)>,
        bool,
        usize,
    ) {
        let mut sql_lines = Vec::new();
        let mut message_lines = Vec::new();
        let mut regex_pattern = None;
        let mut idx = start;
        let mut saw_separator = false;

        // SQL lines run until the `----` separator, if there is one.
        while idx < lines.len() {
            let line = &lines[idx];
            let trimmed = line.trim_start();
            if trimmed == "----" {
                saw_separator = true;
                idx += 1;
                break;
            }
            sql_lines.push((line.clone(), mapping[idx]));
            idx += 1;
        }

        // After the separator, gather expected-message lines; a `<REGEX>:` line
        // supersedes any literal message.
        if saw_separator {
            while idx < lines.len() {
                let line = &lines[idx];
                let trimmed_full = line.trim();
                if trimmed_full.is_empty() {
                    idx += 1;
                    break;
                }
                if let Some(pattern) = trimmed_full.strip_prefix("<REGEX>:") {
                    regex_pattern = Some(pattern.to_string());
                    idx += 1;
                    while idx < lines.len() && lines[idx].trim().is_empty() {
                        idx += 1;
                    }
                    message_lines.clear();
                    break;
                }
                message_lines.push((line.clone(), mapping[idx]));
                idx += 1;
            }
        }

        (sql_lines, regex_pattern, message_lines, saw_separator, idx)
    }

    // `con` followed by one or more ASCII digits, e.g. `con1`.
    fn is_connection_token(token: &str) -> bool {
        token
            .strip_prefix("con")
            .map(|suffix| !suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit()))
            .unwrap_or(false)
    }

    let mut out_lines = Vec::with_capacity(lines.len());
    let mut out_map = Vec::with_capacity(mapping.len());

    let mut i = 0usize;
    while i < lines.len() {
        let line = &lines[i];
        let orig = mapping[i];
        let trimmed = line.trim_start();

        // Split a trailing connection token off `statement`/`query` directives.
        if trimmed.starts_with("statement ") || trimmed.starts_with("query ") {
            let mut tokens: Vec<&str> = trimmed.split_whitespace().collect();
            if tokens.len() >= 3 && tokens.last().is_some_and(|last| is_connection_token(last)) {
                let conn = tokens.pop().unwrap();
                let indent_len = line.len() - trimmed.len();
                let indent = &line[..indent_len];

                out_lines.push(format!("{indent}connection {conn}"));
                out_map.push(orig);

                let normalized = format!("{indent}{}", tokens.join(" "));
                let normalized_trimmed = normalized.trim_start();
                if normalized_trimmed.starts_with("statement error") {
                    let (sql_lines, regex_pattern, message_lines, saw_separator, new_idx) =
                        collect_statement_error_block(&lines, &mapping, i + 1);
                    i = new_idx;

                    if let Some(pattern) = regex_pattern {
                        out_lines.push(format!("{indent}connection {conn}"));
                        out_map.push(orig);
                        out_lines.push(format!("{indent}statement error {}", pattern));
                        out_map.push(orig);
                    } else {
                        out_lines.push(normalized.clone());
                        out_map.push(orig);
                    }
                    for (sql_line, sql_map) in sql_lines {
                        out_lines.push(sql_line);
                        out_map.push(sql_map);
                    }
                    if saw_separator && !message_lines.is_empty() {
                        out_lines.push(format!("{indent}----"));
                        out_map.push(orig);
                        for (msg_line, msg_map) in message_lines {
                            out_lines.push(msg_line);
                            out_map.push(msg_map);
                        }
                    }
                    out_lines.push(String::new());
                    out_map.push(orig);
                    continue;
                } else {
                    out_lines.push(normalized);
                    out_map.push(orig);
                    i += 1;
                    continue;
                }
            }
        }

        // A bare `statement error` block still needs its `<REGEX>:` expectation folded
        // into the directive.
        if trimmed.starts_with("statement error") {
            let indent = &line[..line.len() - trimmed.len()];
            let (sql_lines, regex_pattern, message_lines, saw_separator, new_idx) =
                collect_statement_error_block(&lines, &mapping, i + 1);
            i = new_idx;

            if let Some(pattern) = regex_pattern {
                out_lines.push(format!("{indent}statement error {}", pattern));
                out_map.push(orig);
            } else {
                out_lines.push(line.clone());
                out_map.push(orig);
            }
            for (sql_line, sql_map) in sql_lines {
                out_lines.push(sql_line);
                out_map.push(sql_map);
            }
            if saw_separator && !message_lines.is_empty() {
                out_lines.push(format!("{indent}----"));
                out_map.push(orig);
                for (msg_line, msg_map) in message_lines {
                    out_lines.push(msg_line);
                    out_map.push(msg_map);
                }
            }
            out_lines.push(String::new());
            out_map.push(orig);
            continue;
        }

        out_lines.push(line.clone());
        out_map.push(orig);
        i += 1;
    }

    (out_lines, out_map)
}

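/// Rewrite a runner error message that references `tmp_path:<line>` so it points at the
/// corresponding location in the original file, using the expanded-to-original line
/// `mapping`. Returns the rewritten message and, on success, the original line number.
/// The lookup tolerates the reported line being off by one in either direction.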
pub fn map_temp_error_message(
    err_msg: &str,
    tmp_path: &Path,
    expanded_lines: &[String],
    mapping: &[usize],
    orig_path: &Path,
) -> (String, Option<usize>) {
    let tmp_str = tmp_path.to_string_lossy().to_string();
    let mut out = err_msg.to_string();
    if let Some(pos) = out.find(&tmp_str) {
        let after = &out[pos + tmp_str.len()..];
        if let Some(stripped) = after.strip_prefix(':') {
            // Parse the line number that follows `tmp_path:`.
            let mut digits = String::new();
            for ch in stripped.chars() {
                if ch.is_ascii_digit() {
                    digits.push(ch);
                } else {
                    break;
                }
            }
            if let Ok(expanded_line) = digits.parse::<usize>() {
                // The reported line can be off by one, so try the following, same, and
                // preceding expanded lines before giving up.
                let candidates: [isize; 3] = [1, 0, -1];
                for &off in &candidates {
                    let idx = (expanded_line as isize - 1) + off;
                    if idx >= 0 && (idx as usize) < mapping.len() {
                        let idx_us = idx as usize;
                        let expanded_text =
                            expanded_lines.get(idx_us).map(|s| s.trim()).unwrap_or("");
                        if expanded_text.is_empty() {
                            continue;
                        }
                        let orig_line = mapping[idx_us];
                        let replacement = format!("{}:{}", orig_path.display(), orig_line);
                        out = out.replacen(
                            &format!("{}:{}", tmp_str, expanded_line),
                            &replacement,
                            1,
                        );
                        return (out, Some(orig_line));
                    }
                }
            }
        }
    }
    (out, None)
}

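/// DuckDB-suite workaround: after the first line of the expected message of a literal
/// (non-`<REGEX>:`) `statement error` block, insert an extra blank line so the record is
/// clearly terminated before the next directive.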
fn fix_error_message_spacing(lines: Vec<String>) -> Vec<String> {
    let mut out_lines = Vec::with_capacity(lines.len() + 10);

    let mut i = 0;
    while i < lines.len() {
        let line = &lines[i];
        let trimmed = line.trim();

        if trimmed.starts_with("statement error") && !trimmed.contains("<REGEX>:") {
            // Copy the directive and the SQL lines up to the `----` separator.
            out_lines.push(line.clone());
            i += 1;

            while i < lines.len() && lines[i].trim() != "----" {
                out_lines.push(lines[i].clone());
                i += 1;
            }

            if i < lines.len() && lines[i].trim() == "----" {
                out_lines.push(lines[i].clone());
                i += 1;

                if i < lines.len() && !lines[i].trim().is_empty() {
                    // Copy the first line of the expected message.
                    out_lines.push(lines[i].clone());
                    i += 1;

                    // Keep an existing blank line, if any ...
                    if i < lines.len() && lines[i].trim().is_empty() {
                        out_lines.push(lines[i].clone());
                        i += 1;
                    }

                    // ... and always terminate the block with an extra blank line.
                    out_lines.push(String::new());
                } else if i < lines.len() {
                    // No expected message; copy whatever follows the separator.
                    out_lines.push(lines[i].clone());
                    i += 1;
                }
            }
        } else {
            out_lines.push(line.clone());
            i += 1;
        }
    }

    out_lines
}