htsget_http/
query_builder.rs

1use std::collections::HashSet;
2
3#[cfg(feature = "experimental")]
4use htsget_config::encryption_scheme::EncryptionScheme;
5use htsget_config::types::{Class, Fields, Format, Interval, Query, Request, Tags};
6use tracing::instrument;
7
8use crate::error::{HtsGetError, Result};
9
10/// A helper struct to construct a [Query] from [Strings](String)
11#[derive(Debug)]
12pub struct QueryBuilder {
13  query: Query,
14}
15
16impl QueryBuilder {
17  pub fn new(request: Request, format: Format) -> Self {
18    let id = request.path().to_string();
19
20    Self {
21      query: Query::new(id, format, request),
22    }
23  }
24
25  pub fn build(self) -> Query {
26    self.query
27  }
28
29  #[instrument(level = "trace", skip_all, ret)]
30  pub fn with_class(mut self, class: Option<impl Into<String>>) -> Result<Self> {
31    let class = class.map(Into::into);
32
33    self.query = self.query.with_class(match class {
34      None => Class::Body,
35      Some(class) if class == "header" => Class::Header,
36      Some(class) => {
37        return Err(HtsGetError::InvalidInput(format!(
38          "invalid class `{class}`"
39        )));
40      }
41    });
42
43    Ok(self)
44  }
45
46  #[instrument(level = "trace", skip_all, ret)]
47  pub fn with_reference_name(mut self, reference_name: Option<impl Into<String>>) -> Self {
48    if let Some(reference_name) = reference_name {
49      self.query = self.query.with_reference_name(reference_name);
50    }
51    self
52  }
53
54  /// Set the interval for the query.
55  pub fn with_interval(mut self, interval: Interval) -> Self {
56    self.query = self.query.with_interval(interval);
57    self
58  }
59
60  #[instrument(level = "trace", skip_all, ret)]
61  pub fn with_range(
62    self,
63    start: Option<impl Into<String>>,
64    end: Option<impl Into<String>>,
65  ) -> Result<Self> {
66    let start = start
67      .map(Into::into)
68      .map(|start| {
69        start
70          .parse::<u32>()
71          .map_err(|err| HtsGetError::InvalidInput(format!("`{start}` isn't a valid start: {err}")))
72      })
73      .transpose()?;
74    let end = end
75      .map(Into::into)
76      .map(|end| {
77        end
78          .parse::<u32>()
79          .map_err(|err| HtsGetError::InvalidInput(format!("`{end}` isn't a valid end: {err}")))
80      })
81      .transpose()?;
82
83    self.with_range_from_u32(start, end)
84  }
85
86  pub fn with_range_from_u32(mut self, start: Option<u32>, end: Option<u32>) -> Result<Self> {
87    if let Some(start) = start {
88      self.query = self.query.with_start(start);
89    }
90    if let Some(end) = end {
91      self.query = self.query.with_end(end);
92    }
93
94    if (self.query.interval().start().is_some() || self.query.interval().end().is_some())
95      && self
96        .query
97        .reference_name()
98        .filter(|name| *name != "*")
99        .is_none()
100    {
101      return Err(HtsGetError::InvalidInput(
102        "reference name must be specified with start or end range".to_string(),
103      ));
104    }
105
106    if let (Some(start), Some(end)) = &(self.query.interval().start(), self.query.interval().end())
107      && start > end
108    {
109      return Err(HtsGetError::InvalidRange(format!(
110        "end is greater than start (`{start}` > `{end}`)"
111      )));
112    }
113
114    Ok(self)
115  }
116
117  #[instrument(level = "trace", skip_all, ret)]
118  pub fn with_fields(self, fields: Option<impl Into<String>>) -> Self {
119    self.with_fields_from_vec(
120      fields.map(|fields| fields.into().split(',').map(|s| s.to_string()).collect()),
121    )
122  }
123
124  pub fn with_fields_from_vec(mut self, fields: Option<Vec<impl Into<String>>>) -> Self {
125    if let Some(fields) = fields {
126      self.query = self
127        .query
128        .with_fields(Fields::List(fields.into_iter().map(Into::into).collect()));
129    }
130
131    self
132  }
133
134  #[instrument(level = "trace", skip_all, ret)]
135  pub fn with_tags(
136    self,
137    tags: Option<impl Into<String>>,
138    notags: Option<impl Into<String>>,
139  ) -> Result<Self> {
140    self.with_tags_from_vec(
141      tags.map(|tags| tags.into().split(',').map(|s| s.to_string()).collect()),
142      notags.map(|notags| notags.into().split(',').map(|s| s.to_string()).collect()),
143    )
144  }
145
146  pub fn with_tags_from_vec(
147    mut self,
148    tags: Option<Vec<impl Into<String>>>,
149    notags: Option<Vec<impl Into<String>>>,
150  ) -> Result<Self> {
151    let notags = match notags {
152      Some(notags) => notags.into_iter().map(Into::into).collect(),
153      None => vec![],
154    };
155
156    if let Some(tags) = tags {
157      let tags: HashSet<String> = tags.into_iter().map(Into::into).collect();
158      if tags.iter().any(|tag| notags.contains(tag)) {
159        return Err(HtsGetError::InvalidInput(
160          "tags and notags can't intersect".to_string(),
161        ));
162      }
163      self.query = self.query.with_tags(Tags::List(tags));
164    };
165
166    if !notags.is_empty() {
167      self.query = self.query.with_no_tags(notags);
168    }
169
170    Ok(self)
171  }
172
173  /// Set the encryption scheme.
174  #[cfg(feature = "experimental")]
175  pub fn with_encryption_scheme(
176    mut self,
177    encryption_scheme: Option<impl Into<String>>,
178  ) -> Result<Self> {
179    if let Some(scheme) = encryption_scheme {
180      let scheme = match scheme.into().to_lowercase().as_str() {
181        "c4gh" => Ok(EncryptionScheme::C4GH),
182        scheme => Err(HtsGetError::UnsupportedFormat(format!(
183          "invalid encryption scheme `{scheme}`"
184        ))),
185      }?;
186
187      self.query = self.query.with_encryption_scheme(scheme);
188    }
189
190    Ok(self)
191  }
192}
193
#[cfg(test)]
mod tests {
  use htsget_config::types::Format::{Bam, Vcf};
  use htsget_config::types::NoTags;

