1use std::fmt;
53
54use kimberlite_query::QueryResult;
55
56pub mod duckdb;
57pub mod kimberlite;
58
59pub use self::duckdb::DuckDbOracle;
60pub use self::kimberlite::KimberliteOracle;
61
/// Common interface for a SQL engine driven as a query oracle.
///
/// Implementations execute SQL text and hand back a [`QueryResult`], so the
/// results of two runners can be compared (see `compare_results` in this
/// module).
pub trait OracleRunner {
    /// Executes `sql` and returns the produced result set.
    ///
    /// # Errors
    ///
    /// Returns an [`OracleError`] when the statement cannot be executed.
    /// Which variant is reported for a given failure is left to each
    /// implementation.
    fn execute(&mut self, sql: &str) -> Result<QueryResult, OracleError>;

    /// Resets the runner's state.
    ///
    /// NOTE(review): the exact scope of a reset (dropping tables, clearing
    /// data, reconnecting, …) is implementation-defined and not visible
    /// here — confirm against the `duckdb` / `kimberlite` implementations.
    ///
    /// # Errors
    ///
    /// Returns an [`OracleError`] if the reset fails.
    fn reset(&mut self) -> Result<(), OracleError>;

    /// Returns a short, static identifier for this runner.
    ///
    /// NOTE(review): presumably used as the `left_name` / `right_name`
    /// label passed to `compare_results` — confirm at call sites.
    fn name(&self) -> &'static str;
}
98
/// Errors reported by an [`OracleRunner`] while executing or resetting.
///
/// The `Display` text of each variant comes from its `#[error]` attribute.
/// The string payloads are interpolated verbatim into that text.
#[derive(Debug, thiserror::Error)]
pub enum OracleError {
    /// The SQL text was reported as syntactically invalid.
    #[error("SQL syntax error: {0}")]
    SyntaxError(String),

    /// The SQL parsed but was reported as semantically invalid.
    ///
    /// NOTE(review): the exact boundary with `SyntaxError` (unknown tables,
    /// type errors, …) is decided by each runner — confirm.
    #[error("Semantic error: {0}")]
    SemanticError(String),

    /// A failure that occurred while executing the statement.
    #[error("Runtime error: {0}")]
    RuntimeError(String),

    /// Execution exceeded a time limit. The payload is in milliseconds, per
    /// the `ms` suffix in the message.
    ///
    /// NOTE(review): confirm whether the value is the configured limit or
    /// the elapsed time.
    #[error("Timeout after {0}ms")]
    Timeout(u64),

    /// The runner does not support the requested feature.
    #[error("Unsupported feature: {0}")]
    Unsupported(String),

