redshift_iam/
saml_provider.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::str;
4
5use aws_sdk_sts as sts;
7use base64::prelude::*;
8use log::{debug, warn};
9use scraper::{ElementRef, Html, Selector};
11use secrecy::{ExposeSecret, SecretString};
12use tokio::runtime::Runtime;
13
14use crate::re;
15
/// A source of SAML assertions.
///
/// Implementors return the assertion in the base64-encoded form that
/// `get_credentials` decodes before looking up role/provider ARNs.
pub trait SamlProvider {
    /// Resolves to the base64-encoded SAML assertion.
    fn get_saml_assertion(&self) -> impl Future<Output = String>;
}
21
22fn is_password(inputtag: &ElementRef) -> bool {
24 inputtag.attr("type") == Some("password")
25}
26
27fn is_text(inputtag: &ElementRef) -> bool {
29 inputtag.attr("type") == Some("text")
30}
31
32fn get_form_action(soup: &Html) -> Option<&str> {
36 let selector = Selector::parse("form").unwrap();
38
39 for inputtag in soup.select(&selector) {
40 let action = inputtag.attr("action");
41 if action.is_some() {
42 let method = inputtag.attr("method");
43 if method.is_some() && method.unwrap().to_uppercase() != "POST" {
45 warn!("Found action, but method is not POST. Skipping.");
46 continue;
47 }
48 return action;
49 }
50 }
51
52 None
53}
54
55pub async fn get_credentials<T: SamlProvider>(
64 provider: &T,
65 role_arn: String,
66) -> Option<sts::types::Credentials> {
67 let saml_assertion = provider.get_saml_assertion().await;
69
70 let ass_bytes = BASE64_STANDARD.decode(saml_assertion.as_bytes()).unwrap();
72 let doc = str::from_utf8(&ass_bytes).unwrap();
73
74 debug!("decoded SAML assertion into xml format");
75 let soup = Html::parse_document(doc);
77 let selector = Selector::parse(r"saml\:AttributeValue").unwrap();
78 let attrs = soup.select(&selector);
79
80 let role_pattern = re::compile(r"arn:aws:iam::\d*:role/\S+");
82 let provider_pattern = re::compile(r"arn:aws:iam::\d*:saml-provider/\S+");
83 let mut roles: HashMap<&str, String> = HashMap::new();
84 debug!("searching SAML assertion for values matching patterns for RoleArn and PrincipalArn");
85 for attr in attrs {
87 for value in attr.text() {
88 let mut role = "";
89 let mut provider = String::new();
90 for arn_ in value.split(",") {
91 let arn = arn_.trim();
92 if role_pattern.is_match(arn) {
93 debug!("RoleArn pattern matched");
94 role = arn;
95 }
96 if provider_pattern.is_match(arn) {
97 debug!("PrincipleArn pattern matched");
98 provider = arn.to_string();
99 }
100 }
101 if !role.is_empty() && !provider.is_empty() {
102 roles.insert(role, provider);
103 }
104 }
105 }
106 debug!("Done reading SAML assertion attributes");
107 debug!("{} roles identified in SAML assertion", roles.len());
108
109 if roles.is_empty() {
110 let exec_msg = "No roles were found in SAML assertion. Please verify IdP configuration provides ARNs in the SAML https://aws.amazon.com/SAML/Attributes/Role Attribute.";
111 panic!("{exec_msg}");
112 }
113 debug!("User provided preferred_role, trying to use...");
114 if !roles.contains_key(&*role_arn) {
115 let exec_msg = "User specified preferred_role was not found in SAML assertion https://aws.amazon.com/SAML/Attributes/Role Attribute";
116 panic!("{exec_msg}");
117 }
118
119 let config = aws_config::load_from_env().await;
121 let client = sts::Client::new(&config);
122 debug!(
123 "Attempting to retrieve temporary AWS credentials using the SAML assertion, principal ARN, and role ARN."
124 );
125 let response = client
126 .assume_role_with_saml()
127 .set_principal_arn(roles.remove(&*role_arn)) .set_role_arn(Some(role_arn))
129 .saml_assertion(saml_assertion)
130 .send()
131 .await
132 .unwrap();
133 debug!("Extracting temporary AWS credentials from assume_role_with_saml response");
134
135 response.credentials
136}
137
138pub fn parse_saml_assertion(html: &str) -> String {
153 let soup = Html::parse_document(html);
154 let selector = Selector::parse("INPUT").unwrap();
155 let mut assertion = String::new();
156 for inputtag in soup.select(&selector) {
157 if inputtag.attr("name") == Some("SAMLResponse") {
158 debug!("SAMLResponse tag found");
159 assertion = inputtag.attr("value").unwrap().to_string();
160 }
161 }
162 if assertion.is_empty() {
163 panic!(
164 "Failed to retrieve SAMLAssertion. An input tag named SAMLResponse was not identified in the Ping IdP authentication response"
165 );
166 }
167 assertion
168}
169
/// Credentials provider that logs a user in through a Ping IdP login form
/// to obtain a SAML assertion.
#[derive(Debug)]
pub struct PingCredentialsProvider {
    /// Value of the `PartnerSpId` query parameter sent to `/idp/startSSO.ping`.
    partner_sp_id: String,
    /// IdP host name used to build the `https://` request URLs.
    idp_host: String,
    /// IdP port used to build the `https://` request URLs.
    idp_port: u16,
    /// User name filled into the login form's username field.
    user_name: String,
    /// Password filled into the login form's password field. Held as a
    /// `SecretString` (presumably so the derived `Debug` redacts it — see the
    /// `secrecy` crate; confirm).
    password: SecretString,
    /// `true` means TLS certificate verification is not wanted; read inversely
    /// through `do_verify_ssl_cert`.
    pub ssl_insecure: bool,
}
184
185impl PingCredentialsProvider {
186 pub fn new(
209 partner_sp_id_option: Option<impl ToString>,
210 idp_host: impl ToString,
211 idp_port: Option<u16>,
212 user_name: impl ToString,
213 password: SecretString,
214 ) -> Self {
215 let partner_sp_id = if let Some(partner_sp_id) = partner_sp_id_option {
217 partner_sp_id.to_string()
218 } else {
219 "urn%3Aamazon%3Awebservices".to_string()
220 };
221 Self {
222 partner_sp_id,
223 idp_host: idp_host.to_string(),
224 idp_port: idp_port.unwrap_or(443),
225 user_name: user_name.to_string(),
226 password,
227 ssl_insecure: false,
228 }
229 }
230
231 pub fn user(&self) -> String {
233 self.user_name.clone()
234 }
235
236 pub fn do_verify_ssl_cert(&self) -> bool {
238 !self.ssl_insecure
239 }
240
241 pub fn get_credentials(
246 &self,
247 preferred_role: impl ToString,
248 ) -> Option<sts::types::Credentials> {
249 let rt = Runtime::new().unwrap(); rt.block_on(async { get_credentials(self, preferred_role.to_string()).await })
251 }
252
253 fn parse_login_form(&self, html: &str) -> (HashMap<String, String>, Option<String>) {
256 let soup = Html::parse_document(html);
257 let selector = Selector::parse("INPUT").unwrap();
258 let mut payload: HashMap<String, String> = HashMap::new();
259 let mut username_found = false;
260 let mut pwd_found = false;
261
262 debug!(
263 "Looking for username and password input tags in Ping IdP login page in order to build authentication request payload"
264 );
265 for inputtag in soup.select(&selector) {
266 let name = inputtag.attr("name").unwrap_or("").to_string();
267 let id_ = inputtag.attr("id").unwrap_or("");
268 debug!("name={name} , id={id_}");
269
270 if !username_found && is_text(&inputtag) && id_ == "username" {
271 debug!("Using tag with name {name} for username");
272 payload.insert(name, self.user());
273 username_found = true;
274 } else if is_password(&inputtag) && name.contains("pass") {
275 debug!("Using tag with name {name} for password");
276 if pwd_found {
277 panic!(
278 "Failed to parse Ping IdP login form. More than one password field was found on the Ping IdP login page"
279 );
280 }
281 payload.insert(name, self.password.expose_secret().to_string());
282 pwd_found = true;
283 } else if !name.is_empty() {
284 let value = inputtag.attr("value").unwrap_or("").to_string();
285 payload.insert(name, value);
286 }
287 }
288
289 if !username_found {
290 debug!(
291 "username tag still not found, continuing search using secondary preferred tags"
292 );
293 for inputtag in soup.select(&selector) {
294 let name = inputtag.attr("name").unwrap_or("").to_string();
295 if is_text(&inputtag) && (name.contains("user") || name.contains("email")) {
296 debug!("Using tag with name {name} for username");
297 payload.insert(name, self.user());
298 username_found = true;
299 }
300 }
301 }
302
303 if !username_found || !pwd_found {
304 panic!("Failed to parse Ping IdP login form field(s)");
305 }
306
307 let action = get_form_action(&soup).map(str::to_owned);
308 (payload, action)
309 }
310}
311
312impl SamlProvider for PingCredentialsProvider {
313 async fn get_saml_assertion(&self) -> String {
323 debug!("PingCredentialsProvider.get_saml_assertion");
325 let session = reqwest::Client::builder() .cookie_store(true) .build()
329 .unwrap();
330
331 let mut url = format!(
332 "https://{}:{}/idp/startSSO.ping?PartnerSpId={}",
333 self.idp_host, self.idp_port, self.partner_sp_id,
334 );
335
336 debug!(
337 "Issuing GET request for Ping IdP login page using uri={} verify={}",
338 url,
339 self.do_verify_ssl_cert(),
340 );
341 let resp = session.get(&url).send().await.unwrap(); debug!("Response code: {}", resp.status());
343 debug!("response length: {}", resp.content_length().unwrap_or(0));
344
345 let resp_text = resp.text().await.unwrap();
346 let (payload, action) = self.parse_login_form(&resp_text);
347
348 if let Some(action_str) = action.as_deref()
350 && action_str.starts_with("/")
351 {
352 url = format!("https://{}:{}{action_str}", self.idp_host, self.idp_port);
353 }
354 debug!(
359 "Issuing authentication request to Ping IdP using uri {} verify {}",
360 &url,
361 self.do_verify_ssl_cert(),
362 );
363 let response = session
364 .post(&url) .form(&payload)
366 .send()
367 .await
368 .unwrap();
369 let status_code = response.status();
370 debug!("Response code: {status_code}");
371 let resp_text = response.text().await.unwrap();
372 if status_code != 200 {
373 panic!(
374 "POST to {url} returned non-200 http status.\n{}",
375 &resp_text
376 );
377 }
378
379 parse_saml_assertion(&resp_text)
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider with default partner id/port and fixed test credentials.
    fn _make_valid_ping_credentials_provider() -> PingCredentialsProvider {
        PingCredentialsProvider::new(
            None::<String>,
            "example.example.com",
            None,
            "user",
            SecretString::new("pwd".to_string().into_boxed_str()),
        )
    }

    const LOGIN_PAGE_HTML: &str = r#"<html><body>
        <form action="/idp/authLogin" method="POST">
            <INPUT type="text" name="username" id="username" value="" />
            <INPUT type="password" name="pf.pass" value="" />
            <INPUT type="hidden" name="pf.ok" value="clicked" />
        </form>
        </body></html>"#;

    #[test]
    fn test_new_applies_defaults() {
        let scp = _make_valid_ping_credentials_provider();
        assert_eq!(scp.partner_sp_id, "urn%3Aamazon%3Awebservices");
        assert_eq!(scp.idp_port, 443);
        assert_eq!(scp.user(), "user");
        assert!(scp.do_verify_ssl_cert());
    }

    #[test]
    fn test_do_verify_ssl_cert_is_inverse_of_ssl_insecure() {
        let mut scp = _make_valid_ping_credentials_provider();
        scp.ssl_insecure = true;
        assert!(!scp.do_verify_ssl_cert());
    }

    #[test]
    fn test_parse_login_form_extracts_credentials_and_hidden_fields() {
        let scp = _make_valid_ping_credentials_provider();
        let (payload, action) = scp.parse_login_form(LOGIN_PAGE_HTML);
        assert_eq!(payload.get("username").map(String::as_str), Some("user"));
        assert_eq!(payload.get("pf.pass").map(String::as_str), Some("pwd"));
        assert_eq!(payload.get("pf.ok").map(String::as_str), Some("clicked"));
        assert_eq!(action.as_deref(), Some("/idp/authLogin"));
    }

    #[test]
    fn test_parse_login_form_secondary_username_lookup() {
        let scp = _make_valid_ping_credentials_provider();
        let html = r#"<html><body><form action="/login">
            <INPUT type="text" name="user_email" value="" />
            <INPUT type="password" name="password" value="" />
            </form></body></html>"#;
        let (payload, _) = scp.parse_login_form(html);
        assert_eq!(payload.get("user_email").map(String::as_str), Some("user"));
    }

    #[test]
    #[should_panic(expected = "Failed to parse Ping IdP login form field(s)")]
    fn test_parse_login_form_missing_fields_panics() {
        let scp = _make_valid_ping_credentials_provider();
        scp.parse_login_form("<html><body><form></form></body></html>");
    }

    #[test]
    #[should_panic(expected = "More than one password field")]
    fn test_parse_login_form_duplicate_password_panics() {
        let scp = _make_valid_ping_credentials_provider();
        let html = r#"<html><body><form>
            <INPUT type="text" name="username" id="username" value="" />
            <INPUT type="password" name="pf.pass" value="" />
            <INPUT type="password" name="pf.pass2" value="" />
            </form></body></html>"#;
        scp.parse_login_form(html);
    }

    #[test]
    fn test_parse_saml_assertion_returns_value() {
        let html = r#"<html><body><form>
            <INPUT type="hidden" name="RelayState" value="state" />
            <INPUT type="hidden" name="SAMLResponse" value="PHNhbWw+" />
            </form></body></html>"#;
        assert_eq!(parse_saml_assertion(html), "PHNhbWw+");
    }

    #[test]
    #[should_panic(expected = "Failed to retrieve SAMLAssertion")]
    fn test_parse_saml_assertion_missing_input_panics() {
        parse_saml_assertion("<html><body><form></form></body></html>");
    }

    /// Parses `html` as a full document for the `get_form_action` tests.
    fn _parse(html: &str) -> Html {
        Html::parse_document(html)
    }

    #[test]
    fn test_get_form_action_returns_action_for_post_form() {
        let soup =
            _parse(r#"<html><body><form action="/submit" method="POST"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), Some("/submit"));
    }

    #[test]
    fn test_get_form_action_returns_action_when_no_method_attribute() {
        let soup = _parse(r#"<html><body><form action="/submit"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), Some("/submit"));
    }

    #[test]
    fn test_get_form_action_skips_non_post_form() {
        let soup =
            _parse(r#"<html><body><form action="/submit" method="GET"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), None);
    }

    #[test]
    fn test_get_form_action_returns_none_when_no_action() {
        let soup = _parse(r#"<html><body><form method="POST"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), None);
    }

    #[test]
    fn test_get_form_action_returns_none_when_no_form() {
        let soup = _parse(r#"<html><body></body></html>"#);
        assert_eq!(get_form_action(&soup), None);
    }

    #[test]
    fn test_get_form_action_skips_non_post_returns_second_form_action() {
        let soup = _parse(
            r#"<html><body>
            <form action="/bad" method="GET"></form>
            <form action="/good" method="POST"></form>
            </body></html>"#,
        );
        assert_eq!(get_form_action(&soup), Some("/good"));
    }
}