1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use serde_json::Value as JsonValue;
6use sqlx::AnyPool;
7use sqlx::any::AnyPoolOptions;
8use sqlx::any::AnyRow;
9use tokio::sync::OnceCell;
10use tracing::{error, info, warn};
11
12use camel_component_api::{Body, CamelError, Exchange, Message};
13use camel_component_api::{ConcurrencyModel, Consumer, ConsumerContext};
14
15use crate::config::SqlEndpointConfig;
16use crate::headers;
17use crate::query::{QueryTemplate, parse_query_template, resolve_params};
18use crate::utils::{bind_json_values, row_to_json};
19
/// Polling SQL consumer: holds the endpoint configuration and the cell that
/// carries the database connection pool.
pub struct SqlConsumer {
    /// Endpoint settings read by the consumer (query, delays, post-processing
    /// statements and pool sizing).
    pub(crate) config: SqlEndpointConfig,
    /// Connection pool cell. `start` fills it on first use if it is still
    /// empty; it sits behind an `Arc`, so every holder of the same cell sees
    /// the same pool.
    pub(crate) pool: Arc<OnceCell<AnyPool>>,
}
24
25impl SqlConsumer {
26 pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
27 Self { config, pool }
28 }
29
30 async fn poll_database(
32 &self,
33 pool: &AnyPool,
34 context: &ConsumerContext,
35 template: &QueryTemplate,
36 ) -> Result<(), CamelError> {
37 let empty_exchange = Exchange::new(Message::default());
39
40 let prepared = resolve_params(template, &empty_exchange)?;
42
43 let query = bind_json_values(sqlx::query(&prepared.sql), &prepared.bindings);
45 let rows: Vec<AnyRow> = query
46 .fetch_all(pool)
47 .await
48 .map_err(|e| CamelError::ProcessorError(format!("Query execution failed: {}", e)))?;
49
50 if rows.is_empty() && !self.config.route_empty_result_set {
52 return Ok(());
53 }
54
55 let rows_to_process: Vec<AnyRow> = if let Some(max) = self.config.max_messages_per_poll {
57 if max > 0 {
58 rows.into_iter().take(max as usize).collect()
59 } else {
60 rows
61 }
62 } else {
63 rows
64 };
65
66 if self.config.use_iterator {
67 for row in rows_to_process {
69 let row_json = row_to_json(&row)?;
70
71 let mut msg = Message::new(Body::Json(row_json.clone()));
73
74 if let Some(obj) = row_json.as_object() {
76 for (key, value) in obj {
77 msg.set_header(format!("CamelSql.{}", key), value.clone());
78 }
79 }
80
81 let exchange = Exchange::new(msg);
82
83 let result = context.send_and_wait(exchange).await;
85
86 if let Err(e) = self.handle_post_processing(pool, &result, &row_json).await {
88 error!(error = %e, "Post-processing failed");
89 if self.config.break_batch_on_consume_fail {
90 return Err(e);
91 }
92 }
93
94 if let Err(ref consume_err) = result
96 && self.config.break_batch_on_consume_fail
97 {
98 return Err(consume_err.clone());
99 }
100 }
101 } else {
102 let rows_json: Vec<JsonValue> = rows_to_process
104 .iter()
105 .map(row_to_json)
106 .collect::<Result<Vec<_>, CamelError>>()?;
107
108 let row_count = rows_json.len();
109
110 let mut msg = Message::new(Body::Json(JsonValue::Array(rows_json)));
112 msg.set_header(headers::ROW_COUNT, JsonValue::Number(row_count.into()));
113
114 let exchange = Exchange::new(msg);
115
116 let result = context.send_and_wait(exchange).await;
118 if let Err(e) = self
119 .handle_post_processing(pool, &result, &JsonValue::Null)
120 .await
121 {
122 error!(error = %e, "Post-processing failed for batch");
123 if self.config.break_batch_on_consume_fail {
124 return Err(e);
125 }
126 }
127 if let Err(ref consume_err) = result
129 && self.config.break_batch_on_consume_fail
130 {
131 return Err(consume_err.clone());
132 }
133 }
134
135 if let Some(ref batch_query) = self.config.on_consume_batch_complete
137 && let Err(e) = self
138 .execute_post_query(pool, batch_query, &JsonValue::Null)
139 .await
140 {
141 error!(error = %e, "onConsumeBatchComplete query failed");
142 }
143
144 Ok(())
145 }
146
147 async fn handle_post_processing(
149 &self,
150 pool: &AnyPool,
151 result: &Result<Exchange, CamelError>,
152 row_json: &JsonValue,
153 ) -> Result<(), CamelError> {
154 match result {
155 Ok(_) => {
156 if let Some(ref on_consume) = self.config.on_consume {
158 self.execute_post_query(pool, on_consume, row_json).await?;
159 }
160 }
161 Err(_) => {
162 if let Some(ref on_consume_failed) = self.config.on_consume_failed {
164 self.execute_post_query(pool, on_consume_failed, row_json)
165 .await?;
166 }
167 }
168 }
169 Ok(())
170 }
171
172 async fn execute_post_query(
174 &self,
175 pool: &AnyPool,
176 query_str: &str,
177 row_json: &JsonValue,
178 ) -> Result<(), CamelError> {
179 let template = parse_query_template(query_str, self.config.placeholder)?;
181
182 let mut temp_msg = Message::new(Body::Json(row_json.clone()));
185 if let Some(obj) = row_json.as_object() {
186 for (key, value) in obj {
187 temp_msg.set_header(format!("CamelSql.{}", key), value.clone());
188 }
189 }
190 let temp_exchange = Exchange::new(temp_msg);
191
192 let prepared = resolve_params(&template, &temp_exchange)?;
194
195 let query = bind_json_values(sqlx::query(&prepared.sql), &prepared.bindings);
197 let result = query.execute(pool).await.map_err(|e| {
198 CamelError::ProcessorError(format!("Post-query execution failed: {}", e))
199 })?;
200
201 if result.rows_affected() == 0 {
203 warn!(
204 query = query_str,
205 "Post-processing query affected 0 rows — the row may not have been marked correctly"
206 );
207 }
208
209 Ok(())
210 }
211}
212
213#[async_trait]
214impl Consumer for SqlConsumer {
215 async fn start(&mut self, context: ConsumerContext) -> Result<(), CamelError> {
216 let pool = self
218 .pool
219 .get_or_try_init(|| async {
220 self.config.resolve_defaults();
222
223 sqlx::any::install_default_drivers();
226 AnyPoolOptions::new()
227 .max_connections(
228 self.config
229 .max_connections
230 .expect("must be Some after resolve_defaults()"),
231 )
232 .min_connections(
233 self.config
234 .min_connections
235 .expect("must be Some after resolve_defaults()"),
236 )
237 .idle_timeout(Duration::from_secs(
238 self.config
239 .idle_timeout_secs
240 .expect("must be Some after resolve_defaults()"),
241 ))
242 .max_lifetime(Duration::from_secs(
243 self.config
244 .max_lifetime_secs
245 .expect("must be Some after resolve_defaults()"),
246 ))
247 .connect(&self.config.db_url)
248 .await
249 .map_err(|e| {
250 CamelError::EndpointCreationFailed(format!(
251 "Failed to connect to database: {}",
252 e
253 ))
254 })
255 })
256 .await?;
257
258 if self.config.on_consume.is_none() {
260 warn!(
261 "SQL consumer started without onConsume configured — consumed rows will not be marked/deleted"
262 );
263 }
264
265 let template = parse_query_template(&self.config.query, self.config.placeholder)
267 .map_err(|e| CamelError::Config(format!("Invalid query template: {}", e)))?;
268
269 if self.config.initial_delay_ms > 0 {
271 tokio::select! {
272 _ = context.cancelled() => {
273 info!("SQL consumer stopped during initial delay");
274 return Ok(());
275 }
276 _ = tokio::time::sleep(Duration::from_millis(self.config.initial_delay_ms)) => {}
277 }
278 }
279
280 loop {
282 tokio::select! {
283 _ = context.cancelled() => {
284 info!("SQL consumer stopped");
285 break;
286 }
287 _ = tokio::time::sleep(Duration::from_millis(self.config.delay_ms)) => {
288 if let Err(e) = self.poll_database(pool, &context, &template).await {
289 error!(error = %e, "SQL consumer poll failed");
290 }
291 }
292 }
293 }
294
295 Ok(())
296 }
297
298 async fn stop(&mut self) -> Result<(), CamelError> {
299 Ok(())
300 }
301
302 fn concurrency_model(&self) -> ConcurrencyModel {
303 ConcurrencyModel::Sequential
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SqlEndpointConfig;
    use camel_component_api::ExchangeEnvelope;
    use camel_component_api::UriConfig;
    use sqlx::any::AnyPoolOptions;
    use std::sync::Arc;
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    /// Opens an in-memory SQLite pool limited to a single connection.
    async fn sqlite_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        let options = AnyPoolOptions::new().max_connections(1);
        options
            .connect("sqlite::memory:")
            .await
            .expect("sqlite pool")
    }

    /// Creates the `jobs` table and inserts two rows (ids 1 and 2) with both
    /// flags cleared.
    async fn seed_consumer_table(pool: &AnyPool) {
        let steps = [
            (
                "CREATE TABLE jobs (id INTEGER PRIMARY KEY, processed INTEGER DEFAULT 0, failed INTEGER DEFAULT 0)",
                "create table",
            ),
            (
                "INSERT INTO jobs (id, processed, failed) VALUES (1, 0, 0), (2, 0, 0)",
                "seed rows",
            ),
        ];
        for (sql, step) in steps {
            sqlx::query(sql).execute(pool).await.expect(step);
        }
    }

    /// Parses `uri` into an endpoint configuration and resolves its defaults.
    fn resolved_config(uri: &str) -> SqlEndpointConfig {
        let mut parsed = SqlEndpointConfig::from_uri(uri).unwrap();
        parsed.resolve_defaults();
        parsed
    }

    /// Minimal configuration for tests that never touch a database.
    fn config() -> SqlEndpointConfig {
        resolved_config("sql:select * from t?db_url=postgres://localhost/test")
    }

    /// Builds a consumer with an empty pool cell plus the parsed polling
    /// template for `uri`.
    fn consumer_with_template(uri: &str) -> (SqlConsumer, QueryTemplate) {
        let config = resolved_config(uri);
        let template = parse_query_template(&config.query, config.placeholder).unwrap();
        let consumer = SqlConsumer::new(config, Arc::new(OnceCell::new()));
        (consumer, template)
    }

    /// Spawns a stand-in for the downstream route that answers every envelope
    /// with `reply(exchange)`, and returns a context wired to it.
    fn context_replying_with<F>(reply: F) -> ConsumerContext
    where
        F: Fn(Exchange) -> Result<Exchange, CamelError> + Send + 'static,
    {
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let _ = reply_tx.send(reply(env.exchange));
                }
            }
        });
        ConsumerContext::new(tx, CancellationToken::new())
    }

    /// Reads the integer flag `column` of the job with the given `id`.
    async fn job_flag(pool: &AnyPool, column: &str, id: i64) -> i64 {
        let sql = format!("select {} from jobs where id = {}", column, id);
        let row = sqlx::query(&sql)
            .fetch_one(pool)
            .await
            .expect(&format!("row {}", id));
        sqlx::Row::try_get(&row, 0).expect(column)
    }

    #[test]
    fn consumer_concurrency_model() {
        let c = SqlConsumer::new(config(), Arc::new(OnceCell::new()));
        assert_eq!(c.concurrency_model(), ConcurrencyModel::Sequential);
    }

    #[test]
    fn consumer_stores_config() {
        let config = resolved_config(
            "sql:select * from t?db_url=postgres://localhost/test&delay=2000&onConsume=update t set done=true",
        );
        let c = SqlConsumer::new(config, Arc::new(OnceCell::new()));
        assert_eq!(c.config.delay_ms, 2000);
        assert!(c.config.on_consume.is_some());
    }

    #[tokio::test]
    async fn poll_database_runs_on_consume_for_successful_rows() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let (consumer, template) = consumer_with_template(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsume=update jobs set processed=1 where id=:#id&initialDelay=0&delay=1",
        );

        // Downstream echoes every exchange back as a success.
        let ctx = context_replying_with(Ok);

        consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect("poll must succeed");

        let processed_1 = job_flag(&pool, "processed", 1).await;
        let processed_2 = job_flag(&pool, "processed", 2).await;

        assert_eq!(processed_1, 1);
        assert_eq!(processed_2, 1);
    }

    #[tokio::test]
    async fn poll_database_runs_on_consume_failed_when_downstream_fails() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let (consumer, template) = consumer_with_template(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsumeFailed=update jobs set failed=1 where id=:#id&initialDelay=0&delay=1",
        );

        // Downstream rejects every exchange.
        let ctx = context_replying_with(|_| {
            Err(CamelError::ProcessorError("downstream boom".into()))
        });

        consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect("consumer should swallow downstream errors when breakBatchOnConsumeFail=false");

        let failed_1 = job_flag(&pool, "failed", 1).await;
        let failed_2 = job_flag(&pool, "failed", 2).await;

        assert_eq!(failed_1, 1);
        assert_eq!(failed_2, 1);
    }

    #[tokio::test]
    async fn poll_database_breaks_batch_on_consume_fail() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let (consumer, template) = consumer_with_template(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsumeFailed=update jobs set failed=1 where id=:#id&breakBatchOnConsumeFail=true&initialDelay=0&delay=1",
        );

        // Downstream rejects every exchange.
        let ctx = context_replying_with(|_| {
            Err(CamelError::ProcessorError("downstream boom".into()))
        });

        let err = consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect_err("must stop on first downstream failure");
        assert!(err.to_string().contains("downstream boom"));

        let failed_1 = job_flag(&pool, "failed", 1).await;
        let failed_2 = job_flag(&pool, "failed", 2).await;

        assert_eq!(failed_1, 1);
        assert_eq!(failed_2, 0, "second row must not be processed");
    }
}