redshift_iam/lib.rs
// Inspired by https://github.com/aws/amazon-redshift-python-driver
// Provides SAML and IAM temporary-credential login for Amazon Redshift.
3
4#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
5
6use std::borrow::Cow;
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10use arrow::record_batch::RecordBatch;
11use aws_credential_types::provider::ProvideCredentials;
12use aws_sdk_sts as sts;
13use connectorx::errors::ConnectorXOutError;
14use log::{debug, error};
15use secrecy::{ExposeSecret, SecretString};
16use tokio::runtime::Runtime;
17
18#[doc(hidden)]
19pub mod iam_provider;
20#[doc(hidden)]
21pub mod redshift;
22pub mod saml_provider;
23
24pub(crate) mod re {
25 use regex::Regex;
26
27 pub fn compile(pattern: &str) -> Regex {
28 Regex::new(pattern).unwrap()
29 }
30}
31
32// Re-export public API at crate root so structs and traits appear at the
33// top level in docs and can be imported as `redshift_iam::PingCredentialsProvider`.
34pub use iam_provider::IamProvider;
35pub use redshift::Redshift;
36pub use saml_provider::{PingCredentialsProvider, SamlProvider};
37
38#[doc(hidden)]
39pub mod prelude {
40 pub use crate::iam_provider::IamProvider;
41 pub use crate::redshift::Redshift;
42 pub use crate::saml_provider::PingCredentialsProvider;
43}
44
/// Identifies the SAML provider plugin to use when an IdP host is present in the
/// connection URI.
///
/// The `Plugin_Name` query parameter in the JDBC URI is parsed into one of these
/// variants. The optional `com.amazon.redshift.plugin.` prefix is stripped
/// automatically (case-insensitively), so both `"PingCredentialsProvider"` and
/// `"com.amazon.redshift.plugin.PingCredentialsProvider"` resolve to
/// [`PluginName::PingCredentialsProvider`].
///
/// Only [`PluginName::PingCredentialsProvider`] has a built-in factory.
/// All other variants require a factory to be registered via [`register_provider`]
/// before calling [`read_sql`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PluginName {
    /// PingFederate IdP (built-in — backed by [`PingCredentialsProvider`]).
    PingCredentialsProvider,
    /// Okta IdP.
    OktaCredentialsProvider,
    /// Browser-based SAML flow.
    BrowserSamlCredentialsProvider,
    /// Browser-based Azure AD SAML flow.
    BrowserAzureCredentialsProvider,
    /// Azure AD IdP.
    AzureCredentialsProvider,
    /// ADFS IdP.
    AdfsCredentialsProvider,
    /// User-defined custom provider.
    CustomCredentialsProvider,
    /// Fallback for unrecognised `Plugin_Name` values.
    UnknownCredentialsProvider,
}

impl From<&str> for PluginName {
    /// Converts a `Plugin_Name` URI parameter value to a `PluginName` variant.
    ///
    /// Surrounding whitespace is ignored. The whole value, including the optional
    /// `com.amazon.redshift.plugin.` package prefix, is compared case-insensitively.
    /// The prefix is stripped at most once. Unrecognised strings map to
    /// [`PluginName::UnknownCredentialsProvider`].
    fn from(s: &str) -> Self {
        /// Java package prefix used by the JDBC driver's plugin class names (lower-cased).
        const PACKAGE_PREFIX: &str = "com.amazon.redshift.plugin.";

        // Lower-case *before* stripping so an upper/mixed-case prefix is recognised too,
        // matching the documented case-insensitive comparison.
        let lowered = s.trim().to_lowercase();
        let name = lowered.strip_prefix(PACKAGE_PREFIX).unwrap_or(lowered.as_str());

        match name {
            "pingcredentialsprovider" => Self::PingCredentialsProvider,
            "oktacredentialsprovider" => Self::OktaCredentialsProvider,
            "browsersamlcredentialsprovider" => Self::BrowserSamlCredentialsProvider,
            "browserazurecredentialsprovider" => Self::BrowserAzureCredentialsProvider,
            "azurecredentialsprovider" => Self::AzureCredentialsProvider,
            "adfscredentialsprovider" => Self::AdfsCredentialsProvider,
            "customcredentialsprovider" => Self::CustomCredentialsProvider,
            _ => Self::UnknownCredentialsProvider,
        }
    }
}
100
/// Type-erased factory function stored in the provider registry.
///
/// Arguments, in order:
/// 1. the URI query parameters, with keys lower-cased;
/// 2. the IdP host (`IdP_Host`);
/// 3. the IdP port (`IdP_Port`), if present and parsable as `u16`;
/// 4. the username taken from the connection URI;
/// 5. the password taken from the connection URI.
///
/// The returned [`SamlProvider`] is used to obtain the SAML assertion that is
/// exchanged for temporary AWS credentials.
///
/// The factory is stored behind an `Arc` (rather than a `Box`) so it can be cloned
/// out of the registry and invoked after the registry lock has been released.
/// `Send + Sync` is required because the registry lives in a process-wide `static`.
type ProviderFactory = Arc<
    dyn Fn(
            &HashMap<String, Cow<str>>,
            &str,
            Option<u16>,
            &str,
            SecretString,
        ) -> Box<dyn SamlProvider>
        + Send
        + Sync,
