agentroot_core/providers/
sql.rs

1//! SQL Provider for indexing database content
2//!
3//! Supports SQLite databases with configurable queries.
4//! Can be extended to support PostgreSQL and MySQL.
5
6use crate::db::hash_content;
7use crate::error::{AgentRootError, Result};
8use crate::providers::{ProviderConfig, SourceItem, SourceProvider};
9use async_trait::async_trait;
10use rusqlite::{params, Connection};
11use std::path::Path;
12
/// Provider for extracting content from SQL databases.
///
/// The type is a unit struct: it holds no connection or configuration of its
/// own, so every call receives the database location and options it needs
/// through its arguments.
#[derive(Debug, Clone)]
pub struct SQLProvider;
15
16impl Default for SQLProvider {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl SQLProvider {
23    /// Create a new SQLProvider
24    pub fn new() -> Self {
25        Self
26    }
27
28    /// Execute query and extract rows as SourceItems
29    fn query_database(
30        &self,
31        db_path: &str,
32        query: &str,
33        config: &ProviderConfig,
34    ) -> Result<Vec<SourceItem>> {
35        let conn = Connection::open(db_path).map_err(|e| {
36            AgentRootError::Database(rusqlite::Error::SqliteFailure(
37                rusqlite::ffi::Error::new(1),
38                Some(format!("Failed to open database {}: {}", db_path, e)),
39            ))
40        })?;
41
42        let mut stmt = conn
43            .prepare(query)
44            .map_err(|e| AgentRootError::InvalidInput(format!("Invalid SQL query: {}", e)))?;
45
46        let id_column = config
47            .options
48            .get("id_column")
49            .map(|s| s.as_str())
50            .unwrap_or("id");
51        let title_column = config
52            .options
53            .get("title_column")
54            .map(|s| s.as_str())
55            .unwrap_or("title");
56        let content_column = config
57            .options
58            .get("content_column")
59            .map(|s| s.as_str())
60            .unwrap_or("content");
61
62        let column_count = stmt.column_count();
63        let column_names: Vec<String> = (0..column_count)
64            .map(|i| stmt.column_name(i).unwrap_or("").to_string())
65            .collect();
66
67        let id_idx = column_names
68            .iter()
69            .position(|name| name.eq_ignore_ascii_case(id_column))
70            .ok_or_else(|| {
71                AgentRootError::InvalidInput(format!(
72                    "Column '{}' not found in query result",
73                    id_column
74                ))
75            })?;
76
77        let title_idx = column_names
78            .iter()
79            .position(|name| name.eq_ignore_ascii_case(title_column));
80
81        let content_idx = column_names
82            .iter()
83            .position(|name| name.eq_ignore_ascii_case(content_column))
84            .ok_or_else(|| {
85                AgentRootError::InvalidInput(format!(
86                    "Column '{}' not found in query result",
87                    content_column
88                ))
89            })?;
90
91        let rows = stmt
92            .query_map(params![], |row| {
93                let id: String = match row.get_ref(id_idx)? {
94                    rusqlite::types::ValueRef::Integer(i) => i.to_string(),
95                    rusqlite::types::ValueRef::Text(s) => String::from_utf8_lossy(s).to_string(),
96                    rusqlite::types::ValueRef::Real(f) => f.to_string(),
97                    _ => row.get(id_idx)?,
98                };
99
100                let title: String = if let Some(idx) = title_idx {
101                    row.get(idx).unwrap_or_else(|_| id.clone())
102                } else {
103                    id.clone()
104                };
105                let content: String = row.get(content_idx)?;
106
107                Ok((id, title, content))
108            })
109            .map_err(AgentRootError::Database)?;
110
111        let mut items = Vec::new();
112        for row_result in rows {
113            let (id, title, content) = row_result.map_err(AgentRootError::Database)?;
114
115            if content.trim().is_empty() {
116                continue;
117            }
118
119            let hash = hash_content(&content);
120            let uri = format!("sql://{}/{}", db_path, id);
121
122            let mut item = SourceItem::new(uri, title, content, hash, "sql".to_string());
123            item.metadata
124                .insert("database".to_string(), db_path.to_string());
125            item.metadata.insert("row_id".to_string(), id.clone());
126            item.metadata.insert(
127                "table".to_string(),
128                config
129                    .options
130                    .get("table")
131                    .cloned()
132                    .unwrap_or_else(|| "unknown".to_string()),
133            );
134
135            items.push(item);
136        }
137
138        Ok(items)
139    }
140}
141
142#[async_trait]
143impl SourceProvider for SQLProvider {
144    fn provider_type(&self) -> &'static str {
145        "sql"
146    }
147
148    async fn list_items(&self, config: &ProviderConfig) -> Result<Vec<SourceItem>> {
149        let db_path = &config.base_path;
150
151        if !Path::new(db_path).exists() {
152            return Err(AgentRootError::InvalidInput(format!(
153                "Database file does not exist: {}",
154                db_path
155            )));
156        }
157
158        let query = if let Some(custom_query) = config.options.get("query") {
159            custom_query.clone()
160        } else if let Some(table) = config.options.get("table") {
161            let id_col = config
162                .options
163                .get("id_column")
164                .map(|s| s.as_str())
165                .unwrap_or("id");
166            let title_col = config
167                .options
168                .get("title_column")
169                .map(|s| s.as_str())
170                .unwrap_or("title");
171            let content_col = config
172                .options
173                .get("content_column")
174                .map(|s| s.as_str())
175                .unwrap_or("content");
176
177            format!(
178                "SELECT {}, {}, {} FROM {}",
179                id_col, title_col, content_col, table
180            )
181        } else {
182            return Err(AgentRootError::InvalidInput(
183                "SQL provider requires either 'query' or 'table' option".to_string(),
184            ));
185        };
186
187        self.query_database(db_path, &query, config)
188    }
189
190    async fn fetch_item(&self, uri: &str) -> Result<SourceItem> {
191        if !uri.starts_with("sql://") {
192            return Err(AgentRootError::InvalidInput(format!(
193                "Invalid SQL URI: {}. Expected format: sql://path/to/db.sqlite/id",
194                uri
195            )));
196        }
197
198        let parts: Vec<&str> = uri.strip_prefix("sql://").unwrap().splitn(2, '/').collect();
199        if parts.len() != 2 {
200            return Err(AgentRootError::InvalidInput(format!(
201                "Invalid SQL URI format: {}. Expected: sql://path/to/db.sqlite/id",
202                uri
203            )));
204        }
205
206        let (db_path, id) = (parts[0], parts[1]);
207
208        let conn = Connection::open(db_path)?;
209
210        let mut stmt = conn.prepare("SELECT id, title, content FROM items WHERE id = ?1")?;
211
212        let result = stmt.query_row(params![id], |row| {
213            let id_val: String = match row.get_ref(0)? {
214                rusqlite::types::ValueRef::Integer(i) => i.to_string(),
215                rusqlite::types::ValueRef::Text(s) => String::from_utf8_lossy(s).to_string(),
216                rusqlite::types::ValueRef::Real(f) => f.to_string(),
217                _ => row.get(0)?,
218            };
219            let title: String = row.get(1)?;
220            let content: String = row.get(2)?;
221            Ok((id_val, title, content))
222        })?;
223
224        let (id, title, content) = result;
225        let hash = hash_content(&content);
226
227        let mut item = SourceItem::new(uri.to_string(), title, content, hash, "sql".to_string());
228        item.metadata
229            .insert("database".to_string(), db_path.to_string());
230        item.metadata.insert("row_id".to_string(), id);
231
232        Ok(item)
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a temporary SQLite file containing a `documents` table with two
    /// rows (ids 1 and 2).
    fn create_test_db() -> NamedTempFile {
        let db_file = NamedTempFile::new().unwrap();
        let conn = Connection::open(db_file.path()).unwrap();

        let setup_statements = [
            "CREATE TABLE documents (
                id INTEGER PRIMARY KEY,
                title TEXT NOT NULL,
                content TEXT NOT NULL
            )",
            "INSERT INTO documents (id, title, content) VALUES (1, 'First Document', 'Content of first document')",
            "INSERT INTO documents (id, title, content) VALUES (2, 'Second Document', 'Content of second document')",
        ];
        for &sql in setup_statements.iter() {
            conn.execute(sql, []).unwrap();
        }

        db_file
    }

    /// Creates a provider config rooted at `db` with no options set yet.
    fn config_for(db: &NamedTempFile) -> ProviderConfig {
        ProviderConfig::new(db.path().to_string_lossy().to_string(), "".to_string())
    }

    #[test]
    fn test_provider_type() {
        assert_eq!(SQLProvider::new().provider_type(), "sql");
    }

    /// Listing through the `table` option returns both fixture rows.
    #[tokio::test]
    async fn test_query_database() {
        let temp_db = create_test_db();
        let provider = SQLProvider::new();

        let mut config = config_for(&temp_db);
        let option_pairs = [
            ("table", "documents"),
            ("id_column", "id"),
            ("title_column", "title"),
            ("content_column", "content"),
        ];
        for &(key, value) in option_pairs.iter() {
            config.options.insert(key.to_string(), value.to_string());
        }

        let items = provider.list_items(&config).await.unwrap();
        assert_eq!(items.len(), 2);
        assert_eq!(items[0].title, "First Document");
        assert_eq!(items[1].title, "Second Document");
    }

    /// A custom `query` option is executed as given and restricts the rows.
    #[tokio::test]
    async fn test_custom_query() {
        let temp_db = create_test_db();
        let provider = SQLProvider::new();

        let mut config = config_for(&temp_db);
        config.options.insert(
            "query".to_string(),
            "SELECT id, title, content FROM documents WHERE id = 1".to_string(),
        );

        let items = provider.list_items(&config).await.unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].title, "First Document");
    }

    /// Without `query` or `table`, listing must fail.
    #[tokio::test]
    async fn test_missing_table_option() {
        let temp_db = create_test_db();
        let provider = SQLProvider::new();

        let config = config_for(&temp_db);

        let outcome = provider.list_items(&config).await;
        assert!(outcome.is_err());
    }
}
330}