1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
5
6use std::borrow::Cow;
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex, OnceLock};
9
10#[cfg(feature = "read_sql")]
11use arrow::record_batch::RecordBatch;
12use aws_credential_types::provider::ProvideCredentials;
13use aws_sdk_sts as sts;
14use log::{debug, error};
15use secrecy::{ExposeSecret, SecretString};
16use tokio::runtime::Runtime;
17
18#[doc(hidden)]
19pub mod iam_provider;
20#[doc(hidden)]
21pub mod redshift;
22pub mod saml_provider;
23
24pub(crate) mod re {
25 use regex::Regex;
26
27 pub fn compile(pattern: &str) -> Regex {
28 Regex::new(pattern).unwrap()
29 }
30}
31
32pub use iam_provider::IamProvider;
35pub use redshift::Redshift;
36pub use saml_provider::{PingCredentialsProvider, SamlProvider};
37
38#[doc(hidden)]
39pub mod prelude {
40 pub use crate::iam_provider::IamProvider;
41 pub use crate::redshift::Redshift;
42 pub use crate::saml_provider::PingCredentialsProvider;
43}
44
/// Error type returned by this crate's fallible entry points.
#[derive(Debug)]
pub enum RedshiftIamError {
    /// A failure described by a human-readable message.
    ///
    /// Despite its name, this variant is used both for malformed connection URIs and,
    /// with the `read_sql` feature, for wrapped ConnectorX errors.
    ParseError(String),
}

impl std::fmt::Display for RedshiftIamError {
    /// Writes the wrapped message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::ParseError(description) => f.write_str(description),
        }
    }
}

/// Lets the error flow into `Box<dyn Error>`/`anyhow` and other generic error handlers.
impl std::error::Error for RedshiftIamError {}
59
#[cfg(feature = "read_sql")]
impl From<connectorx::errors::ConnectorXOutError> for RedshiftIamError {
    /// Wraps a ConnectorX failure, prefixing its message with `Error occurred: `.
    fn from(source: connectorx::errors::ConnectorXOutError) -> Self {
        let message = format!("Error occurred: {source}");
        Self::ParseError(message)
    }
}
66
/// Identifies which SAML credentials-provider plugin a connection URI asks for.
///
/// Built from the `plugin_name` query parameter through [`From<&str>`]. The value may be
/// bare (`OktaCredentialsProvider`) or carry the JDBC package prefix
/// (`com.amazon.redshift.plugin.OktaCredentialsProvider`), in any letter case.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PluginName {
    /// `PingCredentialsProvider`.
    PingCredentialsProvider,
    /// `OktaCredentialsProvider`.
    OktaCredentialsProvider,
    /// `BrowserSamlCredentialsProvider`.
    BrowserSamlCredentialsProvider,
    /// `BrowserAzureCredentialsProvider`.
    BrowserAzureCredentialsProvider,
    /// `AzureCredentialsProvider`.
    AzureCredentialsProvider,
    /// `AdfsCredentialsProvider`.
    AdfsCredentialsProvider,
    /// `CustomCredentialsProvider`.
    CustomCredentialsProvider,
    /// Any name that is not recognised (including an empty one).
    UnknownCredentialsProvider,
}

