1use crate::HtsGetError;
5use crate::error::Result as HtsGetResult;
6use crate::middleware::error::Error::AuthBuilderError;
7use crate::middleware::error::Result;
8use cfg_if::cfg_if;
9use headers::authorization::Bearer;
10use headers::{Authorization, Header};
11use htsget_config::config::advanced::auth::authorization::UrlOrStatic;
12use htsget_config::config::advanced::auth::jwt::AuthMode;
13use htsget_config::config::advanced::auth::response::AuthorizationRestrictionsBuilder;
14use htsget_config::config::advanced::auth::{AuthConfig, AuthorizationRestrictions};
15use htsget_config::config::location::{Location, PrefixOrId};
16use htsget_config::types::{Class, Interval, Query};
17use http::{HeaderMap, HeaderName, HeaderValue, Uri};
18use jsonpath_rust::JsonPath;
19use jsonwebtoken::jwk::JwkSet;
20use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, decode, decode_header};
21use serde::de::DeserializeOwned;
22use serde_json::Value;
23use std::fmt::{Debug, Formatter};
24use std::str::FromStr;
25use tracing::trace;
26
/// Builder for [`Auth`].
///
/// A config must be supplied with [`AuthBuilder::with_config`] before calling
/// [`AuthBuilder::build`]; building without one returns an error.
#[derive(Default, Debug)]
pub struct AuthBuilder {
  /// The auth configuration to build from; `None` until `with_config` is called.
  config: Option<AuthConfig>,
}
32
33impl AuthBuilder {
34 pub fn with_config(mut self, config: AuthConfig) -> Self {
36 self.config = Some(config);
37 self
38 }
39
40 pub fn build(self) -> Result<Auth> {
42 let Some(mut config) = self.config else {
43 return Err(AuthBuilderError("missing config".to_string()));
44 };
45
46 let mut decoding_key = None;
47 if let Some(AuthMode::PublicKey(public_key)) = config.auth_mode_mut() {
48 decoding_key = Some(
49 Auth::decode_public_key(public_key)
50 .map_err(|_| AuthBuilderError("failed to decode public key".to_string()))?,
51 );
52 }
53
54 Ok(Auth {
55 config,
56 decoding_key,
57 })
58 }
59}
60
/// Validates JWTs and authorization restrictions for htsget requests according to an
/// [`AuthConfig`]. Constructed through [`AuthBuilder`].
#[derive(Clone)]
pub struct Auth {
  config: AuthConfig,
  /// Public key decoded once by `AuthBuilder::build` when the config uses
  /// `AuthMode::PublicKey`; `None` otherwise, in which case `validate_jwt` obtains a key per call
  /// (e.g. from a JWKS URL).
  decoding_key: Option<DecodingKey>,
}
67
68impl Debug for Auth {
69 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("config").finish()
71 }
72}
73
/// Prefix prepended to the names of request headers and request extensions that are forwarded
/// to the authorization service (see `Auth::forwarded_headers`).
const FORWARD_HEADER_PREFIX: &str = "Htsget-Context-";
75
76impl Auth {
77 pub fn config(&self) -> &AuthConfig {
79 &self.config
80 }
81
82 pub async fn fetch_from_url<D: DeserializeOwned>(
84 &self,
85 url: &str,
86 headers: HeaderMap,
87 ) -> HtsGetResult<D> {
88 trace!("fetching url: {}", url);
89 let err = || HtsGetError::InternalError(format!("failed to fetch data from {url}"));
90 let response = self
91 .config
92 .http_client()
93 .get(url)
94 .headers(headers)
95 .send()
96 .await
97 .map_err(|_| err())?;
98
99 response.json().await.map_err(|_| err())
100 }
101
102 pub async fn decode_jwks(&self, jwks_url: &Uri, token: &str) -> HtsGetResult<DecodingKey> {
104 let header = decode_header(token)?;
106 let kid = header
107 .kid
108 .ok_or_else(|| HtsGetError::PermissionDenied("JWT missing key ID".to_string()))?;
109
110 let jwks = self
112 .fetch_from_url::<JwkSet>(&jwks_url.to_string(), Default::default())
113 .await?;
114 let matched_jwk = jwks
115 .find(&kid)
116 .ok_or_else(|| HtsGetError::PermissionDenied("matching JWK not found".to_string()))?;
117
118 Ok(DecodingKey::from_jwk(matched_jwk)?)
119 }
120
121 pub fn decode_public_key(key: &[u8]) -> HtsGetResult<DecodingKey> {
123 Ok(
124 DecodingKey::from_rsa_pem(key)
125 .or_else(|_| DecodingKey::from_ed_pem(key))
126 .or_else(|_| DecodingKey::from_ec_pem(key))?,
127 )
128 }
129
130 pub fn forwarded_headers(
132 &self,
133 request_headers: &HeaderMap,
134 request_extensions: Option<Value>,
135 ) -> HtsGetResult<HeaderMap> {
136 let mut forwarded_headers = if self.config.passthrough_auth() {
137 let auth_header = request_headers
138 .iter()
139 .find_map(|(name, value)| {
140 if Authorization::<Bearer>::decode(&mut [value].into_iter()).is_ok() {
141 return Some((name.clone(), value.clone()));
142 }
143
144 None
145 })
146 .ok_or_else(|| HtsGetError::PermissionDenied("missing authorization header".to_string()))?;
147 HeaderMap::from_iter([auth_header])
148 } else {
149 HeaderMap::default()
150 };
151
152 for header in self.config.forward_headers() {
153 let Some((existing_name, existing_value)) = request_headers
154 .iter()
155 .find_map(|(name, value)| {
156 if header.to_lowercase() == name.as_str().to_lowercase() {
157 return match HeaderName::from_str(&format!("{}{}", FORWARD_HEADER_PREFIX, name)) {
158 Ok(header) => Some(Ok((header, value))),
159 Err(err) => Some(Err(HtsGetError::InternalError(err.to_string()))),
160 };
161 }
162
163 None
164 })
165 .transpose()?
166 else {
167 continue;
168 };
169
170 forwarded_headers.insert(existing_name, existing_value.clone());
171 }
172
173 if let Some(request_extensions) = request_extensions {
174 for extension in self.config.forward_extensions() {
175 let Some(value) = request_extensions.query(extension.json_path()).ok() else {
176 continue;
177 };
178
179 let value = value.first().ok_or_else(|| {
180 HtsGetError::InternalError("extension does not have only one value".to_string())
181 })?;
182 let value = value.as_str().ok_or_else(|| {
183 HtsGetError::InternalError("extension value is not a string".to_string())
184 })?;
185
186 let header_name =
187 HeaderName::from_str(&format!("{}{}", FORWARD_HEADER_PREFIX, extension.name()))
188 .map_err(|err| HtsGetError::InternalError(err.to_string()))?;
189 let value = HeaderValue::from_str(value)
190 .map_err(|err| HtsGetError::InternalError(err.to_string()))?;
191 forwarded_headers.insert(header_name, value);
192 }
193 }
194
195 Ok(forwarded_headers)
196 }
197
198 pub async fn query_authorization_service(
202 &self,
203 headers: &HeaderMap,
204 request_extensions: Option<Value>,
205 ) -> HtsGetResult<Option<AuthorizationRestrictions>> {
206 match self.config.authorization_url() {
207 Some(UrlOrStatic::Url(uri)) => {
208 let forwarded_headers = self.forwarded_headers(headers, request_extensions)?;
209
210 self
211 .fetch_from_url(&uri.to_string(), forwarded_headers)
212 .await
213 .map(Some)
214 }
215 Some(UrlOrStatic::Static(config)) => Ok(Some(config.clone())),
216 _ => Ok(None),
217 }
218 }
219
220 pub fn validate_restrictions(
225 restrictions: AuthorizationRestrictions,
226 path: &str,
227 queries: &mut [Query],
228 suppressed_interval: bool,
229 ) -> HtsGetResult<AuthorizationRestrictions> {
230 let matching_rules = restrictions
232 .into_rules()
233 .into_iter()
234 .filter(|rule| {
235 match rule.location() {
236 Location::Simple(location) if location.prefix_or_id().is_some() => {
237 match location.prefix_or_id().unwrap_or_default() {
238 PrefixOrId::Prefix(prefix) => {
239 path.starts_with(&prefix)
241 }
242 PrefixOrId::Id(id) => {
243 id == path
245 }
246 }
247 }
248 Location::Regex(location) => {
249 location.regex().is_match(path)
251 }
252 _ => false,
254 }
255 })
256 .collect::<Vec<_>>();
257
258 if matching_rules.is_empty() {
260 return Err(HtsGetError::PermissionDenied(
261 "failed to authorize user based on authorization service restrictions".to_string(),
262 ));
263 }
264
265 let (allows_all, allows_specific): (Vec<_>, Vec<_>) = matching_rules
266 .into_iter()
267 .partition(|rule| rule.rules().is_none());
268
269 for query in queries {
271 if query.class() == Class::Header {
273 continue;
274 }
275
276 let matching_restriction = allows_specific
277 .iter()
278 .flat_map(|rule| rule.rules().unwrap_or_default())
279 .filter_map(|restriction| {
280 let name_match = restriction.reference_name().is_none()
282 || restriction.reference_name() == query.reference_name();
283 let format_match =
285 restriction.format().is_none() || restriction.format() == Some(query.format());
286 let interval_match = if suppressed_interval {
288 restriction.interval().constraint_interval(query.interval())
289 } else {
290 restriction.interval().contains_interval(query.interval())
291 };
292
293 if let Some(interval_match) = interval_match {
294 if name_match && format_match {
295 return Some(interval_match);
296 }
297 }
298
299 None
300 })
301 .max_by(Interval::order_by_range); if suppressed_interval {
304 if allows_all.is_empty() && matching_restriction.is_none() {
305 query.set_class(Class::Header);
307 continue;
308 }
309
310 if let Some(matching_restriction) = matching_restriction {
311 query.set_interval(matching_restriction);
312 }
313 } else if allows_all.is_empty() && matching_restriction.is_none() {
314 return Err(HtsGetError::PermissionDenied(
315 "failed to authorize user based on authorization service restrictions".to_string(),
316 ));
317 }
318 }
319
320 AuthorizationRestrictionsBuilder::default()
321 .rules([allows_all, allows_specific].concat())
322 .build()
323 .map_err(|err| HtsGetError::InternalError(err.to_string()))
324 }
325
326 pub async fn validate_jwt(&self, headers: &HeaderMap) -> HtsGetResult<TokenData<Value>> {
329 let auth_token = headers
330 .values()
331 .find_map(|value| Authorization::<Bearer>::decode(&mut [value].into_iter()).ok())
332 .ok_or_else(|| {
333 HtsGetError::InvalidAuthentication("invalid authorization header".to_string())
334 })?;
335
336 let decoding_key = if let Some(ref decoding_key) = self.decoding_key {
337 decoding_key
338 } else {
339 match self.config.auth_mode() {
340 Some(AuthMode::Jwks(jwks)) => &self.decode_jwks(jwks, auth_token.token()).await?,
341 Some(AuthMode::PublicKey(public_key)) => &Self::decode_public_key(public_key)?,
342 _ => {
343 return Err(HtsGetError::InternalError(
344 "JWT validation not set".to_string(),
345 ));
346 }
347 }
348 };
349
350 let mut validation = Validation::default();
352 validation.validate_exp = true;
353 validation.validate_aud = true;
354 validation.validate_nbf = true;
355
356 if let Some(iss) = self.config.validate_issuer() {
357 validation.set_issuer(iss);
358 validation.required_spec_claims.insert("iss".to_string());
359 }
360 if let Some(aud) = self.config.validate_audience() {
361 validation.set_audience(aud);
362 validation.required_spec_claims.insert("aud".to_string());
363 }
364 if let Some(sub) = self.config.validate_subject() {
365 validation.sub = Some(sub.to_string());
366 validation.required_spec_claims.insert("sub".to_string());
367 }
368
369 validation.algorithms = vec![Algorithm::RS256];
372 let decoded_claims = decode::<Value>(auth_token.token(), decoding_key, &validation)
373 .or_else(|_| {
374 validation.algorithms = vec![Algorithm::ES256];
375 decode::<Value>(auth_token.token(), decoding_key, &validation)
376 })
377 .or_else(|_| {
378 validation.algorithms = vec![Algorithm::EdDSA];
379 decode::<Value>(auth_token.token(), decoding_key, &validation)
380 });
381
382 let claims = match decoded_claims {
383 Ok(claims) => claims,
384 Err(err) => return Err(HtsGetError::PermissionDenied(format!("invalid JWT: {err}"))),
385 };
386
387 Ok(claims)
388 }
389
390 pub async fn validate_authorization(
398 &self,
399 headers: &HeaderMap,
400 path: &str,
401 queries: &mut [Query],
402 request_extensions: Option<Value>,
403 ) -> HtsGetResult<Option<AuthorizationRestrictions>> {
404 let restrictions = self
405 .query_authorization_service(headers, request_extensions)
406 .await?;
407
408 if let Some(restrictions) = restrictions {
409 cfg_if! {
410 if #[cfg(feature = "experimental")] {
411 Self::validate_restrictions(restrictions, path, queries, self.config.suppress_errors()).map(Some)
412 } else {
413 Self::validate_restrictions(restrictions, path, queries, false).map(Some)
414 }
415 }
416 } else {
417 Ok(None)
418 }
419 }
420}
421
#[cfg(test)]
mod tests {
  use super::*;
  use crate::{Endpoint, convert_to_query, match_format_from_query};
  use htsget_config::config::advanced::HttpClient;
  use htsget_config::config::advanced::auth::AuthConfigBuilder;
  use htsget_config::config::advanced::auth::authorization::ForwardExtensions;
  use htsget_config::config::advanced::auth::response::{
    AuthorizationRestrictionsBuilder, AuthorizationRuleBuilder, ReferenceNameRestrictionBuilder,
  };
  use htsget_config::config::advanced::regex_location::RegexLocation;
  use htsget_config::config::location::SimpleLocation;
  use htsget_config::types::{Format, Request};
  use htsget_test::util::generate_key_pair;
  use http::{HeaderMap, Uri};
  use regex::Regex;
  use reqwest_middleware::ClientBuilder;
  use serde_json::json;
  use std::collections::HashMap;

