1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::task::{Context, Poll};
6use std::time::Duration;
7
8use bytes::Bytes;
9use serde_json::json;
10use sqlx::AnyPool;
11use sqlx::any::AnyRow;
12use sqlx::pool::PoolOptions;
13use tokio::sync::OnceCell;
14use tower::Service;
15use tracing::{debug, error, info, warn};
16
17use crate::config::{SqlEndpointConfig, SqlOutputType, enrich_db_url_with_ssl, redact_db_url};
18use crate::headers;
19use crate::query::{PreparedQuery, is_select_query, parse_query_template, resolve_params};
20use crate::utils::{bind_json_values, row_to_json};
21use camel_component_api::{Body, CamelError, Exchange, Message, StreamBody, StreamMetadata};
22
23#[derive(Clone)]
24pub struct SqlProducer {
25 pub(crate) config: SqlEndpointConfig,
26 pub(crate) pool: Arc<OnceCell<AnyPool>>,
27 pub(crate) stopped: Arc<AtomicBool>,
28}
29
30impl SqlProducer {
31 pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
32 Self {
33 config,
34 pool,
35 stopped: Arc::new(AtomicBool::new(false)),
36 }
37 }
38
39 pub fn stop(&self) {
40 self.stopped.store(true, Ordering::Relaxed);
41 if let Some(pool) = self.pool.get() {
43 let pool = pool.clone();
44 tokio::spawn(async move {
45 if tokio::time::timeout(Duration::from_secs(5), pool.close())
46 .await
47 .is_err()
48 {
49 tracing::warn!("SQL producer pool did not close within 5s");
50 }
51 });
52 }
53 }
54
55 pub(crate) fn resolve_query_source(exchange: &Exchange, config: &SqlEndpointConfig) -> String {
60 if let Some(query_value) = exchange.input.header(headers::QUERY)
62 && let Some(query_str) = query_value.as_str()
63 {
64 return query_str.to_string();
65 }
66
67 if config.use_message_body_for_sql
69 && let Some(body_text) = exchange.input.body.as_text()
70 {
71 return body_text.to_string();
72 }
73
74 config.query.clone()
76 }
77
78 pub async fn check_connection(&self) -> Result<(), CamelError> {
84 let pool = self.pool.get().ok_or_else(|| {
85 CamelError::ProcessorError("SQL connection pool not initialized".into())
86 })?;
87
88 debug!("Running health check: SELECT 1");
89 sqlx::query("SELECT 1").execute(pool).await.map_err(|e| {
90 warn!(error = %e, "SQL health check failed");
91 CamelError::ProcessorError(format!("SQL health check failed: {}", e))
92 })?;
93
94 debug!("SQL health check passed");
95 Ok(())
96 }
97}
98
99impl Service<Exchange> for SqlProducer {
100 type Response = Exchange;
101 type Error = CamelError;
102 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
103
104 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105 if self.stopped.load(Ordering::Relaxed) {
106 return Poll::Ready(Err(CamelError::ProcessorError(
107 "SQL producer stopped".into(),
108 )));
109 }
110 if let Some(pool) = self.pool.get()
111 && pool.is_closed()
112 {
113 return Poll::Ready(Err(CamelError::ProcessorError(
114 "SQL connection pool is closed".into(),
115 )));
116 }
117 Poll::Ready(Ok(()))
118 }
119
120 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
121 let mut config = self.config.clone();
122 let pool_cell = Arc::clone(&self.pool);
123
124 Box::pin(async move {
125 let pool: &AnyPool = pool_cell
127 .get_or_try_init(|| async {
128 config.resolve_defaults();
130 config.resolve_file_query().await?;
132 let db_url = enrich_db_url_with_ssl(&config.db_url, &config)?;
133
134 sqlx::any::install_default_drivers();
137
138 let max_conn = config.max_connections.ok_or_else(|| {
139 CamelError::Config("max_connections not resolved for SQL pool".into())
140 })?;
141 let min_conn = config.min_connections.ok_or_else(|| {
142 CamelError::Config("min_connections not resolved for SQL pool".into())
143 })?;
144 let idle_timeout = config.idle_timeout_secs.ok_or_else(|| {
145 CamelError::Config("idle_timeout_secs not resolved for SQL pool".into())
146 })?;
147 let max_lifetime = config.max_lifetime_secs.ok_or_else(|| {
148 CamelError::Config("max_lifetime_secs not resolved for SQL pool".into())
149 })?;
150
151 let opts: PoolOptions<sqlx::Any> = PoolOptions::new()
152 .max_connections(max_conn)
153 .min_connections(min_conn)
154 .idle_timeout(Duration::from_secs(idle_timeout))
155 .max_lifetime(Duration::from_secs(max_lifetime));
156
157 info!(
158 db_url = %redact_db_url(&config.db_url),
159 "SQL producer pool initializing"
160 );
161 opts.connect(&db_url).await.map_err(|e| {
162 error!(error = %e, db_url = %redact_db_url(&config.db_url), "Failed to connect to database");
163 CamelError::EndpointCreationFailed(format!(
164 "Failed to connect to database: {}",
165 e
166 ))
167 })
168 })
169 .await
170 .map_err(|e: CamelError| {
171 error!("SQL producer pool initialization failed: {}", e);
172 e.clone()
173 })?;
174
175 let query_str = Self::resolve_query_source(&exchange, &config);
177
178 if config.transaction_mode == crate::config::TransactionMode::Managed {
180 warn!("transactionManager not yet implemented; using Auto mode");
181 }
182
183 debug!(
184 query = %query_str,
185 "executing SQL query"
186 );
187
188 if config.batch {
190 execute_batch(pool, &config, &mut exchange).await?;
192 } else if config.use_placeholder {
193 let template = parse_query_template(&query_str, config.placeholder)?;
195 let mut prepared = resolve_params(&template, &exchange, &config.in_separator)?;
196
197 if let Some(params_value) = exchange.input.header(headers::PARAMETERS) {
199 if let Some(arr) = params_value.as_array() {
200 if arr.len() != prepared.bindings.len() {
201 warn!(
202 expected = prepared.bindings.len(),
203 got = arr.len(),
204 header = headers::PARAMETERS,
205 "Parameter count mismatch — SQL has {} placeholders but header provides {} values",
206 prepared.bindings.len(),
207 arr.len()
208 );
209 }
210 debug!(
211 "Overriding bindings from {} header with {} parameters",
212 headers::PARAMETERS,
213 arr.len()
214 );
215 prepared.bindings = arr.clone();
216 } else {
217 warn!(
218 header = headers::PARAMETERS,
219 "Header is present but not a JSON array — ignoring parameter override"
220 );
221 }
222 }
223
224 debug!(
225 "Executing prepared SQL ({} bindings)",
226 prepared.bindings.len()
227 );
228
229 if is_select_query(&prepared.sql) {
230 execute_select(pool, &prepared, &config, &mut exchange).await?;
231 } else {
232 execute_modify(pool, &prepared, &config, &mut exchange).await?;
233 }
234 } else {
235 debug!("Executing raw SQL (placeholder processing disabled)");
237 let prepared = PreparedQuery {
238 sql: query_str,
239 bindings: vec![],
240 };
241
242 if is_select_query(&prepared.sql) {
243 execute_select(pool, &prepared, &config, &mut exchange).await?;
244 } else {
245 execute_modify(pool, &prepared, &config, &mut exchange).await?;
246 }
247 }
248
249 Ok(exchange)
250 })
251 }
252}
253
254async fn execute_select(
256 pool: &AnyPool,
257 prepared: &PreparedQuery,
258 config: &SqlEndpointConfig,
259 exchange: &mut Exchange,
260) -> Result<(), CamelError> {
261 match config.output_type {
262 SqlOutputType::SelectOne => {
263 let mut query = sqlx::query(&prepared.sql);
265 query = bind_json_values(query, &prepared.bindings);
266
267 let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
268 warn!(error = %e, "SQL query failed");
269 CamelError::ProcessorError(format!("Query execution failed: {}", e))
270 })?;
271
272 let count = rows.len();
273 debug!(rows = count, "SQL query completed");
274 let json_rows: Vec<serde_json::Value> = rows
275 .iter()
276 .map(row_to_json)
277 .collect::<Result<Vec<_>, _>>()?;
278
279 if let Some(first_row) = json_rows.into_iter().next() {
280 exchange.input.body = Body::Json(first_row);
281 } else {
282 exchange.input.body = Body::Empty;
283 }
284 debug!("SelectOne returned {} row", if count > 0 { 1 } else { 0 });
285 exchange
286 .input
287 .set_header(headers::ROW_COUNT, serde_json::json!(count));
288 }
289 SqlOutputType::SelectList => {
290 let mut query = sqlx::query(&prepared.sql);
292 query = bind_json_values(query, &prepared.bindings);
293
294 let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
295 warn!(error = %e, "SQL query failed");
296 CamelError::ProcessorError(format!("Query execution failed: {}", e))
297 })?;
298
299 let count = rows.len();
300 debug!(rows = count, "SQL query completed");
301 let json_rows: Vec<serde_json::Value> = rows
302 .iter()
303 .map(row_to_json)
304 .collect::<Result<Vec<_>, _>>()?;
305
306 exchange.input.body = Body::Json(serde_json::Value::Array(json_rows));
307 debug!("SelectList returned {} rows", count);
308 exchange
309 .input
310 .set_header(headers::ROW_COUNT, serde_json::json!(count));
311 }
312 SqlOutputType::StreamList => {
313 use futures::TryStreamExt;
315
316 let pool_clone = pool.clone();
317 let sql_str = prepared.sql.clone();
318 let bindings = prepared.bindings.clone();
319
320 let byte_stream = async_stream::try_stream! {
322 let mut q = sqlx::query(&sql_str);
323 q = bind_json_values(q, &bindings);
324 let mut rows = q.fetch(&pool_clone);
325 while let Some(row) = rows.try_next().await.map_err(|e| {
326 CamelError::ProcessorError(format!("Query execution failed: {}", e))
327 })? {
328 let json_val = row_to_json(&row).map_err(|e| {
329 CamelError::ProcessorError(format!("JSON serialization failed: {}", e))
330 })?;
331 let mut bytes = serde_json::to_vec(&json_val)
332 .map_err(|e| CamelError::ProcessorError(format!("JSON serialization failed: {}", e)))?;
333 bytes.push(b'\n');
334 yield Bytes::from(bytes);
335 }
336 };
337
338 exchange.input.body = Body::Stream(StreamBody {
339 stream: Arc::new(tokio::sync::Mutex::new(Some(Box::pin(byte_stream)))),
340 metadata: StreamMetadata {
341 content_type: Some("application/x-ndjson".to_string()),
342 size_hint: None,
343 origin: None,
344 },
345 });
346 debug!("StreamList: created lazy stream (rows fetched on demand)");
347 }
349 }
350
351 Ok(())
352}
353
354async fn execute_modify(
356 pool: &AnyPool,
357 prepared: &PreparedQuery,
358 config: &SqlEndpointConfig,
359 exchange: &mut Exchange,
360) -> Result<(), CamelError> {
361 let mut query = sqlx::query(&prepared.sql);
362 query = bind_json_values(query, &prepared.bindings);
363
364 let result = query.execute(pool).await.map_err(|e| {
365 warn!(error = %e, "SQL query failed");
366 CamelError::ProcessorError(format!("Query execution failed: {}", e))
367 })?;
368
369 let rows_affected = result.rows_affected();
370
371 if let Some(expected) = config.expected_update_count
373 && rows_affected as i64 != expected
374 {
375 warn!(expected, actual = rows_affected, "Row count mismatch");
376 return Err(CamelError::ProcessorError(format!(
377 "Expected {} rows affected, got {}",
378 expected, rows_affected
379 )));
380 }
381
382 exchange
383 .input
384 .set_header(headers::UPDATE_COUNT, serde_json::json!(rows_affected));
385
386 if config.noop {
387 } else {
389 exchange.input.body = Body::Json(json!({ "rowsAffected": rows_affected }));
390 }
391
392 debug!(rows = rows_affected, "SQL modify query completed");
393
394 Ok(())
395}
396
397async fn execute_batch(
399 pool: &AnyPool,
400 config: &SqlEndpointConfig,
401 exchange: &mut Exchange,
402) -> Result<(), CamelError> {
403 let body_json = match &exchange.input.body {
405 Body::Json(val) => val,
406 _ => {
407 return Err(CamelError::ProcessorError(
408 "Batch mode requires body to be a JSON array of arrays".to_string(),
409 ));
410 }
411 };
412
413 let batch_data = body_json
414 .as_array()
415 .ok_or_else(|| {
416 CamelError::ProcessorError("Batch mode requires body to be a JSON array".to_string())
417 })?
418 .clone();
419
420 let template = parse_query_template(&config.query, config.placeholder)?;
422
423 let mut tx = pool.begin().await.map_err(|e| {
425 error!("Failed to begin transaction: {}", e);
426 CamelError::ProcessorError(format!("Failed to begin transaction: {}", e))
427 })?;
428
429 let mut total_rows_affected: u64 = 0;
430
431 for (batch_idx, params_array) in batch_data.into_iter().enumerate() {
432 params_array.as_array().ok_or_else(|| {
434 CamelError::ProcessorError(format!(
435 "Batch item at index {} must be a JSON array of parameters",
436 batch_idx
437 ))
438 })?;
439
440 let temp_msg = Message::new(Body::Json(params_array.clone()));
442 let temp_exchange = Exchange::new(temp_msg);
443
444 let prepared = resolve_params(&template, &temp_exchange, &config.in_separator)?;
446
447 let mut query = sqlx::query(&prepared.sql);
449 query = bind_json_values(query, &prepared.bindings);
450
451 let result = query.execute(&mut *tx).await.map_err(|e| {
452 error!("Batch query execution failed at index {}: {}", batch_idx, e);
453 CamelError::ProcessorError(format!("Batch query execution failed: {}", e))
454 })?;
455
456 if let Some(expected) = config.expected_update_count
458 && result.rows_affected() as i64 != expected
459 {
460 error!(
461 "Batch item {}: expected {} rows affected, got {}",
462 batch_idx,
463 expected,
464 result.rows_affected()
465 );
466 return Err(CamelError::ProcessorError(format!(
467 "Batch item {}: expected {} rows affected, got {}",
468 batch_idx,
469 expected,
470 result.rows_affected()
471 )));
472 }
473
474 total_rows_affected += result.rows_affected();
475 }
476
477 tx.commit().await.map_err(|e| {
479 error!("Failed to commit transaction: {}", e);
480 CamelError::ProcessorError(format!("Failed to commit transaction: {}", e))
481 })?;
482
483 exchange.input.set_header(
484 headers::UPDATE_COUNT,
485 serde_json::json!(total_rows_affected),
486 );
487
488 debug!(
489 "Batch execution completed, total rows affected: {}",
490 total_rows_affected
491 );
492
493 Ok(())
494}
495
// Unit tests for the SQL producer. Tests that need a database use an in-memory SQLite
// pool opened through sqlx's `Any` driver.
#[cfg(test)]
mod tests {
    use super::*;
    use camel_component_api::Message;
    use camel_component_api::UriConfig;
    use sqlx::any::AnyPoolOptions;
    use std::sync::Arc;
    use tokio::sync::OnceCell;

