htsget_http/
lib.rs

1use cfg_if::cfg_if;
2pub use error::{HtsGetError, Result};
3pub use htsget_config::config::Config;
4use htsget_config::types::Format::{Bam, Bcf, Cram, Vcf};
5use htsget_config::types::{Format, Query, Request, Response};
6pub use http_core::{get, post};
7pub use post_request::{PostRequest, Region};
8use query_builder::QueryBuilder;
9pub use service_info::get_service_info_json;
10pub use service_info::{Htsget, ServiceInfo, Type};
11use std::collections::HashMap;
12use std::result;
13use std::str::FromStr;
14
15pub mod error;
16pub mod http_core;
17pub mod middleware;
18pub mod post_request;
19pub mod query_builder;
20pub mod service_info;
21
/// An enum to distinguish between the two endpoints defined in the
/// [HtsGet specification](https://samtools.github.io/hts-specs/htsget.html).
///
/// `Reads` serves alignment formats (BAM/CRAM) and `Variants` serves variant formats (VCF/BCF);
/// see [`match_format`] for the exact mapping and defaults.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Endpoint {
  /// The `reads` endpoint.
  Reads,
  /// The `variants` endpoint.
  Variants,
}
29
30impl FromStr for Endpoint {
31  type Err = ();
32
33  fn from_str(s: &str) -> result::Result<Self, Self::Err> {
34    match s {
35      "reads" => Ok(Self::Reads),
36      "variants" => Ok(Self::Variants),
37      _ => Err(()),
38    }
39  }
40}
41
42/// Match the format from a query parameter.
43pub fn match_format_from_query(
44  endpoint: &Endpoint,
45  query: &HashMap<String, String>,
46) -> Result<Format> {
47  match_format(endpoint, query.get("format"))
48}
49
50/// Get the format from the string.
51pub fn match_format(endpoint: &Endpoint, format: Option<impl Into<String>>) -> Result<Format> {
52  let format = format.map(Into::into).map(|format| format.to_lowercase());
53
54  match (endpoint, format) {
55    (Endpoint::Reads, None) => Ok(Bam),
56    (Endpoint::Variants, None) => Ok(Vcf),
57    (Endpoint::Reads, Some(s)) if s == "bam" => Ok(Bam),
58    (Endpoint::Reads, Some(s)) if s == "cram" => Ok(Cram),
59    (Endpoint::Variants, Some(s)) if s == "vcf" => Ok(Vcf),
60    (Endpoint::Variants, Some(s)) if s == "bcf" => Ok(Bcf),
61    (_, Some(format)) => Err(HtsGetError::UnsupportedFormat(format!(
62      "{format} isn't a supported format for this endpoint"
63    ))),
64  }
65}
66
67fn convert_to_query(request: Request, format: Format) -> Result<Query> {
68  let query = request.query().clone();
69
70  set_query_builder(
71    QueryBuilder::new(request, format),
72    query.get("class"),
73    query.get("referenceName"),
74    query.get("fields"),
75    (query.get("tags"), query.get("notags")),
76    (query.get("start"), query.get("end")),
77    query.get("encryptionScheme"),
78  )
79}
80
81fn set_query_builder(
82  builder: QueryBuilder,
83  class: Option<impl Into<String>>,
84  reference_name: Option<impl Into<String>>,
85  fields: Option<impl Into<String>>,
86  (tags, no_tags): (Option<impl Into<String>>, Option<impl Into<String>>),
87  (start, end): (Option<impl Into<String>>, Option<impl Into<String>>),
88  _encryption_scheme: Option<impl Into<String>>,
89) -> Result<Query> {
90  let builder = builder
91    .with_class(class)?
92    .with_fields(fields)
93    .with_tags(tags, no_tags)?
94    .with_reference_name(reference_name)
95    .with_range(start, end)?;
96
97  cfg_if! {
98    if #[cfg(feature = "experimental")] {
99      Ok(builder.with_encryption_scheme(_encryption_scheme)?.build())
100    } else {
101      Ok(builder.build())
102    }
103  }
104}
105
106fn merge_responses(responses: Vec<Response>) -> Option<Response> {
107  responses.into_iter().reduce(|mut acc, mut response| {
108    acc.urls.append(&mut response.urls);
109    acc
110  })
111}
112
#[cfg(test)]
mod tests {
  use std::collections::HashMap;
  use std::path::PathBuf;

  use htsget_config::storage;
  use htsget_config::types::{Headers, JsonResponse, Request, Scheme, Url};
  use htsget_search::FileStorage;
  use htsget_search::HtsGet;
  use htsget_search::Storage;
  use htsget_search::from_storage::HtsGetFromStorage;
  use http::uri::Authority;

  use super::*;

  #[test]
  fn endpoint_from_str() {
    assert_eq!("reads".parse::<Endpoint>(), Ok(Endpoint::Reads));
    assert_eq!("variants".parse::<Endpoint>(), Ok(Endpoint::Variants));
    // Endpoint names are matched exactly, so other casings are rejected.
    assert_eq!("Reads".parse::<Endpoint>(), Err(()));
    assert_eq!("".parse::<Endpoint>(), Err(()));
  }

  #[test]
  fn match_with_invalid_format() {
    assert!(matches!(
      match_format(&Endpoint::Reads, Some("Invalid".to_string())).unwrap_err(),
      HtsGetError::UnsupportedFormat(_)
    ));
  }

  #[test]
  fn match_with_invalid_endpoint() {
    assert!(matches!(
      match_format(&Endpoint::Variants, Some("bam".to_string())).unwrap_err(),
      HtsGetError::UnsupportedFormat(_)
    ));
  }

