Skip to main content

mdql_core/
executor.rs

1//! Unified SQL execution — single entry point for CLI, REPL, and web server.
2
3use std::path::Path;
4
5use crate::api::Table;
6use crate::database::{ViewDef, is_database_dir, load_database_config, save_database_config};
7use crate::errors::{MdqlError, ValidationError};
8use crate::model::Row;
9use crate::query_engine::{execute_join_query, execute_query};
10use crate::query_parser::{Statement, parse_query};
11
12#[derive(Debug)]
13pub enum QueryResult {
14    Rows { rows: Vec<Row>, columns: Vec<String> },
15    Message(String),
16}
17
18pub fn execute(path: &Path, sql: &str) -> crate::errors::Result<(QueryResult, Vec<ValidationError>)> {
19    let stmt = parse_query(sql)?;
20    let is_db = is_database_dir(path);
21
22    match stmt {
23        Statement::Select(ref q) => {
24            if q.subquery.is_some() || !q.joins.is_empty() || is_db {
25                let (_config, tables, errors) = crate::loader::load_database(path)?;
26                let (rows, cols) = if let Some(ref sub) = q.subquery {
27                    let source_table = &sub.table;
28                    let (schema, table_rows) = tables.get(source_table).ok_or_else(|| {
29                        MdqlError::QueryExecution(format!(
30                            "table '{}' not found in database",
31                            source_table
32                        ))
33                    })?;
34                    execute_query(q, table_rows, schema)?
35                } else if !q.joins.is_empty() {
36                    execute_join_query(q, &tables)?
37                } else {
38                    let (schema, rows) = tables.get(&q.table).ok_or_else(|| {
39                        MdqlError::QueryExecution(format!(
40                            "table '{}' not found in database",
41                            q.table
42                        ))
43                    })?;
44                    execute_query(q, rows, schema)?
45                };
46                Ok((QueryResult::Rows { rows, columns: cols }, errors))
47            } else {
48                let (schema, rows, errors) = crate::loader::load_table(path)?;
49                let (rows, cols) = execute_query(q, &rows, &schema)?;
50                Ok((QueryResult::Rows { rows, columns: cols }, errors))
51            }
52        }
53        Statement::CreateView(ref cv) => {
54            if !is_db {
55                return Err(MdqlError::QueryExecution(
56                    "CREATE VIEW requires a database directory".into(),
57                ));
58            }
59            let mut config = load_database_config(path)?;
60
61            let (_config_check, tables, _errors) = crate::loader::load_database(path)?;
62            if tables.contains_key(&cv.view_name) {
63                return Err(MdqlError::QueryExecution(format!(
64                    "Name '{}' already exists as a table or view",
65                    cv.view_name
66                )));
67            }
68
69            if config.views.iter().any(|v| v.name == cv.view_name) {
70                return Err(MdqlError::QueryExecution(format!(
71                    "View '{}' already exists",
72                    cv.view_name
73                )));
74            }
75
76            let query_str = extract_view_query(sql)?;
77
78            let view_def = ViewDef {
79                name: cv.view_name.clone(),
80                query: query_str,
81            };
82
83            let test_result = crate::loader::load_database(path);
84            if let Ok((_cfg, test_tables, _errs)) = test_result {
85                let test_view = ViewDef {
86                    name: view_def.name.clone(),
87                    query: view_def.query.clone(),
88                };
89                if let Err(e) = super::loader::materialize_view(&test_view, &test_tables) {
90                    return Err(MdqlError::QueryExecution(format!(
91                        "View query failed validation: {}",
92                        e
93                    )));
94                }
95            }
96
97            config.views.push(view_def);
98            save_database_config(path, &config)?;
99            Ok((
100                QueryResult::Message(format!("View '{}' created", cv.view_name)),
101                vec![],
102            ))
103        }
104        Statement::DropView(ref dv) => {
105            if !is_db {
106                return Err(MdqlError::QueryExecution(
107                    "DROP VIEW requires a database directory".into(),
108                ));
109            }
110            let mut config = load_database_config(path)?;
111            let len_before = config.views.len();
112            config.views.retain(|v| v.name != dv.view_name);
113            if config.views.len() == len_before {
114                return Err(MdqlError::QueryExecution(format!(
115                    "View '{}' does not exist",
116                    dv.view_name
117                )));
118            }
119            save_database_config(path, &config)?;
120            Ok((
121                QueryResult::Message(format!("View '{}' dropped", dv.view_name)),
122                vec![],
123            ))
124        }
125        ref stmt @ (Statement::Insert(_)
126        | Statement::Update(_)
127        | Statement::Delete(_)
128        | Statement::AlterRename(_)
129        | Statement::AlterDrop(_)
130        | Statement::AlterMerge(_)) => {
131            if is_db {
132                let config = load_database_config(path)?;
133                let target = stmt.table_name();
134                if config.views.iter().any(|v| v.name == target) {
135                    return Err(MdqlError::QueryExecution(format!(
136                        "Cannot write to view '{}' — views are read-only",
137                        target
138                    )));
139                }
140            }
141            let table_path = if is_db {
142                path.join(stmt.table_name())
143            } else {
144                path.to_path_buf()
145            };
146            let mut table = Table::new(&table_path)?;
147            let msg = table.execute_sql(sql)?;
148            Ok((QueryResult::Message(msg), vec![]))
149        }
150    }
151}
152
153fn extract_view_query(sql: &str) -> crate::errors::Result<String> {
154    let upper = sql.to_uppercase();
155    let as_keyword = upper.find(" AS ");
156    if let Some(pos) = as_keyword {
157        let after = &sql[pos + 4..];
158        let trimmed = after.trim_start();
159        let trimmed_upper = trimmed.to_uppercase();
160        if trimmed_upper.starts_with("SELECT") {
161            return Ok(trimmed.to_string());
162        }
163    }
164    // Fallback: scan for any whitespace-surrounded AS that precedes SELECT
165    let bytes = upper.as_bytes();
166    let mut i = 0;
167    while i + 4 < bytes.len() {
168        if bytes[i].is_ascii_whitespace()
169            && bytes[i + 1] == b'A'
170            && bytes[i + 2] == b'S'
171            && bytes[i + 3].is_ascii_whitespace()
172        {
173            let after = &sql[i + 3..];
174            let trimmed = after.trim_start();
175            let trimmed_upper = trimmed.to_uppercase();
176            if trimmed_upper.starts_with("SELECT") {
177                return Ok(trimmed.to_string());
178            }
179        }
180        i += 1;
181    }
182    Err(crate::errors::MdqlError::QueryExecution(
183        "CREATE VIEW must contain AS clause followed by SELECT".into(),
184    ))
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Row, Value};
    use std::fs;

    /// Manifest that marks a directory as an MDQL database.
    const DATABASE_MANIFEST: &str = "---\ntype: database\nname: testdb\n---\n";

    /// Schema for a `strategies` table with only a `title` column.
    const STRATEGIES_TITLE_SCHEMA: &str = "---\ntype: schema\ntable: strategies\nprimary_key: path\nfrontmatter:\n  title:\n    type: string\n---\n";

    /// View over `strategies` that keeps only rows whose status is LIVE.
    const CREATE_LIVE_VIEW: &str =
        "CREATE VIEW live AS SELECT * FROM strategies WHERE status = 'LIVE'";

    /// Writes `contents` to `dir/file`, panicking if the write fails.
    fn write_file(dir: &std::path::Path, file: &str, contents: &str) {
        fs::write(dir.join(file), contents).unwrap();
    }

    /// Creates the table directory `root/name` holding `schema` as its `_mdql.md`.
    fn add_table(root: &std::path::Path, name: &str, schema: &str) -> std::path::PathBuf {
        let table_dir = root.join(name);
        fs::create_dir(&table_dir).unwrap();
        write_file(&table_dir, "_mdql.md", schema);
        table_dir
    }

    /// Creates a temporary database directory containing only its manifest.
    fn empty_database() -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        write_file(dir.path(), "_mdql.md", DATABASE_MANIFEST);
        dir
    }

    /// Adds a title-only `strategies` table with one `<lowercase title>.md` per title.
    fn add_titled_strategies(root: &std::path::Path, titles: &[&str]) {
        let strategies = add_table(root, "strategies", STRATEGIES_TITLE_SCHEMA);
        for title in titles {
            let body = format!("---\ntitle: {0}\n---\n# {0}\n", title);
            write_file(&strategies, &format!("{}.md", title.to_lowercase()), &body);
        }
    }

    /// Runs `sql` against `dir`, panicking on failure, and returns the result only.
    fn run_ok(dir: &tempfile::TempDir, sql: &str) -> QueryResult {
        let (result, _) = execute(dir.path(), sql).unwrap();
        result
    }

    /// Runs `sql` against `dir`, asserts that it fails, and returns the error text.
    fn run_err(dir: &tempfile::TempDir, sql: &str) -> String {
        let outcome = execute(dir.path(), sql);
        assert!(outcome.is_err());
        outcome.unwrap_err().to_string()
    }

    /// Unwraps a `QueryResult::Rows`, panicking with "Expected Rows" otherwise.
    fn expect_rows(result: QueryResult) -> (Vec<Row>, Vec<String>) {
        match result {
            QueryResult::Rows { rows, columns } => (rows, columns),
            QueryResult::Message(_) => panic!("Expected Rows"),
        }
    }

    /// True when `result` is a message containing `needle`.
    fn is_message_containing(result: &QueryResult, needle: &str) -> bool {
        matches!(result, QueryResult::Message(m) if m.contains(needle))
    }

    /// Returns the joined row whose `s.title` equals `title`, panicking if absent.
    fn row_with_title<'a>(rows: &'a [Row], title: &str) -> &'a Row {
        rows.iter()
            .find(|r| r.get("s.title") == Some(&Value::String(title.into())))
            .unwrap()
    }

    /// Database with one `strategies` table: Alpha (LIVE) and Beta (DRAFT).
    fn make_test_db() -> tempfile::TempDir {
        let dir = empty_database();
        let strategies = add_table(
            dir.path(),
            "strategies",
            "---\ntype: schema\ntable: strategies\nprimary_key: path\nfrontmatter:\n  title:\n    type: string\n  status:\n    type: string\n---\n",
        );
        write_file(&strategies, "alpha.md", "---\ntitle: Alpha\nstatus: LIVE\n---\n# Alpha\n");
        write_file(&strategies, "beta.md", "---\ntitle: Beta\nstatus: DRAFT\n---\n# Beta\n");
        dir
    }

    #[test]
    fn test_create_and_query_view() {
        let dir = make_test_db();
        assert!(is_message_containing(&run_ok(&dir, CREATE_LIVE_VIEW), "created"));

        let (rows, columns) = expect_rows(run_ok(&dir, "SELECT * FROM live"));
        assert_eq!(rows.len(), 1);
        assert!(columns.contains(&"title".to_string()));
    }

    #[test]
    fn test_drop_view() {
        let dir = make_test_db();
        run_ok(&dir, CREATE_LIVE_VIEW);

        assert!(is_message_containing(&run_ok(&dir, "DROP VIEW live"), "dropped"));
        assert!(execute(dir.path(), "SELECT * FROM live").is_err());
    }

    #[test]
    fn test_drop_nonexistent_view() {
        let dir = make_test_db();
        assert!(run_err(&dir, "DROP VIEW nonexistent").contains("does not exist"));
    }

    #[test]
    fn test_create_view_duplicate_name() {
        let dir = make_test_db();
        run_ok(&dir, CREATE_LIVE_VIEW);

        let message = run_err(&dir, "CREATE VIEW live AS SELECT * FROM strategies");
        assert!(message.contains("already exists"));
    }

    #[test]
    fn test_create_view_conflicts_with_table() {
        let dir = make_test_db();
        let message = run_err(&dir, "CREATE VIEW strategies AS SELECT * FROM strategies");
        assert!(message.contains("already exists"));
    }

    #[test]
    fn test_write_to_view_rejected() {
        let dir = make_test_db();
        run_ok(&dir, CREATE_LIVE_VIEW);

        let message = run_err(
            &dir,
            "INSERT INTO live (title, status) VALUES ('Gamma', 'LIVE')",
        );
        assert!(message.contains("read-only"));
    }

    #[test]
    fn test_create_view_not_database() {
        // A lone table directory, not a database.
        let dir = tempfile::tempdir().unwrap();
        write_file(
            dir.path(),
            "_mdql.md",
            "---\ntype: schema\ntable: t\nprimary_key: path\nfrontmatter:\n  x:\n    type: string\n---\n",
        );

        let message = run_err(&dir, "CREATE VIEW v AS SELECT * FROM t");
        assert!(message.contains("database directory"));
    }

    #[test]
    fn test_extract_view_query_basic() {
        let query = extract_view_query("CREATE VIEW v AS SELECT * FROM t").unwrap();
        assert!(query.starts_with("SELECT"));
    }

    #[test]
    fn test_extract_view_query_with_column_alias() {
        let query = extract_view_query(
            "CREATE VIEW v AS SELECT token, SUM(size) as sell_size FROM orders GROUP BY token HAVING sell_size > 0",
        )
        .unwrap();
        assert!(query.starts_with("SELECT"));
        assert!(query.contains("HAVING"));
    }

    #[test]
    fn test_extract_view_query_newline_after_as() {
        let query = extract_view_query("CREATE VIEW v AS\nSELECT * FROM t").unwrap();
        assert!(query.starts_with("SELECT"));
    }

    #[test]
    fn test_create_view_with_aggregate_arithmetic() {
        let dir = make_test_db();
        let outcome = execute(
            dir.path(),
            "CREATE VIEW v AS SELECT status, COUNT(*) - COUNT(*) as zero FROM strategies GROUP BY status",
        );
        assert!(outcome.is_ok());
    }

    // ── Issue #44: HAVING in CREATE VIEW ──

    #[test]
    fn test_create_view_with_having() {
        let dir = make_test_db();
        // LIVE and DRAFT each have cnt = 1, so `HAVING cnt > 0` keeps both groups.
        let created = run_ok(
            &dir,
            "CREATE VIEW popular AS SELECT status, COUNT(*) as cnt FROM strategies GROUP BY status HAVING cnt > 0",
        );
        assert!(is_message_containing(&created, "created"));

        match run_ok(&dir, "SELECT * FROM popular") {
            QueryResult::Rows { rows, columns } => {
                assert!(columns.contains(&"status".to_string()));
                assert!(columns.contains(&"cnt".to_string()));
                assert_eq!(rows.len(), 2);
            }
            other => panic!("Expected Rows, got {:?}", other),
        }
    }

    #[test]
    fn test_extract_view_query_tab_after_as() {
        let query = extract_view_query("CREATE VIEW v AS\tSELECT * FROM t").unwrap();
        assert!(query.starts_with("SELECT"));
        assert!(query.contains("FROM t"));
    }

    /// Database with strategies Alpha, Beta and Gamma, plus one backtest for Alpha.
    fn make_join_db() -> tempfile::TempDir {
        let dir = empty_database();
        add_titled_strategies(dir.path(), &["Alpha", "Beta", "Gamma"]);

        let backtests = add_table(
            dir.path(),
            "backtests",
            "---\ntype: schema\ntable: backtests\nprimary_key: path\nfrontmatter:\n  strategy:\n    type: string\n  sharpe:\n    type: float\n---\n",
        );
        write_file(
            &backtests,
            "bt-alpha.md",
            "---\nstrategy: alpha.md\nsharpe: 1.5\n---\n# BT Alpha\n",
        );
        dir
    }

    #[test]
    fn test_inner_join() {
        let dir = make_join_db();
        let (rows, _) = expect_rows(run_ok(
            &dir,
            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path",
        ));
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get("s.title").unwrap(), &Value::String("Alpha".into()));
    }

    #[test]
    fn test_left_join() {
        let dir = make_join_db();
        let (rows, _) = expect_rows(run_ok(
            &dir,
            "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
        ));
        assert_eq!(rows.len(), 3);
        assert_eq!(row_with_title(&rows, "Alpha").get("b.sharpe"), Some(&Value::Float(1.5)));
        assert_eq!(row_with_title(&rows, "Beta").get("b.sharpe"), Some(&Value::Null));
    }

    #[test]
    fn test_left_join_in_view() {
        let dir = make_join_db();
        run_ok(
            &dir,
            "CREATE VIEW overview AS SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path",
        );

        let (rows, _) = expect_rows(run_ok(&dir, "SELECT * FROM overview"));
        assert_eq!(rows.len(), 3);
    }

    /// Database with strategies Alpha and Beta and backtests split by PAPER/LIVE mode.
    fn make_compound_join_db() -> tempfile::TempDir {
        let dir = empty_database();
        add_titled_strategies(dir.path(), &["Alpha", "Beta"]);

        let backtests = add_table(
            dir.path(),
            "backtests",
            "---\ntype: schema\ntable: backtests\nprimary_key: path\nfrontmatter:\n  strategy:\n    type: string\n  mode:\n    type: string\n  sharpe:\n    type: float\n---\n",
        );
        for (file, strategy, mode, sharpe) in [
            ("bt-alpha-paper.md", "alpha.md", "PAPER", "1.5"),
            ("bt-alpha-live.md", "alpha.md", "LIVE", "1.2"),
            ("bt-beta-paper.md", "beta.md", "PAPER", "0.8"),
        ] {
            let body = format!(
                "---\nstrategy: {}\nmode: {}\nsharpe: {}\n---\n# BT\n",
                strategy, mode, sharpe
            );
            write_file(&backtests, file, &body);
        }
        dir
    }

    #[test]
    fn test_join_compound_and() {
        let dir = make_compound_join_db();
        let (rows, _) = expect_rows(run_ok(
            &dir,
            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.mode = 'PAPER'",
        ));
        assert_eq!(rows.len(), 2);
        assert_eq!(row_with_title(&rows, "Alpha").get("b.sharpe"), Some(&Value::Float(1.5)));
        assert_eq!(row_with_title(&rows, "Beta").get("b.sharpe"), Some(&Value::Float(0.8)));
    }

    #[test]
    fn test_left_join_compound() {
        let dir = make_compound_join_db();
        let (rows, _) = expect_rows(run_ok(
            &dir,
            "SELECT s.title, b.sharpe FROM strategies s LEFT JOIN backtests b ON b.strategy = s.path AND b.mode = 'LIVE'",
        ));
        assert_eq!(rows.len(), 2);
        assert_eq!(row_with_title(&rows, "Alpha").get("b.sharpe"), Some(&Value::Float(1.2)));
        assert_eq!(row_with_title(&rows, "Beta").get("b.sharpe"), Some(&Value::Null));
    }

    #[test]
    fn test_join_compound_with_comparison() {
        let dir = make_compound_join_db();
        let (rows, _) = expect_rows(run_ok(
            &dir,
            "SELECT s.title, b.sharpe FROM strategies s JOIN backtests b ON b.strategy = s.path AND b.sharpe > 1.0",
        ));
        assert_eq!(rows.len(), 2);
        let all_above_one = rows
            .iter()
            .all(|r| matches!(r.get("b.sharpe"), Some(Value::Float(s)) if *s > 1.0));
        assert!(all_above_one);
    }
}