htsget_http/
query_builder.rs

1use std::collections::HashSet;
2
3#[cfg(feature = "experimental")]
4use htsget_config::encryption_scheme::EncryptionScheme;
5use htsget_config::types::{Class, Fields, Format, Query, Request, Tags};
6use tracing::instrument;
7
8use crate::error::{HtsGetError, Result};
9
/// A helper struct to construct a [Query] from [Strings](String).
///
/// A builder is created with [`QueryBuilder::new`], refined through its `with_*` methods and
/// turned into the final [Query] with [`QueryBuilder::build`].
#[derive(Debug)]
pub struct QueryBuilder {
  /// The query being assembled by the builder methods.
  query: Query,
}
15
16impl QueryBuilder {
17  pub fn new(request: Request, format: Format) -> Self {
18    let id = request.path().to_string();
19
20    Self {
21      query: Query::new(id, format, request),
22    }
23  }
24
25  pub fn build(self) -> Query {
26    self.query
27  }
28
29  #[instrument(level = "trace", skip_all, ret)]
30  pub fn with_class(mut self, class: Option<impl Into<String>>) -> Result<Self> {
31    let class = class.map(Into::into);
32
33    self.query = self.query.with_class(match class {
34      None => Class::Body,
35      Some(class) if class == "header" => Class::Header,
36      Some(class) => {
37        return Err(HtsGetError::InvalidInput(format!(
38          "invalid class `{class}`"
39        )));
40      }
41    });
42
43    Ok(self)
44  }
45
46  #[instrument(level = "trace", skip_all, ret)]
47  pub fn with_reference_name(mut self, reference_name: Option<impl Into<String>>) -> Self {
48    if let Some(reference_name) = reference_name {
49      self.query = self.query.with_reference_name(reference_name);
50    }
51    self
52  }
53
54  #[instrument(level = "trace", skip_all, ret)]
55  pub fn with_range(
56    self,
57    start: Option<impl Into<String>>,
58    end: Option<impl Into<String>>,
59  ) -> Result<Self> {
60    let start = start
61      .map(Into::into)
62      .map(|start| {
63        start
64          .parse::<u32>()
65          .map_err(|err| HtsGetError::InvalidInput(format!("`{start}` isn't a valid start: {err}")))
66      })
67      .transpose()?;
68    let end = end
69      .map(Into::into)
70      .map(|end| {
71        end
72          .parse::<u32>()
73          .map_err(|err| HtsGetError::InvalidInput(format!("`{end}` isn't a valid end: {err}")))
74      })
75      .transpose()?;
76
77    self.with_range_from_u32(start, end)
78  }
79
80  pub fn with_range_from_u32(mut self, start: Option<u32>, end: Option<u32>) -> Result<Self> {
81    if let Some(start) = start {
82      self.query = self.query.with_start(start);
83    }
84    if let Some(end) = end {
85      self.query = self.query.with_end(end);
86    }
87
88    if (self.query.interval().start().is_some() || self.query.interval().end().is_some())
89      && self
90        .query
91        .reference_name()
92        .filter(|name| *name != "*")
93        .is_none()
94    {
95      return Err(HtsGetError::InvalidInput(
96        "reference name must be specified with start or end range".to_string(),
97      ));
98    }
99
100    if let (Some(start), Some(end)) = &(self.query.interval().start(), self.query.interval().end())
101    {
102      if start > end {
103        return Err(HtsGetError::InvalidRange(format!(
104          "end is greater than start (`{start}` > `{end}`)"
105        )));
106      }
107    }
108
109    Ok(self)
110  }
111
112  #[instrument(level = "trace", skip_all, ret)]
113  pub fn with_fields(self, fields: Option<impl Into<String>>) -> Self {
114    self.with_fields_from_vec(
115      fields.map(|fields| fields.into().split(',').map(|s| s.to_string()).collect()),
116    )
117  }
118
119  pub fn with_fields_from_vec(mut self, fields: Option<Vec<impl Into<String>>>) -> Self {
120    if let Some(fields) = fields {
121      self.query = self
122        .query
123        .with_fields(Fields::List(fields.into_iter().map(Into::into).collect()));
124    }
125
126    self
127  }
128
129  #[instrument(level = "trace", skip_all, ret)]
130  pub fn with_tags(
131    self,
132    tags: Option<impl Into<String>>,
133    notags: Option<impl Into<String>>,
134  ) -> Result<Self> {
135    self.with_tags_from_vec(
136      tags.map(|tags| tags.into().split(',').map(|s| s.to_string()).collect()),
137      notags.map(|notags| notags.into().split(',').map(|s| s.to_string()).collect()),
138    )
139  }
140
141  pub fn with_tags_from_vec(
142    mut self,
143    tags: Option<Vec<impl Into<String>>>,
144    notags: Option<Vec<impl Into<String>>>,
145  ) -> Result<Self> {
146    let notags = match notags {
147      Some(notags) => notags.into_iter().map(Into::into).collect(),
148      None => vec![],
149    };
150
151    if let Some(tags) = tags {
152      let tags: HashSet<String> = tags.into_iter().map(Into::into).collect();
153      if tags.iter().any(|tag| notags.contains(tag)) {
154        return Err(HtsGetError::InvalidInput(
155          "tags and notags can't intersect".to_string(),
156        ));
157      }
158      self.query = self.query.with_tags(Tags::List(tags));
159    };
160
161    if !notags.is_empty() {
162      self.query = self.query.with_no_tags(notags);
163    }
164
165    Ok(self)
166  }
167
168  /// Set the encryption scheme.
169  #[cfg(feature = "experimental")]
170  pub fn with_encryption_scheme(
171    mut self,
172    encryption_scheme: Option<impl Into<String>>,
173  ) -> Result<Self> {
174    if let Some(scheme) = encryption_scheme {
175      let scheme = match scheme.into().to_lowercase().as_str() {
176        "c4gh" => Ok(EncryptionScheme::C4GH),
177        scheme => Err(HtsGetError::UnsupportedFormat(format!(
178          "invalid encryption scheme `{scheme}`"
179        ))),
180      }?;
181
182      self.query = self.query.with_encryption_scheme(scheme);
183    }
184
185    Ok(self)
186  }
187}
188
#[cfg(test)]
mod tests {
  use htsget_config::types::Format::{Bam, Vcf};
  use htsget_config::types::NoTags;

