1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use bytes::Bytes;
8use serde_json::json;
9use sqlx::AnyPool;
10use sqlx::any::AnyRow;
11use sqlx::pool::PoolOptions;
12use tokio::sync::OnceCell;
13use tower::Service;
14use tracing::{debug, error, warn};
15
16use crate::config::{SqlEndpointConfig, SqlOutputType};
17use crate::headers;
18use crate::query::{PreparedQuery, is_select_query, parse_query_template, resolve_params};
19use crate::utils::{bind_json_values, row_to_json};
20use camel_component_api::{Body, CamelError, Exchange, Message, StreamBody, StreamMetadata};
21
22#[derive(Clone)]
23pub struct SqlProducer {
24 pub(crate) config: SqlEndpointConfig,
25 pub(crate) pool: Arc<OnceCell<AnyPool>>,
26}
27
28impl SqlProducer {
29 pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
30 Self { config, pool }
31 }
32
33 pub(crate) fn resolve_query_source(exchange: &Exchange, config: &SqlEndpointConfig) -> String {
38 if let Some(query_value) = exchange.input.header(headers::QUERY)
40 && let Some(query_str) = query_value.as_str()
41 {
42 return query_str.to_string();
43 }
44
45 if config.use_message_body_for_sql
47 && let Some(body_text) = exchange.input.body.as_text()
48 {
49 return body_text.to_string();
50 }
51
52 config.query.clone()
54 }
55}
56
57impl Service<Exchange> for SqlProducer {
58 type Response = Exchange;
59 type Error = CamelError;
60 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
61
62 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63 Poll::Ready(Ok(()))
64 }
65
66 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
67 let mut config = self.config.clone();
68 let pool_cell = Arc::clone(&self.pool);
69
70 Box::pin(async move {
71 let pool: &AnyPool = pool_cell
73 .get_or_try_init(|| async {
74 config.resolve_defaults();
76
77 sqlx::any::install_default_drivers();
80 let opts: PoolOptions<sqlx::Any> = PoolOptions::new()
81 .max_connections(
82 config
83 .max_connections
84 .expect("must be Some after resolve_defaults()"),
85 )
86 .min_connections(
87 config
88 .min_connections
89 .expect("must be Some after resolve_defaults()"),
90 )
91 .idle_timeout(Duration::from_secs(
92 config
93 .idle_timeout_secs
94 .expect("must be Some after resolve_defaults()"),
95 ))
96 .max_lifetime(Duration::from_secs(
97 config
98 .max_lifetime_secs
99 .expect("must be Some after resolve_defaults()"),
100 ));
101 opts.connect(&config.db_url).await.map_err(|e| {
102 error!("Failed to connect to database: {}", e);
103 CamelError::EndpointCreationFailed(format!(
104 "Failed to connect to database: {}",
105 e
106 ))
107 })
108 })
109 .await
110 .map_err(|e: CamelError| {
111 error!("Pool initialization failed: {}", e);
112 e.clone()
113 })?;
114
115 let query_str = Self::resolve_query_source(&exchange, &config);
117
118 debug!("Executing SQL: {}", query_str);
119
120 if config.batch {
122 execute_batch(pool, &config, &mut exchange).await?;
124 } else {
125 let template = parse_query_template(&query_str, config.placeholder)?;
127 let mut prepared = resolve_params(&template, &exchange)?;
128
129 if let Some(params_value) = exchange.input.header(headers::PARAMETERS) {
131 if let Some(arr) = params_value.as_array() {
132 if arr.len() != prepared.bindings.len() {
133 warn!(
134 expected = prepared.bindings.len(),
135 got = arr.len(),
136 header = headers::PARAMETERS,
137 "Parameter count mismatch — SQL has {} placeholders but header provides {} values",
138 prepared.bindings.len(),
139 arr.len()
140 );
141 }
142 debug!(
143 "Overriding bindings from {} header with {} parameters",
144 headers::PARAMETERS,
145 arr.len()
146 );
147 prepared.bindings = arr.clone();
148 } else {
149 warn!(
150 header = headers::PARAMETERS,
151 "Header is present but not a JSON array — ignoring parameter override"
152 );
153 }
154 }
155
156 debug!("Executing SQL: {}", prepared.sql);
157
158 if is_select_query(&prepared.sql) {
159 execute_select(pool, &prepared, &config, &mut exchange).await?;
160 } else {
161 execute_modify(pool, &prepared, &config, &mut exchange).await?;
162 }
163 }
164
165 Ok(exchange)
166 })
167 }
168}
169
170async fn execute_select(
172 pool: &AnyPool,
173 prepared: &PreparedQuery,
174 config: &SqlEndpointConfig,
175 exchange: &mut Exchange,
176) -> Result<(), CamelError> {
177 match config.output_type {
178 SqlOutputType::SelectOne => {
179 let mut query = sqlx::query(&prepared.sql);
181 query = bind_json_values(query, &prepared.bindings);
182
183 let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
184 error!("Query execution failed: {}", e);
185 CamelError::ProcessorError(format!("Query execution failed: {}", e))
186 })?;
187
188 let count = rows.len();
189 let json_rows: Vec<serde_json::Value> = rows
190 .iter()
191 .map(row_to_json)
192 .collect::<Result<Vec<_>, _>>()?;
193
194 if let Some(first_row) = json_rows.into_iter().next() {
195 exchange.input.body = Body::Json(first_row);
196 } else {
197 exchange.input.body = Body::Empty;
198 }
199 debug!("SelectOne returned {} row", if count > 0 { 1 } else { 0 });
200 exchange
201 .input
202 .set_header(headers::ROW_COUNT, serde_json::json!(count));
203 }
204 SqlOutputType::SelectList => {
205 let mut query = sqlx::query(&prepared.sql);
207 query = bind_json_values(query, &prepared.bindings);
208
209 let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
210 error!("Query execution failed: {}", e);
211 CamelError::ProcessorError(format!("Query execution failed: {}", e))
212 })?;
213
214 let count = rows.len();
215 let json_rows: Vec<serde_json::Value> = rows
216 .iter()
217 .map(row_to_json)
218 .collect::<Result<Vec<_>, _>>()?;
219
220 exchange.input.body = Body::Json(serde_json::Value::Array(json_rows));
221 debug!("SelectList returned {} rows", count);
222 exchange
223 .input
224 .set_header(headers::ROW_COUNT, serde_json::json!(count));
225 }
226 SqlOutputType::StreamList => {
227 use futures::TryStreamExt;
229
230 let pool_clone = pool.clone();
231 let sql_str = prepared.sql.clone();
232 let bindings = prepared.bindings.clone();
233
234 let byte_stream = async_stream::try_stream! {
236 let mut q = sqlx::query(&sql_str);
237 q = bind_json_values(q, &bindings);
238 let mut rows = q.fetch(&pool_clone);
239 while let Some(row) = rows.try_next().await.map_err(|e| {
240 CamelError::ProcessorError(format!("Query execution failed: {}", e))
241 })? {
242 let json_val = row_to_json(&row).map_err(|e| {
243 CamelError::ProcessorError(format!("JSON serialization failed: {}", e))
244 })?;
245 let mut bytes = serde_json::to_vec(&json_val)
246 .map_err(|e| CamelError::ProcessorError(format!("JSON serialization failed: {}", e)))?;
247 bytes.push(b'\n');
248 yield Bytes::from(bytes);
249 }
250 };
251
252 exchange.input.body = Body::Stream(StreamBody {
253 stream: Arc::new(tokio::sync::Mutex::new(Some(Box::pin(byte_stream)))),
254 metadata: StreamMetadata {
255 content_type: Some("application/x-ndjson".to_string()),
256 size_hint: None,
257 origin: None,
258 },
259 });
260 debug!("StreamList: created lazy stream (rows fetched on demand)");
261 }
263 }
264
265 Ok(())
266}
267
268async fn execute_modify(
270 pool: &AnyPool,
271 prepared: &PreparedQuery,
272 config: &SqlEndpointConfig,
273 exchange: &mut Exchange,
274) -> Result<(), CamelError> {
275 let mut query = sqlx::query(&prepared.sql);
276 query = bind_json_values(query, &prepared.bindings);
277
278 let result = query.execute(pool).await.map_err(|e| {
279 error!("Query execution failed: {}", e);
280 CamelError::ProcessorError(format!("Query execution failed: {}", e))
281 })?;
282
283 let rows_affected = result.rows_affected();
284
285 if let Some(expected) = config.expected_update_count
287 && rows_affected as i64 != expected
288 {
289 error!("Expected {} rows affected, got {}", expected, rows_affected);
290 return Err(CamelError::ProcessorError(format!(
291 "Expected {} rows affected, got {}",
292 expected, rows_affected
293 )));
294 }
295
296 exchange
297 .input
298 .set_header(headers::UPDATE_COUNT, serde_json::json!(rows_affected));
299
300 if config.noop {
301 } else {
303 exchange.input.body = Body::Json(json!({ "rowsAffected": rows_affected }));
304 }
305
306 debug!("Modify query affected {} rows", rows_affected);
307
308 Ok(())
309}
310
311async fn execute_batch(
313 pool: &AnyPool,
314 config: &SqlEndpointConfig,
315 exchange: &mut Exchange,
316) -> Result<(), CamelError> {
317 let body_json = match &exchange.input.body {
319 Body::Json(val) => val,
320 _ => {
321 return Err(CamelError::ProcessorError(
322 "Batch mode requires body to be a JSON array of arrays".to_string(),
323 ));
324 }
325 };
326
327 let batch_data = body_json
328 .as_array()
329 .ok_or_else(|| {
330 CamelError::ProcessorError("Batch mode requires body to be a JSON array".to_string())
331 })?
332 .clone();
333
334 let template = parse_query_template(&config.query, config.placeholder)?;
336
337 let mut tx = pool.begin().await.map_err(|e| {
339 error!("Failed to begin transaction: {}", e);
340 CamelError::ProcessorError(format!("Failed to begin transaction: {}", e))
341 })?;
342
343 let mut total_rows_affected: u64 = 0;
344
345 for (batch_idx, params_array) in batch_data.into_iter().enumerate() {
346 params_array.as_array().ok_or_else(|| {
348 CamelError::ProcessorError(format!(
349 "Batch item at index {} must be a JSON array of parameters",
350 batch_idx
351 ))
352 })?;
353
354 let temp_msg = Message::new(Body::Json(params_array.clone()));
356 let temp_exchange = Exchange::new(temp_msg);
357
358 let prepared = resolve_params(&template, &temp_exchange)?;
360
361 let mut query = sqlx::query(&prepared.sql);
363 query = bind_json_values(query, &prepared.bindings);
364
365 let result = query.execute(&mut *tx).await.map_err(|e| {
366 error!("Batch query execution failed at index {}: {}", batch_idx, e);
367 CamelError::ProcessorError(format!("Batch query execution failed: {}", e))
368 })?;
369
370 if let Some(expected) = config.expected_update_count
372 && result.rows_affected() as i64 != expected
373 {
374 error!(
375 "Batch item {}: expected {} rows affected, got {}",
376 batch_idx,
377 expected,
378 result.rows_affected()
379 );
380 return Err(CamelError::ProcessorError(format!(
381 "Batch item {}: expected {} rows affected, got {}",
382 batch_idx,
383 expected,
384 result.rows_affected()
385 )));
386 }
387
388 total_rows_affected += result.rows_affected();
389 }
390
391 tx.commit().await.map_err(|e| {
393 error!("Failed to commit transaction: {}", e);
394 CamelError::ProcessorError(format!("Failed to commit transaction: {}", e))
395 })?;
396
397 exchange.input.set_header(
398 headers::UPDATE_COUNT,
399 serde_json::json!(total_rows_affected),
400 );
401
402 debug!(
403 "Batch execution completed, total rows affected: {}",
404 total_rows_affected
405 );
406
407 Ok(())
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use camel_component_api::Message;
    use camel_component_api::UriConfig;
    use sqlx::any::AnyPoolOptions;
    use std::sync::Arc;
    use tokio::sync::OnceCell;

    /// Opens an in-memory SQLite pool limited to a single connection.
    async fn sqlite_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        let options = AnyPoolOptions::new().max_connections(1);
        options
            .connect("sqlite::memory:")
            .await
            .expect("sqlite pool")
    }

    /// Creates the `items` table and inserts rows `(1, 'a', 0)` and `(2, 'b', 0)`.
    async fn seed_items_table(pool: &AnyPool) {
        let create_table =
            "CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT, done INTEGER DEFAULT 0)";
        sqlx::query(create_table)
            .execute(pool)
            .await
            .expect("create table");

        let insert_rows = "INSERT INTO items (id, name, done) VALUES (1, 'a', 0), (2, 'b', 0)";
        sqlx::query(insert_rows)
            .execute(pool)
            .await
            .expect("seed rows");
    }

    /// Baseline endpoint configuration (`select 1`) with defaults resolved.
    fn config() -> SqlEndpointConfig {
        let mut resolved =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        resolved.resolve_defaults();
        resolved
    }

    /// Parses `uri` into a configuration and resolves its defaults.
    fn resolved_config(uri: &str) -> SqlEndpointConfig {
        let mut parsed = SqlEndpointConfig::from_uri(uri).unwrap();
        parsed.resolve_defaults();
        parsed
    }

    /// Smoke check: binding `values` to `sql` must not panic.
    fn assert_binds(sql: &str, values: Vec<serde_json::Value>) {
        let _bound = bind_json_values(sqlx::query(sql), &values);
    }

    #[test]
    fn producer_clone_shares_pool() {
        let original = SqlProducer::new(config(), Arc::new(OnceCell::new()));
        let cloned = original.clone();
        assert!(Arc::ptr_eq(&original.pool, &cloned.pool));
    }

    #[test]
    fn resolve_query_from_config() {
        let cfg = config();
        let exchange = Exchange::new(Message::default());
        let resolved = SqlProducer::resolve_query_source(&exchange, &cfg);
        assert_eq!(resolved, "select 1");
    }

    #[test]
    fn resolve_query_from_header() {
        let cfg = config();
        let mut message = Message::default();
        message.set_header(headers::QUERY, serde_json::json!("select 2"));
        let exchange = Exchange::new(message);
        let resolved = SqlProducer::resolve_query_source(&exchange, &cfg);
        assert_eq!(resolved, "select 2");
    }

    #[test]
    fn resolve_query_from_body() {
        let mut cfg = config();
        cfg.use_message_body_for_sql = true;
        let exchange = Exchange::new(Message::new(Body::Text("select 3".to_string())));
        let resolved = SqlProducer::resolve_query_source(&exchange, &cfg);
        assert_eq!(resolved, "select 3");
    }

    #[test]
    fn resolve_query_header_priority_over_body() {
        let mut cfg = config();
        cfg.use_message_body_for_sql = true;
        let mut message = Message::new(Body::Text("select from body".to_string()));
        message.set_header(headers::QUERY, serde_json::json!("select from header"));
        let exchange = Exchange::new(message);
        let resolved = SqlProducer::resolve_query_source(&exchange, &cfg);
        assert_eq!(resolved, "select from header");
    }

    #[test]
    fn resolve_query_body_priority_over_config() {
        let mut cfg = config();
        cfg.use_message_body_for_sql = true;
        let exchange = Exchange::new(Message::new(Body::Text("select from body".to_string())));
        let resolved = SqlProducer::resolve_query_source(&exchange, &cfg);
        assert_eq!(resolved, "select from body");
    }

    #[test]
    fn bind_json_null() {
        assert_binds("SELECT ?", vec![serde_json::Value::Null]);
    }

    #[test]
    fn bind_json_bool() {
        assert_binds("SELECT ?", vec![serde_json::Value::Bool(true)]);
    }

    #[test]
    fn bind_json_number_i64() {
        assert_binds("SELECT ?", vec![serde_json::json!(42)]);
    }

    #[test]
    fn bind_json_number_f64() {
        assert_binds("SELECT ?", vec![serde_json::json!(std::f64::consts::PI)]);
    }

    #[test]
    fn bind_json_string() {
        assert_binds("SELECT ?", vec![serde_json::json!("hello world")]);
    }

    #[test]
    fn bind_json_array() {
        assert_binds("SELECT ?", vec![serde_json::json!([1, 2, 3])]);
    }

    #[test]
    fn bind_json_object() {
        assert_binds("SELECT ?", vec![serde_json::json!({"key": "value"})]);
    }

    #[test]
    fn bind_multiple_values() {
        assert_binds(
            "SELECT ?, ?, ?",
            vec![
                serde_json::json!(1),
                serde_json::json!("test"),
                serde_json::Value::Null,
            ],
        );
    }

    /// `expectedUpdateCount` is parsed from the URI, is absent by default, and
    /// accepts negative values.
    #[test]
    fn expected_update_count_validation() {
        let parse = |uri: &str| SqlEndpointConfig::from_uri(uri).unwrap();

        let with_five =
            parse("sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=5");
        assert_eq!(with_five.expected_update_count, Some(5));

        let defaults = config();
        assert_eq!(defaults.expected_update_count, None);

        let negative =
            parse("sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=-1");
        assert_eq!(negative.expected_update_count, Some(-1));
    }

    /// Mirrors the override rule locally: an array header replaces the bindings,
    /// anything else leaves them alone.
    #[test]
    fn parameters_header_override_logic() {
        fn single_binding_query() -> PreparedQuery {
            PreparedQuery {
                sql: "SELECT * FROM t WHERE id = $1".to_string(),
                bindings: vec![serde_json::json!(42)],
            }
        }

        fn override_from_header(target: &mut PreparedQuery, header: &serde_json::Value) {
            if let Some(values) = header.as_array() {
                target.bindings = values.clone();
            }
        }

        let mut overridden = single_binding_query();
        override_from_header(&mut overridden, &serde_json::json!([99, "extra"]));
        assert_eq!(overridden.bindings.len(), 2);
        assert_eq!(overridden.bindings[0], serde_json::json!(99));
        assert_eq!(overridden.bindings[1], serde_json::json!("extra"));

        let mut untouched = single_binding_query();
        override_from_header(&mut untouched, &serde_json::json!({"not": "an array"}));
        assert_eq!(untouched.bindings.len(), 1);
        assert_eq!(untouched.bindings[0], serde_json::json!(42));
    }

    #[tokio::test]
    async fn execute_select_one_sets_body_and_row_count() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;

        let cfg = resolved_config(
            "sql:select id, name from items order by id?db_url=sqlite::memory:&outputType=SelectOne",
        );
        let statement = PreparedQuery {
            sql: "select id, name from items order by id".to_string(),
            bindings: vec![],
        };
        let mut exchange = Exchange::new(Message::default());

        execute_select(&pool, &statement, &cfg, &mut exchange)
            .await
            .expect("select one");

        // The header carries the total number of rows, the body only the first one.
        assert_eq!(exchange.input.header(headers::ROW_COUNT), Some(&json!(2)));
        assert_eq!(
            exchange.input.body,
            Body::Json(json!({"id": 1, "name": "a"}))
        );
    }

    #[tokio::test]
    async fn execute_stream_list_materializes_ndjson() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;

        let cfg = resolved_config(
            "sql:select id from items order by id?db_url=sqlite::memory:&outputType=StreamList",
        );
        let statement = PreparedQuery {
            sql: "select id from items order by id".to_string(),
            bindings: vec![],
        };
        let mut exchange = Exchange::new(Message::default());

        execute_select(&pool, &statement, &cfg, &mut exchange)
            .await
            .expect("stream list");

        let collected = exchange
            .input
            .body
            .clone()
            .into_bytes(1024)
            .await
            .expect("stream bytes");
        let ndjson = String::from_utf8(collected.to_vec()).expect("utf8");
        assert!(ndjson.contains("{\"id\":1}"));
        assert!(ndjson.contains("{\"id\":2}"));
        assert_eq!(exchange.input.header(headers::ROW_COUNT), None);
    }

    #[tokio::test]
    async fn execute_modify_expected_update_count_mismatch_returns_error() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;

        let cfg = resolved_config(
            "sql:update items set done=1 where id = #?db_url=sqlite::memory:&expectedUpdateCount=2",
        );
        let statement = PreparedQuery {
            sql: "update items set done=1 where id = $1".to_string(),
            bindings: vec![json!(1)],
        };
        let mut exchange = Exchange::new(Message::default());

        let failure = execute_modify(&pool, &statement, &cfg, &mut exchange)
            .await
            .expect_err("must fail due expected row count mismatch");
        assert!(
            failure
                .to_string()
                .contains("Expected 2 rows affected, got 1")
        );
    }

    #[tokio::test]
    async fn execute_batch_rollback_when_any_item_fails_expected_count() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;

        let cfg = resolved_config(
            "sql:update items set done=1 where id = #?db_url=sqlite::memory:&batch=true&expectedUpdateCount=1",
        );

        // The first item matches row 1; the second targets an id that does not exist.
        let payload = Body::Json(json!([[1], [999]]));
        let mut exchange = Exchange::new(Message::new(payload));

        let failure = execute_batch(&pool, &cfg, &mut exchange)
            .await
            .expect_err("second batch item should fail expectedUpdateCount");
        assert!(
            failure
                .to_string()
                .contains("Batch item 1: expected 1 rows affected, got 0")
        );

        let first_item = sqlx::query("select done from items where id = 1")
            .fetch_one(&pool)
            .await
            .expect("query row");
        let done: i64 = sqlx::Row::try_get(&first_item, 0).expect("done column");
        assert_eq!(done, 0, "transaction must rollback first update");
    }
}