htsget_http/middleware/
auth.rs

1//! The htsget authorization middleware.
2//!
3
4use crate::HtsGetError;
5use crate::error::Result as HtsGetResult;
6use crate::middleware::error::Error::AuthBuilderError;
7use crate::middleware::error::Result;
8use cfg_if::cfg_if;
9use headers::authorization::Bearer;
10use headers::{Authorization, Header};
11use htsget_config::config::advanced::auth::authorization::UrlOrStatic;
12use htsget_config::config::advanced::auth::jwt::AuthMode;
13use htsget_config::config::advanced::auth::response::AuthorizationRestrictionsBuilder;
14use htsget_config::config::advanced::auth::{AuthConfig, AuthorizationRestrictions};
15use htsget_config::config::location::{Location, PrefixOrId};
16use htsget_config::types::{Class, Interval, Query};
17use http::{HeaderMap, HeaderName, HeaderValue, Uri};
18use jsonpath_rust::JsonPath;
19use jsonwebtoken::jwk::JwkSet;
20use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, decode, decode_header};
21use serde::de::DeserializeOwned;
22use serde_json::Value;
23use std::fmt::{Debug, Formatter};
24use std::str::FromStr;
25use tracing::trace;
26
27/// The authorization middleware builder.
28#[derive(Default, Debug)]
29pub struct AuthBuilder {
30  config: Option<AuthConfig>,
31}
32
33impl AuthBuilder {
34  /// Set the config.
35  pub fn with_config(mut self, config: AuthConfig) -> Self {
36    self.config = Some(config);
37    self
38  }
39
40  /// Build the auth layer, ensures that the config sets the correct parameters.
41  pub fn build(self) -> Result<Auth> {
42    let Some(mut config) = self.config else {
43      return Err(AuthBuilderError("missing config".to_string()));
44    };
45
46    let mut decoding_key = None;
47    if let Some(AuthMode::PublicKey(public_key)) = config.auth_mode_mut() {
48      decoding_key = Some(
49        Auth::decode_public_key(public_key)
50          .map_err(|_| AuthBuilderError("failed to decode public key".to_string()))?,
51      );
52    }
53
54    Ok(Auth {
55      config,
56      decoding_key,
57    })
58  }
59}
60
61/// The auth middleware layer.
62#[derive(Clone)]
63pub struct Auth {
64  config: AuthConfig,
65  decoding_key: Option<DecodingKey>,
66}
67
68impl Debug for Auth {
69  fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
70    f.debug_struct("config").finish()
71  }
72}
73
74const FORWARD_HEADER_PREFIX: &str = "Htsget-Context-";
75
76impl Auth {
77  /// Get the config for this auth layer instance.
78  pub fn config(&self) -> &AuthConfig {
79    &self.config
80  }
81
82  /// Fetch JWKS from the authorization server.
83  pub async fn fetch_from_url<D: DeserializeOwned>(
84    &self,
85    url: &str,
86    headers: HeaderMap,
87  ) -> HtsGetResult<D> {
88    trace!("fetching url: {}", url);
89    let err = || HtsGetError::InternalError(format!("failed to fetch data from {url}"));
90    let response = self
91      .config
92      .http_client()
93      .get(url)
94      .headers(headers)
95      .send()
96      .await
97      .map_err(|_| err())?;
98
99    response.json().await.map_err(|_| err())
100  }
101
102  /// Get a decoding key form the JWKS url.
103  pub async fn decode_jwks(&self, jwks_url: &Uri, token: &str) -> HtsGetResult<DecodingKey> {
104    // Decode header and get the key id.
105    let header = decode_header(token)?;
106    let kid = header
107      .kid
108      .ok_or_else(|| HtsGetError::PermissionDenied("JWT missing key ID".to_string()))?;
109
110    // Fetch JWKS from the authorization server and find matching JWK.
111    let jwks = self
112      .fetch_from_url::<JwkSet>(&jwks_url.to_string(), Default::default())
113      .await?;
114    let matched_jwk = jwks
115      .find(&kid)
116      .ok_or_else(|| HtsGetError::PermissionDenied("matching JWK not found".to_string()))?;
117
118    Ok(DecodingKey::from_jwk(matched_jwk)?)
119  }
120
121  /// Decode a public key into an RSA, EdDSA or ECDSA pem-formatted decoding key.
122  pub fn decode_public_key(key: &[u8]) -> HtsGetResult<DecodingKey> {
123    Ok(
124      DecodingKey::from_rsa_pem(key)
125        .or_else(|_| DecodingKey::from_ed_pem(key))
126        .or_else(|_| DecodingKey::from_ec_pem(key))?,
127    )
128  }
129
130  /// Get the headers to send to the authorization service.
131  pub fn forwarded_headers(
132    &self,
133    request_headers: &HeaderMap,
134    request_extensions: Option<Value>,
135  ) -> HtsGetResult<HeaderMap> {
136    let mut forwarded_headers = if self.config.passthrough_auth() {
137      let auth_header = request_headers
138        .iter()
139        .find_map(|(name, value)| {
140          if Authorization::<Bearer>::decode(&mut [value].into_iter()).is_ok() {
141            return Some((name.clone(), value.clone()));
142          }
143
144          None
145        })
146        .ok_or_else(|| HtsGetError::PermissionDenied("missing authorization header".to_string()))?;
147      HeaderMap::from_iter([auth_header])
148    } else {
149      HeaderMap::default()
150    };
151
152    for header in self.config.forward_headers() {
153      let Some((existing_name, existing_value)) = request_headers
154        .iter()
155        .find_map(|(name, value)| {
156          if header.to_lowercase() == name.as_str().to_lowercase() {
157            return match HeaderName::from_str(&format!("{}{}", FORWARD_HEADER_PREFIX, name)) {
158              Ok(header) => Some(Ok((header, value))),
159              Err(err) => Some(Err(HtsGetError::InternalError(err.to_string()))),
160            };
161          }
162
163          None
164        })
165        .transpose()?
166      else {
167        continue;
168      };
169
170      forwarded_headers.insert(existing_name, existing_value.clone());
171    }
172
173    if let Some(request_extensions) = request_extensions {
174      for extension in self.config.forward_extensions() {
175        let Some(value) = request_extensions.query(extension.json_path()).ok() else {
176          continue;
177        };
178
179        let value = value.first().ok_or_else(|| {
180          HtsGetError::InternalError("extension does not have only one value".to_string())
181        })?;
182        let value = value.as_str().ok_or_else(|| {
183          HtsGetError::InternalError("extension value is not a string".to_string())
184        })?;
185
186        let header_name =
187          HeaderName::from_str(&format!("{}{}", FORWARD_HEADER_PREFIX, extension.name()))
188            .map_err(|err| HtsGetError::InternalError(err.to_string()))?;
189        let value = HeaderValue::from_str(value)
190          .map_err(|err| HtsGetError::InternalError(err.to_string()))?;
191        forwarded_headers.insert(header_name, value);
192      }
193    }
194
195    Ok(forwarded_headers)
196  }
197
198  /// Query the authorization service to get the restrictions. This function validates
199  /// that the authorization url is trusted in the config settings before calling the
200  /// service. The claims are assumed to be valid.
201  pub async fn query_authorization_service(
202    &self,
203    headers: &HeaderMap,
204    request_extensions: Option<Value>,
205  ) -> HtsGetResult<Option<AuthorizationRestrictions>> {
206    match self.config.authorization_url() {
207      Some(UrlOrStatic::Url(uri)) => {
208        let forwarded_headers = self.forwarded_headers(headers, request_extensions)?;
209
210        self
211          .fetch_from_url(&uri.to_string(), forwarded_headers)
212          .await
213          .map(Some)
214      }
215      Some(UrlOrStatic::Static(config)) => Ok(Some(config.clone())),
216      _ => Ok(None),
217    }
218  }
219
220  /// Validate the restrictions, returning an error if the user is not authorized.
221  /// If `suppressed_interval` is set then no error is returning if there is a
222  /// path match but no restrictions match. Instead, as many regions as possible
223  /// are returned.
224  pub fn validate_restrictions(
225    restrictions: AuthorizationRestrictions,
226    path: &str,
227    queries: &mut [Query],
228    suppressed_interval: bool,
229  ) -> HtsGetResult<AuthorizationRestrictions> {
230    // Find all rules matching the path.
231    let matching_rules = restrictions
232      .into_rules()
233      .into_iter()
234      .filter(|rule| {
235        match rule.location() {
236          Location::Simple(location) if location.prefix_or_id().is_some() => {
237            match location.prefix_or_id().unwrap_or_default() {
238              PrefixOrId::Prefix(prefix) => {
239                // A prefix has a starts with rule.
240                path.starts_with(&prefix)
241              }
242              PrefixOrId::Id(id) => {
243                // An id location must match exactly.
244                id == path
245              }
246            }
247          }
248          Location::Regex(location) => {
249            // A regex location matches using the regex.
250            location.regex().is_match(path)
251          }
252          // Missing valid location.
253          _ => false,
254        }
255      })
256      .collect::<Vec<_>>();
257
258    // If no paths match, then this is an authorization error.
259    if matching_rules.is_empty() {
260      return Err(HtsGetError::PermissionDenied(
261        "failed to authorize user based on authorization service restrictions".to_string(),
262      ));
263    }
264
265    let (allows_all, allows_specific): (Vec<_>, Vec<_>) = matching_rules
266      .into_iter()
267      .partition(|rule| rule.rules().is_none());
268
269    // Otherwise, we need to check if the specific reference name is allowed for all queries.
270    for query in queries {
271      // If the request is for headers only, then this should always be allowed.
272      if query.class() == Class::Header {
273        continue;
274      }
275
276      let matching_restriction = allows_specific
277        .iter()
278        .flat_map(|rule| rule.rules().unwrap_or_default())
279        .filter_map(|restriction| {
280          // The reference name should match exactly if it's set, otherwise allow any reference name.
281          let name_match = restriction.reference_name().is_none()
282            || restriction.reference_name() == query.reference_name();
283          // The format should match if it's defined, otherwise it allows any format.
284          let format_match =
285            restriction.format().is_none() || restriction.format() == Some(query.format());
286          // The interval should match and be constrained, considering undefined start or end ranges.
287          let interval_match = if suppressed_interval {
288            restriction.interval().constraint_interval(query.interval())
289          } else {
290            restriction.interval().contains_interval(query.interval())
291          };
292
293          if let Some(interval_match) = interval_match {
294            if name_match && format_match {
295              return Some(interval_match);
296            }
297          }
298
299          None
300        })
301        .max_by(Interval::order_by_range); // The largest interval should be used if there are multiple matches.
302
303      if suppressed_interval {
304        if allows_all.is_empty() && matching_restriction.is_none() {
305          // If nothing allows all and there are no matching intervals then return an empty response.
306          query.set_class(Class::Header);
307          continue;
308        }
309
310        if let Some(matching_restriction) = matching_restriction {
311          query.set_interval(matching_restriction);
312        }
313      } else if allows_all.is_empty() && matching_restriction.is_none() {
314        return Err(HtsGetError::PermissionDenied(
315          "failed to authorize user based on authorization service restrictions".to_string(),
316        ));
317      }
318    }
319
320    AuthorizationRestrictionsBuilder::default()
321      .rules([allows_all, allows_specific].concat())
322      .build()
323      .map_err(|err| HtsGetError::InternalError(err.to_string()))
324  }
325
326  /// Validate only the JWT without looking up restrictions and validating those. Returns the
327  /// decoded JWT token.
328  pub async fn validate_jwt(&self, headers: &HeaderMap) -> HtsGetResult<TokenData<Value>> {
329    let auth_token = headers
330      .values()
331      .find_map(|value| Authorization::<Bearer>::decode(&mut [value].into_iter()).ok())
332      .ok_or_else(|| {
333        HtsGetError::InvalidAuthentication("invalid authorization header".to_string())
334      })?;
335
336    let decoding_key = if let Some(ref decoding_key) = self.decoding_key {
337      decoding_key
338    } else {
339      match self.config.auth_mode() {
340        Some(AuthMode::Jwks(jwks)) => &self.decode_jwks(jwks, auth_token.token()).await?,
341        Some(AuthMode::PublicKey(public_key)) => &Self::decode_public_key(public_key)?,
342        _ => {
343          return Err(HtsGetError::InternalError(
344            "JWT validation not set".to_string(),
345          ));
346        }
347      }
348    };
349
350    // Decode and validate the JWT
351    let mut validation = Validation::default();
352    validation.validate_exp = true;
353    validation.validate_aud = true;
354    validation.validate_nbf = true;
355
356    if let Some(iss) = self.config.validate_issuer() {
357      validation.set_issuer(iss);
358      validation.required_spec_claims.insert("iss".to_string());
359    }
360    if let Some(aud) = self.config.validate_audience() {
361      validation.set_audience(aud);
362      validation.required_spec_claims.insert("aud".to_string());
363    }
364    if let Some(sub) = self.config.validate_subject() {
365      validation.sub = Some(sub.to_string());
366      validation.required_spec_claims.insert("sub".to_string());
367    }
368
369    // Each supported algorithm must be tried individually because the jsonwebtoken validation
370    // logic only tries one algorithm in the vec: https://github.com/Keats/jsonwebtoken/issues/297
371    validation.algorithms = vec![Algorithm::RS256];
372    let decoded_claims = decode::<Value>(auth_token.token(), decoding_key, &validation)
373      .or_else(|_| {
374        validation.algorithms = vec![Algorithm::ES256];
375        decode::<Value>(auth_token.token(), decoding_key, &validation)
376      })
377      .or_else(|_| {
378        validation.algorithms = vec![Algorithm::EdDSA];
379        decode::<Value>(auth_token.token(), decoding_key, &validation)
380      });
381
382    let claims = match decoded_claims {
383      Ok(claims) => claims,
384      Err(err) => return Err(HtsGetError::PermissionDenied(format!("invalid JWT: {err}"))),
385    };
386
387    Ok(claims)
388  }
389
390  /// Validate the authorization flow, returning an error if the user is not authorized.
391  /// This performs the following steps:
392  ///
393  /// 1. Finds the JWT decoding key from the config or by querying a JWKS url.
394  /// 2. Validates the JWT token according to the config.
395  /// 3. Queries the authorization service for restrictions based on the config or JWT claims.
396  /// 4. Validates the restrictions to determine if the user is authorized.
397  pub async fn validate_authorization(
398    &self,
399    headers: &HeaderMap,
400    path: &str,
401    queries: &mut [Query],
402    request_extensions: Option<Value>,
403  ) -> HtsGetResult<Option<AuthorizationRestrictions>> {
404    let restrictions = self
405      .query_authorization_service(headers, request_extensions)
406      .await?;
407
408    if let Some(restrictions) = restrictions {
409      cfg_if! {
410        if #[cfg(feature = "experimental")] {
411          Self::validate_restrictions(restrictions, path, queries, self.config.suppress_errors()).map(Some)
412        } else {
413          Self::validate_restrictions(restrictions, path, queries, false).map(Some)
414        }
415      }
416    } else {
417      Ok(None)
418    }
419  }
420}
421
422#[cfg(test)]
423mod tests {
424  use super::*;
425  use crate::{Endpoint, convert_to_query, match_format_from_query};
426  use htsget_config::config::advanced::HttpClient;
427  use htsget_config::config::advanced::auth::AuthConfigBuilder;
428  use htsget_config::config::advanced::auth::authorization::ForwardExtensions;
429  use htsget_config::config::advanced::auth::response::{
430    AuthorizationRestrictionsBuilder, AuthorizationRuleBuilder, ReferenceNameRestrictionBuilder,
431  };
432  use htsget_config::config::advanced::regex_location::RegexLocation;
433  use htsget_config::config::location::SimpleLocation;
434  use htsget_config::types::{Format, Request};
435  use htsget_test::util::generate_key_pair;
436  use http::{HeaderMap, Uri};
437  use regex::Regex;
438  use reqwest_middleware::ClientBuilder;
439  use serde_json::json;
440  use std::collections::HashMap;
441
442  #[test]
443  fn auth_builder_missing_config() {
444    let result = AuthBuilder::default().build();
445    assert!(matches!(result, Err(AuthBuilderError(_))));
446  }
447
448  #[test]
449  fn auth_builder_success_with_public_key() {
450    let (_, public_key) = generate_key_pair();
451
452    let config = create_test_auth_config(public_key);
453    let result = AuthBuilder::default().with_config(config).build();
454    assert!(result.is_ok());
455  }
456
457  #[test]
458  fn validate_restrictions_rule_allows_all() {
459    let rule = AuthorizationRuleBuilder::default()
460      .location(test_location())
461      .build()
462      .unwrap();
463    let restrictions = AuthorizationRestrictionsBuilder::default()
464      .rule(rule)
465      .build()
466      .unwrap();
467
468    let request = create_test_query(Endpoint::Reads, "sample1", HashMap::new());
469    let result =
470      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
471    assert!(result.is_ok());
472  }
473
474  #[test]
475  fn validate_restrictions_exact_path_match() {
476    let reference_restriction = ReferenceNameRestrictionBuilder::default()
477      .name("chr1")
478      .format(Format::Bam)
479      .start(1000)
480      .end(2000)
481      .build()
482      .unwrap();
483    let rule = AuthorizationRuleBuilder::default()
484      .location(test_location())
485      .reference_name(reference_restriction)
486      .build()
487      .unwrap();
488    let restrictions = AuthorizationRestrictionsBuilder::default()
489      .rule(rule)
490      .build()
491      .unwrap();
492
493    let mut query = HashMap::new();
494    query.insert("referenceName".to_string(), "chr1".to_string());
495    query.insert("start".to_string(), "1500".to_string());
496    query.insert("end".to_string(), "1800".to_string());
497    query.insert("format".to_string(), "BAM".to_string());
498
499    let request = create_test_query(Endpoint::Reads, "sample1", query);
500    let result =
501      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
502    assert!(result.is_ok());
503  }
504
505  #[test]
506  fn validate_restrictions_regex_prefix_match() {
507    let reference_restriction = ReferenceNameRestrictionBuilder::default()
508      .name("chr1")
509      .format(Format::Bam)
510      .build()
511      .unwrap();
512    let rule = AuthorizationRuleBuilder::default()
513      .location(Location::Simple(Box::new(SimpleLocation::new(
514        Default::default(),
515        "".to_string(),
516        Some(PrefixOrId::Prefix("sam".to_string())),
517      ))))
518      .reference_name(reference_restriction)
519      .build()
520      .unwrap();
521    let restrictions = AuthorizationRestrictionsBuilder::default()
522      .rule(rule)
523      .build()
524      .unwrap();
525
526    let mut query = HashMap::new();
527    query.insert("referenceName".to_string(), "chr1".to_string());
528    query.insert("format".to_string(), "BAM".to_string());
529
530    let request = create_test_query(Endpoint::Reads, "sample123", query);
531    let result =
532      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
533    assert!(result.is_ok());
534  }
535
536  #[test]
537  fn validate_restrictions_regex_match() {
538    let reference_restriction = ReferenceNameRestrictionBuilder::default()
539      .name("chr1")
540      .format(Format::Bam)
541      .build()
542      .unwrap();
543    let rule = AuthorizationRuleBuilder::default()
544      .location(Location::Regex(Box::new(RegexLocation::new(
545        Regex::new("sample(.+)").unwrap(),
546        "".to_string(),
547        Default::default(),
548        Default::default(),
549      ))))
550      .reference_name(reference_restriction)
551      .build()
552      .unwrap();
553    let restrictions = AuthorizationRestrictionsBuilder::default()
554      .rule(rule)
555      .build()
556      .unwrap();
557
558    let mut query = HashMap::new();
559    query.insert("referenceName".to_string(), "chr1".to_string());
560    query.insert("format".to_string(), "BAM".to_string());
561
562    let request = create_test_query(Endpoint::Reads, "sample123", query);
563    let result =
564      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
565    assert!(result.is_ok());
566  }
567
568  #[test]
569  fn validate_restrictions_forward_headers() {
570    let (_, public_key) = generate_key_pair();
571
572    let builder = AuthConfigBuilder::default()
573      .auth_mode(AuthMode::PublicKey(public_key))
574      .authorization_url(UrlOrStatic::Url(Uri::from_static(
575        "https://www.example.com",
576      )))
577      .http_client(HttpClient::new(
578        ClientBuilder::new(reqwest::Client::new()).build(),
579      ));
580    let config = builder
581      .clone()
582      .passthrough_auth(true)
583      .forward_headers(vec!["Custom1".to_string()])
584      .build()
585      .unwrap();
586    let result = AuthBuilder::default().with_config(config).build().unwrap();
587
588    let request_headers = HeaderMap::from_iter([
589      (
590        "Authorization".parse().unwrap(),
591        "Bearer Value".parse().unwrap(),
592      ),
593      ("Custom1".parse().unwrap(), "Value".parse().unwrap()),
594      ("Custom2".parse().unwrap(), "Value".parse().unwrap()),
595    ]);
596    let forwarded_headers = result.forwarded_headers(&request_headers, None).unwrap();
597    assert_eq!(
598      forwarded_headers,
599      HeaderMap::from_iter([
600        (
601          format!("{}Custom1", FORWARD_HEADER_PREFIX).parse().unwrap(),
602          "Value".parse().unwrap()
603        ),
604        (
605          "Authorization".parse().unwrap(),
606          "Bearer Value".parse().unwrap()
607        ),
608      ])
609    );
610
611    let config = builder
612      .clone()
613      .passthrough_auth(true)
614      .forward_headers(vec!["Custom1".to_string(), "Authorization".to_string()])
615      .build()
616      .unwrap();
617    let result = AuthBuilder::default().with_config(config).build().unwrap();
618
619    let forwarded_headers = result.forwarded_headers(&request_headers, None).unwrap();
620    assert_eq!(
621      forwarded_headers,
622      HeaderMap::from_iter([
623        (
624          format!("{}Custom1", FORWARD_HEADER_PREFIX).parse().unwrap(),
625          "Value".parse().unwrap()
626        ),
627        (
628          format!("{}Authorization", FORWARD_HEADER_PREFIX)
629            .parse()
630            .unwrap(),
631          "Bearer Value".parse().unwrap()
632        ),
633        (
634          "Authorization".parse().unwrap(),
635          "Bearer Value".parse().unwrap()
636        ),
637      ])
638    );
639
640    let config = builder
641      .clone()
642      .forward_headers(vec!["Custom1".to_string()])
643      .passthrough_auth(false)
644      .build()
645      .unwrap();
646    let result = AuthBuilder::default().with_config(config).build().unwrap();
647
648    let forwarded_headers = result.forwarded_headers(&request_headers, None).unwrap();
649    assert_eq!(
650      forwarded_headers,
651      HeaderMap::from_iter([(
652        format!("{}Custom1", FORWARD_HEADER_PREFIX).parse().unwrap(),
653        "Value".parse().unwrap()
654      ),])
655    );
656
657    let config = builder
658      .forward_extensions(vec![ForwardExtensions::new(
659        "$.Key".to_string(),
660        "Custom1".to_string(),
661      )])
662      .passthrough_auth(false)
663      .build()
664      .unwrap();
665    let result = AuthBuilder::default().with_config(config).build().unwrap();
666
667    let forwarded_headers = result
668      .forwarded_headers(
669        &request_headers,
670        Some(json!({
671          "Key": "Value"
672        })),
673      )
674      .unwrap();
675    assert_eq!(
676      forwarded_headers,
677      HeaderMap::from_iter([(
678        format!("{}Custom1", FORWARD_HEADER_PREFIX).parse().unwrap(),
679        "Value".parse().unwrap()
680      ),])
681    );
682  }
683
684  #[test]
685  fn validate_restrictions_reference_name_mismatch() {
686    let reference_restriction = ReferenceNameRestrictionBuilder::default()
687      .name("chr1")
688      .format(Format::Bam)
689      .build()
690      .unwrap();
691    let rule = AuthorizationRuleBuilder::default()
692      .location(test_location())
693      .reference_name(reference_restriction)
694      .build()
695      .unwrap();
696    let restrictions = AuthorizationRestrictionsBuilder::default()
697      .rule(rule.clone())
698      .build()
699      .unwrap();
700
701    let mut query = HashMap::new();
702    query.insert("class".to_string(), "header".to_string());
703    query.insert("format".to_string(), "BAM".to_string());
704
705    let request = create_test_query(Endpoint::Reads, "sample1", query);
706    let result =
707      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
708    assert!(result.is_ok());
709  }
710
711  #[test]
712  fn validate_restrictions_header() {
713    let reference_restriction = ReferenceNameRestrictionBuilder::default()
714      .name("chr1")
715      .format(Format::Bam)
716      .build()
717      .unwrap();
718    let rule = AuthorizationRuleBuilder::default()
719      .location(test_location())
720      .reference_name(reference_restriction)
721      .build()
722      .unwrap();
723    let restrictions = AuthorizationRestrictionsBuilder::default()
724      .rule(rule.clone())
725      .build()
726      .unwrap();
727
728    let mut query = HashMap::new();
729    query.insert("format".to_string(), "BAM".to_string());
730    query.insert("class".to_string(), "header".to_string());
731
732    let request = create_test_query(Endpoint::Reads, "sample1", query);
733    let result =
734      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
735    assert!(result.is_ok());
736  }
737
738  #[cfg(feature = "experimental")]
739  #[test]
740  fn validate_restrictions_reference_name_mismatch_suppressed() {
741    let reference_restriction = ReferenceNameRestrictionBuilder::default()
742      .name("chr1")
743      .format(Format::Bam)
744      .build()
745      .unwrap();
746    let rule = AuthorizationRuleBuilder::default()
747      .location(test_location())
748      .reference_name(reference_restriction)
749      .build()
750      .unwrap();
751    let restrictions = AuthorizationRestrictionsBuilder::default()
752      .rule(rule.clone())
753      .build()
754      .unwrap();
755
756    let mut query = HashMap::new();
757    query.insert("referenceName".to_string(), "chr2".to_string());
758    query.insert("format".to_string(), "BAM".to_string());
759
760    let request = create_test_query(Endpoint::Reads, "sample1", query);
761    let result =
762      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], true);
763    assert!(result.is_ok());
764  }
765
766  #[test]
767  fn validate_restrictions_format_mismatch() {
768    let reference_restriction = ReferenceNameRestrictionBuilder::default()
769      .name("chr1")
770      .format(Format::Bam)
771      .build()
772      .unwrap();
773    let rule = AuthorizationRuleBuilder::default()
774      .location(test_location())
775      .reference_name(reference_restriction)
776      .build()
777      .unwrap();
778    let restrictions = AuthorizationRestrictionsBuilder::default()
779      .rule(rule.clone())
780      .build()
781      .unwrap();
782
783    let mut query = HashMap::new();
784    query.insert("referenceName".to_string(), "chr1".to_string());
785    query.insert("format".to_string(), "CRAM".to_string());
786
787    let request = create_test_query(Endpoint::Reads, "sample1", query);
788    let result =
789      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
790    assert!(result.is_err());
791  }
792
793  #[cfg(feature = "experimental")]
794  #[test]
795  fn validate_restrictions_format_mismatch_suppressed() {
796    let reference_restriction = ReferenceNameRestrictionBuilder::default()
797      .name("chr1")
798      .format(Format::Bam)
799      .build()
800      .unwrap();
801    let rule = AuthorizationRuleBuilder::default()
802      .location(test_location())
803      .reference_name(reference_restriction)
804      .build()
805      .unwrap();
806    let restrictions = AuthorizationRestrictionsBuilder::default()
807      .rule(rule.clone())
808      .build()
809      .unwrap();
810
811    let mut query = HashMap::new();
812    query.insert("referenceName".to_string(), "chr1".to_string());
813    query.insert("format".to_string(), "CRAM".to_string());
814
815    let request = create_test_query(Endpoint::Reads, "sample1", query);
816    let result =
817      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], true);
818    assert!(result.is_ok());
819  }
820
821  #[test]
822  fn validate_restrictions_interval_not_contained() {
823    // Restriction:       1000----------2000
824    // Request:               1250--1750
825    // Result:                1250--1750
826    test_interval_suppressed(
827      Some(1000),
828      Some(2000),
829      Some(1250),
830      Some(1750),
831      (Interval::new(Some(1250), Some(1750)), Class::Body),
832      false,
833      false,
834    );
835
836    // Restriction:       1000----------2000
837    // Request:   500------------------------------->
838    // Result:                   err
839    test_interval_suppressed(
840      Some(1000),
841      Some(2000),
842      Some(500),
843      None,
844      (Interval::new(Some(500), None), Class::Body),
845      true,
846      false,
847    );
848
849    // Restriction:       1000----------2000
850    // Request:   <------------------------------2500
851    // Result:                   err
852    test_interval_suppressed(
853      Some(1000),
854      Some(2000),
855      None,
856      Some(2500),
857      (Interval::new(None, Some(2500)), Class::Body),
858      true,
859      false,
860    );
861
862    // Restriction:       1000----------2000
863    // Request:   <--------------------------------->
864    // Result:                   err
865    test_interval_suppressed(
866      Some(1000),
867      Some(2000),
868      None,
869      None,
870      (Interval::new(None, None), Class::Body),
871      true,
872      false,
873    );
874
875    // Restriction:       1000----------2000
876    // Request:   500------------1500
877    // Result:                   err
878    test_interval_suppressed(
879      Some(1000),
880      Some(2000),
881      Some(500),
882      Some(1500),
883      (Interval::new(Some(500), Some(1500)), Class::Body),
884      true,
885      false,
886    );
887
888    // Restriction:       1000----------2000
889    // Request:   <--------------1500
890    // Result:                   err
891    test_interval_suppressed(
892      Some(1000),
893      Some(2000),
894      None,
895      Some(1500),
896      (Interval::new(None, Some(1500)), Class::Body),
897      true,
898      false,
899    );
900
901    // Restriction:       1000----------2000
902    // Request:                  1500------------2500
903    // Result:                   err
904    test_interval_suppressed(
905      Some(1000),
906      Some(2000),
907      Some(1500),
908      Some(2500),
909      (Interval::new(Some(1500), Some(2500)), Class::Body),
910      true,
911      false,
912    );
913
914    // Restriction:       1000----------2000
915    // Request:                  1500--------------->
916    // Result:                   err
917    test_interval_suppressed(
918      Some(1000),
919      Some(2000),
920      Some(1500),
921      None,
922      (Interval::new(Some(1500), None), Class::Body),
923      true,
924      false,
925    );
926
927    // Restriction:       1000----------2000
928    // Request:   500-----1000
929    // Result:                   err
930    test_interval_suppressed(
931      Some(1000),
932      Some(2000),
933      Some(500),
934      Some(1000),
935      (Interval::new(Some(500), Some(1000)), Class::Body),
936      true,
937      false,
938    );
939
940    // Restriction:       1000----------2000
941    // Request:                         2000-----2500
942    // Result:                   err
943    test_interval_suppressed(
944      Some(1000),
945      Some(2000),
946      Some(2000),
947      Some(2500),
948      (Interval::new(Some(2000), Some(2500)), Class::Body),
949      true,
950      false,
951    );
952
953    // Restriction:       <-------------2000
954    // Request:   500------------1500
955    // Result:    500------------1500
956    test_interval_suppressed(
957      None,
958      Some(2000),
959      Some(500),
960      Some(1500),
961      (Interval::new(Some(500), Some(1500)), Class::Body),
962      false,
963      false,
964    );
965
966    // Restriction:       <-------------2000
967    // Request:                  1500------------2500
968    // Result:                   err
969    test_interval_suppressed(
970      None,
971      Some(2000),
972      Some(1500),
973      Some(2500),
974      (Interval::new(Some(1500), Some(2500)), Class::Body),
975      true,
976      false,
977    );
978
979    // Restriction:       1000------------->
980    // Request:                  1500------------2500
981    // Result:                   1500------------2500
982    test_interval_suppressed(
983      Some(1000),
984      None,
985      Some(1500),
986      Some(2500),
987      (Interval::new(Some(1500), Some(2500)), Class::Body),
988      false,
989      false,
990    );
991
992    // Restriction:       1000------------->
993    // Request:   500------------1500
994    // Result:                   err
995    test_interval_suppressed(
996      Some(1000),
997      None,
998      Some(500),
999      Some(1500),
1000      (Interval::new(Some(500), Some(1500)), Class::Body),
1001      true,
1002      false,
1003    );
1004
1005    // Restriction:       <---------------->
1006    // Request:   500----------------------------2500
1007    // Result:    500----------------------------2500
1008    test_interval_suppressed(
1009      None,
1010      None,
1011      Some(500),
1012      Some(2500),
1013      (Interval::new(Some(500), Some(2500)), Class::Body),
1014      false,
1015      false,
1016    );
1017
1018    // Restriction:       <---------------->
1019    // Request:   500------------------------------->
1020    // Result:    500------------------------------->
1021    test_interval_suppressed(
1022      None,
1023      None,
1024      Some(500),
1025      None,
1026      (Interval::new(Some(500), None), Class::Body),
1027      false,
1028      false,
1029    );
1030
1031    // Restriction:       <---------------->
1032    // Request:   <------------------------------2500
1033    // Result:    <------------------------------2500
1034    test_interval_suppressed(
1035      None,
1036      None,
1037      None,
1038      Some(2500),
1039      (Interval::new(None, Some(2500)), Class::Body),
1040      false,
1041      false,
1042    );
1043  }
1044
1045  #[cfg(feature = "experimental")]
1046  #[test]
1047  fn validate_restrictions_interval_suppressed() {
1048    // Restriction:       1000----------2000
1049    // Request:               1250--1750
1050    // Result:                1250--1750
1051    test_interval_suppressed(
1052      Some(1000),
1053      Some(2000),
1054      Some(1250),
1055      Some(1750),
1056      (Interval::new(Some(1250), Some(1750)), Class::Body),
1057      false,
1058      true,
1059    );
1060
1061    // Restriction:       1000----------2000
1062    // Request:   500------------------------------->
1063    // Result:            1000----------2000
1064    test_interval_suppressed(
1065      Some(1000),
1066      Some(2000),
1067      Some(500),
1068      None,
1069      (Interval::new(Some(1000), Some(2000)), Class::Body),
1070      false,
1071      true,
1072    );
1073
1074    // Restriction:       1000----------2000
1075    // Request:   <------------------------------2500
1076    // Result:            1000----------2000
1077    test_interval_suppressed(
1078      Some(1000),
1079      Some(2000),
1080      None,
1081      Some(2500),
1082      (Interval::new(Some(1000), Some(2000)), Class::Body),
1083      false,
1084      true,
1085    );
1086
1087    // Restriction:       1000----------2000
1088    // Request:   <--------------------------------->
1089    // Result:            1000----------2000
1090    test_interval_suppressed(
1091      Some(1000),
1092      Some(2000),
1093      None,
1094      None,
1095      (Interval::new(Some(1000), Some(2000)), Class::Body),
1096      false,
1097      true,
1098    );
1099
1100    // Restriction:       1000----------2000
1101    // Request:   500------------1500
1102    // Result:            1000---1500
1103    test_interval_suppressed(
1104      Some(1000),
1105      Some(2000),
1106      Some(500),
1107      Some(1500),
1108      (Interval::new(Some(1000), Some(1500)), Class::Body),
1109      false,
1110      true,
1111    );
1112
1113    // Restriction:       1000----------2000
1114    // Request:   <--------------1500
1115    // Result:            1000---1500
1116    test_interval_suppressed(
1117      Some(1000),
1118      Some(2000),
1119      None,
1120      Some(1500),
1121      (Interval::new(Some(1000), Some(1500)), Class::Body),
1122      false,
1123      true,
1124    );
1125
1126    // Restriction:       1000----------2000
1127    // Request:                  1500------------2500
1128    // Result:                   1500---2000
1129    test_interval_suppressed(
1130      Some(1000),
1131      Some(2000),
1132      Some(1500),
1133      Some(2500),
1134      (Interval::new(Some(1500), Some(2000)), Class::Body),
1135      false,
1136      true,
1137    );
1138
1139    // Restriction:       1000----------2000
1140    // Request:                  1500--------------->
1141    // Result:                   1500---2000
1142    test_interval_suppressed(
1143      Some(1000),
1144      Some(2000),
1145      Some(1500),
1146      None,
1147      (Interval::new(Some(1500), Some(2000)), Class::Body),
1148      false,
1149      true,
1150    );
1151
1152    // Restriction:       1000----------2000
1153    // Request:   500-----1000
1154    // Result:            -
1155    test_interval_suppressed(
1156      Some(1000),
1157      Some(2000),
1158      Some(500),
1159      Some(1000),
1160      (Interval::new(Some(500), Some(1000)), Class::Header),
1161      false,
1162      true,
1163    );
1164
1165    // Restriction:       1000----------2000
1166    // Request:                         2000-----2500
1167    // Result:                          -
1168    test_interval_suppressed(
1169      Some(1000),
1170      Some(2000),
1171      Some(2000),
1172      Some(2500),
1173      (Interval::new(Some(2000), Some(2500)), Class::Header),
1174      false,
1175      true,
1176    );
1177
1178    // Restriction:       <-------------2000
1179    // Request:   500------------1500
1180    // Result:    500------------1500
1181    test_interval_suppressed(
1182      None,
1183      Some(2000),
1184      Some(500),
1185      Some(1500),
1186      (Interval::new(Some(500), Some(1500)), Class::Body),
1187      false,
1188      true,
1189    );
1190
1191    // Restriction:       <-------------2000
1192    // Request:                  1500------------2500
1193    // Result:                   1500---2000
1194    test_interval_suppressed(
1195      None,
1196      Some(2000),
1197      Some(1500),
1198      Some(2500),
1199      (Interval::new(Some(1500), Some(2000)), Class::Body),
1200      false,
1201      true,
1202    );
1203
1204    // Restriction:       1000------------->
1205    // Request:                  1500------------2500
1206    // Result:                   1500------------2500
1207    test_interval_suppressed(
1208      Some(1000),
1209      None,
1210      Some(1500),
1211      Some(2500),
1212      (Interval::new(Some(1500), Some(2500)), Class::Body),
1213      false,
1214      true,
1215    );
1216
1217    // Restriction:       1000------------->
1218    // Request:   500------------1500
1219    // Result:            1000---1500
1220    test_interval_suppressed(
1221      Some(1000),
1222      None,
1223      Some(500),
1224      Some(1500),
1225      (Interval::new(Some(1000), Some(1500)), Class::Body),
1226      false,
1227      true,
1228    );
1229
1230    // Restriction:       <---------------->
1231    // Request:   500----------------------------2500
1232    // Result:    500----------------------------2500
1233    test_interval_suppressed(
1234      None,
1235      None,
1236      Some(500),
1237      Some(2500),
1238      (Interval::new(Some(500), Some(2500)), Class::Body),
1239      false,
1240      true,
1241    );
1242
1243    // Restriction:       <---------------->
1244    // Request:   500------------------------------->
1245    // Result:    500------------------------------->
1246    test_interval_suppressed(
1247      None,
1248      None,
1249      Some(500),
1250      None,
1251      (Interval::new(Some(500), None), Class::Body),
1252      false,
1253      true,
1254    );
1255
1256    // Restriction:       <---------------->
1257    // Request:   <------------------------------2500
1258    // Result:    <------------------------------2500
1259    test_interval_suppressed(
1260      None,
1261      None,
1262      None,
1263      Some(2500),
1264      (Interval::new(None, Some(2500)), Class::Body),
1265      false,
1266      true,
1267    );
1268  }
1269
1270  #[test]
1271  fn validate_restrictions_format_none_allows_any() {
1272    let reference_restriction = ReferenceNameRestrictionBuilder::default()
1273      .name("chr1")
1274      .build()
1275      .unwrap();
1276    let rule = AuthorizationRuleBuilder::default()
1277      .location(test_location())
1278      .reference_name(reference_restriction)
1279      .build()
1280      .unwrap();
1281    let restrictions = AuthorizationRestrictionsBuilder::default()
1282      .rule(rule)
1283      .build()
1284      .unwrap();
1285
1286    let mut query = HashMap::new();
1287    query.insert("referenceName".to_string(), "chr1".to_string());
1288    query.insert("format".to_string(), "CRAM".to_string());
1289
1290    let request = create_test_query(Endpoint::Reads, "sample1", query);
1291    let result =
1292      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
1293    assert!(result.is_ok());
1294  }
1295
1296  #[test]
1297  fn validate_restrictions_path_with_leading_slash() {
1298    let rule = AuthorizationRuleBuilder::default()
1299      .location(test_location())
1300      .build()
1301      .unwrap();
1302    let restrictions = AuthorizationRestrictionsBuilder::default()
1303      .rule(rule)
1304      .build()
1305      .unwrap();
1306    let request = create_test_query(Endpoint::Reads, "sample1", HashMap::new());
1307    let result =
1308      Auth::validate_restrictions(restrictions, request.id(), &mut [request.clone()], false);
1309    assert!(result.is_ok());
1310  }
1311
1312  #[tokio::test]
1313  async fn validate_authorization_missing_auth_header() {
1314    let auth = create_mock_auth_with_restrictions();
1315    let request = Request::new("sample1".to_string(), HashMap::new(), HeaderMap::new());
1316
1317    let result = auth.validate_jwt(request.headers()).await;
1318    assert!(result.is_err());
1319    assert!(matches!(
1320      result.unwrap_err(),
1321      HtsGetError::InvalidAuthentication(_)
1322    ));
1323  }
1324
1325  #[tokio::test]
1326  async fn validate_authorization_invalid_jwt_format() {
1327    let auth = create_mock_auth_with_restrictions();
1328    let request = create_request_with_auth_header("sample1", HashMap::new(), "invalid.jwt.token");
1329
1330    let result = auth.validate_jwt(request.headers()).await;
1331    assert!(result.is_err());
1332    assert!(matches!(
1333      result.unwrap_err(),
1334      HtsGetError::PermissionDenied(_)
1335    ));
1336  }
1337
1338  fn create_test_auth_config(public_key: Vec<u8>) -> AuthConfig {
1339    AuthConfigBuilder::default()
1340      .auth_mode(AuthMode::PublicKey(public_key))
1341      .authorization_url(UrlOrStatic::Url(Uri::from_static(
1342        "https://www.example.com",
1343      )))
1344      .http_client(HttpClient::new(
1345        ClientBuilder::new(reqwest::Client::new()).build(),
1346      ))
1347      .build()
1348      .unwrap()
1349  }
1350
1351  fn create_test_query(endpoint: Endpoint, path: &str, query: HashMap<String, String>) -> Query {
1352    let request = Request::new(path.to_string(), query, HeaderMap::new());
1353    let format = match_format_from_query(&endpoint, request.query()).unwrap();
1354
1355    convert_to_query(request, format).unwrap()
1356  }
1357
1358  fn create_request_with_auth_header(
1359    path: &str,
1360    query: HashMap<String, String>,
1361    token: &str,
1362  ) -> Request {
1363    let mut headers = HeaderMap::new();
1364    headers.insert("authorization", format!("Bearer {token}").parse().unwrap());
1365    Request::new(path.to_string(), query, headers)
1366  }
1367
1368  fn create_mock_auth_with_restrictions() -> Auth {
1369    let (_, public_key) = generate_key_pair();
1370
1371    let config = create_test_auth_config(public_key);
1372    AuthBuilder::default().with_config(config).build().unwrap()
1373  }
1374
1375  fn test_interval_suppressed(
1376    restrict_start: Option<u32>,
1377    restrict_end: Option<u32>,
1378    request_start: Option<u32>,
1379    request_end: Option<u32>,
1380    expected_response: (Interval, Class),
1381    is_err: bool,
1382    suppress_interval: bool,
1383  ) {
1384    let mut reference_restriction = ReferenceNameRestrictionBuilder::default()
1385      .name("chr1")
1386      .format(Format::Bam);
1387
1388    if let Some(start) = restrict_start {
1389      reference_restriction = reference_restriction.start(start);
1390    }
1391    if let Some(end) = restrict_end {
1392      reference_restriction = reference_restriction.end(end);
1393    }
1394
1395    let reference_restriction = reference_restriction.build().unwrap();
1396    let rule = AuthorizationRuleBuilder::default()
1397      .location(test_location())
1398      .reference_name(reference_restriction)
1399      .build()
1400      .unwrap();
1401    let restrictions = AuthorizationRestrictionsBuilder::default()
1402      .rule(rule.clone())
1403      .build()
1404      .unwrap();
1405
1406    let mut query = HashMap::new();
1407    query.insert("referenceName".to_string(), "chr1".to_string());
1408    request_start.map(|start| query.insert("start".to_string(), start.to_string()));
1409    request_end.map(|end| query.insert("end".to_string(), end.to_string()));
1410
1411    let request = create_test_query(Endpoint::Reads, "sample1", query);
1412    let id = request.id().to_string();
1413    let mut slice = [request];
1414    let result = Auth::validate_restrictions(restrictions, &id, &mut slice, suppress_interval);
1415    if is_err {
1416      assert!(result.is_err());
1417    } else {
1418      assert!(result.is_ok());
1419    }
1420    assert_eq!(slice.first().unwrap().interval(), expected_response.0);
1421    assert_eq!(slice.last().unwrap().class(), expected_response.1);
1422  }
1423
1424  fn test_location() -> Location {
1425    Location::Simple(Box::new(SimpleLocation::new(
1426      Default::default(),
1427      "".to_string(),
1428      Some(PrefixOrId::Id("sample1".to_string())),
1429    )))
1430  }
1431}