use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde_json::Value as JsonValue;
use sqlx::AnyPool;
use sqlx::any::AnyPoolOptions;
use sqlx::any::AnyRow;
use tokio::sync::OnceCell;
use tracing::{error, info, warn};
use camel_component_api::{Body, CamelError, Exchange, Message};
use camel_component_api::{ConcurrencyModel, Consumer, ConsumerContext};
use crate::config::{SqlEndpointConfig, enrich_db_url_with_ssl};
use crate::headers;
use crate::query::{QueryTemplate, parse_query_template, resolve_params};
use crate::utils::{bind_json_values, row_to_json};
/// Polling SQL consumer: periodically executes the endpoint's select query
/// and routes the resulting rows downstream as exchanges.
pub struct SqlConsumer {
    // Endpoint options parsed from the URI (query, delays, post-queries, ...).
    pub(crate) config: SqlEndpointConfig,
    // Shared pool cell provided by the endpoint; the pool itself is created
    // lazily on first `start()` so constructing the consumer never touches
    // the database.
    pub(crate) pool: Arc<OnceCell<AnyPool>>,
}
impl SqlConsumer {
pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
Self { config, pool }
}
async fn poll_database(
&self,
pool: &AnyPool,
context: &ConsumerContext,
template: &QueryTemplate,
) -> Result<(), CamelError> {
let empty_exchange = Exchange::new(Message::default());
let prepared = resolve_params(template, &empty_exchange, &self.config.in_separator)?;
let query = bind_json_values(sqlx::query(&prepared.sql), &prepared.bindings);
let rows: Vec<AnyRow> = query
.fetch_all(pool)
.await
.map_err(|e| CamelError::ProcessorError(format!("Query execution failed: {}", e)))?;
if rows.is_empty() && !self.config.route_empty_result_set {
return Ok(());
}
let rows_to_process: Vec<AnyRow> = if let Some(max) = self.config.max_messages_per_poll {
if max > 0 {
rows.into_iter().take(max as usize).collect()
} else {
rows
}
} else {
rows
};
if self.config.use_iterator {
for row in rows_to_process {
let row_json = row_to_json(&row)?;
let mut msg = Message::new(Body::Json(row_json.clone()));
if let Some(obj) = row_json.as_object() {
for (key, value) in obj {
msg.set_header(format!("CamelSql.{}", key), value.clone());
}
}
let exchange = Exchange::new(msg);
let result = context.send_and_wait(exchange).await;
if let Err(e) = self.handle_post_processing(pool, &result, &row_json).await {
error!(error = %e, "Post-processing failed");
if self.config.break_batch_on_consume_fail {
return Err(e);
}
}
if let Err(ref consume_err) = result
&& self.config.break_batch_on_consume_fail
{
return Err(consume_err.clone());
}
}
} else {
let rows_json: Vec<JsonValue> = rows_to_process
.iter()
.map(row_to_json)
.collect::<Result<Vec<_>, CamelError>>()?;
let row_count = rows_json.len();
let mut msg = Message::new(Body::Json(JsonValue::Array(rows_json)));
msg.set_header(headers::ROW_COUNT, JsonValue::Number(row_count.into()));
let exchange = Exchange::new(msg);
let result = context.send_and_wait(exchange).await;
if let Err(e) = self
.handle_post_processing(pool, &result, &JsonValue::Null)
.await
{
error!(error = %e, "Post-processing failed for batch");
if self.config.break_batch_on_consume_fail {
return Err(e);
}
}
if let Err(ref consume_err) = result
&& self.config.break_batch_on_consume_fail
{
return Err(consume_err.clone());
}
}
if let Some(ref batch_query) = self.config.on_consume_batch_complete
&& let Err(e) = self
.execute_post_query(pool, batch_query, &JsonValue::Null)
.await
{
error!(error = %e, "onConsumeBatchComplete query failed");
}
Ok(())
}
async fn handle_post_processing(
&self,
pool: &AnyPool,
result: &Result<Exchange, CamelError>,
row_json: &JsonValue,
) -> Result<(), CamelError> {
match result {
Ok(_) => {
if let Some(ref on_consume) = self.config.on_consume {
self.execute_post_query(pool, on_consume, row_json).await?;
}
}
Err(_) => {
if let Some(ref on_consume_failed) = self.config.on_consume_failed {
self.execute_post_query(pool, on_consume_failed, row_json)
.await?;
}
}
}
Ok(())
}
async fn execute_post_query(
&self,
pool: &AnyPool,
query_str: &str,
row_json: &JsonValue,
) -> Result<(), CamelError> {
let template = parse_query_template(query_str, self.config.placeholder)?;
let mut temp_msg = Message::new(Body::Json(row_json.clone()));
if let Some(obj) = row_json.as_object() {
for (key, value) in obj {
temp_msg.set_header(format!("CamelSql.{}", key), value.clone());
}
}
let temp_exchange = Exchange::new(temp_msg);
let prepared = resolve_params(&template, &temp_exchange, &self.config.in_separator)?;
let query = bind_json_values(sqlx::query(&prepared.sql), &prepared.bindings);
let result = query.execute(pool).await.map_err(|e| {
CamelError::ProcessorError(format!("Post-query execution failed: {}", e))
})?;
if result.rows_affected() == 0 {
warn!(
query = query_str,
"Post-processing query affected 0 rows — the row may not have been marked correctly"
);
}
Ok(())
}
}
#[async_trait]
impl Consumer for SqlConsumer {
    /// Starts the polling loop: initialises the shared pool on first use,
    /// parses the select query once, honours `initialDelay`, then polls
    /// every `delay` milliseconds until the context is cancelled.
    ///
    /// BUG FIX: previously the loop slept `delay` before *every* poll,
    /// including the first, so the first poll effectively happened at
    /// `initialDelay + delay`. Now the first poll runs immediately after the
    /// initial delay and `delay` only separates subsequent polls, matching
    /// the documented meaning of the two options.
    async fn start(&mut self, context: ConsumerContext) -> Result<(), CamelError> {
        let pool = self
            .pool
            .get_or_try_init(|| async {
                self.config.resolve_defaults();
                sqlx::any::install_default_drivers();
                let db_url = enrich_db_url_with_ssl(&self.config.db_url, &self.config)?;
                AnyPoolOptions::new()
                    .max_connections(
                        self.config
                            .max_connections
                            .expect("must be Some after resolve_defaults()"),
                    )
                    .min_connections(
                        self.config
                            .min_connections
                            .expect("must be Some after resolve_defaults()"),
                    )
                    .idle_timeout(Duration::from_secs(
                        self.config
                            .idle_timeout_secs
                            .expect("must be Some after resolve_defaults()"),
                    ))
                    .max_lifetime(Duration::from_secs(
                        self.config
                            .max_lifetime_secs
                            .expect("must be Some after resolve_defaults()"),
                    ))
                    .connect(&db_url)
                    .await
                    .map_err(|e| {
                        CamelError::EndpointCreationFailed(format!(
                            "Failed to connect to database: {}",
                            e
                        ))
                    })
            })
            .await?;
        // Without onConsume the consumed rows stay untouched in the table and
        // will match the select again on the next poll — warn loudly.
        if self.config.on_consume.is_none() {
            warn!(
                "SQL consumer started without onConsume configured — consumed rows will not be marked/deleted"
            );
        }
        let template = parse_query_template(&self.config.query, self.config.placeholder)
            .map_err(|e| CamelError::Config(format!("Invalid query template: {}", e)))?;
        if self.config.initial_delay_ms > 0 {
            tokio::select! {
                _ = context.cancelled() => {
                    info!("SQL consumer stopped during initial delay");
                    return Ok(());
                }
                _ = tokio::time::sleep(Duration::from_millis(self.config.initial_delay_ms)) => {}
            }
        }
        loop {
            // Non-blocking cancellation check before polling: `biased` makes
            // select poll the cancellation future first, and the `ready` arm
            // completes immediately when we are not cancelled.
            tokio::select! {
                biased;
                _ = context.cancelled() => {
                    info!("SQL consumer stopped");
                    break;
                }
                _ = std::future::ready(()) => {}
            }
            // Poll errors are logged but never stop the consumer; the next
            // scheduled poll will retry.
            if let Err(e) = self.poll_database(pool, &context, &template).await {
                error!(error = %e, "SQL consumer poll failed");
            }
            tokio::select! {
                _ = context.cancelled() => {
                    info!("SQL consumer stopped");
                    break;
                }
                _ = tokio::time::sleep(Duration::from_millis(self.config.delay_ms)) => {}
            }
        }
        Ok(())
    }