  use super::*;

  #[test]
  fn query_with_id() {
    let request = Request::new_with_id("ValidId".to_string());
    assert_eq!(
      QueryBuilder::new(request, Bam).build().id(),
      "ValidId".to_string()
    );
  }

  #[test]
  fn query_with_format() {
    let request = Request::new_with_id("ValidId".to_string());
    assert_eq!(QueryBuilder::new(request, Vcf).build().format(), Vcf);
  }

  #[test]
  fn query_with_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_class(Some("header"))
        .unwrap()
        .build()
        .class(),
      Class::Header
    );
  }

  #[test]
  fn query_with_default_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_class(None::<&str>)
        .unwrap()
        .build()
        .class(),
      Class::Body
    );
  }

  #[test]
  fn query_with_invalid_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_class(Some("invalid"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .build()
        .reference_name(),
      Some("ValidName")
    );
  }

  #[test]
  fn query_with_range() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_reference_name(Some("ValidName"))
      .with_range(Some("3"), Some("5"))
      .unwrap()
      .build();
    assert_eq!(
      (query.interval().start(), query.interval().end()),
      (Some(3), Some(5))
    );
  }

  #[test]
  fn query_with_equal_start_and_end() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_reference_name(Some("ValidName"))
      .with_range(Some("5"), Some("5"))
      .unwrap()
      .build();
    assert_eq!(
      (query.interval().start(), query.interval().end()),
      (Some(5), Some(5))
    );
  }

  #[test]
  fn query_with_range_but_without_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_range(Some("3"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_range_and_wildcard_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("*"))
        .with_range(Some("3"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_start() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("a"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_end() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("5"), Some("a"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_range() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("5"), Some("3"))
        .unwrap_err(),
      HtsGetError::InvalidRange(_)
    ));
  }

  #[test]
  fn query_with_fields() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_fields(Some("header,part1,part2"))
        .build()
        .fields(),
      &Fields::List(HashSet::from_iter(vec![
        "header".to_string(),
        "part1".to_string(),
        "part2".to_string()
      ]))
    );
  }

  #[test]
  fn query_with_tags() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_tags(Some("header,part1,part2"), Some("part3"))
      .unwrap()
      .build();
    assert_eq!(
      query.tags(),
      &Tags::List(HashSet::from_iter(vec![
        "header".to_string(),
        "part1".to_string(),
        "part2".to_string()
      ]))
    );
    assert_eq!(
      query.no_tags(),
      &NoTags(Some(HashSet::from_iter(vec!["part3".to_string()])))
    );
  }

  #[test]
  fn query_with_invalid_tags() {
    let request = Request::new_with_id("ValidId".to_string());

    // `part1` appears in both lists, which must be rejected.
    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_tags(Some("header,part1,part2"), Some("part1,part3"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }
}