htsget_http/
lib.rs

1use cfg_if::cfg_if;
2pub use error::{HtsGetError, Result};
3pub use htsget_config::config::Config;
4use htsget_config::types::Format::{Bam, Bcf, Cram, Vcf};
5use htsget_config::types::{Format, Query, Request, Response};
6pub use http_core::{get, post};
7pub use post_request::{PostRequest, Region};
8use query_builder::QueryBuilder;
9pub use service_info::get_service_info_json;
10pub use service_info::{Htsget, ServiceInfo, Type};
11use std::collections::HashMap;
12use std::fmt::{Display, Formatter};
13use std::str::FromStr;
14use std::{fmt, result};
15
16pub mod error;
17pub mod http_core;
18pub mod middleware;
19pub mod post_request;
20pub mod query_builder;
21pub mod service_info;
22
/// An enum to distinguish between the two endpoints defined in the
/// [HtsGet specification](https://samtools.github.io/hts-specs/htsget.html).
///
/// It is a small fieldless enum, so it derives `Copy` (cheap to pass by value) and
/// `Hash` (usable as a map/set key) in addition to equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Endpoint {
  /// The `reads` endpoint.
  Reads,
  /// The `variants` endpoint.
  Variants,
}
30
31impl FromStr for Endpoint {
32  type Err = ();
33
34  fn from_str(s: &str) -> result::Result<Self, Self::Err> {
35    match s {
36      "reads" => Ok(Self::Reads),
37      "variants" => Ok(Self::Variants),
38      _ => Err(()),
39    }
40  }
41}
42
43impl Display for Endpoint {
44  fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
45    match self {
46      Self::Reads => write!(f, "reads"),
47      Self::Variants => write!(f, "variants"),
48    }
49  }
50}
51
52/// Match the format from a query parameter.
53pub fn match_format_from_query(
54  endpoint: &Endpoint,
55  query: &HashMap<String, String>,
56) -> Result<Format> {
57  match_format(endpoint, query.get("format"))
58}
59
60/// Get the format from the string.
61pub fn match_format(endpoint: &Endpoint, format: Option<impl Into<String>>) -> Result<Format> {
62  let format = format.map(Into::into).map(|format| format.to_lowercase());
63
64  match (endpoint, format) {
65    (Endpoint::Reads, None) => Ok(Bam),
66    (Endpoint::Variants, None) => Ok(Vcf),
67    (Endpoint::Reads, Some(s)) if s == "bam" => Ok(Bam),
68    (Endpoint::Reads, Some(s)) if s == "cram" => Ok(Cram),
69    (Endpoint::Variants, Some(s)) if s == "vcf" => Ok(Vcf),
70    (Endpoint::Variants, Some(s)) if s == "bcf" => Ok(Bcf),
71    (_, Some(format)) => Err(HtsGetError::UnsupportedFormat(format!(
72      "{format} isn't a supported format for this endpoint"
73    ))),
74  }
75}
76
77fn convert_to_query(request: Request, format: Format) -> Result<Query> {
78  let query = request.query().clone();
79
80  set_query_builder(
81    QueryBuilder::new(request, format),
82    query.get("class"),
83    query.get("referenceName"),
84    query.get("fields"),
85    (query.get("tags"), query.get("notags")),
86    (query.get("start"), query.get("end")),
87    query.get("encryptionScheme"),
88  )
89}
90
91fn set_query_builder(
92  builder: QueryBuilder,
93  class: Option<impl Into<String>>,
94  reference_name: Option<impl Into<String>>,
95  fields: Option<impl Into<String>>,
96  (tags, no_tags): (Option<impl Into<String>>, Option<impl Into<String>>),
97  (start, end): (Option<impl Into<String>>, Option<impl Into<String>>),
98  _encryption_scheme: Option<impl Into<String>>,
99) -> Result<Query> {
100  let builder = builder
101    .with_class(class)?
102    .with_fields(fields)
103    .with_tags(tags, no_tags)?
104    .with_reference_name(reference_name)
105    .with_range(start, end)?;
106
107  cfg_if! {
108    if #[cfg(feature = "experimental")] {
109      Ok(builder.with_encryption_scheme(_encryption_scheme)?.build())
110    } else {
111      Ok(builder.build())
112    }
113  }
114}
115
116fn merge_responses(responses: Vec<Response>) -> Option<Response> {
117  responses.into_iter().reduce(|mut acc, mut response| {
118    acc.urls.append(&mut response.urls);
119    acc
120  })
121}
122
#[cfg(test)]
mod tests {
  use std::collections::HashMap;
  use std::path::PathBuf;

  use htsget_config::storage;
  use htsget_config::types::{Headers, JsonResponse, Request, Scheme, Url};
  use htsget_search::FileStorage;
  use htsget_search::HtsGet;
  use htsget_search::Storage;
  use htsget_search::from_storage::HtsGetFromStorage;
  use http::uri::Authority;

  use super::*;

  /// Id of the BAM fixture used by the reads tests.
  const BAM_ID: &str = "bam/htsnexus_test_NA12878";
  /// Id of the VCF fixture used by the variants tests.
  const VCF_ID: &str = "vcf/sample1-bcbio-cancer";
  /// `Range` header value expected for an unfiltered reads request on [`BAM_ID`].
  const BAM_RANGE: &str = "bytes=0-2596798";
  /// `Range` header value expected for the `chrM:149-200` variants request on [`VCF_ID`].
  const VCF_RANGE: &str = "bytes=0-3493";

