// camel_component_sql/consumer.rs
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use serde_json::Value as JsonValue;
use sqlx::AnyPool;
use sqlx::any::AnyPoolOptions;
use sqlx::any::AnyRow;
use tokio::sync::OnceCell;
use tracing::{error, info, warn};

use camel_component_api::{Body, CamelError, Exchange, Message};
use camel_component_api::{ConcurrencyModel, Consumer, ConsumerContext};

use crate::config::SqlEndpointConfig;
use crate::headers;
use crate::query::{QueryTemplate, parse_query_template, resolve_params};
use crate::utils::{bind_json_values, row_to_json};
19
/// Polling SQL consumer: periodically executes the configured query against the
/// database and routes the resulting rows downstream as exchanges.
pub struct SqlConsumer {
    /// Resolved endpoint configuration (query, delays, post-processing queries, …).
    pub(crate) config: SqlEndpointConfig,
    /// Lazily-initialized shared connection pool; first initialized in `start`.
    pub(crate) pool: Arc<OnceCell<AnyPool>>,
}
24
25impl SqlConsumer {
26    pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
27        Self { config, pool }
28    }
29
30    /// Poll the database for new rows and process them.
31    async fn poll_database(
32        &self,
33        pool: &AnyPool,
34        context: &ConsumerContext,
35        template: &QueryTemplate,
36    ) -> Result<(), CamelError> {
37        // Create an empty exchange for parameter resolution (consumer has no input)
38        let empty_exchange = Exchange::new(Message::default());
39
40        // Resolve parameters
41        let prepared = resolve_params(template, &empty_exchange)?;
42
43        // Build and execute the query
44        let query = bind_json_values(sqlx::query(&prepared.sql), &prepared.bindings);
45        let rows: Vec<AnyRow> = query
46            .fetch_all(pool)
47            .await
48            .map_err(|e| CamelError::ProcessorError(format!("Query execution failed: {}", e)))?;
49
50        // Check for empty result set
51        if rows.is_empty() && !self.config.route_empty_result_set {
52            return Ok(());
53        }
54
55        // Apply max_messages_per_poll limit
56        let rows_to_process: Vec<AnyRow> = if let Some(max) = self.config.max_messages_per_poll {
57            if max > 0 {
58                rows.into_iter().take(max as usize).collect()
59            } else {
60                rows
61            }
62        } else {
63            rows
64        };
65
66        if self.config.use_iterator {
67            // Process each row individually
68            for row in rows_to_process {
69                let row_json = row_to_json(&row)?;
70
71                // Create exchange with the row as JSON body
72                let mut msg = Message::new(Body::Json(row_json.clone()));
73
74                // Set individual column headers with CamelSql. prefix per Apache Camel convention
75                if let Some(obj) = row_json.as_object() {
76                    for (key, value) in obj {
77                        msg.set_header(format!("CamelSql.{}", key), value.clone());
78                    }
79                }
80
81                let exchange = Exchange::new(msg);
82
83                // Send and wait for processing
84                let result = context.send_and_wait(exchange).await;
85
86                // Handle post-processing (onConsume/onConsumeFailed)
87                if let Err(e) = self.handle_post_processing(pool, &result, &row_json).await {
88                    error!(error = %e, "Post-processing failed");
89                    if self.config.break_batch_on_consume_fail {
90                        return Err(e);
91                    }
92                }
93
94                // If downstream processing itself failed, honour break_batch_on_consume_fail
95                if let Err(ref consume_err) = result
96                    && self.config.break_batch_on_consume_fail
97                {
98                    return Err(consume_err.clone());
99                }
100            }
101        } else {
102            // Process all rows as a single batch
103            let rows_json: Vec<JsonValue> = rows_to_process
104                .iter()
105                .map(row_to_json)
106                .collect::<Result<Vec<_>, CamelError>>()?;
107
108            let row_count = rows_json.len();
109
110            // Create exchange with array of rows
111            let mut msg = Message::new(Body::Json(JsonValue::Array(rows_json)));
112            msg.set_header(headers::ROW_COUNT, JsonValue::Number(row_count.into()));
113
114            let exchange = Exchange::new(msg);
115
116            // Send and wait for result, then run post-processing with Null row
117            let result = context.send_and_wait(exchange).await;
118            if let Err(e) = self
119                .handle_post_processing(pool, &result, &JsonValue::Null)
120                .await
121            {
122                error!(error = %e, "Post-processing failed for batch");
123                if self.config.break_batch_on_consume_fail {
124                    return Err(e);
125                }
126            }
127            // If downstream processing itself failed, honour break_batch_on_consume_fail
128            if let Err(ref consume_err) = result
129                && self.config.break_batch_on_consume_fail
130            {
131                return Err(consume_err.clone());
132            }
133        }
134
135        // Execute on_consume_batch_complete if configured
136        if let Some(ref batch_query) = self.config.on_consume_batch_complete
137            && let Err(e) = self
138                .execute_post_query(pool, batch_query, &JsonValue::Null)
139                .await
140        {
141            error!(error = %e, "onConsumeBatchComplete query failed");
142        }
143
144        Ok(())
145    }
146
147    /// Handle post-processing after a row is processed (onConsume/onConsumeFailed).
148    async fn handle_post_processing(
149        &self,
150        pool: &AnyPool,
151        result: &Result<Exchange, CamelError>,
152        row_json: &JsonValue,
153    ) -> Result<(), CamelError> {
154        match result {
155            Ok(_) => {
156                // Success - execute onConsume if configured
157                if let Some(ref on_consume) = self.config.on_consume {
158                    self.execute_post_query(pool, on_consume, row_json).await?;
159                }
160            }
161            Err(_) => {
162                // Failure - execute onConsumeFailed if configured
163                if let Some(ref on_consume_failed) = self.config.on_consume_failed {
164                    self.execute_post_query(pool, on_consume_failed, row_json)
165                        .await?;
166                }
167            }
168        }
169        Ok(())
170    }
171
172    /// Execute a post-processing query with the row data as parameters.
173    async fn execute_post_query(
174        &self,
175        pool: &AnyPool,
176        query_str: &str,
177        row_json: &JsonValue,
178    ) -> Result<(), CamelError> {
179        // Parse the query template
180        let template = parse_query_template(query_str, self.config.placeholder)?;
181
182        // Create a temporary exchange with the row as body for parameter resolution
183        // Populate CamelSql.* headers so named params can reference them
184        let mut temp_msg = Message::new(Body::Json(row_json.clone()));
185        if let Some(obj) = row_json.as_object() {
186            for (key, value) in obj {
187                temp_msg.set_header(format!("CamelSql.{}", key), value.clone());
188            }
189        }
190        let temp_exchange = Exchange::new(temp_msg);
191
192        // Resolve parameters
193        let prepared = resolve_params(&template, &temp_exchange)?;
194
195        // Build and execute the query
196        let query = bind_json_values(sqlx::query(&prepared.sql), &prepared.bindings);
197        let result = query.execute(pool).await.map_err(|e| {
198            CamelError::ProcessorError(format!("Post-query execution failed: {}", e))
199        })?;
200
201        // Warn if 0 rows affected (the row may not have been marked correctly)
202        if result.rows_affected() == 0 {
203            warn!(
204                query = query_str,
205                "Post-processing query affected 0 rows — the row may not have been marked correctly"
206            );
207        }
208
209        Ok(())
210    }
211}
212
#[async_trait]
impl Consumer for SqlConsumer {
    /// Start the consumer: lazily initialize the shared connection pool, parse
    /// the query template once, wait out the configured initial delay, then
    /// poll the database every `delay_ms` until the context is cancelled.
    ///
    /// # Errors
    /// Fails if the pool cannot be connected or the query template is invalid.
    /// Individual poll failures are logged and the loop continues.
    async fn start(&mut self, context: ConsumerContext) -> Result<(), CamelError> {
        // Step 1: Initialize the connection pool
        let pool = self
            .pool
            .get_or_try_init(|| async {
                // Defensive: ensure config is resolved even if caller didn't use create_endpoint
                self.config.resolve_defaults();

                // Install all compiled-in sqlx drivers so AnyPool can resolve them.
                // This is idempotent; safe to call multiple times.
                sqlx::any::install_default_drivers();
                // All pool-sizing options are Options populated by resolve_defaults()
                // above, so the expect()s encode an invariant, not user input.
                AnyPoolOptions::new()
                    .max_connections(
                        self.config
                            .max_connections
                            .expect("must be Some after resolve_defaults()"),
                    )
                    .min_connections(
                        self.config
                            .min_connections
                            .expect("must be Some after resolve_defaults()"),
                    )
                    .idle_timeout(Duration::from_secs(
                        self.config
                            .idle_timeout_secs
                            .expect("must be Some after resolve_defaults()"),
                    ))
                    .max_lifetime(Duration::from_secs(
                        self.config
                            .max_lifetime_secs
                            .expect("must be Some after resolve_defaults()"),
                    ))
                    .connect(&self.config.db_url)
                    .await
                    .map_err(|e| {
                        CamelError::EndpointCreationFailed(format!(
                            "Failed to connect to database: {}",
                            e
                        ))
                    })
            })
            .await?;

        // Warn if no onConsume configured
        if self.config.on_consume.is_none() {
            warn!(
                "SQL consumer started without onConsume configured — consumed rows will not be marked/deleted"
            );
        }

        // Step 2: Parse query template once (avoid re-parsing every poll)
        let template = parse_query_template(&self.config.query, self.config.placeholder)
            .map_err(|e| CamelError::Config(format!("Invalid query template: {}", e)))?;

        // Step 3: Initial delay before starting polling; remain responsive to
        // cancellation while waiting.
        if self.config.initial_delay_ms > 0 {
            tokio::select! {
                _ = context.cancelled() => {
                    info!("SQL consumer stopped during initial delay");
                    return Ok(());
                }
                _ = tokio::time::sleep(Duration::from_millis(self.config.initial_delay_ms)) => {}
            }
        }

        // Step 4: Polling loop.
        // NOTE(review): the sleep precedes each poll, so the first poll happens
        // at initial_delay + delay rather than immediately after initial_delay —
        // confirm this matches the intended scheduling semantics.
        loop {
            tokio::select! {
                _ = context.cancelled() => {
                    info!("SQL consumer stopped");
                    break;
                }
                _ = tokio::time::sleep(Duration::from_millis(self.config.delay_ms)) => {
                    // Poll errors are logged but do not terminate the consumer.
                    if let Err(e) = self.poll_database(pool, &context, &template).await {
                        error!(error = %e, "SQL consumer poll failed");
                    }
                }
            }
        }

        Ok(())
    }