  #[test]
  fn auth_builder_missing_config() {
    let result = AuthBuilder::default().build();
    assert!(matches!(result, Err(AuthBuilderError(_))));
  }

  #[test]
  fn auth_builder_success_with_public_key() {
    let (_, public_key) = generate_key_pair();

    let config = create_test_auth_config(public_key);
    let result = AuthBuilder::default().with_config(config).build();
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_rule_allows_all() {
    let rule = AuthorizationRuleBuilder::default()
      .location(test_location())
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let request = create_test_query(Endpoint::Reads, "sample1", HashMap::new());
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_exact_path_match() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .start(1000)
      .end(2000)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .location(test_location())
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("start".to_string(), "1500".to_string());
    query.insert("end".to_string(), "1800".to_string());
    query.insert("format".to_string(), "BAM".to_string());

    let request = create_test_query(Endpoint::Reads, "sample1", query);
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
    assert!(result.is_ok());
  }

  // Exercises a simple-location prefix match (not a regex, despite the name).
  #[test]
  fn validate_restrictions_regex_prefix_match() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .location(Location::Simple(Box::new(SimpleLocation::new(
        Default::default(),
        "".to_string(),
        Some(PrefixOrId::Prefix("sam".to_string())),
      ))))
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("format".to_string(), "BAM".to_string());

    let request = create_test_query(Endpoint::Reads, "sample123", query);
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_regex_match() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .location(Location::Regex(Box::new(RegexLocation::new(
        Regex::new("sample(.+)").unwrap(),
        "".to_string(),
        Default::default(),
        Default::default(),
      ))))
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("format".to_string(), "BAM".to_string());

    let request = create_test_query(Endpoint::Reads, "sample123", query);
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_forward_headers() {
    let (_, public_key) = generate_key_pair();

    let builder = AuthConfigBuilder::default()
      .auth_mode(AuthMode::PublicKey(public_key))
      .authorization_url(UrlOrStatic::Url(Uri::from_static(
        "https://www.example.com",
      )))
      .http_client(HttpClient::new(
        ClientBuilder::new(reqwest::Client::new()).build(),
      ));
    let config = builder
      .clone()
      .passthrough_auth(true)
      .forward_headers(vec!["Custom1".to_string()])
      .build()
      .unwrap();
    let result = AuthBuilder::default().with_config(config).build().unwrap();

    let request_headers = HeaderMap::from_iter([
      (
        "Authorization".parse().unwrap(),
        "Bearer Value".parse().unwrap(),
      ),
      ("Custom1".parse().unwrap(), "Value".parse().unwrap()),
      ("Custom2".parse().unwrap(), "Value".parse().unwrap()),
    ]);
    let forwarded_headers = result.forwarded_headers(&request_headers, None).unwrap();
    assert_eq!(
      forwarded_headers,
      HeaderMap::from_iter([
        (
          format!("{}Custom1", FORWARD_HEADER_PREFIX).parse().unwrap(),
          "Value".parse().unwrap()
        ),
        (
          "Authorization".parse().unwrap(),
          "Bearer Value".parse().unwrap()
        ),
      ])
    );

    let config = builder
      .clone()
      .passthrough_auth(true)
      .forward_headers(vec!["Custom1".to_string(), "Authorization".to_string()])
      .build()
      .unwrap();
    let result = AuthBuilder::default().with_config(config).build().unwrap();

    let forwarded_headers = result.forwarded_headers(&request_headers, None).unwrap();
    assert_eq!(
      forwarded_headers,
      HeaderMap::from_iter([
        (
          format!("{}Custom1", FORWARD_HEADER_PREFIX).parse().unwrap(),
          "Value".parse().unwrap()
        ),
        (
          format!("{}Authorization", FORWARD_HEADER_PREFIX)
            .parse()
            .unwrap(),
          "Bearer Value".parse().unwrap()
        ),
        (
          "Authorization".parse().unwrap(),
          "Bearer Value".parse().unwrap()
        ),
      ])
    );

    let config = builder
      .clone()
      .forward_headers(vec!["Custom1".to_string()])
      .passthrough_auth(false)
      .build()
      .unwrap();
    let result = AuthBuilder::default().with_config(config).build().unwrap();

    let forwarded_headers = result.forwarded_headers(&request_headers, None).unwrap();
    assert_eq!(
      forwarded_headers,
      HeaderMap::from_iter([(
        format!("{}Custom1", FORWARD_HEADER_PREFIX).parse().unwrap(),
        "Value".parse().unwrap()
      ),])
    );

    let config = builder
      .forward_extensions(vec![ForwardExtensions::new(
        "$.Key".to_string(),
        "Custom1".to_string(),
      )])
      .passthrough_auth(false)
      .build()
      .unwrap();
    let result = AuthBuilder::default().with_config(config).build().unwrap();

    let forwarded_headers = result
      .forwarded_headers(
        &request_headers,
        Some(json!({
          "Key": "Value"
        })),
      )
      .unwrap();
    assert_eq!(
      forwarded_headers,
      HeaderMap::from_iter([(
        format!("{}Custom1", FORWARD_HEADER_PREFIX).parse().unwrap(),
        "Value".parse().unwrap()
      ),])
    );
  }

  #[test]
  fn validate_restrictions_reference_name_mismatch() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .location(test_location())
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule.clone())
      .build()
      .unwrap();