    /// No teardown needed: the pool is shared via `Arc` and the polling loop
    /// exits through the cancellation token, not through `stop()`.
    async fn stop(&mut self) -> Result<(), CamelError> {
        Ok(())
    }

    /// Polls are strictly sequential — a new poll never starts while a
    /// previous one is still routing rows.
    fn concurrency_model(&self) -> ConcurrencyModel {
        ConcurrencyModel::Sequential
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SqlEndpointConfig;
    use camel_component_api::ExchangeEnvelope;
    use camel_component_api::UriConfig;
    use sqlx::any::AnyPoolOptions;
    use std::sync::Arc;
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    // Single-connection in-memory SQLite pool; one connection keeps every
    // query on the same in-memory database instance.
    async fn sqlite_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        AnyPoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("sqlite pool")
    }

    // Creates a `jobs` table with two unprocessed, unfailed rows (ids 1, 2).
    async fn seed_consumer_table(pool: &AnyPool) {
        sqlx::query("CREATE TABLE jobs (id INTEGER PRIMARY KEY, processed INTEGER DEFAULT 0, failed INTEGER DEFAULT 0)")
            .execute(pool)
            .await
            .expect("create table");
        sqlx::query("INSERT INTO jobs (id, processed, failed) VALUES (1, 0, 0), (2, 0, 0)")
            .execute(pool)
            .await
            .expect("seed rows");
    }