  #[test]
  fn match_with_valid_format() {
    assert!(matches!(
      match_format(&Endpoint::Reads, Some("bam".to_string())).unwrap(),
      Bam,
    ));
  }

  #[test]
  fn match_with_default_format() {
    assert!(matches!(
      match_format(&Endpoint::Reads, None::<String>).unwrap(),
      Bam
    ));
    assert!(matches!(
      match_format(&Endpoint::Variants, None::<String>).unwrap(),
      Vcf
    ));
  }

  #[test]
  fn match_is_case_insensitive() {
    assert!(matches!(
      match_format(&Endpoint::Reads, Some("CRAM")).unwrap(),
      Cram
    ));
    assert!(matches!(
      match_format(&Endpoint::Variants, Some("Bcf")).unwrap(),
      Bcf
    ));
  }

  #[test]
  fn match_format_from_query_uses_format_key() {
    let mut query = HashMap::new();
    query.insert("format".to_string(), "cram".to_string());

    assert!(matches!(
      match_format_from_query(&Endpoint::Reads, &query).unwrap(),
      Cram
    ));
    assert!(matches!(
      match_format_from_query(&Endpoint::Variants, &HashMap::new()).unwrap(),
      Vcf
    ));
  }

  #[test]
  fn merge_responses_concatenates_urls() {
    assert!(merge_responses(vec![]).is_none());

    let first = Response::new(Bam, vec![Url::new("a".to_string())]);
    let second = Response::new(Bam, vec![Url::new("b".to_string())]);
    let merged = merge_responses(vec![first, second]).unwrap();

    assert_eq!(
      merged.urls,
      vec![Url::new("a".to_string()), Url::new("b".to_string())]
    );
  }

  #[tokio::test]
  async fn get_request() {
    let request = HashMap::new();

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

    let request = Request::new(
      "bam/htsnexus_test_NA12878".to_string(),
      request,
      Default::default(),
    );

    assert_eq!(
      get(get_searcher(), request, Endpoint::Reads, None, None).await,
      Ok(expected_bam_json_response(expected_response_headers))
    );
  }

  #[tokio::test]
  async fn get_reads_request_with_variants_format() {
    let mut request = HashMap::new();
    request.insert("format".to_string(), "VCF".to_string());

    let request = Request::new(
      "bam/htsnexus_test_NA12878".to_string(),
      request,
      Default::default(),
    );

    assert!(matches!(
      get(get_searcher(), request, Endpoint::Reads, None, None).await,
      Err(HtsGetError::UnsupportedFormat(_))
    ));
  }

  #[tokio::test]
  async fn get_request_with_range() {
    let mut request = HashMap::new();
    request.insert("referenceName".to_string(), "chrM".to_string());
    request.insert("start".to_string(), "149".to_string());
    request.insert("end".to_string(), "200".to_string());

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

    let request = Request::new(
      "vcf/sample1-bcbio-cancer".to_string(),
      request,
      Default::default(),
    );

    assert_eq!(
      get(get_searcher(), request, Endpoint::Variants, None, None).await,
      Ok(expected_vcf_json_response(expected_response_headers))
    );
  }

  #[tokio::test]
  async fn post_request() {
    let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
    let body = PostRequest {
      format: None,
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: None,
      encryption_scheme: None,
    };

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

    assert_eq!(
      post(get_searcher(), body, request, Endpoint::Reads, None, None).await,
      Ok(expected_bam_json_response(expected_response_headers))
    );
  }

  #[tokio::test]
  async fn post_variants_request_with_reads_format() {
    let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
    let body = PostRequest {
      format: Some("BAM".to_string()),
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: None,
      encryption_scheme: None,
    };

    assert!(matches!(
      post(
        get_searcher(),
        body,
        request,
        Endpoint::Variants,
        None,
        None
      )
      .await,
      Err(HtsGetError::UnsupportedFormat(_))
    ));
  }

  #[tokio::test]
  async fn post_request_with_range() {
    let request = Request::new_with_id("vcf/sample1-bcbio-cancer".to_string());
    let body = PostRequest {
      format: Some("VCF".to_string()),
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: Some(vec![Region {
        reference_name: "chrM".to_string(),
        start: Some(149),
        end: Some(200),
      }]),
      encryption_scheme: None,
    };

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

    assert_eq!(
      post(
        get_searcher(),
        body,
        request,
        Endpoint::Variants,
        None,
        None
      )
      .await,
      Ok(expected_vcf_json_response(expected_response_headers))
    );
  }

  fn expected_vcf_json_response(headers: Headers) -> JsonResponse {
    JsonResponse::from(Response::new(
      Vcf,
      vec![
        Url::new("http://127.0.0.1:8081/vcf/sample1-bcbio-cancer.vcf.gz".to_string())
          .with_headers(headers),
      ],
    ))
  }

  fn expected_bam_json_response(headers: Headers) -> JsonResponse {
    JsonResponse::from(Response::new(
      Bam,
      vec![
        Url::new("http://127.0.0.1:8081/bam/htsnexus_test_NA12878.bam".to_string())
          .with_headers(headers),
      ],
    ))
  }

  /// The shared test data directory, a sibling of this crate's directory.
  fn get_base_path() -> PathBuf {
    std::env::current_dir()
      .unwrap()
      .parent()
      .unwrap()
      .join("data")
  }

  /// A file-backed searcher over the test data that issues `http://127.0.0.1:8081` URLs.
  fn get_searcher() -> impl HtsGet + Clone {
    HtsGetFromStorage::new(Storage::new(
      FileStorage::new(
        get_base_path(),
        storage::file::File::new(
          Scheme::Http,
          Authority::from_static("127.0.0.1:8081"),
          "data".to_string(),
        ),
        vec![],
      )
      .unwrap(),
    ))
  }
}