1use std::str::FromStr;
2
3use camel_component_api::CamelError;
4use camel_component_api::{UriComponents, UriConfig, parse_uri};
5
/// Result shape requested through the `outputType` URI parameter.
///
/// The variant names are also the exact (case-sensitive) spellings accepted by
/// the [`FromStr`] implementation; [`SqlOutputType::SelectList`] is used when the
/// parameter is absent.
#[derive(Default, Debug, Clone, PartialEq)]
pub enum SqlOutputType {
    /// `outputType=SelectList` — also the default.
    #[default]
    SelectList,
    /// `outputType=SelectOne`.
    SelectOne,
    /// `outputType=StreamList`.
    StreamList,
}
17
18impl FromStr for SqlOutputType {
19 type Err = CamelError;
20
21 fn from_str(s: &str) -> Result<Self, Self::Err> {
22 match s {
23 "SelectList" => Ok(SqlOutputType::SelectList),
24 "SelectOne" => Ok(SqlOutputType::SelectOne),
25 "StreamList" => Ok(SqlOutputType::StreamList),
26 _ => Err(CamelError::InvalidUri(format!(
27 "Unknown output type: {}",
28 s
29 ))),
30 }
31 }
32}
33
34#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
39#[serde(default)]
40pub struct SqlGlobalConfig {
41 pub max_connections: u32,
42 pub min_connections: u32,
43 pub idle_timeout_secs: u64,
44 pub max_lifetime_secs: u64,
45 pub ssl_mode: Option<String>,
47 pub ssl_root_cert: Option<String>,
48 pub ssl_cert: Option<String>,
49 pub ssl_key: Option<String>,
50}
51
52impl Default for SqlGlobalConfig {
53 fn default() -> Self {
54 Self {
55 max_connections: 5,
56 min_connections: 1,
57 idle_timeout_secs: 300,
58 max_lifetime_secs: 1800,
59 ssl_mode: None,
60 ssl_root_cert: None,
61 ssl_cert: None,
62 ssl_key: None,
63 }
64 }
65}
66
67impl SqlGlobalConfig {
68 pub fn new() -> Self {
69 Self::default()
70 }
71
72 pub fn with_max_connections(mut self, value: u32) -> Self {
73 self.max_connections = value;
74 self
75 }
76
77 pub fn with_min_connections(mut self, value: u32) -> Self {
78 self.min_connections = value;
79 self
80 }
81
82 pub fn with_idle_timeout_secs(mut self, value: u64) -> Self {
83 self.idle_timeout_secs = value;
84 self
85 }
86
87 pub fn with_max_lifetime_secs(mut self, value: u64) -> Self {
88 self.max_lifetime_secs = value;
89 self
90 }
91
92 pub fn with_ssl_mode(mut self, value: impl Into<String>) -> Self {
93 self.ssl_mode = Some(value.into());
94 self
95 }
96
97 pub fn with_ssl_root_cert(mut self, value: impl Into<String>) -> Self {
98 self.ssl_root_cert = Some(value.into());
99 self
100 }
101
102 pub fn with_ssl_cert(mut self, value: impl Into<String>) -> Self {
103 self.ssl_cert = Some(value.into());
104 self
105 }
106
107 pub fn with_ssl_key(mut self, value: impl Into<String>) -> Self {
108 self.ssl_key = Some(value.into());
109 self
110 }
111}
112
113#[derive(Debug, Clone)]
121pub struct SqlEndpointConfig {
122 pub db_url: String,
125 pub max_connections: Option<u32>,
127 pub min_connections: Option<u32>,
129 pub idle_timeout_secs: Option<u64>,
131 pub max_lifetime_secs: Option<u64>,
133
134 pub query: String,
137 pub source_path: Option<String>,
139 pub output_type: SqlOutputType,
141 pub placeholder: char,
143 pub noop: bool,
145 pub in_separator: String,
147
148 pub delay_ms: u64,
151 pub initial_delay_ms: u64,
153 pub max_messages_per_poll: Option<i32>,
155 pub on_consume: Option<String>,
157 pub on_consume_failed: Option<String>,
159 pub on_consume_batch_complete: Option<String>,
161 pub route_empty_result_set: bool,
163 pub use_iterator: bool,
165 pub expected_update_count: Option<i64>,
167 pub break_batch_on_consume_fail: bool,
169
170 pub batch: bool,
173 pub use_message_body_for_sql: bool,
175
176 pub ssl_mode: Option<String>,
179 pub ssl_root_cert: Option<String>,
181 pub ssl_cert: Option<String>,
183 pub ssl_key: Option<String>,
185}
186
187impl SqlEndpointConfig {
188 pub fn apply_defaults(&mut self, defaults: &SqlGlobalConfig) {
190 if self.max_connections.is_none() {
191 self.max_connections = Some(defaults.max_connections);
192 }
193 if self.min_connections.is_none() {
194 self.min_connections = Some(defaults.min_connections);
195 }
196 if self.idle_timeout_secs.is_none() {
197 self.idle_timeout_secs = Some(defaults.idle_timeout_secs);
198 }
199 if self.max_lifetime_secs.is_none() {
200 self.max_lifetime_secs = Some(defaults.max_lifetime_secs);
201 }
202 if self.ssl_mode.is_none() {
203 self.ssl_mode = defaults.ssl_mode.clone();
204 }
205 if self.ssl_root_cert.is_none() {
206 self.ssl_root_cert = defaults.ssl_root_cert.clone();
207 }
208 if self.ssl_cert.is_none() {
209 self.ssl_cert = defaults.ssl_cert.clone();
210 }
211 if self.ssl_key.is_none() {
212 self.ssl_key = defaults.ssl_key.clone();
213 }
214 }
215
216 pub fn resolve_defaults(&mut self) {
218 let defaults = SqlGlobalConfig::default();
219 self.apply_defaults(&defaults);
220 }
221}
222
/// Driver-specific query-string keys for one Camel-style SSL option.
struct SslParamMapping {
    /// Key written into `postgres://` / `postgresql://` URLs.
    pg_key: &'static str,
    /// Key written into `mysql://` URLs.
    mysql_key: &'static str,
}

