Skip to main content

aurora_dsql_sqlx_connector/
config.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{util::ClusterId, DsqlError, Result};
5use aws_config::{Region, SdkConfig};
6use aws_credential_types::provider::SharedCredentialsProvider;
7use derive_builder::Builder;
8use sqlx::postgres::{PgConnectOptions, PgSslMode};
9#[cfg(feature = "pool")]
10use std::time::Duration;
11use url::Url;
12
/// Database user applied when the connection URL does not name one.
const DEFAULT_USER: &str = "admin";
/// Database applied when the connection URL has an empty path.
const DEFAULT_DATABASE: &str = "postgres";
/// TCP port applied when the connection URL does not specify one.
const DEFAULT_PORT: u16 = 5432;
/// Default token validity in seconds (15 minutes).
const DEFAULT_TOKEN_DURATION_SECS: u64 = 15 * 60;
17
/// Connection settings for an Aurora DSQL cluster: the base PostgreSQL options
/// plus the AWS-side inputs this file uses to resolve the region and host and
/// to generate an authentication token.
///
/// Values are created either by `DsqlConnectOptions::from_connection_string`
/// or through the generated `DsqlConnectOptionsBuilder`, whose `build` step
/// runs `DsqlConnectOptionsBuilder::validate`.
#[derive(Debug, Clone, Builder)]
#[builder(setter(into), build_fn(validate = "Self::validate"))]
pub struct DsqlConnectOptions {
    /// Base PostgreSQL options. The host may be a full hostname or a bare
    /// cluster ID (expanded by `resolve_host`); host, password, SSL mode and
    /// application name are overridden in `build_connect_options`.
    pg_connect_options: PgConnectOptions,
    /// Explicit region. `resolve_region` consults it only when the host name
    /// does not yield a region.
    #[builder(default)]
    region: Option<Region>,
    /// AWS profile name handed to `load_aws_config`, if any.
    #[builder(default)]
    profile: Option<String>,
    /// Token validity in seconds; defaults to `DEFAULT_TOKEN_DURATION_SECS`.
    #[builder(default = "DEFAULT_TOKEN_DURATION_SECS")]
    token_duration_secs: u64,
    /// Optional prefix passed to `util::build_application_name` when the
    /// application name is computed.
    #[builder(default)]
    orm_prefix: Option<String>,
    /// Custom credentials provider handed to `load_aws_config`, if any.
    #[builder(default)]
    credentials_provider: Option<SharedCredentialsProvider>,
}
33
34impl DsqlConnectOptionsBuilder {
35    fn validate(&self) -> std::result::Result<(), String> {
36        if let Some(ref pg) = self.pg_connect_options {
37            crate::util::validate_host(pg.get_host())?;
38        }
39        Ok(())
40    }
41}
42
43impl DsqlConnectOptions {
44    pub fn from_connection_string(conn_str: &str) -> Result<Self> {
45        let url = Self::parse_url(conn_str)?;
46        Self::from_url(&url)
47    }
48
49    fn parse_url(conn_str: &str) -> Result<Url> {
50        let url = Url::parse(conn_str).map_err(|e| DsqlError::ConfigError(e.into()))?;
51
52        match url.scheme() {
53            "postgres" | "postgresql" => {}
54            _ => {
55                return Err(DsqlError::ConfigError(
56                    "Unsupported URL scheme. Use 'postgres://' or 'postgresql://'".into(),
57                ));
58            }
59        }
60
61        Ok(url)
62    }
63
64    fn from_url(url: &Url) -> Result<Self> {
65        let host = url
66            .host_str()
67            .ok_or_else(|| DsqlError::ConfigError("Host is required".into()))?;
68
69        crate::util::validate_host(host).map_err(|e| DsqlError::ConfigError(e.into()))?;
70
71        let port = url.port().unwrap_or(DEFAULT_PORT);
72
73        let user = if !url.username().is_empty() {
74            url.username()
75        } else {
76            DEFAULT_USER
77        };
78
79        let database = {
80            let db = url.path().trim_start_matches('/');
81            if db.is_empty() {
82                DEFAULT_DATABASE
83            } else {
84                db
85            }
86        };
87
88        let mut region = None;
89        let mut profile = None;
90        let mut token_duration_secs = DEFAULT_TOKEN_DURATION_SECS;
91        let mut orm_prefix = None;
92
93        for (key, value) in url.query_pairs() {
94            match key.as_ref() {
95                "region" => {
96                    region = Some(Region::new(value.to_string()));
97                }
98                "profile" => profile = Some(value.to_string()),
99                "tokenDurationSecs" => {
100                    let secs: u64 = value
101                        .parse()
102                        .map_err(|e: std::num::ParseIntError| DsqlError::ConfigError(e.into()))?;
103                    token_duration_secs = secs;
104                }
105                "ormPrefix" => orm_prefix = Some(value.to_string()),
106                other => {
107                    log::debug!(
108                        "aurora-dsql: ignoring unrecognized connection parameter: {}",
109                        other
110                    );
111                }
112            }
113        }
114
115        let app_name = crate::util::build_application_name(orm_prefix.as_deref());
116
117        let pg = PgConnectOptions::new()
118            .host(host)
119            .port(port)
120            .username(user)
121            .database(database)
122            .ssl_mode(PgSslMode::VerifyFull)
123            .application_name(&app_name);
124
125        Ok(DsqlConnectOptions {
126            pg_connect_options: pg,
127            region,
128            profile,
129            token_duration_secs,
130            orm_prefix,
131            credentials_provider: None,
132        })
133    }
134
135    /// Generate a fresh IAM token and return `PgConnectOptions` ready for use.
136    ///
137    /// This is the main entry point for advanced use cases where you need
138    /// to supply your own `PgPoolOptions` or manage connections directly.
139    pub async fn authenticated_pg_options(&self) -> Result<PgConnectOptions> {
140        let sdk_config = load_aws_config(self.profile(), self.credentials_provider()).await;
141        let host = self.resolve_host(&sdk_config)?;
142        let region = self.resolve_region(&sdk_config)?;
143        let signer =
144            crate::token::build_signer(&host, &region, &sdk_config, Some(self.token_duration()))?;
145        let user = self.pg_connect_options.get_username();
146        let token = crate::token::generate_token(&signer, user, &sdk_config).await?;
147        self.build_connect_options(&sdk_config, &token)
148    }
149
150    /// Clone the inner PgConnectOptions with the resolved host and token as password.
151    /// If the host is a bare cluster ID, it is expanded to a full DSQL hostname.
152    /// Always enforces `SslMode::VerifyFull` regardless of how the config was constructed.
153    pub(crate) fn build_connect_options(
154        &self,
155        sdk_config: &SdkConfig,
156        token: &str,
157    ) -> Result<PgConnectOptions> {
158        let host = self.resolve_host(sdk_config)?;
159        let app_name = crate::util::build_application_name(self.orm_prefix.as_deref());
160        Ok(self
161            .pg_connect_options
162            .clone()
163            .host(&host)
164            .password(token)
165            .ssl_mode(PgSslMode::VerifyFull)
166            .application_name(&app_name))
167    }
168
169    /// Read access to the inner PgConnectOptions.
170    #[cfg(feature = "pool")]
171    pub(crate) fn pg_connect_options(&self) -> &PgConnectOptions {
172        &self.pg_connect_options
173    }
174
175    /// AWS profile name, if configured.
176    pub(crate) fn profile(&self) -> Option<&str> {
177        self.profile.as_deref()
178    }
179
180    /// Custom credentials provider, if configured.
181    pub(crate) fn credentials_provider(&self) -> Option<&SharedCredentialsProvider> {
182        self.credentials_provider.as_ref()
183    }
184
185    /// Token validity duration in seconds. Defaults to 900s.
186    pub(crate) fn token_duration(&self) -> u64 {
187        self.token_duration_secs
188    }
189
190    /// How often the background refresh task should rotate tokens.
191    /// Returns `token_duration * 4/5` (80%).
192    #[cfg(feature = "pool")]
193    pub(crate) fn refresh_interval(&self) -> Duration {
194        Duration::from_secs((self.token_duration() * 4 / 5).max(1))
195    }
196
197    /// If host is a bare cluster ID, expand it to a full DSQL hostname.
198    pub(crate) fn resolve_host(&self, sdk_config: &SdkConfig) -> Result<String> {
199        let host = self.pg_connect_options.get_host();
200        if let Some(cluster_id) = ClusterId::new(host) {
201            let region = self.resolve_region(sdk_config)?;
202            Ok(crate::util::build_hostname(&cluster_id, &region))
203        } else {
204            Ok(host.to_string())
205        }
206    }
207
208    pub(crate) fn resolve_region(&self, sdk_config: &SdkConfig) -> Result<Region> {
209        // 1. Parse from hostname
210        let host = self.pg_connect_options.get_host();
211        if let Some(region) = crate::util::parse_region(host) {
212            return Ok(region);
213        }
214
215        // 2. Explicit region
216        if let Some(ref region) = self.region {
217            return Ok(region.clone());
218        }
219
220        // 3. AWS SDK default region
221        if let Some(region) = sdk_config.region() {
222            return Ok(region.clone());
223        }
224
225        Err(DsqlError::ConfigError(
226            "Could not determine region from connection string, hostname, or AWS configuration"
227                .into(),
228        ))
229    }
230}
231
232/// Load AWS SDK config, optionally using a named profile and/or custom credentials.
233pub(crate) async fn load_aws_config(
234    profile: Option<&str>,
235    credentials_provider: Option<&SharedCredentialsProvider>,
236) -> SdkConfig {
237    let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
238    if let Some(profile) = profile {
239        loader = loader.profile_name(profile);
240    }
241    if let Some(provider) = credentials_provider {
242        loader = loader.credentials_provider(provider.clone());
243    }
244    loader.load().await
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Full DSQL hostname used by most tests.
    const FULL_HOST: &str = "example.dsql.us-east-1.on.aws";
    /// URL with a full hostname and no query parameters.
    const BASIC_URL: &str = "postgres://admin@example.dsql.us-east-1.on.aws/postgres";
    /// URL whose host is a bare cluster ID, with an explicit region.
    const CLUSTER_ID_URL: &str =
        "postgres://admin@abcdefghijklmnopqrstuvwxyz/postgres?region=us-east-1";
    /// Hostname the cluster ID in `CLUSTER_ID_URL` expands to.
    const EXPANDED_CLUSTER_HOST: &str = "abcdefghijklmnopqrstuvwxyz.dsql.us-east-1.on.aws";

    /// Shorthand for `DsqlConnectOptions::from_connection_string`.
    fn parse(conn_str: &str) -> Result<DsqlConnectOptions> {
        DsqlConnectOptions::from_connection_string(conn_str)
    }

    /// Load the SDK config using the profile and credentials of `options`.
    async fn sdk_config_for(options: &DsqlConnectOptions) -> SdkConfig {
        load_aws_config(options.profile(), options.credentials_provider()).await
    }

    /// Explicit region of `options` as a string slice, if set.
    fn region_of(options: &DsqlConnectOptions) -> Option<&str> {
        options.region.as_ref().map(|r| r.as_ref())
    }

    /// Base PostgreSQL options pointing at `FULL_HOST` for builder tests.
    fn full_host_pg_options() -> PgConnectOptions {
        PgConnectOptions::new()
            .host(FULL_HOST)
            .username("admin")
            .database("postgres")
    }

    /// Static credentials provider with fixed key material.
    fn custom_provider() -> SharedCredentialsProvider {
        use aws_credential_types::Credentials;

        SharedCredentialsProvider::new(Credentials::new(
            "custom_key",
            "custom_secret",
            None,
            None,
            "test",
        ))
    }

    // --- from_connection_string tests ---

    #[test]
    fn test_parse_basic_connection_string() -> Result<()> {
        let options = parse("postgres://admin@example.dsql.us-east-1.on.aws:5432/postgres")?;
        let pg = &options.pg_connect_options;

        assert_eq!(pg.get_username(), "admin");
        assert_eq!(pg.get_host(), "example.dsql.us-east-1.on.aws");
        assert_eq!(pg.get_port(), 5432);
        assert_eq!(pg.get_database().unwrap(), "postgres");
        assert!(options.region.is_none());
        Ok(())
    }

    #[test]
    fn test_parse_with_region_param() -> Result<()> {
        let options =
            parse("postgres://admin@example.dsql.us-west-2.on.aws/postgres?region=us-west-2")?;

        assert_eq!(region_of(&options), Some("us-west-2"));
        Ok(())
    }

    #[test]
    fn test_parse_with_profile_param() -> Result<()> {
        let options =
            parse("postgres://admin@example.dsql.us-east-1.on.aws/postgres?profile=dev")?;

        assert_eq!(options.profile, Some("dev".to_string()));
        Ok(())
    }

    #[test]
    fn test_parse_with_region_and_profile() -> Result<()> {
        let options = parse(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres?region=us-east-1&profile=prod",
        )?;

        assert_eq!(region_of(&options), Some("us-east-1"));
        assert_eq!(options.profile, Some("prod".to_string()));
        Ok(())
    }

    #[test]
    fn test_invalid_connection_string() {
        assert!(parse("invalid://connection").is_err());
    }

    #[test]
    fn test_postgresql_scheme_alias() -> Result<()> {
        let options = parse("postgresql://admin@example.dsql.us-east-1.on.aws/postgres")?;
        let pg = &options.pg_connect_options;

        assert_eq!(pg.get_host(), "example.dsql.us-east-1.on.aws");
        assert_eq!(pg.get_username(), "admin");
        Ok(())
    }

    #[test]
    fn test_parse_query_params() -> Result<()> {
        let options = parse(
            "postgres://admin@example.dsql.us-east-1.on.aws/postgres?\
             tokenDurationSecs=900&ormPrefix=myapp",
        )?;

        assert_eq!(options.token_duration_secs, 900);
        let app_name = options.pg_connect_options.get_application_name().unwrap();
        assert!(
            app_name.starts_with("myapp:aurora-dsql-rust-sqlx/"),
            "ormPrefix should be prepended to application_name"
        );
        Ok(())
    }

    #[test]
    fn test_parse_cluster_id_stores_raw_host() -> Result<()> {
        let options = parse(CLUSTER_ID_URL)?;

        assert_eq!(
            options.pg_connect_options.get_host(),
            "abcdefghijklmnopqrstuvwxyz"
        );
        assert_eq!(region_of(&options), Some("us-east-1"));
        Ok(())
    }

    // --- resolve_host / resolve_region tests ---

    #[tokio::test]
    async fn test_resolve_region_from_param() -> Result<()> {
        let options =
            parse("postgres://admin@example.dsql.us-east-1.on.aws/postgres?region=us-east-1")?;

        let sdk_config = sdk_config_for(&options).await;
        assert_eq!(options.resolve_region(&sdk_config)?.as_ref(), "us-east-1");
        Ok(())
    }

    #[tokio::test]
    async fn test_resolve_region_from_hostname() -> Result<()> {
        let options = parse("postgres://admin@example.dsql.us-west-2.on.aws/postgres")?;

        let sdk_config = sdk_config_for(&options).await;
        assert_eq!(options.resolve_region(&sdk_config)?.as_ref(), "us-west-2");
        Ok(())
    }

    #[tokio::test]
    async fn test_resolve_host_expands_cluster_id() -> Result<()> {
        let options = parse(CLUSTER_ID_URL)?;

        let sdk_config = sdk_config_for(&options).await;
        assert_eq!(options.resolve_host(&sdk_config)?, EXPANDED_CLUSTER_HOST);
        Ok(())
    }

    #[tokio::test]
    async fn test_resolve_host_noop_for_full_hostname() -> Result<()> {
        let options = parse(BASIC_URL)?;

        let sdk_config = sdk_config_for(&options).await;
        assert_eq!(options.resolve_host(&sdk_config)?, FULL_HOST);
        Ok(())
    }

    // --- builder tests ---

    #[test]
    fn test_builder_rejects_empty_host() {
        let pg = PgConnectOptions::new()
            .host("")
            .username("admin")
            .database("postgres");

        let outcome = DsqlConnectOptionsBuilder::default()
            .pg_connect_options(pg)
            .build();

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("Host is required"),
            "Expected host error, got: {}",
            message
        );
    }

    // --- build_connect_options tests ---

    #[tokio::test]
    async fn test_build_connect_options() -> Result<()> {
        let options = parse(BASIC_URL)?;

        let sdk_config = sdk_config_for(&options).await;
        let pg = options.build_connect_options(&sdk_config, "test-token")?;
        assert_eq!(pg.get_host(), "example.dsql.us-east-1.on.aws");
        assert_eq!(pg.get_port(), 5432);
        assert_eq!(pg.get_username(), "admin");
        assert_eq!(pg.get_database().unwrap(), "postgres");
        assert!(matches!(pg.get_ssl_mode(), PgSslMode::VerifyFull));
        Ok(())
    }

    #[tokio::test]
    async fn test_build_connect_options_with_cluster_id() -> Result<()> {
        let options = parse(CLUSTER_ID_URL)?;

        let sdk_config = sdk_config_for(&options).await;
        let pg = options.build_connect_options(&sdk_config, "test-token")?;
        assert_eq!(pg.get_host(), EXPANDED_CLUSTER_HOST);
        Ok(())
    }

    #[tokio::test]
    async fn test_connect_options_default_application_name() -> Result<()> {
        let options = parse(BASIC_URL)?;

        let sdk_config = sdk_config_for(&options).await;
        let pg = options.build_connect_options(&sdk_config, "test-token")?;
        let app_name = pg
            .get_application_name()
            .expect("application_name should be set");
        assert!(app_name.starts_with("aurora-dsql-rust-sqlx/"));
        Ok(())
    }

    #[tokio::test]
    async fn test_connect_options_with_orm_prefix() -> Result<()> {
        let options =
            parse("postgres://admin@example.dsql.us-east-1.on.aws/postgres?ormPrefix=my-service")?;

        let sdk_config = sdk_config_for(&options).await;
        let pg = options.build_connect_options(&sdk_config, "test-token")?;
        let app_name = pg.get_application_name().unwrap();
        assert!(
            app_name.starts_with("my-service:aurora-dsql-rust-sqlx/"),
            "ormPrefix should be prepended to application_name"
        );
        Ok(())
    }

    #[test]
    fn test_ssl_mode_always_verify_full() {
        let options = parse(BASIC_URL).unwrap();

        assert!(matches!(
            options.pg_connect_options.get_ssl_mode(),
            PgSslMode::VerifyFull
        ));
    }

    #[tokio::test]
    async fn test_ssl_mode_enforced_via_builder() -> Result<()> {
        // Start from a deliberately weak SSL mode.
        let pg = full_host_pg_options().ssl_mode(PgSslMode::Prefer);

        let options = DsqlConnectOptionsBuilder::default()
            .pg_connect_options(pg)
            .build()
            .unwrap();

        let sdk_config = sdk_config_for(&options).await;
        let resolved = options.build_connect_options(&sdk_config, "test-token")?;
        assert!(
            matches!(resolved.get_ssl_mode(), PgSslMode::VerifyFull),
            "SSL must be VerifyFull regardless of builder input"
        );
        Ok(())
    }

    // --- refresh_interval tests ---

    #[test]
    #[cfg(feature = "pool")]
    fn test_refresh_interval_default() {
        let options = parse(BASIC_URL).unwrap();

        assert_eq!(options.refresh_interval(), Duration::from_secs(720));
    }

    #[test]
    #[cfg(feature = "pool")]
    fn test_refresh_interval_floors_to_one_second() {
        let options = DsqlConnectOptionsBuilder::default()
            .pg_connect_options(full_host_pg_options())
            .token_duration_secs(1u64)
            .build()
            .unwrap();

        assert_eq!(options.refresh_interval(), Duration::from_secs(1));
    }

    // --- credentials_provider tests ---

    #[test]
    fn test_from_connection_string_has_no_credentials_provider() {
        let options = parse(BASIC_URL).unwrap();

        assert!(options.credentials_provider.is_none());
    }

    #[test]
    fn test_builder_with_custom_credentials_provider() {
        let options = DsqlConnectOptionsBuilder::default()
            .pg_connect_options(full_host_pg_options())
            .credentials_provider(custom_provider())
            .build()
            .unwrap();

        assert!(options.credentials_provider.is_some());
    }

    #[test]
    fn test_builder_without_credentials_provider() {
        let options = DsqlConnectOptionsBuilder::default()
            .pg_connect_options(full_host_pg_options())
            .build()
            .unwrap();

        assert!(options.credentials_provider.is_none());
    }

    #[tokio::test]
    async fn test_load_aws_config_with_custom_credentials() {
        use aws_credential_types::provider::ProvideCredentials;

        let provider = custom_provider();

        let sdk_config = load_aws_config(None, Some(&provider)).await;
        let resolved = sdk_config
            .credentials_provider()
            .expect("SdkConfig should have a credentials provider")
            .provide_credentials()
            .await
            .expect("should resolve credentials");
        assert_eq!(resolved.access_key_id(), "custom_key");
        assert_eq!(resolved.secret_access_key(), "custom_secret");
    }
}