1use std::collections::HashSet;
2
3#[cfg(feature = "experimental")]
4use htsget_config::encryption_scheme::EncryptionScheme;
5use htsget_config::types::{Class, Fields, Format, Interval, Query, Request, Tags};
6use tracing::instrument;
7
8use crate::error::{HtsGetError, Result};
9
/// Builder that incrementally assembles a [`Query`] from optional request parameters, most of
/// which are accepted in string form and validated as they are applied.
#[derive(Debug)]
pub struct QueryBuilder {
  /// The query under construction; handed back to the caller by `build`.
  query: Query,
}
15
16impl QueryBuilder {
17 pub fn new(request: Request, format: Format) -> Self {
18 let id = request.path().to_string();
19
20 Self {
21 query: Query::new(id, format, request),
22 }
23 }
24
25 pub fn build(self) -> Query {
26 self.query
27 }
28
29 #[instrument(level = "trace", skip_all, ret)]
30 pub fn with_class(mut self, class: Option<impl Into<String>>) -> Result<Self> {
31 let class = class.map(Into::into);
32
33 self.query = self.query.with_class(match class {
34 None => Class::Body,
35 Some(class) if class == "header" => Class::Header,
36 Some(class) => {
37 return Err(HtsGetError::InvalidInput(format!(
38 "invalid class `{class}`"
39 )));
40 }
41 });
42
43 Ok(self)
44 }
45
46 #[instrument(level = "trace", skip_all, ret)]
47 pub fn with_reference_name(mut self, reference_name: Option<impl Into<String>>) -> Self {
48 if let Some(reference_name) = reference_name {
49 self.query = self.query.with_reference_name(reference_name);
50 }
51 self
52 }
53
54 pub fn with_interval(mut self, interval: Interval) -> Self {
56 self.query = self.query.with_interval(interval);
57 self
58 }
59
60 #[instrument(level = "trace", skip_all, ret)]
61 pub fn with_range(
62 self,
63 start: Option<impl Into<String>>,
64 end: Option<impl Into<String>>,
65 ) -> Result<Self> {
66 let start = start
67 .map(Into::into)
68 .map(|start| {
69 start
70 .parse::<u32>()
71 .map_err(|err| HtsGetError::InvalidInput(format!("`{start}` isn't a valid start: {err}")))
72 })
73 .transpose()?;
74 let end = end
75 .map(Into::into)
76 .map(|end| {
77 end
78 .parse::<u32>()
79 .map_err(|err| HtsGetError::InvalidInput(format!("`{end}` isn't a valid end: {err}")))
80 })
81 .transpose()?;
82
83 self.with_range_from_u32(start, end)
84 }
85
86 pub fn with_range_from_u32(mut self, start: Option<u32>, end: Option<u32>) -> Result<Self> {
87 if let Some(start) = start {
88 self.query = self.query.with_start(start);
89 }
90 if let Some(end) = end {
91 self.query = self.query.with_end(end);
92 }
93
94 if (self.query.interval().start().is_some() || self.query.interval().end().is_some())
95 && self
96 .query
97 .reference_name()
98 .filter(|name| *name != "*")
99 .is_none()
100 {
101 return Err(HtsGetError::InvalidInput(
102 "reference name must be specified with start or end range".to_string(),
103 ));
104 }
105
106 if let (Some(start), Some(end)) = &(self.query.interval().start(), self.query.interval().end())
107 {
108 if start > end {
109 return Err(HtsGetError::InvalidRange(format!(
110 "end is greater than start (`{start}` > `{end}`)"
111 )));
112 }
113 }
114
115 Ok(self)
116 }
117
118 #[instrument(level = "trace", skip_all, ret)]
119 pub fn with_fields(self, fields: Option<impl Into<String>>) -> Self {
120 self.with_fields_from_vec(
121 fields.map(|fields| fields.into().split(',').map(|s| s.to_string()).collect()),
122 )
123 }
124
125 pub fn with_fields_from_vec(mut self, fields: Option<Vec<impl Into<String>>>) -> Self {
126 if let Some(fields) = fields {
127 self.query = self
128 .query
129 .with_fields(Fields::List(fields.into_iter().map(Into::into).collect()));
130 }
131
132 self
133 }
134
135 #[instrument(level = "trace", skip_all, ret)]
136 pub fn with_tags(
137 self,
138 tags: Option<impl Into<String>>,
139 notags: Option<impl Into<String>>,
140 ) -> Result<Self> {
141 self.with_tags_from_vec(
142 tags.map(|tags| tags.into().split(',').map(|s| s.to_string()).collect()),
143 notags.map(|notags| notags.into().split(',').map(|s| s.to_string()).collect()),
144 )
145 }
146
147 pub fn with_tags_from_vec(
148 mut self,
149 tags: Option<Vec<impl Into<String>>>,
150 notags: Option<Vec<impl Into<String>>>,
151 ) -> Result<Self> {
152 let notags = match notags {
153 Some(notags) => notags.into_iter().map(Into::into).collect(),
154 None => vec![],
155 };
156
157 if let Some(tags) = tags {
158 let tags: HashSet<String> = tags.into_iter().map(Into::into).collect();
159 if tags.iter().any(|tag| notags.contains(tag)) {
160 return Err(HtsGetError::InvalidInput(
161 "tags and notags can't intersect".to_string(),
162 ));
163 }
164 self.query = self.query.with_tags(Tags::List(tags));
165 };
166
167 if !notags.is_empty() {
168 self.query = self.query.with_no_tags(notags);
169 }
170
171 Ok(self)
172 }
173
174 #[cfg(feature = "experimental")]
176 pub fn with_encryption_scheme(
177 mut self,
178 encryption_scheme: Option<impl Into<String>>,
179 ) -> Result<Self> {
180 if let Some(scheme) = encryption_scheme {
181 let scheme = match scheme.into().to_lowercase().as_str() {
182 "c4gh" => Ok(EncryptionScheme::C4GH),
183 scheme => Err(HtsGetError::UnsupportedFormat(format!(
184 "invalid encryption scheme `{scheme}`"
185 ))),
186 }?;
187
188 self.query = self.query.with_encryption_scheme(scheme);
189 }
190
191 Ok(self)
192 }
193}
194
195#[cfg(test)]
196mod tests {
197 use htsget_config::types::Format::{Bam, Vcf};
198 use htsget_config::types::NoTags;
199
200 use super::*;
201
202 #[test]
203 fn query_with_id() {
204 let request = Request::new_with_id("ValidId".to_string());
205 assert_eq!(
206 QueryBuilder::new(request, Bam).build().id(),
207 "ValidId".to_string()
208 );
209 }
210
211 #[test]
212 fn query_with_format() {
213 let request = Request::new_with_id("ValidId".to_string());
214 assert_eq!(QueryBuilder::new(request, Vcf).build().format(), Vcf);
215 }
216
217 #[test]
218 fn query_with_class() {
219 let request = Request::new_with_id("ValidId".to_string());
220
221 assert_eq!(
222 QueryBuilder::new(request, Bam)
223 .with_class(Some("header"))
224 .unwrap()
225 .build()
226 .class(),
227 Class::Header
228 );
229 }
230
231 #[test]
232 fn query_with_reference_name() {
233 let request = Request::new_with_id("ValidId".to_string());
234
235 assert_eq!(
236 QueryBuilder::new(request, Bam)
237 .with_reference_name(Some("ValidName"))
238 .build()
239 .reference_name(),
240 Some("ValidName")
241 );
242 }
243
244 #[test]
245 fn query_with_range() {
246 let request = Request::new_with_id("ValidId".to_string());
247
248 let query = QueryBuilder::new(request, Bam)
249 .with_reference_name(Some("ValidName"))
250 .with_range(Some("3"), Some("5"))
251 .unwrap()
252 .build();
253 assert_eq!(
254 (query.interval().start(), query.interval().end()),
255 (Some(3), Some(5))
256 );
257 }
258
259 #[test]
260 fn query_with_range_but_without_reference_name() {
261 let request = Request::new_with_id("ValidId".to_string());
262
263 assert!(matches!(
264 QueryBuilder::new(request, Bam)
265 .with_range(Some("3"), Some("5"))
266 .unwrap_err(),
267 HtsGetError::InvalidInput(_)
268 ));
269 }
270
271 #[test]
272 fn query_with_invalid_start() {
273 let request = Request::new_with_id("ValidId".to_string());
274
275 assert!(matches!(
276 QueryBuilder::new(request, Bam)
277 .with_reference_name(Some("ValidName"))
278 .with_range(Some("a"), Some("5"))
279 .unwrap_err(),
280 HtsGetError::InvalidInput(_)
281 ));
282 }
283
284 #[test]
285 fn query_with_invalid_end() {
286 let request = Request::new_with_id("ValidId".to_string());
287
288 assert!(matches!(
289 QueryBuilder::new(request, Bam)
290 .with_reference_name(Some("ValidName"))
291 .with_range(Some("5"), Some("a"))
292 .unwrap_err(),
293 HtsGetError::InvalidInput(_)
294 ));
295 }
296
297 #[test]
298 fn query_with_invalid_range() {
299 let request = Request::new_with_id("ValidId".to_string());
300
301 assert!(matches!(
302 QueryBuilder::new(request, Bam)
303 .with_reference_name(Some("ValidName"))
304 .with_range(Some("5"), Some("3"))
305 .unwrap_err(),
306 HtsGetError::InvalidRange(_)
307 ));
308 }
309
310 #[test]
311 fn query_with_fields() {
312 let request = Request::new_with_id("ValidId".to_string());
313
314 assert_eq!(
315 QueryBuilder::new(request, Bam)
316 .with_fields(Some("header,part1,part2"))
317 .build()
318 .fields(),
319 &Fields::List(HashSet::from_iter(vec![
320 "header".to_string(),
321 "part1".to_string(),
322 "part2".to_string()
323 ]))
324 );
325 }
326
327 #[test]
328 fn query_with_tags() {
329 let request = Request::new_with_id("ValidId".to_string());
330
331 let query = QueryBuilder::new(request, Bam)
332 .with_tags(Some("header,part1,part2"), Some("part3"))
333 .unwrap()
334 .build();
335 assert_eq!(
336 query.tags(),
337 &Tags::List(HashSet::from_iter(vec![
338 "header".to_string(),
339 "part1".to_string(),
340 "part2".to_string()
341 ]))
342 );
343 assert_eq!(
344 query.no_tags(),
345 &NoTags(Some(HashSet::from_iter(vec!["part3".to_string()])))
346 );
347 }
348
349 #[test]
350 fn query_with_invalid_tags() {
351 let request = Request::new_with_id("ValidId".to_string());
352
353 let query = QueryBuilder::new(request, Bam)
354 .with_tags(Some("header,part1,part2"), Some("part3"))
355 .unwrap()
356 .build();
357 assert_eq!(
358 query.tags(),
359 &Tags::List(HashSet::from_iter(vec![
360 "header".to_string(),
361 "part1".to_string(),
362 "part2".to_string()
363 ]))
364 );
365 assert_eq!(
366 query.no_tags(),
367 &NoTags(Some(HashSet::from_iter(vec!["part3".to_string()])))
368 );
369 }
370}