  use super::*;

  /// The id of the query comes from the request.
  #[test]
  fn query_with_id() {
    let request = Request::new_with_id("ValidId".to_string());
    assert_eq!(
      QueryBuilder::new(request, Bam).build().id(),
      "ValidId".to_string()
    );
  }

  /// The format passed to the builder is kept on the query.
  #[test]
  fn query_with_format() {
    let request = Request::new_with_id("ValidId".to_string());
    assert_eq!(QueryBuilder::new(request, Vcf).build().format(), Vcf);
  }

  /// The `header` class is accepted.
  #[test]
  fn query_with_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_class(Some("header"))
        .unwrap()
        .build()
        .class(),
      Class::Header
    );
  }

  /// Any class other than `header` is rejected.
  #[test]
  fn query_with_invalid_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_class(Some("invalid"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  /// The reference name is stored on the query.
  #[test]
  fn query_with_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .build()
        .reference_name(),
      Some("ValidName")
    );
  }

  /// A valid range is parsed and stored on the query.
  #[test]
  fn query_with_range() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_reference_name(Some("ValidName"))
      .with_range(Some("3"), Some("5"))
      .unwrap()
      .build();
    assert_eq!(
      (query.interval().start(), query.interval().end()),
      (Some(3), Some(5))
    );
  }

  /// A range requires a reference name.
  #[test]
  fn query_with_range_but_without_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_range(Some("3"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  /// A start that is not a number is rejected.
  #[test]
  fn query_with_invalid_start() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("a"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  /// An end that is not a number is rejected.
  #[test]
  fn query_with_invalid_end() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("5"), Some("a"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  /// A start greater than the end is rejected.
  #[test]
  fn query_with_invalid_range() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("5"), Some("3"))
        .unwrap_err(),
      HtsGetError::InvalidRange(_)
    ));
  }

  /// A comma separated list of fields is split into its elements.
  #[test]
  fn query_with_fields() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_fields(Some("header,part1,part2"))
        .build()
        .fields(),
      &Fields::List(HashSet::from_iter(vec![
        "header".to_string(),
        "part1".to_string(),
        "part2".to_string()
      ]))
    );
  }

  /// Disjoint tags and notags are both stored on the query.
  #[test]
  fn query_with_tags() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_tags(Some("header,part1,part2"), Some("part3"))
      .unwrap()
      .build();
    assert_eq!(
      query.tags(),
      &Tags::List(HashSet::from_iter(vec![
        "header".to_string(),
        "part1".to_string(),
        "part2".to_string()
      ]))
    );
    assert_eq!(
      query.no_tags(),
      &NoTags(Some(HashSet::from_iter(vec!["part3".to_string()])))
    );
  }

  /// Tags and notags that share a value are rejected.
  #[test]
  fn query_with_invalid_tags() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_tags(Some("header,part1,part2"), Some("part1"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }
}