>;
113
114static PROVIDER_REGISTRY: OnceLock<Mutex<HashMap<PluginName, ProviderFactory>>> = OnceLock::new();
115
116/// Returns the global provider registry, pre-populated with the built-in
117/// [`PluginName::PingCredentialsProvider`] -> [`PingCredentialsProvider`] mapping.
118fn registry() -> &'static Mutex<HashMap<PluginName, ProviderFactory>> {
119 PROVIDER_REGISTRY.get_or_init(|| {
120 let mut map: HashMap<PluginName, ProviderFactory> = HashMap::new();
121 map.insert(
122 PluginName::PingCredentialsProvider,
123 Arc::new(|conn_params, host, port, user, pwd| {
124 Box::new(PingCredentialsProvider::new(
125 conn_params,
126 host,
127 port,
128 user,
129 pwd,
130 ))
131 }),
132 );
133 Mutex::new(map)
134 })
135}
136
137/// Registers a factory for the given [`PluginName`] variant.
138///
139/// The factory receives `(conn_parameters, idp_host, idp_port, username, password)` and must
140/// return a `Box<dyn SamlProvider>`. Call this once at application startup
141/// before invoking [`read_sql`].
142///
143/// conn_parameters is a map of provider-specific arguments, like PartnerSpId for Ping,
144/// app_id - Used only with Okta. https://example.okta.com/home/amazon_aws/0oa2hylwrpM8UGehd1t7/272
145/// idp_tenant - A tenant used for Azure AD. Used only with Azure.
146/// client_id - A client ID for the Amazon Redshift enterprise application in Azure AD. Used only with Azure.
147///
148/// [`PluginName::PingCredentialsProvider`] is pre-registered and maps to
149/// [`PingCredentialsProvider`]. Registering it again replaces the built-in.
150///
151/// # Example
152///
153/// ```rust,no_run
154/// use secrecy::SecretString;
155/// use redshift_iam::{register_provider, PluginName, SamlProvider};
156///
157/// struct MyOktaProvider;
158///
159/// #[async_trait::async_trait]
160/// impl SamlProvider for MyOktaProvider {
161/// async fn get_saml_assertion(&self) -> String { todo!() }
162/// }
163///
164/// register_provider(PluginName::OktaCredentialsProvider, |_conn_params, _host, _port, _user, _pwd| {
165/// Box::new(MyOktaProvider)
166/// });
167/// ```
168pub fn register_provider(
169 plugin: PluginName,
170 factory: impl Fn(
171 &HashMap<String, Cow<str>>,
172 &str,
173 Option<u16>,
174 &str,
175 SecretString,
176 ) -> Box<dyn SamlProvider>
177 + Send
178 + Sync
179 + 'static,
180) {
181 registry().lock().unwrap().insert(plugin, Arc::new(factory));
182}
183
184/// Uses the main functionality from the crate modules to convert connection URI to Redshift type.
185fn get_redshift_from_uri(connection_uri: impl ToString) -> Result<Redshift, ConnectorXOutError> {
186 let uri_string = connection_uri.to_string();
187 let mut uri_str = uri_string.trim();
188
189 let pattern = "redshift:iam://";
190 let (scheme, tail) = match uri_str.split_once(':') {
191 Some((scheme, tail)) => (scheme, tail),
192 None => {
193 return Err(ConnectorXOutError::SourceNotSupport(format!(
194 "The connection uri needs to start with {pattern}"
195 )));
196 }
197 };
198 if scheme == "jdbc" {
199 uri_str = tail;
200 }
201 if !uri_str.starts_with(pattern) && !uri_str.starts_with("redshift-iam://") {
202 return Err(ConnectorXOutError::SourceNotSupport(format!(
203 "The connection uri needs to start with {pattern}"
204 )));
205 }
206 uri_str = uri_str.split_once("://").unwrap().1;
207 let uri_str = format!("redshift://{uri_str}");
208 let redshift_url = reqwest::Url::parse(&uri_str).map_err(|e| {
209 ConnectorXOutError::SourceNotSupport(format!("Invalid Redshift IAM URI: {e}"))
210 })?;
211 let database = redshift_url.path().trim_start_matches("/");
212
213 let params: HashMap<String, Cow<str>> = HashMap::from_iter(
214 redshift_url
215 .query_pairs()
216 .map(|(key, val)| (key.to_lowercase(), val)),
217 );
218 let autocreate = params
219 .get("autocreate")
220 .is_some_and(|val| val.to_lowercase() == "true");
221 let cluster = params.get("clusterid").map_or("", |val| val);
222 let idp_host = params.get("idp_host").map_or("", |val| val);
223 let idp_port = params
224 .get("idp_port")
225 .and_then(|val| val.parse::<u16>().ok());
226 let pwd = redshift_url.password().unwrap_or("");
227
228 let aws_credentials = if idp_host.is_empty() || pwd.is_empty() {
229 // No IdP credentials — fall back to ambient AWS credentials from the environment
230 // TODO: other ways to log in from the uri parameters?
231 debug!("Initiating IAM login");
232 let rt = Runtime::new().unwrap();
233 let creds = rt.block_on(async {
234 aws_config::load_from_env()
235 .await
236 .credentials_provider()
237 .unwrap()
238 .provide_credentials()
239 .await
240 .unwrap()
241 });
242 sts::types::Credentials::builder()
243 .set_access_key_id(Some(creds.access_key_id().to_string()))
244 .set_secret_access_key(Some(creds.secret_access_key().to_string()))
245 .set_session_token(creds.session_token().map(str::to_string))
246 .build()
247 .unwrap()
248 } else {
249 let plugin_name = PluginName::from(params.get("plugin_name").map_or("", |v| v.as_ref()));
250 let factory = registry()
251 .lock()
252 .unwrap()
253 .get(&plugin_name)
254 .cloned()
255 .unwrap_or_else(|| {
256 panic!(
257 "No SAML provider registered for {plugin_name:?}. \
258 Register one with register_provider() before calling read_sql."
259 )
260 });
261 let provider = factory(
262 ¶ms,
263 idp_host,
264 idp_port,
265 redshift_url.username(),
266 SecretString::new(pwd.to_string().into_boxed_str()),
267 );
268 aws_creds_from_saml(provider, params.get("preferred_role").map_or("", |val| val))
269 };
270
271 let mut iam_provider = IamProvider::new(redshift_url.username(), database, cluster, autocreate);
272 if let Some(region) = params.get("region") {
273 iam_provider = iam_provider.set_region(region);
274 }
275 let (username, password) = iam_provider.auth(aws_credentials);
276
277 Ok(Redshift::new(
278 username,
279 password,
280 redshift_url.host_str().unwrap(),
281 redshift_url.port(),
282 database,
283 ))
284}
285
286/// Executes `query` against a Redshift cluster described by a JDBC-style IAM connection URI
287/// and returns the results as Arrow [`RecordBatch`]es.
288///
289/// # URI format
290///
291/// ```text
292/// [jdbc:]redshift:iam://<user>:<password>@<host>:<port>/<database>?<params>
293/// ```
294///
295/// The `jdbc:` prefix is optional and stripped automatically. Supported query parameters
296/// (all case-insensitive):
297///
298/// | Parameter | Description |
299/// |---|---|
300/// | `ClusterID` | Redshift cluster identifier (required for IAM auth) |
301/// | `Region` | AWS region (default: `us-east-1`) |
302/// | `AutoCreate` | `true` to auto-create the DB user |
303/// | `IdP_Host` | IdP hostname. If absent, falls back to ambient AWS credentials |
304/// | `IdP_Port` | IdP port (default: `443`) |
305/// | `Plugin_Name` | SAML provider variant (e.g. `PingCredentialsProvider`). Maps to [`PluginName`]. |
306/// | `Preferred_Role` | IAM role ARN to assume via SAML |
307///
308/// When `IdP_Host` and a password are present the `Plugin_Name` parameter is
309/// parsed into a [`PluginName`] variant and looked up in the global registry.
310/// [`PluginName::PingCredentialsProvider`] is pre-registered. All other variants
311/// must be registered first via [`register_provider`].
312///
313/// # Errors
314///
315/// Returns [`ConnectorXOutError::SourceNotSupport`] if the URI does not start with
316/// `redshift:iam://`.
317pub fn read_sql(
318 query: &str,
319 connection_uri: impl ToString,
320) -> Result<Vec<RecordBatch>, ConnectorXOutError> {
321 let redshift = get_redshift_from_uri(connection_uri).unwrap();
322 redshift.execute(query)
323}
324
325/// Converts a Redshift IAM connection URI into a parsed PostgreSQL connection string
326/// with temporary credentials already embedded.
327///
328/// Parses `connection_uri`, performs the full IAM / SAML authentication flow (identical
329/// to [`read_sql`]), and returns the resulting `postgres://` URL with the short-lived
330/// username and password substituted in.
331///
332/// This is useful when you need to hand a live connection string to a third-party
333/// library that speaks the PostgreSQL wire protocol directly (e.g. `sqlx`, `diesel`,
334/// `psycopg2` via a subprocess) without going through `connectorx`.
335///
336/// # URI format
337///
338/// Accepts the same `[jdbc:]redshift:iam://…` format described in [`read_sql`].
339///
340/// # Fallback behaviour
341///
342/// If the IAM / SAML exchange fails, the error is logged at the `error` level and the
343/// function falls back to returning the original URI with its scheme replaced by
344/// `postgres`. This allows callers to still attempt a direct connection using
345/// whatever credentials were present in the URI.
346///
347/// # Returns
348///
349/// A `postgres://username:password@host:port/database` connection string as an
350/// Url instance. The password is a short-lived STS session token and should
351/// not be cached beyond its expiry window.
352pub fn redshift_to_postgres(connection_uri: impl ToString) -> reqwest::Url {
353 let redshift_res = get_redshift_from_uri(connection_uri.to_string());
354 if let Ok(redshift) = redshift_res {
355 // already parsed before, safe to unwrap
356 reqwest::Url::parse(redshift.connection_string().expose_secret()).unwrap()
357 } else {
358 error!(
359 "Logging to redshift using redshift-iam crate failed with: {:?}",
360 redshift_res.err()
361 );
362 let mut uri = reqwest::Url::parse(&connection_uri.to_string()).unwrap(); // we need to return Url; if not parsable, just panic
363 uri.set_scheme("postgres").unwrap(); // postgres is valid scheme, no reason for panic
364 uri
365 }
366}
367
368/// Obtains temporary AWS credentials from any [`SamlProvider`] synchronously.
369///
370/// Drives the async [`saml_provider::get_credentials`] on a new Tokio runtime.
371fn aws_creds_from_saml(
372 provider: Box<dyn SamlProvider>,
373 preferred_role: &str,
374) -> sts::types::Credentials {
375 let rt = Runtime::new().unwrap();
376 rt.block_on(crate::saml_provider::get_credentials(
377 provider.as_ref(),
378 preferred_role.to_string(),
379 ))
380 .unwrap()
381}