Skip to main content

cai_query/
executor.rs

1//! Query execution engine
2
3use crate::error::{ColumnInfo, QueryError, QueryResult, SchemaInfo, SchemaQueryType};
4use crate::parser::{ParsedQuery, QueryType};
5use cai_core::Entry;
6use cai_storage::Storage;
7use std::sync::Arc;
8
/// Query result type - can be either entries or schema information.
///
/// This is the payload returned by [`QueryEngine::execute_full`]: `SELECT` statements produce
/// [`QueryResultData::Entries`], while `SHOW TABLES` and `DESCRIBE` produce
/// [`QueryResultData::Schema`].
#[derive(Debug, Clone)]
pub enum QueryResultData {
    /// Standard entry query results (rows selected from the `entries` table)
    Entries(Vec<Entry>),
    /// Schema query results (SHOW TABLES, DESCRIBE)
    Schema(SchemaInfo),
}
17
18/// Query engine for executing SQL queries against storage
19#[derive(Clone)]
20pub struct QueryEngine {
21    storage: Arc<dyn Storage>,
22}
23
24impl QueryEngine {
25    /// Create a new query engine
26    pub fn new<S>(storage: S) -> Self
27    where
28        S: Storage + 'static,
29    {
30        Self {
31            storage: Arc::new(storage),
32        }
33    }
34
35    /// Create a new query engine from an Arc<dyn Storage>
36    pub fn from_arc(storage: Arc<dyn Storage>) -> Self {
37        Self { storage }
38    }
39
40    /// Execute a SQL query and return matching entries
41    pub async fn execute(&self, sql: &str) -> QueryResult<Vec<Entry>> {
42        let parsed = crate::parse(sql)?;
43
44        // Handle schema queries
45        match &parsed.query_type {
46            QueryType::ShowTables => {
47                // For backward compatibility, return empty vec for SHOW TABLES
48                // Users should use execute_schema for schema queries
49                Ok(vec![])
50            }
51            QueryType::DescribeTable(_) => {
52                // For backward compatibility, return empty vec for DESCRIBE
53                // Users should use execute_schema for schema queries
54                Ok(vec![])
55            }
56            QueryType::Select => {
57                // Validate table name
58                if parsed
59                    .table
60                    .as_ref()
61                    .is_some_and(|t| t.to_lowercase() != "entries")
62                {
63                    if let Some(table) = parsed.table {
64                        return Err(QueryError::InvalidTable(table));
65                    }
66                }
67
68                // For now, handle simple cases
69                self.execute_simple_query(&parsed).await
70            }
71        }
72    }
73
74    /// Execute a SQL query and return full query result data (including schema)
75    pub async fn execute_full(&self, sql: &str) -> QueryResult<QueryResultData> {
76        let parsed = crate::parse(sql)?;
77
78        match &parsed.query_type {
79            QueryType::ShowTables => Ok(QueryResultData::Schema(SchemaInfo {
80                query_type: SchemaQueryType::ShowTables,
81                table_name: None,
82                tables: vec!["entries".to_string()],
83                columns: vec![],
84            })),
85            QueryType::DescribeTable(table_name) => Ok(QueryResultData::Schema(SchemaInfo {
86                query_type: SchemaQueryType::DescribeTable,
87                table_name: Some(table_name.clone()),
88                tables: vec![],
89                columns: Self::get_entry_columns(),
90            })),
91            QueryType::Select => {
92                // Validate table name
93                if parsed
94                    .table
95                    .as_ref()
96                    .is_some_and(|t| t.to_lowercase() != "entries")
97                {
98                    if let Some(table) = parsed.table.clone() {
99                        return Err(QueryError::InvalidTable(table));
100                    }
101                }
102
103                let entries = self.execute_simple_query(&parsed).await?;
104                Ok(QueryResultData::Entries(entries))
105            }
106        }
107    }
108
109    /// Get column information for the entries table
110    fn get_entry_columns() -> Vec<ColumnInfo> {
111        vec![
112            ColumnInfo {
113                name: "id".to_string(),
114                data_type: "TEXT".to_string(),
115                description: "Unique identifier".to_string(),
116            },
117            ColumnInfo {
118                name: "source".to_string(),
119                data_type: "TEXT".to_string(),
120                description: "Source system (Claude, Codex, Git, Other)".to_string(),
121            },
122            ColumnInfo {
123                name: "timestamp".to_string(),
124                data_type: "TIMESTAMP".to_string(),
125                description: "Interaction timestamp (UTC)".to_string(),
126            },
127            ColumnInfo {
128                name: "prompt".to_string(),
129                data_type: "TEXT".to_string(),
130                description: "User prompt/input".to_string(),
131            },
132            ColumnInfo {
133                name: "response".to_string(),
134                data_type: "TEXT".to_string(),
135                description: "AI response/output".to_string(),
136            },
137            ColumnInfo {
138                name: "metadata".to_string(),
139                data_type: "JSON".to_string(),
140                description: "Additional metadata (file_path, language, etc.)".to_string(),
141            },
142        ]
143    }
144
145    async fn execute_simple_query(&self, parsed: &ParsedQuery) -> QueryResult<Vec<Entry>> {
146        let mut entries = self.storage.query(None).await?;
147
148        // Apply simple WHERE filter
149        if let Some(ref where_sql) = parsed.where_sql {
150            entries = self.apply_where_filter(entries, where_sql)?;
151        }
152
153        // Apply ORDER BY
154        if !parsed.order_by.is_empty() {
155            entries = self.apply_order_by(entries, &parsed.order_by)?;
156        }
157
158        // Apply LIMIT
159        if let Some(limit) = parsed.limit {
160            entries.truncate(limit);
161        }
162
163        Ok(entries)
164    }
165
166    fn apply_where_filter(&self, entries: Vec<Entry>, where_sql: &str) -> QueryResult<Vec<Entry>> {
167        // Parse simple WHERE conditions
168        let where_upper = where_sql.to_uppercase();
169
170        // Extract values once to avoid lifetime issues
171        let expected_source = if where_upper.contains("SOURCE =") || where_upper.contains("SOURCE=")
172        {
173            extract_quoted_string(where_sql)
174        } else {
175            None
176        };
177
178        let expected_ts_gt =
179            if where_upper.contains("TIMESTAMP >") || where_upper.contains("TIMESTAMP>") {
180                extract_timestamp(where_sql)
181                    .and_then(|s| s.parse::<chrono::DateTime<chrono::Utc>>().ok())
182            } else {
183                None
184            };
185
186        let expected_ts_lt =
187            if where_upper.contains("TIMESTAMP <") || where_upper.contains("TIMESTAMP<") {
188                extract_timestamp(where_sql)
189                    .and_then(|s| s.parse::<chrono::DateTime<chrono::Utc>>().ok())
190            } else {
191                None
192            };
193
194        Ok(entries
195            .into_iter()
196            .filter(|entry| {
197                if let Some(ref source) = expected_source {
198                    if format!("{:?}", entry.source) != *source {
199                        return false;
200                    }
201                }
202                if let Some(ts) = expected_ts_gt {
203                    if entry.timestamp <= ts {
204                        return false;
205                    }
206                }
207                if let Some(ts) = expected_ts_lt {
208                    if entry.timestamp >= ts {
209                        return false;
210                    }
211                }
212                true
213            })
214            .collect::<Vec<_>>())
215    }
216
217    fn apply_order_by(
218        &self,
219        mut entries: Vec<Entry>,
220        order_by: &[(String, bool)],
221    ) -> QueryResult<Vec<Entry>> {
222        entries.sort_by(|a, b| {
223            for (col, asc) in order_by {
224                let cmp = match col.to_lowercase().as_str() {
225                    "timestamp" => a.timestamp.cmp(&b.timestamp),
226                    "source" => format!("{:?}", a.source).cmp(&format!("{:?}", b.source)),
227                    "id" => a.id.cmp(&b.id),
228                    "prompt" => a.prompt.cmp(&b.prompt),
229                    "response" => a.response.cmp(&b.response),
230                    _ => std::cmp::Ordering::Equal,
231                };
232
233                let cmp = if *asc { cmp } else { cmp.reverse() };
234
235                if cmp != std::cmp::Ordering::Equal {
236                    return cmp;
237                }
238            }
239            std::cmp::Ordering::Equal
240        });
241        Ok(entries)
242    }
243}
244
/// Returns the contents of the first single-quoted literal in `sql`.
///
/// The result borrows from `sql`. Yields `None` when `sql` has no opening quote or the opening
/// quote is never closed.
fn extract_timestamp(sql: &str) -> Option<&str> {
    let (_, after_open) = sql.split_once('\'')?;
    let (literal, _) = after_open.split_once('\'')?;
    Some(literal)
}
250
/// Returns the contents of the first single-quoted SQL string literal in `sql` as an owned
/// `String`.
///
/// A doubled quote (`''`) inside the literal is SQL's escape for one `'` and is unescaped, so
/// `'O''Brien'` yields `O'Brien` instead of being cut off at the first inner quote. Yields
/// `None` when `sql` has no opening quote or the literal is never terminated.
fn extract_quoted_string(sql: &str) -> Option<String> {
    let (_, after_open) = sql.split_once('\'')?;

    let mut literal = String::new();
    let mut chars = after_open.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '\'' {
            literal.push(c);
        } else if chars.peek() == Some(&'\'') {
            // Doubled quote: consume the second one and keep a single literal quote.
            chars.next();
            literal.push('\'');
        } else {
            // A lone quote closes the literal.
            return Some(literal);
        }
    }

    // Ran out of input before the closing quote.
    None
}
256
#[cfg(test)]
mod tests {
    use super::*;
    use cai_core::Source;
    use cai_storage::MemoryStorage;
    use chrono::Utc;

