// camel_component_sql/consumer.rs
1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use serde_json::Value as JsonValue;
6use sqlx::AnyPool;
7use sqlx::any::AnyPoolOptions;
8use sqlx::any::AnyRow;
9use tokio::sync::OnceCell;
10use tracing::{error, info, warn};
11
12use camel_component_api::{Body, CamelError, Exchange, Message};
13use camel_component_api::{ConcurrencyModel, Consumer, ConsumerContext};
14
15use crate::config::{SqlEndpointConfig, enrich_db_url_with_ssl};
16use crate::headers;
17use crate::query::{QueryTemplate, parse_query_template, resolve_params};
18use crate::utils::{bind_json_values, row_to_json};
19
/// Polling SQL consumer: periodically executes the configured query and routes
/// the resulting rows downstream as exchanges.
pub struct SqlConsumer {
    // Endpoint configuration parsed from the endpoint URI.
    pub(crate) config: SqlEndpointConfig,
    // Lazily-initialized connection pool, shared so repeated starts reuse it.
    pub(crate) pool: Arc<OnceCell<AnyPool>>,
}
24
25impl SqlConsumer {
26    pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
27        Self { config, pool }
28    }
29
30    /// Poll the database for new rows and process them.
31    async fn poll_database(
32        &self,
33        pool: &AnyPool,
34        context: &ConsumerContext,
35        template: &QueryTemplate,
36    ) -> Result<(), CamelError> {
37        // Create an empty exchange for parameter resolution (consumer has no input)
38        let empty_exchange = Exchange::new(Message::default());
39
40        // Resolve parameters
41        let prepared = resolve_params(template, &empty_exchange, &self.config.in_separator)?;
42
43        // Build and execute the query
44        let query = bind_json_values(sqlx::query(&prepared.sql), &prepared.bindings);
45        let rows: Vec<AnyRow> = query
46            .fetch_all(pool)
47            .await
48            .map_err(|e| CamelError::ProcessorError(format!("Query execution failed: {}", e)))?;
49
50        // Check for empty result set
51        if rows.is_empty() && !self.config.route_empty_result_set {
52            return Ok(());
53        }
54
55        // Apply max_messages_per_poll limit
56        let rows_to_process: Vec<AnyRow> = if let Some(max) = self.config.max_messages_per_poll {
57            if max > 0 {
58                rows.into_iter().take(max as usize).collect()
59            } else {
60                rows
61            }
62        } else {
63            rows
64        };
65
66        if self.config.use_iterator {
67            // Process each row individually
68            for row in rows_to_process {
69                let row_json = row_to_json(&row)?;
70
71                // Create exchange with the row as JSON body
72                let mut msg = Message::new(Body::Json(row_json.clone()));
73
74                // Set individual column headers with CamelSql. prefix per Apache Camel convention
75                if let Some(obj) = row_json.as_object() {
76                    for (key, value) in obj {
77                        msg.set_header(format!("CamelSql.{}", key), value.clone());
78                    }
79                }
80
81                let exchange = Exchange::new(msg);
82
83                // Send and wait for processing
84                let result = context.send_and_wait(exchange).await;
85
86                // Handle post-processing (onConsume/onConsumeFailed)
87                if let Err(e) = self.handle_post_processing(pool, &result, &row_json).await {
88                    error!(error = %e, "Post-processing failed");
89                    if self.config.break_batch_on_consume_fail {
90                        return Err(e);
91                    }
92                }
93
94                // If downstream processing itself failed, honour break_batch_on_consume_fail
95                if let Err(ref consume_err) = result
96                    && self.config.break_batch_on_consume_fail
97                {
98                    return Err(consume_err.clone());
99                }
100            }
101        } else {
102            // Process all rows as a single batch
103            let rows_json: Vec<JsonValue> = rows_to_process
104                .iter()
105                .map(row_to_json)
106                .collect::<Result<Vec<_>, CamelError>>()?;
107
108            let row_count = rows_json.len();
109
110            // Create exchange with array of rows
111            let mut msg = Message::new(Body::Json(JsonValue::Array(rows_json)));
112            msg.set_header(headers::ROW_COUNT, JsonValue::Number(row_count.into()));
113
114            let exchange = Exchange::new(msg);
115
116            // Send and wait for result, then run post-processing with Null row
117            let result = context.send_and_wait(exchange).await;
118            if let Err(e) = self
119                .handle_post_processing(pool, &result, &JsonValue::Null)
120                .await
121            {
122                error!(error = %e, "Post-processing failed for batch");
123                if self.config.break_batch_on_consume_fail {
124                    return Err(e);
125                }
126            }
127            // If downstream processing itself failed, honour break_batch_on_consume_fail
128            if let Err(ref consume_err) = result
129                && self.config.break_batch_on_consume_fail
130            {
131                return Err(consume_err.clone());
132            }
133        }
134
135        // Execute on_consume_batch_complete if configured
136        if let Some(ref batch_query) = self.config.on_consume_batch_complete
137            && let Err(e) = self
138                .execute_post_query(pool, batch_query, &JsonValue::Null)
139                .await
140        {
141            error!(error = %e, "onConsumeBatchComplete query failed");
142        }
143
144        Ok(())
145    }
146
147    /// Handle post-processing after a row is processed (onConsume/onConsumeFailed).
148    async fn handle_post_processing(
149        &self,
150        pool: &AnyPool,
151        result: &Result<Exchange, CamelError>,
152        row_json: &JsonValue,
153    ) -> Result<(), CamelError> {
154        match result {
155            Ok(_) => {
156                // Success - execute onConsume if configured
157                if let Some(ref on_consume) = self.config.on_consume {
158                    self.execute_post_query(pool, on_consume, row_json).await?;
159                }
160            }
161            Err(_) => {
162                // Failure - execute onConsumeFailed if configured
163                if let Some(ref on_consume_failed) = self.config.on_consume_failed {
164                    self.execute_post_query(pool, on_consume_failed, row_json)
165                        .await?;
166                }
167            }
168        }
169        Ok(())
170    }
171
172    /// Execute a post-processing query with the row data as parameters.
173    async fn execute_post_query(
174        &self,
175        pool: &AnyPool,
176        query_str: &str,
177        row_json: &JsonValue,
178    ) -> Result<(), CamelError> {
179        // Parse the query template
180        let template = parse_query_template(query_str, self.config.placeholder)?;
181
182        // Create a temporary exchange with the row as body for parameter resolution
183        // Populate CamelSql.* headers so named params can reference them
184        let mut temp_msg = Message::new(Body::Json(row_json.clone()));
185        if let Some(obj) = row_json.as_object() {
186            for (key, value) in obj {
187                temp_msg.set_header(format!("CamelSql.{}", key), value.clone());
188            }
189        }
190        let temp_exchange = Exchange::new(temp_msg);
191
192        // Resolve parameters
193        let prepared = resolve_params(&template, &temp_exchange, &self.config.in_separator)?;
194
195        // Build and execute the query
196        let query = bind_json_values(sqlx::query(&prepared.sql), &prepared.bindings);
197        let result = query.execute(pool).await.map_err(|e| {
198            CamelError::ProcessorError(format!("Post-query execution failed: {}", e))
199        })?;
200
201        // Warn if 0 rows affected (the row may not have been marked correctly)
202        if result.rows_affected() == 0 {
203            warn!(
204                query = query_str,
205                "Post-processing query affected 0 rows — the row may not have been marked correctly"
206            );
207        }
208
209        Ok(())
210    }
211}
212
#[async_trait]
impl Consumer for SqlConsumer {
    /// Start the consumer: initialise the pool, parse the query template once,
    /// honour the configured initial delay, then poll until cancelled.
    ///
    /// Returns an error only if pool creation or template parsing fails;
    /// per-poll errors are logged and the loop continues.
    async fn start(&mut self, context: ConsumerContext) -> Result<(), CamelError> {
        // Step 1: Initialize the connection pool
        let pool = self
            .pool
            .get_or_try_init(|| async {
                // Defensive: ensure config is resolved even if caller didn't use create_endpoint
                self.config.resolve_defaults();

                // Install all compiled-in sqlx drivers so AnyPool can resolve them.
                // This is idempotent; safe to call multiple times.
                sqlx::any::install_default_drivers();
                let db_url = enrich_db_url_with_ssl(&self.config.db_url, &self.config)?;
                AnyPoolOptions::new()
                    .max_connections(
                        self.config
                            .max_connections
                            .expect("must be Some after resolve_defaults()"),
                    )
                    .min_connections(
                        self.config
                            .min_connections
                            .expect("must be Some after resolve_defaults()"),
                    )
                    .idle_timeout(Duration::from_secs(
                        self.config
                            .idle_timeout_secs
                            .expect("must be Some after resolve_defaults()"),
                    ))
                    .max_lifetime(Duration::from_secs(
                        self.config
                            .max_lifetime_secs
                            .expect("must be Some after resolve_defaults()"),
                    ))
                    .connect(&db_url)
                    .await
                    .map_err(|e| {
                        CamelError::EndpointCreationFailed(format!(
                            "Failed to connect to database: {}",
                            e
                        ))
                    })
            })
            .await?;

        // Warn if no onConsume configured: without it, the same rows are
        // re-fetched on every poll.
        if self.config.on_consume.is_none() {
            warn!(
                "SQL consumer started without onConsume configured — consumed rows will not be marked/deleted"
            );
        }

        // Step 2: Parse query template once (avoid re-parsing every poll)
        let template = parse_query_template(&self.config.query, self.config.placeholder)
            .map_err(|e| CamelError::Config(format!("Invalid query template: {}", e)))?;

        // Step 3: Initial delay before starting polling; abort cleanly if
        // cancellation arrives while we are still waiting.
        if self.config.initial_delay_ms > 0 {
            tokio::select! {
                _ = context.cancelled() => {
                    info!("SQL consumer stopped during initial delay");
                    return Ok(());
                }
                _ = tokio::time::sleep(Duration::from_millis(self.config.initial_delay_ms)) => {}
            }
        }

        // Step 4: Polling loop. Each iteration sleeps delay_ms first, then
        // polls; cancellation wins the race and breaks the loop.
        loop {
            tokio::select! {
                _ = context.cancelled() => {
                    info!("SQL consumer stopped");
                    break;
                }
                _ = tokio::time::sleep(Duration::from_millis(self.config.delay_ms)) => {
                    // Poll failures are logged but do not stop the consumer.
                    if let Err(e) = self.poll_database(pool, &context, &template).await {
                        error!(error = %e, "SQL consumer poll failed");
                    }
                }
            }
        }

        Ok(())
    }