impl From<&str> for PluginName {
    /// Maps a plugin name to its variant.
    ///
    /// Surrounding whitespace, letter case and a leading `com.amazon.redshift.plugin.`
    /// package prefix are ignored. Unrecognised names map to
    /// [`PluginName::UnknownCredentialsProvider`].
    fn from(s: &str) -> Self {
        // Lower-case *before* stripping so the package prefix is also case-insensitive.
        let lowered = s.trim().to_lowercase();
        match lowered.trim_start_matches("com.amazon.redshift.plugin.") {
            "pingcredentialsprovider" => Self::PingCredentialsProvider,
            "oktacredentialsprovider" => Self::OktaCredentialsProvider,
            "browsersamlcredentialsprovider" => Self::BrowserSamlCredentialsProvider,
            "browserazurecredentialsprovider" => Self::BrowserAzureCredentialsProvider,
            "azurecredentialsprovider" => Self::AzureCredentialsProvider,
            "adfscredentialsprovider" => Self::AdfsCredentialsProvider,
            "customcredentialsprovider" => Self::CustomCredentialsProvider,
            _ => Self::UnknownCredentialsProvider,
        }
    }
}
122
123type ProviderFactory = Arc<
125 dyn Fn(
126 &HashMap<String, Cow<str>>,
127 &str,
128 Option<u16>,
129 &str,
130 SecretString,
131 ) -> Box<dyn SamlProvider>
132 + Send
133 + Sync,
134>;
135
136static PROVIDER_REGISTRY: OnceLock<Mutex<HashMap<PluginName, ProviderFactory>>> = OnceLock::new();
137
138fn registry() -> &'static Mutex<HashMap<PluginName, ProviderFactory>> {
141 PROVIDER_REGISTRY.get_or_init(|| {
142 let mut map: HashMap<PluginName, ProviderFactory> = HashMap::new();
143 map.insert(
144 PluginName::PingCredentialsProvider,
145 Arc::new(|conn_params, host, port, user, pwd| {
146 Box::new(PingCredentialsProvider::new(
147 conn_params,
148 host,
149 port,
150 user,
151 pwd,
152 ))
153 }),
154 );
155 Mutex::new(map)
156 })
157}
158
159pub fn register_provider(
191 plugin: PluginName,
192 factory: impl Fn(
193 &HashMap<String, Cow<str>>,
194 &str,
195 Option<u16>,
196 &str,
197 SecretString,
198 ) -> Box<dyn SamlProvider>
199 + Send
200 + Sync
201 + 'static,
202) {
203 registry().lock().unwrap().insert(plugin, Arc::new(factory));
204}
205
206fn get_redshift_from_uri(connection_uri: impl ToString) -> Result<Redshift, RedshiftIamError> {
208 let uri_string = connection_uri.to_string();
209 let mut uri_str = uri_string.trim();
210
211 let pattern = "redshift:iam://";
212 let (scheme, tail) = match uri_str.split_once(':') {
213 Some((scheme, tail)) => (scheme, tail),
214 None => {
215 return Err(RedshiftIamError::ParseError(format!(
216 "The connection uri needs to start with {pattern}"
217 )));
218 }
219 };
220 if scheme == "jdbc" {
221 uri_str = tail;
222 }
223 if !uri_str.starts_with(pattern) && !uri_str.starts_with("redshift-iam://") {
224 return Err(RedshiftIamError::ParseError(format!(
225 "The connection uri needs to start with {pattern}"
226 )));
227 }
228 uri_str = uri_str.split_once("://").unwrap().1;
229 let uri_str = format!("redshift://{uri_str}");
230 let redshift_url = reqwest::Url::parse(&uri_str).map_err(|e| {
231 RedshiftIamError::ParseError(format!("Invalid Redshift IAM URI: {e}"))
232 })?;
233 let database = redshift_url.path().trim_start_matches("/");
234
235 let params: HashMap<String, Cow<str>> = HashMap::from_iter(
236 redshift_url
237 .query_pairs()
238 .map(|(key, val)| (key.to_lowercase(), val)),
239 );
240 let autocreate = params
241 .get("autocreate")
242 .is_some_and(|val| val.to_lowercase() == "true");
243 let cluster = params.get("clusterid").map_or("", |val| val);
244 let idp_host = params.get("idp_host").map_or("", |val| val);
245 let idp_port = params
246 .get("idp_port")
247 .and_then(|val| val.parse::<u16>().ok());
248 let pwd = redshift_url.password().unwrap_or("");
249
250 let aws_credentials = if idp_host.is_empty() || pwd.is_empty() {
251 debug!("Initiating IAM login");
254 let rt = Runtime::new().unwrap();
255 let creds = rt.block_on(async {
256 aws_config::load_from_env()
257 .await
258 .credentials_provider()
259 .unwrap()
260 .provide_credentials()
261 .await
262 .unwrap()
263 });
264 sts::types::Credentials::builder()
265 .set_access_key_id(Some(creds.access_key_id().to_string()))
266 .set_secret_access_key(Some(creds.secret_access_key().to_string()))
267 .set_session_token(creds.session_token().map(str::to_string))
268 .build()
269 .unwrap()
270 } else {
271 let plugin_name = PluginName::from(params.get("plugin_name").map_or("", |v| v.as_ref()));
272 let factory = registry()
273 .lock()
274 .unwrap()
275 .get(&plugin_name)
276 .cloned()
277 .unwrap_or_else(|| {
278 panic!(
279 "No SAML provider registered for {plugin_name:?}. \
280 Register one with register_provider() before calling read_sql."
281 )
282 });
283 let provider = factory(
284 ¶ms,
285 idp_host,
286 idp_port,
287 redshift_url.username(),
288 SecretString::new(pwd.to_string().into_boxed_str()),
289 );
290 aws_creds_from_saml(provider, params.get("preferred_role").map_or("", |val| val))
291 };
292
293 let mut iam_provider = IamProvider::new(redshift_url.username(), database, cluster, autocreate);
294 if let Some(region) = params.get("region") {
295 iam_provider = iam_provider.set_region(region);
296 }
297 let (username, password) = iam_provider.auth(aws_credentials);
298
299 Ok(Redshift::new(
300 username,
301 password,
302 redshift_url.host_str().unwrap(),
303 redshift_url.port(),
304 database,
305 ))
306}
307
#[cfg(feature = "read_sql")]
/// Runs `query` on the Redshift cluster described by the IAM `connection_uri`.
///
/// Returns the result set as Arrow record batches.
///
/// # Errors
/// Returns [`RedshiftIamError`] if the connection URI cannot be turned into an
/// authenticated connection, or if executing the query fails.
pub fn read_sql(
    query: &str,
    connection_uri: impl ToString,
) -> Result<Vec<RecordBatch>, RedshiftIamError> {
    let redshift = get_redshift_from_uri(connection_uri)?;
    redshift.execute(query).map_err(RedshiftIamError::from)
}
347
348pub fn redshift_to_postgres(connection_uri: impl ToString) -> reqwest::Url {
376 let redshift_res = get_redshift_from_uri(connection_uri.to_string());
377 if let Ok(redshift) = redshift_res {
378 let mut uri = reqwest::Url::parse(redshift.connection_string().expose_secret()).unwrap();
380 uri.set_query(None);
382 uri
383 } else {
384 error!(
385 "Logging to redshift using redshift-iam crate failed with: {:?}",
386 redshift_res.err()
387 );
388 let mut uri = reqwest::Url::parse(&connection_uri.to_string()).unwrap(); uri.set_scheme("postgres").unwrap(); uri
391 }
392}
393
394fn aws_creds_from_saml(
398 provider: Box<dyn SamlProvider>,
399 preferred_role: &str,
400) -> sts::types::Credentials {
401 let rt = Runtime::new().unwrap();
402 rt.block_on(crate::saml_provider::get_credentials(
403 provider.as_ref(),
404 preferred_role.to_string(),
405 ))
406 .unwrap()
407}