1use crate::config::SqliteSourceConfig;
4use async_trait::async_trait;
5use faucet_core::{FaucetError, Stream, StreamPage};
6use futures::TryStreamExt;
7use serde_json::Value;
8use sqlx::sqlite::SqlitePoolOptions;
9use sqlx::{Column, Row, SqlitePool};
10use std::pin::Pin;
11
12pub struct SqliteSource {
14 config: SqliteSourceConfig,
15 pool: SqlitePool,
16}
17
18impl SqliteSource {
19 pub async fn new(config: SqliteSourceConfig) -> Result<Self, FaucetError> {
21 faucet_core::validate_batch_size(config.batch_size)?;
22
23 let pool = SqlitePoolOptions::new()
24 .max_connections(config.max_connections)
25 .connect(&config.database_url)
26 .await
27 .map_err(|e| FaucetError::Config(format!("SQLite connection failed: {e}")))?;
28
29 Ok(Self { config, pool })
30 }
31}
32
33fn sqlite_value_to_json(row: &sqlx::sqlite::SqliteRow, col_name: &str) -> Value {
38 if let Ok(v) = row.try_get::<Value, _>(col_name) {
40 return v;
41 }
42
43 if let Ok(v) = row.try_get::<String, _>(col_name) {
44 return Value::String(v);
45 }
46 if let Ok(v) = row.try_get::<i64, _>(col_name) {
47 return Value::Number(v.into());
48 }
49 if let Ok(v) = row.try_get::<i32, _>(col_name) {
50 return Value::Number(v.into());
51 }
52 if let Ok(v) = row.try_get::<f64, _>(col_name) {
53 return serde_json::Number::from_f64(v)
54 .map(Value::Number)
55 .unwrap_or(Value::Null);
56 }
57 if let Ok(v) = row.try_get::<bool, _>(col_name) {
58 return Value::Bool(v);
59 }
60 if let Ok(v) = row.try_get::<Vec<u8>, _>(col_name) {
64 use base64::Engine as _;
65 return Value::String(base64::engine::general_purpose::STANDARD.encode(v));
66 }
67
68 Value::Null
69}
70
71fn resolve_query(
77 config: &SqliteSourceConfig,
78 context: &std::collections::HashMap<String, Value>,
79) -> (String, Vec<Value>) {
80 if context.is_empty() {
81 (config.query.clone(), Vec::new())
82 } else {
83 faucet_core::util::substitute_context_bind_params(&config.query, context, 1, |_| {
84 "?".to_string()
85 })
86 }
87}
88
89fn bind_params<'q>(
91 mut query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
92 bind_values: &'q [Value],
93) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
94 for value in bind_values {
95 query = match value {
96 Value::String(s) => query.bind(s.clone()),
97 Value::Number(n) if n.is_i64() => query.bind(n.as_i64().unwrap()),
98 Value::Number(n) => query.bind(n.as_f64().unwrap_or(0.0)),
99 Value::Bool(b) => query.bind(*b),
100 Value::Null => query.bind(None::<String>),
101 _ => query.bind(value.to_string()),
102 };
103 }
104 query
105}
106
107fn row_to_json(row: &sqlx::sqlite::SqliteRow) -> Value {
110 let mut map = serde_json::Map::new();
111 for col in row.columns() {
112 let name = col.name().to_string();
113 let value = sqlite_value_to_json(row, &name);
114 map.insert(name, value);
115 }
116 Value::Object(map)
117}
118
119#[async_trait]
120impl faucet_core::Source for SqliteSource {
121 async fn fetch_with_context(
122 &self,
123 context: &std::collections::HashMap<String, serde_json::Value>,
124 ) -> Result<Vec<Value>, FaucetError> {
125 let (query_str, bind_values) = resolve_query(&self.config, context);
126 let query = bind_params(sqlx::query(&query_str), &bind_values);
127
128 let rows = query
129 .fetch_all(&self.pool)
130 .await
131 .map_err(|e| FaucetError::Config(format!("SQLite query failed: {e}")))?;
132
133 let records: Vec<Value> = rows.iter().map(row_to_json).collect();
134 tracing::info!(
135 rows = records.len(),
136 query = %self.config.query,
137 "SQLite source fetch complete"
138 );
139 Ok(records)
140 }
141
142 fn stream_pages<'a>(
157 &'a self,
158 context: &'a std::collections::HashMap<String, Value>,
159 _batch_size: usize,
160 ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
161 let batch_size = self.config.batch_size;
162
163 Box::pin(async_stream::try_stream! {
164 let (query_str, bind_values) = resolve_query(&self.config, context);
165 let query = bind_params(sqlx::query(&query_str), &bind_values);
166
167 let mut rows = query.fetch(&self.pool);
168 let chunk = if batch_size == 0 { usize::MAX } else { batch_size };
169 let initial_capacity = if batch_size == 0 { 1024 } else { batch_size };
170 let mut buffer: Vec<Value> = Vec::with_capacity(initial_capacity);
171 let mut total = 0usize;
172
173 while let Some(row) = rows
174 .try_next()
175 .await
176 .map_err(|e| FaucetError::Config(format!("SQLite query failed: {e}")))?
177 {
178 buffer.push(row_to_json(&row));
179 if buffer.len() >= chunk {
180 let page = std::mem::replace(&mut buffer, Vec::with_capacity(initial_capacity));
181 total += page.len();
182 yield StreamPage { records: page, bookmark: None };
183 }
184 }
185 if !buffer.is_empty() {
186 total += buffer.len();
187 yield StreamPage { records: buffer, bookmark: None };
188 }
189
190 tracing::info!(
191 rows = total,
192 batch_size,
193 query = %self.config.query,
194 "SQLite source stream complete",
195 );
196 })
197 }
198
199 fn config_schema(&self) -> serde_json::Value {
200 serde_json::to_value(faucet_core::schema_for!(SqliteSourceConfig))
201 .expect("schema serialization")
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use faucet_core::Source;

    /// Opens a source on an in-memory database with `query` as its configured
    /// query, panicking if construction fails.
    async fn memory_source(query: &'static str) -> SqliteSource {
        let config = SqliteSourceConfig::new("sqlite::memory:", query);
        SqliteSource::new(config).await.unwrap()
    }

    /// A literal SELECT comes back as one record with both columns decoded.
    #[tokio::test]
    async fn fetch_from_memory_db() {
        let source = memory_source("SELECT 1 AS val, 'hello' AS msg").await;
        let records = source.fetch_all().await.unwrap();
        assert_eq!(records.len(), 1);
        let first = &records[0];
        assert_eq!(first["val"], 1);
        assert_eq!(first["msg"], "hello");
    }

    /// Rows written through the source's pool can be read back through it.
    #[tokio::test]
    async fn fetch_from_table() {
        let source = memory_source("SELECT 1").await;
        let pool = &source.pool;

        sqlx::query("CREATE TABLE test_items (id INTEGER PRIMARY KEY, name TEXT, score REAL)")
            .execute(pool)
            .await
            .unwrap();
        sqlx::query(
            "INSERT INTO test_items (id, name, score) VALUES (1, 'Alice', 95.5), (2, 'Bob', 87.0)",
        )
        .execute(pool)
        .await
        .unwrap();

        let rows = sqlx::query("SELECT * FROM test_items ORDER BY id")
            .fetch_all(pool)
            .await
            .unwrap();

        assert_eq!(rows.len(), 2);
        let first = &rows[0];
        assert_eq!(first.try_get::<i64, _>("id").unwrap(), 1);
        assert_eq!(first.try_get::<String, _>("name").unwrap(), "Alice");
    }

    /// BLOB columns are emitted as base64 text.
    #[tokio::test]
    async fn blob_column_decodes_to_base64() {
        let source = memory_source("SELECT 1").await;
        let pool = &source.pool;

        sqlx::query("CREATE TABLE b (id INTEGER, data BLOB)")
            .execute(pool)
            .await
            .unwrap();
        sqlx::query("INSERT INTO b (id, data) VALUES (1, X'00FF')")
            .execute(pool)
            .await
            .unwrap();

        let rows = sqlx::query("SELECT data FROM b")
            .fetch_all(pool)
            .await
            .unwrap();
        let encoded = sqlite_value_to_json(&rows[0], "data");
        assert_eq!(
            encoded,
            Value::String("AP8=".to_string()),
            "BLOB must be base64"
        );
    }

    /// A query that matches nothing yields an empty record list, not an error.
    #[tokio::test]
    async fn empty_result() {
        let source = memory_source("SELECT 1 AS x WHERE 1 = 0").await;
        let records = source.fetch_all().await.unwrap();
        assert!(records.is_empty());
    }

    /// Malformed SQL is reported as an error from the fetch.
    #[tokio::test]
    async fn invalid_query_returns_error() {
        let source = memory_source("INVALID SQL").await;
        let outcome = source.fetch_all().await;
        assert!(outcome.is_err());
    }

    /// `{name}` placeholders in the query are filled from the context map.
    #[tokio::test]
    async fn fetch_with_context_substitutes_query_placeholders() {
        let source = memory_source("SELECT {val} AS result, {name} AS name").await;

        let context: std::collections::HashMap<String, serde_json::Value> = [
            ("val".to_string(), serde_json::json!(42)),
            ("name".to_string(), serde_json::json!("hello")),
        ]
        .into_iter()
        .collect();

        let records = source.fetch_with_context(&context).await.unwrap();
        assert_eq!(records.len(), 1);
        let first = &records[0];
        assert_eq!(first["result"], 42);
        assert_eq!(first["name"], "hello");
    }

    /// Context values come back as plain data, never executed as SQL.
    #[tokio::test]
    async fn fetch_with_context_prevents_sql_injection() {
        let source = memory_source("SELECT {val} AS result").await;

        let mut context = std::collections::HashMap::new();
        context.insert(
            "val".to_string(),
            serde_json::json!("1; DROP TABLE test; --"),
        );

        let records = source.fetch_with_context(&context).await.unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0]["result"], "1; DROP TABLE test; --");
    }

    /// A batch size above the allowed maximum is rejected at construction.
    #[tokio::test]
    async fn new_rejects_out_of_range_batch_size() {
        let mut config = SqliteSourceConfig::new("sqlite::memory:", "SELECT 1");
        config.batch_size = faucet_core::MAX_BATCH_SIZE + 1;

        let outcome = SqliteSource::new(config).await;
        if let Err(faucet_core::FaucetError::Config(m)) = outcome {
            assert!(m.contains("batch_size"), "got: {m}");
        } else {
            panic!("expected a batch_size Config error");
        }
    }
}