Skip to main content

overdrive/
query_engine.rs

1//! Embedded SQL Query Engine for OverDrive InCode SDK
2//! 
3//! Parses and executes SQL queries directly against the embedded database.
4//! No server needed — runs entirely in-process.
5//! 
6//! ## Supported SQL
7//! 
8//! - `SELECT [columns] FROM <table> [WHERE ...] [ORDER BY ...] [LIMIT n] [OFFSET n]`
9//! - `INSERT INTO <table> {json}`
10//! - `UPDATE <table> SET {json} [WHERE ...]`
11//! - `DELETE FROM <table> [WHERE ...]`
12//! - `CREATE TABLE <name>`
13//! - `DROP TABLE <name>`
14//! - `SHOW TABLES`
15//! - `SELECT COUNT(*) FROM <table>` (and other aggregates)
16
17use crate::{OverDriveDB, QueryResult};
18use crate::result::{SdkResult, SdkError};
19use serde_json::Value;
20
21/// Execute an SQL query against the embedded database
22pub fn execute(db: &mut OverDriveDB, sql: &str) -> SdkResult<QueryResult> {
23    let sql = sql.trim().trim_end_matches(';').trim();
24    
25    if sql.is_empty() {
26        return Ok(QueryResult::empty());
27    }
28    
29    // Tokenize: split into words respecting quoted strings and braces
30    let tokens = tokenize(sql);
31    
32    if tokens.is_empty() {
33        return Ok(QueryResult::empty());
34    }
35    
36    let first = tokens[0].to_uppercase();
37    
38    match first.as_str() {
39        "SELECT" => execute_select(db, &tokens),
40        "INSERT" => execute_insert(db, &tokens, sql),
41        "UPDATE" => execute_update(db, &tokens, sql),
42        "DELETE" => execute_delete(db, &tokens),
43        "CREATE" => execute_create(db, &tokens),
44        "DROP"   => execute_drop(db, &tokens),
45        "SHOW"   => execute_show(db, &tokens),
46        _ => Err(SdkError::InvalidQuery(format!("Unsupported SQL command: {}", first))),
47    }
48}
49
/// Split a SQL statement into tokens.
///
/// Token kinds produced:
/// - quoted strings (`'...'` or `"..."`, backslash escapes resolved), always
///   re-emitted wrapped in single quotes;
/// - JSON objects / arrays (`{...}` / `[...]`) as one token, matched by
///   nesting depth while ignoring brackets that appear inside JSON string
///   literals;
/// - comparison operators `=`, `==`, `!=`, `<>`, `<`, `<=`, `>`, `>=`;
/// - the single-character tokens `,`, `(`, `)` and `*`;
/// - everything else as whitespace/punctuation-delimited words.
///
/// Unterminated strings or JSON literals simply run to the end of the input.
fn tokenize(input: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut chars = input.chars().peekable();

    while let Some(&c) = chars.peek() {
        if c.is_whitespace() {
            chars.next();
            continue;
        }

        // Quoted string
        if c == '\'' || c == '"' {
            let quote = c;
            chars.next();
            let mut s = String::new();
            while let Some(ch) = chars.next() {
                if ch == quote {
                    break;
                }
                if ch == '\\' {
                    // Keep the escaped character verbatim; a trailing lone
                    // backslash is kept as-is.
                    match chars.next() {
                        Some(escaped) => s.push(escaped),
                        None => s.push(ch),
                    }
                    continue;
                }
                s.push(ch);
            }
            tokens.push(format!("'{}'", s));
            continue;
        }

        // JSON object or array
        if c == '{' || c == '[' {
            let (open, close) = if c == '{' { ('{', '}') } else { ('[', ']') };
            let mut depth = 0usize;
            let mut in_string = false;
            let mut escaped = false;
            let mut s = String::new();
            while let Some(ch) = chars.next() {
                s.push(ch);
                if in_string {
                    // Brackets inside a JSON string must not affect nesting.
                    if escaped {
                        escaped = false;
                    } else if ch == '\\' {
                        escaped = true;
                    } else if ch == '"' {
                        in_string = false;
                    }
                    continue;
                }
                if ch == '"' {
                    in_string = true;
                } else if ch == open {
                    depth += 1;
                } else if ch == close {
                    // The first character is `open`, so depth is >= 1 here.
                    depth -= 1;
                    if depth == 0 {
                        break;
                    }
                }
            }
            tokens.push(s);
            continue;
        }

        // Operators
        if c == '>' || c == '<' || c == '!' || c == '=' {
            let mut op = String::new();
            op.push(c);
            chars.next();
            if let Some(&next) = chars.peek() {
                // Two-character forms: `>=`, `<=`, `!=`, `==` and `<>`.
                if next == '=' || (c == '<' && next == '>') {
                    op.push(next);
                    chars.next();
                }
            }
            tokens.push(op);
            continue;
        }

        // Comma, parentheses, star
        if c == ',' || c == '(' || c == ')' || c == '*' {
            tokens.push(c.to_string());
            chars.next();
            continue;
        }

        // Word or number: runs until whitespace or a character that starts
        // another token kind.
        let mut word = String::new();
        while let Some(&ch) = chars.peek() {
            if ch.is_whitespace()
                || ch == ','
                || ch == '('
                || ch == ')'
                || ch == '{'
                || ch == '['
                || ch == '>'
                || ch == '<'
                || ch == '='
                || ch == '!'
            {
                break;
            }
            word.push(ch);
            chars.next();
        }
        if !word.is_empty() {
            tokens.push(word);
        }
    }

    tokens
}
142
143/// Execute SELECT query
144fn execute_select(db: &mut OverDriveDB, tokens: &[String]) -> SdkResult<QueryResult> {
145    let mut pos = 1; // skip "SELECT"
146    
147    // Parse columns
148    let mut columns: Vec<String> = Vec::new();
149    let mut aggregates: Vec<(String, String)> = Vec::new(); // (func_name, arg)
150    
151    while pos < tokens.len() {
152        let upper = tokens[pos].to_uppercase();
153        if upper == "FROM" {
154            break;
155        }
156        
157        let col = tokens[pos].trim_end_matches(',').to_string();
158        
159        // Check for aggregate functions
160        if let Some(agg) = try_parse_aggregate(tokens, &mut pos) {
161            aggregates.push(agg);
162        } else {
163            columns.push(col);
164        }
165        
166        if pos < tokens.len() && tokens[pos] == "," {
167            pos += 1;
168        } else {
169            pos += 1;
170        }
171    }
172    
173    // Expect FROM
174    if pos >= tokens.len() || tokens[pos].to_uppercase() != "FROM" {
175        return Err(SdkError::InvalidQuery("Expected FROM keyword".to_string()));
176    }
177    pos += 1;
178    
179    // Table name
180    if pos >= tokens.len() {
181        return Err(SdkError::InvalidQuery("Expected table name after FROM".to_string()));
182    }
183    let table = &tokens[pos];
184    pos += 1;
185    
186    // Get all data from table via dynamic FFI
187    let mut data = db.scan(table)?;
188    
189    // Parse WHERE clause
190    if pos < tokens.len() && tokens[pos].to_uppercase() == "WHERE" {
191        pos += 1;
192        data = apply_where_filter(data, tokens, &mut pos);
193    }
194    
195    // Handle aggregates
196    if !aggregates.is_empty() {
197        let mut result_row = serde_json::Map::new();
198        for (func, arg) in &aggregates {
199            let value = execute_aggregate(func, arg, &data);
200            let key = format!("{}({})", func, arg);
201            result_row.insert(key, value);
202        }
203        return Ok(QueryResult {
204            rows: vec![Value::Object(result_row)],
205            columns: vec![],
206            rows_affected: 0,
207            execution_time_ms: 0.0,
208        });
209    }
210    
211    // Parse ORDER BY
212    if pos < tokens.len() && tokens[pos].to_uppercase() == "ORDER" {
213        pos += 1;
214        if pos < tokens.len() && tokens[pos].to_uppercase() == "BY" {
215            pos += 1;
216        }
217        if pos < tokens.len() {
218            let sort_col = tokens[pos].trim_end_matches(',').to_string();
219            pos += 1;
220            let descending = if pos < tokens.len() && tokens[pos].to_uppercase() == "DESC" {
221                pos += 1;
222                true
223            } else {
224                if pos < tokens.len() && tokens[pos].to_uppercase() == "ASC" {
225                    pos += 1;
226                }
227                false
228            };
229            sort_data(&mut data, &sort_col, descending);
230        }
231    }
232    
233    // Parse LIMIT
234    let mut limit: Option<usize> = None;
235    if pos < tokens.len() && tokens[pos].to_uppercase() == "LIMIT" {
236        pos += 1;
237        if pos < tokens.len() {
238            limit = tokens[pos].parse().ok();
239            pos += 1;
240        }
241    }
242    
243    // Parse OFFSET
244    let mut offset: usize = 0;
245    if pos < tokens.len() && tokens[pos].to_uppercase() == "OFFSET" {
246        pos += 1;
247        if pos < tokens.len() {
248            offset = tokens[pos].parse().unwrap_or(0);
249            let _ = pos;
250        }
251    }
252    
253    // Apply offset and limit
254    if offset > 0 {
255        if offset >= data.len() {
256            data.clear();
257        } else {
258            data = data.into_iter().skip(offset).collect();
259        }
260    }
261    
262    if let Some(lim) = limit {
263        data.truncate(lim);
264    }
265    
266    // Column projection (if not SELECT *)
267    let is_select_all = columns.len() == 1 && columns[0] == "*";
268    if !is_select_all && !columns.is_empty() {
269        data = data.into_iter().map(|row| {
270            if let Value::Object(map) = &row {
271                let mut projected = serde_json::Map::new();
272                for col in &columns {
273                    let col_clean = col.trim_end_matches(',');
274                    if let Some(val) = map.get(col_clean) {
275                        projected.insert(col_clean.to_string(), val.clone());
276                    }
277                }
278                Value::Object(projected)
279            } else {
280                row
281            }
282        }).collect();
283    }
284    
285    Ok(QueryResult {
286        rows: data,
287        columns,
288        rows_affected: 0,
289        execution_time_ms: 0.0,
290    })
291}
292
/// Try to parse an aggregate call such as `COUNT ( * )` or `SUM ( col )`
/// starting at `tokens[*pos]`.
///
/// The tokenizer always emits the function name, `(`, the argument and `)` as
/// separate tokens, so that is the only shape recognised here.
///
/// On success returns `(UPPERCASE_FUNC, arg)` and leaves `*pos` on the first
/// token after the call (a missing closing `)` is tolerated). On failure
/// returns `None` and leaves `*pos` unchanged, including when `*pos` is out of
/// range.
fn try_parse_aggregate(tokens: &[String], pos: &mut usize) -> Option<(String, String)> {
    const AGGREGATE_FUNCTIONS: [&str; 5] = ["COUNT", "SUM", "AVG", "MIN", "MAX"];

    let start = *pos;
    let func_name = tokens.get(start)?.to_uppercase();
    if !AGGREGATE_FUNCTIONS.contains(&func_name.as_str()) {
        return None;
    }

    // A bare word like `count` with no parenthesis is an ordinary column.
    if tokens.get(start + 1).map(String::as_str) != Some("(") {
        return None;
    }

    let arg = tokens.get(start + 2)?.clone();

    // Only commit the cursor once the whole call has been recognised.
    let mut next = start + 3;
    if tokens.get(next).map(String::as_str) == Some(")") {
        next += 1;
    }
    *pos = next;

    Some((func_name, arg))
}
333
334/// Execute an aggregate function
335fn execute_aggregate(func: &str, arg: &str, data: &[Value]) -> Value {
336    match func {
337        "COUNT" => Value::from(data.len()),
338        "SUM" => {
339            let sum: f64 = data.iter()
340                .filter_map(|row| row.get(arg).and_then(|v| v.as_f64()))
341                .sum();
342            Value::from(sum)
343        }
344        "AVG" => {
345            let vals: Vec<f64> = data.iter()
346                .filter_map(|row| row.get(arg).and_then(|v| v.as_f64()))
347                .collect();
348            if vals.is_empty() {
349                Value::Null
350            } else {
351                Value::from(vals.iter().sum::<f64>() / vals.len() as f64)
352            }
353        }
354        "MIN" => {
355            data.iter()
356                .filter_map(|row| row.get(arg).and_then(|v| v.as_f64()))
357                .fold(None, |min: Option<f64>, v| Some(min.map_or(v, |m: f64| m.min(v))))
358                .map(Value::from)
359                .unwrap_or(Value::Null)
360        }
361        "MAX" => {
362            data.iter()
363                .filter_map(|row| row.get(arg).and_then(|v| v.as_f64()))
364                .fold(None, |max: Option<f64>, v| Some(max.map_or(v, |m: f64| m.max(v))))
365                .map(Value::from)
366                .unwrap_or(Value::Null)
367        }
368        _ => Value::Null,
369    }
370}
371
372/// Sort data by a column
373fn sort_data(data: &mut Vec<Value>, column: &str, descending: bool) {
374    data.sort_by(|a, b| {
375        let va = a.get(column);
376        let vb = b.get(column);
377        
378        let cmp = match (va, vb) {
379            (Some(Value::Number(a)), Some(Value::Number(b))) => {
380                a.as_f64().unwrap_or(0.0).partial_cmp(&b.as_f64().unwrap_or(0.0))
381                    .unwrap_or(std::cmp::Ordering::Equal)
382            }
383            (Some(Value::String(a)), Some(Value::String(b))) => a.cmp(b),
384            (Some(_), None) => std::cmp::Ordering::Less,
385            (None, Some(_)) => std::cmp::Ordering::Greater,
386            _ => std::cmp::Ordering::Equal,
387        };
388        
389        if descending { cmp.reverse() } else { cmp }
390    });
391}
392
393/// Apply WHERE filtering to data
394fn apply_where_filter(data: Vec<Value>, tokens: &[String], pos: &mut usize) -> Vec<Value> {
395    let mut conditions: Vec<(String, String, String)> = Vec::new();
396    let mut logical_ops: Vec<String> = Vec::new();
397    
398    while *pos < tokens.len() {
399        let upper = tokens[*pos].to_uppercase();
400        if ["ORDER", "LIMIT", "OFFSET", "GROUP"].contains(&upper.as_str()) {
401            break;
402        }
403        
404        if *pos + 2 < tokens.len() {
405            let col = tokens[*pos].clone();
406            let op = tokens[*pos + 1].clone();
407            let val = tokens[*pos + 2].clone();
408            conditions.push((col, op, val));
409            *pos += 3;
410            
411            if *pos < tokens.len() {
412                let next_upper = tokens[*pos].to_uppercase();
413                if next_upper == "AND" || next_upper == "OR" {
414                    logical_ops.push(next_upper);
415                    *pos += 1;
416                }
417            }
418        } else {
419            break;
420        }
421    }
422    
423    if conditions.is_empty() {
424        return data;
425    }
426    
427    data.into_iter().filter(|row| {
428        let mut result = evaluate_condition(row, &conditions[0]);
429        
430        for i in 0..logical_ops.len() {
431            if i + 1 < conditions.len() {
432                let next_result = evaluate_condition(row, &conditions[i + 1]);
433                result = match logical_ops[i].as_str() {
434                    "AND" => result && next_result,
435                    "OR" => result || next_result,
436                    _ => result,
437                };
438            }
439        }
440        
441        result
442    }).collect()
443}
444
445/// Evaluate a single WHERE condition against a row
446fn evaluate_condition(row: &Value, condition: &(String, String, String)) -> bool {
447    let (col, op, val) = condition;
448    
449    let field_val = match row.get(col) {
450        Some(v) => v,
451        None => return false,
452    };
453    
454    let clean_val = val.trim_matches('\'').trim_matches('"');
455    
456    match op.as_str() {
457        "=" | "==" => {
458            if let Ok(n) = clean_val.parse::<f64>() {
459                field_val.as_f64().map_or(false, |fv| (fv - n).abs() < f64::EPSILON)
460            } else {
461                field_val.as_str().map_or(false, |s| s == clean_val)
462                    || field_val.to_string().trim_matches('"') == clean_val
463            }
464        }
465        "!=" | "<>" => {
466            if let Ok(n) = clean_val.parse::<f64>() {
467                field_val.as_f64().map_or(true, |fv| (fv - n).abs() >= f64::EPSILON)
468            } else {
469                field_val.as_str().map_or(true, |s| s != clean_val)
470            }
471        }
472        ">" => compare_values(field_val, clean_val) > 0,
473        "<" => compare_values(field_val, clean_val) < 0,
474        ">=" => compare_values(field_val, clean_val) >= 0,
475        "<=" => compare_values(field_val, clean_val) <= 0,
476        _ => false,
477    }
478}
479
480/// Compare a JSON value with a string value
481fn compare_values(field: &Value, val: &str) -> i32 {
482    if let Ok(n) = val.parse::<f64>() {
483        if let Some(fv) = field.as_f64() {
484            return if fv > n { 1 } else if fv < n { -1 } else { 0 };
485        }
486    }
487    if let Some(s) = field.as_str() {
488        return s.cmp(val) as i32;
489    }
490    0
491}
492
493/// Execute INSERT query
494fn execute_insert(db: &mut OverDriveDB, tokens: &[String], raw_sql: &str) -> SdkResult<QueryResult> {
495    if tokens.len() < 3 || tokens[1].to_uppercase() != "INTO" {
496        return Err(SdkError::InvalidQuery("Expected INSERT INTO <table> {json}".to_string()));
497    }
498    
499    let table = &tokens[2];
500    
501    let json_str = if let Some(brace_pos) = raw_sql.find('{') {
502        &raw_sql[brace_pos..]
503    } else {
504        return Err(SdkError::InvalidQuery("Expected JSON object after table name".to_string()));
505    };
506    
507    let value: Value = serde_json::from_str(json_str)
508        .map_err(|e| SdkError::InvalidQuery(format!("Invalid JSON: {}", e)))?;
509    
510    let id = db.insert(table, &value)?;
511    
512    Ok(QueryResult {
513        rows: vec![serde_json::json!({"_id": id})],
514        columns: vec!["_id".to_string()],
515        rows_affected: 1,
516        execution_time_ms: 0.0,
517    })
518}
519
520/// Execute UPDATE query
521fn execute_update(db: &mut OverDriveDB, tokens: &[String], raw_sql: &str) -> SdkResult<QueryResult> {
522    if tokens.len() < 3 {
523        return Err(SdkError::InvalidQuery("Expected UPDATE <table> SET {json}".to_string()));
524    }
525    
526    let table = tokens[1].clone();
527    
528    let set_pos = tokens.iter().position(|t| t.to_uppercase() == "SET")
529        .ok_or_else(|| SdkError::InvalidQuery("Expected SET keyword".to_string()))?;
530    
531    let json_str = if let Some(brace_pos) = raw_sql.find('{') {
532        let sub = &raw_sql[brace_pos..];
533        let mut depth = 0;
534        let mut end = 0;
535        for (i, c) in sub.chars().enumerate() {
536            if c == '{' { depth += 1; }
537            if c == '}' { depth -= 1; }
538            if depth == 0 { end = i + 1; break; }
539        }
540        &raw_sql[brace_pos..brace_pos + end]
541    } else {
542        return Err(SdkError::InvalidQuery("Expected {updates} after SET".to_string()));
543    };
544    
545    let updates: Value = serde_json::from_str(json_str)
546        .map_err(|e| SdkError::InvalidQuery(format!("Invalid JSON: {}", e)))?;
547    
548    // Get all data via scan to find matching rows
549    let all_data = db.scan(&table)?;
550    
551    // Parse WHERE if present
552    let mut where_pos = set_pos + 1;
553    while where_pos < tokens.len() && tokens[where_pos].to_uppercase() != "WHERE" {
554        where_pos += 1;
555    }
556    
557    let matched_ids: Vec<String>;
558    if where_pos < tokens.len() && tokens[where_pos].to_uppercase() == "WHERE" {
559        where_pos += 1;
560        let filtered = apply_where_filter(all_data, tokens, &mut where_pos);
561        matched_ids = filtered.iter()
562            .filter_map(|r| r.get("_id").and_then(|v| v.as_str()).map(|s| s.to_string()))
563            .collect();
564    } else {
565        matched_ids = all_data.iter()
566            .filter_map(|r| r.get("_id").and_then(|v| v.as_str()).map(|s| s.to_string()))
567            .collect();
568    }
569    
570    let mut affected = 0;
571    for id in &matched_ids {
572        if db.update(&table, id, &updates)? {
573            affected += 1;
574        }
575    }
576    
577    Ok(QueryResult {
578        rows: Vec::new(),
579        columns: Vec::new(),
580        rows_affected: affected,
581        execution_time_ms: 0.0,
582    })
583}
584
585/// Execute DELETE query
586fn execute_delete(db: &mut OverDriveDB, tokens: &[String]) -> SdkResult<QueryResult> {
587    if tokens.len() < 3 || tokens[1].to_uppercase() != "FROM" {
588        return Err(SdkError::InvalidQuery("Expected DELETE FROM <table>".to_string()));
589    }
590    
591    let table = tokens[2].clone();
592    
593    let all_data = db.scan(&table)?;
594    
595    let mut pos = 3;
596    let matched: Vec<Value>;
597    if pos < tokens.len() && tokens[pos].to_uppercase() == "WHERE" {
598        pos += 1;
599        matched = apply_where_filter(all_data, tokens, &mut pos);
600    } else {
601        matched = all_data;
602    }
603    
604    let ids: Vec<String> = matched.iter()
605        .filter_map(|r| r.get("_id").and_then(|v| v.as_str()).map(|s| s.to_string()))
606        .collect();
607    
608    let mut affected = 0;
609    for id in &ids {
610        if db.delete(&table, id)? {
611            affected += 1;
612        }
613    }
614    
615    Ok(QueryResult {
616        rows: Vec::new(),
617        columns: Vec::new(),
618        rows_affected: affected,
619        execution_time_ms: 0.0,
620    })
621}
622
623/// Execute CREATE TABLE
624fn execute_create(db: &mut OverDriveDB, tokens: &[String]) -> SdkResult<QueryResult> {
625    if tokens.len() < 3 {
626        return Err(SdkError::InvalidQuery("Expected CREATE TABLE <name>".to_string()));
627    }
628    let kw = tokens[1].to_uppercase();
629    if kw != "TABLE" && kw != "TB" {
630        return Err(SdkError::InvalidQuery("Expected CREATE TABLE".to_string()));
631    }
632    let name = &tokens[2];
633    db.create_table(name)?;
634    
635    Ok(QueryResult::empty())
636}
637
638/// Execute DROP TABLE
639fn execute_drop(db: &mut OverDriveDB, tokens: &[String]) -> SdkResult<QueryResult> {
640    if tokens.len() < 3 {
641        return Err(SdkError::InvalidQuery("Expected DROP TABLE <name>".to_string()));
642    }
643    let kw = tokens[1].to_uppercase();
644    if kw != "TABLE" && kw != "TB" {
645        return Err(SdkError::InvalidQuery("Expected DROP TABLE".to_string()));
646    }
647    let name = &tokens[2];
648    db.drop_table(name)?;
649    
650    Ok(QueryResult::empty())
651}
652
653/// Execute SHOW TABLES
654fn execute_show(db: &mut OverDriveDB, tokens: &[String]) -> SdkResult<QueryResult> {
655    if tokens.len() < 2 {
656        return Err(SdkError::InvalidQuery("Expected SHOW TABLES".to_string()));
657    }
658    let kw = tokens[1].to_uppercase();
659    if kw != "TABLES" && kw != "TABLE" && kw != "TB" {
660        return Err(SdkError::InvalidQuery("Expected SHOW TABLES".to_string()));
661    }
662    
663    let tables = db.list_tables()?;
664    let rows: Vec<Value> = tables.iter()
665        .map(|t| serde_json::json!({"table_name": t}))
666        .collect();
667    
668    Ok(QueryResult {
669        rows,
670        columns: vec!["table_name".to_string()],
671        rows_affected: 0,
672        execution_time_ms: 0.0,
673    })
674}