sage_runtime/tools/
database.rs1use crate::error::{SageError, SageResult};
7use crate::mock::{try_get_mock, MockResponse};
8
9#[cfg(feature = "database")]
10use sqlx::{any::AnyRow, AnyPool, Column, Row};
11
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct DbRow {
15 pub columns: Vec<String>,
17 pub values: Vec<String>,
19}
20
/// Client used to run SQL statements.
///
/// With the `database` feature enabled this wraps a sqlx [`AnyPool`]; without
/// it the struct only holds a zero-sized marker, and no constructor in this
/// file can produce a value of it (they all return an error).
#[derive(Debug, Clone)]
pub struct DatabaseClient {
    /// Placeholder so the struct still has a (private) field when the
    /// `database` feature is disabled.
    #[cfg(not(feature = "database"))]
    _marker: std::marker::PhantomData<()>,
    /// Connection pool all queries and statements are sent through.
    #[cfg(feature = "database")]
    pool: AnyPool,
}
31
32impl DatabaseClient {
33 #[cfg(feature = "database")]
38 pub async fn connect(url: &str) -> SageResult<Self> {
39 sqlx::any::install_default_drivers();
41
42 let pool = AnyPool::connect(url)
43 .await
44 .map_err(|e| SageError::Tool(format!("Database connection failed: {e}")))?;
45 Ok(Self { pool })
46 }
47
48 #[cfg(not(feature = "database"))]
50 pub async fn connect(_url: &str) -> SageResult<Self> {
51 Err(SageError::Tool(
52 "Database support not enabled. Compile with the 'database' feature.".to_string(),
53 ))
54 }
55
56 #[cfg(feature = "database")]
61 pub async fn from_env() -> SageResult<Self> {
62 let url = std::env::var("SAGE_DATABASE_URL")
63 .map_err(|_| SageError::Tool("SAGE_DATABASE_URL environment variable not set".to_string()))?;
64 Self::connect(&url).await
65 }
66
67 #[cfg(not(feature = "database"))]
69 pub async fn from_env() -> SageResult<Self> {
70 Err(SageError::Tool(
71 "Database support not enabled. Compile with the 'database' feature.".to_string(),
72 ))
73 }
74
75 #[cfg(feature = "database")]
83 pub async fn query(&self, sql: String) -> SageResult<Vec<DbRow>> {
84 if let Some(mock_response) = try_get_mock("Database", "query") {
86 return Self::apply_mock_vec(mock_response);
87 }
88
89 let rows: Vec<AnyRow> = sqlx::query(&sql)
90 .fetch_all(&self.pool)
91 .await
92 .map_err(|e| SageError::Tool(format!("Query failed: {e}")))?;
93
94 let result: Vec<DbRow> = rows
95 .iter()
96 .map(|row| {
97 let columns: Vec<String> = row.columns().iter().map(|c| c.name().to_string()).collect();
98 let values: Vec<String> = (0..row.columns().len())
99 .map(|i| {
100 if let Ok(v) = row.try_get::<String, _>(i) {
102 v
103 } else if let Ok(v) = row.try_get::<i64, _>(i) {
104 v.to_string()
105 } else if let Ok(v) = row.try_get::<i32, _>(i) {
106 v.to_string()
107 } else if let Ok(v) = row.try_get::<f64, _>(i) {
108 v.to_string()
109 } else if let Ok(v) = row.try_get::<bool, _>(i) {
110 v.to_string()
111 } else {
112 row.try_get::<Option<String>, _>(i)
114 .ok()
115 .flatten()
116 .unwrap_or_else(|| "null".to_string())
117 }
118 })
119 .collect();
120 DbRow { columns, values }
121 })
122 .collect();
123
124 Ok(result)
125 }
126
127 #[cfg(not(feature = "database"))]
129 pub async fn query(&self, _sql: String) -> SageResult<Vec<DbRow>> {
130 if let Some(mock_response) = try_get_mock("Database", "query") {
132 return Self::apply_mock_vec(mock_response);
133 }
134
135 Err(SageError::Tool(
136 "Database support not enabled. Compile with the 'database' feature.".to_string(),
137 ))
138 }
139
140 #[cfg(feature = "database")]
148 pub async fn execute(&self, sql: String) -> SageResult<i64> {
149 if let Some(mock_response) = try_get_mock("Database", "execute") {
151 return Self::apply_mock_i64(mock_response);
152 }
153
154 let result = sqlx::query(&sql)
155 .execute(&self.pool)
156 .await
157 .map_err(|e| SageError::Tool(format!("Execute failed: {e}")))?;
158
159 Ok(result.rows_affected() as i64)
160 }
161
162 #[cfg(not(feature = "database"))]
164 pub async fn execute(&self, _sql: String) -> SageResult<i64> {
165 if let Some(mock_response) = try_get_mock("Database", "execute") {
167 return Self::apply_mock_i64(mock_response);
168 }
169
170 Err(SageError::Tool(
171 "Database support not enabled. Compile with the 'database' feature.".to_string(),
172 ))
173 }
174
175 fn apply_mock_vec(mock_response: MockResponse) -> SageResult<Vec<DbRow>> {
177 match mock_response {
178 MockResponse::Value(v) => serde_json::from_value(v)
179 .map_err(|e| SageError::Tool(format!("mock deserialize: {e}"))),
180 MockResponse::Fail(msg) => Err(SageError::Tool(msg)),
181 }
182 }
183
184 fn apply_mock_i64(mock_response: MockResponse) -> SageResult<i64> {
186 match mock_response {
187 MockResponse::Value(v) => serde_json::from_value(v)
188 .map_err(|e| SageError::Tool(format!("mock deserialize: {e}"))),
189 MockResponse::Fail(msg) => Err(SageError::Tool(msg)),
190 }
191 }
192}
193
#[cfg(all(test, feature = "database"))]
mod tests {
    use super::*;

