htsget_http/middleware/
auth.rs

1//! The htsget authorization middleware.
2//!
3
4use crate::error::Result as HtsGetResult;
5use crate::middleware::error::Error::AuthBuilderError;
6use crate::middleware::error::Result;
7use crate::{Endpoint, HtsGetError, convert_to_query, match_format_from_query};
8use headers::authorization::Bearer;
9use headers::{Authorization, Header};
10use htsget_config::config::advanced::auth::{AuthConfig, AuthMode, AuthorizationRestrictions};
11use htsget_config::types::Request;
12use http::{HeaderMap, Uri};
13use jsonpath_rust::JsonPath;
14use jsonwebtoken::jwk::JwkSet;
15use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, decode, decode_header};
16use regex::Regex;
17use serde::de::DeserializeOwned;
18use serde_json::Value;
19use std::collections::HashMap;
20
21/// Builder the the authorization middleware.
22#[derive(Default, Debug)]
23pub struct AuthBuilder {
24  config: Option<AuthConfig>,
25}
26
27impl AuthBuilder {
28  /// Set the config.
29  pub fn with_config(mut self, config: AuthConfig) -> Self {
30    self.config = Some(config);
31    self
32  }
33
34  /// Build the auth layer, ensures that the config sets the correct parameters.
35  pub fn build(self) -> Result<Auth> {
36    let Some(mut config) = self.config else {
37      return Err(AuthBuilderError("missing config".to_string()));
38    };
39
40    if config.trusted_authorization_urls().is_empty() {
41      return Err(AuthBuilderError(
42        "at least one trusted authorization url must be set".to_string(),
43      ));
44    }
45    if config.authorization_path().is_none() && config.trusted_authorization_urls().len() > 1 {
46      return Err(AuthBuilderError(
47        "only one trusted authorization url should be set when not using authorization paths"
48          .to_string(),
49      ));
50    }
51
52    let mut decoding_key = None;
53    if let AuthMode::PublicKey(public_key) = config.auth_mode_mut() {
54      decoding_key = Some(
55        Auth::decode_public_key(public_key)
56          .map_err(|_| AuthBuilderError("failed to decode public key".to_string()))?,
57      );
58    }
59
60    Ok(Auth {
61      config,
62      decoding_key,
63    })
64  }
65}
66
67/// The auth middleware layer.
68#[derive(Clone)]
69pub struct Auth {
70  config: AuthConfig,
71  decoding_key: Option<DecodingKey>,
72}
73
74impl Auth {
75  /// Fetch JWKS from the authorization server.
76  pub async fn fetch_from_url<D: DeserializeOwned>(
77    &self,
78    url: &str,
79    headers: HeaderMap,
80  ) -> HtsGetResult<D> {
81    let err = || HtsGetError::InternalError(format!("failed to fetch data from {url}"));
82    let response = self
83      .config
84      .http_client()
85      .get(url)
86      .headers(headers)
87      .send()
88      .await
89      .map_err(|_| err())?;
90
91    response.json().await.map_err(|_| err())
92  }
93
94  /// Get a decoding key form the JWKS url.
95  pub async fn decode_jwks(&self, jwks_url: &Uri, token: &str) -> HtsGetResult<DecodingKey> {
96    // Decode header and get the key id.
97    let header = decode_header(token)?;
98    let kid = header
99      .kid
100      .ok_or_else(|| HtsGetError::PermissionDenied("JWT missing key ID".to_string()))?;
101
102    // Fetch JWKS from the authorization server and find matching JWK.
103    let jwks = self
104      .fetch_from_url::<JwkSet>(&jwks_url.to_string(), Default::default())
105      .await?;
106    let matched_jwk = jwks
107      .find(&kid)
108      .ok_or_else(|| HtsGetError::PermissionDenied("matching JWK not found".to_string()))?;
109
110    Ok(DecodingKey::from_jwk(matched_jwk)?)
111  }
112
113  /// Decode a public key into an RSA, EdDSA or ECDSA pem-formatted decoding key.
114  pub fn decode_public_key(key: &[u8]) -> HtsGetResult<DecodingKey> {
115    Ok(
116      DecodingKey::from_rsa_pem(key)
117        .or_else(|_| DecodingKey::from_ed_pem(key))
118        .or_else(|_| DecodingKey::from_ec_pem(key))?,
119    )
120  }
121
122  /// Query the authorization service to get the restrictions. This function validates
123  /// that the authorization url is trusted in the config settings before calling the
124  /// service. The claims are assumed to be valid.
125  pub async fn query_authorization_service(
126    &self,
127    claims: Value,
128    headers: &HeaderMap,
129  ) -> HtsGetResult<AuthorizationRestrictions> {
130    let query_url = match self.config.authorization_path() {
131      None => self
132        .config
133        .trusted_authorization_urls()
134        .first()
135        .ok_or_else(|| {
136          HtsGetError::InternalError("missing trusted authorization url".to_string())
137        })?,
138      Some(path) => {
139        // Extract the url from the path.
140        let path = claims.query(path).map_err(|err| {
141          HtsGetError::InvalidAuthentication(format!(
142            "failed to find authorization service in claims: {err}",
143          ))
144        })?;
145        let url = path
146          .first()
147          .ok_or_else(|| {
148            HtsGetError::InvalidAuthentication(
149              "expected one value for authorization service in claims".to_string(),
150            )
151          })?
152          .as_str()
153          .ok_or_else(|| {
154            HtsGetError::InvalidAuthentication(
155              "expected string value for authorization service in claims".to_string(),
156            )
157          })?;
158        &url.parse::<Uri>().map_err(|err| {
159          HtsGetError::InvalidAuthentication(format!("failed to parse authorization url: {err}"))
160        })?
161      }
162    };
163
164    // Ensure that the authorization url is trusted.
165    if !self.config.trusted_authorization_urls().contains(query_url) {
166      return Err(HtsGetError::PermissionDenied(
167        "authorization service in claims not a trusted authorization url".to_string(),
168      ));
169    };
170
171    let auth_header = headers
172      .iter()
173      .find_map(|(name, value)| {
174        if Authorization::<Bearer>::decode(&mut [value].into_iter()).is_ok() {
175          return Some((name.clone(), value.clone()));
176        }
177
178        None
179      })
180      .ok_or_else(|| HtsGetError::PermissionDenied("missing authorization header".to_string()))?;
181
182    self
183      .fetch_from_url(&query_url.to_string(), HeaderMap::from_iter([auth_header]))
184      .await
185  }
186
187  /// Validate the restrictions, returning an error if the user is not authorized.
188  pub fn validate_restrictions(
189    restrictions: AuthorizationRestrictions,
190    request: Request,
191    endpoint: Endpoint,
192  ) -> HtsGetResult<()> {
193    // Find all rules matching the path.
194    let matching_rules = restrictions
195      .htsget_auth()
196      .iter()
197      .filter(|rule| {
198        // If this path is a direct match then just return that.
199        if rule.path().strip_prefix("/").unwrap_or(rule.path())
200          == request.path().strip_prefix("/").unwrap_or(request.path())
201        {
202          return true;
203        }
204
205        // Otherwise, try and parse it as a regex.
206        Regex::new(rule.path()).is_ok_and(|regex| regex.is_match(request.path()))
207      })
208      .collect::<Vec<_>>();
209
210    // If any of the rules allow all reference names (nothing set in the rule) then the user is authorized.
211    if matching_rules
212      .iter()
213      .any(|rule| rule.reference_names().is_none())
214    {
215      return Ok(());
216    }
217
218    let format = match_format_from_query(&endpoint, request.query())?;
219    let query = convert_to_query(request, format)?;
220    let matching_restriction = matching_rules
221      .iter()
222      .flat_map(|rule| rule.reference_names().unwrap_or_default())
223      .find(|restriction| {
224        // The reference name should match exactly.
225        let name_match = Some(restriction.name()) == query.reference_name();
226        // The format should match if it's defined, otherwise it allows any format.
227        let format_match =
228          restriction.format().is_none() || restriction.format() == Some(query.format());
229        // The interval should match exactly, considering undefined start or end ranges.
230        let interval_match = restriction.interval().contains_interval(query.interval());
231
232        name_match && format_match && interval_match
233      });
234
235    // If the matching rule with the restriction was found, then the user is authorized, otherwise
236    // it is a permission denied response.
237    if matching_restriction.is_some() {
238      Ok(())
239    } else {
240      Err(HtsGetError::PermissionDenied(
241        "failed to authorize user based on authorization service restrictions".to_string(),
242      ))
243    }
244  }
245
246  /// Validate only the JWT without looking up restrictions and validating those. Returns the
247  /// decoded JWT token.
248  pub async fn validate_jwt(&self, headers: &HeaderMap) -> HtsGetResult<TokenData<Value>> {
249    let auth_token = headers
250      .values()
251      .find_map(|value| Authorization::<Bearer>::decode(&mut [value].into_iter()).ok())
252      .ok_or_else(|| {
253        HtsGetError::InvalidAuthentication("invalid authorization header".to_string())
254      })?;
255
256    let decoding_key = if let Some(ref decoding_key) = self.decoding_key {
257      decoding_key
258    } else {
259      match self.config.auth_mode() {
260        AuthMode::Jwks(jwks) => &self.decode_jwks(jwks, auth_token.token()).await?,
261        AuthMode::PublicKey(public_key) => &Self::decode_public_key(public_key)?,
262      }
263    };
264
265    // Decode and validate the JWT
266    let mut validation = Validation::default();
267    validation.validate_exp = true;
268    validation.validate_aud = true;
269    validation.validate_nbf = true;
270
271    if let Some(iss) = self.config.validate_issuer() {
272      validation.set_issuer(iss);
273      validation.required_spec_claims.insert("iss".to_string());
274    }
275    if let Some(aud) = self.config.validate_audience() {
276      validation.set_audience(aud);
277      validation.required_spec_claims.insert("aud".to_string());
278    }
279    if let Some(sub) = self.config.validate_subject() {
280      validation.sub = Some(sub.to_string());
281      validation.required_spec_claims.insert("sub".to_string());
282    }
283
284    // Each supported algorithm must be tried individually because the jsonwebtoken validation
285    // logic only tries one algorithm in the vec: https://github.com/Keats/jsonwebtoken/issues/297
286    validation.algorithms = vec![Algorithm::RS256];
287    let decoded_claims = decode::<Value>(auth_token.token(), decoding_key, &validation)
288      .or_else(|_| {
289        validation.algorithms = vec![Algorithm::ES256];
290        decode::<Value>(auth_token.token(), decoding_key, &validation)
291      })
292      .or_else(|_| {
293        validation.algorithms = vec![Algorithm::EdDSA];
294        decode::<Value>(auth_token.token(), decoding_key, &validation)
295      });
296
297    let claims = match decoded_claims {
298      Ok(claims) => claims,
299      Err(err) => return Err(HtsGetError::PermissionDenied(format!("invalid JWT: {err}"))),
300    };
301
302    Ok(claims)
303  }
304
305  /// Validate the authorization flow, returning an error if the user is not authorized.
306  /// This performs the following steps:
307  ///
308  /// 1. Finds the JWT decoding key from the config or by querying a JWKS url.
309  /// 2. Validates the JWT token according to the config.
310  /// 3. Queries the authorization service for restrictions based on the config or JWT claims.
311  /// 4. Validates the restrictions to determine if the user is authorized.
312  pub async fn validate_authorization(
313    &self,
314    request: Request,
315    endpoint: Endpoint,
316  ) -> HtsGetResult<()> {
317    let claims = self.validate_jwt(request.headers()).await?;
318    if self.config.authentication_only() {
319      return Ok(());
320    }
321
322    let restrictions = self
323      .query_authorization_service(claims.claims, request.headers())
324      .await?;
325    Self::validate_restrictions(restrictions, request, endpoint)
326  }
327
328  /// Validate authorization flow according to `validate_authorization` and `validate_jwt`.
329  /// This function will only `validate_jwt` on service-info requests, and will otherwise
330  /// `validate_authorization` for htsget requests to `/reads` or `/variants`.
331  pub async fn authorize_request(
332    &self,
333    path: &str,
334    query: HashMap<String, String>,
335    headers: HeaderMap,
336  ) -> HtsGetResult<()> {
337    if let Some(reads) = path.strip_prefix("/reads") {
338      if reads.starts_with("/service-info") {
339        self.validate_jwt(&headers).await?;
340        return Ok(());
341      }
342
343      self
344        .validate_authorization(
345          Request::new(path.to_string(), query, headers),
346          Endpoint::Reads,
347        )
348        .await
349    } else if let Some(variants) = path.strip_prefix("/variants") {
350      if variants.starts_with("/service-info") {
351        self.validate_jwt(&headers).await?;
352        return Ok(());
353      }
354
355      self
356        .validate_authorization(
357          Request::new(path.to_string(), query, headers),
358          Endpoint::Variants,
359        )
360        .await
361    } else {
362      self.validate_jwt(&headers).await?;
363      Ok(())
364    }
365  }
366}
367
#[cfg(test)]
mod tests {
  use super::*;
  use htsget_config::config::advanced::HttpClient;
  use htsget_config::config::advanced::auth::response::{
    AuthorizationRestrictionsBuilder, AuthorizationRuleBuilder, ReferenceNameRestrictionBuilder,
  };
  use htsget_config::config::advanced::auth::{AuthConfigBuilder, AuthMode};
  use htsget_config::types::Format;
  use htsget_test::util::generate_key_pair;
  use http::{HeaderMap, Uri};
  use std::collections::HashMap;

