1use crate::error::{Result as HtsGetResult, WrappedHtsGetError};
5use crate::middleware::error::Error::AuthBuilderError;
6use crate::middleware::error::Result;
7use crate::{Endpoint, HtsGetError};
8use cfg_if::cfg_if;
9use headers::authorization::Bearer;
10use headers::{Authorization, Header};
11use htsget_config::config::advanced::CONTEXT_HEADER_PREFIX;
12use htsget_config::config::advanced::auth::authorization::UrlOrStatic;
13use htsget_config::config::advanced::auth::jwt::AuthMode;
14use htsget_config::config::advanced::auth::response::AuthorizationRestrictionsBuilder;
15use htsget_config::config::advanced::auth::{AuthConfig, AuthorizationRestrictions};
16use htsget_config::config::location::{Location, PrefixOrId};
17use htsget_config::types::{Class, Interval, Query};
18use http::{HeaderMap, HeaderName, HeaderValue, Uri};
19use jsonpath_rust::JsonPath;
20use jsonwebtoken::jwk::JwkSet;
21use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, decode, decode_header};
22use serde::de::DeserializeOwned;
23use serde_json::Value;
24use std::fmt::{Debug, Formatter};
25use std::str::FromStr;
26use tracing::{debug, trace};
27
28#[derive(Default, Debug)]
30pub struct AuthBuilder {
31 config: Option<AuthConfig>,
32}
33
34impl AuthBuilder {
35 pub fn with_config(mut self, config: AuthConfig) -> Self {
37 self.config = Some(config);
38 self
39 }
40
41 pub fn build(self) -> Result<Auth> {
43 let Some(mut config) = self.config else {
44 return Err(AuthBuilderError("missing config".to_string()));
45 };
46
47 let mut decoding_key = None;
48 if let Some(AuthMode::PublicKey(public_key)) = config.auth_mode_mut() {
49 decoding_key = Some(
50 Auth::decode_public_key(public_key)
51 .map_err(|_| AuthBuilderError("failed to decode public key".to_string()))?,
52 );
53 }
54
55 Ok(Auth {
56 config,
57 decoding_key,
58 })
59 }
60}
61
62#[derive(Clone)]
64pub struct Auth {
65 config: AuthConfig,
66 decoding_key: Option<DecodingKey>,
67}
68
69impl Debug for Auth {
70 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("config").finish()
72 }
73}
74
/// Suffix appended to `CONTEXT_HEADER_PREFIX` to name the forwarded header that carries the
/// request endpoint (sent when `forward_endpoint_type` is enabled).
const ENDPOINT_TYPE_HEADER_NAME: &str = "Endpoint-Type";
/// Suffix appended to `CONTEXT_HEADER_PREFIX` to name the forwarded header that carries the
/// request id (sent when `forward_id` is enabled).
const ID_HEADER_NAME: &str = "Id";
77
78impl Auth {
79 pub fn config(&self) -> &AuthConfig {
81 &self.config
82 }
83
84 pub async fn fetch_from_url<D: DeserializeOwned>(
86 &mut self,
87 url: &str,
88 headers: HeaderMap,
89 ) -> HtsGetResult<D> {
90 trace!("fetching url: {}", url);
91 let response = self
92 .config
93 .http_client()
94 .map_err(|err| HtsGetError::InternalError(format!("failed to fetch data from {url}: {err}")))?
95 .get(url)
96 .headers(headers)
97 .send()
98 .await?;
99 trace!("response: {:?}", response);
100
101 let status = response.status();
102
103 let value = response.json::<Value>().await.map_err(|err| {
105 HtsGetError::InternalError(format!("failed to fetch data from {url}: {err}"))
106 })?;
107 trace!("value: {}", value);
108
109 match serde_json::from_value::<D>(value.clone()) {
110 Ok(response) => Ok(response),
111 Err(_) => match serde_json::from_value::<WrappedHtsGetError>(value.clone()) {
112 Ok(err) => Err(HtsGetError::Wrapped(err, status)),
113 Err(_) => Err(HtsGetError::InternalError(format!(
114 "failed to fetch data from {url}: {value}"
115 ))),
116 },
117 }
118 }
119
120 pub async fn decode_jwks(&mut self, jwks_url: &Uri, token: &str) -> HtsGetResult<DecodingKey> {
122 let header = decode_header(token)?;
124 let kid = header
125 .kid
126 .ok_or_else(|| HtsGetError::PermissionDenied("JWT missing key ID".to_string()))?;
127
128 let jwks = self
130 .fetch_from_url::<JwkSet>(&jwks_url.to_string(), Default::default())
131 .await?;
132 let matched_jwk = jwks
133 .find(&kid)
134 .ok_or_else(|| HtsGetError::PermissionDenied("matching JWK not found".to_string()))?;
135
136 Ok(DecodingKey::from_jwk(matched_jwk)?)
137 }
138
139 pub fn decode_public_key(key: &[u8]) -> HtsGetResult<DecodingKey> {
141 Ok(
142 DecodingKey::from_rsa_pem(key)
143 .or_else(|_| DecodingKey::from_ed_pem(key))
144 .or_else(|_| DecodingKey::from_ec_pem(key))?,
145 )
146 }
147
148 pub fn forwarded_headers(
150 &self,
151 request_headers: &HeaderMap,
152 request_extensions: Option<Value>,
153 request_endpoint: &Endpoint,
154 id: &str,
155 ) -> HtsGetResult<HeaderMap> {
156 let mut forwarded_headers = if self.config.passthrough_auth() {
157 let auth_header = request_headers
158 .iter()
159 .find_map(|(name, value)| {
160 if Authorization::<Bearer>::decode(&mut [value].into_iter()).is_ok() {
161 return Some((name.clone(), value.clone()));
162 }
163
164 None
165 })
166 .ok_or_else(|| HtsGetError::PermissionDenied("missing authorization header".to_string()))?;
167 HeaderMap::from_iter([auth_header])
168 } else {
169 HeaderMap::default()
170 };
171
172 for header in self.config.forward_headers() {
173 let Some((existing_name, existing_value)) = request_headers
174 .iter()
175 .find_map(|(name, value)| {
176 if header.to_lowercase() == name.as_str().to_lowercase() {
177 return match HeaderName::from_str(&format!("{}{}", CONTEXT_HEADER_PREFIX, name)) {
178 Ok(header) => Some(Ok((header, value))),
179 Err(err) => Some(Err(HtsGetError::InternalError(err.to_string()))),
180 };
181 }
182
183 None
184 })
185 .transpose()?
186 else {
187 continue;
188 };
189
190 forwarded_headers.insert(existing_name, existing_value.clone());
191 }
192
193 if let Some(request_extensions) = request_extensions {
194 for extension in self.config.forward_extensions() {
195 let Some(value) = request_extensions.query(extension.json_path()).ok() else {
196 continue;
197 };
198
199 let value = value.first().ok_or_else(|| {
200 HtsGetError::InternalError("extension does not have only one value".to_string())
201 })?;
202 let value = value.as_str().ok_or_else(|| {
203 HtsGetError::InternalError("extension value is not a string".to_string())
204 })?;
205
206 let header_name =
207 HeaderName::from_str(&format!("{}{}", CONTEXT_HEADER_PREFIX, extension.name()))?;
208 let value = HeaderValue::from_str(value)?;
209 forwarded_headers.insert(header_name, value);
210 }
211 }
212
213 if self.config.forward_endpoint_type() {
214 let header_name = HeaderName::from_str(&format!(
215 "{}{}",
216 CONTEXT_HEADER_PREFIX, ENDPOINT_TYPE_HEADER_NAME
217 ))?;
218 let value = HeaderValue::from_str(&request_endpoint.to_string())?;
219
220 forwarded_headers.insert(header_name, value);
221 }
222
223 if self.config.forward_id() {
224 let header_name =
225 HeaderName::from_str(&format!("{}{}", CONTEXT_HEADER_PREFIX, ID_HEADER_NAME))?;
226 let value = HeaderValue::from_str(id)?;
227
228 forwarded_headers.insert(header_name, value);
229 }
230
231 Ok(forwarded_headers)
232 }
233
234 pub async fn query_authorization_service(
238 &mut self,
239 headers: &HeaderMap,
240 request_extensions: Option<Value>,
241 request_endpoint: &Endpoint,
242 id: &str,
243 ) -> HtsGetResult<Option<AuthorizationRestrictions>> {
244 match self.config.authorization_url() {
245 Some(UrlOrStatic::Url(uri)) => {
246 let forwarded_headers =
247 self.forwarded_headers(headers, request_extensions, request_endpoint, id)?;
248
249 self
250 .fetch_from_url(&uri.to_string(), forwarded_headers)
251 .await
252 .map(Some)
253 }
254 Some(UrlOrStatic::Static(config)) => Ok(Some(config.clone())),
255 _ => Ok(None),
256 }
257 }
258
259 pub fn validate_restrictions(
264 restrictions: AuthorizationRestrictions,
265 path: &str,
266 queries: &mut [Query],
267 suppressed_interval: bool,
268 ) -> HtsGetResult<AuthorizationRestrictions> {
269 let matching_rules = restrictions
271 .into_rules()
272 .into_iter()
273 .filter(|rule| {
274 match rule.location() {
275 Location::Simple(location) if location.prefix_or_id().is_some() => {
276 match location.prefix_or_id().unwrap_or_default() {
277 PrefixOrId::Prefix(prefix) => {
278 path.starts_with(&prefix)
280 }
281 PrefixOrId::Id(id) => {
282 id == path
284 }
285 }
286 }
287 Location::Regex(location) => {
288 location.regex().is_match(path)
290 }
291 _ => false,
293 }
294 })
295 .collect::<Vec<_>>();
296
297 if matching_rules.is_empty() {
299 return Err(HtsGetError::PermissionDenied(
300 "failed to authorize user based on authorization service restrictions".to_string(),
301 ));
302 }
303
304 let (allows_all, allows_specific): (Vec<_>, Vec<_>) = matching_rules
305 .into_iter()
306 .partition(|rule| rule.rules().is_none());
307
308 for query in queries {
310 if query.class() == Class::Header {
312 continue;
313 }
314
315 let matching_restriction = allows_specific
316 .iter()
317 .flat_map(|rule| rule.rules().unwrap_or_default())
318 .filter_map(|restriction| {
319 let name_match = restriction.reference_name().is_none()
321 || restriction.reference_name() == query.reference_name();
322 let format_match =
324 restriction.format().is_none() || restriction.format() == Some(query.format());
325 let interval_match = if suppressed_interval {
327 restriction.interval().constraint_interval(query.interval())
328 } else {
329 restriction.interval().contains_interval(query.interval())
330 };
331
332 if let Some(interval_match) = interval_match
333 && name_match
334 && format_match
335 {
336 return Some(interval_match);
337 }
338
339 None
340 })
341 .max_by(Interval::order_by_range); if suppressed_interval {
344 if allows_all.is_empty() && matching_restriction.is_none() {
345 query.set_class(Class::Header);
347 continue;
348 }
349
350 if let Some(matching_restriction) = matching_restriction {
351 query.set_interval(matching_restriction);
352 }
353 } else if allows_all.is_empty() && matching_restriction.is_none() {
354 return Err(HtsGetError::PermissionDenied(
355 "failed to authorize user based on authorization service restrictions".to_string(),
356 ));
357 }
358 }
359
360 AuthorizationRestrictionsBuilder::default()
361 .rules([allows_all, allows_specific].concat())
362 .build()
363 .map_err(|err| HtsGetError::InternalError(err.to_string()))
364 }
365
366 pub async fn validate_jwt(&mut self, headers: &HeaderMap) -> HtsGetResult<TokenData<Value>> {
369 let auth_token = headers
370 .values()
371 .find_map(|value| Authorization::<Bearer>::decode(&mut [value].into_iter()).ok())
372 .ok_or_else(|| {
373 HtsGetError::InvalidAuthentication("invalid authorization header".to_string())
374 })?;
375
376 let decoding_key = if let Some(ref decoding_key) = self.decoding_key {
377 decoding_key
378 } else if matches!(self.config.auth_mode(), Some(AuthMode::Jwks(_))) {
379 let url = if let Some(AuthMode::Jwks(uri)) = self.config.auth_mode() {
380 uri.clone()
381 } else {
382 return Err(HtsGetError::InternalError(
383 "JWT validation not set".to_string(),
384 ));
385 };
386
387 &self.decode_jwks(&url, auth_token.token()).await?
388 } else if let Some(AuthMode::PublicKey(key)) = self.config.auth_mode() {
389 &Self::decode_public_key(key)?
390 } else {
391 return Err(HtsGetError::InternalError(
392 "JWT validation not set".to_string(),
393 ));
394 };
395
396 let mut validation = Validation::default();
398 validation.validate_exp = true;
399 validation.validate_aud = true;
400 validation.validate_nbf = true;
401
402 if let Some(iss) = self.config.validate_issuer() {
403 validation.set_issuer(iss);
404 validation.required_spec_claims.insert("iss".to_string());
405 }
406 if let Some(aud) = self.config.validate_audience() {
407 validation.set_audience(aud);
408 validation.required_spec_claims.insert("aud".to_string());
409 }
410 if let Some(sub) = self.config.validate_subject() {
411 validation.sub = Some(sub.to_string());
412 validation.required_spec_claims.insert("sub".to_string());
413 }
414
415 validation.algorithms = vec![Algorithm::RS256];
418 let decoded_claims = decode::<Value>(auth_token.token(), decoding_key, &validation)
419 .or_else(|_| {
420 validation.algorithms = vec![Algorithm::ES256];
421 decode::<Value>(auth_token.token(), decoding_key, &validation)
422 })
423 .or_else(|_| {
424 validation.algorithms = vec![Algorithm::EdDSA];
425 decode::<Value>(auth_token.token(), decoding_key, &validation)
426 });
427
428 let claims = match decoded_claims {
429 Ok(claims) => claims,
430 Err(err) => return Err(HtsGetError::PermissionDenied(format!("invalid JWT: {err}"))),
431 };
432
433 Ok(claims)
434 }
435
436 pub async fn validate_authorization(
444 &mut self,
445 headers: &HeaderMap,
446 path: &str,
447 queries: &mut [Query],
448 request_extensions: Option<Value>,
449 endpoint: &Endpoint,
450 ) -> HtsGetResult<Option<AuthorizationRestrictions>> {
451 let restrictions = self
452 .query_authorization_service(headers, request_extensions, endpoint, path)
453 .await?;
454
455 debug!(restrictions = ?restrictions, "restrictions");
456
457 if let Some(restrictions) = restrictions {
458 cfg_if! {
459 if #[cfg(feature = "experimental")] {
460 Self::validate_restrictions(restrictions, path, queries, self.config.suppress_errors()).map(Some)
461 } else {
462 Self::validate_restrictions(restrictions, path, queries, false).map(Some)
463 }
464 }
465 } else {
466 Ok(None)
467 }
468 }
469}
470
471#[cfg(test)]
472mod tests {
473 use super::*;
474 use crate::{Endpoint, convert_to_query, match_format_from_query};
475 use htsget_config::config::advanced::HttpClient;
476 use htsget_config::config::advanced::auth::AuthConfigBuilder;
477 use htsget_config::config::advanced::auth::authorization::ForwardExtensions;
478 use htsget_config::config::advanced::auth::response::{
479 AuthorizationRestrictionsBuilder, AuthorizationRuleBuilder, ReferenceNameRestrictionBuilder,
480 };
481 use htsget_config::config::advanced::regex_location::RegexLocation;
482 use htsget_config::config::location::SimpleLocation;
483 use htsget_config::types::{Format, Request};
484 use htsget_test::util::generate_key_pair;
485 use http::{HeaderMap, Uri};
486 use regex::Regex;
487 use reqwest_middleware::ClientBuilder;
488 use serde_json::json;
489 use std::collections::HashMap;
490
491 #[test]
492 fn auth_builder_missing_config() {
493 let result = AuthBuilder::default().build();
494 assert!(matches!(result, Err(AuthBuilderError(_))));
495 }
496
497 #[test]
498 fn auth_builder_success_with_public_key() {
499 let (_, public_key) = generate_key_pair();
500
501 let config = create_test_auth_config(public_key);
502 let result = AuthBuilder::default().with_config(config).build();
503 assert!(result.is_ok());
504 }
505
506 #[test]
507 fn validate_restrictions_rule_allows_all() {
508 let rule = AuthorizationRuleBuilder::default()
509 .location(test_location())
510 .build()
511 .unwrap();
512 let restrictions = AuthorizationRestrictionsBuilder::default()
513 .rule(rule)
514 .build()
515 .unwrap();
516
517 let request = create_test_query(Endpoint::Reads, "sample1", HashMap::new());
518 let result =
519 Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
520 assert!(result.is_ok());
521 }
522
523 #[test]
524 fn validate_restrictions_exact_path_match() {
525 let reference_restriction = ReferenceNameRestrictionBuilder::default()
526 .name("chr1")
527 .format(Format::Bam)
528 .start(1000)
529 .end(2000)
530 .build()
531 .unwrap();
532 let rule = AuthorizationRuleBuilder::default()
533 .location(test_location())
534 .reference_name(reference_restriction)
535 .build()
536 .unwrap();
537 let restrictions = AuthorizationRestrictionsBuilder::default()
538 .rule(rule)
539 .build()
540 .unwrap();
541
542 let mut query = HashMap::new();
543 query.insert("referenceName".to_string(), "chr1".to_string());
544 query.insert("start".to_string(), "1500".to_string());
545 query.insert("end".to_string(), "1800".to_string());
546 query.insert("format".to_string(), "BAM".to_string());
547
548 let request = create_test_query(Endpoint::Reads, "sample1", query);
549 let result =
550 Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
551 assert!(result.is_ok());
552 }
553
554 #[test]
555 fn validate_restrictions_regex_prefix_match() {
556 let reference_restriction = ReferenceNameRestrictionBuilder::default()
557 .name("chr1")
558 .format(Format::Bam)
559 .build()
560 .unwrap();
561 let rule = AuthorizationRuleBuilder::default()
562 .location(Location::Simple(Box::new(SimpleLocation::new(
563 Default::default(),
564 "".to_string(),
565 Some(PrefixOrId::Prefix("sam".to_string())),
566 ))))
567 .reference_name(reference_restriction)
568 .build()
569 .unwrap();
570 let restrictions = AuthorizationRestrictionsBuilder::default()
571 .rule(rule)
572 .build()
573 .unwrap();
574
575 let mut query = HashMap::new();
576 query.insert("referenceName".to_string(), "chr1".to_string());
577 query.insert("format".to_string(), "BAM".to_string());
578
579 let request = create_test_query(Endpoint::Reads, "sample123", query);
580 let result =
581 Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
582 assert!(result.is_ok());
583 }
584
585 #[test]
586 fn validate_restrictions_regex_match() {
587 let reference_restriction = ReferenceNameRestrictionBuilder::default()
588 .name("chr1")
589 .format(Format::Bam)
590 .build()
591 .unwrap();
592 let rule = AuthorizationRuleBuilder::default()
593 .location(Location::Regex(Box::new(RegexLocation::new(
594 Regex::new("sample(.+)").unwrap(),
595 "".to_string(),
596 Default::default(),
597 Default::default(),
598 ))))
599 .reference_name(reference_restriction)
600 .build()
601 .unwrap();
602 let restrictions = AuthorizationRestrictionsBuilder::default()
603 .rule(rule)
604 .build()
605 .unwrap();
606
607 let mut query = HashMap::new();
608 query.insert("referenceName".to_string(), "chr1".to_string());
609 query.insert("format".to_string(), "BAM".to_string());
610
611 let request = create_test_query(Endpoint::Reads, "sample123", query);
612 let result =
613 Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
614 assert!(result.is_ok());
615 }
616
617 #[test]
618 fn validate_restrictions_forward_headers() {
619 let (_, public_key) = generate_key_pair();
620
621 let builder = AuthConfigBuilder::default()
622 .auth_mode(AuthMode::PublicKey(public_key))
623 .authorization_url(UrlOrStatic::Url(Uri::from_static(
624 "https://www.example.com",
625 )))
626 .http_client(HttpClient::new(
627 ClientBuilder::new(reqwest::Client::new()).build(),
628 ));
629 let config = builder
630 .clone()
631 .passthrough_auth(true)
632 .forward_headers(vec!["Custom1".to_string()])
633 .build()
634 .unwrap();
635 let result = AuthBuilder::default().with_config(config).build().unwrap();
636
637 let request_headers = HeaderMap::from_iter([
638 (
639 "Authorization".parse().unwrap(),
640 "Bearer Value".parse().unwrap(),
641 ),
642 ("Custom1".parse().unwrap(), "Value".parse().unwrap()),
643 ("Custom2".parse().unwrap(), "Value".parse().unwrap()),
644 ]);
645 let forwarded_headers = result
646 .forwarded_headers(&request_headers, None, &Endpoint::Reads, "id")
647 .unwrap();
648 assert_eq!(
649 forwarded_headers,
650 HeaderMap::from_iter([
651 (
652 format!("{}Custom1", CONTEXT_HEADER_PREFIX).parse().unwrap(),
653 "Value".parse().unwrap()
654 ),
655 (
656 "Authorization".parse().unwrap(),
657 "Bearer Value".parse().unwrap()
658 ),
659 ])
660 );
661
662 let config = builder
663 .clone()
664 .passthrough_auth(true)
665 .forward_headers(vec!["Custom1".to_string(), "Authorization".to_string()])
666 .build()
667 .unwrap();
668 let result = AuthBuilder::default().with_config(config).build().unwrap();
669
670 let forwarded_headers = result
671 .forwarded_headers(&request_headers, None, &Endpoint::Reads, "id")
672 .unwrap();
673 assert_eq!(
674 forwarded_headers,
675 HeaderMap::from_iter([
676 (
677 format!("{}Custom1", CONTEXT_HEADER_PREFIX).parse().unwrap(),
678 "Value".parse().unwrap()
679 ),
680 (
681 format!("{}Authorization", CONTEXT_HEADER_PREFIX)
682 .parse()
683 .unwrap(),
684 "Bearer Value".parse().unwrap()
685 ),
686 (
687 "Authorization".parse().unwrap(),
688 "Bearer Value".parse().unwrap()
689 ),
690 ])
691 );
692
693 let config = builder
694 .clone()
695 .forward_headers(vec!["Custom1".to_string()])
696 .build()
697 .unwrap();
698 let result = AuthBuilder::default().with_config(config).build().unwrap();
699
700 let forwarded_headers = result
701 .forwarded_headers(&request_headers, None, &Endpoint::Reads, "id")
702 .unwrap();
703 assert_eq!(
704 forwarded_headers,
705 HeaderMap::from_iter([(
706 format!("{}Custom1", CONTEXT_HEADER_PREFIX).parse().unwrap(),
707 "Value".parse().unwrap()
708 ),])
709 );
710
711 let config = builder
712 .clone()
713 .forward_extensions(vec![ForwardExtensions::new(
714 "$.Key".to_string(),
715 "Custom1".to_string(),
716 )])
717 .build()
718 .unwrap();
719 let result = AuthBuilder::default().with_config(config).build().unwrap();
720
721 let forwarded_headers = result
722 .forwarded_headers(
723 &request_headers,
724 Some(json!({
725 "Key": "Value"
726 })),
727 &Endpoint::Reads,
728 "id",
729 )
730 .unwrap();
731 assert_eq!(
732 forwarded_headers,
733 HeaderMap::from_iter([(
734 format!("{}Custom1", CONTEXT_HEADER_PREFIX).parse().unwrap(),
735 "Value".parse().unwrap()
736 ),])
737 );
738
739 let config = builder.clone().forward_endpoint_type(true).build().unwrap();
740 let result = AuthBuilder::default().with_config(config).build().unwrap();
741
742 let forwarded_headers = result
743 .forwarded_headers(&request_headers, None, &Endpoint::Variants, "id")
744 .unwrap();
745 assert_eq!(
746 forwarded_headers,
747 HeaderMap::from_iter([(
748 format!("{}{}", CONTEXT_HEADER_PREFIX, ENDPOINT_TYPE_HEADER_NAME)
749 .parse()
750 .unwrap(),
751 "variants".parse().unwrap()
752 ),])
753 );
754
755 let config = builder.forward_id(true).build().unwrap();
756 let result = AuthBuilder::default().with_config(config).build().unwrap();
757
758 let forwarded_headers = result
759 .forwarded_headers(&request_headers, None, &Endpoint::Variants, "id")
760 .unwrap();
761 assert_eq!(
762 forwarded_headers,
763 HeaderMap::from_iter([(
764 format!("{}{}", CONTEXT_HEADER_PREFIX, ID_HEADER_NAME)
765 .parse()
766 .unwrap(),
767 "id".parse().unwrap()
768 ),])
769 );
770 }
771
772 #[test]
773 fn validate_restrictions_reference_name_mismatch() {
774 let reference_restriction = ReferenceNameRestrictionBuilder::default()
775 .name("chr1")
776 .format(Format::Bam)
777 .build()
778 .unwrap();
779 let rule = AuthorizationRuleBuilder::default()
780 .location(test_location())
781 .reference_name(reference_restriction)
782 .build()
783 .unwrap();
784 let restrictions = AuthorizationRestrictionsBuilder::default()
785 .rule(rule.clone())
786 .build()
787 .unwrap();
788
789 let mut query = HashMap::new();
790 query.insert("class".to_string(), "header".to_string());
791 query.insert("format".to_string(), "BAM".to_string());
792
793 let request = create_test_query(Endpoint::Reads, "sample1", query);
794 let result =
795 Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
796 assert!(result.is_ok());
797 }
798
799 #[test]
800 fn validate_restrictions_header() {
801 let reference_restriction = ReferenceNameRestrictionBuilder::default()
802 .name("chr1")
803 .format(Format::Bam)
804 .build()
805 .unwrap();
806 let rule = AuthorizationRuleBuilder::default()
807 .location(test_location())
808 .reference_name(reference_restriction)
809 .build()
810 .unwrap();
811 let restrictions = AuthorizationRestrictionsBuilder::default()
812 .rule(rule.clone())
813 .build()
814 .unwrap();
815
816 let mut query = HashMap::new();
817 query.insert("format".to_string(), "BAM".to_string());
818 query.insert("class".to_string(), "header".to_string());
819
820 let request = create_test_query(Endpoint::Reads, "sample1", query);
821 let result =
822 Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
823 assert!(result.is_ok());
824 }
825
826 #[cfg(feature = "experimental")]
827 #[test]
828 fn validate_restrictions_reference_name_mismatch_suppressed() {
829 let reference_restriction = ReferenceNameRestrictionBuilder::default()
830 .name("chr1")
831 .format(Format::Bam)
832 .build()
833 .unwrap();
834 let rule = AuthorizationRuleBuilder::default()
835 .location(test_location())
836 .reference_name(reference_restriction)
837 .build()
838 .unwrap();
839 let restrictions = AuthorizationRestrictionsBuilder::default()
840 .rule(rule.clone())
841 .build()
842 .unwrap();
843
844 let mut query = HashMap::new();
845 query.insert("referenceName".to_string(), "chr2".to_string());
846 query.insert("format".to_string(), "BAM".to_string());
847
848 let request = create_test_query(Endpoint::Reads, "sample1", query);
849 let result =
850 Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], true);
851 assert!(result.is_ok());
852 }
853
854 #[test]
855 fn validate_restrictions_format_mismatch() {
856 let reference_restriction = ReferenceNameRestrictionBuilder::default()
857 .name("chr1")
858 .format(Format::Bam)
859 .build()
860 .unwrap();
861 let rule = AuthorizationRuleBuilder::default()
862 .location(test_location())
863 .reference_name(reference_restriction)
864 .build()
865 .unwrap();
866 let restrictions = AuthorizationRestrictionsBuilder::default()
867 .rule(rule.clone())
868 .build()
869 .unwrap();
870
871 let mut query = HashMap::new();
872 query.insert("referenceName".to_string(), "chr1".to_string());
873 query.insert("format".to_string(), "CRAM".to_string());
874
875 let request = create_test_query(Endpoint::Reads, "sample1", query);
876 let result =
877 Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
878 assert!(result.is_err());
879 }
880
881 #[cfg(feature = "experimental")]
882 #[test]
883 fn validate_restrictions_format_mismatch_suppressed() {
884 let reference_restriction = ReferenceNameRestrictionBuilder::default()
885 .name("chr1")
886 .format(Format::Bam)
887 .build()
888 .unwrap();
889 let rule = AuthorizationRuleBuilder::default()
890 .location(test_location())
891 .reference_name(reference_restriction)
892 .build()
893 .unwrap();
894 let restrictions = AuthorizationRestrictionsBuilder::default()
895 .rule(rule.clone())
896 .build()
897 .unwrap();
898
899 let mut query = HashMap::new();
900 query.insert("referenceName".to_string(), "chr1".to_string());
901 query.insert("format".to_string(), "CRAM".to_string());
902
903 let request = create_test_query(Endpoint::Reads, "sample1", query);
904 let result =
905 Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], true);
906 assert!(result.is_ok());
907 }
908
909 #[test]
910 fn validate_restrictions_interval_not_contained() {
911 test_interval_suppressed(
915 Some(1000),
916 Some(2000),
917 Some(1250),
918 Some(1750),
919 (Interval::new(Some(1250), Some(1750)), Class::Body),
920 false,
921 false,
922 );
923
924 test_interval_suppressed(
928 Some(1000),
929 Some(2000),
930 Some(500),
931 None,
932 (Interval::new(Some(500), None), Class::Body),
933 true,
934 false,
935 );
936
937 test_interval_suppressed(
941 Some(1000),
942 Some(2000),
943 None,
944 Some(2500),
945 (Interval::new(None, Some(2500)), Class::Body),
946 true,
947 false,
948 );
949
950 test_interval_suppressed(
954 Some(1000),
955 Some(2000),
956 None,
957 None,
958 (Interval::new(None, None), Class::Body),
959 true,
960 false,
961 );
962
963 test_interval_suppressed(
967 Some(1000),
968 Some(2000),
969 Some(500),
970 Some(1500),
971 (Interval::new(Some(500), Some(1500)), Class::Body),
972 true,
973 false,
974 );
975
976 test_interval_suppressed(
980 Some(1000),
981 Some(2000),
982 None,
983 Some(1500),
984 (Interval::new(None, Some(1500)), Class::Body),
985 true,
986 false,
987 );
988
989 test_interval_suppressed(
993 Some(1000),
994 Some(2000),
995 Some(1500),
996 Some(2500),
997 (Interval::new(Some(1500), Some(2500)), Class::Body),
998 true,
999 false,
1000 );
1001
1002 test_interval_suppressed(
1006 Some(1000),
1007 Some(2000),
1008 Some(1500),
1009 None,
1010 (Interval::new(Some(1500), None), Class::Body),
1011 true,
1012 false,
1013 );
1014
1015 test_interval_suppressed(
1019 Some(1000),
1020 Some(2000),
1021 Some(500),
1022 Some(1000),
1023 (Interval::new(Some(500), Some(1000)), Class::Body),
1024 true,
1025 false,
1026 );
1027
1028 test_interval_suppressed(
1032 Some(1000),
1033 Some(2000),
1034 Some(2000),
1035 Some(2500),
1036 (Interval::new(Some(2000), Some(2500)), Class::Body),
1037 true,
1038 false,
1039 );
1040
1041 test_interval_suppressed(
1045 None,
1046 Some(2000),
1047 Some(500),
1048 Some(1500),
1049 (Interval::new(Some(500), Some(1500)), Class::Body),
1050 false,
1051 false,
1052 );
1053
1054 test_interval_suppressed(
1058 None,
1059 Some(2000),
1060 Some(1500),
1061 Some(2500),
1062 (Interval::new(Some(1500), Some(2500)), Class::Body),
1063 true,
1064 false,
1065 );
1066
1067 test_interval_suppressed(
1071 Some(1000),
1072 None,
1073 Some(1500),
1074 Some(2500),
1075 (Interval::new(Some(1500), Some(2500)), Class::Body),
1076 false,
1077 false,
1078 );
1079
1080 test_interval_suppressed(
1084 Some(1000),
1085 None,
1086 Some(500),
1087 Some(1500),
1088 (Interval::new(Some(500), Some(1500)), Class::Body),
1089 true,
1090 false,
1091 );
1092
1093 test_interval_suppressed(
1097 None,
1098 None,
1099 Some(500),
1100 Some(2500),
1101 (Interval::new(Some(500), Some(2500)), Class::Body),
1102 false,
1103 false,
1104 );
1105
1106 test_interval_suppressed(
1110 None,
1111 None,
1112 Some(500),
1113 None,
1114 (Interval::new(Some(500), None), Class::Body),
1115 false,
1116 false,
1117 );
1118
1119 test_interval_suppressed(
1123 None,
1124 None,
1125 None,
1126 Some(2500),
1127 (Interval::new(None, Some(2500)), Class::Body),
1128 false,
1129 false,
1130 );
1131 }
1132
1133 #[cfg(feature = "experimental")]
1134 #[test]
1135 fn validate_restrictions_interval_suppressed() {
1136 test_interval_suppressed(
1140 Some(1000),
1141 Some(2000),
1142 Some(1250),
1143 Some(1750),
1144 (Interval::new(Some(1250), Some(1750)), Class::Body),
1145 false,
1146 true,
1147 );
1148
1149 test_interval_suppressed(
1153 Some(1000),
1154 Some(2000),
1155 Some(500),
1156 None,
1157 (Interval::new(Some(1000), Some(2000)), Class::Body),
1158 false,
1159 true,
1160 );
1161
1162 test_interval_suppressed(
1166 Some(1000),
1167 Some(2000),
1168 None,
1169 Some(2500),
1170 (Interval::new(Some(1000), Some(2000)), Class::Body),
1171 false,
1172 true,
1173 );
1174
1175 test_interval_suppressed(
1179 Some(1000),
1180 Some(2000),
1181 None,
1182 None,
1183 (Interval::new(Some(1000), Some(2000)), Class::Body),
1184 false,
1185 true,
1186 );
1187
1188 test_interval_suppressed(
1192 Some(1000),
1193 Some(2000),
1194 Some(500),
1195 Some(1500),
1196 (Interval::new(Some(1000), Some(1500)), Class::Body),
1197 false,
1198 true,
1199 );
1200
1201 test_interval_suppressed(
1205 Some(1000),
1206 Some(2000),
1207 None,
1208 Some(1500),
1209 (Interval::new(Some(1000), Some(1500)), Class::Body),
1210 false,
1211 true,
1212 );
1213
1214 test_interval_suppressed(
1218 Some(1000),
1219 Some(2000),
1220 Some(1500),
1221 Some(2500),
1222 (Interval::new(Some(1500), Some(2000)), Class::Body),
1223 false,
1224 true,
1225 );
1226
1227 test_interval_suppressed(
1231 Some(1000),
1232 Some(2000),
1233 Some(1500),
1234 None,
1235 (Interval::new(Some(1500), Some(2000)), Class::Body),
1236 false,
1237 true,
1238 );
1239
1240 test_interval_suppressed(
1244 Some(1000),
1245 Some(2000),
1246 Some(500),
1247 Some(1000),
1248 (Interval::new(Some(500), Some(1000)), Class::Header),
1249 false,
1250 true,
1251 );
1252
1253 test_interval_suppressed(
1257 Some(1000),
1258 Some(2000),
1259 Some(2000),
1260 Some(2500),
1261 (Interval::new(Some(2000), Some(2500)), Class::Header),
1262 false,
1263 true,
1264 );
1265
1266 test_interval_suppressed(
1270 None,
1271 Some(2000),
1272 Some(500),
1273 Some(1500),
1274 (Interval::new(Some(500), Some(1500)), Class::Body),
1275 false,
1276 true,
1277 );
1278
1279 test_interval_suppressed(
1283 None,
1284 Some(2000),
1285 Some(1500),
1286 Some(2500),
1287 (Interval::new(Some(1500), Some(2000)), Class::Body),
1288 false,
1289 true,
1290 );
1291
1292 test_interval_suppressed(
1296 Some(1000),
1297 None,
1298 Some(1500),
1299 Some(2500),
1300 (Interval::new(Some(1500), Some(2500)), Class::Body),
1301 false,
1302 true,
1303 );
1304
1305 test_interval_suppressed(
1309 Some(1000),
1310 None,
1311 Some(500),
1312 Some(1500),
1313 (Interval::new(Some(1000), Some(1500)), Class::Body),
1314 false,
1315 true,
1316 );
1317
1318 test_interval_suppressed(
1322 None,
1323 None,
1324 Some(500),
1325 Some(2500),
1326 (Interval::new(Some(500), Some(2500)), Class::Body),
1327 false,
1328 true,
1329 );
1330
1331 test_interval_suppressed(
1335 None,
1336 None,
1337 Some(500),
1338 None,
1339 (Interval::new(Some(500), None), Class::Body),
1340 false,
1341 true,
1342 );
1343
1344 test_interval_suppressed(
1348 None,
1349 None,
1350 None,
1351 Some(2500),
1352 (Interval::new(None, Some(2500)), Class::Body),
1353 false,
1354 true,
1355 );
1356 }
1357
#[test]
fn validate_restrictions_format_none_allows_any() {
  // The reference-name restriction below never sets a format, so a request that explicitly
  // asks for CRAM must still pass validation.
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(
      AuthorizationRuleBuilder::default()
        .location(test_location())
        .reference_name(
          ReferenceNameRestrictionBuilder::default()
            .name("chr1")
            .build()
            .unwrap(),
        )
        .build()
        .unwrap(),
    )
    .build()
    .unwrap();

  let query = HashMap::from([
    ("referenceName".to_string(), "chr1".to_string()),
    ("format".to_string(), "CRAM".to_string()),
  ]);

  let request = create_test_query(Endpoint::Reads, "sample1", query);
  // Copy the id out first so the request itself can be moved into the slice (no clone needed).
  let id = request.id().to_string();
  let mut requests = [request];

  let outcome = Auth::validate_restrictions(restrictions, &id, &mut requests, false);
  assert!(outcome.is_ok());
}
1383
#[test]
fn validate_restrictions_path_with_leading_slash() {
  // A rule that only names a location (id "sample1", no reference-name restriction) should
  // accept a plain request for that id.
  // NOTE(review): despite the test name, the request path used here ("sample1") has no leading
  // slash, so no leading-slash handling is exercised. Confirm whether a "/sample1" case was
  // intended.
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(
      AuthorizationRuleBuilder::default()
        .location(test_location())
        .build()
        .unwrap(),
    )
    .build()
    .unwrap();

  let request = create_test_query(Endpoint::Reads, "sample1", HashMap::new());
  let id = request.id().to_owned();
  let mut requests = [request];

  assert!(Auth::validate_restrictions(restrictions, &id, &mut requests, false).is_ok());
}
1399
1400 #[tokio::test]
1401 async fn validate_authorization_missing_auth_header() {
1402 let mut auth = create_mock_auth_with_restrictions();
1403 let request = Request::new("sample1".to_string(), HashMap::new(), HeaderMap::new());
1404
1405 let result = auth.validate_jwt(request.headers()).await;
1406 assert!(result.is_err());
1407 assert!(matches!(
1408 result.unwrap_err(),
1409 HtsGetError::InvalidAuthentication(_)
1410 ));
1411 }
1412
1413 #[tokio::test]
1414 async fn validate_authorization_invalid_jwt_format() {
1415 let mut auth = create_mock_auth_with_restrictions();
1416 let request = create_request_with_auth_header("sample1", HashMap::new(), "invalid.jwt.token");
1417
1418 let result = auth.validate_jwt(request.headers()).await;
1419 assert!(result.is_err());
1420 assert!(matches!(
1421 result.unwrap_err(),
1422 HtsGetError::PermissionDenied(_)
1423 ));
1424 }
1425
1426 fn create_test_auth_config(public_key: Vec<u8>) -> AuthConfig {
1427 AuthConfigBuilder::default()
1428 .auth_mode(AuthMode::PublicKey(public_key))
1429 .authorization_url(UrlOrStatic::Url(Uri::from_static(
1430 "https://www.example.com",
1431 )))
1432 .http_client(HttpClient::new(
1433 ClientBuilder::new(reqwest::Client::new()).build(),
1434 ))
1435 .build()
1436 .unwrap()
1437 }
1438
1439 fn create_test_query(endpoint: Endpoint, path: &str, query: HashMap<String, String>) -> Query {
1440 let request = Request::new(path.to_string(), query, HeaderMap::new());
1441 let format = match_format_from_query(&endpoint, request.query()).unwrap();
1442
1443 convert_to_query(request, format).unwrap()
1444 }
1445
1446 fn create_request_with_auth_header(
1447 path: &str,
1448 query: HashMap<String, String>,
1449 token: &str,
1450 ) -> Request {
1451 let mut headers = HeaderMap::new();
1452 headers.insert("authorization", format!("Bearer {token}").parse().unwrap());
1453 Request::new(path.to_string(), query, headers)
1454 }
1455
1456 fn create_mock_auth_with_restrictions() -> Auth {
1457 let (_, public_key) = generate_key_pair();
1458
1459 let config = create_test_auth_config(public_key);
1460 AuthBuilder::default().with_config(config).build().unwrap()
1461 }
1462
1463 fn test_interval_suppressed(
1464 restrict_start: Option<u32>,
1465 restrict_end: Option<u32>,
1466 request_start: Option<u32>,
1467 request_end: Option<u32>,
1468 expected_response: (Interval, Class),
1469 is_err: bool,
1470 suppress_interval: bool,
1471 ) {
1472 let mut reference_restriction = ReferenceNameRestrictionBuilder::default()
1473 .name("chr1")
1474 .format(Format::Bam);
1475
1476 if let Some(start) = restrict_start {
1477 reference_restriction = reference_restriction.start(start);
1478 }
1479 if let Some(end) = restrict_end {
1480 reference_restriction = reference_restriction.end(end);
1481 }
1482
1483 let reference_restriction = reference_restriction.build().unwrap();
1484 let rule = AuthorizationRuleBuilder::default()
1485 .location(test_location())
1486 .reference_name(reference_restriction)
1487 .build()
1488 .unwrap();
1489 let restrictions = AuthorizationRestrictionsBuilder::default()
1490 .rule(rule.clone())
1491 .build()
1492 .unwrap();
1493
1494 let mut query = HashMap::new();
1495 query.insert("referenceName".to_string(), "chr1".to_string());
1496 request_start.map(|start| query.insert("start".to_string(), start.to_string()));
1497 request_end.map(|end| query.insert("end".to_string(), end.to_string()));
1498
1499 let request = create_test_query(Endpoint::Reads, "sample1", query);
1500 let id = request.id().to_string();
1501 let mut slice = [request];
1502 let result = Auth::validate_restrictions(restrictions, &id, &mut slice, suppress_interval);
1503 if is_err {
1504 assert!(result.is_err());
1505 } else {
1506 assert!(result.is_ok());
1507 }
1508 assert_eq!(slice.first().unwrap().interval(), expected_response.0);
1509 assert_eq!(slice.last().unwrap().class(), expected_response.1);
1510 }
1511
1512 fn test_location() -> Location {
1513 Location::Simple(Box::new(SimpleLocation::new(
1514 Default::default(),
1515 "".to_string(),
1516 Some(PrefixOrId::Id("sample1".to_string())),
1517 )))
1518 }
1519}