use serde_json::Value;
use sql_cli::data::csv_datasource::CsvApiClient;
use std::collections::HashMap;
use std::fs;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CapturedQuery {
pub description: String,
pub data_file: String,
pub query: String,
pub expected_row_count: usize,
pub expected_columns: Vec<String>,
pub expected_first_row: Option<HashMap<String, Value>>,
pub case_insensitive: bool,
}
pub struct QueryReplayHarness {
queries: Vec<CapturedQuery>,
}
impl Default for QueryReplayHarness {
fn default() -> Self {
Self::new()
}
}
impl QueryReplayHarness {
#[must_use]
pub fn new() -> Self {
Self {
queries: Vec::new(),
}
}
pub fn add_query(&mut self, query: CapturedQuery) {
self.queries.push(query);
}
pub fn load_from_file(file_path: &str) -> anyhow::Result<Self> {
let json_content = fs::read_to_string(file_path)?;
let queries: Vec<CapturedQuery> = serde_json::from_str(&json_content)?;
Ok(Self { queries })
}
pub fn run_all_tests(&self) -> anyhow::Result<()> {
for (i, query) in self.queries.iter().enumerate() {
println!("Running captured query {}: {}", i + 1, query.description);
self.run_single_test(query)?;
}
println!("All {} captured queries passed!", self.queries.len());
Ok(())
}
pub fn run_single_test(&self, query: &CapturedQuery) -> anyhow::Result<()> {
let mut csv_client = CsvApiClient::new();
csv_client.set_case_insensitive(query.case_insensitive);
if query.data_file.ends_with(".json") {
csv_client.load_json(&query.data_file, "data")?;
} else if query.data_file.ends_with(".csv") {
csv_client.load_csv(&query.data_file, "data")?;
} else {
return Err(anyhow::anyhow!(
"Unsupported file type: {}",
query.data_file
));
}
let response = csv_client.query_csv(&query.query)?;
if response.data.len() != query.expected_row_count {
return Err(anyhow::anyhow!(
"Row count mismatch for query '{}': expected {}, got {}",
query.description,
query.expected_row_count,
response.data.len()
));
}
if !query.expected_columns.is_empty() && !response.data.is_empty() {
if let Some(first_row) = response.data.first() {
if let Some(obj) = first_row.as_object() {
let actual_columns: std::collections::HashSet<_> = obj.keys().collect();
let expected_columns: std::collections::HashSet<_> =
query.expected_columns.iter().collect();
if actual_columns != expected_columns {
return Err(anyhow::anyhow!(
"Column mismatch for query '{}': expected {:?}, got {:?}",
query.description,
query.expected_columns,
obj.keys().collect::<Vec<_>>()
));
}
}
}
}
if let Some(ref expected_first) = query.expected_first_row {
if let Some(actual_first) = response.data.first() {
for (key, expected_value) in expected_first {
if let Some(actual_value) = actual_first.get(key) {
if actual_value != expected_value {
return Err(anyhow::anyhow!(
"First row data mismatch for query '{}' column '{}': expected {:?}, got {:?}",
query.description,
key,
expected_value,
actual_value
));
}
} else {
return Err(anyhow::anyhow!(
"Missing column '{}' in first row for query '{}'",
key,
query.description
));
}
}
}
}
println!("✓ Query '{}' passed", query.description);
Ok(())
}
}
#[must_use]
pub fn suggest_test_from_debug_dump(_debug_content: &str, data_file: &str) -> String {
format!(
"// Suggested test from TUI debug dump:\n\
harness.add_query(CapturedQuery {{\n\
\x20 description: \"Query captured from TUI debug session\".to_string(),\n\
\x20 data_file: \"{data_file}\".to_string(),\n\
\x20 query: \"SELECT * FROM data\".to_string(), // TODO: Extract from debug\n\
\x20 expected_row_count: 0, // TODO: Count from debug output\n\
\x20 expected_columns: vec![], // TODO: Extract from debug output\n\
\x20 expected_first_row: None, // TODO: Extract first row if needed\n\
\x20 case_insensitive: false,\n\
}});"
)
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[test]
#[ignore = "Aggregate functions (COUNT, SUM, AVG) not yet implemented in parser"]
fn test_captured_queries_from_sample_data() -> anyhow::Result<()> {
let mut harness = QueryReplayHarness::new();
harness.add_query(CapturedQuery {
description: "Get all completed trades".to_string(),
data_file: "sample_trades.json".to_string(),
query: "SELECT * FROM data WHERE status = 'Completed'".to_string(),
expected_row_count: 3, expected_columns: vec![
"id".to_string(),
"platformOrderId".to_string(),
"tradeDate".to_string(),
"executionSide".to_string(),
"quantity".to_string(),
"price".to_string(),
"counterparty".to_string(),
"counterpartyCountry".to_string(),
"commission".to_string(),
"status".to_string(),
"trader".to_string(),
],
expected_first_row: None, case_insensitive: false,
});
harness.add_query(CapturedQuery {
description: "Complex aggregation by country".to_string(),
data_file: "sample_trades.json".to_string(),
query: "SELECT counterpartyCountry, COUNT(*) as trade_count, SUM(quantity) as total_quantity, AVG(price) as avg_price FROM data GROUP BY counterpartyCountry ORDER BY trade_count DESC".to_string(),
expected_row_count: 4, expected_columns: vec![
"counterpartyCountry".to_string(),
"trade_count".to_string(),
"total_quantity".to_string(),
"avg_price".to_string(),
],
expected_first_row: None,
case_insensitive: false,
});
harness.add_query(CapturedQuery {
description: "High value trades with complex filtering".to_string(),
data_file: "sample_trades.json".to_string(),
query: "SELECT trader, platformOrderId, quantity * price as trade_value FROM data WHERE quantity * price > 100000 AND status = 'Completed' ORDER BY trade_value DESC".to_string(),
expected_row_count: 3, expected_columns: vec![
"trader".to_string(),
"platformOrderId".to_string(),
"trade_value".to_string(),
],
expected_first_row: None,
case_insensitive: false,
});
if std::path::Path::new("sample_trades.json").exists() {
harness.run_all_tests()?;
} else {
println!("Skipping captured query tests - sample_trades.json not found");
}
Ok(())
}
#[test]
fn test_complex_string_and_date_operations() -> anyhow::Result<()> {
let temp_dir = tempdir()?;
let test_file = temp_dir.path().join("complex_test.json");
let complex_data = serde_json::json!([
{
"id": 1,
"trader": "John Smith Jr.",
"email": "john.smith@bank.com",
"tradeDate": "2024-01-15",
"executionTime": "14:30:25",
"description": "Large block trade for pension fund",
"tags": ["institutional", "pension", "large-cap"],
"metadata": {"risk_level": "low", "client_tier": "premium"}
},
{
"id": 2,
"trader": "Sarah O'Connor",
"email": "sarah.oconnor@hedge.com",
"tradeDate": "2024-01-16",
"executionTime": "09:15:00",
"description": "High-frequency algorithmic trade",
"tags": ["algorithmic", "hft", "momentum"],
"metadata": {"risk_level": "high", "client_tier": "standard"}
},
{
"id": 3,
"trader": "李小明", "email": "xiaoming.li@asia.com",
"tradeDate": "2024-01-17",
"executionTime": "23:45:12",
"description": "Cross-border arbitrage opportunity",
"tags": ["arbitrage", "cross-border", "forex"],
"metadata": {"risk_level": "medium", "client_tier": "vip"}
}
]);
fs::write(&test_file, serde_json::to_string_pretty(&complex_data)?)?;
let mut harness = QueryReplayHarness::new();
harness.add_query(CapturedQuery {
description: "Case-insensitive string matching with apostrophes".to_string(),
data_file: test_file.to_str().unwrap().to_string(),
query: "SELECT * FROM data WHERE trader LIKE '%Sarah%'".to_string(),
expected_row_count: 1,
expected_columns: vec![],
expected_first_row: None,
case_insensitive: false,
});
harness.add_query(CapturedQuery {
description: "Filter by nested JSON metadata".to_string(),
data_file: test_file.to_str().unwrap().to_string(),
query: "SELECT trader, description FROM data WHERE description LIKE '%algorithmic%'"
.to_string(),
expected_row_count: 1,
expected_columns: vec!["trader".to_string(), "description".to_string()],
expected_first_row: None,
case_insensitive: false,
});
harness.add_query(CapturedQuery {
description: "Unicode trader names".to_string(),
data_file: test_file.to_str().unwrap().to_string(),
query: "SELECT * FROM data WHERE trader = '李小明'".to_string(),
expected_row_count: 1,
expected_columns: vec![],
expected_first_row: None,
case_insensitive: false,
});
harness.run_all_tests()?;
Ok(())
}
#[test]
#[ignore = "Aggregate functions (COUNT, SUM, AVG, MAX) and GROUP BY not yet implemented"]
fn test_from_yanked_tui_session() -> anyhow::Result<()> {
let yanked_query_session = r"
SELECT
trader,
COUNT(*) as trade_count,
SUM(quantity * price) as total_value,
AVG(commission) as avg_commission,
MAX(tradeDate) as last_trade_date
FROM data
WHERE status = 'Completed'
AND counterpartyCountry IN ('US', 'JP')
AND quantity > 500
GROUP BY trader
HAVING trade_count >= 1
ORDER BY total_value DESC
"
.trim();
let mut harness = QueryReplayHarness::new();
harness.add_query(CapturedQuery {
description: "Complex yanked query from TUI session".to_string(),
data_file: "sample_trades.json".to_string(),
query: yanked_query_session.to_string(),
expected_row_count: 1, expected_columns: vec![
"trader".to_string(),
"trade_count".to_string(),
"total_value".to_string(),
"avg_commission".to_string(),
"last_trade_date".to_string(),
],
expected_first_row: None,
case_insensitive: false,
});
if std::path::Path::new("sample_trades.json").exists() {
harness.run_all_tests()?;
} else {
println!("Skipping yanked query test - sample_trades.json not found");
}
Ok(())
}
}