  #[test]
  fn auth_builder_missing_config() {
    let result = AuthBuilder::default().build();
    assert!(matches!(result, Err(AuthBuilderError(_))));
  }

  #[test]
  fn auth_builder_success_with_public_key() {
    let (_, public_key) = generate_key_pair();

    let config = create_test_auth_config(public_key);
    let result = AuthBuilder::default().with_config(config).build();
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_rule_allows_all() {
    let rule = AuthorizationRuleBuilder::default()
      .path("/reads/sample1")
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let request = create_test_request("/reads/sample1", HashMap::new());
    let result = Auth::validate_restrictions(restrictions, request, Endpoint::Reads);
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_exact_path_match() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .start(1000)
      .end(2000)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .path("/reads/sample1")
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("start".to_string(), "1500".to_string());
    query.insert("end".to_string(), "1800".to_string());
    query.insert("format".to_string(), "BAM".to_string());

    let request = create_test_request("/reads/sample1", query);
    let result = Auth::validate_restrictions(restrictions, request, Endpoint::Reads);
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_regex_path_match() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .path("/reads/sample(.+)")
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("format".to_string(), "BAM".to_string());

    let request = create_test_request("/reads/sample123", query);
    let result = Auth::validate_restrictions(restrictions, request, Endpoint::Reads);
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_reference_name_mismatch() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .path("/reads/sample1")
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr2".to_string());
    query.insert("format".to_string(), "BAM".to_string());

    let request = create_test_request("/reads/sample1", query);
    let result = Auth::validate_restrictions(restrictions, request, Endpoint::Reads);
    assert!(result.is_err());
  }

  #[test]
  fn validate_restrictions_format_mismatch() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .path("/reads/sample1")
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("format".to_string(), "CRAM".to_string());

    let request = create_test_request("/reads/sample1", query);
    let result = Auth::validate_restrictions(restrictions, request, Endpoint::Reads);
    assert!(result.is_err());
  }

