use crate::config::SqliteSourceConfig;
use async_trait::async_trait;
use faucet_core::{FaucetError, Stream, StreamPage};
use futures::TryStreamExt;
use serde_json::Value;
use sqlx::sqlite::SqlitePoolOptions;
use sqlx::{Column, Row, SqlitePool};
use std::pin::Pin;
/// A faucet source that reads rows from a SQLite database through a `sqlx` pool.
pub struct SqliteSource {
    /// Database URL, pool size, query text and batch size this source was built from.
    config: SqliteSourceConfig,
    /// Pool opened from `config.database_url` by [`SqliteSource::new`].
    pool: SqlitePool,
}

impl SqliteSource {
    /// Validates `config` and opens a connection pool for `config.database_url`.
    ///
    /// The pool is capped at `config.max_connections` connections.
    ///
    /// # Errors
    ///
    /// Propagates any error from `faucet_core::validate_batch_size`, and returns
    /// `FaucetError::Config` when the pool cannot connect.
    ///
    /// NOTE(review): with an in-memory URL such as `sqlite::memory:` and
    /// `max_connections > 1`, separate pooled connections may each see their own
    /// database — confirm this is acceptable for callers.
    pub async fn new(config: SqliteSourceConfig) -> Result<Self, FaucetError> {
        faucet_core::validate_batch_size(config.batch_size)?;

        let connected = SqlitePoolOptions::new()
            .max_connections(config.max_connections)
            .connect(&config.database_url)
            .await;

        match connected {
            Ok(pool) => Ok(Self { config, pool }),
            Err(e) => Err(FaucetError::Config(format!(
                "SQLite connection failed: {e}"
            ))),
        }
    }
}
fn sqlite_value_to_json(row: &sqlx::sqlite::SqliteRow, col_name: &str) -> Value {
if let Ok(v) = row.try_get::<Value, _>(col_name) {
return v;
}
if let Ok(v) = row.try_get::<String, _>(col_name) {
return Value::String(v);
}
if let Ok(v) = row.try_get::<i64, _>(col_name) {
return Value::Number(v.into());
}
if let Ok(v) = row.try_get::<i32, _>(col_name) {
return Value::Number(v.into());
}
if let Ok(v) = row.try_get::<f64, _>(col_name) {
return serde_json::Number::from_f64(v)
.map(Value::Number)
.unwrap_or(Value::Null);
}
if let Ok(v) = row.try_get::<bool, _>(col_name) {
return Value::Bool(v);
}
if let Ok(v) = row.try_get::<Vec<u8>, _>(col_name) {
use base64::Engine as _;
return Value::String(base64::engine::general_purpose::STANDARD.encode(v));
}
Value::Null
}
/// Produces the SQL text and ordered bind values for one query execution.
///
/// An empty `context` means there is nothing to substitute. In that case the
/// configured query is returned verbatim with no bind values.
///
/// Otherwise substitution is delegated to
/// `faucet_core::util::substitute_context_bind_params`, and every parameter is
/// rendered as SQLite's positional `?` marker. The callback's argument is not
/// needed; it is presumably the parameter number counting from the `1` passed
/// here — confirm.
fn resolve_query(
    config: &SqliteSourceConfig,
    context: &std::collections::HashMap<String, Value>,
) -> (String, Vec<Value>) {
    if !context.is_empty() {
        return faucet_core::util::substitute_context_bind_params(
            &config.query,
            context,
            1,
            |_| String::from("?"),
        );
    }
    (config.query.clone(), Vec::new())
}
/// Binds `bind_values` to `query` positionally, in slice order.
///
/// JSON values map to SQLite parameters as follows:
/// - `null` is bound as NULL (typed as an absent `String`).
/// - Booleans are bound as `bool`.
/// - Numbers representable as `i64` are bound as integers. Any other number is
///   bound as `f64`, or as `0.0` if it has no `f64` form.
/// - Strings are bound as owned text.
/// - Arrays and objects are bound as their compact JSON text.
fn bind_params<'q>(
    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
    bind_values: &'q [Value],
) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
    bind_values.iter().fold(query, |q, value| match value {
        Value::Null => q.bind(None::<String>),
        Value::Bool(flag) => q.bind(*flag),
        Value::Number(num) => match num.as_i64() {
            Some(int) => q.bind(int),
            None => q.bind(num.as_f64().unwrap_or(0.0)),
        },
        Value::String(text) => q.bind(text.clone()),
        Value::Array(_) | Value::Object(_) => q.bind(value.to_string()),
    })
}
/// Converts a full result row into a JSON object keyed by column name.
///
/// Each cell is converted with `sqlite_value_to_json`. If the result set repeats a
/// column name, the later column's value replaces the earlier one.
fn row_to_json(row: &sqlx::sqlite::SqliteRow) -> Value {
    let object: serde_json::Map<String, Value> = row
        .columns()
        .iter()
        .map(|column| {
            let key = column.name();
            (key.to_owned(), sqlite_value_to_json(row, key))
        })
        .collect();
    Value::Object(object)
}
#[async_trait]
impl faucet_core::Source for SqliteSource {
    /// Executes the configured query once and returns every row as a JSON object.
    ///
    /// The query and its bind values come from `resolve_query`. With a non-empty
    /// `context`, placeholders are turned into `?` parameters whose values are bound;
    /// with an empty context, the query text runs unchanged.
    ///
    /// # Errors
    ///
    /// Returns `FaucetError::Config` if the query fails.
    async fn fetch_with_context(
        &self,
        context: &std::collections::HashMap<String, serde_json::Value>,
    ) -> Result<Vec<Value>, FaucetError> {
        let (sql, params) = resolve_query(&self.config, context);
        let fetched = bind_params(sqlx::query(&sql), &params)
            .fetch_all(&self.pool)
            .await
            .map_err(|err| FaucetError::Config(format!("SQLite query failed: {err}")))?;

        let records = fetched.iter().map(row_to_json).collect::<Vec<Value>>();
        tracing::info!(
            rows = records.len(),
            query = %self.config.query,
            "SQLite source fetch complete"
        );
        Ok(records)
    }

