use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::Bytes;
use serde_json::json;
use sqlx::AnyPool;
use sqlx::any::AnyRow;
use sqlx::pool::PoolOptions;
use tokio::sync::OnceCell;
use tower::Service;
use tracing::{debug, error, warn};
use crate::config::{SqlEndpointConfig, SqlOutputType, enrich_db_url_with_ssl};
use crate::headers;
use crate::query::{PreparedQuery, is_select_query, parse_query_template, resolve_params};
use crate::utils::{bind_json_values, row_to_json};
use camel_component_api::{Body, CamelError, Exchange, Message, StreamBody, StreamMetadata};
/// Producer side of the SQL component: runs the SQL for each [`Exchange`]
/// passed to its [`Service`] implementation.
///
/// Clones share the same pool cell (`Arc<OnceCell<_>>`), so every clone
/// uses the same connection pool once it has been created.
#[derive(Clone)]
pub struct SqlProducer {
    /// Endpoint configuration this producer executes against.
    pub(crate) config: SqlEndpointConfig,
    /// Lazily-populated connection pool, shared by every clone.
    pub(crate) pool: Arc<OnceCell<AnyPool>>,
}

impl SqlProducer {
    /// Creates a producer for `config` backed by the (possibly still empty)
    /// pool cell `pool`.
    pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
        SqlProducer { config, pool }
    }

    /// Chooses the SQL text to execute for `exchange`.
    ///
    /// Priority, highest first:
    /// 1. the `QUERY` header, when it holds a JSON string;
    /// 2. the message body as text, when `use_message_body_for_sql` is set
    ///    and the body has a text form;
    /// 3. the query configured on the endpoint.
    pub(crate) fn resolve_query_source(exchange: &Exchange, config: &SqlEndpointConfig) -> String {
        let from_header = exchange
            .input
            .header(headers::QUERY)
            .and_then(|value| value.as_str());
        if let Some(sql) = from_header {
            return sql.to_string();
        }

        let from_body = if config.use_message_body_for_sql {
            exchange.input.body.as_text()
        } else {
            None
        };
        match from_body {
            Some(sql) => sql.to_string(),
            None => config.query.clone(),
        }
    }
}
impl Service<Exchange> for SqlProducer {
type Response = Exchange;
type Error = CamelError;
type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut exchange: Exchange) -> Self::Future {
let mut config = self.config.clone();
let pool_cell = Arc::clone(&self.pool);
Box::pin(async move {
let pool: &AnyPool = pool_cell
.get_or_try_init(|| async {
config.resolve_defaults();
let db_url = enrich_db_url_with_ssl(&config.db_url, &config)?;
sqlx::any::install_default_drivers();
let opts: PoolOptions<sqlx::Any> = PoolOptions::new()
.max_connections(
config
.max_connections
.expect("must be Some after resolve_defaults()"),
)
.min_connections(
config
.min_connections
.expect("must be Some after resolve_defaults()"),
)
.idle_timeout(Duration::from_secs(
config
.idle_timeout_secs
.expect("must be Some after resolve_defaults()"),
))
.max_lifetime(Duration::from_secs(
config
.max_lifetime_secs
.expect("must be Some after resolve_defaults()"),
));
opts.connect(&db_url).await.map_err(|e| {
error!("Failed to connect to database: {}", e);
CamelError::EndpointCreationFailed(format!(
"Failed to connect to database: {}",
e
))
})
})
.await
.map_err(|e: CamelError| {
error!("Pool initialization failed: {}", e);
e.clone()
})?;
let query_str = Self::resolve_query_source(&exchange, &config);
debug!("Executing SQL: {}", query_str);
if config.batch {
execute_batch(pool, &config, &mut exchange).await?;
} else {
let template = parse_query_template(&query_str, config.placeholder)?;
let mut prepared = resolve_params(&template, &exchange, &config.in_separator)?;
if let Some(params_value) = exchange.input.header(headers::PARAMETERS) {
if let Some(arr) = params_value.as_array() {
if arr.len() != prepared.bindings.len() {
warn!(
expected = prepared.bindings.len(),
got = arr.len(),
header = headers::PARAMETERS,
"Parameter count mismatch — SQL has {} placeholders but header provides {} values",
prepared.bindings.len(),
arr.len()
);
}
debug!(
"Overriding bindings from {} header with {} parameters",
headers::PARAMETERS,
arr.len()
);
prepared.bindings = arr.clone();
} else {
warn!(
header = headers::PARAMETERS,
"Header is present but not a JSON array — ignoring parameter override"
);
}
}
debug!("Executing SQL: {}", prepared.sql);
if is_select_query(&prepared.sql) {
execute_select(pool, &prepared, &config, &mut exchange).await?;
} else {
execute_modify(pool, &prepared, &config, &mut exchange).await?;
}
}
Ok(exchange)
})
}
}
async fn execute_select(
pool: &AnyPool,
prepared: &PreparedQuery,
config: &SqlEndpointConfig,
exchange: &mut Exchange,
) -> Result<(), CamelError> {
match config.output_type {
SqlOutputType::SelectOne => {
let mut query = sqlx::query(&prepared.sql);
query = bind_json_values(query, &prepared.bindings);
let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
error!("Query execution failed: {}", e);
CamelError::ProcessorError(format!("Query execution failed: {}", e))
})?;
let count = rows.len();
let json_rows: Vec<serde_json::Value> = rows
.iter()
.map(row_to_json)
.collect::<Result<Vec<_>, _>>()?;
if let Some(first_row) = json_rows.into_iter().next() {
exchange.input.body = Body::Json(first_row);
} else {
exchange.input.body = Body::Empty;
}
debug!("SelectOne returned {} row", if count > 0 { 1 } else { 0 });
exchange
.input
.set_header(headers::ROW_COUNT, serde_json::json!(count));
}
SqlOutputType::SelectList => {
let mut query = sqlx::query(&prepared.sql);
query = bind_json_values(query, &prepared.bindings);
let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
error!("Query execution failed: {}", e);
CamelError::ProcessorError(format!("Query execution failed: {}", e))
})?;
let count = rows.len();
let json_rows: Vec<serde_json::Value> = rows
.iter()
.map(row_to_json)
.collect::<Result<Vec<_>, _>>()?;
exchange.input.body = Body::Json(serde_json::Value::Array(json_rows));
debug!("SelectList returned {} rows", count);
exchange
.input
.set_header(headers::ROW_COUNT, serde_json::json!(count));
}
SqlOutputType::StreamList => {
use futures::TryStreamExt;
let pool_clone = pool.clone();
let sql_str = prepared.sql.clone();
let bindings = prepared.bindings.clone();
let byte_stream = async_stream::try_stream! {
let mut q = sqlx::query(&sql_str);
q = bind_json_values(q, &bindings);
let mut rows = q.fetch(&pool_clone);
while let Some(row) = rows.try_next().await.map_err(|e| {
CamelError::ProcessorError(format!("Query execution failed: {}", e))
})? {
let json_val = row_to_json(&row).map_err(|e| {
CamelError::ProcessorError(format!("JSON serialization failed: {}", e))
})?;
let mut bytes = serde_json::to_vec(&json_val)
.map_err(|e| CamelError::ProcessorError(format!("JSON serialization failed: {}", e)))?;
bytes.push(b'\n');
yield Bytes::from(bytes);
}
};
exchange.input.body = Body::Stream(StreamBody {
stream: Arc::new(tokio::sync::Mutex::new(Some(Box::pin(byte_stream)))),
metadata: StreamMetadata {
content_type: Some("application/x-ndjson".to_string()),
size_hint: None,
origin: None,
},
});
debug!("StreamList: created lazy stream (rows fetched on demand)");
}
}
Ok(())
}
async fn execute_modify(
pool: &AnyPool,
prepared: &PreparedQuery,
config: &SqlEndpointConfig,
exchange: &mut Exchange,
) -> Result<(), CamelError> {
let mut query = sqlx::query(&prepared.sql);
query = bind_json_values(query, &prepared.bindings);
let result = query.execute(pool).await.map_err(|e| {
error!("Query execution failed: {}", e);
CamelError::ProcessorError(format!("Query execution failed: {}", e))
})?;
let rows_affected = result.rows_affected();
if let Some(expected) = config.expected_update_count
&& rows_affected as i64 != expected
{
error!("Expected {} rows affected, got {}", expected, rows_affected);
return Err(CamelError::ProcessorError(format!(
"Expected {} rows affected, got {}",
expected, rows_affected
)));
}
exchange
.input
.set_header(headers::UPDATE_COUNT, serde_json::json!(rows_affected));
if config.noop {
} else {
exchange.input.body = Body::Json(json!({ "rowsAffected": rows_affected }));
}
debug!("Modify query affected {} rows", rows_affected);
Ok(())
}
async fn execute_batch(
pool: &AnyPool,
config: &SqlEndpointConfig,
exchange: &mut Exchange,
) -> Result<(), CamelError> {
let body_json = match &exchange.input.body {
Body::Json(val) => val,
_ => {
return Err(CamelError::ProcessorError(
"Batch mode requires body to be a JSON array of arrays".to_string(),
));
}
};
let batch_data = body_json
.as_array()
.ok_or_else(|| {
CamelError::ProcessorError("Batch mode requires body to be a JSON array".to_string())
})?
.clone();
let template = parse_query_template(&config.query, config.placeholder)?;
let mut tx = pool.begin().await.map_err(|e| {
error!("Failed to begin transaction: {}", e);
CamelError::ProcessorError(format!("Failed to begin transaction: {}", e))
})?;
let mut total_rows_affected: u64 = 0;
for (batch_idx, params_array) in batch_data.into_iter().enumerate() {
params_array.as_array().ok_or_else(|| {
CamelError::ProcessorError(format!(
"Batch item at index {} must be a JSON array of parameters",
batch_idx
))
})?;
let temp_msg = Message::new(Body::Json(params_array.clone()));
let temp_exchange = Exchange::new(temp_msg);
let prepared = resolve_params(&template, &temp_exchange, &config.in_separator)?;
let mut query = sqlx::query(&prepared.sql);
query = bind_json_values(query, &prepared.bindings);
let result = query.execute(&mut *tx).await.map_err(|e| {
error!("Batch query execution failed at index {}: {}", batch_idx, e);
CamelError::ProcessorError(format!("Batch query execution failed: {}", e))
})?;
if let Some(expected) = config.expected_update_count
&& result.rows_affected() as i64 != expected
{
error!(
"Batch item {}: expected {} rows affected, got {}",
batch_idx,
expected,
result.rows_affected()
);
return Err(CamelError::ProcessorError(format!(
"Batch item {}: expected {} rows affected, got {}",
batch_idx,
expected,
result.rows_affected()
)));
}
total_rows_affected += result.rows_affected();
}
tx.commit().await.map_err(|e| {
error!("Failed to commit transaction: {}", e);
CamelError::ProcessorError(format!("Failed to commit transaction: {}", e))
})?;
exchange.input.set_header(
headers::UPDATE_COUNT,
serde_json::json!(total_rows_affected),
);
debug!(
"Batch execution completed, total rows affected: {}",
total_rows_affected
);
Ok(())
}
#[cfg(test)]
mod tests {
    // Unit tests for the SQL producer. The synchronous tests cover query
    // resolution and config parsing; the async tests run the execute_*
    // helpers against an in-memory SQLite database.
    use super::*;
    use camel_component_api::Message;
    use camel_component_api::UriConfig;
    use sqlx::any::AnyPoolOptions;
    use std::sync::Arc;
    use tokio::sync::OnceCell;

