1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use bytes::Bytes;
8use serde_json::json;
9use sqlx::AnyPool;
10use sqlx::any::AnyRow;
11use sqlx::pool::PoolOptions;
12use tokio::sync::OnceCell;
13use tower::Service;
14use tracing::{debug, error, warn};
15
16use crate::config::{SqlEndpointConfig, SqlOutputType, enrich_db_url_with_ssl};
17use crate::headers;
18use crate::query::{PreparedQuery, is_select_query, parse_query_template, resolve_params};
19use crate::utils::{bind_json_values, row_to_json};
20use camel_component_api::{Body, CamelError, Exchange, Message, StreamBody, StreamMetadata};
21
22#[derive(Clone)]
23pub struct SqlProducer {
24 pub(crate) config: SqlEndpointConfig,
25 pub(crate) pool: Arc<OnceCell<AnyPool>>,
26}
27
28impl SqlProducer {
29 pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
30 Self { config, pool }
31 }
32
33 pub(crate) fn resolve_query_source(exchange: &Exchange, config: &SqlEndpointConfig) -> String {
38 if let Some(query_value) = exchange.input.header(headers::QUERY)
40 && let Some(query_str) = query_value.as_str()
41 {
42 return query_str.to_string();
43 }
44
45 if config.use_message_body_for_sql
47 && let Some(body_text) = exchange.input.body.as_text()
48 {
49 return body_text.to_string();
50 }
51
52 config.query.clone()
54 }
55}
56
57impl Service<Exchange> for SqlProducer {
58 type Response = Exchange;
59 type Error = CamelError;
60 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
61
62 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63 Poll::Ready(Ok(()))
64 }
65
66 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
67 let mut config = self.config.clone();
68 let pool_cell = Arc::clone(&self.pool);
69
70 Box::pin(async move {
71 let pool: &AnyPool = pool_cell
73 .get_or_try_init(|| async {
74 config.resolve_defaults();
76 let db_url = enrich_db_url_with_ssl(&config.db_url, &config)?;
77
78 sqlx::any::install_default_drivers();
81 let opts: PoolOptions<sqlx::Any> = PoolOptions::new()
82 .max_connections(
83 config
84 .max_connections
85 .expect("must be Some after resolve_defaults()"),
86 )
87 .min_connections(
88 config
89 .min_connections
90 .expect("must be Some after resolve_defaults()"),
91 )
92 .idle_timeout(Duration::from_secs(
93 config
94 .idle_timeout_secs
95 .expect("must be Some after resolve_defaults()"),
96 ))
97 .max_lifetime(Duration::from_secs(
98 config
99 .max_lifetime_secs
100 .expect("must be Some after resolve_defaults()"),
101 ));
102 opts.connect(&db_url).await.map_err(|e| {
103 error!("Failed to connect to database: {}", e);
104 CamelError::EndpointCreationFailed(format!(
105 "Failed to connect to database: {}",
106 e
107 ))
108 })
109 })
110 .await
111 .map_err(|e: CamelError| {
112 error!("Pool initialization failed: {}", e);
113 e.clone()
114 })?;
115
116 let query_str = Self::resolve_query_source(&exchange, &config);
118
119 debug!("Executing SQL: {}", query_str);
120
121 if config.batch {
123 execute_batch(pool, &config, &mut exchange).await?;
125 } else {
126 let template = parse_query_template(&query_str, config.placeholder)?;
128 let mut prepared = resolve_params(&template, &exchange, &config.in_separator)?;
129
130 if let Some(params_value) = exchange.input.header(headers::PARAMETERS) {
132 if let Some(arr) = params_value.as_array() {
133 if arr.len() != prepared.bindings.len() {
134 warn!(
135 expected = prepared.bindings.len(),
136 got = arr.len(),
137 header = headers::PARAMETERS,
138 "Parameter count mismatch — SQL has {} placeholders but header provides {} values",
139 prepared.bindings.len(),
140 arr.len()
141 );
142 }
143 debug!(
144 "Overriding bindings from {} header with {} parameters",
145 headers::PARAMETERS,
146 arr.len()
147 );
148 prepared.bindings = arr.clone();
149 } else {
150 warn!(
151 header = headers::PARAMETERS,
152 "Header is present but not a JSON array — ignoring parameter override"
153 );
154 }
155 }
156
157 debug!("Executing SQL: {}", prepared.sql);
158
159 if is_select_query(&prepared.sql) {
160 execute_select(pool, &prepared, &config, &mut exchange).await?;
161 } else {
162 execute_modify(pool, &prepared, &config, &mut exchange).await?;
163 }
164 }
165
166 Ok(exchange)
167 })
168 }
169}
170
171async fn execute_select(
173 pool: &AnyPool,
174 prepared: &PreparedQuery,
175 config: &SqlEndpointConfig,
176 exchange: &mut Exchange,
177) -> Result<(), CamelError> {
178 match config.output_type {
179 SqlOutputType::SelectOne => {
180 let mut query = sqlx::query(&prepared.sql);
182 query = bind_json_values(query, &prepared.bindings);
183
184 let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
185 error!("Query execution failed: {}", e);
186 CamelError::ProcessorError(format!("Query execution failed: {}", e))
187 })?;
188
189 let count = rows.len();
190 let json_rows: Vec<serde_json::Value> = rows
191 .iter()
192 .map(row_to_json)
193 .collect::<Result<Vec<_>, _>>()?;
194
195 if let Some(first_row) = json_rows.into_iter().next() {
196 exchange.input.body = Body::Json(first_row);
197 } else {
198 exchange.input.body = Body::Empty;
199 }
200 debug!("SelectOne returned {} row", if count > 0 { 1 } else { 0 });
201 exchange
202 .input
203 .set_header(headers::ROW_COUNT, serde_json::json!(count));
204 }
205 SqlOutputType::SelectList => {
206 let mut query = sqlx::query(&prepared.sql);
208 query = bind_json_values(query, &prepared.bindings);
209
210 let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
211 error!("Query execution failed: {}", e);
212 CamelError::ProcessorError(format!("Query execution failed: {}", e))
213 })?;
214
215 let count = rows.len();
216 let json_rows: Vec<serde_json::Value> = rows
217 .iter()
218 .map(row_to_json)
219 .collect::<Result<Vec<_>, _>>()?;
220
221 exchange.input.body = Body::Json(serde_json::Value::Array(json_rows));
222 debug!("SelectList returned {} rows", count);
223 exchange
224 .input
225 .set_header(headers::ROW_COUNT, serde_json::json!(count));
226 }
227 SqlOutputType::StreamList => {
228 use futures::TryStreamExt;
230
231 let pool_clone = pool.clone();
232 let sql_str = prepared.sql.clone();
233 let bindings = prepared.bindings.clone();
234
235 let byte_stream = async_stream::try_stream! {
237 let mut q = sqlx::query(&sql_str);
238 q = bind_json_values(q, &bindings);
239 let mut rows = q.fetch(&pool_clone);
240 while let Some(row) = rows.try_next().await.map_err(|e| {
241 CamelError::ProcessorError(format!("Query execution failed: {}", e))
242 })? {
243 let json_val = row_to_json(&row).map_err(|e| {
244 CamelError::ProcessorError(format!("JSON serialization failed: {}", e))
245 })?;
246 let mut bytes = serde_json::to_vec(&json_val)
247 .map_err(|e| CamelError::ProcessorError(format!("JSON serialization failed: {}", e)))?;
248 bytes.push(b'\n');
249 yield Bytes::from(bytes);
250 }
251 };
252
253 exchange.input.body = Body::Stream(StreamBody {
254 stream: Arc::new(tokio::sync::Mutex::new(Some(Box::pin(byte_stream)))),
255 metadata: StreamMetadata {
256 content_type: Some("application/x-ndjson".to_string()),
257 size_hint: None,
258 origin: None,
259 },
260 });
261 debug!("StreamList: created lazy stream (rows fetched on demand)");
262 }
264 }
265
266 Ok(())
267}
268
269async fn execute_modify(
271 pool: &AnyPool,
272 prepared: &PreparedQuery,
273 config: &SqlEndpointConfig,
274 exchange: &mut Exchange,
275) -> Result<(), CamelError> {
276 let mut query = sqlx::query(&prepared.sql);
277 query = bind_json_values(query, &prepared.bindings);
278
279 let result = query.execute(pool).await.map_err(|e| {
280 error!("Query execution failed: {}", e);
281 CamelError::ProcessorError(format!("Query execution failed: {}", e))
282 })?;
283
284 let rows_affected = result.rows_affected();
285
286 if let Some(expected) = config.expected_update_count
288 && rows_affected as i64 != expected
289 {
290 error!("Expected {} rows affected, got {}", expected, rows_affected);
291 return Err(CamelError::ProcessorError(format!(
292 "Expected {} rows affected, got {}",
293 expected, rows_affected
294 )));
295 }
296
297 exchange
298 .input
299 .set_header(headers::UPDATE_COUNT, serde_json::json!(rows_affected));
300
301 if config.noop {
302 } else {
304 exchange.input.body = Body::Json(json!({ "rowsAffected": rows_affected }));
305 }
306
307 debug!("Modify query affected {} rows", rows_affected);
308
309 Ok(())
310}
311
312async fn execute_batch(
314 pool: &AnyPool,
315 config: &SqlEndpointConfig,
316 exchange: &mut Exchange,
317) -> Result<(), CamelError> {
318 let body_json = match &exchange.input.body {
320 Body::Json(val) => val,
321 _ => {
322 return Err(CamelError::ProcessorError(
323 "Batch mode requires body to be a JSON array of arrays".to_string(),
324 ));
325 }
326 };
327
328 let batch_data = body_json
329 .as_array()
330 .ok_or_else(|| {
331 CamelError::ProcessorError("Batch mode requires body to be a JSON array".to_string())
332 })?
333 .clone();
334
335 let template = parse_query_template(&config.query, config.placeholder)?;
337
338 let mut tx = pool.begin().await.map_err(|e| {
340 error!("Failed to begin transaction: {}", e);
341 CamelError::ProcessorError(format!("Failed to begin transaction: {}", e))
342 })?;
343
344 let mut total_rows_affected: u64 = 0;
345
346 for (batch_idx, params_array) in batch_data.into_iter().enumerate() {
347 params_array.as_array().ok_or_else(|| {
349 CamelError::ProcessorError(format!(
350 "Batch item at index {} must be a JSON array of parameters",
351 batch_idx
352 ))
353 })?;
354
355 let temp_msg = Message::new(Body::Json(params_array.clone()));
357 let temp_exchange = Exchange::new(temp_msg);
358
359 let prepared = resolve_params(&template, &temp_exchange, &config.in_separator)?;
361
362 let mut query = sqlx::query(&prepared.sql);
364 query = bind_json_values(query, &prepared.bindings);
365
366 let result = query.execute(&mut *tx).await.map_err(|e| {
367 error!("Batch query execution failed at index {}: {}", batch_idx, e);
368 CamelError::ProcessorError(format!("Batch query execution failed: {}", e))
369 })?;
370
371 if let Some(expected) = config.expected_update_count
373 && result.rows_affected() as i64 != expected
374 {
375 error!(
376 "Batch item {}: expected {} rows affected, got {}",
377 batch_idx,
378 expected,
379 result.rows_affected()
380 );
381 return Err(CamelError::ProcessorError(format!(
382 "Batch item {}: expected {} rows affected, got {}",
383 batch_idx,
384 expected,
385 result.rows_affected()
386 )));
387 }
388
389 total_rows_affected += result.rows_affected();
390 }
391
392 tx.commit().await.map_err(|e| {
394 error!("Failed to commit transaction: {}", e);
395 CamelError::ProcessorError(format!("Failed to commit transaction: {}", e))
396 })?;
397
398 exchange.input.set_header(
399 headers::UPDATE_COUNT,
400 serde_json::json!(total_rows_affected),
401 );
402
403 debug!(
404 "Batch execution completed, total rows affected: {}",
405 total_rows_affected
406 );
407
408 Ok(())
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use camel_component_api::Message;
    use camel_component_api::UriConfig;
    use sqlx::any::AnyPoolOptions;
    use std::sync::Arc;
    use tokio::sync::OnceCell;

    /// In-memory SQLite pool capped at one connection, presumably so every query
    /// in a test hits the same in-memory database.
    async fn sqlite_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        let connecting = AnyPoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:");
        connecting.await.expect("sqlite pool")
    }

    /// Creates the `items` table and inserts rows (1, 'a') and (2, 'b'), both `done = 0`.
    async fn seed_items_table(pool: &AnyPool) {
        const DDL: &str =
            "CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT, done INTEGER DEFAULT 0)";
        const SEED: &str = "INSERT INTO items (id, name, done) VALUES (1, 'a', 0), (2, 'b', 0)";

        sqlx::query(DDL).execute(pool).await.expect("create table");
        sqlx::query(SEED).execute(pool).await.expect("seed rows");
    }

    /// Fresh SQLite pool with the `items` fixture already loaded.
    async fn seeded_pool() -> AnyPool {
        let db = sqlite_pool().await;
        seed_items_table(&db).await;
        db
    }

    /// Parses `uri` into an endpoint config and resolves its defaults.
    fn resolved_config(uri: &str) -> SqlEndpointConfig {
        let mut cfg = SqlEndpointConfig::from_uri(uri).unwrap();
        cfg.resolve_defaults();
        cfg
    }

    /// Default config used by most unit tests: query `select 1`.
    fn config() -> SqlEndpointConfig {
        resolved_config("sql:select 1?db_url=postgres://localhost/test")
    }

    /// `config()` with `use_message_body_for_sql` switched on.
    fn body_sql_config() -> SqlEndpointConfig {
        let mut cfg = config();
        cfg.use_message_body_for_sql = true;
        cfg
    }

    /// Resolves the SQL for `msg` against `cfg`.
    fn query_for(msg: Message, cfg: &SqlEndpointConfig) -> String {
        SqlProducer::resolve_query_source(&Exchange::new(msg), cfg)
    }

    /// Builds a `PreparedQuery` from literal parts.
    fn prepared_query(sql: &str, bindings: Vec<serde_json::Value>) -> PreparedQuery {
        PreparedQuery {
            sql: sql.to_string(),
            bindings,
        }
    }

    /// Smoke check: binding `values` onto `sql` must not panic.
    fn bind_smoke(sql: &str, values: Vec<serde_json::Value>) {
        let _bound = bind_json_values(sqlx::query(sql), &values);
    }

    #[test]
    fn producer_clone_shares_pool() {
        let original = SqlProducer::new(config(), Arc::new(OnceCell::new()));
        let copy = original.clone();
        assert!(Arc::ptr_eq(&original.pool, &copy.pool));
    }

    #[test]
    fn resolve_query_from_config() {
        assert_eq!(query_for(Message::default(), &config()), "select 1");
    }

    #[test]
    fn resolve_query_from_header() {
        let mut msg = Message::default();
        msg.set_header(headers::QUERY, serde_json::json!("select 2"));
        assert_eq!(query_for(msg, &config()), "select 2");
    }

    #[test]
    fn resolve_query_from_body() {
        let msg = Message::new(Body::Text("select 3".to_string()));
        assert_eq!(query_for(msg, &body_sql_config()), "select 3");
    }

    #[test]
    fn resolve_query_header_priority_over_body() {
        let mut msg = Message::new(Body::Text("select from body".to_string()));
        msg.set_header(headers::QUERY, serde_json::json!("select from header"));
        assert_eq!(query_for(msg, &body_sql_config()), "select from header");
    }

    #[test]
    fn resolve_query_body_priority_over_config() {
        let msg = Message::new(Body::Text("select from body".to_string()));
        assert_eq!(query_for(msg, &body_sql_config()), "select from body");
    }

    #[test]
    fn bind_json_null() {
        bind_smoke("SELECT ?", vec![serde_json::Value::Null]);
    }

    #[test]
    fn bind_json_bool() {
        bind_smoke("SELECT ?", vec![serde_json::Value::Bool(true)]);
    }

    #[test]
    fn bind_json_number_i64() {
        bind_smoke("SELECT ?", vec![serde_json::json!(42)]);
    }

    #[test]
    fn bind_json_number_f64() {
        bind_smoke("SELECT ?", vec![serde_json::json!(std::f64::consts::PI)]);
    }

    #[test]
    fn bind_json_string() {
        bind_smoke("SELECT ?", vec![serde_json::json!("hello world")]);
    }

    #[test]
    fn bind_json_array() {
        bind_smoke("SELECT ?", vec![serde_json::json!([1, 2, 3])]);
    }

    #[test]
    fn bind_json_object() {
        bind_smoke("SELECT ?", vec![serde_json::json!({"key": "value"})]);
    }

    #[test]
    fn bind_multiple_values() {
        bind_smoke(
            "SELECT ?, ?, ?",
            vec![
                serde_json::json!(1),
                serde_json::json!("test"),
                serde_json::Value::Null,
            ],
        );
    }

    #[test]
    fn expected_update_count_validation() {
        let parsed_count =
            |uri: &str| SqlEndpointConfig::from_uri(uri).unwrap().expected_update_count;

        assert_eq!(
            parsed_count(
                "sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=5"
            ),
            Some(5)
        );
        assert_eq!(config().expected_update_count, None);
        assert_eq!(
            parsed_count(
                "sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=-1"
            ),
            Some(-1)
        );
    }

    #[test]
    fn parameters_header_override_logic() {
        let fresh = || prepared_query("SELECT * FROM t WHERE id = $1", vec![serde_json::json!(42)]);
        // Mirrors the producer's rule: only a JSON array replaces the bindings.
        let apply = |target: &mut PreparedQuery, header: &serde_json::Value| {
            if let Some(arr) = header.as_array() {
                target.bindings = arr.clone();
            }
        };

        let mut overridden = fresh();
        apply(&mut overridden, &serde_json::json!([99, "extra"]));
        assert_eq!(overridden.bindings.len(), 2);
        assert_eq!(overridden.bindings[0], serde_json::json!(99));
        assert_eq!(overridden.bindings[1], serde_json::json!("extra"));

        let mut untouched = fresh();
        apply(&mut untouched, &serde_json::json!({"not": "an array"}));
        assert_eq!(untouched.bindings.len(), 1);
        assert_eq!(untouched.bindings[0], serde_json::json!(42));
    }

    #[tokio::test]
    async fn execute_select_one_sets_body_and_row_count() {
        let db = seeded_pool().await;
        let cfg = resolved_config(
            "sql:select id, name from items order by id?db_url=sqlite::memory:&outputType=SelectOne",
        );
        let query = prepared_query("select id, name from items order by id", vec![]);
        let mut exchange = Exchange::new(Message::default());

        execute_select(&db, &query, &cfg, &mut exchange)
            .await
            .expect("select one");

        assert_eq!(exchange.input.header(headers::ROW_COUNT), Some(&json!(2)));
        assert_eq!(
            exchange.input.body,
            Body::Json(json!({"id": 1, "name": "a"}))
        );
    }

    #[tokio::test]
    async fn execute_stream_list_materializes_ndjson() {
        let db = seeded_pool().await;
        let cfg = resolved_config(
            "sql:select id from items order by id?db_url=sqlite::memory:&outputType=StreamList",
        );
        let query = prepared_query("select id from items order by id", vec![]);
        let mut exchange = Exchange::new(Message::default());

        execute_select(&db, &query, &cfg, &mut exchange)
            .await
            .expect("stream list");

        let body_copy = exchange.input.body.clone();
        let raw = body_copy.into_bytes(1024).await.expect("stream bytes");
        let text = String::from_utf8(raw.to_vec()).expect("utf8");
        for line in ["{\"id\":1}", "{\"id\":2}"] {
            assert!(text.contains(line));
        }
        assert_eq!(exchange.input.header(headers::ROW_COUNT), None);
    }

    #[tokio::test]
    async fn execute_modify_expected_update_count_mismatch_returns_error() {
        let db = seeded_pool().await;
        let cfg = resolved_config(
            "sql:update items set done=1 where id = #?db_url=sqlite::memory:&expectedUpdateCount=2",
        );
        let query = prepared_query("update items set done=1 where id = $1", vec![json!(1)]);
        let mut exchange = Exchange::new(Message::default());

        let failure = execute_modify(&db, &query, &cfg, &mut exchange)
            .await
            .expect_err("must fail due expected row count mismatch");
        assert!(
            failure
                .to_string()
                .contains("Expected 2 rows affected, got 1")
        );
    }

    #[tokio::test]
    async fn execute_batch_rollback_when_any_item_fails_expected_count() {
        let db = seeded_pool().await;
        let cfg = resolved_config(
            "sql:update items set done=1 where id = #?db_url=sqlite::memory:&batch=true&expectedUpdateCount=1",
        );
        let mut exchange = Exchange::new(Message::new(Body::Json(json!([[1], [999]]))));

        let failure = execute_batch(&db, &cfg, &mut exchange)
            .await
            .expect_err("second batch item should fail expectedUpdateCount");
        assert!(
            failure
                .to_string()
                .contains("Batch item 1: expected 1 rows affected, got 0")
        );

        let first_item = sqlx::query("select done from items where id = 1")
            .fetch_one(&db)
            .await
            .expect("query row");
        let done: i64 = sqlx::Row::try_get(&first_item, 0).expect("done column");
        assert_eq!(done, 0, "transaction must rollback first update");
    }
}