use crate::error::{Result as HtsGetResult, WrappedHtsGetError};
use crate::middleware::error::Error::AuthBuilderError;
use crate::middleware::error::Result;
use crate::{Endpoint, HtsGetError};
use cfg_if::cfg_if;
use headers::authorization::Bearer;
use headers::{Authorization, Header};
use htsget_config::config::advanced::CONTEXT_HEADER_PREFIX;
use htsget_config::config::advanced::auth::authorization::UrlOrStatic;
use htsget_config::config::advanced::auth::jwt::AuthMode;
use htsget_config::config::advanced::auth::response::AuthorizationRestrictionsBuilder;
use htsget_config::config::advanced::auth::{AuthConfig, AuthorizationRestrictions};
use htsget_config::config::location::{Location, PrefixOrId};
use htsget_config::types::{Class, Interval, Query};
use http::{HeaderMap, HeaderName, HeaderValue, Uri};
use jsonpath_rust::JsonPath;
use jsonwebtoken::jwk::JwkSet;
use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation, decode, decode_header};
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::fmt::{Debug, Formatter};
use std::str::FromStr;
use tracing::{debug, trace};
/// Builder for [`Auth`].
///
/// A config must be supplied with [`AuthBuilder::with_config`] before calling
/// [`AuthBuilder::build`], which otherwise fails with `AuthBuilderError`.
#[derive(Default, Debug)]
pub struct AuthBuilder {
// The auth configuration; `None` until `with_config` is called.
config: Option<AuthConfig>,
}
impl AuthBuilder {
  /// Sets the auth configuration that [`AuthBuilder::build`] will use.
  pub fn with_config(mut self, config: AuthConfig) -> Self {
    self.config = Some(config);
    self
  }

