aurora_dsql_sqlx_connector/
config.rs1use crate::{util::ClusterId, DsqlError, Result};
5use aws_config::{Region, SdkConfig};
6use aws_credential_types::provider::SharedCredentialsProvider;
7use derive_builder::Builder;
8use sqlx::postgres::{PgConnectOptions, PgSslMode};
9#[cfg(feature = "pool")]
10use std::time::Duration;
11use url::Url;
12
/// Database user assumed when the connection string has no username.
const DEFAULT_USER: &str = "admin";
/// Database name assumed when the connection string has no (or an empty) path.
const DEFAULT_DATABASE: &str = "postgres";
/// Port assumed when the connection string has no explicit port.
const DEFAULT_PORT: u16 = 5432;
/// Default lifetime, in seconds, requested for generated auth tokens (15 minutes).
/// Used both by connection-string parsing and as the builder default.
const DEFAULT_TOKEN_DURATION_SECS: u64 = 900;
17
/// Connection options for an Aurora DSQL cluster.
///
/// Wraps sqlx [`PgConnectOptions`] together with the settings needed to
/// generate an authentication token. Instances come either from
/// [`DsqlConnectOptions::from_connection_string`] or from the generated
/// [`DsqlConnectOptionsBuilder`], whose `build()` runs the builder's
/// `validate` hook on the supplied host.
#[derive(Debug, Clone, Builder)]
#[builder(setter(into), build_fn(validate = "Self::validate"))]
pub struct DsqlConnectOptions {
    /// Base Postgres options (host, port, user, database, ...). The host may be a
    /// full endpoint or a bare cluster ID, which `resolve_host` expands.
    pg_connect_options: PgConnectOptions,
    /// Region used when none can be parsed from the hostname; checked before the
    /// region from the AWS SDK configuration.
    #[builder(default)]
    region: Option<Region>,
    /// AWS config profile name applied when loading the SDK configuration.
    #[builder(default)]
    profile: Option<String>,
    /// Token lifetime in seconds passed to the token signer (default 900).
    #[builder(default = "DEFAULT_TOKEN_DURATION_SECS")]
    token_duration_secs: u64,
    /// Optional prefix prepended to the `application_name` sent to the server.
    #[builder(default)]
    orm_prefix: Option<String>,
    /// Explicit credentials provider; when `None`, the `aws_config` default
    /// credentials chain is used.
    #[builder(default)]
    credentials_provider: Option<SharedCredentialsProvider>,
}
33
34impl DsqlConnectOptionsBuilder {
35 fn validate(&self) -> std::result::Result<(), String> {
36 if let Some(ref pg) = self.pg_connect_options {
37 crate::util::validate_host(pg.get_host())?;
38 }
39 Ok(())
40 }
41}
42
43impl DsqlConnectOptions {
44 pub fn from_connection_string(conn_str: &str) -> Result<Self> {
45 let url = Self::parse_url(conn_str)?;
46 Self::from_url(&url)
47 }
48
49 fn parse_url(conn_str: &str) -> Result<Url> {
50 let url = Url::parse(conn_str).map_err(|e| DsqlError::ConfigError(e.into()))?;
51
52 match url.scheme() {
53 "postgres" | "postgresql" => {}
54 _ => {
55 return Err(DsqlError::ConfigError(
56 "Unsupported URL scheme. Use 'postgres://' or 'postgresql://'".into(),
57 ));
58 }
59 }
60
61 Ok(url)
62 }
63
64 fn from_url(url: &Url) -> Result<Self> {
65 let host = url
66 .host_str()
67 .ok_or_else(|| DsqlError::ConfigError("Host is required".into()))?;
68
69 crate::util::validate_host(host).map_err(|e| DsqlError::ConfigError(e.into()))?;
70
71 let port = url.port().unwrap_or(DEFAULT_PORT);
72
73 let user = if !url.username().is_empty() {
74 url.username()
75 } else {
76 DEFAULT_USER
77 };
78
79 let database = {
80 let db = url.path().trim_start_matches('/');
81 if db.is_empty() {
82 DEFAULT_DATABASE
83 } else {
84 db
85 }
86 };
87
88 let mut region = None;
89 let mut profile = None;
90 let mut token_duration_secs = DEFAULT_TOKEN_DURATION_SECS;
91 let mut orm_prefix = None;
92
93 for (key, value) in url.query_pairs() {
94 match key.as_ref() {
95 "region" => {
96 region = Some(Region::new(value.to_string()));
97 }
98 "profile" => profile = Some(value.to_string()),
99 "tokenDurationSecs" => {
100 let secs: u64 = value
101 .parse()
102 .map_err(|e: std::num::ParseIntError| DsqlError::ConfigError(e.into()))?;
103 token_duration_secs = secs;
104 }
105 "ormPrefix" => orm_prefix = Some(value.to_string()),
106 other => {
107 log::debug!(
108 "aurora-dsql: ignoring unrecognized connection parameter: {}",
109 other
110 );
111 }
112 }
113 }
114
115 let app_name = crate::util::build_application_name(orm_prefix.as_deref());
116
117 let pg = PgConnectOptions::new()
118 .host(host)
119 .port(port)
120 .username(user)
121 .database(database)
122 .ssl_mode(PgSslMode::VerifyFull)
123 .application_name(&app_name);
124
125 Ok(DsqlConnectOptions {
126 pg_connect_options: pg,
127 region,
128 profile,
129 token_duration_secs,
130 orm_prefix,
131 credentials_provider: None,
132 })
133 }
134
135 pub async fn authenticated_pg_options(&self) -> Result<PgConnectOptions> {
140 let sdk_config = load_aws_config(self.profile(), self.credentials_provider()).await;
141 let host = self.resolve_host(&sdk_config)?;
142 let region = self.resolve_region(&sdk_config)?;
143 let signer =
144 crate::token::build_signer(&host, ®ion, &sdk_config, Some(self.token_duration()))?;
145 let user = self.pg_connect_options.get_username();
146 let token = crate::token::generate_token(&signer, user, &sdk_config).await?;
147 self.build_connect_options(&sdk_config, &token)
148 }
149
150 pub(crate) fn build_connect_options(
154 &self,
155 sdk_config: &SdkConfig,
156 token: &str,
157 ) -> Result<PgConnectOptions> {
158 let host = self.resolve_host(sdk_config)?;
159 let app_name = crate::util::build_application_name(self.orm_prefix.as_deref());
160 Ok(self
161 .pg_connect_options
162 .clone()
163 .host(&host)
164 .password(token)
165 .ssl_mode(PgSslMode::VerifyFull)
166 .application_name(&app_name))
167 }
168
169 #[cfg(feature = "pool")]
171 pub(crate) fn pg_connect_options(&self) -> &PgConnectOptions {
172 &self.pg_connect_options
173 }
174
175 pub(crate) fn profile(&self) -> Option<&str> {
177 self.profile.as_deref()
178 }
179
180 pub(crate) fn credentials_provider(&self) -> Option<&SharedCredentialsProvider> {
182 self.credentials_provider.as_ref()
183 }
184
185 pub(crate) fn token_duration(&self) -> u64 {
187 self.token_duration_secs
188 }
189
190 #[cfg(feature = "pool")]
193 pub(crate) fn refresh_interval(&self) -> Duration {
194 Duration::from_secs((self.token_duration() * 4 / 5).max(1))
195 }
196
197 pub(crate) fn resolve_host(&self, sdk_config: &SdkConfig) -> Result<String> {
199 let host = self.pg_connect_options.get_host();
200 if let Some(cluster_id) = ClusterId::new(host) {
201 let region = self.resolve_region(sdk_config)?;
202 Ok(crate::util::build_hostname(&cluster_id, ®ion))
203 } else {
204 Ok(host.to_string())
205 }
206 }
207
208 pub(crate) fn resolve_region(&self, sdk_config: &SdkConfig) -> Result<Region> {
209 let host = self.pg_connect_options.get_host();
211 if let Some(region) = crate::util::parse_region(host) {
212 return Ok(region);
213 }
214
215 if let Some(ref region) = self.region {
217 return Ok(region.clone());
218 }
219
220 if let Some(region) = sdk_config.region() {
222 return Ok(region.clone());
223 }
224
225 Err(DsqlError::ConfigError(
226 "Could not determine region from connection string, hostname, or AWS configuration"
227 .into(),
228 ))
229 }
230}
231
232pub(crate) async fn load_aws_config(
234 profile: Option<&str>,
235 credentials_provider: Option<&SharedCredentialsProvider>,
236) -> SdkConfig {
237 let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
238 if let Some(profile) = profile {
239 loader = loader.profile_name(profile);
240 }
241 if let Some(provider) = credentials_provider {
242 loader = loader.credentials_provider(provider.clone());
243 }
244 loader.load().await
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns an `SdkConfig` with no region and no credentials.
    ///
    /// It is built directly instead of via `load_aws_config`. That function runs
    /// the default AWS provider chains (environment, profile files, IMDS), which
    /// would make these tests' results and runtime depend on the machine they
    /// run on.
    fn empty_sdk_config() -> SdkConfig {
        SdkConfig::builder().build()
    }

    #[test]
    fn test_parse_basic_connection_string() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws:5432/postgres",
        )?;

        assert_eq!(config.pg_connect_options.get_username(), "admin");
        assert_eq!(
            config.pg_connect_options.get_host(),
            "example.dsql.us-east-1.on.aws"
        );
        assert_eq!(config.pg_connect_options.get_port(), 5432);
        assert_eq!(
            config.pg_connect_options.get_database().unwrap(),
            "postgres"
        );
        assert!(config.region.is_none());
        Ok(())
    }

    #[test]
    fn test_parse_applies_defaults() -> Result<()> {
        let config =
            DsqlConnectOptions::from_connection_string("postgres://example.dsql.us-east-1.on.aws")?;

        assert_eq!(config.pg_connect_options.get_username(), "admin");
        assert_eq!(config.pg_connect_options.get_port(), 5432);
        assert_eq!(config.pg_connect_options.get_database(), Some("postgres"));
        assert_eq!(config.token_duration_secs, 900);
        assert!(config.profile.is_none());
        assert!(config.orm_prefix.is_none());
        Ok(())
    }

    #[test]
    fn test_parse_with_region_param() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-west-2.on.aws/postgres?region=us-west-2",
        )?;

        assert_eq!(
            config.region.as_ref().map(|r| r.as_ref()),
            Some("us-west-2")
        );
        Ok(())
    }

    #[test]
    fn test_parse_with_profile_param() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres?profile=dev",
        )?;

        assert_eq!(config.profile, Some("dev".to_string()));
        Ok(())
    }

    #[test]
    fn test_parse_with_region_and_profile() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres?region=us-east-1&profile=prod",
        )?;

        assert_eq!(
            config.region.as_ref().map(|r| r.as_ref()),
            Some("us-east-1")
        );
        assert_eq!(config.profile, Some("prod".to_string()));
        Ok(())
    }

    #[test]
    fn test_invalid_connection_string() {
        let result = DsqlConnectOptions::from_connection_string("invalid://connection");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_rejects_missing_host() {
        let result = DsqlConnectOptions::from_connection_string("postgres:///postgres");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_rejects_invalid_token_duration() {
        let result = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres?tokenDurationSecs=abc",
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_postgresql_scheme_alias() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgresql://admin@example.dsql.us-east-1.on.aws/postgres",
        )?;

        assert_eq!(
            config.pg_connect_options.get_host(),
            "example.dsql.us-east-1.on.aws"
        );
        assert_eq!(config.pg_connect_options.get_username(), "admin");
        Ok(())
    }

    #[test]
    fn test_parse_query_params() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres?\
             tokenDurationSecs=900&ormPrefix=myapp",
        )?;

        assert_eq!(config.token_duration_secs, 900);
        assert!(
            config
                .pg_connect_options
                .get_application_name()
                .unwrap()
                .starts_with("myapp:aurora-dsql-rust-sqlx/"),
            "ormPrefix should be prepended to application_name"
        );
        Ok(())
    }

    #[test]
    fn test_parse_cluster_id_stores_raw_host() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@abcdefghijklmnopqrstuvwxyz/postgres?region=us-east-1",
        )?;

        assert_eq!(
            config.pg_connect_options.get_host(),
            "abcdefghijklmnopqrstuvwxyz"
        );
        assert_eq!(
            config.region.as_ref().map(|r| r.as_ref()),
            Some("us-east-1")
        );
        Ok(())
    }

    #[test]
    fn test_resolve_region_from_param() -> Result<()> {
        // A bare cluster ID carries no region, so with an empty SDK config the
        // `region` parameter is the only available source.
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@abcdefghijklmnopqrstuvwxyz/postgres?region=eu-west-1",
        )?;

        let region = config.resolve_region(&empty_sdk_config())?;
        assert_eq!(region.as_ref(), "eu-west-1");
        Ok(())
    }

    #[test]
    fn test_resolve_region_from_hostname() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-west-2.on.aws/postgres",
        )?;

        let region = config.resolve_region(&empty_sdk_config())?;
        assert_eq!(region.as_ref(), "us-west-2");
        Ok(())
    }

    #[test]
    fn test_resolve_region_prefers_hostname_over_param() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres?region=us-west-2",
        )?;

        let region = config.resolve_region(&empty_sdk_config())?;
        assert_eq!(region.as_ref(), "us-east-1");
        Ok(())
    }

    #[test]
    fn test_resolve_region_errors_without_any_source() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@abcdefghijklmnopqrstuvwxyz/postgres",
        )?;

        assert!(config.resolve_region(&empty_sdk_config()).is_err());
        Ok(())
    }

    #[test]
    fn test_resolve_host_expands_cluster_id() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@abcdefghijklmnopqrstuvwxyz/postgres?region=us-east-1",
        )?;

        let host = config.resolve_host(&empty_sdk_config())?;
        assert_eq!(host, "abcdefghijklmnopqrstuvwxyz.dsql.us-east-1.on.aws");
        Ok(())
    }

    #[test]
    fn test_resolve_host_noop_for_full_hostname() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres",
        )?;

        let host = config.resolve_host(&empty_sdk_config())?;
        assert_eq!(host, "example.dsql.us-east-1.on.aws");
        Ok(())
    }

    #[test]
    fn test_builder_rejects_empty_host() {
        let pg = PgConnectOptions::new()
            .host("")
            .username("admin")
            .database("postgres");

        let result = DsqlConnectOptionsBuilder::default()
            .pg_connect_options(pg)
            .build();

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Host is required"),
            "Expected host error, got: {}",
            err
        );
    }

    #[test]
    fn test_build_connect_options() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres",
        )?;

        let opts = config.build_connect_options(&empty_sdk_config(), "test-token")?;
        assert_eq!(opts.get_host(), "example.dsql.us-east-1.on.aws");
        assert_eq!(opts.get_port(), 5432);
        assert_eq!(opts.get_username(), "admin");
        assert_eq!(opts.get_database().unwrap(), "postgres");
        assert!(matches!(opts.get_ssl_mode(), PgSslMode::VerifyFull));
        Ok(())
    }

    #[test]
    fn test_build_connect_options_with_cluster_id() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@abcdefghijklmnopqrstuvwxyz/postgres?region=us-east-1",
        )?;

        let opts = config.build_connect_options(&empty_sdk_config(), "test-token")?;
        assert_eq!(
            opts.get_host(),
            "abcdefghijklmnopqrstuvwxyz.dsql.us-east-1.on.aws",
        );
        Ok(())
    }

    #[test]
    fn test_connect_options_default_application_name() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres",
        )?;

        let opts = config.build_connect_options(&empty_sdk_config(), "test-token")?;
        let app_name = opts
            .get_application_name()
            .expect("application_name should be set");
        assert!(app_name.starts_with("aurora-dsql-rust-sqlx/"));
        Ok(())
    }

    #[test]
    fn test_connect_options_with_orm_prefix() -> Result<()> {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres?ormPrefix=my-service",
        )?;

        let opts = config.build_connect_options(&empty_sdk_config(), "test-token")?;
        assert!(
            opts.get_application_name()
                .unwrap()
                .starts_with("my-service:aurora-dsql-rust-sqlx/"),
            "ormPrefix should be prepended to application_name"
        );
        Ok(())
    }

    #[test]
    fn test_ssl_mode_always_verify_full() {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres",
        )
        .unwrap();

        assert!(matches!(
            config.pg_connect_options.get_ssl_mode(),
            PgSslMode::VerifyFull
        ));
    }

    #[test]
    fn test_ssl_mode_enforced_via_builder() -> Result<()> {
        let pg = PgConnectOptions::new()
            .host("example.dsql.us-east-1.on.aws")
            .username("admin")
            .database("postgres")
            .ssl_mode(PgSslMode::Prefer);

        let config = DsqlConnectOptionsBuilder::default()
            .pg_connect_options(pg)
            .build()
            .unwrap();

        let opts = config.build_connect_options(&empty_sdk_config(), "test-token")?;
        assert!(
            matches!(opts.get_ssl_mode(), PgSslMode::VerifyFull),
            "SSL must be VerifyFull regardless of builder input"
        );
        Ok(())
    }

    #[test]
    #[cfg(feature = "pool")]
    fn test_refresh_interval_default() {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres",
        )
        .unwrap();

        assert_eq!(config.refresh_interval(), Duration::from_secs(720));
    }

    #[test]
    #[cfg(feature = "pool")]
    fn test_refresh_interval_floors_to_one_second() {
        let pg = PgConnectOptions::new()
            .host("example.dsql.us-east-1.on.aws")
            .username("admin")
            .database("postgres");

        let config = DsqlConnectOptionsBuilder::default()
            .pg_connect_options(pg)
            .token_duration_secs(1u64)
            .build()
            .unwrap();

        assert_eq!(config.refresh_interval(), Duration::from_secs(1));
    }

    #[test]
    fn test_from_connection_string_has_no_credentials_provider() {
        let config = DsqlConnectOptions::from_connection_string(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres",
        )
        .unwrap();

        assert!(config.credentials_provider.is_none());
    }

    #[test]
    fn test_builder_with_custom_credentials_provider() {
        use aws_credential_types::Credentials;

        let creds = Credentials::new("custom_key", "custom_secret", None, None, "test");
        let provider = SharedCredentialsProvider::new(creds);

        let pg = PgConnectOptions::new()
            .host("example.dsql.us-east-1.on.aws")
            .username("admin")
            .database("postgres");

        let config = DsqlConnectOptionsBuilder::default()
            .pg_connect_options(pg)
            .credentials_provider(provider)
            .build()
            .unwrap();

        assert!(config.credentials_provider.is_some());
    }

    #[test]
    fn test_builder_without_credentials_provider() {
        let pg = PgConnectOptions::new()
            .host("example.dsql.us-east-1.on.aws")
            .username("admin")
            .database("postgres");

        let config = DsqlConnectOptionsBuilder::default()
            .pg_connect_options(pg)
            .build()
            .unwrap();

        assert!(config.credentials_provider.is_none());
    }

    // This test exercises `load_aws_config` itself, so it intentionally goes
    // through the real loader rather than `empty_sdk_config`.
    #[tokio::test]
    async fn test_load_aws_config_with_custom_credentials() {
        use aws_credential_types::provider::ProvideCredentials;
        use aws_credential_types::Credentials;

        let creds = Credentials::new("custom_key", "custom_secret", None, None, "test");
        let provider = SharedCredentialsProvider::new(creds);

        let sdk_config = load_aws_config(None, Some(&provider)).await;
        let resolved = sdk_config
            .credentials_provider()
            .expect("SdkConfig should have a credentials provider")
            .provide_credentials()
            .await
            .expect("should resolve credentials");
        assert_eq!(resolved.access_key_id(), "custom_key");
        assert_eq!(resolved.secret_access_key(), "custom_secret");
    }
}