Skip to main content

redshift_iam/
lib.rs

// Inspired by github.com/aws/amazon-redshift-python-driver.
// Provides SAML and IAM temporary-credential login for Redshift.
3
4#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
5
6use std::borrow::Cow;
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10use arrow::record_batch::RecordBatch;
11use aws_credential_types::provider::ProvideCredentials;
12use aws_sdk_sts as sts;
13use connectorx::errors::ConnectorXOutError;
14use log::{debug, error};
15use secrecy::{ExposeSecret, SecretString};
16use tokio::runtime::Runtime;
17
18#[doc(hidden)]
19pub mod iam_provider;
20#[doc(hidden)]
21pub mod redshift;
22pub mod saml_provider;
23
24pub(crate) mod re {
25    use regex::Regex;
26
27    pub fn compile(pattern: &str) -> Regex {
28        Regex::new(pattern).unwrap()
29    }
30}
31
32// Re-export public API at crate root so structs and traits appear at the
33// top level in docs and can be imported as `redshift_iam::PingCredentialsProvider`.
34pub use iam_provider::IamProvider;
35pub use redshift::Redshift;
36pub use saml_provider::{PingCredentialsProvider, SamlProvider};
37
38#[doc(hidden)]
39pub mod prelude {
40    pub use crate::iam_provider::IamProvider;
41    pub use crate::redshift::Redshift;
42    pub use crate::saml_provider::PingCredentialsProvider;
43}
44
/// Identifies the SAML provider plugin to use when an IdP host is present in the
/// connection URI.
///
/// The `Plugin_Name` query parameter in the JDBC URI is parsed into one of these
/// variants. The optional `com.amazon.redshift.plugin.` prefix is stripped
/// automatically, so both `"PingCredentialsProvider"` and
/// `"com.amazon.redshift.plugin.PingCredentialsProvider"` resolve to
/// [`PluginName::PingCredentialsProvider`].
///
/// Only [`PluginName::PingCredentialsProvider`] has a built-in factory.
/// All other variants require a factory to be registered via [`register_provider`]
/// before calling [`read_sql`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PluginName {
    /// PingFederate IdP (built-in — backed by [`PingCredentialsProvider`]).
    PingCredentialsProvider,
    /// Okta IdP.
    OktaCredentialsProvider,
    /// Browser-based SAML flow.
    BrowserSamlCredentialsProvider,
    /// Browser-based Azure AD SAML flow.
    BrowserAzureCredentialsProvider,
    /// Azure AD IdP.
    AzureCredentialsProvider,
    /// ADFS IdP.
    AdfsCredentialsProvider,
    /// User-defined custom provider.
    CustomCredentialsProvider,
    /// Fallback for unrecognised `Plugin_Name` values.
    UnknownCredentialsProvider,
}

