1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use serde_json::Value as JsonValue;
6use sqlx::AnyPool;
7use sqlx::any::AnyPoolOptions;
8use sqlx::any::AnyRow;
9use tokio::sync::OnceCell;
10use tracing::{error, info, warn};
11
12use camel_component_api::{Body, CamelError, Exchange, Message};
13use camel_component_api::{ConcurrencyModel, Consumer, ConsumerContext};
14
15use crate::config::{SqlEndpointConfig, enrich_db_url_with_ssl};
16use crate::headers;
17use crate::query::{QueryTemplate, parse_query_template, resolve_params};
18use crate::utils::{bind_json_values, row_to_json};
19
20pub struct SqlConsumer {
21 pub(crate) config: SqlEndpointConfig,
22 pub(crate) pool: Arc<OnceCell<AnyPool>>,
23}
24
25impl SqlConsumer {
26 pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
27 Self { config, pool }
28 }
29
30 async fn poll_database(
32 &self,
33 pool: &AnyPool,
34 context: &ConsumerContext,
35 template: &QueryTemplate,
36 ) -> Result<(), CamelError> {
37 let empty_exchange = Exchange::new(Message::default());
39
40 let prepared = resolve_params(template, &empty_exchange, &self.config.in_separator)?;
42
43 let query = bind_json_values(sqlx::query(&prepared.sql), &prepared.bindings);
45 let rows: Vec<AnyRow> = query
46 .fetch_all(pool)
47 .await
48 .map_err(|e| CamelError::ProcessorError(format!("Query execution failed: {}", e)))?;
49
50 if rows.is_empty() && !self.config.route_empty_result_set {
52 return Ok(());
53 }
54
55 let rows_to_process: Vec<AnyRow> = if let Some(max) = self.config.max_messages_per_poll {
57 if max > 0 {
58 rows.into_iter().take(max as usize).collect()
59 } else {
60 rows
61 }
62 } else {
63 rows
64 };
65
66 if self.config.use_iterator {
67 for row in rows_to_process {
69 let row_json = row_to_json(&row)?;
70
71 let mut msg = Message::new(Body::Json(row_json.clone()));
73
74 if let Some(obj) = row_json.as_object() {
76 for (key, value) in obj {
77 msg.set_header(format!("CamelSql.{}", key), value.clone());
78 }
79 }
80
81 let exchange = Exchange::new(msg);
82
83 let result = context.send_and_wait(exchange).await;
85
86 if let Err(e) = self.handle_post_processing(pool, &result, &row_json).await {
88 error!(error = %e, "Post-processing failed");
89 if self.config.break_batch_on_consume_fail {
90 return Err(e);
91 }
92 }
93
94 if let Err(ref consume_err) = result
96 && self.config.break_batch_on_consume_fail
97 {
98 return Err(consume_err.clone());
99 }
100 }
101 } else {
102 let rows_json: Vec<JsonValue> = rows_to_process
104 .iter()
105 .map(row_to_json)
106 .collect::<Result<Vec<_>, CamelError>>()?;
107
108 let row_count = rows_json.len();
109
110 let mut msg = Message::new(Body::Json(JsonValue::Array(rows_json)));
112 msg.set_header(headers::ROW_COUNT, JsonValue::Number(row_count.into()));
113
114 let exchange = Exchange::new(msg);
115
116 let result = context.send_and_wait(exchange).await;
118 if let Err(e) = self
119 .handle_post_processing(pool, &result, &JsonValue::Null)
120 .await
121 {
122 error!(error = %e, "Post-processing failed for batch");
123 if self.config.break_batch_on_consume_fail {
124 return Err(e);
125 }
126 }
127 if let Err(ref consume_err) = result
129 && self.config.break_batch_on_consume_fail
130 {
131 return Err(consume_err.clone());
132 }
133 }
134
135 if let Some(ref batch_query) = self.config.on_consume_batch_complete
137 && let Err(e) = self
138 .execute_post_query(pool, batch_query, &JsonValue::Null)
139 .await
140 {
141 error!(error = %e, "onConsumeBatchComplete query failed");
142 }
143
144 Ok(())
145 }
146
147 async fn handle_post_processing(
149 &self,
150 pool: &AnyPool,
151 result: &Result<Exchange, CamelError>,
152 row_json: &JsonValue,
153 ) -> Result<(), CamelError> {
154 match result {
155 Ok(_) => {
156 if let Some(ref on_consume) = self.config.on_consume {
158 self.execute_post_query(pool, on_consume, row_json).await?;
159 }
160 }
161 Err(_) => {
162 if let Some(ref on_consume_failed) = self.config.on_consume_failed {
164 self.execute_post_query(pool, on_consume_failed, row_json)
165 .await?;
166 }
167 }
168 }
169 Ok(())
170 }
171
172 async fn execute_post_query(
174 &self,
175 pool: &AnyPool,
176 query_str: &str,
177 row_json: &JsonValue,
178 ) -> Result<(), CamelError> {
179 let template = parse_query_template(query_str, self.config.placeholder)?;
181
182 let mut temp_msg = Message::new(Body::Json(row_json.clone()));
185 if let Some(obj) = row_json.as_object() {
186 for (key, value) in obj {
187 temp_msg.set_header(format!("CamelSql.{}", key), value.clone());
188 }
189 }
190 let temp_exchange = Exchange::new(temp_msg);
191
192 let prepared = resolve_params(&template, &temp_exchange, &self.config.in_separator)?;
194
195 let query = bind_json_values(sqlx::query(&prepared.sql), &prepared.bindings);
197 let result = query.execute(pool).await.map_err(|e| {
198 CamelError::ProcessorError(format!("Post-query execution failed: {}", e))
199 })?;
200
201 if result.rows_affected() == 0 {
203 warn!(
204 query = query_str,
205 "Post-processing query affected 0 rows — the row may not have been marked correctly"
206 );
207 }
208
209 Ok(())
210 }
211}
212
213#[async_trait]
214impl Consumer for SqlConsumer {
215 async fn start(&mut self, context: ConsumerContext) -> Result<(), CamelError> {
216 let pool = self
218 .pool
219 .get_or_try_init(|| async {
220 self.config.resolve_defaults();
222
223 sqlx::any::install_default_drivers();
226 let db_url = enrich_db_url_with_ssl(&self.config.db_url, &self.config)?;
227 AnyPoolOptions::new()
228 .max_connections(
229 self.config
230 .max_connections
231 .expect("must be Some after resolve_defaults()"),
232 )
233 .min_connections(
234 self.config
235 .min_connections
236 .expect("must be Some after resolve_defaults()"),
237 )
238 .idle_timeout(Duration::from_secs(
239 self.config
240 .idle_timeout_secs
241 .expect("must be Some after resolve_defaults()"),
242 ))
243 .max_lifetime(Duration::from_secs(
244 self.config
245 .max_lifetime_secs
246 .expect("must be Some after resolve_defaults()"),
247 ))
248 .connect(&db_url)
249 .await
250 .map_err(|e| {
251 CamelError::EndpointCreationFailed(format!(
252 "Failed to connect to database: {}",
253 e
254 ))
255 })
256 })
257 .await?;
258
259 if self.config.on_consume.is_none() {
261 warn!(
262 "SQL consumer started without onConsume configured — consumed rows will not be marked/deleted"
263 );
264 }
265
266 let template = parse_query_template(&self.config.query, self.config.placeholder)
268 .map_err(|e| CamelError::Config(format!("Invalid query template: {}", e)))?;
269
270 if self.config.initial_delay_ms > 0 {
272 tokio::select! {
273 _ = context.cancelled() => {
274 info!("SQL consumer stopped during initial delay");
275 return Ok(());
276 }
277 _ = tokio::time::sleep(Duration::from_millis(self.config.initial_delay_ms)) => {}
278 }
279 }
280
281 loop {
283 tokio::select! {
284 _ = context.cancelled() => {
285 info!("SQL consumer stopped");
286 break;
287 }
288 _ = tokio::time::sleep(Duration::from_millis(self.config.delay_ms)) => {
289 if let Err(e) = self.poll_database(pool, &context, &template).await {
290 error!(error = %e, "SQL consumer poll failed");
291 }
292 }
293 }
294 }
295
296 Ok(())
297 }
298
299 async fn stop(&mut self) -> Result<(), CamelError> {
300 Ok(())
301 }
302
303 fn concurrency_model(&self) -> ConcurrencyModel {
304 ConcurrencyModel::Sequential
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SqlEndpointConfig;
    use camel_component_api::ExchangeEnvelope;
    use camel_component_api::UriConfig;
    use sqlx::any::AnyPoolOptions;
    use std::sync::Arc;
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    /// In-memory SQLite pool capped at one connection (presumably so every query sees the
    /// same `sqlite::memory:` database).
    async fn sqlite_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        AnyPoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("sqlite pool")
    }

    /// Creates the `jobs` table holding rows 1 and 2, both unprocessed and not failed.
    async fn seed_consumer_table(pool: &AnyPool) {
        sqlx::query("CREATE TABLE jobs (id INTEGER PRIMARY KEY, processed INTEGER DEFAULT 0, failed INTEGER DEFAULT 0)")
            .execute(pool)
            .await
            .expect("create table");
        sqlx::query("INSERT INTO jobs (id, processed, failed) VALUES (1, 0, 0), (2, 0, 0)")
            .execute(pool)
            .await
            .expect("seed rows");
    }

    /// Parses an endpoint URI and resolves its defaults.
    fn resolved_config(uri: &str) -> SqlEndpointConfig {
        let mut c = SqlEndpointConfig::from_uri(uri).unwrap();
        c.resolve_defaults();
        c
    }

    fn config() -> SqlEndpointConfig {
        resolved_config("sql:select * from t?db_url=postgres://localhost/test")
    }

    /// Builds a consumer context backed by a fake downstream route. The route echoes every
    /// exchange back when `succeed` is true; otherwise it replies with a "downstream boom"
    /// error.
    fn downstream_context(succeed: bool) -> ConsumerContext {
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let reply = if succeed {
                        Ok(env.exchange)
                    } else {
                        Err(CamelError::ProcessorError("downstream boom".into()))
                    };
                    let _ = reply_tx.send(reply);
                }
            }
        });
        ConsumerContext::new(tx, CancellationToken::new())
    }

    /// Reads integer column `column` of the `jobs` row whose primary key is `id`.
    async fn job_flag(pool: &AnyPool, column: &str, id: i64) -> i64 {
        let sql = format!("select {column} from jobs where id = {id}");
        let row = sqlx::query(&sql).fetch_one(pool).await.expect("job row");
        sqlx::Row::try_get(&row, 0).expect(column)
    }

    #[test]
    fn consumer_concurrency_model() {
        let c = SqlConsumer::new(config(), Arc::new(OnceCell::new()));
        assert_eq!(c.concurrency_model(), ConcurrencyModel::Sequential);
    }

    #[test]
    fn consumer_stores_config() {
        let config = resolved_config(
            "sql:select * from t?db_url=postgres://localhost/test&delay=2000&onConsume=update t set done=true",
        );
        let c = SqlConsumer::new(config, Arc::new(OnceCell::new()));
        assert_eq!(c.config.delay_ms, 2000);
        assert!(c.config.on_consume.is_some());
    }

    #[tokio::test]
    async fn poll_database_runs_on_consume_for_successful_rows() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let config = resolved_config(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsume=update jobs set processed=1 where id=:#id&initialDelay=0&delay=1",
        );
        let template = parse_query_template(&config.query, config.placeholder).unwrap();
        let consumer = SqlConsumer::new(config, Arc::new(OnceCell::new()));
        let ctx = downstream_context(true);

        consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect("poll must succeed");

        assert_eq!(job_flag(&pool, "processed", 1).await, 1);
        assert_eq!(job_flag(&pool, "processed", 2).await, 1);
    }

    #[tokio::test]
    async fn poll_database_runs_on_consume_failed_when_downstream_fails() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let config = resolved_config(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsumeFailed=update jobs set failed=1 where id=:#id&initialDelay=0&delay=1",
        );
        let template = parse_query_template(&config.query, config.placeholder).unwrap();
        let consumer = SqlConsumer::new(config, Arc::new(OnceCell::new()));
        let ctx = downstream_context(false);

        consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect("consumer should swallow downstream errors when breakBatchOnConsumeFail=false");

        assert_eq!(job_flag(&pool, "failed", 1).await, 1);
        assert_eq!(job_flag(&pool, "failed", 2).await, 1);
    }

    #[tokio::test]
    async fn poll_database_breaks_batch_on_consume_fail() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let config = resolved_config(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsumeFailed=update jobs set failed=1 where id=:#id&breakBatchOnConsumeFail=true&initialDelay=0&delay=1",
        );
        let template = parse_query_template(&config.query, config.placeholder).unwrap();
        let consumer = SqlConsumer::new(config, Arc::new(OnceCell::new()));
        let ctx = downstream_context(false);

        let err = consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect_err("must stop on first downstream failure");
        assert!(err.to_string().contains("downstream boom"));

        assert_eq!(job_flag(&pool, "failed", 1).await, 1);
        assert_eq!(
            job_flag(&pool, "failed", 2).await,
            0,
            "second row must not be processed"
        );
    }
}