Skip to main content

cai_storage/
lib.rs

1//! CAI Storage - Pluggable storage backends
2
3#![warn(missing_docs)]
4
5pub use cai_core::{Error, Result};
6
7use async_trait::async_trait;
8use cai_core::Entry;
9
10// SQLite storage temporarily disabled due to async/sync mismatch
11// #[cfg(feature = "sqlite")]
12// pub mod sqlite;
13//
14// #[cfg(feature = "sqlite")]
15// pub use sqlite::SqliteStorage;
16
17/// Storage backend trait
18#[async_trait]
19pub trait Storage: Send + Sync {
20    /// Store an entry
21    async fn store(&self, entry: &Entry) -> Result<()>;
22
23    /// Retrieve an entry by ID
24    async fn get(&self, id: &str) -> Result<Option<Entry>>;
25
26    /// Query entries with optional filter
27    async fn query(&self, filter: Option<&Filter>) -> Result<Vec<Entry>>;
28
29    /// Count entries
30    async fn count(&self) -> Result<usize>;
31}
32
33#[cfg(feature = "duckdb")]
34pub mod duckdb;
35
36#[cfg(feature = "duckdb")]
37pub use duckdb::DuckDBStorage;
38
39/// Query filter
40#[derive(Debug, Clone, Default)]
41pub struct Filter {
42    /// Source to filter by
43    pub source: Option<String>,
44    /// Minimum timestamp
45    pub after: Option<chrono::DateTime<chrono::Utc>>,
46    /// Maximum timestamp
47    pub before: Option<chrono::DateTime<chrono::Utc>>,
48}
49
50/// In-memory storage implementation
51pub struct MemoryStorage {
52    entries: std::sync::Arc<tokio::sync::RwLock<Vec<Entry>>>,
53}
54
55impl MemoryStorage {
56    /// Create new in-memory storage
57    pub fn new() -> Self {
58        Self {
59            entries: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
60        }
61    }
62}
63
64impl Default for MemoryStorage {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl MemoryStorage {
71    /// Create new in-memory storage with mock data for testing
72    pub fn with_mock_data() -> Self {
73        use cai_core::{Entry, Metadata, Source};
74        use chrono::Utc;
75
76        let storage = Self::new();
77
78        let mock_entries = vec![
79            Entry {
80                id: "1".to_string(),
81                source: Source::Claude,
82                timestamp: Utc::now() - chrono::Duration::hours(2),
83                prompt: "Help me refactor this Rust function to be more idiomatic".to_string(),
84                response: "Here's a more idiomatic version using iterators and pattern matching..."
85                    .to_string(),
86                metadata: Metadata {
87                    file_path: Some("src/main.rs".to_string()),
88                    language: Some("Rust".to_string()),
89                    ..Default::default()
90                },
91            },
92            Entry {
93                id: "2".to_string(),
94                source: Source::Claude,
95                timestamp: Utc::now() - chrono::Duration::hours(4),
96                prompt: "Write a unit test for this module".to_string(),
97                response: "Here are comprehensive unit tests using rstest...".to_string(),
98                metadata: Metadata {
99                    file_path: Some("src/storage.rs".to_string()),
100                    language: Some("Rust".to_string()),
101                    ..Default::default()
102                },
103            },
104            Entry {
105                id: "3".to_string(),
106                source: Source::Claude,
107                timestamp: Utc::now() - chrono::Duration::hours(6),
108                prompt: "Explain how async/await works in Rust".to_string(),
109                response: "Async/await in Rust is built on futures...".to_string(),
110                metadata: Metadata {
111                    language: Some("Rust".to_string()),
112                    ..Default::default()
113                },
114            },
115            Entry {
116                id: "4".to_string(),
117                source: Source::Codex,
118                timestamp: Utc::now() - chrono::Duration::hours(8),
119                prompt: "Implement a binary search function".to_string(),
120                response: "fn binary_search(arr: &[i32], target: i32) -> Option<usize> { ... }"
121                    .to_string(),
122                metadata: Metadata {
123                    language: Some("Rust".to_string()),
124                    ..Default::default()
125                },
126            },
127            Entry {
128                id: "5".to_string(),
129                source: Source::Git,
130                timestamp: Utc::now() - chrono::Duration::hours(10),
131                prompt: "feat: Add TUI implementation".to_string(),
132                response: "Implemented terminal UI with ratatui...".to_string(),
133                metadata: Metadata {
134                    commit_hash: Some("abc123def".to_string()),
135                    repo_url: Some("https://github.com/cai-dev/coding-agent-insights".to_string()),
136                    ..Default::default()
137                },
138            },
139            Entry {
140                id: "6".to_string(),
141                source: Source::Claude,
142                timestamp: Utc::now() - chrono::Duration::hours(12),
143                prompt: "What's the difference between Arc and Rc in Rust?".to_string(),
144                response: "Arc (Atomic Reference Counting) is thread-safe...".to_string(),
145                metadata: Metadata {
146                    language: Some("Rust".to_string()),
147                    ..Default::default()
148                },
149            },
150            Entry {
151                id: "7".to_string(),
152                source: Source::Claude,
153                timestamp: Utc::now() - chrono::Duration::days(1),
154                prompt: "Design a REST API for user management".to_string(),
155                response: "Here's a RESTful API design using axum...".to_string(),
156                metadata: Metadata {
157                    language: Some("Rust".to_string()),
158                    ..Default::default()
159                },
160            },
161            Entry {
162                id: "8".to_string(),
163                source: Source::Claude,
164                timestamp: Utc::now() - chrono::Duration::days(2),
165                prompt: "Debug this segmentation fault".to_string(),
166                response: "The segfault is caused by a dangling reference...".to_string(),
167                metadata: Metadata {
168                    file_path: Some("src/parser.rs".to_string()),
169                    language: Some("Rust".to_string()),
170                    ..Default::default()
171                },
172            },
173        ];
174
175        // Use tokio runtime to store entries
176        let rt = tokio::runtime::Runtime::new().unwrap();
177        rt.block_on(async {
178            for entry in mock_entries {
179                let _ = storage.store(&entry).await;
180            }
181        });
182
183        storage
184    }
185}
186
187#[async_trait]
188impl Storage for MemoryStorage {
189    async fn store(&self, entry: &Entry) -> Result<()> {
190        self.entries.write().await.push(entry.clone());
191        Ok(())
192    }
193
194    async fn get(&self, id: &str) -> Result<Option<Entry>> {
195        Ok(self
196            .entries
197            .read()
198            .await
199            .iter()
200            .find(|e| e.id == id)
201            .cloned())
202    }
203
204    async fn query(&self, filter: Option<&Filter>) -> Result<Vec<Entry>> {
205        let entries = self.entries.read().await;
206        Ok(if let Some(f) = filter {
207            entries
208                .iter()
209                .filter(|e| {
210                    if let Some(ref src) = f.source {
211                        if format!("{:?}", e.source) != *src {
212                            return false;
213                        }
214                    }
215                    if let Some(after) = f.after {
216                        if e.timestamp < after {
217                            return false;
218                        }
219                    }
220                    if let Some(before) = f.before {
221                        if e.timestamp > before {
222                            return false;
223                        }
224                    }
225                    true
226                })
227                .cloned()
228                .collect()
229        } else {
230            entries.clone()
231        })
232    }
233
234    async fn count(&self) -> Result<usize> {
235        Ok(self.entries.read().await.len())
236    }
237}