    // Minimal valid endpoint config with defaults resolved.
    fn config() -> SqlEndpointConfig {
        let mut c =
            SqlEndpointConfig::from_uri("sql:select * from t?db_url=postgres://localhost/test")
                .unwrap();
        c.resolve_defaults();
        c
    }

    #[test]
    fn consumer_concurrency_model() {
        let c = SqlConsumer::new(config(), Arc::new(OnceCell::new()));
        assert_eq!(c.concurrency_model(), ConcurrencyModel::Sequential);
    }

    #[test]
    fn consumer_stores_config() {
        let mut config = SqlEndpointConfig::from_uri(
            "sql:select * from t?db_url=postgres://localhost/test&delay=2000&onConsume=update t set done=true"
        ).unwrap();
        config.resolve_defaults();
        let c = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        assert_eq!(c.config.delay_ms, 2000);
        assert!(c.config.on_consume.is_some());
    }

    // Happy path: every routed row is acknowledged, so onConsume must mark
    // both rows as processed.
    #[tokio::test]
    async fn poll_database_runs_on_consume_for_successful_rows() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;
        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsume=update jobs set processed=1 where id=:#id&initialDelay=0&delay=1",
        )
        .unwrap();
        config.resolve_defaults();
        let consumer = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        let template = parse_query_template(&config.query, config.placeholder).unwrap();
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        // Downstream stub: echo every exchange back as a success reply.
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let _ = reply_tx.send(Ok(env.exchange));
                }
            }
        });
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect("poll must succeed");
        let row = sqlx::query("select processed from jobs where id = 1")
            .fetch_one(&pool)
            .await
            .expect("row 1");
        let processed_1: i64 = sqlx::Row::try_get(&row, 0).expect("processed");
        let row = sqlx::query("select processed from jobs where id = 2")
            .fetch_one(&pool)
            .await
            .expect("row 2");
        let processed_2: i64 = sqlx::Row::try_get(&row, 0).expect("processed");
        assert_eq!(processed_1, 1);
        assert_eq!(processed_2, 1);
    }

    // Downstream always fails; with breakBatchOnConsumeFail left at its
    // default (false) the poll still succeeds and onConsumeFailed marks
    // BOTH rows as failed.
    #[tokio::test]
    async fn poll_database_runs_on_consume_failed_when_downstream_fails() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;
        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsumeFailed=update jobs set failed=1 where id=:#id&initialDelay=0&delay=1",
        )
        .unwrap();
        config.resolve_defaults();
        let consumer = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        let template = parse_query_template(&config.query, config.placeholder).unwrap();
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        // Downstream stub: reply to every exchange with an error.
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let _ =
                        reply_tx.send(Err(CamelError::ProcessorError("downstream boom".into())));
                }
            }
        });
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect("consumer should swallow downstream errors when breakBatchOnConsumeFail=false");
        let row = sqlx::query("select failed from jobs where id = 1")
            .fetch_one(&pool)
            .await
            .expect("row 1");
        let failed_1: i64 = sqlx::Row::try_get(&row, 0).expect("failed");
        let row = sqlx::query("select failed from jobs where id = 2")
            .fetch_one(&pool)
            .await
            .expect("row 2");
        let failed_2: i64 = sqlx::Row::try_get(&row, 0).expect("failed");
        assert_eq!(failed_1, 1);
        assert_eq!(failed_2, 1);
    }

    // With breakBatchOnConsumeFail=true the first downstream failure must
    // abort the batch: row 1 gets onConsumeFailed, row 2 is never routed.
    #[tokio::test]
    async fn poll_database_breaks_batch_on_consume_fail() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;
        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsumeFailed=update jobs set failed=1 where id=:#id&breakBatchOnConsumeFail=true&initialDelay=0&delay=1",
        )
        .unwrap();
        config.resolve_defaults();
        let consumer = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        let template = parse_query_template(&config.query, config.placeholder).unwrap();
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        // Downstream stub: reply to every exchange with an error.
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let _ =
                        reply_tx.send(Err(CamelError::ProcessorError("downstream boom".into())));
                }
            }
        });
        let ctx = ConsumerContext::new(tx, CancellationToken::new());
        let err = consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect_err("must stop on first downstream failure");
        assert!(err.to_string().contains("downstream boom"));
        let row = sqlx::query("select failed from jobs where id = 1")
            .fetch_one(&pool)
            .await
            .expect("row 1");
        let failed_1: i64 = sqlx::Row::try_get(&row, 0).expect("failed");
        let row = sqlx::query("select failed from jobs where id = 2")
            .fetch_one(&pool)
            .await
            .expect("row 2");
        let failed_2: i64 = sqlx::Row::try_get(&row, 0).expect("failed");
        assert_eq!(failed_1, 1);
        assert_eq!(failed_2, 0, "second row must not be processed");
    }
}