  #[test]
  fn validate_restrictions_interval_not_contained() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .start(1000)
      .end(2000)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .path("/reads/sample1")
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("start".to_string(), "500".to_string());
    query.insert("end".to_string(), "1500".to_string());

    let request = create_test_request("/reads/sample1", query);
    let result = Auth::validate_restrictions(restrictions, request, Endpoint::Reads);
    assert!(result.is_err());
  }

  #[test]
  fn validate_restrictions_format_none_allows_any() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .path("/reads/sample1")
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("format".to_string(), "CRAM".to_string());

    let request = create_test_request("/reads/sample1", query);
    let result = Auth::validate_restrictions(restrictions, request, Endpoint::Reads);
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_path_with_leading_slash() {
    // The rule path has a leading slash that the request path lacks. Exact path matching
    // should ignore the difference.
    let rule = AuthorizationRuleBuilder::default()
      .path("/reads/sample1")
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();
    let request = create_test_request("reads/sample1", HashMap::new());
    let result = Auth::validate_restrictions(restrictions, request, Endpoint::Reads);
    assert!(result.is_ok());
  }

  #[test]
  fn validate_restrictions_no_matching_rule() {
    let reference_restriction = ReferenceNameRestrictionBuilder::default()
      .name("chr1")
      .format(Format::Bam)
      .build()
      .unwrap();
    let rule = AuthorizationRuleBuilder::default()
      .path("/reads/sample1")
      .reference_name(reference_restriction)
      .build()
      .unwrap();
    let restrictions = AuthorizationRestrictionsBuilder::default()
      .rule(rule)
      .build()
      .unwrap();

    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chr1".to_string());
    query.insert("format".to_string(), "BAM".to_string());

    // A request for a different sample must not be authorized by the sample1 rule.
    let request = create_test_request("/reads/sample2", query);
    let result = Auth::validate_restrictions(restrictions, request, Endpoint::Reads);
    assert!(matches!(result, Err(HtsGetError::PermissionDenied(_))));
  }

  #[tokio::test]
  async fn validate_authorization_missing_auth_header() {
    let auth = create_test_auth();
    let request = create_test_request("/reads/sample1", HashMap::new());

    let result = auth.validate_authorization(request, Endpoint::Reads).await;
    assert!(matches!(
      result,
      Err(HtsGetError::InvalidAuthentication(_))
    ));
  }

  #[tokio::test]
  async fn validate_authorization_invalid_jwt_format() {
    let auth = create_test_auth();
    let request =
      create_request_with_auth_header("/reads/sample1", HashMap::new(), "invalid.jwt.token");

    let result = auth.validate_authorization(request, Endpoint::Reads).await;
    assert!(matches!(result, Err(HtsGetError::PermissionDenied(_))));
  }

  #[tokio::test]
  async fn authorize_request_service_info_requires_jwt() {
    let auth = create_test_auth();

    let result = auth
      .authorize_request("/reads/service-info", HashMap::new(), HeaderMap::new())
      .await;
    assert!(matches!(
      result,
      Err(HtsGetError::InvalidAuthentication(_))
    ));
  }

  /// Build an auth config using `public_key` and a single trusted authorization url.
  fn create_test_auth_config(public_key: Vec<u8>) -> AuthConfig {
    AuthConfigBuilder::default()
      .auth_mode(AuthMode::PublicKey(public_key))
      .trusted_authorization_url(Uri::from_static("https://www.example.com"))
      .http_client(HttpClient::new(reqwest::Client::new()))
      .build()
      .unwrap()
  }

  /// Build an auth layer from a freshly generated public key.
  fn create_test_auth() -> Auth {
    let (_, public_key) = generate_key_pair();

    AuthBuilder::default()
      .with_config(create_test_auth_config(public_key))
      .build()
      .unwrap()
  }

  /// Build a request with no headers.
  fn create_test_request(path: &str, query: HashMap<String, String>) -> Request {
    Request::new(path.to_string(), query, HeaderMap::new())
  }

  /// Build a request carrying `token` as a bearer authorization header.
  fn create_request_with_auth_header(
    path: &str,
    query: HashMap<String, String>,
    token: &str,
  ) -> Request {
    let mut headers = HeaderMap::new();
    headers.insert("authorization", format!("Bearer {token}").parse().unwrap());
    Request::new(path.to_string(), query, headers)
  }
}