Skip to main content

redshift_iam/
lib.rs

1// inspired by github.com/aws/amazon-redshift-python-driver
2// provides saml and IAM temp credential login
3
4#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
5
6use std::borrow::Cow;
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10#[cfg(feature = "read_sql")]
11use arrow::record_batch::RecordBatch;
12use aws_credential_types::provider::ProvideCredentials;
13use aws_sdk_sts as sts;
14use log::{debug, error};
15use secrecy::{ExposeSecret, SecretString};
16use tokio::runtime::Runtime;
17
18#[doc(hidden)]
19pub mod iam_provider;
20#[doc(hidden)]
21pub mod redshift;
22pub mod saml_provider;
23
24pub(crate) mod re {
25    use regex::Regex;
26
27    pub fn compile(pattern: &str) -> Regex {
28        Regex::new(pattern).unwrap()
29    }
30}
31
32// Re-export public API at crate root so structs and traits appear at the
33// top level in docs and can be imported as `redshift_iam::PingCredentialsProvider`.
34pub use iam_provider::IamProvider;
35pub use redshift::Redshift;
36pub use saml_provider::{PingCredentialsProvider, SamlProvider};
37
38#[doc(hidden)]
39pub mod prelude {
40    pub use crate::iam_provider::IamProvider;
41    pub use crate::redshift::Redshift;
42    pub use crate::saml_provider::PingCredentialsProvider;
43}
44
/// Error type returned by this crate's fallible entry points.
#[derive(Debug)]
pub enum RedshiftIamError {
    /// A failure described by a human-readable message. It is used for malformed
    /// connection URIs and for errors converted from other sources via `From`.
    ParseError(String),
}

impl std::fmt::Display for RedshiftIamError {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force an update here.
        match self {
            RedshiftIamError::ParseError(description) => fmt.write_str(description),
        }
    }
}

/// Lets the error be boxed as `dyn std::error::Error` and used with `?` in
/// functions returning generic error types.
impl std::error::Error for RedshiftIamError {}
59
#[cfg(feature = "read_sql")]
impl From<connectorx::errors::ConnectorXOutError> for RedshiftIamError {
    /// Wraps a connectorx failure, keeping its display text behind a fixed prefix.
    fn from(source: connectorx::errors::ConnectorXOutError) -> Self {
        let message = format!("Error occurred: {source}");
        Self::ParseError(message)
    }
}
66
/// Identifies the SAML provider plugin to use when an IdP host is present in the
/// connection URI.
///
/// The `Plugin_Name` query parameter in the JDBC URI is parsed into one of these
/// variants (see the `From<&str>` impl). Matching ignores case, and an optional
/// `com.amazon.redshift.plugin.` package prefix is stripped. As a result,
/// `"PingCredentialsProvider"`, `"pingcredentialsprovider"` and
/// `"com.amazon.redshift.plugin.PingCredentialsProvider"` all resolve to
/// [`PluginName::PingCredentialsProvider`].
///
/// Only [`PluginName::PingCredentialsProvider`] has a built-in factory.
/// All other variants require a factory to be registered via [`register_provider`]
/// before calling [`read_sql`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PluginName {
    /// PingFederate IdP (built-in — backed by [`PingCredentialsProvider`]).
    PingCredentialsProvider,
    /// Okta IdP.
    OktaCredentialsProvider,
    /// Browser-based SAML flow.
    BrowserSamlCredentialsProvider,
    /// Browser-based Azure AD SAML flow.
    BrowserAzureCredentialsProvider,
    /// Azure AD IdP.
    AzureCredentialsProvider,
    /// ADFS IdP.
    AdfsCredentialsProvider,
    /// User-defined custom provider.
    CustomCredentialsProvider,
    /// Fallback for unrecognised `Plugin_Name` values.
    UnknownCredentialsProvider,
}