    // In-memory SQLite pool for the async tests. It is capped at one
    // connection — presumably because each connection to `sqlite::memory:`
    // opens its own private database, so tables seeded through one
    // connection would be invisible to another.
    async fn sqlite_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        AnyPoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("sqlite pool")
    }

    // Creates the `items` table and inserts rows (1, 'a', 0) and (2, 'b', 0).
    async fn seed_items_table(pool: &AnyPool) {
        sqlx::query(
            "CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT, done INTEGER DEFAULT 0)",
        )
        .execute(pool)
        .await
        .expect("create table");
        sqlx::query("INSERT INTO items (id, name, done) VALUES (1, 'a', 0), (2, 'b', 0)")
            .execute(pool)
            .await
            .expect("seed rows");
    }

    // Baseline endpoint config (query `select 1`, Postgres URL) with defaults
    // resolved. Used by the synchronous tests below.
    fn config() -> SqlEndpointConfig {
        let mut c =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        c.resolve_defaults();
        c
    }

    // Cloning a producer must share the same pool cell (same Arc allocation).
    #[test]
    fn producer_clone_shares_pool() {
        let p1 = SqlProducer::new(config(), Arc::new(OnceCell::new()));
        let p2 = p1.clone();
        assert!(Arc::ptr_eq(&p1.pool, &p2.pool));
    }

    // --- resolve_query_source priority: QUERY header > body (when enabled) > config ---

    #[test]
    fn resolve_query_from_config() {
        let config = config();
        let ex = Exchange::new(Message::default());
        let q = SqlProducer::resolve_query_source(&ex, &config);
        assert_eq!(q, "select 1");
    }

    #[test]
    fn resolve_query_from_header() {
        let config = config();
        let mut msg = Message::default();
        msg.set_header(headers::QUERY, serde_json::json!("select 2"));
        let ex = Exchange::new(msg);
        let q = SqlProducer::resolve_query_source(&ex, &config);
        assert_eq!(q, "select 2");
    }

    #[test]
    fn resolve_query_from_body() {
        let mut config = config();
        config.use_message_body_for_sql = true;
        let msg = Message::new(Body::Text("select 3".to_string()));
        let ex = Exchange::new(msg);
        let q = SqlProducer::resolve_query_source(&ex, &config);
        assert_eq!(q, "select 3");
    }

    #[test]
    fn resolve_query_header_priority_over_body() {
        let mut config = config();
        config.use_message_body_for_sql = true;
        let mut msg = Message::new(Body::Text("select from body".to_string()));
        msg.set_header(headers::QUERY, serde_json::json!("select from header"));
        let ex = Exchange::new(msg);
        let q = SqlProducer::resolve_query_source(&ex, &config);
        assert_eq!(q, "select from header");
    }

    #[test]
    fn resolve_query_body_priority_over_config() {
        let mut config = config();
        config.use_message_body_for_sql = true;
        let msg = Message::new(Body::Text("select from body".to_string()));
        let ex = Exchange::new(msg);
        let q = SqlProducer::resolve_query_source(&ex, &config);
        assert_eq!(q, "select from body");
    }

    // --- bind_json_values smoke tests ---
    // These only check that each JSON value kind can be bound without
    // panicking; the queries are never executed and nothing is asserted.

    #[test]
    fn bind_json_null() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::Value::Null];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_bool() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::Value::Bool(true)];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_number_i64() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::json!(42)];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_number_f64() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::json!(std::f64::consts::PI)];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_string() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::json!("hello world")];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_array() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::json!([1, 2, 3])];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_object() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::json!({"key": "value"})];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_multiple_values() {
        let query = sqlx::query("SELECT ?, ?, ?");
        let values = vec![
            serde_json::json!(1),
            serde_json::json!("test"),
            serde_json::Value::Null,
        ];
        let _bound = bind_json_values(query, &values);
    }

    // Only checks URI parsing of `expectedUpdateCount` (absent, positive and
    // negative); enforcement is exercised by the async execute_* tests below.
    #[test]
    fn expected_update_count_validation() {
        let config = SqlEndpointConfig::from_uri(
            "sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=5",
        )
        .unwrap();
        assert_eq!(config.expected_update_count, Some(5));
        let config_default = self::config();
        assert_eq!(config_default.expected_update_count, None);
        let config_neg = SqlEndpointConfig::from_uri(
            "sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=-1",
        )
        .unwrap();
        assert_eq!(config_neg.expected_update_count, Some(-1));
    }

    // NOTE(review): this re-implements the PARAMETERS-header override inline
    // instead of calling the producer, so it cannot catch regressions in the
    // production code path.
    #[test]
    fn parameters_header_override_logic() {
        let mut prepared = PreparedQuery {
            sql: "SELECT * FROM t WHERE id = $1".to_string(),
            bindings: vec![serde_json::json!(42)],
        };
        let header_params = serde_json::json!([99, "extra"]);
        if let Some(arr) = header_params.as_array() {
            prepared.bindings = arr.clone();
        }
        assert_eq!(prepared.bindings.len(), 2);
        assert_eq!(prepared.bindings[0], serde_json::json!(99));
        assert_eq!(prepared.bindings[1], serde_json::json!("extra"));
        // A non-array header value must leave the original bindings untouched.
        let mut prepared2 = PreparedQuery {
            sql: "SELECT * FROM t WHERE id = $1".to_string(),
            bindings: vec![serde_json::json!(42)],
        };
        let header_non_array = serde_json::json!({"not": "an array"});
        if let Some(arr) = header_non_array.as_array() {
            prepared2.bindings = arr.clone();
        }
        assert_eq!(prepared2.bindings.len(), 1);
        assert_eq!(prepared2.bindings[0], serde_json::json!(42));
    }

    // SelectOne puts only the first row in the body, while ROW_COUNT reports
    // every row the query returned (2 here).
    #[tokio::test]
    async fn execute_select_one_sets_body_and_row_count() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;
        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, name from items order by id?db_url=sqlite::memory:&outputType=SelectOne",
        )
        .unwrap();
        config.resolve_defaults();
        let prepared = PreparedQuery {
            sql: "select id, name from items order by id".to_string(),
            bindings: vec![],
        };
        let mut exchange = Exchange::new(Message::default());
        execute_select(&pool, &prepared, &config, &mut exchange)
            .await
            .expect("select one");
        assert_eq!(exchange.input.header(headers::ROW_COUNT), Some(&json!(2)));
        assert_eq!(
            exchange.input.body,
            Body::Json(json!({"id": 1, "name": "a"}))
        );
    }

    // StreamList yields NDJSON once the body is drained, and sets no ROW_COUNT.
    #[tokio::test]
    async fn execute_stream_list_materializes_ndjson() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;
        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id from items order by id?db_url=sqlite::memory:&outputType=StreamList",
        )
        .unwrap();
        config.resolve_defaults();
        let prepared = PreparedQuery {
            sql: "select id from items order by id".to_string(),
            bindings: vec![],
        };
        let mut exchange = Exchange::new(Message::default());
        execute_select(&pool, &prepared, &config, &mut exchange)
            .await
            .expect("stream list");
        let bytes = exchange
            .input
            .body
            .clone()
            .into_bytes(1024)
            .await
            .expect("stream bytes");
        let text = String::from_utf8(bytes.to_vec()).expect("utf8");
        assert!(text.contains("{\"id\":1}"));
        assert!(text.contains("{\"id\":2}"));
        assert_eq!(exchange.input.header(headers::ROW_COUNT), None);
    }

    // Updating one row while expecting two must produce the mismatch error.
    #[tokio::test]
    async fn execute_modify_expected_update_count_mismatch_returns_error() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;
        let mut config = SqlEndpointConfig::from_uri(
            "sql:update items set done=1 where id = #?db_url=sqlite::memory:&expectedUpdateCount=2",
        )
        .unwrap();
        config.resolve_defaults();
        let prepared = PreparedQuery {
            sql: "update items set done=1 where id = $1".to_string(),
            bindings: vec![json!(1)],
        };
        let mut exchange = Exchange::new(Message::default());
        let err = execute_modify(&pool, &prepared, &config, &mut exchange)
            .await
            .expect_err("must fail due expected row count mismatch");
        assert!(err.to_string().contains("Expected 2 rows affected, got 1"));
    }

    // Item 0 (id 1) updates one row, item 1 (id 999) updates none; the failure
    // on item 1 must roll back item 0's update as well.
    #[tokio::test]
    async fn execute_batch_rollback_when_any_item_fails_expected_count() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;
        let mut config = SqlEndpointConfig::from_uri(
            "sql:update items set done=1 where id = #?db_url=sqlite::memory:&batch=true&expectedUpdateCount=1",
        )
        .unwrap();
        config.resolve_defaults();
        let mut exchange = Exchange::new(Message::new(Body::Json(json!([[1], [999]]))));
        let err = execute_batch(&pool, &config, &mut exchange)
            .await
            .expect_err("second batch item should fail expectedUpdateCount");
        assert!(
            err.to_string()
                .contains("Batch item 1: expected 1 rows affected, got 0")
        );
        let row = sqlx::query("select done from items where id = 1")
            .fetch_one(&pool)
            .await
            .expect("query row");
        let done: i64 = sqlx::Row::try_get(&row, 0).expect("done column");
        assert_eq!(done, 0, "transaction must rollback first update");
    }
}