    // Opens an in-memory SQLite pool limited to a single connection. Presumably the
    // single connection keeps every query on the same in-memory database — confirm.
    async fn sqlite_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        AnyPoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("sqlite pool")
    }

    // Creates `items(id, name, done)` and inserts rows (1, 'a', 0) and (2, 'b', 0).
    async fn seed_items_table(pool: &AnyPool) {
        sqlx::query(
            "CREATE TABLE items (id INTEGER PRIMARY KEY, name TEXT, done INTEGER DEFAULT 0)",
        )
        .execute(pool)
        .await
        .expect("create table");
        sqlx::query("INSERT INTO items (id, name, done) VALUES (1, 'a', 0), (2, 'b', 0)")
            .execute(pool)
            .await
            .expect("seed rows");
    }

    // Endpoint config for `select 1` with defaults resolved. None of the tests using this
    // helper connect to the Postgres URL.
    fn config() -> SqlEndpointConfig {
        let mut c =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        c.resolve_defaults();
        c
    }

    // Clones must share the same pool cell and the same stop flag.
    #[test]
    fn producer_clone_shares_pool() {
        let p1 = SqlProducer::new(config(), Arc::new(OnceCell::new()));
        let p2 = p1.clone();
        assert!(Arc::ptr_eq(&p1.pool, &p2.pool));
        assert!(Arc::ptr_eq(&p1.stopped, &p2.stopped));
    }

    // --- Query source precedence: QUERY header > body (when enabled) > configured query ---

    #[test]
    fn resolve_query_from_config() {
        let config = config();
        let ex = Exchange::new(Message::default());
        let q = SqlProducer::resolve_query_source(&ex, &config);
        assert_eq!(q, "select 1");
    }

    #[test]
    fn resolve_query_from_header() {
        let config = config();
        let mut msg = Message::default();
        msg.set_header(headers::QUERY, serde_json::json!("select 2"));
        let ex = Exchange::new(msg);
        let q = SqlProducer::resolve_query_source(&ex, &config);
        assert_eq!(q, "select 2");
    }

    #[test]
    fn resolve_query_from_body() {
        let mut config = config();
        config.use_message_body_for_sql = true;
        let msg = Message::new(Body::Text("select 3".to_string()));
        let ex = Exchange::new(msg);
        let q = SqlProducer::resolve_query_source(&ex, &config);
        assert_eq!(q, "select 3");
    }

    #[test]
    fn resolve_query_header_priority_over_body() {
        let mut config = config();
        config.use_message_body_for_sql = true;
        let mut msg = Message::new(Body::Text("select from body".to_string()));
        msg.set_header(headers::QUERY, serde_json::json!("select from header"));
        let ex = Exchange::new(msg);
        let q = SqlProducer::resolve_query_source(&ex, &config);
        assert_eq!(q, "select from header");
    }

    #[test]
    fn resolve_query_body_priority_over_config() {
        let mut config = config();
        config.use_message_body_for_sql = true;
        let msg = Message::new(Body::Text("select from body".to_string()));
        let ex = Exchange::new(msg);
        let q = SqlProducer::resolve_query_source(&ex, &config);
        assert_eq!(q, "select from body");
    }

    // --- bind_json_values smoke tests: each JSON value kind can be bound without
    // panicking. The queries are built but never executed. ---

    #[test]
    fn bind_json_null() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::Value::Null];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_bool() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::Value::Bool(true)];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_number_i64() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::json!(42)];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_number_f64() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::json!(std::f64::consts::PI)];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_string() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::json!("hello world")];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_array() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::json!([1, 2, 3])];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_json_object() {
        let query = sqlx::query("SELECT ?");
        let values = vec![serde_json::json!({"key": "value"})];
        let _bound = bind_json_values(query, &values);
    }

    #[test]
    fn bind_multiple_values() {
        let query = sqlx::query("SELECT ?, ?, ?");
        let values = vec![
            serde_json::json!(1),
            serde_json::json!("test"),
            serde_json::Value::Null,
        ];
        let _bound = bind_json_values(query, &values);
    }

    // `expectedUpdateCount` URI parsing: absent means `None`; negative values are kept as-is.
    #[test]
    fn expected_update_count_validation() {
        let config = SqlEndpointConfig::from_uri(
            "sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=5",
        )
        .unwrap();
        assert_eq!(config.expected_update_count, Some(5));

        let config_default = self::config();
        assert_eq!(config_default.expected_update_count, None);

        let config_neg = SqlEndpointConfig::from_uri(
            "sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=-1",
        )
        .unwrap();
        assert_eq!(config_neg.expected_update_count, Some(-1));
    }

    // Reproduces the PARAMETERS-header override rule on a local `PreparedQuery` (it does
    // not call the producer): a JSON array replaces the bindings, including when the
    // lengths differ; a non-array value leaves them untouched.
    #[test]
    fn parameters_header_override_logic() {
        let mut prepared = PreparedQuery {
            sql: "SELECT * FROM t WHERE id = $1".to_string(),
            bindings: vec![serde_json::json!(42)],
        };

        let header_params = serde_json::json!([99, "extra"]);
        if let Some(arr) = header_params.as_array() {
            prepared.bindings = arr.clone();
        }

        assert_eq!(prepared.bindings.len(), 2);
        assert_eq!(prepared.bindings[0], serde_json::json!(99));
        assert_eq!(prepared.bindings[1], serde_json::json!("extra"));

        let mut prepared2 = PreparedQuery {
            sql: "SELECT * FROM t WHERE id = $1".to_string(),
            bindings: vec![serde_json::json!(42)],
        };
        let header_non_array = serde_json::json!({"not": "an array"});
        if let Some(arr) = header_non_array.as_array() {
            prepared2.bindings = arr.clone();
        }
        assert_eq!(prepared2.bindings.len(), 1);
        assert_eq!(prepared2.bindings[0], serde_json::json!(42));
    }

    // SelectOne: the body is the first row; ROW_COUNT is the total number of rows (2).
    #[tokio::test]
    async fn execute_select_one_sets_body_and_row_count() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, name from items order by id?db_url=sqlite::memory:&outputType=SelectOne",
        )
        .unwrap();
        config.resolve_defaults();

        let prepared = PreparedQuery {
            sql: "select id, name from items order by id".to_string(),
            bindings: vec![],
        };
        let mut exchange = Exchange::new(Message::default());

        execute_select(&pool, &prepared, &config, &mut exchange)
            .await
            .expect("select one");

        assert_eq!(exchange.input.header(headers::ROW_COUNT), Some(&json!(2)));
        assert_eq!(
            exchange.input.body,
            Body::Json(json!({"id": 1, "name": "a"}))
        );
    }

    // StreamList: the body is a stream of one JSON object per line, and no ROW_COUNT header
    // is set. The `1024` passed to `into_bytes` is presumably a byte limit — confirm.
    #[tokio::test]
    async fn execute_stream_list_materializes_ndjson() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id from items order by id?db_url=sqlite::memory:&outputType=StreamList",
        )
        .unwrap();
        config.resolve_defaults();

        let prepared = PreparedQuery {
            sql: "select id from items order by id".to_string(),
            bindings: vec![],
        };
        let mut exchange = Exchange::new(Message::default());

        execute_select(&pool, &prepared, &config, &mut exchange)
            .await
            .expect("stream list");

        let bytes = exchange
            .input
            .body
            .clone()
            .into_bytes(1024)
            .await
            .expect("stream bytes");
        let text = String::from_utf8(bytes.to_vec()).expect("utf8");
        assert!(text.contains("{\"id\":1}"));
        assert!(text.contains("{\"id\":2}"));
        assert_eq!(exchange.input.header(headers::ROW_COUNT), None);
    }

    // Updating one row while expectedUpdateCount=2 must fail. The prepared SQL is passed in
    // directly, so the `#` query in the URI is not used here.
    #[tokio::test]
    async fn execute_modify_expected_update_count_mismatch_returns_error() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:update items set done=1 where id = #?db_url=sqlite::memory:&expectedUpdateCount=2",
        )
        .unwrap();
        config.resolve_defaults();

        let prepared = PreparedQuery {
            sql: "update items set done=1 where id = $1".to_string(),
            bindings: vec![json!(1)],
        };
        let mut exchange = Exchange::new(Message::default());

        let err = execute_modify(&pool, &prepared, &config, &mut exchange)
            .await
            .expect_err("must fail due expected row count mismatch");
        assert!(err.to_string().contains("Expected 2 rows affected, got 1"));
    }

    // Batch [[1], [999]] with expectedUpdateCount=1: the second item matches no row. The
    // whole batch must fail, and the first item's update must not persist.
    #[tokio::test]
    async fn execute_batch_rollback_when_any_item_fails_expected_count() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:update items set done=1 where id = #?db_url=sqlite::memory:&batch=true&expectedUpdateCount=1",
        )
        .unwrap();
        config.resolve_defaults();

        let mut exchange = Exchange::new(Message::new(Body::Json(json!([[1], [999]]))));

        let err = execute_batch(&pool, &config, &mut exchange)
            .await
            .expect_err("second batch item should fail expectedUpdateCount");
        assert!(
            err.to_string()
                .contains("Batch item 1: expected 1 rows affected, got 0")
        );

        let row = sqlx::query("select done from items where id = 1")
            .fetch_one(&pool)
            .await
            .expect("query row");
        let done: i64 = sqlx::Row::try_get(&row, 0).expect("done column");
        assert_eq!(done, 0, "transaction must rollback first update");
    }

    // The pool settings stay `None` until defaults are resolved. `call` must handle an
    // unresolved config without panicking.
    #[tokio::test]
    async fn producer_no_panic_without_prior_resolve_defaults() {
        let config = SqlEndpointConfig::from_uri("sql:select 1?db_url=sqlite::memory:").unwrap();
        assert!(config.max_connections.is_none());

        let mut producer = SqlProducer::new(config, Arc::new(OnceCell::new()));
        let exchange = Exchange::new(Message::default());

        let result = producer.call(exchange).await;
        assert!(
            result.is_ok(),
            "Producer should initialize pool without panic, got: {:?}",
            result
        );
    }

    // An unreachable Postgres host must surface as an error from `call`. Despite the test
    // name, the assertion accepts any message containing "Failed to connect" or "database".
    // NOTE(review): this test tries a real connection to an unresolvable host; its runtime
    // depends on DNS and connect timeouts.
    #[tokio::test]
    async fn producer_pool_init_returns_config_error_for_invalid_db() {
        let mut config = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://nonexistent-host:5432/nonexistent_db",
        )
        .unwrap();
        config.max_connections = Some(1);
        config.min_connections = Some(0);
        config.idle_timeout_secs = Some(300);
        config.max_lifetime_secs = Some(1800);

        let mut producer = SqlProducer::new(config, Arc::new(OnceCell::new()));
        let exchange = Exchange::new(Message::default());

        let result = producer.call(exchange).await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("Failed to connect") || err_msg.contains("database"),
            "Expected connection error, got: {}",
            err_msg
        );
    }

    // --- poll_ready: an empty pool cell is ready; stopped or closed pools are not ---

    #[test]
    fn poll_ready_returns_ready_for_uninitialized_pool() {
        let config = {
            let mut c = SqlEndpointConfig::from_uri("sql:select 1?db_url=sqlite::memory:").unwrap();
            c.resolve_defaults();
            c
        };
        let mut producer = SqlProducer::new(config, Arc::new(OnceCell::new()));
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let result = producer.poll_ready(&mut cx);
        assert!(matches!(result, Poll::Ready(Ok(()))));
    }

    // Runs without a Tokio runtime. `stop()` is safe here only because the pool cell is
    // empty, so nothing is spawned.
    #[test]
    fn poll_ready_returns_error_when_stopped() {
        let config = {
            let mut c = SqlEndpointConfig::from_uri("sql:select 1?db_url=sqlite::memory:").unwrap();
            c.resolve_defaults();
            c
        };
        let mut producer = SqlProducer::new(config, Arc::new(OnceCell::new()));
        producer.stop();
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let result = producer.poll_ready(&mut cx);
        assert!(matches!(result, Poll::Ready(Err(_))));
        let err_msg = match result {
            Poll::Ready(Err(e)) => e.to_string(),
            _ => unreachable!(),
        };
        assert!(err_msg.contains("SQL producer stopped"));
    }

    #[tokio::test]
    async fn poll_ready_returns_error_when_pool_closed() {
        let pool = sqlite_pool().await;
        pool.close().await;

        let config = {
            let mut c = SqlEndpointConfig::from_uri("sql:select 1?db_url=sqlite::memory:").unwrap();
            c.resolve_defaults();
            c
        };
        let pool_cell = Arc::new(OnceCell::new());
        pool_cell.set(pool).unwrap();

        let mut producer = SqlProducer::new(config, pool_cell);
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let result = producer.poll_ready(&mut cx);
        assert!(matches!(result, Poll::Ready(Err(_))));
        let err_msg = match result {
            Poll::Ready(Err(e)) => e.to_string(),
            _ => unreachable!(),
        };
        assert!(err_msg.contains("SQL connection pool is closed"));
    }

    #[tokio::test]
    async fn poll_ready_returns_ok_for_healthy_pool() {
        let pool = sqlite_pool().await;

        let config = {
            let mut c = SqlEndpointConfig::from_uri("sql:select 1?db_url=sqlite::memory:").unwrap();
            c.resolve_defaults();
            c
        };
        let pool_cell = Arc::new(OnceCell::new());
        pool_cell.set(pool).unwrap();

        let mut producer = SqlProducer::new(config, pool_cell);
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let result = producer.poll_ready(&mut cx);
        assert!(matches!(result, Poll::Ready(Ok(()))));
    }

    // `stop()` closes the pool in a background task. The 100 ms sleep gives that task
    // time to run, so this test is timing-dependent. A second producer sharing the same
    // cell must then see the closed pool.
    #[tokio::test]
    async fn test_sql_stop_closes_pool() {
        let pool = sqlite_pool().await;

        let config = {
            let mut c = SqlEndpointConfig::from_uri("sql:select 1?db_url=sqlite::memory:").unwrap();
            c.resolve_defaults();
            c
        };
        let pool_cell = Arc::new(OnceCell::new());
        pool_cell.set(pool.clone()).unwrap();

        let producer = SqlProducer::new(config, pool_cell.clone());
        assert!(!pool.is_closed(), "Pool should be open before stop");

        producer.stop();

        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(
            pool.is_closed(),
            "Pool should be closed after producer.stop()"
        );

        let mut producer2 = SqlProducer::new(
            {
                let mut c =
                    SqlEndpointConfig::from_uri("sql:select 1?db_url=sqlite::memory:").unwrap();
                c.resolve_defaults();
                c
            },
            pool_cell.clone(),
        );
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let result = producer2.poll_ready(&mut cx);
        assert!(
            matches!(result, Poll::Ready(Err(_))),
            "poll_ready should fail after pool closed"
        );
    }

    // With usePlaceholder=false the configured query runs verbatim. The pool cell is
    // pre-populated with the seeded pool before `call`.
    #[tokio::test]
    async fn use_placeholder_false_executes_raw_sql() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, name from items order by id?db_url=sqlite::memory:&usePlaceholder=false",
        )
        .unwrap();
        config.resolve_defaults();
        assert!(!config.use_placeholder);

        let mut producer = SqlProducer::new(config, Arc::new(OnceCell::new()));
        producer.pool.set(pool.clone()).unwrap();

        let exchange = Exchange::new(Message::default());
        let result = producer.call(exchange).await;
        assert!(result.is_ok());
        let exchange = result.unwrap();
        assert!(matches!(exchange.input.body, Body::Json(_)));
    }

    // Placeholders are on by default here. The `#` parameter is taken from the JSON
    // array in the body.
    #[tokio::test]
    async fn use_placeholder_true_processes_placeholders() {
        let pool = sqlite_pool().await;
        seed_items_table(&pool).await;

        let mut config = SqlEndpointConfig::from_uri(
            "sql:select id, name from items where id = #?db_url=sqlite::memory:",
        )
        .unwrap();
        config.resolve_defaults();
        assert!(config.use_placeholder);

        let mut producer = SqlProducer::new(config, Arc::new(OnceCell::new()));
        producer.pool.set(pool.clone()).unwrap();

        let msg = Message::new(Body::Json(json!([1])));
        let exchange = Exchange::new(msg);
        let result = producer.call(exchange).await;
        assert!(result.is_ok());
    }
}