//! htsget-http 0.8.5
//!
//! Crate for handling HTTP in htsget-rs.
use cfg_if::cfg_if;
pub use error::{HtsGetError, Result};
pub use htsget_config::config::Config;
use htsget_config::types::Format::{Bam, Bcf, Cram, Vcf};
use htsget_config::types::{Format, Query, Request, Response};
pub use http_core::{get, post};
pub use post_request::{PostRequest, Region};
use query_builder::QueryBuilder;
pub use service_info::get_service_info_json;
pub use service_info::{Htsget, ServiceInfo, Type};
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
use std::{fmt, result};

pub mod error;
pub mod http_core;
pub mod middleware;
pub mod post_request;
pub mod query_builder;
pub mod service_info;

/// An enum to distinguish between the two endpoints defined in the
/// [htsget specification](https://samtools.github.io/hts-specs/htsget.html).
///
/// The endpoint decides which formats a request may ask for (see [`match_format`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Endpoint {
  /// The `reads` endpoint, serving BAM or CRAM (defaulting to BAM).
  Reads,
  /// The `variants` endpoint, serving VCF or BCF (defaulting to VCF).
  Variants,
}

impl FromStr for Endpoint {
  type Err = ();

  /// Parses an endpoint from its lowercase path name.
  ///
  /// Only the exact strings `"reads"` and `"variants"` are accepted. The comparison
  /// is case-sensitive, and any other input yields `Err(())`.
  fn from_str(s: &str) -> result::Result<Self, Self::Err> {
    if s == "reads" {
      Ok(Self::Reads)
    } else if s == "variants" {
      Ok(Self::Variants)
    } else {
      Err(())
    }
  }
}

impl Display for Endpoint {
  /// Writes the lowercase endpoint name (`reads` or `variants`).
  ///
  /// The name is written verbatim, and any width or fill flags are ignored.
  fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
    let name = match self {
      Self::Reads => "reads",
      Self::Variants => "variants",
    };
    f.write_str(name)
  }
}

/// Resolve the format requested through the `format` query parameter.
///
/// If the parameter is absent, the endpoint's default format is chosen. Otherwise the
/// value is validated against `endpoint` by [`match_format`].
pub fn match_format_from_query(
  endpoint: &Endpoint,
  query: &HashMap<String, String>,
) -> Result<Format> {
  let requested = query.get("format");
  match_format(endpoint, requested)
}

/// Resolve the format requested for `endpoint`.
///
/// If `format` is `None`, the endpoint default is returned: BAM for
/// [`Endpoint::Reads`] and VCF for [`Endpoint::Variants`]. Otherwise the name is
/// matched case-insensitively against the formats that endpoint supports: `bam` or
/// `cram` for reads, and `vcf` or `bcf` for variants.
///
/// # Errors
///
/// Returns [`HtsGetError::UnsupportedFormat`] when the format is not valid for the
/// endpoint, for example `VCF` on the reads endpoint. The message echoes the
/// format exactly as the client sent it, not the normalised lowercase form.
pub fn match_format(endpoint: &Endpoint, format: Option<impl Into<String>>) -> Result<Format> {
  let requested: String = match format {
    Some(value) => value.into(),
    None => {
      return Ok(match endpoint {
        Endpoint::Reads => Bam,
        Endpoint::Variants => Vcf,
      })
    }
  };

  // Compare a lowercased copy so matching is case-insensitive. `requested`
  // itself stays untouched for the error message.
  match (endpoint, requested.to_lowercase().as_str()) {
    (Endpoint::Reads, "bam") => Ok(Bam),
    (Endpoint::Reads, "cram") => Ok(Cram),
    (Endpoint::Variants, "vcf") => Ok(Vcf),
    (Endpoint::Variants, "bcf") => Ok(Bcf),
    _ => Err(HtsGetError::UnsupportedFormat(format!(
      "{requested} isn't a supported format for this endpoint"
    ))),
  }
}

