redshift_iam/
saml_provider.rs1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::str;
4
5use aws_sdk_sts as sts;
7use base64::prelude::*;
8use log::{debug, warn};
9use scraper::{ElementRef, Html, Selector};
11use secrecy::{ExposeSecret, SecretString};
12use tokio::runtime::Runtime;
13
14use crate::re;
15
16#[async_trait::async_trait]
21pub trait SamlProvider: Send + Sync {
22 async fn get_saml_assertion(&self) -> String;
24}
25
26fn is_password(inputtag: &ElementRef) -> bool {
28 inputtag.attr("type") == Some("password")
29}
30
31fn is_text(inputtag: &ElementRef) -> bool {
33 inputtag.attr("type") == Some("text")
34}
35
36fn get_form_action(soup: &Html) -> Option<&str> {
40 let selector = Selector::parse("form").unwrap();
42
43 for inputtag in soup.select(&selector) {
44 let action = inputtag.attr("action");
45 if action.is_some() {
46 let method = inputtag.attr("method");
47 if method.is_some() && method.unwrap().to_uppercase() != "POST" {
49 warn!("Found action, but method is not POST. Skipping.");
50 continue;
51 }
52 return action;
53 }
54 }
55
56 None
57}
58
59pub async fn get_credentials(
68 provider: &dyn SamlProvider,
69 role_arn: String,
70) -> Option<sts::types::Credentials> {
71 let saml_assertion = provider.get_saml_assertion().await;
73
74 let ass_bytes = BASE64_STANDARD.decode(saml_assertion.as_bytes()).unwrap();
76 let doc = str::from_utf8(&ass_bytes).unwrap();
77
78 debug!("decoded SAML assertion into xml format");
79 let soup = Html::parse_document(doc);
81 let selector = Selector::parse(r"saml\:AttributeValue").unwrap();
82 let attrs = soup.select(&selector);
83
84 let role_pattern = re::compile(r"arn:aws:iam::\d*:role/\S+");
86 let provider_pattern = re::compile(r"arn:aws:iam::\d*:saml-provider/\S+");
87 let mut roles: HashMap<&str, String> = HashMap::new();
88 debug!("searching SAML assertion for values matching patterns for RoleArn and PrincipalArn");
89 for attr in attrs {
91 for value in attr.text() {
92 let mut role = "";
93 let mut provider = String::new();
94 for arn_ in value.split(",") {
95 let arn = arn_.trim();
96 if role_pattern.is_match(arn) {
97 debug!("RoleArn pattern matched");
98 role = arn;
99 }
100 if provider_pattern.is_match(arn) {
101 debug!("PrincipleArn pattern matched");
102 provider = arn.to_string();
103 }
104 }
105 if !role.is_empty() && !provider.is_empty() {
106 roles.insert(role, provider);
107 }
108 }
109 }
110 debug!("Done reading SAML assertion attributes");
111 debug!("{} roles identified in SAML assertion", roles.len());
112
113 if roles.is_empty() {
114 let exec_msg = "No roles were found in SAML assertion. Please verify IdP configuration provides ARNs in the SAML https://aws.amazon.com/SAML/Attributes/Role Attribute.";
115 panic!("{exec_msg}");
116 }
117 debug!("User provided preferred_role, trying to use...");
118 if !roles.contains_key(&*role_arn) {
119 let exec_msg = "User specified preferred_role was not found in SAML assertion https://aws.amazon.com/SAML/Attributes/Role Attribute";
120 panic!("{exec_msg}");
121 }
122
123 let config = aws_config::load_from_env().await;
125 let client = sts::Client::new(&config);
126 debug!(
127 "Attempting to retrieve temporary AWS credentials using the SAML assertion, principal ARN, and role ARN."
128 );
129 let response = client
130 .assume_role_with_saml()
131 .set_principal_arn(roles.remove(&*role_arn)) .set_role_arn(Some(role_arn))
133 .saml_assertion(saml_assertion)
134 .send()
135 .await
136 .unwrap();
137 debug!("Extracting temporary AWS credentials from assume_role_with_saml response");
138
139 response.credentials
140}
141
142fn parse_saml_assertion(html: &str) -> String {
157 let soup = Html::parse_document(html);
158 let selector = Selector::parse("INPUT").unwrap();
159 let mut assertion = String::new();
160 for inputtag in soup.select(&selector) {
161 if inputtag.attr("name") == Some("SAMLResponse") {
162 debug!("SAMLResponse tag found");
163 assertion = inputtag.attr("value").unwrap().to_string();
164 }
165 }
166 if assertion.is_empty() {
167 panic!(
168 "Failed to retrieve SAMLAssertion. An input tag named SAMLResponse was not identified in the Ping IdP authentication response"
169 );
170 }
171 assertion
172}
173
174#[derive(Debug)]
179pub struct PingCredentialsProvider {
180 partner_sp_id: String,
181 idp_host: String,
182 idp_port: u16,
183 user_name: String,
184 password: SecretString,
185 pub ssl_insecure: bool,
187}
188
189impl PingCredentialsProvider {
190 pub fn new(
214 conn_parameters: &HashMap<String, Cow<str>>,
215 idp_host: impl ToString,
216 idp_port: Option<u16>,
217 user_name: impl ToString,
218 password: SecretString,
219 ) -> Self {
220 let partner_sp_id_option = conn_parameters.get("partnerspid");
222 let partner_sp_id = if let Some(partner_sp_id) = partner_sp_id_option {
223 partner_sp_id.to_string()
224 } else {
225 "urn%3Aamazon%3Awebservices".to_string()
226 };
227 Self {
228 partner_sp_id,
229 idp_host: idp_host.to_string(),
230 idp_port: idp_port.unwrap_or(443),
231 user_name: user_name.to_string(),
232 password,
233 ssl_insecure: false,
234 }
235 }
236
237 pub fn user(&self) -> String {
239 self.user_name.clone()
240 }
241
242 pub fn do_verify_ssl_cert(&self) -> bool {
244 !self.ssl_insecure
245 }
246
247 pub fn get_credentials(
252 &self,
253 preferred_role: impl ToString,
254 ) -> Option<sts::types::Credentials> {
255 let rt = Runtime::new().unwrap(); rt.block_on(async { get_credentials(self, preferred_role.to_string()).await })
257 }
258
259 fn parse_login_form(&self, html: &str) -> (HashMap<String, String>, Option<String>) {
262 let soup = Html::parse_document(html);
263 let selector = Selector::parse("INPUT").unwrap();
264 let mut payload: HashMap<String, String> = HashMap::new();
265 let mut username_found = false;
266 let mut pwd_found = false;
267
268 debug!(
269 "Looking for username and password input tags in Ping IdP login page in order to build authentication request payload"
270 );
271 for inputtag in soup.select(&selector) {
272 let name = inputtag.attr("name").unwrap_or("").to_string();
273 let id_ = inputtag.attr("id").unwrap_or("");
274 debug!("name={name} , id={id_}");
275
276 if !username_found && is_text(&inputtag) && id_ == "username" {
277 debug!("Using tag with name {name} for username");
278 payload.insert(name, self.user());
279 username_found = true;
280 } else if is_password(&inputtag) && name.contains("pass") {
281 debug!("Using tag with name {name} for password");
282 if pwd_found {
283 panic!(
284 "Failed to parse Ping IdP login form. More than one password field was found on the Ping IdP login page"
285 );
286 }
287 payload.insert(name, self.password.expose_secret().to_string());
288 pwd_found = true;
289 } else if !name.is_empty() {
290 let value = inputtag.attr("value").unwrap_or("").to_string();
291 payload.insert(name, value);
292 }
293 }
294
295 if !username_found {
296 debug!(
297 "username tag still not found, continuing search using secondary preferred tags"
298 );
299 for inputtag in soup.select(&selector) {
300 let name = inputtag.attr("name").unwrap_or("").to_string();
301 if is_text(&inputtag) && (name.contains("user") || name.contains("email")) {
302 debug!("Using tag with name {name} for username");
303 payload.insert(name, self.user());
304 username_found = true;
305 }
306 }
307 }
308
309 if !username_found || !pwd_found {
310 panic!("Failed to parse Ping IdP login form field(s)");
311 }
312
313 let action = get_form_action(&soup).map(str::to_owned);
314 (payload, action)
315 }
316}
317
318#[async_trait::async_trait]
319impl SamlProvider for PingCredentialsProvider {
320 async fn get_saml_assertion(&self) -> String {
330 debug!("PingCredentialsProvider.get_saml_assertion");
332 let session = reqwest::Client::builder() .cookie_store(true) .build()
336 .unwrap();
337
338 let mut url = format!(
339 "https://{}:{}/idp/startSSO.ping?PartnerSpId={}",
340 self.idp_host, self.idp_port, self.partner_sp_id,
341 );
342
343 debug!(
344 "Issuing GET request for Ping IdP login page using uri={} verify={}",
345 url,
346 self.do_verify_ssl_cert(),
347 );
348 let resp = session.get(&url).send().await.unwrap(); debug!("Response code: {}", resp.status());
350 debug!("response length: {}", resp.content_length().unwrap_or(0));
351
352 let resp_text = resp.text().await.unwrap();
353 let (payload, action) = self.parse_login_form(&resp_text);
354
355 if let Some(action_str) = action.as_deref()
357 && action_str.starts_with("/")
358 {
359 url = format!("https://{}:{}{action_str}", self.idp_host, self.idp_port);
360 }
361 debug!(
366 "Issuing authentication request to Ping IdP using uri {} verify {}",
367 &url,
368 self.do_verify_ssl_cert(),
369 );
370 let response = session
371 .post(&url) .form(&payload)
373 .send()
374 .await
375 .unwrap();
376 let status_code = response.status();
377 debug!("Response code: {status_code}");
378 let resp_text = response.text().await.unwrap();
379 if status_code != 200 {
380 panic!(
381 "POST to {url} returned non-200 http status.\n{}",
382 &resp_text
383 );
384 }
385
386 parse_saml_assertion(&resp_text)
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider with default connection parameters, host `example.example.com`,
    /// user `user` and password `pwd`.
    fn make_provider() -> PingCredentialsProvider {
        PingCredentialsProvider::new(
            &HashMap::new(),
            "example.example.com",
            None,
            "user",
            SecretString::new("pwd".to_string().into_boxed_str()),
        )
    }

    /// Parses `html` as a full HTML document.
    fn parse_html(html: &str) -> Html {
        Html::parse_document(html)
    }

    #[test]
    fn test_new_applies_defaults() {
        let scp = make_provider();
        assert_eq!(scp.partner_sp_id, "urn%3Aamazon%3Awebservices");
        assert_eq!(scp.idp_port, 443);
        assert_eq!(scp.user(), "user");
        assert!(scp.do_verify_ssl_cert());
    }

    #[test]
    fn test_new_prefers_partnerspid_connection_parameter() {
        let mut params: HashMap<String, Cow<str>> = HashMap::new();
        params.insert("partnerspid".to_string(), Cow::Borrowed("custom-sp"));
        let scp = PingCredentialsProvider::new(
            &params,
            "example.example.com",
            Some(8443),
            "user",
            SecretString::new("pwd".to_string().into_boxed_str()),
        );
        assert_eq!(scp.partner_sp_id, "custom-sp");
        assert_eq!(scp.idp_port, 8443);
    }

    #[test]
    fn test_parse_saml_assertion_returns_value() {
        let html = r#"<html><body><form>
            <input type="hidden" name="RelayState" value="xyz" />
            <input type="hidden" name="SAMLResponse" value="PHNhbWw+" />
            </form></body></html>"#;
        assert_eq!(parse_saml_assertion(html), "PHNhbWw+");
    }

    #[test]
    #[should_panic(expected = "Failed to retrieve SAMLAssertion")]
    fn test_parse_saml_assertion_missing_panics() {
        parse_saml_assertion("<html><body><form></form></body></html>");
    }

    const LOGIN_PAGE_HTML: &str = r#"<html><body>
        <form action="/idp/authLogin" method="POST">
        <INPUT type="text" name="username" id="username" value="" />
        <INPUT type="password" name="pf.pass" value="" />
        <INPUT type="hidden" name="pf.ok" value="clicked" />
        </form>
        </body></html>"#;

    #[test]
    fn test_parse_login_form_extracts_credentials_and_hidden_fields() {
        let scp = make_provider();
        let (payload, action) = scp.parse_login_form(LOGIN_PAGE_HTML);
        assert_eq!(payload.get("username").map(String::as_str), Some("user"));
        assert_eq!(payload.get("pf.pass").map(String::as_str), Some("pwd"));
        assert_eq!(payload.get("pf.ok").map(String::as_str), Some("clicked"));
        assert_eq!(action.as_deref(), Some("/idp/authLogin"));
    }

    #[test]
    fn test_parse_login_form_secondary_username_lookup() {
        let scp = make_provider();
        let html = r#"<html><body><form action="/login">
            <INPUT type="text" name="user_email" value="" />
            <INPUT type="password" name="password" value="" />
            </form></body></html>"#;
        let (payload, _) = scp.parse_login_form(html);
        assert_eq!(payload.get("user_email").map(String::as_str), Some("user"));
    }

    #[test]
    #[should_panic(expected = "Failed to parse Ping IdP login form field(s)")]
    fn test_parse_login_form_missing_fields_panics() {
        let scp = make_provider();
        scp.parse_login_form("<html><body><form></form></body></html>");
    }

    #[test]
    #[should_panic(expected = "More than one password field")]
    fn test_parse_login_form_duplicate_password_panics() {
        let scp = make_provider();
        let html = r#"<html><body><form>
            <INPUT type="text" name="username" id="username" value="" />
            <INPUT type="password" name="pf.pass" value="" />
            <INPUT type="password" name="pf.pass2" value="" />
            </form></body></html>"#;
        scp.parse_login_form(html);
    }

    #[test]
    fn test_get_form_action_returns_action_for_post_form() {
        let soup =
            parse_html(r#"<html><body><form action="/submit" method="POST"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), Some("/submit"));
    }

    #[test]
    fn test_get_form_action_accepts_lowercase_post() {
        let soup =
            parse_html(r#"<html><body><form action="/submit" method="post"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), Some("/submit"));
    }

    #[test]
    fn test_get_form_action_returns_action_when_no_method_attribute() {
        let soup = parse_html(r#"<html><body><form action="/submit"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), Some("/submit"));
    }

    #[test]
    fn test_get_form_action_skips_non_post_form() {
        let soup =
            parse_html(r#"<html><body><form action="/submit" method="GET"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), None);
    }

    #[test]
    fn test_get_form_action_returns_none_when_no_action() {
        let soup = parse_html(r#"<html><body><form method="POST"></form></body></html>"#);
        assert_eq!(get_form_action(&soup), None);
    }

    #[test]
    fn test_get_form_action_returns_none_when_no_form() {
        let soup = parse_html(r#"<html><body></body></html>"#);
        assert_eq!(get_form_action(&soup), None);
    }

    #[test]
    fn test_get_form_action_skips_non_post_returns_second_form_action() {
        let soup = parse_html(
            r#"<html><body>
            <form action="/bad" method="GET"></form>
            <form action="/good" method="POST"></form>
            </body></html>"#,
        );
        assert_eq!(get_form_action(&soup), Some("/good"));
    }
}