    /// Stop is a no-op: the polling loop exits via the cancellation token, and
    /// the pool is shared (owned by the OnceCell), so nothing to tear down here.
    async fn stop(&mut self) -> Result<(), CamelError> {
        Ok(())
    }

    fn concurrency_model(&self) -> ConcurrencyModel {
        // Sequential is correct for SQL consumers: concurrent polls would fetch
        // duplicate rows. The design doc mentioned SharedState (which doesn't exist
        // in this runtime) — Sequential is the correct equivalent.
        ConcurrencyModel::Sequential
    }
}
310
#[cfg(test)]
mod tests {
    // Tests cover construction, concurrency model, and the three onConsume /
    // onConsumeFailed / breakBatchOnConsumeFail behaviours of poll_database,
    // using an in-memory SQLite database and a spawned task that answers
    // send_and_wait replies.
    use super::*;
    use crate::config::SqlEndpointConfig;
    use camel_component_api::ExchangeEnvelope;
    use camel_component_api::UriConfig;
    use sqlx::any::AnyPoolOptions;
    use std::sync::Arc;
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    // Build a single-connection in-memory SQLite pool for tests.
    async fn sqlite_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        AnyPoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("sqlite pool")
    }

    // Create the `jobs` table and seed two unprocessed rows (ids 1 and 2).
    async fn seed_consumer_table(pool: &AnyPool) {
        sqlx::query("CREATE TABLE jobs (id INTEGER PRIMARY KEY, processed INTEGER DEFAULT 0, failed INTEGER DEFAULT 0)")
            .execute(pool)
            .await
            .expect("create table");
        sqlx::query("INSERT INTO jobs (id, processed, failed) VALUES (1, 0, 0), (2, 0, 0)")
            .execute(pool)
            .await
            .expect("seed rows");
    }

    // Minimal resolved config for tests that never touch the database.
    fn config() -> SqlEndpointConfig {
        let mut c =
            SqlEndpointConfig::from_uri("sql:select * from t?db_url=postgres://localhost/test")
                .unwrap();
        c.resolve_defaults();
        c
    }

    #[test]
    fn consumer_concurrency_model() {
        let c = SqlConsumer::new(config(), Arc::new(OnceCell::new()));
        assert_eq!(c.concurrency_model(), ConcurrencyModel::Sequential);
    }

    #[test]
    fn consumer_stores_config() {
        let mut config = SqlEndpointConfig::from_uri(
            "sql:select * from t?db_url=postgres://localhost/test&delay=2000&onConsume=update t set done=true"
        ).unwrap();
        config.resolve_defaults();
        let c = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        assert_eq!(c.config.delay_ms, 2000);
        assert!(c.config.on_consume.is_some());
    }

    #[tokio::test]
    async fn poll_database_runs_on_consume_for_successful_rows() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsume=update jobs set processed=1 where id=:#id&initialDelay=0&delay=1",
        )
        .unwrap();
        config.resolve_defaults();

        let consumer = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        let template = parse_query_template(&config.query, config.placeholder).unwrap();

        // Downstream stub: acknowledge every exchange successfully so
        // onConsume fires for each row.
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let _ = reply_tx.send(Ok(env.exchange));
                }
            }
        });
        let ctx = ConsumerContext::new(tx, CancellationToken::new());

        consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect("poll must succeed");

        // Both rows must have been marked processed by the onConsume query.
        let row = sqlx::query("select processed from jobs where id = 1")
            .fetch_one(&pool)
            .await
            .expect("row 1");
        let processed_1: i64 = sqlx::Row::try_get(&row, 0).expect("processed");

        let row = sqlx::query("select processed from jobs where id = 2")
            .fetch_one(&pool)
            .await
            .expect("row 2");
        let processed_2: i64 = sqlx::Row::try_get(&row, 0).expect("processed");

        assert_eq!(processed_1, 1);
        assert_eq!(processed_2, 1);
    }

    #[tokio::test]
    async fn poll_database_runs_on_consume_failed_when_downstream_fails() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsumeFailed=update jobs set failed=1 where id=:#id&initialDelay=0&delay=1",
        )
        .unwrap();
        config.resolve_defaults();

        let consumer = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        let template = parse_query_template(&config.query, config.placeholder).unwrap();

        // Downstream stub: fail every exchange so onConsumeFailed fires.
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let _ =
                        reply_tx.send(Err(CamelError::ProcessorError("downstream boom".into())));
                }
            }
        });
        let ctx = ConsumerContext::new(tx, CancellationToken::new());

        consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect("consumer should swallow downstream errors when breakBatchOnConsumeFail=false");

        // Both rows must be marked failed: the batch continues past failures.
        let row = sqlx::query("select failed from jobs where id = 1")
            .fetch_one(&pool)
            .await
            .expect("row 1");
        let failed_1: i64 = sqlx::Row::try_get(&row, 0).expect("failed");

        let row = sqlx::query("select failed from jobs where id = 2")
            .fetch_one(&pool)
            .await
            .expect("row 2");
        let failed_2: i64 = sqlx::Row::try_get(&row, 0).expect("failed");

        assert_eq!(failed_1, 1);
        assert_eq!(failed_2, 1);
    }

    #[tokio::test]
    async fn poll_database_breaks_batch_on_consume_fail() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsumeFailed=update jobs set failed=1 where id=:#id&breakBatchOnConsumeFail=true&initialDelay=0&delay=1",
        )
        .unwrap();
        config.resolve_defaults();

        let consumer = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        let template = parse_query_template(&config.query, config.placeholder).unwrap();

        // Downstream stub: fail every exchange; with breakBatchOnConsumeFail=true
        // the first failure must abort the batch.
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let _ =
                        reply_tx.send(Err(CamelError::ProcessorError("downstream boom".into())));
                }
            }
        });
        let ctx = ConsumerContext::new(tx, CancellationToken::new());

        let err = consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect_err("must stop on first downstream failure");
        assert!(err.to_string().contains("downstream boom"));

        // Row 1 was marked failed before the abort; row 2 was never reached.
        let row = sqlx::query("select failed from jobs where id = 1")
            .fetch_one(&pool)
            .await
            .expect("row 1");
        let failed_1: i64 = sqlx::Row::try_get(&row, 0).expect("failed");

        let row = sqlx::query("select failed from jobs where id = 2")
            .fetch_one(&pool)
            .await
            .expect("row 2");
        let failed_2: i64 = sqlx::Row::try_get(&row, 0).expect("failed");

        assert_eq!(failed_1, 1);
        assert_eq!(failed_2, 0, "second row must not be processed");
    }
}
504}