/// Build a [`Query`] for `request` in `format`, taking the optional htsget
/// parameters from the request's query map.
fn convert_to_query(request: Request, format: Format) -> Result<Query> {
  // `QueryBuilder::new` takes ownership of `request`, so copy the parameters out first.
  let params = request.query().clone();
  let builder = QueryBuilder::new(request, format);

  let class = params.get("class");
  let reference_name = params.get("referenceName");
  let fields = params.get("fields");
  let tag_filters = (params.get("tags"), params.get("notags"));
  let interval = (params.get("start"), params.get("end"));
  let encryption_scheme = params.get("encryptionScheme");

  set_query_builder(
    builder,
    class,
    reference_name,
    fields,
    tag_filters,
    interval,
    encryption_scheme,
  )
}

/// Apply the optional query parameters to `builder` and build the final [`Query`].
///
/// Parameters are applied in a fixed order: class, fields, tags, reference name,
/// then range. Only when the `experimental` feature is enabled is the encryption
/// scheme applied last; otherwise it is ignored. The first error from any step is
/// returned.
fn set_query_builder(
  builder: QueryBuilder,
  class: Option<impl Into<String>>,
  reference_name: Option<impl Into<String>>,
  fields: Option<impl Into<String>>,
  (tags, no_tags): (Option<impl Into<String>>, Option<impl Into<String>>),
  (start, end): (Option<impl Into<String>>, Option<impl Into<String>>),
  _encryption_scheme: Option<impl Into<String>>,
) -> Result<Query> {
  let builder = builder.with_class(class)?;
  let builder = builder.with_fields(fields);
  let builder = builder.with_tags(tags, no_tags)?;
  let builder = builder.with_reference_name(reference_name);
  let builder = builder.with_range(start, end)?;

  cfg_if! {
    if #[cfg(feature = "experimental")] {
      Ok(builder.with_encryption_scheme(_encryption_scheme)?.build())
    } else {
      Ok(builder.build())
    }
  }
}

/// Combine several responses into one by concatenating their URLs in order.
///
/// Every other field is taken from the first response. Returns `None` if
/// `responses` is empty.
fn merge_responses(responses: Vec<Response>) -> Option<Response> {
  let mut rest = responses.into_iter();
  let mut merged = rest.next()?;
  for response in rest {
    merged.urls.extend(response.urls);
  }
  Some(merged)
}

#[cfg(test)]
mod tests {
  use std::collections::HashMap;
  use std::path::PathBuf;

  use htsget_config::storage;
  use htsget_config::types::{Headers, JsonResponse, Request, Scheme, Url};
  use htsget_search::FileStorage;
  use htsget_search::HtsGet;
  use htsget_search::Storage;
  use htsget_search::from_storage::HtsGetFromStorage;
  use http::uri::Authority;

  use super::*;

  #[test]
  fn match_with_invalid_format() {
    assert!(matches!(
      match_format(&Endpoint::Reads, Some("Invalid".to_string())).unwrap_err(),
      HtsGetError::UnsupportedFormat(_)
    ));
  }

  /// A format that is valid for the reads endpoint (BAM) must be rejected by the
  /// variants endpoint.
  #[test]
  fn match_with_invalid_endpoint() {
    assert!(matches!(
      match_format(&Endpoint::Variants, Some("bam".to_string())).unwrap_err(),
      HtsGetError::UnsupportedFormat(_)
    ));
  }

  #[test]
  fn match_with_valid_format() {
    assert!(matches!(
      match_format(&Endpoint::Reads, Some("bam".to_string())).unwrap(),
      Bam,
    ));
  }

  #[tokio::test]
  async fn get_request() {
    let query = HashMap::new();

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

    let request = Request::new(
      "bam/htsnexus_test_NA12878".to_string(),
      query,
      Default::default(),
    );

    assert_eq!(
      get(get_searcher(), request, Endpoint::Reads, None, None, None).await,
      Ok(expected_bam_json_response(expected_response_headers))
    );
  }

