// redshift_iam/saml_provider.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::str;
4
5// use aws_config;
6use aws_sdk_sts as sts;
7use base64::prelude::*;
8use log::{debug, warn};
9// use reqwest;
10use scraper::{ElementRef, Html, Selector};
11use secrecy::{ExposeSecret, SecretString};
12use tokio::runtime::Runtime;
13
14use crate::re;
15
16/// Trait for identity providers that can supply a SAML assertion.
17///
18/// Implement this trait with `#[async_trait::async_trait]` to create a custom IdP plugin,
19/// then register it via [`crate::register_provider`].
20#[async_trait::async_trait]
21pub trait SamlProvider: Send + Sync {
22    /// Fetches and returns a base64-encoded SAML assertion from the IdP.
23    async fn get_saml_assertion(&self) -> String;
24}
25
26/// Returns `true` if the input tag has `type="password"`.
27fn is_password(inputtag: &ElementRef) -> bool {
28    inputtag.attr("type") == Some("password")
29}
30
31/// Returns `true` if the input tag has `type="text"`.
32fn is_text(inputtag: &ElementRef) -> bool {
33    inputtag.attr("type") == Some("text")
34}
35
36/// Finds the first form `action` attribute whose method is POST (or unspecified).
37/// Forms with an explicit non-POST method are skipped. Returns `None` if no
38/// qualifying form is found.
39fn get_form_action(soup: &Html) -> Option<&str> {
40    // NOTE: selector case-insensitive; it will match both form and FORM
41    let selector = Selector::parse("form").unwrap();
42
43    for inputtag in soup.select(&selector) {
44        let action = inputtag.attr("action");
45        if action.is_some() {
46            let method = inputtag.attr("method");
47            // safe unwrap
48            if method.is_some() && method.unwrap().to_uppercase() != "POST" {
49                warn!("Found action, but method is not POST. Skipping.");
50                continue;
51            }
52            return action;
53        }
54    }
55
56    None
57}
58
59/// Obtains temporary AWS credentials by exchanging a SAML assertion for STS credentials.
60///
61/// Calls [`SamlProvider::get_saml_assertion`], decodes the assertion, extracts the
62/// IAM role and principal ARNs, and calls `sts:AssumeRoleWithSAML` for `role_arn`.
63///
64/// # Panics
65/// - If no IAM roles are found in the SAML assertion.
66/// - If `role_arn` is not present among the roles in the assertion.
67pub async fn get_credentials(
68    provider: &dyn SamlProvider,
69    role_arn: String,
70) -> Option<sts::types::Credentials> {
71    // refresh method alias
72    let saml_assertion = provider.get_saml_assertion().await;
73
74    // decode SAML assertion into xml format
75    let ass_bytes = BASE64_STANDARD.decode(saml_assertion.as_bytes()).unwrap();
76    let doc = str::from_utf8(&ass_bytes).unwrap();
77
78    debug!("decoded SAML assertion into xml format");
79    // NOTE could parse it as xml, but keeping it lightweighted
80    let soup = Html::parse_document(doc);
81    let selector = Selector::parse(r"saml\:AttributeValue").unwrap();
82    let attrs = soup.select(&selector);
83
84    // extract RoleArn and PrincipleArn from SAML assertion
85    let role_pattern = re::compile(r"arn:aws:iam::\d*:role/\S+");
86    let provider_pattern = re::compile(r"arn:aws:iam::\d*:saml-provider/\S+");
87    let mut roles: HashMap<&str, String> = HashMap::new();
88    debug!("searching SAML assertion for values matching patterns for RoleArn and PrincipalArn");
89    // TODO: let user specify None as role and then pick the first one
90    for attr in attrs {
91        for value in attr.text() {
92            let mut role = "";
93            let mut provider = String::new();
94            for arn_ in value.split(",") {
95                let arn = arn_.trim();
96                if role_pattern.is_match(arn) {
97                    debug!("RoleArn pattern matched");
98                    role = arn;
99                }
100                if provider_pattern.is_match(arn) {
101                    debug!("PrincipleArn pattern matched");
102                    provider = arn.to_string();
103                }
104            }
105            if !role.is_empty() && !provider.is_empty() {
106                roles.insert(role, provider);
107            }
108        }
109    }
110    debug!("Done reading SAML assertion attributes");
111    debug!("{} roles identified in SAML assertion", roles.len());
112
113    if roles.is_empty() {
114        let exec_msg = "No roles were found in SAML assertion. Please verify IdP configuration provides ARNs in the SAML https://aws.amazon.com/SAML/Attributes/Role Attribute.";
115        panic!("{exec_msg}");
116    }
117    debug!("User provided preferred_role, trying to use...");
118    if !roles.contains_key(&*role_arn) {
119        let exec_msg = "User specified preferred_role was not found in SAML assertion https://aws.amazon.com/SAML/Attributes/Role Attribute";
120        panic!("{exec_msg}");
121    }
122
123    // empty config; no prior aws identity needed
124    let config = aws_config::load_from_env().await;
125    let client = sts::Client::new(&config);
126    debug!(
127        "Attempting to retrieve temporary AWS credentials using the SAML assertion, principal ARN, and role ARN."
128    );
129    let response = client
130        .assume_role_with_saml()
131        .set_principal_arn(roles.remove(&*role_arn)) // remove instead of get, so we move the value out and not get ref
132        .set_role_arn(Some(role_arn))
133        .saml_assertion(saml_assertion)
134        .send()
135        .await
136        .unwrap();
137    debug!("Extracting temporary AWS credentials from assume_role_with_saml response");
138
139    response.credentials
140}
141
142/// Extracts the SAMLResponse assertion value from the IdP authentication response HTML.
143/// Panics if no `SAMLResponse` input tag is found.
144///
145/// # Examples
146/// ```rust,ignore
147/// use crate::saml_provider::parse_saml_assertion;
148///
149/// let html = r#"<html><body>
150/// <form method="POST" action="https://signin.aws.amazon.com/saml">
151///   <INPUT type="hidden" name="SAMLResponse" value="dGVzdA==" />
152/// </form>
153/// </body></html>"#;
154/// assert_eq!(parse_saml_assertion(html), "dGVzdA==");
155/// ```
156fn parse_saml_assertion(html: &str) -> String {
157    let soup = Html::parse_document(html);
158    let selector = Selector::parse("INPUT").unwrap();
159    let mut assertion = String::new();
160    for inputtag in soup.select(&selector) {
161        if inputtag.attr("name") == Some("SAMLResponse") {
162            debug!("SAMLResponse tag found");
163            assertion = inputtag.attr("value").unwrap().to_string();
164        }
165    }
166    if assertion.is_empty() {
167        panic!(
168            "Failed to retrieve SAMLAssertion. An input tag named SAMLResponse was not identified in the Ping IdP authentication response"
169        );
170    }
171    assertion
172}
173
174/// PingFederate identity provider plugin for SAML-based Redshift authentication.
175///
176/// See the [Amazon Redshift IAM docs](https://docs.aws.amazon.com/redshift/latest/mgmt/options-for-providing-iam-credentials.html)
177/// for setup instructions.
178#[derive(Debug)]
179pub struct PingCredentialsProvider {
180    partner_sp_id: String,
181    idp_host: String,
182    idp_port: u16,
183    user_name: String,
184    password: SecretString,
185    /// When `true`, TLS certificate verification is disabled. Defaults to `false`.
186    pub ssl_insecure: bool,
187}
188
189impl PingCredentialsProvider {
190    /// Creates a new `PingCredentialsProvider`.
191    ///
192    /// - `conn_parameters`: `HashMap` that may contain a `partnerspid` key: the SP entity ID sent to PingFederate.
193    ///   If the map is empty or does not contain `partnerspid`, `"urn%3Aamazon%3Awebservices"` is used.
194    /// - `idp_port`: Defaults to `443` when `None`.
195    ///
196    /// # Examples
197    /// ```
198    /// use std::collections::HashMap;
199    /// use secrecy::SecretString;
200    /// use redshift_iam::PingCredentialsProvider;
201    ///
202    /// let scp = PingCredentialsProvider::new(
203    ///     &HashMap::new(),
204    ///     "pingfed.example.com",
205    ///     None,
206    ///     "alice",
207    ///     SecretString::new("s3cr3t".to_string().into_boxed_str()),
208    /// );
209    /// assert!(!scp.ssl_insecure);
210    /// assert!(scp.do_verify_ssl_cert());
211    /// assert_eq!(scp.user(), "alice");
212    /// ```
213    pub fn new(
214        conn_parameters: &HashMap<String, Cow<str>>,
215        idp_host: impl ToString,
216        idp_port: Option<u16>,
217        user_name: impl ToString,
218        password: SecretString,
219    ) -> Self {
220        // We could either accept pwd and create secretString here or force user to pass it
221        let partner_sp_id_option = conn_parameters.get("partnerspid");
222        let partner_sp_id = if let Some(partner_sp_id) = partner_sp_id_option {
223            partner_sp_id.to_string()
224        } else {
225            "urn%3Aamazon%3Awebservices".to_string()
226        };
227        Self {
228            partner_sp_id,
229            idp_host: idp_host.to_string(),
230            idp_port: idp_port.unwrap_or(443),
231            user_name: user_name.to_string(),
232            password,
233            ssl_insecure: false,
234        }
235    }
236
237    /// user getter
238    pub fn user(&self) -> String {
239        self.user_name.clone()
240    }
241
242    /// Returns `true` when TLS certificate verification is enabled (i.e. `ssl_insecure` is `false`).
243    pub fn do_verify_ssl_cert(&self) -> bool {
244        !self.ssl_insecure
245    }
246
247    /// Synchronously retrieves temporary AWS credentials for `preferred_role`.
248    ///
249    /// Drives the full SAML -> STS flow on a new Tokio runtime. Prefer the async
250    /// [`get_credentials`] free function when already inside an async context.
251    pub fn get_credentials(
252        &self,
253        preferred_role: impl ToString,
254    ) -> Option<sts::types::Credentials> {
255        let rt = Runtime::new().unwrap(); //?
256        rt.block_on(async { get_credentials(self, preferred_role.to_string()).await })
257    }
258
259    /// Parses the IdP login page HTML, extracting the form submission payload and
260    /// the form's action path. Panics if username or password fields cannot be found.
261    fn parse_login_form(&self, html: &str) -> (HashMap<String, String>, Option<String>) {
262        let soup = Html::parse_document(html);
263        let selector = Selector::parse("INPUT").unwrap();
264        let mut payload: HashMap<String, String> = HashMap::new();
265        let mut username_found = false;
266        let mut pwd_found = false;
267
268        debug!(
269            "Looking for username and password input tags in Ping IdP login page in order to build authentication request payload"
270        );
271        for inputtag in soup.select(&selector) {
272            let name = inputtag.attr("name").unwrap_or("").to_string();
273            let id_ = inputtag.attr("id").unwrap_or("");
274            debug!("name={name} , id={id_}");
275
276            if !username_found && is_text(&inputtag) && id_ == "username" {
277                debug!("Using tag with name {name} for username");
278                payload.insert(name, self.user());
279                username_found = true;
280            } else if is_password(&inputtag) && name.contains("pass") {
281                debug!("Using tag with name {name} for password");
282                if pwd_found {
283                    panic!(
284                        "Failed to parse Ping IdP login form. More than one password field was found on the Ping IdP login page"
285                    );
286                }
287                payload.insert(name, self.password.expose_secret().to_string());
288                pwd_found = true;
289            } else if !name.is_empty() {
290                let value = inputtag.attr("value").unwrap_or("").to_string();
291                payload.insert(name, value);
292            }
293        }
294
295        if !username_found {
296            debug!(
297                "username tag still not found, continuing search using secondary preferred tags"
298            );
299            for inputtag in soup.select(&selector) {
300                let name = inputtag.attr("name").unwrap_or("").to_string();
301                if is_text(&inputtag) && (name.contains("user") || name.contains("email")) {
302                    debug!("Using tag with name {name} for username");
303                    payload.insert(name, self.user());
304                    username_found = true;
305                }
306            }
307        }
308
309        if !username_found || !pwd_found {
310            panic!("Failed to parse Ping IdP login form field(s)");
311        }
312
313        let action = get_form_action(&soup).map(str::to_owned);
314        (payload, action)
315    }
316}
317
318#[async_trait::async_trait]
319impl SamlProvider for PingCredentialsProvider {
320    /// Logs in to the PingFederate IdP and returns a base64-encoded SAML assertion.
321    ///
322    /// Issues a GET to the SSO start URL, parses the login form, submits credentials,
323    /// and extracts the `SAMLResponse` value from the resulting page.
324    ///
325    /// # Panics
326    /// - If the login form cannot be parsed or credentials fields are missing.
327    /// - If the POST to the IdP returns a non-200 status.
328    /// - If no `SAMLResponse` input is found in the response.
329    async fn get_saml_assertion(&self) -> String {
330        // Method to grab the SAML Response. Used to refresh temporary credentials.
331        debug!("PingCredentialsProvider.get_saml_assertion");
332        let session = reqwest::Client::builder() // scoped only in this method
333            .cookie_store(true) // the PF=... session state cookie needs to be preserved
334            // .https_only(true)
335            .build()
336            .unwrap();
337
338        let mut url = format!(
339            "https://{}:{}/idp/startSSO.ping?PartnerSpId={}",
340            self.idp_host, self.idp_port, self.partner_sp_id,
341        );
342
343        debug!(
344            "Issuing GET request for Ping IdP login page using uri={} verify={}",
345            url,
346            self.do_verify_ssl_cert(),
347        );
348        let resp = session.get(&url).send().await.unwrap(); // TODO: , verify=self.do_verify_ssl_cert()
349        debug!("Response code: {}", resp.status());
350        debug!("response length: {}", resp.content_length().unwrap_or(0));
351
352        let resp_text = resp.text().await.unwrap();
353        let (payload, action) = self.parse_login_form(&resp_text);
354
355        // NOTE: not sure if we want to continue with the original url in None case
356        if let Some(action_str) = action.as_deref()
357            && action_str.starts_with("/")
358        {
359            url = format!("https://{}:{}{action_str}", self.idp_host, self.idp_port);
360        }
361        // else {
362        //     panic!();
363        // }
364
365        debug!(
366            "Issuing authentication request to Ping IdP using uri {} verify {}",
367            &url,
368            self.do_verify_ssl_cert(),
369        );
370        let response = session
371            .post(&url) //verify=self.do_verify_ssl_cert()
372            .form(&payload)
373            .send()
374            .await
375            .unwrap();
376        let status_code = response.status();
377        debug!("Response code: {status_code}");
378        let resp_text = response.text().await.unwrap();
379        if status_code != 200 {
380            panic!(
381                "POST to {url} returned non-200 http status.\n{}",
382                &resp_text
383            );
384        }
385
386        parse_saml_assertion(&resp_text)
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    fn _make_valid_ping_credentials_provider() -> PingCredentialsProvider {
        PingCredentialsProvider::new(
            &HashMap::new(),
            "example.example.com",
            None,
            "user",
            SecretString::new("pwd".to_string().into_boxed_str()),
        )
    }

    // PingCredentialsProvider::new tests

    #[test]
    fn test_new_applies_defaults() {
        let scp = _make_valid_ping_credentials_provider();
        assert_eq!(scp.partner_sp_id, "urn%3Aamazon%3Awebservices");
        assert_eq!(scp.idp_host, "example.example.com");
        assert_eq!(scp.idp_port, 443);
        assert!(!scp.ssl_insecure);
        assert!(scp.do_verify_ssl_cert());
    }

    #[test]
    fn test_new_reads_partnerspid_and_port() {
        let mut params: HashMap<String, Cow<str>> = HashMap::new();
        params.insert("partnerspid".to_string(), Cow::Borrowed("urn:custom:sp"));
        let scp = PingCredentialsProvider::new(
            &params,
            "idp.example.com",
            Some(9031),
            "alice",
            SecretString::new("pw".to_string().into_boxed_str()),
        );
        assert_eq!(scp.partner_sp_id, "urn:custom:sp");
        assert_eq!(scp.idp_port, 9031);
        assert_eq!(scp.user(), "alice");
    }

    // parse_saml_assertion tests

    #[test]
    fn test_parse_saml_assertion_returns_value() {
        let html = r#"<html><body>
        <form method="POST" action="https://signin.aws.amazon.com/saml">
        <INPUT type="hidden" name="SAMLResponse" value="dGVzdA==" />
        </form>
        </body></html>"#;
        assert_eq!(parse_saml_assertion(html), "dGVzdA==");
    }

    #[test]
    #[should_panic(expected = "Failed to retrieve SAMLAssertion")]
    fn test_parse_saml_assertion_missing_panics() {
        parse_saml_assertion("<html><body><form></form></body></html>");
    }

    // parse_login_form tests
    const LOGIN_PAGE_HTML: &str = r#"<html><body>
    <form action="/idp/authLogin" method="POST">
    <INPUT type="text" name="username" id="username" value="" />
    <INPUT type="password" name="pf.pass" value="" />
    <INPUT type="hidden" name="pf.ok" value="clicked" />
    </form>
    </body></html>"#;

    #[test]
    fn test_parse_login_form_extracts_credentials_and_hidden_fields() {
        let scp = _make_valid_ping_credentials_provider();
        let (payload, action) = scp.parse_login_form(LOGIN_PAGE_HTML);
        assert_eq!(payload.get("username").map(String::as_str), Some("user"));
        assert_eq!(payload.get("pf.pass").map(String::as_str), Some("pwd"));
        assert_eq!(payload.get("pf.ok").map(String::as_str), Some("clicked"));
        assert_eq!(action.as_deref(), Some("/idp/authLogin"));
    }

    #[test]
    fn test_parse_login_form_secondary_username_lookup() {
        let scp = _make_valid_ping_credentials_provider();
        // No id="username"; falls back to matching by name containing "user"
        let html = r#"<html><body><form action="/login">
        <INPUT type="text" name="user_email" value="" />
        <INPUT type="password" name="password" value="" />
        </form></body></html>"#;
        let (payload, _) = scp.parse_login_form(html);
        assert_eq!(payload.get("user_email").map(String::as_str), Some("user"));
    }

    #[test]
    #[should_panic(expected = "Failed to parse Ping IdP login form field(s)")]
    fn test_parse_login_form_password_name_without_pass_is_not_used() {
        let scp = _make_valid_ping_credentials_provider();
        // The password input's name lacks "pass", so no password field is recognised.
        let html = r#"<html><body><form>
        <INPUT type="text" name="username" id="username" value="" />
        <INPUT type="password" name="secret" value="" />
        </form></body></html>"#;
        scp.parse_login_form(html);
    }

    #[test]
    #[should_panic(expected = "Failed to parse Ping IdP login form field(s)")]
    fn test_parse_login_form_missing_fields_panics() {
        let scp = _make_valid_ping_credentials_provider();
        scp.parse_login_form("<html><body><form></form></body></html>");
    }

    #[test]
    #[should_panic(expected = "More than one password field")]
    fn test_parse_login_form_duplicate_password_panics() {
        let scp = _make_valid_ping_credentials_provider();
        let html = r#"<html><body><form>
        <INPUT type="text" name="username" id="username" value="" />
        <INPUT type="password" name="pf.pass" value="" />
        <INPUT type="password" name="pf.pass2" value="" />
        </form></body></html>"#;
        scp.parse_login_form(html);
    }

    // get_form_action tests

    fn _parse(html: &str) -> Html {
        Html::parse_document(html)
    }

    #[test]
    fn test_get_form_action_returns_action_for_post_form() {
        let soup =
            _parse(r#"<html><body><form action="/submit" method="POST"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), Some("/submit"));
    }

    #[test]
    fn test_get_form_action_accepts_lowercase_post() {
        let soup =
            _parse(r#"<html><body><form action="/submit" method="post"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), Some("/submit"));
    }

    #[test]
    fn test_get_form_action_returns_action_when_no_method_attribute() {
        // method is None -> the non-POST check doesn't fire -> action is returned
        let soup = _parse(r#"<html><body><form action="/submit"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), Some("/submit"));
    }

    #[test]
    fn test_get_form_action_skips_non_post_form() {
        let soup =
            _parse(r#"<html><body><form action="/submit" method="GET"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), None);
    }

    #[test]
    fn test_get_form_action_returns_none_when_no_action() {
        let soup = _parse(r#"<html><body><form method="POST"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), None);
    }

    #[test]
    fn test_get_form_action_returns_none_when_no_form() {
        let soup = _parse(r#"<html><body></body></html>"#);
        assert_eq!(get_form_action(&soup), None);
    }

    #[test]
    fn test_get_form_action_skips_non_post_returns_second_form_action() {
        // First form has method=GET (skipped), second has a valid action
        let soup = _parse(
            r#"<html><body>
            <form action="/bad" method="GET"></form>
            <form action="/good" method="POST"></form>
        </body></html>"#,
        );
        assert_eq!(get_form_action(&soup), Some("/good"));
    }
}