agentic_connect/engine/
db_engine.rs1use rusqlite::Connection as SqlConn;
4use std::collections::HashMap;
5use std::path::Path;
6
7use crate::types::{ConnectError, ConnectResult};
8
9pub struct DbConnection {
11 db: SqlConn,
12 db_type: DbType,
13 path: String,
14}
15
16#[derive(Debug, Clone, Copy, serde::Serialize)]
17pub enum DbType { Sqlite }
18
19#[derive(Debug, Clone, serde::Serialize)]
21pub struct ColumnInfo {
22 pub name: String,
23 pub data_type: String,
24 pub nullable: bool,
25 pub primary_key: bool,
26 pub default_value: Option<String>,
27}
28
29#[derive(Debug, Clone, serde::Serialize)]
31pub struct TableInfo {
32 pub name: String,
33 pub columns: Vec<ColumnInfo>,
34 pub row_count: Option<i64>,
35}
36
37#[derive(Debug, Clone, serde::Serialize)]
39pub struct QueryResult {
40 pub columns: Vec<String>,
41 pub rows: Vec<HashMap<String, serde_json::Value>>,
42 pub row_count: usize,
43 pub affected: usize,
44}
45
46impl DbConnection {
47 pub fn open_sqlite(path: &str) -> ConnectResult<Self> {
49 let db = if path == ":memory:" {
50 SqlConn::open_in_memory()?
51 } else {
52 SqlConn::open(Path::new(path))?
53 };
54 Ok(Self { db, db_type: DbType::Sqlite, path: path.into() })
55 }
56
57 pub fn from_url(url: &str) -> ConnectResult<Self> {
59 if url.starts_with("sqlite://") || url.ends_with(".db") || url.ends_with(".sqlite") {
60 let path = url.strip_prefix("sqlite://").unwrap_or(url);
61 Self::open_sqlite(path)
62 } else {
63 Err(ConnectError::NotSupported(
64 "Only SQLite is supported in this build. Use sqlx feature for Postgres/MySQL".into()
65 ))
66 }
67 }
68
69 pub fn query(&self, sql: &str) -> ConnectResult<QueryResult> {
71 let mut stmt = self.db.prepare(sql)?;
72 let col_names: Vec<String> = stmt.column_names().iter().map(|s| s.to_string()).collect();
73 let col_count = col_names.len();
74
75 let rows: Vec<HashMap<String, serde_json::Value>> = stmt
76 .query_map([], |row| {
77 let mut map = HashMap::new();
78 for i in 0..col_count {
79 let val = match row.get_ref(i) {
80 Ok(rusqlite::types::ValueRef::Null) => serde_json::Value::Null,
81 Ok(rusqlite::types::ValueRef::Integer(n)) => serde_json::json!(n),
82 Ok(rusqlite::types::ValueRef::Real(f)) => serde_json::json!(f),
83 Ok(rusqlite::types::ValueRef::Text(s)) => {
84 serde_json::Value::String(String::from_utf8_lossy(s).into())
85 }
86 Ok(rusqlite::types::ValueRef::Blob(b)) => {
87 serde_json::json!(format!("<blob {} bytes>", b.len()))
88 }
89 Err(_) => serde_json::Value::Null,
90 };
91 map.insert(col_names[i].clone(), val);
92 }
93 Ok(map)
94 })?
95 .filter_map(|r| r.ok())
96 .collect();
97
98 let count = rows.len();
99 Ok(QueryResult { columns: col_names, rows, row_count: count, affected: 0 })
100 }
101
102 pub fn execute(&self, sql: &str) -> ConnectResult<usize> {
104 let affected = self.db.execute(sql, [])?;
105 Ok(affected)
106 }
107
108 pub fn discover_schema(&self) -> ConnectResult<Vec<TableInfo>> {
110 let tables = self.query(
111 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
112 )?;
113
114 let mut result = Vec::new();
115 for row in &tables.rows {
116 if let Some(serde_json::Value::String(name)) = row.get("name") {
117 let columns = self.table_columns(name)?;
118 let count = self.query(&format!("SELECT COUNT(*) as cnt FROM \"{}\"", name))?;
119 let row_count = count.rows.first()
120 .and_then(|r| r.get("cnt"))
121 .and_then(|v| v.as_i64());
122 result.push(TableInfo { name: name.clone(), columns, row_count });
123 }
124 }
125 Ok(result)
126 }
127
128 pub fn table_columns(&self, table: &str) -> ConnectResult<Vec<ColumnInfo>> {
130 let info = self.query(&format!("PRAGMA table_info(\"{}\")", table))?;
131 let mut cols = Vec::new();
132 for row in &info.rows {
133 cols.push(ColumnInfo {
134 name: row.get("name").and_then(|v| v.as_str()).unwrap_or("").into(),
135 data_type: row.get("type").and_then(|v| v.as_str()).unwrap_or("").into(),
136 nullable: row.get("notnull").and_then(|v| v.as_i64()).unwrap_or(0) == 0,
137 primary_key: row.get("pk").and_then(|v| v.as_i64()).unwrap_or(0) != 0,
138 default_value: row.get("dflt_value").and_then(|v| v.as_str()).map(String::from),
139 });
140 }
141 Ok(cols)
142 }
143
144 pub fn db_type(&self) -> DbType { self.db_type }
145 pub fn path(&self) -> &str { &self.path }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Rows inserted into an in-memory table come back with the expected
    /// count and column order.
    #[test]
    fn test_sqlite_query() {
        let conn = DbConnection::open_sqlite(":memory:").unwrap();
        let setup = [
            "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)",
            "INSERT INTO test VALUES (1, 'alice')",
            "INSERT INTO test VALUES (2, 'bob')",
        ];
        for statement in setup.iter() {
            conn.execute(statement).unwrap();
        }

        let fetched = conn.query("SELECT * FROM test ORDER BY id").unwrap();
        assert_eq!(fetched.row_count, 2);
        assert_eq!(fetched.columns, vec!["id", "name"]);
    }

    /// Schema discovery reports the single table, its three columns, and the
    /// primary-key flag on the first column.
    #[test]
    fn test_schema_discovery() {
        let conn = DbConnection::open_sqlite(":memory:").unwrap();
        conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT NOT NULL, age REAL)")
            .unwrap();

        let tables = conn.discover_schema().unwrap();
        assert_eq!(tables.len(), 1);

        let users = &tables[0];
        assert_eq!(users.name, "users");
        assert_eq!(users.columns.len(), 3);
        assert!(users.columns[0].primary_key);
    }

    /// A `sqlite://` URL opens a usable connection.
    #[test]
    fn test_from_url_sqlite() {
        let conn = DbConnection::from_url("sqlite://:memory:").unwrap();
        conn.execute("CREATE TABLE t (x INTEGER)").unwrap();

        let counted = conn.query("SELECT COUNT(*) as c FROM t").unwrap();
        assert_eq!(counted.rows[0]["c"], serde_json::json!(0));
    }
}