htsget_http/
query_builder.rs

1use std::collections::HashSet;
2
3#[cfg(feature = "experimental")]
4use htsget_config::encryption_scheme::EncryptionScheme;
5use htsget_config::types::{Class, Fields, Format, Interval, Query, Request, Tags};
6use tracing::instrument;
7
8use crate::error::{HtsGetError, Result};
9
10/// A helper struct to construct a [Query] from [Strings](String)
11#[derive(Debug)]
12pub struct QueryBuilder {
13  query: Query,
14}
15
16impl QueryBuilder {
17  pub fn new(request: Request, format: Format) -> Self {
18    let id = request.path().to_string();
19
20    Self {
21      query: Query::new(id, format, request),
22    }
23  }
24
25  pub fn build(self) -> Query {
26    self.query
27  }
28
29  #[instrument(level = "trace", skip_all, ret)]
30  pub fn with_class(mut self, class: Option<impl Into<String>>) -> Result<Self> {
31    let class = class.map(Into::into);
32
33    self.query = self.query.with_class(match class {
34      None => Class::Body,
35      Some(class) if class == "header" => Class::Header,
36      Some(class) => {
37        return Err(HtsGetError::InvalidInput(format!(
38          "invalid class `{class}`"
39        )));
40      }
41    });
42
43    Ok(self)
44  }
45
46  #[instrument(level = "trace", skip_all, ret)]
47  pub fn with_reference_name(mut self, reference_name: Option<impl Into<String>>) -> Self {
48    if let Some(reference_name) = reference_name {
49      self.query = self.query.with_reference_name(reference_name);
50    }
51    self
52  }
53
54  /// Set the interval for the query.
55  pub fn with_interval(mut self, interval: Interval) -> Self {
56    self.query = self.query.with_interval(interval);
57    self
58  }
59
60  #[instrument(level = "trace", skip_all, ret)]
61  pub fn with_range(
62    self,
63    start: Option<impl Into<String>>,
64    end: Option<impl Into<String>>,
65  ) -> Result<Self> {
66    let start = start
67      .map(Into::into)
68      .map(|start| {
69        start
70          .parse::<u32>()
71          .map_err(|err| HtsGetError::InvalidInput(format!("`{start}` isn't a valid start: {err}")))
72      })
73      .transpose()?;
74    let end = end
75      .map(Into::into)
76      .map(|end| {
77        end
78          .parse::<u32>()
79          .map_err(|err| HtsGetError::InvalidInput(format!("`{end}` isn't a valid end: {err}")))
80      })
81      .transpose()?;
82
83    self.with_range_from_u32(start, end)
84  }
85
86  pub fn with_range_from_u32(mut self, start: Option<u32>, end: Option<u32>) -> Result<Self> {
87    if let Some(start) = start {
88      self.query = self.query.with_start(start);
89    }
90    if let Some(end) = end {
91      self.query = self.query.with_end(end);
92    }
93
94    if (self.query.interval().start().is_some() || self.query.interval().end().is_some())
95      && self
96        .query
97        .reference_name()
98        .filter(|name| *name != "*")
99        .is_none()
100    {
101      return Err(HtsGetError::InvalidInput(
102        "reference name must be specified with start or end range".to_string(),
103      ));
104    }
105
106    if let (Some(start), Some(end)) = &(self.query.interval().start(), self.query.interval().end())
107    {
108      if start > end {
109        return Err(HtsGetError::InvalidRange(format!(
110          "end is greater than start (`{start}` > `{end}`)"
111        )));
112      }
113    }
114
115    Ok(self)
116  }
117
118  #[instrument(level = "trace", skip_all, ret)]
119  pub fn with_fields(self, fields: Option<impl Into<String>>) -> Self {
120    self.with_fields_from_vec(
121      fields.map(|fields| fields.into().split(',').map(|s| s.to_string()).collect()),
122    )
123  }
124
125  pub fn with_fields_from_vec(mut self, fields: Option<Vec<impl Into<String>>>) -> Self {
126    if let Some(fields) = fields {
127      self.query = self
128        .query
129        .with_fields(Fields::List(fields.into_iter().map(Into::into).collect()));
130    }
131
132    self
133  }
134
135  #[instrument(level = "trace", skip_all, ret)]
136  pub fn with_tags(
137    self,
138    tags: Option<impl Into<String>>,
139    notags: Option<impl Into<String>>,
140  ) -> Result<Self> {
141    self.with_tags_from_vec(
142      tags.map(|tags| tags.into().split(',').map(|s| s.to_string()).collect()),
143      notags.map(|notags| notags.into().split(',').map(|s| s.to_string()).collect()),
144    )
145  }
146
147  pub fn with_tags_from_vec(
148    mut self,
149    tags: Option<Vec<impl Into<String>>>,
150    notags: Option<Vec<impl Into<String>>>,
151  ) -> Result<Self> {
152    let notags = match notags {
153      Some(notags) => notags.into_iter().map(Into::into).collect(),
154      None => vec![],
155    };
156
157    if let Some(tags) = tags {
158      let tags: HashSet<String> = tags.into_iter().map(Into::into).collect();
159      if tags.iter().any(|tag| notags.contains(tag)) {
160        return Err(HtsGetError::InvalidInput(
161          "tags and notags can't intersect".to_string(),
162        ));
163      }
164      self.query = self.query.with_tags(Tags::List(tags));
165    };
166
167    if !notags.is_empty() {
168      self.query = self.query.with_no_tags(notags);
169    }
170
171    Ok(self)
172  }
173
174  /// Set the encryption scheme.
175  #[cfg(feature = "experimental")]
176  pub fn with_encryption_scheme(
177    mut self,
178    encryption_scheme: Option<impl Into<String>>,
179  ) -> Result<Self> {
180    if let Some(scheme) = encryption_scheme {
181      let scheme = match scheme.into().to_lowercase().as_str() {
182        "c4gh" => Ok(EncryptionScheme::C4GH),
183        scheme => Err(HtsGetError::UnsupportedFormat(format!(
184          "invalid encryption scheme `{scheme}`"
185        ))),
186      }?;
187
188      self.query = self.query.with_encryption_scheme(scheme);
189    }
190
191    Ok(self)
192  }
193}
194
/// Unit tests for [QueryBuilder]: each test builds a query from a minimal
/// request and checks either the resulting field or the returned error.
#[cfg(test)]
mod tests {
  use htsget_config::types::Format::{Bam, Vcf};
  use htsget_config::types::NoTags;