  #[tokio::test]
  async fn get_reads_request_with_variants_format() {
    let mut query = HashMap::new();
    query.insert("format".to_string(), "VCF".to_string());

    let request = Request::new(
      "bam/htsnexus_test_NA12878".to_string(),
      query,
      Default::default(),
    );

    assert!(matches!(
      get(get_searcher(), request, Endpoint::Reads, None, None, None).await,
      Err(HtsGetError::UnsupportedFormat(_))
    ));
  }

  #[tokio::test]
  async fn get_request_with_range() {
    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chrM".to_string());
    query.insert("start".to_string(), "149".to_string());
    query.insert("end".to_string(), "200".to_string());

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

    let request = Request::new(
      "vcf/sample1-bcbio-cancer".to_string(),
      query,
      Default::default(),
    );

    assert_eq!(
      get(
        get_searcher(),
        request,
        Endpoint::Variants,
        None,
        None,
        None
      )
      .await,
      Ok(expected_vcf_json_response(expected_response_headers))
    );
  }

  #[tokio::test]
  async fn post_request() {
    let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
    let body = PostRequest {
      format: None,
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: None,
      encryption_scheme: None,
    };

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

    assert_eq!(
      post(
        get_searcher(),
        body,
        request,
        Endpoint::Reads,
        None,
        None,
        None
      )
      .await,
      Ok(expected_bam_json_response(expected_response_headers))
    );
  }

  #[tokio::test]
  async fn post_variants_request_with_reads_format() {
    let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
    let body = PostRequest {
      format: Some("BAM".to_string()),
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: None,
      encryption_scheme: None,
    };

    assert!(matches!(
      post(
        get_searcher(),
        body,
        request,
        Endpoint::Variants,
        None,
        None,
        None
      )
      .await,
      Err(HtsGetError::UnsupportedFormat(_))
    ));
  }

  #[tokio::test]
  async fn post_request_with_range() {
    let request = Request::new_with_id("vcf/sample1-bcbio-cancer".to_string());
    let body = PostRequest {
      format: Some("VCF".to_string()),
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: Some(vec![Region {
        reference_name: "chrM".to_string(),
        start: Some(149),
        end: Some(200),
      }]),
      encryption_scheme: None,
    };

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

    assert_eq!(
      post(
        get_searcher(),
        body,
        request,
        Endpoint::Variants,
        None,
        None,
        None
      )
      .await,
      Ok(expected_vcf_json_response(expected_response_headers))
    );
  }

  /// Expected single-URL JSON response for the `vcf/sample1-bcbio-cancer` fixture.
  fn expected_vcf_json_response(headers: Headers) -> JsonResponse {
    JsonResponse::from(Response::new(
      Vcf,
      vec![
        Url::new("http://127.0.0.1:8081/vcf/sample1-bcbio-cancer.vcf.gz".to_string())
          .with_headers(headers),
      ],
    ))
  }

  /// Expected single-URL JSON response for the `bam/htsnexus_test_NA12878` fixture.
  fn expected_bam_json_response(headers: Headers) -> JsonResponse {
    JsonResponse::from(Response::new(
      Bam,
      vec![
        Url::new("http://127.0.0.1:8081/bam/htsnexus_test_NA12878.bam".to_string())
          .with_headers(headers),
      ],
    ))
  }

  /// Path to the `data` directory that sits next to this crate's directory.
  ///
  /// This is anchored on `CARGO_MANIFEST_DIR` instead of the process working
  /// directory, so the tests still find their fixtures when the test binary is
  /// launched from somewhere other than the crate root.
  fn get_base_path() -> PathBuf {
    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
      .parent()
      .expect("the crate manifest directory should have a parent directory")
      .join("data")
  }

  /// Searcher over the local fixture files, issuing URLs for `http://127.0.0.1:8081`.
  fn get_searcher() -> impl HtsGet + Clone {
    HtsGetFromStorage::new(Storage::new(
      FileStorage::new(
        get_base_path(),
        storage::file::File::new(
          Scheme::Http,
          Authority::from_static("127.0.0.1:8081"),
          "data".to_string(),
        ),
        vec![],
      )
      .unwrap(),
    ))
  }
}