1use crate::error::{ColumnInfo, QueryError, QueryResult, SchemaInfo, SchemaQueryType};
4use crate::parser::{ParsedQuery, QueryType};
5use cai_core::Entry;
6use cai_storage::Storage;
7use std::sync::Arc;
8
9#[derive(Debug, Clone)]
11pub enum QueryResultData {
12 Entries(Vec<Entry>),
14 Schema(SchemaInfo),
16}
17
18#[derive(Clone)]
20pub struct QueryEngine {
21 storage: Arc<dyn Storage>,
22}
23
24impl QueryEngine {
25 pub fn new<S>(storage: S) -> Self
27 where
28 S: Storage + 'static,
29 {
30 Self {
31 storage: Arc::new(storage),
32 }
33 }
34
35 pub fn from_arc(storage: Arc<dyn Storage>) -> Self {
37 Self { storage }
38 }
39
40 pub async fn execute(&self, sql: &str) -> QueryResult<Vec<Entry>> {
42 let parsed = crate::parse(sql)?;
43
44 match &parsed.query_type {
46 QueryType::ShowTables => {
47 Ok(vec![])
50 }
51 QueryType::DescribeTable(_) => {
52 Ok(vec![])
55 }
56 QueryType::Select => {
57 if parsed
59 .table
60 .as_ref()
61 .is_some_and(|t| t.to_lowercase() != "entries")
62 {
63 if let Some(table) = parsed.table {
64 return Err(QueryError::InvalidTable(table));
65 }
66 }
67
68 self.execute_simple_query(&parsed).await
70 }
71 }
72 }
73
74 pub async fn execute_full(&self, sql: &str) -> QueryResult<QueryResultData> {
76 let parsed = crate::parse(sql)?;
77
78 match &parsed.query_type {
79 QueryType::ShowTables => Ok(QueryResultData::Schema(SchemaInfo {
80 query_type: SchemaQueryType::ShowTables,
81 table_name: None,
82 tables: vec!["entries".to_string()],
83 columns: vec![],
84 })),
85 QueryType::DescribeTable(table_name) => Ok(QueryResultData::Schema(SchemaInfo {
86 query_type: SchemaQueryType::DescribeTable,
87 table_name: Some(table_name.clone()),
88 tables: vec![],
89 columns: Self::get_entry_columns(),
90 })),
91 QueryType::Select => {
92 if parsed
94 .table
95 .as_ref()
96 .is_some_and(|t| t.to_lowercase() != "entries")
97 {
98 if let Some(table) = parsed.table.clone() {
99 return Err(QueryError::InvalidTable(table));
100 }
101 }
102
103 let entries = self.execute_simple_query(&parsed).await?;
104 Ok(QueryResultData::Entries(entries))
105 }
106 }
107 }
108
109 fn get_entry_columns() -> Vec<ColumnInfo> {
111 vec![
112 ColumnInfo {
113 name: "id".to_string(),
114 data_type: "TEXT".to_string(),
115 description: "Unique identifier".to_string(),
116 },
117 ColumnInfo {
118 name: "source".to_string(),
119 data_type: "TEXT".to_string(),
120 description: "Source system (Claude, Codex, Git, Other)".to_string(),
121 },
122 ColumnInfo {
123 name: "timestamp".to_string(),
124 data_type: "TIMESTAMP".to_string(),
125 description: "Interaction timestamp (UTC)".to_string(),
126 },
127 ColumnInfo {
128 name: "prompt".to_string(),
129 data_type: "TEXT".to_string(),
130 description: "User prompt/input".to_string(),
131 },
132 ColumnInfo {
133 name: "response".to_string(),
134 data_type: "TEXT".to_string(),
135 description: "AI response/output".to_string(),
136 },
137 ColumnInfo {
138 name: "metadata".to_string(),
139 data_type: "JSON".to_string(),
140 description: "Additional metadata (file_path, language, etc.)".to_string(),
141 },
142 ]
143 }
144
145 async fn execute_simple_query(&self, parsed: &ParsedQuery) -> QueryResult<Vec<Entry>> {
146 let mut entries = self.storage.query(None).await?;
147
148 if let Some(ref where_sql) = parsed.where_sql {
150 entries = self.apply_where_filter(entries, where_sql)?;
151 }
152
153 if !parsed.order_by.is_empty() {
155 entries = self.apply_order_by(entries, &parsed.order_by)?;
156 }
157
158 if let Some(limit) = parsed.limit {
160 entries.truncate(limit);
161 }
162
163 Ok(entries)
164 }
165
166 fn apply_where_filter(&self, entries: Vec<Entry>, where_sql: &str) -> QueryResult<Vec<Entry>> {
167 let where_upper = where_sql.to_uppercase();
169
170 let expected_source = if where_upper.contains("SOURCE =") || where_upper.contains("SOURCE=")
172 {
173 extract_quoted_string(where_sql)
174 } else {
175 None
176 };
177
178 let expected_ts_gt =
179 if where_upper.contains("TIMESTAMP >") || where_upper.contains("TIMESTAMP>") {
180 extract_timestamp(where_sql)
181 .and_then(|s| s.parse::<chrono::DateTime<chrono::Utc>>().ok())
182 } else {
183 None
184 };
185
186 let expected_ts_lt =
187 if where_upper.contains("TIMESTAMP <") || where_upper.contains("TIMESTAMP<") {
188 extract_timestamp(where_sql)
189 .and_then(|s| s.parse::<chrono::DateTime<chrono::Utc>>().ok())
190 } else {
191 None
192 };
193
194 Ok(entries
195 .into_iter()
196 .filter(|entry| {
197 if let Some(ref source) = expected_source {
198 if format!("{:?}", entry.source) != *source {
199 return false;
200 }
201 }
202 if let Some(ts) = expected_ts_gt {
203 if entry.timestamp <= ts {
204 return false;
205 }
206 }
207 if let Some(ts) = expected_ts_lt {
208 if entry.timestamp >= ts {
209 return false;
210 }
211 }
212 true
213 })
214 .collect::<Vec<_>>())
215 }
216
217 fn apply_order_by(
218 &self,
219 mut entries: Vec<Entry>,
220 order_by: &[(String, bool)],
221 ) -> QueryResult<Vec<Entry>> {
222 entries.sort_by(|a, b| {
223 for (col, asc) in order_by {
224 let cmp = match col.to_lowercase().as_str() {
225 "timestamp" => a.timestamp.cmp(&b.timestamp),
226 "source" => format!("{:?}", a.source).cmp(&format!("{:?}", b.source)),
227 "id" => a.id.cmp(&b.id),
228 "prompt" => a.prompt.cmp(&b.prompt),
229 "response" => a.response.cmp(&b.response),
230 _ => std::cmp::Ordering::Equal,
231 };
232
233 let cmp = if *asc { cmp } else { cmp.reverse() };
234
235 if cmp != std::cmp::Ordering::Equal {
236 return cmp;
237 }
238 }
239 std::cmp::Ordering::Equal
240 });
241 Ok(entries)
242 }
243}
244
/// Borrows the text between the first pair of single quotes in `sql`.
///
/// Returns `None` when `sql` has no opening quote or the literal is never
/// closed. No unescaping is performed.
fn extract_timestamp(sql: &str) -> Option<&str> {
    sql.split_once('\'')
        .and_then(|(_, after_open)| after_open.split_once('\''))
        .map(|(literal, _)| literal)
}
250
/// Returns the first single-quoted SQL string literal in `sql`, unescaped.
///
/// Follows SQL quoting rules: inside the literal, a doubled quote (`''`)
/// stands for one `'` character, so `'it''s'` yields `it's`. Returns `None`
/// when there is no opening quote or the literal is never terminated.
fn extract_quoted_string(sql: &str) -> Option<String> {
    let (_, rest) = sql.split_once('\'')?;
    let mut literal = String::new();
    let mut chars = rest.chars().peekable();

    while let Some(c) = chars.next() {
        if c != '\'' {
            literal.push(c);
        } else if chars.peek() == Some(&'\'') {
            // Escaped quote: consume the second `'` and keep one.
            chars.next();
            literal.push('\'');
        } else {
            // A lone quote closes the literal.
            return Some(literal);
        }
    }

    None
}
256
#[cfg(test)]
mod tests {
    use super::*;
    use cai_core::Source;
    use cai_storage::MemoryStorage;
    use chrono::Utc;

