1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use bytes::Bytes;
8use serde_json::json;
9use sqlx::AnyPool;
10use sqlx::any::AnyRow;
11use sqlx::pool::PoolOptions;
12use tokio::sync::OnceCell;
13use tower::Service;
14use tracing::{debug, error, warn};
15
16use crate::config::{SqlEndpointConfig, SqlOutputType};
17use crate::headers;
18use crate::query::{PreparedQuery, is_select_query, parse_query_template, resolve_params};
19use crate::utils::{bind_json_values, row_to_json};
20use camel_api::{Body, CamelError, Exchange, Message, StreamBody, StreamMetadata};
21
/// Producer that runs SQL statements for an endpoint through a shared,
/// lazily-initialised connection pool.
///
/// `Clone` is derived; because `pool` is an `Arc`, every clone refers to the
/// same pool cell.
#[derive(Clone)]
pub struct SqlProducer {
    /// Endpoint configuration; a copy is taken at the start of every `call`.
    pub(crate) config: SqlEndpointConfig,
    /// Pool cell that is filled on first use inside `Service::call`.
    pub(crate) pool: Arc<OnceCell<AnyPool>>,
}
27
28impl SqlProducer {
29 pub fn new(config: SqlEndpointConfig, pool: Arc<OnceCell<AnyPool>>) -> Self {
30 Self { config, pool }
31 }
32
33 pub(crate) fn resolve_query_source(exchange: &Exchange, config: &SqlEndpointConfig) -> String {
38 if let Some(query_value) = exchange.input.header(headers::QUERY)
40 && let Some(query_str) = query_value.as_str()
41 {
42 return query_str.to_string();
43 }
44
45 if config.use_message_body_for_sql
47 && let Some(body_text) = exchange.input.body.as_text()
48 {
49 return body_text.to_string();
50 }
51
52 config.query.clone()
54 }
55}
56
57impl Service<Exchange> for SqlProducer {
58 type Response = Exchange;
59 type Error = CamelError;
60 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
61
62 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63 Poll::Ready(Ok(()))
64 }
65
66 fn call(&mut self, mut exchange: Exchange) -> Self::Future {
67 let mut config = self.config.clone();
68 let pool_cell = Arc::clone(&self.pool);
69
70 Box::pin(async move {
71 let pool: &AnyPool = pool_cell
73 .get_or_try_init(|| async {
74 config.resolve_defaults();
76
77 sqlx::any::install_default_drivers();
80 let opts: PoolOptions<sqlx::Any> = PoolOptions::new()
81 .max_connections(
82 config
83 .max_connections
84 .expect("must be Some after resolve_defaults()"),
85 )
86 .min_connections(
87 config
88 .min_connections
89 .expect("must be Some after resolve_defaults()"),
90 )
91 .idle_timeout(Duration::from_secs(
92 config
93 .idle_timeout_secs
94 .expect("must be Some after resolve_defaults()"),
95 ))
96 .max_lifetime(Duration::from_secs(
97 config
98 .max_lifetime_secs
99 .expect("must be Some after resolve_defaults()"),
100 ));
101 opts.connect(&config.db_url).await.map_err(|e| {
102 error!("Failed to connect to database: {}", e);
103 CamelError::EndpointCreationFailed(format!(
104 "Failed to connect to database: {}",
105 e
106 ))
107 })
108 })
109 .await
110 .map_err(|e: CamelError| {
111 error!("Pool initialization failed: {}", e);
112 e.clone()
113 })?;
114
115 let query_str = Self::resolve_query_source(&exchange, &config);
117
118 debug!("Executing SQL: {}", query_str);
119
120 if config.batch {
122 execute_batch(pool, &config, &mut exchange).await?;
124 } else {
125 let template = parse_query_template(&query_str, config.placeholder)?;
127 let mut prepared = resolve_params(&template, &exchange)?;
128
129 if let Some(params_value) = exchange.input.header(headers::PARAMETERS) {
131 if let Some(arr) = params_value.as_array() {
132 if arr.len() != prepared.bindings.len() {
133 warn!(
134 expected = prepared.bindings.len(),
135 got = arr.len(),
136 header = headers::PARAMETERS,
137 "Parameter count mismatch — SQL has {} placeholders but header provides {} values",
138 prepared.bindings.len(),
139 arr.len()
140 );
141 }
142 debug!(
143 "Overriding bindings from {} header with {} parameters",
144 headers::PARAMETERS,
145 arr.len()
146 );
147 prepared.bindings = arr.clone();
148 } else {
149 warn!(
150 header = headers::PARAMETERS,
151 "Header is present but not a JSON array — ignoring parameter override"
152 );
153 }
154 }
155
156 debug!("Executing SQL: {}", prepared.sql);
157
158 if is_select_query(&prepared.sql) {
159 execute_select(pool, &prepared, &config, &mut exchange).await?;
160 } else {
161 execute_modify(pool, &prepared, &config, &mut exchange).await?;
162 }
163 }
164
165 Ok(exchange)
166 })
167 }
168}
169
170async fn execute_select(
172 pool: &AnyPool,
173 prepared: &PreparedQuery,
174 config: &SqlEndpointConfig,
175 exchange: &mut Exchange,
176) -> Result<(), CamelError> {
177 match config.output_type {
178 SqlOutputType::SelectOne => {
179 let mut query = sqlx::query(&prepared.sql);
181 query = bind_json_values(query, &prepared.bindings);
182
183 let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
184 error!("Query execution failed: {}", e);
185 CamelError::ProcessorError(format!("Query execution failed: {}", e))
186 })?;
187
188 let count = rows.len();
189 let json_rows: Vec<serde_json::Value> = rows
190 .iter()
191 .map(row_to_json)
192 .collect::<Result<Vec<_>, _>>()?;
193
194 if let Some(first_row) = json_rows.into_iter().next() {
195 exchange.input.body = Body::Json(first_row);
196 } else {
197 exchange.input.body = Body::Empty;
198 }
199 debug!("SelectOne returned {} row", if count > 0 { 1 } else { 0 });
200 exchange
201 .input
202 .set_header(headers::ROW_COUNT, serde_json::json!(count));
203 }
204 SqlOutputType::SelectList => {
205 let mut query = sqlx::query(&prepared.sql);
207 query = bind_json_values(query, &prepared.bindings);
208
209 let rows: Vec<AnyRow> = query.fetch_all(pool).await.map_err(|e| {
210 error!("Query execution failed: {}", e);
211 CamelError::ProcessorError(format!("Query execution failed: {}", e))
212 })?;
213
214 let count = rows.len();
215 let json_rows: Vec<serde_json::Value> = rows
216 .iter()
217 .map(row_to_json)
218 .collect::<Result<Vec<_>, _>>()?;
219
220 exchange.input.body = Body::Json(serde_json::Value::Array(json_rows));
221 debug!("SelectList returned {} rows", count);
222 exchange
223 .input
224 .set_header(headers::ROW_COUNT, serde_json::json!(count));
225 }
226 SqlOutputType::StreamList => {
227 use futures::TryStreamExt;
229
230 let pool_clone = pool.clone();
231 let sql_str = prepared.sql.clone();
232 let bindings = prepared.bindings.clone();
233
234 let byte_stream = async_stream::try_stream! {
236 let mut q = sqlx::query(&sql_str);
237 q = bind_json_values(q, &bindings);
238 let mut rows = q.fetch(&pool_clone);
239 while let Some(row) = rows.try_next().await.map_err(|e| {
240 CamelError::ProcessorError(format!("Query execution failed: {}", e))
241 })? {
242 let json_val = row_to_json(&row).map_err(|e| {
243 CamelError::ProcessorError(format!("JSON serialization failed: {}", e))
244 })?;
245 let mut bytes = serde_json::to_vec(&json_val)
246 .map_err(|e| CamelError::ProcessorError(format!("JSON serialization failed: {}", e)))?;
247 bytes.push(b'\n');
248 yield Bytes::from(bytes);
249 }
250 };
251
252 exchange.input.body = Body::Stream(StreamBody {
253 stream: Arc::new(tokio::sync::Mutex::new(Some(Box::pin(byte_stream)))),
254 metadata: StreamMetadata {
255 content_type: Some("application/x-ndjson".to_string()),
256 size_hint: None,
257 origin: None,
258 },
259 });
260 debug!("StreamList: created lazy stream (rows fetched on demand)");
261 }
263 }
264
265 Ok(())
266}
267
268async fn execute_modify(
270 pool: &AnyPool,
271 prepared: &PreparedQuery,
272 config: &SqlEndpointConfig,
273 exchange: &mut Exchange,
274) -> Result<(), CamelError> {
275 let mut query = sqlx::query(&prepared.sql);
276 query = bind_json_values(query, &prepared.bindings);
277
278 let result = query.execute(pool).await.map_err(|e| {
279 error!("Query execution failed: {}", e);
280 CamelError::ProcessorError(format!("Query execution failed: {}", e))
281 })?;
282
283 let rows_affected = result.rows_affected();
284
285 if let Some(expected) = config.expected_update_count
287 && rows_affected as i64 != expected
288 {
289 error!("Expected {} rows affected, got {}", expected, rows_affected);
290 return Err(CamelError::ProcessorError(format!(
291 "Expected {} rows affected, got {}",
292 expected, rows_affected
293 )));
294 }
295
296 exchange
297 .input
298 .set_header(headers::UPDATE_COUNT, serde_json::json!(rows_affected));
299
300 if config.noop {
301 } else {
303 exchange.input.body = Body::Json(json!({ "rowsAffected": rows_affected }));
304 }
305
306 debug!("Modify query affected {} rows", rows_affected);
307
308 Ok(())
309}
310
311async fn execute_batch(
313 pool: &AnyPool,
314 config: &SqlEndpointConfig,
315 exchange: &mut Exchange,
316) -> Result<(), CamelError> {
317 let body_json = match &exchange.input.body {
319 Body::Json(val) => val,
320 _ => {
321 return Err(CamelError::ProcessorError(
322 "Batch mode requires body to be a JSON array of arrays".to_string(),
323 ));
324 }
325 };
326
327 let batch_data = body_json
328 .as_array()
329 .ok_or_else(|| {
330 CamelError::ProcessorError("Batch mode requires body to be a JSON array".to_string())
331 })?
332 .clone();
333
334 let template = parse_query_template(&config.query, config.placeholder)?;
336
337 let mut tx = pool.begin().await.map_err(|e| {
339 error!("Failed to begin transaction: {}", e);
340 CamelError::ProcessorError(format!("Failed to begin transaction: {}", e))
341 })?;
342
343 let mut total_rows_affected: u64 = 0;
344
345 for (batch_idx, params_array) in batch_data.into_iter().enumerate() {
346 params_array.as_array().ok_or_else(|| {
348 CamelError::ProcessorError(format!(
349 "Batch item at index {} must be a JSON array of parameters",
350 batch_idx
351 ))
352 })?;
353
354 let temp_msg = Message::new(Body::Json(params_array.clone()));
356 let temp_exchange = Exchange::new(temp_msg);
357
358 let prepared = resolve_params(&template, &temp_exchange)?;
360
361 let mut query = sqlx::query(&prepared.sql);
363 query = bind_json_values(query, &prepared.bindings);
364
365 let result = query.execute(&mut *tx).await.map_err(|e| {
366 error!("Batch query execution failed at index {}: {}", batch_idx, e);
367 CamelError::ProcessorError(format!("Batch query execution failed: {}", e))
368 })?;
369
370 if let Some(expected) = config.expected_update_count
372 && result.rows_affected() as i64 != expected
373 {
374 error!(
375 "Batch item {}: expected {} rows affected, got {}",
376 batch_idx,
377 expected,
378 result.rows_affected()
379 );
380 return Err(CamelError::ProcessorError(format!(
381 "Batch item {}: expected {} rows affected, got {}",
382 batch_idx,
383 expected,
384 result.rows_affected()
385 )));
386 }
387
388 total_rows_affected += result.rows_affected();
389 }
390
391 tx.commit().await.map_err(|e| {
393 error!("Failed to commit transaction: {}", e);
394 CamelError::ProcessorError(format!("Failed to commit transaction: {}", e))
395 })?;
396
397 exchange.input.set_header(
398 headers::UPDATE_COUNT,
399 serde_json::json!(total_rows_affected),
400 );
401
402 debug!(
403 "Batch execution completed, total rows affected: {}",
404 total_rows_affected
405 );
406
407 Ok(())
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::Message;
    use camel_endpoint::UriConfig;
    use std::sync::Arc;
    use tokio::sync::OnceCell;

    /// Baseline endpoint config with defaults resolved.
    fn test_config() -> SqlEndpointConfig {
        let mut cfg =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        cfg.resolve_defaults();
        cfg
    }

    /// Baseline config with `use_message_body_for_sql` switched on.
    fn body_sql_config() -> SqlEndpointConfig {
        let mut cfg = test_config();
        cfg.use_message_body_for_sql = true;
        cfg
    }

    /// Wraps `msg` in an exchange and resolves the query text for it.
    fn resolve(msg: Message, config: &SqlEndpointConfig) -> String {
        let exchange = Exchange::new(msg);
        SqlProducer::resolve_query_source(&exchange, config)
    }

    /// Smoke check: binding `values` to `sql` must not panic.
    fn bind_all(sql: &'static str, values: Vec<serde_json::Value>) {
        let query = sqlx::query(sql);
        let _bound = bind_json_values(query, &values);
    }

    /// Smoke check for a single bound value.
    fn bind_one(value: serde_json::Value) {
        bind_all("SELECT ?", vec![value]);
    }

    #[test]
    fn test_producer_clone_shares_pool() {
        let original = SqlProducer::new(test_config(), Arc::new(OnceCell::new()));
        let copy = original.clone();
        assert!(Arc::ptr_eq(&original.pool, &copy.pool));
    }

    #[test]
    fn test_resolve_query_from_config() {
        let resolved = resolve(Message::default(), &test_config());
        assert_eq!(resolved, "select 1");
    }

    #[test]
    fn test_resolve_query_from_header() {
        let mut msg = Message::default();
        msg.set_header(headers::QUERY, serde_json::json!("select 2"));
        let resolved = resolve(msg, &test_config());
        assert_eq!(resolved, "select 2");
    }

    #[test]
    fn test_resolve_query_from_body() {
        let msg = Message::new(Body::Text("select 3".to_string()));
        let resolved = resolve(msg, &body_sql_config());
        assert_eq!(resolved, "select 3");
    }

    #[test]
    fn test_resolve_query_header_priority_over_body() {
        let mut msg = Message::new(Body::Text("select from body".to_string()));
        msg.set_header(headers::QUERY, serde_json::json!("select from header"));
        let resolved = resolve(msg, &body_sql_config());
        assert_eq!(resolved, "select from header");
    }

    #[test]
    fn test_resolve_query_body_priority_over_config() {
        let msg = Message::new(Body::Text("select from body".to_string()));
        let resolved = resolve(msg, &body_sql_config());
        assert_eq!(resolved, "select from body");
    }

    #[test]
    fn test_bind_json_null() {
        bind_one(serde_json::Value::Null);
    }

    #[test]
    fn test_bind_json_bool() {
        bind_one(serde_json::Value::Bool(true));
    }

    #[test]
    fn test_bind_json_number_i64() {
        bind_one(serde_json::json!(42));
    }

    #[test]
    fn test_bind_json_number_f64() {
        bind_one(serde_json::json!(std::f64::consts::PI));
    }

    #[test]
    fn test_bind_json_string() {
        bind_one(serde_json::json!("hello world"));
    }

    #[test]
    fn test_bind_json_array() {
        bind_one(serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn test_bind_json_object() {
        bind_one(serde_json::json!({"key": "value"}));
    }

    #[test]
    fn test_bind_multiple_values() {
        bind_all(
            "SELECT ?, ?, ?",
            vec![
                serde_json::json!(1),
                serde_json::json!("test"),
                serde_json::Value::Null,
            ],
        );
    }

    #[test]
    fn test_expected_update_count_validation() {
        // Explicit positive value is parsed from the URI.
        let with_count = SqlEndpointConfig::from_uri(
            "sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=5",
        )
        .unwrap();
        assert_eq!(with_count.expected_update_count, Some(5));

        // Absent option stays unset.
        assert_eq!(test_config().expected_update_count, None);

        // Negative values are accepted by the parser as-is.
        let negative = SqlEndpointConfig::from_uri(
            "sql:update t set x=1?db_url=postgres://localhost/test&expectedUpdateCount=-1",
        )
        .unwrap();
        assert_eq!(negative.expected_update_count, Some(-1));
    }

    #[test]
    fn test_parameters_header_override_logic() {
        /// Fresh single-binding query used by both scenarios below.
        fn single_binding_query() -> PreparedQuery {
            PreparedQuery {
                sql: "SELECT * FROM t WHERE id = $1".to_string(),
                bindings: vec![serde_json::json!(42)],
            }
        }

        // An array header value replaces the bindings wholesale.
        let mut overridden = single_binding_query();
        let array_header = serde_json::json!([99, "extra"]);
        if let Some(params) = array_header.as_array() {
            overridden.bindings = params.clone();
        }
        assert_eq!(overridden.bindings.len(), 2);
        assert_eq!(overridden.bindings[0], serde_json::json!(99));
        assert_eq!(overridden.bindings[1], serde_json::json!("extra"));

        // A non-array header value leaves the bindings untouched.
        let mut untouched = single_binding_query();
        let object_header = serde_json::json!({"not": "an array"});
        if let Some(params) = object_header.as_array() {
            untouched.bindings = params.clone();
        }
        assert_eq!(untouched.bindings.len(), 1);
        assert_eq!(untouched.bindings[0], serde_json::json!(42));
    }
}