impl SslParamMapping {
    /// Builds a mapping in constant context so [`SSL_MAPPINGS`] can stay a `const`.
    const fn new(pg_key: &'static str, mysql_key: &'static str) -> Self {
        Self { pg_key, mysql_key }
    }
}

/// Camel-style SSL option name paired with its per-driver query keys.
const SSL_MAPPINGS: &[(&str, SslParamMapping)] = &[
    ("sslMode", SslParamMapping::new("sslmode", "ssl-mode")),
    ("sslRootCert", SslParamMapping::new("sslrootcert", "ssl-ca")),
    ("sslCert", SslParamMapping::new("sslcert", "ssl-cert")),
    ("sslKey", SslParamMapping::new("sslkey", "ssl-key")),
];
258
259pub fn enrich_db_url_with_ssl(
260 db_url: &str,
261 config: &SqlEndpointConfig,
262) -> Result<String, CamelError> {
263 let ssl_params: Vec<(&str, &str)> = [
264 config.ssl_mode.as_deref().map(|v| ("sslMode", v)),
265 config.ssl_root_cert.as_deref().map(|v| ("sslRootCert", v)),
266 config.ssl_cert.as_deref().map(|v| ("sslCert", v)),
267 config.ssl_key.as_deref().map(|v| ("sslKey", v)),
268 ]
269 .into_iter()
270 .flatten()
271 .collect();
272
273 if ssl_params.is_empty() {
274 return Ok(db_url.to_string());
275 }
276
277 let mut parsed = url::Url::parse(db_url).map_err(|e| {
278 CamelError::InvalidUri(format!(
279 "Cannot parse database URL for SSL enrichment: {}",
280 e
281 ))
282 })?;
283
284 let scheme = parsed.scheme();
285 if scheme != "postgres" && scheme != "postgresql" && scheme != "mysql" {
286 return Ok(db_url.to_string());
287 }
288 let is_mysql = scheme == "mysql";
289
290 let mut query_pairs = parsed.query_pairs().collect::<Vec<_>>();
291 for (camel_name, value) in &ssl_params {
292 if let Some((_, mapping)) = SSL_MAPPINGS.iter().find(|(name, _)| *name == *camel_name) {
293 let driver_key = if is_mysql {
294 mapping.mysql_key
295 } else {
296 mapping.pg_key
297 };
298
299 if let Some(pos) = query_pairs.iter().position(|(k, _)| k == driver_key) {
300 query_pairs[pos].1 = (*value).into();
301 } else {
302 query_pairs.push((driver_key.into(), (*value).into()));
303 }
304 }
305 }
306
307 {
308 let mut serializer = url::form_urlencoded::Serializer::new(String::new());
309 for (k, v) in query_pairs {
310 serializer.append_pair(&k, &v);
311 }
312 parsed.set_query(Some(&serializer.finish()));
313 }
314
315 Ok(parsed.to_string())
316}
317
318impl UriConfig for SqlEndpointConfig {
319 fn scheme() -> &'static str {
320 "sql"
321 }
322
323 fn from_uri(uri: &str) -> Result<Self, CamelError> {
324 let parts = parse_uri(uri)?;
325 Self::from_components(parts)
326 }
327
328 fn from_components(parts: UriComponents) -> Result<Self, CamelError> {
329 if parts.scheme != Self::scheme() {
331 return Err(CamelError::InvalidUri(format!(
332 "expected scheme '{}' but got '{}'",
333 Self::scheme(),
334 parts.scheme
335 )));
336 }
337
338 let params = &parts.params;
339
340 let (query, source_path) = if parts.path.starts_with("file:") {
342 let file_path = parts.path.trim_start_matches("file:").to_string();
343 let contents = std::fs::read_to_string(&file_path).map_err(|e| {
344 CamelError::Config(format!("Failed to read SQL file '{}': {}", file_path, e))
345 })?;
346 (contents.trim().to_string(), Some(file_path))
347 } else {
348 (parts.path.clone(), None)
349 };
350
351 let db_url = params
353 .get("db_url")
354 .ok_or_else(|| CamelError::Config("db_url parameter is required".to_string()))?
355 .clone();
356
357 let max_connections = params.get("maxConnections").and_then(|v| v.parse().ok());
359 let min_connections = params.get("minConnections").and_then(|v| v.parse().ok());
360 let idle_timeout_secs = params.get("idleTimeoutSecs").and_then(|v| v.parse().ok());
361 let max_lifetime_secs = params.get("maxLifetimeSecs").and_then(|v| v.parse().ok());
362
363 let output_type = params
365 .get("outputType")
366 .map(|s| s.parse())
367 .transpose()?
368 .unwrap_or_default();
369 let placeholder = params
370 .get("placeholder")
371 .filter(|v| !v.is_empty())
372 .map(|v| v.chars().next().unwrap())
373 .unwrap_or('#');
374 let noop = params
375 .get("noop")
376 .map(|v| v.eq_ignore_ascii_case("true"))
377 .unwrap_or(false);
378 let in_separator = params
379 .get("inSeparator")
380 .map(|v| v.to_string())
381 .unwrap_or_else(|| ", ".to_string());
382 if in_separator.is_empty() {
383 return Err(CamelError::InvalidUri(
384 "inSeparator must not be empty".to_string(),
385 ));
386 }
387
388 let delay_ms = params
390 .get("delay")
391 .and_then(|v| v.parse().ok())
392 .unwrap_or(500);
393 let initial_delay_ms = params
394 .get("initialDelay")
395 .and_then(|v| v.parse().ok())
396 .unwrap_or(1000);
397 let max_messages_per_poll = params
398 .get("maxMessagesPerPoll")
399 .and_then(|v| v.parse().ok());
400 let on_consume = params.get("onConsume").cloned();
401 let on_consume_failed = params.get("onConsumeFailed").cloned();
402 let on_consume_batch_complete = params.get("onConsumeBatchComplete").cloned();
403 let route_empty_result_set = params
404 .get("routeEmptyResultSet")
405 .map(|v| v.eq_ignore_ascii_case("true"))
406 .unwrap_or(false);
407 let use_iterator = params
408 .get("useIterator")
409 .map(|v| v.eq_ignore_ascii_case("true"))
410 .unwrap_or(true);
411 let expected_update_count = params
412 .get("expectedUpdateCount")
413 .and_then(|v| v.parse().ok());
414 let break_batch_on_consume_fail = params
415 .get("breakBatchOnConsumeFail")
416 .map(|v| v.eq_ignore_ascii_case("true"))
417 .unwrap_or(false);
418
419 let batch = params
421 .get("batch")
422 .map(|v| v.eq_ignore_ascii_case("true"))
423 .unwrap_or(false);
424 let use_message_body_for_sql = params
425 .get("useMessageBodyForSql")
426 .map(|v| v.eq_ignore_ascii_case("true"))
427 .unwrap_or(false);
428 let ssl_mode = params.get("sslMode").cloned();
429 let ssl_root_cert = params.get("sslRootCert").cloned();
430 let ssl_cert = params.get("sslCert").cloned();
431 let ssl_key = params.get("sslKey").cloned();
432
433 Ok(Self {
434 db_url,
435 max_connections,
436 min_connections,
437 idle_timeout_secs,
438 max_lifetime_secs,
439 query,
440 source_path,
441 output_type,
442 placeholder,
443 noop,
444 in_separator,
445 delay_ms,
446 initial_delay_ms,
447 max_messages_per_poll,
448 on_consume,
449 on_consume_failed,
450 on_consume_batch_complete,
451 route_empty_result_set,
452 use_iterator,
453 expected_update_count,
454 break_batch_on_consume_fail,
455 batch,
456 use_message_body_for_sql,
457 ssl_mode,
458 ssl_root_cert,
459 ssl_cert,
460 ssl_key,
461 })
462 }
463}
464
#[cfg(test)]
mod tests {
    // Unit tests for URI parsing, default resolution and SSL URL enrichment.
    use super::*;

