Skip to main content

redshift_iam/
saml_provider.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::str;
4
5// use aws_config;
6use aws_sdk_sts as sts;
7use base64::prelude::*;
8use log::{debug, warn};
9// use reqwest;
10use scraper::{ElementRef, Html, Selector};
11use secrecy::{ExposeSecret, SecretString};
12use tokio::runtime::Runtime;
13
14use crate::re;
15
/// Abstraction over an identity provider (IdP) that can hand out a SAML assertion.
///
/// Implementors perform whatever login exchange their IdP requires and resolve to
/// the assertion as a `String`.
pub trait SamlProvider {
    /// Asynchronously obtains a SAML assertion from the IdP.
    ///
    /// The returned string is expected to be base64-encoded: the free function
    /// `get_credentials` in this module base64-decodes it before inspecting it.
    fn get_saml_assertion(&self) -> impl Future<Output = String>;
}
21
22/// Returns `true` if the input tag has `type="password"`.
23fn is_password(inputtag: &ElementRef) -> bool {
24    inputtag.attr("type") == Some("password")
25}
26
27/// Returns `true` if the input tag has `type="text"`.
28fn is_text(inputtag: &ElementRef) -> bool {
29    inputtag.attr("type") == Some("text")
30}
31
32/// Finds the first form `action` attribute whose method is POST (or unspecified).
33/// Forms with an explicit non-POST method are skipped. Returns `None` if no
34/// qualifying form is found.
35fn get_form_action(soup: &Html) -> Option<&str> {
36    // NOTE: selector case-insensitive; it will match both form and FORM
37    let selector = Selector::parse("form").unwrap();
38
39    for inputtag in soup.select(&selector) {
40        let action = inputtag.attr("action");
41        if action.is_some() {
42            let method = inputtag.attr("method");
43            // safe unwrap
44            if method.is_some() && method.unwrap().to_uppercase() != "POST" {
45                warn!("Found action, but method is not POST. Skipping.");
46                continue;
47            }
48            return action;
49        }
50    }
51
52    None
53}
54
55/// Obtains temporary AWS credentials by exchanging a SAML assertion for STS credentials.
56///
57/// Calls [`SamlProvider::get_saml_assertion`], decodes the assertion, extracts the
58/// IAM role and principal ARNs, and calls `sts:AssumeRoleWithSAML` for `role_arn`.
59///
60/// # Panics
61/// - If no IAM roles are found in the SAML assertion.
62/// - If `role_arn` is not present among the roles in the assertion.
63pub async fn get_credentials<T: SamlProvider>(
64    provider: &T,
65    role_arn: String,
66) -> Option<sts::types::Credentials> {
67    // refresh method alias
68    let saml_assertion = provider.get_saml_assertion().await;
69
70    // decode SAML assertion into xml format
71    let ass_bytes = BASE64_STANDARD.decode(saml_assertion.as_bytes()).unwrap();
72    let doc = str::from_utf8(&ass_bytes).unwrap();
73
74    debug!("decoded SAML assertion into xml format");
75    // NOTE could parse it as xml, but keeping it lightweighted
76    let soup = Html::parse_document(doc);
77    let selector = Selector::parse(r"saml\:AttributeValue").unwrap();
78    let attrs = soup.select(&selector);
79
80    // extract RoleArn and PrincipleArn from SAML assertion
81    let role_pattern = re::compile(r"arn:aws:iam::\d*:role/\S+");
82    let provider_pattern = re::compile(r"arn:aws:iam::\d*:saml-provider/\S+");
83    let mut roles: HashMap<&str, String> = HashMap::new();
84    debug!("searching SAML assertion for values matching patterns for RoleArn and PrincipalArn");
85    // TODO: let user specify None as role and then pick the first one
86    for attr in attrs {
87        for value in attr.text() {
88            let mut role = "";
89            let mut provider = String::new();
90            for arn_ in value.split(",") {
91                let arn = arn_.trim();
92                if role_pattern.is_match(arn) {
93                    debug!("RoleArn pattern matched");
94                    role = arn;
95                }
96                if provider_pattern.is_match(arn) {
97                    debug!("PrincipleArn pattern matched");
98                    provider = arn.to_string();
99                }
100            }
101            if !role.is_empty() && !provider.is_empty() {
102                roles.insert(role, provider);
103            }
104        }
105    }
106    debug!("Done reading SAML assertion attributes");
107    debug!("{} roles identified in SAML assertion", roles.len());
108
109    if roles.is_empty() {
110        let exec_msg = "No roles were found in SAML assertion. Please verify IdP configuration provides ARNs in the SAML https://aws.amazon.com/SAML/Attributes/Role Attribute.";
111        panic!("{exec_msg}");
112    }
113    debug!("User provided preferred_role, trying to use...");
114    if !roles.contains_key(&*role_arn) {
115        let exec_msg = "User specified preferred_role was not found in SAML assertion https://aws.amazon.com/SAML/Attributes/Role Attribute";
116        panic!("{exec_msg}");
117    }
118
119    // empty config; no prior aws identity needed
120    let config = aws_config::load_from_env().await;
121    let client = sts::Client::new(&config);
122    debug!(
123        "Attempting to retrieve temporary AWS credentials using the SAML assertion, principal ARN, and role ARN."
124    );
125    let response = client
126        .assume_role_with_saml()
127        .set_principal_arn(roles.remove(&*role_arn)) // remove instead of get, so we move the value out and not get ref
128        .set_role_arn(Some(role_arn))
129        .saml_assertion(saml_assertion)
130        .send()
131        .await
132        .unwrap();
133    debug!("Extracting temporary AWS credentials from assume_role_with_saml response");
134
135    response.credentials
136}
137
138/// Extracts the SAMLResponse assertion value from the IdP authentication response HTML.
139/// Panics if no `SAMLResponse` input tag is found.
140///
141/// # Examples
142/// ```
143/// use redshift_iam::saml_provider::parse_saml_assertion;
144///
145/// let html = r#"<html><body>
146/// <form method="POST" action="https://signin.aws.amazon.com/saml">
147///   <INPUT type="hidden" name="SAMLResponse" value="dGVzdA==" />
148/// </form>
149/// </body></html>"#;
150/// assert_eq!(parse_saml_assertion(html), "dGVzdA==");
151/// ```
152pub fn parse_saml_assertion(html: &str) -> String {
153    let soup = Html::parse_document(html);
154    let selector = Selector::parse("INPUT").unwrap();
155    let mut assertion = String::new();
156    for inputtag in soup.select(&selector) {
157        if inputtag.attr("name") == Some("SAMLResponse") {
158            debug!("SAMLResponse tag found");
159            assertion = inputtag.attr("value").unwrap().to_string();
160        }
161    }
162    if assertion.is_empty() {
163        panic!(
164            "Failed to retrieve SAMLAssertion. An input tag named SAMLResponse was not identified in the Ping IdP authentication response"
165        );
166    }
167    assertion
168}
169
170/// PingFederate identity provider plugin for SAML-based Redshift authentication.
171///
172/// See the [Amazon Redshift IAM docs](https://docs.aws.amazon.com/redshift/latest/mgmt/options-for-providing-iam-credentials.html)
173/// for setup instructions.
174#[derive(Debug)]
175pub struct PingCredentialsProvider {
176    partner_sp_id: String,
177    idp_host: String,
178    idp_port: u16,
179    user_name: String,
180    password: SecretString,
181    /// When `true`, TLS certificate verification is disabled. Defaults to `false`.
182    pub ssl_insecure: bool,
183}
184
185impl PingCredentialsProvider {
186    /// Creates a new `PingCredentialsProvider`.
187    ///
188    /// - `partner_sp_id`: The SP entity ID sent to PingFederate. `None` defaults to
189    ///   `"urn%3Aamazon%3Awebservices"`.
190    /// - `idp_port`: Defaults to `443` when `None`.
191    ///
192    /// # Examples
193    /// ```
194    /// use secrecy::SecretString;
195    /// use redshift_iam::PingCredentialsProvider;
196    ///
197    /// let scp = PingCredentialsProvider::new(
198    ///     None::<String>,
199    ///     "pingfed.example.com",
200    ///     None,
201    ///     "alice",
202    ///     SecretString::new("s3cr3t".to_string().into_boxed_str()),
203    /// );
204    /// assert!(!scp.ssl_insecure);
205    /// assert!(scp.do_verify_ssl_cert());
206    /// assert_eq!(scp.user(), "alice");
207    /// ```
208    pub fn new(
209        partner_sp_id_option: Option<impl ToString>,
210        idp_host: impl ToString,
211        idp_port: Option<u16>,
212        user_name: impl ToString,
213        password: SecretString,
214    ) -> Self {
215        // We could either accept pwd and create secretString here or force user to pass it
216        let partner_sp_id = if let Some(partner_sp_id) = partner_sp_id_option {
217            partner_sp_id.to_string()
218        } else {
219            "urn%3Aamazon%3Awebservices".to_string()
220        };
221        Self {
222            partner_sp_id,
223            idp_host: idp_host.to_string(),
224            idp_port: idp_port.unwrap_or(443),
225            user_name: user_name.to_string(),
226            password,
227            ssl_insecure: false,
228        }
229    }
230
231    /// user getter
232    pub fn user(&self) -> String {
233        self.user_name.clone()
234    }
235
236    /// Returns `true` when TLS certificate verification is enabled (i.e. `ssl_insecure` is `false`).
237    pub fn do_verify_ssl_cert(&self) -> bool {
238        !self.ssl_insecure
239    }
240
241    /// Synchronously retrieves temporary AWS credentials for `preferred_role`.
242    ///
243    /// Drives the full SAML -> STS flow on a new Tokio runtime. Prefer the async
244    /// [`get_credentials`] free function when already inside an async context.
245    pub fn get_credentials(
246        &self,
247        preferred_role: impl ToString,
248    ) -> Option<sts::types::Credentials> {
249        let rt = Runtime::new().unwrap(); //?
250        rt.block_on(async { get_credentials(self, preferred_role.to_string()).await })
251    }
252
253    /// Parses the IdP login page HTML, extracting the form submission payload and
254    /// the form's action path. Panics if username or password fields cannot be found.
255    fn parse_login_form(&self, html: &str) -> (HashMap<String, String>, Option<String>) {
256        let soup = Html::parse_document(html);
257        let selector = Selector::parse("INPUT").unwrap();
258        let mut payload: HashMap<String, String> = HashMap::new();
259        let mut username_found = false;
260        let mut pwd_found = false;
261
262        debug!(
263            "Looking for username and password input tags in Ping IdP login page in order to build authentication request payload"
264        );
265        for inputtag in soup.select(&selector) {
266            let name = inputtag.attr("name").unwrap_or("").to_string();
267            let id_ = inputtag.attr("id").unwrap_or("");
268            debug!("name={name} , id={id_}");
269
270            if !username_found && is_text(&inputtag) && id_ == "username" {
271                debug!("Using tag with name {name} for username");
272                payload.insert(name, self.user());
273                username_found = true;
274            } else if is_password(&inputtag) && name.contains("pass") {
275                debug!("Using tag with name {name} for password");
276                if pwd_found {
277                    panic!(
278                        "Failed to parse Ping IdP login form. More than one password field was found on the Ping IdP login page"
279                    );
280                }
281                payload.insert(name, self.password.expose_secret().to_string());
282                pwd_found = true;
283            } else if !name.is_empty() {
284                let value = inputtag.attr("value").unwrap_or("").to_string();
285                payload.insert(name, value);
286            }
287        }
288
289        if !username_found {
290            debug!(
291                "username tag still not found, continuing search using secondary preferred tags"
292            );
293            for inputtag in soup.select(&selector) {
294                let name = inputtag.attr("name").unwrap_or("").to_string();
295                if is_text(&inputtag) && (name.contains("user") || name.contains("email")) {
296                    debug!("Using tag with name {name} for username");
297                    payload.insert(name, self.user());
298                    username_found = true;
299                }
300            }
301        }
302
303        if !username_found || !pwd_found {
304            panic!("Failed to parse Ping IdP login form field(s)");
305        }
306
307        let action = get_form_action(&soup).map(str::to_owned);
308        (payload, action)
309    }
310}
311
312impl SamlProvider for PingCredentialsProvider {
313    /// Logs in to the PingFederate IdP and returns a base64-encoded SAML assertion.
314    ///
315    /// Issues a GET to the SSO start URL, parses the login form, submits credentials,
316    /// and extracts the `SAMLResponse` value from the resulting page.
317    ///
318    /// # Panics
319    /// - If the login form cannot be parsed or credentials fields are missing.
320    /// - If the POST to the IdP returns a non-200 status.
321    /// - If no `SAMLResponse` input is found in the response.
322    async fn get_saml_assertion(&self) -> String {
323        // Method to grab the SAML Response. Used to refresh temporary credentials.
324        debug!("PingCredentialsProvider.get_saml_assertion");
325        let session = reqwest::Client::builder() // scoped only in this method
326            .cookie_store(true) // the PF=... session state cookie needs to be preserved
327            // .https_only(true)
328            .build()
329            .unwrap();
330
331        let mut url = format!(
332            "https://{}:{}/idp/startSSO.ping?PartnerSpId={}",
333            self.idp_host, self.idp_port, self.partner_sp_id,
334        );
335
336        debug!(
337            "Issuing GET request for Ping IdP login page using uri={} verify={}",
338            url,
339            self.do_verify_ssl_cert(),
340        );
341        let resp = session.get(&url).send().await.unwrap(); // TODO: , verify=self.do_verify_ssl_cert()
342        debug!("Response code: {}", resp.status());
343        debug!("response length: {}", resp.content_length().unwrap_or(0));
344
345        let resp_text = resp.text().await.unwrap();
346        let (payload, action) = self.parse_login_form(&resp_text);
347
348        // NOTE: not sure if we want to continue with the original url in None case
349        if let Some(action_str) = action.as_deref()
350            && action_str.starts_with("/")
351        {
352            url = format!("https://{}:{}{action_str}", self.idp_host, self.idp_port);
353        }
354        // else {
355        //     panic!();
356        // }
357
358        debug!(
359            "Issuing authentication request to Ping IdP using uri {} verify {}",
360            &url,
361            self.do_verify_ssl_cert(),
362        );
363        let response = session
364            .post(&url) //verify=self.do_verify_ssl_cert()
365            .form(&payload)
366            .send()
367            .await
368            .unwrap();
369        let status_code = response.status();
370        debug!("Response code: {status_code}");
371        let resp_text = response.text().await.unwrap();
372        if status_code != 200 {
373            panic!(
374                "POST to {url} returned non-200 http status.\n{}",
375                &resp_text
376            );
377        }
378
379        parse_saml_assertion(&resp_text)
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider with default SP id and port, user `user` and password `pwd`.
    fn make_provider() -> PingCredentialsProvider {
        let password = SecretString::new("pwd".to_string().into_boxed_str());
        PingCredentialsProvider::new(None::<String>, "example.example.com", None, "user", password)
    }

    /// Parses an HTML snippet into a document for the `get_form_action` tests.
    fn parse_html(html: &str) -> Html {
        Html::parse_document(html)
    }

    // ---- parse_login_form ----

    /// Login page with a user-name field, a password field and one hidden field.
    const LOGIN_PAGE_HTML: &str = r#"<html><body>
    <form action="/idp/authLogin" method="POST">
    <INPUT type="text" name="username" id="username" value="" />
    <INPUT type="password" name="pf.pass" value="" />
    <INPUT type="hidden" name="pf.ok" value="clicked" />
    </form>
    </body></html>"#;

    #[test]
    fn test_parse_login_form_extracts_credentials_and_hidden_fields() {
        let provider = make_provider();
        let (fields, form_action) = provider.parse_login_form(LOGIN_PAGE_HTML);

        assert_eq!(fields.get("username"), Some(&"user".to_string()));
        assert_eq!(fields.get("pf.pass"), Some(&"pwd".to_string()));
        assert_eq!(fields.get("pf.ok"), Some(&"clicked".to_string()));
        assert_eq!(form_action, Some("/idp/authLogin".to_string()));
    }

    #[test]
    fn test_parse_login_form_secondary_username_lookup() {
        // Without id="username" the field is located by its name containing "user".
        let page = r#"<html><body><form action="/login">
        <INPUT type="text" name="user_email" value="" />
        <INPUT type="password" name="password" value="" />
        </form></body></html>"#;

        let (fields, _) = make_provider().parse_login_form(page);
        assert_eq!(fields.get("user_email"), Some(&"user".to_string()));
    }

    #[test]
    #[should_panic(expected = "Failed to parse Ping IdP login form field(s)")]
    fn test_parse_login_form_missing_fields_panics() {
        let provider = make_provider();
        provider.parse_login_form("<html><body><form></form></body></html>");
    }

    #[test]
    #[should_panic(expected = "More than one password field")]
    fn test_parse_login_form_duplicate_password_panics() {
        let page = r#"<html><body><form>
        <INPUT type="text" name="username" id="username" value="" />
        <INPUT type="password" name="pf.pass" value="" />
        <INPUT type="password" name="pf.pass2" value="" />
        </form></body></html>"#;

        make_provider().parse_login_form(page);
    }

    // ---- get_form_action ----

    #[test]
    fn test_get_form_action_returns_action_for_post_form() {
        let document =
            parse_html(r#"<html><body><form action="/submit" method="POST"></form></body></html>"#);
        assert_eq!(Some("/submit"), get_form_action(&document));
    }

    #[test]
    fn test_get_form_action_returns_action_when_no_method_attribute() {
        // No method attribute: the non-POST rejection does not apply, so the action is used.
        let document = parse_html(r#"<html><body><form action="/submit"></form></body></html>"#);
        assert_eq!(Some("/submit"), get_form_action(&document));
    }

    #[test]
    fn test_get_form_action_skips_non_post_form() {
        let document =
            parse_html(r#"<html><body><form action="/submit" method="GET"></form></body></html>"#);
        assert!(get_form_action(&document).is_none());
    }

    #[test]
    fn test_get_form_action_returns_none_when_no_action() {
        let document = parse_html(r#"<html><body><form method="POST"></form></body></html>"#);
        assert!(get_form_action(&document).is_none());
    }

    #[test]
    fn test_get_form_action_returns_none_when_no_form() {
        let document = parse_html(r#"<html><body></body></html>"#);
        assert!(get_form_action(&document).is_none());
    }

    #[test]
    fn test_get_form_action_skips_non_post_returns_second_form_action() {
        // The GET form comes first and is skipped; the POST form after it supplies the action.
        let document = parse_html(
            r#"<html><body>
            <form action="/bad" method="GET"></form>
            <form action="/good" method="POST"></form>
        </body></html>"#,
        );
        assert_eq!(Some("/good"), get_form_action(&document));
    }
}