    /// Parses an RFC 3339 literal into a UTC timestamp, panicking on malformed input.
    fn utc(rfc3339: &str) -> chrono::DateTime<Utc> {
        chrono::DateTime::parse_from_rfc3339(rfc3339)
            .unwrap()
            .with_timezone(&Utc)
    }

    /// Builds an empty in-memory storage plus two fixture entries (one Claude, one Git).
    ///
    /// The entries are NOT stored yet; see `seeded_engine` for a ready-to-query engine.
    fn make_test_entries() -> (MemoryStorage, Vec<Entry>) {
        let claude_entry = Entry {
            id: "1".to_string(),
            source: Source::Claude,
            timestamp: utc("2024-01-15T10:00:00Z"),
            prompt: "hello".to_string(),
            response: "world".to_string(),
            metadata: cai_core::Metadata {
                file_path: Some("/path/to/file.rs".to_string()),
                repo_url: None,
                commit_hash: None,
                language: Some("rust".to_string()),
                ..Default::default()
            },
        };

        let git_entry = Entry {
            id: "2".to_string(),
            source: Source::Git,
            timestamp: utc("2024-01-16T11:00:00Z"),
            prompt: "commit".to_string(),
            response: "message".to_string(),
            metadata: cai_core::Metadata {
                file_path: None,
                repo_url: None,
                commit_hash: Some("abc123".to_string()),
                language: None,
                ..Default::default()
            },
        };

        (MemoryStorage::new(), vec![claude_entry, git_entry])
    }

