Skip to main content

camel_component_sql/
producer.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use bytes::Bytes;
8use serde_json::json;
9use sqlx::AnyPool;
10use sqlx::any::AnyRow;
11use sqlx::pool::PoolOptions;
12use tokio::sync::OnceCell;
13use tower::Service;
14use tracing::{debug, error, warn};
15
16use crate::config::{SqlConfig, SqlOutputType};
17use crate::headers;
18use crate::query::{PreparedQuery, is_select_query, parse_query_template, resolve_params};
19use crate::utils::{bind_json_values, row_to_json};
20use camel_api::{Body, CamelError, Exchange, Message, StreamBody, StreamMetadata};
21
22#[derive(Clone)]
23pub struct SqlProducer {
24    pub(crate) config: SqlConfig,
25    pub(crate) pool: Arc<OnceCell<AnyPool>>,
26}
27
28impl SqlProducer {
29    pub fn new(config: SqlConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
30        Self { config, pool }
31    }
32
33    /// Resolves the query source based on priority:
34    /// 1. Header `CamelSql.Query`
35    /// 2. Body (if `use_message_body_for_sql` is true)
36    /// 3. Config query
37    pub(crate) fn resolve_query_source(exchange: &Exchange, config: &SqlConfig) -> String {
38        // Priority 1: Header
39        if let Some(query_value) = exchange.input.header(headers::QUERY)
40            && let Some(query_str) = query_value.as_str()
41        {
42            return query_str.to_string();
43        }
44
45        // Priority 2: Body (if use_message_body_for_sql)
46        if config.use_message_body_for_sql
47            && let Some(body_text) = exchange.input.body.as_text()
48        {
49            return body_text.to_string();
50        }
51
52        // Priority 3: Config query
53        config.query.clone()
54    }
55}
56
57impl Service<Exchange> for SqlProducer {
58    type Response = Exchange;
59    type Error = CamelError;
60    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
61
62    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63        Poll::Ready(Ok(()))
64    }
65
66    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
67        let config = self.config.clone();
68        let pool_cell = Arc::clone(&self.pool);
69
70        Box::pin(async move {
71            // Get or initialize the connection pool
72            let pool: &AnyPool = pool_cell
73                .get_or_try_init(|| async {
74                    // Install all compiled-in sqlx drivers so AnyPool can resolve them.
75                    // This is idempotent; safe to call multiple times.
76                    sqlx::any::install_default_drivers();
77                    let opts: PoolOptions<sqlx::Any> = PoolOptions::new()
78                        .max_connections(config.max_connections)
79                        .min_connections(config.min_connections)
80                        .idle_timeout(Duration::from_secs(config.idle_timeout_secs))
81                        .max_lifetime(Duration::from_secs(config.max_lifetime_secs));
82                    opts.connect(&config.db_url).await.map_err(|e| {
83                        error!("Failed to connect to database: {}", e);
84                        CamelError::EndpointCreationFailed(format!(
85                            "Failed to connect to database: {}",
86                            e
87                        ))
88                    })
89                })
90                .await
91                .map_err(|e: CamelError| {
92                    error!("Pool initialization failed: {}", e);
93                    e.clone()
94                })?;
95
96            // Resolve query string
97            let query_str = Self::resolve_query_source(&exchange, &config);
98
99            debug!("Executing SQL: {}", query_str);
100
101            // Execute based on mode
102            if config.batch {
103                // Batch mode: execute_batch handles its own template parsing per item
104                execute_batch(pool, &config, &mut exchange).await?;
105            } else {
106                // Non-batch: parse template, resolve params, apply header override
107                let template = parse_query_template(&query_str, config.placeholder)?;
108                let mut prepared = resolve_params(&template, &exchange)?;
109
110                // CamelSql.Parameters header override
111                if let Some(params_value) = exchange.input.header(headers::PARAMETERS) {
112                    if let Some(arr) = params_value.as_array() {
113                        if arr.len() != prepared.bindings.len() {
114                            warn!(
115                                expected = prepared.bindings.len(),
116                                got = arr.len(),
117                                header = headers::PARAMETERS,
118                                "Parameter count mismatch — SQL has {} placeholders but header provides {} values",
119                                prepared.bindings.len(),
120                                arr.len()
121                            );
122                        }
123                        debug!(
124                            "Overriding bindings from {} header with {} parameters",
125                            headers::PARAMETERS,
126                            arr.len()
127                        );
128                        prepared.bindings = arr.clone();
129                    } else {
130                        warn!(
131                            header = headers::PARAMETERS,
132                            "Header is present but not a JSON array — ignoring parameter override"
133                        );
134                    }
135                }
136
137                debug!("Executing SQL: {}", prepared.sql);
138
139                if is_select_query(&prepared.sql) {
140                    execute_select(pool, &prepared, &config, &mut exchange).await?;
141                } else {
142                    execute_modify(pool, &prepared, &config, &mut exchange).await?;
143                }
144            }
145
146            Ok(exchange)
147        })
148    }
149}
150
151/// Executes a SELECT query and populates the exchange body with results.
152async fn execute_select(
153    pool: &AnyPool,
154    prepared: &PreparedQuery,
155    config: &SqlConfig,
156    exchange: &mut Exchange,
157) -> Result<(), CamelError> {
158    match config.output_type {
159        SqlOutputType::SelectOne => {
160            // fetch_all and take first row
161            let mut query = sqlx::query(&prepared.sql);
162            query = bind_json_values(query, &prepared.bindings);
163
164            let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
165                error!("Query execution failed: {}", e);
166                CamelError::ProcessorError(format!("Query execution failed: {}", e))
167            })?;
168
169            let count = rows.len();
170            let json_rows: Vec<serde_json::Value> = rows
171                .iter()
172                .map(row_to_json)
173                .collect::<Result<Vec<_>, _>>()?;
174
175            if let Some(first_row) = json_rows.into_iter().next() {
176                exchange.input.body = Body::Json(first_row);
177            } else {
178                exchange.input.body = Body::Empty;
179            }
180            debug!("SelectOne returned {} row", if count > 0 { 1 } else { 0 });
181            exchange
182                .input
183                .set_header(headers::ROW_COUNT, serde_json::json!(count));
184        }
185        SqlOutputType::SelectList => {
186            // fetch_all for list output
187            let mut query = sqlx::query(&prepared.sql);
188            query = bind_json_values(query, &prepared.bindings);
189
190            let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
191                error!("Query execution failed: {}", e);
192                CamelError::ProcessorError(format!("Query execution failed: {}", e))
193            })?;
194
195            let count = rows.len();
196            let json_rows: Vec<serde_json::Value> = rows
197                .iter()
198                .map(row_to_json)
199                .collect::<Result<Vec<_>, _>>()?;
200
201            exchange.input.body = Body::Json(serde_json::Value::Array(json_rows));
202            debug!("SelectList returned {} rows", count);
203            exchange
204                .input
205                .set_header(headers::ROW_COUNT, serde_json::json!(count));
206        }
207        SqlOutputType::StreamList => {
208            // Use fetch() for true streaming - avoids loading all rows into memory
209            use futures::TryStreamExt;
210
211            let pool_clone = pool.clone();
212            let sql_str = prepared.sql.clone();
213            let bindings = prepared.bindings.clone();
214
215            // Build the stream that reads rows on demand and serializes to NDJSON bytes
216            let byte_stream = async_stream::try_stream! {
217                let mut q = sqlx::query(&sql_str);
218                q = bind_json_values(q, &bindings);
219                let mut rows = q.fetch(&pool_clone);
220                while let Some(row) = rows.try_next().await.map_err(|e| {
221                    CamelError::ProcessorError(format!("Query execution failed: {}", e))
222                })? {
223                    let json_val = row_to_json(&row).map_err(|e| {
224                        CamelError::ProcessorError(format!("JSON serialization failed: {}", e))
225                    })?;
226                    let mut bytes = serde_json::to_vec(&json_val)
227                        .map_err(|e| CamelError::ProcessorError(format!("JSON serialization failed: {}", e)))?;
228                    bytes.push(b'\n');
229                    yield Bytes::from(bytes);
230                }
231            };
232
233            exchange.input.body = Body::Stream(StreamBody {
234                stream: Arc::new(tokio::sync::Mutex::new(Some(Box::pin(byte_stream)))),
235                metadata: StreamMetadata {
236                    content_type: Some("application/x-ndjson".to_string()),
237                    size_hint: None,
238                    origin: None,
239                },
240            });
241            debug!("StreamList: created lazy stream (rows fetched on demand)");
242            // Note: ROW_COUNT not set for StreamList since row count is unknown until exhausted
243        }
244    }
245
246    Ok(())
247}
248
249/// Executes a modification query (INSERT/UPDATE/DELETE).
250async fn execute_modify(
251    pool: &AnyPool,
252    prepared: &PreparedQuery,
253    config: &SqlConfig,
254    exchange: &mut Exchange,
255) -> Result<(), CamelError> {
256    let mut query = sqlx::query(&prepared.sql);
257    query = bind_json_values(query, &prepared.bindings);
258
259    let result = query.execute(pool).await.map_err(|e| {
260        error!("Query execution failed: {}", e);
261        CamelError::ProcessorError(format!("Query execution failed: {}", e))
262    })?;
263
264    let rows_affected = result.rows_affected();
265
266    // Fix 4: Implement expected_update_count validation
267    if let Some(expected) = config.expected_update_count
268        && rows_affected as i64 != expected
269    {
270        error!("Expected {} rows affected, got {}", expected, rows_affected);
271        return Err(CamelError::ProcessorError(format!(
272            "Expected {} rows affected, got {}",
273            expected, rows_affected
274        )));
275    }
276
277    exchange
278        .input
279        .set_header(headers::UPDATE_COUNT, serde_json::json!(rows_affected));
280
281    if config.noop {
282        // Preserve original body
283    } else {
284        exchange.input.body = Body::Json(json!({ "rowsAffected": rows_affected }));
285    }
286
287    debug!("Modify query affected {} rows", rows_affected);
288
289    Ok(())
290}
291
292/// Executes a batch of queries from a JSON array body.
293async fn execute_batch(
294    pool: &AnyPool,
295    config: &SqlConfig,
296    exchange: &mut Exchange,
297) -> Result<(), CamelError> {
298    // Body must be JSON array of arrays
299    let body_json = match &exchange.input.body {
300        Body::Json(val) => val,
301        _ => {
302            return Err(CamelError::ProcessorError(
303                "Batch mode requires body to be a JSON array of arrays".to_string(),
304            ));
305        }
306    };
307
308    let batch_data = body_json
309        .as_array()
310        .ok_or_else(|| {
311            CamelError::ProcessorError("Batch mode requires body to be a JSON array".to_string())
312        })?
313        .clone();
314
315    // Parse template from config query
316    let template = parse_query_template(&config.query, config.placeholder)?;
317
318    // Fix 2: Batch operations must be wrapped in a transaction
319    let mut tx = pool.begin().await.map_err(|e| {
320        error!("Failed to begin transaction: {}", e);
321        CamelError::ProcessorError(format!("Failed to begin transaction: {}", e))
322    })?;
323
324    let mut total_rows_affected: u64 = 0;
325
326    for (batch_idx, params_array) in batch_data.into_iter().enumerate() {
327        // Each item must be an array of parameters
328        params_array.as_array().ok_or_else(|| {
329            CamelError::ProcessorError(format!(
330                "Batch item at index {} must be a JSON array of parameters",
331                batch_idx
332            ))
333        })?;
334
335        // Create a temporary exchange with the params as body for resolution
336        let temp_msg = Message::new(Body::Json(params_array.clone()));
337        let temp_exchange = Exchange::new(temp_msg);
338
339        // Resolve parameters for this batch item
340        let prepared = resolve_params(&template, &temp_exchange)?;
341
342        // Execute against transaction
343        let mut query = sqlx::query(&prepared.sql);
344        query = bind_json_values(query, &prepared.bindings);
345
346        let result = query.execute(&mut *tx).await.map_err(|e| {
347            error!("Batch query execution failed at index {}: {}", batch_idx, e);
348            CamelError::ProcessorError(format!("Batch query execution failed: {}", e))
349        })?;
350
351        // Validate expected_update_count per batch item
352        if let Some(expected) = config.expected_update_count
353            && result.rows_affected() as i64 != expected
354        {
355            error!(
356                "Batch item {}: expected {} rows affected, got {}",
357                batch_idx,
358                expected,
359                result.rows_affected()
360            );
361            return Err(CamelError::ProcessorError(format!(
362                "Batch item {}: expected {} rows affected, got {}",
363                batch_idx,
364                expected,
365                result.rows_affected()
366            )));
367        }
368
369        total_rows_affected += result.rows_affected();
370    }
371
372    // Commit the transaction
373    tx.commit().await.map_err(|e| {
374        error!("Failed to commit transaction: {}", e);
375        CamelError::ProcessorError(format!("Failed to commit transaction: {}", e))
376    })?;
377
378    exchange.input.set_header(
379        headers::UPDATE_COUNT,
380        serde_json::json!(total_rows_affected),
381    );
382
383    debug!(
384        "Batch execution completed, total rows affected: {}",
385        total_rows_affected
386    );
387
388    Ok(())
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::Message;
    use std::sync::Arc;
    use tokio::sync::OnceCell;

    /// Endpoint config whose URI-level query is `select 1`.
    fn test_config() -> SqlConfig {
        SqlConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap()
    }

    /// `test_config()` with `use_message_body_for_sql` switched on.
    fn body_sql_config() -> SqlConfig {
        let mut config = test_config();
        config.use_message_body_for_sql = true;
        config
    }

    /// Wraps `msg` in an exchange and asks the producer which SQL it would run.
    fn resolved_query(msg: Message, config: &SqlConfig) -> String {
        SqlProducer::resolve_query_source(&Exchange::new(msg), config)
    }

    /// Binds `values` onto `sql`. These are smoke tests: they only check that
    /// binding each JSON shape does not panic.
    fn bind_smoke(sql: &str, values: Vec<serde_json::Value>) {
        let query = sqlx::query(sql);
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn test_producer_clone_shares_pool() {
        let original = SqlProducer::new(test_config(), Arc::new(OnceCell::new()));
        let cloned = original.clone();
        assert!(Arc::ptr_eq(&original.pool, &cloned.pool));
    }

    #[test]
    fn test_resolve_query_from_config() {
        assert_eq!(resolved_query(Message::default(), &test_config()), "select 1");
    }

    #[test]
    fn test_resolve_query_from_header() {
        let mut msg = Message::default();
        msg.set_header(headers::QUERY, serde_json::json!("select 2"));
        assert_eq!(resolved_query(msg, &test_config()), "select 2");
    }

    #[test]
    fn test_resolve_query_from_body() {
        let msg = Message::new(Body::Text("select 3".to_string()));
        assert_eq!(resolved_query(msg, &body_sql_config()), "select 3");
    }

    #[test]
    fn test_resolve_query_header_priority_over_body() {
        let mut msg = Message::new(Body::Text("select from body".to_string()));
        msg.set_header(headers::QUERY, serde_json::json!("select from header"));
        assert_eq!(resolved_query(msg, &body_sql_config()), "select from header");
    }

    #[test]
    fn test_resolve_query_body_priority_over_config() {
        let msg = Message::new(Body::Text("select from body".to_string()));
        assert_eq!(resolved_query(msg, &body_sql_config()), "select from body");
    }

    #[test]
    fn test_bind_json_null() {
        bind_smoke("SELECT ?", vec![serde_json::Value::Null]);
    }

    #[test]
    fn test_bind_json_bool() {
        bind_smoke("SELECT ?", vec![serde_json::Value::Bool(true)]);
    }

    #[test]
    fn test_bind_json_number_i64() {
        bind_smoke("SELECT ?", vec![serde_json::json!(42)]);
    }

    #[test]
    fn test_bind_json_number_f64() {
        bind_smoke("SELECT ?", vec![serde_json::json!(std::f64::consts::PI)]);
    }

    #[test]
    fn test_bind_json_string() {
        bind_smoke("SELECT ?", vec![serde_json::json!("hello world")]);
    }

    #[test]
    fn test_bind_json_array() {
        bind_smoke("SELECT ?", vec![serde_json::json!([1, 2, 3])]);
    }

    #[test]
    fn test_bind_json_object() {
        bind_smoke("SELECT ?", vec![serde_json::json!({"key": "value"})]);
    }

    #[test]
    fn test_bind_multiple_values() {
        bind_smoke(
            "SELECT ?, ?, ?",
            vec![
                serde_json::json!(1),
                serde_json::json!("test"),
                serde_json::Value::Null,
            ],
        );
    }

    // `expectedUpdateCount` is parsed from the URI, including negative values.
    #[test]
    fn test_expected_update_count_validation() {
        let parse = |uri: &str| SqlConfig::from_uri(uri).unwrap().expected_update_count;

        assert_eq!(
            parse("sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=5"),
            Some(5)
        );
        assert_eq!(test_config().expected_update_count, None);
        assert_eq!(
            parse("sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=-1"),
            Some(-1)
        );
    }

    // Mirrors the `CamelSql.Parameters` override step: an array replaces the
    // bindings wholesale, anything else leaves them untouched.
    #[test]
    fn test_parameters_header_override_logic() {
        fn apply_override(prepared: &mut PreparedQuery, header: &serde_json::Value) {
            if let Some(arr) = header.as_array() {
                prepared.bindings = arr.clone();
            }
        }

        fn fresh_query() -> PreparedQuery {
            PreparedQuery {
                sql: "SELECT * FROM t WHERE id = $1".to_string(),
                bindings: vec![serde_json::json!(42)],
            }
        }

        let mut overridden = fresh_query();
        apply_override(&mut overridden, &serde_json::json!([99, "extra"]));
        assert_eq!(
            overridden.bindings,
            vec![serde_json::json!(99), serde_json::json!("extra")]
        );

        let mut untouched = fresh_query();
        apply_override(&mut untouched, &serde_json::json!({"not": "an array"}));
        assert_eq!(untouched.bindings, vec![serde_json::json!(42)]);
    }
}