    #[test]
    fn config_defaults() {
        let mut c =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        c.resolve_defaults();
        assert_eq!(c.query, "select 1");
        assert_eq!(c.db_url, "postgres://localhost/test");
        assert_eq!(c.max_connections, Some(5));
        assert_eq!(c.min_connections, Some(1));
        assert_eq!(c.idle_timeout_secs, Some(300));
        assert_eq!(c.max_lifetime_secs, Some(1800));
        assert_eq!(c.output_type, SqlOutputType::SelectList);
        assert_eq!(c.placeholder, '#');
        assert!(!c.noop);
        assert_eq!(c.in_separator, ", ");
        assert_eq!(c.delay_ms, 500);
        assert_eq!(c.initial_delay_ms, 1000);
        assert!(c.max_messages_per_poll.is_none());
        assert!(c.on_consume.is_none());
        assert!(c.on_consume_failed.is_none());
        assert!(c.on_consume_batch_complete.is_none());
        assert!(!c.route_empty_result_set);
        assert!(c.use_iterator);
        assert!(c.expected_update_count.is_none());
        assert!(!c.break_batch_on_consume_fail);
        assert!(!c.batch);
        assert!(!c.use_message_body_for_sql);
        assert!(c.ssl_mode.is_none());
        assert!(c.ssl_root_cert.is_none());
        assert!(c.ssl_cert.is_none());
        assert!(c.ssl_key.is_none());
    }