  #[test]
  fn match_with_invalid_format() {
    let error = match_format(&Endpoint::Reads, Some("Invalid".to_string())).unwrap_err();
    assert!(matches!(error, HtsGetError::UnsupportedFormat(_)));
  }

  #[test]
  fn match_with_invalid_endpoint() {
    let error = match_format(&Endpoint::Variants, Some("bam".to_string())).unwrap_err();
    assert!(matches!(error, HtsGetError::UnsupportedFormat(_)));
  }

  #[test]
  fn match_with_valid_format() {
    let matched = match_format(&Endpoint::Reads, Some("bam".to_string())).unwrap();
    assert!(matches!(matched, Bam));
  }

  #[tokio::test]
  async fn get_request() {
    let request = Request::new(BAM_ID.to_string(), HashMap::new(), Default::default());

    let response = get(get_searcher(), request, Endpoint::Reads, None, None, None).await;

    assert_eq!(
      response,
      Ok(expected_bam_json_response(range_headers(BAM_RANGE)))
    );
  }

  #[tokio::test]
  async fn get_reads_request_with_variants_format() {
    let query = HashMap::from([("format".to_string(), "VCF".to_string())]);
    let request = Request::new(BAM_ID.to_string(), query, Default::default());

    let response = get(get_searcher(), request, Endpoint::Reads, None, None, None).await;

    assert!(matches!(response, Err(HtsGetError::UnsupportedFormat(_))));
  }

  #[tokio::test]
  async fn get_request_with_range() {
    let query = HashMap::from([
      ("referenceName".to_string(), "chrM".to_string()),
      ("start".to_string(), "149".to_string()),
      ("end".to_string(), "200".to_string()),
    ]);
    let request = Request::new(VCF_ID.to_string(), query, Default::default());

    let response = get(
      get_searcher(),
      request,
      Endpoint::Variants,
      None,
      None,
      None,
    )
    .await;

    assert_eq!(
      response,
      Ok(expected_vcf_json_response(range_headers(VCF_RANGE)))
    );
  }

  #[tokio::test]
  async fn post_request() {
    let request = Request::new_with_id(BAM_ID.to_string());

    let response = post(
      get_searcher(),
      post_body(None, None),
      request,
      Endpoint::Reads,
      None,
      None,
      None,
    )
    .await;

    assert_eq!(
      response,
      Ok(expected_bam_json_response(range_headers(BAM_RANGE)))
    );
  }

  #[tokio::test]
  async fn post_variants_request_with_reads_format() {
    let request = Request::new_with_id(BAM_ID.to_string());

    let response = post(
      get_searcher(),
      post_body(Some("BAM"), None),
      request,
      Endpoint::Variants,
      None,
      None,
      None,
    )
    .await;

    assert!(matches!(response, Err(HtsGetError::UnsupportedFormat(_))));
  }

  #[tokio::test]
  async fn post_request_with_range() {
    let request = Request::new_with_id(VCF_ID.to_string());
    let regions = vec![Region {
      reference_name: "chrM".to_string(),
      start: Some(149),
      end: Some(200),
    }];

    let response = post(
      get_searcher(),
      post_body(Some("VCF"), Some(regions)),
      request,
      Endpoint::Variants,
      None,
      None,
      None,
    )
    .await;

    assert_eq!(
      response,
      Ok(expected_vcf_json_response(range_headers(VCF_RANGE)))
    );
  }

  /// A POST body with only `format` and `regions` set; every other field is `None`.
  fn post_body(format: Option<&str>, regions: Option<Vec<Region>>) -> PostRequest {
    PostRequest {
      format: format.map(str::to_string),
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions,
      encryption_scheme: None,
    }
  }

  /// Headers containing a single `Range` entry with the given value.
  fn range_headers(range: &str) -> Headers {
    let mut headers = Headers::default();
    headers.insert("Range".to_string(), range.to_string());
    headers
  }

  /// Expected JSON response: a single URL carrying `headers`, in the given format.
  fn expected_json_response(format: Format, url: &str, headers: Headers) -> JsonResponse {
    let url = Url::new(url.to_string()).with_headers(headers);
    JsonResponse::from(Response::new(format, vec![url]))
  }

  fn expected_vcf_json_response(headers: Headers) -> JsonResponse {
    expected_json_response(
      Vcf,
      "http://127.0.0.1:8081/vcf/sample1-bcbio-cancer.vcf.gz",
      headers,
    )
  }

  fn expected_bam_json_response(headers: Headers) -> JsonResponse {
    expected_json_response(
      Bam,
      "http://127.0.0.1:8081/bam/htsnexus_test_NA12878.bam",
      headers,
    )
  }

  /// The `data` directory next to the current working directory's parent.
  fn get_base_path() -> PathBuf {
    let cwd = std::env::current_dir().unwrap();
    cwd.parent().unwrap().join("data")
  }

  /// A file-backed searcher over [`get_base_path`] serving URLs from `127.0.0.1:8081`.
  fn get_searcher() -> impl HtsGet + Clone {
    let file = storage::file::File::new(
      Scheme::Http,
      Authority::from_static("127.0.0.1:8081"),
      "data".to_string(),
    );
    let file_storage = FileStorage::new(get_base_path(), file, vec![]).unwrap();

    HtsGetFromStorage::new(Storage::new(file_storage))
  }
}