Skip to main content

mdql_core/
executor.rs

1//! Unified SQL execution — single entry point for CLI, REPL, and web server.
2
3use std::path::Path;
4
5use crate::api::Table;
6use crate::cascade;
7use crate::database::{ViewDef, is_database_dir, load_database_config, save_database_config};
8use crate::errors::{MdqlError, ValidationError};
9use crate::model::Row;
10use crate::query_ast::*;
11use crate::query_engine::{execute_join_query, execute_query};
12use crate::query_parser::{Statement, parse_query};
13use crate::schema::Schema;
14
15#[derive(Debug)]
16pub enum QueryResult {
17    Rows { rows: Vec<Row>, columns: Vec<String> },
18    Message(String),
19}
20
21pub fn execute(path: &Path, sql: &str) -> crate::errors::Result<(QueryResult, Vec<ValidationError>)> {
22    let stmt = parse_query(sql)?;
23    let is_db = is_database_dir(path);
24
25    match stmt {
26        Statement::Select(ref q) => {
27            let has_ctes = !q.ctes.is_empty();
28            let has_subqueries = query_has_subqueries(q);
29            let needs_db = has_ctes || has_subqueries || q.subquery.is_some() || !q.joins.is_empty() || is_db;
30
31            if has_ctes && !is_db {
32                return Err(MdqlError::QueryExecution(
33                    "CTEs (WITH) require a database directory".into(),
34                ));
35            }
36
37            if needs_db {
38                let (_config, mut tables, errors) = crate::loader::load_database(path)?;
39
40                for cte in &q.ctes {
41                    let (rows, cols) = materialize_cte(&cte.query, &tables)?;
42                    let schema = crate::loader::build_view_schema(&cte.name, &cols, &rows);
43                    tables.insert(cte.name.clone(), (schema, rows));
44                }
45
46                let mut q = q.clone();
47                if has_subqueries {
48                    materialize_subqueries(&mut q, &tables)?;
49                }
50
51                let (rows, cols) = if let Some(ref sub) = q.subquery {
52                    let source_table = &sub.table;
53                    let (schema, table_rows) = tables.get(source_table).ok_or_else(|| {
54                        MdqlError::QueryExecution(format!(
55                            "table '{}' not found in database",
56                            source_table
57                        ))
58                    })?;
59                    execute_query(&q, table_rows, schema)?
60                } else if !q.joins.is_empty() {
61                    execute_join_query(&q, &tables)?
62                } else {
63                    let (schema, rows) = tables.get(&q.table).ok_or_else(|| {
64                        MdqlError::QueryExecution(format!(
65                            "table '{}' not found in database",
66                            q.table
67                        ))
68                    })?;
69                    execute_query(&q, rows, schema)?
70                };
71                Ok((QueryResult::Rows { rows, columns: cols }, errors))
72            } else {
73                let (schema, rows, errors) = crate::loader::load_table(path)?;
74                let (rows, cols) = execute_query(q, &rows, &schema)?;
75                Ok((QueryResult::Rows { rows, columns: cols }, errors))
76            }
77        }
78        Statement::CreateView(ref cv) => {
79            if !is_db {
80                return Err(MdqlError::QueryExecution(
81                    "CREATE VIEW requires a database directory".into(),
82                ));
83            }
84            let mut config = load_database_config(path)?;
85
86            let (_config_check, tables, _errors) = crate::loader::load_database(path)?;
87            if tables.contains_key(&cv.view_name) {
88                return Err(MdqlError::QueryExecution(format!(
89                    "Name '{}' already exists as a table or view",
90                    cv.view_name
91                )));
92            }
93
94            if config.views.iter().any(|v| v.name == cv.view_name) {
95                return Err(MdqlError::QueryExecution(format!(
96                    "View '{}' already exists",
97                    cv.view_name
98                )));
99            }
100
101            let query_str = extract_view_query(sql)?;
102
103            let view_def = ViewDef {
104                name: cv.view_name.clone(),
105                query: query_str,
106            };
107
108            let test_result = crate::loader::load_database(path);
109            if let Ok((_cfg, test_tables, _errs)) = test_result {
110                let test_view = ViewDef {
111                    name: view_def.name.clone(),
112                    query: view_def.query.clone(),
113                };
114                if let Err(e) = super::loader::materialize_view(&test_view, &test_tables) {
115                    return Err(MdqlError::QueryExecution(format!(
116                        "View query failed validation: {}",
117                        e
118                    )));
119                }
120            }
121
122            config.views.push(view_def);
123            save_database_config(path, &config)?;
124            Ok((
125                QueryResult::Message(format!("View '{}' created", cv.view_name)),
126                vec![],
127            ))
128        }
129        Statement::DropView(ref dv) => {
130            if !is_db {
131                return Err(MdqlError::QueryExecution(
132                    "DROP VIEW requires a database directory".into(),
133                ));
134            }
135            let mut config = load_database_config(path)?;
136            let len_before = config.views.len();
137            config.views.retain(|v| v.name != dv.view_name);
138            if config.views.len() == len_before {
139                return Err(MdqlError::QueryExecution(format!(
140                    "View '{}' does not exist",
141                    dv.view_name
142                )));
143            }
144            save_database_config(path, &config)?;
145            Ok((
146                QueryResult::Message(format!("View '{}' dropped", dv.view_name)),
147                vec![],
148            ))
149        }
150        Statement::Delete(ref dq) if dq.mode != DeleteMode::Default => {
151            if !is_db {
152                return Err(MdqlError::QueryExecution(
153                    "CASCADE/RESTRICT requires a database directory".into(),
154                ));
155            }
156            let config = load_database_config(path)?;
157            if config.views.iter().any(|v| v.name == dq.table) {
158                return Err(MdqlError::QueryExecution(format!(
159                    "Cannot write to view '{}' — views are read-only",
160                    dq.table
161                )));
162            }
163            let (_cfg, tables, errors) = crate::loader::load_database(path)?;
164            let (_, rows) = tables.get(&dq.table).ok_or_else(|| {
165                MdqlError::QueryExecution(format!("table '{}' not found in database", dq.table))
166            })?;
167            let matched_filenames: Vec<String> = if let Some(ref wc) = dq.where_clause {
168                rows.iter()
169                    .filter(|r| crate::query_engine::evaluate(wc, r))
170                    .filter_map(|r| r.get("path").and_then(|v| v.as_str()).map(|s| s.to_string()))
171                    .collect()
172            } else {
173                rows.iter()
174                    .filter_map(|r| r.get("path").and_then(|v| v.as_str()).map(|s| s.to_string()))
175                    .collect()
176            };
177
178            match dq.mode {
179                DeleteMode::Cascade => {
180                    let plan = cascade::build_cascade_plan(
181                        &dq.table, &matched_filenames, &config, &tables,
182                    );
183                    let msg = cascade::execute_cascade_plan(&plan, path)?;
184                    Ok((QueryResult::Message(msg), errors))
185                }
186                DeleteMode::Restrict => {
187                    let plan = cascade::build_restrict_plan(
188                        &dq.table, &matched_filenames, &config, &tables,
189                    );
190                    if !plan.restrict_violations.is_empty() {
191                        let violations = plan.restrict_violations.join("\n  ");
192                        return Err(MdqlError::QueryExecution(format!(
193                            "RESTRICT: cannot delete — {} dependent references:\n  {}",
194                            plan.restrict_violations.len(),
195                            violations,
196                        )));
197                    }
198                    let table_path = path.join(&dq.table);
199                    let table = Table::new(&table_path)?;
200                    let msg = table.exec_delete_matched(&matched_filenames)?;
201                    Ok((QueryResult::Message(msg), errors))
202                }
203                DeleteMode::Default => unreachable!(),
204            }
205        }
206        ref stmt @ (Statement::Insert(_)
207        | Statement::Update(_)
208        | Statement::Delete(_)
209        | Statement::AlterRename(_)
210        | Statement::AlterDrop(_)
211        | Statement::AlterMerge(_)) => {
212            if is_db {
213                let config = load_database_config(path)?;
214                let target = stmt.table_name();
215                if config.views.iter().any(|v| v.name == target) {
216                    return Err(MdqlError::QueryExecution(format!(
217                        "Cannot write to view '{}' — views are read-only",
218                        target
219                    )));
220                }
221            }
222            let table_path = if is_db {
223                path.join(stmt.table_name())
224            } else {
225                path.to_path_buf()
226            };
227            let mut table = Table::new(&table_path)?;
228            let msg = table.execute_sql(sql)?;
229            Ok((QueryResult::Message(msg), vec![]))
230        }
231    }
232}
233
234pub fn materialize_cte(
235    query: &crate::query_ast::SelectQuery,
236    tables: &std::collections::HashMap<String, (crate::schema::Schema, Vec<Row>)>,
237) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
238    if let Some(ref sub) = query.subquery {
239        let (_, sub_rows) = tables.get(&sub.table).ok_or_else(|| {
240            MdqlError::QueryExecution(format!("table '{}' not found in database", sub.table))
241        })?;
242        let (sub_rows, _) = execute_query(sub, sub_rows, &tables.get(&sub.table).unwrap().0)?;
243        execute_query(query, &sub_rows, &tables.get(&sub.table).unwrap().0)
244    } else if !query.joins.is_empty() {
245        execute_join_query(query, tables)
246    } else {
247        let (schema, rows) = tables.get(&query.table).ok_or_else(|| {
248            MdqlError::QueryExecution(format!("table '{}' not found in database", query.table))
249        })?;
250        execute_query(query, rows, schema)
251    }
252}
253
254type Tables = std::collections::HashMap<String, (Schema, Vec<Row>)>;
255
256fn query_has_subqueries(q: &SelectQuery) -> bool {
257    if let Some(ref wc) = q.where_clause {
258        if where_has_subquery(wc) { return true; }
259    }
260    if let ColumnList::Named(ref exprs) = q.columns {
261        for se in exprs {
262            match se {
263                SelectExpr::Expr { expr, .. } => {
264                    if expr_has_subquery(expr) { return true; }
265                }
266                SelectExpr::Aggregate { arg_expr: Some(e), .. } => {
267                    if expr_has_subquery(e) { return true; }
268                }
269                _ => {}
270            }
271        }
272    }
273    false
274}
275
276fn where_has_subquery(wc: &WhereClause) -> bool {
277    match wc {
278        WhereClause::BoolOp(bop) => where_has_subquery(&bop.left) || where_has_subquery(&bop.right),
279        WhereClause::Comparison(cmp) => {
280            cmp.left_expr.as_ref().map_or(false, |e| expr_has_subquery(e))
281                || cmp.right_expr.as_ref().map_or(false, |e| expr_has_subquery(e))
282        }
283    }
284}
285
286fn expr_has_subquery(expr: &Expr) -> bool {
287    match expr {
288        Expr::Subquery(_) => true,
289        Expr::BinaryOp { left, right, .. } => expr_has_subquery(left) || expr_has_subquery(right),
290        Expr::UnaryMinus(inner) => expr_has_subquery(inner),
291        Expr::Case { whens, else_expr } => {
292            whens.iter().any(|(c, e)| where_has_subquery(c) || expr_has_subquery(e))
293                || else_expr.as_ref().map_or(false, |e| expr_has_subquery(e))
294        }
295        _ => false,
296    }
297}
298
299pub fn materialize_subqueries(
300    query: &mut SelectQuery,
301    tables: &Tables,
302) -> crate::errors::Result<()> {
303    if let Some(ref mut wc) = query.where_clause {
304        materialize_in_where(wc, tables)?;
305    }
306    if let ColumnList::Named(ref mut exprs) = query.columns {
307        for se in exprs.iter_mut() {
308            match se {
309                SelectExpr::Expr { ref mut expr, .. } => {
310                    materialize_in_expr(expr, tables)?;
311                }
312                SelectExpr::Aggregate { ref mut arg_expr, .. } => {
313                    if let Some(ref mut e) = arg_expr {
314                        materialize_in_expr(e, tables)?;
315                    }
316                }
317                _ => {}
318            }
319        }
320    }
321    Ok(())
322}
323
324fn materialize_in_where(wc: &mut WhereClause, tables: &Tables) -> crate::errors::Result<()> {
325    match wc {
326        WhereClause::BoolOp(ref mut bop) => {
327            materialize_in_where(&mut bop.left, tables)?;
328            materialize_in_where(&mut bop.right, tables)?;
329        }
330        WhereClause::Comparison(ref mut cmp) => {
331            if let Some(ref mut expr) = cmp.left_expr {
332                materialize_in_expr(expr, tables)?;
333            }
334            if let Some(ref mut expr) = cmp.right_expr {
335                if let Expr::Subquery(ref sq) = expr {
336                    let (rows, _cols) = materialize_cte(sq, tables)?;
337                    if cmp.op == CmpOp::In {
338                        let values: Vec<SqlValue> = rows.iter()
339                            .filter_map(|r| r.values().next())
340                            .map(|v| value_to_sql_value(v))
341                            .collect();
342                        cmp.value = Some(SqlValue::List(values.clone()));
343                        cmp.right_expr = None;
344                    } else {
345                        let val = rows.first()
346                            .and_then(|r| r.values().next())
347                            .map(|v| value_to_sql_value(v))
348                            .unwrap_or(SqlValue::Null);
349                        *expr = Expr::Literal(val);
350                    }
351                } else {
352                    materialize_in_expr(expr, tables)?;
353                }
354            }
355        }
356    }
357    Ok(())
358}
359
360fn materialize_in_expr(expr: &mut Expr, tables: &Tables) -> crate::errors::Result<()> {
361    match expr {
362        Expr::Subquery(ref sq) => {
363            let (rows, _cols) = materialize_cte(sq, tables)?;
364            let val = rows.first()
365                .and_then(|r| r.values().next())
366                .map(|v| value_to_sql_value(v))
367                .unwrap_or(SqlValue::Null);
368            *expr = Expr::Literal(val);
369        }
370        Expr::BinaryOp { ref mut left, ref mut right, .. } => {
371            materialize_in_expr(left, tables)?;
372            materialize_in_expr(right, tables)?;
373        }
374        Expr::UnaryMinus(ref mut inner) => {
375            materialize_in_expr(inner, tables)?;
376        }
377        Expr::Case { ref mut whens, ref mut else_expr } => {
378            for (ref mut cond, ref mut result) in whens.iter_mut() {
379                materialize_in_where(cond, tables)?;
380                materialize_in_expr(result, tables)?;
381            }
382            if let Some(ref mut e) = else_expr {
383                materialize_in_expr(e, tables)?;
384            }
385        }
386        _ => {}
387    }
388    Ok(())
389}
390
391fn value_to_sql_value(v: &crate::model::Value) -> SqlValue {
392    match v {
393        crate::model::Value::String(s) => SqlValue::String(s.clone()),
394        crate::model::Value::Int(n) => SqlValue::Int(*n),
395        crate::model::Value::Float(f) => SqlValue::Float(*f),
396        crate::model::Value::Bool(b) => SqlValue::Int(if *b { 1 } else { 0 }),
397        _ => SqlValue::Null,
398    }
399}
400
401fn extract_view_query(sql: &str) -> crate::errors::Result<String> {
402    let upper = sql.to_uppercase();
403    let as_keyword = upper.find(" AS ");
404    if let Some(pos) = as_keyword {
405        let after = &sql[pos + 4..];
406        let trimmed = after.trim_start();
407        let trimmed_upper = trimmed.to_uppercase();
408        if trimmed_upper.starts_with("SELECT") {
409            return Ok(trimmed.to_string());
410        }
411    }
412    // Fallback: scan for any whitespace-surrounded AS that precedes SELECT
413    let bytes = upper.as_bytes();
414    let mut i = 0;
415    while i + 4 < bytes.len() {
416        if bytes[i].is_ascii_whitespace()
417            && bytes[i + 1] == b'A'
418            && bytes[i + 2] == b'S'
419            && bytes[i + 3].is_ascii_whitespace()
420        {
421            let after = &sql[i + 3..];
422            let trimmed = after.trim_start();
423            let trimmed_upper = trimmed.to_uppercase();
424            if trimmed_upper.starts_with("SELECT") {
425                return Ok(trimmed.to_string());
426            }
427        }
428        i += 1;
429    }
430    Err(crate::errors::MdqlError::QueryExecution(
431        "CREATE VIEW must contain AS clause followed by SELECT".into(),
432    ))
433}
434
435#[cfg(test)]
436mod tests {
437    use super::*;
438    use crate::model::Value;
439    use std::fs;
440
441    fn make_test_db() -> tempfile::TempDir {
442        let dir = tempfile::tempdir().unwrap();
443
444        // Database-level _mdql.md
445        fs::write(
446            dir.path().join("_mdql.md"),
447            "---\ntype: database\nname: testdb\n---\n",
448        )
449        .unwrap();
450
451        // Table: strategies
452        let strats = dir.path().join("strategies");
453        fs::create_dir(&strats).unwrap();
454        fs::write(
455            strats.join("_mdql.md"),
456            "---\ntype: schema\ntable: strategies\nprimary_key: path\nfrontmatter:\n  title:\n    type: string\n  status:\n    type: string\n---\n",
457        )
458        .unwrap();
459        fs::write(
460            strats.join("alpha.md"),
461            "---\ntitle: Alpha\nstatus: LIVE\n---\n# Alpha\n",
462        )
463        .unwrap();
464        fs::write(
465            strats.join("beta.md"),
466            "---\ntitle: Beta\nstatus: DRAFT\n---\n# Beta\n",
467        )
468        .unwrap();
469
470        dir
471    }
472
473    #[test]
474    fn test_create_and_query_view() {
475        let dir = make_test_db();
476        let (result, _) = execute(
477            dir.path(),
478            "CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'",
479        )
480        .unwrap();
481        assert!(matches!(result, QueryResult::Message(ref m) if m.contains("created")));
482
483        let (result, _) = execute(dir.path(), "SELECT * FROM live").unwrap();
484        if let QueryResult::Rows { rows, columns } = result {
485            assert_eq!(rows.len(), 1);
486            assert!(columns.contains(&"title".to_string()));
487        } else {
488            panic!("Expected Rows");
489        }
490    }
491
492    #[test]
493    fn test_drop_view() {
494        let dir = make_test_db();
495        execute(
496            dir.path(),
497            "CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'",
498        )
499        .unwrap();
500
501        let (result, _) = execute(dir.path(), "DROP VIEW live").unwrap();
502        assert!(matches!(result, QueryResult::Message(ref m) if m.contains("dropped")));
503
504        let err = execute(dir.path(), "SELECT * FROM live");
505        assert!(err.is_err());
506    }
507
508    #[test]
509    fn test_drop_nonexistent_view() {
510        let dir = make_test_db();
511        let err = execute(dir.path(), "DROP VIEW nonexistent");
512        assert!(err.is_err());
513        assert!(err.unwrap_err().to_string().contains("does not exist"));
514    }
515
516    #[test]
517    fn test_create_view_duplicate_name() {
518        let dir = make_test_db();
519        execute(
520            dir.path(),
521            "CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'",
522        )
523        .unwrap();
524
525        let err = execute(
526            dir.path(),
527            "CREATE VIEW live AS SELECT * FROM strategies",
528        );
529        assert!(err.is_err());
530        assert!(err.unwrap_err().to_string().contains("already exists"));
531    }
532
533    #[test]
534    fn test_create_view_conflicts_with_table() {
535        let dir = make_test_db();
536        let err = execute(
537            dir.path(),
538            "CREATE VIEW strategies AS SELECT * FROM strategies",
539        );
540        assert!(err.is_err());
541        assert!(err.unwrap_err().to_string().contains("already exists"));
542    }
543
544    #[test]
545    fn test_write_to_view_rejected() {
546        let dir = make_test_db();
547        execute(
548            dir.path(),
549            "CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'",
550        )
551        .unwrap();
552
553        let err = execute(
554            dir.path(),
555            "INSERT INTO live (title, status) VALUES ('Gamma', 'LIVE')",
556        );
557        assert!(err.is_err());
558        assert!(err.unwrap_err().to_string().contains("read-only"));
559    }
560
561    #[test]
562    fn test_create_view_not_database() {
563        let dir = tempfile::tempdir().unwrap();
564        fs::write(
565            dir.path().join("_mdql.md"),
566            "---\ntype: schema\ntable: t\nprimary_key: path\nfrontmatter:\n  x:\n    type: string\n---\n",
567        )
568        .unwrap();
569
570        let err = execute(
571            dir.path(),
572            "CREATE VIEW v AS SELECT * FROM t",
573        );
574        assert!(err.is_err());
575        assert!(err.unwrap_err().to_string().contains("database directory"));
576    }
577
578    #[test]
579    fn test_extract_view_query_basic() {
580        let q = extract_view_query("CREATE VIEW v AS SELECT * FROM t").unwrap();
581        assert!(q.starts_with("SELECT"));
582    }
583
584    #[test]
585    fn test_extract_view_query_with_column_alias() {
586        let q = extract_view_query(
587            "CREATE VIEW v AS SELECT token, SUM(size) as sell_size FROM orders GROUP BY token HAVING sell_size > 0"
588        ).unwrap();
589        assert!(q.starts_with("SELECT"));
590        assert!(q.contains("HAVING"));
591    }
592
593    #[test]
594    fn test_extract_view_query_newline_after_as() {
595        let q = extract_view_query("CREATE VIEW v AS\nSELECT * FROM t").unwrap();
596        assert!(q.starts_with("SELECT"));
597    }
598
599    #[test]
600    fn test_create_view_with_aggregate_arithmetic() {
601        let dir = make_test_db();
602        let result = execute(
603            dir.path(),
604            "CREATE VIEW v AS SELECT status, COUNT(*) - COUNT(*) as zero FROM strategies GROUP BY status",
605        );
606        assert!(result.is_ok());
607    }
608
609    // ── Issue #44: HAVING in CREATE VIEW ──
610
611    #[test]
612    fn test_create_view_with_having() {
613        let dir = make_test_db();
614        // Create a view with HAVING — both statuses have cnt=1, so HAVING cnt > 0 keeps both
615        let (result, _) = execute(
616            dir.path(),
617            "CREATE VIEW popular AS SELECT status, COUNT(*) as cnt FROM strategies GROUP BY status HAVING cnt > 0",
618        )
619        .unwrap();
620        assert!(matches!(result, QueryResult::Message(ref m) if m.contains("created")));
621
622        // Query the view to confirm it works
623        let (result, _) = execute(dir.path(), "SELECT * FROM popular").unwrap();
624        if let QueryResult::Rows { rows, columns } = result {
625            assert!(columns.contains(&"status".to_string()));
626            assert!(columns.contains(&"cnt".to_string()));
627            // Both LIVE and DRAFT have count 1, both > 0
628            assert_eq!(rows.len(), 2);
629        } else {
630            panic!("Expected Rows, got {:?}", result);
631        }
632    }
633
634    #[test]
635    fn test_extract_view_query_tab_after_as() {
636        let q = extract_view_query("CREATE VIEW v AS\tSELECT * FROM t").unwrap();
637        assert!(q.starts_with("SELECT"));
638        assert!(q.contains("FROM t"));
639    }
640
641    fn make_join_db() -> tempfile::TempDir {
642        let dir = tempfile::tempdir().unwrap();
643        fs::write(
644            dir.path().join("_mdql.md"),
645            "---\ntype: database\nname: testdb\n---\n",
646        )
647        .unwrap();
648
649        let strats = dir.path().join("strategies");
650        fs::create_dir(&strats).unwrap();
651        fs::write(
652            strats.join("_mdql.md"),
653            "---\ntype: schema\ntable: strategies\nprimary_key: path\nfrontmatter:\n  title:\n    type: string\n---\n",
654        )
655        .unwrap();
656        fs::write(strats.join("alpha.md"), "---\ntitle: Alpha\n---\n# Alpha\n").unwrap();
657        fs::write(strats.join("beta.md"), "---\ntitle: Beta\n---\n# Beta\n").unwrap();
658        fs::write(strats.join("gamma.md"), "---\ntitle: Gamma\n---\n# Gamma\n").unwrap();
659
660        let bt = dir.path().join("backtests");
661        fs::create_dir(&bt).unwrap();
662        fs::write(
663            bt.join("_mdql.md"),
664            "---\ntype: schema\ntable: backtests\nprimary_key: path\nfrontmatter:\n  strategy:\n    type: string\n  sharpe:\n    type: float\n---\n",
665        )
666        .unwrap();
667        fs::write(bt.join("bt-alpha.md"), "---\nstrategy: alpha.md\nsharpe: 1.5\n---\n# BT Alpha\n").unwrap();
668
669        dir
670    }
671
672    #[test]
673    fn test_inner_join() {
674        let dir = make_join_db();
675        let (result, _) = execute(
676            dir.path(),
677            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
678        )
679        .unwrap();
680        if let QueryResult::Rows { rows, .. } = result {
681            assert_eq!(rows.len(), 1);
682            assert_eq!(rows[0].get("s.title").unwrap(), &Value::String("Alpha".into()));
683        } else {
684            panic!("Expected Rows");
685        }
686    }
687
688    #[test]
689    fn test_left_join() {
690        let dir = make_join_db();
691        let (result, _) = execute(
692            dir.path(),
693            "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
694        )
695        .unwrap();
696        if let QueryResult::Rows { rows, .. } = result {
697            assert_eq!(rows.len(), 3);
698            let alpha = rows.iter().find(|r| r.get("s.title") == Some(&Value::String("Alpha".into()))).unwrap();
699            assert_eq!(alpha.get("b.sharpe"), Some(&Value::Float(1.5)));
700            let beta = rows.iter().find(|r| r.get("s.title") == Some(&Value::String("Beta".into()))).unwrap();
701            assert_eq!(beta.get("b.sharpe"), Some(&Value::Null));
702        } else {
703            panic!("Expected Rows");
704        }
705    }
706
707    #[test]
708    fn test_left_join_in_view() {
709        let dir = make_join_db();
710        execute(
711            dir.path(),
712            "CREATE VIEW overview AS SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
713        )
714        .unwrap();
715        let (result, _) = execute(dir.path(), "SELECT * FROM overview").unwrap();
716        if let QueryResult::Rows { rows, .. } = result {
717            assert_eq!(rows.len(), 3);
718        } else {
719            panic!("Expected Rows");
720        }
721    }
722
723    fn make_compound_join_db() -> tempfile::TempDir {
724        let dir = tempfile::tempdir().unwrap();
725        fs::write(
726            dir.path().join("_mdql.md"),
727            "---\ntype: database\nname: testdb\n---\n",
728        )
729        .unwrap();
730
731        let strats = dir.path().join("strategies");
732        fs::create_dir(&strats).unwrap();
733        fs::write(
734            strats.join("_mdql.md"),
735            "---\ntype: schema\ntable: strategies\nprimary_key: path\nfrontmatter:\n  title:\n    type: string\n---\n",
736        )
737        .unwrap();
738        fs::write(strats.join("alpha.md"), "---\ntitle: Alpha\n---\n# Alpha\n").unwrap();
739        fs::write(strats.join("beta.md"), "---\ntitle: Beta\n---\n# Beta\n").unwrap();
740
741        let bt = dir.path().join("backtests");
742        fs::create_dir(&bt).unwrap();
743        fs::write(
744            bt.join("_mdql.md"),
745            "---\ntype: schema\ntable: backtests\nprimary_key: path\nfrontmatter:\n  strategy:\n    type: string\n  mode:\n    type: string\n  sharpe:\n    type: float\n---\n",
746        )
747        .unwrap();
748        fs::write(bt.join("bt-alpha-paper.md"), "---\nstrategy: alpha.md\nmode: PAPER\nsharpe: 1.5\n---\n# BT\n").unwrap();
749        fs::write(bt.join("bt-alpha-live.md"), "---\nstrategy: alpha.md\nmode: LIVE\nsharpe: 1.2\n---\n# BT\n").unwrap();
750        fs::write(bt.join("bt-beta-paper.md"), "---\nstrategy: beta.md\nmode: PAPER\nsharpe: 0.8\n---\n# BT\n").unwrap();
751
752        dir
753    }
754
755    #[test]
756    fn test_join_compound_and() {
757        let dir = make_compound_join_db();
758        let (result, _) = execute(
759            dir.path(),
760            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'",
761        )
762        .unwrap();
763        if let QueryResult::Rows { rows, .. } = result {
764            assert_eq!(rows.len(), 2);
765            let alpha = rows.iter().find(|r| r.get("s.title") == Some(&Value::String("Alpha".into()))).unwrap();
766            assert_eq!(alpha.get("b.sharpe"), Some(&Value::Float(1.5)));
767            let beta = rows.iter().find(|r| r.get("s.title") == Some(&Value::String("Beta".into()))).unwrap();
768            assert_eq!(beta.get("b.sharpe"), Some(&Value::Float(0.8)));
769        } else {
770            panic!("Expected Rows");
771        }
772    }
773
774    #[test]
775    fn test_left_join_compound() {
776        let dir = make_compound_join_db();
777        let (result, _) = execute(
778            dir.path(),
779            "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'LIVE'",
780        )
781        .unwrap();
782        if let QueryResult::Rows { rows, .. } = result {
783            assert_eq!(rows.len(), 2);
784            let alpha = rows.iter().find(|r| r.get("s.title") == Some(&Value::String("Alpha".into()))).unwrap();
785            assert_eq!(alpha.get("b.sharpe"), Some(&Value::Float(1.2)));
786            let beta = rows.iter().find(|r| r.get("s.title") == Some(&Value::String("Beta".into()))).unwrap();
787            assert_eq!(beta.get("b.sharpe"), Some(&Value::Null));
788        } else {
789            panic!("Expected Rows");
790        }
791    }
792
793    #[test]
794    fn test_join_compound_with_comparison() {
795        let dir = make_compound_join_db();
796        let (result, _) = execute(
797            dir.path(),
798            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.sharpe > 1.0",
799        )
800        .unwrap();
801        if let QueryResult::Rows { rows, .. } = result {
802            assert_eq!(rows.len(), 2);
803            assert!(rows.iter().all(|r| {
804                if let Some(Value::Float(s)) = r.get("b.sharpe") { *s > 1.0 } else { false }
805            }));
806        } else {
807            panic!("Expected Rows");
808        }
809    }
810
811    fn make_cascade_db() -> tempfile::TempDir {
812        let dir = tempfile::tempdir().unwrap();
813
814        fs::write(
815            dir.path().join("_mdql.md"),
816            "---\ntype: database\nname: testdb\nforeign_keys:\n  - from: backtests.strategy\n    to: strategies.path\n---\n",
817        )
818        .unwrap();
819
820        let strats = dir.path().join("strategies");
821        fs::create_dir(&strats).unwrap();
822        fs::write(
823            strats.join("_mdql.md"),
824            "---\ntype: schema\ntable: strategies\nprimary_key: path\nfrontmatter:\n  title:\n    type: string\n  status:\n    type: string\n---\n",
825        )
826        .unwrap();
827        fs::write(strats.join("alpha.md"), "---\ntitle: Alpha\nstatus: KILLED\n---\n# Alpha\n").unwrap();
828        fs::write(strats.join("beta.md"), "---\ntitle: Beta\nstatus: LIVE\n---\n# Beta\n").unwrap();
829
830        let bt = dir.path().join("backtests");
831        fs::create_dir(&bt).unwrap();
832        fs::write(
833            bt.join("_mdql.md"),
834            "---\ntype: schema\ntable: backtests\nprimary_key: path\nfrontmatter:\n  strategy:\n    type: string\n  sharpe:\n    type: float\n---\n",
835        )
836        .unwrap();
837        fs::write(bt.join("bt-alpha.md"), "---\nstrategy: alpha.md\nsharpe: 1.5\n---\n# BT Alpha\n").unwrap();
838        fs::write(bt.join("bt-beta.md"), "---\nstrategy: beta.md\nsharpe: 0.8\n---\n# BT Beta\n").unwrap();
839
840        dir
841    }
842
843    #[test]
844    fn test_cascade_delete() {
845        let dir = make_cascade_db();
846        let (result, _) = execute(
847            dir.path(),
848            "DELETE FROM strategies WHERE status = 'KILLED' CASCADE",
849        )
850        .unwrap();
851        if let QueryResult::Message(msg) = result {
852            assert!(msg.contains("DELETE 1"));
853            assert!(msg.contains("cascade"));
854        } else {
855            panic!("Expected Message");
856        }
857        assert!(!dir.path().join("strategies/alpha.md").exists());
858        assert!(!dir.path().join("backtests/bt-alpha.md").exists());
859        assert!(dir.path().join("strategies/beta.md").exists());
860        assert!(dir.path().join("backtests/bt-beta.md").exists());
861    }
862
863    #[test]
864    fn test_restrict_delete_blocks() {
865        let dir = make_cascade_db();
866        let err = execute(
867            dir.path(),
868            "DELETE FROM strategies WHERE status = 'KILLED' RESTRICT",
869        );
870        assert!(err.is_err());
871        let msg = err.unwrap_err().to_string();
872        assert!(msg.contains("RESTRICT"));
873        assert!(dir.path().join("strategies/alpha.md").exists());
874    }
875
876    #[test]
877    fn test_restrict_delete_allows_no_dependents() {
878        let dir = make_cascade_db();
879        fs::remove_file(dir.path().join("backtests/bt-alpha.md")).unwrap();
880
881        let (result, _) = execute(
882            dir.path(),
883            "DELETE FROM strategies WHERE status = 'KILLED' RESTRICT",
884        )
885        .unwrap();
886        if let QueryResult::Message(msg) = result {
887            assert!(msg.contains("DELETE 1"));
888        } else {
889            panic!("Expected Message");
890        }
891        assert!(!dir.path().join("strategies/alpha.md").exists());
892    }
893
894    #[test]
895    fn test_cascade_default_unchanged() {
896        let dir = make_cascade_db();
897        let (result, _) = execute(
898            dir.path(),
899            "DELETE FROM strategies WHERE status = 'KILLED'",
900        )
901        .unwrap();
902        if let QueryResult::Message(msg) = result {
903            assert!(msg.contains("DELETE 1"));
904        } else {
905            panic!("Expected Message");
906        }
907        assert!(!dir.path().join("strategies/alpha.md").exists());
908        assert!(dir.path().join("backtests/bt-alpha.md").exists());
909    }
910
911    // ── CTE (WITH) integration tests ──
912
913    #[test]
914    fn test_cte_basic() {
915        let dir = make_test_db();
916        let (result, _) = execute(
917            dir.path(),
918            "WITH live AS (SELECT * FROM strategies WHERE status = 'LIVE') SELECT * FROM live",
919        )
920        .unwrap();
921        if let QueryResult::Rows { rows, columns } = result {
922            assert_eq!(rows.len(), 1);
923            assert!(columns.contains(&"title".to_string()));
924            assert_eq!(rows[0].get("title"), Some(&Value::String("Alpha".into())));
925        } else {
926            panic!("Expected Rows");
927        }
928    }
929
930    #[test]
931    fn test_cte_with_where_on_cte() {
932        let dir = make_join_db();
933        let (result, _) = execute(
934            dir.path(),
935            "WITH bt AS (SELECT * FROM backtests WHERE sharpe > 1.0) SELECT * FROM bt",
936        )
937        .unwrap();
938        if let QueryResult::Rows { rows, .. } = result {
939            assert_eq!(rows.len(), 1);
940            assert_eq!(rows[0].get("sharpe"), Some(&Value::Float(1.5)));
941        } else {
942            panic!("Expected Rows");
943        }
944    }
945
946    #[test]
947    fn test_cte_multi_with_join() {
948        let dir = make_join_db();
949        let (result, _) = execute(
950            dir.path(),
951            "WITH s AS (SELECT * FROM strategies), b AS (SELECT * FROM backtests) SELECT s.title, b.sharpe FROM s JOIN b ON b.strategy = s.path",
952        )
953        .unwrap();
954        if let QueryResult::Rows { rows, .. } = result {
955            assert_eq!(rows.len(), 1);
956            assert_eq!(rows[0].get("s.title"), Some(&Value::String("Alpha".into())));
957        } else {
958            panic!("Expected Rows");
959        }
960    }
961
962    #[test]
963    fn test_cte_with_aggregation() {
964        let dir = make_test_db();
965        let (result, _) = execute(
966            dir.path(),
967            "WITH counts AS (SELECT status, COUNT(*) AS cnt FROM strategies GROUP BY status) SELECT * FROM counts WHERE cnt > 0",
968        )
969        .unwrap();
970        if let QueryResult::Rows { rows, columns } = result {
971            assert!(columns.contains(&"status".to_string()));
972            assert!(columns.contains(&"cnt".to_string()));
973            assert!(rows.len() >= 1);
974        } else {
975            panic!("Expected Rows");
976        }
977    }
978
979    #[test]
980    fn test_cte_chained() {
981        let dir = make_join_db();
982        let (result, _) = execute(
983            dir.path(),
984            "WITH good AS (SELECT * FROM backtests WHERE sharpe > 1.0), matched AS (SELECT s.title, g.sharpe FROM strategies s JOIN good g ON g.strategy = s.path) SELECT * FROM matched",
985        )
986        .unwrap();
987        if let QueryResult::Rows { rows, .. } = result {
988            assert_eq!(rows.len(), 1);
989        } else {
990            panic!("Expected Rows");
991        }
992    }
993
994    // ── Subquery integration tests ──
995
996    #[test]
997    fn test_where_in_subquery() {
998        let dir = make_join_db();
999        let (result, _) = execute(
1000            dir.path(),
1001            "SELECT * FROM strategies WHERE path IN (SELECT strategy FROM backtests)",
1002        )
1003        .unwrap();
1004        if let QueryResult::Rows { rows, .. } = result {
1005            assert_eq!(rows.len(), 1);
1006            assert_eq!(rows[0].get("title"), Some(&Value::String("Alpha".into())));
1007        } else {
1008            panic!("Expected Rows");
1009        }
1010    }
1011
1012    fn make_multi_bt_db() -> tempfile::TempDir {
1013        let dir = tempfile::tempdir().unwrap();
1014        fs::write(
1015            dir.path().join("_mdql.md"),
1016            "---\ntype: database\nname: testdb\n---\n",
1017        )
1018        .unwrap();
1019
1020        let strats = dir.path().join("strategies");
1021        fs::create_dir(&strats).unwrap();
1022        fs::write(
1023            strats.join("_mdql.md"),
1024            "---\ntype: schema\ntable: strategies\nprimary_key: path\nfrontmatter:\n  title:\n    type: string\n---\n",
1025        )
1026        .unwrap();
1027        fs::write(strats.join("alpha.md"), "---\ntitle: Alpha\n---\n# Alpha\n").unwrap();
1028        fs::write(strats.join("beta.md"), "---\ntitle: Beta\n---\n# Beta\n").unwrap();
1029
1030        let bt = dir.path().join("backtests");
1031        fs::create_dir(&bt).unwrap();
1032        fs::write(
1033            bt.join("_mdql.md"),
1034            "---\ntype: schema\ntable: backtests\nprimary_key: path\nfrontmatter:\n  strategy:\n    type: string\n  sharpe:\n    type: float\n---\n",
1035        )
1036        .unwrap();
1037        fs::write(bt.join("bt-alpha.md"), "---\nstrategy: alpha.md\nsharpe: 2.0\n---\n# BT\n").unwrap();
1038        fs::write(bt.join("bt-beta.md"), "---\nstrategy: beta.md\nsharpe: 0.5\n---\n# BT\n").unwrap();
1039
1040        dir
1041    }
1042
1043    #[test]
1044    fn test_where_scalar_subquery() {
1045        let dir = make_multi_bt_db();
1046        // AVG(sharpe) = (2.0 + 0.5) / 2 = 1.25 → only bt-alpha (2.0) passes
1047        let (result, _) = execute(
1048            dir.path(),
1049            "SELECT * FROM backtests WHERE sharpe > (SELECT AVG(sharpe) FROM backtests)",
1050        )
1051        .unwrap();
1052        if let QueryResult::Rows { rows, .. } = result {
1053            assert_eq!(rows.len(), 1);
1054            assert_eq!(rows[0].get("sharpe"), Some(&Value::Float(2.0)));
1055        } else {
1056            panic!("Expected Rows");
1057        }
1058    }
1059
1060    #[test]
1061    fn test_select_scalar_subquery() {
1062        let dir = make_join_db();
1063        let (result, _) = execute(
1064            dir.path(),
1065            "SELECT title, (SELECT COUNT(*) FROM backtests) AS bt_count FROM strategies",
1066        )
1067        .unwrap();
1068        if let QueryResult::Rows { rows, columns } = result {
1069            assert!(columns.contains(&"bt_count".to_string()));
1070            for row in &rows {
1071                assert_eq!(row.get("bt_count"), Some(&Value::Int(1)));
1072            }
1073        } else {
1074            panic!("Expected Rows");
1075        }
1076    }
1077}