    /// Stop is a no-op: shutdown is driven by the cancellation token observed
    /// inside `start`'s select! loops.
    async fn stop(&mut self) -> Result<(), CamelError> {
        Ok(())
    }

    fn concurrency_model(&self) -> ConcurrencyModel {
        // Sequential is correct for SQL consumers: concurrent polls would fetch
        // duplicate rows. The design doc mentioned SharedState (which doesn't exist
        // in this runtime) — Sequential is the correct equivalent.
        ConcurrencyModel::Sequential
    }
}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SqlEndpointConfig;
    use camel_component_api::ExchangeEnvelope;
    use camel_component_api::UriConfig;
    use sqlx::any::AnyPoolOptions;
    use std::sync::Arc;
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    /// Build a single-connection in-memory SQLite pool for tests.
    async fn sqlite_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        AnyPoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("sqlite pool")
    }

    /// Create the `jobs` table and seed two unprocessed, unfailed rows (ids 1, 2).
    async fn seed_consumer_table(pool: &AnyPool) {
        sqlx::query("CREATE TABLE jobs (id INTEGER PRIMARY KEY, processed INTEGER DEFAULT 0, failed INTEGER DEFAULT 0)")
            .execute(pool)
            .await
            .expect("create table");
        sqlx::query("INSERT INTO jobs (id, processed, failed) VALUES (1, 0, 0), (2, 0, 0)")
            .execute(pool)
            .await
            .expect("seed rows");
    }

    /// Minimal resolved config for tests that never actually connect.
    fn config() -> SqlEndpointConfig {
        let mut c =
            SqlEndpointConfig::from_uri("sql:select * from t?db_url=postgres://localhost/test")
                .unwrap();
        c.resolve_defaults();
        c
    }

    #[test]
    fn consumer_concurrency_model() {
        let c = SqlConsumer::new(config(), Arc::new(OnceCell::new()));
        assert_eq!(c.concurrency_model(), ConcurrencyModel::Sequential);
    }

    #[test]
    fn consumer_stores_config() {
        let mut config = SqlEndpointConfig::from_uri(
            "sql:select * from t?db_url=postgres://localhost/test&delay=2000&onConsume=update t set done=true"
        ).unwrap();
        config.resolve_defaults();
        let c = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        assert_eq!(c.config.delay_ms, 2000);
        assert!(c.config.on_consume.is_some());
    }

    #[tokio::test]
    async fn poll_database_runs_on_consume_for_successful_rows() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsume=update jobs set processed=1 where id=:#id&initialDelay=0&delay=1",
        )
        .unwrap();
        config.resolve_defaults();

        let consumer = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        let template = parse_query_template(&config.query, config.placeholder).unwrap();

        // Downstream stub: echo every exchange back as success so onConsume runs.
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let _ = reply_tx.send(Ok(env.exchange));
                }
            }
        });
        let ctx = ConsumerContext::new(tx, CancellationToken::new());

        consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect("poll must succeed");

        // Both rows should have been marked processed by the onConsume query.
        let row = sqlx::query("select processed from jobs where id = 1")
            .fetch_one(&pool)
            .await
            .expect("row 1");
        let processed_1: i64 = sqlx::Row::try_get(&row, 0).expect("processed");

        let row = sqlx::query("select processed from jobs where id = 2")
            .fetch_one(&pool)
            .await
            .expect("row 2");
        let processed_2: i64 = sqlx::Row::try_get(&row, 0).expect("processed");

        assert_eq!(processed_1, 1);
        assert_eq!(processed_2, 1);
    }

    #[tokio::test]
    async fn poll_database_runs_on_consume_failed_when_downstream_fails() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsumeFailed=update jobs set failed=1 where id=:#id&initialDelay=0&delay=1",
        )
        .unwrap();
        config.resolve_defaults();

        let consumer = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        let template = parse_query_template(&config.query, config.placeholder).unwrap();

        // Downstream stub: fail every exchange so onConsumeFailed runs for each row.
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let _ =
                        reply_tx.send(Err(CamelError::ProcessorError("downstream boom".into())));
                }
            }
        });
        let ctx = ConsumerContext::new(tx, CancellationToken::new());

        // With breakBatchOnConsumeFail left at its default (false), downstream
        // failures must not bubble out of poll_database.
        consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect("consumer should swallow downstream errors when breakBatchOnConsumeFail=false");

        // Both rows should have been marked failed by the onConsumeFailed query.
        let row = sqlx::query("select failed from jobs where id = 1")
            .fetch_one(&pool)
            .await
            .expect("row 1");
        let failed_1: i64 = sqlx::Row::try_get(&row, 0).expect("failed");

        let row = sqlx::query("select failed from jobs where id = 2")
            .fetch_one(&pool)
            .await
            .expect("row 2");
        let failed_2: i64 = sqlx::Row::try_get(&row, 0).expect("failed");

        assert_eq!(failed_1, 1);
        assert_eq!(failed_2, 1);
    }

    #[tokio::test]
    async fn poll_database_breaks_batch_on_consume_fail() {
        let pool = sqlite_pool().await;
        seed_consumer_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, processed, failed from jobs where processed = 0 order by id?db_url=sqlite::memory:&onConsumeFailed=update jobs set failed=1 where id=:#id&breakBatchOnConsumeFail=true&initialDelay=0&delay=1",
        )
        .unwrap();
        config.resolve_defaults();

        let consumer = SqlConsumer::new(config.clone(), Arc::new(OnceCell::new()));
        let template = parse_query_template(&config.query, config.placeholder).unwrap();

        // Downstream stub: fail every exchange.
        let (tx, mut rx) = mpsc::channel::<ExchangeEnvelope>(8);
        tokio::spawn(async move {
            while let Some(env) = rx.recv().await {
                if let Some(reply_tx) = env.reply_tx {
                    let _ =
                        reply_tx.send(Err(CamelError::ProcessorError("downstream boom".into())));
                }
            }
        });
        let ctx = ConsumerContext::new(tx, CancellationToken::new());

        // breakBatchOnConsumeFail=true: the first downstream failure must abort
        // the batch and surface the downstream error.
        let err = consumer
            .poll_database(&pool, &ctx, &template)
            .await
            .expect_err("must stop on first downstream failure");
        assert!(err.to_string().contains("downstream boom"));

        // Row 1 ran onConsumeFailed; row 2 was never routed because the batch broke.
        let row = sqlx::query("select failed from jobs where id = 1")
            .fetch_one(&pool)
            .await
            .expect("row 1");
        let failed_1: i64 = sqlx::Row::try_get(&row, 0).expect("failed");

        let row = sqlx::query("select failed from jobs where id = 2")
            .fetch_one(&pool)
            .await
            .expect("row 2");
        let failed_2: i64 = sqlx::Row::try_get(&row, 0).expect("failed");

        assert_eq!(failed_1, 1);
        assert_eq!(failed_2, 0, "second row must not be processed");
    }
}