    /// Returns an engine whose storage already contains both fixture entries.
    async fn seeded_engine() -> QueryEngine {
        let (storage, fixtures) = make_test_entries();
        for fixture in &fixtures {
            storage.store(fixture).await.unwrap();
        }
        QueryEngine::new(storage)
    }

    #[tokio::test]
    async fn test_simple_select() {
        let engine = seeded_engine().await;

        let rows = engine.execute("SELECT * FROM entries").await.unwrap();

        assert_eq!(rows.len(), 2);
    }

    #[tokio::test]
    async fn test_select_with_limit() {
        let engine = seeded_engine().await;

        let rows = engine
            .execute("SELECT * FROM entries LIMIT 1")
            .await
            .unwrap();

        assert_eq!(rows.len(), 1);
    }

    #[tokio::test]
    async fn test_select_with_where() {
        let engine = seeded_engine().await;

        let rows = engine
            .execute("SELECT * FROM entries WHERE source = 'Claude'")
            .await
            .unwrap();

        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].source, Source::Claude);
    }

    #[tokio::test]
    async fn test_order_by() {
        let engine = seeded_engine().await;

        // Note: ORDER BY parsing not implemented in simple parser yet
        let rows = engine
            .execute("SELECT * FROM entries ORDER BY timestamp DESC")
            .await
            .unwrap();

        // For now, just verify we get results (ordering is not asserted here)
        assert_eq!(rows.len(), 2);
    }

    #[tokio::test]
    async fn test_invalid_table() {
        let engine = QueryEngine::new(MemoryStorage::new());

        let outcome = engine.execute("SELECT * FROM invalid_table").await;

        assert!(matches!(outcome, Err(QueryError::InvalidTable(_))));
    }
}