Skip to main content

cai_query/
parser.rs

1//! SQL query parser
2
3use crate::error::{QueryError, QueryResult};
4
/// Parsed query representation
///
/// This is the output of `parse`. For `SHOW TABLES` every field except
/// `query_type` is left at its `Default` value; for `DESCRIBE`-style queries
/// only `query_type` and `table` are set.
#[derive(Debug, Clone, Default)]
pub struct ParsedQuery {
    /// Whether the query selects all columns (`parse` sets this when the
    /// query text contains `SELECT *`)
    pub select_wildcard: bool,
    /// Selected column names (`parse` currently always leaves this empty)
    pub columns: Vec<String>,
    /// Table name (must be "entries"); `None` for queries without a table,
    /// such as `SHOW TABLES`
    pub table: Option<String>,
    /// WHERE clause as SQL string (for simple cases), without the `WHERE`
    /// keyword itself
    pub where_sql: Option<String>,
    /// GROUP BY columns (`parse` currently always leaves this empty)
    pub group_by: Vec<String>,
    /// ORDER BY columns (`parse` currently always leaves this empty)
    pub order_by: Vec<(String, bool)>, // (column, asc)
    /// LIMIT value, if one was given and is a valid non-negative integer
    pub limit: Option<usize>,
    /// Has aggregate functions (`parse` currently always sets this to `false`)
    pub has_aggregates: bool,
    /// Query type for schema queries
    pub query_type: QueryType,
}
27
/// Type of SQL query
///
/// Distinguishes ordinary `SELECT` statements from the schema-inspection
/// statements the parser also understands. The default is [`QueryType::Select`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum QueryType {
    /// Standard SELECT query
    #[default]
    Select,
    /// SHOW TABLES query
    ShowTables,
    /// DESCRIBE table query; the payload is the name of the described table
    DescribeTable(String),
}
39
40/// Parse a SQL query string
41pub fn parse(sql: &str) -> QueryResult<ParsedQuery> {
42    let sql_upper = sql.trim().to_uppercase();
43
44    // Handle SHOW TABLES
45    if sql_upper == "SHOW TABLES" || sql_upper.starts_with("SHOW TABLES;") {
46        return Ok(ParsedQuery {
47            query_type: QueryType::ShowTables,
48            ..Default::default()
49        });
50    }
51
52    // Handle DESCRIBE table
53    if sql_upper.starts_with("DESCRIBE ") || sql_upper.starts_with("DESC ") {
54        let keyword = if sql_upper.starts_with("DESCRIBE ") {
55            "DESCRIBE "
56        } else {
57            "DESC "
58        };
59        let table_name = sql[keyword.len()..].trim().to_string();
60        let table_name = table_name.trim_end_matches(';').trim().to_string();
61
62        if table_name.to_lowercase() != "entries" {
63            return Err(QueryError::InvalidTable(table_name));
64        }
65
66        return Ok(ParsedQuery {
67            query_type: QueryType::DescribeTable("entries".to_string()),
68            table: Some("entries".to_string()),
69            ..Default::default()
70        });
71    }
72
73    // Handle PRAGMA table_info (SQLite-style)
74    if sql_upper.starts_with("PRAGMA TABLE_INFO(") {
75        // Extract table name from PRAGMA table_info(entries)
76        let start = sql_upper.find('(').unwrap() + 1;
77        let end = sql_upper.find(')').unwrap();
78        let table_name = sql[start..end].trim().to_string();
79
80        if table_name.to_lowercase() != "entries" {
81            return Err(QueryError::InvalidTable(table_name));
82        }
83
84        return Ok(ParsedQuery {
85            query_type: QueryType::DescribeTable("entries".to_string()),
86            table: Some("entries".to_string()),
87            ..Default::default()
88        });
89    }
90
91    // Validate it's a SELECT statement
92    if !sql_upper.starts_with("SELECT") {
93        return Err(QueryError::ParseError(
94            "Only SELECT, SHOW TABLES, and DESCRIBE statements are supported".to_string(),
95        ));
96    }
97
98    // Check for FROM entries
99    if !sql_upper.contains("FROM") {
100        return Err(QueryError::ParseError("Missing FROM clause".to_string()));
101    }
102
103    // Extract table name
104    let table = if sql_upper.contains("FROM ENTRIES") {
105        Some("entries".to_string())
106    } else {
107        // Try to find what comes after FROM
108        let from_idx = sql_upper.find("FROM ").unwrap() + 5;
109        let table_part = &sql[from_idx..];
110        let table_end = table_part
111            .find(|c: char| c.is_whitespace())
112            .or_else(|| table_part.find(';'))
113            .unwrap_or(table_part.len());
114        let table_name = table_part[..table_end].trim().to_string();
115        // Validate table name
116        if table_name.to_lowercase() != "entries" {
117            return Err(QueryError::InvalidTable(table_name));
118        }
119        Some("entries".to_string())
120    };
121
122    // Check for LIMIT
123    let limit = if let Some(limit_idx) = sql_upper.find("LIMIT ") {
124        let limit_str = &sql[limit_idx + 6..];
125        let limit_end = limit_str
126            .find(|c: char| c.is_whitespace())
127            .or_else(|| limit_str.find(';'))
128            .unwrap_or(limit_str.len());
129        limit_str[..limit_end].trim().parse::<usize>().ok()
130    } else {
131        None
132    };
133
134    // Check for WHERE
135    let where_sql = if sql_upper.contains("WHERE ") {
136        let where_idx = sql_upper.find("WHERE ").unwrap() + 6;
137        let where_end = sql_upper[where_idx..]
138            .find(" GROUP BY")
139            .or_else(|| sql_upper[where_idx..].find(" ORDER BY"))
140            .or_else(|| sql_upper[where_idx..].find(" LIMIT"))
141            .or_else(|| sql_upper[where_idx..].find(';'))
142            .unwrap_or(sql_upper[where_idx..].len());
143        Some(sql[where_idx..where_idx + where_end].trim().to_string())
144    } else {
145        None
146    };
147
148    // Check for wildcard
149    let select_wildcard = sql_upper.contains("SELECT *");
150
151    Ok(ParsedQuery {
152        select_wildcard,
153        columns: vec![],
154        table,
155        where_sql,
156        group_by: vec![],
157        order_by: vec![],
158        limit,
159        has_aggregates: false,
160        query_type: QueryType::Select,
161    })
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Query type every DESCRIBE-style statement on `entries` should produce.
    fn describe_entries() -> QueryType {
        QueryType::DescribeTable("entries".to_string())
    }

    #[test]
    fn test_parse_simple_select() {
        let parsed = parse("SELECT * FROM entries").expect("simple SELECT should parse");
        assert!(parsed.select_wildcard);
        assert_eq!(parsed.table.as_deref(), Some("entries"));
        assert_eq!(parsed.query_type, QueryType::Select);
    }

    #[test]
    fn test_parse_select_with_limit() {
        let parsed =
            parse("SELECT * FROM entries LIMIT 10").expect("SELECT with LIMIT should parse");
        assert_eq!(parsed.limit, Some(10));
    }

    #[test]
    fn test_parse_select_with_where() {
        let parsed = parse("SELECT * FROM entries WHERE source = 'Claude'")
            .expect("SELECT with WHERE should parse");
        assert!(parsed.where_sql.is_some());
    }

    #[test]
    fn test_parse_show_tables() {
        let parsed = parse("SHOW TABLES").expect("SHOW TABLES should parse");
        assert_eq!(parsed.query_type, QueryType::ShowTables);
    }

    #[test]
    fn test_parse_show_tables_with_semicolon() {
        let parsed = parse("SHOW TABLES;").expect("SHOW TABLES; should parse");
        assert_eq!(parsed.query_type, QueryType::ShowTables);
    }

    #[test]
    fn test_parse_describe_entries() {
        let parsed = parse("DESCRIBE entries").expect("DESCRIBE entries should parse");
        assert_eq!(parsed.query_type, describe_entries());
    }

    #[test]
    fn test_parse_desc_entries() {
        let parsed = parse("DESC entries").expect("DESC entries should parse");
        assert_eq!(parsed.query_type, describe_entries());
    }

    #[test]
    fn test_parse_pragma_table_info() {
        let parsed =
            parse("PRAGMA table_info(entries)").expect("PRAGMA table_info should parse");
        assert_eq!(parsed.query_type, describe_entries());
    }

    #[test]
    fn test_parse_describe_invalid_table() {
        // Any table other than `entries` must be rejected as an invalid table.
        let outcome = parse("DESCRIBE invalid_table");
        assert!(matches!(outcome, Err(QueryError::InvalidTable(_))));
    }
}
249}