    /// In-memory SQLite URL used by the tests that do not need a file on disk.
    const MEMORY_URL: &str = "sqlite:file::memory:?mode=memory&cache=shared";

    /// Connecting to an in-memory SQLite database succeeds.
    #[tokio::test]
    async fn database_connect_sqlite() {
        let connected = DatabaseClient::connect(MEMORY_URL).await;
        drop(connected.unwrap());
    }

    /// Creates a table, inserts two rows and reads them back in order.
    #[tokio::test]
    async fn database_execute_and_query() {
        // Kept in a local declared before the client so the directory is
        // removed only after the client has been dropped.
        let scratch = tempfile::tempdir().unwrap();
        let file = scratch.path().join("test.db");
        std::fs::write(&file, "").unwrap();

        let db = DatabaseClient::connect(&format!("sqlite:{}?mode=rwc", file.display()))
            .await
            .unwrap();

        let create = "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)".to_string();
        db.execute(create).await.unwrap();

        let insert = "INSERT INTO test (id, name) VALUES (1, 'Alice'), (2, 'Bob')".to_string();
        let inserted = db.execute(insert).await.unwrap();
        assert_eq!(inserted, 2);

        let select = "SELECT id, name FROM test ORDER BY id".to_string();
        let fetched = db.query(select).await.unwrap();
        assert_eq!(fetched.len(), 2);
        assert_eq!(fetched[0].columns, vec!["id", "name"]);
        assert_eq!(fetched[0].values, vec!["1", "Alice"]);
        assert_eq!(fetched[1].values, vec!["2", "Bob"]);
    }

    /// A literal `SELECT` yields one row whose integer value is stringified.
    #[tokio::test]
    async fn database_query_select_one() {
        let db = DatabaseClient::connect(MEMORY_URL).await.unwrap();
        let fetched = db.query("SELECT 1 as value".to_string()).await.unwrap();
        assert_eq!(fetched.len(), 1);
        assert_eq!(fetched[0].columns, vec!["value"]);
        assert_eq!(fetched[0].values, vec!["1"]);
    }
}