impl From<&str> for PluginName {
    /// Converts a `Plugin_Name` URI parameter value to a `PluginName` variant.
    ///
    /// Surrounding whitespace is ignored and the value is lower-cased before any
    /// leading `com.amazon.redshift.plugin.` package prefix is stripped. This means
    /// both the prefix and the class name match case-insensitively. Unrecognised
    /// strings map to [`PluginName::UnknownCredentialsProvider`].
    fn from(s: &str) -> Self {
        // Lower-case first: stripping the prefix before lower-casing would leave e.g.
        // `COM.AMAZON.REDSHIFT.PLUGIN.` in place and break the case-insensitive promise.
        let lowered = s.trim().to_lowercase();
        let name = lowered.trim_start_matches("com.amazon.redshift.plugin.");
        match name {
            "pingcredentialsprovider" => Self::PingCredentialsProvider,
            "oktacredentialsprovider" => Self::OktaCredentialsProvider,
            "browsersamlcredentialsprovider" => Self::BrowserSamlCredentialsProvider,
            "browserazurecredentialsprovider" => Self::BrowserAzureCredentialsProvider,
            "azurecredentialsprovider" => Self::AzureCredentialsProvider,
            "adfscredentialsprovider" => Self::AdfsCredentialsProvider,
            "customcredentialsprovider" => Self::CustomCredentialsProvider,
            _ => Self::UnknownCredentialsProvider,
        }
    }
}
122
123/// Type-erased factory function stored in the provider registry.
124type ProviderFactory = Arc<
125    dyn Fn(
126            &HashMap<String, Cow<str>>,
127            &str,
128            Option<u16>,
129            &str,
130            SecretString,
131        ) -> Box<dyn SamlProvider>
132        + Send
133        + Sync,
134>;
135
136static PROVIDER_REGISTRY: OnceLock<Mutex<HashMap<PluginName, ProviderFactory>>> = OnceLock::new();
137
138/// Returns the global provider registry, pre-populated with the built-in
139/// [`PluginName::PingCredentialsProvider`] -> [`PingCredentialsProvider`] mapping.
140fn registry() -> &'static Mutex<HashMap<PluginName, ProviderFactory>> {
141    PROVIDER_REGISTRY.get_or_init(|| {
142        let mut map: HashMap<PluginName, ProviderFactory> = HashMap::new();
143        map.insert(
144            PluginName::PingCredentialsProvider,
145            Arc::new(|conn_params, host, port, user, pwd| {
146                Box::new(PingCredentialsProvider::new(
147                    conn_params,
148                    host,
149                    port,
150                    user,
151                    pwd,
152                ))
153            }),
154        );
155        Mutex::new(map)
156    })
157}
158
159/// Registers a factory for the given [`PluginName`] variant.
160///
161/// The factory receives `(conn_parameters, idp_host, idp_port, username, password)` and must
162/// return a `Box<dyn SamlProvider>`. Call this once at application startup
163/// before invoking [`read_sql`].
164///
165/// conn_parameters is a map of provider-specific arguments, like PartnerSpId for Ping,
166/// app_id - Used only with Okta. https://example.okta.com/home/amazon_aws/0oa2hylwrpM8UGehd1t7/272
167/// idp_tenant - A tenant used for Azure AD. Used only with Azure.
168/// client_id - A client ID for the Amazon Redshift enterprise application in Azure AD. Used only with Azure.
169///
170/// [`PluginName::PingCredentialsProvider`] is pre-registered and maps to
171/// [`PingCredentialsProvider`]. Registering it again replaces the built-in.
172///
173/// # Example
174///
175/// ```rust,no_run
176/// use secrecy::SecretString;
177/// use redshift_iam::{register_provider, PluginName, SamlProvider};
178///
179/// struct MyOktaProvider;
180///
181/// #[async_trait::async_trait]
182/// impl SamlProvider for MyOktaProvider {
183///     async fn get_saml_assertion(&self) -> String { todo!() }
184/// }
185///
186/// register_provider(PluginName::OktaCredentialsProvider, |_conn_params, _host, _port, _user, _pwd| {
187///     Box::new(MyOktaProvider)
188/// });
189/// ```
190pub fn register_provider(
191    plugin: PluginName,
192    factory: impl Fn(
193        &HashMap<String, Cow<str>>,
194        &str,
195        Option<u16>,
196        &str,
197        SecretString,
198    ) -> Box<dyn SamlProvider>
199    + Send
200    + Sync
201    + 'static,
202) {
203    registry().lock().unwrap().insert(plugin, Arc::new(factory));
204}
205
206/// Uses the main functionality from the crate modules to convert connection URI to Redshift type.
207fn get_redshift_from_uri(connection_uri: impl ToString) -> Result<Redshift, RedshiftIamError> {
208    let uri_string = connection_uri.to_string();
209    let mut uri_str = uri_string.trim();
210
211    let pattern = "redshift:iam://";
212    let (scheme, tail) = match uri_str.split_once(':') {
213        Some((scheme, tail)) => (scheme, tail),
214        None => {
215            return Err(RedshiftIamError::ParseError(format!(
216                "The connection uri needs to start with {pattern}"
217            )));
218        }
219    };
220    if scheme == "jdbc" {
221        uri_str = tail;
222    }
223    if !uri_str.starts_with(pattern) && !uri_str.starts_with("redshift-iam://") {
224        return Err(RedshiftIamError::ParseError(format!(
225            "The connection uri needs to start with {pattern}"
226        )));
227    }
228    uri_str = uri_str.split_once("://").unwrap().1;
229    let uri_str = format!("redshift://{uri_str}");
230    let redshift_url = reqwest::Url::parse(&uri_str).map_err(|e| {
231        RedshiftIamError::ParseError(format!("Invalid Redshift IAM URI: {e}"))
232    })?;
233    let database = redshift_url.path().trim_start_matches("/");
234
235    let params: HashMap<String, Cow<str>> = HashMap::from_iter(
236        redshift_url
237            .query_pairs()
238            .map(|(key, val)| (key.to_lowercase(), val)),
239    );
240    let autocreate = params
241        .get("autocreate")
242        .is_some_and(|val| val.to_lowercase() == "true");
243    let cluster = params.get("clusterid").map_or("", |val| val);
244    let idp_host = params.get("idp_host").map_or("", |val| val);
245    let idp_port = params
246        .get("idp_port")
247        .and_then(|val| val.parse::<u16>().ok());
248    let pwd = redshift_url.password().unwrap_or("");
249
250    let aws_credentials = if idp_host.is_empty() || pwd.is_empty() {
251        // No IdP credentials — fall back to ambient AWS credentials from the environment
252        // TODO: other ways to log in from the uri parameters?
253        debug!("Initiating IAM login");
254        let rt = Runtime::new().unwrap();
255        let creds = rt.block_on(async {
256            aws_config::load_from_env()
257                .await
258                .credentials_provider()
259                .unwrap()
260                .provide_credentials()
261                .await
262                .unwrap()
263        });
264        sts::types::Credentials::builder()
265            .set_access_key_id(Some(creds.access_key_id().to_string()))
266            .set_secret_access_key(Some(creds.secret_access_key().to_string()))
267            .set_session_token(creds.session_token().map(str::to_string))
268            .build()
269            .unwrap()
270    } else {
271        let plugin_name = PluginName::from(params.get("plugin_name").map_or("", |v| v.as_ref()));
272        let factory = registry()
273            .lock()
274            .unwrap()
275            .get(&plugin_name)
276            .cloned()
277            .unwrap_or_else(|| {
278                panic!(
279                    "No SAML provider registered for {plugin_name:?}. \
280                    Register one with register_provider() before calling read_sql."
281                )
282            });
283        let provider = factory(
284            &params,
285            idp_host,
286            idp_port,
287            redshift_url.username(),
288            SecretString::new(pwd.to_string().into_boxed_str()),
289        );
290        aws_creds_from_saml(provider, params.get("preferred_role").map_or("", |val| val))
291    };
292
293    let mut iam_provider = IamProvider::new(redshift_url.username(), database, cluster, autocreate);
294    if let Some(region) = params.get("region") {
295        iam_provider = iam_provider.set_region(region);
296    }
297    let (username, password) = iam_provider.auth(aws_credentials);
298
299    Ok(Redshift::new(
300        username,
301        password,
302        redshift_url.host_str().unwrap(),
303        redshift_url.port(),
304        database,
305    ))
306}
307
/// Runs `query` on the Redshift cluster named by a JDBC-style IAM connection URI and
/// collects the result as Arrow [`RecordBatch`]es.
///
/// # URI format
///
/// ```text
/// [jdbc:]redshift:iam://<user>:<password>@<host>:<port>/<database>?<params>
/// ```
///
/// A leading `jdbc:` is optional and removed before parsing. `redshift-iam://` is
/// accepted in place of `redshift:iam://`. Query parameter names are matched
/// case-insensitively:
///
/// | Parameter | Description |
/// |---|---|
/// | `ClusterID` | Redshift cluster identifier (required for IAM auth) |
/// | `Region` | AWS region (default: `us-east-1`) |
/// | `AutoCreate` | `true` to auto-create the DB user |
/// | `IdP_Host` | IdP hostname. If absent, falls back to ambient AWS credentials |
/// | `IdP_Port` | IdP port (default: `443`) |
/// | `Plugin_Name` | SAML provider variant (e.g. `PingCredentialsProvider`). Maps to [`PluginName`]. |
/// | `Preferred_Role` | IAM role ARN to assume via SAML |
///
/// If both `IdP_Host` and a password are given, `Plugin_Name` is converted to a
/// [`PluginName`] and the matching factory is taken from the global registry. Only
/// [`PluginName::PingCredentialsProvider`] is registered out of the box. Any other
/// plugin must first be added with [`register_provider`].
///
/// # Errors
///
/// Returns [`RedshiftIamError::ParseError`] when the URI does not start with
/// `redshift:iam://` (optionally preceded by `jdbc:`) or cannot be parsed. Errors
/// raised while executing `query` are converted into [`RedshiftIamError`] and
/// returned as well.
#[cfg(feature = "read_sql")]
pub fn read_sql(
    query: &str,
    connection_uri: impl ToString,
) -> Result<Vec<RecordBatch>, RedshiftIamError> {
    get_redshift_from_uri(connection_uri)?
        .execute(query)
        .map_err(RedshiftIamError::from)
}
347
348/// Converts a Redshift IAM connection URI into a parsed PostgreSQL connection string
349/// with temporary credentials already embedded.
350///
351/// Parses `connection_uri`, performs the full IAM / SAML authentication flow (identical
352/// to [`read_sql`]), and returns the resulting `postgres://` URL with the short-lived
353/// username and password substituted in.
354///
355/// This is useful when you need to hand a live connection string to a third-party
356/// library that speaks the PostgreSQL wire protocol directly (e.g. `sqlx`, `diesel`,
357/// `psycopg2` via a subprocess) without going through `connectorx`.
358///
359/// # URI format
360///
361/// Accepts the same `[jdbc:]redshift:iam://…` format described in [`read_sql`].
362///
363/// # Fallback behaviour
364///
365/// If the IAM / SAML exchange fails, the error is logged at the `error` level and the
366/// function falls back to returning the original URI with its scheme replaced by
367/// `postgres`. This allows callers to still attempt a direct connection using
368/// whatever credentials were present in the URI.
369///
370/// # Returns
371///
372/// A `postgres://username:password@host:port/database` connection string as an
373/// Url instance. The password is a short-lived STS session token and should
374/// not be cached beyond its expiry window.
375pub fn redshift_to_postgres(connection_uri: impl ToString) -> reqwest::Url {
376    let redshift_res = get_redshift_from_uri(connection_uri.to_string());
377    if let Ok(redshift) = redshift_res {
378        // already parsed before, safe to unwrap
379        let mut uri = reqwest::Url::parse(redshift.connection_string().expose_secret()).unwrap();
380        // remove the protocol
381        uri.set_query(None);
382        uri
383    } else {
384        error!(
385            "Logging to redshift using redshift-iam crate failed with: {:?}",
386            redshift_res.err()
387        );
388        let mut uri = reqwest::Url::parse(&connection_uri.to_string()).unwrap(); // we need to return Url; if not parsable, just panic
389        uri.set_scheme("postgres").unwrap(); // postgres is valid scheme, no reason for panic
390        uri
391    }
392}
393
394/// Obtains temporary AWS credentials from any [`SamlProvider`] synchronously.
395///
396/// Drives the async [`saml_provider::get_credentials`] on a new Tokio runtime.
397fn aws_creds_from_saml(
398    provider: Box<dyn SamlProvider>,
399    preferred_role: &str,
400) -> sts::types::Credentials {
401    let rt = Runtime::new().unwrap();
402    rt.block_on(crate::saml_provider::get_credentials(
403        provider.as_ref(),
404        preferred_role.to_string(),
405    ))
406    .unwrap()
407}