impl From<&str> for PluginName {
    /// Converts a `Plugin_Name` URI parameter value to a `PluginName` variant.
    ///
    /// Surrounding whitespace is trimmed and the whole value is lower-cased
    /// *before* the optional `com.amazon.redshift.plugin.` package prefix is
    /// removed, so both the prefix and the provider name are matched
    /// case-insensitively. The prefix is removed at most once. Unrecognised
    /// strings (including the empty string) map to
    /// [`PluginName::UnknownCredentialsProvider`].
    fn from(s: &str) -> Self {
        const PACKAGE_PREFIX: &str = "com.amazon.redshift.plugin.";

        let lowered = s.trim().to_lowercase();
        // Strip after lower-casing; stripping first would make the prefix
        // match case-sensitive while the rest of the comparison is not.
        let name = lowered
            .strip_prefix(PACKAGE_PREFIX)
            .unwrap_or(lowered.as_str());

        match name {
            "pingcredentialsprovider" => Self::PingCredentialsProvider,
            "oktacredentialsprovider" => Self::OktaCredentialsProvider,
            "browsersamlcredentialsprovider" => Self::BrowserSamlCredentialsProvider,
            "browserazurecredentialsprovider" => Self::BrowserAzureCredentialsProvider,
            "azurecredentialsprovider" => Self::AzureCredentialsProvider,
            "adfscredentialsprovider" => Self::AdfsCredentialsProvider,
            "customcredentialsprovider" => Self::CustomCredentialsProvider,
            _ => Self::UnknownCredentialsProvider,
        }
    }
}
100
101/// Type-erased factory function stored in the provider registry.
102type ProviderFactory = Arc<
103    dyn Fn(
104            &HashMap<String, Cow<str>>,
105            &str,
106            Option<u16>,
107            &str,
108            SecretString,
109        ) -> Box<dyn SamlProvider>
110        + Send
111        + Sync,
112>;
113
114static PROVIDER_REGISTRY: OnceLock<Mutex<HashMap<PluginName, ProviderFactory>>> = OnceLock::new();
115
116/// Returns the global provider registry, pre-populated with the built-in
117/// [`PluginName::PingCredentialsProvider`] -> [`PingCredentialsProvider`] mapping.
118fn registry() -> &'static Mutex<HashMap<PluginName, ProviderFactory>> {
119    PROVIDER_REGISTRY.get_or_init(|| {
120        let mut map: HashMap<PluginName, ProviderFactory> = HashMap::new();
121        map.insert(
122            PluginName::PingCredentialsProvider,
123            Arc::new(|conn_params, host, port, user, pwd| {
124                Box::new(PingCredentialsProvider::new(
125                    conn_params,
126                    host,
127                    port,
128                    user,
129                    pwd,
130                ))
131            }),
132        );
133        Mutex::new(map)
134    })
135}
136
137/// Registers a factory for the given [`PluginName`] variant.
138///
139/// The factory receives `(conn_parameters, idp_host, idp_port, username, password)` and must
140/// return a `Box<dyn SamlProvider>`. Call this once at application startup
141/// before invoking [`read_sql`].
142///
143/// conn_parameters is a map of provider-specific arguments, like PartnerSpId for Ping,
144/// app_id - Used only with Okta. https://example.okta.com/home/amazon_aws/0oa2hylwrpM8UGehd1t7/272
145/// idp_tenant - A tenant used for Azure AD. Used only with Azure.
146/// client_id - A client ID for the Amazon Redshift enterprise application in Azure AD. Used only with Azure.
147///
148/// [`PluginName::PingCredentialsProvider`] is pre-registered and maps to
149/// [`PingCredentialsProvider`]. Registering it again replaces the built-in.
150///
151/// # Example
152///
153/// ```rust,no_run
154/// use secrecy::SecretString;
155/// use redshift_iam::{register_provider, PluginName, SamlProvider};
156///
157/// struct MyOktaProvider;
158///
159/// #[async_trait::async_trait]
160/// impl SamlProvider for MyOktaProvider {
161///     async fn get_saml_assertion(&self) -> String { todo!() }
162/// }
163///
164/// register_provider(PluginName::OktaCredentialsProvider, |_conn_params, _host, _port, _user, _pwd| {
165///     Box::new(MyOktaProvider)
166/// });
167/// ```
168pub fn register_provider(
169    plugin: PluginName,
170    factory: impl Fn(
171        &HashMap<String, Cow<str>>,
172        &str,
173        Option<u16>,
174        &str,
175        SecretString,
176    ) -> Box<dyn SamlProvider>
177    + Send
178    + Sync
179    + 'static,
180) {
181    registry().lock().unwrap().insert(plugin, Arc::new(factory));
182}
183
184/// Uses the main functionality from the crate modules to convert connection URI to Redshift type.
185fn get_redshift_from_uri(connection_uri: impl ToString) -> Result<Redshift, ConnectorXOutError> {
186    let uri_string = connection_uri.to_string();
187    let mut uri_str = uri_string.trim();
188
189    let pattern = "redshift:iam://";
190    let (scheme, tail) = match uri_str.split_once(':') {
191        Some((scheme, tail)) => (scheme, tail),
192        None => {
193            return Err(ConnectorXOutError::SourceNotSupport(format!(
194                "The connection uri needs to start with {pattern}"
195            )));
196        }
197    };
198    if scheme == "jdbc" {
199        uri_str = tail;
200    }
201    if !uri_str.starts_with(pattern) && !uri_str.starts_with("redshift-iam://") {
202        return Err(ConnectorXOutError::SourceNotSupport(format!(
203            "The connection uri needs to start with {pattern}"
204        )));
205    }
206    uri_str = uri_str.split_once("://").unwrap().1;
207    let uri_str = format!("redshift://{uri_str}");
208    let redshift_url = reqwest::Url::parse(&uri_str).map_err(|e| {
209        ConnectorXOutError::SourceNotSupport(format!("Invalid Redshift IAM URI: {e}"))
210    })?;
211    let database = redshift_url.path().trim_start_matches("/");
212
213    let params: HashMap<String, Cow<str>> = HashMap::from_iter(
214        redshift_url
215            .query_pairs()
216            .map(|(key, val)| (key.to_lowercase(), val)),
217    );
218    let autocreate = params
219        .get("autocreate")
220        .is_some_and(|val| val.to_lowercase() == "true");
221    let cluster = params.get("clusterid").map_or("", |val| val);
222    let idp_host = params.get("idp_host").map_or("", |val| val);
223    let idp_port = params
224        .get("idp_port")
225        .and_then(|val| val.parse::<u16>().ok());
226    let pwd = redshift_url.password().unwrap_or("");
227
228    let aws_credentials = if idp_host.is_empty() || pwd.is_empty() {
229        // No IdP credentials — fall back to ambient AWS credentials from the environment
230        // TODO: other ways to log in from the uri parameters?
231        debug!("Initiating IAM login");
232        let rt = Runtime::new().unwrap();
233        let creds = rt.block_on(async {
234            aws_config::load_from_env()
235                .await
236                .credentials_provider()
237                .unwrap()
238                .provide_credentials()
239                .await
240                .unwrap()
241        });
242        sts::types::Credentials::builder()
243            .set_access_key_id(Some(creds.access_key_id().to_string()))
244            .set_secret_access_key(Some(creds.secret_access_key().to_string()))
245            .set_session_token(creds.session_token().map(str::to_string))
246            .build()
247            .unwrap()
248    } else {
249        let plugin_name = PluginName::from(params.get("plugin_name").map_or("", |v| v.as_ref()));
250        let factory = registry()
251            .lock()
252            .unwrap()
253            .get(&plugin_name)
254            .cloned()
255            .unwrap_or_else(|| {
256                panic!(
257                    "No SAML provider registered for {plugin_name:?}. \
258                    Register one with register_provider() before calling read_sql."
259                )
260            });
261        let provider = factory(
262            &params,
263            idp_host,
264            idp_port,
265            redshift_url.username(),
266            SecretString::new(pwd.to_string().into_boxed_str()),
267        );
268        aws_creds_from_saml(provider, params.get("preferred_role").map_or("", |val| val))
269    };
270
271    let mut iam_provider = IamProvider::new(redshift_url.username(), database, cluster, autocreate);
272    if let Some(region) = params.get("region") {
273        iam_provider = iam_provider.set_region(region);
274    }
275    let (username, password) = iam_provider.auth(aws_credentials);
276
277    Ok(Redshift::new(
278        username,
279        password,
280        redshift_url.host_str().unwrap(),
281        redshift_url.port(),
282        database,
283    ))
284}
285
286/// Executes `query` against a Redshift cluster described by a JDBC-style IAM connection URI
287/// and returns the results as Arrow [`RecordBatch`]es.
288///
289/// # URI format
290///
291/// ```text
292/// [jdbc:]redshift:iam://<user>:<password>@<host>:<port>/<database>?<params>
293/// ```
294///
295/// The `jdbc:` prefix is optional and stripped automatically. Supported query parameters
296/// (all case-insensitive):
297///
298/// | Parameter | Description |
299/// |---|---|
300/// | `ClusterID` | Redshift cluster identifier (required for IAM auth) |
301/// | `Region` | AWS region (default: `us-east-1`) |
302/// | `AutoCreate` | `true` to auto-create the DB user |
303/// | `IdP_Host` | IdP hostname. If absent, falls back to ambient AWS credentials |
304/// | `IdP_Port` | IdP port (default: `443`) |
305/// | `Plugin_Name` | SAML provider variant (e.g. `PingCredentialsProvider`). Maps to [`PluginName`]. |
306/// | `Preferred_Role` | IAM role ARN to assume via SAML |
307///
308/// When `IdP_Host` and a password are present the `Plugin_Name` parameter is
309/// parsed into a [`PluginName`] variant and looked up in the global registry.
310/// [`PluginName::PingCredentialsProvider`] is pre-registered. All other variants
311/// must be registered first via [`register_provider`].
312///
313/// # Errors
314///
315/// Returns [`ConnectorXOutError::SourceNotSupport`] if the URI does not start with
316/// `redshift:iam://`.
317pub fn read_sql(
318    query: &str,
319    connection_uri: impl ToString,
320) -> Result<Vec<RecordBatch>, ConnectorXOutError> {
321    let redshift = get_redshift_from_uri(connection_uri).unwrap();
322    redshift.execute(query)
323}
324
325/// Converts a Redshift IAM connection URI into a parsed PostgreSQL connection string
326/// with temporary credentials already embedded.
327///
328/// Parses `connection_uri`, performs the full IAM / SAML authentication flow (identical
329/// to [`read_sql`]), and returns the resulting `postgres://` URL with the short-lived
330/// username and password substituted in.
331///
332/// This is useful when you need to hand a live connection string to a third-party
333/// library that speaks the PostgreSQL wire protocol directly (e.g. `sqlx`, `diesel`,
334/// `psycopg2` via a subprocess) without going through `connectorx`.
335///
336/// # URI format
337///
338/// Accepts the same `[jdbc:]redshift:iam://…` format described in [`read_sql`].
339///
340/// # Fallback behaviour
341///
342/// If the IAM / SAML exchange fails, the error is logged at the `error` level and the
343/// function falls back to returning the original URI with its scheme replaced by
344/// `postgres`. This allows callers to still attempt a direct connection using
345/// whatever credentials were present in the URI.
346///
347/// # Returns
348///
349/// A `postgres://username:password@host:port/database` connection string as an
350/// Url instance. The password is a short-lived STS session token and should
351/// not be cached beyond its expiry window.
352pub fn redshift_to_postgres(connection_uri: impl ToString) -> reqwest::Url {
353    let redshift_res = get_redshift_from_uri(connection_uri.to_string());
354    if let Ok(redshift) = redshift_res {
355        // already parsed before, safe to unwrap
356        reqwest::Url::parse(redshift.connection_string().expose_secret()).unwrap()
357    } else {
358        error!(
359            "Logging to redshift using redshift-iam crate failed with: {:?}",
360            redshift_res.err()
361        );
362        let mut uri = reqwest::Url::parse(&connection_uri.to_string()).unwrap(); // we need to return Url; if not parsable, just panic
363        uri.set_scheme("postgres").unwrap(); // postgres is valid scheme, no reason for panic
364        uri
365    }
366}
367
368/// Obtains temporary AWS credentials from any [`SamlProvider`] synchronously.
369///
370/// Drives the async [`saml_provider::get_credentials`] on a new Tokio runtime.
371fn aws_creds_from_saml(
372    provider: Box<dyn SamlProvider>,
373    preferred_role: &str,
374) -> sts::types::Credentials {
375    let rt = Runtime::new().unwrap();
376    rt.block_on(crate::saml_provider::get_credentials(
377        provider.as_ref(),
378        preferred_role.to_string(),
379    ))
380    .unwrap()
381}