Skip to main content

overdrive/
query_engine.rs

//! Embedded SQL query engine for the OverDrive InCode SDK.
//!
//! Parses and executes SQL statements directly against the embedded database.
//! No server is needed — everything runs in-process.
//!
//! ## Supported SQL
//!
//! - `SELECT [columns] FROM <table> [WHERE ...] [ORDER BY <col> [ASC|DESC]] [LIMIT n] [OFFSET n]`
//! - `INSERT INTO <table> {json}`
//! - `UPDATE <table> SET {json} [WHERE ...]`
//! - `DELETE FROM <table> [WHERE ...]`
//! - `CREATE TABLE <name>`
//! - `DROP TABLE <name>`
//! - `SHOW TABLES`
//! - `SELECT COUNT(*) FROM <table>` (and the other aggregates `SUM`, `AVG`, `MIN`, `MAX`)
//!
//! `WHERE` clauses are `column op value` conditions joined by `AND` / `OR`,
//! where `op` is one of `=`, `==`, `!=`, `<`, `<=`, `>`, `>=`.

17use crate::{OverDriveDB, QueryResult};
18use crate::result::{SdkResult, SdkError};
19use serde_json::Value;
20
21/// Execute an SQL query against the embedded database
22pub fn execute(db: &mut OverDriveDB, sql: &str) -> SdkResult<QueryResult> {
23    let sql = sql.trim().trim_end_matches(';').trim();
24    
25    if sql.is_empty() {
26        return Ok(QueryResult::empty());
27    }
28    
29    // Tokenize: split into words respecting quoted strings and braces
30    let tokens = tokenize(sql);
31    
32    if tokens.is_empty() {
33        return Ok(QueryResult::empty());
34    }
35    
36    let first = tokens[0].to_uppercase();
37    
38    match first.as_str() {
39        "SELECT" => execute_select(db, &tokens),
40        "INSERT" => execute_insert(db, &tokens, sql),
41        "UPDATE" => execute_update(db, &tokens, sql),
42        "DELETE" => execute_delete(db, &tokens),
43        "CREATE" => execute_create(db, &tokens),
44        "DROP"   => execute_drop(db, &tokens),
45        "SHOW"   => execute_show(db, &tokens),
46        _ => Err(SdkError::InvalidQuery(format!("Unsupported SQL command: {}", first))),
47    }
48}
49
/// Split a SQL statement into tokens.
///
/// Rules:
/// - whitespace separates tokens and is discarded;
/// - a `'…'` or `"…"` literal becomes one token, re-wrapped in single quotes
///   (`"abc"` → `'abc'`); a backslash includes the following character verbatim;
/// - a `{…}` or `[…]` group becomes one token, matched by nesting depth.
///   Brackets inside JSON string literals in the group are ignored, so
///   `{"a": "}"}` stays a single token;
/// - comparison operators `=`, `==`, `!`, `!=`, `<`, `<=`, `<>`, `>`, `>=`
///   are single tokens;
/// - `,`, `(`, `)` and `*` are one-character tokens;
/// - anything else is a word/number running up to the next delimiter.
fn tokenize(input: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut chars = input.chars().peekable();

    while let Some(&c) = chars.peek() {
        if c.is_whitespace() {
            chars.next();
            continue;
        }

        // Quoted string literal.
        if c == '\'' || c == '"' {
            let quote = c;
            chars.next();
            let mut s = String::new();
            while let Some(&ch) = chars.peek() {
                if ch == quote {
                    chars.next();
                    break;
                }
                if ch == '\\' {
                    chars.next();
                    if let Some(&escaped) = chars.peek() {
                        s.push(escaped);
                        chars.next();
                        continue;
                    }
                }
                s.push(ch);
                chars.next();
            }
            tokens.push(format!("'{}'", s));
            continue;
        }

        // JSON object or array: consume up to the matching close bracket,
        // skipping over JSON string literals so brackets inside them don't count.
        if c == '{' || c == '[' {
            let (open, close) = if c == '{' { ('{', '}') } else { ('[', ']') };
            let mut depth = 0usize;
            let mut in_string = false;
            let mut escaped = false;
            let mut group = String::new();
            for ch in chars.by_ref() {
                group.push(ch);
                if in_string {
                    if escaped {
                        escaped = false;
                    } else if ch == '\\' {
                        escaped = true;
                    } else if ch == '"' {
                        in_string = false;
                    }
                    continue;
                }
                if ch == '"' {
                    in_string = true;
                } else if ch == open {
                    depth += 1;
                } else if ch == close {
                    // depth >= 1 here: the group always starts with `open`.
                    depth -= 1;
                    if depth == 0 {
                        break;
                    }
                }
            }
            tokens.push(group);
            continue;
        }

        // Comparison operators, including the two-character forms.
        if matches!(c, '>' | '<' | '!' | '=') {
            chars.next();
            let mut op = c.to_string();
            let next = chars.peek().copied();
            match next {
                Some('=') => {
                    op.push('=');
                    chars.next();
                }
                // SQL's "not equal" spelling; evaluate_condition expects "<>".
                Some('>') if c == '<' => {
                    op.push('>');
                    chars.next();
                }
                _ => {}
            }
            tokens.push(op);
            continue;
        }

        // Comma, parentheses, star.
        if matches!(c, ',' | '(' | ')' | '*') {
            tokens.push(c.to_string());
            chars.next();
            continue;
        }

        // Word or number.
        let mut word = String::new();
        while let Some(&ch) = chars.peek() {
            if ch.is_whitespace()
                || matches!(ch, ',' | '(' | ')' | '{' | '[' | '>' | '<' | '=' | '!')
            {
                break;
            }
            word.push(ch);
            chars.next();
        }
        if !word.is_empty() {
            tokens.push(word);
        }
    }

    tokens
}
142
143/// Execute SELECT query
144fn execute_select(db: &mut OverDriveDB, tokens: &[String]) -> SdkResult<QueryResult> {
145    let mut pos = 1; // skip "SELECT"
146    
147    // Parse columns
148    let mut columns: Vec<String> = Vec::new();
149    let mut aggregates: Vec<(String, String)> = Vec::new(); // (func_name, arg)
150    
151    while pos < tokens.len() {
152        let upper = tokens[pos].to_uppercase();
153        if upper == "FROM" {
154            break;
155        }
156        
157        let col = tokens[pos].trim_end_matches(',').to_string();
158        
159        // Check for aggregate functions
160        if let Some(agg) = try_parse_aggregate(tokens, &mut pos) {
161            aggregates.push(agg);
162        } else {
163            columns.push(col);
164        }
165        
166        // Advance past comma separator if present, otherwise advance past current token
167        pos += 1;
168    }
169    
170    // Expect FROM
171    if pos >= tokens.len() || tokens[pos].to_uppercase() != "FROM" {
172        return Err(SdkError::InvalidQuery("Expected FROM keyword".to_string()));
173    }
174    pos += 1;
175    
176    // Table name
177    if pos >= tokens.len() {
178        return Err(SdkError::InvalidQuery("Expected table name after FROM".to_string()));
179    }
180    let table = &tokens[pos];
181    pos += 1;
182    
183    // Get all data from table via dynamic FFI
184    let mut data = db.scan(table)?;
185    
186    // Parse WHERE clause
187    if pos < tokens.len() && tokens[pos].to_uppercase() == "WHERE" {
188        pos += 1;
189        data = apply_where_filter(data, tokens, &mut pos);
190    }
191    
192    // Handle aggregates
193    if !aggregates.is_empty() {
194        let mut result_row = serde_json::Map::new();
195        for (func, arg) in &aggregates {
196            let value = execute_aggregate(func, arg, &data);
197            let key = format!("{}({})", func, arg);
198            result_row.insert(key, value);
199        }
200        return Ok(QueryResult {
201            rows: vec![Value::Object(result_row)],
202            columns: vec![],
203            rows_affected: 0,
204            execution_time_ms: 0.0,
205        });
206    }
207    
208    // Parse ORDER BY
209    if pos < tokens.len() && tokens[pos].to_uppercase() == "ORDER" {
210        pos += 1;
211        if pos < tokens.len() && tokens[pos].to_uppercase() == "BY" {
212            pos += 1;
213        }
214        if pos < tokens.len() {
215            let sort_col = tokens[pos].trim_end_matches(',').to_string();
216            pos += 1;
217            let descending = if pos < tokens.len() && tokens[pos].to_uppercase() == "DESC" {
218                pos += 1;
219                true
220            } else {
221                if pos < tokens.len() && tokens[pos].to_uppercase() == "ASC" {
222                    pos += 1;
223                }
224                false
225            };
226            sort_data(&mut data, &sort_col, descending);
227        }
228    }
229    
230    // Parse LIMIT
231    let mut limit: Option<usize> = None;
232    if pos < tokens.len() && tokens[pos].to_uppercase() == "LIMIT" {
233        pos += 1;
234        if pos < tokens.len() {
235            limit = tokens[pos].parse().ok();
236            pos += 1;
237        }
238    }
239    
240    // Parse OFFSET
241    let mut offset: usize = 0;
242    if pos < tokens.len() && tokens[pos].to_uppercase() == "OFFSET" {
243        pos += 1;
244        if pos < tokens.len() {
245            offset = tokens[pos].parse().unwrap_or(0);
246            let _ = pos;
247        }
248    }
249    
250    // Apply offset and limit
251    if offset > 0 {
252        if offset >= data.len() {
253            data.clear();
254        } else {
255            data = data.into_iter().skip(offset).collect();
256        }
257    }
258    
259    if let Some(lim) = limit {
260        data.truncate(lim);
261    }
262    
263    // Column projection (if not SELECT *)
264    let is_select_all = columns.len() == 1 && columns[0] == "*";
265    if !is_select_all && !columns.is_empty() {
266        data = data.into_iter().map(|row| {
267            if let Value::Object(map) = &row {
268                let mut projected = serde_json::Map::new();
269                for col in &columns {
270                    let col_clean = col.trim_end_matches(',');
271                    if let Some(val) = map.get(col_clean) {
272                        projected.insert(col_clean.to_string(), val.clone());
273                    }
274                }
275                Value::Object(projected)
276            } else {
277                row
278            }
279        }).collect();
280    }
281    
282    Ok(QueryResult {
283        rows: data,
284        columns,
285        rows_affected: 0,
286        execution_time_ms: 0.0,
287    })
288}
289
/// Try to parse an aggregate call such as `COUNT(*)`, `SUM(col)` or `AVG(col)`
/// starting at `tokens[*pos]`.
///
/// Expects the tokenizer's output shape: the function name (any case), `(`,
/// one argument token, and optionally `)`. A missing `)` is tolerated.
///
/// On success, returns `(FUNC, arg)` with `FUNC` upper-cased and advances
/// `*pos` past the call. Otherwise returns `None` and leaves `*pos` untouched,
/// so the caller can treat the token as a plain column.
fn try_parse_aggregate(tokens: &[String], pos: &mut usize) -> Option<(String, String)> {
    const FUNC_NAMES: [&str; 5] = ["COUNT", "SUM", "AVG", "MIN", "MAX"];

    let func_name = tokens.get(*pos)?.to_uppercase();
    if !FUNC_NAMES.contains(&func_name.as_str()) {
        return None;
    }
    if tokens.get(*pos + 1).map(String::as_str) != Some("(") {
        return None;
    }
    let arg = tokens.get(*pos + 2)?.clone();

    let mut next = *pos + 3;
    if tokens.get(next).map(String::as_str) == Some(")") {
        next += 1;
    }

    // Commit the cursor only once the call has been recognised.
    *pos = next;
    Some((func_name, arg))
}
330
331/// Execute an aggregate function
332fn execute_aggregate(func: &str, arg: &str, data: &[Value]) -> Value {
333    match func {
334        "COUNT" => Value::from(data.len()),
335        "SUM" => {
336            let sum: f64 = data.iter()
337                .filter_map(|row| row.get(arg).and_then(|v| v.as_f64()))
338                .sum();
339            Value::from(sum)
340        }
341        "AVG" => {
342            let vals: Vec<f64> = data.iter()
343                .filter_map(|row| row.get(arg).and_then(|v| v.as_f64()))
344                .collect();
345            if vals.is_empty() {
346                Value::Null
347            } else {
348                Value::from(vals.iter().sum::<f64>() / vals.len() as f64)
349            }
350        }
351        "MIN" => {
352            data.iter()
353                .filter_map(|row| row.get(arg).and_then(|v| v.as_f64()))
354                .fold(None, |min: Option<f64>, v| Some(min.map_or(v, |m: f64| m.min(v))))
355                .map(Value::from)
356                .unwrap_or(Value::Null)
357        }
358        "MAX" => {
359            data.iter()
360                .filter_map(|row| row.get(arg).and_then(|v| v.as_f64()))
361                .fold(None, |max: Option<f64>, v| Some(max.map_or(v, |m: f64| m.max(v))))
362                .map(Value::from)
363                .unwrap_or(Value::Null)
364        }
365        _ => Value::Null,
366    }
367}
368
369/// Sort data by a column
370fn sort_data(data: &mut [Value], column: &str, descending: bool) {
371    data.sort_by(|a, b| {
372        let va = a.get(column);
373        let vb = b.get(column);
374        
375        let cmp = match (va, vb) {
376            (Some(Value::Number(a)), Some(Value::Number(b))) => {
377                a.as_f64().unwrap_or(0.0).partial_cmp(&b.as_f64().unwrap_or(0.0))
378                    .unwrap_or(std::cmp::Ordering::Equal)
379            }
380            (Some(Value::String(a)), Some(Value::String(b))) => a.cmp(b),
381            (Some(_), None) => std::cmp::Ordering::Less,
382            (None, Some(_)) => std::cmp::Ordering::Greater,
383            _ => std::cmp::Ordering::Equal,
384        };
385        
386        if descending { cmp.reverse() } else { cmp }
387    });
388}
389
390/// Apply WHERE filtering to data
391fn apply_where_filter(data: Vec<Value>, tokens: &[String], pos: &mut usize) -> Vec<Value> {
392    let mut conditions: Vec<(String, String, String)> = Vec::new();
393    let mut logical_ops: Vec<String> = Vec::new();
394    
395    while *pos < tokens.len() {
396        let upper = tokens[*pos].to_uppercase();
397        if ["ORDER", "LIMIT", "OFFSET", "GROUP"].contains(&upper.as_str()) {
398            break;
399        }
400        
401        if *pos + 2 < tokens.len() {
402            let col = tokens[*pos].clone();
403            let op = tokens[*pos + 1].clone();
404            let val = tokens[*pos + 2].clone();
405            conditions.push((col, op, val));
406            *pos += 3;
407            
408            if *pos < tokens.len() {
409                let next_upper = tokens[*pos].to_uppercase();
410                if next_upper == "AND" || next_upper == "OR" {
411                    logical_ops.push(next_upper);
412                    *pos += 1;
413                }
414            }
415        } else {
416            break;
417        }
418    }
419    
420    if conditions.is_empty() {
421        return data;
422    }
423    
424    data.into_iter().filter(|row| {
425        let mut result = evaluate_condition(row, &conditions[0]);
426        
427        for i in 0..logical_ops.len() {
428            if i + 1 < conditions.len() {
429                let next_result = evaluate_condition(row, &conditions[i + 1]);
430                result = match logical_ops[i].as_str() {
431                    "AND" => result && next_result,
432                    "OR" => result || next_result,
433                    _ => result,
434                };
435            }
436        }
437        
438        result
439    }).collect()
440}
441
442/// Evaluate a single WHERE condition against a row
443fn evaluate_condition(row: &Value, condition: &(String, String, String)) -> bool {
444    let (col, op, val) = condition;
445    
446    let field_val = match row.get(col) {
447        Some(v) => v,
448        None => return false,
449    };
450    
451    let clean_val = val.trim_matches('\'').trim_matches('"');
452    
453    match op.as_str() {
454        "=" | "==" => {
455            if let Ok(n) = clean_val.parse::<f64>() {
456                field_val.as_f64().is_some_and(|fv| (fv - n).abs() < f64::EPSILON)
457            } else {
458                field_val.as_str().is_some_and(|s| s == clean_val)
459                    || field_val.to_string().trim_matches('"') == clean_val
460            }
461        }
462        "!=" | "<>" => {
463            if let Ok(n) = clean_val.parse::<f64>() {
464                !field_val.as_f64().is_some_and(|fv| (fv - n).abs() < f64::EPSILON)
465            } else {
466                field_val.as_str().is_none_or(|s| s != clean_val)
467            }
468        }
469        ">" => compare_values(field_val, clean_val) > 0,
470        "<" => compare_values(field_val, clean_val) < 0,
471        ">=" => compare_values(field_val, clean_val) >= 0,
472        "<=" => compare_values(field_val, clean_val) <= 0,
473        _ => false,
474    }
475}
476
477/// Compare a JSON value with a string value
478fn compare_values(field: &Value, val: &str) -> i32 {
479    if let Ok(n) = val.parse::<f64>() {
480        if let Some(fv) = field.as_f64() {
481            return if fv > n { 1 } else if fv < n { -1 } else { 0 };
482        }
483    }
484    if let Some(s) = field.as_str() {
485        return s.cmp(val) as i32;
486    }
487    0
488}
489
490/// Execute INSERT query
491fn execute_insert(db: &mut OverDriveDB, tokens: &[String], raw_sql: &str) -> SdkResult<QueryResult> {
492    if tokens.len() < 3 || tokens[1].to_uppercase() != "INTO" {
493        return Err(SdkError::InvalidQuery("Expected INSERT INTO <table> {json}".to_string()));
494    }
495    
496    let table = &tokens[2];
497    
498    let json_str = if let Some(brace_pos) = raw_sql.find('{') {
499        &raw_sql[brace_pos..]
500    } else {
501        return Err(SdkError::InvalidQuery("Expected JSON object after table name".to_string()));
502    };
503    
504    let value: Value = serde_json::from_str(json_str)
505        .map_err(|e| SdkError::InvalidQuery(format!("Invalid JSON: {}", e)))?;
506    
507    let id = db.insert(table, &value)?;
508    
509    Ok(QueryResult {
510        rows: vec![serde_json::json!({"_id": id})],
511        columns: vec!["_id".to_string()],
512        rows_affected: 1,
513        execution_time_ms: 0.0,
514    })
515}
516
517/// Execute UPDATE query
518fn execute_update(db: &mut OverDriveDB, tokens: &[String], raw_sql: &str) -> SdkResult<QueryResult> {
519    if tokens.len() < 3 {
520        return Err(SdkError::InvalidQuery("Expected UPDATE <table> SET {json}".to_string()));
521    }
522    
523    let table = tokens[1].clone();
524    
525    let set_pos = tokens.iter().position(|t| t.to_uppercase() == "SET")
526        .ok_or_else(|| SdkError::InvalidQuery("Expected SET keyword".to_string()))?;
527    
528    let json_str = if let Some(brace_pos) = raw_sql.find('{') {
529        let sub = &raw_sql[brace_pos..];
530        let mut depth = 0;
531        let mut end = 0;
532        for (i, c) in sub.chars().enumerate() {
533            if c == '{' { depth += 1; }
534            if c == '}' { depth -= 1; }
535            if depth == 0 { end = i + 1; break; }
536        }
537        &raw_sql[brace_pos..brace_pos + end]
538    } else {
539        return Err(SdkError::InvalidQuery("Expected {updates} after SET".to_string()));
540    };
541    
542    let updates: Value = serde_json::from_str(json_str)
543        .map_err(|e| SdkError::InvalidQuery(format!("Invalid JSON: {}", e)))?;
544    
545    // Get all data via scan to find matching rows
546    let all_data = db.scan(&table)?;
547    
548    // Parse WHERE if present
549    let mut where_pos = set_pos + 1;
550    while where_pos < tokens.len() && tokens[where_pos].to_uppercase() != "WHERE" {
551        where_pos += 1;
552    }
553    
554    let matched_ids: Vec<String> = if where_pos < tokens.len() && tokens[where_pos].to_uppercase() == "WHERE" {
555        where_pos += 1;
556        let filtered = apply_where_filter(all_data, tokens, &mut where_pos);
557        filtered.iter()
558            .filter_map(|r| r.get("_id").and_then(|v| v.as_str()).map(|s| s.to_string()))
559            .collect()
560    } else {
561        all_data.iter()
562            .filter_map(|r| r.get("_id").and_then(|v| v.as_str()).map(|s| s.to_string()))
563            .collect()
564    };
565    
566    let mut affected = 0;
567    for id in &matched_ids {
568        if db.update(&table, id, &updates)? {
569            affected += 1;
570        }
571    }
572    
573    Ok(QueryResult {
574        rows: Vec::new(),
575        columns: Vec::new(),
576        rows_affected: affected,
577        execution_time_ms: 0.0,
578    })
579}
580
581/// Execute DELETE query
582fn execute_delete(db: &mut OverDriveDB, tokens: &[String]) -> SdkResult<QueryResult> {
583    if tokens.len() < 3 || tokens[1].to_uppercase() != "FROM" {
584        return Err(SdkError::InvalidQuery("Expected DELETE FROM <table>".to_string()));
585    }
586    
587    let table = tokens[2].clone();
588    
589    let all_data = db.scan(&table)?;
590    
591    let mut pos = 3;
592    let matched: Vec<Value> = if pos < tokens.len() && tokens[pos].to_uppercase() == "WHERE" {
593        pos += 1;
594        apply_where_filter(all_data, tokens, &mut pos)
595    } else {
596        all_data
597    };
598    
599    let ids: Vec<String> = matched.iter()
600        .filter_map(|r| r.get("_id").and_then(|v| v.as_str()).map(|s| s.to_string()))
601        .collect();
602    
603    let mut affected = 0;
604    for id in &ids {
605        if db.delete(&table, id)? {
606            affected += 1;
607        }
608    }
609    
610    Ok(QueryResult {
611        rows: Vec::new(),
612        columns: Vec::new(),
613        rows_affected: affected,
614        execution_time_ms: 0.0,
615    })
616}
617
618/// Execute CREATE TABLE
619fn execute_create(db: &mut OverDriveDB, tokens: &[String]) -> SdkResult<QueryResult> {
620    if tokens.len() < 3 {
621        return Err(SdkError::InvalidQuery("Expected CREATE TABLE <name>".to_string()));
622    }
623    let kw = tokens[1].to_uppercase();
624    if kw != "TABLE" && kw != "TB" {
625        return Err(SdkError::InvalidQuery("Expected CREATE TABLE".to_string()));
626    }
627    let name = &tokens[2];
628    db.create_table(name)?;
629    
630    Ok(QueryResult::empty())
631}
632
633/// Execute DROP TABLE
634fn execute_drop(db: &mut OverDriveDB, tokens: &[String]) -> SdkResult<QueryResult> {
635    if tokens.len() < 3 {
636        return Err(SdkError::InvalidQuery("Expected DROP TABLE <name>".to_string()));
637    }
638    let kw = tokens[1].to_uppercase();
639    if kw != "TABLE" && kw != "TB" {
640        return Err(SdkError::InvalidQuery("Expected DROP TABLE".to_string()));
641    }
642    let name = &tokens[2];
643    db.drop_table(name)?;
644    
645    Ok(QueryResult::empty())
646}
647
648/// Execute SHOW TABLES
649fn execute_show(db: &mut OverDriveDB, tokens: &[String]) -> SdkResult<QueryResult> {
650    if tokens.len() < 2 {
651        return Err(SdkError::InvalidQuery("Expected SHOW TABLES".to_string()));
652    }
653    let kw = tokens[1].to_uppercase();
654    if kw != "TABLES" && kw != "TABLE" && kw != "TB" {
655        return Err(SdkError::InvalidQuery("Expected SHOW TABLES".to_string()));
656    }
657    
658    let tables = db.list_tables()?;
659    let rows: Vec<Value> = tables.iter()
660        .map(|t| serde_json::json!({"table_name": t}))
661        .collect();
662    
663    Ok(QueryResult {
664        rows,
665        columns: vec!["table_name".to_string()],
666        rows_affected: 0,
667        execution_time_ms: 0.0,
668    })
669}