Skip to main content

camel_component_sql/
producer.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use bytes::Bytes;
8use serde_json::json;
9use sqlx::AnyPool;
10use sqlx::any::AnyRow;
11use sqlx::pool::PoolOptions;
12use tokio::sync::OnceCell;
13use tower::Service;
14use tracing::{debug, error, warn};
15
16use crate::config::{SqlEndpointConfig, SqlOutputType};
17use crate::headers;
18use crate::query::{PreparedQuery, is_select_query, parse_query_template, resolve_params};
19use crate::utils::{bind_json_values, row_to_json};
20use camel_api::{Body, CamelError, Exchange, Message, StreamBody, StreamMetadata};
21
22#[derive(Clone)]
23pub struct SqlProducer {
24    pub(crate) config: SqlEndpointConfig,
25    pub(crate) pool: Arc<OnceCell<AnyPool>>,
26}
27
28impl SqlProducer {
29    pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
30        Self { config, pool }
31    }
32
33    /// Resolves the query source based on priority:
34    /// 1. Header `CamelSql.Query`
35    /// 2. Body (if `use_message_body_for_sql` is true)
36    /// 3. Config query
37    pub(crate) fn resolve_query_source(exchange: &Exchange, config: &SqlEndpointConfig) -> String {
38        // Priority 1: Header
39        if let Some(query_value) = exchange.input.header(headers::QUERY)
40            && let Some(query_str) = query_value.as_str()
41        {
42            return query_str.to_string();
43        }
44
45        // Priority 2: Body (if use_message_body_for_sql)
46        if config.use_message_body_for_sql
47            && let Some(body_text) = exchange.input.body.as_text()
48        {
49            return body_text.to_string();
50        }
51
52        // Priority 3: Config query
53        config.query.clone()
54    }
55}
56
57impl Service<Exchange> for SqlProducer {
58    type Response = Exchange;
59    type Error = CamelError;
60    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
61
62    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63        Poll::Ready(Ok(()))
64    }
65
66    fn call(&mut self, mut exchange: Exchange) -> Self::Future {
67        let mut config = self.config.clone();
68        let pool_cell = Arc::clone(&self.pool);
69
70        Box::pin(async move {
71            // Get or initialize the connection pool
72            let pool: &AnyPool = pool_cell
73                .get_or_try_init(|| async {
74                    // Defensive: ensure config is resolved even if caller didn't use create_endpoint
75                    config.resolve_defaults();
76
77                    // Install all compiled-in sqlx drivers so AnyPool can resolve them.
78                    // This is idempotent; safe to call multiple times.
79                    sqlx::any::install_default_drivers();
80                    let opts: PoolOptions<sqlx::Any> = PoolOptions::new()
81                        .max_connections(
82                            config
83                                .max_connections
84                                .expect("must be Some after resolve_defaults()"),
85                        )
86                        .min_connections(
87                            config
88                                .min_connections
89                                .expect("must be Some after resolve_defaults()"),
90                        )
91                        .idle_timeout(Duration::from_secs(
92                            config
93                                .idle_timeout_secs
94                                .expect("must be Some after resolve_defaults()"),
95                        ))
96                        .max_lifetime(Duration::from_secs(
97                            config
98                                .max_lifetime_secs
99                                .expect("must be Some after resolve_defaults()"),
100                        ));
101                    opts.connect(&config.db_url).await.map_err(|e| {
102                        error!("Failed to connect to database: {}", e);
103                        CamelError::EndpointCreationFailed(format!(
104                            "Failed to connect to database: {}",
105                            e
106                        ))
107                    })
108                })
109                .await
110                .map_err(|e: CamelError| {
111                    error!("Pool initialization failed: {}", e);
112                    e.clone()
113                })?;
114
115            // Resolve query string
116            let query_str = Self::resolve_query_source(&exchange, &config);
117
118            debug!("Executing SQL: {}", query_str);
119
120            // Execute based on mode
121            if config.batch {
122                // Batch mode: execute_batch handles its own template parsing per item
123                execute_batch(pool, &config, &mut exchange).await?;
124            } else {
125                // Non-batch: parse template, resolve params, apply header override
126                let template = parse_query_template(&query_str, config.placeholder)?;
127                let mut prepared = resolve_params(&template, &exchange)?;
128
129                // CamelSql.Parameters header override
130                if let Some(params_value) = exchange.input.header(headers::PARAMETERS) {
131                    if let Some(arr) = params_value.as_array() {
132                        if arr.len() != prepared.bindings.len() {
133                            warn!(
134                                expected = prepared.bindings.len(),
135                                got = arr.len(),
136                                header = headers::PARAMETERS,
137                                "Parameter count mismatch — SQL has {} placeholders but header provides {} values",
138                                prepared.bindings.len(),
139                                arr.len()
140                            );
141                        }
142                        debug!(
143                            "Overriding bindings from {} header with {} parameters",
144                            headers::PARAMETERS,
145                            arr.len()
146                        );
147                        prepared.bindings = arr.clone();
148                    } else {
149                        warn!(
150                            header = headers::PARAMETERS,
151                            "Header is present but not a JSON array — ignoring parameter override"
152                        );
153                    }
154                }
155
156                debug!("Executing SQL: {}", prepared.sql);
157
158                if is_select_query(&prepared.sql) {
159                    execute_select(pool, &prepared, &config, &mut exchange).await?;
160                } else {
161                    execute_modify(pool, &prepared, &config, &mut exchange).await?;
162                }
163            }
164
165            Ok(exchange)
166        })
167    }
168}
169
170/// Executes a SELECT query and populates the exchange body with results.
171async fn execute_select(
172    pool: &AnyPool,
173    prepared: &PreparedQuery,
174    config: &SqlEndpointConfig,
175    exchange: &mut Exchange,
176) -> Result<(), CamelError> {
177    match config.output_type {
178        SqlOutputType::SelectOne => {
179            // fetch_all and take first row
180            let mut query = sqlx::query(&prepared.sql);
181            query = bind_json_values(query, &prepared.bindings);
182
183            let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
184                error!("Query execution failed: {}", e);
185                CamelError::ProcessorError(format!("Query execution failed: {}", e))
186            })?;
187
188            let count = rows.len();
189            let json_rows: Vec<serde_json::Value> = rows
190                .iter()
191                .map(row_to_json)
192                .collect::<Result<Vec<_>, _>>()?;
193
194            if let Some(first_row) = json_rows.into_iter().next() {
195                exchange.input.body = Body::Json(first_row);
196            } else {
197                exchange.input.body = Body::Empty;
198            }
199            debug!("SelectOne returned {} row", if count > 0 { 1 } else { 0 });
200            exchange
201                .input
202                .set_header(headers::ROW_COUNT, serde_json::json!(count));
203        }
204        SqlOutputType::SelectList => {
205            // fetch_all for list output
206            let mut query = sqlx::query(&prepared.sql);
207            query = bind_json_values(query, &prepared.bindings);
208
209            let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
210                error!("Query execution failed: {}", e);
211                CamelError::ProcessorError(format!("Query execution failed: {}", e))
212            })?;
213
214            let count = rows.len();
215            let json_rows: Vec<serde_json::Value> = rows
216                .iter()
217                .map(row_to_json)
218                .collect::<Result<Vec<_>, _>>()?;
219
220            exchange.input.body = Body::Json(serde_json::Value::Array(json_rows));
221            debug!("SelectList returned {} rows", count);
222            exchange
223                .input
224                .set_header(headers::ROW_COUNT, serde_json::json!(count));
225        }
226        SqlOutputType::StreamList => {
227            // Use fetch() for true streaming - avoids loading all rows into memory
228            use futures::TryStreamExt;
229
230            let pool_clone = pool.clone();
231            let sql_str = prepared.sql.clone();
232            let bindings = prepared.bindings.clone();
233
234            // Build the stream that reads rows on demand and serializes to NDJSON bytes
235            let byte_stream = async_stream::try_stream! {
236                let mut q = sqlx::query(&sql_str);
237                q = bind_json_values(q, &bindings);
238                let mut rows = q.fetch(&pool_clone);
239                while let Some(row) = rows.try_next().await.map_err(|e| {
240                    CamelError::ProcessorError(format!("Query execution failed: {}", e))
241                })? {
242                    let json_val = row_to_json(&row).map_err(|e| {
243                        CamelError::ProcessorError(format!("JSON serialization failed: {}", e))
244                    })?;
245                    let mut bytes = serde_json::to_vec(&json_val)
246                        .map_err(|e| CamelError::ProcessorError(format!("JSON serialization failed: {}", e)))?;
247                    bytes.push(b'\n');
248                    yield Bytes::from(bytes);
249                }
250            };
251
252            exchange.input.body = Body::Stream(StreamBody {
253                stream: Arc::new(tokio::sync::Mutex::new(Some(Box::pin(byte_stream)))),
254                metadata: StreamMetadata {
255                    content_type: Some("application/x-ndjson".to_string()),
256                    size_hint: None,
257                    origin: None,
258                },
259            });
260            debug!("StreamList: created lazy stream (rows fetched on demand)");
261            // Note: ROW_COUNT not set for StreamList since row count is unknown until exhausted
262        }
263    }
264
265    Ok(())
266}
267
268/// Executes a modification query (INSERT/UPDATE/DELETE).
269async fn execute_modify(
270    pool: &AnyPool,
271    prepared: &PreparedQuery,
272    config: &SqlEndpointConfig,
273    exchange: &mut Exchange,
274) -> Result<(), CamelError> {
275    let mut query = sqlx::query(&prepared.sql);
276    query = bind_json_values(query, &prepared.bindings);
277
278    let result = query.execute(pool).await.map_err(|e| {
279        error!("Query execution failed: {}", e);
280        CamelError::ProcessorError(format!("Query execution failed: {}", e))
281    })?;
282
283    let rows_affected = result.rows_affected();
284
285    // Fix 4: Implement expected_update_count validation
286    if let Some(expected) = config.expected_update_count
287        && rows_affected as i64 != expected
288    {
289        error!("Expected {} rows affected, got {}", expected, rows_affected);
290        return Err(CamelError::ProcessorError(format!(
291            "Expected {} rows affected, got {}",
292            expected, rows_affected
293        )));
294    }
295
296    exchange
297        .input
298        .set_header(headers::UPDATE_COUNT, serde_json::json!(rows_affected));
299
300    if config.noop {
301        // Preserve original body
302    } else {
303        exchange.input.body = Body::Json(json!({ "rowsAffected": rows_affected }));
304    }
305
306    debug!("Modify query affected {} rows", rows_affected);
307
308    Ok(())
309}
310
311/// Executes a batch of queries from a JSON array body.
312async fn execute_batch(
313    pool: &AnyPool,
314    config: &SqlEndpointConfig,
315    exchange: &mut Exchange,
316) -> Result<(), CamelError> {
317    // Body must be JSON array of arrays
318    let body_json = match &exchange.input.body {
319        Body::Json(val) => val,
320        _ => {
321            return Err(CamelError::ProcessorError(
322                "Batch mode requires body to be a JSON array of arrays".to_string(),
323            ));
324        }
325    };
326
327    let batch_data = body_json
328        .as_array()
329        .ok_or_else(|| {
330            CamelError::ProcessorError("Batch mode requires body to be a JSON array".to_string())
331        })?
332        .clone();
333
334    // Parse template from config query
335    let template = parse_query_template(&config.query, config.placeholder)?;
336
337    // Fix 2: Batch operations must be wrapped in a transaction
338    let mut tx = pool.begin().await.map_err(|e| {
339        error!("Failed to begin transaction: {}", e);
340        CamelError::ProcessorError(format!("Failed to begin transaction: {}", e))
341    })?;
342
343    let mut total_rows_affected: u64 = 0;
344
345    for (batch_idx, params_array) in batch_data.into_iter().enumerate() {
346        // Each item must be an array of parameters
347        params_array.as_array().ok_or_else(|| {
348            CamelError::ProcessorError(format!(
349                "Batch item at index {} must be a JSON array of parameters",
350                batch_idx
351            ))
352        })?;
353
354        // Create a temporary exchange with the params as body for resolution
355        let temp_msg = Message::new(Body::Json(params_array.clone()));
356        let temp_exchange = Exchange::new(temp_msg);
357
358        // Resolve parameters for this batch item
359        let prepared = resolve_params(&template, &temp_exchange)?;
360
361        // Execute against transaction
362        let mut query = sqlx::query(&prepared.sql);
363        query = bind_json_values(query, &prepared.bindings);
364
365        let result = query.execute(&mut *tx).await.map_err(|e| {
366            error!("Batch query execution failed at index {}: {}", batch_idx, e);
367            CamelError::ProcessorError(format!("Batch query execution failed: {}", e))
368        })?;
369
370        // Validate expected_update_count per batch item
371        if let Some(expected) = config.expected_update_count
372            && result.rows_affected() as i64 != expected
373        {
374            error!(
375                "Batch item {}: expected {} rows affected, got {}",
376                batch_idx,
377                expected,
378                result.rows_affected()
379            );
380            return Err(CamelError::ProcessorError(format!(
381                "Batch item {}: expected {} rows affected, got {}",
382                batch_idx,
383                expected,
384                result.rows_affected()
385            )));
386        }
387
388        total_rows_affected += result.rows_affected();
389    }
390
391    // Commit the transaction
392    tx.commit().await.map_err(|e| {
393        error!("Failed to commit transaction: {}", e);
394        CamelError::ProcessorError(format!("Failed to commit transaction: {}", e))
395    })?;
396
397    exchange.input.set_header(
398        headers::UPDATE_COUNT,
399        serde_json::json!(total_rows_affected),
400    );
401
402    debug!(
403        "Batch execution completed, total rows affected: {}",
404        total_rows_affected
405    );
406
407    Ok(())
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::Message;
    use camel_endpoint::UriConfig;
    use std::sync::Arc;
    use tokio::sync::OnceCell;

    /// Baseline endpoint config (`select 1`) with defaults resolved.
    fn test_config() -> SqlEndpointConfig {
        let mut cfg =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        cfg.resolve_defaults();
        cfg
    }

    /// Baseline config with `use_message_body_for_sql` switched on.
    fn body_sql_config() -> SqlEndpointConfig {
        let mut cfg = test_config();
        cfg.use_message_body_for_sql = true;
        cfg
    }

    /// Binds `values` onto a fresh query for `sql`; succeeding without a panic is the check.
    fn bind_onto(sql: &'static str, values: Vec<serde_json::Value>) {
        let query = sqlx::query(sql);
        let _bound = bind_json_values(query, &values);
    }

    /// A one-placeholder query whose binding still carries `id = 42`.
    fn single_binding_query() -> PreparedQuery {
        PreparedQuery {
            sql: "SELECT * FROM t WHERE id = $1".to_string(),
            bindings: vec![serde_json::json!(42)],
        }
    }

    #[test]
    fn test_producer_clone_shares_pool() {
        let original = SqlProducer::new(test_config(), Arc::new(OnceCell::new()));
        let copy = original.clone();
        assert!(Arc::ptr_eq(&original.pool, &copy.pool));
    }

    #[test]
    fn test_resolve_query_from_config() {
        let exchange = Exchange::new(Message::default());
        let resolved = SqlProducer::resolve_query_source(&exchange, &test_config());
        assert_eq!(resolved, "select 1");
    }

    #[test]
    fn test_resolve_query_from_header() {
        let mut message = Message::default();
        message.set_header(headers::QUERY, serde_json::json!("select 2"));
        let exchange = Exchange::new(message);
        let resolved = SqlProducer::resolve_query_source(&exchange, &test_config());
        assert_eq!(resolved, "select 2");
    }

    #[test]
    fn test_resolve_query_from_body() {
        let exchange = Exchange::new(Message::new(Body::Text("select 3".to_string())));
        let resolved = SqlProducer::resolve_query_source(&exchange, &body_sql_config());
        assert_eq!(resolved, "select 3");
    }

    #[test]
    fn test_resolve_query_header_priority_over_body() {
        let mut message = Message::new(Body::Text("select from body".to_string()));
        message.set_header(headers::QUERY, serde_json::json!("select from header"));
        let exchange = Exchange::new(message);
        let resolved = SqlProducer::resolve_query_source(&exchange, &body_sql_config());
        assert_eq!(resolved, "select from header");
    }

    #[test]
    fn test_resolve_query_body_priority_over_config() {
        let exchange = Exchange::new(Message::new(Body::Text("select from body".to_string())));
        let resolved = SqlProducer::resolve_query_source(&exchange, &body_sql_config());
        assert_eq!(resolved, "select from body");
    }

    // The binding tests below are smoke tests: they pass as long as binding
    // the given JSON value compiles and does not panic.

    #[test]
    fn test_bind_json_null() {
        bind_onto("SELECT ?", vec![serde_json::Value::Null]);
    }

    #[test]
    fn test_bind_json_bool() {
        bind_onto("SELECT ?", vec![serde_json::Value::Bool(true)]);
    }

    #[test]
    fn test_bind_json_number_i64() {
        bind_onto("SELECT ?", vec![serde_json::json!(42)]);
    }

    #[test]
    fn test_bind_json_number_f64() {
        bind_onto("SELECT ?", vec![serde_json::json!(std::f64::consts::PI)]);
    }

    #[test]
    fn test_bind_json_string() {
        bind_onto("SELECT ?", vec![serde_json::json!("hello world")]);
    }

    #[test]
    fn test_bind_json_array() {
        bind_onto("SELECT ?", vec![serde_json::json!([1, 2, 3])]);
    }

    #[test]
    fn test_bind_json_object() {
        bind_onto("SELECT ?", vec![serde_json::json!({"key": "value"})]);
    }

    #[test]
    fn test_bind_multiple_values() {
        bind_onto(
            "SELECT ?, ?, ?",
            vec![
                serde_json::json!(1),
                serde_json::json!("test"),
                serde_json::Value::Null,
            ],
        );
    }

    // expected_update_count: presence and parsing of the config field
    #[test]
    fn test_expected_update_count_validation() {
        // Explicit value is parsed from the URI
        let explicit = SqlEndpointConfig::from_uri(
            "sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=5",
        )
        .unwrap();
        assert_eq!(explicit.expected_update_count, Some(5));

        // Absent parameter leaves the field unset
        assert_eq!(test_config().expected_update_count, None);

        // Negative values are accepted by the parser
        let negative = SqlEndpointConfig::from_uri(
            "sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=-1",
        )
        .unwrap();
        assert_eq!(negative.expected_update_count, Some(-1));
    }

    // Parameters header override: mirrors the override step used by the producer
    #[test]
    fn test_parameters_header_override_logic() {
        // An array header replaces the bindings wholesale
        let mut overridden = single_binding_query();
        let array_header = serde_json::json!([99, "extra"]);
        if let Some(items) = array_header.as_array() {
            overridden.bindings = items.clone();
        }
        assert_eq!(overridden.bindings.len(), 2);
        assert_eq!(overridden.bindings[0], serde_json::json!(99));
        assert_eq!(overridden.bindings[1], serde_json::json!("extra"));

        // A non-array header leaves the bindings untouched
        let mut untouched = single_binding_query();
        let object_header = serde_json::json!({"not": "an array"});
        if let Some(items) = object_header.as_array() {
            untouched.bindings = items.clone();
        }
        assert_eq!(untouched.bindings.len(), 1);
        assert_eq!(untouched.bindings[0], serde_json::json!(42));
    }
}