1use crate::error::{QueryError, QueryResult};
4
5#[derive(Debug, Clone, Default)]
7pub struct ParsedQuery {
8 pub select_wildcard: bool,
10 pub columns: Vec<String>,
12 pub table: Option<String>,
14 pub where_sql: Option<String>,
16 pub group_by: Vec<String>,
18 pub order_by: Vec<(String, bool)>, pub limit: Option<usize>,
22 pub has_aggregates: bool,
24 pub query_type: QueryType,
26}
27
/// Kind of statement recognised by the parser.
///
/// `Eq` is derived alongside `PartialEq` because every payload (`String`) has
/// total equality.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum QueryType {
    /// A `SELECT` statement; this is the default kind.
    #[default]
    Select,
    /// A `SHOW TABLES` statement.
    ShowTables,
    /// A request to describe the named table (`DESCRIBE`, `DESC`, or
    /// `PRAGMA table_info(...)`).
    DescribeTable(String),
}
39
40pub fn parse(sql: &str) -> QueryResult<ParsedQuery> {
42 let sql_upper = sql.trim().to_uppercase();
43
44 if sql_upper == "SHOW TABLES" || sql_upper.starts_with("SHOW TABLES;") {
46 return Ok(ParsedQuery {
47 query_type: QueryType::ShowTables,
48 ..Default::default()
49 });
50 }
51
52 if sql_upper.starts_with("DESCRIBE ") || sql_upper.starts_with("DESC ") {
54 let keyword = if sql_upper.starts_with("DESCRIBE ") {
55 "DESCRIBE "
56 } else {
57 "DESC "
58 };
59 let table_name = sql[keyword.len()..].trim().to_string();
60 let table_name = table_name.trim_end_matches(';').trim().to_string();
61
62 if table_name.to_lowercase() != "entries" {
63 return Err(QueryError::InvalidTable(table_name));
64 }
65
66 return Ok(ParsedQuery {
67 query_type: QueryType::DescribeTable("entries".to_string()),
68 table: Some("entries".to_string()),
69 ..Default::default()
70 });
71 }
72
73 if sql_upper.starts_with("PRAGMA TABLE_INFO(") {
75 let start = sql_upper.find('(').unwrap() + 1;
77 let end = sql_upper.find(')').unwrap();
78 let table_name = sql[start..end].trim().to_string();
79
80 if table_name.to_lowercase() != "entries" {
81 return Err(QueryError::InvalidTable(table_name));
82 }
83
84 return Ok(ParsedQuery {
85 query_type: QueryType::DescribeTable("entries".to_string()),
86 table: Some("entries".to_string()),
87 ..Default::default()
88 });
89 }
90
91 if !sql_upper.starts_with("SELECT") {
93 return Err(QueryError::ParseError(
94 "Only SELECT, SHOW TABLES, and DESCRIBE statements are supported".to_string(),
95 ));
96 }
97
98 if !sql_upper.contains("FROM") {
100 return Err(QueryError::ParseError("Missing FROM clause".to_string()));
101 }
102
103 let table = if sql_upper.contains("FROM ENTRIES") {
105 Some("entries".to_string())
106 } else {
107 let from_idx = sql_upper.find("FROM ").unwrap() + 5;
109 let table_part = &sql[from_idx..];
110 let table_end = table_part
111 .find(|c: char| c.is_whitespace())
112 .or_else(|| table_part.find(';'))
113 .unwrap_or(table_part.len());
114 let table_name = table_part[..table_end].trim().to_string();
115 if table_name.to_lowercase() != "entries" {
117 return Err(QueryError::InvalidTable(table_name));
118 }
119 Some("entries".to_string())
120 };
121
122 let limit = if let Some(limit_idx) = sql_upper.find("LIMIT ") {
124 let limit_str = &sql[limit_idx + 6..];
125 let limit_end = limit_str
126 .find(|c: char| c.is_whitespace())
127 .or_else(|| limit_str.find(';'))
128 .unwrap_or(limit_str.len());
129 limit_str[..limit_end].trim().parse::<usize>().ok()
130 } else {
131 None
132 };
133
134 let where_sql = if sql_upper.contains("WHERE ") {
136 let where_idx = sql_upper.find("WHERE ").unwrap() + 6;
137 let where_end = sql_upper[where_idx..]
138 .find(" GROUP BY")
139 .or_else(|| sql_upper[where_idx..].find(" ORDER BY"))
140 .or_else(|| sql_upper[where_idx..].find(" LIMIT"))
141 .or_else(|| sql_upper[where_idx..].find(';'))
142 .unwrap_or(sql_upper[where_idx..].len());
143 Some(sql[where_idx..where_idx + where_end].trim().to_string())
144 } else {
145 None
146 };
147
148 let select_wildcard = sql_upper.contains("SELECT *");
150
151 Ok(ParsedQuery {
152 select_wildcard,
153 columns: vec![],
154 table,
155 where_sql,
156 group_by: vec![],
157 order_by: vec![],
158 limit,
159 has_aggregates: false,
160 query_type: QueryType::Select,
161 })
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql`, failing the test with the parser's error if it is rejected.
    fn parse_ok(sql: &str) -> ParsedQuery {
        parse(sql).unwrap_or_else(|err| panic!("{:?} should parse, got {:?}", sql, err))
    }

    /// The query type every describe-style statement must produce.
    fn describe_entries() -> QueryType {
        QueryType::DescribeTable("entries".to_string())
    }

    #[test]
    fn test_parse_simple_select() {
        let parsed = parse_ok("SELECT * FROM entries");
        assert!(parsed.select_wildcard);
        assert_eq!(parsed.table.as_deref(), Some("entries"));
        assert_eq!(parsed.query_type, QueryType::Select);
        assert_eq!(parsed.where_sql, None);
        assert_eq!(parsed.limit, None);
    }

    #[test]
    fn test_parse_select_is_case_insensitive() {
        let parsed = parse_ok("select * from ENTRIES limit 5");
        assert!(parsed.select_wildcard);
        assert_eq!(parsed.table.as_deref(), Some("entries"));
        assert_eq!(parsed.limit, Some(5));
    }

    #[test]
    fn test_parse_select_with_limit() {
        let parsed = parse_ok("SELECT * FROM entries LIMIT 10");
        assert_eq!(parsed.limit, Some(10));
    }

    #[test]
    fn test_parse_select_with_where() {
        let parsed = parse_ok("SELECT * FROM entries WHERE source = 'Claude'");
        // Pin the extracted text, not merely its presence.
        assert_eq!(parsed.where_sql.as_deref(), Some("source = 'Claude'"));
    }

    #[test]
    fn test_parse_select_with_where_and_limit() {
        let parsed = parse_ok("SELECT * FROM entries WHERE source = 'Claude' LIMIT 10");
        assert_eq!(parsed.where_sql.as_deref(), Some("source = 'Claude'"));
        assert_eq!(parsed.limit, Some(10));
    }

    #[test]
    fn test_parse_select_invalid_table() {
        assert!(matches!(
            parse("SELECT * FROM other"),
            Err(QueryError::InvalidTable(_))
        ));
    }

    #[test]
    fn test_parse_rejects_unsupported_statement() {
        assert!(matches!(
            parse("DELETE FROM entries"),
            Err(QueryError::ParseError(_))
        ));
    }

    #[test]
    fn test_parse_show_tables() {
        assert_eq!(parse_ok("SHOW TABLES").query_type, QueryType::ShowTables);
    }

    #[test]
    fn test_parse_show_tables_with_semicolon() {
        assert_eq!(parse_ok("SHOW TABLES;").query_type, QueryType::ShowTables);
    }

    #[test]
    fn test_parse_describe_entries() {
        let parsed = parse_ok("DESCRIBE entries");
        assert_eq!(parsed.query_type, describe_entries());
        assert_eq!(parsed.table.as_deref(), Some("entries"));
    }

    #[test]
    fn test_parse_describe_entries_with_semicolon() {
        assert_eq!(parse_ok("DESCRIBE entries;").query_type, describe_entries());
    }

    #[test]
    fn test_parse_desc_entries() {
        assert_eq!(parse_ok("DESC entries").query_type, describe_entries());
    }

    #[test]
    fn test_parse_pragma_table_info() {
        assert_eq!(
            parse_ok("PRAGMA table_info(entries)").query_type,
            describe_entries()
        );
    }

    #[test]
    fn test_parse_pragma_invalid_table() {
        assert!(matches!(
            parse("PRAGMA table_info(other)"),
            Err(QueryError::InvalidTable(_))
        ));
    }

    #[test]
    fn test_parse_describe_invalid_table() {
        assert!(matches!(
            parse("DESCRIBE invalid_table"),
            Err(QueryError::InvalidTable(_))
        ));
    }
}