agentroot_core/providers/
sql.rs1use crate::db::hash_content;
7use crate::error::{AgentRootError, Result};
8use crate::providers::{ProviderConfig, SourceItem, SourceProvider};
9use async_trait::async_trait;
10use rusqlite::{params, Connection};
11use std::path::Path;
12
/// Source provider that reads items from a SQLite database through `rusqlite`.
///
/// The provider is a stateless unit struct: a database connection is opened
/// for each call rather than being cached on the value.
#[derive(Debug, Clone)]
pub struct SQLProvider;
15
16impl Default for SQLProvider {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl SQLProvider {
23 pub fn new() -> Self {
25 Self
26 }
27
28 fn query_database(
30 &self,
31 db_path: &str,
32 query: &str,
33 config: &ProviderConfig,
34 ) -> Result<Vec<SourceItem>> {
35 let conn = Connection::open(db_path).map_err(|e| {
36 AgentRootError::Database(rusqlite::Error::SqliteFailure(
37 rusqlite::ffi::Error::new(1),
38 Some(format!("Failed to open database {}: {}", db_path, e)),
39 ))
40 })?;
41
42 let mut stmt = conn
43 .prepare(query)
44 .map_err(|e| AgentRootError::InvalidInput(format!("Invalid SQL query: {}", e)))?;
45
46 let id_column = config
47 .options
48 .get("id_column")
49 .map(|s| s.as_str())
50 .unwrap_or("id");
51 let title_column = config
52 .options
53 .get("title_column")
54 .map(|s| s.as_str())
55 .unwrap_or("title");
56 let content_column = config
57 .options
58 .get("content_column")
59 .map(|s| s.as_str())
60 .unwrap_or("content");
61
62 let column_count = stmt.column_count();
63 let column_names: Vec<String> = (0..column_count)
64 .map(|i| stmt.column_name(i).unwrap_or("").to_string())
65 .collect();
66
67 let id_idx = column_names
68 .iter()
69 .position(|name| name.eq_ignore_ascii_case(id_column))
70 .ok_or_else(|| {
71 AgentRootError::InvalidInput(format!(
72 "Column '{}' not found in query result",
73 id_column
74 ))
75 })?;
76
77 let title_idx = column_names
78 .iter()
79 .position(|name| name.eq_ignore_ascii_case(title_column));
80
81 let content_idx = column_names
82 .iter()
83 .position(|name| name.eq_ignore_ascii_case(content_column))
84 .ok_or_else(|| {
85 AgentRootError::InvalidInput(format!(
86 "Column '{}' not found in query result",
87 content_column
88 ))
89 })?;
90
91 let rows = stmt
92 .query_map(params![], |row| {
93 let id: String = match row.get_ref(id_idx)? {
94 rusqlite::types::ValueRef::Integer(i) => i.to_string(),
95 rusqlite::types::ValueRef::Text(s) => String::from_utf8_lossy(s).to_string(),
96 rusqlite::types::ValueRef::Real(f) => f.to_string(),
97 _ => row.get(id_idx)?,
98 };
99
100 let title: String = if let Some(idx) = title_idx {
101 row.get(idx).unwrap_or_else(|_| id.clone())
102 } else {
103 id.clone()
104 };
105 let content: String = row.get(content_idx)?;
106
107 Ok((id, title, content))
108 })
109 .map_err(AgentRootError::Database)?;
110
111 let mut items = Vec::new();
112 for row_result in rows {
113 let (id, title, content) = row_result.map_err(AgentRootError::Database)?;
114
115 if content.trim().is_empty() {
116 continue;
117 }
118
119 let hash = hash_content(&content);
120 let uri = format!("sql://{}/{}", db_path, id);
121
122 let mut item = SourceItem::new(uri, title, content, hash, "sql".to_string());
123 item.metadata
124 .insert("database".to_string(), db_path.to_string());
125 item.metadata.insert("row_id".to_string(), id.clone());
126 item.metadata.insert(
127 "table".to_string(),
128 config
129 .options
130 .get("table")
131 .cloned()
132 .unwrap_or_else(|| "unknown".to_string()),
133 );
134
135 items.push(item);
136 }
137
138 Ok(items)
139 }
140}
141
142#[async_trait]
143impl SourceProvider for SQLProvider {
144 fn provider_type(&self) -> &'static str {
145 "sql"
146 }
147
148 async fn list_items(&self, config: &ProviderConfig) -> Result<Vec<SourceItem>> {
149 let db_path = &config.base_path;
150
151 if !Path::new(db_path).exists() {
152 return Err(AgentRootError::InvalidInput(format!(
153 "Database file does not exist: {}",
154 db_path
155 )));
156 }
157
158 let query = if let Some(custom_query) = config.options.get("query") {
159 custom_query.clone()
160 } else if let Some(table) = config.options.get("table") {
161 let id_col = config
162 .options
163 .get("id_column")
164 .map(|s| s.as_str())
165 .unwrap_or("id");
166 let title_col = config
167 .options
168 .get("title_column")
169 .map(|s| s.as_str())
170 .unwrap_or("title");
171 let content_col = config
172 .options
173 .get("content_column")
174 .map(|s| s.as_str())
175 .unwrap_or("content");
176
177 format!(
178 "SELECT {}, {}, {} FROM {}",
179 id_col, title_col, content_col, table
180 )
181 } else {
182 return Err(AgentRootError::InvalidInput(
183 "SQL provider requires either 'query' or 'table' option".to_string(),
184 ));
185 };
186
187 self.query_database(db_path, &query, config)
188 }
189
190 async fn fetch_item(&self, uri: &str) -> Result<SourceItem> {
191 if !uri.starts_with("sql://") {
192 return Err(AgentRootError::InvalidInput(format!(
193 "Invalid SQL URI: {}. Expected format: sql://path/to/db.sqlite/id",
194 uri
195 )));
196 }
197
198 let parts: Vec<&str> = uri.strip_prefix("sql://").unwrap().splitn(2, '/').collect();
199 if parts.len() != 2 {
200 return Err(AgentRootError::InvalidInput(format!(
201 "Invalid SQL URI format: {}. Expected: sql://path/to/db.sqlite/id",
202 uri
203 )));
204 }
205
206 let (db_path, id) = (parts[0], parts[1]);
207
208 let conn = Connection::open(db_path)?;
209
210 let mut stmt = conn.prepare("SELECT id, title, content FROM items WHERE id = ?1")?;
211
212 let result = stmt.query_row(params![id], |row| {
213 let id_val: String = match row.get_ref(0)? {
214 rusqlite::types::ValueRef::Integer(i) => i.to_string(),
215 rusqlite::types::ValueRef::Text(s) => String::from_utf8_lossy(s).to_string(),
216 rusqlite::types::ValueRef::Real(f) => f.to_string(),
217 _ => row.get(0)?,
218 };
219 let title: String = row.get(1)?;
220 let content: String = row.get(2)?;
221 Ok((id_val, title, content))
222 })?;
223
224 let (id, title, content) = result;
225 let hash = hash_content(&content);
226
227 let mut item = SourceItem::new(uri.to_string(), title, content, hash, "sql".to_string());
228 item.metadata
229 .insert("database".to_string(), db_path.to_string());
230 item.metadata.insert("row_id".to_string(), id);
231
232 Ok(item)
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Seed rows inserted into the `documents` table of every test database.
    const SEED_ROWS: [&str; 2] = [
        "INSERT INTO documents (id, title, content) VALUES (1, 'First Document', 'Content of first document')",
        "INSERT INTO documents (id, title, content) VALUES (2, 'Second Document', 'Content of second document')",
    ];

    /// Creates a temporary SQLite file holding a `documents` table with two rows.
    fn create_test_db() -> NamedTempFile {
        let temp_file = NamedTempFile::new().unwrap();
        let conn = Connection::open(temp_file.path()).unwrap();

        conn.execute(
            "CREATE TABLE documents (
                id INTEGER PRIMARY KEY,
                title TEXT NOT NULL,
                content TEXT NOT NULL
            )",
            [],
        )
        .unwrap();

        for insert in SEED_ROWS.iter() {
            conn.execute(insert, []).unwrap();
        }

        temp_file
    }

    /// Builds a provider config whose base path is the given temporary database.
    fn config_for(db: &NamedTempFile) -> ProviderConfig {
        ProviderConfig::new(db.path().to_string_lossy().to_string(), "".to_string())
    }

    #[test]
    fn test_provider_type() {
        assert_eq!(SQLProvider::new().provider_type(), "sql");
    }

    #[tokio::test]
    async fn test_query_database() {
        let temp_db = create_test_db();
        let provider = SQLProvider::new();

        let mut config = config_for(&temp_db);
        let options = [
            ("table", "documents"),
            ("id_column", "id"),
            ("title_column", "title"),
            ("content_column", "content"),
        ];
        for &(key, value) in options.iter() {
            config.options.insert(key.to_string(), value.to_string());
        }

        let items = provider.list_items(&config).await.unwrap();
        assert_eq!(items.len(), 2);
        assert_eq!(items[0].title, "First Document");
        assert_eq!(items[1].title, "Second Document");
    }

    #[tokio::test]
    async fn test_custom_query() {
        let temp_db = create_test_db();
        let provider = SQLProvider::new();

        let mut config = config_for(&temp_db);
        config.options.insert(
            "query".to_string(),
            "SELECT id, title, content FROM documents WHERE id = 1".to_string(),
        );

        let items = provider.list_items(&config).await.unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].title, "First Document");
    }

    #[tokio::test]
    async fn test_missing_table_option() {
        let temp_db = create_test_db();
        let provider = SQLProvider::new();

        // Neither `query` nor `table` is set, so listing must be rejected.
        let config = config_for(&temp_db);

        let outcome = provider.list_items(&config).await;
        assert!(outcome.is_err());
    }
}