  /// Builds an [`Auth`] from the configured [`AuthConfig`].
  ///
  /// When the auth mode is a public key, the key is decoded here once and stored on the
  /// resulting `Auth`, so JWT validation can reuse it instead of decoding it per request.
  ///
  /// # Errors
  ///
  /// Returns `AuthBuilderError` if no config was set, or if the configured public key cannot
  /// be decoded as an RSA, Ed25519 or EC PEM key (the underlying cause is included).
  pub fn build(self) -> Result<Auth> {
    let Some(config) = self.config else {
      return Err(AuthBuilderError("missing config".to_string()));
    };

    // Only a shared borrow is needed to read the key; the config is moved into `Auth` after.
    let decoding_key = match config.auth_mode() {
      Some(AuthMode::PublicKey(public_key)) => Some(
        Auth::decode_public_key(public_key)
          .map_err(|err| AuthBuilderError(format!("failed to decode public key: {err}")))?,
      ),
      _ => None,
    };

    Ok(Auth {
      config,
      decoding_key,
    })
  }
}
/// Authentication and authorization logic built from an [`AuthConfig`] via [`AuthBuilder`].
#[derive(Clone)]
pub struct Auth {
config: AuthConfig,
// Set by `AuthBuilder::build` when the auth mode is a public key; `validate_jwt` uses it when
// present and otherwise resolves a key from the configured auth mode on each call.
decoding_key: Option<DecodingKey>,
}
impl Debug for Auth {
  /// Formats as `Auth { .. }`.
  ///
  /// The config and decoding key are omitted from the output. The struct name was previously
  /// printed as `config`, which made debug output misleading.
  fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("Auth").finish_non_exhaustive()
  }
}
/// Header name (appended to `CONTEXT_HEADER_PREFIX`) used to forward the request's endpoint type.
const ENDPOINT_TYPE_HEADER_NAME: &str = "Endpoint-Type";
/// Header name (appended to `CONTEXT_HEADER_PREFIX`) used to forward the requested ID.
const ID_HEADER_NAME: &str = "Id";
impl Auth {
/// Returns the auth configuration this instance was built with.
pub fn config(&self) -> &AuthConfig {
&self.config
}
pub async fn fetch_from_url<D: DeserializeOwned>(
&mut self,
url: &str,
headers: HeaderMap,
) -> HtsGetResult<D> {
trace!("fetching url: {}", url);
let response = self
.config
.http_client()
.map_err(|err| HtsGetError::InternalError(format!("failed to fetch data from {url}: {err}")))?
.get(url)
.headers(headers)
.send()
.await?;
trace!("response: {:?}", response);
let status = response.status();
let value = response.json::<Value>().await.map_err(|err| {
HtsGetError::InternalError(format!("failed to fetch data from {url}: {err}"))
})?;
trace!("value: {}", value);
match serde_json::from_value::<D>(value.clone()) {
Ok(response) => Ok(response),
Err(_) => match serde_json::from_value::<WrappedHtsGetError>(value.clone()) {
Ok(err) => Err(HtsGetError::Wrapped(err, status)),
Err(_) => Err(HtsGetError::InternalError(format!(
"failed to fetch data from {url}: {value}"
))),
},
}
}
pub async fn decode_jwks(&mut self, jwks_url: &Uri, token: &str) -> HtsGetResult<DecodingKey> {
let header = decode_header(token)?;
let kid = header
.kid
.ok_or_else(|| HtsGetError::PermissionDenied("JWT missing key ID".to_string()))?;
let jwks = self
.fetch_from_url::<JwkSet>(&jwks_url.to_string(), Default::default())
.await?;
let matched_jwk = jwks
.find(&kid)
.ok_or_else(|| HtsGetError::PermissionDenied("matching JWK not found".to_string()))?;
Ok(DecodingKey::from_jwk(matched_jwk)?)
}
/// Decodes a PEM-encoded public key, trying RSA, then Ed25519, then EC.
///
/// # Errors
///
/// If none of the formats parse, the error from the final (EC) attempt is returned.
pub fn decode_public_key(key: &[u8]) -> HtsGetResult<DecodingKey> {
  if let Ok(rsa) = DecodingKey::from_rsa_pem(key) {
    return Ok(rsa);
  }
  if let Ok(ed) = DecodingKey::from_ed_pem(key) {
    return Ok(ed);
  }
  Ok(DecodingKey::from_ec_pem(key)?)
}
pub fn forwarded_headers(
&self,
request_headers: &HeaderMap,
request_extensions: Option<Value>,
request_endpoint: &Endpoint,
id: &str,
) -> HtsGetResult<HeaderMap> {
let mut forwarded_headers = if self.config.passthrough_auth() {
let auth_header = request_headers
.iter()
.find_map(|(name, value)| {
if Authorization::<Bearer>::decode(&mut [value].into_iter()).is_ok() {
return Some((name.clone(), value.clone()));
}
None
})
.ok_or_else(|| HtsGetError::PermissionDenied("missing authorization header".to_string()))?;
HeaderMap::from_iter([auth_header])
} else {
HeaderMap::default()
};
for header in self.config.forward_headers() {
let Some((existing_name, existing_value)) = request_headers
.iter()
.find_map(|(name, value)| {
if header.to_lowercase() == name.as_str().to_lowercase() {
return match HeaderName::from_str(&format!("{}{}", CONTEXT_HEADER_PREFIX, name)) {
Ok(header) => Some(Ok((header, value))),
Err(err) => Some(Err(HtsGetError::InternalError(err.to_string()))),
};
}
None
})
.transpose()?
else {
continue;
};
forwarded_headers.insert(existing_name, existing_value.clone());
}
if let Some(request_extensions) = request_extensions {
for extension in self.config.forward_extensions() {
let Some(value) = request_extensions.query(extension.json_path()).ok() else {
continue;
};
let value = value.first().ok_or_else(|| {
HtsGetError::InternalError("extension does not have only one value".to_string())
})?;
let value = value.as_str().ok_or_else(|| {
HtsGetError::InternalError("extension value is not a string".to_string())
})?;
let header_name =
HeaderName::from_str(&format!("{}{}", CONTEXT_HEADER_PREFIX, extension.name()))?;
let value = HeaderValue::from_str(value)?;
forwarded_headers.insert(header_name, value);
}
}
if self.config.forward_endpoint_type() {
let header_name = HeaderName::from_str(&format!(
"{}{}",
CONTEXT_HEADER_PREFIX, ENDPOINT_TYPE_HEADER_NAME
))?;
let value = HeaderValue::from_str(&request_endpoint.to_string())?;
forwarded_headers.insert(header_name, value);
}
if self.config.forward_id() {
let header_name =
HeaderName::from_str(&format!("{}{}", CONTEXT_HEADER_PREFIX, ID_HEADER_NAME))?;
let value = HeaderValue::from_str(id)?;
forwarded_headers.insert(header_name, value);
}
Ok(forwarded_headers)
}
/// Obtains the authorization restrictions that apply to a request.
///
/// - A URL authorization source is queried with the headers built by
///   [`Auth::forwarded_headers`].
/// - A static source is cloned and returned as-is.
/// - Otherwise `Ok(None)` is returned.
pub async fn query_authorization_service(
  &mut self,
  headers: &HeaderMap,
  request_extensions: Option<Value>,
  request_endpoint: &Endpoint,
  id: &str,
) -> HtsGetResult<Option<AuthorizationRestrictions>> {
  match self.config.authorization_url() {
    Some(UrlOrStatic::Static(restrictions)) => Ok(Some(restrictions.clone())),
    Some(UrlOrStatic::Url(uri)) => {
      let url = uri.to_string();
      let forwarded = self.forwarded_headers(headers, request_extensions, request_endpoint, id)?;
      let restrictions = self
        .fetch_from_url::<AuthorizationRestrictions>(&url, forwarded)
        .await?;
      Ok(Some(restrictions))
    }
    _ => Ok(None),
  }
}
/// Checks `queries` against the authorization `restrictions` that apply to `path`.
///
/// A rule applies to `path` when its location is either:
/// - a simple location whose ID equals `path` or whose prefix starts `path`, or
/// - a regex location matching `path`.
///
/// If no rule applies, access is denied. If any applicable rule carries no reference-name
/// restrictions, every query is permitted unchanged. Otherwise each non-header query must match
/// a restriction on reference name, format and interval:
/// - when `suppressed_interval` is `false`, a query with no matching restriction (interval
///   containment via `contains_interval`) is rejected;
/// - when `suppressed_interval` is `true`, the query's interval is replaced by the matching
///   `constraint_interval` ranked highest by `Interval::order_by_range`. If nothing matches,
///   the query is downgraded to a header request.
///
/// Returns the applicable rules, with unrestricted rules listed first.
///
/// # Errors
///
/// - `PermissionDenied` if no rule applies, or if a query is not permitted and intervals are
///   not suppressed.
/// - `InternalError` if rebuilding the restrictions fails.
pub fn validate_restrictions(
  restrictions: AuthorizationRestrictions,
  path: &str,
  queries: &mut [Query],
  suppressed_interval: bool,
) -> HtsGetResult<AuthorizationRestrictions> {
  let permission_denied = || {
    HtsGetError::PermissionDenied(
      "failed to authorize user based on authorization service restrictions".to_string(),
    )
  };

  let matching_rules = restrictions
    .into_rules()
    .into_iter()
    .filter(|rule| match rule.location() {
      Location::Simple(location) if location.prefix_or_id().is_some() => {
        match location.prefix_or_id().unwrap_or_default() {
          PrefixOrId::Prefix(prefix) => path.starts_with(&prefix),
          PrefixOrId::Id(id) => id == path,
        }
      }
      Location::Regex(location) => location.regex().is_match(path),
      _ => false,
    })
    .collect::<Vec<_>>();
  if matching_rules.is_empty() {
    return Err(permission_denied());
  }

  let (allows_all, allows_specific): (Vec<_>, Vec<_>) = matching_rules
    .into_iter()
    .partition(|rule| rule.rules().is_none());

  // An unrestricted rule grants access to the whole file, so queries are only checked (and,
  // when suppressing, only narrowed) when no such rule applies. This keeps suppressed mode
  // consistent with the non-suppressed path, where an unrestricted rule always permits the
  // query.
  if allows_all.is_empty() {
    for query in queries.iter_mut() {
      // Header requests are permitted once any rule applies to the path.
      if query.class() == Class::Header {
        continue;
      }

      let matching_restriction = allows_specific
        .iter()
        .flat_map(|rule| rule.rules().unwrap_or_default())
        .filter_map(|restriction| {
          let name_match = restriction.reference_name().is_none()
            || restriction.reference_name() == query.reference_name();
          let format_match =
            restriction.format().is_none() || restriction.format() == Some(query.format());
          if !name_match || !format_match {
            return None;
          }
          if suppressed_interval {
            restriction.interval().constraint_interval(query.interval())
          } else {
            restriction.interval().contains_interval(query.interval())
          }
        })
        .max_by(Interval::order_by_range);

      match matching_restriction {
        Some(interval) if suppressed_interval => {
          query.set_interval(interval);
        }
        Some(_) => {}
        None if suppressed_interval => {
          query.set_class(Class::Header);
        }
        None => return Err(permission_denied()),
      }
    }
  }

  AuthorizationRestrictionsBuilder::default()
    .rules([allows_all, allows_specific].concat())
    .build()
    .map_err(|err| HtsGetError::InternalError(err.to_string()))
}
pub async fn validate_jwt(&mut self, headers: &HeaderMap) -> HtsGetResult<TokenData<Value>> {
let auth_token = headers
.values()
.find_map(|value| Authorization::<Bearer>::decode(&mut [value].into_iter()).ok())
.ok_or_else(|| {
HtsGetError::InvalidAuthentication("invalid authorization header".to_string())
})?;
let decoding_key = if let Some(ref decoding_key) = self.decoding_key {
decoding_key
} else if matches!(self.config.auth_mode(), Some(AuthMode::Jwks(_))) {
let url = if let Some(AuthMode::Jwks(uri)) = self.config.auth_mode() {
uri.clone()
} else {
return Err(HtsGetError::InternalError(
"JWT validation not set".to_string(),
));
};
&self.decode_jwks(&url, auth_token.token()).await?
} else if let Some(AuthMode::PublicKey(key)) = self.config.auth_mode() {
&Self::decode_public_key(key)?
} else {
return Err(HtsGetError::InternalError(
"JWT validation not set".to_string(),
));
};
let mut validation = Validation::default();
validation.validate_exp = true;
validation.validate_aud = true;
validation.validate_nbf = true;
if let Some(iss) = self.config.validate_issuer() {
validation.set_issuer(iss);
validation.required_spec_claims.insert("iss".to_string());
}
if let Some(aud) = self.config.validate_audience() {
validation.set_audience(aud);
validation.required_spec_claims.insert("aud".to_string());
}
if let Some(sub) = self.config.validate_subject() {
validation.sub = Some(sub.to_string());
validation.required_spec_claims.insert("sub".to_string());
}
validation.algorithms = vec![Algorithm::RS256];
let decoded_claims = decode::<Value>(auth_token.token(), decoding_key, &validation)
.or_else(|_| {
validation.algorithms = vec![Algorithm::ES256];
decode::<Value>(auth_token.token(), decoding_key, &validation)
})
.or_else(|_| {
validation.algorithms = vec![Algorithm::EdDSA];
decode::<Value>(auth_token.token(), decoding_key, &validation)
});
let claims = match decoded_claims {
Ok(claims) => claims,
Err(err) => return Err(HtsGetError::PermissionDenied(format!("invalid JWT: {err}"))),
};
Ok(claims)
}
/// Queries the authorization source and validates `queries` against the result.
///
/// The request `path` is also passed as the ID forwarded to the authorization service. When
/// no restrictions are returned, `Ok(None)` is returned. Interval suppression is only taken
/// from the config when the `experimental` feature is enabled; otherwise it is disabled.
pub async fn validate_authorization(
  &mut self,
  headers: &HeaderMap,
  path: &str,
  queries: &mut [Query],
  request_extensions: Option<Value>,
  endpoint: &Endpoint,
) -> HtsGetResult<Option<AuthorizationRestrictions>> {
  let restrictions = self
    .query_authorization_service(headers, request_extensions, endpoint, path)
    .await?;
  debug!(restrictions = ?restrictions, "restrictions");

  let Some(restrictions) = restrictions else {
    return Ok(None);
  };

  cfg_if! {
    if #[cfg(feature = "experimental")] {
      Self::validate_restrictions(restrictions, path, queries, self.config.suppress_errors()).map(Some)
    } else {
      Self::validate_restrictions(restrictions, path, queries, false).map(Some)
    }
  }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Endpoint, convert_to_query, match_format_from_query};
