1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use bytes::Bytes;
8use serde_json::json;
9use sqlx::AnyPool;
10use sqlx::any::AnyRow;
11use sqlx::pool::PoolOptions;
12use tokio::sync::OnceCell;
13use tower::Service;
14use tracing::{debug, error, warn};
15
16use crate::config::{SqlConfig, SqlOutputType};
17use crate::headers;
18use crate::query::{PreparedQuery, is_select_query, parse_query_template, resolve_params};
19use crate::utils::{bind_json_values, row_to_json};
20use camel_api::{Body, CamelError, Exchange, Message, StreamBody, StreamMetadata};
21
22#[derive(Clone)]
23pub struct SqlProducer {
24 pub(crate) config: SqlConfig,
25 pub(crate) pool: Arc<OnceCell<AnyPool>>,
26}
27
28impl SqlProducer {
29 pub fn new(config: SqlConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
30 Self { config, pool }
31 }
32
33 pub(crate) fn resolve_query_source(exchange: &Exchange, config: &SqlConfig) -> String {
38 if let Some(query_value) = exchange.input.header(headers::QUERY)
40 && let Some(query_str) = query_value.as_str()
41 {
42 return query_str.to_string();
43 }
44
45 if config.use_message_body_for_sql
47 && let Some(body_text) = exchange.input.body.as_text()
48 {
49 return body_text.to_string();
50 }
51
52 config.query.clone()
54 }
55}
56
57impl Service<Exchange> for SqlProducer {
58 type Response = Exchange;
59 type Error = CamelError;
60 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
61
62 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63 Poll::Ready(Ok(()))
64 }
65
66 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
67 let config = self.config.clone();
68 let pool_cell = Arc::clone(&self.pool);
69
70 Box::pin(async move {
71 let pool: &AnyPool = pool_cell
73 .get_or_try_init(|| async {
74 sqlx::any::install_default_drivers();
77 let opts: PoolOptions<sqlx::Any> = PoolOptions::new()
78 .max_connections(config.max_connections)
79 .min_connections(config.min_connections)
80 .idle_timeout(Duration::from_secs(config.idle_timeout_secs))
81 .max_lifetime(Duration::from_secs(config.max_lifetime_secs));
82 opts.connect(&config.db_url).await.map_err(|e| {
83 error!("Failed to connect to database: {}", e);
84 CamelError::EndpointCreationFailed(format!(
85 "Failed to connect to database: {}",
86 e
87 ))
88 })
89 })
90 .await
91 .map_err(|e: CamelError| {
92 error!("Pool initialization failed: {}", e);
93 e.clone()
94 })?;
95
96 let query_str = Self::resolve_query_source(&exchange, &config);
98
99 debug!("Executing SQL: {}", query_str);
100
101 if config.batch {
103 execute_batch(pool, &config, &mut exchange).await?;
105 } else {
106 let template = parse_query_template(&query_str, config.placeholder)?;
108 let mut prepared = resolve_params(&template, &exchange)?;
109
110 if let Some(params_value) = exchange.input.header(headers::PARAMETERS) {
112 if let Some(arr) = params_value.as_array() {
113 if arr.len() != prepared.bindings.len() {
114 warn!(
115 expected = prepared.bindings.len(),
116 got = arr.len(),
117 header = headers::PARAMETERS,
118 "Parameter count mismatch — SQL has {} placeholders but header provides {} values",
119 prepared.bindings.len(),
120 arr.len()
121 );
122 }
123 debug!(
124 "Overriding bindings from {} header with {} parameters",
125 headers::PARAMETERS,
126 arr.len()
127 );
128 prepared.bindings = arr.clone();
129 } else {
130 warn!(
131 header = headers::PARAMETERS,
132 "Header is present but not a JSON array — ignoring parameter override"
133 );
134 }
135 }
136
137 debug!("Executing SQL: {}", prepared.sql);
138
139 if is_select_query(&prepared.sql) {
140 execute_select(pool, &prepared, &config, &mut exchange).await?;
141 } else {
142 execute_modify(pool, &prepared, &config, &mut exchange).await?;
143 }
144 }
145
146 Ok(exchange)
147 })
148 }
149}
150
151async fn execute_select(
153 pool: &AnyPool,
154 prepared: &PreparedQuery,
155 config: &SqlConfig,
156 exchange: &mut Exchange,
157) -> Result<(), CamelError> {
158 match config.output_type {
159 SqlOutputType::SelectOne => {
160 let mut query = sqlx::query(&prepared.sql);
162 query = bind_json_values(query, &prepared.bindings);
163
164 let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
165 error!("Query execution failed: {}", e);
166 CamelError::ProcessorError(format!("Query execution failed: {}", e))
167 })?;
168
169 let count = rows.len();
170 let json_rows: Vec<serde_json::Value> = rows
171 .iter()
172 .map(row_to_json)
173 .collect::<Result<Vec<_>, _>>()?;
174
175 if let Some(first_row) = json_rows.into_iter().next() {
176 exchange.input.body = Body::Json(first_row);
177 } else {
178 exchange.input.body = Body::Empty;
179 }
180 debug!("SelectOne returned {} row", if count > 0 { 1 } else { 0 });
181 exchange
182 .input
183 .set_header(headers::ROW_COUNT, serde_json::json!(count));
184 }
185 SqlOutputType::SelectList => {
186 let mut query = sqlx::query(&prepared.sql);
188 query = bind_json_values(query, &prepared.bindings);
189
190 let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
191 error!("Query execution failed: {}", e);
192 CamelError::ProcessorError(format!("Query execution failed: {}", e))
193 })?;
194
195 let count = rows.len();
196 let json_rows: Vec<serde_json::Value> = rows
197 .iter()
198 .map(row_to_json)
199 .collect::<Result<Vec<_>, _>>()?;
200
201 exchange.input.body = Body::Json(serde_json::Value::Array(json_rows));
202 debug!("SelectList returned {} rows", count);
203 exchange
204 .input
205 .set_header(headers::ROW_COUNT, serde_json::json!(count));
206 }
207 SqlOutputType::StreamList => {
208 use futures::TryStreamExt;
210
211 let pool_clone = pool.clone();
212 let sql_str = prepared.sql.clone();
213 let bindings = prepared.bindings.clone();
214
215 let byte_stream = async_stream::try_stream! {
217 let mut q = sqlx::query(&sql_str);
218 q = bind_json_values(q, &bindings);
219 let mut rows = q.fetch(&pool_clone);
220 while let Some(row) = rows.try_next().await.map_err(|e| {
221 CamelError::ProcessorError(format!("Query execution failed: {}", e))
222 })? {
223 let json_val = row_to_json(&row).map_err(|e| {
224 CamelError::ProcessorError(format!("JSON serialization failed: {}", e))
225 })?;
226 let mut bytes = serde_json::to_vec(&json_val)
227 .map_err(|e| CamelError::ProcessorError(format!("JSON serialization failed: {}", e)))?;
228 bytes.push(b'\n');
229 yield Bytes::from(bytes);
230 }
231 };
232
233 exchange.input.body = Body::Stream(StreamBody {
234 stream: Arc::new(tokio::sync::Mutex::new(Some(Box::pin(byte_stream)))),
235 metadata: StreamMetadata {
236 content_type: Some("application/x-ndjson".to_string()),
237 size_hint: None,
238 origin: None,
239 },
240 });
241 debug!("StreamList: created lazy stream (rows fetched on demand)");
242 }
244 }
245
246 Ok(())
247}
248
249async fn execute_modify(
251 pool: &AnyPool,
252 prepared: &PreparedQuery,
253 config: &SqlConfig,
254 exchange: &mut Exchange,
255) -> Result<(), CamelError> {
256 let mut query = sqlx::query(&prepared.sql);
257 query = bind_json_values(query, &prepared.bindings);
258
259 let result = query.execute(pool).await.map_err(|e| {
260 error!("Query execution failed: {}", e);
261 CamelError::ProcessorError(format!("Query execution failed: {}", e))
262 })?;
263
264 let rows_affected = result.rows_affected();
265
266 if let Some(expected) = config.expected_update_count
268 && rows_affected as i64 != expected
269 {
270 error!("Expected {} rows affected, got {}", expected, rows_affected);
271 return Err(CamelError::ProcessorError(format!(
272 "Expected {} rows affected, got {}",
273 expected, rows_affected
274 )));
275 }
276
277 exchange
278 .input
279 .set_header(headers::UPDATE_COUNT, serde_json::json!(rows_affected));
280
281 if config.noop {
282 } else {
284 exchange.input.body = Body::Json(json!({ "rowsAffected": rows_affected }));
285 }
286
287 debug!("Modify query affected {} rows", rows_affected);
288
289 Ok(())
290}
291
292async fn execute_batch(
294 pool: &AnyPool,
295 config: &SqlConfig,
296 exchange: &mut Exchange,
297) -> Result<(), CamelError> {
298 let body_json = match &exchange.input.body {
300 Body::Json(val) => val,
301 _ => {
302 return Err(CamelError::ProcessorError(
303 "Batch mode requires body to be a JSON array of arrays".to_string(),
304 ));
305 }
306 };
307
308 let batch_data = body_json
309 .as_array()
310 .ok_or_else(|| {
311 CamelError::ProcessorError("Batch mode requires body to be a JSON array".to_string())
312 })?
313 .clone();
314
315 let template = parse_query_template(&config.query, config.placeholder)?;
317
318 let mut tx = pool.begin().await.map_err(|e| {
320 error!("Failed to begin transaction: {}", e);
321 CamelError::ProcessorError(format!("Failed to begin transaction: {}", e))
322 })?;
323
324 let mut total_rows_affected: u64 = 0;
325
326 for (batch_idx, params_array) in batch_data.into_iter().enumerate() {
327 params_array.as_array().ok_or_else(|| {
329 CamelError::ProcessorError(format!(
330 "Batch item at index {} must be a JSON array of parameters",
331 batch_idx
332 ))
333 })?;
334
335 let temp_msg = Message::new(Body::Json(params_array.clone()));
337 let temp_exchange = Exchange::new(temp_msg);
338
339 let prepared = resolve_params(&template, &temp_exchange)?;
341
342 let mut query = sqlx::query(&prepared.sql);
344 query = bind_json_values(query, &prepared.bindings);
345
346 let result = query.execute(&mut *tx).await.map_err(|e| {
347 error!("Batch query execution failed at index {}: {}", batch_idx, e);
348 CamelError::ProcessorError(format!("Batch query execution failed: {}", e))
349 })?;
350
351 if let Some(expected) = config.expected_update_count
353 && result.rows_affected() as i64 != expected
354 {
355 error!(
356 "Batch item {}: expected {} rows affected, got {}",
357 batch_idx,
358 expected,
359 result.rows_affected()
360 );
361 return Err(CamelError::ProcessorError(format!(
362 "Batch item {}: expected {} rows affected, got {}",
363 batch_idx,
364 expected,
365 result.rows_affected()
366 )));
367 }
368
369 total_rows_affected += result.rows_affected();
370 }
371
372 tx.commit().await.map_err(|e| {
374 error!("Failed to commit transaction: {}", e);
375 CamelError::ProcessorError(format!("Failed to commit transaction: {}", e))
376 })?;
377
378 exchange.input.set_header(
379 headers::UPDATE_COUNT,
380 serde_json::json!(total_rows_affected),
381 );
382
383 debug!(
384 "Batch execution completed, total rows affected: {}",
385 total_rows_affected
386 );
387
388 Ok(())
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::Message;
    use std::sync::Arc;
    use tokio::sync::OnceCell;

    /// Config for `select 1`; the Postgres URL is only parsed, never connected to.
    fn test_config() -> SqlConfig {
        SqlConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap()
    }

    /// `test_config()` with `use_message_body_for_sql` switched on.
    fn body_sql_config() -> SqlConfig {
        let mut config = test_config();
        config.use_message_body_for_sql = true;
        config
    }

    /// Wraps `msg` in an exchange and resolves its SQL against `config`.
    fn resolve(msg: Message, config: &SqlConfig) -> String {
        SqlProducer::resolve_query_source(&Exchange::new(msg), config)
    }

    /// Smoke check: binding `values` to `sql` must not panic; the bound query is discarded.
    fn bind_smoke(sql: &str, values: Vec<serde_json::Value>) {
        let _bound = bind_json_values(sqlx::query(sql), &values);
    }

    #[test]
    fn test_producer_clone_shares_pool() {
        let original = SqlProducer::new(test_config(), Arc::new(OnceCell::new()));
        let cloned = original.clone();
        assert!(Arc::ptr_eq(&original.pool, &cloned.pool));
    }

    #[test]
    fn test_resolve_query_from_config() {
        assert_eq!(resolve(Message::default(), &test_config()), "select 1");
    }

    #[test]
    fn test_resolve_query_from_header() {
        let mut msg = Message::default();
        msg.set_header(headers::QUERY, serde_json::json!("select 2"));
        assert_eq!(resolve(msg, &test_config()), "select 2");
    }

    #[test]
    fn test_resolve_query_from_body() {
        let msg = Message::new(Body::Text("select 3".to_string()));
        assert_eq!(resolve(msg, &body_sql_config()), "select 3");
    }

    #[test]
    fn test_resolve_query_header_priority_over_body() {
        let mut msg = Message::new(Body::Text("select from body".to_string()));
        msg.set_header(headers::QUERY, serde_json::json!("select from header"));
        assert_eq!(resolve(msg, &body_sql_config()), "select from header");
    }

    #[test]
    fn test_resolve_query_body_priority_over_config() {
        let msg = Message::new(Body::Text("select from body".to_string()));
        assert_eq!(resolve(msg, &body_sql_config()), "select from body");
    }

    #[test]
    fn test_bind_json_null() {
        bind_smoke("SELECT ?", vec![serde_json::Value::Null]);
    }

    #[test]
    fn test_bind_json_bool() {
        bind_smoke("SELECT ?", vec![serde_json::Value::Bool(true)]);
    }

    #[test]
    fn test_bind_json_number_i64() {
        bind_smoke("SELECT ?", vec![serde_json::json!(42)]);
    }

    #[test]
    fn test_bind_json_number_f64() {
        bind_smoke("SELECT ?", vec![serde_json::json!(std::f64::consts::PI)]);
    }

    #[test]
    fn test_bind_json_string() {
        bind_smoke("SELECT ?", vec![serde_json::json!("hello world")]);
    }

    #[test]
    fn test_bind_json_array() {
        bind_smoke("SELECT ?", vec![serde_json::json!([1, 2, 3])]);
    }

    #[test]
    fn test_bind_json_object() {
        bind_smoke("SELECT ?", vec![serde_json::json!({"key": "value"})]);
    }

    #[test]
    fn test_bind_multiple_values() {
        bind_smoke(
            "SELECT ?, ?, ?",
            vec![
                serde_json::json!(1),
                serde_json::json!("test"),
                serde_json::Value::Null,
            ],
        );
    }

    #[test]
    fn test_expected_update_count_validation() {
        let parse = |uri: &str| SqlConfig::from_uri(uri).unwrap().expected_update_count;

        assert_eq!(
            parse("sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=5"),
            Some(5)
        );
        // Not configured at all.
        assert_eq!(test_config().expected_update_count, None);
        // Negative values are accepted by the parser as-is.
        assert_eq!(
            parse("sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=-1"),
            Some(-1)
        );
    }

    #[test]
    fn test_parameters_header_override_logic() {
        let fresh = || PreparedQuery {
            sql: "SELECT * FROM t WHERE id = $1".to_string(),
            bindings: vec![serde_json::json!(42)],
        };

        // An array header value replaces the bindings wholesale, even with a different count.
        let mut overridden = fresh();
        let array_header = serde_json::json!([99, "extra"]);
        if let Some(arr) = array_header.as_array() {
            overridden.bindings = arr.clone();
        }
        assert_eq!(
            overridden.bindings,
            vec![serde_json::json!(99), serde_json::json!("extra")]
        );

        // A non-array header value leaves the template bindings untouched.
        let mut untouched = fresh();
        let object_header = serde_json::json!({"not": "an array"});
        if let Some(arr) = object_header.as_array() {
            untouched.bindings = arr.clone();
        }
        assert_eq!(untouched.bindings, vec![serde_json::json!(42)]);
    }
}