    /// An unexpected failure inside the runner itself.
    #[error("Internal error: {0}")]
    Internal(String),
}
130
131pub fn compare_results(
150 left: &QueryResult,
151 right: &QueryResult,
152 left_name: &str,
153 right_name: &str,
154) -> Result<(), ResultMismatch> {
155 if left.columns.len() != right.columns.len() {
157 return Err(ResultMismatch::ColumnCountMismatch {
158 left: left.columns.len(),
159 right: right.columns.len(),
160 left_name: left_name.to_string(),
161 right_name: right_name.to_string(),
162 });
163 }
164
165 for (i, (left_col, right_col)) in left.columns.iter().zip(right.columns.iter()).enumerate() {
167 if left_col.as_str() != right_col.as_str() {
168 return Err(ResultMismatch::ColumnNameMismatch {
169 column_index: i,
170 left: left_col.as_str().to_string(),
171 right: right_col.as_str().to_string(),
172 left_name: left_name.to_string(),
173 right_name: right_name.to_string(),
174 });
175 }
176 }
177
178 if left.rows.len() != right.rows.len() {
180 return Err(ResultMismatch::RowCountMismatch {
181 left: left.rows.len(),
182 right: right.rows.len(),
183 left_name: left_name.to_string(),
184 right_name: right_name.to_string(),
185 });
186 }
187
188 for (row_idx, (left_row, right_row)) in left.rows.iter().zip(right.rows.iter()).enumerate() {
190 if left_row.len() != right_row.len() {
191 return Err(ResultMismatch::RowValueCountMismatch {
192 row_index: row_idx,
193 left: left_row.len(),
194 right: right_row.len(),
195 left_name: left_name.to_string(),
196 right_name: right_name.to_string(),
197 });
198 }
199
200 for (col_idx, (left_val, right_val)) in left_row.iter().zip(right_row.iter()).enumerate() {
201 if left_val != right_val {
202 return Err(ResultMismatch::ValueMismatch {
203 row_index: row_idx,
204 column_index: col_idx,
205 left: format!("{left_val:?}"),
206 right: format!("{right_val:?}"),
207 left_name: left_name.to_string(),
208 right_name: right_name.to_string(),
209 });
210 }
211 }
212 }
213
214 Ok(())
215}
216
/// The first difference found by `compare_results` between two query
/// results.
///
/// Every variant carries `left_name` / `right_name`, the caller-supplied
/// labels for the two sides, so a mismatch can be reported without extra
/// context. `PartialEq` / `Eq` are derived so mismatches can be compared
/// directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResultMismatch {
    /// The results have a different number of columns.
    ColumnCountMismatch {
        /// Column count on the left side.
        left: usize,
        /// Column count on the right side.
        right: usize,
        /// Label of the left side.
        left_name: String,
        /// Label of the right side.
        right_name: String,
    },

    /// A column at the same position has a different name on each side.
    ColumnNameMismatch {
        /// Zero-based position of the differing column.
        column_index: usize,
        /// Column name on the left side.
        left: String,
        /// Column name on the right side.
        right: String,
        /// Label of the left side.
        left_name: String,
        /// Label of the right side.
        right_name: String,
    },

    /// The results have a different number of rows.
    RowCountMismatch {
        /// Row count on the left side.
        left: usize,
        /// Row count on the right side.
        right: usize,
        /// Label of the left side.
        left_name: String,
        /// Label of the right side.
        right_name: String,
    },

    /// A row at the same position holds a different number of values.
    RowValueCountMismatch {
        /// Zero-based index of the differing row.
        row_index: usize,
        /// Value count of that row on the left side.
        left: usize,
        /// Value count of that row on the right side.
        right: usize,
        /// Label of the left side.
        left_name: String,
        /// Label of the right side.
        right_name: String,
    },

    /// A single cell differs between the two sides.
    ValueMismatch {
        /// Zero-based index of the row containing the cell.
        row_index: usize,
        /// Zero-based index of the column containing the cell.
        column_index: usize,
        /// Rendering of the left-side value (`compare_results` stores its
        /// `Debug` form).
        left: String,
        /// Rendering of the right-side value (`compare_results` stores its
        /// `Debug` form).
        right: String,
        /// Label of the left side.
        left_name: String,
        /// Label of the right side.
        right_name: String,
    },
}
264
265impl fmt::Display for ResultMismatch {
266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267 match self {
268 ResultMismatch::ColumnCountMismatch {
269 left,
270 right,
271 left_name,
272 right_name,
273 } => {
274 write!(
275 f,
276 "Column count mismatch: {left_name}={left}, {right_name}={right}"
277 )
278 }
279 ResultMismatch::ColumnNameMismatch {
280 column_index,
281 left,
282 right,
283 left_name,
284 right_name,
285 } => {
286 write!(
287 f,
288 "Column name mismatch at index {column_index}: {left_name}='{left}', {right_name}='{right}'"
289 )
290 }
291 ResultMismatch::RowCountMismatch {
292 left,
293 right,
294 left_name,
295 right_name,
296 } => {
297 write!(
298 f,
299 "Row count mismatch: {left_name}={left}, {right_name}={right}"
300 )
301 }
302 ResultMismatch::RowValueCountMismatch {
303 row_index,
304 left,
305 right,
306 left_name,
307 right_name,
308 } => {
309 write!(
310 f,
311 "Row value count mismatch at row {row_index}: {left_name}={left}, {right_name}={right}"
312 )
313 }
314 ResultMismatch::ValueMismatch {
315 row_index,
316 column_index,
317 left,
318 right,
319 left_name,
320 right_name,
321 } => {
322 write!(
323 f,
324 "Value mismatch at row {row_index}, column {column_index}: {left_name}={left}, {right_name}={right}"
325 )
326 }
327 }
328 }
329}
330
/// Lets a [`ResultMismatch`] be used as a standard error (for example as
/// `Box<dyn std::error::Error>`). The message comes from its `Display` impl.
/// No methods are overridden, so `source()` uses the default and reports no
/// underlying cause.
impl std::error::Error for ResultMismatch {}
332
#[cfg(test)]
mod tests {
    // Unit tests for `compare_results` (one test per mismatch variant, plus
    // check ordering) and for the `Display` text of `ResultMismatch`.
    use super::*;
    use kimberlite_query::{ColumnName, Value};

    #[test]
    fn test_compare_results_identical() {
        let result1 = QueryResult {
            columns: vec![ColumnName::from("id"), ColumnName::from("name")],
            rows: vec![
                vec![Value::BigInt(1), Value::Text("Alice".to_string())],
                vec![Value::BigInt(2), Value::Text("Bob".to_string())],
            ],
        };

        let result2 = result1.clone();

        assert!(compare_results(&result1, &result2, "left", "right").is_ok());
    }

    #[test]
    fn test_compare_results_empty() {
        let result1 = QueryResult {
            columns: vec![],
            rows: vec![],
        };

        let result2 = result1.clone();

        assert!(compare_results(&result1, &result2, "left", "right").is_ok());
    }