    /// Streams the query result as pages of at most `config.batch_size` rows.
    ///
    /// The trait-supplied `_batch_size` is ignored in favour of the source's own
    /// configured size. A configured size of `0` disables paging, so all rows arrive
    /// in a single page. No page is yielded for an empty result, and pages carry no
    /// bookmark. Placeholders are resolved from `context` via `resolve_query`, the
    /// same way `fetch_with_context` does.
    ///
    /// A failure while reading rows ends the stream with `FaucetError::Config`.
    fn stream_pages<'a>(
        &'a self,
        context: &'a std::collections::HashMap<String, Value>,
        _batch_size: usize,
    ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
        let batch_size = self.config.batch_size;
        Box::pin(async_stream::try_stream! {
            let (sql, params) = resolve_query(&self.config, context);
            let mut row_stream = bind_params(sqlx::query(&sql), &params).fetch(&self.pool);

            // `0` means "no limit": never flush mid-stream, but still start from a
            // modest buffer rather than trying to reserve `usize::MAX` slots.
            let (page_limit, page_capacity) = match batch_size {
                0 => (usize::MAX, 1024),
                n => (n, n),
            };
            let mut pending: Vec<Value> = Vec::with_capacity(page_capacity);
            let mut emitted = 0usize;

            while let Some(row) = row_stream
                .try_next()
                .await
                .map_err(|err| FaucetError::Config(format!("SQLite query failed: {err}")))?
            {
                pending.push(row_to_json(&row));
                if pending.len() < page_limit {
                    continue;
                }
                let full_page =
                    std::mem::replace(&mut pending, Vec::with_capacity(page_capacity));
                emitted += full_page.len();
                yield StreamPage { records: full_page, bookmark: None };
            }

            // Flush the trailing partial page, if any.
            if !pending.is_empty() {
                emitted += pending.len();
                yield StreamPage { records: pending, bookmark: None };
            }

            tracing::info!(
                rows = emitted,
                batch_size,
                query = %self.config.query,
                "SQLite source stream complete",
            );
        })
    }

    /// Returns the schema generated by `faucet_core::schema_for!` for
    /// `SqliteSourceConfig`, as a `serde_json::Value`.
    ///
    /// # Panics
    ///
    /// Panics if the generated schema cannot be converted into a `serde_json::Value`.
    fn config_schema(&self) -> serde_json::Value {
        let schema = faucet_core::schema_for!(SqliteSourceConfig);
        serde_json::to_value(schema).expect("schema serialization")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use faucet_core::Source;

    #[tokio::test]
    async fn fetch_from_memory_db() {
        let config = SqliteSourceConfig::new("sqlite::memory:", "SELECT 1 AS val, 'hello' AS msg");
        let source = SqliteSource::new(config).await.unwrap();
        let records = source.fetch_all().await.unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0]["val"], 1);
        assert_eq!(records[0]["msg"], "hello");
    }

    #[tokio::test]
    async fn fetch_from_table() {
        let config = SqliteSourceConfig::new("sqlite::memory:", "SELECT 1");
        let source = SqliteSource::new(config).await.unwrap();
        sqlx::query("CREATE TABLE test_items (id INTEGER PRIMARY KEY, name TEXT, score REAL)")
            .execute(&source.pool)
            .await
            .unwrap();
        sqlx::query(
            "INSERT INTO test_items (id, name, score) VALUES (1, 'Alice', 95.5), (2, 'Bob', 87.0)",
        )
        .execute(&source.pool)
        .await
        .unwrap();
        let rows = sqlx::query("SELECT * FROM test_items ORDER BY id")
            .fetch_all(&source.pool)
            .await
            .unwrap();
        assert_eq!(rows.len(), 2);
        let row0 = &rows[0];
        assert_eq!(row0.try_get::<i64, _>("id").unwrap(), 1);
        assert_eq!(row0.try_get::<String, _>("name").unwrap(), "Alice");
        // Also exercise the source's own row -> JSON conversion on declared column types.
        let json = row_to_json(row0);
        assert_eq!(json["id"], 1);
        assert_eq!(json["name"], "Alice");
        assert_eq!(json["score"], 95.5);
    }

    #[tokio::test]
    async fn blob_column_decodes_to_base64() {
        let config = SqliteSourceConfig::new("sqlite::memory:", "SELECT 1");
        let source = SqliteSource::new(config).await.unwrap();
        sqlx::query("CREATE TABLE b (id INTEGER, data BLOB)")
            .execute(&source.pool)
            .await
            .unwrap();
        sqlx::query("INSERT INTO b (id, data) VALUES (1, X'00FF')")
            .execute(&source.pool)
            .await
            .unwrap();
        let rows = sqlx::query("SELECT data FROM b")
            .fetch_all(&source.pool)
            .await
            .unwrap();
        let v = sqlite_value_to_json(&rows[0], "data");
        assert_eq!(v, Value::String("AP8=".to_string()), "BLOB must be base64");
    }

    #[tokio::test]
    async fn empty_result() {
        let config = SqliteSourceConfig::new("sqlite::memory:", "SELECT 1 AS x WHERE 1 = 0");
        let source = SqliteSource::new(config).await.unwrap();
        let records = source.fetch_all().await.unwrap();
        assert!(records.is_empty());
    }

    #[tokio::test]
    async fn invalid_query_returns_error() {
        let config = SqliteSourceConfig::new("sqlite::memory:", "INVALID SQL");
        let source = SqliteSource::new(config).await.unwrap();
        let result = source.fetch_all().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn fetch_with_context_substitutes_query_placeholders() {
        let config =
            SqliteSourceConfig::new("sqlite::memory:", "SELECT {val} AS result, {name} AS name");
        let source = SqliteSource::new(config).await.unwrap();
        let mut context = std::collections::HashMap::new();
        context.insert("val".to_string(), serde_json::json!(42));
        context.insert("name".to_string(), serde_json::json!("hello"));
        let records = source.fetch_with_context(&context).await.unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0]["result"], 42);
        assert_eq!(records[0]["name"], "hello");
    }

    #[tokio::test]
    async fn fetch_with_context_prevents_sql_injection() {
        let config = SqliteSourceConfig::new("sqlite::memory:", "SELECT {val} AS result");
        let source = SqliteSource::new(config).await.unwrap();
        let mut context = std::collections::HashMap::new();
        context.insert(
            "val".to_string(),
            serde_json::json!("1; DROP TABLE test; --"),
        );
        let records = source.fetch_with_context(&context).await.unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0]["result"], "1; DROP TABLE test; --");
    }

    /// `stream_pages` must split rows into pages of `config.batch_size`, with the
    /// remainder flushed as a final short page.
    #[tokio::test]
    async fn stream_pages_splits_rows_into_configured_batches() {
        let mut config = SqliteSourceConfig::new(
            "sqlite::memory:",
            "SELECT 1 AS n UNION ALL SELECT 2 UNION ALL SELECT 3",
        );
        config.batch_size = 2;
        let source = SqliteSource::new(config).await.unwrap();
        let context: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
        let pages = source
            .stream_pages(&context, 0)
            .try_collect::<Vec<StreamPage>>()
            .await
            .unwrap();
        let sizes: Vec<usize> = pages.iter().map(|p| p.records.len()).collect();
        assert_eq!(sizes, vec![2, 1]);
        assert_eq!(pages[0].records[0]["n"], 1);
        assert_eq!(pages[1].records[0]["n"], 3);
    }

    #[tokio::test]
    async fn new_rejects_out_of_range_batch_size() {
        let mut config = SqliteSourceConfig::new("sqlite::memory:", "SELECT 1");
        config.batch_size = faucet_core::MAX_BATCH_SIZE + 1;
        match SqliteSource::new(config).await {
            Err(faucet_core::FaucetError::Config(m)) => {
                assert!(m.contains("batch_size"), "got: {m}")
            }
            _ => panic!("expected a batch_size Config error"),
        }
    }
}