1use crate::error::Result as HtsGetResult;
5use crate::middleware::error::Error::AuthBuilderError;
6use crate::middleware::error::Result;
7use crate::{Endpoint, HtsGetError, convert_to_query, match_format_from_query};
8use headers::authorization::Bearer;
9use headers::{Authorization, Header};
10use htsget_config::config::advanced::auth::{AuthConfig, AuthMode, AuthorizationRestrictions};
11use htsget_config::types::Request;
12use http::{HeaderMap, Uri};
13use jsonpath_rust::JsonPath;
14use jsonwebtoken::jwk::JwkSet;
15use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, decode, decode_header};
16use regex::Regex;
17use serde::de::DeserializeOwned;
18use serde_json::Value;
19use std::collections::HashMap;
20
21#[derive(Default, Debug)]
23pub struct AuthBuilder {
24 config: Option<AuthConfig>,
25}
26
27impl AuthBuilder {
28 pub fn with_config(mut self, config: AuthConfig) -> Self {
30 self.config = Some(config);
31 self
32 }
33
34 pub fn build(self) -> Result<Auth> {
36 let Some(mut config) = self.config else {
37 return Err(AuthBuilderError("missing config".to_string()));
38 };
39
40 if config.trusted_authorization_urls().is_empty() {
41 return Err(AuthBuilderError(
42 "at least one trusted authorization url must be set".to_string(),
43 ));
44 }
45 if config.authorization_path().is_none() && config.trusted_authorization_urls().len() > 1 {
46 return Err(AuthBuilderError(
47 "only one trusted authorization url should be set when not using authorization paths"
48 .to_string(),
49 ));
50 }
51
52 let mut decoding_key = None;
53 if let AuthMode::PublicKey(public_key) = config.auth_mode_mut() {
54 decoding_key = Some(
55 Auth::decode_public_key(public_key)
56 .map_err(|_| AuthBuilderError("failed to decode public key".to_string()))?,
57 );
58 }
59
60 Ok(Auth {
61 config,
62 decoding_key,
63 })
64 }
65}
66
67#[derive(Clone)]
69pub struct Auth {
70 config: AuthConfig,
71 decoding_key: Option<DecodingKey>,
72}
73
74impl Auth {
75 pub async fn fetch_from_url<D: DeserializeOwned>(
77 &self,
78 url: &str,
79 headers: HeaderMap,
80 ) -> HtsGetResult<D> {
81 let err = || HtsGetError::InternalError(format!("failed to fetch data from {url}"));
82 let response = self
83 .config
84 .http_client()
85 .get(url)
86 .headers(headers)
87 .send()
88 .await
89 .map_err(|_| err())?;
90
91 response.json().await.map_err(|_| err())
92 }
93
94 pub async fn decode_jwks(&self, jwks_url: &Uri, token: &str) -> HtsGetResult<DecodingKey> {
96 let header = decode_header(token)?;
98 let kid = header
99 .kid
100 .ok_or_else(|| HtsGetError::PermissionDenied("JWT missing key ID".to_string()))?;
101
102 let jwks = self
104 .fetch_from_url::<JwkSet>(&jwks_url.to_string(), Default::default())
105 .await?;
106 let matched_jwk = jwks
107 .find(&kid)
108 .ok_or_else(|| HtsGetError::PermissionDenied("matching JWK not found".to_string()))?;
109
110 Ok(DecodingKey::from_jwk(matched_jwk)?)
111 }
112
113 pub fn decode_public_key(key: &[u8]) -> HtsGetResult<DecodingKey> {
115 Ok(
116 DecodingKey::from_rsa_pem(key)
117 .or_else(|_| DecodingKey::from_ed_pem(key))
118 .or_else(|_| DecodingKey::from_ec_pem(key))?,
119 )
120 }
121
122 pub async fn query_authorization_service(
126 &self,
127 claims: Value,
128 headers: &HeaderMap,
129 ) -> HtsGetResult<AuthorizationRestrictions> {
130 let query_url = match self.config.authorization_path() {
131 None => self
132 .config
133 .trusted_authorization_urls()
134 .first()
135 .ok_or_else(|| {
136 HtsGetError::InternalError("missing trusted authorization url".to_string())
137 })?,
138 Some(path) => {
139 let path = claims.query(path).map_err(|err| {
141 HtsGetError::InvalidAuthentication(format!(
142 "failed to find authorization service in claims: {err}",
143 ))
144 })?;
145 let url = path
146 .first()
147 .ok_or_else(|| {
148 HtsGetError::InvalidAuthentication(
149 "expected one value for authorization service in claims".to_string(),
150 )
151 })?
152 .as_str()
153 .ok_or_else(|| {
154 HtsGetError::InvalidAuthentication(
155 "expected string value for authorization service in claims".to_string(),
156 )
157 })?;
158 &url.parse::<Uri>().map_err(|err| {
159 HtsGetError::InvalidAuthentication(format!("failed to parse authorization url: {err}"))
160 })?
161 }
162 };
163
164 if !self.config.trusted_authorization_urls().contains(query_url) {
166 return Err(HtsGetError::PermissionDenied(
167 "authorization service in claims not a trusted authorization url".to_string(),
168 ));
169 };
170
171 let auth_header = headers
172 .iter()
173 .find_map(|(name, value)| {
174 if Authorization::<Bearer>::decode(&mut [value].into_iter()).is_ok() {
175 return Some((name.clone(), value.clone()));
176 }
177
178 None
179 })
180 .ok_or_else(|| HtsGetError::PermissionDenied("missing authorization header".to_string()))?;
181
182 self
183 .fetch_from_url(&query_url.to_string(), HeaderMap::from_iter([auth_header]))
184 .await
185 }
186
187 pub fn validate_restrictions(
189 restrictions: AuthorizationRestrictions,
190 request: Request,
191 endpoint: Endpoint,
192 ) -> HtsGetResult<()> {
193 let matching_rules = restrictions
195 .htsget_auth()
196 .iter()
197 .filter(|rule| {
198 if rule.path().strip_prefix("/").unwrap_or(rule.path())
200 == request.path().strip_prefix("/").unwrap_or(request.path())
201 {
202 return true;
203 }
204
205 Regex::new(rule.path()).is_ok_and(|regex| regex.is_match(request.path()))
207 })
208 .collect::<Vec<_>>();
209
210 if matching_rules
212 .iter()
213 .any(|rule| rule.reference_names().is_none())
214 {
215 return Ok(());
216 }
217
218 let format = match_format_from_query(&endpoint, request.query())?;
219 let query = convert_to_query(request, format)?;
220 let matching_restriction = matching_rules
221 .iter()
222 .flat_map(|rule| rule.reference_names().unwrap_or_default())
223 .find(|restriction| {
224 let name_match = Some(restriction.name()) == query.reference_name();
226 let format_match =
228 restriction.format().is_none() || restriction.format() == Some(query.format());
229 let interval_match = restriction.interval().contains_interval(query.interval());
231
232 name_match && format_match && interval_match
233 });
234
235 if matching_restriction.is_some() {
238 Ok(())
239 } else {
240 Err(HtsGetError::PermissionDenied(
241 "failed to authorize user based on authorization service restrictions".to_string(),
242 ))
243 }
244 }
245
246 pub async fn validate_jwt(&self, headers: &HeaderMap) -> HtsGetResult<TokenData<Value>> {
249 let auth_token = headers
250 .values()
251 .find_map(|value| Authorization::<Bearer>::decode(&mut [value].into_iter()).ok())
252 .ok_or_else(|| {
253 HtsGetError::InvalidAuthentication("invalid authorization header".to_string())
254 })?;
255
256 let decoding_key = if let Some(ref decoding_key) = self.decoding_key {
257 decoding_key
258 } else {
259 match self.config.auth_mode() {
260 AuthMode::Jwks(jwks) => &self.decode_jwks(jwks, auth_token.token()).await?,
261 AuthMode::PublicKey(public_key) => &Self::decode_public_key(public_key)?,
262 }
263 };
264
265 let mut validation = Validation::default();
267 validation.validate_exp = true;
268 validation.validate_aud = true;
269 validation.validate_nbf = true;
270
271 if let Some(iss) = self.config.validate_issuer() {
272 validation.set_issuer(iss);
273 validation.required_spec_claims.insert("iss".to_string());
274 }
275 if let Some(aud) = self.config.validate_audience() {
276 validation.set_audience(aud);
277 validation.required_spec_claims.insert("aud".to_string());
278 }
279 if let Some(sub) = self.config.validate_subject() {
280 validation.sub = Some(sub.to_string());
281 validation.required_spec_claims.insert("sub".to_string());
282 }
283
284 validation.algorithms = vec![Algorithm::RS256];
287 let decoded_claims = decode::<Value>(auth_token.token(), decoding_key, &validation)
288 .or_else(|_| {
289 validation.algorithms = vec![Algorithm::ES256];
290 decode::<Value>(auth_token.token(), decoding_key, &validation)
291 })
292 .or_else(|_| {
293 validation.algorithms = vec![Algorithm::EdDSA];
294 decode::<Value>(auth_token.token(), decoding_key, &validation)
295 });
296
297 let claims = match decoded_claims {
298 Ok(claims) => claims,
299 Err(err) => return Err(HtsGetError::PermissionDenied(format!("invalid JWT: {err}"))),
300 };
301
302 Ok(claims)
303 }
304
305 pub async fn validate_authorization(
313 &self,
314 request: Request,
315 endpoint: Endpoint,
316 ) -> HtsGetResult<()> {
317 let claims = self.validate_jwt(request.headers()).await?;
318 if self.config.authentication_only() {
319 return Ok(());
320 }
321
322 let restrictions = self
323 .query_authorization_service(claims.claims, request.headers())
324 .await?;
325 Self::validate_restrictions(restrictions, request, endpoint)
326 }
327
328 pub async fn authorize_request(
332 &self,
333 path: &str,
334 query: HashMap<String, String>,
335 headers: HeaderMap,
336 ) -> HtsGetResult<()> {
337 if let Some(reads) = path.strip_prefix("/reads") {
338 if reads.starts_with("/service-info") {
339 self.validate_jwt(&headers).await?;
340 return Ok(());
341 }
342
343 self
344 .validate_authorization(
345 Request::new(path.to_string(), query, headers),
346 Endpoint::Reads,
347 )
348 .await
349 } else if let Some(variants) = path.strip_prefix("/variants") {
350 if variants.starts_with("/service-info") {
351 self.validate_jwt(&headers).await?;
352 return Ok(());
353 }
354
355 self
356 .validate_authorization(
357 Request::new(path.to_string(), query, headers),
358 Endpoint::Variants,
359 )
360 .await
361 } else {
362 self.validate_jwt(&headers).await?;
363 Ok(())
364 }
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use htsget_config::config::advanced::HttpClient;
    use htsget_config::config::advanced::auth::response::{
        AuthorizationRestrictionsBuilder, AuthorizationRuleBuilder, ReferenceNameRestrictionBuilder,
    };
    use htsget_config::config::advanced::auth::{
        AuthConfigBuilder, AuthMode, AuthorizationRestrictions,
    };
    use htsget_config::types::Format;
    use htsget_test::util::generate_key_pair;
    use http::{HeaderMap, Uri};
    use std::collections::HashMap;

    #[test]
    fn auth_builder_missing_config() {
        assert!(matches!(
            AuthBuilder::default().build(),
            Err(AuthBuilderError(_))
        ));
    }

    #[test]
    fn auth_builder_success_with_public_key() {
        let (_, public_key) = generate_key_pair();

        let built = AuthBuilder::default()
            .with_config(create_test_auth_config(public_key))
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn validate_restrictions_rule_allows_all() {
        let rule = AuthorizationRuleBuilder::default()
            .path("/reads/sample1")
            .build()
            .unwrap();
        let restrictions = AuthorizationRestrictionsBuilder::default()
            .rule(rule)
            .build()
            .unwrap();

        assert!(validate_reads(restrictions, "/reads/sample1", HashMap::new()).is_ok());
    }

    #[test]
    fn validate_restrictions_exact_path_match() {
        let reference_restriction = ReferenceNameRestrictionBuilder::default()
            .name("chr1")
            .format(Format::Bam)
            .start(1000)
            .end(2000)
            .build()
            .unwrap();
        let rule = AuthorizationRuleBuilder::default()
            .path("/reads/sample1")
            .reference_name(reference_restriction)
            .build()
            .unwrap();
        let restrictions = AuthorizationRestrictionsBuilder::default()
            .rule(rule)
            .build()
            .unwrap();

        let query = query_params(&[
            ("referenceName", "chr1"),
            ("start", "1500"),
            ("end", "1800"),
            ("format", "BAM"),
        ]);
        assert!(validate_reads(restrictions, "/reads/sample1", query).is_ok());
    }

    #[test]
    fn validate_restrictions_regex_path_match() {
        let reference_restriction = ReferenceNameRestrictionBuilder::default()
            .name("chr1")
            .format(Format::Bam)
            .build()
            .unwrap();
        let rule = AuthorizationRuleBuilder::default()
            .path("/reads/sample(.+)")
            .reference_name(reference_restriction)
            .build()
            .unwrap();
        let restrictions = AuthorizationRestrictionsBuilder::default()
            .rule(rule)
            .build()
            .unwrap();

        let query = query_params(&[("referenceName", "chr1"), ("format", "BAM")]);
        assert!(validate_reads(restrictions, "/reads/sample123", query).is_ok());
    }

    #[test]
    fn validate_restrictions_reference_name_mismatch() {
        let reference_restriction = ReferenceNameRestrictionBuilder::default()
            .name("chr1")
            .format(Format::Bam)
            .build()
            .unwrap();
        let rule = AuthorizationRuleBuilder::default()
            .path("/reads/sample1")
            .reference_name(reference_restriction)
            .build()
            .unwrap();
        let restrictions = AuthorizationRestrictionsBuilder::default()
            .rule(rule)
            .build()
            .unwrap();

        let query = query_params(&[("referenceName", "chr2"), ("format", "BAM")]);
        assert!(validate_reads(restrictions, "/reads/sample1", query).is_err());
    }

    #[test]
    fn validate_restrictions_format_mismatch() {
        let reference_restriction = ReferenceNameRestrictionBuilder::default()
            .name("chr1")
            .format(Format::Bam)
            .build()
            .unwrap();
        let rule = AuthorizationRuleBuilder::default()
            .path("/reads/sample1")
            .reference_name(reference_restriction)
            .build()
            .unwrap();
        let restrictions = AuthorizationRestrictionsBuilder::default()
            .rule(rule)
            .build()
            .unwrap();

        let query = query_params(&[("referenceName", "chr1"), ("format", "CRAM")]);
        assert!(validate_reads(restrictions, "/reads/sample1", query).is_err());
    }

    #[test]
    fn validate_restrictions_interval_not_contained() {
        let reference_restriction = ReferenceNameRestrictionBuilder::default()
            .name("chr1")
            .format(Format::Bam)
            .start(1000)
            .end(2000)
            .build()
            .unwrap();
        let rule = AuthorizationRuleBuilder::default()
            .path("/reads/sample1")
            .reference_name(reference_restriction)
            .build()
            .unwrap();
        let restrictions = AuthorizationRestrictionsBuilder::default()
            .rule(rule)
            .build()
            .unwrap();

        let query = query_params(&[
            ("referenceName", "chr1"),
            ("start", "500"),
            ("end", "1500"),
        ]);
        assert!(validate_reads(restrictions, "/reads/sample1", query).is_err());
    }

    #[test]
    fn validate_restrictions_format_none_allows_any() {
        let reference_restriction = ReferenceNameRestrictionBuilder::default()
            .name("chr1")
            .build()
            .unwrap();
        let rule = AuthorizationRuleBuilder::default()
            .path("/reads/sample1")
            .reference_name(reference_restriction)
            .build()
            .unwrap();
        let restrictions = AuthorizationRestrictionsBuilder::default()
            .rule(rule)
            .build()
            .unwrap();

        let query = query_params(&[("referenceName", "chr1"), ("format", "CRAM")]);
        assert!(validate_reads(restrictions, "/reads/sample1", query).is_ok());
    }

    #[test]
    fn validate_restrictions_path_with_leading_slash() {
        let rule = AuthorizationRuleBuilder::default()
            .path("/reads/sample1")
            .build()
            .unwrap();
        let restrictions = AuthorizationRestrictionsBuilder::default()
            .rule(rule)
            .build()
            .unwrap();

        assert!(validate_reads(restrictions, "/reads/sample1", HashMap::new()).is_ok());
    }

    #[tokio::test]
    async fn validate_authorization_missing_auth_header() {
        let (auth, _) = create_mock_auth_with_restrictions();
        let request = create_test_request("/reads/sample1", HashMap::new());

        let err = auth
            .validate_authorization(request, Endpoint::Reads)
            .await
            .expect_err("a request without an authorization header must be rejected");
        assert!(matches!(err, HtsGetError::InvalidAuthentication(_)));
    }

    #[tokio::test]
    async fn validate_authorization_invalid_jwt_format() {
        let (auth, _) = create_mock_auth_with_restrictions();
        let request =
            create_request_with_auth_header("/reads/sample1", HashMap::new(), "invalid.jwt.token");

        let err = auth
            .validate_authorization(request, Endpoint::Reads)
            .await
            .expect_err("a malformed JWT must be rejected");
        assert!(matches!(err, HtsGetError::PermissionDenied(_)));
    }

    /// Run `Auth::validate_restrictions` for a reads request to `path` with `query`.
    fn validate_reads(
        restrictions: AuthorizationRestrictions,
        path: &str,
        query: HashMap<String, String>,
    ) -> HtsGetResult<()> {
        Auth::validate_restrictions(restrictions, create_test_request(path, query), Endpoint::Reads)
    }

    /// Build an owned query-parameter map from borrowed key/value pairs.
    fn query_params(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// Auth config using `public_key` and a single trusted authorization URL.
    fn create_test_auth_config(public_key: Vec<u8>) -> AuthConfig {
        AuthConfigBuilder::default()
            .auth_mode(AuthMode::PublicKey(public_key))
            .trusted_authorization_url(Uri::from_static("https://www.example.com"))
            .http_client(HttpClient::new(reqwest::Client::new()))
            .build()
            .unwrap()
    }

    /// Request to `path` with `query` and no headers.
    fn create_test_request(path: &str, query: HashMap<String, String>) -> Request {
        Request::new(path.to_string(), query, HeaderMap::new())
    }

    /// Request to `path` with `query` carrying `Authorization: Bearer <token>`.
    fn create_request_with_auth_header(
        path: &str,
        query: HashMap<String, String>,
        token: &str,
    ) -> Request {
        let mut headers = HeaderMap::new();
        headers.insert("authorization", format!("Bearer {token}").parse().unwrap());
        Request::new(path.to_string(), query, headers)
    }

    /// A public-key `Auth` plus a sample single-rule restriction set for `/reads/sample1`.
    fn create_mock_auth_with_restrictions() -> (Auth, AuthorizationRestrictions) {
        let (_, public_key) = generate_key_pair();

        let auth = AuthBuilder::default()
            .with_config(create_test_auth_config(public_key))
            .build()
            .unwrap();

        let reference_restriction = ReferenceNameRestrictionBuilder::default()
            .name("chr1")
            .format(Format::Bam)
            .start(1000)
            .end(2000)
            .build()
            .unwrap();
        let rule = AuthorizationRuleBuilder::default()
            .path("/reads/sample1")
            .reference_name(reference_restriction)
            .build()
            .unwrap();
        let restrictions = AuthorizationRestrictionsBuilder::default()
            .rule(rule)
            .build()
            .unwrap();

        (auth, restrictions)
    }
}