    // A body request for a reference the rule does not allow must be rejected when errors are
    // not suppressed. (This test previously duplicated `validate_restrictions_header`.)
    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr2".to_string());
    query.insert("format".to_string(), "BAM".to_string());

    let request = create_test_query(Endpoint::Reads, "sample1", query);
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
    assert!(result.is_err());
  }

  #[test]
  fn validate_restrictions_header() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .location(test_location())
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule.clone())
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("format".to_string(), "BAM".to_string());
    query.insert("class".to_string(), "header".to_string());

    let request = create_test_query(Endpoint::Reads, "sample1", query);
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
    assert!(result.is_ok());
  }

  #[cfg(feature = "experimental")]
  #[test]
  fn validate_restrictions_reference_name_mismatch_suppressed() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .location(test_location())
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule.clone())
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr2".to_string());
    query.insert("format".to_string(), "BAM".to_string());

    let request = create_test_query(Endpoint::Reads, "sample1", query);
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], true);
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_format_mismatch() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .location(test_location())
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule.clone())
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("format".to_string(), "CRAM".to_string());

    let request = create_test_query(Endpoint::Reads, "sample1", query);
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
    assert!(result.is_err());
  }

  #[cfg(feature = "experimental")]
  #[test]
  fn validate_restrictions_format_mismatch_suppressed() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .location(test_location())
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule.clone())
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("format".to_string(), "CRAM".to_string());

    let request = create_test_query(Endpoint::Reads, "sample1", query);
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], true);
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_interval_not_contained() {
    // (restrict_start, restrict_end, request_start, request_end, is_err). Without suppression the
    // query is never modified, so the expected interval is always the requested one.
    let cases = [
      (Some(1000), Some(2000), Some(1250), Some(1750), false),
      (Some(1000), Some(2000), Some(500), None, true),
      (Some(1000), Some(2000), None, Some(2500), true),
      (Some(1000), Some(2000), None, None, true),
      (Some(1000), Some(2000), Some(500), Some(1500), true),
      (Some(1000), Some(2000), None, Some(1500), true),
      (Some(1000), Some(2000), Some(1500), Some(2500), true),
      (Some(1000), Some(2000), Some(1500), None, true),
      (Some(1000), Some(2000), Some(500), Some(1000), true),
      (Some(1000), Some(2000), Some(2000), Some(2500), true),
      (None, Some(2000), Some(500), Some(1500), false),
      (None, Some(2000), Some(1500), Some(2500), true),
      (Some(1000), None, Some(1500), Some(2500), false),
      (Some(1000), None, Some(500), Some(1500), true),
      (None, None, Some(500), Some(2500), false),
      (None, None, Some(500), None, false),
      (None, None, None, Some(2500), false),
    ];

    for (restrict_start, restrict_end, request_start, request_end, is_err) in cases {
      test_interval_suppressed(
        restrict_start,
        restrict_end,
        request_start,
        request_end,
        (Interval::new(request_start, request_end), Class::Body),
        is_err,
        false,
      );
    }
  }

  #[cfg(feature = "experimental")]
  #[test]
  fn validate_restrictions_interval_suppressed() {
    // (restrict_start, restrict_end, request_start, request_end, expected interval, expected
    // class). With suppression no case is an error; disallowed queries become header-only.
    let cases = [
      (Some(1000), Some(2000), Some(1250), Some(1750), (Some(1250), Some(1750)), Class::Body),
      (Some(1000), Some(2000), Some(500), None, (Some(1000), Some(2000)), Class::Body),
      (Some(1000), Some(2000), None, Some(2500), (Some(1000), Some(2000)), Class::Body),
      (Some(1000), Some(2000), None, None, (Some(1000), Some(2000)), Class::Body),
      (Some(1000), Some(2000), Some(500), Some(1500), (Some(1000), Some(1500)), Class::Body),
      (Some(1000), Some(2000), None, Some(1500), (Some(1000), Some(1500)), Class::Body),
      (Some(1000), Some(2000), Some(1500), Some(2500), (Some(1500), Some(2000)), Class::Body),
      (Some(1000), Some(2000), Some(1500), None, (Some(1500), Some(2000)), Class::Body),
      (Some(1000), Some(2000), Some(500), Some(1000), (Some(500), Some(1000)), Class::Header),
      (Some(1000), Some(2000), Some(2000), Some(2500), (Some(2000), Some(2500)), Class::Header),
      (None, Some(2000), Some(500), Some(1500), (Some(500), Some(1500)), Class::Body),
      (None, Some(2000), Some(1500), Some(2500), (Some(1500), Some(2000)), Class::Body),
      (Some(1000), None, Some(1500), Some(2500), (Some(1500), Some(2500)), Class::Body),
      (Some(1000), None, Some(500), Some(1500), (Some(1000), Some(1500)), Class::Body),
      (None, None, Some(500), Some(2500), (Some(500), Some(2500)), Class::Body),
      (None, None, Some(500), None, (Some(500), None), Class::Body),
      (None, None, None, Some(2500), (None, Some(2500)), Class::Body),
    ];

    for (restrict_start, restrict_end, request_start, request_end, (start, end), class) in cases {
      test_interval_suppressed(
        restrict_start,
        restrict_end,
        request_start,
        request_end,
        (Interval::new(start, end), class),
        false,
        true,
      );
    }
  }

  #[test]
  fn validate_restrictions_format_none_allows_any() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .location(test_location())
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("format".to_string(), "CRAM".to_string());

    let request = create_test_query(Endpoint::Reads, "sample1", query);
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
    assert!(result.is_ok());
  }

  // NOTE(review): this currently exercises the same inputs as
  // `validate_restrictions_rule_allows_all`; no leading slash is involved in the path.
  #[test]
  fn validate_restrictions_path_with_leading_slash() {
    let rule = AuthorizationRuleBuilder::default()
      .location(test_location())
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();
    let request = create_test_query(Endpoint::Reads, "sample1", HashMap::new());
    let result =
      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
    assert!(result.is_ok());
  }

  #[tokio::test]
  async fn validate_authorization_missing_auth_header() {
    let auth = create_mock_auth_with_restrictions();
    let request = Request::new("sample1".to_string(), HashMap::new(), HeaderMap::new());

    let result = auth.validate_jwt(request.headers()).await;
    assert!(result.is_err());
    assert!(matches!(
      result.unwrap_err(),
      HtsGetError::InvalidAuthentication(_)
    ));
  }

  #[tokio::test]
  async fn validate_authorization_invalid_jwt_format() {
    let auth = create_mock_auth_with_restrictions();
    let request = create_request_with_auth_header("sample1", HashMap::new(), "invalid.jwt.token");

    let result = auth.validate_jwt(request.headers()).await;
    assert!(result.is_err());
    assert!(matches!(
      result.unwrap_err(),
      HtsGetError::PermissionDenied(_)
    ));
  }

  /// Auth config with a static public key and an (unreachable in tests) authorization URL.
  fn create_test_auth_config(public_key: Vec<u8>) -> AuthConfig {
    AuthConfigBuilder::default()
      .auth_mode(AuthMode::PublicKey(public_key))
      .authorization_url(UrlOrStatic::Url(Uri::from_static(
        "https://www.example.com",
      )))
      .http_client(HttpClient::new(
        ClientBuilder::new(reqwest::Client::new()).build(),
      ))
      .build()
      .unwrap()
  }

  /// Converts a request for `path` with the given query parameters into a `Query`.
  fn create_test_query(endpoint: Endpoint, path: &str, query: HashMap<String, String>) -> Query {
    let request = Request::new(path.to_string(), query, HeaderMap::new());
    let format = match_format_from_query(&endpoint, request.query()).unwrap();

    convert_to_query(request, format).unwrap()
  }

  /// Builds a request carrying `Authorization: Bearer <token>`.
  fn create_request_with_auth_header(
    path: &str,
    query: HashMap<String, String>,
    token: &str,
  ) -> Request {
    let mut headers = HeaderMap::new();
    headers.insert("authorization", format!("Bearer {token}").parse().unwrap());
    Request::new(path.to_string(), query, headers)
  }

  /// Builds an `Auth` backed by a freshly generated public key.
  fn create_mock_auth_with_restrictions() -> Auth {
    let (_, public_key) = generate_key_pair();

    let config = create_test_auth_config(public_key);
    AuthBuilder::default().with_config(config).build().unwrap()
  }

  /// Validates a `chr1` BAM query against a single `chr1` restriction.
  ///
  /// Asserts the error outcome, and the resulting query interval and class. Assertion messages
  /// identify the failing case.
  fn test_interval_suppressed(
    restrict_start: Option<u32>,
    restrict_end: Option<u32>,
    request_start: Option<u32>,
    request_end: Option<u32>,
    expected_response: (Interval, Class),
    is_err: bool,
    suppress_interval: bool,
  ) {
    let mut reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam);

    if let Some(start) = restrict_start {
      reference_restriction = reference_restriction.start(start);
    }
    if let Some(end) = restrict_end {
      reference_restriction = reference_restriction.end(end);
    }

    let reference_restriction = reference_restriction.build().unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .location(test_location())
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    if let Some(start) = request_start {
      query.insert("start".to_string(), start.to_string());
    }
    if let Some(end) = request_end {
      query.insert("end".to_string(), end.to_string());
    }

    let request = create_test_query(Endpoint::Reads, "sample1", query);
    let id = request.id().to_string();
    let mut slice = [request];
    let result = Auth::validate_restrictions(restrictions, &id, &mut slice, suppress_interval);

    let context = format!(
      "restriction {restrict_start:?}-{restrict_end:?}, request {request_start:?}-{request_end:?}, suppressed: {suppress_interval}"
    );
    assert_eq!(result.is_err(), is_err, "{context}");
    let validated = &slice[0];
    assert_eq!(validated.interval(), expected_response.0, "{context}");
    assert_eq!(validated.class(), expected_response.1, "{context}");
  }

  /// A simple location matching exactly the id `sample1`.
  fn test_location() -> Location {
    Location::Simple(Box::new(SimpleLocation::new(
      Default::default(),
      "".to_string(),
      Some(PrefixOrId::Id("sample1".to_string())),
    )))
  }
}