  use super::*;

  #[test]
  fn query_with_id() {
    let request = Request::new_with_id("ValidId".to_string());
    assert_eq!(
      QueryBuilder::new(request, Bam).build().id(),
      "ValidId".to_string()
    );
  }

  #[test]
  fn query_with_format() {
    let request = Request::new_with_id("ValidId".to_string());
    assert_eq!(QueryBuilder::new(request, Vcf).build().format(), Vcf);
  }

  #[test]
  fn query_with_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_class(Some("header"))
        .unwrap()
        .build()
        .class(),
      Class::Header
    );
  }

  #[test]
  fn query_with_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .build()
        .reference_name(),
      Some("ValidName")
    );
  }

  #[test]
  fn query_with_range() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_reference_name(Some("ValidName"))
      .with_range(Some("3"), Some("5"))
      .unwrap()
      .build();
    assert_eq!(
      (query.interval().start(), query.interval().end()),
      (Some(3), Some(5))
    );
  }

  #[test]
  fn query_with_range_but_without_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_range(Some("3"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_start() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("a"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_end() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("5"), Some("a"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_range() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("5"), Some("3"))
        .unwrap_err(),
      HtsGetError::InvalidRange(_)
    ));
  }

  #[test]
  fn query_with_fields() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_fields(Some("header,part1,part2"))
        .build()
        .fields(),
      &Fields::List(HashSet::from_iter(vec![
        "header".to_string(),
        "part1".to_string(),
        "part2".to_string()
      ]))
    );
  }

  #[test]
  fn query_with_tags() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_tags(Some("header,part1,part2"), Some("part3"))
      .unwrap()
      .build();
    assert_eq!(
      query.tags(),
      &Tags::List(HashSet::from_iter(vec![
        "header".to_string(),
        "part1".to_string(),
        "part2".to_string()
      ]))
    );
    assert_eq!(
      query.no_tags(),
      &NoTags(Some(HashSet::from_iter(vec!["part3".to_string()])))
    );
  }

  #[test]
  fn query_with_invalid_tags() {
    let request = Request::new_with_id("ValidId".to_string());

    // `part1` appears in both lists, so the builder must reject the request
    // instead of producing a query.
    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_tags(Some("header,part1,part2"), Some("part1"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }
}