    #[test]
    fn test_compare_results_column_count_mismatch() {
        let result1 = QueryResult {
            columns: vec![ColumnName::from("id"), ColumnName::from("name")],
            rows: vec![],
        };

        let result2 = QueryResult {
            columns: vec![ColumnName::from("id")],
            rows: vec![],
        };

        let err = compare_results(&result1, &result2, "left", "right").unwrap_err();
        assert!(matches!(
            err,
            ResultMismatch::ColumnCountMismatch { left: 2, right: 1, .. }
        ));
    }

    #[test]
    fn test_compare_results_column_checks_precede_row_checks() {
        // Both the column count and the row count differ; the column check
        // runs first, so it is the one reported.
        let result1 = QueryResult {
            columns: vec![ColumnName::from("id"), ColumnName::from("name")],
            rows: vec![vec![Value::BigInt(1), Value::Text("Alice".to_string())]],
        };

        let result2 = QueryResult {
            columns: vec![ColumnName::from("id")],
            rows: vec![],
        };

        let err = compare_results(&result1, &result2, "left", "right").unwrap_err();
        assert!(matches!(err, ResultMismatch::ColumnCountMismatch { .. }));
    }

    #[test]
    fn test_compare_results_column_name_mismatch() {
        let result1 = QueryResult {
            columns: vec![ColumnName::from("id"), ColumnName::from("name")],
            rows: vec![],
        };

        let result2 = QueryResult {
            columns: vec![ColumnName::from("id"), ColumnName::from("email")],
            rows: vec![],
        };

        let err = compare_results(&result1, &result2, "left", "right").unwrap_err();
        match err {
            ResultMismatch::ColumnNameMismatch {
                column_index,
                left,
                right,
                ..
            } => {
                assert_eq!(column_index, 1);
                assert_eq!(left, "name");
                assert_eq!(right, "email");
            }
            other => panic!("expected ColumnNameMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_compare_results_row_count_mismatch() {
        let result1 = QueryResult {
            columns: vec![ColumnName::from("id")],
            rows: vec![vec![Value::BigInt(1)], vec![Value::BigInt(2)]],
        };

        let result2 = QueryResult {
            columns: vec![ColumnName::from("id")],
            rows: vec![vec![Value::BigInt(1)]],
        };

        let err = compare_results(&result1, &result2, "left", "right").unwrap_err();
        assert!(matches!(
            err,
            ResultMismatch::RowCountMismatch { left: 2, right: 1, .. }
        ));
    }

    #[test]
    fn test_compare_results_row_value_count_mismatch() {
        // Same row count, but the second row has an extra value on the right.
        let result1 = QueryResult {
            columns: vec![ColumnName::from("id")],
            rows: vec![vec![Value::BigInt(1)], vec![Value::BigInt(2)]],
        };

        let result2 = QueryResult {
            columns: vec![ColumnName::from("id")],
            rows: vec![
                vec![Value::BigInt(1)],
                vec![Value::BigInt(2), Value::BigInt(3)],
            ],
        };

        let err = compare_results(&result1, &result2, "left", "right").unwrap_err();
        assert!(matches!(
            err,
            ResultMismatch::RowValueCountMismatch {
                row_index: 1,
                left: 1,
                right: 2,
                ..
            }
        ));
    }

    #[test]
    fn test_compare_results_value_mismatch() {
        let result1 = QueryResult {
            columns: vec![ColumnName::from("id")],
            rows: vec![vec![Value::BigInt(1)]],
        };

        let result2 = QueryResult {
            columns: vec![ColumnName::from("id")],
            rows: vec![vec![Value::BigInt(2)]],
        };

        let err = compare_results(&result1, &result2, "left", "right").unwrap_err();
        match err {
            ResultMismatch::ValueMismatch {
                row_index,
                column_index,
                left,
                right,
                left_name,
                right_name,
            } => {
                assert_eq!((row_index, column_index), (0, 0));
                // Values are recorded via their `Debug` representation.
                assert_eq!(left, format!("{:?}", Value::BigInt(1)));
                assert_eq!(right, format!("{:?}", Value::BigInt(2)));
                assert_eq!(left_name, "left");
                assert_eq!(right_name, "right");
            }
            other => panic!("expected ValueMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_result_mismatch_display() {
        let err = ResultMismatch::RowCountMismatch {
            left: 2,
            right: 1,
            left_name: "duckdb".to_string(),
            right_name: "kimberlite".to_string(),
        };
        assert_eq!(err.to_string(), "Row count mismatch: duckdb=2, kimberlite=1");

        let err = ResultMismatch::ColumnNameMismatch {
            column_index: 1,
            left: "name".to_string(),
            right: "email".to_string(),
            left_name: "duckdb".to_string(),
            right_name: "kimberlite".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Column name mismatch at index 1: duckdb='name', kimberlite='email'"
        );
    }
}