    #[test]
    fn ssl_none_by_default() {
        let c =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        assert!(c.ssl_mode.is_none());
        assert!(c.ssl_root_cert.is_none());
        assert!(c.ssl_cert.is_none());
        assert!(c.ssl_key.is_none());
    }

    #[test]
    fn ssl_mode_from_uri() {
        let c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&sslMode=require",
        )
        .unwrap();
        assert_eq!(c.ssl_mode, Some("require".to_string()));
        assert!(c.ssl_root_cert.is_none());
    }

    #[test]
    fn ssl_all_params_from_uri() {
        let c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&sslMode=require&sslRootCert=/ca.pem&sslCert=/cert.pem&sslKey=/key.pem",
        )
        .unwrap();
        assert_eq!(c.ssl_mode, Some("require".to_string()));
        assert_eq!(c.ssl_root_cert, Some("/ca.pem".to_string()));
        assert_eq!(c.ssl_cert, Some("/cert.pem".to_string()));
        assert_eq!(c.ssl_key, Some("/key.pem".to_string()));
    }

    #[test]
    fn ssl_global_applied_to_endpoint() {
        let mut c =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        let global = SqlGlobalConfig::default()
            .with_ssl_mode("require")
            .with_ssl_root_cert("/etc/ssl/ca.pem");
        c.apply_defaults(&global);
        assert_eq!(c.ssl_mode, Some("require".to_string()));
        assert_eq!(c.ssl_root_cert, Some("/etc/ssl/ca.pem".to_string()));
        assert!(c.ssl_cert.is_none());
        assert!(c.ssl_key.is_none());
    }

    #[test]
    fn ssl_uri_overrides_global() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&sslMode=verify-full",
        )
        .unwrap();
        let global = SqlGlobalConfig::default().with_ssl_mode("require");
        c.apply_defaults(&global);
        assert_eq!(c.ssl_mode, Some("verify-full".to_string()));
    }

    #[test]
    fn config_wrong_scheme() {
        assert!(SqlEndpointConfig::from_uri("redis://localhost:6379").is_err());
    }

    #[test]
    fn config_missing_db_url() {
        assert!(SqlEndpointConfig::from_uri("sql:select 1").is_err());
    }

    #[test]
    fn config_output_type_select_one() {
        let c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&outputType=SelectOne",
        )
        .unwrap();
        assert_eq!(c.output_type, SqlOutputType::SelectOne);
    }

    #[test]
    fn config_output_type_stream_list() {
        let c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&outputType=StreamList",
        )
        .unwrap();
        assert_eq!(c.output_type, SqlOutputType::StreamList);
    }

    #[test]
    fn in_separator_default() {
        let c =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        assert_eq!(c.in_separator, ", ");
    }

    #[test]
    fn in_separator_from_uri() {
        let c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&inSeparator=;",
        )
        .unwrap();
        assert_eq!(c.in_separator, ";");
    }

    #[test]
    fn in_separator_empty_rejected() {
        let result = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&inSeparator=",
        );
        assert!(result.is_err());
        let msg = format!("{:?}", result.unwrap_err());
        assert!(msg.contains("inSeparator") || msg.contains("empty"));
    }

    #[test]
    fn placeholder_from_uri() {
        let c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&placeholder=$",
        )
        .unwrap();
        assert_eq!(c.placeholder, '$');
    }

    #[test]
    fn boolean_params_case_insensitive() {
        let c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&noop=TRUE&batch=True&useIterator=FALSE",
        )
        .unwrap();
        assert!(c.noop);
        assert!(c.batch);
        assert!(!c.use_iterator);
    }

    #[test]
    fn config_consumer_options() {
        let c = SqlEndpointConfig::from_uri(
            "sql:select * from t?db_url=postgres://localhost/test&delay=2000&initialDelay=500&maxMessagesPerPoll=10&onConsume=update t set done=true where id=:#id&onConsumeFailed=update t set failed=true where id=:#id&onConsumeBatchComplete=delete from t where done=true&routeEmptyResultSet=true&useIterator=false&expectedUpdateCount=1&breakBatchOnConsumeFail=true"
        ).unwrap();
        assert_eq!(c.delay_ms, 2000);
        assert_eq!(c.initial_delay_ms, 500);
        assert_eq!(c.max_messages_per_poll, Some(10));
        assert_eq!(
            c.on_consume,
            Some("update t set done=true where id=:#id".to_string())
        );
        assert_eq!(
            c.on_consume_failed,
            Some("update t set failed=true where id=:#id".to_string())
        );
        assert_eq!(
            c.on_consume_batch_complete,
            Some("delete from t where done=true".to_string())
        );
        assert!(c.route_empty_result_set);
        assert!(!c.use_iterator);
        assert_eq!(c.expected_update_count, Some(1));
        assert!(c.break_batch_on_consume_fail);
    }

    #[test]
    fn config_producer_options() {
        let c = SqlEndpointConfig::from_uri(
            "sql:insert into t values (#)?db_url=postgres://localhost/test&batch=true&useMessageBodyForSql=true&noop=true"
        ).unwrap();
        assert!(c.batch);
        assert!(c.use_message_body_for_sql);
        assert!(c.noop);
    }

    #[test]
    fn config_pool_options() {
        let c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&maxConnections=20&minConnections=3&idleTimeoutSecs=600&maxLifetimeSecs=3600"
        ).unwrap();
        assert_eq!(c.max_connections, Some(20));
        assert_eq!(c.min_connections, Some(3));
        assert_eq!(c.idle_timeout_secs, Some(600));
        assert_eq!(c.max_lifetime_secs, Some(3600));
    }

    #[test]
    fn config_query_with_special_chars() {
        let c = SqlEndpointConfig::from_uri(
            "sql:select * from users where name = :#name and age > #?db_url=postgres://localhost/test",
        )
        .unwrap();
        assert_eq!(
            c.query,
            "select * from users where name = :#name and age > #"
        );
    }

    #[test]
    fn output_type_from_str() {
        assert_eq!(
            "SelectList".parse::<SqlOutputType>().unwrap(),
            SqlOutputType::SelectList
        );
        assert_eq!(
            "SelectOne".parse::<SqlOutputType>().unwrap(),
            SqlOutputType::SelectOne
        );
        assert_eq!(
            "StreamList".parse::<SqlOutputType>().unwrap(),
            SqlOutputType::StreamList
        );
        assert!("Invalid".parse::<SqlOutputType>().is_err());
    }

    #[test]
    fn config_file_not_found() {
        let result = SqlEndpointConfig::from_uri(
            "sql:file:/nonexistent/path/query.sql?db_url=postgres://localhost/test",
        );
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = format!("{:?}", err);
        assert!(msg.contains("Failed to read SQL file") || msg.contains("nonexistent"));
    }

    #[test]
    fn config_file_query() {
        use std::io::Write;
        let unique_name = format!(
            "test_sql_query_{}.sql",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos()
        );
        let mut tmp = std::env::temp_dir();
        tmp.push(unique_name);
        {
            let mut f = std::fs::File::create(&tmp).unwrap();
            writeln!(f, "SELECT * FROM users").unwrap();
        }
        let uri = format!(
            "sql:file:{}?db_url=postgres://localhost/test",
            tmp.display()
        );
        let c = SqlEndpointConfig::from_uri(&uri).unwrap();
        assert_eq!(c.query, "SELECT * FROM users");
        assert_eq!(c.source_path, Some(tmp.to_string_lossy().into_owned()));
        std::fs::remove_file(&tmp).ok();
    }

    #[test]
    fn pool_fields_none_when_not_set() {
        let c =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        assert_eq!(c.max_connections, None);
        assert_eq!(c.min_connections, None);
        assert_eq!(c.idle_timeout_secs, None);
        assert_eq!(c.max_lifetime_secs, None);
    }

    #[test]
    fn apply_defaults_fills_none() {
        let mut c =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        let global = SqlGlobalConfig {
            max_connections: 10,
            min_connections: 2,
            idle_timeout_secs: 600,
            max_lifetime_secs: 3600,
            ssl_mode: None,
            ssl_root_cert: None,
            ssl_cert: None,
            ssl_key: None,
        };
        c.apply_defaults(&global);
        assert_eq!(c.max_connections, Some(10));
        assert_eq!(c.min_connections, Some(2));
        assert_eq!(c.idle_timeout_secs, Some(600));
        assert_eq!(c.max_lifetime_secs, Some(3600));
        assert!(c.ssl_mode.is_none());
        assert!(c.ssl_root_cert.is_none());
        assert!(c.ssl_cert.is_none());
        assert!(c.ssl_key.is_none());
    }

    #[test]
    fn apply_defaults_does_not_override() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&maxConnections=99&minConnections=5",
        )
        .unwrap();
        let global = SqlGlobalConfig {
            max_connections: 10,
            min_connections: 2,
            idle_timeout_secs: 600,
            max_lifetime_secs: 3600,
            ssl_mode: None,
            ssl_root_cert: None,
            ssl_cert: None,
            ssl_key: None,
        };
        c.apply_defaults(&global);
        assert_eq!(c.max_connections, Some(99));
        assert_eq!(c.min_connections, Some(5));
        assert_eq!(c.idle_timeout_secs, Some(600));
        assert_eq!(c.max_lifetime_secs, Some(3600));
    }

    #[test]
    fn resolve_defaults_fills_remaining() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&maxConnections=7",
        )
        .unwrap();
        c.resolve_defaults();
        assert_eq!(c.max_connections, Some(7));
        assert_eq!(c.min_connections, Some(1));
        assert_eq!(c.idle_timeout_secs, Some(300));
        assert_eq!(c.max_lifetime_secs, Some(1800));
    }

    #[test]
    fn global_config_builder() {
        let c = SqlGlobalConfig::default()
            .with_max_connections(20)
            .with_min_connections(3)
            .with_idle_timeout_secs(600)
            .with_max_lifetime_secs(3600)
            .with_ssl_mode("require")
            .with_ssl_root_cert("/ca.pem")
            .with_ssl_cert("/cert.pem")
            .with_ssl_key("/key.pem");
        assert_eq!(c.max_connections, 20);
        assert_eq!(c.min_connections, 3);
        assert_eq!(c.idle_timeout_secs, 600);
        assert_eq!(c.max_lifetime_secs, 3600);
        assert_eq!(c.ssl_mode, Some("require".to_string()));
        assert_eq!(c.ssl_root_cert, Some("/ca.pem".to_string()));
        assert_eq!(c.ssl_cert, Some("/cert.pem".to_string()));
        assert_eq!(c.ssl_key, Some("/key.pem".to_string()));
    }

    #[test]
    fn enrich_postgres_ssl_mode() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&sslMode=require",
        )
        .unwrap();
        c.resolve_defaults();
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert!(url.contains("sslmode=require"), "got: {}", url);
    }

    #[test]
    fn enrich_postgresql_scheme() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgresql://localhost/test&sslMode=require",
        )
        .unwrap();
        c.resolve_defaults();
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert!(url.contains("sslmode=require"), "got: {}", url);
    }

    #[test]
    fn enrich_postgres_all_ssl() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&sslMode=require&sslRootCert=/ca.pem&sslCert=/cert.pem&sslKey=/key.pem",
        )
        .unwrap();
        c.resolve_defaults();
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert!(url.contains("sslmode=require"), "got: {}", url);
        assert!(url.contains("sslrootcert="), "got: {}", url);
        assert!(url.contains("sslcert="), "got: {}", url);
        assert!(url.contains("sslkey="), "got: {}", url);
    }

    #[test]
    fn enrich_global_ssl_applies_to_url() {
        let mut c =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        c.apply_defaults(&SqlGlobalConfig::default().with_ssl_mode("verify-full"));
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert!(url.contains("sslmode=verify-full"), "got: {}", url);
    }

    #[test]
    fn enrich_mysql_ssl() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=mysql://localhost/test&sslMode=require",
        )
        .unwrap();
        c.resolve_defaults();
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert!(url.contains("ssl-mode=require"), "got: {}", url);
    }

    #[test]
    fn enrich_mysql_all_ssl() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=mysql://localhost/test&sslMode=require&sslRootCert=/ca.pem&sslCert=/cert.pem&sslKey=/key.pem",
        )
        .unwrap();
        c.resolve_defaults();
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert!(url.contains("ssl-mode=require"), "got: {}", url);
        assert!(url.contains("ssl-ca="), "got: {}", url);
        assert!(url.contains("ssl-cert="), "got: {}", url);
        assert!(url.contains("ssl-key="), "got: {}", url);
        assert!(!url.contains("sslrootcert="), "got: {}", url);
    }

    #[test]
    fn enrich_existing_query_params() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test?existing=1&sslMode=require",
        )
        .unwrap();
        c.resolve_defaults();
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert!(url.contains("existing=1"), "got: {}", url);
        assert!(url.contains("sslmode=require"), "got: {}", url);
    }

    #[test]
    fn enrich_override_existing() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test?sslmode=allow&sslMode=require",
        )
        .unwrap();
        c.resolve_defaults();
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert!(url.contains("sslmode=require"), "got: {}", url);
        assert!(!url.contains("sslmode=allow"), "got: {}", url);
    }

    #[test]
    fn enrich_no_params() {
        let mut c =
            SqlEndpointConfig::from_uri("sql:select 1?db_url=postgres://localhost/test").unwrap();
        c.resolve_defaults();
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert_eq!(url, "postgres://localhost/test");
    }

    #[test]
    fn enrich_url_encodes_paths() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&sslRootCert=/path/to/my%20cert.pem",
        )
        .unwrap();
        c.resolve_defaults();
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert!(url.contains("sslrootcert="), "got: {}", url);
    }

    #[test]
    fn enrich_unsupported_scheme_returns_unchanged() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=sqlite://localhost/test.db&sslMode=require",
        )
        .unwrap();
        c.resolve_defaults();
        let url = enrich_db_url_with_ssl(&c.db_url, &c).unwrap();
        assert_eq!(url, "sqlite://localhost/test.db");
    }

    #[test]
    fn enrich_invalid_url_returns_error() {
        let mut c = SqlEndpointConfig::from_uri(
            "sql:select 1?db_url=postgres://localhost/test&sslMode=require",
        )
        .unwrap();
        c.resolve_defaults();
        let result = enrich_db_url_with_ssl("://not-a-valid-url", &c);
        assert!(result.is_err());
    }
}