1use libtest_mimic::{Arguments, Conclusion, Failed, Trial};
2use llkv_result::Error;
3use regex::escape;
4use sqllogictest::{AsyncDB, DefaultColumnType, Runner};
5use std::path::Path;
6
7pub async fn run_slt_file_with_factory<F, Fut, D, E>(path: &Path, factory: F) -> Result<(), Error>
11where
12 F: Fn() -> Fut + Send + Sync + 'static,
13 Fut: std::future::Future<Output = Result<D, E>> + Send,
14 D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
15 E: std::fmt::Debug,
16{
17 let text = std::fs::read_to_string(path)
18 .map_err(|e| Error::Internal(format!("failed to read slt file: {}", e)))?;
19 let raw_lines: Vec<String> = text.lines().map(|l| l.to_string()).collect();
20 let (expanded_lines, mapping) = expand_loops_with_mapping(&raw_lines, 0)?;
21 let (expanded_lines, mapping) = {
22 let mut filtered_lines = Vec::with_capacity(expanded_lines.len());
23 let mut filtered_mapping = Vec::with_capacity(mapping.len());
24 for (line, orig_line) in expanded_lines.into_iter().zip(mapping.into_iter()) {
25 if line.trim_start().starts_with("load ") {
26 tracing::warn!(
27 "Ignoring unsupported SLT directive `load`: {}:{} -> {}",
28 path.display(),
29 orig_line,
30 line.trim()
31 );
32 continue;
33 }
34 filtered_lines.push(line);
35 filtered_mapping.push(orig_line);
36 }
37 (filtered_lines, filtered_mapping)
38 };
39 let (normalized_lines, mapping) = normalize_inline_connections(expanded_lines, mapping);
40
41 let expanded_text = normalized_lines.join("\n");
42 let mut named = tempfile::NamedTempFile::new()
43 .map_err(|e| Error::Internal(format!("failed to create temp slt file: {}", e)))?;
44 use std::io::Write as _;
45 named
46 .write_all(expanded_text.as_bytes())
47 .map_err(|e| Error::Internal(format!("failed to write temp slt file: {}", e)))?;
48 if std::env::var("LLKV_DUMP_SLT").is_ok() {
49 let dump_path = std::path::Path::new("target/normalized.slt");
50 if let Some(parent) = dump_path.parent() {
51 let _ = std::fs::create_dir_all(parent);
52 }
53 if let Err(e) = std::fs::write(dump_path, &expanded_text) {
54 tracing::warn!("failed to dump normalized slt file: {}", e);
55 }
56 }
57 let tmp = named.path().to_path_buf();
58
59 let mut runner = Runner::new(|| async {
60 factory()
61 .await
62 .map_err(|e| Error::Internal(format!("factory error: {:?}", e)))
63 });
64
65 runner.with_hash_threshold(256);
68
69 if let Err(e) = runner.run_file_async(&tmp).await {
70 let (mapped, opt_orig_line) =
71 map_temp_error_message(&format!("{}", e), &tmp, &normalized_lines, &mapping, path);
72 if let Some(orig_line) = opt_orig_line
73 && let Ok(text) = std::fs::read_to_string(path)
74 && let Some(line) = text.lines().nth(orig_line - 1)
75 {
76 eprintln!(
77 "[llkv-slt] original {}:{}: {}",
78 path.display(),
79 orig_line,
80 line.trim()
81 );
82 }
83 drop(named);
84 return Err(Error::Internal(format!("slt runner failed: {}", mapped)));
85 }
86
87 drop(named);
88 Ok(())
89}
90
91pub fn run_slt_harness<FF, F, Fut, D, E>(slt_dir: &str, factory_factory: FF)
100where
101 FF: Fn() -> F + Send + Sync + 'static + Clone,
102 F: Fn() -> Fut + Send + Sync + 'static,
103 Fut: std::future::Future<Output = Result<D, E>> + Send + 'static,
104 D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
105 E: std::fmt::Debug + Send + 'static,
106{
107 let args = Arguments::from_args();
108 let conclusion = run_slt_harness_with_args(slt_dir, factory_factory, args);
109 if conclusion.has_failed() {
110 panic!(
111 "SLT harness reported {} failed test(s)",
112 conclusion.num_failed
113 );
114 }
115}
116
117pub fn run_slt_harness_with_args<FF, F, Fut, D, E>(
120 slt_dir: &str,
121 factory_factory: FF,
122 args: Arguments,
123) -> Conclusion
124where
125 FF: Fn() -> F + Send + Sync + 'static + Clone,
126 F: Fn() -> Fut + Send + Sync + 'static,
127 Fut: std::future::Future<Output = Result<D, E>> + Send + 'static,
128 D: AsyncDB<Error = Error, ColumnType = DefaultColumnType> + Send + 'static,
129 E: std::fmt::Debug + Send + 'static,
130{
131 let base = std::path::Path::new(slt_dir);
132 let files = {
134 let mut out = Vec::new();
135 if base.exists() {
136 let mut stack = vec![base.to_path_buf()];
137 while let Some(p) = stack.pop() {
138 if p.is_dir() {
139 if let Ok(read) = std::fs::read_dir(&p) {
140 for entry in read.flatten() {
141 stack.push(entry.path());
142 }
143 }
144 } else if let Some(ext) = p.extension()
145 && ext == "slt"
146 {
147 out.push(p);
148 }
149 }
150 }
151 out.sort();
152 out
153 };
154
155 let base_parent = base.parent();
156 let mut trials: Vec<Trial> = Vec::new();
157 for f in files {
158 let name_path = base_parent
159 .and_then(|parent| f.strip_prefix(parent).ok())
160 .or_else(|| f.strip_prefix(base).ok())
161 .unwrap_or(&f);
162 let mut name = name_path.to_string_lossy().to_string();
163 if std::path::MAIN_SEPARATOR != '/' {
164 name = name.replace(std::path::MAIN_SEPARATOR, "/");
165 }
166 let name = name.trim_start_matches(&['/', '\\'][..]).to_string();
167 let path_clone = f.clone();
168 let factory_factory_clone = factory_factory.clone();
169 trials.push(Trial::test(name, move || {
170 let p = path_clone.clone();
171 let fac = factory_factory_clone();
173 let rt = tokio::runtime::Builder::new_current_thread()
174 .enable_all()
175 .build()
176 .map_err(|e| Failed::from(format!("failed to build tokio runtime: {e}")))?;
177 let res: Result<(), Error> =
178 rt.block_on(async move { run_slt_file_with_factory(&p, fac).await });
179 res.map_err(|e| Failed::from(format!("slt runner error: {e}")))
180 }));
181 }
182
183 libtest_mimic::run(&args, trials)
184}
185
186pub fn expand_loops_with_mapping(
189 lines: &[String],
190 base_index: usize,
191) -> Result<(Vec<String>, Vec<usize>), Error> {
192 let mut out_lines: Vec<String> = Vec::new();
193 let mut out_map: Vec<usize> = Vec::new();
194 let mut i = 0usize;
195 while i < lines.len() {
196 let line = lines[i].trim_start().to_string();
197 if line.starts_with("loop ") {
198 let parts: Vec<&str> = line.split_whitespace().collect();
199 if parts.len() < 4 {
200 return Err(Error::Internal(format!(
201 "malformed loop directive: {}",
202 line
203 )));
204 }
205 let var = parts[1];
206 let start: i64 = parts[2]
207 .parse()
208 .map_err(|e| Error::Internal(format!("invalid loop start: {}", e)))?;
209 let count: i64 = parts[3]
210 .parse()
211 .map_err(|e| Error::Internal(format!("invalid loop count: {}", e)))?;
212
213 let mut j = i + 1;
214 while j < lines.len() && lines[j].trim_start() != "endloop" {
215 j += 1;
216 }
217 if j >= lines.len() {
218 return Err(Error::Internal("unterminated loop in slt".to_string()));
219 }
220
221 let inner = &lines[i + 1..j];
222 let (expanded_inner, inner_map) = expand_loops_with_mapping(inner, base_index + i + 1)?;
223
224 for k in 0..count {
225 let val = (start + k).to_string();
226 let token_plain = format!("${}", var);
227 let token_braced = format!("${{{}}}", var);
228 for (s, &orig_line) in expanded_inner.iter().zip(inner_map.iter()) {
229 let substituted = s.replace(&token_braced, &val).replace(&token_plain, &val);
230 out_lines.push(substituted);
231 out_map.push(orig_line);
232 }
233 }
234
235 i = j + 1;
236 } else {
237 out_lines.push(lines[i].clone());
238 out_map.push(base_index + i + 1);
239 i += 1;
240 }
241 }
242 Ok((out_lines, out_map))
243}
244
245#[allow(clippy::type_complexity)] fn normalize_inline_connections(
251 lines: Vec<String>,
252 mapping: Vec<usize>,
253) -> (Vec<String>, Vec<usize>) {
254 fn collect_statement_error_block(
255 lines: &[String],
256 mapping: &[usize],
257 start: usize,
258 ) -> (
259 Vec<(String, usize)>,
260 Option<String>,
261 Vec<(String, usize)>,
262 bool,
263 usize,
264 ) {
265 let mut sql_lines = Vec::new();
266 let mut message_lines = Vec::new();
267 let mut regex_pattern = None;
268 let mut idx = start;
269 let mut saw_separator = false;
270
271 while idx < lines.len() {
272 let line = &lines[idx];
273 let trimmed = line.trim_start();
274 if trimmed == "----" {
275 saw_separator = true;
276 idx += 1;
277 break;
278 }
279 sql_lines.push((line.clone(), mapping[idx]));
280 idx += 1;
281 }
282
283 if saw_separator {
284 while idx < lines.len() {
285 let line = &lines[idx];
286 let trimmed_full = line.trim();
287 if trimmed_full.is_empty() {
288 idx += 1;
289 break;
290 }
291 if let Some(pattern) = trimmed_full.strip_prefix("<REGEX>:") {
292 regex_pattern = Some(pattern.to_string());
293 idx += 1;
294 while idx < lines.len() && lines[idx].trim().is_empty() {
295 idx += 1;
296 }
297 message_lines.clear();
298 break;
299 }
300 message_lines.push((line.clone(), mapping[idx]));
301 idx += 1;
302 }
303
304 if regex_pattern.is_none()
305 && !message_lines.is_empty()
306 && let Some((first_line, _)) = message_lines.first()
307 {
308 let trimmed_first = first_line.trim();
309 if !trimmed_first.is_empty() {
310 let escaped = escape(trimmed_first);
311 regex_pattern = Some(format!(".*{}.*", escaped));
312 message_lines.clear();
313 }
314 }
315 }
316
317 (sql_lines, regex_pattern, message_lines, saw_separator, idx)
318 }
319
320 fn is_connection_token(token: &str) -> bool {
321 token
322 .strip_prefix("con")
323 .map(|suffix| !suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit()))
324 .unwrap_or(false)
325 }
326
327 let mut out_lines = Vec::with_capacity(lines.len());
328 let mut out_map = Vec::with_capacity(mapping.len());
329
330 let mut i = 0usize;
331 while i < lines.len() {
332 let line = &lines[i];
333 let orig = mapping[i];
334 let trimmed = line.trim_start();
335
336 if trimmed.starts_with("statement ") || trimmed.starts_with("query ") {
338 let mut tokens: Vec<&str> = trimmed.split_whitespace().collect();
339 if tokens.len() >= 3 && tokens.last().is_some_and(|last| is_connection_token(last)) {
340 let conn = tokens.pop().unwrap();
341 let indent_len = line.len() - trimmed.len();
342 let indent = &line[..indent_len];
343
344 out_lines.push(format!("{indent}connection {conn}"));
345 out_map.push(orig);
346
347 let normalized = format!("{indent}{}", tokens.join(" "));
348 let normalized_trimmed = normalized.trim_start();
349 if normalized_trimmed.starts_with("statement error") {
350 let (sql_lines, regex_pattern, message_lines, saw_separator, new_idx) =
351 collect_statement_error_block(&lines, &mapping, i + 1);
352 i = new_idx;
353
354 let has_regex = regex_pattern.is_some();
355 if let Some(pattern) = regex_pattern {
356 out_lines.push(format!("{indent}connection {conn}"));
357 out_map.push(orig);
358 out_lines.push(format!("{indent}statement error {}", pattern));
359 out_map.push(orig);
360 } else {
361 out_lines.push(normalized.clone());
362 out_map.push(orig);
363 }
364 for (sql_line, sql_map) in sql_lines {
365 out_lines.push(sql_line);
366 out_map.push(sql_map);
367 }
368 if saw_separator && !has_regex && !message_lines.is_empty() {
370 out_lines.push(format!("{indent}----"));
371 out_map.push(orig);
372 for (msg_line, msg_map) in message_lines {
373 out_lines.push(msg_line);
374 out_map.push(msg_map);
375 }
376 out_lines.push(String::new());
378 out_map.push(orig);
379 }
380 out_lines.push(String::new());
381 out_map.push(orig);
382 continue;
383 } else {
384 out_lines.push(normalized);
386 out_map.push(orig);
387 i += 1;
388 continue;
389 }
390 }
391 }
392
393 if trimmed.starts_with("statement error") {
395 let indent = &line[..line.len() - trimmed.len()];
396 let (sql_lines, regex_pattern, message_lines, saw_separator, new_idx) =
397 collect_statement_error_block(&lines, &mapping, i + 1);
398 i = new_idx;
399
400 let has_regex = regex_pattern.is_some();
401 if let Some(pattern) = regex_pattern {
402 out_lines.push(format!("{indent}statement error {}", pattern));
403 out_map.push(orig);
404 } else {
405 out_lines.push(line.clone());
406 out_map.push(orig);
407 }
408 for (sql_line, sql_map) in sql_lines {
409 out_lines.push(sql_line);
410 out_map.push(sql_map);
411 }
412 if saw_separator && !has_regex && !message_lines.is_empty() {
414 out_lines.push(format!("{indent}----"));
415 out_map.push(orig);
416 for (msg_line, msg_map) in message_lines {
417 out_lines.push(msg_line);
418 out_map.push(msg_map);
419 }
420 out_lines.push(String::new());
422 out_map.push(orig);
423 }
424 out_lines.push(String::new());
425 out_map.push(orig);
426 continue;
427 }
428
429 out_lines.push(line.clone());
430 out_map.push(orig);
431 i += 1;
432 }
433
434 (out_lines, out_map)
435}
436
/// Rewrites the first `<tmp_path>:<line>` location in `err_msg` so it points into the
/// original file.
///
/// The temp-file line number is resolved through `mapping`, with `expanded_lines` holding the
/// temp file's contents. The lookup tries, in order, the line after the reported one, the
/// reported line itself, and the line before it. The first of these that exists and is not
/// blank is used. The reported `<tmp_path>:<line>` is then replaced by
/// `<orig_path>:<original line>`.
///
/// Returns the possibly rewritten message and the original line number, if one was resolved.
/// When no location can be resolved, the message is returned unchanged with `None`.
pub fn map_temp_error_message(
    err_msg: &str,
    tmp_path: &Path,
    expanded_lines: &[String],
    mapping: &[usize],
    orig_path: &Path,
) -> (String, Option<usize>) {
    let tmp_str = tmp_path.to_string_lossy().to_string();

    // Parse the decimal line number immediately following `<tmp_path>:`.
    let Some(expanded_line) = err_msg
        .find(tmp_str.as_str())
        .and_then(|pos| err_msg[pos + tmp_str.len()..].strip_prefix(':'))
        .and_then(|rest| {
            let end = rest
                .find(|ch: char| !ch.is_ascii_digit())
                .unwrap_or(rest.len());
            rest[..end].parse::<usize>().ok()
        })
    else {
        return (err_msg.to_string(), None);
    };

    // 0-based index of the reported line; probe the next, same, and previous lines.
    let reported = expanded_line as isize - 1;
    let hit = [1isize, 0, -1]
        .into_iter()
        .filter_map(|offset| usize::try_from(reported + offset).ok())
        .filter(|&idx| idx < mapping.len())
        .find(|&idx| {
            expanded_lines
                .get(idx)
                .is_some_and(|text| !text.trim().is_empty())
        });

    match hit {
        Some(idx) => {
            let orig_line = mapping[idx];
            let rewritten = err_msg.replacen(
                &format!("{}:{}", tmp_str, expanded_line),
                &format!("{}:{}", orig_path.display(), orig_line),
                1,
            );
            (rewritten, Some(orig_line))
        }
        None => (err_msg.to_string(), None),
    }
}