use htsget_config::config::advanced::HttpClient;
use htsget_config::config::advanced::auth::AuthConfigBuilder;
use htsget_config::config::advanced::auth::authorization::ForwardExtensions;
use htsget_config::config::advanced::auth::response::{
AuthorizationRestrictionsBuilder, AuthorizationRuleBuilder, ReferenceNameRestrictionBuilder,
};
use htsget_config::config::advanced::regex_location::RegexLocation;
use htsget_config::config::location::SimpleLocation;
use htsget_config::types::{Format, Request};
use htsget_test::util::generate_key_pair;
use http::{HeaderMap, Uri};
use regex::Regex;
use reqwest_middleware::ClientBuilder;
use serde_json::json;
use std::collections::HashMap;
#[test]
fn auth_builder_missing_config() {
  // Building without a config must fail with the builder error.
  let built = AuthBuilder::default().build();
  assert!(matches!(built, Err(AuthBuilderError(_))));
}
#[test]
fn auth_builder_success_with_public_key() {
  // A freshly generated public key must decode during `build`.
  let (_, public_key) = generate_key_pair();
  let built = AuthBuilder::default()
    .with_config(create_test_auth_config(public_key))
    .build();
  assert!(built.is_ok());
}
#[test]
fn validate_restrictions_rule_allows_all() {
  // A rule with no reference-name restrictions permits any query at its location.
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(
      AuthorizationRuleBuilder::default()
        .location(test_location())
        .build()
        .unwrap(),
    )
    .build()
    .unwrap();
  let request = create_test_query(Endpoint::Reads, "sample1", HashMap::new());
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, false);

  assert!(result.is_ok());
}
#[test]
fn validate_restrictions_exact_path_match() {
  // chr1:1500-1800 (BAM) lies inside the allowed chr1:1000-2000 (BAM) window.
  let reference_restriction = ReferenceNameRestrictionBuilder::default()
    .name("chr1")
    .format(Format::Bam)
    .start(1000)
    .end(2000)
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(
      AuthorizationRuleBuilder::default()
        .location(test_location())
        .reference_name(reference_restriction)
        .build()
        .unwrap(),
    )
    .build()
    .unwrap();
  let query = HashMap::from([
    ("referenceName".to_string(), "chr1".to_string()),
    ("start".to_string(), "1500".to_string()),
    ("end".to_string(), "1800".to_string()),
    ("format".to_string(), "BAM".to_string()),
  ]);
  let request = create_test_query(Endpoint::Reads, "sample1", query);
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, false);

  assert!(result.is_ok());
}
#[test]
fn validate_restrictions_regex_prefix_match() {
  // Despite the name, this exercises a simple location with the ID prefix `sam`.
  let location = Location::Simple(Box::new(SimpleLocation::new(
    Default::default(),
    String::new(),
    Some(PrefixOrId::Prefix("sam".to_string())),
  )));
  let reference_restriction = ReferenceNameRestrictionBuilder::default()
    .name("chr1")
    .format(Format::Bam)
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(
      AuthorizationRuleBuilder::default()
        .location(location)
        .reference_name(reference_restriction)
        .build()
        .unwrap(),
    )
    .build()
    .unwrap();
  let query = HashMap::from([
    ("referenceName".to_string(), "chr1".to_string()),
    ("format".to_string(), "BAM".to_string()),
  ]);
  let request = create_test_query(Endpoint::Reads, "sample123", query);
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, false);

  assert!(result.is_ok());
}
#[test]
fn validate_restrictions_regex_match() {
  // A regex location `sample(.+)` applies to the ID `sample123`.
  let location = Location::Regex(Box::new(RegexLocation::new(
    Regex::new("sample(.+)").unwrap(),
    String::new(),
    Default::default(),
    Default::default(),
  )));
  let reference_restriction = ReferenceNameRestrictionBuilder::default()
    .name("chr1")
    .format(Format::Bam)
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(
      AuthorizationRuleBuilder::default()
        .location(location)
        .reference_name(reference_restriction)
        .build()
        .unwrap(),
    )
    .build()
    .unwrap();
  let query = HashMap::from([
    ("referenceName".to_string(), "chr1".to_string()),
    ("format".to_string(), "BAM".to_string()),
  ]);
  let request = create_test_query(Endpoint::Reads, "sample123", query);
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, false);

  assert!(result.is_ok());
}
#[test]
fn validate_restrictions_forward_headers() {
  let (_, public_key) = generate_key_pair();
  let builder = AuthConfigBuilder::default()
    .auth_mode(AuthMode::PublicKey(public_key))
    .authorization_url(UrlOrStatic::Url(Uri::from_static(
      "https://www.example.com",
    )))
    .http_client(HttpClient::new(
      ClientBuilder::new(reqwest::Client::new()).build(),
    ));
  let request_headers = HeaderMap::from_iter([
    (
      "Authorization".parse().unwrap(),
      "Bearer Value".parse().unwrap(),
    ),
    ("Custom1".parse().unwrap(), "Value".parse().unwrap()),
    ("Custom2".parse().unwrap(), "Value".parse().unwrap()),
  ]);

  // Builds an `Auth` from `config` and returns what it forwards for `request_headers`.
  let forward = |config: AuthConfig, extensions: Option<Value>, endpoint: &Endpoint| {
    AuthBuilder::default()
      .with_config(config)
      .build()
      .unwrap()
      .forwarded_headers(&request_headers, extensions, endpoint, "id")
      .unwrap()
  };
  // Parses `suffix` as a header name carrying the context prefix.
  let context = |suffix: &str| -> HeaderName {
    format!("{CONTEXT_HEADER_PREFIX}{suffix}").parse().unwrap()
  };

  // Passthrough auth plus one forwarded custom header.
  let config = builder
    .clone()
    .passthrough_auth(true)
    .forward_headers(vec!["Custom1".to_string()])
    .build()
    .unwrap();
  assert_eq!(
    forward(config, None, &Endpoint::Reads),
    HeaderMap::from_iter([
      (context("Custom1"), "Value".parse().unwrap()),
      (
        "Authorization".parse().unwrap(),
        "Bearer Value".parse().unwrap()
      ),
    ])
  );

  // Forwarding `Authorization` itself adds a prefixed copy next to the passthrough header.
  let config = builder
    .clone()
    .passthrough_auth(true)
    .forward_headers(vec!["Custom1".to_string(), "Authorization".to_string()])
    .build()
    .unwrap();
  assert_eq!(
    forward(config, None, &Endpoint::Reads),
    HeaderMap::from_iter([
      (context("Custom1"), "Value".parse().unwrap()),
      (context("Authorization"), "Bearer Value".parse().unwrap()),
      (
        "Authorization".parse().unwrap(),
        "Bearer Value".parse().unwrap()
      ),
    ])
  );

  // Without passthrough only the prefixed custom header is sent.
  let config = builder
    .clone()
    .forward_headers(vec!["Custom1".to_string()])
    .build()
    .unwrap();
  assert_eq!(
    forward(config, None, &Endpoint::Reads),
    HeaderMap::from_iter([(context("Custom1"), "Value".parse().unwrap())])
  );

  // Extensions are read from the request JSON by path.
  let config = builder
    .clone()
    .forward_extensions(vec![ForwardExtensions::new(
      "$.Key".to_string(),
      "Custom1".to_string(),
    )])
    .build()
    .unwrap();
  assert_eq!(
    forward(config, Some(json!({ "Key": "Value" })), &Endpoint::Reads),
    HeaderMap::from_iter([(context("Custom1"), "Value".parse().unwrap())])
  );

  // Endpoint type forwarding.
  let config = builder.clone().forward_endpoint_type(true).build().unwrap();
  assert_eq!(
    forward(config, None, &Endpoint::Variants),
    HeaderMap::from_iter([(
      context(ENDPOINT_TYPE_HEADER_NAME),
      "variants".parse().unwrap()
    )])
  );

  // ID forwarding.
  let config = builder.forward_id(true).build().unwrap();
  assert_eq!(
    forward(config, None, &Endpoint::Variants),
    HeaderMap::from_iter([(context(ID_HEADER_NAME), "id".parse().unwrap())])
  );
}
#[test]
fn validate_restrictions_reference_name_mismatch() {
  // A body query for a reference name that no restriction allows must be rejected when errors
  // are not suppressed. This previously sent a header-class query, which made it a duplicate of
  // `validate_restrictions_header` and never exercised a mismatch.
  let reference_restriction = ReferenceNameRestrictionBuilder::default()
    .name("chr1")
    .format(Format::Bam)
    .build()
    .unwrap();
  let rule = AuthorizationRuleBuilder::default()
    .location(test_location())
    .reference_name(reference_restriction)
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(rule)
    .build()
    .unwrap();
  let query = HashMap::from([
    ("referenceName".to_string(), "chr2".to_string()),
    ("format".to_string(), "BAM".to_string()),
  ]);
  let request = create_test_query(Endpoint::Reads, "sample1", query);
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, false);

  assert!(matches!(result, Err(HtsGetError::PermissionDenied(_))));
}
#[test]
fn validate_restrictions_header() {
  // Header-class queries bypass the reference-name checks once the path matches.
  let reference_restriction = ReferenceNameRestrictionBuilder::default()
    .name("chr1")
    .format(Format::Bam)
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(
      AuthorizationRuleBuilder::default()
        .location(test_location())
        .reference_name(reference_restriction)
        .build()
        .unwrap(),
    )
    .build()
    .unwrap();
  let query = HashMap::from([
    ("format".to_string(), "BAM".to_string()),
    ("class".to_string(), "header".to_string()),
  ]);
  let request = create_test_query(Endpoint::Reads, "sample1", query);
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, false);

  assert!(result.is_ok());
}
#[cfg(feature = "experimental")]
#[test]
fn validate_restrictions_reference_name_mismatch_suppressed() {
  // With suppression, a disallowed reference name downgrades the query to a header request
  // instead of failing. The query array is kept so the downgrade itself can be asserted.
  let reference_restriction = ReferenceNameRestrictionBuilder::default()
    .name("chr1")
    .format(Format::Bam)
    .build()
    .unwrap();
  let rule = AuthorizationRuleBuilder::default()
    .location(test_location())
    .reference_name(reference_restriction)
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(rule)
    .build()
    .unwrap();
  let query = HashMap::from([
    ("referenceName".to_string(), "chr2".to_string()),
    ("format".to_string(), "BAM".to_string()),
  ]);
  let request = create_test_query(Endpoint::Reads, "sample1", query);
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, true);

  assert!(result.is_ok());
  assert_eq!(queries[0].class(), Class::Header);
}
#[test]
fn validate_restrictions_format_mismatch() {
  // The restriction only allows BAM, so a CRAM query is rejected without suppression.
  let reference_restriction = ReferenceNameRestrictionBuilder::default()
    .name("chr1")
    .format(Format::Bam)
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(
      AuthorizationRuleBuilder::default()
        .location(test_location())
        .reference_name(reference_restriction)
        .build()
        .unwrap(),
    )
    .build()
    .unwrap();
  let query = HashMap::from([
    ("referenceName".to_string(), "chr1".to_string()),
    ("format".to_string(), "CRAM".to_string()),
  ]);
  let request = create_test_query(Endpoint::Reads, "sample1", query);
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, false);

  assert!(result.is_err());
}
#[cfg(feature = "experimental")]
#[test]
fn validate_restrictions_format_mismatch_suppressed() {
  // With suppression, a disallowed format downgrades the query to a header request instead of
  // failing. The query array is kept so the downgrade itself can be asserted.
  let reference_restriction = ReferenceNameRestrictionBuilder::default()
    .name("chr1")
    .format(Format::Bam)
    .build()
    .unwrap();
  let rule = AuthorizationRuleBuilder::default()
    .location(test_location())
    .reference_name(reference_restriction)
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(rule)
    .build()
    .unwrap();
  let query = HashMap::from([
    ("referenceName".to_string(), "chr1".to_string()),
    ("format".to_string(), "CRAM".to_string()),
  ]);
  let request = create_test_query(Endpoint::Reads, "sample1", query);
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, true);

  assert!(result.is_ok());
  assert_eq!(queries[0].class(), Class::Header);
}
#[test]
fn validate_restrictions_interval_not_contained() {
  // Each case is (restriction start, restriction end, request start, request end, expect error).
  // Without suppression the query's interval and class must be left exactly as requested.
  type Case = (Option<u32>, Option<u32>, Option<u32>, Option<u32>, bool);
  let cases: [Case; 17] = [
    (Some(1000), Some(2000), Some(1250), Some(1750), false),
    (Some(1000), Some(2000), Some(500), None, true),
    (Some(1000), Some(2000), None, Some(2500), true),
    (Some(1000), Some(2000), None, None, true),
    (Some(1000), Some(2000), Some(500), Some(1500), true),
    (Some(1000), Some(2000), None, Some(1500), true),
    (Some(1000), Some(2000), Some(1500), Some(2500), true),
    (Some(1000), Some(2000), Some(1500), None, true),
    (Some(1000), Some(2000), Some(500), Some(1000), true),
    (Some(1000), Some(2000), Some(2000), Some(2500), true),
    (None, Some(2000), Some(500), Some(1500), false),
    (None, Some(2000), Some(1500), Some(2500), true),
    (Some(1000), None, Some(1500), Some(2500), false),
    (Some(1000), None, Some(500), Some(1500), true),
    (None, None, Some(500), Some(2500), false),
    (None, None, Some(500), None, false),
    (None, None, None, Some(2500), false),
  ];

  for (restrict_start, restrict_end, request_start, request_end, is_err) in cases {
    test_interval_suppressed(
      restrict_start,
      restrict_end,
      request_start,
      request_end,
      (Interval::new(request_start, request_end), Class::Body),
      is_err,
      false,
    );
  }
}
#[cfg(feature = "experimental")]
#[test]
fn validate_restrictions_interval_suppressed() {
  // Each case is (restriction start, restriction end, request start, request end,
  // expected interval, expected class). With suppression validation never fails; the interval
  // is clamped to the restriction, or the query becomes header-only when nothing overlaps.
  type Case = (
    Option<u32>,
    Option<u32>,
    Option<u32>,
    Option<u32>,
    Interval,
    Class,
  );
  let cases: [Case; 17] = [
    (Some(1000), Some(2000), Some(1250), Some(1750), Interval::new(Some(1250), Some(1750)), Class::Body),
    (Some(1000), Some(2000), Some(500), None, Interval::new(Some(1000), Some(2000)), Class::Body),
    (Some(1000), Some(2000), None, Some(2500), Interval::new(Some(1000), Some(2000)), Class::Body),
    (Some(1000), Some(2000), None, None, Interval::new(Some(1000), Some(2000)), Class::Body),
    (Some(1000), Some(2000), Some(500), Some(1500), Interval::new(Some(1000), Some(1500)), Class::Body),
    (Some(1000), Some(2000), None, Some(1500), Interval::new(Some(1000), Some(1500)), Class::Body),
    (Some(1000), Some(2000), Some(1500), Some(2500), Interval::new(Some(1500), Some(2000)), Class::Body),
    (Some(1000), Some(2000), Some(1500), None, Interval::new(Some(1500), Some(2000)), Class::Body),
    (Some(1000), Some(2000), Some(500), Some(1000), Interval::new(Some(500), Some(1000)), Class::Header),
    (Some(1000), Some(2000), Some(2000), Some(2500), Interval::new(Some(2000), Some(2500)), Class::Header),
    (None, Some(2000), Some(500), Some(1500), Interval::new(Some(500), Some(1500)), Class::Body),
    (None, Some(2000), Some(1500), Some(2500), Interval::new(Some(1500), Some(2000)), Class::Body),
    (Some(1000), None, Some(1500), Some(2500), Interval::new(Some(1500), Some(2500)), Class::Body),
    (Some(1000), None, Some(500), Some(1500), Interval::new(Some(1000), Some(1500)), Class::Body),
    (None, None, Some(500), Some(2500), Interval::new(Some(500), Some(2500)), Class::Body),
    (None, None, Some(500), None, Interval::new(Some(500), None), Class::Body),
    (None, None, None, Some(2500), Interval::new(None, Some(2500)), Class::Body),
  ];

  for (restrict_start, restrict_end, request_start, request_end, interval, class) in cases {
    test_interval_suppressed(
      restrict_start,
      restrict_end,
      request_start,
      request_end,
      (interval, class),
      false,
      true,
    );
  }
}
#[test]
fn validate_restrictions_format_none_allows_any() {
  // A restriction without a format accepts any format, here CRAM.
  let reference_restriction = ReferenceNameRestrictionBuilder::default()
    .name("chr1")
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(
      AuthorizationRuleBuilder::default()
        .location(test_location())
        .reference_name(reference_restriction)
        .build()
        .unwrap(),
    )
    .build()
    .unwrap();
  let query = HashMap::from([
    ("referenceName".to_string(), "chr1".to_string()),
    ("format".to_string(), "CRAM".to_string()),
  ]);
  let request = create_test_query(Endpoint::Reads, "sample1", query);
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, false);

  assert!(result.is_ok());
}
#[test]
fn validate_restrictions_path_with_leading_slash() {
  // NOTE(review): despite the name, the request ID is `sample1` with no leading slash, so this
  // currently covers the same case as `validate_restrictions_rule_allows_all`.
  let rule = AuthorizationRuleBuilder::default()
    .location(test_location())
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(rule)
    .build()
    .unwrap();
  let request = create_test_query(Endpoint::Reads, "sample1", HashMap::new());
  let mut queries = [request.clone()];

  let result = Auth::validate_restrictions(restrictions, request.id(), &mut queries, false);

  assert!(result.is_ok());
}
#[tokio::test]
async fn validate_authorization_missing_auth_header() {
  // No bearer header at all is an authentication failure, not a permission failure.
  let mut auth = create_mock_auth_with_restrictions();
  let request = Request::new("sample1".to_string(), HashMap::new(), HeaderMap::new());

  let result = auth.validate_jwt(request.headers()).await;

  assert!(matches!(result, Err(HtsGetError::InvalidAuthentication(_))));
}
#[tokio::test]
async fn validate_authorization_invalid_jwt_format() {
  // A bearer value that is not a decodable JWT is rejected with `PermissionDenied`.
  let mut auth = create_mock_auth_with_restrictions();
  let request = create_request_with_auth_header("sample1", HashMap::new(), "invalid.jwt.token");

  let result = auth.validate_jwt(request.headers()).await;

  assert!(matches!(result, Err(HtsGetError::PermissionDenied(_))));
}
/// Builds a public-key auth config pointing at a dummy authorization URL.
fn create_test_auth_config(public_key: Vec<u8>) -> AuthConfig {
  let authorization_url = UrlOrStatic::Url(Uri::from_static("https://www.example.com"));
  let http_client = HttpClient::new(ClientBuilder::new(reqwest::Client::new()).build());
  AuthConfigBuilder::default()
    .auth_mode(AuthMode::PublicKey(public_key))
    .authorization_url(authorization_url)
    .http_client(http_client)
    .build()
    .unwrap()
}
/// Converts an ID and query parameters into a [`Query`] for `endpoint`, panicking on failure.
fn create_test_query(endpoint: Endpoint, path: &str, query: HashMap<String, String>) -> Query {
  let request = Request::new(path.to_owned(), query, HeaderMap::default());
  let format = match_format_from_query(&endpoint, request.query()).unwrap();
  convert_to_query(request, format).unwrap()
}
fn create_request_with_auth_header(
path: &str,
query: HashMap<String, String>,
token: &str,
) -> Request {
let mut headers = HeaderMap::new();
headers.insert("authorization", format!("Bearer {token}").parse().unwrap());
Request::new(path.to_string(), query, headers)
}
/// Builds an [`Auth`] configured with a freshly generated public key.
fn create_mock_auth_with_restrictions() -> Auth {
  let (_, public_key) = generate_key_pair();
  AuthBuilder::default()
    .with_config(create_test_auth_config(public_key))
    .build()
    .unwrap()
}
/// Validates one `chr1` query for `sample1` against a BAM restriction on `chr1` and checks the
/// outcome.
///
/// The restriction spans `restrict_start..restrict_end` and the query spans
/// `request_start..request_end` (`None` leaves a bound open). The test asserts whether
/// validation errors (`is_err`) and the query's resulting interval and class
/// (`expected_response`), with interval suppression set by `suppress_interval`.
fn test_interval_suppressed(
  restrict_start: Option<u32>,
  restrict_end: Option<u32>,
  request_start: Option<u32>,
  request_end: Option<u32>,
  expected_response: (Interval, Class),
  is_err: bool,
  suppress_interval: bool,
) {
  let mut reference_restriction = ReferenceNameRestrictionBuilder::default()
    .name("chr1")
    .format(Format::Bam);
  if let Some(start) = restrict_start {
    reference_restriction = reference_restriction.start(start);
  }
  if let Some(end) = restrict_end {
    reference_restriction = reference_restriction.end(end);
  }
  let rule = AuthorizationRuleBuilder::default()
    .location(test_location())
    .reference_name(reference_restriction.build().unwrap())
    .build()
    .unwrap();
  let restrictions = AuthorizationRestrictionsBuilder::default()
    .rule(rule)
    .build()
    .unwrap();

  // Use `if let` rather than `Option::map` for side-effecting inserts.
  let mut query = HashMap::from([("referenceName".to_string(), "chr1".to_string())]);
  if let Some(start) = request_start {
    query.insert("start".to_string(), start.to_string());
  }
  if let Some(end) = request_end {
    query.insert("end".to_string(), end.to_string());
  }

  let request = create_test_query(Endpoint::Reads, "sample1", query);
  let id = request.id().to_string();
  let mut queries = [request];
  let result = Auth::validate_restrictions(restrictions, &id, &mut queries, suppress_interval);

  assert_eq!(result.is_err(), is_err, "unexpected result: {result:?}");
  // There is exactly one query, so both the interval and the class are read from it.
  let [validated] = &queries;
  assert_eq!(validated.interval(), expected_response.0);
  assert_eq!(validated.class(), expected_response.1);
}
/// A simple location that applies only to the ID `sample1`.
fn test_location() -> Location {
  let id = PrefixOrId::Id("sample1".to_string());
  let simple = SimpleLocation::new(Default::default(), String::new(), Some(id));
  Location::Simple(Box::new(simple))
}
}