    /// Returns an empty `MemoryStorage` plus two fixture entries (not yet stored):
    /// id "1" (Claude, 2024-01-15T10:00:00Z) and id "2" (Git, 2024-01-16T11:00:00Z).
    fn make_test_entries() -> (MemoryStorage, Vec<Entry>) {
        let storage = MemoryStorage::new();

        let entries = vec![
            Entry {
                id: "1".to_string(),
                source: Source::Claude,
                timestamp: chrono::DateTime::parse_from_rfc3339("2024-01-15T10:00:00Z")
                    .unwrap()
                    .with_timezone(&Utc),
                prompt: "hello".to_string(),
                response: "world".to_string(),
                metadata: cai_core::Metadata {
                    file_path: Some("/path/to/file.rs".to_string()),
                    repo_url: None,
                    commit_hash: None,
                    language: Some("rust".to_string()),
                    ..Default::default()
                },
            },
            Entry {
                id: "2".to_string(),
                source: Source::Git,
                timestamp: chrono::DateTime::parse_from_rfc3339("2024-01-16T11:00:00Z")
                    .unwrap()
                    .with_timezone(&Utc),
                prompt: "commit".to_string(),
                response: "message".to_string(),
                metadata: cai_core::Metadata {
                    file_path: None,
                    repo_url: None,
                    commit_hash: Some("abc123".to_string()),
                    language: None,
                    ..Default::default()
                },
            },
        ];

        (storage, entries)
    }

    /// Builds an engine whose storage already contains the fixture entries.
    async fn populated_engine() -> QueryEngine {
        let (storage, entries) = make_test_entries();
        for entry in &entries {
            storage.store(entry).await.unwrap();
        }
        QueryEngine::new(storage)
    }

    #[tokio::test]
    async fn test_simple_select() {
        let engine = populated_engine().await;
        let results = engine.execute("SELECT * FROM entries").await.unwrap();

        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_select_with_limit() {
        let engine = populated_engine().await;
        let results = engine
            .execute("SELECT * FROM entries LIMIT 1")
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_select_with_where() {
        let engine = populated_engine().await;
        let results = engine
            .execute("SELECT * FROM entries WHERE source = 'Claude'")
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].source, Source::Claude);
    }

    #[tokio::test]
    async fn test_order_by() {
        let engine = populated_engine().await;
        let results = engine
            .execute("SELECT * FROM entries ORDER BY timestamp DESC")
            .await
            .unwrap();

        // Newest first: the Git entry (2024-01-16) precedes the Claude entry (2024-01-15).
        let ids: Vec<&str> = results.iter().map(|entry| entry.id.as_str()).collect();
        assert_eq!(ids, ["2", "1"]);
    }

    #[tokio::test]
    async fn test_invalid_table() {
        let storage = MemoryStorage::new();
        let engine = QueryEngine::new(storage);

        let result = engine.execute("SELECT * FROM invalid_table").await;

        assert!(matches!(result, Err(QueryError::InvalidTable(_))));
    }
}
378}